/ tests / gateway / test_step_callback_compat.py
test_step_callback_compat.py
  1  """Tests for step_callback backward compatibility.
  2  
  3  Verifies that the gateway's step_callback normalization keeps
  4  ``tool_names`` as a list of strings for backward-compatible hooks,
  5  while also providing the enriched ``tools`` list with results.
  6  """
  7  
  8  import asyncio
  9  from unittest.mock import AsyncMock, MagicMock, patch
 10  
 11  import pytest
 12  
 13  
 14  class TestStepCallbackNormalization:
 15      """The gateway's _step_callback_sync normalizes prev_tools from run_agent."""
 16  
 17      def _extract_step_callback(self):
 18          """Build a minimal _step_callback_sync using the same logic as gateway/run.py.
 19  
 20          We replicate the closure so we can test normalisation in isolation
 21          without spinning up the full gateway.
 22          """
 23          captured_events = []
 24  
 25          class FakeHooks:
 26              async def emit(self, event_type, data):
 27                  captured_events.append((event_type, data))
 28  
 29          hooks_ref = FakeHooks()
 30          loop = asyncio.new_event_loop()
 31  
 32          def _step_callback_sync(iteration: int, prev_tools: list) -> None:
 33              _names: list[str] = []
 34              for _t in (prev_tools or []):
 35                  if isinstance(_t, dict):
 36                      _names.append(_t.get("name") or "")
 37                  else:
 38                      _names.append(str(_t))
 39              asyncio.run_coroutine_threadsafe(
 40                  hooks_ref.emit("agent:step", {
 41                      "iteration": iteration,
 42                      "tool_names": _names,
 43                      "tools": prev_tools,
 44                  }),
 45                  loop,
 46              )
 47  
 48          return _step_callback_sync, captured_events, loop
 49  
 50      def test_dict_prev_tools_produce_string_tool_names(self):
 51          """When prev_tools is list[dict], tool_names should be list[str]."""
 52          cb, events, loop = self._extract_step_callback()
 53  
 54          # Simulate the enriched format from run_agent.py
 55          prev_tools = [
 56              {"name": "terminal", "result": '{"output": "hello"}'},
 57              {"name": "read_file", "result": '{"content": "..."}'},
 58          ]
 59  
 60          try:
 61              loop.run_until_complete(asyncio.sleep(0))  # prime the loop
 62              import threading
 63              t = threading.Thread(target=cb, args=(1, prev_tools))
 64              t.start()
 65              t.join(timeout=2)
 66              loop.run_until_complete(asyncio.sleep(0.1))
 67          finally:
 68              loop.close()
 69  
 70          assert len(events) == 1
 71          _, data = events[0]
 72          # tool_names must be strings for backward compat
 73          assert data["tool_names"] == ["terminal", "read_file"]
 74          assert all(isinstance(n, str) for n in data["tool_names"])
 75          # tools should be the enriched dicts
 76          assert data["tools"] == prev_tools
 77  
 78      def test_string_prev_tools_still_work(self):
 79          """When prev_tools is list[str] (legacy), tool_names should pass through."""
 80          cb, events, loop = self._extract_step_callback()
 81  
 82          prev_tools = ["terminal", "read_file"]
 83  
 84          try:
 85              loop.run_until_complete(asyncio.sleep(0))
 86              import threading
 87              t = threading.Thread(target=cb, args=(2, prev_tools))
 88              t.start()
 89              t.join(timeout=2)
 90              loop.run_until_complete(asyncio.sleep(0.1))
 91          finally:
 92              loop.close()
 93  
 94          assert len(events) == 1
 95          _, data = events[0]
 96          assert data["tool_names"] == ["terminal", "read_file"]
 97  
 98      def test_empty_prev_tools(self):
 99          """Empty or None prev_tools should produce empty tool_names."""
100          cb, events, loop = self._extract_step_callback()
101  
102          try:
103              loop.run_until_complete(asyncio.sleep(0))
104              import threading
105              t = threading.Thread(target=cb, args=(1, []))
106              t.start()
107              t.join(timeout=2)
108              loop.run_until_complete(asyncio.sleep(0.1))
109          finally:
110              loop.close()
111  
112          assert len(events) == 1
113          _, data = events[0]
114          assert data["tool_names"] == []
115  
116      def test_joinable_for_hook_example(self):
117          """The documented hook example: ', '.join(tool_names) should work."""
118          # This is the exact pattern from the docs
119          prev_tools = [
120              {"name": "terminal", "result": "ok"},
121              {"name": "web_search", "result": None},
122          ]
123  
124          _names = []
125          for _t in prev_tools:
126              if isinstance(_t, dict):
127                  _names.append(_t.get("name") or "")
128              else:
129                  _names.append(str(_t))
130  
131          # This must not raise — documented hook pattern
132          result = ", ".join(_names)
133          assert result == "terminal, web_search"