diff options
Diffstat (limited to 'autogpts/autogpt/tests/integration/test_image_gen.py')
-rw-r--r-- | autogpts/autogpt/tests/integration/test_image_gen.py | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/autogpts/autogpt/tests/integration/test_image_gen.py b/autogpts/autogpt/tests/integration/test_image_gen.py new file mode 100644 index 000000000..14b90aec7 --- /dev/null +++ b/autogpts/autogpt/tests/integration/test_image_gen.py @@ -0,0 +1,235 @@ +import functools +import hashlib +from pathlib import Path +from unittest.mock import patch + +import pytest +from PIL import Image + +from autogpt.agents.agent import Agent +from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui + + +@pytest.fixture(params=[256, 512, 1024]) +def image_size(request): + """Parametrize image size.""" + return request.param + + +@pytest.mark.requires_openai_api_key +@pytest.mark.vcr +def test_dalle(agent: Agent, workspace, image_size, cached_openai_client): + """Test DALL-E image generation.""" + generate_and_validate( + agent, + workspace, + image_provider="dalle", + image_size=image_size, + ) + + +@pytest.mark.xfail( + reason="The image is too big to be put in a cassette for a CI pipeline. " + "We're looking into a solution." +) +@pytest.mark.requires_huggingface_api_key +@pytest.mark.parametrize( + "image_model", + ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"], +) +def test_huggingface(agent: Agent, workspace, image_size, image_model): + """Test HuggingFace image generation.""" + generate_and_validate( + agent, + workspace, + image_provider="huggingface", + image_size=image_size, + hugging_face_image_model=image_model, + ) + + +@pytest.mark.xfail(reason="SD WebUI call does not work.") +def test_sd_webui(agent: Agent, workspace, image_size): + """Test SD WebUI image generation.""" + generate_and_validate( + agent, + workspace, + image_provider="sd_webui", + image_size=image_size, + ) + + +@pytest.mark.xfail(reason="SD WebUI call does not work.") +def test_sd_webui_negative_prompt(agent: Agent, workspace, image_size): + gen_image = functools.partial( + generate_image_with_sd_webui, + prompt="astronaut riding a horse", + agent=agent, + size=image_size, + extra={"seed": 123}, + ) + + # Generate an image with a negative prompt + image_path = lst( + gen_image(negative_prompt="horse", output_file=Path("negative.jpg")) + ) + with Image.open(image_path) as img: + neg_image_hash = hashlib.md5(img.tobytes()).hexdigest() + + # Generate an image without a negative prompt + image_path = lst(gen_image(output_file=Path("positive.jpg"))) + with Image.open(image_path) as img: + image_hash = hashlib.md5(img.tobytes()).hexdigest() + + assert image_hash != neg_image_hash + + +def lst(txt): + """Extract the file path from the output of `generate_image()`""" + return Path(txt.split(": ", maxsplit=1)[1].strip()) + + +def generate_and_validate( + agent: Agent, + workspace, + image_size, + image_provider, + hugging_face_image_model=None, + **kwargs, +): + """Generate an image and validate the output.""" + agent.legacy_config.image_provider = image_provider + if hugging_face_image_model: + agent.legacy_config.huggingface_image_model = hugging_face_image_model + prompt = "astronaut riding a horse" + + image_path = lst(generate_image(prompt, agent, image_size, **kwargs)) + assert image_path.exists() + with Image.open(image_path) as img: + assert img.size == (image_size, image_size) + + +@pytest.mark.parametrize( + "return_text", + [ + # Delay + '{"error":"Model [model] is currently loading","estimated_time": [delay]}', + '{"error":"Model [model] is currently loading"}', # No delay + '{"error:}', # Bad JSON + "", # Bad Image + ], +) +@pytest.mark.parametrize( + "image_model", + ["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"], +) +@pytest.mark.parametrize("delay", [10, 0]) +def test_huggingface_fail_request_with_delay( + agent: Agent, workspace, image_size, image_model, return_text, delay +): + return_text = return_text.replace("[model]", image_model).replace( + "[delay]", str(delay) + ) + + with patch("requests.post") as mock_post: + if return_text == "": + # Test bad image + mock_post.return_value.status_code = 200 + mock_post.return_value.ok = True + mock_post.return_value.content = b"bad image" + else: + # Test delay and bad json + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = return_text + + agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_api_token = "mock-api-key" + agent.legacy_config.huggingface_image_model = image_model + prompt = "astronaut riding a horse" + + with patch("time.sleep") as mock_sleep: + # Verify request fails. + result = generate_image(prompt, agent, image_size) + assert result == "Error creating image." + + # Verify retry was called with delay if delay is in return_text + if "estimated_time" in return_text: + mock_sleep.assert_called_with(delay) + else: + mock_sleep.assert_not_called() + + +def test_huggingface_fail_request_no_delay(mocker, agent: Agent): + agent.legacy_config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = ( + '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}' + ) + + # Mock time.sleep + mock_sleep = mocker.patch("time.sleep") + + agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", agent, 512) + + assert result == "Error creating image." + + # Verify retry was not called. + mock_sleep.assert_not_called() + + +def test_huggingface_fail_request_bad_json(mocker, agent: Agent): + agent.legacy_config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.ok = False + mock_post.return_value.text = '{"error:}' + + # Mock time.sleep + mock_sleep = mocker.patch("time.sleep") + + agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", agent, 512) + + assert result == "Error creating image." + + # Verify retry was not called. + mock_sleep.assert_not_called() + + +def test_huggingface_fail_request_bad_image(mocker, agent: Agent): + agent.legacy_config.huggingface_api_token = "1" + + # Mock requests.post + mock_post = mocker.patch("requests.post") + mock_post.return_value.status_code = 200 + + agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + result = generate_image("astronaut riding a horse", agent, 512) + + assert result == "Error creating image." + + +def test_huggingface_fail_missing_api_token(mocker, agent: Agent): + agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" + + # Mock requests.post to raise ValueError + mocker.patch("requests.post", side_effect=ValueError) + + # Verify request raises an error. + with pytest.raises(ValueError): + generate_image("astronaut riding a horse", agent, 512) |