aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Reinier van der Leer <pwuts@agpt.co> 2024-01-19 19:23:17 +0100
committerGravatar Reinier van der Leer <pwuts@agpt.co> 2024-01-19 19:23:17 +0100
commitfc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc (patch)
tree04adae4ddeb5768e949ef03ee334ddbb4ff5b087
parentfix(agent/serve): Fix task cost tracking persistence in `AgentProtocolServer` (diff)
downloadAuto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.tar.gz
Auto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.tar.bz2
Auto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.zip
feat(agent/llm/openai): Include compatibility tool call extraction in LLM response parse-fix loop
-rw-r--r--autogpts/autogpt/autogpt/core/resource/model_providers/openai.py46
1 file changed, 22 insertions, 24 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 506286949..af26ba961 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -377,29 +377,17 @@ class OpenAIProvider(
**completion_kwargs,
)
- _response_msg = _response.choices[0].message
- if (
- tool_calls_compat_mode
- and _response_msg.content
- and not _response_msg.tool_calls
- ):
- tool_calls = list(
- _tool_calls_compat_extract_calls(_response_msg.content)
- )
- elif _response_msg.tool_calls:
- tool_calls = [
- AssistantToolCall(**tc.dict()) for tc in _response_msg.tool_calls
- ]
- else:
- tool_calls = None
-
- assistant_message = AssistantChatMessage(
- content=_response_msg.content,
- tool_calls=tool_calls,
+ _assistant_msg = _response.choices[0].message
+ assistant_msg = AssistantChatMessage(
+ content=_assistant_msg.content,
+ tool_calls=(
+ [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
+ if _assistant_msg.tool_calls
+ else None
+ ),
)
-
response = ChatModelResponse(
- response=assistant_message,
+ response=assistant_msg,
model_info=OPEN_AI_CHAT_MODELS[model_name],
prompt_tokens_used=(
_response.usage.prompt_tokens if _response.usage else 0
@@ -418,11 +406,21 @@ class OpenAIProvider(
# LLM fix its mistake(s).
try:
attempts += 1
- response.parsed_result = completion_parser(assistant_message)
+
+ if (
+ tool_calls_compat_mode
+ and assistant_msg.content
+ and not assistant_msg.tool_calls
+ ):
+ assistant_msg.tool_calls = list(
+ _tool_calls_compat_extract_calls(assistant_msg.content)
+ )
+
+ response.parsed_result = completion_parser(assistant_msg)
break
except Exception as e:
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
- self._logger.debug(f"Parsing failed on response: '''{_response_msg}'''")
+ self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
if attempts < self._configuration.fix_failed_parse_tries:
model_prompt.append(
ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
@@ -722,7 +720,7 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
- raise ValueError("Could not find tool calls block in response")
+ raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
for t in tool_calls: