aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/unit/test_commands.py
blob: a939ec4d207e1bff9a9af2aa0120daa79cd5d2b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from __future__ import annotations

import os
import shutil
import sys
from pathlib import Path
from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
    from autogpt.agents import Agent, BaseAgent

from autogpt.core.utils.json_schema import JSONSchema
from autogpt.models.command import Command, CommandParameter
from autogpt.models.command_registry import CommandRegistry

# Parameter specs shared by every example command in this module:
# ``arg1`` is a required integer and ``arg2`` is an optional string.
PARAMETERS = [
    CommandParameter(
        param_name,
        spec=JSONSchema(
            type=param_type,
            description=param_description,
            required=is_required,
        ),
    )
    for param_name, param_type, param_description, is_required in (
        ("arg1", JSONSchema.Type.INTEGER, "Argument 1", True),
        ("arg2", JSONSchema.Type.STRING, "Argument 2", False),
    )
]


def example_command_method(arg1: int, arg2: str, agent: BaseAgent) -> str:
    """Example function for testing the Command class.

    Joins its two arguments with `` - ``; ``agent`` is accepted only so the
    signature matches what commands receive, and is otherwise unused.
    """
    return "{} - {}".format(arg1, arg2)


def test_command_creation():
    """Test that a Command exposes the attributes and string form it was built with."""
    command = Command(
        name="example",
        description="Example command",
        method=example_command_method,
        parameters=PARAMETERS,
    )

    expected_str = (
        "example: Example command. Params: (arg1: integer, arg2: Optional[string])"
    )
    assert command.name == "example"
    assert command.description == "Example command"
    assert command.method == example_command_method
    assert str(command) == expected_str


@pytest.fixture
def example_command():
    """Provide a fresh ``example`` Command backed by ``example_command_method``."""
    command = Command(
        name="example",
        description="Example command",
        method=example_command_method,
        parameters=PARAMETERS,
    )
    yield command


def test_command_call(example_command: Command, agent: Agent):
    """Test that calling a Command forwards its kwargs to the method and returns the result."""
    call_kwargs = {"arg1": 1, "arg2": "test", "agent": agent}
    assert example_command(**call_kwargs) == "1 - test"


def test_command_call_with_invalid_arguments(example_command: Command, agent: Agent):
    """Test that invoking a Command with bad keyword arguments raises TypeError."""
    bad_kwargs = dict(arg1="invalid", does_not_exist="test", agent=agent)
    with pytest.raises(TypeError):
        example_command(**bad_kwargs)


def test_register_command(example_command: Command):
    """Test that a registered command can be looked up by name and is the only entry."""
    cmd_registry = CommandRegistry()
    cmd_registry.register(example_command)

    fetched = cmd_registry.get_command(example_command.name)
    assert fetched == example_command
    assert len(cmd_registry.commands) == 1


def test_unregister_command(example_command: Command):
    """Test that unregistering a command leaves the registry empty."""
    cmd_registry = CommandRegistry()
    cmd_registry.register(example_command)
    cmd_registry.unregister(example_command)

    # Both the command table and the membership check must forget it.
    assert len(cmd_registry.commands) == 0
    assert example_command.name not in cmd_registry


@pytest.fixture
def example_command_with_aliases(example_command: Command):
    """Return the ``example_command`` fixture after giving it two aliases."""
    alias_names = ["example_alias", "example_alias_2"]
    example_command.aliases = alias_names
    return example_command


def test_register_command_aliases(example_command_with_aliases: Command):
    """Test that registering a command also makes it resolvable through its aliases."""
    registry = CommandRegistry()
    cmd = example_command_with_aliases
    registry.register(cmd)

    assert cmd.name in registry
    assert registry.get_command(cmd.name) == cmd
    assert all(registry.get_command(alias) == cmd for alias in cmd.aliases)
    # Aliases resolve to the same command; they are not separate entries.
    assert len(registry.commands) == 1


