/ tests / tools / test_session_search.py
test_session_search.py
  1  """Tests for tools/session_search_tool.py — helper functions and search dispatcher."""
  2  
  3  import asyncio
  4  import json
  5  import time
  6  import pytest
  7  
  8  from tools.session_search_tool import (
  9      _format_timestamp,
 10      _format_conversation,
 11      _truncate_around_matches,
 12      _get_session_search_max_concurrency,
 13      _list_recent_sessions,
 14      _HIDDEN_SESSION_SOURCES,
 15      MAX_SESSION_CHARS,
 16      SESSION_SEARCH_SCHEMA,
 17  )
 18  
 19  
 20  # =========================================================================
 21  # Tool schema guidance
 22  # =========================================================================
 23  
 24  class TestHiddenSessionSources:
 25      """Verify the _HIDDEN_SESSION_SOURCES constant used for third-party isolation."""
 26  
 27      def test_tool_source_is_hidden(self):
 28          assert "tool" in _HIDDEN_SESSION_SOURCES
 29  
 30      def test_standard_sources_not_hidden(self):
 31          for src in ("cli", "telegram", "discord", "slack", "cron"):
 32              assert src not in _HIDDEN_SESSION_SOURCES
 33  
 34  
 35  class TestSessionSearchSchema:
 36      def test_keeps_cross_session_recall_guidance_without_current_session_nudge(self):
 37          description = SESSION_SEARCH_SCHEMA["description"]
 38          assert "past conversations" in description
 39          assert "recent turns of the current session" not in description
 40  
 41  
 42  # =========================================================================
 43  # _format_timestamp
 44  # =========================================================================
 45  
 46  class TestFormatTimestamp:
 47      def test_unix_float(self):
 48          ts = 1700000000.0  # Nov 14, 2023
 49          result = _format_timestamp(ts)
 50          assert "2023" in result or "November" in result
 51  
 52      def test_unix_int(self):
 53          result = _format_timestamp(1700000000)
 54          assert isinstance(result, str)
 55          assert len(result) > 5
 56  
 57      def test_iso_string(self):
 58          result = _format_timestamp("2024-01-15T10:30:00")
 59          assert isinstance(result, str)
 60  
 61      def test_none_returns_unknown(self):
 62          assert _format_timestamp(None) == "unknown"
 63  
 64      def test_numeric_string(self):
 65          result = _format_timestamp("1700000000.0")
 66          assert isinstance(result, str)
 67          assert "unknown" not in result.lower()
 68  
 69  
 70  # =========================================================================
 71  # _format_conversation
 72  # =========================================================================
 73  
 74  class TestFormatConversation:
 75      def test_basic_messages(self):
 76          msgs = [
 77              {"role": "user", "content": "Hello"},
 78              {"role": "assistant", "content": "Hi there!"},
 79          ]
 80          result = _format_conversation(msgs)
 81          assert "[USER]: Hello" in result
 82          assert "[ASSISTANT]: Hi there!" in result
 83  
 84      def test_tool_message(self):
 85          msgs = [
 86              {"role": "tool", "content": "search results", "tool_name": "web_search"},
 87          ]
 88          result = _format_conversation(msgs)
 89          assert "[TOOL:web_search]" in result
 90  
 91      def test_long_tool_output_truncated(self):
 92          msgs = [
 93              {"role": "tool", "content": "x" * 1000, "tool_name": "terminal"},
 94          ]
 95          result = _format_conversation(msgs)
 96          assert "[truncated]" in result
 97  
 98      def test_assistant_with_tool_calls(self):
 99          msgs = [
100              {
101                  "role": "assistant",
102                  "content": "",
103                  "tool_calls": [
104                      {"function": {"name": "web_search"}},
105                      {"function": {"name": "terminal"}},
106                  ],
107              },
108          ]
109          result = _format_conversation(msgs)
110          assert "web_search" in result
111          assert "terminal" in result
112  
113      def test_empty_messages(self):
114          result = _format_conversation([])
115          assert result == ""
116  
117  
118  # =========================================================================
119  # _truncate_around_matches
120  # =========================================================================
121  
class TestTruncateAroundMatches:
    """Exercises ``_truncate_around_matches`` on texts around ``MAX_SESSION_CHARS``."""

    def test_short_text_unchanged(self):
        """Text that needs no truncation is returned as-is."""
        original = "Short text about docker"
        assert _truncate_around_matches(original, "docker") == original

    def test_long_text_truncated(self):
        """An over-long text is cut down while keeping the matched keyword."""
        # Query term sits in the middle of a text longer than MAX_SESSION_CHARS.
        side = "x" * (MAX_SESSION_CHARS + 5000)
        excerpt = _truncate_around_matches(f"{side} KEYWORD_HERE {side}", "KEYWORD_HERE")
        assert len(excerpt) <= MAX_SESSION_CHARS + 100  # +100 for prefix/suffix markers
        assert "KEYWORD_HERE" in excerpt

    def test_truncation_adds_markers(self):
        """A truncated excerpt mentions that it was truncated."""
        haystack = "".join(
            ("a" * 50000, " target ", "b" * (MAX_SESSION_CHARS + 5000))
        )
        excerpt = _truncate_around_matches(haystack, "target")
        assert "truncated" in excerpt.lower()

    def test_no_match_takes_from_start(self):
        """Without any match, the excerpt starts with the beginning of the text."""
        haystack = "x" * (MAX_SESSION_CHARS + 5000)
        excerpt = _truncate_around_matches(haystack, "nonexistent")
        assert excerpt.startswith("x")

    def test_match_at_beginning(self):
        """A match at offset zero is retained in the excerpt."""
        haystack = "KEYWORD " + "x" * (MAX_SESSION_CHARS + 5000)
        excerpt = _truncate_around_matches(haystack, "KEYWORD")
        assert "KEYWORD" in excerpt

    def test_multiword_phrase_match_beats_individual_term(self):
        """A full phrase deep in the text wins over a single term seen early on."""
        # ~15K of boilerplate that contains only the lone term 'project'.
        intro = "The project setup is complex. " * 500
        spacer = "x" * (MAX_SESSION_CHARS + 20000)
        sentence = "We reviewed the keystone project roadmap in detail."
        haystack = "".join((intro, spacer, sentence, spacer))

        excerpt = _truncate_around_matches(haystack, "keystone project")

        assert "keystone project" in excerpt.lower()

    def test_multiword_proximity_cooccurrence(self):
        """With no exact phrase, nearby co-occurring terms beat a lone early term."""
        lone_term = "project " + "a" * (MAX_SESSION_CHARS + 20000)
        # Both terms close together, but not adjacent as the exact phrase.
        nearby_terms = "this keystone initiative for the project was pivotal"
        trailer = "b" * (MAX_SESSION_CHARS + 20000)

        excerpt = _truncate_around_matches(
            lone_term + nearby_terms + trailer, "keystone project"
        )
        lowered = excerpt.lower()

        assert "keystone" in lowered
        assert "project" in lowered

    def test_multiword_window_maximises_coverage(self):
        """Two phrase matches about 10K apart both land in the excerpt."""
        # Phrase matches at roughly 50K and 60K; both should fit in one window.
        phrase = " alpha beta "
        pieces = [
            "z" * 50000,
            phrase,
            "z" * 10000,
            phrase,
            "z" * (MAX_SESSION_CHARS + 40000),
        ]

        excerpt = _truncate_around_matches("".join(pieces), "alpha beta")

        assert excerpt.lower().count("alpha beta") == 2
