diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 51 |
1 files changed, 29 insertions, 22 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py index ccf3255b4..2ed667725 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py @@ -1,5 +1,6 @@ import abc import enum +import math from typing import ( Callable, ClassVar, @@ -13,7 +14,7 @@ from typing import ( from pydantic import BaseModel, Field, SecretStr, validator -from autogpt.core.configuration import UserConfigurable +from autogpt.core.configuration import SystemConfiguration, UserConfigurable from autogpt.core.resource.schema import ( Embedding, ProviderBudget, @@ -90,7 +91,7 @@ class AssistantToolCallDict(TypedDict): class AssistantChatMessage(ChatMessage): - role: Literal["assistant"] + role: Literal["assistant"] = "assistant" content: Optional[str] tool_calls: Optional[list[AssistantToolCall]] @@ -163,6 +164,11 @@ class ModelResponse(BaseModel): model_info: ModelInfo +class ModelProviderConfiguration(SystemConfiguration): + retries_per_request: int = UserConfigurable() + extra_request_headers: dict[str, str] = Field(default_factory=dict) + + class ModelProviderCredentials(ProviderCredentials): """Credentials for a model provider.""" @@ -172,24 +178,10 @@ class ModelProviderCredentials(ProviderCredentials): api_version: SecretStr | None = UserConfigurable(default=None) deployment_id: SecretStr | None = UserConfigurable(default=None) - def unmasked(self) -> dict: - return unmask(self) - class Config: extra = "ignore" -def unmask(model: BaseModel): - unmasked_fields = {} - for field_name, field in model.__fields__.items(): - value = getattr(model, field_name) - if isinstance(value, SecretStr): - unmasked_fields[field_name] = value.get_secret_value() - else: - unmasked_fields[field_name] = value - return unmasked_fields - - class ModelProviderUsage(ProviderUsage): """Usage for a particular model from a model provider.""" @@ -217,8 +209,12 @@ class ModelProviderBudget(ProviderBudget): def update_usage_and_cost( self, model_response: ModelResponse, - ) -> None: - """Update the usage and cost of the provider.""" + ) -> float: + """Update the usage and cost of the provider. + + Returns: + float: The (calculated) cost of the given model response. + """ model_info = model_response.model_info self.usage.update_usage(model_response) incurred_cost = ( @@ -227,10 +223,12 @@ class ModelProviderBudget(ProviderBudget): ) self.total_cost += incurred_cost self.remaining_budget -= incurred_cost + return incurred_cost class ModelProviderSettings(ProviderSettings): resource_type: ResourceType = ResourceType.MODEL + configuration: ModelProviderConfiguration credentials: ModelProviderCredentials budget: ModelProviderBudget @@ -240,6 +238,9 @@ class ModelProvider(abc.ABC): default_settings: ClassVar[ModelProviderSettings] + _budget: Optional[ModelProviderBudget] + _configuration: ModelProviderConfiguration + @abc.abstractmethod def count_tokens(self, text: str, model_name: str) -> int: ... @@ -252,9 +253,15 @@ class ModelProvider(abc.ABC): def get_token_limit(self, model_name: str) -> int: ... - @abc.abstractmethod + def get_incurred_cost(self) -> float: + if self._budget: + return self._budget.total_cost + return 0 + def get_remaining_budget(self) -> float: - ... + if self._budget: + return self._budget.remaining_budget + return math.inf class ModelTokenizer(Protocol): @@ -326,7 +333,7 @@ _T = TypeVar("_T") class ChatModelResponse(ModelResponse, Generic[_T]): """Standard response struct for a response from a language model.""" - response: AssistantChatMessageDict + response: AssistantChatMessage parsed_result: _T = None @@ -344,7 +351,7 @@ class ChatModelProvider(ModelProvider): self, model_prompt: list[ChatMessage], model_name: str, - completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None, + completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, **kwargs, ) -> ChatModelResponse[_T]: |