Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/schema.py')
 autogpts/autogpt/autogpt/core/resource/model_providers/schema.py | 38 +++++++++++++++++---------------------
 1 file changed, 17 insertions(+), 21 deletions(-)
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index 327718c11..dd69b526e 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -1,6 +1,7 @@
 import abc
 import enum
 import math
+from collections import defaultdict
 from typing import (
     Any,
     Callable,
@@ -90,7 +91,7 @@ class AssistantToolCallDict(TypedDict):
 class AssistantChatMessage(ChatMessage):
-    role: Literal["assistant"] = "assistant"
+    role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT
     content: Optional[str]
     tool_calls: Optional[list[AssistantToolCall]] = None
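
The role default now comes from the ChatMessage.Role enum rather than a bare string literal. A minimal sketch of the effect, assuming ChatMessage.Role is an enum with an ASSISTANT member:

    msg = AssistantChatMessage(content="Command executed successfully.")
    assert msg.role == ChatMessage.Role.ASSISTANT  # enum member, not a raw string
    assert msg.tool_calls is None                  # still defaults to None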
@@ -187,39 +188,34 @@ class ModelProviderUsage(ProviderUsage):
     completion_tokens: int = 0
     prompt_tokens: int = 0
-    total_tokens: int = 0

     def update_usage(
         self,
-        model_response: ModelResponse,
+        input_tokens_used: int,
+        output_tokens_used: int = 0,
     ) -> None:
-        self.completion_tokens += model_response.completion_tokens_used
-        self.prompt_tokens += model_response.prompt_tokens_used
-        self.total_tokens += (
-            model_response.completion_tokens_used + model_response.prompt_tokens_used
-        )
+        self.prompt_tokens += input_tokens_used
+        self.completion_tokens += output_tokens_used


 class ModelProviderBudget(ProviderBudget):
-    total_budget: float = UserConfigurable()
-    total_cost: float
-    remaining_budget: float
-    usage: ModelProviderUsage
+    usage: defaultdict[str, ModelProviderUsage] = defaultdict(ModelProviderUsage)

     def update_usage_and_cost(
         self,
-        model_response: ModelResponse,
+        model_info: ModelInfo,
+        input_tokens_used: int,
+        output_tokens_used: int = 0,
     ) -> float:
         """Update the usage and cost of the provider.

         Returns:
             float: The (calculated) cost of the given model response.
         """
-        model_info = model_response.model_info
-        self.usage.update_usage(model_response)
+        self.usage[model_info.name].update_usage(input_tokens_used, output_tokens_used)
         incurred_cost = (
-            model_response.completion_tokens_used * model_info.completion_token_cost
-            + model_response.prompt_tokens_used * model_info.prompt_token_cost
+            output_tokens_used * model_info.completion_token_cost
+            + input_tokens_used * model_info.prompt_token_cost
         )
         self.total_cost += incurred_cost
         self.remaining_budget -= incurred_cost
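
This hunk drops the redundant total_tokens counter, keeps total_budget, total_cost, and remaining_budget on the ProviderBudget base class (the method body still updates them here), and switches usage to a defaultdict keyed by model name, so spend is tracked per model rather than in one aggregate bucket. A rough usage sketch, assuming the method goes on to return incurred_cost per its -> float signature, and that gpt4_info is a hypothetical ModelInfo with name="gpt-4" and its token costs set:

    budget = ModelProviderBudget(total_budget=10.0, total_cost=0.0, remaining_budget=10.0)
    cost = budget.update_usage_and_cost(
        model_info=gpt4_info,
        input_tokens_used=1500,
        output_tokens_used=300,
    )
    # Unknown model names get a fresh ModelProviderUsage automatically.
    assert budget.usage["gpt-4"].prompt_tokens == 1500
    assert budget.usage["gpt-4"].completion_tokens == 300
    assert budget.remaining_budget == 10.0 - cost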
@@ -230,7 +226,7 @@ class ModelProviderSettings(ProviderSettings):
     resource_type: ResourceType = ResourceType.MODEL
     configuration: ModelProviderConfiguration
     credentials: ModelProviderCredentials
-    budget: ModelProviderBudget
+    budget: Optional[ModelProviderBudget] = None


 class ModelProvider(abc.ABC):
@@ -238,8 +234,8 @@
     default_settings: ClassVar[ModelProviderSettings]

-    _budget: Optional[ModelProviderBudget]
     _configuration: ModelProviderConfiguration
+    _budget: Optional[ModelProviderBudget] = None

     @abc.abstractmethod
     def count_tokens(self, text: str, model_name: str) -> int:
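
With the budget now optional on both the settings and the provider (defaulting to None), concrete providers have to guard their spend tracking. A hedged sketch of one way to do that; the ExampleProvider class and _track_cost helper are illustrative, not part of this diff:

    class ExampleProvider(ModelProvider):
        def _track_cost(
            self, model_info: ModelInfo, input_tokens: int, output_tokens: int = 0
        ) -> float:
            # Budget tracking is opt-in after this change; skip it when unset.
            if self._budget is None:
                return 0.0
            return self._budget.update_usage_and_cost(
                model_info, input_tokens, output_tokens
            )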
@@ -284,7 +280,7 @@ class ModelTokenizer(Protocol):
 class EmbeddingModelInfo(ModelInfo):
     """Struct for embedding model information."""

-    llm_service = ModelProviderService.EMBEDDING
+    service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING
     max_tokens: int
     embedding_dimensions: int
@@ -322,7 +318,7 @@ class EmbeddingModelProvider(ModelProvider):
 class ChatModelInfo(ModelInfo):
     """Struct for language model information."""

-    llm_service = ModelProviderService.CHAT
+    service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT
     max_tokens: int
     has_function_call_api: bool = False
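
The old llm_service was an unannotated class attribute, invisible to pydantic as a field; the new annotated service field is part of the schema, and its Literal type pins the allowed value per subclass, giving ModelInfo variants a type-checkable discriminator. A small sketch, assuming ModelInfo carries name, provider_name, and per-token cost fields (all values hypothetical):

    info = ChatModelInfo(
        name="gpt-4",
        provider_name="openai",
        prompt_token_cost=0.03 / 1000,
        completion_token_cost=0.06 / 1000,
        max_tokens=8191,
        has_function_call_api=True,
    )
    assert info.service == ModelProviderService.CHAT  # fixed by the Literal default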