aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/integration/test_image_gen.py
blob: 4509664869a3892030b3c2fcc7cb77329f9540bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import functools
import hashlib
from pathlib import Path
from unittest.mock import patch

import pytest
from PIL import Image

from autogpt.agents.agent import Agent
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui


@pytest.fixture(params=(256, 512, 1024))
def image_size(request):
    """Supply each image size under test, one per parametrized run."""
    size = request.param
    return size


@pytest.mark.requires_openai_api_key
@pytest.mark.vcr
def test_dalle(agent: Agent, workspace, image_size, patched_api_requestor):
    """Generate an image through the "dalle" provider and validate it.

    `patched_api_requestor` is not referenced in the body; presumably it is
    requested only for its fixture setup -- confirm against its definition.
    """
    provider_options = {"image_provider": "dalle", "image_size": image_size}
    generate_and_validate(agent, workspace, **provider_options)


@pytest.mark.xfail(
    reason=(
        "The image is too big to be put in a cassette for a CI pipeline. "
        "We're looking into a solution."
    )
)
@pytest.mark.requires_huggingface_api_key
@pytest.mark.parametrize(
    "image_model",
    ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
)
def test_huggingface(agent: Agent, workspace, image_size, image_model):
    """Generate an image through the "huggingface" provider and validate it.

    Runs once per (image size, model) combination; currently marked xfail.
    """
    provider_options = {
        "image_provider": "huggingface",
        "image_size": image_size,
        "hugging_face_image_model": image_model,
    }
    generate_and_validate(agent, workspace, **provider_options)


@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(agent: Agent, workspace, image_size):
    """Generate an image through the "sd_webui" provider and validate it.

    Currently marked xfail.
    """
    provider_options = {"image_provider": "sd_webui", "image_size": image_size}
    generate_and_validate(agent, workspace, **provider_options)


@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(agent: Agent, workspace, image_size):
    """Check that a negative prompt changes the image SD WebUI produces.

    Two images are generated with the same prompt, size and seed; only the
    first is given a negative prompt. The MD5 digests of their pixel data are
    then required to differ.
    """
    # Arguments shared by both generations, so the per-call overrides are the
    # only thing that varies between them.
    fixed_args = {
        "prompt": "astronaut riding a horse",
        "agent": agent,
        "size": image_size,
        "extra": {"seed": 123},
    }

    def render_and_digest(**overrides):
        """Generate one image and return the MD5 hex digest of its pixels."""
        output = generate_image_with_sd_webui(**fixed_args, **overrides)
        with Image.open(lst(output)) as picture:
            return hashlib.md5(picture.tobytes()).hexdigest()

    # With a negative prompt first, then without, each into its own file.
    negative_digest = render_and_digest(
        negative_prompt="horse", filename="negative.jpg"
    )
    plain_digest = render_and_digest(filename="positive.jpg")

    assert plain_digest != negative_digest


def lst(txt):
    """Return the file path embedded in the output of `generate_image()`.

    The path is whatever follows the first colon in `txt`, with surrounding
    whitespace removed. Only the first colon is treated as the separator, so
    colons inside the path itself are preserved.
    """
    pieces = txt.split(":", maxsplit=1)
    raw_path = pieces[1]
    return Path(raw_path.strip())


def generate_and_validate(
    agent: Agent,
    workspace,
    image_size,
    image_provider,
    hugging_face_image_model=None,
    **kwargs,
):
    """Run `generate_image()` for one provider and check the resulting file.

    The provider and Hugging Face model are written onto the agent's legacy
    config, an image is generated for a fixed prompt, and the file named in
    the result must exist and be a square of `image_size` on each side.
    Extra keyword arguments are forwarded to `generate_image()`.
    """
    config = agent.legacy_config
    config.image_provider = image_provider
    config.huggingface_image_model = hugging_face_image_model

    output = generate_image("astronaut riding a horse", agent, image_size, **kwargs)
    saved_file = lst(output)
    assert saved_file.exists()

    expected_dimensions = (image_size, image_size)
    with Image.open(saved_file) as picture:
        assert picture.size == expected_dimensions


@pytest.mark.parametrize(
    "return_text",
    [
        '{"error":"Model [model] is currently loading","estimated_time": [delay]}',  # Delay
        '{"error":"Model [model] is currently loading"}',  # No delay
        '{"error:}',  # Bad JSON
        "",  # Bad Image
    ],
)
@pytest.mark.parametrize(
    "image_model",
    ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
)
@pytest.mark.parametrize("delay", [10, 0])
def test_huggingface_fail_request_parametrized(
    agent: Agent, workspace, image_size, image_model, return_text, delay
):
    """Check Hugging Face failure handling across models, sizes and responses.

    Each `return_text` template describes one failure mode of the mocked
    `requests.post` response:

    * error JSON with an ``estimated_time`` -> a sleep with that delay is expected,
    * error JSON without ``estimated_time`` -> no sleep is expected,
    * malformed JSON -> no sleep is expected,
    * empty template -> a successful response whose body is not a valid image.

    In every case `generate_image()` must return "Error creating image.".

    This test has its own name because the module defines a single-case
    `test_huggingface_fail_request_with_delay` further down; sharing that name
    rebinds it, so pytest would never collect this parametrized version.
    """
    # Fill the placeholders of the template for this parameter combination.
    return_text = return_text.replace("[model]", image_model).replace(
        "[delay]", str(delay)
    )

    # A token is set by the other failure tests in this module before they
    # generate; without one, test_huggingface_fail_missing_api_token expects
    # a ValueError rather than the error string asserted below.
    agent.legacy_config.huggingface_api_token = "1"
    agent.legacy_config.image_provider = "huggingface"
    agent.legacy_config.huggingface_image_model = image_model
    prompt = "astronaut riding a horse"

    with patch("requests.post") as mock_post, patch("time.sleep") as mock_sleep:
        response = mock_post.return_value
        if return_text == "":
            # Test bad image: the request succeeds but the payload is junk.
            response.status_code = 200
            response.ok = True
            response.content = b"bad image"
        else:
            # Test delay and bad json: the request itself fails.
            response.status_code = 500
            response.ok = False
            response.text = return_text

        # Verify request fails.
        result = generate_image(prompt, agent, image_size)
        assert result == "Error creating image."

        # Verify retry was called with delay if delay is in return_text
        if "estimated_time" in return_text:
            mock_sleep.assert_called_with(delay)
        else:
            mock_sleep.assert_not_called()


