aboutsummaryrefslogtreecommitdiff
path: root/autogpts/forge/forge/db.py
blob: 10e78dc5a1e7d8df951b32f742ccbbc63dcdf941 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from .sdk import AgentDB, ForgeLogger, NotFoundError, Base
from sqlalchemy.exc import SQLAlchemyError

import datetime
from sqlalchemy import (
    Column,
    DateTime,
    String,
)
import uuid

# Module-level logger used by the ForgeDatabase methods below.
LOG = ForgeLogger(__name__)

class ChatModel(Base):
    """ORM mapping for the ``chat`` table.

    ``ForgeDatabase.add_chat_message`` inserts one row per message, and
    ``ForgeDatabase.get_chat_history`` reads rows back by ``task_id`` ordered
    by ``created_at``.
    """

    __tablename__ = "chat"

    # Row identity and the task the message is filed under.
    msg_id = Column(String, primary_key=True, index=True)
    task_id = Column(String)

    # Message payload.
    role = Column(String)
    content = Column(String)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    # ``default`` supplies the value on insert, ``onupdate`` on update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )

class ActionModel(Base):
    """ORM mapping for the ``action`` table.

    ``ForgeDatabase.create_action`` inserts one row per action, and
    ``ForgeDatabase.get_action_history`` reads rows back by ``task_id``
    ordered by ``created_at``.
    """

    __tablename__ = "action"

    # Row identity and the task the action is filed under.
    action_id = Column(String, primary_key=True, index=True)
    task_id = Column(String)

    # Action payload; ``args`` is a plain string column.
    name = Column(String)
    args = Column(String)

    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    # ``default`` supplies the value on insert, ``onupdate`` on update.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    modified_at = Column(
        DateTime,
        default=datetime.datetime.utcnow,
        onupdate=datetime.datetime.utcnow,
    )


class ForgeDatabase(AgentDB):
    """Persistence helpers for per-task chat messages and actions.

    Extends ``AgentDB`` and reads two attributes that are not defined in this
    class: ``self.Session`` (called to open a session that is used as a
    context manager) and ``self.debug_enabled`` (gates the debug logging).

    Every method is declared ``async``, but the session calls are not awaited,
    so they execute synchronously inside the coroutine.
    """

    async def add_chat_history(self, task_id, messages):
        """Store every message in ``messages`` under ``task_id``, in order.

        Args:
            task_id: Identifier the messages are filed under.
            messages: Iterable of mappings providing ``'role'`` and
                ``'content'`` keys.

        Each message is written and committed individually through
        :meth:`add_chat_message`, so a failure part-way through leaves the
        earlier messages stored.
        """
        for message in messages:
            await self.add_chat_message(task_id, message['role'], message['content'])

    async def add_chat_message(self, task_id, role, content):
        """Insert a single chat message and return the stored row.

        Args:
            task_id: Identifier the message is filed under.
            role: Value stored in the ``role`` column.
            content: Value stored in the ``content`` column.

        Returns:
            The committed ``ChatModel`` instance, with a freshly generated
            UUID4 string as ``msg_id``.

        Raises:
            SQLAlchemyError: Database failure; logged, then re-raised.
            Exception: Any other failure; logged, then re-raised.
        """
        if self.debug_enabled:
            LOG.debug(f"Creating new chat message for task_id: {task_id}")
        try:
            with self.Session() as session:
                new_msg = ChatModel(
                    msg_id=str(uuid.uuid4()),
                    task_id=task_id,
                    role=role,
                    content=content,
                )
                session.add(new_msg)
                session.commit()
                # Reload the row's state from the database before the session
                # is closed and the object is handed back to the caller.
                session.refresh(new_msg)
                if self.debug_enabled:
                    LOG.debug(f"Created new chat message with msg_id: {new_msg.msg_id}")
                return new_msg
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating chat message: {e}")
            raise
        except NotFoundError:
            # Passed through untouched so it is not logged as "unexpected".
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating chat message: {e}")
            raise

    async def get_chat_history(self, task_id):
        """Return the chat messages stored for ``task_id``, oldest first.

        Args:
            task_id: Identifier whose messages are looked up.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts ordered by
            ``created_at``.

        Raises:
            NotFoundError: No message is stored for ``task_id``.
            SQLAlchemyError: Database failure; logged, then re-raised.
            Exception: Any other failure; logged, then re-raised.
        """
        if self.debug_enabled:
            LOG.debug(f"Getting chat history with task_id: {task_id}")
        try:
            with self.Session() as session:
                messages = (
                    session.query(ChatModel)
                    .filter(ChatModel.task_id == task_id)
                    .order_by(ChatModel.created_at)
                    .all()
                )
                # An empty result is reported as "not found", not as [].
                if not messages:
                    LOG.error(
                        f"Chat history not found with task_id: {task_id}"
                    )
                    raise NotFoundError("Chat history not found")
                return [{"role": m.role, "content": m.content} for m in messages]
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting chat history: {e}")
            raise
        except NotFoundError:
            # Already logged above; re-raise without the "unexpected" log line.
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting chat history: {e}")
            raise

    async def create_action(self, task_id, name, args):
        """Insert a single action record and return the stored row.

        Args:
            task_id: Identifier the action is filed under.
            name: Value stored in the ``name`` column.
            args: Action arguments; persisted as ``str(args)``.

        Returns:
            The committed ``ActionModel`` instance, with a freshly generated
            UUID4 string as ``action_id``.

        Raises:
            SQLAlchemyError: Database failure; logged, then re-raised.
            Exception: Any other failure; logged, then re-raised.
        """
        if self.debug_enabled:
            LOG.debug(f"Creating new action for task_id: {task_id}")
        try:
            with self.Session() as session:
                new_action = ActionModel(
                    action_id=str(uuid.uuid4()),
                    task_id=task_id,
                    name=name,
                    # Stored in its str() form; get_action_history returns
                    # that string as-is.
                    args=str(args),
                )
                session.add(new_action)
                session.commit()
                # Reload the row's state from the database before the session
                # is closed and the object is handed back to the caller.
                session.refresh(new_action)
                if self.debug_enabled:
                    LOG.debug(f"Created new action with action_id: {new_action.action_id}")
                return new_action
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating action: {e}")
            raise
        except NotFoundError:
            # Passed through untouched so it is not logged as "unexpected".
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating action: {e}")
            raise

    async def get_action_history(self, task_id):
        """Return the actions stored for ``task_id``, oldest first.

        Args:
            task_id: Identifier whose actions are looked up.

        Returns:
            A list of ``{"name": ..., "args": ...}`` dicts ordered by
            ``created_at``; ``args`` is the string that was stored.

        Raises:
            NotFoundError: No action is stored for ``task_id``.
            SQLAlchemyError: Database failure; logged, then re-raised.
            Exception: Any other failure; logged, then re-raised.
        """
        if self.debug_enabled:
            LOG.debug(f"Getting action history with task_id: {task_id}")
        try:
            with self.Session() as session:
                actions = (
                    session.query(ActionModel)
                    .filter(ActionModel.task_id == task_id)
                    .order_by(ActionModel.created_at)
                    .all()
                )
                # An empty result is reported as "not found", not as [].
                if not actions:
                    LOG.error(
                        f"Action history not found with task_id: {task_id}"
                    )
                    raise NotFoundError("Action history not found")
                return [{"name": a.name, "args": a.args} for a in actions]
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting action history: {e}")
            raise
        except NotFoundError:
            # Already logged above; re-raise without the "unexpected" log line.
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting action history: {e}")
            raise