aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/reports/processing/report_types.py
blob: e2fb1bc6235400e74129a5be43db09399e5d9455 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Any, Dict, List, Union

from pydantic import BaseModel, Field

# Regex that `ReportBase.benchmark_start_time` must match: an ISO-8601
# timestamp with second precision and an explicit "+00:00" offset,
# e.g. "2023-09-01T12:34:56+00:00". Fractional seconds and any other
# offset spelling (including "Z") do not match.
datetime_format = (
    r"^\d{4}-\d{2}-\d{2}"  # date: YYYY-MM-DD
    r"T\d{2}:\d{2}:\d{2}"  # time: HH:MM:SS
    r"\+00:00$"  # offset: the literal "+00:00"
)
from pydantic import BaseModel, constr


class ForbidOptionalMeta(type(BaseModel)):
    """Metaclass that rejects Optional type objects bound in a class body.

    Only the *values* in the class namespace (``name = value``) are
    inspected. A field declared purely by annotation, such as
    ``x: str | None`` with no default, lives in ``__annotations__`` and is
    therefore NOT rejected; the models in this module declare their
    nullable fields that way.
    """

    def __new__(
        cls, name: str, bases: tuple, dct: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Create the class after checking its namespace for Optional values.

        Args:
            name: Name of the class being created.
            bases: Base classes of the class being created.
            dct: Class namespace (everything bound in the class body).
            **kwargs: Class-definition keywords, e.g.
                ``class M(Base, extra="forbid")``; forwarded unchanged to
                the parent (pydantic) metaclass instead of being rejected.

        Raises:
            TypeError: If a namespace value is a union containing ``None``,
                spelled either ``Optional[X]`` / ``Union[X, None]`` or
                ``X | None``.
        """
        # Local import keeps this block self-contained; `types.UnionType`
        # is the runtime type of PEP 604 unions such as `str | None`.
        import types

        for attr_name, attr_value in dct.items():
            is_union = getattr(
                attr_value, "__origin__", None
            ) is Union or isinstance(attr_value, types.UnionType)
            if is_union and type(None) in attr_value.__args__:
                raise TypeError(
                    f"Optional fields are forbidden, but found in {attr_name}"
                )

        return super().__new__(cls, name, bases, dct, **kwargs)


# Common base for every report model in this module: builds subclasses
# through ForbidOptionalMeta and rejects unknown input keys.
class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta):
    class Config:
        # Pydantic v1 model config: unexpected fields raise a validation
        # error instead of being silently ignored.
        extra = "forbid"


# Result metrics for a single test (the type of `Test.metrics`).
# Comments are used instead of a docstring because pydantic v1 copies a
# model's docstring into its generated JSON schema.
class Metrics(BaseModelBenchmark):
    difficulty: str
    success: bool
    # Read from the input key "success_%" (not a valid Python identifier),
    # exposed as `success_percentage`; `...` marks the field as required.
    success_percentage: float = Field(..., alias="success_%")
    run_time: str
    # NOTE(review): no default is given here; pydantic v1 treats an
    # un-defaulted `X | None` field as optional with default None —
    # confirm against the installed pydantic version.
    fail_reason: str | None
    attempted: bool
    cost: float | None


# Aggregate metrics stored in `ReportBase.metrics`.
class MetricsOverall(BaseModelBenchmark):
    run_time: str
    highest_difficulty: str
    # Nullable: may be absent/None in a report.
    percentage: float | None
    total_cost: float | None


# Report entry for one test: its metadata plus its per-test `Metrics`.
# Used as the value type of `Report.tests` and as a base of `ReportV2`.
class Test(BaseModelBenchmark):
    data_path: str
    is_regression: bool
    answer: str
    description: str
    metrics: Metrics
    category: List[str]
    task: str
    reached_cutoff: bool


# Fields shared by `Report` and `ReportV2`.
class ReportBase(BaseModelBenchmark):
    command: str
    completion_time: str | None
    # Must match the module-level `datetime_format` regex (pydantic v1
    # `constr(regex=...)`): ISO-8601 with an explicit "+00:00" offset.
    benchmark_start_time: constr(regex=datetime_format)
    metrics: MetricsOverall
    # Each value is either a plain string or a flat str -> str mapping.
    config: Dict[str, str | dict[str, str]]
    agent_git_commit_sha: str | None
    benchmark_git_commit_sha: str | None
    repo_url: str | None


# Report holding many test results in a mapping.
class Report(ReportBase):
    # Keys presumably identify the test (e.g. its name) — confirm against
    # the code that writes reports.
    tests: Dict[str, Test]


# Flattened report: one test's fields (from `Test`) together with the
# shared `ReportBase` fields, plus the identifiers declared below.
class ReportV2(Test, ReportBase):
    # NOTE(review): both bases declare `metrics` with different types
    # (Test: Metrics, ReportBase: MetricsOverall). Pydantic v1 merges base
    # fields in reversed-bases order, so Test's `Metrics` presumably wins
    # here — confirm this is intended.
    test_name: str
    run_id: str | None
    team_name: str | None