Diffstat (limited to 'benchmark/agbenchmark/generate_test.py')
-rw-r--r-- | benchmark/agbenchmark/generate_test.py | 235
1 file changed, 18 insertions, 217 deletions
diff --git a/benchmark/agbenchmark/generate_test.py b/benchmark/agbenchmark/generate_test.py
index 544e09ee7..5bc41971e 100644
--- a/benchmark/agbenchmark/generate_test.py
+++ b/benchmark/agbenchmark/generate_test.py
@@ -1,225 +1,26 @@
-import glob
-import importlib
-import json
-import os
-import sys
-import types
-from collections import deque
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
-
-import pytest
-
-from agbenchmark.__main__ import CHALLENGES_ALREADY_BEATEN
-from agbenchmark.agent_api_interface import append_updates_file
-from agbenchmark.agent_protocol_client.models.step import Step
-from agbenchmark.utils.challenge import Challenge
-from agbenchmark.utils.data_types import AgentBenchmarkConfig, ChallengeData
-
-DATA_CATEGORY = {}
-
-
-def create_single_test(
-    data: Dict[str, Any] | ChallengeData,
-    challenge_location: str,
-    file_datum: Optional[list[dict[str, Any]]] = None,
-) -> None:
-    challenge_data = None
-    artifacts_location = None
-    if isinstance(data, ChallengeData):
-        challenge_data = data
-        data = data.get_data()
-
-    DATA_CATEGORY[data["name"]] = data["category"][0]
-
-    # Define test class dynamically
-    challenge_class = types.new_class(f"Test{data['name']}", (Challenge,))
-    print(f"challenge_class: {challenge_class}")
-    # clean_challenge_location = get_test_path(challenge_location)
-    setattr(challenge_class, "CHALLENGE_LOCATION", challenge_location)
-
-    setattr(
-        challenge_class,
-        "ARTIFACTS_LOCATION",
-        artifacts_location or str(Path(challenge_location).resolve().parent),
-    )
-
-    # Define test method within the dynamically created class
-    @pytest.mark.asyncio
-    async def test_method(self, config: Dict[str, Any], request) -> None:  # type: ignore
-        # create a random number between 0 and 1
-        test_name = self.data.name
-
-        try:
-            with open(CHALLENGES_ALREADY_BEATEN, "r") as f:
-                challenges_beaten_in_the_past = json.load(f)
-        except:
-            challenges_beaten_in_the_past = {}
+"""
+AGBenchmark's test discovery endpoint for Pytest.
 
-        if request.config.getoption("--explore") and challenges_beaten_in_the_past.get(
-            test_name, False
-        ):
-            return None
+This module is picked up by Pytest's *_test.py file matching pattern, and all challenge
+classes in the module that conform to the `Test*` pattern are collected.
+"""
 
-        # skip optional categories
-        self.skip_optional_categories(config)
-
-        from helicone.lock import HeliconeLockManager
-
-        if os.environ.get("HELICONE_API_KEY"):
-            HeliconeLockManager.write_custom_property("challenge", self.data.name)
-
-        cutoff = self.data.cutoff or 60
-
-        timeout = cutoff
-        if "--nc" in sys.argv:
-            timeout = 100000
-        if "--cutoff" in sys.argv:
-            timeout = int(sys.argv[sys.argv.index("--cutoff") + 1])
-
-        await self.setup_challenge(config, timeout)
-
-        scores = self.get_scores(config)
-        request.node.answers = (
-            scores["answers"] if "--keep-answers" in sys.argv else None
-        )
-        del scores["answers"]  # remove answers from scores
-        request.node.scores = scores  # store scores in request.node
-        is_score_100 = 1 in scores["values"]
-
-        evaluation = "Correct!" if is_score_100 else "Incorrect."
-        eval_step = Step(
-            input=evaluation,
-            additional_input=None,
-            task_id="irrelevant, this step is a hack",
-            step_id="irrelevant, this step is a hack",
-            name="",
-            status="created",
-            output=None,
-            additional_output=None,
-            artifacts=[],
-            is_last=True,
-        )
-        await append_updates_file(eval_step)
+import importlib
+import logging
+from itertools import chain
 
-        assert is_score_100
+from agbenchmark.challenges.builtin import load_builtin_challenges
+from agbenchmark.challenges.webarena import load_webarena_challenges
 
