aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/models/command_registry.py
blob: ec372c9f53c10d470c9074df64a6b7bc85e14d28 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from __future__ import annotations

import importlib
import inspect
import logging
from dataclasses import dataclass, field
from types import ModuleType
from typing import TYPE_CHECKING, Any, Iterator

if TYPE_CHECKING:
    from autogpt.agents.base import BaseAgent
    from autogpt.config import Config


from autogpt.command_decorator import AUTO_GPT_COMMAND_IDENTIFIER
from autogpt.models.command import Command

logger = logging.getLogger(__name__)


class CommandRegistry:
    """
    The CommandRegistry class is a manager for a collection of Command objects.
    It allows the registration, modification, and retrieval of Command objects,
    as well as the scanning and loading of command plugins from a specified
    directory.
    """

    commands: dict[str, Command]
    commands_aliases: dict[str, Command]

    # Alternative way to structure the registry; currently redundant with self.commands
    categories: dict[str, CommandCategory]

    @dataclass
    class CommandCategory:
        name: str
        title: str
        description: str
        commands: list[Command] = field(default_factory=list[Command])
        modules: list[ModuleType] = field(default_factory=list[ModuleType])

    def __init__(self):
        self.commands = {}
        self.commands_aliases = {}
        self.categories = {}

    def __contains__(self, command_name: str):
        return command_name in self.commands or command_name in self.commands_aliases

    def _import_module(self, module_name: str) -> Any:
        return importlib.import_module(module_name)

    def _reload_module(self, module: Any) -> Any:
        return importlib.reload(module)

    def register(self, cmd: Command) -> None:
        if cmd.name in self.commands:
            logger.warning(
                f"Command '{cmd.name}' already registered and will be overwritten!"
            )
        self.commands[cmd.name] = cmd

        if cmd.name in self.commands_aliases:
            logger.warning(
                f"Command '{cmd.name}' will overwrite alias with the same name of "
                f"'{self.commands_aliases[cmd.name]}'!"
            )
        for alias in cmd.aliases:
            self.commands_aliases[alias] = cmd

    def unregister(self, command: Command) -> None:
        if command.name in self.commands:
            del self.commands[command.name]
            for alias in command.aliases:
                del self.commands_aliases[alias]
        else:
            raise KeyError(f"Command '{command.name}' not found in registry.")

    def reload_commands(self) -> None:
        """Reloads all loaded command plugins."""
        for cmd_name in self.commands:
            cmd = self.commands[cmd_name]
            module = self._import_module(cmd.__module__)
            reloaded_module = self._reload_module(module)
            if hasattr(reloaded_module, "register"):
                reloaded_module.register(self)

    def get_command(self, name: str) -> Command | None:
        if name in self.commands:
            return self.commands[name]

        if name in self.commands_aliases:
            return self.commands_aliases[name]

    def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
        if command := self.get_command(command_name):
            return command(**kwargs, agent=agent)
        raise KeyError(f"Command '{command_name}' not found in registry")

    def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:
        """Iterates over all registered commands and yields those that are available.

        Params:
            agent (BaseAgent): The agent that the commands will be checked against.

        Yields:
            Command: The next available command.
        """

        for cmd in self.commands.values():
            available = cmd.available
            if callable(cmd.available):
                available = cmd.available(agent)
            if available:
                yield cmd

    # def command_specs(self) -> str:
    #     """
    #     Returns a technical declaration of all commands in the registry,
    #     for use in a prompt.
    #     """
    #
    #     Declaring functions or commands should be done in a model-specific way to
    #     achieve optimal results. For this reason, it should NOT be implemented here,
    #     but in an LLM provider module.
    #     MUST take command AVAILABILITY into account.

    @staticmethod
    def with_command_modules(modules: list[str], config: Config) -> CommandRegistry:
        new_registry = CommandRegistry()

        logger.debug(
            "The following command categories are disabled: "
            f"{config.disabled_command_categories}"
        )
        enabled_command_modules = [
            x for x in modules if x not in config.disabled_command_categories
        ]

        logger.debug(
            f"The following command categories are enabled: {enabled_command_modules}"
        )

        for command_module in enabled_command_modules:
            new_registry.import_command_module(command_module)

        # Unregister commands that are incompatible with the current config
        for command in [c for c in new_registry.commands.values()]:
            if callable(command.enabled) and not command.enabled(config):
                new_registry.unregister(command)
                logger.debug(
                    f"Unregistering incompatible command '{command.name}':"
                    f" \"{command.disabled_reason or 'Disabled by current config.'}\""
                )

        return new_registry

    def import_command_module(self, module_name: str) -> None:
        """
        Imports the specified Python module containing command plugins.

        This method imports the associated module and registers any functions or
        classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
        as `Command` objects. The registered `Command` objects are then added to the
        `commands` dictionary of the `CommandRegistry` object.

        Args:
            module_name (str): The name of the module to import for command plugins.
        """

        module = importlib.import_module(module_name)

        category = self.register_module_category(module)

        for attr_name in dir(module):
            attr = getattr(module, attr_name)

            command = None

            # Register decorated functions
            if getattr(attr, AUTO_GPT_COMMAND_IDENTIFIER, False):
                command = attr.command

            # Register command classes
            elif (
                inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
            ):
                command = attr()

            if command:
                self.register(command)
                category.commands.append(command)

    def register_module_category(self, module: ModuleType) -> CommandCategory:
        if not (category_name := getattr(module, "COMMAND_CATEGORY", None)):
            raise ValueError(f"Cannot import invalid command module {module.__name__}")

        if category_name not in self.categories:
            self.categories[category_name] = CommandRegistry.CommandCategory(
                name=category_name,
                title=getattr(
                    module, "COMMAND_CATEGORY_TITLE", category_name.capitalize()
                ),
                description=getattr(module, "__doc__", ""),
            )

        category = self.categories[category_name]
        if module not in category.modules:
            category.modules.append(module)

        return category