diff options
Diffstat (limited to 'autogpts/forge/forge/sdk/db_test.py')
-rw-r--r-- | autogpts/forge/forge/sdk/db_test.py | 331 |
1 files changed, 331 insertions, 0 deletions
diff --git a/autogpts/forge/forge/sdk/db_test.py b/autogpts/forge/forge/sdk/db_test.py new file mode 100644 index 000000000..14330435e --- /dev/null +++ b/autogpts/forge/forge/sdk/db_test.py @@ -0,0 +1,331 @@ +import os +import sqlite3 +from datetime import datetime + +import pytest + +from forge.sdk.db import ( + AgentDB, + ArtifactModel, + StepModel, + TaskModel, + convert_to_artifact, + convert_to_step, + convert_to_task, +) +from forge.sdk.errors import NotFoundError as DataNotFoundError +from forge.sdk.model import ( + Artifact, + Status, + Step, + StepRequestBody, + Task, +) + + +@pytest.mark.asyncio +def test_table_creation(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + + conn = sqlite3.connect("test_db.sqlite3") + cursor = conn.cursor() + + # Test for tasks table existence + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'") + assert cursor.fetchone() is not None + + # Test for steps table existence + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='steps'") + assert cursor.fetchone() is not None + + # Test for artifacts table existence + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='artifacts'" + ) + assert cursor.fetchone() is not None + + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_task_schema(): + now = datetime.now() + task = Task( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + input="Write the words you receive to the file 'output.txt'.", + created_at=now, + modified_at=now, + artifacts=[ + Artifact( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + agent_created=True, + file_name="main.py", + relative_path="python/code/", + created_at=now, + modified_at=now, + ) + ], + ) + assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert task.input == "Write the words you receive to the file 'output.txt'." + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + + +@pytest.mark.asyncio +async def test_step_schema(): + now = datetime.now() + step = Step( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e", + created_at=now, + modified_at=now, + name="Write to file", + input="Write the words you receive to the file 'output.txt'.", + status=Status.created, + output="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>", + artifacts=[ + Artifact( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + file_name="main.py", + relative_path="python/code/", + created_at=now, + modified_at=now, + agent_created=True, + ) + ], + is_last=False, + ) + assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e" + assert step.name == "Write to file" + assert step.status == Status.created + assert ( + step.output + == "I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')>" + ) + assert len(step.artifacts) == 1 + assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert step.is_last == False + + +@pytest.mark.asyncio +async def test_convert_to_task(): + now = datetime.now() + task_model = TaskModel( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + created_at=now, + modified_at=now, + input="Write the words you receive to the file 'output.txt'.", + artifacts=[ + ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + relative_path="file:///path/to/main.py", + agent_created=True, + file_name="main.py", + ) + ], + ) + task = convert_to_task(task_model) + assert task.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert task.input == "Write the words you receive to the file 'output.txt'." + assert len(task.artifacts) == 1 + assert task.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + + +@pytest.mark.asyncio +async def test_convert_to_step(): + now = datetime.now() + step_model = StepModel( + task_id="50da533e-3904-4401-8a07-c49adf88b5eb", + step_id="6bb1801a-fd80-45e8-899a-4dd723cc602e", + created_at=now, + modified_at=now, + name="Write to file", + status="created", + input="Write the words you receive to the file 'output.txt'.", + artifacts=[ + ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + relative_path="file:///path/to/main.py", + agent_created=True, + file_name="main.py", + ) + ], + is_last=False, + ) + step = convert_to_step(step_model) + assert step.task_id == "50da533e-3904-4401-8a07-c49adf88b5eb" + assert step.step_id == "6bb1801a-fd80-45e8-899a-4dd723cc602e" + assert step.name == "Write to file" + assert step.status == Status.created + assert len(step.artifacts) == 1 + assert step.artifacts[0].artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert step.is_last == False + + +@pytest.mark.asyncio +async def test_convert_to_artifact(): + now = datetime.now() + artifact_model = ArtifactModel( + artifact_id="b225e278-8b4c-4f99-a696-8facf19f0e56", + created_at=now, + modified_at=now, + relative_path="file:///path/to/main.py", + agent_created=True, + file_name="main.py", + ) + artifact = convert_to_artifact(artifact_model) + assert artifact.artifact_id == "b225e278-8b4c-4f99-a696-8facf19f0e56" + assert artifact.relative_path == "file:///path/to/main.py" + assert artifact.agent_created == True + + +@pytest.mark.asyncio +async def test_create_task(): + # Having issues with pytest fixture so added setup and teardown in each test as a rapid workaround + # TODO: Fix this! + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + + task = await agent_db.create_task("task_input") + assert task.input == "task_input" + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_create_and_get_task(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + task = await agent_db.create_task("test_input") + fetched_task = await agent_db.get_task(task.task_id) + assert fetched_task.input == "test_input" + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_get_task_not_found(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + with pytest.raises(DataNotFoundError): + await agent_db.get_task(9999) + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_create_and_get_step(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + task = await agent_db.create_task("task_input") + step_input = StepInput(type="python/code") + request = StepRequestBody(input="test_input debug", additional_input=step_input) + step = await agent_db.create_step(task.task_id, request) + step = await agent_db.get_step(task.task_id, step.step_id) + assert step.input == "test_input debug" + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_updating_step(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + created_task = await agent_db.create_task("task_input") + step_input = StepInput(type="python/code") + request = StepRequestBody(input="test_input debug", additional_input=step_input) + created_step = await agent_db.create_step(created_task.task_id, request) + await agent_db.update_step(created_task.task_id, created_step.step_id, "completed") + + step = await agent_db.get_step(created_task.task_id, created_step.step_id) + assert step.status.value == "completed" + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_get_step_not_found(): + db_name = "sqlite:///test_db.sqlite3" + agent_db = AgentDB(db_name) + with pytest.raises(DataNotFoundError): + await agent_db.get_step(9999, 9999) + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_get_artifact(): + db_name = "sqlite:///test_db.sqlite3" + db = AgentDB(db_name) + + # Given: A task and its corresponding artifact + task = await db.create_task("test_input debug") + step_input = StepInput(type="python/code") + requst = StepRequestBody(input="test_input debug", additional_input=step_input) + + step = await db.create_step(task.task_id, requst) + + # Create an artifact + artifact = await db.create_artifact( + task_id=task.task_id, + file_name="test_get_artifact_sample_file.txt", + relative_path="file:///path/to/test_get_artifact_sample_file.txt", + agent_created=True, + step_id=step.step_id, + ) + + # When: The artifact is fetched by its ID + fetched_artifact = await db.get_artifact(artifact.artifact_id) + + # Then: The fetched artifact matches the original + assert fetched_artifact.artifact_id == artifact.artifact_id + assert ( + fetched_artifact.relative_path + == "file:///path/to/test_get_artifact_sample_file.txt" + ) + + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_list_tasks(): + db_name = "sqlite:///test_db.sqlite3" + db = AgentDB(db_name) + + # Given: Multiple tasks in the database + task1 = await db.create_task("test_input_1") + task2 = await db.create_task("test_input_2") + + # When: All tasks are fetched + fetched_tasks, pagination = await db.list_tasks() + + # Then: The fetched tasks list includes the created tasks + task_ids = [task.task_id for task in fetched_tasks] + assert task1.task_id in task_ids + assert task2.task_id in task_ids + os.remove(db_name.split("///")[1]) + + +@pytest.mark.asyncio +async def test_list_steps(): + db_name = "sqlite:///test_db.sqlite3" + db = AgentDB(db_name) + + step_input = StepInput(type="python/code") + requst = StepRequestBody(input="test_input debug", additional_input=step_input) + + # Given: A task and multiple steps for that task + task = await db.create_task("test_input") + step1 = await db.create_step(task.task_id, requst) + requst = StepRequestBody(input="step two", additional_input=step_input) + step2 = await db.create_step(task.task_id, requst) + + # When: All steps for the task are fetched + fetched_steps, pagination = await db.list_steps(task.task_id) + + # Then: The fetched steps list includes the created steps + step_ids = [step.step_id for step in fetched_steps] + assert step1.step_id in step_ids + assert step2.step_id in step_ids + os.remove(db_name.split("///")[1]) |