diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 58 |
1 files changed, 57 insertions, 1 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py index 60df855f2..bb2e29490 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py @@ -1,8 +1,10 @@ import abc import enum +import logging import math from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -28,6 +30,9 @@ from autogpt.core.resource.schema import ( from autogpt.core.utils.json_schema import JSONSchema from autogpt.logs.utils import fmt_kwargs +if TYPE_CHECKING: + from jsonschema import ValidationError + class ModelProviderService(str, enum.Enum): """A ModelService describes what kind of service the model provides.""" @@ -39,6 +44,7 @@ class ModelProviderService(str, enum.Enum): class ModelProviderName(str, enum.Enum): OPENAI = "openai" + ANTHROPIC = "anthropic" class ChatMessage(BaseModel): @@ -100,6 +106,12 @@ class AssistantChatMessage(ChatMessage): tool_calls: Optional[list[AssistantToolCall]] = None +class ToolResultMessage(ChatMessage): + role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL + is_error: bool = False + tool_call_id: str + + class AssistantChatMessageDict(TypedDict, total=False): role: str content: str @@ -146,6 +158,30 @@ class CompletionModelFunction(BaseModel): ) return f"{self.name}: {self.description}. Params: ({params})" + def validate_call( + self, function_call: AssistantFunctionCall + ) -> tuple[bool, list["ValidationError"]]: + """ + Validates the given function call against the function's parameter specs + + Returns: + bool: Whether the given set of arguments is valid for this command + list[ValidationError]: Issues with the set of arguments (if any) + + Raises: + ValueError: If the function_call doesn't call this function + """ + if function_call.name != self.name: + raise ValueError( + f"Can't validate {function_call.name} call using {self.name} spec" + ) + + params_schema = JSONSchema( + type=JSONSchema.Type.OBJECT, + properties={name: spec for name, spec in self.parameters.items()}, + ) + return params_schema.validate_object(function_call.arguments) + class ModelInfo(BaseModel): """Struct for model information. @@ -229,7 +265,7 @@ class ModelProviderBudget(ProviderBudget): class ModelProviderSettings(ProviderSettings): resource_type: ResourceType = ResourceType.MODEL configuration: ModelProviderConfiguration - credentials: ModelProviderCredentials + credentials: Optional[ModelProviderCredentials] = None budget: Optional[ModelProviderBudget] = None @@ -238,9 +274,28 @@ class ModelProvider(abc.ABC): default_settings: ClassVar[ModelProviderSettings] + _settings: ModelProviderSettings _configuration: ModelProviderConfiguration + _credentials: Optional[ModelProviderCredentials] = None _budget: Optional[ModelProviderBudget] = None + _logger: logging.Logger + + def __init__( + self, + settings: Optional[ModelProviderSettings] = None, + logger: Optional[logging.Logger] = None, + ): + if not settings: + settings = self.default_settings.copy(deep=True) + + self._settings = settings + self._configuration = settings.configuration + self._credentials = settings.credentials + self._budget = settings.budget + + self._logger = logger or logging.getLogger(self.__module__) + @abc.abstractmethod def count_tokens(self, text: str, model_name: str) -> int: ... @@ -358,6 +413,7 @@ class ChatModelProvider(ModelProvider): completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, max_output_tokens: Optional[int] = None, + prefill_response: str = "", **kwargs, ) -> ChatModelResponse[_T]: ... |