diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/app/main.py')
-rw-r--r-- | autogpts/autogpt/autogpt/app/main.py | 30 |
1 files changed, 9 insertions, 21 deletions
diff --git a/autogpts/autogpt/autogpt/app/main.py b/autogpts/autogpt/autogpt/app/main.py index aaab5fe48..04354fb10 100644 --- a/autogpts/autogpt/autogpt/app/main.py +++ b/autogpts/autogpt/autogpt/app/main.py @@ -35,7 +35,7 @@ from autogpt.config import ( ConfigBuilder, assert_config_has_openai_api_key, ) -from autogpt.core.resource.model_providers.openai import OpenAIProvider +from autogpt.core.resource.model_providers import MultiProvider from autogpt.core.runner.client_lib.utils import coroutine from autogpt.file_storage import FileStorageBackendName, get_storage from autogpt.logs.config import configure_logging @@ -123,7 +123,7 @@ async def run_auto_gpt( skip_news=skip_news, ) - llm_provider = _configure_openai_provider(config) + llm_provider = _configure_llm_provider(config) logger = logging.getLogger(__name__) @@ -399,7 +399,7 @@ async def run_auto_gpt_server( allow_downloads=allow_downloads, ) - llm_provider = _configure_openai_provider(config) + llm_provider = _configure_llm_provider(config) # Set up & start server database = AgentDB( @@ -421,24 +421,12 @@ async def run_auto_gpt_server( ) -def _configure_openai_provider(config: Config) -> OpenAIProvider: - """Create a configured OpenAIProvider object. - - Args: - config: The program's configuration. - - Returns: - A configured OpenAIProvider object. - """ - if config.openai_credentials is None: - raise RuntimeError("OpenAI key is not configured") - - openai_settings = OpenAIProvider.default_settings.copy(deep=True) - openai_settings.credentials = config.openai_credentials - return OpenAIProvider( - settings=openai_settings, - logger=logging.getLogger("OpenAIProvider"), - ) +def _configure_llm_provider(config: Config) -> MultiProvider: + multi_provider = MultiProvider() + for model in [config.smart_llm, config.fast_llm]: + # Ensure model providers for configured LLMs are available + multi_provider.get_model_provider(model) + return multi_provider def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float: |