diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/configuration/schema.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/configuration/schema.py | 285 |
1 files changed, 263 insertions, 22 deletions
diff --git a/autogpts/autogpt/autogpt/core/configuration/schema.py b/autogpts/autogpt/autogpt/core/configuration/schema.py index eebb7ba9d..5bc95ffac 100644 --- a/autogpts/autogpt/autogpt/core/configuration/schema.py +++ b/autogpts/autogpt/autogpt/core/configuration/schema.py @@ -1,24 +1,73 @@ import abc -import functools +import os import typing -from typing import Any, Generic, TypeVar +from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError +from pydantic.fields import ModelField, Undefined, UndefinedType +from pydantic.main import ModelMetaclass +T = TypeVar("T") +M = TypeVar("M", bound=BaseModel) -@functools.wraps(Field) -def UserConfigurable(*args, **kwargs): - return Field(*args, **kwargs, user_configurable=True) + +def UserConfigurable( + default: T | UndefinedType = Undefined, + *args, + default_factory: Optional[Callable[[], T]] = None, + from_env: Optional[str | Callable[[], T | None]] = None, + description: str = "", + **kwargs, +) -> T: # TODO: use this to auto-generate docs for the application configuration + return Field( + default, + *args, + default_factory=default_factory, + from_env=from_env, + description=description, + **kwargs, + user_configurable=True, + ) class SystemConfiguration(BaseModel): def get_user_config(self) -> dict[str, Any]: - return _get_user_config_fields(self) + return _recurse_user_config_values(self) + + @classmethod + def from_env(cls): + """ + Initializes the config object from environment variables. + + Environment variables are mapped to UserConfigurable fields using the from_env + attribute that can be passed to UserConfigurable. + """ + + def infer_field_value(field: ModelField): + field_info = field.field_info + default_value = ( + field.default + if field.default not in (None, Undefined) + else (field.default_factory() if field.default_factory else Undefined) + ) + if from_env := field_info.extra.get("from_env"): + val_from_env = ( + os.getenv(from_env) if type(from_env) is str else from_env() + ) + if val_from_env is not None: + return val_from_env + return default_value + + return _recursive_init_model(cls, infer_field_value) class Config: extra = "forbid" use_enum_values = True + validate_assignment = True + + +SC = TypeVar("SC", bound=SystemConfiguration) class SystemSettings(BaseModel): @@ -30,6 +79,7 @@ class SystemSettings(BaseModel): class Config: extra = "forbid" use_enum_values = True + validate_assignment = True S = TypeVar("S", bound=SystemSettings) @@ -43,55 +93,238 @@ class Configurable(abc.ABC, Generic[S]): @classmethod def get_user_config(cls) -> dict[str, Any]: - return _get_user_config_fields(cls.default_settings) + return _recurse_user_config_values(cls.default_settings) @classmethod - def build_agent_configuration(cls, configuration: dict) -> S: + def build_agent_configuration(cls, overrides: dict = {}) -> S: """Process the configuration for this object.""" - defaults = cls.default_settings.dict() - final_configuration = deep_update(defaults, configuration) + base_config = _update_user_config_from_env(cls.default_settings) + final_configuration = deep_update(base_config, overrides) return cls.default_settings.__class__.parse_obj(final_configuration) -def _get_user_config_fields(instance: BaseModel) -> dict[str, Any]: +def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]: """ - Get the user config fields of a Pydantic model instance. + Update config fields of a Pydantic model instance from environment variables. + + Precedence: + 1. Non-default value already on the instance + 2. Value returned by `from_env()` + 3. Default value for the field - Args: + Params: instance: The Pydantic model instance. Returns: The user config fields of the instance. """ + + def infer_field_value(field: ModelField, value): + field_info = field.field_info + default_value = ( + field.default + if field.default not in (None, Undefined) + else (field.default_factory() if field.default_factory else None) + ) + if value == default_value and (from_env := field_info.extra.get("from_env")): + val_from_env = os.getenv(from_env) if type(from_env) is str else from_env() + if val_from_env is not None: + return val_from_env + return value + + def init_sub_config(model: Type[SC]) -> SC | None: + try: + return model.from_env() + except ValidationError as e: + # Gracefully handle missing fields + if all(e["type"] == "value_error.missing" for e in e.errors()): + return None + raise + + return _recurse_user_config_fields(instance, infer_field_value, init_sub_config) + + +def _recursive_init_model( + model: Type[M], + infer_field_value: Callable[[ModelField], Any], +) -> M: + """ + Recursively initialize the user configuration fields of a Pydantic model. + + Parameters: + model: The Pydantic model type. + infer_field_value: A callback function to infer the value of each field. + Parameters: + ModelField: The Pydantic ModelField object describing the field. + + Returns: + BaseModel: An instance of the model with the initialized configuration. + """ user_config_fields = {} + for name, field in model.__fields__.items(): + if "user_configurable" in field.field_info.extra: + user_config_fields[name] = infer_field_value(field) + elif type(field.outer_type_) is ModelMetaclass and issubclass( + field.outer_type_, SystemConfiguration + ): + try: + user_config_fields[name] = _recursive_init_model( + model=field.outer_type_, + infer_field_value=infer_field_value, + ) + except ValidationError as e: + # Gracefully handle missing fields + if all(e["type"] == "value_error.missing" for e in e.errors()): + user_config_fields[name] = None + raise - for name, value in instance.__dict__.items(): - field_info = instance.__fields__[name] - if "user_configurable" in field_info.field_info.extra: - user_config_fields[name] = value + user_config_fields = remove_none_items(user_config_fields) + + return model.parse_obj(user_config_fields) + + +def _recurse_user_config_fields( + model: BaseModel, + infer_field_value: Callable[[ModelField, Any], Any], + init_sub_config: Optional[ + Callable[[Type[SystemConfiguration]], SystemConfiguration | None] + ] = None, +) -> dict[str, Any]: + """ + Recursively process the user configuration fields of a Pydantic model instance. + + Params: + model: The Pydantic model to iterate over. + infer_field_value: A callback function to process each field. + Params: + ModelField: The Pydantic ModelField object describing the field. + Any: The current value of the field. + init_sub_config: An optional callback function to initialize a sub-config. + Params: + Type[SystemConfiguration]: The type of the sub-config to initialize. + + Returns: + dict[str, Any]: The processed user configuration fields of the instance. + """ + user_config_fields = {} + + for name, field in model.__fields__.items(): + value = getattr(model, name) + + # Handle individual field + if "user_configurable" in field.field_info.extra: + user_config_fields[name] = infer_field_value(field, value) + + # Recurse into nested config object elif isinstance(value, SystemConfiguration): - user_config_fields[name] = value.get_user_config() + user_config_fields[name] = _recurse_user_config_fields( + model=value, + infer_field_value=infer_field_value, + init_sub_config=init_sub_config, + ) + + # Recurse into optional nested config object + elif value is None and init_sub_config: + field_type = get_args(field.annotation)[0] # Optional[T] -> T + if type(field_type) is ModelMetaclass and issubclass( + field_type, SystemConfiguration + ): + sub_config = init_sub_config(field_type) + if sub_config: + user_config_fields[name] = _recurse_user_config_fields( + model=sub_config, + infer_field_value=infer_field_value, + init_sub_config=init_sub_config, + ) + elif isinstance(value, list) and all( isinstance(i, SystemConfiguration) for i in value ): - user_config_fields[name] = [i.get_user_config() for i in value] + user_config_fields[name] = [ + _recurse_user_config_fields(i, infer_field_value, init_sub_config) + for i in value + ] elif isinstance(value, dict) and all( isinstance(i, SystemConfiguration) for i in value.values() ): user_config_fields[name] = { - k: v.get_user_config() for k, v in value.items() + k: _recurse_user_config_fields(v, infer_field_value, init_sub_config) + for k, v in value.items() } return user_config_fields +def _recurse_user_config_values( + instance: BaseModel, + get_field_value: Callable[[ModelField, T], T] = lambda _, v: v, +) -> dict[str, Any]: + """ + This function recursively traverses the user configuration values in a Pydantic + model instance. + + Params: + instance: A Pydantic model instance. + get_field_value: A callback function to process each field. Parameters: + ModelField: The Pydantic ModelField object that describes the field. + Any: The current value of the field. + + Returns: + A dictionary containing the processed user configuration fields of the instance. + """ + user_config_values = {} + + for name, value in instance.__dict__.items(): + field = instance.__fields__[name] + if "user_configurable" in field.field_info.extra: + user_config_values[name] = get_field_value(field, value) + elif isinstance(value, SystemConfiguration): + user_config_values[name] = _recurse_user_config_values( + instance=value, get_field_value=get_field_value + ) + elif isinstance(value, list) and all( + isinstance(i, SystemConfiguration) for i in value + ): + user_config_values[name] = [ + _recurse_user_config_values(i, get_field_value) for i in value + ] + elif isinstance(value, dict) and all( + isinstance(i, SystemConfiguration) for i in value.values() + ): + user_config_values[name] = { + k: _recurse_user_config_values(v, get_field_value) + for k, v in value.items() + } + + return user_config_values + + +def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]: + """ + Get the non-default user config fields of a Pydantic model instance. + + Params: + instance: The Pydantic model instance. + + Returns: + dict[str, Any]: The non-default user config values on the instance. + """ + + def get_field_value(field: ModelField, value): + default = field.default_factory() if field.default_factory else field.default + if value != default: + return value + + return remove_none_items(_recurse_user_config_values(instance, get_field_value)) + + def deep_update(original_dict: dict, update_dict: dict) -> dict: """ Recursively update a dictionary. - Args: + Params: original_dict (dict): The dictionary to be updated. update_dict (dict): The dictionary to update with. @@ -108,3 +341,11 @@ def deep_update(original_dict: dict, update_dict: dict) -> dict: else: original_dict[key] = value return original_dict + + +def remove_none_items(d): + if isinstance(d, dict): + return { + k: remove_none_items(v) for k, v in d.items() if v not in (None, Undefined) + } + return d |