diff options
Diffstat (limited to 'autogpts/autogpt/tests/integration/test_image_gen.py')
-rw-r--r-- | autogpts/autogpt/tests/integration/test_image_gen.py | 45 |
1 files changed, 14 insertions, 31 deletions
diff --git a/autogpts/autogpt/tests/integration/test_image_gen.py b/autogpts/autogpt/tests/integration/test_image_gen.py index 450966486..14b90aec7 100644 --- a/autogpts/autogpt/tests/integration/test_image_gen.py +++ b/autogpts/autogpt/tests/integration/test_image_gen.py @@ -18,7 +18,7 @@ def image_size(request): @pytest.mark.requires_openai_api_key @pytest.mark.vcr -def test_dalle(agent: Agent, workspace, image_size, patched_api_requestor): +def test_dalle(agent: Agent, workspace, image_size, cached_openai_client): """Test DALL-E image generation.""" generate_and_validate( agent, @@ -29,7 +29,8 @@ def test_dalle(agent: Agent, workspace, image_size, patched_api_requestor): @pytest.mark.xfail( - reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution." + reason="The image is too big to be put in a cassette for a CI pipeline. " + "We're looking into a solution." ) @pytest.mark.requires_huggingface_api_key @pytest.mark.parametrize( @@ -69,12 +70,14 @@ def test_sd_webui_negative_prompt(agent: Agent, workspace, image_size): ) # Generate an image with a negative prompt - image_path = lst(gen_image(negative_prompt="horse", filename="negative.jpg")) + image_path = lst( + gen_image(negative_prompt="horse", output_file=Path("negative.jpg")) + ) with Image.open(image_path) as img: neg_image_hash = hashlib.md5(img.tobytes()).hexdigest() # Generate an image without a negative prompt - image_path = lst(gen_image(filename="positive.jpg")) + image_path = lst(gen_image(output_file=Path("positive.jpg"))) with Image.open(image_path) as img: image_hash = hashlib.md5(img.tobytes()).hexdigest() @@ -83,7 +86,7 @@ def test_sd_webui_negative_prompt(agent: Agent, workspace, image_size): def lst(txt): """Extract the file path from the output of `generate_image()`""" - return Path(txt.split(":", maxsplit=1)[1].strip()) + return Path(txt.split(": ", maxsplit=1)[1].strip()) def generate_and_validate( @@ -96,7 +99,8 @@ def generate_and_validate( ): """Generate an image and validate the output.""" agent.legacy_config.image_provider = image_provider - agent.legacy_config.huggingface_image_model = hugging_face_image_model + if hugging_face_image_model: + agent.legacy_config.huggingface_image_model = hugging_face_image_model prompt = "astronaut riding a horse" image_path = lst(generate_image(prompt, agent, image_size, **kwargs)) @@ -108,7 +112,8 @@ def generate_and_validate( @pytest.mark.parametrize( "return_text", [ - '{"error":"Model [model] is currently loading","estimated_time": [delay]}', # Delay + # Delay + '{"error":"Model [model] is currently loading","estimated_time": [delay]}', '{"error":"Model [model] is currently loading"}', # No delay '{"error:}', # Bad JSON "", # Bad Image @@ -139,6 +144,7 @@ def test_huggingface_fail_request_with_delay( mock_post.return_value.text = return_text agent.legacy_config.image_provider = "huggingface" + agent.legacy_config.huggingface_api_token = "mock-api-key" agent.legacy_config.huggingface_image_model = image_model prompt = "astronaut riding a horse" @@ -154,29 +160,6 @@ def test_huggingface_fail_request_with_delay( mock_sleep.assert_not_called() -def test_huggingface_fail_request_with_delay(mocker, agent: Agent): - agent.legacy_config.huggingface_api_token = "1" - - # Mock requests.post - mock_post = mocker.patch("requests.post") - mock_post.return_value.status_code = 500 - mock_post.return_value.ok = False - mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}' - - # Mock time.sleep - mock_sleep = mocker.patch("time.sleep") - - agent.legacy_config.image_provider = "huggingface" - agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" - - result = generate_image("astronaut riding a horse", agent, 512) - - assert result == "Error creating image." - - # Verify retry was called with delay. - mock_sleep.assert_called_with(0) - - def test_huggingface_fail_request_no_delay(mocker, agent: Agent): agent.legacy_config.huggingface_api_token = "1" @@ -245,7 +228,7 @@ def test_huggingface_fail_missing_api_token(mocker, agent: Agent): agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4" # Mock requests.post to raise ValueError - mock_post = mocker.patch("requests.post", side_effect=ValueError) + mocker.patch("requests.post", side_effect=ValueError) # Verify request raises an error. with pytest.raises(ValueError): |