author    Reinier van der Leer <pwuts@agpt.co>  2024-04-16 10:38:49 +0200
committer GitHub <noreply@github.com>           2024-04-16 10:38:49 +0200
commit    7082e63b115d72440ee2dfe3f545fa3dcba490d5 (patch)
tree      e82f8153b9b1ab021af0be80d96ab584350a46df
parent    feat(agent): Improve feedback in `create_chat_completion` parse-fix mechanism (diff)
download  Auto-GPT-7082e63b115d72440ee2dfe3f545fa3dcba490d5.tar.gz
          Auto-GPT-7082e63b115d72440ee2dfe3f545fa3dcba490d5.tar.bz2
          Auto-GPT-7082e63b115d72440ee2dfe3f545fa3dcba490d5.zip
refactor(agent): Refactor & improve `create_chat_completion` (#7082)
* refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion`
  - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow multiple points of failure during parsing
  - Move conversion from `ChatMessage` to `openai.types.ChatCompletionMessageParam` to `_get_chat_completion_args`
  - Move token usage and cost tracking boilerplate code to `_create_chat_completion`
  - Move tool call conversion/parsing to `_parse_assistant_tool_calls` (new)

* fix(agent/core): Handle decoding of function call arguments in `create_chat_completion`
  - Amend `model_providers.schema`: change type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict`
  - Implement robust and transparent parsing in `OpenAIProvider._parse_assistant_tool_calls`
  - Remove now unnecessary `json_loads` calls throughout codebase

* feat(agent/utils): Improve conditions and errors in `json_loads`
  - Include all decoding errors when raising a ValueError on decode failure
  - Use errors returned by `return_errors` instead of an error buffer
  - Fix check for decode failure
-rw-r--r--  autogpts/autogpt/autogpt/agent_factory/profile_generator.py                   5
-rw-r--r--  autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py                 4
-rw-r--r--  autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py      5
-rw-r--r--  autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py    5
-rw-r--r--  autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py      5
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py            273
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/schema.py              5
-rw-r--r--  autogpts/autogpt/autogpt/core/utils/json_utils.py                             19
8 files changed, 208 insertions, 113 deletions
diff --git a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
index 889b7f2d4..78afbe51a 100644
--- a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
+++ b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py
@@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
f"LLM did not call {self._create_agent_function.name} function; "
"agent profile creation failed"
)
- arguments: object = json_loads(
- response_content.tool_calls[0].function.arguments
- )
+ arguments: object = response_content.tool_calls[0].function.arguments
ai_profile = AIProfile(
ai_name=arguments.get("name"),
ai_role=arguments.get("description"),
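
This hunk shows the downstream effect that recurs in the four prompt-strategy files below (one_shot.py, initial_plan.py, name_and_goals.py, next_ability.py): `AssistantFunctionCall.arguments` now arrives as a parsed `dict[str, Any]`, so call sites read values directly instead of decoding JSON themselves. A minimal sketch of the before/after, using an invented payload for illustration:

    from typing import Any

    # Hypothetical payload; in the real flow this dict is produced by
    # OpenAIProvider._parse_assistant_tool_calls from the raw JSON string.
    arguments: dict[str, Any] = {
        "name": "ResearchGPT",
        "description": "an autonomous agent that researches topics",
    }

    # Before: arguments = json_loads(tool_call.function.arguments)
    # After: the provider decodes once; consumers just read the dict.
    ai_name = arguments.get("name")
    ai_role = arguments.get("description")
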
diff --git a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
index 0234c59a5..994df6181 100644
--- a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
+++ b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
+from autogpt.core.utils.json_utils import extract_dict_from_json
from autogpt.prompts.utils import format_numbered_list, indent
@@ -436,7 +436,7 @@ def extract_command(
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply.tool_calls[0].function.name,
- "args": json_loads(assistant_reply.tool_calls[0].function.arguments),
+ "args": assistant_reply.tool_calls[0].function.arguments,
}
try:
if not isinstance(assistant_reply_json, dict):
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
index d26d86fd6..ae137a985 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py
@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy):
f"LLM did not call {self._create_plan_function.name} function; "
"plan creation failed"
)
- parsed_response: object = json_loads(
- response_content.tool_calls[0].function.arguments
- )
+ parsed_response: object = response_content.tool_calls[0].function.arguments
parsed_response["task_list"] = [
Task.parse_obj(task) for task in parsed_response["task_list"]
]
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
index d030c05e1..133b4590d 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py
@@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
f"LLM did not call {self._create_agent_function} function; "
"agent profile creation failed"
)
- parsed_response = json_loads(
- response_content.tool_calls[0].function.arguments
- )
+ parsed_response = response_content.tool_calls[0].function.arguments
except KeyError:
logger.debug(f"Failed to parse this response content: {response_content}")
raise
diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
index dec67c295..0d6daad2e 100644
--- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
+++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py
@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@@ -188,9 +187,7 @@ class NextAbility(PromptStrategy):
raise ValueError("LLM did not call any function")
function_name = response_content.tool_calls[0].function.name
- function_arguments = json_loads(
- response_content.tool_calls[0].function.arguments
- )
+ function_arguments = response_content.tool_calls[0].function.arguments
parsed_response = {
"motivation": function_arguments.pop("motivation"),
"self_criticism": function_arguments.pop("self_criticism"),
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index d68254a9c..cd01b496a 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -3,7 +3,7 @@ import logging
import math
import os
from pathlib import Path
-from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
+from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
@@ -11,12 +11,17 @@ import tiktoken
import yaml
from openai._exceptions import APIStatusError, RateLimitError
from openai.types import CreateEmbeddingResponse
-from openai.types.chat import ChatCompletion
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionMessage,
+ ChatCompletionMessageParam,
+)
from pydantic import SecretStr
from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
+ AssistantFunctionCall,
AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
@@ -406,83 +411,90 @@ class OpenAIProvider(
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API."""
- completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
- tool_calls_compat_mode = functions and "tools" not in completion_kwargs
- if "messages" in completion_kwargs:
- model_prompt += completion_kwargs["messages"]
- del completion_kwargs["messages"]
+ openai_messages, completion_kwargs = self._get_chat_completion_args(
+ model_prompt, model_name, functions, **kwargs
+ )
+ tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
- cost = 0.0
+ total_cost = 0.0
attempts = 0
while True:
- _response = await self._create_chat_completion(
- messages=model_prompt,
+ _response, _cost, t_input, t_output = await self._create_chat_completion(
+ messages=openai_messages,
**completion_kwargs,
)
+ total_cost += _cost
+
+ # If parsing the response fails, append the error to the prompt, and let the
+ # LLM fix its mistake(s).
+ attempts += 1
+ parse_errors: list[Exception] = []
_assistant_msg = _response.choices[0].message
+
+ tool_calls, _errors = self._parse_assistant_tool_calls(
+ _assistant_msg, tool_calls_compat_mode
+ )
+ parse_errors += _errors
+
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
- tool_calls=(
- [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
- if _assistant_msg.tool_calls
- else None
- ),
- )
- response = ChatModelResponse(
- response=assistant_msg,
- model_info=OPEN_AI_CHAT_MODELS[model_name],
- prompt_tokens_used=(
- _response.usage.prompt_tokens if _response.usage else 0
- ),
- completion_tokens_used=(
- _response.usage.completion_tokens if _response.usage else 0
- ),
- )
- cost += self._budget.update_usage_and_cost(response)
- self._logger.debug(
- f"Completion usage: {response.prompt_tokens_used} input, "
- f"{response.completion_tokens_used} output - ${round(cost, 5)}"
+ tool_calls=tool_calls or None,
)
- # If parsing the response fails, append the error to the prompt, and let the
- # LLM fix its mistake(s).
- try:
- attempts += 1
-
- if (
- tool_calls_compat_mode
- and assistant_msg.content
- and not assistant_msg.tool_calls
- ):
- assistant_msg.tool_calls = list(
- _tool_calls_compat_extract_calls(assistant_msg.content)
+ parsed_result: _T = None # type: ignore
+ if not parse_errors:
+ try:
+ parsed_result = completion_parser(assistant_msg)
+ except Exception as e:
+ parse_errors.append(e)
+
+ if not parse_errors:
+ if attempts > 1:
+ self._logger.debug(
+ f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
- response.parsed_result = completion_parser(assistant_msg)
- break
- except Exception as e:
- self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
- self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
- sentry_sdk.capture_exception(
- error=e,
- extras={"assistant_msg": assistant_msg, "i_attempt": attempts},
+ return ChatModelResponse(
+ response=AssistantChatMessage(
+ content=_assistant_msg.content,
+ tool_calls=tool_calls or None,
+ ),
+ parsed_result=parsed_result,
+ model_info=OPEN_AI_CHAT_MODELS[model_name],
+ prompt_tokens_used=t_input,
+ completion_tokens_used=t_output,
)
+
+ else:
+ self._logger.debug(
+ f"Parsing failed on response: '''{_assistant_msg}'''"
+ )
+ self._logger.warning(
+ f"Parsing attempt #{attempts} failed: {parse_errors}"
+ )
+ for e in parse_errors:
+ sentry_sdk.capture_exception(
+ error=e,
+ extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
+ )
+
if attempts < self._configuration.fix_failed_parse_tries:
- model_prompt.append(assistant_msg)
- model_prompt.append(
- ChatMessage.system(
- "ERROR PARSING YOUR RESPONSE:\n\n"
- f"{e.__class__.__name__}: {e}"
- )
+ openai_messages.append(_assistant_msg.dict(exclude_none=True))
+ openai_messages.append(
+ {
+ "role": "system",
+ "content": (
+ "ERROR PARSING YOUR RESPONSE:\n\n"
+ + "\n\n".join(
+ f"{e.__class__.__name__}: {e}" for e in parse_errors
+ )
+ ),
+ }
)
+ continue
else:
- raise
-
- if attempts > 1:
- self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
-
- return response
+ raise parse_errors[0]
async def create_embedding(
self,
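
The rewrite above turns one monolithic method into a loop with three distinct failure points: the API call, tool-call decoding, and `completion_parser`. All parse errors from one attempt are collected, fed back to the model as a system message, and only raised once the retry budget (`fix_failed_parse_tries`) is exhausted. A stripped-down sketch of that control flow; `call_llm` and `parse` are hypothetical stand-ins for the provider internals:

    from typing import Any, Callable

    MAX_TRIES = 3  # stands in for self._configuration.fix_failed_parse_tries

    def complete_with_fixes(
        messages: list[dict[str, Any]],
        call_llm: Callable[[list[dict[str, Any]]], str],
        parse: Callable[[str], Any],
    ) -> Any:
        attempts = 0
        while True:
            reply = call_llm(messages)
            attempts += 1
            errors: list[Exception] = []
            try:
                return parse(reply)  # success: hand back the parsed result
            except Exception as e:
                errors.append(e)
            if attempts >= MAX_TRIES:
                raise errors[0]
            # Append the failing reply plus the errors so the LLM can fix them.
            messages.append({"role": "assistant", "content": reply})
            messages.append({
                "role": "system",
                "content": "ERROR PARSING YOUR RESPONSE:\n\n"
                + "\n\n".join(f"{e.__class__.__name__}: {e}" for e in errors),
            })
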
@@ -504,21 +516,24 @@ class OpenAIProvider(
self._budget.update_usage_and_cost(response)
return response
- def _get_completion_kwargs(
+ def _get_chat_completion_args(
self,
+ model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
- ) -> dict:
- """Get kwargs for completion API call.
+ ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
+ """Prepare chat completion arguments and keyword arguments for API call.
Args:
- model: The model to use.
- kwargs: Keyword arguments to override the default values.
+ model_prompt: List of ChatMessages.
+ model_name: The model to use.
+ functions: Optional list of functions available to the LLM.
+ kwargs: Additional keyword arguments.
Returns:
- The kwargs for the chat API call.
-
+ list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
+ dict[str, Any]: Any other kwargs for the OpenAI call
"""
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
@@ -541,7 +556,19 @@ class OpenAIProvider(
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
kwargs["extra_headers"].update(extra_headers.copy())
- return kwargs
+ if "messages" in kwargs:
+ model_prompt += kwargs["messages"]
+ del kwargs["messages"]
+
+ openai_messages: list[ChatCompletionMessageParam] = [
+ message.dict(
+ include={"role", "content", "tool_calls", "name"},
+ exclude_none=True,
+ )
+ for message in model_prompt
+ ]
+
+ return openai_messages, kwargs
def _get_embedding_kwargs(
self,
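
As the hunk above shows, `_get_chat_completion_args` now also owns the conversion from the internal `ChatMessage` model into the plain dicts the OpenAI client accepts, via pydantic v1's `.dict()` with a field whitelist and `exclude_none=True` so unset fields like `name` are dropped. A self-contained sketch of that conversion; `Msg` is a hypothetical stand-in for `ChatMessage`:

    from typing import Optional
    from pydantic import BaseModel  # pydantic v1 API, as used in this codebase

    class Msg(BaseModel):  # hypothetical stand-in for autogpt's ChatMessage
        role: str
        content: str
        name: Optional[str] = None
        tool_calls: Optional[list] = None

    prompt = [
        Msg(role="system", content="You are a helpful assistant."),
        Msg(role="user", content="Hello!", name="jane"),
    ]

    openai_messages = [
        m.dict(include={"role", "content", "tool_calls", "name"}, exclude_none=True)
        for m in prompt
    ]
    # [{'role': 'system', 'content': 'You are a helpful assistant.'},
    #  {'role': 'user', 'content': 'Hello!', 'name': 'jane'}]
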
@@ -566,28 +593,106 @@ class OpenAIProvider(
return kwargs
- def _create_chat_completion(
- self, messages: list[ChatMessage], *_, **kwargs
- ) -> Coroutine[None, None, ChatCompletion]:
- """Create a chat completion using the OpenAI API with retry handling."""
+ async def _create_chat_completion(
+ self,
+ messages: list[ChatCompletionMessageParam],
+ model: OpenAIModelName,
+ *_,
+ **kwargs,
+ ) -> tuple[ChatCompletion, float, int, int]:
+ """
+ Create a chat completion using the OpenAI API with retry handling.
+
+ Params:
+ openai_messages: List of OpenAI-consumable message dict objects
+ model: The model to use for the completion
+
+ Returns:
+ ChatCompletion: The chat completion response object
+ float: The cost ($) of this completion
+ int: Number of prompt tokens used
+ int: Number of completion tokens used
+ """
@self._retry_api_request
async def _create_chat_completion_with_retry(
- messages: list[ChatMessage], *_, **kwargs
+ messages: list[ChatCompletionMessageParam], **kwargs
) -> ChatCompletion:
- raw_messages = [
- message.dict(
- include={"role", "content", "tool_calls", "name"},
- exclude_none=True,
- )
- for message in messages
- ]
return await self._client.chat.completions.create(
- messages=raw_messages, # type: ignore
+ messages=messages, # type: ignore
**kwargs,
)
- return _create_chat_completion_with_retry(messages, *_, **kwargs)
+ completion = await _create_chat_completion_with_retry(
+ messages, model=model, **kwargs
+ )
+
+ if completion.usage:
+ prompt_tokens_used = completion.usage.prompt_tokens
+ completion_tokens_used = completion.usage.completion_tokens
+ else:
+ prompt_tokens_used = completion_tokens_used = 0
+
+ cost = self._budget.update_usage_and_cost(
+ ChatModelResponse(
+ response=AssistantChatMessage(content=None),
+ model_info=OPEN_AI_CHAT_MODELS[model],
+ prompt_tokens_used=prompt_tokens_used,
+ completion_tokens_used=completion_tokens_used,
+ )
+ )
+ self._logger.debug(
+ f"Completion usage: {prompt_tokens_used} input, "
+ f"{completion_tokens_used} output - ${round(cost, 5)}"
+ )
+ return completion, cost, prompt_tokens_used, completion_tokens_used
+
+ def _parse_assistant_tool_calls(
+ self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
+ ):
+ tool_calls: list[AssistantToolCall] = []
+ parse_errors: list[Exception] = []
+
+ if assistant_message.tool_calls:
+ for _tc in assistant_message.tool_calls:
+ try:
+ parsed_arguments = json_loads(_tc.function.arguments)
+ except Exception as e:
+ err_message = (
+ f"Decoding arguments for {_tc.function.name} failed: "
+ + str(e.args[0])
+ )
+ parse_errors.append(
+ type(e)(err_message, *e.args[1:]).with_traceback(
+ e.__traceback__
+ )
+ )
+ continue
+
+ tool_calls.append(
+ AssistantToolCall(
+ id=_tc.id,
+ type=_tc.type,
+ function=AssistantFunctionCall(
+ name=_tc.function.name,
+ arguments=parsed_arguments,
+ ),
+ )
+ )
+
+ # If parsing of all tool calls succeeds in the end, we ignore any issues
+ if len(tool_calls) == len(assistant_message.tool_calls):
+ parse_errors = []
+
+ elif compat_mode and assistant_message.content:
+ try:
+ tool_calls = list(
+ _tool_calls_compat_extract_calls(assistant_message.content)
+ )
+ except Exception as e:
+ parse_errors.append(e)
+
+ return tool_calls, parse_errors
def _create_embedding(
self, text: str, *_, **kwargs
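
`_parse_assistant_tool_calls` decodes each tool call independently, so one malformed `arguments` string doesn't discard valid sibling calls, and accumulated errors are dropped if every call ultimately parses. Its exception-rewrapping idiom is worth a note: it rebuilds the exception with a prefixed message while preserving the original type, remaining args, and traceback. In isolation (the `get_weather` name is invented):

    def rewrap(e: Exception, context: str) -> Exception:
        """Prefix an exception's message with context; keep type and traceback."""
        message = f"{context}: {e.args[0] if e.args else e}"
        return type(e)(message, *e.args[1:]).with_traceback(e.__traceback__)

    try:
        raise ValueError("Expecting value: line 1 column 1 (char 0)")
    except ValueError as e:
        wrapped = rewrap(e, "Decoding arguments for get_weather failed")
        print(wrapped)
        # Decoding arguments for get_weather failed: Expecting value: ...
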
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index 43d4bd296..cc0030995 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -2,6 +2,7 @@ import abc
import enum
import math
from typing import (
+ Any,
Callable,
ClassVar,
Generic,
@@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict):
class AssistantFunctionCall(BaseModel):
name: str
- arguments: str
+ arguments: dict[str, Any]
class AssistantFunctionCallDict(TypedDict):
name: str
- arguments: str
+ arguments: dict[str, Any]
class AssistantToolCall(BaseModel):
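
With `arguments` retyped from `str` to `dict[str, Any]`, pydantic now enforces at the schema boundary that tool-call arguments arrive already decoded; a raw JSON string fails validation instead of silently propagating. A quick illustration against a model mirroring the diff above (pydantic v1 semantics; the `web_search` name is invented):

    from typing import Any
    from pydantic import BaseModel, ValidationError

    class AssistantFunctionCall(BaseModel):
        name: str
        arguments: dict[str, Any]

    ok = AssistantFunctionCall(name="web_search", arguments={"query": "demjson3"})

    try:
        AssistantFunctionCall(name="web_search", arguments='{"query": "demjson3"}')
    except ValidationError as e:
        print(e)  # arguments: value is not a valid dict
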
diff --git a/autogpts/autogpt/autogpt/core/utils/json_utils.py b/autogpts/autogpt/autogpt/core/utils/json_utils.py
index 664cb87c1..0374a85c1 100644
--- a/autogpts/autogpt/autogpt/core/utils/json_utils.py
+++ b/autogpts/autogpt/autogpt/core/utils/json_utils.py
@@ -1,4 +1,3 @@
-import io
import logging
import re
from typing import Any
@@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any:
if match:
json_str = match.group(1).strip()
- error_buffer = io.StringIO()
- json_result = demjson3.decode(
- json_str, return_errors=True, write_errors=error_buffer
- )
+ json_result = demjson3.decode(json_str, return_errors=True)
+ assert json_result is not None # by virtue of return_errors=True
- if error_buffer.getvalue():
- logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
+ if json_result.errors:
+ logger.debug(
+ "JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
+ )
- if json_result is None:
- raise ValueError(f"Failed to parse JSON string: {json_str}")
+ if json_result.object is demjson3.undefined:
+ raise ValueError(
+ f"Failed to parse JSON string: {json_str}", *json_result.errors
+ )
return json_result.object
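
The new `json_loads` leans on `demjson3.decode(..., return_errors=True)`, which returns a result object exposing the decoded value as `.object` (`demjson3.undefined` when nothing could be decoded) and all errors and warnings as `.errors`, replacing the old `write_errors` string buffer. A minimal sketch of the same pattern, assuming `demjson3` is installed:

    import demjson3

    def lenient_loads(json_str: str):
        result = demjson3.decode(json_str, return_errors=True)
        assert result is not None  # guaranteed when return_errors=True
        if result.object is demjson3.undefined:
            raise ValueError(f"Failed to parse JSON string: {json_str}", *result.errors)
        return result.object

    # demjson3 is lenient by default: single quotes and trailing commas are
    # recoverable warnings collected in result.errors, not fatal errors.
    print(lenient_loads("{'a': 1, 'b': [2, 3,],}"))  # {'a': 1, 'b': [2, 3]}
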