about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGravatar Reinier van der Leer <pwuts@agpt.co> 2024-02-16 15:17:11 +0100
committerGravatar Reinier van der Leer <pwuts@agpt.co> 2024-02-16 15:17:11 +0100
commit23d58a3cc0b7473b1ff2362acf8392eeae73a1d8 (patch)
treef98f3e0d05cc91344c0b5e7a8103c6cbbcaec2e2
parentrefactor(benchmark): `load_webarena_challenges` (diff)
downloadAuto-GPT-23d58a3cc0b7473b1ff2362acf8392eeae73a1d8.tar.gz
Auto-GPT-23d58a3cc0b7473b1ff2362acf8392eeae73a1d8.tar.bz2
Auto-GPT-23d58a3cc0b7473b1ff2362acf8392eeae73a1d8.zip
feat(benchmark/cli): Add `challenge list`, `challenge info` subcommands
- Add `challenge list` command with options `--all`, `--names`, `--json`
- Add `tabulate` dependency
- Add `.utils.utils.sorted_by_enum_index` function to easily sort lists by an enum value/property based on the order of the enum's definition
- Add `challenge info [name]` command with option `--json`
- Add `.utils.utils.pretty_print_model` routine to pretty-print Pydantic models
- Refactor `config` subcommand to use `pretty_print_model`
-rw-r--r--benchmark/agbenchmark/__main__.py127
-rw-r--r--benchmark/agbenchmark/utils/data_types.py2
-rw-r--r--benchmark/agbenchmark/utils/utils.py79
-rw-r--r--benchmark/poetry.lock16
-rw-r--r--benchmark/pyproject.toml1
5 files changed, 219 insertions, 6 deletions
diff --git a/benchmark/agbenchmark/__main__.py b/benchmark/agbenchmark/__main__.py
index 9fff53523..571f19f35 100644
--- a/benchmark/agbenchmark/__main__.py
+++ b/benchmark/agbenchmark/__main__.py
@@ -202,15 +202,136 @@ def serve(port: Optional[int] = None):
@cli.command()
def config():
"""Displays info regarding the present AGBenchmark config."""
+ from .utils.utils import pretty_print_model
+
try:
config = AgentBenchmarkConfig.load()
except FileNotFoundError as e:
click.echo(e, err=True)
return 1
- k_col_width = max(len(k) for k in config.dict().keys())
- for k, v in config.dict().items():
- click.echo(f"{k: <{k_col_width}} = {v}")
+ pretty_print_model(config, include_header=False)
+
+
@cli.group()
def challenge():
    """Commands to inspect the challenges that make up the benchmark."""
    # Challenge loaders log a lot at INFO level; raise the threshold so the
    # command output stays readable. (The redundant trailing `pass` is gone.)
    logging.getLogger().setLevel(logging.WARNING)
