aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/commands/image_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/commands/image_gen.py')
-rw-r--r--autogpts/autogpt/autogpt/commands/image_gen.py51
1 files changed, 28 insertions, 23 deletions
diff --git a/autogpts/autogpt/autogpt/commands/image_gen.py b/autogpts/autogpt/autogpt/commands/image_gen.py
index ba771635f..957c2ac5a 100644
--- a/autogpts/autogpt/autogpt/commands/image_gen.py
+++ b/autogpts/autogpt/autogpt/commands/image_gen.py
@@ -1,23 +1,25 @@
"""Commands to generate images based on text input"""
-COMMAND_CATEGORY = "text_to_image"
-COMMAND_CATEGORY_TITLE = "Text to Image"
-
import io
import json
import logging
import time
import uuid
from base64 import b64decode
+from pathlib import Path
-import openai
import requests
+from openai import OpenAI
from PIL import Image
from autogpt.agents.agent import Agent
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema
+COMMAND_CATEGORY = "text_to_image"
+COMMAND_CATEGORY_TITLE = "Text to Image"
+
+
logger = logging.getLogger(__name__)
@@ -39,7 +41,8 @@ def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
Args:
prompt (str): The prompt to use
- size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
+ size (int, optional): The size of the image. Defaults to 256.
+ Not supported by HuggingFace.
Returns:
str: The filename of the image
@@ -58,17 +61,17 @@ def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
return "No Image Provider Set"
-def generate_image_with_hf(prompt: str, filename: str, agent: Agent) -> str:
+def generate_image_with_hf(prompt: str, output_file: Path, agent: Agent) -> str:
"""Generate an image with HuggingFace's API.
Args:
prompt (str): The prompt to use
- filename (str): The filename to save the image to
+ filename (Path): The filename to save the image to
Returns:
str: The filename of the image
"""
- API_URL = f"https://api-inference.huggingface.co/models/{agent.legacy_config.huggingface_image_model}"
+ API_URL = f"https://api-inference.huggingface.co/models/{agent.legacy_config.huggingface_image_model}" # noqa: E501
if agent.legacy_config.huggingface_api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
@@ -92,8 +95,8 @@ def generate_image_with_hf(prompt: str, filename: str, agent: Agent) -> str:
try:
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
- image.save(filename)
- return f"Saved to disk:{filename}"
+ image.save(output_file)
+ return f"Saved to disk: {output_file}"
except Exception as e:
logger.error(e)
break
@@ -113,17 +116,17 @@ def generate_image_with_hf(prompt: str, filename: str, agent: Agent) -> str:
retry_count += 1
- return f"Error creating image."
+ return "Error creating image."
def generate_image_with_dalle(
- prompt: str, filename: str, size: int, agent: Agent
+ prompt: str, output_file: Path, size: int, agent: Agent
) -> str:
"""Generate an image with DALL-E.
Args:
prompt (str): The prompt to use
- filename (str): The filename to save the image to
+ filename (Path): The filename to save the image to
size (int): The size of the image
Returns:
@@ -134,31 +137,33 @@ def generate_image_with_dalle(
if size not in [256, 512, 1024]:
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
logger.info(
- f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to {closest}, was {size}."
+ "DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
+ f"Setting to {closest}, was {size}."
)
size = closest
- response = openai.Image.create(
+ response = OpenAI(
+ api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
+ ).images.generate(
prompt=prompt,
n=1,
size=f"{size}x{size}",
response_format="b64_json",
- api_key=agent.legacy_config.openai_api_key,
)
logger.info(f"Image Generated for prompt:{prompt}")
- image_data = b64decode(response["data"][0]["b64_json"])
+ image_data = b64decode(response.data[0].b64_json)
- with open(filename, mode="wb") as png:
+ with open(output_file, mode="wb") as png:
png.write(image_data)
- return f"Saved to disk:{filename}"
+ return f"Saved to disk: {output_file}"
def generate_image_with_sd_webui(
prompt: str,
- filename: str,
+ output_file: Path,
agent: Agent,
size: int = 512,
negative_prompt: str = "",
@@ -196,12 +201,12 @@ def generate_image_with_sd_webui(
},
)
- logger.info(f"Image Generated for prompt:{prompt}")
+ logger.info(f"Image Generated for prompt: '{prompt}'")
# Save the image to disk
response = response.json()
b64 = b64decode(response["images"][0].split(",", 1)[0])
image = Image.open(io.BytesIO(b64))
- image.save(filename)
+ image.save(output_file)
- return f"Saved to disk:{filename}"
+ return f"Saved to disk: {output_file}"