aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/vcr/vcr_filter.py
blob: 81c269fc50d1f9aa6aafe4f7bc537737e98893c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import contextlib
import json
import re
from io import BytesIO
from typing import Any

from vcr.request import Request

# Host fragments that select which requests get recorded: a request is kept
# when any of these strings occurs anywhere in its URL (substring test in
# should_cache_request).
HOSTNAMES_TO_CACHE: list[str] = [
    "api.openai.com",
    "localhost:50337",
    "duckduckgo.com",
]

# Request headers that filter_request_headers deletes before recording.
# Plain strings are compared to the header name case-insensitively; compiled
# patterns are applied with ``Pattern.match`` (anchored at the start of the
# header name only).
IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
    "Authorization",
    "Cookie",
    "OpenAI-Organization",
    "X-OpenAI-Client-User-Agent",
    "User-Agent",
    re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
}

# Regex substitutions that freeze_request_body applies, in order, to the
# content of every "system" message in a request body.
LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
    {
        # Text shaped like "Tue Jan  1 00:00:00 2000" (day name, month name,
        # day, time, year) is pinned to one fixed value.
        "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
        "replacement": "Tue Jan  1 00:00:00 2000",
    },
    {
        # Angle-bracketed text starting with the selenium Chrome WebDriver
        # class path is removed entirely.
        # NOTE(review): presumably an object repr that differs per run -- confirm.
        "regex": r"<selenium.webdriver.chrome.webdriver.WebDriver[^>]*>",
        "replacement": "",
    },
]

# NOTE(review): not referenced by any function visible in this module;
# presumably imported by other test code -- confirm before removing.
OPENAI_URL = "api.openai.com"


def before_record_request(request: Request) -> Request | None:
    """Prepare ``request`` for recording, or return ``None`` to skip it.

    Requests that ``should_cache_request`` rejects yield ``None``. All others
    have their ignored headers stripped and their body frozen, in that order,
    and the resulting request is returned.
    """
    if should_cache_request(request):
        return freeze_request(filter_request_headers(request))
    return None


def should_cache_request(request: Request) -> bool:
    """Return True when the request URL contains a host from HOSTNAMES_TO_CACHE.

    The check is a plain substring test against ``request.url``.
    """
    url = request.url
    for candidate in HOSTNAMES_TO_CACHE:
        if candidate in url:
            return True
    return False


def filter_request_headers(request: Request) -> Request:
    """Delete ignored headers from ``request`` in place and return it.

    A header is removed when its name equals one of the string entries of
    ``IGNORE_REQUEST_HEADERS`` (case-insensitive comparison), or when one of
    the compiled patterns in that set matches at the start of the name.
    Entries of any other type are skipped.
    """
    # Iterate over a snapshot of the names: entries are deleted as we go.
    for header_name in list(request.headers):
        # Lower-case once per header instead of once per ignore entry.
        lowered_name = header_name.lower()
        for ignore in IGNORE_REQUEST_HEADERS:
            if isinstance(ignore, str):
                is_ignored = ignore.lower() == lowered_name
            elif isinstance(ignore, re.Pattern):
                is_ignored = ignore.match(header_name) is not None
            else:
                is_ignored = False

            if is_ignored:
                del request.headers[header_name]
                break  # already removed; no need to test further entries
    return request


def freeze_request(request: Request) -> Request:
    """Replace a JSON request body with its frozen form.

    A falsy request or a falsy body is returned untouched. Otherwise the body
    (unwrapped with ``getvalue()`` when it is a ``BytesIO``) is parsed as JSON
    and passed through ``freeze_request_body``; the result is stored back on
    ``request.body``. If a ``ValueError`` is raised along the way (for
    example, the body is not valid JSON), the body is left as it was.
    """
    if not request or not request.body:
        return request

    payload = request.body
    if isinstance(payload, BytesIO):
        payload = payload.getvalue()

    try:
        request.body = freeze_request_body(json.loads(payload))
    except ValueError:
        # Deliberate best effort: bodies that cannot be parsed stay unchanged.
        pass

    return request


def freeze_request_body(body: dict) -> bytes:
    """Remove any dynamic items from the request body.

    The body is serialized as JSON with sorted keys and encoded to bytes.
    When it is a dict with a ``"messages"`` key, two things happen first
    (the dict is modified in place):

    * the ``"max_tokens"`` entry is dropped, if present;
    * every message whose ``"role"`` is ``"system"`` has its string
      ``"content"`` rewritten with ``LLM_MESSAGE_REPLACEMENTS``.

    Any other decoded JSON value (number, string, list, dict without
    ``"messages"``) is serialized unchanged. Messages that are not dicts, and
    system messages whose content is not a string, are left as they are
    rather than raising ``TypeError``.
    """
    if isinstance(body, dict) and "messages" in body:
        body.pop("max_tokens", None)

        messages = body["messages"]
        if isinstance(messages, list):
            for message in messages:
                if not isinstance(message, dict):
                    continue
                if message.get("role") != "system":
                    continue
                # Only plain-text content can take regex substitutions.
                if isinstance(message.get("content"), str):
                    message["content"] = replace_message_content(
                        message["content"], LLM_MESSAGE_REPLACEMENTS
                    )

    return json.dumps(body, sort_keys=True).encode()


def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
    """Apply each regex substitution in ``replacements`` to ``content``.

    Every entry supplies a ``"regex"`` and a ``"replacement"``. Entries are
    applied in list order, each one operating on the output of the previous
    one, and the final text is returned.
    """
    result = content
    for rule in replacements:
        result = re.sub(rule["regex"], rule["replacement"], result)
    return result


def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
    """Remove the Transfer-Encoding header from ``response`` before recording.

    ``response["headers"]`` is modified in place and the same ``response``
    object is returned. HTTP header names are case-insensitive, so the header
    is removed whatever its capitalization (``Transfer-Encoding``,
    ``transfer-encoding``, ...).

    Raises:
        KeyError: if ``response`` has no ``"headers"`` entry.
    """
    headers = response["headers"]
    # Collect the matching names first so the mapping is not mutated while
    # it is being iterated.
    stale_names = [
        name
        for name in headers
        if isinstance(name, str) and name.lower() == "transfer-encoding"
    ]
    for name in stale_names:
        del headers[name]
    return response