path: root/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
Diffstat (limited to 'autogpts/autogpt/autogpt/core/resource/model_providers/openai.py')
-rw-r--r--  autogpts/autogpt/autogpt/core/resource/model_providers/openai.py  438
1 file changed, 261 insertions(+), 177 deletions(-)
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index d242942ae..6d1eca5e5 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -1,21 +1,22 @@
import enum
-import functools
import logging
import math
-import time
-from typing import Callable, Optional, ParamSpec, TypeVar
+import os
+from pathlib import Path
+from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
-import openai
+import tenacity
import tiktoken
-from openai.error import APIError, RateLimitError
+import yaml
+from openai._exceptions import APIStatusError, RateLimitError
+from openai.types import CreateEmbeddingResponse
+from openai.types.chat import ChatCompletion
+from pydantic import SecretStr
-from autogpt.core.configuration import (
- Configurable,
- SystemConfiguration,
- UserConfigurable,
-)
+from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
- AssistantChatMessageDict,
+ AssistantChatMessage,
+ AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
ChatModelInfo,
@@ -27,6 +28,7 @@ from autogpt.core.resource.model_providers.schema import (
EmbeddingModelProvider,
EmbeddingModelResponse,
ModelProviderBudget,
+ ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderService,
@@ -60,8 +62,10 @@ class OpenAIModelName(str, enum.Enum):
GPT4_v2 = "gpt-4-0613"
GPT4_v2_32k = "gpt-4-32k-0613"
GPT4_v3 = "gpt-4-1106-preview"
+ GPT4_v4 = "gpt-4-0125-preview"
GPT4_ROLLING = "gpt-4"
GPT4_ROLLING_32k = "gpt-4-32k"
+ GPT4_TURBO = "gpt-4-turbo-preview"
GPT4_VISION = "gpt-4-vision-preview"
GPT4 = GPT4_ROLLING
GPT4_32k = GPT4_ROLLING_32k
@@ -128,7 +132,7 @@ OPEN_AI_CHAT_MODELS = {
has_function_call_api=True,
),
ChatModelInfo(
- name=OpenAIModelName.GPT4_v3,
+ name=OpenAIModelName.GPT4_TURBO,
service=ModelProviderService.CHAT,
provider_name=ModelProviderName.OPENAI,
prompt_token_cost=0.01 / 1000,
@@ -147,6 +151,7 @@ chat_model_mapping = {
OpenAIModelName.GPT4_v1_32k,
OpenAIModelName.GPT4_v2_32k,
],
+ OpenAIModelName.GPT4_TURBO: [OpenAIModelName.GPT4_v3, OpenAIModelName.GPT4_v4],
}
for base, copies in chat_model_mapping.items():
for copy in copies:
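(The body of this loop falls outside the diff context. A plausible reconstruction, assuming ChatModelInfo is a pydantic model with the standard .copy() method:)

# Hypothetical sketch of the elided loop body: each alias gets a copy of the
# base model's info with only the name swapped out.
for base, copies in chat_model_mapping.items():
    for copy in copies:
        OPEN_AI_CHAT_MODELS[copy] = OPEN_AI_CHAT_MODELS[base].copy(
            update={"name": copy}
        )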
@@ -163,19 +168,86 @@ OPEN_AI_MODELS = {
}
-class OpenAIConfiguration(SystemConfiguration):
- retries_per_request: int = UserConfigurable()
+class OpenAIConfiguration(ModelProviderConfiguration):
+ fix_failed_parse_tries: int = UserConfigurable(3)
+
+
+class OpenAICredentials(ModelProviderCredentials):
+ """Credentials for OpenAI."""
+
+ api_key: SecretStr = UserConfigurable(from_env="OPENAI_API_KEY")
+ api_base: Optional[SecretStr] = UserConfigurable(
+ default=None, from_env="OPENAI_API_BASE_URL"
+ )
+ organization: Optional[SecretStr] = UserConfigurable(from_env="OPENAI_ORGANIZATION")
+
+ api_type: str = UserConfigurable(
+ default="",
+ from_env=lambda: (
+ "azure"
+ if os.getenv("USE_AZURE") == "True"
+ else os.getenv("OPENAI_API_TYPE")
+ ),
+ )
+ api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
+ azure_endpoint: Optional[SecretStr] = None
+ azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
+
+ def get_api_access_kwargs(self) -> dict[str, str]:
+ kwargs = {
+ k: (v.get_secret_value() if type(v) is SecretStr else v)
+ for k, v in {
+ "api_key": self.api_key,
+ "base_url": self.api_base,
+ "organization": self.organization,
+ }.items()
+ if v is not None
+ }
+ if self.api_type == "azure":
+ kwargs["api_version"] = self.api_version
+ kwargs["azure_endpoint"] = self.azure_endpoint
+ return kwargs
+
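(For a plain OpenAI setup, the result is just the unwrapped key material. A minimal sketch, assuming the optional fields resolve to None:)

# Illustrative only: SecretStr values are unwrapped and None entries dropped.
creds = OpenAICredentials(api_key=SecretStr("sk-..."))
assert creds.get_api_access_kwargs() == {"api_key": "sk-..."}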
+ def get_model_access_kwargs(self, model: str) -> dict[str, str]:
+ kwargs = {"model": model}
+ if self.api_type == "azure" and model:
+ azure_kwargs = self._get_azure_access_kwargs(model)
+ kwargs.update(azure_kwargs)
+ return kwargs
+
+ def load_azure_config(self, config_file: Path) -> None:
+ with open(config_file) as file:
+ config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
+
+ if not config_params.get("azure_model_map"):
+ # Plain raise instead of assert, which is stripped under `python -O`.
+ raise ValueError("Azure model->deployment_id map is empty")
+ self.api_type = config_params.get("azure_api_type", "azure")
+ self.api_version = config_params.get("azure_api_version", "")
+ self.azure_endpoint = config_params.get("azure_endpoint")
+ self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
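(The config file read here is a small YAML document. Its expected shape, written as the dict yaml.load() should return; all values are illustrative:)

# Expected shape of the parsed Azure config; endpoint and deployment names
# are hypothetical.
config_params = {
    "azure_api_type": "azure",
    "azure_api_version": "2024-02-01",  # assumed version string
    "azure_endpoint": "https://my-resource.openai.azure.com",
    "azure_model_map": {
        # model name -> Azure deployment ID
        "gpt-4": "my-gpt4-deployment",
        "gpt-3.5-turbo": "my-gpt35-deployment",
    },
}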
-class OpenAIModelProviderBudget(ModelProviderBudget):
- graceful_shutdown_threshold: float = UserConfigurable()
- warning_threshold: float = UserConfigurable()
+ def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
+ """Get the kwargs for the Azure API."""
+
+ if not self.azure_model_to_deploy_id_map:
+ raise ValueError("Azure model deployment map not configured")
+
+ if model not in self.azure_model_to_deploy_id_map:
+ raise ValueError(f"No Azure deployment ID configured for model '{model}'")
+ deployment_id = self.azure_model_to_deploy_id_map[model]
+
+ return {"model": deployment_id}
class OpenAISettings(ModelProviderSettings):
configuration: OpenAIConfiguration
- credentials: ModelProviderCredentials
- budget: OpenAIModelProviderBudget
+ credentials: Optional[OpenAICredentials]
+ budget: ModelProviderBudget
class OpenAIProvider(
@@ -187,8 +259,8 @@ class OpenAIProvider(
configuration=OpenAIConfiguration(
retries_per_request=10,
),
- credentials=ModelProviderCredentials(),
- budget=OpenAIModelProviderBudget(
+ credentials=None,
+ budget=ModelProviderBudget(
total_budget=math.inf,
total_cost=0.0,
remaining_budget=math.inf,
@@ -197,38 +269,41 @@ class OpenAIProvider(
completion_tokens=0,
total_tokens=0,
),
- graceful_shutdown_threshold=0.005,
- warning_threshold=0.01,
),
)
+ _budget: ModelProviderBudget
+ _configuration: OpenAIConfiguration
+
def __init__(
self,
settings: OpenAISettings,
logger: logging.Logger,
):
+ self._settings = settings
+
+ assert settings.credentials, "Cannot create OpenAIProvider without credentials"
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget
- self._logger = logger
+ if self._credentials.api_type == "azure":
+ from openai import AsyncAzureOpenAI
- retry_handler = _OpenAIRetryHandler(
- logger=self._logger,
- num_retries=self._configuration.retries_per_request,
- )
+ # API key and org (if configured) are passed explicitly; the rest of the
+ # required credentials are loaded from the environment by the Azure client.
+ self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
+ else:
+ from openai import AsyncOpenAI
+
+ self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
- self._create_chat_completion = retry_handler(_create_chat_completion)
- self._create_embedding = retry_handler(_create_embedding)
+ self._logger = logger
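(Constructing a provider then looks roughly like this. A sketch, assuming the Configurable default_settings pattern used elsewhere in autogpt.core:)

import logging
from pydantic import SecretStr

# Hypothetical wiring: clone the default settings and fill in credentials.
settings = OpenAIProvider.default_settings.copy(deep=True)
settings.credentials = OpenAICredentials(api_key=SecretStr("sk-..."))
provider = OpenAIProvider(settings=settings, logger=logging.getLogger(__name__))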
def get_token_limit(self, model_name: str) -> int:
"""Get the token limit for a given model."""
return OPEN_AI_MODELS[model_name].max_tokens
- def get_remaining_budget(self) -> float:
- """Get the remaining budget."""
- return self._budget.remaining_budget
-
@classmethod
def get_tokenizer(cls, model_name: OpenAIModelName) -> ModelTokenizer:
return tiktoken.encoding_for_model(model_name)
@@ -266,7 +341,7 @@ class OpenAIProvider(
try:
encoding = tiktoken.encoding_for_model(encoding_model)
except KeyError:
- cls._logger.warn(
+ logging.getLogger(__class__.__name__).warning(
f"Model {model_name} not found. Defaulting to cl100k_base encoding."
)
encoding = tiktoken.get_encoding("cl100k_base")
@@ -285,7 +360,7 @@ class OpenAIProvider(
self,
model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
- completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
+ completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:
@@ -297,27 +372,68 @@ class OpenAIProvider(
model_prompt += completion_kwargs["messages"]
del completion_kwargs["messages"]
- response = await self._create_chat_completion(
- messages=model_prompt,
- **completion_kwargs,
- )
- response_args = {
- "model_info": OPEN_AI_CHAT_MODELS[model_name],
- "prompt_tokens_used": response.usage.prompt_tokens,
- "completion_tokens_used": response.usage.completion_tokens,
- }
+ cost = 0.0
+ attempts = 0
+ while True:
+ _response = await self._create_chat_completion(
+ messages=model_prompt,
+ **completion_kwargs,
+ )
- response_message = response.choices[0].message.to_dict_recursive()
- if tool_calls_compat_mode:
- response_message["tool_calls"] = _tool_calls_compat_extract_calls(
- response_message["content"]
+ _assistant_msg = _response.choices[0].message
+ assistant_msg = AssistantChatMessage(
+ content=_assistant_msg.content,
+ tool_calls=(
+ [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
+ if _assistant_msg.tool_calls
+ else None
+ ),
)
- response = ChatModelResponse(
- response=response_message,
- parsed_result=completion_parser(response_message),
- **response_args,
- )
- self._budget.update_usage_and_cost(response)
+ response = ChatModelResponse(
+ response=assistant_msg,
+ model_info=OPEN_AI_CHAT_MODELS[model_name],
+ prompt_tokens_used=(
+ _response.usage.prompt_tokens if _response.usage else 0
+ ),
+ completion_tokens_used=(
+ _response.usage.completion_tokens if _response.usage else 0
+ ),
+ )
+ cost += self._budget.update_usage_and_cost(response)
+ self._logger.debug(
+ f"Completion usage: {response.prompt_tokens_used} input, "
+ f"{response.completion_tokens_used} output - ${round(cost, 5)}"
+ )
+
+ # If parsing the response fails, append the error to the prompt, and let the
+ # LLM fix its mistake(s).
+ try:
+ attempts += 1
+
+ if (
+ tool_calls_compat_mode
+ and assistant_msg.content
+ and not assistant_msg.tool_calls
+ ):
+ assistant_msg.tool_calls = list(
+ _tool_calls_compat_extract_calls(assistant_msg.content)
+ )
+
+ response.parsed_result = completion_parser(assistant_msg)
+ break
+ except Exception as e:
+ self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
+ self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
+ if attempts < self._configuration.fix_failed_parse_tries:
+ model_prompt.append(
+ ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
+ )
+ else:
+ raise
+
+ if attempts > 1:
+ self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
+
return response
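(In practice this loop means a completion_parser can simply raise on malformed output and the provider will feed the error back to the model. A usage sketch, to be run inside an async context; ChatMessage.user() is assumed to exist alongside the system() helper used above:)

# Sketch: a strict parser drives the fix-failed-parse retry loop above.
def require_tool_call(msg: AssistantChatMessage) -> AssistantToolCall:
    if not msg.tool_calls:
        raise ValueError("expected a tool call in the assistant response")
    return msg.tool_calls[0]

response = await provider.create_chat_completion(
    model_prompt=[ChatMessage.user("Pick a tool.")],  # assumed helper
    model_name=OpenAIModelName.GPT3,
    completion_parser=require_tool_call,
)
tool_call = response.parsed_result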
async def create_embedding(
@@ -331,14 +447,11 @@ class OpenAIProvider(
embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs)
response = await self._create_embedding(text=text, **embedding_kwargs)
- response_args = {
- "model_info": OPEN_AI_EMBEDDING_MODELS[model_name],
- "prompt_tokens_used": response.usage.prompt_tokens,
- "completion_tokens_used": response.usage.completion_tokens,
- }
response = EmbeddingModelResponse(
- **response_args,
- embedding=embedding_parser(response.embeddings[0]),
+ embedding=embedding_parser(response.data[0].embedding),
+ model_info=OPEN_AI_EMBEDDING_MODELS[model_name],
+ prompt_tokens_used=response.usage.prompt_tokens,
+ completion_tokens_used=0,
)
self._budget.update_usage_and_cost(response)
return response
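(The embedding path is simpler. A minimal usage sketch, again inside an async context; the embedding model enum member name is an assumption:)

# Sketch: embed a string and read back the raw vector.
embedding_response = await provider.create_embedding(
    text="hello world",
    model_name=OpenAIModelName.EMBEDDING_v2,  # assumed enum member
    embedding_parser=lambda vector: vector,
)
vector = embedding_response.embedding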
@@ -359,28 +472,29 @@ class OpenAIProvider(
The kwargs for the chat API call.
"""
- completion_kwargs = {
- "model": model_name,
- **kwargs,
- **self._credentials.unmasked(),
- }
+ kwargs.update(self._credentials.get_model_access_kwargs(model_name))
if functions:
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
- completion_kwargs["tools"] = [
+ kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
- completion_kwargs["tool_choice"] = {
+ kwargs["tool_choice"] = {
"type": "function",
"function": {"name": functions[0].name},
}
else:
# Provide compatibility with older models
- _functions_compat_fix_kwargs(functions, completion_kwargs)
+ _functions_compat_fix_kwargs(functions, kwargs)
+
+ if extra_headers := self._configuration.extra_request_headers:
+ # dict.update() returns None; build a merged dict instead.
+ kwargs["extra_headers"] = {
+ **(kwargs.get("extra_headers") or {}),
+ **extra_headers,
+ }
- return completion_kwargs
+ return kwargs
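(The header merge keeps any caller-supplied extra_headers and layers the configured ones on top. The semantics in isolation, with hypothetical header names:)

# Plain-dict illustration of the merge used above.
configured = {"X-Request-Source": "autogpt"}
kwargs = {"extra_headers": {"X-Trace-Id": "abc123"}}
kwargs["extra_headers"] = {**(kwargs.get("extra_headers") or {}), **configured}
assert kwargs["extra_headers"] == {
    "X-Trace-Id": "abc123",
    "X-Request-Source": "autogpt",
}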
def _get_embedding_kwargs(
self,
@@ -397,114 +511,82 @@ class OpenAIProvider(
The kwargs for the embedding API call.
"""
- embedding_kwargs = {
- "model": model_name,
- **kwargs,
- **self._credentials.unmasked(),
- }
+ kwargs.update(self._credentials.get_model_access_kwargs(model_name))
- return embedding_kwargs
-
- def __repr__(self):
- return "OpenAIProvider()"
-
-
-async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding:
- """Embed text using the OpenAI API.
-
- Args:
- text str: The text to embed.
- model str: The name of the model to use.
-
- Returns:
- str: The embedding.
- """
- return await openai.Embedding.acreate(
- input=[text],
- **kwargs,
- )
-
-
-async def _create_chat_completion(
- messages: list[ChatMessage], *_, **kwargs
-) -> openai.Completion:
- """Create a chat completion using the OpenAI API.
-
- Args:
- messages: The prompt to use.
-
- Returns:
- The completion.
- """
- raw_messages = [
- message.dict(include={"role", "content", "tool_calls", "name"})
- for message in messages
- ]
- return await openai.ChatCompletion.acreate(
- messages=raw_messages,
- **kwargs,
- )
-
-
-class _OpenAIRetryHandler:
- """Retry Handler for OpenAI API call.
-
- Args:
- num_retries int: Number of retries. Defaults to 10.
- backoff_base float: Base for exponential backoff. Defaults to 2.
- warn_user bool: Whether to warn the user. Defaults to True.
- """
-
- _retry_limit_msg = "Error: Reached rate limit, passing..."
- _api_key_error_msg = (
- "Please double check that you have setup a PAID OpenAI API Account. You can "
- "read more here: https://docs.agpt.co/setup/#getting-an-api-key"
- )
- _backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."
-
- def __init__(
- self,
- logger: logging.Logger,
- num_retries: int = 10,
- backoff_base: float = 2.0,
- warn_user: bool = True,
- ):
- self._logger = logger
- self._num_retries = num_retries
- self._backoff_base = backoff_base
- self._warn_user = warn_user
-
- def _log_rate_limit_error(self) -> None:
- self._logger.debug(self._retry_limit_msg)
- if self._warn_user:
- self._logger.warning(self._api_key_error_msg)
- self._warn_user = False
+ if extra_headers := self._configuration.extra_request_headers:
+ # dict.update() returns None; build a merged dict instead.
+ kwargs["extra_headers"] = {
+ **(kwargs.get("extra_headers") or {}),
+ **extra_headers,
+ }
- def _backoff(self, attempt: int) -> None:
- backoff = self._backoff_base ** (attempt + 2)
- self._logger.debug(self._backoff_msg.format(backoff=backoff))
- time.sleep(backoff)
+ return kwargs
+
+ def _create_chat_completion(
+ self, messages: list[ChatMessage], *_, **kwargs
+ ) -> Coroutine[None, None, ChatCompletion]:
+ """Create a chat completion using the OpenAI API with retry handling."""
+
+ @self._retry_api_request
+ async def _create_chat_completion_with_retry(
+ messages: list[ChatMessage], *_, **kwargs
+ ) -> ChatCompletion:
+ raw_messages = [
+ message.dict(include={"role", "content", "tool_calls", "name"})
+ for message in messages
+ ]
+ return await self._client.chat.completions.create(
+ messages=raw_messages, # type: ignore
+ **kwargs,
+ )
- def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
- @functools.wraps(func)
- async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
- num_attempts = self._num_retries + 1 # +1 for the first attempt
- for attempt in range(1, num_attempts + 1):
- try:
- return await func(*args, **kwargs)
+ return _create_chat_completion_with_retry(messages, *_, **kwargs)
- except RateLimitError:
- if attempt == num_attempts:
- raise
- self._log_rate_limit_error()
+ def _create_embedding(
+ self, text: str, *_, **kwargs
+ ) -> Coroutine[None, None, CreateEmbeddingResponse]:
+ """Create an embedding using the OpenAI API with retry handling."""
- except APIError as e:
- if (e.http_status != 502) or (attempt == num_attempts):
- raise
+ @self._retry_api_request
+ async def _create_embedding_with_retry(
+ text: str, *_, **kwargs
+ ) -> CreateEmbeddingResponse:
+ return await self._client.embeddings.create(
+ input=[text],
+ **kwargs,
+ )
- self._backoff(attempt)
+ return _create_embedding_with_retry(text, *_, **kwargs)
+
+ def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
+ _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)
+
+ def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
+ _log_retry_debug_message(retry_state)
+
+ if (
+ retry_state.attempt_number == 1  # tenacity counts attempts from 1
+ and retry_state.outcome
+ and isinstance(retry_state.outcome.exception(), RateLimitError)
+ ):
+ self._logger.warning(
+ "Please double check that you have setup a PAID OpenAI API Account."
+ " You can read more here: "
+ "https://docs.agpt.co/setup/#getting-an-openai-api-key"
+ )
+
+ return tenacity.retry(
+ retry=(
+ tenacity.retry_if_exception_type(RateLimitError)
+ | tenacity.retry_if_exception(
+ lambda e: isinstance(e, APIStatusError) and e.status_code == 502
+ )
+ ),
+ wait=tenacity.wait_exponential(),
+ stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
+ after=_log_on_fail,
+ )(func)
- return _wrapped
+ def __repr__(self):
+ return "OpenAIProvider()"
def format_function_specs_as_typescript_ns(
@@ -572,10 +654,12 @@ def count_openai_functions_tokens(
) -> int:
"""Returns the number of tokens taken up by a set of function definitions
- Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+ Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18 # noqa: E501
"""
return count_tokens(
- f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}"
+ "# Tools\n\n"
+ "## functions\n\n"
+ f"{format_function_specs_as_typescript_ns(functions)}"
)
@@ -628,7 +712,7 @@ def _functions_compat_fix_kwargs(
]
-def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]:
+def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import json
import re
@@ -639,10 +723,10 @@ def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDic
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
- raise ValueError("Could not find tool calls block in response")
+ raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
for t in tool_calls:
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
- return tool_calls
+ yield AssistantToolCall.parse_obj(t)
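(For models without the function-call API, this compat path extracts calls from a fenced block in plain response text. A sketch of the kind of input it targets; the field names in the JSON are assumed to match the AssistantToolCall schema:)

# Sketch: a plain-text response carrying a tool_calls block.
response_text = (
    "Calling the tool now.\n"
    "```tool_calls\n"
    '[{"id": "call_1", "type": "function",'
    ' "function": {"name": "greet", "arguments": {"name": "world"}}}]\n'
    "```"
)
calls = list(_tool_calls_compat_extract_calls(response_text))
# -> one AssistantToolCall; function.arguments was stringified by the HACK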