185  
186  
class TestSessionSearchConcurrency:
    """Tests for the ``max_concurrency`` setting of session search.

    Covers the helper that resolves the configured value and the dispatcher
    honouring that value while summarising several sessions.
    """

    def test_defaults_to_three(self, monkeypatch):
        """With no ``max_concurrency`` configured, the helper reports 3."""
        # Stub the config loader, as the sibling tests do, so the outcome does
        # not depend on whatever configuration the host environment provides.
        # NOTE(review): assumes the helper falls back to its default when the
        # loaded config is empty — confirm against the implementation.
        monkeypatch.setattr("hermes_cli.config.load_config", lambda: {})
        assert _get_session_search_max_concurrency() == 3

    def test_reads_and_clamps_configured_value(self, monkeypatch):
        """A configured value of 9 is reported as 5."""
        monkeypatch.setattr(
            "hermes_cli.config.load_config",
            lambda: {"auxiliary": {"session_search": {"max_concurrency": 9}}},
        )
        assert _get_session_search_max_concurrency() == 5

    def test_session_search_respects_configured_concurrency_limit(self, monkeypatch):
        """With ``max_concurrency`` set to 1, summaries never overlap."""
        from unittest.mock import MagicMock
        from tools.session_search_tool import session_search

        monkeypatch.setattr(
            "hermes_cli.config.load_config",
            lambda: {"auxiliary": {"session_search": {"max_concurrency": 1}}},
        )

        # Mutable cells so the coroutine below can update them in place.
        max_seen = {"value": 0}
        active = {"value": 0}

        async def fake_summarize(_text, _query, _meta):
            """Stand-in summariser that records how many calls are in flight."""
            active["value"] += 1
            try:
                max_seen["value"] = max(max_seen["value"], active["value"])
                # Yield to the event loop so overlapping calls would be observed.
                await asyncio.sleep(0.01)
            finally:
                # Always release the slot, even if the await is cancelled or
                # fails, so the in-flight counter cannot be left inflated.
                active["value"] -= 1
            return "summary"

        monkeypatch.setattr("tools.session_search_tool._summarize_session", fake_summarize)
        monkeypatch.setattr("model_tools._run_async", lambda coro: asyncio.run(coro))

        mock_db = MagicMock()
        # Three hits from three distinct sessions.
        mock_db.search_messages.return_value = [
            {"session_id": "s1", "source": "cli", "session_started": 1709500000, "model": "test"},
            {"session_id": "s2", "source": "cli", "session_started": 1709500001, "model": "test"},
            {"session_id": "s3", "source": "cli", "session_started": 1709500002, "model": "test"},
        ]
        mock_db.get_session.side_effect = lambda sid: {
            "id": sid,
            "parent_session_id": None,
            "source": "cli",
            "started_at": 1709500000,
        }
        mock_db.get_messages_as_conversation.side_effect = lambda sid: [
            {"role": "user", "content": f"message from {sid}"},
            {"role": "assistant", "content": "response"},
        ]

        result = json.loads(session_search(query="message", db=mock_db, limit=3))

        assert result["success"] is True
        assert result["count"] == 3
        assert max_seen["value"] == 1
