From 6dd76afad5c161b70afaa092b9ba3a7e35e2b3f1 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 22 Mar 2024 13:08:15 +0100 Subject: test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040) - Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py - Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version` - Remove deprecated OpenAI proxy logic - Reorder methods in vcr_filter.py for readability --- autogpts/autogpt/tests/vcr/__init__.py | 16 ---- autogpts/autogpt/tests/vcr/vcr_filter.py | 124 +++++++++++++++---------------- 2 files changed, 58 insertions(+), 82 deletions(-) diff --git a/autogpts/autogpt/tests/vcr/__init__.py b/autogpts/autogpt/tests/vcr/__init__.py index faef46a68..8d477cfe2 100644 --- a/autogpts/autogpt/tests/vcr/__init__.py +++ b/autogpts/autogpt/tests/vcr/__init__.py @@ -10,7 +10,6 @@ from openai._utils import is_given from pytest_mock import MockerFixture from .vcr_filter import ( - PROXY, before_record_request, before_record_response, freeze_request_body, @@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes" BASE_VCR_CONFIG = { "before_record_request": before_record_request, "before_record_response": before_record_response, - "filter_headers": [ - "Authorization", - "AGENT-MODE", - "AGENT-TYPE", - "Cookie", - "OpenAI-Organization", - "X-OpenAI-Client-User-Agent", - "User-Agent", - ], "match_on": ["method", "headers"], } @@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI: options.headers = headers data: dict = options.json_data - if PROXY: - headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit()) - headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit()) - logging.getLogger("cached_openai_client").debug( f"Outgoing API request: {headers}\n{data if data else None}" ) @@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI: freeze_request_body(data), usedforsecurity=False ).hexdigest() - if PROXY: - client.base_url = f"{PROXY}/v1" mocker.patch.object( client, "_prepare_options", diff --git a/autogpts/autogpt/tests/vcr/vcr_filter.py b/autogpts/autogpt/tests/vcr/vcr_filter.py index 77b65d207..81c269fc5 100644 --- a/autogpts/autogpt/tests/vcr/vcr_filter.py +++ b/autogpts/autogpt/tests/vcr/vcr_filter.py @@ -1,16 +1,27 @@ import contextlib import json -import os import re from io import BytesIO -from typing import Any, Dict, List -from urllib.parse import urlparse, urlunparse +from typing import Any from vcr.request import Request -PROXY = os.environ.get("PROXY") +HOSTNAMES_TO_CACHE: list[str] = [ + "api.openai.com", + "localhost:50337", + "duckduckgo.com", +] -REPLACEMENTS: List[Dict[str, str]] = [ +IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = { + "Authorization", + "Cookie", + "OpenAI-Organization", + "X-OpenAI-Client-User-Agent", + "User-Agent", + re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE), +} + +LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [ { "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}", "replacement": "Tue Jan 1 00:00:00 2000", @@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [ }, ] -ALLOWED_HOSTNAMES: List[str] = [ - "api.openai.com", - "localhost:50337", - "duckduckgo.com", -] - -if PROXY: - ALLOWED_HOSTNAMES.append(PROXY) - ORIGINAL_URL = PROXY -else: - ORIGINAL_URL = "no_ci" - -NEW_URL = "api.openai.com" - - -def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str: - for replacement in replacements: - pattern = re.compile(replacement["regex"]) - content = pattern.sub(replacement["replacement"], content) +OPENAI_URL = "api.openai.com" - return content +def before_record_request(request: Request) -> Request | None: + if not should_cache_request(request): + return None -def freeze_request_body(body: dict) -> bytes: - """Remove any dynamic items from the request body""" + request = filter_request_headers(request) + request = freeze_request(request) + return request - if "messages" not in body: - return json.dumps(body, sort_keys=True).encode() - if "max_tokens" in body: - del body["max_tokens"] +def should_cache_request(request: Request) -> bool: + return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE) - for message in body["messages"]: - if "content" in message and "role" in message: - if message["role"] == "system": - message["content"] = replace_message_content( - message["content"], REPLACEMENTS - ) - return json.dumps(body, sort_keys=True).encode() +def filter_request_headers(request: Request) -> Request: + for header_name in list(request.headers): + if any( + ( + (type(ignore) is str and ignore.lower() == header_name.lower()) + or (isinstance(ignore, re.Pattern) and ignore.match(header_name)) + ) + for ignore in IGNORE_REQUEST_HEADERS + ): + del request.headers[header_name] + return request def freeze_request(request: Request) -> Request: @@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request: return request -def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]: - if "Transfer-Encoding" in response["headers"]: - del response["headers"]["Transfer-Encoding"] - return response - +def freeze_request_body(body: dict) -> bytes: + """Remove any dynamic items from the request body""" -def before_record_request(request: Request) -> Request | None: - request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL) + if "messages" not in body: + return json.dumps(body, sort_keys=True).encode() - filtered_request = filter_hostnames(request) - if not filtered_request: - return None + if "max_tokens" in body: + del body["max_tokens"] - filtered_request_without_dynamic_data = freeze_request(filtered_request) - return filtered_request_without_dynamic_data + for message in body["messages"]: + if "content" in message and "role" in message: + if message["role"] == "system": + message["content"] = replace_message_content( + message["content"], LLM_MESSAGE_REPLACEMENTS + ) + return json.dumps(body, sort_keys=True).encode() -def replace_request_hostname( - request: Request, original_url: str, new_hostname: str -) -> Request: - parsed_url = urlparse(request.uri) - if parsed_url.hostname in original_url: - new_path = parsed_url.path.replace("/proxy_function", "") - request.uri = urlunparse( - parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https") - ) +def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str: + for replacement in replacements: + pattern = re.compile(replacement["regex"]) + content = pattern.sub(replacement["replacement"], content) - return request + return content -def filter_hostnames(request: Request) -> Request | None: - # Add your implementation here for filtering hostnames - if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES): - return request - else: - return None +def before_record_response(response: dict[str, Any]) -> dict[str, Any]: + if "Transfer-Encoding" in response["headers"]: + del response["headers"]["Transfer-Encoding"] + return response -- cgit v1.2.3