aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/vcr/__init__.py
blob: 4d45aafc9976d3ac0f48e4852eb309501cf31679 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import logging
import os
from hashlib import sha256

import openai.api_requestor
import pytest
from pytest_mock import MockerFixture

from .vcr_filter import (
    PROXY,
    before_record_request,
    before_record_response,
    freeze_request_body,
)

# Record mode that get_base_vcr_config falls back to when --record-mode is None.
DEFAULT_RECORD_MODE = "new_episodes"
# Base VCR settings handed out by the get_base_vcr_config / vcr_config fixtures.
BASE_VCR_CONFIG = {
    # Recording hooks imported from .vcr_filter.
    "before_record_request": before_record_request,
    "before_record_response": before_record_response,
    # Presumably VCR's "filter_headers" option, i.e. headers kept out of the
    # recorded cassettes (credentials, client/agent metadata) — confirm
    # against the VCR.py docs.
    "filter_headers": [
        "Authorization",
        "AGENT-MODE",
        "AGENT-TYPE",
        "OpenAI-Organization",
        "X-OpenAI-Client-User-Agent",
        "User-Agent",
    ],
    # Requests are matched on method and headers; patched_api_requestor adds
    # an X-Content-Hash header of the request body for matching on playback.
    "match_on": ["method", "headers"],
}


@pytest.fixture(scope="session")
def vcr_config(get_base_vcr_config):
    """Expose the base VCR configuration under the ``vcr_config`` fixture name."""
    base_config = get_base_vcr_config
    return base_config


@pytest.fixture(scope="session")
def get_base_vcr_config(request):
    """Build the session-wide VCR configuration.

    Starts from ``BASE_VCR_CONFIG`` and adds ``record_mode`` set to
    ``DEFAULT_RECORD_MODE`` when the ``--record-mode`` option resolves to
    ``None``.

    Args:
        request: pytest's ``request`` fixture, used to read command-line options.

    Returns:
        A new dict with the VCR settings. The module-level ``BASE_VCR_CONFIG``
        is left unmodified.
    """
    # ``getoption`` only uses ``default`` when the option is not registered at
    # all; a registered-but-unset option can still come back as ``None``.
    record_mode = request.config.getoption("--record-mode", default="new_episodes")

    # Shallow copy so adding "record_mode" below does not leak into the shared
    # module-level constant (nested lists are still shared, but never mutated
    # here).
    config = dict(BASE_VCR_CONFIG)

    if record_mode is None:
        config["record_mode"] = DEFAULT_RECORD_MODE

    return config


@pytest.fixture()
def vcr_cassette_dir(request):
    """Return the cassette directory for the requesting test.

    The result is ``tests/vcr_cassettes/<name>``, where ``<name>`` is the test
    node's name with whatever ``os.path.splitext`` treats as an extension
    removed.
    """
    node_stem, _extension = os.path.splitext(request.node.name)
    return os.path.join("tests/vcr_cassettes", node_stem)


def patch_api_base(requestor: openai.api_requestor.APIRequestor):
    """Point ``requestor.api_base`` at ``{PROXY}/v1`` and return the requestor.

    The requestor is modified in place; the same object is returned.
    """
    requestor.api_base = f"{PROXY}/v1"
    return requestor


@pytest.fixture
def patched_api_requestor(mocker: MockerFixture):
    """Patch ``openai.api_requestor.APIRequestor`` for the requesting test.

    ``_prepare_request_raw`` is always wrapped: every prepared request is
    debug-logged and gets an ``X-Content-Hash`` header holding the SHA-256 of
    its frozen body. When ``PROXY`` is truthy, the ``AGENT-MODE`` and
    ``AGENT-TYPE`` headers are also set from the environment, and ``__init__``
    is wrapped so each new requestor has its ``api_base`` redirected through
    ``patch_api_base``.

    The fixture itself yields no value.
    """
    requestor_cls = openai.api_requestor.APIRequestor
    # Keep references to the unpatched methods so the wrappers can delegate.
    original_init = requestor_cls.__init__
    original_prepare = requestor_cls._prepare_request_raw

    def init_with_proxy_base(self, *args, **kwargs):
        """Run the real ``__init__``, then redirect the API base."""
        original_init(self, *args, **kwargs)
        patch_api_base(self)

    def prepare_with_extra_headers(self, *args, **kwargs):
        """Run the real preparation, then decorate and log the request."""
        url, headers, data = original_prepare(self, *args, **kwargs)

        if PROXY:
            for header_name, env_var in (
                ("AGENT-MODE", "AGENT_MODE"),
                ("AGENT-TYPE", "AGENT_TYPE"),
            ):
                headers[header_name] = os.environ.get(env_var)

        request_logger = logging.getLogger("patched_api_requestor")
        # Logged before the hash header is added, so the hash is not in the log.
        request_logger.debug(
            f"Outgoing API request: {headers}\n{data.decode() if data else None}"
        )

        # Hash header gives cassette playback a cheap & fast way to match bodies
        body_digest = sha256(freeze_request_body(data), usedforsecurity=False)
        headers["X-Content-Hash"] = body_digest.hexdigest()

        return url, headers, data

    # The constructor only needs patching when requests go through the proxy.
    if PROXY:
        mocker.patch.object(requestor_cls, "__init__", new=init_with_proxy_base)

    mocker.patch.object(
        requestor_cls,
        "_prepare_request_raw",
        new=prepare_with_extra_headers,
    )