about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Reinier van der Leer <pwuts@agpt.co> 2024-01-29 11:33:42 +0100
committer: Reinier van der Leer <pwuts@agpt.co> 2024-01-29 11:33:42 +0100
commit: 8b0579a87c08305e01e0fae9a3a2f79c45870a3d (patch)
tree: 861ddc8c2e0a2bd48756e29c6f1ccd58efea1efe
parent: feat(agent/llm): Add support for `gpt-4-0125-preview` (diff)
download: Auto-GPT-benchmark/concurrency.tar.gz
Auto-GPT-benchmark/concurrency.tar.bz2
Auto-GPT-benchmark/concurrency.zip
feat(benchmark): Add `-P`, `--parallel-tasks` option to allow running multiple tasks concurrently (branch: benchmark/concurrency)
* Add dependency `pytest-parallel` and indirect dependency `py` (pylib)
* Make `SingletonReportManager` thread safe
-rw-r--r-- benchmark/agbenchmark/__main__.py 6
-rw-r--r-- benchmark/agbenchmark/main.py 4
-rw-r--r-- benchmark/agbenchmark/reports/ReportManager.py 57
-rw-r--r-- benchmark/poetry.lock 39
-rw-r--r-- benchmark/pyproject.toml 2
5 files changed, 82 insertions, 26 deletions
diff --git a/benchmark/agbenchmark/__main__.py b/benchmark/agbenchmark/__main__.py
index 9fff53523..b531b4518 100644
--- a/benchmark/agbenchmark/__main__.py
+++ b/benchmark/agbenchmark/__main__.py
@@ -63,6 +63,9 @@ def start():
"-N", "--attempts", default=1, help="Number of times to run each challenge."
)
@click.option(
+ "-P", "--parallel-tasks", default=1, help="Number of challenges to run in parallel."
+)
+@click.option(
"-c",
"--category",
multiple=True,
@@ -111,6 +114,7 @@ def run(
category: tuple[str],
skip_category: tuple[str],
attempts: int,
+ parallel_tasks: int,
cutoff: Optional[int] = None,
backend: Optional[bool] = False,
# agent_path: Optional[Path] = None,
@@ -158,6 +162,7 @@ def run(
categories=category,
skip_categories=skip_category,
attempts_per_challenge=attempts,
+ concurrent_tasks=parallel_tasks,
cutoff=cutoff,
)
@@ -177,6 +182,7 @@ def run(
categories=category,
skip_categories=skip_category,
attempts_per_challenge=attempts,
+ concurrent_tasks=parallel_tasks,
cutoff=cutoff,
)
diff --git a/benchmark/agbenchmark/main.py b/benchmark/agbenchmark/main.py
index 4cd97bd89..4128a0a26 100644
--- a/benchmark/agbenchmark/main.py
+++ b/benchmark/agbenchmark/main.py
@@ -22,6 +22,7 @@ def run_benchmark(
categories: tuple[str] = tuple(),
skip_categories: tuple[str] = tuple(),
attempts_per_challenge: int = 1,
+ concurrent_tasks: int = 1,
mock: bool = False,
no_dep: bool = False,
no_cutoff: bool = False,
@@ -100,6 +101,9 @@ def run_benchmark(
if attempts_per_challenge > 1:
pytest_args.append(f"--attempts={attempts_per_challenge}")
+ if concurrent_tasks > 1:
+ pytest_args.append(f"--tests-per-worker={concurrent_tasks}")
+
if cutoff:
pytest_args.append(f"--cutoff={cutoff}")
logger.debug(f"Setting cuttoff override to {cutoff} seconds.")
diff --git a/benchmark/agbenchmark/reports/ReportManager.py b/benchmark/agbenchmark/reports/ReportManager.py
index d04beee43..5d1392bbc 100644
--- a/benchmark/agbenchmark/reports/ReportManager.py
+++ b/benchmark/agbenchmark/reports/ReportManager.py
@@ -3,10 +3,11 @@ import json
import logging
import os
import sys
+import threading
import time
from datetime import datetime, timezone
from pathlib import Path
-from typing import Any
+from typing import Any, ClassVar
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.reports.processing.graphs import save_single_radar_chart
@@ -20,39 +21,39 @@ logger = logging.getLogger(__name__)
class SingletonReportManager:
- instance = None
+ _instance = None
+ _lock: ClassVar[threading.Lock] = threading.Lock()
INFO_MANAGER: "SessionReportManager"
REGRESSION_MANAGER: "RegressionTestsTracker"
SUCCESS_RATE_TRACKER: "SuccessRatesTracker"
def __new__(cls):
- if not cls.instance:
- cls.instance = super(SingletonReportManager, cls).__new__(cls)
-
- agent_benchmark_config = AgentBenchmarkConfig.load()
- benchmark_start_time_dt = datetime.now(
- timezone.utc
- ) # or any logic to fetch the datetime
-
- # Make the Managers class attributes
- cls.INFO_MANAGER = SessionReportManager(
- agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
- / "report.json",
- benchmark_start_time_dt,
- )
- cls.REGRESSION_MANAGER = RegressionTestsTracker(
- agent_benchmark_config.regression_tests_file
- )
- cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
- agent_benchmark_config.success_rate_file
- )
-
- return cls.instance
+ with cls._lock:
+ if not cls._instance:
+ cls._instance = super(SingletonReportManager, cls).__new__(cls)
+
+ agent_benchmark_config = AgentBenchmarkConfig.load()
+ benchmark_start_time_dt = datetime.now(timezone.utc)
+
+ # Make the Managers class attributes
+ cls.INFO_MANAGER = SessionReportManager(
+ agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
+ / "report.json",
+ benchmark_start_time_dt,
+ )
+ cls.REGRESSION_MANAGER = RegressionTestsTracker(
+ agent_benchmark_config.regression_tests_file
+ )
+ cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
+ agent_benchmark_config.success_rate_file
+ )
+
+ return cls._instance
@classmethod
def clear_instance(cls):
- cls.instance = None
+ cls._instance = None
cls.INFO_MANAGER = None
cls.REGRESSION_MANAGER = None
cls.SUCCESS_RATE_TRACKER = None
@@ -131,6 +132,12 @@ class SessionReportManager(BaseReportManager):
self.save()
+ def get_test_report(self, test_name: str) -> Test | None:
+ if isinstance(self.tests, Report):
+ return self.tests.tests.get(test_name)
+ else:
+ return self.tests.get(test_name)
+
def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
command = " ".join(sys.argv)
diff --git a/benchmark/poetry.lock b/benchmark/poetry.lock
index 057b89aa4..5bdfbe089 100644
--- a/benchmark/poetry.lock
+++ b/benchmark/poetry.lock
@@ -1947,6 +1947,17 @@ files = [
tests = ["pytest"]
[[package]]
+name = "py"
+version = "1.11.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
+ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+]
+
+[[package]]
name = "pyasn1"
version = "0.5.1"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
@@ -2138,6 +2149,21 @@ docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
+name = "pytest-parallel"
+version = "0.1.1"
+description = "a pytest plugin for parallel and concurrent testing"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-parallel-0.1.1.tar.gz", hash = "sha256:9aac3fc199a168c0a8559b60249d9eb254de7af58c12cee0310b54d4affdbfab"},
+ {file = "pytest_parallel-0.1.1-py3-none-any.whl", hash = "sha256:9e3703015b0eda52be9e07d2ba3498f09340a56d5c79a39b50f22fc5c38212fe"},
+]
+
+[package.dependencies]
+pytest = ">=3.0.0"
+tblib = "*"
+
+[[package]]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
@@ -2432,6 +2458,17 @@ anyio = ">=3.4.0,<5"
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
[[package]]
+name = "tblib"
+version = "3.0.0"
+description = "Traceback serialization library."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"},
+ {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"},
+]
+
+[[package]]
name = "toml"
version = "0.10.2"
description = "Python Library for Tom's Obvious, Minimal Language"
@@ -2760,4 +2797,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "e0d1f991958a5d630287c7bb668e7fdc6183630e06196cf6f507a086be10baec"
+content-hash = "4a4e53f252c8996b172bbb35a730197c07c53d7b50bf1d21964d3b2237495066"
diff --git a/benchmark/pyproject.toml b/benchmark/pyproject.toml
index 6740004b4..b8bf8ccab 100644
--- a/benchmark/pyproject.toml
+++ b/benchmark/pyproject.toml
@@ -25,7 +25,9 @@ networkx = "^3.1"
colorama = "^0.4.6"
pyvis = "^0.3.2"
selenium = "^4.11.2"
+py = "^1.11.0" # needed for pytest-parallel
pytest-asyncio = "^0.21.1"
+pytest-parallel = "^0.1.1"
uvicorn = "^0.23.2"
fastapi = "^0.99.0"
python-multipart = "^0.0.6"