aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/agent_api_interface.py
diff options
context:
space:
mode:
Diffstat (limited to 'benchmark/agbenchmark/agent_api_interface.py')
-rw-r--r--benchmark/agbenchmark/agent_api_interface.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/benchmark/agbenchmark/agent_api_interface.py b/benchmark/agbenchmark/agent_api_interface.py
new file mode 100644
index 000000000..6eadcc537
--- /dev/null
+++ b/benchmark/agbenchmark/agent_api_interface.py
@@ -0,0 +1,104 @@
+import logging
+import os
+import time
+from pathlib import Path
+from typing import AsyncIterator, Optional
+
+from agent_protocol_client import (
+ AgentApi,
+ ApiClient,
+ Configuration,
+ Step,
+ TaskRequestBody,
+)
+
+from agbenchmark.agent_interface import get_list_of_file_paths
+from agbenchmark.config import AgentBenchmarkConfig
+
+logger = logging.getLogger(__name__)
+
+
+async def run_api_agent(
+ task: str,
+ config: AgentBenchmarkConfig,
+ timeout: int,
+ artifacts_location: Optional[Path] = None,
+) -> AsyncIterator[Step]:
+ configuration = Configuration(host=config.host)
+ async with ApiClient(configuration) as api_client:
+ api_instance = AgentApi(api_client)
+ task_request_body = TaskRequestBody(input=task)
+
+ start_time = time.time()
+ response = await api_instance.create_agent_task(
+ task_request_body=task_request_body
+ )
+ task_id = response.task_id
+
+ if artifacts_location:
+ await upload_artifacts(
+ api_instance, artifacts_location, task_id, "artifacts_in"
+ )
+
+ while True:
+ step = await api_instance.execute_agent_task_step(task_id=task_id)
+ yield step
+
+ if time.time() - start_time > timeout:
+ raise TimeoutError("Time limit exceeded")
+ if not step or step.is_last:
+ break
+
+ if artifacts_location:
+ # In "mock" mode, we cheat by giving the correct artifacts to pass the test
+ if os.getenv("IS_MOCK"):
+ await upload_artifacts(
+ api_instance, artifacts_location, task_id, "artifacts_out"
+ )
+
+ await download_agent_artifacts_into_folder(
+ api_instance, task_id, config.temp_folder
+ )
+
+
+async def download_agent_artifacts_into_folder(
+ api_instance: AgentApi, task_id: str, folder: Path
+):
+ artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)
+
+ for artifact in artifacts.artifacts:
+ # current absolute path of the directory of the file
+ if artifact.relative_path:
+ path: str = (
+ artifact.relative_path
+ if not artifact.relative_path.startswith("/")
+ else artifact.relative_path[1:]
+ )
+ folder = (folder / path).parent
+
+ if not folder.exists():
+ folder.mkdir(parents=True)
+
+ file_path = folder / artifact.file_name
+ logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}")
+ with open(file_path, "wb") as f:
+ content = await api_instance.download_agent_task_artifact(
+ task_id=task_id, artifact_id=artifact.artifact_id
+ )
+
+ f.write(content)
+
+
+async def upload_artifacts(
+ api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
+) -> None:
+ for file_path in get_list_of_file_paths(artifacts_location, type):
+ relative_path: Optional[str] = "/".join(
+ str(file_path).split(f"{type}/", 1)[-1].split("/")[:-1]
+ )
+ if not relative_path:
+ relative_path = None
+
+ await api_instance.upload_agent_task_artifacts(
+ task_id=task_id, file=str(file_path), relative_path=relative_path
+ )