aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/models/command_registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/models/command_registry.py')
-rw-r--r--autogpts/autogpt/autogpt/models/command_registry.py212
1 files changed, 212 insertions, 0 deletions
diff --git a/autogpts/autogpt/autogpt/models/command_registry.py b/autogpts/autogpt/autogpt/models/command_registry.py
new file mode 100644
index 000000000..ec372c9f5
--- /dev/null
+++ b/autogpts/autogpt/autogpt/models/command_registry.py
@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+import importlib
+import inspect
+import logging
+from dataclasses import dataclass, field
+from types import ModuleType
+from typing import TYPE_CHECKING, Any, Iterator
+
+if TYPE_CHECKING:
+ from autogpt.agents.base import BaseAgent
+ from autogpt.config import Config
+
+
+from autogpt.command_decorator import AUTO_GPT_COMMAND_IDENTIFIER
+from autogpt.models.command import Command
+
+logger = logging.getLogger(__name__)
+
+
+class CommandRegistry:
+ """
+ The CommandRegistry class is a manager for a collection of Command objects.
+ It allows the registration, modification, and retrieval of Command objects,
+ as well as the scanning and loading of command plugins from a specified
+ directory.
+ """
+
+ commands: dict[str, Command]
+ commands_aliases: dict[str, Command]
+
+ # Alternative way to structure the registry; currently redundant with self.commands
+ categories: dict[str, CommandCategory]
+
+ @dataclass
+ class CommandCategory:
+ name: str
+ title: str
+ description: str
+ commands: list[Command] = field(default_factory=list[Command])
+ modules: list[ModuleType] = field(default_factory=list[ModuleType])
+
+ def __init__(self):
+ self.commands = {}
+ self.commands_aliases = {}
+ self.categories = {}
+
+ def __contains__(self, command_name: str):
+ return command_name in self.commands or command_name in self.commands_aliases
+
+ def _import_module(self, module_name: str) -> Any:
+ return importlib.import_module(module_name)
+
+ def _reload_module(self, module: Any) -> Any:
+ return importlib.reload(module)
+
+ def register(self, cmd: Command) -> None:
+ if cmd.name in self.commands:
+ logger.warning(
+ f"Command '{cmd.name}' already registered and will be overwritten!"
+ )
+ self.commands[cmd.name] = cmd
+
+ if cmd.name in self.commands_aliases:
+ logger.warning(
+ f"Command '{cmd.name}' will overwrite alias with the same name of "
+ f"'{self.commands_aliases[cmd.name]}'!"
+ )
+ for alias in cmd.aliases:
+ self.commands_aliases[alias] = cmd
+
+ def unregister(self, command: Command) -> None:
+ if command.name in self.commands:
+ del self.commands[command.name]
+ for alias in command.aliases:
+ del self.commands_aliases[alias]
+ else:
+ raise KeyError(f"Command '{command.name}' not found in registry.")
+
+ def reload_commands(self) -> None:
+ """Reloads all loaded command plugins."""
+ for cmd_name in self.commands:
+ cmd = self.commands[cmd_name]
+ module = self._import_module(cmd.__module__)
+ reloaded_module = self._reload_module(module)
+ if hasattr(reloaded_module, "register"):
+ reloaded_module.register(self)
+
+ def get_command(self, name: str) -> Command | None:
+ if name in self.commands:
+ return self.commands[name]
+
+ if name in self.commands_aliases:
+ return self.commands_aliases[name]
+
+ def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
+ if command := self.get_command(command_name):
+ return command(**kwargs, agent=agent)
+ raise KeyError(f"Command '{command_name}' not found in registry")
+
+ def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:
+ """Iterates over all registered commands and yields those that are available.
+
+ Params:
+ agent (BaseAgent): The agent that the commands will be checked against.
+
+ Yields:
+ Command: The next available command.
+ """
+
+ for cmd in self.commands.values():
+ available = cmd.available
+ if callable(cmd.available):
+ available = cmd.available(agent)
+ if available:
+ yield cmd
+
+ # def command_specs(self) -> str:
+ # """
+ # Returns a technical declaration of all commands in the registry,
+ # for use in a prompt.
+ # """
+ #
+ # Declaring functions or commands should be done in a model-specific way to
+ # achieve optimal results. For this reason, it should NOT be implemented here,
+ # but in an LLM provider module.
+ # MUST take command AVAILABILITY into account.
+
+ @staticmethod
+ def with_command_modules(modules: list[str], config: Config) -> CommandRegistry:
+ new_registry = CommandRegistry()
+
+ logger.debug(
+ "The following command categories are disabled: "
+ f"{config.disabled_command_categories}"
+ )
+ enabled_command_modules = [
+ x for x in modules if x not in config.disabled_command_categories
+ ]
+
+ logger.debug(
+ f"The following command categories are enabled: {enabled_command_modules}"
+ )
+
+ for command_module in enabled_command_modules:
+ new_registry.import_command_module(command_module)
+
+ # Unregister commands that are incompatible with the current config
+ for command in [c for c in new_registry.commands.values()]:
+ if callable(command.enabled) and not command.enabled(config):
+ new_registry.unregister(command)
+ logger.debug(
+ f"Unregistering incompatible command '{command.name}':"
+ f" \"{command.disabled_reason or 'Disabled by current config.'}\""
+ )
+
+ return new_registry
+
+ def import_command_module(self, module_name: str) -> None:
+ """
+ Imports the specified Python module containing command plugins.
+
+ This method imports the associated module and registers any functions or
+ classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
+ as `Command` objects. The registered `Command` objects are then added to the
+ `commands` dictionary of the `CommandRegistry` object.
+
+ Args:
+ module_name (str): The name of the module to import for command plugins.
+ """
+
+ module = importlib.import_module(module_name)
+
+ category = self.register_module_category(module)
+
+ for attr_name in dir(module):
+ attr = getattr(module, attr_name)
+
+ command = None
+
+ # Register decorated functions
+ if getattr(attr, AUTO_GPT_COMMAND_IDENTIFIER, False):
+ command = attr.command
+
+ # Register command classes
+ elif (
+ inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
+ ):
+ command = attr()
+
+ if command:
+ self.register(command)
+ category.commands.append(command)
+
+ def register_module_category(self, module: ModuleType) -> CommandCategory:
+ if not (category_name := getattr(module, "COMMAND_CATEGORY", None)):
+ raise ValueError(f"Cannot import invalid command module {module.__name__}")
+
+ if category_name not in self.categories:
+ self.categories[category_name] = CommandRegistry.CommandCategory(
+ name=category_name,
+ title=getattr(
+ module, "COMMAND_CATEGORY_TITLE", category_name.capitalize()
+ ),
+ description=getattr(module, "__doc__", ""),
+ )
+
+ category = self.categories[category_name]
+ if module not in category.modules:
+ category.modules.append(module)
+
+ return category