Diffstat (limited to 'benchmark/agbenchmark/reports/processing/report_types.py')
-rw-r--r-- | benchmark/agbenchmark/reports/processing/report_types.py | 109 |
1 file changed, 61 insertions, 48 deletions
diff --git a/benchmark/agbenchmark/reports/processing/report_types.py b/benchmark/agbenchmark/reports/processing/report_types.py
index e2fb1bc62..b6deef021 100644
--- a/benchmark/agbenchmark/reports/processing/report_types.py
+++ b/benchmark/agbenchmark/reports/processing/report_types.py
@@ -1,74 +1,87 @@
-from typing import Any, Dict, List, Union
+"""
+Model definitions used internally and for reports generated during command-line runs.
+"""
 
-from pydantic import BaseModel, Field
+from typing import Any, Dict, List
 
-datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$"
-from pydantic import BaseModel, constr
-
-
-class ForbidOptionalMeta(type(BaseModel)):  # metaclass to forbid optional fields
-    def __new__(cls, name: str, bases: tuple, dct: Dict[str, Any]) -> Any:
-        for attr_name, attr_value in dct.items():
-            if (
-                getattr(attr_value, "__origin__", None) == Union
-                and type(None) in attr_value.__args__
-            ):
-                raise TypeError(
-                    f"Optional fields are forbidden, but found in {attr_name}"
-                )
-
-        return super().__new__(cls, name, bases, dct)
+from pydantic import BaseModel, Field, constr, validator
 
+datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$"
+
 
-class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta):
-    class Config:
-        extra = "forbid"
+class TestResult(BaseModel):
+    """Result details for a single run of a test/challenge."""
+
+    success: bool | None = None
+    """Whether the run was successful"""
+    run_time: str | None = None
+    """The (formatted) duration of the run"""
+    fail_reason: str | None = None
+    """If applicable, the reason why the run was not successful"""
+    reached_cutoff: bool | None = None  # None if in progress
+    """Whether the run had to be stopped due to reaching the timeout"""
+    cost: float | None = None
+    """The (known) cost incurred by the run, e.g. from using paid LLM APIs"""
+
+    @validator("fail_reason")
+    def success_xor_fail_reason(cls, v: str | None, values: dict[str, Any]):
+        if v:
+            success = values["success"]
+            assert not success, "fail_reason must only be specified if success=False"
+        else:
+            assert values["success"], "fail_reason is required if success=False"
+        return v
+
 
+class TestMetrics(BaseModel):
+    """
+    Result metrics for a set of runs for a test/challenge. Should be an aggregate of all
+    results for the same test/challenge within a benchmarking session.
+    """
 
-class Metrics(BaseModelBenchmark):
-    difficulty: str
-    success: bool
-    success_percentage: float = Field(..., alias="success_%")
-    run_time: str
-    fail_reason: str | None
     attempted: bool
-    cost: float | None
+    """Whether the challenge was attempted during this session"""
+    is_regression: bool
+    """Whether the challenge was considered a regression test at the time of running"""
+    success_percentage: float | None = Field(default=None, alias="success_%")
+    """Success rate (0-100) for this challenge within the session"""
+
 
+class MetricsOverall(BaseModel):
+    """Global metrics concerning a benchmarking session"""
 
-class MetricsOverall(BaseModelBenchmark):
     run_time: str
+    """Duration from beginning to end of the session"""
     highest_difficulty: str
-    percentage: float | None
-    total_cost: float | None
+    """
+    Difficulty of the most difficult challenge that succeeded at least once this session
+    """
+    total_cost: float | None = None
+    """Total known cost of the session"""
 
 
-class Test(BaseModelBenchmark):
+class Test(BaseModel):
+    category: List[str]
+    difficulty: str | None
     data_path: str
-    is_regression: bool
-    answer: str
     description: str
-    metrics: Metrics
-    category: List[str]
     task: str
-    reached_cutoff: bool
+    answer: str
+    metrics: TestMetrics
+    results: list[TestResult]
+    metadata: dict[str, Any] | None = Field(default_factory=dict)
 
 
-class ReportBase(BaseModelBenchmark):
+class ReportBase(BaseModel):
     command: str
-    completion_time: str | None
+    completion_time: str | None = None
     benchmark_start_time: constr(regex=datetime_format)
     metrics: MetricsOverall
     config: Dict[str, str | dict[str, str]]
-    agent_git_commit_sha: str | None
-    benchmark_git_commit_sha: str | None
-    repo_url: str | None
+    agent_git_commit_sha: str | None = None
+    benchmark_git_commit_sha: str | None = None
+    repo_url: str | None = None
 
 
 class Report(ReportBase):
     tests: Dict[str, Test]
-
-
-class ReportV2(Test, ReportBase):
-    test_name: str
-    run_id: str | None
-    team_name: str | None
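For reference, a minimal usage sketch of the new TestResult model and its fail_reason validator. This is not part of the diff; it assumes pydantic v1 semantics (the @validator and constr(regex=...) API the models use) and that the benchmark/ directory is on the import path so the module resolves as agbenchmark.reports.processing.report_types.

# Usage sketch (illustrative only): the validator requires fail_reason when
# success=False and forbids it when success=True.
from pydantic import ValidationError

from agbenchmark.reports.processing.report_types import TestResult

# A failed run carries a fail_reason; this validates successfully.
failed_run = TestResult(
    success=False,
    fail_reason="Timed out before producing an answer",
    run_time="300.0 seconds",
    reached_cutoff=True,
    cost=0.03,
)

# A successful run must not carry a fail_reason; this raises ValidationError.
try:
    TestResult(success=True, fail_reason="should not be set", run_time="12.5 seconds")
except ValidationError as e:
    print(e)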