diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/config/config.py')
-rw-r--r-- | autogpts/autogpt/autogpt/config/config.py | 451 |
1 files changed, 186 insertions, 265 deletions
diff --git a/autogpts/autogpt/autogpt/config/config.py b/autogpts/autogpt/autogpt/config/config.py index 871479e88..ff0053a76 100644 --- a/autogpts/autogpt/autogpt/config/config.py +++ b/autogpts/autogpt/autogpt/config/config.py @@ -1,20 +1,27 @@ """Configuration class to store the state of bools for different scripts access.""" from __future__ import annotations -import contextlib import os import re from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union -import yaml from auto_gpt_plugin_template import AutoGPTPluginTemplate from colorama import Fore -from pydantic import Field, validator +from pydantic import Field, SecretStr, validator import autogpt -from autogpt.core.configuration.schema import Configurable, SystemSettings -from autogpt.core.resource.model_providers.openai import OPEN_AI_CHAT_MODELS +from autogpt.core.configuration.schema import ( + Configurable, + SystemSettings, + UserConfigurable, +) +from autogpt.core.resource.model_providers.openai import ( + OPEN_AI_CHAT_MODELS, + OpenAICredentials, +) +from autogpt.file_workspace import FileWorkspaceBackendName +from autogpt.logs.config import LoggingConfig from autogpt.plugins.plugins_config import PluginsConfig from autogpt.speech import TTSConfig @@ -31,6 +38,7 @@ GPT_3_MODEL = "gpt-3.5-turbo" class Config(SystemSettings, arbitrary_types_allowed=True): name: str = "Auto-GPT configuration" description: str = "Default configuration for the Auto-GPT application." + ######################## # Application Settings # ######################## @@ -38,28 +46,61 @@ class Config(SystemSettings, arbitrary_types_allowed=True): app_data_dir: Path = project_root / "data" skip_news: bool = False skip_reprompt: bool = False - authorise_key: str = "y" - exit_key: str = "n" - debug_mode: bool = False - plain_output: bool = False + authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY") + exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY") noninteractive_mode: bool = False - chat_messages_enabled: bool = True + chat_messages_enabled: bool = UserConfigurable( + default=True, from_env=lambda: os.getenv("CHAT_MESSAGES_ENABLED") == "True" + ) + # TTS configuration tts_config: TTSConfig = TTSConfig() + logging: LoggingConfig = LoggingConfig() + + # Workspace + workspace_backend: FileWorkspaceBackendName = UserConfigurable( + default=FileWorkspaceBackendName.LOCAL, + from_env=lambda: FileWorkspaceBackendName(v) + if (v := os.getenv("WORKSPACE_BACKEND")) + else None, + ) ########################## # Agent Control Settings # ########################## # Paths - ai_settings_file: Path = project_root / AI_SETTINGS_FILE - prompt_settings_file: Path = project_root / PROMPT_SETTINGS_FILE + ai_settings_file: Path = UserConfigurable( + default=AI_SETTINGS_FILE, + from_env=lambda: Path(f) if (f := os.getenv("AI_SETTINGS_FILE")) else None, + ) + prompt_settings_file: Path = UserConfigurable( + default=PROMPT_SETTINGS_FILE, + from_env=lambda: Path(f) if (f := os.getenv("PROMPT_SETTINGS_FILE")) else None, + ) + # Model configuration - fast_llm: str = "gpt-3.5-turbo-16k" - smart_llm: str = "gpt-4" - temperature: float = 0 - openai_functions: bool = False - embedding_model: str = "text-embedding-ada-002" - browse_spacy_language_model: str = "en_core_web_sm" + fast_llm: str = UserConfigurable( + default="gpt-3.5-turbo-16k", + from_env=lambda: os.getenv("FAST_LLM"), + ) + smart_llm: str = UserConfigurable( + default="gpt-4", + from_env=lambda: os.getenv("SMART_LLM"), + ) + temperature: float = UserConfigurable( + default=0, + from_env=lambda: float(v) if (v := os.getenv("TEMPERATURE")) else None, + ) + openai_functions: bool = UserConfigurable( + default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True" + ) + embedding_model: str = UserConfigurable( + default="text-embedding-ada-002", from_env="EMBEDDING_MODEL" + ) + browse_spacy_language_model: str = UserConfigurable( + default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL" + ) + # Run loop configuration continuous_mode: bool = False continuous_limit: int = 0 @@ -67,74 +108,138 @@ class Config(SystemSettings, arbitrary_types_allowed=True): ########## # Memory # ########## - memory_backend: str = "json_file" - memory_index: str = "auto-gpt-memory" - redis_host: str = "localhost" - redis_port: int = 6379 - redis_password: str = "" - wipe_redis_on_start: bool = True + memory_backend: str = UserConfigurable("json_file", from_env="MEMORY_BACKEND") + memory_index: str = UserConfigurable("auto-gpt-memory", from_env="MEMORY_INDEX") + redis_host: str = UserConfigurable("localhost", from_env="REDIS_HOST") + redis_port: int = UserConfigurable( + default=6379, + from_env=lambda: int(v) if (v := os.getenv("REDIS_PORT")) else None, + ) + redis_password: str = UserConfigurable("", from_env="REDIS_PASSWORD") + wipe_redis_on_start: bool = UserConfigurable( + default=True, + from_env=lambda: os.getenv("WIPE_REDIS_ON_START", "True") == "True", + ) ############ # Commands # ############ # General - disabled_command_categories: list[str] = Field(default_factory=list) + disabled_command_categories: list[str] = UserConfigurable( + default_factory=list, + from_env=lambda: _safe_split(os.getenv("DISABLED_COMMAND_CATEGORIES")), + ) + # File ops - restrict_to_workspace: bool = True + restrict_to_workspace: bool = UserConfigurable( + default=True, + from_env=lambda: os.getenv("RESTRICT_TO_WORKSPACE", "True") == "True", + ) allow_downloads: bool = False + # Shell commands - shell_command_control: str = "denylist" - execute_local_commands: bool = False - shell_denylist: list[str] = Field(default_factory=lambda: ["sudo", "su"]) - shell_allowlist: list[str] = Field(default_factory=list) + shell_command_control: str = UserConfigurable( + default="denylist", from_env="SHELL_COMMAND_CONTROL" + ) + execute_local_commands: bool = UserConfigurable( + default=False, + from_env=lambda: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True", + ) + shell_denylist: list[str] = UserConfigurable( + default_factory=lambda: ["sudo", "su"], + from_env=lambda: _safe_split( + os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS")) + ), + ) + shell_allowlist: list[str] = UserConfigurable( + default_factory=list, + from_env=lambda: _safe_split( + os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS")) + ), + ) + # Text to image - image_provider: Optional[str] = None - huggingface_image_model: str = "CompVis/stable-diffusion-v1-4" - sd_webui_url: Optional[str] = "http://localhost:7860" - image_size: int = 256 + image_provider: Optional[str] = UserConfigurable(from_env="IMAGE_PROVIDER") + huggingface_image_model: str = UserConfigurable( + default="CompVis/stable-diffusion-v1-4", from_env="HUGGINGFACE_IMAGE_MODEL" + ) + sd_webui_url: Optional[str] = UserConfigurable( + default="http://localhost:7860", from_env="SD_WEBUI_URL" + ) + image_size: int = UserConfigurable( + default=256, + from_env=lambda: int(v) if (v := os.getenv("IMAGE_SIZE")) else None, + ) + # Audio to text - audio_to_text_provider: str = "huggingface" - huggingface_audio_to_text_model: Optional[str] = None + audio_to_text_provider: str = UserConfigurable( + default="huggingface", from_env="AUDIO_TO_TEXT_PROVIDER" + ) + huggingface_audio_to_text_model: Optional[str] = UserConfigurable( + from_env="HUGGINGFACE_AUDIO_TO_TEXT_MODEL" + ) + # Web browsing - selenium_web_browser: str = "chrome" - selenium_headless: bool = True - user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36" + selenium_web_browser: str = UserConfigurable("chrome", from_env="USE_WEB_BROWSER") + selenium_headless: bool = UserConfigurable( + default=True, from_env=lambda: os.getenv("HEADLESS_BROWSER", "True") == "True" + ) + user_agent: str = UserConfigurable( + default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36", # noqa: E501 + from_env="USER_AGENT", + ) ################### # Plugin Settings # ################### - plugins_dir: str = "plugins" - plugins_config_file: Path = project_root / PLUGINS_CONFIG_FILE + plugins_dir: str = UserConfigurable("plugins", from_env="PLUGINS_DIR") + plugins_config_file: Path = UserConfigurable( + default=PLUGINS_CONFIG_FILE, + from_env=lambda: Path(f) if (f := os.getenv("PLUGINS_CONFIG_FILE")) else None, + ) plugins_config: PluginsConfig = Field( default_factory=lambda: PluginsConfig(plugins={}) ) plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True) - plugins_allowlist: list[str] = Field(default_factory=list) - plugins_denylist: list[str] = Field(default_factory=list) - plugins_openai: list[str] = Field(default_factory=list) + plugins_allowlist: list[str] = UserConfigurable( + default_factory=list, + from_env=lambda: _safe_split(os.getenv("ALLOWLISTED_PLUGINS")), + ) + plugins_denylist: list[str] = UserConfigurable( + default_factory=list, + from_env=lambda: _safe_split(os.getenv("DENYLISTED_PLUGINS")), + ) + plugins_openai: list[str] = UserConfigurable( + default_factory=list, from_env=lambda: _safe_split(os.getenv("OPENAI_PLUGINS")) + ) ############### # Credentials # ############### # OpenAI - openai_api_key: Optional[str] = None - openai_api_type: Optional[str] = None - openai_api_base: Optional[str] = None - openai_api_version: Optional[str] = None - openai_organization: Optional[str] = None - use_azure: bool = False - azure_config_file: Optional[Path] = project_root / AZURE_CONFIG_FILE - azure_model_to_deployment_id_map: Optional[Dict[str, str]] = None + openai_credentials: Optional[OpenAICredentials] = None + azure_config_file: Optional[Path] = UserConfigurable( + default=AZURE_CONFIG_FILE, + from_env=lambda: Path(f) if (f := os.getenv("AZURE_CONFIG_FILE")) else None, + ) + # Github - github_api_key: Optional[str] = None - github_username: Optional[str] = None + github_api_key: Optional[str] = UserConfigurable(from_env="GITHUB_API_KEY") + github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME") + # Google - google_api_key: Optional[str] = None - google_custom_search_engine_id: Optional[str] = None + google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY") + google_custom_search_engine_id: Optional[str] = UserConfigurable( + from_env=lambda: os.getenv("GOOGLE_CUSTOM_SEARCH_ENGINE_ID"), + ) + # Huggingface - huggingface_api_token: Optional[str] = None + huggingface_api_token: Optional[str] = UserConfigurable( + from_env="HUGGINGFACE_API_TOKEN" + ) + # Stable Diffusion - sd_webui_auth: Optional[str] = None + sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH") @validator("plugins", each_item=True) def validate_plugins(cls, p: AutoGPTPluginTemplate | Any): @@ -156,67 +261,6 @@ class Config(SystemSettings, arbitrary_types_allowed=True): ) return v - def get_openai_credentials(self, model: str) -> dict[str, str]: - credentials = { - "api_key": self.openai_api_key, - "api_base": self.openai_api_base, - "organization": self.openai_organization, - } - if self.use_azure: - azure_credentials = self.get_azure_credentials(model) - credentials.update(azure_credentials) - return credentials - - def get_azure_credentials(self, model: str) -> dict[str, str]: - """Get the kwargs for the Azure API.""" - - # Fix --gpt3only and --gpt4only in combination with Azure - fast_llm = ( - self.fast_llm - if not ( - self.fast_llm == self.smart_llm - and self.fast_llm.startswith(GPT_4_MODEL) - ) - else f"not_{self.fast_llm}" - ) - smart_llm = ( - self.smart_llm - if not ( - self.smart_llm == self.fast_llm - and self.smart_llm.startswith(GPT_3_MODEL) - ) - else f"not_{self.smart_llm}" - ) - - deployment_id = { - fast_llm: self.azure_model_to_deployment_id_map.get( - "fast_llm_deployment_id", - self.azure_model_to_deployment_id_map.get( - "fast_llm_model_deployment_id" # backwards compatibility - ), - ), - smart_llm: self.azure_model_to_deployment_id_map.get( - "smart_llm_deployment_id", - self.azure_model_to_deployment_id_map.get( - "smart_llm_model_deployment_id" # backwards compatibility - ), - ), - self.embedding_model: self.azure_model_to_deployment_id_map.get( - "embedding_model_deployment_id" - ), - }.get(model, None) - - kwargs = { - "api_type": self.openai_api_type, - "api_base": self.openai_api_base, - "api_version": self.openai_api_version, - } - if model == self.embedding_model: - kwargs["engine"] = deployment_id - else: - kwargs["deployment_id"] = deployment_id - return kwargs - class ConfigBuilder(Configurable[Config]): default_settings = Config() @@ -224,124 +268,25 @@ class ConfigBuilder(Configurable[Config]): @classmethod def build_config_from_env(cls, project_root: Path = PROJECT_ROOT) -> Config: """Initialize the Config class""" - config_dict = { - "project_root": project_root, - "authorise_key": os.getenv("AUTHORISE_COMMAND_KEY"), - "exit_key": os.getenv("EXIT_KEY"), - "plain_output": os.getenv("PLAIN_OUTPUT", "False") == "True", - "shell_command_control": os.getenv("SHELL_COMMAND_CONTROL"), - "ai_settings_file": project_root - / Path(os.getenv("AI_SETTINGS_FILE", AI_SETTINGS_FILE)), - "prompt_settings_file": project_root - / Path(os.getenv("PROMPT_SETTINGS_FILE", PROMPT_SETTINGS_FILE)), - "fast_llm": os.getenv("FAST_LLM", os.getenv("FAST_LLM_MODEL")), - "smart_llm": os.getenv("SMART_LLM", os.getenv("SMART_LLM_MODEL")), - "embedding_model": os.getenv("EMBEDDING_MODEL"), - "browse_spacy_language_model": os.getenv("BROWSE_SPACY_LANGUAGE_MODEL"), - "openai_api_key": os.getenv("OPENAI_API_KEY"), - "use_azure": os.getenv("USE_AZURE") == "True", - "azure_config_file": project_root - / Path(os.getenv("AZURE_CONFIG_FILE", AZURE_CONFIG_FILE)), - "execute_local_commands": os.getenv("EXECUTE_LOCAL_COMMANDS", "False") - == "True", - "restrict_to_workspace": os.getenv("RESTRICT_TO_WORKSPACE", "True") - == "True", - "openai_functions": os.getenv("OPENAI_FUNCTIONS", "False") == "True", - "tts_config": { - "provider": os.getenv("TEXT_TO_SPEECH_PROVIDER"), - }, - "github_api_key": os.getenv("GITHUB_API_KEY"), - "github_username": os.getenv("GITHUB_USERNAME"), - "google_api_key": os.getenv("GOOGLE_API_KEY"), - "image_provider": os.getenv("IMAGE_PROVIDER"), - "huggingface_api_token": os.getenv("HUGGINGFACE_API_TOKEN"), - "huggingface_image_model": os.getenv("HUGGINGFACE_IMAGE_MODEL"), - "audio_to_text_provider": os.getenv("AUDIO_TO_TEXT_PROVIDER"), - "huggingface_audio_to_text_model": os.getenv( - "HUGGINGFACE_AUDIO_TO_TEXT_MODEL" - ), - "sd_webui_url": os.getenv("SD_WEBUI_URL"), - "sd_webui_auth": os.getenv("SD_WEBUI_AUTH"), - "selenium_web_browser": os.getenv("USE_WEB_BROWSER"), - "selenium_headless": os.getenv("HEADLESS_BROWSER", "True") == "True", - "user_agent": os.getenv("USER_AGENT"), - "memory_backend": os.getenv("MEMORY_BACKEND"), - "memory_index": os.getenv("MEMORY_INDEX"), - "redis_host": os.getenv("REDIS_HOST"), - "redis_password": os.getenv("REDIS_PASSWORD"), - "wipe_redis_on_start": os.getenv("WIPE_REDIS_ON_START", "True") == "True", - "plugins_dir": os.getenv("PLUGINS_DIR"), - "plugins_config_file": project_root - / Path(os.getenv("PLUGINS_CONFIG_FILE", PLUGINS_CONFIG_FILE)), - "chat_messages_enabled": os.getenv("CHAT_MESSAGES_ENABLED") == "True", - } - - config_dict["disabled_command_categories"] = _safe_split( - os.getenv("DISABLED_COMMAND_CATEGORIES") - ) - config_dict["shell_denylist"] = _safe_split( - os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS")) - ) - config_dict["shell_allowlist"] = _safe_split( - os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS")) - ) - - config_dict["google_custom_search_engine_id"] = os.getenv( - "GOOGLE_CUSTOM_SEARCH_ENGINE_ID", os.getenv("CUSTOM_SEARCH_ENGINE_ID") - ) - - if os.getenv("ELEVENLABS_API_KEY"): - config_dict["tts_config"]["elevenlabs"] = { - "api_key": os.getenv("ELEVENLABS_API_KEY"), - "voice_id": os.getenv("ELEVENLABS_VOICE_ID", ""), - } - if os.getenv("STREAMELEMENTS_VOICE"): - config_dict["tts_config"]["streamelements"] = { - "voice": os.getenv("STREAMELEMENTS_VOICE"), - } - - if not config_dict["tts_config"]["provider"]: - if os.getenv("USE_MAC_OS_TTS"): - default_tts_provider = "macos" - elif "elevenlabs" in config_dict["tts_config"]: - default_tts_provider = "elevenlabs" - elif os.getenv("USE_BRIAN_TTS"): - default_tts_provider = "streamelements" - else: - default_tts_provider = "gtts" - config_dict["tts_config"]["provider"] = default_tts_provider - - config_dict["plugins_allowlist"] = _safe_split(os.getenv("ALLOWLISTED_PLUGINS")) - config_dict["plugins_denylist"] = _safe_split(os.getenv("DENYLISTED_PLUGINS")) - - with contextlib.suppress(TypeError): - config_dict["image_size"] = int(os.getenv("IMAGE_SIZE")) - with contextlib.suppress(TypeError): - config_dict["redis_port"] = int(os.getenv("REDIS_PORT")) - with contextlib.suppress(TypeError): - config_dict["temperature"] = float(os.getenv("TEMPERATURE")) - - if config_dict["use_azure"]: - azure_config = cls.load_azure_config( - project_root / config_dict["azure_config_file"] - ) - config_dict.update(azure_config) - - elif os.getenv("OPENAI_API_BASE_URL"): - config_dict["openai_api_base"] = os.getenv("OPENAI_API_BASE_URL") - - openai_organization = os.getenv("OPENAI_ORGANIZATION") - if openai_organization is not None: - config_dict["openai_organization"] = openai_organization - - config_dict_without_none_values = { - k: v for k, v in config_dict.items() if v is not None - } - - config = cls.build_agent_configuration(config_dict_without_none_values) - - # Set secondary config variables (that depend on other config variables) + config = cls.build_agent_configuration() + config.project_root = project_root + + # Make relative paths absolute + for k in { + "ai_settings_file", # TODO: deprecate or repurpose + "prompt_settings_file", # TODO: deprecate or repurpose + "plugins_config_file", # TODO: move from project root + "azure_config_file", # TODO: move from project root + }: + setattr(config, k, project_root / getattr(config, k)) + + if ( + config.openai_credentials + and config.openai_credentials.api_type == "azure" + and (config_file := config.azure_config_file) + ): + config.openai_credentials.load_azure_config(config_file) config.plugins_config = PluginsConfig.load_config( config.plugins_config_file, @@ -351,36 +296,10 @@ class ConfigBuilder(Configurable[Config]): return config - @classmethod - def load_azure_config(cls, config_file: Path) -> Dict[str, str]: - """ - Loads the configuration parameters for Azure hosting from the specified file - path as a yaml file. - - Parameters: - config_file (Path): The path to the config yaml file. - - Returns: - Dict - """ - with open(config_file) as file: - config_params = yaml.load(file, Loader=yaml.FullLoader) or {} - - return { - "openai_api_type": config_params.get("azure_api_type", "azure"), - "openai_api_base": config_params.get("azure_api_base", ""), - "openai_api_version": config_params.get( - "azure_api_version", "2023-03-15-preview" - ), - "azure_model_to_deployment_id_map": config_params.get( - "azure_model_map", {} - ), - } - def assert_config_has_openai_api_key(config: Config) -> None: """Check if the OpenAI API key is set in config.py or as an environment variable.""" - if not config.openai_api_key: + if not config.openai_credentials: print( Fore.RED + "Please set your OpenAI API key in .env or as an environment variable." @@ -394,7 +313,9 @@ def assert_config_has_openai_api_key(config: Config) -> None: openai_api_key = openai_api_key.strip() if re.search(key_pattern, openai_api_key): os.environ["OPENAI_API_KEY"] = openai_api_key - config.openai_api_key = openai_api_key + config.openai_credentials = OpenAICredentials( + api_key=SecretStr(openai_api_key) + ) print( Fore.GREEN + "OpenAI API key successfully set!\n" |