+
+
@challenge.command("list")
@click.option(
    "--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
)
@click.option(
    "--names", "only_names", is_flag=True, help="List only the challenge names."
)
@click.option("--json", "output_json", is_flag=True)
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
    """Lists [available|all] challenges."""
    import json

    from tabulate import tabulate

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.data_types import Category, DifficultyLevel
    from .utils.utils import sorted_by_enum_index

    # Assign each difficulty/category a terminal color by enum definition order.
    # (dict(zip(...)) replaces the redundant `{k: v for k, v in zip(...)}`.)
    DIFFICULTY_COLORS = dict(
        zip(
            DifficultyLevel,
            ["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
        )
    )
    CATEGORY_COLORS = dict(
        zip(
            Category,
            [
                f"bright_{color}"
                for color in
                ["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"]
            ],
        )
    )

    # Load challenges
    challenges = filter(
        lambda c: c.info.available or include_unavailable,
        [
            *load_builtin_challenges(),
            *load_webarena_challenges(skip_unavailable=False),
        ],
    )
    challenges = sorted_by_enum_index(
        challenges, DifficultyLevel, key=lambda c: c.info.difficulty
    )

    if only_names:
        if output_json:
            click.echo(json.dumps([c.info.name for c in challenges]))
            return

        for c in challenges:
            # Unavailable challenges are dimmed (rendered in "black").
            click.echo(
                click.style(c.info.name, fg=None if c.info.available else "black")
            )
        return

    if output_json:
        click.echo(json.dumps([json.loads(c.info.json()) for c in challenges]))
        return

    headers = tuple(
        click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
    )
    # NOTE: the row variable is `c`, not `challenge`, to avoid shadowing the
    # module-level `challenge` command group (and to match the loop above).
    table = [
        tuple(
            v if c.info.available else click.style(v, fg="black")
            for v in (
                c.info.name,
                (
                    click.style(
                        c.info.difficulty.value,
                        fg=DIFFICULTY_COLORS[c.info.difficulty],
                    )
                    if c.info.difficulty
                    else click.style("-", fg="black")
                ),
                " ".join(
                    click.style(cat.value, fg=CATEGORY_COLORS[cat])
                    for cat in sorted_by_enum_index(c.info.category, Category)
                ),
            )
        )
        for c in challenges
    ]
    click.echo(tabulate(table, headers=headers))
+
+
@challenge.command()
@click.option("--json", "output_json", is_flag=True)
@click.argument("name")
def info(name: str, output_json: bool):
    """Shows info about the challenge NAME."""
    from itertools import chain

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.utils import pretty_print_model

    for challenge in chain(
        load_builtin_challenges(),
        load_webarena_challenges(skip_unavailable=False),
    ):
        if challenge.info.name != name:
            continue

        # `output_json` (not `json`) keeps the module name unshadowed and is
        # consistent with the --json mapping used by `challenge list`.
        if output_json:
            click.echo(challenge.info.json())
        else:
            pretty_print_model(challenge.info)
        break
    else:
        # for/else: only runs when no challenge matched `name`.
        click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
@cli.command()
diff --git a/benchmark/agbenchmark/utils/data_types.py b/benchmark/agbenchmark/utils/data_types.py
index 688209682..ac7444921 100644
--- a/benchmark/agbenchmark/utils/data_types.py
+++ b/benchmark/agbenchmark/utils/data_types.py
@@ -29,8 +29,8 @@ STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel}
class Category(str, Enum):
- DATA = "data"
GENERALIST = "general"
+ DATA = "data"
CODING = "coding"
SCRAPE_SYNTHESIZE = "scrape_synthesize"
WEB = "web"
diff --git a/benchmark/agbenchmark/utils/utils.py b/benchmark/agbenchmark/utils/utils.py
index 93724de85..0f0ad56d9 100644
--- a/benchmark/agbenchmark/utils/utils.py
+++ b/benchmark/agbenchmark/utils/utils.py
@@ -3,10 +3,13 @@ import json
import logging
import os
import re
+from enum import Enum
from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Callable, Iterable, Optional, TypeVar, overload
+import click
from dotenv import load_dotenv
+from pydantic import BaseModel
from agbenchmark.reports.processing.report_types import Test
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel
@@ -17,6 +20,9 @@ AGENT_NAME = os.getenv("AGENT_NAME")
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+E = TypeVar("E", bound=Enum)
+
def replace_backslash(value: Any) -> Any:
if isinstance(value, str):
@@ -124,6 +130,42 @@ def write_pretty_json(data, json_file):
f.write("\n")
def pretty_print_model(model: BaseModel, include_header: bool = True) -> None:
    """Pretty-prints a Pydantic model to the terminal with colorized values.

    Params:
        model: The model instance to print.
        include_header: Whether to print a `ModelName[name, id]:` header line
            and indent the field listing under it.
    """
    attrs = model.dict()  # hoisted: was recomputed on every use

    indent = ""
    if include_header:
        # Try to find the ID and/or name attribute of the model
        model_id, model_name = None, None  # renamed: `id` shadowed the builtin
        for attr, value in attrs.items():
            if attr == "id" or attr.endswith("_id"):
                model_id = value
            if attr.endswith("name"):
                model_name = value
            if model_id and model_name:
                break
        identifiers = [v for v in (model_name, model_id) if v]
        click.echo(
            f"{model.__repr_name__()}{repr(identifiers) if identifiers else ''}:"
        )
        indent = " " * 2

    k_col_width = max(len(k) for k in attrs)
    for k, v in attrs.items():
        # One precedence-explicit chain. Enum / list-of-enum come first because
        # in the original they overrode the earlier formatting branches.
        if isinstance(v, Enum):
            v_fmt = click.style(v.value, fg="blue")
        elif type(v) is list and len(v) > 0 and isinstance(v[0], Enum):
            v_fmt = ", ".join(click.style(lv.value, fg="blue") for lv in v)
        elif v is None or v == "":
            v_fmt = click.style(repr(v), fg="black")  # dim empty values
        elif type(v) is bool:
            v_fmt = click.style(repr(v), fg="green" if v else "red")
        elif type(v) is str and "\n" in v:
            # Render multi-line strings as an indented, pipe-prefixed block.
            v_fmt = f"\n{v}".replace(
                "\n", f"\n{indent} {click.style('|', fg='black')} "
            )
        else:
            v_fmt = repr(v)
        click.echo(f"{indent}{k: <{k_col_width}} = {v_fmt}")
