Diffstat (limited to 'benchmark/agbenchmark/challenges/webarena.py')
-rw-r--r-- | benchmark/agbenchmark/challenges/webarena.py | 478
1 files changed, 478 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/challenges/webarena.py b/benchmark/agbenchmark/challenges/webarena.py
new file mode 100644
index 000000000..24f569327
--- /dev/null
+++ b/benchmark/agbenchmark/challenges/webarena.py
@@ -0,0 +1,478 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import ClassVar, Iterator, Literal
+
+import pytest
+import requests
+from agent_protocol_client import AgentApi, Step
+from pydantic import BaseModel, validator, ValidationError
+
+from agbenchmark.config import AgentBenchmarkConfig
+from agbenchmark.utils.data_types import Category, EvalResult
+
+from .base import BaseChallenge, ChallengeInfo
+
+logger = logging.getLogger(__name__)
+
+
+EvalType = Literal["string_match", "url_match", "program_html"]
+WebArenaSite = Literal[
+    "gitlab", "map", "reddit", "shopping", "shopping_admin", "wikipedia"
+]
+ReferenceAnswerType = Literal["exact_match", "fuzzy_match", "must_include"]
+
+
+class WebArenaSiteInfo(BaseModel):
+    base_url: str
+    available: bool = True
+    additional_info: str = ""
+    unavailable_reason: str = ""
+
+
+_git_user, _git_password = os.getenv("WEBARENA_GIT_CREDENTIALS", ":").split(":", 1)
+
+site_info_map: dict[WebArenaSite, WebArenaSiteInfo] = {
+    "gitlab": WebArenaSiteInfo(
+        base_url="http://git.junglegym.ai",
+        available=bool(_git_user and _git_password),
+        additional_info=(
+            f"To log in, use the username '{_git_user}' and password '{_git_password}'."
+        ),
+        unavailable_reason=(
+            "WEBARENA_GIT_CREDENTIALS not set (correctly): "
+            f"'{os.getenv('WEBARENA_GIT_CREDENTIALS', '')}', "
+            "should be USERNAME:PASSWORD."
+        ),
+    ),
+    "map": WebArenaSiteInfo(
+        base_url="http://ec2-3-131-244-37.us-east-2.compute.amazonaws.com:3000/"
+    ),
+    "reddit": WebArenaSiteInfo(base_url="http://forum.junglegym.ai"),
+    "shopping": WebArenaSiteInfo(base_url="http://shop.junglegym.ai"),
+    "shopping_admin": WebArenaSiteInfo(
+        base_url="http://cms.junglegym.ai/admin",
+        additional_info="To log in, use the username 'admin' and password 'admin1234'.",
+    ),
+    "wikipedia": WebArenaSiteInfo(base_url="http://wiki.junglegym.ai"),
+}
+
+
+def get_site_url(site: WebArenaSite) -> str:
+    if site not in site_info_map:
+        raise ValueError(f"JungleGym site '{site}' unknown, cannot resolve URL")
+    return site_info_map[site].base_url
+
+
+def resolve_uri(uri: str) -> str:
+    """
+    Resolves URIs with mock hosts, like `__WIKI__/wiki/Octopus`, with the corresponding
+    JungleGym site mirror host.
+    """
+    segments = uri.split("__")
+    if len(segments) > 2 and (site := segments[1]).lower() in site_info_map:
+        return uri.replace(f"__{site}__", get_site_url(site.lower()))  # type: ignore
+    return uri
+
+
+class Eval(ABC):
+    @abstractmethod
+    def evaluate(self, string: str) -> bool:
+        ...
+
+    @property
+    @abstractmethod
+    def description(self) -> str:
+        ...
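+
+
+# The concrete evals below implement this interface; a string eval, for
+# example, judges the agent's final answer against a reference value:
+#
+#   >>> ExactStringMatchEval(reference_answer="42").evaluate("42")
+#   True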
+
+
+class StringEval(BaseModel, Eval):
+    type: ReferenceAnswerType
+
+
+class ExactStringMatchEval(StringEval):
+    type: Literal["exact_match"] = "exact_match"
+    reference_answer: str
+
+    @property
+    def description(self) -> str:
+        return f"Answer must be '{self.reference_answer}'"
+
+    def evaluate(self, string: str) -> bool:
+        return string == self.reference_answer
+
+
+class FuzzyStringMatchEval(StringEval):
+    type: Literal["fuzzy_match"] = "fuzzy_match"
+    reference_answer: str
+
+    @property
+    def description(self) -> str:
+        return f"Answer must contain something like '{self.reference_answer}'"
+
+    def evaluate(self, string: str) -> bool:
+        # TODO: use LLM for matching (or something else that's flexible/robust)
+        return self.reference_answer.lower() in string.lower()
+
+
+class MustIncludeStringEval(StringEval):
+    type: Literal["must_include"] = "must_include"
+    reference_answer: str
+
+    @property
+    def description(self) -> str:
+        return f"Answer must include '{self.reference_answer}'"
+
+    def evaluate(self, string: str) -> bool:
+        return self.reference_answer.lower() in string.lower()
+
+
+class UrlMatchEval(BaseModel, Eval):
+    url: str
+    """Example: `"__WIKI__/wiki/Octopus"`"""
+
+    @property
+    def description(self) -> str:
+        return f"Agent must navigate to '{self.url}'"
+
+    def evaluate(self, url: str) -> bool:
+        return url == resolve_uri(self.url)
+
+
+class ProgramHtmlEval(BaseModel):
+    url: str
+    locator: str
+    """JavaScript code that returns the value to check"""
+    required_contents: str
+
+    @property
+    def description(self) -> str:
+        return (
+            f"On the webpage {self.url}, "
+            f"`{self.locator}` should contain '{self.required_contents}'"
+        )
+
+    def evaluate(self, selenium_instance) -> bool:
+        result = selenium_instance.execute_script(
+            self.locator or "return document.body.innerHTML;"
+        )
+        return self.required_contents in result
+
+
+_Eval = StringEval | UrlMatchEval | ProgramHtmlEval
+
+
+class WebArenaChallengeSpec(BaseModel):
+    task_id: int
+    sites: list[WebArenaSite]
+    """The sites needed to complete the task"""
+    start_url: str
+    """The full URL at which to start"""
+    start_url_junglegym: str
+    """The JungleGym site (base URL) at which to start"""
+    require_login: bool
+    require_reset: bool
+    storage_state: str | None
+
+    intent: str
+    intent_template: str
+    intent_template_id: int
+    instantiation_dict: dict[str, str | list[str]]
+
+    class EvalSet(BaseModel):
+        class StringMatchEvalSet(BaseModel):
+            exact_match: str | None
+            fuzzy_match: list[str] | None
+            must_include: list[str] | None
+
+        reference_answers: StringMatchEvalSet | None
+        """For string_match eval, a set of criteria to judge the final answer"""
+        reference_answer_raw_annotation: str | None
+        string_note: str | None
+        annotation_note: str | None
+
+        reference_url: str | None
+        """For url_match eval, the last URL that should be visited"""
+        url_note: str | None
+
+        program_html: list[ProgramHtmlEval]
+        """For program_html eval, a list of criteria to judge the site state by"""
+
+        eval_types: list[EvalType]
+
+        @validator("eval_types")
+        def check_eval_parameters(cls, v: list[EvalType], values):
+            if "string_match" in v and not values.get("reference_answers"):
+                raise ValueError("'string_match' eval_type requires reference_answers")
+            if "url_match" in v and not values.get("reference_url"):
+                raise ValueError("'url_match' eval_type requires reference_url")
+            if "program_html" in v and not values.get("program_html"):
+                raise ValueError(
+                    "'program_html' eval_type requires at least one program_html eval"
+                )
+            return v
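+
+        # `evaluators` maps this declarative eval set onto concrete Eval
+        # objects; e.g. reference_answers.must_include=["Octopus"] yields one
+        # MustIncludeStringEval(reference_answer="Octopus") per entry.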
+
+        @property
+        def evaluators(self) -> list[_Eval]:
+            evaluators: list[_Eval] = []
+            if self.reference_answers:
+                if self.reference_answers.exact_match:
+                    evaluators.append(
+                        ExactStringMatchEval(
+                            reference_answer=self.reference_answers.exact_match
+                        )
+                    )
+                if self.reference_answers.fuzzy_match:
+                    evaluators.extend(
+                        FuzzyStringMatchEval(reference_answer=a)
+                        for a in self.reference_answers.fuzzy_match
+                    )
+                if self.reference_answers.must_include:
+                    evaluators.extend(
+                        MustIncludeStringEval(reference_answer=a)
+                        for a in self.reference_answers.must_include
+                    )
+            if self.reference_url:
+                evaluators.append(UrlMatchEval(url=self.reference_url))
+            evaluators.extend(self.program_html)
+            return evaluators
+
+    eval: EvalSet
+    """Evaluation criteria by which to judge the agent's performance"""
+
+    @property
+    def assignment_for_agent(self):
+        sites = [get_site_url(s) for s in self.sites]
+        nav_constraint = (
+            f"You are ONLY allowed to access URLs in {' and '.join(sites)}."
+        )
+
+        return (
+            f"First of all, go to {self.start_url}. "
+            f"{self.intent.rstrip('.')}.\n"
+            f"{nav_constraint}"
+        )
+
+
+class WebArenaChallenge(BaseChallenge):
+    _spec: ClassVar[WebArenaChallengeSpec]
+
+    SOURCE_URI_PREFIX = "__JUNGLEGYM__/webarena/tasks/"
+    SOURCE_URI_TEMPLATE = f"{SOURCE_URI_PREFIX}{{task_id}}"
+
+    @classmethod
+    def from_source_uri(cls, source_uri: str) -> type["WebArenaChallenge"]:
+        if not source_uri.startswith(cls.SOURCE_URI_PREFIX):
+            raise ValueError(f"Invalid source_uri for WebArenaChallenge: {source_uri}")
+
+        source_url = source_uri.replace(
+            cls.SOURCE_URI_PREFIX,
+            "https://api.junglegym.ai/get_webarena_by_task_id?task_id=",
+        )
+        results = requests.get(source_url).json()["data"]
+        if not results:
+            raise ValueError(f"Could not fetch challenge {source_uri}")
+        return cls.from_challenge_spec(WebArenaChallengeSpec.parse_obj(results[0]))
+
+    @classmethod
+    def from_challenge_spec(
+        cls, spec: WebArenaChallengeSpec
+    ) -> type["WebArenaChallenge"]:
+        challenge_info = ChallengeInfo(
+            eval_id=f"junglegym-webarena-{spec.task_id}",
+            name=f"WebArenaTask_{spec.task_id}",
+            task=spec.assignment_for_agent,
+            category=[
+                Category.GENERALIST,
+                Category.WEB,
+            ],  # TODO: make categories more specific
+            reference_answer=spec.eval.reference_answer_raw_annotation,
+            source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id),
+        )
+        return type(
+            f"Test{challenge_info.name}",
+            (WebArenaChallenge,),
+            {
+                "info": challenge_info,
+                "_spec": spec,
+            },
+        )
+
+    @classmethod
+    def evaluate_answer(cls, answer: str) -> list[tuple[_Eval, EvalResult]]:
+        results: list[tuple[_Eval, EvalResult]] = []
+        for evaluator in cls._spec.eval.evaluators:
+            if isinstance(evaluator, StringEval):  # string_match
+                passed = evaluator.evaluate(answer)
+                results.append(
+                    (
+                        evaluator,
+                        EvalResult(
+                            result=answer,
+                            result_source="step_output",
+                            score=passed,
+                            passed=passed,
+                        ),
+                    )
+                )
+        return results
+
+    @classmethod
+    def evaluate_step_result(cls, step: Step) -> list[tuple[_Eval, EvalResult]]:
+        assert step.output
+        eval_results = cls.evaluate_answer(step.output)
+        for evaluator in cls._spec.eval.evaluators:
+            if isinstance(evaluator, UrlMatchEval):
+                # HACK: url_match bodge
+                passed = resolve_uri(evaluator.url) in step.output
+                eval_results.append(
+                    (
+                        evaluator,
+                        EvalResult(
+                            result=step.output,
+                            result_source="step_output",
+                            score=1.0 if passed else 0.0,
+                            passed=passed,
+                        ),
+                    )
+                )
+        # TODO: add support for program_html evals
+        return eval_results
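+
+    # Evaluating a whole task yields a matrix of (Eval, EvalResult) pairs:
+    # one row per step, one column per evaluator. `zip(*rows)` transposes it
+    # so the highest-scoring result per evaluator can be picked across steps.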
+
+    @classmethod
+    async def evaluate_task_state(
+        cls, agent: AgentApi, task_id: str
+    ) -> list[EvalResult]:
+        steps: list[Step] = (await agent.list_agent_task_steps(task_id)).steps
+
+        eval_results_per_step = [cls.evaluate_step_result(step) for step in steps]
+        # Get the column aggregate (highest scored EvalResult for each Eval)
+        # from the matrix of EvalResults per step.
+        return [
+            max(step_results_for_eval, key=lambda r: r[1].score)[1]
+            for step_results_for_eval in zip(*eval_results_per_step)
+        ]
+
+    @pytest.mark.asyncio
+    async def test_method(
+        self,
+        config: AgentBenchmarkConfig,
+        request: pytest.FixtureRequest,
+        i_attempt: int,
+    ) -> None:
+        if os.environ.get("HELICONE_API_KEY"):
+            from helicone.lock import HeliconeLockManager
+
+            HeliconeLockManager.write_custom_property("challenge", self.info.name)
+
+        timeout = 120
+        if request.config.getoption("--nc"):
+            timeout = 100000
+        elif cutoff := request.config.getoption("--cutoff"):
+            timeout = int(cutoff)
+
+        timed_out = None
+        eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
+        try:
+            async for step in self.run_challenge(config, timeout):
+                if not step.output:
+                    logger.warning(f"Step has no output: {step}")
+                    continue
+                step_eval_results = self.evaluate_step_result(step)
+                logger.debug(f"Intermediate results: {step_eval_results}")
+                eval_results_per_step.append(step_eval_results)
+                if step.is_last:
+                    request.node.user_properties.append(
+                        (
+                            "answers",
+                            step.output
+                            if request.config.getoption("--keep-answers")
+                            else None,
+                        )
+                    )
+            timed_out = False
+        except TimeoutError:
+            timed_out = True
+        request.node.user_properties.append(("timed_out", timed_out))
+
+        # Get the column aggregate (highest score for each Eval)
+        # from the matrix of EvalResults per step.
+        evals_results = [
+            max(step_results_for_eval, key=lambda r: r[1].score)
+            for step_results_for_eval in zip(*eval_results_per_step)
+        ]
+
+        if not evals_results:
+            if timed_out:
+                raise TimeoutError("Timed out, no results to evaluate")
+            else:
+                raise ValueError("No results to evaluate")
+
+        request.node.user_properties.append(
+            ("scores", [r[1].score for r in evals_results])
+        )
+
+        # FIXME: arbitrary threshold
+        assert all(r[1].score > 0.9 for r in evals_results), (
+            "Scores insufficient:\n\n"
+            if not timed_out
+            else "Timed out; scores insufficient:\n\n"
+        ) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results)
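+
+
+# Example (sketch): the challenge classes yielded by load_webarena_challenges()
+# below might be registered for pytest collection by pulling them into a test
+# module's namespace; the exact collection mechanism is assumed here:
+#
+#   for challenge in load_webarena_challenges():
+#       globals()[challenge.__name__] = challenge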
+
+
+def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
+    logger.info("Loading WebArena challenges...")
+
+    for site, info in site_info_map.items():
+        if not info.available:
+            logger.warning(
+                f"JungleGym site '{site}' is not available: {info.unavailable_reason} "
+                "Skipping all challenges which use this site."
+            )
+
+    # response = requests.get("https://api.junglegym.ai/get_full_webarena_dataset")
+    # challenge_dicts = response.json()["data"]
+
+    # Until the full WebArena challenge set is supported, use a hand-picked selection
+    import json
+    from pathlib import Path
+
+    challenge_dicts = json.loads(
+        (Path(__file__).parent / "webarena_selection.json").read_bytes()
+    )
+
+    logger.debug(
+        "Loaded WebArena dataset. "
+        f"Constructing {len(challenge_dicts)} WebArenaChallenges..."
+    )
+    loaded = 0
+    failed = 0
+    skipped = 0
+    for entry in challenge_dicts:
+        try:
+            challenge_spec = WebArenaChallengeSpec.parse_obj(entry)
+            for site in challenge_spec.sites:
+                site_info = site_info_map.get(site)
+                if site_info is None:
+                    logger.warning(
+                        f"WebArena task {challenge_spec.task_id} requires unknown site "
+                        f"'{site}'; skipping..."
+                    )
+                    break
+                if not site_info.available:
+                    logger.debug(
+                        f"WebArena task {challenge_spec.task_id} requires unavailable "
+                        f"site '{site}'; skipping..."
+                    )
+                    break
+            else:
+                yield WebArenaChallenge.from_challenge_spec(challenge_spec)
+                loaded += 1
+                continue
+            skipped += 1
+        except ValidationError as e:
+            failed += 1
+            logger.warning(f"Error validating WebArena challenge entry: {entry}")
+            logger.warning(f"Error details: {e}")
+    logger.info(
+        "Loading WebArena challenges complete: "
+        f"loaded {loaded}, skipped {skipped}. {failed} challenges failed to load."
+    )
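+
+
+# Example (sketch): building a single challenge class from a source URI,
+# assuming api.junglegym.ai is reachable and a task with this (hypothetical)
+# ID exists in the WebArena dataset:
+#
+#   challenge_cls = WebArenaChallenge.from_source_uri(
+#       "__JUNGLEGYM__/webarena/tasks/165"
+#   )
+#   print(challenge_cls.info.task)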