author    Reinier van der Leer <pwuts@agpt.co>  2024-02-16 17:53:19 +0100
committer Reinier van der Leer <pwuts@agpt.co>  2024-02-16 17:53:19 +0100
commit    752bac099bb977a0e9e106a1b92ee4d4141d525f (patch)
tree      7dd8def6696a525434dd8f5923562348fa841033
parent    ci(benchmark): Add nightly benchmark workflow (diff)
feat(benchmark/report): Add and record `TestResult.n_steps`
- Added `n_steps` attribute to `TestResult` type
- Added logic to record the number of steps to `BuiltinChallenge.test_method`, `WebArenaChallenge.test_method`, and `.reports.add_test_result_to_report`
 benchmark/agbenchmark/challenges/builtin.py              | 3 +++
 benchmark/agbenchmark/challenges/webarena.py             | 3 +++
 benchmark/agbenchmark/reports/processing/report_types.py | 2 ++
 benchmark/agbenchmark/reports/reports.py                 | 1 +
 4 files changed, 9 insertions(+), 0 deletions(-)
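The plumbing these changes rely on: pytest copies each test item's `user_properties` list of (name, value) pairs onto the resulting `TestReport`, so a value appended while the test runs can be read back in a reporting hook. A minimal, self-contained sketch of that mechanism (not agbenchmark code; the test body and the step loop are invented for illustration):

    import pytest

    def test_agent_run(request: pytest.FixtureRequest):
        n_steps = 0
        for _ in range(3):  # stand-in for iterating over the agent's steps
            n_steps += 1
        # (name, value) pairs; pytest copies this list onto the TestReport
        request.node.user_properties.append(("n_steps", n_steps))

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_makereport(item, call):
        outcome = yield
        report = outcome.get_result()
        if report.when == "call":
            props = dict(report.user_properties)  # list of pairs -> dict
            print(props.get("n_steps"))  # -> 3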
diff --git a/benchmark/agbenchmark/challenges/builtin.py b/benchmark/agbenchmark/challenges/builtin.py
index fd28dc3ee..5b616e449 100644
--- a/benchmark/agbenchmark/challenges/builtin.py
+++ b/benchmark/agbenchmark/challenges/builtin.py
@@ -173,6 +173,7 @@ class BuiltinChallenge(BaseChallenge):
             timeout = int(cutoff)  # type: ignore
 
         task_id = ""
+        n_steps = 0
         timed_out = None
         try:
             async for step in self.run_challenge(
@@ -180,9 +181,11 @@ class BuiltinChallenge(BaseChallenge):
             ):
                 if not task_id:
                     task_id = step.task_id
+                n_steps += 1
                 timed_out = False
         except TimeoutError:
             timed_out = True
+        request.node.user_properties.append(("n_steps", n_steps))
         request.node.user_properties.append(("timed_out", timed_out))
 
         agent_client_config = ClientConfig(host=config.host)
diff --git a/benchmark/agbenchmark/challenges/webarena.py b/benchmark/agbenchmark/challenges/webarena.py
index 395b5a6ee..3cec1f956 100644
--- a/benchmark/agbenchmark/challenges/webarena.py
+++ b/benchmark/agbenchmark/challenges/webarena.py
@@ -393,6 +393,7 @@ class WebArenaChallenge(BaseChallenge):
         elif cutoff := request.config.getoption("--cutoff"):
             timeout = int(cutoff)
 
+        n_steps = 0
         timed_out = None
         eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
         try:
@@ -402,6 +403,7 @@ class WebArenaChallenge(BaseChallenge):
                 if not step.output:
                     logger.warn(f"Step has no output: {step}")
                     continue
+                n_steps += 1
                 step_eval_results = self.evaluate_step_result(
                     step, mock=request.config.getoption("--mock")
                 )
@@ -419,6 +421,7 @@ class WebArenaChallenge(BaseChallenge):
                 timed_out = False
         except TimeoutError:
             timed_out = True
+        request.node.user_properties.append(("n_steps", n_steps))
         request.node.user_properties.append(("timed_out", timed_out))
 
         # Get the column aggregate (highest score for each Eval)
diff --git a/benchmark/agbenchmark/reports/processing/report_types.py b/benchmark/agbenchmark/reports/processing/report_types.py
index b6deef021..2ed4acf3b 100644
--- a/benchmark/agbenchmark/reports/processing/report_types.py
+++ b/benchmark/agbenchmark/reports/processing/report_types.py
@@ -20,6 +20,8 @@ class TestResult(BaseModel):
     fail_reason: str | None = None
     """If applicable, the reason why the run was not successful"""
     reached_cutoff: bool | None = None  # None if in progress
     """Whether the run had to be stopped due to reaching the timeout"""
+    n_steps: int | None = None
+    """The number of steps executed by the agent"""
     cost: float | None = None
     """The (known) cost incurred by the run, e.g. from using paid LLM APIs"""
diff --git a/benchmark/agbenchmark/reports/reports.py b/benchmark/agbenchmark/reports/reports.py
index 728d19fd9..4844f5bfe 100644
--- a/benchmark/agbenchmark/reports/reports.py
+++ b/benchmark/agbenchmark/reports/reports.py
@@ -92,6 +92,7 @@ def add_test_result_to_report(
             run_time=f"{str(round(call.duration, 3))} seconds",
             fail_reason=str(call.excinfo.value) if call.excinfo else None,
             reached_cutoff=user_properties.get("timed_out", False),
+            n_steps=user_properties.get("n_steps"),
         )
     )
     test_report.metrics.success_percentage = (
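Note that `user_properties` is consumed here as a dict, while pytest stores the appended pairs as a list of tuples. The lookup semantics the new line depends on, in a self-contained sketch (the sample pairs are invented, and converting with `dict()` assumes each name is appended at most once per test):

    from typing import Any

    pairs: list[tuple[str, Any]] = [("n_steps", 7), ("timed_out", False)]
    user_properties: dict[str, Any] = dict(pairs)  # assumes unique names

    print(user_properties.get("n_steps"))           # 7; None if never appended
    print(user_properties.get("timed_out", False))  # False when recorded as
                                                    # False and also when absent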