about | summary | refs | log | tree | commit | diff
path: root/benchmark
diff options
context:
space:
mode:
author: Reinier van der Leer <pwuts@agpt.co> 2024-02-16 14:58:53 +0100
committer: Reinier van der Leer <pwuts@agpt.co> 2024-02-16 15:11:48 +0100
commit: 70e345b2cecddb85e6fe136d12ce263643830cb8 (patch)
tree: 95f847183fff3358c018baa6d56399b80e8bb83d /benchmark
parent: chore: Update `agbenchmark` dependency for agent and forge (diff)
download: Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.tar.gz
Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.tar.bz2
Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.zip
refactor(benchmark): `load_webarena_challenges`
- Reduce duplicate and nested statements
- Add `skip_unavailable` parameter

Related changes:
- Add `available` and `unavailable_reason` attributes to `ChallengeInfo` and `WebArenaChallengeSpec`
- Add `pytest.skip` statement to `WebArenaChallenge.test_method` to make sure unavailable challenges are not run
Diffstat (limited to 'benchmark')
-rw-r--r-- benchmark/agbenchmark/challenges/base.py | 3
-rw-r--r-- benchmark/agbenchmark/challenges/webarena.py | 62
2 files changed, 43 insertions, 22 deletions
diff --git a/benchmark/agbenchmark/challenges/base.py b/benchmark/agbenchmark/challenges/base.py
index 4fe73a2d7..f77a08c65 100644
--- a/benchmark/agbenchmark/challenges/base.py
+++ b/benchmark/agbenchmark/challenges/base.py
@@ -28,6 +28,9 @@ class ChallengeInfo(BaseModel):
source_uri: str
"""Internal reference indicating the source of the challenge specification"""
+ available: bool = True
+ unavailable_reason: str = ""
+
class BaseChallenge(ABC):
"""
diff --git a/benchmark/agbenchmark/challenges/webarena.py b/benchmark/agbenchmark/challenges/webarena.py
index a11330c1d..09f801089 100644
--- a/benchmark/agbenchmark/challenges/webarena.py
+++ b/benchmark/agbenchmark/challenges/webarena.py
@@ -179,6 +179,9 @@ class WebArenaChallengeSpec(BaseModel):
intent_template_id: int
instantiation_dict: dict[str, str | list[str]]
+ available: bool = True
+ unavailable_reason: str = ""
+
class EvalSet(BaseModel):
class StringMatchEvalSet(BaseModel):
exact_match: str | None
@@ -288,6 +291,8 @@ class WebArenaChallenge(BaseChallenge):
], # TODO: make categories more specific
reference_answer=spec.eval.reference_answer_raw_annotation,
source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id),
+ available=spec.available,
+ unavailable_reason=spec.unavailable_reason,
)
return type(
f"Test{challenge_info.name}",
@@ -362,6 +367,9 @@ class WebArenaChallenge(BaseChallenge):
request: pytest.FixtureRequest,
i_attempt: int = 0,
) -> None:
+ if not self._spec.available:
+ pytest.skip(self._spec.unavailable_reason)
+
# if os.environ.get("HELICONE_API_KEY"):
# from helicone.lock import HeliconeLockManager
@@ -426,11 +434,13 @@ class WebArenaChallenge(BaseChallenge):
) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results)
-def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
+def load_webarena_challenges(
+ skip_unavailable: bool = True
+) -> Iterator[type[WebArenaChallenge]]:
logger.info("Loading WebArena challenges...")
for site, info in site_info_map.items():
- if not info.available:
+ if not info.available and skip_unavailable:
logger.warning(
f"JungleGym site '{site}' is not available: {info.unavailable_reason} "
"Skipping all challenges which use this site."
@@ -457,30 +467,38 @@ def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
for entry in challenge_dicts:
try:
challenge_spec = WebArenaChallengeSpec.parse_obj(entry)
- for site in challenge_spec.sites:
- site_info = site_info_map.get(site)
- if site_info is None:
- logger.warning(
- f"WebArena task {challenge_spec.task_id} requires unknown site "
- f"'{site}'; skipping..."
- )
- break
- if not site_info.available:
- logger.debug(
- f"WebArena task {challenge_spec.task_id} requires unavailable "
- f"site '{site}'; skipping..."
- )
- break
- else:
- yield WebArenaChallenge.from_challenge_spec(challenge_spec)
- loaded += 1
- continue
- skipped += 1
except ValidationError as e:
failed += 1
logger.warning(f"Error validating WebArena challenge entry: {entry}")
logger.warning(f"Error details: {e}")
+ continue
+
+ # Check all required sites for availability
+ for site in challenge_spec.sites:
+ site_info = site_info_map.get(site)
+ if site_info is None:
+ challenge_spec.available = False
+ challenge_spec.unavailable_reason = (
+ f"WebArena task {challenge_spec.task_id} requires unknown site "
+ f"'{site}'"
+ )
+ elif not site_info.available:
+ challenge_spec.available = False
+ challenge_spec.unavailable_reason = (
+ f"WebArena task {challenge_spec.task_id} requires unavailable "
+ f"site '{site}'"
+ )
+
+ if not challenge_spec.available and skip_unavailable:
+ logger.debug(f"{challenge_spec.unavailable_reason}; skipping...")
+ skipped += 1
+ continue
+
+ yield WebArenaChallenge.from_challenge_spec(challenge_spec)
+ loaded += 1
+
logger.info(
"Loading WebArena challenges complete: "
- f"loaded {loaded}, skipped {skipped}. {failed} challenge failed to load."
+ f"loaded {loaded}, skipped {skipped}."
+ + (f" {failed} challenges failed to load." if failed else "")
)