aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/llm/api_manager.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/llm/api_manager.py')
-rw-r--r--autogpts/autogpt/autogpt/llm/api_manager.py115
1 file changed, 115 insertions(+), 0 deletions(-)
diff --git a/autogpts/autogpt/autogpt/llm/api_manager.py b/autogpts/autogpt/autogpt/llm/api_manager.py
new file mode 100644
index 000000000..35d28d632
--- /dev/null
+++ b/autogpts/autogpt/autogpt/llm/api_manager.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import logging
+from typing import List, Optional
+
+from openai import OpenAI
+from openai.types import Model
+
+from autogpt.core.resource.model_providers.openai import (
+ OPEN_AI_MODELS,
+ OpenAICredentials,
+)
+from autogpt.core.resource.model_providers.schema import ChatModelInfo
+from autogpt.singleton import Singleton
+
+logger = logging.getLogger(__name__)
+
+
class ApiManager(metaclass=Singleton):
    """Track token usage, running cost, and user budget for OpenAI API calls.

    Implemented as a singleton (via the ``Singleton`` metaclass) so every
    caller in the process accumulates into the same running totals.
    """

    def __init__(self):
        # Delegate to reset() so the freshly-constructed state and the
        # post-reset state can never drift apart (the original initialized
        # total_budget as int 0 here but float 0.0 in reset()).
        self.reset()

    def reset(self):
        """Reset all usage counters and drop the cached model list."""
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        # Monetary values are kept as floats for type consistency.
        self.total_cost = 0.0
        self.total_budget = 0.0
        self.models: Optional[list[Model]] = None

    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
        """
        Update the total cost, prompt tokens, and completion tokens.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.

        Raises:
            KeyError: If the (suffix-stripped) model name is not a key in
                OPEN_AI_MODELS.
        """
        # The .model property in API responses can contain version suffixes
        # like -v2; strip the suffix so the name matches the pricing table.
        # Tie the slice length to the literal so they cannot drift apart.
        if model.endswith("-v2"):
            model = model[: -len("-v2")]
        model_info = OPEN_AI_MODELS[model]

        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens
        self.total_cost += prompt_tokens * model_info.prompt_token_cost / 1000
        # Only chat models define a separate completion-token price;
        # other model kinds are billed on prompt tokens alone.
        if isinstance(model_info, ChatModelInfo):
            self.total_cost += (
                completion_tokens * model_info.completion_token_cost / 1000
            )

        logger.debug(f"Total running cost: ${self.total_cost:.3f}")

    def set_total_budget(self, total_budget: float):
        """
        Sets the total user-defined budget for API calls.

        Args:
            total_budget (float): The total budget for API calls.
        """
        self.total_budget = total_budget

    def get_total_prompt_tokens(self) -> int:
        """
        Get the total number of prompt tokens.

        Returns:
            int: The total number of prompt tokens.
        """
        return self.total_prompt_tokens

    def get_total_completion_tokens(self) -> int:
        """
        Get the total number of completion tokens.

        Returns:
            int: The total number of completion tokens.
        """
        return self.total_completion_tokens

    def get_total_cost(self) -> float:
        """
        Get the total cost of API calls.

        Returns:
            float: The total cost of API calls.
        """
        return self.total_cost

    def get_total_budget(self) -> float:
        """
        Get the total user-defined budget for API calls.

        Returns:
            float: The total budget for API calls.
        """
        return self.total_budget

    def get_models(self, openai_credentials: OpenAICredentials) -> List[Model]:
        """
        Get list of available GPT models.

        The result is fetched once and cached on the instance; call reset()
        to force a refresh on the next call.

        Args:
            openai_credentials (OpenAICredentials): Credentials used to build
                the OpenAI client for the listing request.

        Returns:
            list[Model]: List of available GPT models (ids containing "gpt").
        """
        if self.models is None:
            all_models = (
                OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data
            )
            # Keep only chat/GPT models; other entries (embeddings, audio,
            # etc.) are irrelevant to this manager.
            self.models = [model for model in all_models if "gpt" in model.id]

        return self.models