-    # Parametrize the method here
-    test_method = pytest.mark.parametrize(
-        "challenge_data",
-        [data],
-        indirect=True,
-    )(test_method)
+logger = logging.getLogger(__name__)
 
-    setattr(challenge_class, "test_method", test_method)
-    print(f"Challenge Class {challenge_class}")
+DATA_CATEGORY = {}
 
-    # Attach the new class to a module so it can be discovered by pytest
+# Load challenges and attach them to this module
+for challenge in chain(load_builtin_challenges(), load_webarena_challenges()):
+    # Attach the Challenge class to this module so it can be discovered by pytest
     module = importlib.import_module(__name__)
-    setattr(module, f"Test{data['name']}", challenge_class)
-    return challenge_class
-
-
-def create_single_suite_challenge(challenge_data: ChallengeData, path: Path) -> None:
-    create_single_test(challenge_data, str(path))
-
-
-def create_challenge(
-    data: Dict[str, Any],
-    json_file: str,
-    json_files: deque,
-) -> Union[deque, Any]:
-    path = Path(json_file).resolve()
-    print("Creating challenge for", path)
-
-    challenge_class = create_single_test(data, str(path))
-    print("Creation complete for", path)
-
-    return json_files, challenge_class
-
-
-def generate_tests() -> None:  # sourcery skip: invert-any-all
-    print("Generating tests...")
-
-    challenges_path = os.path.join(os.path.dirname(__file__), "challenges")
-    print(f"Looking for challenges in {challenges_path}...")
-
-    json_files = deque(
-        glob.glob(
-            f"{challenges_path}/**/data.json",
-            recursive=True,
-        )
-    )
-
-    print(f"Found {len(json_files)} challenges.")
-    print(f"Sample path: {json_files[0]}")
-
-    agent_benchmark_config_path = str(Path.cwd() / "agbenchmark_config" / "config.json")
-    try:
-        with open(agent_benchmark_config_path, "r") as f:
-            agent_benchmark_config = AgentBenchmarkConfig(**json.load(f))
-            agent_benchmark_config.agent_benchmark_config_path = (
-                agent_benchmark_config_path
-            )
-    except json.JSONDecodeError:
-        print("Error: benchmark_config.json is not a valid JSON file.")
-        raise
-
-    regression_reports_path = agent_benchmark_config.get_regression_reports_path()
-    if regression_reports_path and os.path.exists(regression_reports_path):
-        with open(regression_reports_path, "r") as f:
-            regression_tests = json.load(f)
-    else:
-        regression_tests = {}
-
-    while json_files:
-        json_file = (
-            json_files.popleft()
-        )  # Take and remove the first element from json_files
-        if challenge_should_be_ignored(json_file):
-            continue
-
-        data = ChallengeData.get_json_from_path(json_file)
-
-        commands = sys.argv
-        # --by flag
-        if "--category" in commands:
-            categories = data.get("category", [])
-            commands_set = set(commands)
-
-            # Convert the combined list to a set
-            categories_set = set(categories)
-
-            # If there's no overlap with commands
-            if not categories_set.intersection(commands_set):
-                continue
-
-        # --test flag, only run the test if it's the exact one specified
-        tests = []
-        for command in commands:
-            if command.startswith("--test="):
-                tests.append(command.split("=")[1])
-
-        if tests and data["name"] not in tests:
-            continue
-
-        # --maintain and --improve flag
-        in_regression = regression_tests.get(data["name"], None)
-        improve_flag = in_regression and "--improve" in commands
-        maintain_flag = not in_regression and "--maintain" in commands
-        if "--maintain" in commands and maintain_flag:
-            continue
-        elif "--improve" in commands and improve_flag:
-            continue
-        json_files, challenge_class = create_challenge(data, json_file, json_files)
-
-        print(f"Generated test for {data['name']}.")
-        print(f"- {data}")
-    print("Test generation complete.")
-
-
-def challenge_should_be_ignored(json_file):
-    return "challenges/deprecated" in json_file or "challenges/library" in json_file
-
+    setattr(module, challenge.__name__, challenge)
 
-generate_tests()
+    # Build a map of challenge names and their primary category
+    DATA_CATEGORY[challenge.info.name] = challenge.info.category[0].value