aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
-rw-r--r--autogpts/autogpt/autogpt/core/resource/model_providers/schema.py58
1 files changed, 57 insertions, 1 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index 60df855f2..bb2e29490 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -1,8 +1,10 @@
import abc
import enum
+import logging
import math
from collections import defaultdict
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
ClassVar,
@@ -28,6 +30,9 @@ from autogpt.core.resource.schema import (
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.logs.utils import fmt_kwargs
+if TYPE_CHECKING:
+ from jsonschema import ValidationError
+
class ModelProviderService(str, enum.Enum):
"""A ModelService describes what kind of service the model provides."""
@@ -39,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
class ModelProviderName(str, enum.Enum):
OPENAI = "openai"
+ ANTHROPIC = "anthropic"
class ChatMessage(BaseModel):
@@ -100,6 +106,12 @@ class AssistantChatMessage(ChatMessage):
tool_calls: Optional[list[AssistantToolCall]] = None
+class ToolResultMessage(ChatMessage):
+ role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
+ is_error: bool = False
+ tool_call_id: str
+
+
class AssistantChatMessageDict(TypedDict, total=False):
role: str
content: str
@@ -146,6 +158,30 @@ class CompletionModelFunction(BaseModel):
)
return f"{self.name}: {self.description}. Params: ({params})"
+ def validate_call(
+ self, function_call: AssistantFunctionCall
+ ) -> tuple[bool, list["ValidationError"]]:
+ """
+ Validates the given function call against the function's parameter specs
+
+ Returns:
+ bool: Whether the given set of arguments is valid for this command
+ list[ValidationError]: Issues with the set of arguments (if any)
+
+ Raises:
+ ValueError: If the function_call doesn't call this function
+ """
+ if function_call.name != self.name:
+ raise ValueError(
+ f"Can't validate {function_call.name} call using {self.name} spec"
+ )
+
+ params_schema = JSONSchema(
+ type=JSONSchema.Type.OBJECT,
+ properties={name: spec for name, spec in self.parameters.items()},
+ )
+ return params_schema.validate_object(function_call.arguments)
+
class ModelInfo(BaseModel):
"""Struct for model information.
@@ -229,7 +265,7 @@ class ModelProviderBudget(ProviderBudget):
class ModelProviderSettings(ProviderSettings):
resource_type: ResourceType = ResourceType.MODEL
configuration: ModelProviderConfiguration
- credentials: ModelProviderCredentials
+ credentials: Optional[ModelProviderCredentials] = None
budget: Optional[ModelProviderBudget] = None
@@ -238,9 +274,28 @@ class ModelProvider(abc.ABC):
default_settings: ClassVar[ModelProviderSettings]
+ _settings: ModelProviderSettings
_configuration: ModelProviderConfiguration
+ _credentials: Optional[ModelProviderCredentials] = None
_budget: Optional[ModelProviderBudget] = None
+ _logger: logging.Logger
+
+ def __init__(
+ self,
+ settings: Optional[ModelProviderSettings] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ if not settings:
+ settings = self.default_settings.copy(deep=True)
+
+ self._settings = settings
+ self._configuration = settings.configuration
+ self._credentials = settings.credentials
+ self._budget = settings.budget
+
+ self._logger = logger or logging.getLogger(self.__module__)
+
@abc.abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
...
@@ -358,6 +413,7 @@ class ChatModelProvider(ModelProvider):
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
+ prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
...