diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/configuration/schema.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/configuration/schema.py | 351 |
1 files changed, 351 insertions, 0 deletions
diff --git a/autogpts/autogpt/autogpt/core/configuration/schema.py b/autogpts/autogpt/autogpt/core/configuration/schema.py new file mode 100644 index 000000000..5bc95ffac --- /dev/null +++ b/autogpts/autogpt/autogpt/core/configuration/schema.py @@ -0,0 +1,351 @@ +import abc +import os +import typing +from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args + +from pydantic import BaseModel, Field, ValidationError +from pydantic.fields import ModelField, Undefined, UndefinedType +from pydantic.main import ModelMetaclass + +T = TypeVar("T") +M = TypeVar("M", bound=BaseModel) + + +def UserConfigurable( + default: T | UndefinedType = Undefined, + *args, + default_factory: Optional[Callable[[], T]] = None, + from_env: Optional[str | Callable[[], T | None]] = None, + description: str = "", + **kwargs, +) -> T: + # TODO: use this to auto-generate docs for the application configuration + return Field( + default, + *args, + default_factory=default_factory, + from_env=from_env, + description=description, + **kwargs, + user_configurable=True, + ) + + +class SystemConfiguration(BaseModel): + def get_user_config(self) -> dict[str, Any]: + return _recurse_user_config_values(self) + + @classmethod + def from_env(cls): + """ + Initializes the config object from environment variables. + + Environment variables are mapped to UserConfigurable fields using the from_env + attribute that can be passed to UserConfigurable. + """ + + def infer_field_value(field: ModelField): + field_info = field.field_info + default_value = ( + field.default + if field.default not in (None, Undefined) + else (field.default_factory() if field.default_factory else Undefined) + ) + if from_env := field_info.extra.get("from_env"): + val_from_env = ( + os.getenv(from_env) if type(from_env) is str else from_env() + ) + if val_from_env is not None: + return val_from_env + return default_value + + return _recursive_init_model(cls, infer_field_value) + + class Config: + extra = "forbid" + use_enum_values = True + validate_assignment = True + + +SC = TypeVar("SC", bound=SystemConfiguration) + + +class SystemSettings(BaseModel): + """A base class for all system settings.""" + + name: str + description: str + + class Config: + extra = "forbid" + use_enum_values = True + validate_assignment = True + + +S = TypeVar("S", bound=SystemSettings) + + +class Configurable(abc.ABC, Generic[S]): + """A base class for all configurable objects.""" + + prefix: str = "" + default_settings: typing.ClassVar[S] + + @classmethod + def get_user_config(cls) -> dict[str, Any]: + return _recurse_user_config_values(cls.default_settings) + + @classmethod + def build_agent_configuration(cls, overrides: dict = {}) -> S: + """Process the configuration for this object.""" + + base_config = _update_user_config_from_env(cls.default_settings) + final_configuration = deep_update(base_config, overrides) + + return cls.default_settings.__class__.parse_obj(final_configuration) + + +def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]: + """ + Update config fields of a Pydantic model instance from environment variables. + + Precedence: + 1. Non-default value already on the instance + 2. Value returned by `from_env()` + 3. Default value for the field + + Params: + instance: The Pydantic model instance. + + Returns: + The user config fields of the instance. + """ + + def infer_field_value(field: ModelField, value): + field_info = field.field_info + default_value = ( + field.default + if field.default not in (None, Undefined) + else (field.default_factory() if field.default_factory else None) + ) + if value == default_value and (from_env := field_info.extra.get("from_env")): + val_from_env = os.getenv(from_env) if type(from_env) is str else from_env() + if val_from_env is not None: + return val_from_env + return value + + def init_sub_config(model: Type[SC]) -> SC | None: + try: + return model.from_env() + except ValidationError as e: + # Gracefully handle missing fields + if all(e["type"] == "value_error.missing" for e in e.errors()): + return None + raise + + return _recurse_user_config_fields(instance, infer_field_value, init_sub_config) + + +def _recursive_init_model( + model: Type[M], + infer_field_value: Callable[[ModelField], Any], +) -> M: + """ + Recursively initialize the user configuration fields of a Pydantic model. + + Parameters: + model: The Pydantic model type. + infer_field_value: A callback function to infer the value of each field. + Parameters: + ModelField: The Pydantic ModelField object describing the field. + + Returns: + BaseModel: An instance of the model with the initialized configuration. + """ + user_config_fields = {} + for name, field in model.__fields__.items(): + if "user_configurable" in field.field_info.extra: + user_config_fields[name] = infer_field_value(field) + elif type(field.outer_type_) is ModelMetaclass and issubclass( + field.outer_type_, SystemConfiguration + ): + try: + user_config_fields[name] = _recursive_init_model( + model=field.outer_type_, + infer_field_value=infer_field_value, + ) + except ValidationError as e: + # Gracefully handle missing fields + if all(e["type"] == "value_error.missing" for e in e.errors()): + user_config_fields[name] = None + raise + + user_config_fields = remove_none_items(user_config_fields) + + return model.parse_obj(user_config_fields) + + +def _recurse_user_config_fields( + model: BaseModel, + infer_field_value: Callable[[ModelField, Any], Any], + init_sub_config: Optional[ + Callable[[Type[SystemConfiguration]], SystemConfiguration | None] + ] = None, +) -> dict[str, Any]: + """ + Recursively process the user configuration fields of a Pydantic model instance. + + Params: + model: The Pydantic model to iterate over. + infer_field_value: A callback function to process each field. + Params: + ModelField: The Pydantic ModelField object describing the field. + Any: The current value of the field. + init_sub_config: An optional callback function to initialize a sub-config. + Params: + Type[SystemConfiguration]: The type of the sub-config to initialize. + + Returns: + dict[str, Any]: The processed user configuration fields of the instance. + """ + user_config_fields = {} + + for name, field in model.__fields__.items(): + value = getattr(model, name) + + # Handle individual field + if "user_configurable" in field.field_info.extra: + user_config_fields[name] = infer_field_value(field, value) + + # Recurse into nested config object + elif isinstance(value, SystemConfiguration): + user_config_fields[name] = _recurse_user_config_fields( + model=value, + infer_field_value=infer_field_value, + init_sub_config=init_sub_config, + ) + + # Recurse into optional nested config object + elif value is None and init_sub_config: + field_type = get_args(field.annotation)[0] # Optional[T] -> T + if type(field_type) is ModelMetaclass and issubclass( + field_type, SystemConfiguration + ): + sub_config = init_sub_config(field_type) + if sub_config: + user_config_fields[name] = _recurse_user_config_fields( + model=sub_config, + infer_field_value=infer_field_value, + init_sub_config=init_sub_config, + ) + + elif isinstance(value, list) and all( + isinstance(i, SystemConfiguration) for i in value + ): + user_config_fields[name] = [ + _recurse_user_config_fields(i, infer_field_value, init_sub_config) + for i in value + ] + elif isinstance(value, dict) and all( + isinstance(i, SystemConfiguration) for i in value.values() + ): + user_config_fields[name] = { + k: _recurse_user_config_fields(v, infer_field_value, init_sub_config) + for k, v in value.items() + } + + return user_config_fields + + +def _recurse_user_config_values( + instance: BaseModel, + get_field_value: Callable[[ModelField, T], T] = lambda _, v: v, +) -> dict[str, Any]: + """ + This function recursively traverses the user configuration values in a Pydantic + model instance. + + Params: + instance: A Pydantic model instance. + get_field_value: A callback function to process each field. Parameters: + ModelField: The Pydantic ModelField object that describes the field. + Any: The current value of the field. + + Returns: + A dictionary containing the processed user configuration fields of the instance. + """ + user_config_values = {} + + for name, value in instance.__dict__.items(): + field = instance.__fields__[name] + if "user_configurable" in field.field_info.extra: + user_config_values[name] = get_field_value(field, value) + elif isinstance(value, SystemConfiguration): + user_config_values[name] = _recurse_user_config_values( + instance=value, get_field_value=get_field_value + ) + elif isinstance(value, list) and all( + isinstance(i, SystemConfiguration) for i in value + ): + user_config_values[name] = [ + _recurse_user_config_values(i, get_field_value) for i in value + ] + elif isinstance(value, dict) and all( + isinstance(i, SystemConfiguration) for i in value.values() + ): + user_config_values[name] = { + k: _recurse_user_config_values(v, get_field_value) + for k, v in value.items() + } + + return user_config_values + + +def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]: + """ + Get the non-default user config fields of a Pydantic model instance. + + Params: + instance: The Pydantic model instance. + + Returns: + dict[str, Any]: The non-default user config values on the instance. + """ + + def get_field_value(field: ModelField, value): + default = field.default_factory() if field.default_factory else field.default + if value != default: + return value + + return remove_none_items(_recurse_user_config_values(instance, get_field_value)) + + +def deep_update(original_dict: dict, update_dict: dict) -> dict: + """ + Recursively update a dictionary. + + Params: + original_dict (dict): The dictionary to be updated. + update_dict (dict): The dictionary to update with. + + Returns: + dict: The updated dictionary. + """ + for key, value in update_dict.items(): + if ( + key in original_dict + and isinstance(original_dict[key], dict) + and isinstance(value, dict) + ): + original_dict[key] = deep_update(original_dict[key], value) + else: + original_dict[key] = value + return original_dict + + +def remove_none_items(d): + if isinstance(d, dict): + return { + k: remove_none_items(v) for k, v in d.items() if v not in (None, Undefined) + } + return d |