aboutsummaryrefslogtreecommitdiff
path: root/benchmark/agbenchmark/agent_api_interface.py
blob: 6eadcc537f5edbffcf4f962fb14866fa748ebaf3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
import os
import time
from pathlib import Path
from typing import AsyncIterator, Optional

from agent_protocol_client import (
    AgentApi,
    ApiClient,
    Configuration,
    Step,
    TaskRequestBody,
)

from agbenchmark.agent_interface import get_list_of_file_paths
from agbenchmark.config import AgentBenchmarkConfig

logger = logging.getLogger(__name__)


async def run_api_agent(
    task: str,
    config: AgentBenchmarkConfig,
    timeout: int,
    artifacts_location: Optional[Path] = None,
) -> AsyncIterator[Step]:
    """Run `task` on the agent served at `config.host`, yielding every step.

    A task is created through the agent API, the input artifacts (if any) are
    uploaded, and the task is then stepped until the agent returns a falsy step
    or one flagged `is_last`. Afterwards the agent's artifacts are downloaded
    into `config.temp_folder`.

    Args:
        task: The task input sent to the agent.
        config: Benchmark configuration; `host` and `temp_folder` are read.
        timeout: Time limit in seconds, measured from just before the task is
            created. Time spent by the consumer between iterations counts too.
        artifacts_location: Directory holding the "artifacts_in" (and, in mock
            mode, "artifacts_out") folders. When falsy, no artifacts are
            uploaded or downloaded.

    Yields:
        Each step returned by the agent, in order.

    Raises:
        TimeoutError: If the limit is exceeded. The check runs after a step has
            been yielded, so the step that overran is still delivered first.
            The agent's artifacts are not downloaded in that case.
    """
    configuration = Configuration(host=config.host)
    async with ApiClient(configuration) as api_client:
        api_instance = AgentApi(api_client)
        task_request_body = TaskRequestBody(input=task)

        # Use the monotonic clock: it is unaffected by system clock adjustments,
        # which could otherwise trigger the timeout early or suppress it.
        start_time = time.monotonic()
        response = await api_instance.create_agent_task(
            task_request_body=task_request_body
        )
        task_id = response.task_id

        if artifacts_location:
            await upload_artifacts(
                api_instance, artifacts_location, task_id, "artifacts_in"
            )

        while True:
            step = await api_instance.execute_agent_task_step(task_id=task_id)
            yield step

            if time.monotonic() - start_time > timeout:
                raise TimeoutError("Time limit exceeded")
            if not step or step.is_last:
                break

        if artifacts_location:
            # In "mock" mode, we cheat by giving the correct artifacts to pass the test
            if os.getenv("IS_MOCK"):
                await upload_artifacts(
                    api_instance, artifacts_location, task_id, "artifacts_out"
                )

            await download_agent_artifacts_into_folder(
                api_instance, task_id, config.temp_folder
            )


async def download_agent_artifacts_into_folder(
    api_instance: AgentApi, task_id: str, folder: Path
):
    """Download all artifacts of task `task_id` into `folder`.

    An artifact without a `relative_path` is written directly into `folder`.
    Otherwise it is written into the parent directory of
    `folder / relative_path` (a single leading "/" is stripped first), under
    its `file_name`.
    NOTE(review): taking `.parent` presumes `relative_path` includes the file
    name itself - confirm against the agent's artifact listing.

    Args:
        api_instance: Agent API client used to list and download artifacts.
        task_id: ID of the task whose artifacts are fetched.
        folder: Base directory for the downloads; created if missing.
    """
    artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)

    for artifact in artifacts.artifacts:
        # Resolve the destination per artifact, always starting from the
        # caller's base folder. Rebinding `folder` itself would make every
        # later artifact resolve relative to the previous artifact's directory.
        target_dir = folder
        if artifact.relative_path:
            # Drop one leading "/" so the join below cannot discard `folder`
            # (joining an absolute path replaces the left-hand side).
            path: str = (
                artifact.relative_path
                if not artifact.relative_path.startswith("/")
                else artifact.relative_path[1:]
            )
            target_dir = (folder / path).parent

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        target_dir.mkdir(parents=True, exist_ok=True)

        file_path = target_dir / artifact.file_name
        logger.debug(
            f"Downloading agent artifact {artifact.file_name} to {target_dir}"
        )
        # Fetch first, then open: a failed download must not leave behind an
        # empty (or truncated pre-existing) file.
        content = await api_instance.download_agent_task_artifact(
            task_id=task_id, artifact_id=artifact.artifact_id
        )
        with open(file_path, "wb") as f:
            f.write(content)


async def upload_artifacts(
    api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
) -> None:
    """Upload the files listed for `type` under `artifacts_location` to a task.

    Each file returned by `get_list_of_file_paths(artifacts_location, type)` is
    uploaded together with a `relative_path`: the "/"-joined directory of the
    file relative to `<artifacts_location>/<type>`, or `None` when the file
    sits directly in that directory.

    Args:
        api_instance: Agent API client used for the uploads.
        artifacts_location: Directory containing the `type` folder.
        task_id: ID of the task that receives the artifacts.
        type: Name of the artifact folder, e.g. "artifacts_in".
    """
    base_dir = Path(artifacts_location) / type

    for file_path in get_list_of_file_paths(artifacts_location, type):
        try:
            # Path-based derivation: independent of the OS path separator and
            # not confused by ancestor directories whose name ends with `type`.
            relative_parts = Path(file_path).relative_to(base_dir).parts
            parent_dir = "/".join(relative_parts[:-1])
        except ValueError:
            # The file is not located under `base_dir` (presumably the helper
            # returned a differently-rooted path); keep the previous
            # string-based derivation for that case.
            tail = Path(file_path).as_posix().split(f"{type}/", 1)[-1]
            parent_dir = "/".join(tail.split("/")[:-1])

        # An empty directory component means "no relative path".
        relative_path: Optional[str] = parent_dir or None

        await api_instance.upload_agent_task_artifacts(
            task_id=task_id, file=str(file_path), relative_path=relative_path
        )