diff options
Diffstat (limited to 'autogpts/forge/forge/db.py')
-rw-r--r-- | autogpts/forge/forge/db.py | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/autogpts/forge/forge/db.py b/autogpts/forge/forge/db.py index 10e78dc5a..cff096379 100644 --- a/autogpts/forge/forge/db.py +++ b/autogpts/forge/forge/db.py @@ -1,16 +1,14 @@ -from .sdk import AgentDB, ForgeLogger, NotFoundError, Base -from sqlalchemy.exc import SQLAlchemyError - import datetime -from sqlalchemy import ( - Column, - DateTime, - String, -) import uuid +from sqlalchemy import Column, DateTime, String +from sqlalchemy.exc import SQLAlchemyError + +from .sdk import AgentDB, Base, ForgeLogger, NotFoundError + LOG = ForgeLogger(__name__) + class ChatModel(Base): __tablename__ = "chat" msg_id = Column(String, primary_key=True, index=True) @@ -22,6 +20,7 @@ class ChatModel(Base): DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow ) + class ActionModel(Base): __tablename__ = "action" action_id = Column(String, primary_key=True, index=True) @@ -35,10 +34,9 @@ class ActionModel(Base): class ForgeDatabase(AgentDB): - async def add_chat_history(self, task_id, messages): for message in messages: - await self.add_chat_message(task_id, message['role'], message['content']) + await self.add_chat_message(task_id, message["role"], message["content"]) async def add_chat_message(self, task_id, role, content): if self.debug_enabled: @@ -55,7 +53,9 @@ class ForgeDatabase(AgentDB): session.commit() session.refresh(mew_msg) if self.debug_enabled: - LOG.debug(f"Created new Chat message with task_id: {mew_msg.msg_id}") + LOG.debug( + f"Created new Chat message with task_id: {mew_msg.msg_id}" + ) return mew_msg except SQLAlchemyError as e: LOG.error(f"SQLAlchemy error while creating task: {e}") @@ -65,7 +65,7 @@ class ForgeDatabase(AgentDB): except Exception as e: LOG.error(f"Unexpected error while creating task: {e}") raise - + async def get_chat_history(self, task_id): if self.debug_enabled: LOG.debug(f"Getting chat history with task_id: {task_id}") @@ -80,9 +80,7 @@ class ForgeDatabase(AgentDB): return [{"role": m.role, "content": m.content} for m in messages] else: - LOG.error( - f"Chat history not found with task_id: {task_id}" - ) + LOG.error(f"Chat history not found with task_id: {task_id}") raise NotFoundError("Chat history not found") except SQLAlchemyError as e: LOG.error(f"SQLAlchemy error while getting chat history: {e}") @@ -92,7 +90,7 @@ class ForgeDatabase(AgentDB): except Exception as e: LOG.error(f"Unexpected error while getting chat history: {e}") raise - + async def create_action(self, task_id, name, args): try: with self.Session() as session: @@ -106,7 +104,9 @@ class ForgeDatabase(AgentDB): session.commit() session.refresh(new_action) if self.debug_enabled: - LOG.debug(f"Created new Action with task_id: {new_action.action_id}") + LOG.debug( + f"Created new Action with task_id: {new_action.action_id}" + ) return new_action except SQLAlchemyError as e: LOG.error(f"SQLAlchemy error while creating action: {e}") @@ -131,9 +131,7 @@ class ForgeDatabase(AgentDB): return [{"name": a.name, "args": a.args} for a in actions] else: - LOG.error( - f"Action history not found with task_id: {task_id}" - ) + LOG.error(f"Action history not found with task_id: {task_id}") raise NotFoundError("Action history not found") except SQLAlchemyError as e: LOG.error(f"SQLAlchemy error while getting action history: {e}") |