aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/generate_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'benchmark/agbenchmark/generate_test.py')
-rw-r--r--benchmark/agbenchmark/generate_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/generate_test.py b/benchmark/agbenchmark/generate_test.py
new file mode 100644
index 000000000..5bc41971e
--- /dev/null
+++ b/benchmark/agbenchmark/generate_test.py
@@ -0,0 +1,26 @@
+"""
+AGBenchmark's test discovery endpoint for Pytest.
+
+This module is picked up by Pytest's *_test.py file matching pattern, and all challenge
+classes in the module that conform to the `Test*` pattern are collected.
+"""
+
+import importlib
+import logging
+from itertools import chain
+
+from agbenchmark.challenges.builtin import load_builtin_challenges
+from agbenchmark.challenges.webarena import load_webarena_challenges
+
+logger = logging.getLogger(__name__)
+
+DATA_CATEGORY = {}
+
+# Load challenges and attach them to this module
+for challenge in chain(load_builtin_challenges(), load_webarena_challenges()):
+ # Attach the Challenge class to this module so it can be discovered by pytest
+ module = importlib.import_module(__name__)
+ setattr(module, challenge.__name__, challenge)
+
+ # Build a map of challenge names and their primary category
+ DATA_CATEGORY[challenge.info.name] = challenge.info.category[0].value