aboutsummaryrefslogtreecommitdiff
path: root/autogpt/memory/message_history.py
blob: 30dbbb809edc3e2761905988d1e9591a1fabc816 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from __future__ import annotations

import copy
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from autogpt.agent import Agent

from autogpt.config import Config
from autogpt.json_utils.utilities import extract_json_from_response
from autogpt.llm.base import ChatSequence, Message
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
from autogpt.llm.utils import (
    count_message_tokens,
    count_string_tokens,
    create_chat_completion,
)
from autogpt.logs import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME, logger


@dataclass
class MessageHistory(ChatSequence):
    """A ChatSequence that keeps an LLM-generated running summary of old messages.

    Messages that are no longer part of the active context can be folded into
    `summary` (see `trim_messages` / `update_running_summary`); the summary is
    presented back to the model through `summary_message()`.
    """

    # Default `max_tokens` for each summarization completion.
    max_summary_tlength: int = 500
    # Optional agent; when set, summarization prompts and results are written to
    # its cycle log.
    agent: Optional[Agent] = None
    # The running summary text (the prompt asks for 1st person past tense).
    summary: str = "I was created"
    # Index (in iteration order) of the last message folded into `summary`.
    last_trimmed_index: int = 0

    SUMMARIZATION_PROMPT = '''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.

You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.

Summary So Far:
"""
{summary}
"""

Latest Development:
"""
{new_events}
"""
'''

    def trim_messages(
        self, current_message_chain: list[Message], config: Config
    ) -> tuple[Message, list[Message]]:
        """
        Fold messages that have left the context window into the running summary.

        A message is "trimmed" when its index in this history is greater than
        `last_trimmed_index` and it is not present in `current_message_chain`.
        Trimmed messages are summarized via `update_running_summary`, and
        `last_trimmed_index` is advanced to the last of them.

        NOTE: the comparison is strictly greater-than and `last_trimmed_index`
        defaults to 0, so the message at index 0 is never considered.

        Args:
            current_message_chain (list[Message]): The messages currently in the context.
            config (Config): The config to use.

        Returns:
            Message: A message with the new running summary after adding the trimmed messages.
            list[Message]: The trimmed messages in history order (empty if there were none).
        """
        # Keep each candidate's index alongside it: looking the last one up again
        # with list.index() would return the FIRST equal message, which is wrong
        # when the history contains duplicate messages.
        trimmed = [
            (i, msg)
            for i, msg in enumerate(self)
            if i > self.last_trimmed_index and msg not in current_message_chain
        ]

        if not trimmed:
            return self.summary_message(), []

        trimmed_messages = [msg for _, msg in trimmed]
        new_summary_message = self.update_running_summary(
            new_events=trimmed_messages, config=config
        )

        # Remember how far we got so these messages are not summarized again.
        self.last_trimmed_index = trimmed[-1][0]

        return new_summary_message, trimmed_messages

    def per_cycle(self, messages: list[Message] | None = None):
        """
        Iterate over complete agent cycles in a list of messages.

        Args:
            messages: The messages to scan; defaults to this history's messages
                when None (an explicit empty list yields nothing).

        Yields:
            Message | None: the "user" message directly before the AI response, or None
            Message: a message of type "ai_response" whose content parses to a non-empty object
            Message: the message of type "action_result" directly after it
        """
        if messages is None:
            messages = self.messages

        for i in range(len(messages) - 1):
            ai_message = messages[i]
            if ai_message.type != "ai_response":
                continue
            user_message = (
                messages[i - 1] if i > 0 and messages[i - 1].role == "user" else None
            )
            result_message = messages[i + 1]

            # Explicit checks rather than `assert`: asserts are stripped when
            # Python runs with -O, which would let malformed cycles through.
            if extract_json_from_response(ai_message.content) == {}:
                problem = "AI response is not a valid JSON object"
            elif result_message.type != "action_result":
                problem = "AI response is not followed by an action result"
            else:
                yield user_message, ai_message, result_message
                continue

            # Clamp the start so i == 0 does not wrap around to the list's end.
            logger.debug(
                f"Invalid item in message history: {problem}; "
                f"Messages: {messages[max(i - 1, 0):i + 2]}"
            )

    def summary_message(self) -> Message:
        """Return the running summary wrapped in a "system" Message."""
        return Message(
            "system",
            f"This reminds you of these events from your past: \n{self.summary}",
        )

    def update_running_summary(
        self,
        new_events: list[Message],
        config: Config,
        max_summary_length: Optional[int] = None,
    ) -> Message:
        """
        Fold `new_events` into the running summary and return it as a Message.

        The events are deep-copied and rewritten before summarization:
        "assistant" messages become "you" (with their "thoughts" entry removed),
        "system" messages become "your computer", and "user" messages are
        dropped. The remaining events are sent to the `fast_llm` model in
        batches that fit its token budget; each batch updates `self.summary`.

        Args:
            new_events: The latest messages to add to the summary. Not modified.
            config: The config; `config.fast_llm` selects the summarization model.
            max_summary_length: `max_tokens` for each summary completion; when
                falsy, `self.max_summary_tlength` is used.

        Returns:
            Message: a "system" Message containing the updated running summary
            (the same shape as `summary_message()`).
        """
        if not new_events:
            return self.summary_message()
        if not max_summary_length:
            max_summary_length = self.max_summary_tlength

        events = MessageHistory._prepare_events_for_summary(new_events, config)

        summ_model = OPEN_AI_CHAT_MODELS[config.fast_llm]

        # Token budget for one summarization request. Every quantity here is in
        # tokens; the empty template used to be measured in characters.
        prompt_template_tlength = count_string_tokens(
            MessageHistory.SUMMARIZATION_PROMPT.format(summary="", new_events=""),
            summ_model.name,
        )
        max_input_tokens = summ_model.max_tokens - max_summary_length
        summary_tlength = count_string_tokens(self.summary, summ_model.name)
        batch: list[Message] = []
        batch_tlength = 0

        # TODO: Put a cap on length of total new events and drop some previous events to
        # save API cost. Need to think thru more how to do it without losing the context.
        for event in events:
            event_tlength = count_message_tokens(event, summ_model.name)

            # Only flush a non-empty batch: an oversized first event would
            # otherwise trigger a pointless summarization of an empty batch.
            if (
                batch
                and batch_tlength + event_tlength
                > max_input_tokens - prompt_template_tlength - summary_tlength
            ):
                # The batch is full. Summarize it and start a new one.
                self.summarize_batch(batch, config, max_summary_length)
                summary_tlength = count_string_tokens(self.summary, summ_model.name)
                batch = [event]
                batch_tlength = event_tlength
            else:
                batch.append(event)
                batch_tlength += event_tlength

        if batch:
            # There's an unprocessed batch. Summarize it.
            self.summarize_batch(batch, config, max_summary_length)

        return self.summary_message()

    @staticmethod
    def _prepare_events_for_summary(
        events: list[Message], config: Config
    ) -> list[Message]:
        """Return rewritten deep copies of `events` for the summarization prompt."""
        prepared: list[Message] = []
        for event in copy.deepcopy(events):
            role = event.role.lower()
            if role == "user":
                # Delete all user messages. Building a new list (instead of
                # removing from the list being iterated) avoids skipping the
                # message that follows each removed one.
                continue
            if role == "assistant":
                # "you" produces much better first person past tense results.
                event.role = "you"
                event.content = MessageHistory._strip_thoughts(event.content, config)
            elif role == "system":
                event.role = "your computer"
            prepared.append(event)
        return prepared

    @staticmethod
    def _strip_thoughts(content: str, config: Config) -> str:
        """Return `content` re-serialized without its "thoughts" entry.

        If the content cannot be parsed into a non-empty dict, it is returned
        unchanged rather than being replaced by "{}".
        """
        try:
            content_dict = extract_json_from_response(content)
        except json.JSONDecodeError as e:
            logger.error(f"Error: Invalid JSON: {e}")
            if config.debug_mode:
                logger.error(f"{content}")
            return content

        # extract_json_from_response returns {} when it cannot parse (the same
        # convention per_cycle relies on); keep the raw text so nothing is lost.
        if not isinstance(content_dict, dict) or not content_dict:
            return content

        content_dict.pop("thoughts", None)
        return json.dumps(content_dict)

    def summarize_batch(
        self, new_events_batch: list[Message], config: Config, max_output_length: int
    ) -> None:
        """
        Ask the `fast_llm` model to merge a batch of events into `self.summary`.

        The prompt and the resulting summary are logged through the agent's
        cycle log handler when `self.agent` is set.

        Args:
            new_events_batch: The (already rewritten) events to add.
            config: The config; `config.fast_llm` selects the model.
            max_output_length: `max_tokens` for the completion.
        """
        # NOTE(review): the list is interpolated via its str()/repr form, so the
        # model sees Message reprs rather than plain text — confirm intended.
        prompt_text = MessageHistory.SUMMARIZATION_PROMPT.format(
            summary=self.summary, new_events=new_events_batch
        )

        prompt = ChatSequence.for_model(config.fast_llm, [Message("user", prompt_text)])
        if self.agent:
            self.agent.log_cycle_handler.log_cycle(
                self.agent.ai_config.ai_name,
                self.agent.created_at,
                self.agent.cycle_count,
                prompt.raw(),
                PROMPT_SUMMARY_FILE_NAME,
            )

        self.summary = create_chat_completion(
            prompt, config, max_tokens=max_output_length
        ).content

        if self.agent:
            self.agent.log_cycle_handler.log_cycle(
                self.agent.ai_config.ai_name,
                self.agent.created_at,
                self.agent.cycle_count,
                self.summary,
                SUMMARY_FILE_NAME,
            )