diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/commands/image_gen.py')
-rw-r--r-- | autogpts/autogpt/autogpt/commands/image_gen.py | 212 |
1 file changed, 212 insertions, 0 deletions
"""Commands to generate images based on text input."""

import io
import json
import logging
import time
import uuid
from base64 import b64decode
from pathlib import Path

import requests
from openai import OpenAI
from PIL import Image

from autogpt.agents.agent import Agent
from autogpt.command_decorator import command
from autogpt.core.utils.json_schema import JSONSchema

COMMAND_CATEGORY = "text_to_image"
COMMAND_CATEGORY_TITLE = "Text to Image"


logger = logging.getLogger(__name__)


@command(
    "generate_image",
    "Generates an Image",
    {
        "prompt": JSONSchema(
            type=JSONSchema.Type.STRING,
            description="The prompt used to generate the image",
            required=True,
        ),
    },
    lambda config: bool(config.image_provider),
    "Requires a image provider to be set.",
)
def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
    """Generate an image from a prompt, dispatching to the configured provider.

    Args:
        prompt (str): The prompt to use
        agent (Agent): The agent whose config selects the image provider
        size (int, optional): The size of the image. Defaults to 256.
            Not supported by HuggingFace.

    Returns:
        str: A "Saved to disk: <path>" message, or an error message
    """
    # Random filename avoids collisions between successive generations.
    filename = agent.workspace.root / f"{uuid.uuid4()}.jpg"

    # DALL-E
    if agent.legacy_config.image_provider == "dalle":
        return generate_image_with_dalle(prompt, filename, size, agent)
    # HuggingFace
    elif agent.legacy_config.image_provider == "huggingface":
        return generate_image_with_hf(prompt, filename, agent)
    # SD WebUI
    elif agent.legacy_config.image_provider == "sdwebui":
        return generate_image_with_sd_webui(prompt, filename, agent, size)
    return "No Image Provider Set"


def generate_image_with_hf(prompt: str, output_file: Path, agent: Agent) -> str:
    """Generate an image with HuggingFace's inference API.

    Retries up to 10 times while the model is loading (the API reports an
    ``estimated_time`` until it is ready).

    Args:
        prompt (str): The prompt to use
        output_file (Path): The file to save the image to
        agent (Agent): The agent providing API credentials and model choice

    Returns:
        str: A "Saved to disk: <path>" message, or an error message

    Raises:
        ValueError: If no Hugging Face API token is configured
    """
    API_URL = f"https://api-inference.huggingface.co/models/{agent.legacy_config.huggingface_image_model}"  # noqa: E501
    if agent.legacy_config.huggingface_api_token is None:
        raise ValueError(
            "You need to set your Hugging Face API token in the config file."
        )
    headers = {
        "Authorization": f"Bearer {agent.legacy_config.huggingface_api_token}",
        # Force a fresh generation instead of a cached result for the prompt.
        "X-Use-Cache": "false",
    }

    retry_count = 0
    while retry_count < 10:
        response = requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": prompt,
            },
        )

        if response.ok:
            try:
                image = Image.open(io.BytesIO(response.content))
                logger.info(f"Image Generated for prompt:{prompt}")
                image.save(output_file)
                return f"Saved to disk: {output_file}"
            except Exception as e:
                logger.error(e)
                break
        else:
            try:
                error = json.loads(response.text)
                if "estimated_time" in error:
                    # Model is still loading; the API tells us how long to wait.
                    delay = error["estimated_time"]
                    logger.debug(response.text)
                    # FIX: logger.info("Retrying in", delay) passed a stray
                    # positional arg with no %s placeholder — use lazy %-style.
                    logger.info("Retrying in %s seconds...", delay)
                    time.sleep(delay)
                else:
                    break
            except Exception as e:
                logger.error(e)
                break

        retry_count += 1

    return "Error creating image."


def generate_image_with_dalle(
    prompt: str, output_file: Path, size: int, agent: Agent
) -> str:
    """Generate an image with DALL-E.

    Args:
        prompt (str): The prompt to use
        output_file (Path): The file to save the image to
        size (int): The size of the image; snapped to 256, 512, or 1024
        agent (Agent): The agent providing the OpenAI credentials

    Returns:
        str: A "Saved to disk: <path>" message
    """
    # Check for supported image sizes; snap unsupported values to the closest.
    if size not in [256, 512, 1024]:
        closest = min([256, 512, 1024], key=lambda x: abs(x - size))
        logger.info(
            "DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
            f"Setting to {closest}, was {size}."
        )
        size = closest

    response = OpenAI(
        api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
    ).images.generate(
        prompt=prompt,
        n=1,
        size=f"{size}x{size}",
        response_format="b64_json",
    )

    logger.info(f"Image Generated for prompt:{prompt}")

    image_data = b64decode(response.data[0].b64_json)

    with open(output_file, mode="wb") as png:
        png.write(image_data)

    return f"Saved to disk: {output_file}"


def generate_image_with_sd_webui(
    prompt: str,
    output_file: Path,
    agent: Agent,
    size: int = 512,
    negative_prompt: str = "",
    extra: dict | None = None,
) -> str:
    """Generate an image with Stable Diffusion WebUI (AUTOMATIC1111 API).

    Args:
        prompt (str): The prompt to use
        output_file (Path): The file to save the image to
        agent (Agent): The agent providing the WebUI URL and auth config
        size (int, optional): The size of the image. Defaults to 512.
        negative_prompt (str, optional): The negative prompt to use. Defaults to "".
        extra (dict, optional): Extra parameters to pass to the API. Defaults to None.
            (FIX: was a mutable default ``{}``.)

    Returns:
        str: A "Saved to disk: <path>" message
    """
    # Create a session and set the basic auth if needed
    s = requests.Session()
    if agent.legacy_config.sd_webui_auth:
        # FIX: partition() instead of split(":") so a colon inside the
        # password cannot raise ValueError on unpacking.
        username, _, password = agent.legacy_config.sd_webui_auth.partition(":")
        s.auth = (username, password or "")

    # Generate the image.
    # FIX: use the authenticated session (was requests.post, which ignored
    # the auth configured above) and the correct A1111 parameter name
    # "cfg_scale" (was "config_scale", which the API silently ignores).
    response = s.post(
        f"{agent.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
        json={
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "sampler_index": "DDIM",
            "steps": 20,
            "cfg_scale": 7.0,
            "width": size,
            "height": size,
            "n_iter": 1,
            **(extra or {}),
        },
    )

    logger.info(f"Image Generated for prompt: '{prompt}'")

    # Save the image to disk; the API returns base64 images, optionally
    # prefixed with a "data:...," header which we strip.
    response = response.json()
    b64 = b64decode(response["images"][0].split(",", 1)[0])
    image = Image.open(io.BytesIO(b64))
    image.save(output_file)

    return f"Saved to disk: {output_file}"