/ tests / test_agent2_memory.py
test_agent2_memory.py
  1  """Tests for agent2 session memory (in-memory backend)."""
  2  import asyncio
  3  from collections import OrderedDict
  4  from types import SimpleNamespace
  5  from unittest.mock import patch
  6  
  7  from restai.agent2.memory import (
  8      LOCAL_SESSION_CAP,
  9      get_session,
 10      save_session,
 11  )
 12  from restai.agent2.types import (
 13      AgentSession,
 14      Message,
 15      TextBlock,
 16      ToolUseBlock,
 17  )
 18  
 19  
 20  def _make_brain():
 21      """Create a minimal mock brain with no Redis configured."""
 22      brain = SimpleNamespace(_agent2_sessions=OrderedDict(), _agent2_redis=None, _agent2_redis_url=None)
 23      return brain
 24  
 25  
 26  @patch("restai.agent2.memory.config.build_redis_url", return_value=None)
 27  def test_get_session_returns_empty_for_unknown_chat_id(_mock):
 28      brain = _make_brain()
 29      session = asyncio.run(get_session(brain, "nonexistent-id"))
 30      assert isinstance(session, AgentSession)
 31      assert session.messages == []
 32  
 33  
 34  @patch("restai.agent2.memory.config.build_redis_url", return_value=None)
 35  def test_save_and_get_session_round_trip(_mock):
 36      brain = _make_brain()
 37  
 38      msg = Message(role="user", content=[TextBlock(text="hello")])
 39      session = AgentSession(messages=[msg])
 40  
 41      asyncio.run(save_session(brain, "chat-1", session))
 42      loaded = asyncio.run(get_session(brain, "chat-1"))
 43  
 44      assert len(loaded.messages) == 1
 45      assert loaded.messages[0].role == "user"
 46      assert len(loaded.messages[0].content) == 1
 47      assert isinstance(loaded.messages[0].content[0], TextBlock)
 48      assert loaded.messages[0].content[0].text == "hello"
 49  
 50  
 51  @patch("restai.agent2.memory.config.build_redis_url", return_value=None)
 52  def test_lru_eviction(_mock):
 53      brain = _make_brain()
 54  
 55      # Save LOCAL_SESSION_CAP + 1 sessions
 56      async def fill():
 57          for i in range(LOCAL_SESSION_CAP + 1):
 58              msg = Message(role="user", content=[TextBlock(text=f"msg-{i}")])
 59              session = AgentSession(messages=[msg])
 60              await save_session(brain, f"chat-{i}", session)
 61  
 62      asyncio.run(fill())
 63  
 64      # The oldest session (chat-0) should have been evicted
 65      oldest = asyncio.run(get_session(brain, "chat-0"))
 66      assert oldest.messages == [], "Oldest session should have been evicted"
 67  
 68      # The newest session should still be present
 69      newest = asyncio.run(get_session(brain, f"chat-{LOCAL_SESSION_CAP}"))
 70      assert len(newest.messages) == 1
 71      assert newest.messages[0].content[0].text == f"msg-{LOCAL_SESSION_CAP}"
 72  
 73  
 74  @patch("restai.agent2.memory.config.build_redis_url", return_value=None)
 75  def test_message_serialization_round_trip(_mock):
 76      """TextBlock and ToolUseBlock survive save/load through JSON serialization."""
 77      brain = _make_brain()
 78  
 79      messages = [
 80          Message(role="user", content=[TextBlock(text="What is 2+2?")]),
 81          Message(
 82              role="assistant",
 83              content=[
 84                  TextBlock(text="Let me calculate that."),
 85                  ToolUseBlock(id="call_1", name="calculator", input={"expr": "2+2"}),
 86              ],
 87          ),
 88      ]
 89      session = AgentSession(messages=messages)
 90  
 91      async def roundtrip():
 92          await save_session(brain, "chat-ser", session)
 93          return await get_session(brain, "chat-ser")
 94  
 95      loaded = asyncio.run(roundtrip())
 96      assert len(loaded.messages) == 2
 97  
 98      m0 = loaded.messages[0]
 99      assert m0.role == "user"
100      assert isinstance(m0.content[0], TextBlock)
101      assert m0.content[0].text == "What is 2+2?"
102  
103      m1 = loaded.messages[1]
104      assert m1.role == "assistant"
105      assert len(m1.content) == 2
106      assert isinstance(m1.content[0], TextBlock)
107      assert isinstance(m1.content[1], ToolUseBlock)
108      assert m1.content[1].name == "calculator"
109      assert m1.content[1].input == {"expr": "2+2"}