def test_unregister_command_aliases(example_command_with_aliases: Command):
    """Test that unregistering a command also removes all of its aliases."""
    registry = CommandRegistry()
    cmd = example_command_with_aliases
    registry.register(cmd)
    registry.unregister(cmd)

    assert len(registry.commands) == 0
    assert cmd.name not in registry
    assert all(alias not in registry for alias in cmd.aliases)


def test_command_in_registry(example_command_with_aliases: Command):
    """Test that `command_name in registry` works."""
    registry = CommandRegistry()
    cmd = example_command_with_aliases
    missing_name = "nonexistent_command"

    # Before registration, neither the real name nor a bogus one is present.
    assert cmd.name not in registry
    assert missing_name not in registry

    registry.register(cmd)

    # Afterwards the name and every alias are present; the bogus one still isn't.
    assert cmd.name in registry
    assert missing_name not in registry
    assert all(alias in registry for alias in cmd.aliases)


def test_get_command(example_command: Command):
    """Test that get_command returns the exact command that was registered."""
    cmd_registry = CommandRegistry()
    cmd_registry.register(example_command)

    assert cmd_registry.get_command(example_command.name) == example_command


def test_get_nonexistent_command():
    """Test that looking up an unregistered command returns None instead of raising."""
    registry = CommandRegistry()

    # get_command() reports "not found" with None; the membership check agrees.
    assert registry.get_command("nonexistent_command") is None
    assert "nonexistent_command" not in registry


def test_call_command(agent: Agent):
    """Test that the registry dispatches a call by name and returns the method's result."""
    registry = CommandRegistry()
    registry.register(
        Command(
            name="example",
            description="Example command",
            method=example_command_method,
            parameters=PARAMETERS,
        )
    )

    assert registry.call("example", arg1=1, arg2="test", agent=agent) == "1 - test"


def test_call_nonexistent_command(agent: Agent):
    """Test that calling an unregistered command through the registry raises KeyError."""
    registry = CommandRegistry()
    call_kwargs = {"arg1": 1, "arg2": "test", "agent": agent}

    with pytest.raises(KeyError):
        registry.call("nonexistent_command", **call_kwargs)


def test_import_mock_commands_module():
    """Test that the registry picks up commands defined in the mock commands module."""
    registry = CommandRegistry()
    registry.import_command_module("tests.mocks.mock_commands")

    cmd_name = "function_based_cmd"
    assert cmd_name in registry
    imported_cmd = registry.commands[cmd_name]
    assert imported_cmd.name == cmd_name
    assert imported_cmd.description == "Function-based test command"


def test_import_temp_command_file_module(tmp_path: Path):
    """
    Test that the registry can import a command plugins module from a temp file.

    The mock commands file is copied into ``tmp_path``, which is temporarily put
    on ``sys.path`` so the copy can be imported by its bare module name.

    Args:
        tmp_path (pathlib.Path): Path to a temporary directory.
    """
    registry = CommandRegistry()

    # Create a temp command file. The source path is resolved against the
    # current working directory, so this test expects to run from the project
    # root (where ``tests/mocks/mock_commands.py`` lives).
    src = Path(os.getcwd()) / "tests/mocks/mock_commands.py"
    temp_commands_file = tmp_path / "mock_commands.py"
    shutil.copyfile(src, temp_commands_file)

    temp_dir = str(tmp_path)
    temp_commands_module = "mock_commands"

    # Prepend (rather than append) so the temp copy wins over any other
    # ``mock_commands`` already importable, and always undo the change -- even
    # if the import fails -- so later tests never see a sys.path entry for a
    # directory pytest is about to delete.
    sys.path.insert(0, temp_dir)
    try:
        registry.import_command_module(temp_commands_module)
    finally:
        sys.path.remove(temp_dir)
        # Drop the cached temp module so it cannot leak into other tests.
        sys.modules.pop(temp_commands_module, None)

    assert "function_based_cmd" in registry
    assert registry.commands["function_based_cmd"].name == "function_based_cmd"
    assert (
        registry.commands["function_based_cmd"].description
        == "Function-based test command"
    )