about summary refs log tree commit diff
path: root/autogpts/autogpt/autogpt/memory/vector/providers/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpts/autogpt/autogpt/memory/vector/providers/base.py')
-rw-r--r--autogpts/autogpt/autogpt/memory/vector/providers/base.py79
1 file changed, 79 insertions, 0 deletions
diff --git a/autogpts/autogpt/autogpt/memory/vector/providers/base.py b/autogpts/autogpt/autogpt/memory/vector/providers/base.py
new file mode 100644
index 000000000..8883f1346
--- /dev/null
+++ b/autogpts/autogpt/autogpt/memory/vector/providers/base.py
@@ -0,0 +1,79 @@
+import abc
+import functools
+import logging
+from typing import MutableSet, Sequence
+
+import numpy as np
+
+from autogpt.config.config import Config
+
+from .. import MemoryItem, MemoryItemRelevance
+from ..utils import Embedding, get_embedding
+
+logger = logging.getLogger(__name__)
+
+
class VectorMemoryProvider(MutableSet[MemoryItem]):
    """Abstract base class for vector-backed memory providers.

    Stores :class:`MemoryItem`s and retrieves them by embedding similarity
    to a query string. Concrete subclasses implement storage via the
    ``MutableSet`` interface and may override
    :meth:`score_memories_for_relevance` for performance.
    """

    @abc.abstractmethod
    def __init__(self, config: Config):
        pass

    def get(self, query: str, config: Config) -> MemoryItemRelevance | None:
        """
        Gets the data from the memory that is most relevant to the given query.

        Args:
            query: The query used to retrieve information.
            config: The config Object.

        Returns: The most relevant Memory, or None if the index is empty.
        """
        result = self.get_relevant(query, 1, config)
        return result[0] if result else None

    def get_relevant(
        self, query: str, k: int, config: Config
    ) -> Sequence[MemoryItemRelevance]:
        """
        Returns the top-k most relevant memories for the given query

        Args:
            query: the query to compare stored memories to
            k: the number of relevant memories to fetch
            config: The config Object.

        Returns:
            list[MemoryItemRelevance] containing the top [k] relevant memories
        """
        # Guard k <= 0 explicitly: with k == 0 the slice [-k:] below would
        # be [0:], silently returning *every* memory instead of none.
        if k <= 0 or len(self) < 1:
            return []

        logger.debug(
            f"Searching for {k} relevant memories for query '{query}'; "
            f"{len(self)} memories in index"
        )

        relevances = self.score_memories_for_relevance(query, config)
        logger.debug(f"Memory relevance scores: {[str(r) for r in relevances]}")

        # argsort is ascending, so take the last k indices and reverse them
        # to get the highest-scoring memories first.
        top_k_indices = np.argsort([r.score for r in relevances])[-k:][::-1]

        return [relevances[i] for i in top_k_indices]

    def score_memories_for_relevance(
        self, for_query: str, config: Config
    ) -> Sequence[MemoryItemRelevance]:
        """
        Returns MemoryItemRelevance for every memory in the index.
        Implementations may override this function for performance purposes.
        """
        # One embedding for the query, reused against every stored memory.
        e_query: Embedding = get_embedding(for_query, config)
        return [m.relevance_for(for_query, e_query) for m in self]

    def get_stats(self) -> tuple[int, int]:
        """
        Returns:
            tuple (n_memories: int, n_chunks: int): the stats of the memory index
        """
        # n_chunks sums the embedded chunks (m.e_chunks) across all items.
        return len(self), functools.reduce(lambda t, m: t + len(m.e_chunks), self, 0)