aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/vcr/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/tests/vcr/__init__.py')
-rw-r--r--autogpts/autogpt/tests/vcr/__init__.py54
1 files changed, 25 insertions, 29 deletions
diff --git a/autogpts/autogpt/tests/vcr/__init__.py b/autogpts/autogpt/tests/vcr/__init__.py
index 4d45aafc9..faef46a68 100644
--- a/autogpts/autogpt/tests/vcr/__init__.py
+++ b/autogpts/autogpt/tests/vcr/__init__.py
@@ -2,8 +2,11 @@ import logging
import os
from hashlib import sha256
-import openai.api_requestor
import pytest
+from openai import OpenAI
+from openai._models import FinalRequestOptions
+from openai._types import Omit
+from openai._utils import is_given
from pytest_mock import MockerFixture
from .vcr_filter import (
@@ -21,6 +24,7 @@ BASE_VCR_CONFIG = {
"Authorization",
"AGENT-MODE",
"AGENT-TYPE",
+ "Cookie",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
@@ -51,30 +55,26 @@ def vcr_cassette_dir(request):
return os.path.join("tests/vcr_cassettes", test_name)
-def patch_api_base(requestor: openai.api_requestor.APIRequestor):
- new_api_base = f"{PROXY}/v1"
- requestor.api_base = new_api_base
- return requestor
-
-
@pytest.fixture
-def patched_api_requestor(mocker: MockerFixture):
- init_requestor = openai.api_requestor.APIRequestor.__init__
- prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw
+def cached_openai_client(mocker: MockerFixture) -> OpenAI:
+ client = OpenAI()
+ _prepare_options = client._prepare_options
- def patched_init_requestor(requestor, *args, **kwargs):
- init_requestor(requestor, *args, **kwargs)
- patch_api_base(requestor)
+ def _patched_prepare_options(self, options: FinalRequestOptions):
+ _prepare_options(options)
- def patched_prepare_request(self, *args, **kwargs):
- url, headers, data = prepare_request(self, *args, **kwargs)
+ headers: dict[str, str | Omit] = (
+ {**options.headers} if is_given(options.headers) else {}
+ )
+ options.headers = headers
+ data: dict = options.json_data
if PROXY:
- headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
- headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
+ headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
+ headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
- logging.getLogger("patched_api_requestor").debug(
- f"Outgoing API request: {headers}\n{data.decode() if data else None}"
+ logging.getLogger("cached_openai_client").debug(
+ f"Outgoing API request: {headers}\n{data if data else None}"
)
# Add hash header for cheap & fast matching on cassette playback
@@ -82,16 +82,12 @@ def patched_api_requestor(mocker: MockerFixture):
freeze_request_body(data), usedforsecurity=False
).hexdigest()
- return url, headers, data
-
if PROXY:
- mocker.patch.object(
- openai.api_requestor.APIRequestor,
- "__init__",
- new=patched_init_requestor,
- )
+ client.base_url = f"{PROXY}/v1"
mocker.patch.object(
- openai.api_requestor.APIRequestor,
- "_prepare_request_raw",
- new=patched_prepare_request,
+ client,
+ "_prepare_options",
+ new=_patched_prepare_options,
)
+
+ return client