diff options
Diffstat (limited to 'autogpt/memory/vector/providers/base.py')
-rw-r--r-- | autogpt/memory/vector/providers/base.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/autogpt/memory/vector/providers/base.py b/autogpt/memory/vector/providers/base.py index 969d89347..dc4dbf3cc 100644 --- a/autogpt/memory/vector/providers/base.py +++ b/autogpt/memory/vector/providers/base.py @@ -17,25 +17,29 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton): def __init__(self, config: Config): pass - def get(self, query: str) -> MemoryItemRelevance | None: + def get(self, query: str, config: Config) -> MemoryItemRelevance | None: """ Gets the data from the memory that is most relevant to the given query. Args: - data: The data to compare to. + query: The query used to retrieve information. + config: The config Object. Returns: The most relevant Memory """ - result = self.get_relevant(query, 1) + result = self.get_relevant(query, 1, config) return result[0] if result else None - def get_relevant(self, query: str, k: int) -> Sequence[MemoryItemRelevance]: + def get_relevant( + self, query: str, k: int, config: Config + ) -> Sequence[MemoryItemRelevance]: """ Returns the top-k most relevant memories for the given query Args: query: the query to compare stored memories to k: the number of relevant memories to fetch + config: The config Object. Returns: list[MemoryItemRelevance] containing the top [k] relevant memories @@ -48,7 +52,7 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton): f"{len(self)} memories in index" ) - relevances = self.score_memories_for_relevance(query) + relevances = self.score_memories_for_relevance(query, config) logger.debug(f"Memory relevance scores: {[str(r) for r in relevances]}") # take last k items and reverse @@ -57,13 +61,13 @@ class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton): return [relevances[i] for i in top_k_indices] def score_memories_for_relevance( - self, for_query: str + self, for_query: str, config: Config ) -> Sequence[MemoryItemRelevance]: """ Returns MemoryItemRelevance for every memory in the index. Implementations may override this function for performance purposes. """ - e_query: Embedding = get_embedding(for_query) + e_query: Embedding = get_embedding(for_query, config) return [m.relevance_for(for_query, e_query) for m in self] def get_stats(self) -> tuple[int, int]: |