diff options
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/openai.py')
-rw-r--r-- | autogpts/autogpt/autogpt/core/resource/model_providers/openai.py | 438 |
1 files changed, 261 insertions, 177 deletions
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index d242942ae..6d1eca5e5 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -1,21 +1,22 @@ import enum -import functools import logging import math -import time -from typing import Callable, Optional, ParamSpec, TypeVar +import os +from pathlib import Path +from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar -import openai +import tenacity import tiktoken -from openai.error import APIError, RateLimitError +import yaml +from openai._exceptions import APIStatusError, RateLimitError +from openai.types import CreateEmbeddingResponse +from openai.types.chat import ChatCompletion +from pydantic import SecretStr -from autogpt.core.configuration import ( - Configurable, - SystemConfiguration, - UserConfigurable, -) +from autogpt.core.configuration import Configurable, UserConfigurable from autogpt.core.resource.model_providers.schema import ( - AssistantChatMessageDict, + AssistantChatMessage, + AssistantToolCall, AssistantToolCallDict, ChatMessage, ChatModelInfo, @@ -27,6 +28,7 @@ from autogpt.core.resource.model_providers.schema import ( EmbeddingModelProvider, EmbeddingModelResponse, ModelProviderBudget, + ModelProviderConfiguration, ModelProviderCredentials, ModelProviderName, ModelProviderService, @@ -60,8 +62,10 @@ class OpenAIModelName(str, enum.Enum): GPT4_v2 = "gpt-4-0613" GPT4_v2_32k = "gpt-4-32k-0613" GPT4_v3 = "gpt-4-1106-preview" + GPT4_v4 = "gpt-4-0125-preview" GPT4_ROLLING = "gpt-4" GPT4_ROLLING_32k = "gpt-4-32k" + GPT4_TURBO = "gpt-4-turbo-preview" GPT4_VISION = "gpt-4-vision-preview" GPT4 = GPT4_ROLLING GPT4_32k = GPT4_ROLLING_32k @@ -128,7 +132,7 @@ OPEN_AI_CHAT_MODELS = { has_function_call_api=True, ), ChatModelInfo( - name=OpenAIModelName.GPT4_v3, + name=OpenAIModelName.GPT4_TURBO, service=ModelProviderService.CHAT, provider_name=ModelProviderName.OPENAI, prompt_token_cost=0.01 / 1000, @@ -147,6 +151,7 @@ chat_model_mapping = { OpenAIModelName.GPT4_v1_32k, OpenAIModelName.GPT4_v2_32k, ], + OpenAIModelName.GPT4_TURBO: [OpenAIModelName.GPT4_v3, OpenAIModelName.GPT4_v4], } for base, copies in chat_model_mapping.items(): for copy in copies: @@ -163,19 +168,86 @@ OPEN_AI_MODELS = { } -class OpenAIConfiguration(SystemConfiguration): - retries_per_request: int = UserConfigurable() +class OpenAIConfiguration(ModelProviderConfiguration): + fix_failed_parse_tries: int = UserConfigurable(3) + + +class OpenAICredentials(ModelProviderCredentials): + """Credentials for OpenAI.""" + + api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY") + api_base: Optional[SecretStr] = UserConfigurable( + default=None, from_env="OPENAI_API_BASE_URL" + ) + organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION") + + api_type: str = UserConfigurable( + default="", + from_env=lambda: ( + "azure" + if os.getenv("USE_AZURE") == "True" + else os.getenv("OPENAI_API_TYPE") + ), + ) + api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION") + azure_endpoint: Optional[SecretStr] = None + azure_model_to_deploy_id_map: Optional[dict[str, str]] = None + + def get_api_access_kwargs(self) -> dict[str, str]: + kwargs = { + k: (v.get_secret_value() if type(v) is SecretStr else v) + for k, v in { + "api_key": self.api_key, + "base_url": self.api_base, + "organization": self.organization, + }.items() + if v is not None + } + if self.api_type == "azure": + kwargs["api_version"] = self.api_version + kwargs["azure_endpoint"] = self.azure_endpoint + return kwargs + + def get_model_access_kwargs(self, model: str) -> dict[str, str]: + kwargs = {"model": model} + if self.api_type == "azure" and model: + azure_kwargs = self._get_azure_access_kwargs(model) + kwargs.update(azure_kwargs) + return kwargs + + def load_azure_config(self, config_file: Path) -> None: + with open(config_file) as file: + config_params = yaml.load(file, Loader=yaml.FullLoader) or {} + + try: + assert config_params.get( + "azure_model_map", {} + ), "Azure model->deployment_id map is empty" + except AssertionError as e: + raise ValueError(*e.args) + self.api_type = config_params.get("azure_api_type", "azure") + self.api_version = config_params.get("azure_api_version", "") + self.azure_endpoint = config_params.get("azure_endpoint") + self.azure_model_to_deploy_id_map = config_params.get("azure_model_map") -class OpenAIModelProviderBudget(ModelProviderBudget): - graceful_shutdown_threshold: float = UserConfigurable() - warning_threshold: float = UserConfigurable() + def _get_azure_access_kwargs(self, model: str) -> dict[str, str]: + """Get the kwargs for the Azure API.""" + + if not self.azure_model_to_deploy_id_map: + raise ValueError("Azure model deployment map not configured") + + if model not in self.azure_model_to_deploy_id_map: + raise ValueError(f"No Azure deployment ID configured for model '{model}'") + deployment_id = self.azure_model_to_deploy_id_map[model] + + return {"model": deployment_id} class OpenAISettings(ModelProviderSettings): configuration: OpenAIConfiguration - credentials: ModelProviderCredentials - budget: OpenAIModelProviderBudget + credentials: Optional[OpenAICredentials] + budget: ModelProviderBudget class OpenAIProvider( @@ -187,8 +259,8 @@ class OpenAIProvider( configuration=OpenAIConfiguration( retries_per_request=10, ), - credentials=ModelProviderCredentials(), - budget=OpenAIModelProviderBudget( + credentials=None, + budget=ModelProviderBudget( total_budget=math.inf, total_cost=0.0, remaining_budget=math.inf, @@ -197,38 +269,41 @@ class OpenAIProvider( completion_tokens=0, total_tokens=0, ), - graceful_shutdown_threshold=0.005, - warning_threshold=0.01, ), ) + _budget: ModelProviderBudget + _configuration: OpenAIConfiguration + def __init__( self, settings: OpenAISettings, logger: logging.Logger, ): + self._settings = settings + + assert settings.credentials, "Cannot create OpenAIProvider without credentials" self._configuration = settings.configuration self._credentials = settings.credentials self._budget = settings.budget - self._logger = logger + if self._credentials.api_type == "azure": + from openai import AsyncAzureOpenAI - retry_handler = _OpenAIRetryHandler( - logger=self._logger, - num_retries=self._configuration.retries_per_request, - ) + # API key and org (if configured) are passed, the rest of the required + # credentials is loaded from the environment by the AzureOpenAI client. + self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs()) + else: + from openai import AsyncOpenAI + + self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs()) - self._create_chat_completion = retry_handler(_create_chat_completion) - self._create_embedding = retry_handler(_create_embedding) + self._logger = logger def get_token_limit(self, model_name: str) -> int: """Get the token limit for a given model.""" return OPEN_AI_MODELS[model_name].max_tokens - def get_remaining_budget(self) -> float: - """Get the remaining budget.""" - return self._budget.remaining_budget - @classmethod def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer: return tiktoken.encoding_for_model(model_name) @@ -266,7 +341,7 @@ class OpenAIProvider( try: encoding = tiktoken.encoding_for_model(encoding_model) except KeyError: - cls._logger.warn( + logging.getLogger(__class__.__name__).warning( f"Model {model_name} not found. Defaulting to cl100k_base encoding." ) encoding = tiktoken.get_encoding("cl100k_base") @@ -285,7 +360,7 @@ class OpenAIProvider( self, model_prompt: list[ChatMessage], model_name: OpenAIModelName, - completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None, + completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, functions: Optional[list[CompletionModelFunction]] = None, **kwargs, ) -> ChatModelResponse[_T]: @@ -297,27 +372,68 @@ class OpenAIProvider( model_prompt += completion_kwargs["messages"] del completion_kwargs["messages"] - response = await self._create_chat_completion( - messages=model_prompt, - **completion_kwargs, - ) - response_args = { - "model_info": OPEN_AI_CHAT_MODELS[model_name], - "prompt_tokens_used": response.usage.prompt_tokens, - "completion_tokens_used": response.usage.completion_tokens, - } + cost = 0.0 + attempts = 0 + while True: + _response = await self._create_chat_completion( + messages=model_prompt, + **completion_kwargs, + ) - response_message = response.choices[0].message.to_dict_recursive() - if tool_calls_compat_mode: - response_message["tool_calls"] = _tool_calls_compat_extract_calls( - response_message["content"] + _assistant_msg = _response.choices[0].message + assistant_msg = AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=( + [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls] + if _assistant_msg.tool_calls + else None + ), ) - response = ChatModelResponse( - response=response_message, - parsed_result=completion_parser(response_message), - **response_args, - ) - self._budget.update_usage_and_cost(response) + response = ChatModelResponse( + response=assistant_msg, + model_info=OPEN_AI_CHAT_MODELS[model_name], + prompt_tokens_used=( + _response.usage.prompt_tokens if _response.usage else 0 + ), + completion_tokens_used=( + _response.usage.completion_tokens if _response.usage else 0 + ), + ) + cost += self._budget.update_usage_and_cost(response) + self._logger.debug( + f"Completion usage: {response.prompt_tokens_used} input, " + f"{response.completion_tokens_used} output - ${round(cost, 5)}" + ) + + # If parsing the response fails, append the error to the prompt, and let the + # LLM fix its mistake(s). + try: + attempts += 1 + + if ( + tool_calls_compat_mode + and assistant_msg.content + and not assistant_msg.tool_calls + ): + assistant_msg.tool_calls = list( + _tool_calls_compat_extract_calls(assistant_msg.content) + ) + + response.parsed_result = completion_parser(assistant_msg) + break + except Exception as e: + self._logger.warning(f"Parsing attempt #{attempts} failed: {e}") + self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''") + if attempts < self._configuration.fix_failed_parse_tries: + model_prompt.append( + ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}") + ) + else: + raise + + if attempts > 1: + self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}") + return response async def create_embedding( @@ -331,14 +447,11 @@ class OpenAIProvider( embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs) response = await self._create_embedding(text=text, **embedding_kwargs) - response_args = { - "model_info": OPEN_AI_EMBEDDING_MODELS[model_name], - "prompt_tokens_used": response.usage.prompt_tokens, - "completion_tokens_used": response.usage.completion_tokens, - } response = EmbeddingModelResponse( - **response_args, - embedding=embedding_parser(response.embeddings[0]), + embedding=embedding_parser(response.data[0].embedding), + model_info=OPEN_AI_EMBEDDING_MODELS[model_name], + prompt_tokens_used=response.usage.prompt_tokens, + completion_tokens_used=0, ) self._budget.update_usage_and_cost(response) return response @@ -359,28 +472,29 @@ class OpenAIProvider( The kwargs for the chat API call. """ - completion_kwargs = { - "model": model_name, - **kwargs, - **self._credentials.unmasked(), - } + kwargs.update(self._credentials.get_model_access_kwargs(model_name)) if functions: if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api: - completion_kwargs["tools"] = [ + kwargs["tools"] = [ {"type": "function", "function": f.schema} for f in functions ] if len(functions) == 1: # force the model to call the only specified function - completion_kwargs["tool_choice"] = { + kwargs["tool_choice"] = { "type": "function", "function": {"name": functions[0].name}, } else: # Provide compatibility with older models - _functions_compat_fix_kwargs(functions, completion_kwargs) + _functions_compat_fix_kwargs(functions, kwargs) + + if extra_headers := self._configuration.extra_request_headers: + kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update( + extra_headers.copy() + ) - return completion_kwargs + return kwargs def _get_embedding_kwargs( self, @@ -397,114 +511,82 @@ class OpenAIProvider( The kwargs for the embedding API call. """ - embedding_kwargs = { - "model": model_name, - **kwargs, - **self._credentials.unmasked(), - } + kwargs.update(self._credentials.get_model_access_kwargs(model_name)) - return embedding_kwargs - - def __repr__(self): - return "OpenAIProvider()" - - -async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding: - """Embed text using the OpenAI API. - - Args: - text str: The text to embed. - model str: The name of the model to use. - - Returns: - str: The embedding. - """ - return await openai.Embedding.acreate( - input=[text], - **kwargs, - ) - - -async def _create_chat_completion( - messages: list[ChatMessage], *_, **kwargs -) -> openai.Completion: - """Create a chat completion using the OpenAI API. - - Args: - messages: The prompt to use. - - Returns: - The completion. - """ - raw_messages = [ - message.dict(include={"role", "content", "tool_calls", "name"}) - for message in messages - ] - return await openai.ChatCompletion.acreate( - messages=raw_messages, - **kwargs, - ) - - -class _OpenAIRetryHandler: - """Retry Handler for OpenAI API call. - - Args: - num_retries int: Number of retries. Defaults to 10. - backoff_base float: Base for exponential backoff. Defaults to 2. - warn_user bool: Whether to warn the user. Defaults to True. - """ - - _retry_limit_msg = "Error: Reached rate limit, passing..." - _api_key_error_msg = ( - "Please double check that you have setup a PAID OpenAI API Account. You can " - "read more here: https://docs.agpt.co/setup/#getting-an-api-key" - ) - _backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..." - - def __init__( - self, - logger: logging.Logger, - num_retries: int = 10, - backoff_base: float = 2.0, - warn_user: bool = True, - ): - self._logger = logger - self._num_retries = num_retries - self._backoff_base = backoff_base - self._warn_user = warn_user - - def _log_rate_limit_error(self) -> None: - self._logger.debug(self._retry_limit_msg) - if self._warn_user: - self._logger.warning(self._api_key_error_msg) - self._warn_user = False + if extra_headers := self._configuration.extra_request_headers: + kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update( + extra_headers.copy() + ) - def _backoff(self, attempt: int) -> None: - backoff = self._backoff_base ** (attempt + 2) - self._logger.debug(self._backoff_msg.format(backoff=backoff)) - time.sleep(backoff) + return kwargs + + def _create_chat_completion( + self, messages: list[ChatMessage], *_, **kwargs + ) -> Coroutine[None, None, ChatCompletion]: + """Create a chat completion using the OpenAI API with retry handling.""" + + @self._retry_api_request + async def _create_chat_completion_with_retry( + messages: list[ChatMessage], *_, **kwargs + ) -> ChatCompletion: + raw_messages = [ + message.dict(include={"role", "content", "tool_calls", "name"}) + for message in messages + ] + return await self._client.chat.completions.create( + messages=raw_messages, # type: ignore + **kwargs, + ) - def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]: - @functools.wraps(func) - async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T: - num_attempts = self._num_retries + 1 # +1 for the first attempt - for attempt in range(1, num_attempts + 1): - try: - return await func(*args, **kwargs) + return _create_chat_completion_with_retry(messages, *_, **kwargs) - except RateLimitError: - if attempt == num_attempts: - raise - self._log_rate_limit_error() + def _create_embedding( + self, text: str, *_, **kwargs + ) -> Coroutine[None, None, CreateEmbeddingResponse]: + """Create an embedding using the OpenAI API with retry handling.""" - except APIError as e: - if (e.http_status != 502) or (attempt == num_attempts): - raise + @self._retry_api_request + async def _create_embedding_with_retry( + text: str, *_, **kwargs + ) -> CreateEmbeddingResponse: + return await self._client.embeddings.create( + input=[text], + **kwargs, + ) - self._backoff(attempt) + return _create_embedding_with_retry(text, *_, **kwargs) + + def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]: + _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG) + + def _log_on_fail(retry_state: tenacity.RetryCallState) -> None: + _log_retry_debug_message(retry_state) + + if ( + retry_state.attempt_number == 0 + and retry_state.outcome + and isinstance(retry_state.outcome.exception(), RateLimitError) + ): + self._logger.warning( + "Please double check that you have setup a PAID OpenAI API Account." + " You can read more here: " + "https://docs.agpt.co/setup/#getting-an-openai-api-key" + ) + + return tenacity.retry( + retry=( + tenacity.retry_if_exception_type(RateLimitError) + | tenacity.retry_if_exception( + lambda e: isinstance(e, APIStatusError) and e.status_code == 502 + ) + ), + wait=tenacity.wait_exponential(), + stop=tenacity.stop_after_attempt(self._configuration.retries_per_request), + after=_log_on_fail, + )(func) - return _wrapped + def __repr__(self): + return "OpenAIProvider()" def format_function_specs_as_typescript_ns( @@ -572,10 +654,12 @@ def count_openai_functions_tokens( ) -> int: """Returns the number of tokens taken up by a set of function definitions - Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18 + Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18 # noqa: E501 """ return count_tokens( - f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}" + "# Tools\n\n" + "## functions\n\n" + f"{format_function_specs_as_typescript_ns(functions)}" ) @@ -628,7 +712,7 @@ def _functions_compat_fix_kwargs( ] -def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]: +def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]: import json import re @@ -639,10 +723,10 @@ def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDic else: block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL) if not block: - raise ValueError("Could not find tool calls block in response") + raise ValueError("Could not find tool_calls block in response") tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1)) for t in tool_calls: t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK - return tool_calls + yield AssistantToolCall.parse_obj(t) |