diff options
Diffstat (limited to 'autogpts/autogpt/tests/vcr/__init__.py')
-rw-r--r-- | autogpts/autogpt/tests/vcr/__init__.py | 54 |
1 files changed, 25 insertions, 29 deletions
diff --git a/autogpts/autogpt/tests/vcr/__init__.py b/autogpts/autogpt/tests/vcr/__init__.py index 4d45aafc9..faef46a68 100644 --- a/autogpts/autogpt/tests/vcr/__init__.py +++ b/autogpts/autogpt/tests/vcr/__init__.py @@ -2,8 +2,11 @@ import logging import os from hashlib import sha256 -import openai.api_requestor import pytest +from openai import OpenAI +from openai._models import FinalRequestOptions +from openai._types import Omit +from openai._utils import is_given from pytest_mock import MockerFixture from .vcr_filter import ( @@ -21,6 +24,7 @@ BASE_VCR_CONFIG = { "Authorization", "AGENT-MODE", "AGENT-TYPE", + "Cookie", "OpenAI-Organization", "X-OpenAI-Client-User-Agent", "User-Agent", @@ -51,30 +55,26 @@ def vcr_cassette_dir(request): return os.path.join("tests/vcr_cassettes", test_name) -def patch_api_base(requestor: openai.api_requestor.APIRequestor): - new_api_base = f"{PROXY}/v1" - requestor.api_base = new_api_base - return requestor - - @pytest.fixture -def patched_api_requestor(mocker: MockerFixture): - init_requestor = openai.api_requestor.APIRequestor.__init__ - prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw +def cached_openai_client(mocker: MockerFixture) -> OpenAI: + client = OpenAI() + _prepare_options = client._prepare_options - def patched_init_requestor(requestor, *args, **kwargs): - init_requestor(requestor, *args, **kwargs) - patch_api_base(requestor) + def _patched_prepare_options(self, options: FinalRequestOptions): + _prepare_options(options) - def patched_prepare_request(self, *args, **kwargs): - url, headers, data = prepare_request(self, *args, **kwargs) + headers: dict[str, str | Omit] = ( + {**options.headers} if is_given(options.headers) else {} + ) + options.headers = headers + data: dict = options.json_data if PROXY: - headers["AGENT-MODE"] = os.environ.get("AGENT_MODE") - headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE") + headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit()) + headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit()) - logging.getLogger("patched_api_requestor").debug( - f"Outgoing API request: {headers}\n{data.decode() if data else None}" + logging.getLogger("cached_openai_client").debug( + f"Outgoing API request: {headers}\n{data if data else None}" ) # Add hash header for cheap & fast matching on cassette playback @@ -82,16 +82,12 @@ def patched_api_requestor(mocker: MockerFixture): freeze_request_body(data), usedforsecurity=False ).hexdigest() - return url, headers, data - if PROXY: - mocker.patch.object( - openai.api_requestor.APIRequestor, - "__init__", - new=patched_init_requestor, - ) + client.base_url = f"{PROXY}/v1" mocker.patch.object( - openai.api_requestor.APIRequestor, - "_prepare_request_raw", - new=patched_prepare_request, + client, + "_prepare_options", + new=_patched_prepare_options, ) + + return client |