diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 38 |
1 files changed, 17 insertions, 21 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py index 327718c11..dd69b526e 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py @@ -1,6 +1,7 @@ import abc import enum import math +from collections import defaultdict from typing import ( Any, Callable, @@ -90,7 +91,7 @@ class AssistantToolCallDict(TypedDict): class AssistantChatMessage(ChatMessage): - role: Literal["assistant"] = "assistant" + role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT content: Optional[str] tool_calls: Optional[list[AssistantToolCall]] = None @@ -187,39 +188,34 @@ class ModelProviderUsage(ProviderUsage): completion_tokens: int = 0 prompt_tokens: int = 0 - total_tokens: int = 0 def update_usage( self, - model_response: ModelResponse, + input_tokens_used: int, + output_tokens_used: int = 0, ) -> None: - self.completion_tokens += model_response.completion_tokens_used - self.prompt_tokens += model_response.prompt_tokens_used - self.total_tokens += ( - model_response.completion_tokens_used + model_response.prompt_tokens_used - ) + self.prompt_tokens += input_tokens_used + self.completion_tokens += output_tokens_used class ModelProviderBudget(ProviderBudget): - total_budget: float = UserConfigurable() - total_cost: float - remaining_budget: float - usage: ModelProviderUsage + usage: defaultdict[str, ModelProviderUsage] = defaultdict(ModelProviderUsage) def update_usage_and_cost( self, - model_response: ModelResponse, + model_info: ModelInfo, + input_tokens_used: int, + output_tokens_used: int = 0, ) -> float: """Update the usage and cost of the provider. Returns: float: The (calculated) cost of the given model response. """ - model_info = model_response.model_info - self.usage.update_usage(model_response) + self.usage[model_info.name].update_usage(input_tokens_used, output_tokens_used) incurred_cost = ( - model_response.completion_tokens_used * model_info.completion_token_cost - + model_response.prompt_tokens_used * model_info.prompt_token_cost + output_tokens_used * model_info.completion_token_cost + + input_tokens_used * model_info.prompt_token_cost ) self.total_cost += incurred_cost self.remaining_budget -= incurred_cost @@ -230,7 +226,7 @@ class ModelProviderSettings(ProviderSettings): resource_type: ResourceType = ResourceType.MODEL configuration: ModelProviderConfiguration credentials: ModelProviderCredentials - budget: ModelProviderBudget + budget: Optional[ModelProviderBudget] = None class ModelProvider(abc.ABC): @@ -238,8 +234,8 @@ class ModelProvider(abc.ABC): default_settings: ClassVar[ModelProviderSettings] - _budget: Optional[ModelProviderBudget] _configuration: ModelProviderConfiguration + _budget: Optional[ModelProviderBudget] = None @abc.abstractmethod def count_tokens(self, text: str, model_name: str) -> int: @@ -284,7 +280,7 @@ class ModelTokenizer(Protocol): class EmbeddingModelInfo(ModelInfo): """Struct for embedding model information.""" - llm_service = ModelProviderService.EMBEDDING + service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING max_tokens: int embedding_dimensions: int @@ -322,7 +318,7 @@ class EmbeddingModelProvider(ModelProvider): class ChatModelInfo(ModelInfo): """Struct for language model information.""" - llm_service = ModelProviderService.CHAT + service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT max_tokens: int has_function_call_api: bool = False |