diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/llm/api_manager.py')
-rw-r--r-- | autogpts/autogpt/autogpt/llm/api_manager.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/autogpts/autogpt/autogpt/llm/api_manager.py b/autogpts/autogpt/autogpt/llm/api_manager.py index 146263626..35d28d632 100644 --- a/autogpts/autogpt/autogpt/llm/api_manager.py +++ b/autogpts/autogpt/autogpt/llm/api_manager.py @@ -3,10 +3,13 @@ from __future__ import annotations import logging from typing import List, Optional -import openai -from openai import Model +from openai import OpenAI +from openai.types import Model -from autogpt.core.resource.model_providers.openai import OPEN_AI_MODELS +from autogpt.core.resource.model_providers.openai import ( + OPEN_AI_MODELS, + OpenAICredentials, +) from autogpt.core.resource.model_providers.schema import ChatModelInfo from autogpt.singleton import Singleton @@ -96,16 +99,17 @@ class ApiManager(metaclass=Singleton): """ return self.total_budget - def get_models(self, **openai_credentials) -> List[Model]: + def get_models(self, openai_credentials: OpenAICredentials) -> List[Model]: """ Get list of available GPT models. Returns: - list: List of available GPT models. - + list[Model]: List of available GPT models. """ if self.models is None: - all_models = openai.Model.list(**openai_credentials)["data"] - self.models = [model for model in all_models if "gpt" in model["id"]] + all_models = ( + OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data + ) + self.models = [model for model in all_models if "gpt" in model.id] return self.models |