242  
243  
class TestRecentSessionListing:
    """Exercises ``_list_recent_sessions`` against a mocked session DB."""

    def test_recent_mode_requests_last_active_ordering(self):
        """The DB is queried once, ordered by last activity, with "tool" excluded."""
        from unittest.mock import MagicMock

        db = MagicMock()
        db.list_sessions_rich.return_value = []
        expected_query = {
            "limit": 10,
            "exclude_sources": ["tool"],
            "order_by_last_active": True,
        }

        payload = json.loads(_list_recent_sessions(db, limit=5))

        assert payload["success"] is True
        db.list_sessions_rich.assert_called_once_with(**expected_query)

    def test_current_child_session_excludes_root_lineage_even_when_child_id_is_longer(self):
        """The root of the current child session is dropped from the listing."""
        from unittest.mock import MagicMock

        child_id = "child_session_id_that_is_definitely_longer"

        current_root = {
            "id": "root",
            "title": "Current conversation",
            "source": "cli",
            "started_at": 1709500000,
            "last_active": 1709500100,
            "message_count": 4,
            "preview": "current root",
            "parent_session_id": None,
        }
        unrelated = {
            "id": "other_session",
            "title": "Other conversation",
            "source": "cli",
            "started_at": 1709400000,
            "last_active": 1709400100,
            "message_count": 3,
            "preview": "other root",
            "parent_session_id": None,
        }

        def _lookup(session_id):
            # The child points at "root"; "root" has no parent; anything else
            # is unknown to the DB.
            if session_id == child_id:
                return {"parent_session_id": "root"}
            if session_id == "root":
                return {"parent_session_id": None}
            return None

        db = MagicMock()
        db.list_sessions_rich.return_value = [current_root, unrelated]
        db.get_session.side_effect = _lookup

        payload = json.loads(
            _list_recent_sessions(db, limit=5, current_session_id=child_id)
        )
        listed_ids = [entry["session_id"] for entry in payload["results"]]

        assert payload["success"] is True
        assert listed_ids == ["other_session"]
        assert all(sid != "root" for sid in listed_ids)
