# test_steer.py
"""Tests for AIAgent.steer() — mid-run user message injection.

/steer lets the user add a note to the agent's next tool result without
interrupting the current tool call. The agent sees the note inline with
tool output on its next iteration, preserving message-role alternation
and prompt-cache integrity.
"""
from __future__ import annotations

import threading

import pytest

from run_agent import AIAgent


def _bare_agent() -> AIAgent:
    """Build an AIAgent without running __init__, then install the steer
    state manually — matches the existing object.__new__ stub pattern
    used elsewhere in the test suite.
    """
    agent = object.__new__(AIAgent)
    agent._pending_steer = None
    agent._pending_steer_lock = threading.Lock()
    return agent


def _inject_into_last_tool_msg(messages: list[dict], text: str) -> bool:
    """Append ``text`` as user guidance to the most recent tool message.

    Test-local copy of the backward scan performed by the pre-API-call
    steer drain in ``run_conversation`` (kept in one place so the three
    TestPreApiCallSteerDrain tests cannot drift apart from each other).

    Args:
        messages: Conversation messages; the newest message is last.
        text: The drained steer text to inject.

    Returns:
        True if a tool-role message was found and modified in place,
        False if ``messages`` contains no tool-role message (in which case
        ``messages`` is left untouched and the caller should restash).
    """
    for msg in reversed(messages):
        if msg.get("role") == "tool":
            msg["content"] += f"\n\nUser guidance: {text}"
            return True
    return False


class TestSteerAcceptance:
    def test_accepts_non_empty_text(self):
        agent = _bare_agent()
        assert agent.steer("go ahead and check the logs") is True
        assert agent._pending_steer == "go ahead and check the logs"

    def test_rejects_empty_string(self):
        agent = _bare_agent()
        assert agent.steer("") is False
        assert agent._pending_steer is None

    def test_rejects_whitespace_only(self):
        agent = _bare_agent()
        assert agent.steer("   \n\t  ") is False
        assert agent._pending_steer is None

    def test_rejects_none(self):
        agent = _bare_agent()
        assert agent.steer(None) is False  # type: ignore[arg-type]
        assert agent._pending_steer is None

    def test_strips_surrounding_whitespace(self):
        agent = _bare_agent()
        assert agent.steer("  hello world  \n") is True
        assert agent._pending_steer == "hello world"

    def test_concatenates_multiple_steers_with_newlines(self):
        agent = _bare_agent()
        assert agent.steer("first note") is True
        assert agent.steer("second note") is True
        assert agent.steer("third note") is True
        assert agent._pending_steer == "first note\nsecond note\nthird note"


class TestSteerDrain:
    def test_drain_returns_and_clears(self):
        agent = _bare_agent()
        agent.steer("hello")
        assert agent._drain_pending_steer() == "hello"
        assert agent._pending_steer is None

    def test_drain_on_empty_returns_none(self):
        agent = _bare_agent()
        assert agent._drain_pending_steer() is None


class TestSteerInjection:
    def test_appends_to_last_tool_result(self):
        agent = _bare_agent()
        agent.steer("please also check auth.log")
        messages = [
            {"role": "user", "content": "what's in /var/log?"},
            {"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
            {"role": "tool", "content": "ls output A", "tool_call_id": "a"},
            {"role": "tool", "content": "ls output B", "tool_call_id": "b"},
        ]
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
        # The LAST tool result is modified; earlier ones are untouched.
        assert messages[2]["content"] == "ls output A"
        assert "ls output B" in messages[3]["content"]
        assert "User guidance:" in messages[3]["content"]
        assert "please also check auth.log" in messages[3]["content"]
        # And pending_steer is consumed.
        assert agent._pending_steer is None

    def test_no_op_when_no_steer_pending(self):
        agent = _bare_agent()
        messages = [
            {"role": "assistant", "tool_calls": [{"id": "a"}]},
            {"role": "tool", "content": "output", "tool_call_id": "a"},
        ]
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
        assert messages[-1]["content"] == "output"  # unchanged
        # Nothing was pending, so nothing should have been stashed either.
        assert agent._pending_steer is None

    def test_no_op_when_num_tool_msgs_zero(self):
        agent = _bare_agent()
        agent.steer("steer")
        messages = [{"role": "user", "content": "hi"}]
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0)
        # Steer should remain pending (nothing to drain into)
        assert agent._pending_steer == "steer"

    def test_marker_labels_text_as_user_guidance(self):
        """The injection marker must label the appended text as user
        guidance so the model attributes it to the user rather than
        confusing it with tool output. This is the cache-safe way to
        signal provenance without violating message-role alternation.
        """
        agent = _bare_agent()
        agent.steer("stop after next step")
        messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}]
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
        content = messages[-1]["content"]
        assert "User guidance:" in content
        assert "stop after next step" in content

    def test_multimodal_content_list_preserved(self):
        """Anthropic-style list content should be preserved, with the steer
        appended as a text block."""
        agent = _bare_agent()
        agent.steer("extra note")
        original_blocks = [{"type": "text", "text": "existing output"}]
        messages = [
            {"role": "tool", "content": list(original_blocks), "tool_call_id": "1"}
        ]
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
        new_content = messages[-1]["content"]
        assert isinstance(new_content, list)
        assert len(new_content) == 2
        assert new_content[0] == {"type": "text", "text": "existing output"}
        assert new_content[1]["type"] == "text"
        assert "extra note" in new_content[1]["text"]

    def test_restashed_when_no_tool_result_in_batch(self):
        """If the 'batch' contains no tool-role messages (e.g. all skipped
        after an interrupt), the steer should be put back into the pending
        slot so the caller's fallback path can deliver it."""
        agent = _bare_agent()
        agent.steer("ping")
        messages = [
            {"role": "user", "content": "x"},
            {"role": "assistant", "content": "y"},
        ]
        # Claim there were N tool msgs, but the tail has none — simulates
        # the interrupt-cancelled case.
        agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
        # Messages untouched
        assert messages[-1]["content"] == "y"
        # And the steer is back in pending so the fallback can grab it
        assert agent._pending_steer == "ping"


