path: root/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/openai.py')
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py | 633
1 file changed, 390 insertions(+), 243 deletions(-)
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 3aad03fb1..cc6acd7df 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -1,21 +1,27 @@
import enum
-import functools
import logging
-import math
import os
-import time
from pathlib import Path
-from typing import Callable, Optional, ParamSpec, TypeVar
+from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
-import openai
+import sentry_sdk
+import tenacity
import tiktoken
import yaml
-from openai.error import APIError, RateLimitError
+from openai._exceptions import APIStatusError, RateLimitError
+from openai.types import CreateEmbeddingResponse
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionMessage,
+ ChatCompletionMessageParam,
+)
from pydantic import SecretStr
from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
- AssistantChatMessageDict,
+ AssistantChatMessage,
+ AssistantFunctionCall,
+ AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
ChatModelInfo,
@@ -30,27 +36,28 @@ from autogpt.core.resource.model_providers.schema import (
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
- ModelProviderService,
ModelProviderSettings,
- ModelProviderUsage,
ModelTokenizer,
)
from autogpt.core.utils.json_schema import JSONSchema
+from autogpt.core.utils.json_utils import json_loads
_T = TypeVar("_T")
_P = ParamSpec("_P")
OpenAIEmbeddingParser = Callable[[Embedding], Embedding]
-OpenAIChatParser = Callable[[str], dict]
class OpenAIModelName(str, enum.Enum):
- ADA = "text-embedding-ada-002"
+ EMBEDDING_v2 = "text-embedding-ada-002"
+ EMBEDDING_v3_S = "text-embedding-3-small"
+ EMBEDDING_v3_L = "text-embedding-3-large"
GPT3_v1 = "gpt-3.5-turbo-0301"
GPT3_v2 = "gpt-3.5-turbo-0613"
GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
GPT3_v3 = "gpt-3.5-turbo-1106"
+ GPT3_v4 = "gpt-3.5-turbo-0125"
GPT3_ROLLING = "gpt-3.5-turbo"
GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
GPT3 = GPT3_ROLLING
@@ -61,22 +68,41 @@ class OpenAIModelName(str, enum.Enum):
GPT4_v2 = "gpt-4-0613"
GPT4_v2_32k = "gpt-4-32k-0613"
GPT4_v3 = "gpt-4-1106-preview"
+ GPT4_v3_VISION = "gpt-4-1106-vision-preview"
+ GPT4_v4 = "gpt-4-0125-preview"
GPT4_ROLLING = "gpt-4"
GPT4_ROLLING_32k = "gpt-4-32k"
+ GPT4_TURBO = "gpt-4-turbo-preview"
GPT4_VISION = "gpt-4-vision-preview"
GPT4 = GPT4_ROLLING
GPT4_32k = GPT4_ROLLING_32k
OPEN_AI_EMBEDDING_MODELS = {
- OpenAIModelName.ADA: EmbeddingModelInfo(
- name=OpenAIModelName.ADA,
- service=ModelProviderService.EMBEDDING,
- provider_name=ModelProviderName.OPENAI,
- prompt_token_cost=0.0001 / 1000,
- max_tokens=8191,
- embedding_dimensions=1536,
- ),
+ info.name: info
+ for info in [
+ EmbeddingModelInfo(
+ name=OpenAIModelName.EMBEDDING_v2,
+ provider_name=ModelProviderName.OPENAI,
+ prompt_token_cost=0.0001 / 1000,
+ max_tokens=8191,
+ embedding_dimensions=1536,
+ ),
+ EmbeddingModelInfo(
+ name=OpenAIModelName.EMBEDDING_v3_S,
+ provider_name=ModelProviderName.OPENAI,
+ prompt_token_cost=0.00002 / 1000,
+ max_tokens=8191,
+ embedding_dimensions=1536,
+ ),
+ EmbeddingModelInfo(
+ name=OpenAIModelName.EMBEDDING_v3_L,
+ provider_name=ModelProviderName.OPENAI,
+ prompt_token_cost=0.00013 / 1000,
+ max_tokens=8191,
+ embedding_dimensions=3072,
+ ),
+ ]
}
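
The pricing fields above are per-token rates, so a request's dollar cost is simply its token count multiplied by the rate. A minimal sketch using the text-embedding-3-small rate from the table (the helper name is hypothetical):

    def estimate_embedding_cost(n_tokens: int, per_token: float = 0.00002 / 1000) -> float:
        # 0.00002 / 1000 is the prompt_token_cost of text-embedding-3-small above
        return n_tokens * per_token

    print(f"${estimate_embedding_cost(8191):.5f}")  # max-size request: ~$0.00016
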
@@ -84,8 +110,7 @@ OPEN_AI_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
- name=OpenAIModelName.GPT3,
- service=ModelProviderService.CHAT,
+ name=OpenAIModelName.GPT3_v1,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.0015 / 1000,
completion_token_cost=0.002 / 1000,
@@ -93,8 +118,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True,
),
ChatModelInfo(
- name=OpenAIModelName.GPT3_16k,
- service=ModelProviderService.CHAT,
+ name=OpenAIModelName.GPT3_v2_16k,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.003 / 1000,
completion_token_cost=0.004 / 1000,
@@ -103,7 +127,6 @@ OPEN_AI_CHAT_MODELS = {
),
ChatModelInfo(
name=OpenAIModelName.GPT3_v3,
- service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.001 / 1000,
completion_token_cost=0.002 / 1000,
@@ -111,8 +134,15 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True,
),
ChatModelInfo(
- name=OpenAIModelName.GPT4,
- service=ModelProviderService.CHAT,
+ name=OpenAIModelName.GPT3_v4,
+ provider_name=ModelProviderName.OPENAI,
+ prompt_token_cost=0.0005 / 1000,
+ completion_token_cost=0.0015 / 1000,
+ max_tokens=16384,
+ has_function_call_api=True,
+ ),
+ ChatModelInfo(
+ name=OpenAIModelName.GPT4_v1,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.03 / 1000,
completion_token_cost=0.06 / 1000,
@@ -120,8 +150,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True,
),
ChatModelInfo(
- name=OpenAIModelName.GPT4_32k,
- service=ModelProviderService.CHAT,
+ name=OpenAIModelName.GPT4_v1_32k,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.06 / 1000,
completion_token_cost=0.12 / 1000,
@@ -129,8 +158,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True,
),
ChatModelInfo(
- name=OpenAIModelName.GPT4_v3,
- service=ModelProviderService.CHAT,
+ name=OpenAIModelName.GPT4_TURBO,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.01 / 1000,
completion_token_cost=0.03 / 1000,
@@ -141,18 +169,24 @@ OPEN_AI_CHAT_MODELS = {
}
# Copy entries for models with equivalent specs
chat_model_mapping = {
- OpenAIModelName.GPT3: [OpenAIModelName.GPT3_v1, OpenAIModelName.GPT3_v2],
- OpenAIModelName.GPT3_16k: [OpenAIModelName.GPT3_v2_16k],
- OpenAIModelName.GPT4: [OpenAIModelName.GPT4_v1, OpenAIModelName.GPT4_v2],
- OpenAIModelName.GPT4_32k: [
- OpenAIModelName.GPT4_v1_32k,
+ OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2],
+ OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k],
+ OpenAIModelName.GPT3_v4: [OpenAIModelName.GPT3_ROLLING],
+ OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING],
+ OpenAIModelName.GPT4_v1_32k: [
OpenAIModelName.GPT4_v2_32k,
+ OpenAIModelName.GPT4_32k,
+ ],
+ OpenAIModelName.GPT4_TURBO: [
+ OpenAIModelName.GPT4_v3,
+ OpenAIModelName.GPT4_v3_VISION,
+ OpenAIModelName.GPT4_v4,
+ OpenAIModelName.GPT4_VISION,
],
}
for base, copies in chat_model_mapping.items():
for copy in copies:
- copy_info = ChatModelInfo(**OPEN_AI_CHAT_MODELS[base].__dict__)
- copy_info.name = copy
+ copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy})
OPEN_AI_CHAT_MODELS[copy] = copy_info
if copy.endswith(("-0301", "-0314")):
copy_info.has_function_call_api = False
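
The alias-copying loop above relies on pydantic v1's .copy(update=...), which clones a model instance while overriding selected fields. A standalone sketch of the pattern with a toy model (not the real ChatModelInfo):

    from pydantic import BaseModel

    class Info(BaseModel):
        name: str
        max_tokens: int

    base = Info(name="gpt-4-0613", max_tokens=8191)
    alias = base.copy(update={"name": "gpt-4"})  # same specs, different name
    assert base.name == "gpt-4-0613" and alias.max_tokens == 8191
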
@@ -165,7 +199,7 @@ OPEN_AI_MODELS = {
class OpenAIConfiguration(ModelProviderConfiguration):
- pass
+ fix_failed_parse_tries: int = UserConfigurable(3)
class OpenAICredentials(ModelProviderCredentials):
@@ -186,32 +220,46 @@ class OpenAICredentials(ModelProviderCredentials):
),
)
api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
+ azure_endpoint: Optional[SecretStr] = None
azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
- def get_api_access_kwargs(self, model: str = "") -> dict[str, str]:
- credentials = {k: v for k, v in self.unmasked().items() if type(v) is str}
+ def get_api_access_kwargs(self) -> dict[str, str]:
+ kwargs = {
+ k: (v.get_secret_value() if type(v) is SecretStr else v)
+ for k, v in {
+ "api_key": self.api_key,
+ "base_url": self.api_base,
+ "organization": self.organization,
+ }.items()
+ if v is not None
+ }
+ if self.api_type == "azure":
+ kwargs["api_version"] = self.api_version
+ assert self.azure_endpoint, "Azure endpoint not configured"
+ kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
+ return kwargs
+
+ def get_model_access_kwargs(self, model: str) -> dict[str, str]:
+ kwargs = {"model": model}
if self.api_type == "azure" and model:
- azure_credentials = self._get_azure_access_kwargs(model)
- credentials.update(azure_credentials)
- return credentials
+ azure_kwargs = self._get_azure_access_kwargs(model)
+ kwargs.update(azure_kwargs)
+ return kwargs
def load_azure_config(self, config_file: Path) -> None:
with open(config_file) as file:
- config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
+ config_params = yaml.load(file, Loader=yaml.SafeLoader) or {}
try:
- assert (
- azure_api_base := config_params.get("azure_api_base", "")
- ) != "", "Azure API base URL not set"
assert config_params.get(
"azure_model_map", {}
), "Azure model->deployment_id map is empty"
except AssertionError as e:
raise ValueError(*e.args)
- self.api_base = SecretStr(azure_api_base)
self.api_type = config_params.get("azure_api_type", "azure")
self.api_version = config_params.get("azure_api_version", "")
+ self.azure_endpoint = config_params.get("azure_endpoint")
self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
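
With the split above, connection-level settings feed the client constructor once, while model (or Azure deployment) resolution happens per request. A minimal sketch of the intended call pattern, assuming this module is importable and the usual OpenAI environment variables (e.g. OPENAI_API_KEY) are set:

    from openai import AsyncAzureOpenAI, AsyncOpenAI

    credentials = OpenAICredentials.from_env()
    if credentials.api_type == "azure":
        client = AsyncAzureOpenAI(**credentials.get_api_access_kwargs())
    else:
        client = AsyncOpenAI(**credentials.get_api_access_kwargs())

    # Per-request kwargs map the model name to an Azure deployment ID if needed:
    request_kwargs = credentials.get_model_access_kwargs("gpt-4")
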
@@ -224,21 +272,13 @@ class OpenAICredentials(ModelProviderCredentials):
raise ValueError(f"No Azure deployment ID configured for model '{model}'")
deployment_id = self.azure_model_to_deploy_id_map[model]
- if model in OPEN_AI_EMBEDDING_MODELS:
- return {"engine": deployment_id}
- else:
- return {"deployment_id": deployment_id}
-
-
-class OpenAIModelProviderBudget(ModelProviderBudget):
- graceful_shutdown_threshold: float = UserConfigurable()
- warning_threshold: float = UserConfigurable()
+ return {"model": deployment_id}
class OpenAISettings(ModelProviderSettings):
configuration: OpenAIConfiguration
credentials: Optional[OpenAICredentials]
- budget: OpenAIModelProviderBudget
+ budget: ModelProviderBudget
class OpenAIProvider(
@@ -248,53 +288,53 @@ class OpenAIProvider(
name="openai_provider",
description="Provides access to OpenAI's API.",
configuration=OpenAIConfiguration(
- retries_per_request=10,
+ retries_per_request=7,
),
credentials=None,
- budget=OpenAIModelProviderBudget(
- total_budget=math.inf,
- total_cost=0.0,
- remaining_budget=math.inf,
- usage=ModelProviderUsage(
- prompt_tokens=0,
- completion_tokens=0,
- total_tokens=0,
- ),
- graceful_shutdown_threshold=0.005,
- warning_threshold=0.01,
- ),
+ budget=ModelProviderBudget(),
)
_configuration: OpenAIConfiguration
+ _credentials: OpenAICredentials
+ _budget: ModelProviderBudget
def __init__(
self,
- settings: OpenAISettings,
- logger: logging.Logger,
+ settings: Optional[OpenAISettings] = None,
+ logger: Optional[logging.Logger] = None,
):
- assert settings.credentials, "Cannot create OpenAIProvider without credentials"
+ if not settings:
+ settings = self.default_settings.copy(deep=True)
+ if not settings.credentials:
+ settings.credentials = OpenAICredentials.from_env()
+
+ self._settings = settings
+
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget
- self._logger = logger
+ if self._credentials.api_type == "azure":
+ from openai import AsyncAzureOpenAI
- retry_handler = _OpenAIRetryHandler(
- logger=self._logger,
- num_retries=self._configuration.retries_per_request,
- )
+            # The API key and org (if configured) are passed explicitly; the rest
+            # of the required credentials are loaded from the environment by the
+            # AzureOpenAI client.
+ self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
+ else:
+ from openai import AsyncOpenAI
+
+ self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
+
+ self._logger = logger or logging.getLogger(__name__)
- self._create_chat_completion = retry_handler(_create_chat_completion)
- self._create_embedding = retry_handler(_create_embedding)
+ async def get_available_models(self) -> list[ChatModelInfo]:
+ _models = (await self._client.models.list()).data
+ return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
def get_token_limit(self, model_name: str) -> int:
"""Get the token limit for a given model."""
return OPEN_AI_MODELS[model_name].max_tokens
- def get_remaining_budget(self) -> float:
- """Get the remaining budget."""
- return self._budget.remaining_budget
-
@classmethod
def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer:
return tiktoken.encoding_for_model(model_name)
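
Since settings and logger are now optional, the provider can be constructed with no arguments; it falls back to default_settings plus OpenAICredentials.from_env(). A minimal usage sketch (assumes OPENAI_API_KEY is set in the environment):

    import asyncio

    provider = OpenAIProvider()  # default settings + credentials from environment

    async def main() -> None:
        models = await provider.get_available_models()
        print([m.name for m in models])

    asyncio.run(main())
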
@@ -332,7 +372,7 @@ class OpenAIProvider(
try:
encoding = tiktoken.encoding_for_model(encoding_model)
except KeyError:
- cls._logger.warning(
+ logging.getLogger(__class__.__name__).warning(
f"Model {model_name} not found. Defaulting to cl100k_base encoding."
)
encoding = tiktoken.get_encoding("cl100k_base")
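
A standalone sketch of the fallback above: tiktoken raises KeyError for model names it does not know, and cl100k_base is a reasonable default encoding for current OpenAI chat models:

    import tiktoken

    try:
        encoding = tiktoken.encoding_for_model("some-unknown-model")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    print(len(encoding.encode("hello world")))  # counts tokens, not characters
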
@@ -351,40 +391,101 @@ class OpenAIProvider(
self,
model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
- completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
+ completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
+ max_output_tokens: Optional[int] = None,
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API."""
- completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
- tool_calls_compat_mode = functions and "tools" not in completion_kwargs
- if "messages" in completion_kwargs:
- model_prompt += completion_kwargs["messages"]
- del completion_kwargs["messages"]
-
- response = await self._create_chat_completion(
- messages=model_prompt,
- **completion_kwargs,
+ openai_messages, completion_kwargs = self._get_chat_completion_args(
+ model_prompt=model_prompt,
+ model_name=model_name,
+ functions=functions,
+ max_tokens=max_output_tokens,
+ **kwargs,
)
- response_args = {
- "model_info": OPEN_AI_CHAT_MODELS[model_name],
- "prompt_tokens_used": response.usage.prompt_tokens,
- "completion_tokens_used": response.usage.completion_tokens,
- }
+ tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
+
+ total_cost = 0.0
+ attempts = 0
+ while True:
+ _response, _cost, t_input, t_output = await self._create_chat_completion(
+ messages=openai_messages,
+ **completion_kwargs,
+ )
+ total_cost += _cost
+
+ # If parsing the response fails, append the error to the prompt, and let the
+ # LLM fix its mistake(s).
+ attempts += 1
+ parse_errors: list[Exception] = []
- response_message = response.choices[0].message.to_dict_recursive()
- if tool_calls_compat_mode:
- response_message["tool_calls"] = _tool_calls_compat_extract_calls(
- response_message["content"]
+ _assistant_msg = _response.choices[0].message
+
+ tool_calls, _errors = self._parse_assistant_tool_calls(
+ _assistant_msg, tool_calls_compat_mode
+ )
+ parse_errors += _errors
+
+ assistant_msg = AssistantChatMessage(
+ content=_assistant_msg.content,
+ tool_calls=tool_calls or None,
)
- response = ChatModelResponse(
- response=response_message,
- parsed_result=completion_parser(response_message),
- **response_args,
- )
- self._budget.update_usage_and_cost(response)
- return response
+
+ parsed_result: _T = None # type: ignore
+ if not parse_errors:
+ try:
+ parsed_result = completion_parser(assistant_msg)
+ except Exception as e:
+ parse_errors.append(e)
+
+ if not parse_errors:
+ if attempts > 1:
+ self._logger.debug(
+ f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
+ )
+
+ return ChatModelResponse(
+ response=AssistantChatMessage(
+ content=_assistant_msg.content,
+ tool_calls=tool_calls or None,
+ ),
+ parsed_result=parsed_result,
+ model_info=OPEN_AI_CHAT_MODELS[model_name],
+ prompt_tokens_used=t_input,
+ completion_tokens_used=t_output,
+ )
+
+ else:
+ self._logger.debug(
+ f"Parsing failed on response: '''{_assistant_msg}'''"
+ )
+ self._logger.warning(
+ f"Parsing attempt #{attempts} failed: {parse_errors}"
+ )
+ for e in parse_errors:
+ sentry_sdk.capture_exception(
+ error=e,
+ extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
+ )
+
+ if attempts < self._configuration.fix_failed_parse_tries:
+ openai_messages.append(_assistant_msg.dict(exclude_none=True))
+ openai_messages.append(
+ {
+ "role": "system",
+ "content": (
+ "ERROR PARSING YOUR RESPONSE:\n\n"
+ + "\n\n".join(
+ f"{e.__class__.__name__}: {e}" for e in parse_errors
+ )
+ ),
+ }
+ )
+ continue
+ else:
+ raise parse_errors[0]
async def create_embedding(
self,
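
The loop above implements a self-correction pattern: if the completion cannot be parsed, the failed reply plus the parse errors are appended to the conversation and the model is queried again, up to fix_failed_parse_tries times. A minimal provider-agnostic sketch (all names here are illustrative, not the real API):

    def complete_with_fixes(call_llm, parse, messages, max_tries=3):
        for attempt in range(1, max_tries + 1):
            reply = call_llm(messages)  # returns the assistant's text
            try:
                return parse(reply)
            except Exception as e:
                if attempt == max_tries:
                    raise
                # Feed the mistake back so the model can correct itself
                messages.append({"role": "assistant", "content": reply})
                messages.append({
                    "role": "system",
                    "content": f"ERROR PARSING YOUR RESPONSE:\n\n{type(e).__name__}: {e}",
                })
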
@@ -397,62 +498,68 @@ class OpenAIProvider(
embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs)
response = await self._create_embedding(text=text, **embedding_kwargs)
- response_args = {
- "model_info": OPEN_AI_EMBEDDING_MODELS[model_name],
- "prompt_tokens_used": response.usage.prompt_tokens,
- "completion_tokens_used": response.usage.completion_tokens,
- }
response = EmbeddingModelResponse(
- **response_args,
- embedding=embedding_parser(response.embeddings[0]),
+ embedding=embedding_parser(response.data[0].embedding),
+ model_info=OPEN_AI_EMBEDDING_MODELS[model_name],
+ prompt_tokens_used=response.usage.prompt_tokens,
+ completion_tokens_used=0,
)
self._budget.update_usage_and_cost(response)
return response
- def _get_completion_kwargs(
+ def _get_chat_completion_args(
self,
+ model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
- ) -> dict:
- """Get kwargs for completion API call.
+ ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
+ """Prepare chat completion arguments and keyword arguments for API call.
Args:
- model: The model to use.
- kwargs: Keyword arguments to override the default values.
+ model_prompt: List of ChatMessages.
+ model_name: The model to use.
+ functions: Optional list of functions available to the LLM.
+ kwargs: Additional keyword arguments.
Returns:
- The kwargs for the chat API call.
-
+ list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
+ dict[str, Any]: Any other kwargs for the OpenAI call
"""
- completion_kwargs = {
- "model": model_name,
- **kwargs,
- **self._credentials.get_api_access_kwargs(model_name),
- }
+ kwargs.update(self._credentials.get_model_access_kwargs(model_name))
if functions:
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
- completion_kwargs["tools"] = [
+ kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
- completion_kwargs["tool_choice"] = {
+ kwargs["tool_choice"] = {
"type": "function",
"function": {"name": functions[0].name},
}
else:
# Provide compatibility with older models
- _functions_compat_fix_kwargs(functions, completion_kwargs)
+ _functions_compat_fix_kwargs(functions, kwargs)
if extra_headers := self._configuration.extra_request_headers:
- if completion_kwargs.get("headers"):
- completion_kwargs["headers"].update(extra_headers)
- else:
- completion_kwargs["headers"] = extra_headers.copy()
+ kwargs["extra_headers"] = kwargs.get("extra_headers", {})
+ kwargs["extra_headers"].update(extra_headers.copy())
+
+ if "messages" in kwargs:
+ model_prompt += kwargs["messages"]
+ del kwargs["messages"]
+
+ openai_messages: list[ChatCompletionMessageParam] = [
+ message.dict(
+ include={"role", "content", "tool_calls", "name"},
+ exclude_none=True,
+ )
+ for message in model_prompt
+ ]
- return completion_kwargs
+ return openai_messages, kwargs
def _get_embedding_kwargs(
self,
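
One detail worth noting in _get_chat_completion_args: when exactly one function is available, the request pins the model to it via tool_choice. A minimal sketch of the resulting request fragment (hypothetical function name; openai >= 1.0 request shape):

    functions = [{"name": "create_agent", "parameters": {"type": "object", "properties": {}}}]
    kwargs = {
        "tools": [{"type": "function", "function": f} for f in functions],
    }
    if len(functions) == 1:
        # Force the model to call the only specified function
        kwargs["tool_choice"] = {
            "type": "function",
            "function": {"name": functions[0]["name"]},
        }
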
@@ -469,120 +576,159 @@ class OpenAIProvider(
The kwargs for the embedding API call.
"""
- embedding_kwargs = {
- "model": model_name,
- **kwargs,
- **self._credentials.unmasked(),
- }
+ kwargs.update(self._credentials.get_model_access_kwargs(model_name))
if extra_headers := self._configuration.extra_request_headers:
- if embedding_kwargs.get("headers"):
- embedding_kwargs["headers"].update(extra_headers)
- else:
- embedding_kwargs["headers"] = extra_headers.copy()
-
- return embedding_kwargs
-
- def __repr__(self):
- return "OpenAIProvider()"
+ kwargs["extra_headers"] = kwargs.get("extra_headers", {})
+ kwargs["extra_headers"].update(extra_headers.copy())
+ return kwargs
-async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding:
- """Embed text using the OpenAI API.
-
- Args:
- text str: The text to embed.
- model str: The name of the model to use.
-
- Returns:
- str: The embedding.
- """
- return await openai.Embedding.acreate(
- input=[text],
+ async def _create_chat_completion(
+ self,
+ messages: list[ChatCompletionMessageParam],
+ model: OpenAIModelName,
+ *_,
**kwargs,
- )
-
-
-async def _create_chat_completion(
- messages: list[ChatMessage], *_, **kwargs
-) -> openai.Completion:
- """Create a chat completion using the OpenAI API.
+ ) -> tuple[ChatCompletion, float, int, int]:
+ """
+ Create a chat completion using the OpenAI API with retry handling.
- Args:
- messages: The prompt to use.
+ Params:
+ openai_messages: List of OpenAI-consumable message dict objects
+ model: The model to use for the completion
- Returns:
- The completion.
- """
- raw_messages = [
- message.dict(include={"role", "content", "tool_calls", "name"})
- for message in messages
- ]
- return await openai.ChatCompletion.acreate(
- messages=raw_messages,
- **kwargs,
- )
+ Returns:
+ ChatCompletion: The chat completion response object
+ float: The cost ($) of this completion
+ int: Number of prompt tokens used
+ int: Number of completion tokens used
+ """
+ @self._retry_api_request
+ async def _create_chat_completion_with_retry(
+ messages: list[ChatCompletionMessageParam], **kwargs
+ ) -> ChatCompletion:
+ return await self._client.chat.completions.create(
+ messages=messages, # type: ignore
+ **kwargs,
+ )
-class _OpenAIRetryHandler:
- """Retry Handler for OpenAI API call.
+ completion = await _create_chat_completion_with_retry(
+ messages, model=model, **kwargs
+ )
- Args:
- num_retries int: Number of retries. Defaults to 10.
- backoff_base float: Base for exponential backoff. Defaults to 2.
- warn_user bool: Whether to warn the user. Defaults to True.
- """
+ if completion.usage:
+ prompt_tokens_used = completion.usage.prompt_tokens
+ completion_tokens_used = completion.usage.completion_tokens
+ else:
+ prompt_tokens_used = completion_tokens_used = 0
- _retry_limit_msg = "Error: Reached rate limit, passing..."
- _api_key_error_msg = (
- "Please double check that you have setup a PAID OpenAI API Account. You can "
- "read more here: https://docs.agpt.co/setup/#getting-an-openai-api-key"
- )
- _backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."
+ cost = self._budget.update_usage_and_cost(
+ model_info=OPEN_AI_CHAT_MODELS[model],
+ input_tokens_used=prompt_tokens_used,
+ output_tokens_used=completion_tokens_used,
+ )
+ self._logger.debug(
+ f"Completion usage: {prompt_tokens_used} input, "
+ f"{completion_tokens_used} output - ${round(cost, 5)}"
+ )
+ return completion, cost, prompt_tokens_used, completion_tokens_used
- def __init__(
- self,
- logger: logging.Logger,
- num_retries: int = 10,
- backoff_base: float = 2.0,
- warn_user: bool = True,
+ def _parse_assistant_tool_calls(
+ self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
):
- self._logger = logger
- self._num_retries = num_retries
- self._backoff_base = backoff_base
- self._warn_user = warn_user
-
- def _log_rate_limit_error(self) -> None:
- self._logger.debug(self._retry_limit_msg)
- if self._warn_user:
- self._logger.warning(self._api_key_error_msg)
- self._warn_user = False
-
- def _backoff(self, attempt: int) -> None:
- backoff = self._backoff_base ** (attempt + 2)
- self._logger.debug(self._backoff_msg.format(backoff=backoff))
- time.sleep(backoff)
-
- def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
- @functools.wraps(func)
- async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
- num_attempts = self._num_retries + 1 # +1 for the first attempt
- for attempt in range(1, num_attempts + 1):
- try:
- return await func(*args, **kwargs)
+ tool_calls: list[AssistantToolCall] = []
+ parse_errors: list[Exception] = []
- except RateLimitError:
- if attempt == num_attempts:
- raise
- self._log_rate_limit_error()
-
- except APIError as e:
- if (e.http_status != 502) or (attempt == num_attempts):
- raise
+ if assistant_message.tool_calls:
+ for _tc in assistant_message.tool_calls:
+ try:
+ parsed_arguments = json_loads(_tc.function.arguments)
+ except Exception as e:
+ err_message = (
+ f"Decoding arguments for {_tc.function.name} failed: "
+ + str(e.args[0])
+ )
+ parse_errors.append(
+ type(e)(err_message, *e.args[1:]).with_traceback(
+ e.__traceback__
+ )
+ )
+ continue
+
+ tool_calls.append(
+ AssistantToolCall(
+ id=_tc.id,
+ type=_tc.type,
+ function=AssistantFunctionCall(
+ name=_tc.function.name,
+ arguments=parsed_arguments,
+ ),
+ )
+ )
+
+ # If parsing of all tool calls succeeds in the end, we ignore any issues
+ if len(tool_calls) == len(assistant_message.tool_calls):
+ parse_errors = []
+
+ elif compat_mode and assistant_message.content:
+ try:
+ tool_calls = list(
+ _tool_calls_compat_extract_calls(assistant_message.content)
+ )
+ except Exception as e:
+ parse_errors.append(e)
+
+ return tool_calls, parse_errors
+
+ def _create_embedding(
+ self, text: str, *_, **kwargs
+ ) -> Coroutine[None, None, CreateEmbeddingResponse]:
+ """Create an embedding using the OpenAI API with retry handling."""
+
+ @self._retry_api_request
+ async def _create_embedding_with_retry(
+ text: str, *_, **kwargs
+ ) -> CreateEmbeddingResponse:
+ return await self._client.embeddings.create(
+ input=[text],
+ **kwargs,
+ )
- self._backoff(attempt)
+ return _create_embedding_with_retry(text, *_, **kwargs)
+
+ def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
+ _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)
+
+ def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
+ _log_retry_debug_message(retry_state)
+
+ if (
+                retry_state.attempt_number == 1
+ and retry_state.outcome
+ and isinstance(retry_state.outcome.exception(), RateLimitError)
+ ):
+ self._logger.warning(
+ "Please double check that you have setup a PAID OpenAI API Account."
+ " You can read more here: "
+ "https://docs.agpt.co/setup/#getting-an-openai-api-key"
+ )
+
+ return tenacity.retry(
+ retry=(
+ tenacity.retry_if_exception_type(RateLimitError)
+ | tenacity.retry_if_exception(
+ lambda e: isinstance(e, APIStatusError) and e.status_code == 502
+ )
+ ),
+ wait=tenacity.wait_exponential(),
+ stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
+ after=_log_on_fail,
+ )(func)
- return _wrapped
+ def __repr__(self):
+ return "OpenAIProvider()"
def format_function_specs_as_typescript_ns(
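
The _retry_api_request decorator above replaces the hand-rolled backoff loop with tenacity's declarative combinators. A standalone sketch of the same policy, with a stand-in exception type (openai's RateLimitError needs a live response object to construct):

    import logging
    import tenacity

    logger = logging.getLogger(__name__)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(ConnectionError),  # stand-in predicate
        wait=tenacity.wait_exponential(),  # exponential backoff between attempts
        stop=tenacity.stop_after_attempt(7),  # matches retries_per_request above
        after=tenacity.after_log(logger, logging.DEBUG),
    )
    def flaky_call() -> str:
        return "ok"
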
@@ -708,21 +854,22 @@ def _functions_compat_fix_kwargs(
]
-def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]:
- import json
+def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import re
+ import uuid
logging.debug(f"Trying to extract tool calls from response:\n{response}")
if response[0] == "[":
- tool_calls: list[AssistantToolCallDict] = json.loads(response)
+ tool_calls: list[AssistantToolCallDict] = json_loads(response)
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
- raise ValueError("Could not find tool calls block in response")
- tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
+ raise ValueError("Could not find tool_calls block in response")
+ tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
for t in tool_calls:
+ t["id"] = str(uuid.uuid4())
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
- return tool_calls
+ yield AssistantToolCall.parse_obj(t)
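
A standalone sketch of the compat extraction path above: pull the fenced tool_calls block out of plain completion text with the same regex, then give each call a fresh id (json_loads and AssistantToolCall are this codebase's helpers; plain json stands in here):

    import json
    import re
    import uuid

    response = 'Thinking...\n```tool_calls\n[{"function": {"name": "f", "arguments": {}}}]\n```'
    block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
    tool_calls = json.loads(block.group(1)) if block else []
    for t in tool_calls:
        t["id"] = str(uuid.uuid4())  # native tool calls carry ids; compat calls need one
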