about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGravatar Reinier van der Leer <pwuts@agpt.co> 2024-03-22 13:08:15 +0100
committerGravatar GitHub <noreply@github.com> 2024-03-22 13:08:15 +0100
commit6dd76afad5c161b70afaa092b9ba3a7e35e2b3f1 (patch)
tree8787a2f280607392481a72db68357f3022e9568a
parentci(agent): Fix Docker CI for PR runs from forks (vol. 2) (diff)
downloadAuto-GPT-6dd76afad5c161b70afaa092b9ba3a7e35e2b3f1.tar.gz
Auto-GPT-6dd76afad5c161b70afaa092b9ba3a7e35e2b3f1.tar.bz2
Auto-GPT-6dd76afad5c161b70afaa092b9ba3a7e35e2b3f1.zip
test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040)
- Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py
- Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version`
- Remove deprecated OpenAI proxy logic
- Reorder methods in vcr_filter.py for readability
-rw-r--r--autogpts/autogpt/tests/vcr/__init__.py16
-rw-r--r--autogpts/autogpt/tests/vcr/vcr_filter.py124
2 files changed, 58 insertions, 82 deletions
diff --git a/autogpts/autogpt/tests/vcr/__init__.py b/autogpts/autogpt/tests/vcr/__init__.py
index faef46a68..8d477cfe2 100644
--- a/autogpts/autogpt/tests/vcr/__init__.py
+++ b/autogpts/autogpt/tests/vcr/__init__.py
@@ -10,7 +10,6 @@ from openai._utils import is_given
from pytest_mock import MockerFixture
from .vcr_filter import (
- PROXY,
before_record_request,
before_record_response,
freeze_request_body,
@@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes"
BASE_VCR_CONFIG = {
"before_record_request": before_record_request,
"before_record_response": before_record_response,
- "filter_headers": [
- "Authorization",
- "AGENT-MODE",
- "AGENT-TYPE",
- "Cookie",
- "OpenAI-Organization",
- "X-OpenAI-Client-User-Agent",
- "User-Agent",
- ],
"match_on": ["method", "headers"],
}
@@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
options.headers = headers
data: dict = options.json_data
- if PROXY:
- headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
- headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
-
logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}"
)
@@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
freeze_request_body(data), usedforsecurity=False
).hexdigest()
- if PROXY:
- client.base_url = f"{PROXY}/v1"
mocker.patch.object(
client,
"_prepare_options",
diff --git a/autogpts/autogpt/tests/vcr/vcr_filter.py b/autogpts/autogpt/tests/vcr/vcr_filter.py
index 77b65d207..81c269fc5 100644
--- a/autogpts/autogpt/tests/vcr/vcr_filter.py
+++ b/autogpts/autogpt/tests/vcr/vcr_filter.py
@@ -1,16 +1,27 @@
import contextlib
import json
-import os
import re
from io import BytesIO
-from typing import Any, Dict, List
-from urllib.parse import urlparse, urlunparse
+from typing import Any
from vcr.request import Request
-PROXY = os.environ.get("PROXY")
+HOSTNAMES_TO_CACHE: list[str] = [
+ "api.openai.com",
+ "localhost:50337",
+ "duckduckgo.com",
+]
-REPLACEMENTS: List[Dict[str, str]] = [
+IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
+ "Authorization",
+ "Cookie",
+ "OpenAI-Organization",
+ "X-OpenAI-Client-User-Agent",
+ "User-Agent",
+ re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
+}
+
+LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
{
"regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
"replacement": "Tue Jan 1 00:00:00 2000",
@@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [
},
]
-ALLOWED_HOSTNAMES: List[str] = [
- "api.openai.com",
- "localhost:50337",
- "duckduckgo.com",
-]
-
-if PROXY:
- ALLOWED_HOSTNAMES.append(PROXY)
- ORIGINAL_URL = PROXY
-else:
- ORIGINAL_URL = "no_ci"
-
-NEW_URL = "api.openai.com"
-
-
-def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str:
- for replacement in replacements:
- pattern = re.compile(replacement["regex"])
- content = pattern.sub(replacement["replacement"], content)
+OPENAI_URL = "api.openai.com"
- return content
+def before_record_request(request: Request) -> Request | None:
+ if not should_cache_request(request):
+ return None
-def freeze_request_body(body: dict) -> bytes:
- """Remove any dynamic items from the request body"""
+ request = filter_request_headers(request)
+ request = freeze_request(request)
+ return request
- if "messages" not in body:
- return json.dumps(body, sort_keys=True).encode()
- if "max_tokens" in body:
- del body["max_tokens"]
+def should_cache_request(request: Request) -> bool:
+ return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE)
- for message in body["messages"]:
- if "content" in message and "role" in message:
- if message["role"] == "system":
- message["content"] = replace_message_content(
- message["content"], REPLACEMENTS
- )
- return json.dumps(body, sort_keys=True).encode()
+def filter_request_headers(request: Request) -> Request:
+ for header_name in list(request.headers):
+ if any(
+ (
+ (type(ignore) is str and ignore.lower() == header_name.lower())
+ or (isinstance(ignore, re.Pattern) and ignore.match(header_name))
+ )
+ for ignore in IGNORE_REQUEST_HEADERS
+ ):
+ del request.headers[header_name]
+ return request
def freeze_request(request: Request) -> Request:
@@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request:
return request
-def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
- if "Transfer-Encoding" in response["headers"]:
- del response["headers"]["Transfer-Encoding"]
- return response
-
+def freeze_request_body(body: dict) -> bytes:
+ """Remove any dynamic items from the request body"""
-def before_record_request(request: Request) -> Request | None:
- request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
+ if "messages" not in body:
+ return json.dumps(body, sort_keys=True).encode()
- filtered_request = filter_hostnames(request)
- if not filtered_request:
- return None
+ if "max_tokens" in body:
+ del body["max_tokens"]
- filtered_request_without_dynamic_data = freeze_request(filtered_request)
- return filtered_request_without_dynamic_data
+ for message in body["messages"]:
+ if "content" in message and "role" in message:
+ if message["role"] == "system":
+ message["content"] = replace_message_content(
+ message["content"], LLM_MESSAGE_REPLACEMENTS
+ )
+ return json.dumps(body, sort_keys=True).encode()
-def replace_request_hostname(
- request: Request, original_url: str, new_hostname: str
-) -> Request:
- parsed_url = urlparse(request.uri)
- if parsed_url.hostname in original_url:
- new_path = parsed_url.path.replace("/proxy_function", "")
- request.uri = urlunparse(
- parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
- )
+def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
+ for replacement in replacements:
+ pattern = re.compile(replacement["regex"])
+ content = pattern.sub(replacement["replacement"], content)
- return request
+ return content
-def filter_hostnames(request: Request) -> Request | None:
- # Add your implementation here for filtering hostnames
- if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
- return request
- else:
- return None
+def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
+ if "Transfer-Encoding" in response["headers"]:
+ del response["headers"]["Transfer-Encoding"]
+ return response