class TestSteerThreadSafety:
    def test_concurrent_steer_calls_preserve_all_text(self):
        agent = _bare_agent()
        N = 200

        def worker(idx: int) -> None:
            agent.steer(f"note-{idx}")

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        text = agent._drain_pending_steer()
        assert text is not None
        # Every single note must be preserved — none lost to a race.
        lines = text.split("\n")
        assert len(lines) == N
        assert set(lines) == {f"note-{i}" for i in range(N)}


class TestSteerClearedOnInterrupt:
    def test_clear_interrupt_drops_pending_steer(self):
        """A hard interrupt supersedes any pending steer — the agent's
        next tool iteration won't happen, so delivering the steer later
        would be surprising."""
        agent = _bare_agent()
        # Minimal surface needed by clear_interrupt()
        agent._interrupt_requested = True
        agent._interrupt_message = None
        agent._interrupt_thread_signal_pending = False
        agent._execution_thread_id = None
        agent._tool_worker_threads = None
        agent._tool_worker_threads_lock = None

        agent.steer("will be dropped")
        assert agent._pending_steer == "will be dropped"

        agent.clear_interrupt()
        assert agent._pending_steer is None


class TestPreApiCallSteerDrain:
    """Test that steers arriving during an API call are drained before the
    next API call — not deferred until the next tool batch. This is the
    fix for the scenario where /steer sent during model thinking only lands
    after the agent is completely done."""

    def test_pre_api_drain_injects_into_last_tool_result(self):
        """If a steer is pending when the main loop starts building
        api_messages, it should be injected into the last tool result
        in the messages list."""
        agent = _bare_agent()
        # Simulate messages after a tool batch completed
        messages = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "ok", "tool_calls": [
                {"id": "tc1", "function": {"name": "terminal", "arguments": "{}"}}
            ]},
            {"role": "tool", "content": "output here", "tool_call_id": "tc1"},
        ]
        # Steer arrives during API call (set after tool execution)
        agent.steer("focus on error handling")
        # Simulate what the pre-API-call drain does:
        _pre_api_steer = agent._drain_pending_steer()
        assert _pre_api_steer == "focus on error handling"
        # Inject into last tool msg (mirrors the code in run_conversation)
        assert _inject_into_last_tool_msg(messages, _pre_api_steer) is True
        assert "User guidance:" in messages[-1]["content"]
        assert "focus on error handling" in messages[-1]["content"]
        assert agent._pending_steer is None

    def test_pre_api_drain_restashes_when_no_tool_message(self):
        """If there are no tool results yet (first iteration), the steer
        should be put back into _pending_steer for the post-tool drain."""
        agent = _bare_agent()
        messages = [
            {"role": "user", "content": "hello"},
        ]
        agent.steer("early steer")
        _pre_api_steer = agent._drain_pending_steer()
        assert _pre_api_steer == "early steer"
        # The drain must actually have emptied the slot before the restash,
        # otherwise the final assertion below would pass trivially.
        assert agent._pending_steer is None
        # No tool message found — nothing is modified, so put it back.
        injected = _inject_into_last_tool_msg(messages, _pre_api_steer)
        assert injected is False
        assert messages == [{"role": "user", "content": "hello"}]
        if not injected:
            agent._pending_steer = _pre_api_steer
        assert agent._pending_steer == "early steer"

    def test_pre_api_drain_finds_tool_msg_past_assistant(self):
        """The pre-API drain should scan backwards past a non-tool message
        (e.g., if an assistant message was somehow appended after tools)
        and still find the tool result."""
        agent = _bare_agent()
        messages = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "let me check", "tool_calls": [
                {"id": "tc1", "function": {"name": "web_search", "arguments": "{}"}}
            ]},
            {"role": "tool", "content": "search results", "tool_call_id": "tc1"},
            # Trailing non-tool message the backward scan must skip over —
            # without it this test would not exercise the "past assistant"
            # case its name and docstring describe.
            {"role": "assistant", "content": "interim note"},
        ]
        agent.steer("change approach")
        _pre_api_steer = agent._drain_pending_steer()
        assert _pre_api_steer is not None
        assert _inject_into_last_tool_msg(messages, _pre_api_steer) is True
        assert "change approach" in messages[2]["content"]
        # The trailing assistant message itself must be left alone.
        assert messages[3]["content"] == "interim note"


class TestSteerCommandRegistry:
    def test_steer_in_command_registry(self):
        """The /steer slash command must be registered so it reaches all
        platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus).
        """
        from hermes_cli.commands import resolve_command

        cmd = resolve_command("steer")
        assert cmd is not None
        assert cmd.name == "steer"
        assert cmd.category == "Session"
        assert cmd.args_hint == "<prompt>"

    def test_steer_in_bypass_set(self):
        """When the agent is running, /steer MUST bypass the Level-1
        base-adapter queue so it reaches the gateway runner's /steer
        handler. Otherwise it would be queued as user text and only
        delivered at turn end — defeating the whole point.
        """
        from hermes_cli.commands import ACTIVE_SESSION_BYPASS_COMMANDS, should_bypass_active_session

        assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS
        assert should_bypass_active_session("steer") is True


if __name__ == "__main__":  # pragma: no cover
    pytest.main([__file__, "-v"])