Skip to content

Commit 4c98430

Browse files
authored
feat: Add agent memory (#4829)
1 parent d4bbde2 commit 4c98430

File tree

9 files changed

+397
-4
lines changed

9 files changed

+397
-4
lines changed

haystack/agents/memory/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from haystack.agents.memory.base import Memory
2+
from haystack.agents.memory.no_memory import NoMemory
3+
from haystack.agents.memory.conversation_memory import ConversationMemory
4+
from haystack.agents.memory.conversation_summary_memory import ConversationSummaryMemory

haystack/agents/memory/base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, Any, List, Optional
3+
4+
5+
class Memory(ABC):
6+
"""
7+
Abstract base class for memory management in an Agent.
8+
"""
9+
10+
@abstractmethod
11+
def load(self, keys: Optional[List[str]] = None, **kwargs) -> Any:
12+
"""
13+
Load the context of this model run from memory.
14+
15+
:param keys: Optional list of keys to specify the data to load.
16+
:return: The loaded data.
17+
"""
18+
19+
@abstractmethod
20+
def save(self, data: Dict[str, Any]) -> None:
21+
"""
22+
Save the context of this model run to memory.
23+
24+
:param data: A dictionary containing the data to save.
25+
"""
26+
27+
@abstractmethod
28+
def clear(self) -> None:
29+
"""
30+
Clear memory contents.
31+
"""
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import collections
2+
from typing import OrderedDict, List, Optional, Any, Dict
3+
4+
from haystack.agents.memory import Memory
5+
6+
7+
class ConversationMemory(Memory):
8+
"""
9+
A memory class that stores conversation history.
10+
"""
11+
12+
def __init__(self, input_key: str = "input", output_key: str = "output"):
13+
"""
14+
Initialize ConversationMemory with input and output keys.
15+
16+
:param input_key: The key to use for storing user input.
17+
:param output_key: The key to use for storing model output.
18+
"""
19+
self.list: List[OrderedDict] = []
20+
self.input_key = input_key
21+
self.output_key = output_key
22+
23+
def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
24+
"""
25+
Load conversation history as a formatted string.
26+
27+
:param keys: Optional list of keys (ignored in this implementation).
28+
:param kwargs: Optional keyword arguments
29+
- window_size: integer specifying the number of most recent conversation snippets to load.
30+
:return: A formatted string containing the conversation history.
31+
"""
32+
chat_transcript = ""
33+
window_size = kwargs.get("window_size", None)
34+
35+
if window_size is not None:
36+
chat_list = self.list[-window_size:] # pylint: disable=invalid-unary-operand-type
37+
else:
38+
chat_list = self.list
39+
40+
for chat_snippet in chat_list:
41+
chat_transcript += f"Human: {chat_snippet['Human']}\n"
42+
chat_transcript += f"AI: {chat_snippet['AI']}\n"
43+
return chat_transcript
44+
45+
def save(self, data: Dict[str, Any]) -> None:
46+
"""
47+
Save a conversation snippet to memory.
48+
49+
:param data: A dictionary containing the conversation snippet to save.
50+
"""
51+
chat_snippet = collections.OrderedDict()
52+
chat_snippet["Human"] = data[self.input_key]
53+
chat_snippet["AI"] = data[self.output_key]
54+
self.list.append(chat_snippet)
55+
56+
def clear(self) -> None:
57+
"""
58+
Clear the conversation history.
59+
"""
60+
self.list = []
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from typing import Optional, Union, Dict, Any, List
2+
3+
from haystack.agents.memory import ConversationMemory
4+
from haystack.nodes import PromptTemplate, PromptNode
5+
6+
7+
class ConversationSummaryMemory(ConversationMemory):
8+
"""
9+
A memory class that stores conversation history and periodically generates summaries.
10+
"""
11+
12+
def __init__(
13+
self,
14+
prompt_node: PromptNode,
15+
prompt_template: Optional[Union[str, PromptTemplate]] = None,
16+
input_key: str = "input",
17+
output_key: str = "output",
18+
summary_frequency: int = 3,
19+
):
20+
"""
21+
Initialize ConversationSummaryMemory with a PromptNode, optional prompt_template,
22+
input and output keys, and a summary_frequency.
23+
24+
:param prompt_node: A PromptNode object for generating conversation summaries.
25+
:param prompt_template: Optional prompt template as a string or PromptTemplate object.
26+
:param input_key: input key, default is "input".
27+
:param output_key: output key, default is "output".
28+
:param summary_frequency: integer specifying how often to generate a summary (default is 3).
29+
"""
30+
super().__init__(input_key, output_key)
31+
self.save_count = 0
32+
self.prompt_node = prompt_node
33+
34+
template = (
35+
prompt_template
36+
if prompt_template is not None
37+
else prompt_node.default_prompt_template or "conversational-summary"
38+
)
39+
self.template = prompt_node.get_prompt_template(template)
40+
self.summary_frequency = summary_frequency
41+
self.summary = ""
42+
43+
def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
44+
"""
45+
Load conversation history as a formatted string, including the latest summary.
46+
47+
:param keys: Optional list of keys (ignored in this implementation).
48+
:param kwargs: Optional keyword arguments
49+
- window_size: integer specifying the number of most recent conversation snippets to load.
50+
:return: A formatted string containing the conversation history with the latest summary.
51+
"""
52+
if self.has_unsummarized_snippets():
53+
unsummarized = super().load(keys=keys, window_size=self.unsummarized_snippets())
54+
return f"{self.summary}\n{unsummarized}"
55+
else:
56+
return self.summary
57+
58+
def summarize(self) -> str:
59+
"""
60+
Generate a summary of the conversation history and clear the history.
61+
62+
:return: A string containing the generated summary.
63+
"""
64+
most_recent_chat_snippets = self.load(window_size=self.summary_frequency)
65+
pn_response = self.prompt_node.prompt(self.template, chat_transcript=most_recent_chat_snippets)
66+
return pn_response[0]
67+
68+
def needs_summary(self) -> bool:
69+
"""
70+
Determine if a new summary should be generated.
71+
72+
:return: True if a new summary should be generated, otherwise False.
73+
"""
74+
return self.save_count % self.summary_frequency == 0
75+
76+
def unsummarized_snippets(self) -> int:
77+
"""
78+
Returns how many conversation snippets have not been summarized.
79+
:return: The number of conversation snippets that have not been summarized.
80+
"""
81+
return self.save_count % self.summary_frequency
82+
83+
def has_unsummarized_snippets(self) -> bool:
84+
"""
85+
Returns True if there are any conversation snippets that have not been summarized.
86+
:return: True if there are unsummarized snippets, otherwise False.
87+
"""
88+
return self.unsummarized_snippets() != 0
89+
90+
def save(self, data: Dict[str, Any]) -> None:
91+
"""
92+
Save a conversation snippet to memory and update the save count.
93+
Generate a summary if needed.
94+
95+
:param data: A dictionary containing the conversation snippet to save.
96+
"""
97+
super().save(data)
98+
self.save_count += 1
99+
if self.needs_summary():
100+
self.summary = self.summarize()
101+
102+
def clear(self) -> None:
103+
"""
104+
Clear the conversation history and the summary.
105+
"""
106+
super().clear()
107+
self.save_count = 0
108+
self.summary = ""

haystack/agents/memory/no_memory.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Optional, List, Any, Dict
2+
3+
from haystack.agents.memory import Memory
4+
5+
6+
class NoMemory(Memory):
7+
"""
8+
A memory class that doesn't store any data.
9+
"""
10+
11+
def load(self, keys: Optional[List[str]] = None, **kwargs) -> str:
12+
"""
13+
Load an empty dictionary.
14+
15+
:param keys: Optional list of keys (ignored in this implementation).
16+
:return: An empty str.
17+
"""
18+
return ""
19+
20+
def save(self, data: Dict[str, Any]) -> None:
21+
"""
22+
Save method that does nothing.
23+
24+
:param data: A dictionary containing the data to save (ignored in this implementation).
25+
"""
26+
pass
27+
28+
def clear(self) -> None:
29+
"""
30+
Clear method that does nothing.
31+
"""
32+
pass

haystack/nodes/prompt/prompt_template.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,4 +434,8 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
434434
"Question: {query}\n"
435435
"Thought: Let's think step-by-step, I first need to ",
436436
),
437+
PromptTemplate(
438+
name="conversational-summary",
439+
prompt_text="Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:",
440+
),
437441
]