304          assert all(item["session_id"] != "root" for item in result["results"])
305  
306  
307  # =========================================================================
308  # session_search (dispatcher)
309  # =========================================================================
310  
class TestSessionSearch:
    """Dispatcher-level checks of ``session_search`` against mocked session DBs."""

    # ------------------------------------------------------------------
    # Private helpers (not collected by pytest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _run_search(**kwargs):
        """Call ``session_search`` with *kwargs* and decode its JSON reply."""
        from tools.session_search_tool import session_search

        return json.loads(session_search(**kwargs))

    @staticmethod
    def _db_with_hits(hits):
        """Return a ``MagicMock`` DB whose ``search_messages`` yields *hits*."""
        from unittest.mock import MagicMock

        db = MagicMock()
        db.search_messages.return_value = hits
        return db

    @staticmethod
    def _lineage_lookup(rows):
        """Build a ``get_session`` stand-in backed by the *rows* mapping.

        Known session IDs get a fresh copy of their row on every call; any
        other ID yields ``None``.
        """

        def _get_session(session_id):
            # Equality scan instead of a dict lookup, so any argument type is
            # tolerated, exactly like a chain of ``if session_id == ...``.
            for known_id, row in rows.items():
                if session_id == known_id:
                    return dict(row)
            return None

        return _get_session

    @staticmethod
    def _no_llm_provider():
        """Patch ``async_call_llm`` so every call raises ``RuntimeError``."""
        from unittest.mock import AsyncMock, patch as _patch

        return _patch(
            "tools.session_search_tool.async_call_llm",
            new_callable=AsyncMock,
            side_effect=RuntimeError("no provider"),
        )

    @staticmethod
    def _assert_nothing_searched(payload):
        """Assert a successful reply that holds no results and searched nothing."""
        assert payload["success"] is True
        assert payload["count"] == 0
        assert payload["results"] == []
        assert payload["sessions_searched"] == 0

    # ------------------------------------------------------------------
    # Argument validation
    # ------------------------------------------------------------------

    def test_no_db_returns_error(self):
        """Without a DB the reply is a failure mentioning "not available"."""
        payload = self._run_search(query="test")
        assert payload["success"] is False
        assert "not available" in payload["error"].lower()

    def test_empty_query_returns_error(self):
        """An empty query is rejected."""
        placeholder_db = object()
        payload = self._run_search(query="", db=placeholder_db)
        assert payload["success"] is False

    def test_whitespace_query_returns_error(self):
        """A query made only of spaces is rejected."""
        placeholder_db = object()
        payload = self._run_search(query="   ", db=placeholder_db)
        assert payload["success"] is False

    # ------------------------------------------------------------------
    # Current-session exclusion
    # ------------------------------------------------------------------

    def test_current_session_excluded(self):
        """session_search should never return the current session."""
        active_sid = "20260304_120000_abc123"

        # Simulate FTS5 returning matches only from the current session.
        db = self._db_with_hits(
            [
                {
                    "session_id": active_sid,
                    "content": "test match",
                    "source": "cli",
                    "session_started": 1709500000,
                    "model": "test",
                },
            ]
        )
        db.get_session.return_value = {"parent_session_id": None}

        payload = self._run_search(
            query="test", db=db, current_session_id=active_sid
        )

        assert payload["success"] is True
        assert payload["count"] == 0
        assert payload["results"] == []

    def test_current_session_excluded_keeps_others(self):
        """Other sessions should still be returned when current is excluded."""
        active_sid = "20260304_120000_abc123"
        foreign_sid = "20260303_100000_def456"

        db = self._db_with_hits(
            [
                {
                    "session_id": active_sid,
                    "content": "match 1",
                    "source": "cli",
                    "session_started": 1709500000,
                    "model": "test",
                },
                {
                    "session_id": foreign_sid,
                    "content": "match 2",
                    "source": "telegram",
                    "session_started": 1709400000,
                    "model": "test",
                },
            ]
        )
        db.get_session.return_value = {"parent_session_id": None}
        db.get_messages_as_conversation.return_value = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi there"},
        ]

        # async_call_llm raises RuntimeError for the duration of the search.
        with self._no_llm_provider():
            payload = self._run_search(
                query="test", db=db, current_session_id=active_sid
            )

        assert payload["success"] is True
        # Current session should be skipped, only the foreign one is searched.
        assert payload["sessions_searched"] == 1
        assert all(
            entry.get("session_id") != active_sid
            for entry in payload.get("results", [])
        )

    def test_current_child_session_excludes_parent_lineage(self):
        """Compression/delegation parents should be excluded for the active child session."""
        db = self._db_with_hits(
            [
                {
                    "session_id": "parent_sid",
                    "content": "match",
                    "source": "cli",
                    "session_started": 1709500000,
                    "model": "test",
                },
            ]
        )
        db.get_session.side_effect = self._lineage_lookup(
            {
                "child_sid": {"parent_session_id": "parent_sid"},
                "parent_sid": {"parent_session_id": None},
            }
        )

        payload = self._run_search(
            query="test", db=db, current_session_id="child_sid"
        )

        self._assert_nothing_searched(payload)

    # ------------------------------------------------------------------
    # ``limit`` coercion — these tests only check that the call succeeds
    # ------------------------------------------------------------------

    def test_limit_none_coerced_to_default(self):
        """A ``limit`` of ``None`` (model sent null) must not break the call."""
        db = self._db_with_hits([])
        payload = self._run_search(query="test", db=db, limit=None)
        assert payload["success"] is True

    def test_limit_type_object_coerced_to_default(self):
        """A type object passed as ``limit`` must not break the call."""
        db = self._db_with_hits([])
        payload = self._run_search(query="test", db=db, limit=int)
        assert payload["success"] is True

    def test_limit_string_coerced(self):
        """A ``limit`` given as the string ``'2'`` must not break the call."""
        db = self._db_with_hits([])
        payload = self._run_search(query="test", db=db, limit="2")
        assert payload["success"] is True

    def test_limit_clamped_to_range(self):
        """A negative or zero ``limit`` must not break the call."""
        db = self._db_with_hits([])
        for out_of_range in (-5, 0):
            payload = self._run_search(query="test", db=db, limit=out_of_range)
            assert payload["success"] is True

    # ------------------------------------------------------------------
    # Lineage resolution
    # ------------------------------------------------------------------

    def test_current_root_session_excludes_child_lineage(self):
        """Delegation child hits should be excluded when they resolve to the current root session."""
        db = self._db_with_hits(
            [
                {
                    "session_id": "child_sid",
                    "content": "match",
                    "source": "cli",
                    "session_started": 1709500000,
                    "model": "test",
                },
            ]
        )
        db.get_session.side_effect = self._lineage_lookup(
            {
                "root_sid": {"parent_session_id": None},
                "child_sid": {"parent_session_id": "root_sid"},
            }
        )

        payload = self._run_search(
            query="test", db=db, current_session_id="root_sid"
        )

        self._assert_nothing_searched(payload)

    def test_source_from_resolved_parent_not_fts5_child(self):
        """The reported source comes from the resolved parent, not the matching child.

        Regression test for #15909: when a delegation child session
        (source='telegram') resolves to a parent (source='api_server'), the
        result entry must report 'api_server', not 'telegram'.
        """
        # The FTS5 hit lives in the child delegation session, whose own source
        # ('telegram') is the wrong value to surface.
        db = self._db_with_hits(
            [
                {
                    "session_id": "child_sid",
                    "content": "hello world",
                    "source": "telegram",
                    "session_started": 1709400000,
                    "model": "gpt-4o-mini",
                },
            ]
        )
        db.get_session.side_effect = self._lineage_lookup(
            {
                "child_sid": {
                    "id": "child_sid",
                    "parent_session_id": "parent_sid",
                    "source": "telegram",
                    "started_at": 1709400000,
                    "model": "gpt-4o-mini",
                },
                "parent_sid": {
                    "id": "parent_sid",
                    "parent_session_id": None,
                    "source": "api_server",  # the value that must be reported
                    "started_at": 1709300000,
                    "model": "gpt-4o-mini",
                },
            }
        )
        db.get_messages_as_conversation.return_value = [
            {"role": "user", "content": "hello world"},
            {"role": "assistant", "content": "hi there"},
        ]

        with self._no_llm_provider():
            payload = self._run_search(query="hello world", db=db)

        assert payload["success"] is True
        assert payload["count"] == 1
        entry = payload["results"][0]
        assert entry["session_id"] == "parent_sid", "should report resolved parent session ID"
        assert entry["source"] == "api_server", (
            f"source should be parent's 'api_server', got {entry['source']!r}"
        )
561              f"source should be parent's 'api_server', got {entry['source']!r}"
562          )