From 028d2c319f3dcca6aa57fc4fdcd2e78a01926e3f Mon Sep 17 00:00:00 2001 From: Krzysztof Czerwinski <34861343+kcze@users.noreply.github.com> Date: Fri, 22 Mar 2024 12:55:40 +0100 Subject: feat(autogpt): Handle OpenAI API key exceptions gracefully (#6992) * Better handle no API keys or invalid ones * Handle exception and exit when invalid key is provided * Handle any APIError exception when trying to get OpenAI models and exit --------- Co-authored-by: Reinier van der Leer --- autogpts/autogpt/autogpt/app/main.py | 25 ++++++------ autogpts/autogpt/autogpt/config/config.py | 59 +++++++++++++++++++---------- autogpts/autogpt/autogpt/llm/api_manager.py | 10 ++++- 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/autogpts/autogpt/autogpt/app/main.py b/autogpts/autogpt/autogpt/app/main.py index 976ecfabf..30ab1120f 100644 --- a/autogpts/autogpt/autogpt/app/main.py +++ b/autogpts/autogpt/autogpt/app/main.py @@ -94,6 +94,12 @@ async def run_auto_gpt( ) file_storage.initialize() + # Set up logging module + configure_logging( + **config.logging.dict(), + tts_config=config.tts_config, + ) + # TODO: fill in llm values here assert_config_has_openai_api_key(config) @@ -116,12 +122,6 @@ async def run_auto_gpt( skip_news=skip_news, ) - # Set up logging module - configure_logging( - **config.logging.dict(), - tts_config=config.tts_config, - ) - llm_provider = _configure_openai_provider(config) logger = logging.getLogger(__name__) @@ -373,7 +373,6 @@ async def run_auto_gpt_server( from .agent_protocol_server import AgentProtocolServer config = ConfigBuilder.build_config_from_env() - # Storage local = config.file_storage_backend == FileStorageBackendName.LOCAL restrict_to_root = not local or config.restrict_to_workspace @@ -382,6 +381,12 @@ async def run_auto_gpt_server( ) file_storage.initialize() + # Set up logging module + configure_logging( + **config.logging.dict(), + tts_config=config.tts_config, + ) + # TODO: fill in llm values here assert_config_has_openai_api_key(config) @@ -398,12 +403,6 @@ async def run_auto_gpt_server( allow_downloads=allow_downloads, ) - # Set up logging module - configure_logging( - **config.logging.dict(), - tts_config=config.tts_config, - ) - llm_provider = _configure_openai_provider(config) if install_plugin_deps: diff --git a/autogpts/autogpt/autogpt/config/config.py b/autogpts/autogpt/autogpt/config/config.py index eed9eea34..ed1e5f78c 100644 --- a/autogpts/autogpt/autogpt/config/config.py +++ b/autogpts/autogpt/autogpt/config/config.py @@ -1,6 +1,7 @@ """Configuration class to store the state of bools for different scripts access.""" from __future__ import annotations +import logging import os import re from pathlib import Path @@ -11,6 +12,7 @@ from colorama import Fore from pydantic import Field, SecretStr, validator import autogpt +from autogpt.app.utils import clean_input from autogpt.core.configuration.schema import ( Configurable, SystemSettings, @@ -25,6 +27,8 @@ from autogpt.logs.config import LoggingConfig from autogpt.plugins.plugins_config import PluginsConfig from autogpt.speech import TTSConfig +logger = logging.getLogger(__name__) + PROJECT_ROOT = Path(autogpt.__file__).parent.parent AI_SETTINGS_FILE = Path("ai_settings.yaml") AZURE_CONFIG_FILE = Path("azure.yaml") @@ -299,34 +303,51 @@ class ConfigBuilder(Configurable[Config]): def assert_config_has_openai_api_key(config: Config) -> None: """Check if the OpenAI API key is set in config.py or as an environment variable.""" - if not config.openai_credentials: - print( - Fore.RED - + "Please set your OpenAI API key in .env or as an environment variable." - + Fore.RESET + key_pattern = r"^sk-\w{48}" + openai_api_key = ( + config.openai_credentials.api_key.get_secret_value() + if config.openai_credentials + else "" + ) + + # If there's no credentials or empty API key, prompt the user to set it + if not openai_api_key: + logger.error( + "Please set your OpenAI API key in .env or as an environment variable." + ) + logger.info( + "You can get your key from https://platform.openai.com/account/api-keys" ) - print("You can get your key from https://platform.openai.com/account/api-keys") - openai_api_key = input( - "If you do have the key, please enter your OpenAI API key now:\n" + openai_api_key = clean_input( + config, "Please enter your OpenAI API key if you have it:" ) - key_pattern = r"^sk-\w{48}" openai_api_key = openai_api_key.strip() if re.search(key_pattern, openai_api_key): os.environ["OPENAI_API_KEY"] = openai_api_key - config.openai_credentials = OpenAICredentials( - api_key=SecretStr(openai_api_key) - ) + if config.openai_credentials: + config.openai_credentials.api_key = SecretStr(openai_api_key) + else: + config.openai_credentials = OpenAICredentials( + api_key=SecretStr(openai_api_key) + ) + print("OpenAI API key successfully set!") print( - Fore.GREEN - + "OpenAI API key successfully set!\n" - + Fore.YELLOW - + "NOTE: The API key you've set is only temporary.\n" - + "For longer sessions, please set it in .env file" - + Fore.RESET + f"{Fore.YELLOW}NOTE: The API key you've set is only temporary. " + f"For longer sessions, please set it in the .env file{Fore.RESET}" ) else: - print("Invalid OpenAI API key!") + print(f"{Fore.RED}Invalid OpenAI API key{Fore.RESET}") exit(1) + # If key is set, but it looks invalid + elif not re.search(key_pattern, openai_api_key): + logger.error( + "Invalid OpenAI API key! " + "Please set your OpenAI API key in .env or as an environment variable." + ) + logger.info( + "You can get your key from https://platform.openai.com/account/api-keys" + ) + exit(1) def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]: diff --git a/autogpts/autogpt/autogpt/llm/api_manager.py b/autogpts/autogpt/autogpt/llm/api_manager.py index 3ce1cd831..1cfcdd755 100644 --- a/autogpts/autogpt/autogpt/llm/api_manager.py +++ b/autogpts/autogpt/autogpt/llm/api_manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from typing import List, Optional -from openai import AzureOpenAI, OpenAI +from openai import APIError, AzureOpenAI, OpenAI from openai.types import Model from autogpt.core.resource.model_providers.openai import ( @@ -106,7 +106,10 @@ class ApiManager(metaclass=Singleton): Returns: list[Model]: List of available GPT models. """ - if self.models is None: + if self.models is not None: + return self.models + + try: if openai_credentials.api_type == "azure": all_models = ( AzureOpenAI(**openai_credentials.get_api_access_kwargs()) @@ -120,5 +123,8 @@ class ApiManager(metaclass=Singleton): .data ) self.models = [model for model in all_models if "gpt" in model.id] + except APIError as e: + logger.error(e.message) + exit(1) return self.models -- cgit v1.2.3