Diffstat (limited to 'autogpt/llm/base.py')
-rw-r--r--  autogpt/llm/base.py  195
1 file changed, 0 insertions, 195 deletions
diff --git a/autogpt/llm/base.py b/autogpt/llm/base.py
deleted file mode 100644
index 14a146b3c..000000000
--- a/autogpt/llm/base.py
+++ /dev/null
@@ -1,195 +0,0 @@
-from __future__ import annotations
-
-from copy import deepcopy
-from dataclasses import dataclass, field
-from math import ceil, floor
-from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
-
-if TYPE_CHECKING:
-    from autogpt.llm.providers.openai import OpenAIFunctionCall
-
-MessageRole = Literal["system", "user", "assistant", "function"]
-MessageType = Literal["ai_response", "action_result"]
-
-TText = list[int]
-"""Token array representing tokenized text"""
-
-
-class MessageDict(TypedDict):
-    role: MessageRole
-    content: str
-
-
-class ResponseMessageDict(TypedDict):
-    role: Literal["assistant"]
-    content: Optional[str]
-    function_call: Optional[FunctionCallDict]
-
-
-class FunctionCallDict(TypedDict):
-    name: str
-    arguments: str
-
-
-@dataclass
-class Message:
-    """OpenAI Message object containing a role and the message content"""
-
-    role: MessageRole
-    content: str
-    type: MessageType | None = None
-
-    def raw(self) -> MessageDict:
-        return {"role": self.role, "content": self.content}
-
-
-@dataclass
-class ModelInfo:
-    """Struct for model information.
-
-    Would be lovely to eventually get this directly from APIs, but needs to be scraped from
-    websites for now.
-    """
-
-    name: str
-    max_tokens: int
-    prompt_token_cost: float
-
-
-@dataclass
-class CompletionModelInfo(ModelInfo):
-    """Struct for generic completion model information."""
-
-    completion_token_cost: float
-
-
-@dataclass
-class ChatModelInfo(CompletionModelInfo):
-    """Struct for chat model information."""
-
-
-@dataclass
-class TextModelInfo(CompletionModelInfo):
-    """Struct for text completion model information."""
-
-
-@dataclass
-class EmbeddingModelInfo(ModelInfo):
-    """Struct for embedding model information."""
-
-    embedding_dimensions: int
-
-
-# Can be replaced by Self in Python 3.11
-TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")
-
-
-@dataclass
-class ChatSequence:
-    """Utility container for a chat sequence"""
-
-    model: ChatModelInfo
-    messages: list[Message] = field(default_factory=list[Message])
-
-    @overload
-    def __getitem__(self, key: int) -> Message:
-        ...
-
-    @overload
-    def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:
-        ...
-
-    def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
-        if isinstance(key, slice):
-            copy = deepcopy(self)
-            copy.messages = self.messages[key]
-            return copy
-        return self.messages[key]
-
-    def __iter__(self):
-        return iter(self.messages)
-
-    def __len__(self):
-        return len(self.messages)
-
-    def add(
-        self,
-        message_role: MessageRole,
-        content: str,
-        type: MessageType | None = None,
-    ) -> None:
-        self.append(Message(message_role, content, type))
-
-    def append(self, message: Message):
-        return self.messages.append(message)
-
-    def extend(self, messages: list[Message] | ChatSequence):
-        return self.messages.extend(messages)
-
-    def insert(self, index: int, *messages: Message):
-        for message in reversed(messages):
-            self.messages.insert(index, message)
-
-    @classmethod
-    def for_model(
-        cls: Type[TChatSequence],
-        model_name: str,
-        messages: list[Message] | ChatSequence = [],
-        **kwargs,
-    ) -> TChatSequence:
-        from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
-
-        if not model_name in OPEN_AI_CHAT_MODELS:
-            raise ValueError(f"Unknown chat model '{model_name}'")
-
-        return cls(
-            model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages), **kwargs
-        )
-
-    @property
-    def token_length(self) -> int:
-        from autogpt.llm.utils import count_message_tokens
-
-        return count_message_tokens(self.messages, self.model.name)
-
-    def raw(self) -> list[MessageDict]:
-        return [m.raw() for m in self.messages]
-
-    def dump(self) -> str:
-        SEPARATOR_LENGTH = 42
-
-        def separator(text: str):
-            half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
-            return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"
-
-        formatted_messages = "\n".join(
-            [f"{separator(m.role)}\n{m.content}" for m in self.messages]
-        )
-        return f"""
-============== {__class__.__name__} ==============
-Length: {self.token_length} tokens; {len(self.messages)} messages
-{formatted_messages}
-==========================================
-"""
-
-
-@dataclass
-class LLMResponse:
-    """Standard response struct for a response from an LLM model."""
-
-    model_info: ModelInfo
-
-
-@dataclass
-class EmbeddingModelResponse(LLMResponse):
-    """Standard response struct for a response from an embedding model."""
-
-    embedding: list[float] = field(default_factory=list)
-
-
-@dataclass
-class ChatModelResponse(LLMResponse):
-    """Standard response struct for a response from a chat LLM."""
-
-    content: Optional[str]
-    function_call: Optional[OpenAIFunctionCall]
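For context, the removed ChatSequence was the utility this module exposed for assembling prompts. A minimal usage sketch, assuming a revision where autogpt.llm.base is still importable and that "gpt-3.5-turbo" is a key in the OPEN_AI_CHAT_MODELS registry that for_model() consults:

    from autogpt.llm.base import ChatSequence

    # for_model() looks the name up in OPEN_AI_CHAT_MODELS and raises
    # ValueError for unknown chat models.
    prompt = ChatSequence.for_model("gpt-3.5-turbo")
    prompt.add("system", "You are a helpful assistant.")
    prompt.add("user", "Summarize this module's design.")

    # raw() produces the list[MessageDict] payload shape sent to the API.
    api_messages = prompt.raw()

    # Slicing deep-copies the container, so the result is a new
    # ChatSequence rather than a bare list of Messages.
    without_system = prompt[1:]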