aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/app/agent_protocol_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/app/agent_protocol_server.py')
-rw-r--r--autogpts/autogpt/autogpt/app/agent_protocol_server.py32
1 files changed, 26 insertions, 6 deletions
diff --git a/autogpts/autogpt/autogpt/app/agent_protocol_server.py b/autogpts/autogpt/autogpt/app/agent_protocol_server.py
index aeb8d8f46..7bebca930 100644
--- a/autogpts/autogpt/autogpt/app/agent_protocol_server.py
+++ b/autogpts/autogpt/autogpt/app/agent_protocol_server.py
@@ -213,12 +213,8 @@ class AgentProtocolServer:
# Execute previously proposed action
if execute_command:
assert execute_command_args is not None
- agent.workspace.on_write_file = lambda path: self.db.create_artifact(
- task_id=step.task_id,
- step_id=step.step_id,
- file_name=path.parts[-1],
- agent_created=True,
- relative_path=str(path),
+ agent.workspace.on_write_file = lambda path: self._on_agent_write_file(
+ task=task, step=step, relative_path=path
)
if step.is_last and execute_command == finish.__name__:
@@ -317,6 +313,30 @@ class AgentProtocolServer:
agent.state.save_to_json_file(agent.file_manager.state_file_path)
return step
+ async def _on_agent_write_file(
+ self, task: Task, step: Step, relative_path: pathlib.Path
+ ) -> None:
+ """
+ Creates an Artifact for the written file, or updates the Artifact if it exists.
+ """
+ if relative_path.is_absolute():
+ raise ValueError(f"File path '{relative_path}' is not relative")
+ for a in task.artifacts or []:
+ if a.relative_path == str(relative_path):
+ logger.debug(f"Updating Artifact after writing to existing file: {a}")
+ if not a.agent_created:
+ await self.db.update_artifact(a.artifact_id, agent_created=True)
+ break
+ else:
+ logger.debug(f"Creating Artifact for new file '{relative_path}'")
+ await self.db.create_artifact(
+ task_id=step.task_id,
+ step_id=step.step_id,
+ file_name=relative_path.parts[-1],
+ agent_created=True,
+ relative_path=str(relative_path),
+ )
+
async def get_step(self, task_id: str, step_id: str) -> Step:
"""
Get a step by ID.