test_step_callback_compat.py
"""Tests for step_callback backward compatibility.

Verifies that the gateway's step_callback normalization keeps
``tool_names`` as a list of strings for backward-compatible hooks,
while also providing the enriched ``tools`` list with results.
"""

import asyncio
import threading
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


class TestStepCallbackNormalization:
    """The gateway's _step_callback_sync normalizes prev_tools from run_agent."""

    @staticmethod
    def _tool_names(prev_tools):
        """Return the string tool names for *prev_tools*.

        Dict entries contribute their ``"name"`` value (or ``""`` when the key
        is missing or falsy); any other entry is passed through ``str()``.
        ``None`` or an empty list yields ``[]``.
        """
        return [
            (tool.get("name") or "") if isinstance(tool, dict) else str(tool)
            for tool in (prev_tools or [])
        ]

    def _extract_step_callback(self):
        """Build a minimal _step_callback_sync using the same logic as gateway/run.py.

        We replicate the closure so we can test normalisation in isolation
        without spinning up the full gateway.

        Returns:
            A ``(callback, captured_events, loop)`` tuple. ``callback`` schedules
            an ``agent:step`` emit on ``loop``; ``captured_events`` collects the
            ``(event_type, data)`` pairs once the loop has run. The caller owns
            ``loop`` and must close it.
        """
        captured_events = []

        class FakeHooks:
            async def emit(self, event_type, data):
                captured_events.append((event_type, data))

        hooks_ref = FakeHooks()
        loop = asyncio.new_event_loop()
        normalize = self._tool_names

        def _step_callback_sync(iteration: int, prev_tools: list) -> None:
            asyncio.run_coroutine_threadsafe(
                hooks_ref.emit("agent:step", {
                    "iteration": iteration,
                    "tool_names": normalize(prev_tools),
                    "tools": prev_tools,
                }),
                loop,
            )

        return _step_callback_sync, captured_events, loop

    def _run_step(self, iteration, prev_tools):
        """Invoke the step callback from a worker thread and return the events.

        The callback is called off the loop thread (it uses
        ``run_coroutine_threadsafe``), then the loop is run until every task
        the callback scheduled has finished. Waiting on the tasks themselves
        avoids a wall-clock sleep, so the tests are fast and not
        timing-dependent.
        """
        callback, events, loop = self._extract_step_callback()

        async def _drain():
            # One yield lets the thread-safe handle create the emit task.
            await asyncio.sleep(0)
            current = asyncio.current_task()
            pending = [task for task in asyncio.all_tasks() if task is not current]
            if pending:
                await asyncio.gather(*pending)

        try:
            worker = threading.Thread(target=callback, args=(iteration, prev_tools))
            worker.start()
            worker.join(timeout=2)
            assert not worker.is_alive(), "step callback did not return"
            loop.run_until_complete(_drain())
        finally:
            loop.close()

        return events

    def test_dict_prev_tools_produce_string_tool_names(self):
        """When prev_tools is list[dict], tool_names should be list[str]."""
        # Simulate the enriched format from run_agent.py
        prev_tools = [
            {"name": "terminal", "result": '{"output": "hello"}'},
            {"name": "read_file", "result": '{"content": "..."}'},
        ]

        events = self._run_step(1, prev_tools)

        assert len(events) == 1
        _, data = events[0]
        # tool_names must be strings for backward compat
        assert data["tool_names"] == ["terminal", "read_file"]
        assert all(isinstance(n, str) for n in data["tool_names"])
        # tools should be the enriched dicts
        assert data["tools"] == prev_tools

    def test_string_prev_tools_still_work(self):
        """When prev_tools is list[str] (legacy), tool_names should pass through."""
        events = self._run_step(2, ["terminal", "read_file"])

        assert len(events) == 1
        _, data = events[0]
        assert data["tool_names"] == ["terminal", "read_file"]

    @pytest.mark.parametrize("prev_tools", [[], None], ids=["empty-list", "none"])
    def test_empty_prev_tools(self, prev_tools):
        """Empty or None prev_tools should produce empty tool_names."""
        events = self._run_step(1, prev_tools)

        assert len(events) == 1
        _, data = events[0]
        assert data["tool_names"] == []

    def test_joinable_for_hook_example(self):
        """The documented hook example: ', '.join(tool_names) should work."""
        # This is the exact pattern from the docs
        prev_tools = [
            {"name": "terminal", "result": "ok"},
            {"name": "web_search", "result": None},
        ]

        names = self._tool_names(prev_tools)

        # This must not raise — documented hook pattern
        result = ", ".join(names)
        assert result == "terminal, web_search"