diff options
author | Reinier van der Leer <pwuts@agpt.co> | 2024-01-19 19:23:17 +0100 |
---|---|---|
committer | Reinier van der Leer <pwuts@agpt.co> | 2024-01-19 19:23:17 +0100 |
commit | fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc (patch) | |
tree | 04adae4ddeb5768e949ef03ee334ddbb4ff5b087 | |
parent | fix(agent/serve): Fix task cost tracking persistence in `AgentProtocolServer` (diff) | |
download | Auto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.tar.gz Auto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.tar.bz2 Auto-GPT-fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc.zip |
feat(agent/llm/openai): Include compatibility tool call extraction in LLM response parse-fix loop
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/openai.py | 46 |
1 files changed, 22 insertions, 24 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index 506286949..af26ba961 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -377,29 +377,17 @@ class OpenAIProvider( **completion_kwargs, ) - _response_msg = _response.choices[0].message - if ( - tool_calls_compat_mode - and _response_msg.content - and not _response_msg.tool_calls - ): - tool_calls = list( - _tool_calls_compat_extract_calls(_response_msg.content) - ) - elif _response_msg.tool_calls: - tool_calls = [ - AssistantToolCall(**tc.dict()) for tc in _response_msg.tool_calls - ] - else: - tool_calls = None - - assistant_message = AssistantChatMessage( - content=_response_msg.content, - tool_calls=tool_calls, + _assistant_msg = _response.choices[0].message + assistant_msg = AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=( + [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls] + if _assistant_msg.tool_calls + else None + ), ) - response = ChatModelResponse( - response=assistant_message, + response=assistant_msg, model_info=OPEN_AI_CHAT_MODELS[model_name], prompt_tokens_used=( _response.usage.prompt_tokens if _response.usage else 0 @@ -418,11 +406,21 @@ class OpenAIProvider( # LLM fix its mistake(s). try: attempts += 1 - response.parsed_result = completion_parser(assistant_message) + + if ( + tool_calls_compat_mode + and assistant_msg.content + and not assistant_msg.tool_calls + ): + assistant_msg.tool_calls = list( + _tool_calls_compat_extract_calls(assistant_msg.content) + ) + + response.parsed_result = completion_parser(assistant_msg) break except Exception as e: self._logger.warning(f"Parsing attempt #{attempts} failed: {e}") - self._logger.debug(f"Parsing failed on response: '''{_response_msg}'''") + self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''") if attempts < self._configuration.fix_failed_parse_tries: model_prompt.append( ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}") @@ -722,7 +720,7 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal else: block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL) if not block: - raise ValueError("Could not find tool calls block in response") + raise ValueError("Could not find tool_calls block in response") tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1)) for t in tool_calls: |