Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 51
1 file changed, 29 insertions(+), 22 deletions(-)
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index ccf3255b4..2ed667725 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -1,5 +1,6 @@
import abc
import enum
+import math
from typing import (
Callable,
ClassVar,
@@ -13,7 +14,7 @@ from typing import (
from pydantic import BaseModel, Field, SecretStr, validator
-from autogpt.core.configuration import UserConfigurable
+from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.resource.schema import (
Embedding,
ProviderBudget,
@@ -90,7 +91,7 @@ class AssistantToolCallDict(TypedDict):
class AssistantChatMessage(ChatMessage):
- role: Literal["assistant"]
+ role: Literal["assistant"] = "assistant"
content: Optional[str]
tool_calls: Optional[list[AssistantToolCall]]
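With the default in place, callers can construct an assistant message without spelling out the role. A minimal usage sketch:

    msg = AssistantChatMessage(content="Hello!", tool_calls=None)
    assert msg.role == "assistant"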
@@ -163,6 +164,11 @@ class ModelResponse(BaseModel):
model_info: ModelInfo
+class ModelProviderConfiguration(SystemConfiguration):
+ retries_per_request: int = UserConfigurable()
+ extra_request_headers: dict[str, str] = Field(default_factory=dict)
+
+
class ModelProviderCredentials(ProviderCredentials):
"""Credentials for a model provider."""
@@ -172,24 +178,10 @@ class ModelProviderCredentials(ProviderCredentials):
api_version: SecretStr | None = UserConfigurable(default=None)
deployment_id: SecretStr | None = UserConfigurable(default=None)
- def unmasked(self) -> dict:
- return unmask(self)
-
class Config:
extra = "ignore"
-def unmask(model: BaseModel):
- unmasked_fields = {}
- for field_name, field in model.__fields__.items():
- value = getattr(model, field_name)
- if isinstance(value, SecretStr):
- unmasked_fields[field_name] = value.get_secret_value()
- else:
- unmasked_fields[field_name] = value
- return unmasked_fields
-
-
class ModelProviderUsage(ProviderUsage):
"""Usage for a particular model from a model provider."""
@@ -217,8 +209,12 @@ class ModelProviderBudget(ProviderBudget):
def update_usage_and_cost(
self,
model_response: ModelResponse,
- ) -> None:
- """Update the usage and cost of the provider."""
+ ) -> float:
+ """Update the usage and cost of the provider.
+
+ Returns:
+ float: The (calculated) cost of the given model response.
+ """
model_info = model_response.model_info
self.usage.update_usage(model_response)
incurred_cost = (
@@ -227,10 +223,12 @@ class ModelProviderBudget(ProviderBudget):
)
self.total_cost += incurred_cost
self.remaining_budget -= incurred_cost
+ return incurred_cost
class ModelProviderSettings(ProviderSettings):
resource_type: ResourceType = ResourceType.MODEL
+ configuration: ModelProviderConfiguration
credentials: ModelProviderCredentials
budget: ModelProviderBudget
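Returning the incurred cost lets call sites report per-request spend without recomputing it from the usage totals. A hypothetical caller, assuming budget and response objects are already in scope:

    cost = budget.update_usage_and_cost(model_response=response)
    print(f"request cost: ${cost:.5f} (remaining: ${budget.remaining_budget:.2f})")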
@@ -240,6 +238,9 @@ class ModelProvider(abc.ABC):
default_settings: ClassVar[ModelProviderSettings]
+ _budget: Optional[ModelProviderBudget]
+ _configuration: ModelProviderConfiguration
+
@abc.abstractmethod
def count_tokens(self, text: str, model_name: str) -> int:
...
@@ -252,9 +253,15 @@ class ModelProvider(abc.ABC):
def get_token_limit(self, model_name: str) -> int:
...
- @abc.abstractmethod
+ def get_incurred_cost(self) -> float:
+ if self._budget:
+ return self._budget.total_cost
+ return 0
+
def get_remaining_budget(self) -> float:
- ...
+ if self._budget:
+ return self._budget.remaining_budget
+ return math.inf
class ModelTokenizer(Protocol):
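The two getters are no longer abstract: a provider without an attached budget now reports zero incurred cost and an unlimited remaining budget. A behavior sketch (poking the private _budget attribute only for illustration):

    provider._budget = None
    assert provider.get_incurred_cost() == 0
    assert provider.get_remaining_budget() == math.inf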
@@ -326,7 +333,7 @@ _T = TypeVar("_T")
class ChatModelResponse(ModelResponse, Generic[_T]):
"""Standard response struct for a response from a language model."""
- response: AssistantChatMessageDict
+ response: AssistantChatMessage
parsed_result: _T = None
@@ -344,7 +351,7 @@ class ChatModelProvider(ModelProvider):
self,
model_prompt: list[ChatMessage],
model_name: str,
- completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
+ completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:
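Since the completion parser now receives the AssistantChatMessage model rather than a TypedDict, parsers use attribute access. A hypothetical parser to pass as completion_parser; its return value lands in ChatModelResponse.parsed_result:

    def parse_reply(message: AssistantChatMessage) -> str:
        # Attribute access replaces the old message["content"] lookup.
        return message.content or ""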