# path: root/benchmark/agbenchmark/main.py
# blob: 4cd97bd89f85e97c8f3b1877d98b07d83eff59d0
import logging
import os
from pathlib import Path
from typing import Optional, Sequence

from dotenv import load_dotenv

from agbenchmark.challenges import get_unique_categories
from agbenchmark.config import AgentBenchmarkConfig

# Import-time side effect: python-dotenv reads a `.env` file (if it finds one)
# and adds its entries to the process environment.
load_dotenv()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def run_benchmark(
    config: AgentBenchmarkConfig,
    maintain: bool = False,
    improve: bool = False,
    explore: bool = False,
    tests: tuple[str, ...] = (),
    categories: tuple[str, ...] = (),
    skip_categories: tuple[str, ...] = (),
    attempts_per_challenge: int = 1,
    mock: bool = False,
    no_dep: bool = False,
    no_cutoff: bool = False,
    cutoff: Optional[int] = None,
    keep_answers: bool = False,
    server: bool = False,
) -> int:
    """
    Starts the benchmark. If a category flag is provided, only challenges with the
    corresponding mark will be run.

    The options are validated with `validate_args`, translated into a pytest
    command line targeting the sibling `generate_test.py`, and handed to
    `pytest.main`.

    Params:
        config: Benchmark configuration. This function only logs its attributes
            at debug level.
        maintain: Forwarded to pytest as `--maintain`.
        improve: Forwarded to pytest as `--improve`.
        explore: Forwarded to pytest as `--explore`.
        tests: Names of specific tests; each is forwarded as `--test=<name>`.
            When given, no category selection is performed.
        categories: Categories to run; defaults to all known categories when
            only `skip_categories` is given.
        skip_categories: Categories removed from the selection.
        attempts_per_challenge: Forwarded as `--attempts=<n>` when greater than 1.
        mock: Forwarded as `--mock`; also sets the `IS_MOCK` environment
            variable to `"True"` for the rest of the process.
        no_dep: Forwarded to pytest as `--no-dep`.
        no_cutoff: Forwarded to pytest as `--nc`.
        cutoff: Forwarded as `--cutoff=<n>` when truthy.
        keep_answers: Forwarded to pytest as `--keep-answers`.
        server: Not used by this function.

    Returns:
        The exit code returned by `pytest.main`.

    Raises:
        InvalidInvocationError: If `validate_args` rejects the options, or if
            `skip_categories` removes every category that would have been run.
    """
    import pytest

    from agbenchmark.reports.ReportManager import SingletonReportManager

    validate_args(
        maintain=maintain,
        improve=improve,
        explore=explore,
        tests=tests,
        categories=categories,
        skip_categories=skip_categories,
        no_cutoff=no_cutoff,
        cutoff=cutoff,
    )

    SingletonReportManager()

    # From here on the singleton exists; the `finally` below guarantees it is
    # cleared even if building the arguments or running pytest raises.
    try:
        for key, value in vars(config).items():
            logger.debug(f"config.{key} = {repr(value)}")

        pytest_args = ["-vs"]

        if tests:
            logger.info(f"Running specific test(s): {' '.join(tests)}")
            pytest_args += [f"--test={t}" for t in tests]
        else:
            all_categories = get_unique_categories()

            if categories or skip_categories:
                categories_to_run = set(categories) or all_categories
                if skip_categories:
                    categories_to_run = categories_to_run.difference(
                        set(skip_categories)
                    )
                if not categories_to_run:
                    # Explicit raise rather than `assert`: asserts are stripped
                    # under `python -O`, and an empty selection would then add
                    # no `--category` filter, i.e. silently run every category.
                    raise InvalidInvocationError(
                        "Error: You can't skip all categories"
                    )
                # Sorted so the generated command line does not depend on set
                # iteration order.
                pytest_args += [f"--category={c}" for c in sorted(categories_to_run)]
                logger.info(f"Running tests of category: {categories_to_run}")
            else:
                logger.info("Running all categories")

            if maintain:
                logger.info("Running only regression tests")
            elif improve:
                logger.info("Running only non-regression tests")
            elif explore:
                logger.info("Only attempt challenges that have never been beaten")

        if mock:
            # TODO: unhack
            # ugly hack to make the mock work when calling from API
            os.environ["IS_MOCK"] = "True"

        # Pass through flags
        pass_through_flags = {
            "--maintain": maintain,
            "--improve": improve,
            "--explore": explore,
            "--no-dep": no_dep,
            "--mock": mock,
            "--nc": no_cutoff,
            "--keep-answers": keep_answers,
        }
        pytest_args += [flag for flag, active in pass_through_flags.items() if active]

        if attempts_per_challenge > 1:
            pytest_args.append(f"--attempts={attempts_per_challenge}")

        if cutoff:
            pytest_args.append(f"--cutoff={cutoff}")
            logger.debug(f"Setting cuttoff override to {cutoff} seconds.")

        current_dir = Path(__file__).resolve().parent
        pytest_args.append(str(current_dir / "generate_test.py"))

        pytest_args.append("--cache-clear")
        logger.debug(f"Running Pytest with args: {pytest_args}")
        return pytest.main(pytest_args)
    finally:
        SingletonReportManager.clear_instance()


class InvalidInvocationError(ValueError):
    """
    Raised when the benchmark is invoked with an invalid or conflicting set of
    options.
    """


def validate_args(
    maintain: bool,
    improve: bool,
    explore: bool,
    tests: Sequence[str],
    categories: Sequence[str],
    skip_categories: Sequence[str],
    no_cutoff: bool,
    cutoff: Optional[int],
) -> None:
    """
    Reject benchmark option combinations that cannot be honoured.

    The rules are checked in this order, and the first violation is raised:

    1. every entry of `categories` must be a known category;
    2. at most one of `maintain`, `improve` and `explore` may be set;
    3. `tests` may not be combined with `categories`, `skip_categories`,
       `maintain`, `improve` or `explore`;
    4. `no_cutoff` and `cutoff` may not both be given.

    Raises:
        InvalidInvocationError: If any of the rules above is violated.
    """
    if categories:
        known = get_unique_categories()
        unknown = set(categories) - known
        if unknown:
            unknown_listing = ", ".join(unknown)
            known_listing = ", ".join(known)
            raise InvalidInvocationError(
                "One or more invalid categories were specified: "
                f"{unknown_listing}.\n"
                f"Valid categories are: {known_listing}."
            )

    # The three run modes are mutually exclusive.
    selected_modes = sum((maintain, improve, explore))
    if selected_modes > 1:
        raise InvalidInvocationError(
            "You can't use --maintain, --improve or --explore at the same time. "
            "Please choose one."
        )

    # Explicit test selection excludes every other kind of selection.
    other_selection = any((categories, skip_categories, maintain, improve, explore))
    if tests and other_selection:
        raise InvalidInvocationError(
            "If you're running a specific test make sure no other options are "
            "selected. Please just pass the --test."
        )

    if no_cutoff:
        if cutoff:
            raise InvalidInvocationError(
                "You can't use both --nc and --cutoff at the same time. "
                "Please choose one."
            )