author    Reinier van der Leer <pwuts@agpt.co> 2023-12-13 22:41:55 +0100
committer Reinier van der Leer <pwuts@agpt.co> 2023-12-13 22:48:07 +0100
commit acf4df9f874855ec4b7dda80c5d18856561547b1 (patch)
tree   a3610b111e47d32bc84bbda06bd4f5532ad84979
parent ci: Reset cassettes for test_dalle (diff)
fix: Implement self-correction for invalid LLM responses
- Fix the parsing of invalid LLM responses by appending an error message to the prompt, allowing the LLM to fix its mistake(s).
- Update the `OpenAIProvider` to handle the self-correction process and limit the number of attempts to fix parsing errors.
- Update the `BaseAgent` to benefit from the new parsing and parse-fixing mechanism.

This change ensures that the system can handle and recover from errors in parsing LLM responses. Hopefully this fixes #1407 once and for all.
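In short, the commit wraps the completion call in a bounded retry loop: each response is run through the parser, and when parsing fails the error text is appended to the prompt as a system message so the model can correct itself on the next attempt. A minimal standalone sketch of that pattern follows; `complete` and `parse` are hypothetical stand-ins for the provider's completion call and the agent's response parser, not part of the actual Auto-GPT API:

    from typing import Any, Callable, TypeVar

    T = TypeVar("T")

    def complete_with_self_correction(
        messages: list[dict[str, Any]],
        complete: Callable[[list[dict[str, Any]]], dict[str, Any]],
        parse: Callable[[dict[str, Any]], T],
        max_tries: int = 3,
    ) -> T:
        """Retry until `parse` succeeds, feeding each parse error back to the LLM."""
        attempts = 0
        while True:
            response = complete(messages)  # e.g. one chat completion request
            try:
                attempts += 1
                return parse(response)
            except Exception as e:
                if attempts >= max_tries:
                    raise  # out of retries; surface the last parse error
                # Show the model its own mistake so it can fix it next time,
                # mirroring the ChatMessage.system(...) call in the diff below.
                messages.append(
                    {"role": "system", "content": f"ERROR PARSING YOUR RESPONSE:\n\n{e}"}
                )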
-rw-r--r--  autogpts/autogpt/autogpt/agents/agent.py                            19
-rw-r--r--  autogpts/autogpt/autogpt/agents/base.py                             18
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py    52
-rw-r--r--  autogpts/autogpt/autogpt/core/utils/json_schema.py                  10
-rw-r--r--  autogpts/autogpt/autogpt/json_utils/utilities.py                    14
5 files changed, 67 insertions, 46 deletions
diff --git a/autogpts/autogpt/autogpt/agents/agent.py b/autogpts/autogpt/autogpt/agents/agent.py
index fa387e79d..2b4dcd5ab 100644
--- a/autogpts/autogpt/autogpt/agents/agent.py
+++ b/autogpts/autogpt/autogpt/agents/agent.py
@@ -15,9 +15,9 @@ from pydantic import Field
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
+ AssistantChatMessageDict,
ChatMessage,
ChatModelProvider,
- ChatModelResponse,
)
from autogpt.llm.api_manager import ApiManager
from autogpt.logs.log_cycle import (
@@ -44,7 +44,12 @@ from .prompt_strategies.one_shot import (
OneShotAgentPromptConfiguration,
OneShotAgentPromptStrategy,
)
-from .utils.exceptions import AgentException, CommandExecutionError, UnknownCommandError
+from .utils.exceptions import (
+ AgentException,
+ AgentTerminated,
+ CommandExecutionError,
+ UnknownCommandError,
+)
logger = logging.getLogger(__name__)
@@ -76,6 +81,8 @@ class Agent(
description=__doc__,
)
+ prompt_strategy: OneShotAgentPromptStrategy
+
def __init__(
self,
settings: AgentSettings,
@@ -164,20 +171,20 @@ class Agent(
return prompt
def parse_and_process_response(
- self, llm_response: ChatModelResponse, *args, **kwargs
+ self, llm_response: AssistantChatMessageDict, *args, **kwargs
) -> Agent.ThoughtProcessOutput:
for plugin in self.config.plugins:
if not plugin.can_handle_post_planning():
continue
- llm_response.response["content"] = plugin.post_planning(
- llm_response.response.get("content", "")
+ llm_response["content"] = plugin.post_planning(
+ llm_response.get("content", "")
)
(
command_name,
arguments,
assistant_reply_dict,
- ) = self.prompt_strategy.parse_response_content(llm_response.response)
+ ) = self.prompt_strategy.parse_response_content(llm_response)
self.log_cycle_handler.log_cycle(
self.ai_profile.ai_name,
diff --git a/autogpts/autogpt/autogpt/agents/base.py b/autogpts/autogpt/autogpt/agents/base.py
index 7c34d40ad..7bd118e03 100644
--- a/autogpts/autogpt/autogpt/agents/base.py
+++ b/autogpts/autogpt/autogpt/agents/base.py
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.core.prompting.base import PromptStrategy
from autogpt.core.resource.model_providers.schema import (
+ AssistantChatMessageDict,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
@@ -247,7 +248,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)
logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
- raw_response = await self.llm_provider.create_chat_completion(
+ response = await self.llm_provider.create_chat_completion(
prompt.messages,
functions=get_openai_command_specs(
self.command_registry.list_available_commands(self)
@@ -256,11 +257,16 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
if self.config.use_functions_api
else [],
model_name=self.llm.name,
+ completion_parser=lambda r: self.parse_and_process_response(
+ r,
+ prompt,
+ scratchpad=self._prompt_scratchpad,
+ ),
)
self.config.cycle_count += 1
return self.on_response(
- llm_response=raw_response,
+ llm_response=response,
prompt=prompt,
scratchpad=self._prompt_scratchpad,
)
@@ -397,18 +403,14 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
The parsed command name and command args, if any, and the agent thoughts.
"""
- return self.parse_and_process_response(
- llm_response,
- prompt,
- scratchpad=scratchpad,
- )
+ return llm_response.parsed_result
# TODO: update memory/context
@abstractmethod
def parse_and_process_response(
self,
- llm_response: ChatModelResponse,
+ llm_response: AssistantChatMessageDict,
prompt: ChatPrompt,
scratchpad: PromptScratchpad,
) -> ThoughtProcessOutput:
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 3aad03fb1..bac2c9393 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -165,6 +165,7 @@ OPEN_AI_MODELS = {
class OpenAIConfiguration(ModelProviderConfiguration):
+ fix_failed_parse_tries: int = UserConfigurable(3)
pass
@@ -363,24 +364,45 @@ class OpenAIProvider(
model_prompt += completion_kwargs["messages"]
del completion_kwargs["messages"]
- response = await self._create_chat_completion(
- messages=model_prompt,
- **completion_kwargs,
- )
- response_args = {
- "model_info": OPEN_AI_CHAT_MODELS[model_name],
- "prompt_tokens_used": response.usage.prompt_tokens,
- "completion_tokens_used": response.usage.completion_tokens,
- }
-
- response_message = response.choices[0].message.to_dict_recursive()
- if tool_calls_compat_mode:
- response_message["tool_calls"] = _tool_calls_compat_extract_calls(
- response_message["content"]
+ attempts = 0
+ while True:
+ response = await self._create_chat_completion(
+ messages=model_prompt,
+ **completion_kwargs,
)
+ response_args = {
+ "model_info": OPEN_AI_CHAT_MODELS[model_name],
+ "prompt_tokens_used": response.usage.prompt_tokens,
+ "completion_tokens_used": response.usage.completion_tokens,
+ }
+
+ response_message = response.choices[0].message.to_dict_recursive()
+ if tool_calls_compat_mode:
+ response_message["tool_calls"] = _tool_calls_compat_extract_calls(
+ response_message["content"]
+ )
+
+ # If parsing the response fails, append the error to the prompt, and let the
+ # LLM fix its mistake(s).
+ try:
+ attempts += 1
+ parsed_response = completion_parser(response_message)
+ break
+ except Exception as e:
+ self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
+ self._logger.debug(
+ f"Parsing failed on response: '''{response_message}'''"
+ )
+ if attempts < self._configuration.fix_failed_parse_tries:
+ model_prompt.append(
+ ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
+ )
+ else:
+ raise
+
response = ChatModelResponse(
response=response_message,
- parsed_result=completion_parser(response_message),
+ parsed_result=parsed_response,
**response_args,
)
self._budget.update_usage_and_cost(response)
diff --git a/autogpts/autogpt/autogpt/core/utils/json_schema.py b/autogpts/autogpt/autogpt/core/utils/json_schema.py
index a28286f17..2ee4b67e9 100644
--- a/autogpts/autogpt/autogpt/core/utils/json_schema.py
+++ b/autogpts/autogpt/autogpt/core/utils/json_schema.py
@@ -103,18 +103,8 @@ class JSONSchema(BaseModel):
validator = Draft7Validator(self.to_dict())
if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
- for error in errors:
- logger.debug(f"JSON Validation Error: {error}")
-
- logger.error(json.dumps(object, indent=4))
- logger.error("The following issues were found:")
-
- for error in errors:
- logger.error(f"Error: {error.message}")
return False, errors
- logger.debug("The JSON object is valid.")
-
return True, None
def to_typescript_object_interface(self, interface_name: str = "") -> str:
diff --git a/autogpts/autogpt/autogpt/json_utils/utilities.py b/autogpts/autogpt/autogpt/json_utils/utilities.py
index 80ef8cee3..fe203b290 100644
--- a/autogpts/autogpt/autogpt/json_utils/utilities.py
+++ b/autogpts/autogpt/autogpt/json_utils/utilities.py
@@ -26,10 +26,10 @@ def extract_dict_from_response(response_content: str) -> dict[str, Any]:
# Response content comes from OpenAI as a Python `str(content_dict)`.
# `literal_eval` does the reverse of `str(dict)`.
- try:
- return ast.literal_eval(response_content)
- except BaseException as e:
- logger.info(f"Error parsing JSON response with literal_eval {e}")
- logger.debug(f"Invalid JSON received in response:\n{response_content}")
- # TODO: How to raise an error here without causing the program to exit?
- return {}
+ result = ast.literal_eval(response_content)
+ if not isinstance(result, dict):
+ raise ValueError(
+ f"Response '''{response_content}''' evaluated to "
+ f"non-dict value {repr(result)}"
+ )
+ return result
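The contract this commit introduces is that `completion_parser` may raise: `OpenAIProvider.create_chat_completion` retries internally, so callers only ever see a fully parsed result or the final error. A hedged usage sketch follows; the `provider`, `strategy`, and `prompt` objects and the model name are illustrative and assumed to be configured elsewhere, not taken from this commit:

    # Illustrative only: assumes an OpenAIProvider `provider`, a prompt
    # strategy `strategy` with parse_response_content(), and a ChatPrompt
    # `prompt`, all set up elsewhere.
    async def propose_action(provider, strategy, prompt):
        response = await provider.create_chat_completion(
            prompt.messages,
            model_name="gpt-3.5-turbo",  # example model name
            completion_parser=lambda r: strategy.parse_response_content(r),
        )
        # Any unparsable responses were already retried inside the provider,
        # up to fix_failed_parse_tries (default 3) attempts.
        return response.parsed_result  # (command_name, arguments, thoughts)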