diff options
author | Reinier van der Leer <pwuts@agpt.co> | 2024-02-16 14:58:53 +0100 |
---|---|---|
committer | Reinier van der Leer <pwuts@agpt.co> | 2024-02-16 15:11:48 +0100 |
commit | 70e345b2cecddb85e6fe136d12ce263643830cb8 (patch) | |
tree | 95f847183fff3358c018baa6d56399b80e8bb83d /benchmark | |
parent | chore: Update `agbenchmark` dependency for agent and forge (diff) | |
download | Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.tar.gz Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.tar.bz2 Auto-GPT-70e345b2cecddb85e6fe136d12ce263643830cb8.zip |
refactor(benchmark): `load_webarena_challenges`
- Reduce duplicate and nested statements
- Add `skip_unavailable` parameter
Related changes:
- Add `available` and `unavailable_reason` attributes to `ChallengeInfo` and `WebArenaChallengeSpec`
- Add `pytest.skip` statement to `WebArenaChallenge.test_method` to make sure unavailable challenges are not run
Diffstat (limited to 'benchmark')
-rw-r--r-- | benchmark/agbenchmark/challenges/base.py | 3 | ||||
-rw-r--r-- | benchmark/agbenchmark/challenges/webarena.py | 62 |
2 files changed, 43 insertions, 22 deletions
diff --git a/benchmark/agbenchmark/challenges/base.py b/benchmark/agbenchmark/challenges/base.py index 4fe73a2d7..f77a08c65 100644 --- a/benchmark/agbenchmark/challenges/base.py +++ b/benchmark/agbenchmark/challenges/base.py @@ -28,6 +28,9 @@ class ChallengeInfo(BaseModel): source_uri: str """Internal reference indicating the source of the challenge specification""" + available: bool = True + unavailable_reason: str = "" + class BaseChallenge(ABC): """ diff --git a/benchmark/agbenchmark/challenges/webarena.py b/benchmark/agbenchmark/challenges/webarena.py index a11330c1d..09f801089 100644 --- a/benchmark/agbenchmark/challenges/webarena.py +++ b/benchmark/agbenchmark/challenges/webarena.py @@ -179,6 +179,9 @@ class WebArenaChallengeSpec(BaseModel): intent_template_id: int instantiation_dict: dict[str, str | list[str]] + available: bool = True + unavailable_reason: str = "" + class EvalSet(BaseModel): class StringMatchEvalSet(BaseModel): exact_match: str | None @@ -288,6 +291,8 @@ class WebArenaChallenge(BaseChallenge): ], # TODO: make categories more specific reference_answer=spec.eval.reference_answer_raw_annotation, source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id), + available=spec.available, + unavailable_reason=spec.unavailable_reason, ) return type( f"Test{challenge_info.name}", @@ -362,6 +367,9 @@ class WebArenaChallenge(BaseChallenge): request: pytest.FixtureRequest, i_attempt: int = 0, ) -> None: + if not self._spec.available: + pytest.skip(self._spec.unavailable_reason) + # if os.environ.get("HELICONE_API_KEY"): # from helicone.lock import HeliconeLockManager @@ -426,11 +434,13 @@ class WebArenaChallenge(BaseChallenge): ) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results) -def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]: +def load_webarena_challenges( + skip_unavailable: bool = True +) -> Iterator[type[WebArenaChallenge]]: logger.info("Loading WebArena challenges...") for site, info in site_info_map.items(): - if not info.available: + if not info.available and skip_unavailable: logger.warning( f"JungleGym site '{site}' is not available: {info.unavailable_reason} " "Skipping all challenges which use this site." @@ -457,30 +467,38 @@ def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]: for entry in challenge_dicts: try: challenge_spec = WebArenaChallengeSpec.parse_obj(entry) - for site in challenge_spec.sites: - site_info = site_info_map.get(site) - if site_info is None: - logger.warning( - f"WebArena task {challenge_spec.task_id} requires unknown site " - f"'{site}'; skipping..." - ) - break - if not site_info.available: - logger.debug( - f"WebArena task {challenge_spec.task_id} requires unavailable " - f"site '{site}'; skipping..." - ) - break - else: - yield WebArenaChallenge.from_challenge_spec(challenge_spec) - loaded += 1 - continue - skipped += 1 except ValidationError as e: failed += 1 logger.warning(f"Error validating WebArena challenge entry: {entry}") logger.warning(f"Error details: {e}") + continue + + # Check all required sites for availability + for site in challenge_spec.sites: + site_info = site_info_map.get(site) + if site_info is None: + challenge_spec.available = False + challenge_spec.unavailable_reason = ( + f"WebArena task {challenge_spec.task_id} requires unknown site " + f"'{site}'" + ) + elif not site_info.available: + challenge_spec.available = False + challenge_spec.unavailable_reason = ( + f"WebArena task {challenge_spec.task_id} requires unavailable " + f"site '{site}'" + ) + + if not challenge_spec.available and skip_unavailable: + logger.debug(f"{challenge_spec.unavailable_reason}; skipping...") + skipped += 1 + continue + + yield WebArenaChallenge.from_challenge_spec(challenge_spec) + loaded += 1 + logger.info( "Loading WebArena challenges complete: " - f"loaded {loaded}, skipped {skipped}. {failed} challenge failed to load." + f"loaded {loaded}, skipped {skipped}." + + (f" {failed} challenges failed to load." if failed else "") ) |