+
+
def deep_sort(obj):
"""
Recursively sort the keys in JSON object
@@ -133,3 +175,38 @@ def deep_sort(obj):
if isinstance(obj, list):
return [deep_sort(elem) for elem in obj]
return obj
+
+
@overload
def sorted_by_enum_index(
    sortable: Iterable[E],
    enum: type[E],
    *,
    reverse: bool = False,
) -> list[E]:
    ...


@overload
def sorted_by_enum_index(
    sortable: Iterable[T],
    enum: type[Enum],
    *,
    key: Callable[[T], Enum | None],
    reverse: bool = False,
) -> list[T]:
    ...


def sorted_by_enum_index(
    sortable: Iterable[T],
    enum: type[Enum],
    *,
    key: Callable[[T], Enum | None] = lambda x: x,  # type: ignore
    reverse: bool = False,
) -> list[T]:
    """Sorts items by the definition order of their associated enum member.

    Params:
        sortable: The items to sort.
        enum: The enum whose member definition order determines sort order.
        key: Maps an item to the enum member to sort it by; defaults to the
            item itself. Items whose key is falsy/None sort after all others.
        reverse: Whether to sort in descending order instead.

    Returns:
        list[T]: The sorted items.
    """
    member_order = enum._member_names_
    # len(member_order) places falsy/None keys after every real member
    # (replaces the unexplained magic sentinel `420e3`).
    return sorted(
        sortable,
        key=lambda x: (
            member_order.index(e.name) if (e := key(x)) else len(member_order)
        ),
        reverse=reverse,
    )
diff --git a/benchmark/poetry.lock b/benchmark/poetry.lock
index 005086565..70bef01f6 100644
--- a/benchmark/poetry.lock
+++ b/benchmark/poetry.lock
@@ -2432,6 +2432,20 @@ anyio = ">=3.4.0,<5"
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+ {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
+[[package]]
name = "toml"
version = "0.10.2"
description = "Python Library for Tom's Obvious, Minimal Language"
@@ -2760,4 +2774,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "d7893a88906b5a8eda566e13e6a9492d012c910ded0da1b1ef12b69a14f8e047"
+content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
diff --git a/benchmark/pyproject.toml b/benchmark/pyproject.toml
index c659dcc8b..6c3976743 100644
--- a/benchmark/pyproject.toml
+++ b/benchmark/pyproject.toml
@@ -34,6 +34,7 @@ toml = "^0.10.2"
httpx = "^0.24.0"
agent-protocol-client = "^1.1.0"
click-default-group = "^1.2.4"
+tabulate = "^0.9.0"
[tool.poetry.group.dev.dependencies]
flake8 = "^3.9.2"