diff options
author | Reinier van der Leer <pwuts@agpt.co> | 2024-01-19 11:41:40 +0100 |
---|---|---|
committer | Reinier van der Leer <pwuts@agpt.co> | 2024-01-19 11:41:40 +0100 |
commit | b238abac52a4f945325603d433b7eade5bb92d2a (patch) | |
tree | 16aa453a7d7bd197a2c94bf12e6016d7db3099df | |
parent | refactor(benchmark): Interface & type consoledation, and arch change, to allo... (diff) | |
download | Auto-GPT-b238abac52a4f945325603d433b7eade5bb92d2a.tar.gz Auto-GPT-b238abac52a4f945325603d433b7eade5bb92d2a.tar.bz2 Auto-GPT-b238abac52a4f945325603d433b7eade5bb92d2a.zip |
feat(forge/db): Add `AgentDB.update_artifact` method
-rw-r--r-- | autogpts/forge/forge/sdk/db.py | 73 |
1 files changed, 50 insertions, 23 deletions
diff --git a/autogpts/forge/forge/sdk/db.py b/autogpts/forge/forge/sdk/db.py index 1af538c62..ce4d22f6f 100644 --- a/autogpts/forge/forge/sdk/db.py +++ b/autogpts/forge/forge/sdk/db.py @@ -7,7 +7,7 @@ IT IS NOT ADVISED TO USE THIS IN PRODUCTION! import datetime import math import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple from sqlalchemy import ( JSON, @@ -311,6 +311,29 @@ class AgentDB: LOG.error(f"Unexpected error while getting step: {e}") raise + async def get_artifact(self, artifact_id: str) -> Artifact: + if self.debug_enabled: + LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}") + try: + with self.Session() as session: + if ( + artifact_model := session.query(ArtifactModel) + .filter_by(artifact_id=artifact_id) + .first() + ): + return convert_to_artifact(artifact_model) + else: + LOG.error(f"Artifact not found with and artifact_id: {artifact_id}") + raise NotFoundError("Artifact not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting artifact: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting artifact: {e}") + raise + async def update_step( self, task_id: str, @@ -353,28 +376,32 @@ class AgentDB: LOG.error(f"Unexpected error while getting step: {e}") raise - async def get_artifact(self, artifact_id: str) -> Artifact: - if self.debug_enabled: - LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}") - try: - with self.Session() as session: - if ( - artifact_model := session.query(ArtifactModel) - .filter_by(artifact_id=artifact_id) - .first() - ): - return convert_to_artifact(artifact_model) - else: - LOG.error(f"Artifact not found with and artifact_id: {artifact_id}") - raise NotFoundError("Artifact not found") - except SQLAlchemyError as e: - LOG.error(f"SQLAlchemy error while getting artifact: {e}") - raise - except NotFoundError as e: - raise - except Exception as e: - LOG.error(f"Unexpected error while getting artifact: {e}") - raise + async def update_artifact( + self, + artifact_id: str, + *, + file_name: str = "", + relative_path: str = "", + agent_created: Optional[Literal[True]] = None, + ) -> Artifact: + LOG.debug(f"Updating artifact with artifact_id: {artifact_id}") + with self.Session() as session: + if ( + artifact := session.query(ArtifactModel) + .filter_by(artifact_id=artifact_id) + .first() + ): + if file_name: + artifact.file_name = file_name + if relative_path: + artifact.relative_path = relative_path + if agent_created: + artifact.agent_created = agent_created + session.commit() + return await self.get_artifact(artifact_id) + else: + LOG.error(f"Artifact not found with artifact_id: {artifact_id}") + raise NotFoundError("Artifact not found") async def list_tasks( self, page: int = 1, per_page: int = 10 |