2525 MCPTool ,
2626 NamedMCPServer ,
2727)
28+ from .memory import Memory , NoMemory
2829from .messages import ChatMessage , ChatHistory , StructuredOutput
2930from .model import default_model , Model
3031from .util import is_user_confirmed
@@ -222,6 +223,7 @@ def __init__(
222223 mcp_servers : list [NamedMCPServer ] | None = None ,
223224 mcp_server_agent_type : str = "mcp_server" ,
224225 model : Model = default_model ,
226+ memory : Memory | None = None ,
225227 timeout : float = 300 ,
226228 ):
227229 super ().__init__ (timeout = timeout )
@@ -236,6 +238,7 @@ def __init__(
236238 self ._swarm_client : Swarm = Swarm (self .model )
237239 self ._swarm_agent : SwarmAgent | None = None
238240
241+ self ._memory : Memory = memory or NoMemory ()
239242 self ._history : ChatHistory = ChatHistory (messages = [])
240243
241244 @property
@@ -267,6 +270,10 @@ def mcp_server_agent_type(self) -> str:
267270 def model (self ) -> Model :
268271 return self ._model
269272
273+ @property
274+ def memory (self ) -> Memory :
275+ return self ._memory
276+
270277 async def stopped (self ) -> None :
271278 for server in self .mcp_servers :
272279 if server .connect is not None :
@@ -286,7 +293,7 @@ def get_swarm_client(self, extensions: dict) -> Swarm:
286293 if model_id :
287294 # We assume that non-empty model ID indicates the use of a dynamic model client.
288295 model = Model (
289- model = model_id ,
296+ id = model_id ,
290297 base_url = extensions .get ("model_base_url" , "" ),
291298 api_key = extensions .get ("model_api_key" , "" ),
292299 api_version = extensions .get ("model_api_version" , "" ),
@@ -319,13 +326,15 @@ async def get_swarm_agent(self) -> SwarmAgent:
319326
320327 async def agent (self , agent_type : str ) -> AsyncIterator [ChatMessage ]:
321328 """The candidate agent to delegate the conversation to."""
329+ # TODO: Handle memory?
322330 async for chunk in Delegate (self , agent_type ).handle (self ._history ):
323331 yield chunk
324332
325333 @handler
326334 async def handle_history (
327335 self , msg : ChatHistory , ctx : Context
328336 ) -> AsyncIterator [ChatMessage ]:
337+ # TODO: Handle memory?
329338 response = self ._handle_history (msg )
330339 async for resp in response :
331340 yield resp
@@ -334,15 +343,25 @@ async def handle_history(
334343 async def handle_message (
335344 self , msg : ChatMessage , ctx : Context
336345 ) -> AsyncIterator [ChatMessage ]:
337- history = ChatHistory (messages = [msg ])
346+ existing = await self .memory .get_items ()
347+ history = ChatHistory (messages = existing + [msg ])
348+
338349 response = self ._handle_history (history )
350+ full_content = ""
339351 async for resp in response :
340352 yield resp
353+ full_content += resp .content
354+
355+ await self .memory .add_items (
356+ msg , # input item
357+ ChatMessage (role = "assistant" , content = full_content ), # output item
358+ )
341359
342360 @handler
343361 async def handle_structured_output (
344362 self , msg : StructuredOutput , ctx : Context
345363 ) -> AsyncIterator [ChatMessage ]:
364+ # TODO: Handle memory?
346365 match msg .input :
347366 case ChatMessage ():
348367 history = ChatHistory (messages = [msg .input ])
0 commit comments