diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/openai.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/openai.py | 633 |
1 files changed, 390 insertions, 243 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index 3aad03fb1..cc6acd7df 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -1,21 +1,27 @@ import enum -import functools import logging -import math import os -import time from pathlib import Path -from typing import Callable, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar -import openai +import sentry_sdk +import tenacity import tiktoken import yaml -from openai.error import APIError, RateLimitError +from openai._exceptions import APIStatusError, RateLimitError +from openai.types import CreateEmbeddingResponse +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageParam, +) from pydantic import SecretStr from autogpt.core.configuration import Configurable, UserConfigurable from autogpt.core.resource.model_providers.schema import ( - AssistantChatMessageDict, + AssistantChatMessage, + AssistantFunctionCall, + AssistantToolCall, AssistantToolCallDict, ChatMessage, ChatModelInfo, @@ -30,27 +36,28 @@ from autogpt.core.resource.model_providers.schema import ( ModelProviderConfiguration, ModelProviderCredentials, ModelProviderName, - ModelProviderService, ModelProviderSettings, - ModelProviderUsage, ModelTokenizer, ) from autogpt.core.utils.json_schema import JSONSchema +from autogpt.core.utils.json_utils import json_loads _T = TypeVar("_T") _P = ParamSpec("_P") OpenAIEmbeddingParser = Callable[[Embedding], Embedding] -OpenAIChatParser = Callable[[str], dict] class OpenAIModelName(str, enum.Enum): - ADA = "text-embedding-ada-002" + EMBEDDING_v2 = "text-embedding-ada-002" + EMBEDDING_v3_S = "text-embedding-3-small" + EMBEDDING_v3_L = "text-embedding-3-large" GPT3_v1 = "gpt-3.5-turbo-0301" GPT3_v2 = "gpt-3.5-turbo-0613" GPT3_v2_16k = "gpt-3.5-turbo-16k-0613" GPT3_v3 = "gpt-3.5-turbo-1106" + GPT3_v4 = "gpt-3.5-turbo-0125" GPT3_ROLLING = "gpt-3.5-turbo" GPT3_ROLLING_16k = "gpt-3.5-turbo-16k" GPT3 = GPT3_ROLLING @@ -61,22 +68,41 @@ class OpenAIModelName(str, enum.Enum): GPT4_v2 = "gpt-4-0613" GPT4_v2_32k = "gpt-4-32k-0613" GPT4_v3 = "gpt-4-1106-preview" + GPT4_v3_VISION = "gpt-4-1106-vision-preview" + GPT4_v4 = "gpt-4-0125-preview" GPT4_ROLLING = "gpt-4" GPT4_ROLLING_32k = "gpt-4-32k" + GPT4_TURBO = "gpt-4-turbo-preview" GPT4_VISION = "gpt-4-vision-preview" GPT4 = GPT4_ROLLING GPT4_32k = GPT4_ROLLING_32k OPEN_AI_EMBEDDING_MODELS = { - OpenAIModelName.ADA: EmbeddingModelInfo( - name=OpenAIModelName.ADA, - service=ModelProviderService.EMBEDDING, - provider_name=ModelProviderName.OPENAI, - prompt_token_cost=0.0001 / 1000, - max_tokens=8191, - embedding_dimensions=1536, - ), + info.name: info + for info in [ + EmbeddingModelInfo( + name=OpenAIModelName.EMBEDDING_v2, + provider_name=ModelProviderName.OPENAI, + prompt_token_cost=0.0001 / 1000, + max_tokens=8191, + embedding_dimensions=1536, + ), + EmbeddingModelInfo( + name=OpenAIModelName.EMBEDDING_v3_S, + provider_name=ModelProviderName.OPENAI, + prompt_token_cost=0.00002 / 1000, + max_tokens=8191, + embedding_dimensions=1536, + ), + EmbeddingModelInfo( + name=OpenAIModelName.EMBEDDING_v3_L, + provider_name=ModelProviderName.OPENAI, + prompt_token_cost=0.00013 / 1000, + max_tokens=8191, + embedding_dimensions=3072, + ), + ] } @@ -84,8 +110,7 @@ OPEN_AI_CHAT_MODELS = { info.name: info for info in [ ChatModelInfo( - name=OpenAIModelName.GPT3, - service=ModelProviderService.CHAT, + name=OpenAIModelName.GPT3_v1, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.0015 / 1000, completion_token_cost=0.002 / 1000, @@ -93,8 +118,7 @@ OPEN_AI_CHAT_MODELS = { has_function_call_api=True, ), ChatModelInfo( - name=OpenAIModelName.GPT3_16k, - service=ModelProviderService.CHAT, + name=OpenAIModelName.GPT3_v2_16k, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.003 / 1000, completion_token_cost=0.004 / 1000, @@ -103,7 +127,6 @@ OPEN_AI_CHAT_MODELS = { ), ChatModelInfo( name=OpenAIModelName.GPT3_v3, - service=ModelProviderService.CHAT, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.001 / 1000, completion_token_cost=0.002 / 1000, @@ -111,8 +134,15 @@ OPEN_AI_CHAT_MODELS = { has_function_call_api=True, ), ChatModelInfo( - name=OpenAIModelName.GPT4, - service=ModelProviderService.CHAT, + name=OpenAIModelName.GPT3_v4, + provider_name=ModelProviderName.OPENAI, + prompt_token_cost=0.0005 / 1000, + completion_token_cost=0.0015 / 1000, + max_tokens=16384, + has_function_call_api=True, + ), + ChatModelInfo( + name=OpenAIModelName.GPT4_v1, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.03 / 1000, completion_token_cost=0.06 / 1000, @@ -120,8 +150,7 @@ OPEN_AI_CHAT_MODELS = { has_function_call_api=True, ), ChatModelInfo( - name=OpenAIModelName.GPT4_32k, - service=ModelProviderService.CHAT, + name=OpenAIModelName.GPT4_v1_32k, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.06 / 1000, completion_token_cost=0.12 / 1000, @@ -129,8 +158,7 @@ OPEN_AI_CHAT_MODELS = { has_function_call_api=True, ), ChatModelInfo( - name=OpenAIModelName.GPT4_v3, - service=ModelProviderService.CHAT, + name=OpenAIModelName.GPT4_TURBO, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.01 / 1000, completion_token_cost=0.03 / 1000, @@ -141,18 +169,24 @@ OPEN_AI_CHAT_MODELS = { } # Copy entries for models with equivalent specs chat_model_mapping = { - OpenAIModelName.GPT3: [OpenAIModelName.GPT3_v1, OpenAIModelName.GPT3_v2], - OpenAIModelName.GPT3_16k: [OpenAIModelName.GPT3_v2_16k], - OpenAIModelName.GPT4: [OpenAIModelName.GPT4_v1, OpenAIModelName.GPT4_v2], - OpenAIModelName.GPT4_32k: [ - OpenAIModelName.GPT4_v1_32k, + OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2], + OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k], + OpenAIModelName.GPT3_v4: [OpenAIModelName.GPT3_ROLLING], + OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING], + OpenAIModelName.GPT4_v1_32k: [ OpenAIModelName.GPT4_v2_32k, + OpenAIModelName.GPT4_32k, + ], + OpenAIModelName.GPT4_TURBO: [ + OpenAIModelName.GPT4_v3, + OpenAIModelName.GPT4_v3_VISION, + OpenAIModelName.GPT4_v4, + OpenAIModelName.GPT4_VISION, ], } for base, copies in chat_model_mapping.items(): for copy in copies: - copy_info = ChatModelInfo(**OPEN_AI_CHAT_MODELS[base].__dict__) - copy_info.name = copy + copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy}) OPEN_AI_CHAT_MODELS[copy] = copy_info if copy.endswith(("-0301", "-0314")): copy_info.has_function_call_api = False @@ -165,7 +199,7 @@ OPEN_AI_MODELS = { class OpenAIConfiguration(ModelProviderConfiguration): - pass + fix_failed_parse_tries: int = UserConfigurable(3) class OpenAICredentials(ModelProviderCredentials): @@ -186,32 +220,46 @@ class OpenAICredentials(ModelProviderCredentials): ), ) api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION") + azure_endpoint: Optional[SecretStr] = None azure_model_to_deploy_id_map: Optional[dict[str, str]] = None - def get_api_access_kwargs(self, model: str = "") -> dict[str, str]: - credentials = {k: v for k, v in self.unmasked().items() if type(v) is str} + def get_api_access_kwargs(self) -> dict[str, str]: + kwargs = { + k: (v.get_secret_value() if type(v) is SecretStr else v) + for k, v in { + "api_key": self.api_key, + "base_url": self.api_base, + "organization": self.organization, + }.items() + if v is not None + } + if self.api_type == "azure": + kwargs["api_version"] = self.api_version + assert self.azure_endpoint, "Azure endpoint not configured" + kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value() + return kwargs + + def get_model_access_kwargs(self, model: str) -> dict[str, str]: + kwargs = {"model": model} if self.api_type == "azure" and model: - azure_credentials = self._get_azure_access_kwargs(model) - credentials.update(azure_credentials) - return credentials + azure_kwargs = self._get_azure_access_kwargs(model) + kwargs.update(azure_kwargs) + return kwargs def load_azure_config(self, config_file: Path) -> None: with open(config_file) as file: - config_params = yaml.load(file, Loader=yaml.FullLoader) or {} + config_params = yaml.load(file, Loader=yaml.SafeLoader) or {} try: - assert ( - azure_api_base := config_params.get("azure_api_base", "") - ) != "", "Azure API base URL not set" assert config_params.get( "azure_model_map", {} ), "Azure model->deployment_id map is empty" except AssertionError as e: raise ValueError(*e.args) - self.api_base = SecretStr(azure_api_base) self.api_type = config_params.get("azure_api_type", "azure") self.api_version = config_params.get("azure_api_version", "") + self.azure_endpoint = config_params.get("azure_endpoint") self.azure_model_to_deploy_id_map = config_params.get("azure_model_map") def _get_azure_access_kwargs(self, model: str) -> dict[str, str]: @@ -224,21 +272,13 @@ class OpenAICredentials(ModelProviderCredentials): raise ValueError(f"No Azure deployment ID configured for model '{model}'") deployment_id = self.azure_model_to_deploy_id_map[model] - if model in OPEN_AI_EMBEDDING_MODELS: - return {"engine": deployment_id} - else: - return {"deployment_id": deployment_id} - - -class OpenAIModelProviderBudget(ModelProviderBudget): - graceful_shutdown_threshold: float = UserConfigurable() - warning_threshold: float = UserConfigurable() + return {"model": deployment_id} class OpenAISettings(ModelProviderSettings): configuration: OpenAIConfiguration credentials: Optional[OpenAICredentials] - budget: OpenAIModelProviderBudget + budget: ModelProviderBudget class OpenAIProvider( @@ -248,53 +288,53 @@ class OpenAIProvider( name="openai_provider", description="Provides access to OpenAI's API.", configuration=OpenAIConfiguration( - retries_per_request=10, + retries_per_request=7, ), credentials=None, - budget=OpenAIModelProviderBudget( - total_budget=math.inf, - total_cost=0.0, - remaining_budget=math.inf, - usage=ModelProviderUsage( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0, - ), - graceful_shutdown_threshold=0.005, - warning_threshold=0.01, - ), + budget=ModelProviderBudget(), ) _configuration: OpenAIConfiguration + _credentials: OpenAICredentials + _budget: ModelProviderBudget def __init__( self, - settings: OpenAISettings, - logger: logging.Logger, + settings: Optional[OpenAISettings] = None, + logger: Optional[logging.Logger] = None, ): - assert settings.credentials, "Cannot create OpenAIProvider without credentials" + if not settings: + settings = self.default_settings.copy(deep=True) + if not settings.credentials: + settings.credentials = OpenAICredentials.from_env() + + self._settings = settings + self._configuration = settings.configuration self._credentials = settings.credentials self._budget = settings.budget - self._logger = logger + if self._credentials.api_type == "azure": + from openai import AsyncAzureOpenAI - retry_handler = _OpenAIRetryHandler( - logger=self._logger, - num_retries=self._configuration.retries_per_request, - ) + # API key and org (if configured) are passed, the rest of the required + # credentials is loaded from the environment by the AzureOpenAI client. + self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs()) + else: + from openai import AsyncOpenAI + + self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs()) + + self._logger = logger or logging.getLogger(__name__) - self._create_chat_completion = retry_handler(_create_chat_completion) - self._create_embedding = retry_handler(_create_embedding) + async def get_available_models(self) -> list[ChatModelInfo]: + _models = (await self._client.models.list()).data + return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS] def get_token_limit(self, model_name: str) -> int: """Get the token limit for a given model.""" return OPEN_AI_MODELS[model_name].max_tokens - def get_remaining_budget(self) -> float: - """Get the remaining budget.""" - return self._budget.remaining_budget - @classmethod def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer: return tiktoken.encoding_for_model(model_name) @@ -332,7 +372,7 @@ class OpenAIProvider( try: encoding = tiktoken.encoding_for_model(encoding_model) except KeyError: - cls._logger.warning( + logging.getLogger(__class__.__name__).warning( f"Model {model_name} not found. Defaulting to cl100k_base encoding." ) encoding = tiktoken.get_encoding("cl100k_base") @@ -351,40 +391,101 @@ class OpenAIProvider( self, model_prompt: list[ChatMessage], model_name: OpenAIModelName, - completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None, + completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, + max_output_tokens: Optional[int] = None, **kwargs, ) -> ChatModelResponse[_T]: """Create a completion using the OpenAI API.""" - completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs) - tool_calls_compat_mode = functions and "tools" not in completion_kwargs - if "messages" in completion_kwargs: - model_prompt += completion_kwargs["messages"] - del completion_kwargs["messages"] - - response = await self._create_chat_completion( - messages=model_prompt, - **completion_kwargs, + openai_messages, completion_kwargs = self._get_chat_completion_args( + model_prompt=model_prompt, + model_name=model_name, + functions=functions, + max_tokens=max_output_tokens, + **kwargs, ) - response_args = { - "model_info": OPEN_AI_CHAT_MODELS[model_name], - "prompt_tokens_used": response.usage.prompt_tokens, - "completion_tokens_used": response.usage.completion_tokens, - } + tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs) + + total_cost = 0.0 + attempts = 0 + while True: + _response, _cost, t_input, t_output = await self._create_chat_completion( + messages=openai_messages, + **completion_kwargs, + ) + total_cost += _cost + + # If parsing the response fails, append the error to the prompt, and let the + # LLM fix its mistake(s). + attempts += 1 + parse_errors: list[Exception] = [] - response_message = response.choices[0].message.to_dict_recursive() - if tool_calls_compat_mode: - response_message["tool_calls"] = _tool_calls_compat_extract_calls( - response_message["content"] + _assistant_msg = _response.choices[0].message + + tool_calls, _errors = self._parse_assistant_tool_calls( + _assistant_msg, tool_calls_compat_mode + ) + parse_errors += _errors + + assistant_msg = AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=tool_calls or None, ) - response = ChatModelResponse( - response=response_message, - parsed_result=completion_parser(response_message), - **response_args, - ) - self._budget.update_usage_and_cost(response) - return response + + parsed_result: _T = None # type: ignore + if not parse_errors: + try: + parsed_result = completion_parser(assistant_msg) + except Exception as e: + parse_errors.append(e) + + if not parse_errors: + if attempts > 1: + self._logger.debug( + f"Total cost for {attempts} attempts: ${round(total_cost, 5)}" + ) + + return ChatModelResponse( + response=AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=tool_calls or None, + ), + parsed_result=parsed_result, + model_info=OPEN_AI_CHAT_MODELS[model_name], + prompt_tokens_used=t_input, + completion_tokens_used=t_output, + ) + + else: + self._logger.debug( + f"Parsing failed on response: '''{_assistant_msg}'''" + ) + self._logger.warning( + f"Parsing attempt #{attempts} failed: {parse_errors}" + ) + for e in parse_errors: + sentry_sdk.capture_exception( + error=e, + extras={"assistant_msg": _assistant_msg, "i_attempt": attempts}, + ) + + if attempts < self._configuration.fix_failed_parse_tries: + openai_messages.append(_assistant_msg.dict(exclude_none=True)) + openai_messages.append( + { + "role": "system", + "content": ( + "ERROR PARSING YOUR RESPONSE:\n\n" + + "\n\n".join( + f"{e.__class__.__name__}: {e}" for e in parse_errors + ) + ), + } + ) + continue + else: + raise parse_errors[0] async def create_embedding( self, @@ -397,62 +498,68 @@ class OpenAIProvider( embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs) response = await self._create_embedding(text=text, **embedding_kwargs) - response_args = { - "model_info": OPEN_AI_EMBEDDING_MODELS[model_name], - "prompt_tokens_used": response.usage.prompt_tokens, - "completion_tokens_used": response.usage.completion_tokens, - } response = EmbeddingModelResponse( - **response_args, - embedding=embedding_parser(response.embeddings[0]), + embedding=embedding_parser(response.data[0].embedding), + model_info=OPEN_AI_EMBEDDING_MODELS[model_name], + prompt_tokens_used=response.usage.prompt_tokens, + completion_tokens_used=0, ) self._budget.update_usage_and_cost(response) return response - def _get_completion_kwargs( + def _get_chat_completion_args( self, + model_prompt: list[ChatMessage], model_name: OpenAIModelName, functions: Optional[list[CompletionModelFunction]] = None, **kwargs, - ) -> dict: - """Get kwargs for completion API call. + ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]: + """Prepare chat completion arguments and keyword arguments for API call. Args: - model: The model to use. - kwargs: Keyword arguments to override the default values. + model_prompt: List of ChatMessages. + model_name: The model to use. + functions: Optional list of functions available to the LLM. + kwargs: Additional keyword arguments. Returns: - The kwargs for the chat API call. - + list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call + dict[str, Any]: Any other kwargs for the OpenAI call """ - completion_kwargs = { - "model": model_name, - **kwargs, - **self._credentials.get_api_access_kwargs(model_name), - } + kwargs.update(self._credentials.get_model_access_kwargs(model_name)) if functions: if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api: - completion_kwargs["tools"] = [ + kwargs["tools"] = [ {"type": "function", "function": f.schema} for f in functions ] if len(functions) == 1: # force the model to call the only specified function - completion_kwargs["tool_choice"] = { + kwargs["tool_choice"] = { "type": "function", "function": {"name": functions[0].name}, } else: # Provide compatibility with older models - _functions_compat_fix_kwargs(functions, completion_kwargs) + _functions_compat_fix_kwargs(functions, kwargs) if extra_headers := self._configuration.extra_request_headers: - if completion_kwargs.get("headers"): - completion_kwargs["headers"].update(extra_headers) - else: - completion_kwargs["headers"] = extra_headers.copy() + kwargs["extra_headers"] = kwargs.get("extra_headers", {}) + kwargs["extra_headers"].update(extra_headers.copy()) + + if "messages" in kwargs: + model_prompt += kwargs["messages"] + del kwargs["messages"] + + openai_messages: list[ChatCompletionMessageParam] = [ + message.dict( + include={"role", "content", "tool_calls", "name"}, + exclude_none=True, + ) + for message in model_prompt + ] - return completion_kwargs + return openai_messages, kwargs def _get_embedding_kwargs( self, @@ -469,120 +576,159 @@ class OpenAIProvider( The kwargs for the embedding API call. """ - embedding_kwargs = { - "model": model_name, - **kwargs, - **self._credentials.unmasked(), - } + kwargs.update(self._credentials.get_model_access_kwargs(model_name)) if extra_headers := self._configuration.extra_request_headers: - if embedding_kwargs.get("headers"): - embedding_kwargs["headers"].update(extra_headers) - else: - embedding_kwargs["headers"] = extra_headers.copy() - - return embedding_kwargs - - def __repr__(self): - return "OpenAIProvider()" + kwargs["extra_headers"] = kwargs.get("extra_headers", {}) + kwargs["extra_headers"].update(extra_headers.copy()) + return kwargs -async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding: - """Embed text using the OpenAI API. - - Args: - text str: The text to embed. - model str: The name of the model to use. - - Returns: - str: The embedding. - """ - return await openai.Embedding.acreate( - input=[text], + async def _create_chat_completion( + self, + messages: list[ChatCompletionMessageParam], + model: OpenAIModelName, + *_, **kwargs, - ) - - -async def _create_chat_completion( - messages: list[ChatMessage], *_, **kwargs -) -> openai.Completion: - """Create a chat completion using the OpenAI API. + ) -> tuple[ChatCompletion, float, int, int]: + """ + Create a chat completion using the OpenAI API with retry handling. - Args: - messages: The prompt to use. + Params: + openai_messages: List of OpenAI-consumable message dict objects + model: The model to use for the completion - Returns: - The completion. - """ - raw_messages = [ - message.dict(include={"role", "content", "tool_calls", "name"}) - for message in messages - ] - return await openai.ChatCompletion.acreate( - messages=raw_messages, - **kwargs, - ) + Returns: + ChatCompletion: The chat completion response object + float: The cost ($) of this completion + int: Number of prompt tokens used + int: Number of completion tokens used + """ + @self._retry_api_request + async def _create_chat_completion_with_retry( + messages: list[ChatCompletionMessageParam], **kwargs + ) -> ChatCompletion: + return await self._client.chat.completions.create( + messages=messages, # type: ignore + **kwargs, + ) -class _OpenAIRetryHandler: - """Retry Handler for OpenAI API call. + completion = await _create_chat_completion_with_retry( + messages, model=model, **kwargs + ) - Args: - num_retries int: Number of retries. Defaults to 10. - backoff_base float: Base for exponential backoff. Defaults to 2. - warn_user bool: Whether to warn the user. Defaults to True. - """ + if completion.usage: + prompt_tokens_used = completion.usage.prompt_tokens + completion_tokens_used = completion.usage.completion_tokens + else: + prompt_tokens_used = completion_tokens_used = 0 - _retry_limit_msg = "Error: Reached rate limit, passing..." - _api_key_error_msg = ( - "Please double check that you have setup a PAID OpenAI API Account. You can " - "read more here: https://docs.agpt.co/setup/#getting-an-openai-api-key" - ) - _backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..." + cost = self._budget.update_usage_and_cost( + model_info=OPEN_AI_CHAT_MODELS[model], + input_tokens_used=prompt_tokens_used, + output_tokens_used=completion_tokens_used, + ) + self._logger.debug( + f"Completion usage: {prompt_tokens_used} input, " + f"{completion_tokens_used} output - ${round(cost, 5)}" + ) + return completion, cost, prompt_tokens_used, completion_tokens_used - def __init__( - self, - logger: logging.Logger, - num_retries: int = 10, - backoff_base: float = 2.0, - warn_user: bool = True, + def _parse_assistant_tool_calls( + self, assistant_message: ChatCompletionMessage, compat_mode: bool = False ): - self._logger = logger - self._num_retries = num_retries - self._backoff_base = backoff_base - self._warn_user = warn_user - - def _log_rate_limit_error(self) -> None: - self._logger.debug(self._retry_limit_msg) - if self._warn_user: - self._logger.warning(self._api_key_error_msg) - self._warn_user = False - - def _backoff(self, attempt: int) -> None: - backoff = self._backoff_base ** (attempt + 2) - self._logger.debug(self._backoff_msg.format(backoff=backoff)) - time.sleep(backoff) - - def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]: - @functools.wraps(func) - async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T: - num_attempts = self._num_retries + 1 # +1 for the first attempt - for attempt in range(1, num_attempts + 1): - try: - return await func(*args, **kwargs) + tool_calls: list[AssistantToolCall] = [] + parse_errors: list[Exception] = [] - except RateLimitError: - if attempt == num_attempts: - raise - self._log_rate_limit_error() - - except APIError as e: - if (e.http_status != 502) or (attempt == num_attempts): - raise + if assistant_message.tool_calls: + for _tc in assistant_message.tool_calls: + try: + parsed_arguments = json_loads(_tc.function.arguments) + except Exception as e: + err_message = ( + f"Decoding arguments for {_tc.function.name} failed: " + + str(e.args[0]) + ) + parse_errors.append( + type(e)(err_message, *e.args[1:]).with_traceback( + e.__traceback__ + ) + ) + continue + + tool_calls.append( + AssistantToolCall( + id=_tc.id, + type=_tc.type, + function=AssistantFunctionCall( + name=_tc.function.name, + arguments=parsed_arguments, + ), + ) + ) + + # If parsing of all tool calls succeeds in the end, we ignore any issues + if len(tool_calls) == len(assistant_message.tool_calls): + parse_errors = [] + + elif compat_mode and assistant_message.content: + try: + tool_calls = list( + _tool_calls_compat_extract_calls(assistant_message.content) + ) + except Exception as e: + parse_errors.append(e) + + return tool_calls, parse_errors + + def _create_embedding( + self, text: str, *_, **kwargs + ) -> Coroutine[None, None, CreateEmbeddingResponse]: + """Create an embedding using the OpenAI API with retry handling.""" + + @self._retry_api_request + async def _create_embedding_with_retry( + text: str, *_, **kwargs + ) -> CreateEmbeddingResponse: + return await self._client.embeddings.create( + input=[text], + **kwargs, + ) - self._backoff(attempt) + return _create_embedding_with_retry(text, *_, **kwargs) + + def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]: + _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG) + + def _log_on_fail(retry_state: tenacity.RetryCallState) -> None: + _log_retry_debug_message(retry_state) + + if ( + retry_state.attempt_number == 0 + and retry_state.outcome + and isinstance(retry_state.outcome.exception(), RateLimitError) + ): + self._logger.warning( + "Please double check that you have setup a PAID OpenAI API Account." + " You can read more here: " + "https://docs.agpt.co/setup/#getting-an-openai-api-key" + ) + + return tenacity.retry( + retry=( + tenacity.retry_if_exception_type(RateLimitError) + | tenacity.retry_if_exception( + lambda e: isinstance(e, APIStatusError) and e.status_code == 502 + ) + ), + wait=tenacity.wait_exponential(), + stop=tenacity.stop_after_attempt(self._configuration.retries_per_request), + after=_log_on_fail, + )(func) - return _wrapped + def __repr__(self): + return "OpenAIProvider()" def format_function_specs_as_typescript_ns( @@ -708,21 +854,22 @@ def _functions_compat_fix_kwargs( ] -def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]: - import json +def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]: import re + import uuid logging.debug(f"Trying to extract tool calls from response:\n{response}") if response[0] == "[": - tool_calls: list[AssistantToolCallDict] = json.loads(response) + tool_calls: list[AssistantToolCallDict] = json_loads(response) else: block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL) if not block: - raise ValueError("Could not find tool calls block in response") - tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1)) + raise ValueError("Could not find tool_calls block in response") + tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1)) for t in tool_calls: + t["id"] = str(uuid.uuid4()) t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK - return tool_calls + yield AssistantToolCall.parse_obj(t) |