def test_huggingface_fail_request_with_delay(mocker, agent: Agent):
    """Check that a failing response carrying `estimated_time` triggers a sleep.

    `requests.post` is mocked to return a 500 response whose JSON body has
    ``"estimated_time":0``; generation must report an error and `time.sleep`
    must have been called with that delay.
    """
    config = agent.legacy_config
    config.huggingface_api_token = "1"
    config.image_provider = "huggingface"
    config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

    # Failing HTTP response that advertises an estimated loading time.
    failed_response = mocker.patch("requests.post").return_value
    failed_response.status_code = 500
    failed_response.ok = False
    failed_response.text = (
        '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading",'
        '"estimated_time":0}'
    )

    # Replace time.sleep so the call can be inspected.
    sleep_spy = mocker.patch("time.sleep")

    outcome = generate_image("astronaut riding a horse", agent, 512)

    assert outcome == "Error creating image."

    # The advertised delay must have been passed to sleep.
    sleep_spy.assert_called_with(0)


def test_huggingface_fail_request_no_delay(mocker, agent: Agent):
    """Check that a failing response without `estimated_time` causes no sleep.

    `requests.post` is mocked to return a 500 response whose JSON body only
    has an ``error`` field; generation must report an error and `time.sleep`
    must not have been called.
    """
    config = agent.legacy_config
    config.huggingface_api_token = "1"
    config.image_provider = "huggingface"
    config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

    # Failing HTTP response with no estimated loading time.
    failed_response = mocker.patch("requests.post").return_value
    failed_response.status_code = 500
    failed_response.ok = False
    failed_response.text = (
        '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
    )

    # Replace time.sleep so the absence of a call can be checked.
    sleep_spy = mocker.patch("time.sleep")

    outcome = generate_image("astronaut riding a horse", agent, 512)

    assert outcome == "Error creating image."

    # No delay was advertised, so no sleep is expected.
    sleep_spy.assert_not_called()


def test_huggingface_fail_request_bad_json(mocker, agent: Agent):
    """Check that a failing response with malformed JSON causes no sleep.

    `requests.post` is mocked to return a 500 response whose body is not
    valid JSON; generation must report an error and `time.sleep` must not
    have been called.
    """
    config = agent.legacy_config
    config.huggingface_api_token = "1"
    config.image_provider = "huggingface"
    config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

    # Failing HTTP response whose body is truncated, invalid JSON.
    failed_response = mocker.patch("requests.post").return_value
    failed_response.status_code = 500
    failed_response.ok = False
    failed_response.text = '{"error:}'

    # Replace time.sleep so the absence of a call can be checked.
    sleep_spy = mocker.patch("time.sleep")

    outcome = generate_image("astronaut riding a horse", agent, 512)

    assert outcome == "Error creating image."

    # Nothing parseable was returned, so no sleep is expected.
    sleep_spy.assert_not_called()


def test_huggingface_fail_request_bad_image(mocker, agent: Agent):
    """Check that a 200 response without usable image data reports an error.

    Only `status_code` is configured on the mocked response; every other
    attribute is whatever the mock creates on access.
    """
    config = agent.legacy_config
    config.huggingface_api_token = "1"
    config.image_provider = "huggingface"
    config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

    # Mocked HTTP call that claims success but carries no real image payload.
    mocker.patch("requests.post").return_value.status_code = 200

    outcome = generate_image("astronaut riding a horse", agent, 512)

    assert outcome == "Error creating image."


def test_huggingface_fail_missing_api_token(mocker, agent: Agent):
    """Check that generating without configuring an API token raises ValueError.

    No `huggingface_api_token` is assigned here, unlike the other Hugging Face
    failure tests in this module.

    NOTE(review): the patched `requests.post` is itself set to raise
    ValueError, so this test passes whether the error comes from token
    handling or from the mock -- confirm which one is intended.
    """
    config = agent.legacy_config
    config.image_provider = "huggingface"
    config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

    # Any attempt to actually send the request raises ValueError.
    mocker.patch("requests.post", side_effect=ValueError)

    # Generation must surface a ValueError to the caller.
    with pytest.raises(ValueError):
        generate_image("astronaut riding a horse", agent, 512)