aboutsummaryrefslogtreecommitdiff
path: root/autogpts/autogpt/tests/integration/memory/_test_json_file_memory.py
blob: 94bf0d1bdee8697529125fe657c1d8b7d612f0a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# sourcery skip: snake-case-functions
"""Tests for JSONFileMemory class"""
import orjson
import pytest

from autogpt.config import Config
from autogpt.file_storage import FileStorage
from autogpt.memory.vector import JSONFileMemory, MemoryItem


def test_json_memory_init_without_backing_file(config: Config, storage: FileStorage):
    """Constructing the memory with no index file on disk creates an empty one."""
    backing_path = storage.root / f"{config.memory_index}.json"
    assert not backing_path.exists()

    JSONFileMemory(config)

    # The constructor is expected to have written an empty JSON list.
    assert backing_path.exists()
    assert backing_path.read_text() == "[]"


def test_json_memory_init_with_backing_empty_file(config: Config, storage: FileStorage):
    """Constructing the memory over a zero-byte index file leaves it as "[]"."""
    backing_path = storage.root / f"{config.memory_index}.json"
    backing_path.touch()
    assert backing_path.exists()

    JSONFileMemory(config)

    # The file must still exist and now hold an empty JSON list.
    assert backing_path.exists()
    assert backing_path.read_text() == "[]"


def test_json_memory_init_with_backing_invalid_file(
    config: Config, storage: FileStorage
):
    """Constructing the memory over an index file with unexpected JSON resets it."""
    backing_path = storage.root / f"{config.memory_index}.json"
    backing_path.touch()

    # Pre-populate the file with a JSON object instead of the "[]" asserted below.
    unexpected_payload = orjson.dumps(
        {"texts": ["test"]}, option=JSONFileMemory.SAVE_OPTIONS
    )
    backing_path.write_bytes(unexpected_payload)
    assert backing_path.exists()

    JSONFileMemory(config)

    assert backing_path.exists()
    assert backing_path.read_text() == "[]"


def test_json_memory_add(config: Config, memory_item: MemoryItem):
    """An item passed to add() is found at position 0 of ``memories``."""
    memory = JSONFileMemory(config)
    memory.add(memory_item)

    first_stored = memory.memories[0]
    assert first_stored == memory_item


def test_json_memory_clear(config: Config, memory_item: MemoryItem):
    """clear() empties ``memories`` after an item has been added."""
    memory = JSONFileMemory(config)
    assert memory.memories == []

    # Precondition: add() must work, otherwise clear() has nothing to remove.
    memory.add(memory_item)
    first_stored = memory.memories[0]
    assert first_stored == memory_item, "Cannot test clear() because add() fails"

    memory.clear()
    assert memory.memories == []


def test_json_memory_get(config: Config, memory_item: MemoryItem, mock_get_embedding):
    """get() returns None on a fresh index and the added item afterwards."""
    memory = JSONFileMemory(config)

    # Precondition: nothing can be retrieved before anything was added.
    miss = memory.get("test", config)
    assert miss is None, "Cannot test get() because initial index is not empty"

    memory.add(memory_item)

    hit = memory.get("test", config)
    assert hit is not None
    assert hit.memory_item == memory_item


def test_json_memory_load_index(config: Config, memory_item: MemoryItem):
    """load_index() restores a previously added item after ``memories`` is wiped.

    Raises:
        ValueError: If the setup (adding the item and finding it both in the
            index and on disk) fails, so that a broken precondition is not
            reported as a load_index() failure.
    """
    index = JSONFileMemory(config)
    index.add(memory_item)

    try:
        assert index.file_path.exists(), "index was not saved to file"
        assert len(index) == 1, f"index contains {len(index)} items instead of 1"
        assert index.memories[0] == memory_item, "item in index != added mock item"
    except AssertionError as e:
        # Chain explicitly so the traceback shows the failed precondition as the
        # direct cause rather than as an incidental "during handling" context.
        raise ValueError(f"Setting up for load_index test failed: {e}") from e

    # Drop the in-memory state; the item must be restored by load_index().
    index.memories = []
    index.load_index()

    assert len(index) == 1
    assert index.memories[0] == memory_item


@pytest.mark.vcr
@pytest.mark.requires_openai_api_key
def test_json_memory_get_relevant(config: Config, cached_openai_client: None) -> None:
    """get_relevant() returns each item for its own text and orders multiple hits."""
    memory = JSONFileMemory(config)
    lipsum = "Lorem ipsum dolor sit amet"

    # Items are created and added in a fixed order.
    sample = MemoryItem.from_text_file("Sample text", "sample.txt", config)
    groceries = MemoryItem.from_text_file(
        "Grocery list:\n- Pancake mix", "groceries.txt", config
    )
    color = MemoryItem.from_text_file(
        "What is your favorite color?", "color.txt", config
    )
    long_lipsum = MemoryItem.from_text_file(
        " ".join([lipsum] * 100), "lipsum.txt", config
    )
    for item in (sample, groceries, color, long_lipsum):
        memory.add(item)

    # Querying with an item's own raw content must return that item first.
    for item in (sample, groceries, color):
        top_hit = memory.get_relevant(item.raw_content, 1, config)[0]
        assert top_hit.memory_item == item

    top_two = [mr.memory_item for mr in memory.get_relevant(lipsum, 2, config)]
    assert top_two == [long_lipsum, sample]


def test_json_memory_get_stats(config: Config, memory_item: MemoryItem) -> None:
    """get_stats() reports a pair of ones after a single item is added."""
    memory = JSONFileMemory(config)
    memory.add(memory_item)

    # get_stats() is unpacked as (number of memories, number of chunks).
    memory_count, chunk_count = memory.get_stats()

    assert memory_count == 1
    assert chunk_count == 1