aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/autogpt/agents/features/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/agents/features/context.py')
-rw-r--r--autogpts/autogpt/autogpt/agents/features/context.py82
1 files changed, 82 insertions, 0 deletions
diff --git a/autogpts/autogpt/autogpt/agents/features/context.py b/autogpts/autogpt/autogpt/agents/features/context.py
new file mode 100644
index 000000000..748635ddc
--- /dev/null
+++ b/autogpts/autogpt/autogpt/agents/features/context.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+ from autogpt.core.prompting import ChatPrompt
+ from autogpt.models.context_item import ContextItem
+
+ from ..base import BaseAgent
+
+from autogpt.core.resource.model_providers import ChatMessage
+
+
+class AgentContext:
+ items: list[ContextItem]
+
+ def __init__(self, items: Optional[list[ContextItem]] = None):
+ self.items = items or []
+
+ def __bool__(self) -> bool:
+ return len(self.items) > 0
+
+ def __contains__(self, item: ContextItem) -> bool:
+ return any([i.source == item.source for i in self.items])
+
+ def add(self, item: ContextItem) -> None:
+ self.items.append(item)
+
+ def close(self, index: int) -> None:
+ self.items.pop(index - 1)
+
+ def clear(self) -> None:
+ self.items.clear()
+
+ def format_numbered(self) -> str:
+ return "\n\n".join([f"{i}. {c.fmt()}" for i, c in enumerate(self.items, 1)])
+
+
+class ContextMixin:
+ """Mixin that adds context support to a BaseAgent subclass"""
+
+ context: AgentContext
+
+ def __init__(self, **kwargs: Any):
+ self.context = AgentContext()
+
+ super(ContextMixin, self).__init__(**kwargs)
+
+ def build_prompt(
+ self,
+ *args: Any,
+ extra_messages: Optional[list[ChatMessage]] = None,
+ **kwargs: Any,
+ ) -> ChatPrompt:
+ if not extra_messages:
+ extra_messages = []
+
+ # Add context section to prompt
+ if self.context:
+ extra_messages.insert(
+ 0,
+ ChatMessage.system(
+ "## Context\n"
+ f"{self.context.format_numbered()}\n\n"
+ "When a context item is no longer needed and you are not done yet, "
+ "you can hide the item by specifying its number in the list above "
+ "to `hide_context_item`.",
+ ),
+ )
+
+ return super(ContextMixin, self).build_prompt(
+ *args,
+ extra_messages=extra_messages,
+ **kwargs,
+ ) # type: ignore
+
+
+def get_agent_context(agent: BaseAgent) -> AgentContext | None:
+ if isinstance(agent, ContextMixin):
+ return agent.context
+
+ return None