test/agents/test_memory.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
from typing import Dict, Any
3+
from haystack.agents.memory import NoMemory, ConversationMemory
4+
5+
6+
@pytest.mark.unit
7+
def test_no_memory():
8+
no_mem = NoMemory()
9+
assert no_mem.load() == ""
10+
no_mem.save({"key": "value"})
11+
no_mem.clear()
12+
13+
14+
@pytest.mark.unit
15+
def test_conversation_memory():
16+
conv_mem = ConversationMemory()
17+
assert conv_mem.load() == ""
18+
data: Dict[str, Any] = {"input": "Hello", "output": "Hi there"}
19+
conv_mem.save(data)
20+
assert conv_mem.load() == "Human: Hello\nAI: Hi there\n"
21+
22+
data: Dict[str, Any] = {"input": "How are you?", "output": "I'm doing well, thanks."}
23+
conv_mem.save(data)
24+
assert conv_mem.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
25+
assert conv_mem.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"
26+
27+
conv_mem.clear()
28+
assert conv_mem.load() == ""
29+
30+
31+
@pytest.mark.unit
32+
def test_conversation_memory_window_size():
33+
conv_mem = ConversationMemory()
34+
assert conv_mem.load() == ""
35+
data: Dict[str, Any] = {"input": "Hello", "output": "Hi there"}
36+
conv_mem.save(data)
37+
data: Dict[str, Any] = {"input": "How are you?", "output": "I'm doing well, thanks."}
38+
conv_mem.save(data)
39+
assert conv_mem.load() == "Human: Hello\nAI: Hi there\nHuman: How are you?\nAI: I'm doing well, thanks.\n"
40+
assert conv_mem.load(window_size=1) == "Human: How are you?\nAI: I'm doing well, thanks.\n"
41+
42+
# clear the memory
43+
conv_mem.clear()
44+
assert conv_mem.load() == ""
45+
assert conv_mem.load(window_size=1) == ""

0 commit comments

Comments
 (0)