aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/challenges/base.py
blob: 64ead3a9df75b1e03db8a632c98933941ed4e1c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, ClassVar, Optional

import pytest
from agent_protocol_client import AgentApi, Step
from colorama import Fore, Style
from pydantic import BaseModel, Field

from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.data_types import Category, DifficultyLevel, EvalResult

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class ChallengeInfo(BaseModel):
    """
    Metadata describing a single benchmark challenge.

    Challenge classes expose an instance of this model through their `info` class
    attribute (see `BaseChallenge.info`).
    """

    # Evaluation identifier; defaults to an empty string.
    eval_id: str = ""
    # Name of the challenge, as used in console output and log messages.
    name: str
    # The task that is handed to the subject agent when the challenge is run.
    task: str
    # Optional artifacts directory that is passed along with the task.
    task_artifacts_dir: Optional[Path] = None
    # Categories this challenge belongs to.
    category: list[Category]
    difficulty: Optional[DifficultyLevel] = None
    description: Optional[str] = None
    # Presumably the names of challenges this one depends on -- TODO confirm.
    dependencies: list[str] = Field(default_factory=list)
    # Explicit `None` default, consistent with the other optional fields. Without it
    # the field only defaults to None under pydantic v1; pydantic v2 would treat a
    # bare `Optional[str]` annotation as a *required* (but nullable) field.
    reference_answer: Optional[str] = None

    source_uri: str
    """Internal reference indicating the source of the challenge specification"""


class BaseChallenge(ABC):
    """
    Abstract foundation for every concrete challenge implementation; it defines the
    interface that all of them share.
    """

    # Metadata of the challenge, attached at class level
    info: ClassVar[ChallengeInfo]

    @classmethod
    @abstractmethod
    def from_source_uri(cls, source_uri: str) -> type["BaseChallenge"]:
        """
        Build the challenge subclass that belongs to the given `source_uri`
        (the same kind of value as held by `ChallengeInfo.source_uri`).
        """

    @abstractmethod
    def test_method(
        self,
        config: AgentBenchmarkConfig,
        request: pytest.FixtureRequest,
        i_attempt: int,
    ) -> None:
        """
        Entry point for Pytest-driven benchmark sessions. A passing challenge is
        signalled by returning normally; a failing one by raising an error, ideally
        one with a descriptive message.
        """

    @classmethod
    async def run_challenge(
        cls, config: AgentBenchmarkConfig, timeout: int
    ) -> AsyncIterator[Step]:
        """
        Lets the subject agent work on this challenge's task, bounded by `timeout`,
        while writing basic challenge and progress information to STDOUT.

        Params:
            config: Benchmark config of the agent under test.
            timeout: Number of seconds after which an unfinished run is stopped.

        Yields:
            Step: Each step the agent produces while working on the task.
        """
        # imported here rather than at module level to avoid a circular import
        from agbenchmark.agent_api_interface import run_api_agent

        challenge = cls.info
        rule = "=" * 24

        # Banner with the challenge name, followed by the run parameters
        print()
        print(
            f"{Fore.MAGENTA}{Style.BRIGHT}{rule} "
            f"Starting {challenge.name} challenge "
            f"{rule}{Style.RESET_ALL}"
        )
        print(f"{Fore.CYAN}Timeout:{Fore.RESET} {timeout} seconds")
        print(f"{Fore.CYAN}Task:{Fore.RESET} {challenge.task}")
        print()

        logger.debug(f"Starting {challenge.name} challenge run")
        request_count = 0
        agent_steps = run_api_agent(
            challenge.task, config, timeout, challenge.task_artifacts_dir
        )
        async for step in agent_steps:
            # 1-based ordinal of the step, as shown in the progress line
            request_count += 1
            print(f"[{challenge.name}] - step {step.name} ({request_count}. request)")
            yield step
        logger.debug(f"Finished {challenge.name} challenge run")

    @classmethod
    @abstractmethod
    async def evaluate_task_state(
        cls, agent: AgentApi, task_id: str
    ) -> list[EvalResult]:
        """
        Evaluate the state of the task identified by `task_id`, using the given
        `agent` API client, and report the outcome as a list of `EvalResult`s.
        """