Diffstat (limited to 'benchmark/agbenchmark/app.py')
-rw-r--r--  benchmark/agbenchmark/app.py  334
1 file changed, 334 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/app.py b/benchmark/agbenchmark/app.py
new file mode 100644
index 000000000..40fee14b6
--- /dev/null
+++ b/benchmark/agbenchmark/app.py
@@ -0,0 +1,334 @@
+import datetime
+import glob
+import json
+import logging
+import sys
+import time
+import uuid
+from collections import deque
+from multiprocessing import Process
+from pathlib import Path
+from typing import Optional
+
+import httpx
+import psutil
+from agent_protocol_client import AgentApi, ApiClient, ApiException, Configuration
+from agent_protocol_client.models import Task, TaskRequestBody
+from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Extra, ValidationError
+
+from agbenchmark.challenges import ChallengeInfo
+from agbenchmark.config import AgentBenchmarkConfig
+from agbenchmark.reports.processing.report_types_v2 import (
+ BenchmarkRun,
+ Metrics,
+ RepositoryInfo,
+ RunDetails,
+ TaskInfo,
+)
+from agbenchmark.schema import TaskEvalRequestBody
+from agbenchmark.utils.utils import write_pretty_json
+
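+# Make the agbenchmark package importable by adding its parent directory to sys.path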
+sys.path.append(str(Path(__file__).parent.parent))
+
+logger = logging.getLogger(__name__)
+
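+# Map of eval_id -> ChallengeInfo for every challenge spec found on disk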
+CHALLENGES: dict[str, ChallengeInfo] = {}
+challenges_path = Path(__file__).parent / "challenges"
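+# Collect every challenge spec file (data.json) under the challenges directory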
+challenge_spec_files = deque(
+ glob.glob(
+ f"{challenges_path}/**/data.json",
+ recursive=True,
+ )
+)
+
+logger.debug("Loading challenges...")
+while challenge_spec_files:
+ challenge_spec_file = Path(challenge_spec_files.popleft())
+ challenge_relpath = challenge_spec_file.relative_to(challenges_path.parent)
+ if challenge_relpath.is_relative_to("challenges/deprecated"):
+ continue
+
+ logger.debug(f"Loading {challenge_relpath}...")
+ try:
+ challenge_info = ChallengeInfo.parse_file(challenge_spec_file)
+ except ValidationError as e:
+ if logging.getLogger().level == logging.DEBUG:
+ logger.warning(f"Spec file {challenge_relpath} failed to load:\n{e}")
+ logger.debug(f"Invalid challenge spec: {challenge_spec_file.read_text()}")
+ continue
+ challenge_info.spec_file = challenge_spec_file
+
+ if not challenge_info.eval_id:
+ challenge_info.eval_id = str(uuid.uuid4())
+ # this will sort all the keys of the JSON systematically
+ # so that the order is always the same
+ write_pretty_json(challenge_info.dict(), challenge_spec_file)
+
+ CHALLENGES[challenge_info.eval_id] = challenge_info
+
+
+class BenchmarkTaskInfo(BaseModel):
+ task_id: str
+ start_time: datetime.datetime
+ challenge_info: ChallengeInfo
+
+
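+# Maps task_id (as assigned by the agent) to the info needed to evaluate the task later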
+task_informations: dict[str, BenchmarkTaskInfo] = {}
+
+
+def find_agbenchmark_without_uvicorn():
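+    """Return the PIDs of running agbenchmark processes that are not uvicorn workers."""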
+ pids = []
+ for process in psutil.process_iter(
+ attrs=[
+ "pid",
+ "cmdline",
+ "name",
+ "username",
+ "status",
+ "cpu_percent",
+ "memory_info",
+ "create_time",
+ "cwd",
+ "connections",
+ ]
+ ):
+ try:
+            # Concatenate the string values of the process attributes requested above
+            full_info = " ".join(str(v) for v in process.info.values())
+
+ if "agbenchmark" in full_info and "uvicorn" not in full_info:
+ pids.append(process.pid)
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
+ pass
+ return pids
+
+
+class CreateReportRequest(BaseModel):
+    test: Optional[str] = None
+    test_run_id: Optional[str] = None
+    # category: Optional[str] = []
+    mock: Optional[bool] = False
+
+ class Config:
+ extra = Extra.forbid # this will forbid any extra fields
+
+
+updates_list = []
+
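+# Browser origins that are allowed to call this API (CORS allowlist)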
+origins = [
+ "http://localhost:8000",
+ "http://localhost:8080",
+ "http://127.0.0.1:5000",
+ "http://localhost:5000",
+]
+
+
+def stream_output(pipe):
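+    """Print lines from a pipe to stdout as they arrive."""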
+ for line in pipe:
+ print(line, end="")
+
+
+def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
+ from agbenchmark.agent_api_interface import upload_artifacts
+ from agbenchmark.challenges import get_challenge_from_source_uri
+ from agbenchmark.main import run_benchmark
+
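+    # Client configuration for the agent's Agent Protocol API (the agent under test)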
+ configuration = Configuration(
+ host=agbenchmark_config.host or "http://localhost:8000"
+ )
+ app = FastAPI()
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ router = APIRouter()
+
+ @router.post("/reports")
+ def run_single_test(body: CreateReportRequest) -> dict:
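+        """Run a single benchmark test in a subprocess and return the resulting report."""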
+ pids = find_agbenchmark_without_uvicorn()
+ logger.info(f"pids already running with agbenchmark: {pids}")
+
+ logger.debug(f"Request to /reports: {body.dict()}")
+
+        # Start the benchmark in a separate process
+ benchmark_process = Process(
+ target=lambda: run_benchmark(
+ config=agbenchmark_config,
+ tests=(body.test,),
+ mock=body.mock or False,
+ )
+ )
+ benchmark_process.start()
+
+ # Wait for the benchmark to finish, with a timeout of 200 seconds
+ timeout = 200
+ start_time = time.time()
+ while benchmark_process.is_alive():
+ if time.time() - start_time > timeout:
+ logger.warning(f"Benchmark run timed out after {timeout} seconds")
+ benchmark_process.terminate()
+ break
+ time.sleep(1)
+ else:
+            logger.debug(f"Benchmark finished running in {time.time() - start_time:.1f} s")
+
+        # List all report folders in the configured reports directory
+        path_reports = agbenchmark_config.reports_folder
+        folders = [folder for folder in path_reports.iterdir() if folder.is_dir()]
+
+        # Sort the folders by name; report folder names sort chronologically,
+        # so the last one belongs to the most recent run
+        sorted_folders = sorted(folders, key=lambda x: x.name)
+        latest_folder = sorted_folders[-1] if sorted_folders else None
+
+        # Read report.json from this folder
+        if latest_folder:
+            report_path = latest_folder / "report.json"
+            logger.debug(f"Getting latest report from {report_path}")
+            if report_path.exists():
+                with report_path.open() as file:
+                    data = json.load(file)
+                logger.debug(f"Report data: {data}")
+            else:
+                logger.error(
+                    "Could not get result after running benchmark: "
+                    f"'report.json' does not exist in '{latest_folder}'"
+                )
+                raise HTTPException(
+                    500, "Benchmark finished, but no report was generated"
+                )
+        else:
+            logger.error(
+                "Could not get result after running benchmark: no reports found"
+            )
+            raise HTTPException(
+                500, "Benchmark finished, but no reports folder was found"
+            )
+
+        return data
+
+ @router.post("/agent/tasks", tags=["agent"])
+ async def create_agent_task(task_eval_request: TaskEvalRequestBody) -> Task:
+ """
+ Creates a new task using the provided TaskEvalRequestBody and returns a Task.
+
+ Args:
+            task_eval_request: `TaskEvalRequestBody` specifying the eval_id
+                of the challenge to run.
+
+ Returns:
+ Task: A new task with task_id, input, additional_input,
+ and empty lists for artifacts and steps.
+
+ Example:
+ Request (TaskEvalRequestBody defined in schema.py):
+ {
+ ...,
+ "eval_id": "50da533e-3904-4401-8a07-c49adf88b5eb"
+ }
+
+ Response (Task defined in `agent_protocol_client.models`):
+ {
+ "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
+ "input": "Write the word 'Washington' to a .txt file",
+ "artifacts": []
+ }
+ """
+ try:
+ challenge_info = CHALLENGES[task_eval_request.eval_id]
+ async with ApiClient(configuration) as api_client:
+ api_instance = AgentApi(api_client)
+ task_input = challenge_info.task
+
+ task_request_body = TaskRequestBody(input=task_input)
+ task_response = await api_instance.create_agent_task(
+ task_request_body=task_request_body
+ )
+ task_info = BenchmarkTaskInfo(
+ task_id=task_response.task_id,
+ start_time=datetime.datetime.now(datetime.timezone.utc),
+ challenge_info=challenge_info,
+ )
+ task_informations[task_info.task_id] = task_info
+
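+            # Upload the challenge's input artifacts to the agent, if any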
+ if input_artifacts_dir := challenge_info.task_artifacts_dir:
+ await upload_artifacts(
+ api_instance,
+ input_artifacts_dir,
+ task_response.task_id,
+ "artifacts_in",
+ )
+ return task_response
+ except ApiException as e:
+ logger.error(f"Error whilst trying to create a task:\n{e}")
+ logger.error(
+ "The above error was caused while processing request: "
+ f"{task_eval_request}"
+ )
+ raise HTTPException(500)
+
+ @router.post("/agent/tasks/{task_id}/steps")
+ async def proxy(request: Request, task_id: str):
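+        """Forward a step request to the agent's own API and relay its response."""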
+ timeout = httpx.Timeout(300.0, read=300.0) # 5 minutes
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ # Construct the new URL
+ new_url = f"{configuration.host}/ap/v1/agent/tasks/{task_id}/steps"
+
+ # Forward the request
+ response = await client.post(
+ new_url,
+                content=await request.body(),
+ headers=dict(request.headers),
+ )
+
+ # Return the response from the forwarded request
+ return Response(content=response.content, status_code=response.status_code)
+
+ @router.post("/agent/tasks/{task_id}/evaluations")
+ async def create_evaluation(task_id: str) -> BenchmarkRun:
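+        """Evaluate the agent's work on a task and return a BenchmarkRun report."""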
+        if not (task_info := task_informations.get(task_id)):
+            raise HTTPException(404, f"Unknown task_id: {task_id}")
+ challenge = get_challenge_from_source_uri(task_info.challenge_info.source_uri)
+ try:
+ async with ApiClient(configuration) as api_client:
+ api_instance = AgentApi(api_client)
+ eval_results = await challenge.evaluate_task_state(
+ api_instance, task_id
+ )
+
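+            # Assemble a v2 BenchmarkRun report from the evaluation results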
+ eval_info = BenchmarkRun(
+ repository_info=RepositoryInfo(),
+ run_details=RunDetails(
+ command=f"agbenchmark --test={challenge.info.name}",
+ benchmark_start_time=(
+ task_info.start_time.strftime("%Y-%m-%dT%H:%M:%S+00:00")
+ ),
+ test_name=challenge.info.name,
+ ),
+ task_info=TaskInfo(
+ data_path=challenge.info.source_uri,
+ is_regression=None,
+ category=[c.value for c in challenge.info.category],
+ task=challenge.info.task,
+ answer=challenge.info.reference_answer or "",
+ description=challenge.info.description or "",
+ ),
+ metrics=Metrics(
+ success=all(e.passed for e in eval_results),
+ success_percentage=(
+ 100 * sum(e.score for e in eval_results) / len(eval_results)
+ if eval_results # avoid division by 0
+ else 0
+ ),
+ attempted=True,
+ ),
+ config={},
+ )
+
+ logger.debug(f"Returning evaluation data:\n{eval_info.json(indent=4)}")
+ return eval_info
+ except ApiException as e:
+ logger.error(f"Error {e} whilst trying to evaluate task: {task_id}")
+ raise HTTPException(500)
+
+ app.include_router(router, prefix="/ap/v1")
+
+ return app