Skip to content

Commit ee557a2

Browse files
committed
Add support for agent memory
1 parent ffc7f9a commit ee557a2

File tree

3 files changed

+111
-2
lines changed

3 files changed

+111
-2
lines changed

coagent/agents/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
StructuredOutput,
99
type_to_response_format_param,
1010
)
11+
from .memory import Memory, InMemMemory
1112
from .model import Model
1213
from .parallel import Aggregator, AggregationResult, Parallel
1314
from .sequential import Sequential

coagent/agents/chat_agent.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MCPTool,
2626
NamedMCPServer,
2727
)
28+
from .memory import Memory, NoMemory
2829
from .messages import ChatMessage, ChatHistory, StructuredOutput
2930
from .model import default_model, Model
3031
from .util import is_user_confirmed
@@ -222,6 +223,7 @@ def __init__(
222223
mcp_servers: list[NamedMCPServer] | None = None,
223224
mcp_server_agent_type: str = "mcp_server",
224225
model: Model = default_model,
226+
memory: Memory | None = None,
225227
timeout: float = 300,
226228
):
227229
super().__init__(timeout=timeout)
@@ -236,6 +238,7 @@ def __init__(
236238
self._swarm_client: Swarm = Swarm(self.model)
237239
self._swarm_agent: SwarmAgent | None = None
238240

241+
self._memory: Memory = memory or NoMemory()
239242
self._history: ChatHistory = ChatHistory(messages=[])
240243

241244
@property
@@ -267,6 +270,10 @@ def mcp_server_agent_type(self) -> str:
267270
def model(self) -> Model:
268271
return self._model
269272

273+
@property
274+
def memory(self) -> Memory:
275+
return self._memory
276+
270277
async def stopped(self) -> None:
271278
for server in self.mcp_servers:
272279
if server.connect is not None:
@@ -286,7 +293,7 @@ def get_swarm_client(self, extensions: dict) -> Swarm:
286293
if model_id:
287294
# We assume that non-empty model ID indicates the use of a dynamic model client.
288295
model = Model(
289-
model=model_id,
296+
id=model_id,
290297
base_url=extensions.get("model_base_url", ""),
291298
api_key=extensions.get("model_api_key", ""),
292299
api_version=extensions.get("model_api_version", ""),
@@ -319,13 +326,15 @@ async def get_swarm_agent(self) -> SwarmAgent:
319326

320327
async def agent(self, agent_type: str) -> AsyncIterator[ChatMessage]:
321328
"""The candidate agent to delegate the conversation to."""
329+
# TODO: Handle memory?
322330
async for chunk in Delegate(self, agent_type).handle(self._history):
323331
yield chunk
324332

325333
@handler
326334
async def handle_history(
327335
self, msg: ChatHistory, ctx: Context
328336
) -> AsyncIterator[ChatMessage]:
337+
# TODO: Handle memory?
329338
response = self._handle_history(msg)
330339
async for resp in response:
331340
yield resp
@@ -334,15 +343,25 @@ async def handle_history(
334343
async def handle_message(
335344
self, msg: ChatMessage, ctx: Context
336345
) -> AsyncIterator[ChatMessage]:
337-
history = ChatHistory(messages=[msg])
346+
existing = await self.memory.get_items()
347+
history = ChatHistory(messages=existing + [msg])
348+
338349
response = self._handle_history(history)
350+
full_content = ""
339351
async for resp in response:
340352
yield resp
353+
full_content += resp.content
354+
355+
await self.memory.add_items(
356+
msg, # input item
357+
ChatMessage(role="assistant", content=full_content), # output item
358+
)
341359

342360
@handler
343361
async def handle_structured_output(
344362
self, msg: StructuredOutput, ctx: Context
345363
) -> AsyncIterator[ChatMessage]:
364+
# TODO: Handle memory?
346365
match msg.input:
347366
case ChatMessage():
348367
history = ChatHistory(messages=[msg.input])

coagent/agents/memory.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import Protocol, runtime_checkable
2+
3+
from .messages import ChatMessage
4+
5+
6+
@runtime_checkable
7+
class Memory(Protocol):
8+
"""Protocol for memory implementations.
9+
10+
Memory stores conversation history for a specific agent, allowing
11+
agents to maintain context without requiring explicit manual memory management.
12+
"""
13+
14+
async def get_items(self, limit: int | None = None) -> list[ChatMessage]:
15+
"""Retrieve the conversation history for this memory.
16+
17+
Args:
18+
limit: Maximum number of items to retrieve. If None, retrieves all items.
19+
When specified, returns the latest N items in chronological order.
20+
21+
Returns:
22+
List of input items representing the conversation history
23+
"""
24+
...
25+
26+
async def add_items(self, *items: ChatMessage) -> None:
27+
"""Add new items to the conversation history.
28+
29+
Args:
30+
items: List of input items to add to the history
31+
"""
32+
...
33+
34+
async def pop_item(self) -> ChatMessage | None:
35+
"""Remove and return the most recent item from the memory.
36+
37+
Returns:
38+
The most recent item if it exists, None if the memory is empty
39+
"""
40+
...
41+
42+
async def clear_items(self) -> None:
43+
"""Clear all items for this memory."""
44+
...
45+
46+
47+
class NoMemory(list):
48+
"""Built-in memory implementation that stores no conversation history."""
49+
50+
async def get_items(self, limit: int | None = None) -> list[ChatMessage]:
51+
"""Retrieve the conversation history for this memory."""
52+
return []
53+
54+
async def add_items(self, *items: ChatMessage) -> None:
55+
"""Add new items to the conversation history."""
56+
return
57+
58+
async def pop_item(self) -> ChatMessage | None:
59+
"""Remove and return the most recent item from the memory."""
60+
return
61+
62+
async def clear_items(self) -> None:
63+
"""Clear all items for this memory."""
64+
return
65+
66+
67+
class InMemMemory(list):
68+
"""Built-in memory implementation that stores conversation history in memory."""
69+
70+
async def get_items(self, limit: int | None = None) -> list[ChatMessage]:
71+
"""Retrieve the conversation history for this memory."""
72+
if limit is None:
73+
return self
74+
# Return the latest limit number of items.
75+
return self[-limit:]
76+
77+
async def add_items(self, *items: ChatMessage) -> None:
78+
"""Add new items to the conversation history."""
79+
self.extend(items)
80+
81+
async def pop_item(self) -> ChatMessage | None:
82+
"""Remove and return the most recent item from the memory."""
83+
if not self:
84+
return None
85+
return self.pop()
86+
87+
async def clear_items(self) -> None:
88+
"""Clear all items for this memory."""
89+
self.clear()

0 commit comments

Comments
 (0)