Diffstat (limited to 'autogpts/autogpt/autogpt/app/main.py')
-rw-r--r--  autogpts/autogpt/autogpt/app/main.py  30
1 file changed, 9 insertions(+), 21 deletions(-)
diff --git a/autogpts/autogpt/autogpt/app/main.py b/autogpts/autogpt/autogpt/app/main.py
index aaab5fe48..04354fb10 100644
--- a/autogpts/autogpt/autogpt/app/main.py
+++ b/autogpts/autogpt/autogpt/app/main.py
@@ -35,7 +35,7 @@ from autogpt.config import (
ConfigBuilder,
assert_config_has_openai_api_key,
)
-from autogpt.core.resource.model_providers.openai import OpenAIProvider
+from autogpt.core.resource.model_providers import MultiProvider
from autogpt.core.runner.client_lib.utils import coroutine
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
@@ -123,7 +123,7 @@ async def run_auto_gpt(
skip_news=skip_news,
)
- llm_provider = _configure_openai_provider(config)
+ llm_provider = _configure_llm_provider(config)
logger = logging.getLogger(__name__)
@@ -399,7 +399,7 @@ async def run_auto_gpt_server(
allow_downloads=allow_downloads,
)
- llm_provider = _configure_openai_provider(config)
+ llm_provider = _configure_llm_provider(config)
# Set up & start server
database = AgentDB(
@@ -421,24 +421,12 @@ async def run_auto_gpt_server(
)
-def _configure_openai_provider(config: Config) -> OpenAIProvider:
- """Create a configured OpenAIProvider object.
-
- Args:
- config: The program's configuration.
-
- Returns:
- A configured OpenAIProvider object.
- """
- if config.openai_credentials is None:
- raise RuntimeError("OpenAI key is not configured")
-
- openai_settings = OpenAIProvider.default_settings.copy(deep=True)
- openai_settings.credentials = config.openai_credentials
- return OpenAIProvider(
- settings=openai_settings,
- logger=logging.getLogger("OpenAIProvider"),
- )
+def _configure_llm_provider(config: Config) -> MultiProvider:
+ multi_provider = MultiProvider()
+ for model in [config.smart_llm, config.fast_llm]:
+ # Ensure model providers for configured LLMs are available
+ multi_provider.get_model_provider(model)
+ return multi_provider
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
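
For orientation, a minimal usage sketch of the new helper (not part of the commit). It relies only on names visible in the diff above (MultiProvider, get_model_provider, the smart_llm/fast_llm fields on Config, and the ConfigBuilder import); the build_config_from_env() call is an assumption about how a Config instance is normally obtained.

    # Hypothetical sketch, not from the commit; build_config_from_env() is assumed.
    from autogpt.config import ConfigBuilder
    from autogpt.core.resource.model_providers import MultiProvider

    config = ConfigBuilder.build_config_from_env()  # assumed way to obtain a Config

    # Mirrors _configure_llm_provider: requesting the provider for each configured
    # LLM up front ensures the backing provider can actually be created (credentials
    # present, model known) before the agent starts running.
    llm_provider = MultiProvider()
    for model in (config.smart_llm, config.fast_llm):
        llm_provider.get_model_provider(model)

Compared with the removed _configure_openai_provider, this no longer hard-codes OpenAI: MultiProvider resolves the appropriate provider per model, so the explicit openai_credentials check and OpenAIProvider settings copy are no longer needed here.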