diff options
Diffstat (limited to 'autogpts/forge/forge/sdk/db.py')
-rw-r--r-- | autogpts/forge/forge/sdk/db.py | 77 |
1 files changed, 52 insertions, 25 deletions
diff --git a/autogpts/forge/forge/sdk/db.py b/autogpts/forge/forge/sdk/db.py index 0b35139b9..ce4d22f6f 100644 --- a/autogpts/forge/forge/sdk/db.py +++ b/autogpts/forge/forge/sdk/db.py @@ -7,7 +7,7 @@ IT IS NOT ADVISED TO USE THIS IN PRODUCTION! import datetime import math import uuid -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple from sqlalchemy import ( JSON, @@ -23,7 +23,7 @@ from sqlalchemy.orm import DeclarativeBase, joinedload, relationship, sessionmak from .errors import NotFoundError from .forge_log import ForgeLogger -from .schema import Artifact, Pagination, Status, Step, StepRequestBody, Task +from .model import Artifact, Pagination, Status, Step, StepRequestBody, Task LOG = ForgeLogger(__name__) @@ -259,7 +259,7 @@ class AgentDB: LOG.error(f"Unexpected error while creating step: {e}") raise - async def get_task(self, task_id: int) -> Task: + async def get_task(self, task_id: str) -> Task: """Get a task by its id""" if self.debug_enabled: LOG.debug(f"Getting task with task_id: {task_id}") @@ -311,6 +311,29 @@ class AgentDB: LOG.error(f"Unexpected error while getting step: {e}") raise + async def get_artifact(self, artifact_id: str) -> Artifact: + if self.debug_enabled: + LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}") + try: + with self.Session() as session: + if ( + artifact_model := session.query(ArtifactModel) + .filter_by(artifact_id=artifact_id) + .first() + ): + return convert_to_artifact(artifact_model) + else: + LOG.error(f"Artifact not found with and artifact_id: {artifact_id}") + raise NotFoundError("Artifact not found") + except SQLAlchemyError as e: + LOG.error(f"SQLAlchemy error while getting artifact: {e}") + raise + except NotFoundError as e: + raise + except Exception as e: + LOG.error(f"Unexpected error while getting artifact: {e}") + raise + async def update_step( self, task_id: str, @@ -353,28 +376,32 @@ class AgentDB: LOG.error(f"Unexpected error while getting step: {e}") raise - async def get_artifact(self, artifact_id: str) -> Artifact: - if self.debug_enabled: - LOG.debug(f"Getting artifact with and artifact_id: {artifact_id}") - try: - with self.Session() as session: - if ( - artifact_model := session.query(ArtifactModel) - .filter_by(artifact_id=artifact_id) - .first() - ): - return convert_to_artifact(artifact_model) - else: - LOG.error(f"Artifact not found with and artifact_id: {artifact_id}") - raise NotFoundError("Artifact not found") - except SQLAlchemyError as e: - LOG.error(f"SQLAlchemy error while getting artifact: {e}") - raise - except NotFoundError as e: - raise - except Exception as e: - LOG.error(f"Unexpected error while getting artifact: {e}") - raise + async def update_artifact( + self, + artifact_id: str, + *, + file_name: str = "", + relative_path: str = "", + agent_created: Optional[Literal[True]] = None, + ) -> Artifact: + LOG.debug(f"Updating artifact with artifact_id: {artifact_id}") + with self.Session() as session: + if ( + artifact := session.query(ArtifactModel) + .filter_by(artifact_id=artifact_id) + .first() + ): + if file_name: + artifact.file_name = file_name + if relative_path: + artifact.relative_path = relative_path + if agent_created: + artifact.agent_created = agent_created + session.commit() + return await self.get_artifact(artifact_id) + else: + LOG.error(f"Artifact not found with artifact_id: {artifact_id}") + raise NotFoundError("Artifact not found") async def list_tasks( self, page: int = 1, per_page: int = 10 |