Diffstat (limited to 'benchmark/agbenchmark/challenges/base.py')
-rw-r--r-- | benchmark/agbenchmark/challenges/base.py | 104 |
1 file changed, 104 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/challenges/base.py b/benchmark/agbenchmark/challenges/base.py
new file mode 100644
index 000000000..64ead3a9d
--- /dev/null
+++ b/benchmark/agbenchmark/challenges/base.py
@@ -0,0 +1,104 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import AsyncIterator, ClassVar, Optional
+
+import pytest
+from agent_protocol_client import AgentApi, Step
+from colorama import Fore, Style
+from pydantic import BaseModel, Field
+
+from agbenchmark.config import AgentBenchmarkConfig
+from agbenchmark.utils.data_types import Category, DifficultyLevel, EvalResult
+
+logger = logging.getLogger(__name__)
+
+
+class ChallengeInfo(BaseModel):
+    eval_id: str = ""
+    name: str
+    task: str
+    task_artifacts_dir: Optional[Path] = None
+    category: list[Category]
+    difficulty: Optional[DifficultyLevel] = None
+    description: Optional[str] = None
+    dependencies: list[str] = Field(default_factory=list)
+    reference_answer: Optional[str]
+
+    source_uri: str
+    """Internal reference indicating the source of the challenge specification"""
+
+
+class BaseChallenge(ABC):
+    """
+    The base class and shared interface for all specific challenge implementations.
+    """
+
+    info: ClassVar[ChallengeInfo]
+
+    @classmethod
+    @abstractmethod
+    def from_source_uri(cls, source_uri: str) -> type["BaseChallenge"]:
+        """
+        Construct an individual challenge subclass from a suitable `source_uri` (as in
+        `ChallengeInfo.source_uri`).
+        """
+        ...
+
+    @abstractmethod
+    def test_method(
+        self,
+        config: AgentBenchmarkConfig,
+        request: pytest.FixtureRequest,
+        i_attempt: int,
+    ) -> None:
+        """
+        Test method for use by Pytest-based benchmark sessions. Should return normally
+        if the challenge passes, and raise a (preferably descriptive) error otherwise.
+        """
+        ...
+
+    @classmethod
+    async def run_challenge(
+        cls, config: AgentBenchmarkConfig, timeout: int
+    ) -> AsyncIterator[Step]:
+        """
+        Runs the challenge on the subject agent with the specified timeout.
+        Also prints basic challenge and status info to STDOUT.
+
+        Params:
+            config: The subject agent's benchmark config.
+            timeout: Timeout (seconds) after which to stop the run if not finished.
+
+        Yields:
+            Step: The steps generated by the agent for the challenge task.
+        """
+        # avoid circular import
+        from agbenchmark.agent_api_interface import run_api_agent
+
+        print()
+        print(
+            f"{Fore.MAGENTA + Style.BRIGHT}{'='*24} "
+            f"Starting {cls.info.name} challenge"
+            f" {'='*24}{Style.RESET_ALL}"
+        )
+        print(f"{Fore.CYAN}Timeout:{Fore.RESET} {timeout} seconds")
+        print(f"{Fore.CYAN}Task:{Fore.RESET} {cls.info.task}")
+
+        print()
+        logger.debug(f"Starting {cls.info.name} challenge run")
+        i = 0
+        async for step in run_api_agent(
+            cls.info.task, config, timeout, cls.info.task_artifacts_dir
+        ):
+            i += 1
+            print(f"[{cls.info.name}] - step {step.name} ({i}. request)")
+            yield step
+        logger.debug(f"Finished {cls.info.name} challenge run")
+
+    @classmethod
+    @abstractmethod
+    async def evaluate_task_state(
+        cls, agent: AgentApi, task_id: str
+    ) -> list[EvalResult]:
+        ...
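
For orientation, below is a minimal sketch of what a concrete subclass of this new base class might look like. It is not part of this commit: the class name ExampleChallenge, the task text, the Category.GENERALIST value, the placeholder source_uri, and the trivial pass/evaluation logic are all illustrative assumptions; real challenge implementations in agbenchmark build their metadata from the spec referenced by source_uri.

import asyncio
from typing import ClassVar

import pytest
from agent_protocol_client import AgentApi

from agbenchmark.challenges.base import BaseChallenge, ChallengeInfo
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.data_types import Category, EvalResult


class ExampleChallenge(BaseChallenge):
    # Hypothetical metadata; a real challenge derives this from its spec file.
    info: ClassVar[ChallengeInfo] = ChallengeInfo(
        name="ExampleChallenge",
        task="Write the word 'hello' to a file named output.txt.",
        category=[Category.GENERALIST],  # assumed enum member
        reference_answer="hello",
        source_uri="example://challenges/example",  # placeholder URI
    )

    @classmethod
    def from_source_uri(cls, source_uri: str) -> type["BaseChallenge"]:
        # A real implementation would parse `source_uri`, load the challenge
        # spec it points to, and construct a subclass from that spec.
        return cls

    def test_method(
        self,
        config: AgentBenchmarkConfig,
        request: pytest.FixtureRequest,
        i_attempt: int,
    ) -> None:
        async def run() -> int:
            steps = 0
            # Drive the agent through the task via the shared runner.
            async for _ in self.run_challenge(config, timeout=60):
                steps += 1
            return steps

        # Trivial pass condition for this sketch: the agent produced output.
        assert asyncio.run(run()) > 0, "agent produced no steps"

    @classmethod
    async def evaluate_task_state(
        cls, agent: AgentApi, task_id: str
    ) -> list[EvalResult]:
        # Placeholder: a real challenge would inspect the task's artifacts or
        # step output via `agent` and score it against `info.reference_answer`.
        return []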