aboutsummaryrefslogtreecommitdiff
path: root/autogpt/memory/vector/providers/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'autogpt/memory/vector/providers/base.py')
-rw-r--r--autogpt/memory/vector/providers/base.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/autogpt/memory/vector/providers/base.py b/autogpt/memory/vector/providers/base.py
index 969d89347..dc4dbf3cc 100644
--- a/autogpt/memory/vector/providers/base.py
+++ b/autogpt/memory/vector/providers/base.py
@@ -17,25 +17,29 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton):
def __init__(self, config: Config):
pass
- def get(self, query: str) -> MemoryItemRelevance | None:
+ def get(self, query: str, config: Config) -> MemoryItemRelevance | None:
"""
Gets the data from the memory that is most relevant to the given query.
Args:
- data: The data to compare to.
+ query: The query used to retrieve information.
+ config: The config Object.
Returns: The most relevant Memory
"""
- result = self.get_relevant(query, 1)
+ result = self.get_relevant(query, 1, config)
return result[0] if result else None
- def get_relevant(self, query: str, k: int) -> Sequence[MemoryItemRelevance]:
+ def get_relevant(
+ self, query: str, k: int, config: Config
+ ) -> Sequence[MemoryItemRelevance]:
"""
Returns the top-k most relevant memories for the given query
Args:
query: the query to compare stored memories to
k: the number of relevant memories to fetch
+ config: The config Object.
Returns:
list[MemoryItemRelevance] containing the top [k] relevant memories
@@ -48,7 +52,7 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton):
f"{len(self)} memories in index"
)
- relevances = self.score_memories_for_relevance(query)
+ relevances = self.score_memories_for_relevance(query, config)
logger.debug(f"Memory relevance scores: {[str(r) for r in relevances]}")
# take last k items and reverse
@@ -57,13 +61,13 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton):
return [relevances[i] for i in top_k_indices]
def score_memories_for_relevance(
- self, for_query: str
+ self, for_query: str, config: Config
) -> Sequence[MemoryItemRelevance]:
"""
Returns MemoryItemRelevance for every memory in the index.
Implementations may override this function for performance purposes.
"""
- e_query: Embedding = get_embedding(for_query)
+ e_query: Embedding = get_embedding(for_query, config)
return [m.relevance_for(for_query, e_query) for m in self]
def get_stats(self) -> tuple[int, int]: