/ tests / run_agent / test_steer.py
test_steer.py
  1  """Tests for AIAgent.steer() — mid-run user message injection.
  2  
  3  /steer lets the user add a note to the agent's next tool result without
  4  interrupting the current tool call. The agent sees the note inline with
  5  tool output on its next iteration, preserving message-role alternation
  6  and prompt-cache integrity.
  7  """
  8  from __future__ import annotations
  9  
 10  import threading
 11  
 12  import pytest
 13  
 14  from run_agent import AIAgent
 15  
 16  
 17  def _bare_agent() -> AIAgent:
 18      """Build an AIAgent without running __init__, then install the steer
 19      state manually — matches the existing object.__new__ stub pattern
 20      used elsewhere in the test suite.
 21      """
 22      agent = object.__new__(AIAgent)
 23      agent._pending_steer = None
 24      agent._pending_steer_lock = threading.Lock()
 25      return agent
 26  
 27  
 28  class TestSteerAcceptance:
 29      def test_accepts_non_empty_text(self):
 30          agent = _bare_agent()
 31          assert agent.steer("go ahead and check the logs") is True
 32          assert agent._pending_steer == "go ahead and check the logs"
 33  
 34      def test_rejects_empty_string(self):
 35          agent = _bare_agent()
 36          assert agent.steer("") is False
 37          assert agent._pending_steer is None
 38  
 39      def test_rejects_whitespace_only(self):
 40          agent = _bare_agent()
 41          assert agent.steer("   \n\t  ") is False
 42          assert agent._pending_steer is None
 43  
 44      def test_rejects_none(self):
 45          agent = _bare_agent()
 46          assert agent.steer(None) is False  # type: ignore[arg-type]
 47          assert agent._pending_steer is None
 48  
 49      def test_strips_surrounding_whitespace(self):
 50          agent = _bare_agent()
 51          assert agent.steer("  hello world  \n") is True
 52          assert agent._pending_steer == "hello world"
 53  
 54      def test_concatenates_multiple_steers_with_newlines(self):
 55          agent = _bare_agent()
 56          agent.steer("first note")
 57          agent.steer("second note")
 58          agent.steer("third note")
 59          assert agent._pending_steer == "first note\nsecond note\nthird note"
 60  
 61  
 62  class TestSteerDrain:
 63      def test_drain_returns_and_clears(self):
 64          agent = _bare_agent()
 65          agent.steer("hello")
 66          assert agent._drain_pending_steer() == "hello"
 67          assert agent._pending_steer is None
 68  
 69      def test_drain_on_empty_returns_none(self):
 70          agent = _bare_agent()
 71          assert agent._drain_pending_steer() is None
 72  
 73  
 74  class TestSteerInjection:
 75      def test_appends_to_last_tool_result(self):
 76          agent = _bare_agent()
 77          agent.steer("please also check auth.log")
 78          messages = [
 79              {"role": "user", "content": "what's in /var/log?"},
 80              {"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
 81              {"role": "tool", "content": "ls output A", "tool_call_id": "a"},
 82              {"role": "tool", "content": "ls output B", "tool_call_id": "b"},
 83          ]
 84          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
 85          # The LAST tool result is modified; earlier ones are untouched.
 86          assert messages[2]["content"] == "ls output A"
 87          assert "ls output B" in messages[3]["content"]
 88          assert "User guidance:" in messages[3]["content"]
 89          assert "please also check auth.log" in messages[3]["content"]
 90          # And pending_steer is consumed.
 91          assert agent._pending_steer is None
 92  
 93      def test_no_op_when_no_steer_pending(self):
 94          agent = _bare_agent()
 95          messages = [
 96              {"role": "assistant", "tool_calls": [{"id": "a"}]},
 97              {"role": "tool", "content": "output", "tool_call_id": "a"},
 98          ]
 99          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
100          assert messages[-1]["content"] == "output"  # unchanged
101  
102      def test_no_op_when_num_tool_msgs_zero(self):
103          agent = _bare_agent()
104          agent.steer("steer")
105          messages = [{"role": "user", "content": "hi"}]
106          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0)
107          # Steer should remain pending (nothing to drain into)
108          assert agent._pending_steer == "steer"
109  
110      def test_marker_labels_text_as_user_guidance(self):
111          """The injection marker must label the appended text as user
112          guidance so the model attributes it to the user rather than
113          confusing it with tool output.  This is the cache-safe way to
114          signal provenance without violating message-role alternation.
115          """
116          agent = _bare_agent()
117          agent.steer("stop after next step")
118          messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}]
119          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
120          content = messages[-1]["content"]
121          assert "User guidance:" in content
122          assert "stop after next step" in content
123  
124      def test_multimodal_content_list_preserved(self):
125          """Anthropic-style list content should be preserved, with the steer
126          appended as a text block."""
127          agent = _bare_agent()
128          agent.steer("extra note")
129          original_blocks = [{"type": "text", "text": "existing output"}]
130          messages = [
131              {"role": "tool", "content": list(original_blocks), "tool_call_id": "1"}
132          ]
133          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
134          new_content = messages[-1]["content"]
135          assert isinstance(new_content, list)
136          assert len(new_content) == 2
137          assert new_content[0] == {"type": "text", "text": "existing output"}
138          assert new_content[1]["type"] == "text"
139          assert "extra note" in new_content[1]["text"]
140  
141      def test_restashed_when_no_tool_result_in_batch(self):
142          """If the 'batch' contains no tool-role messages (e.g. all skipped
143          after an interrupt), the steer should be put back into the pending
144          slot so the caller's fallback path can deliver it."""
145          agent = _bare_agent()
146          agent.steer("ping")
147          messages = [
148              {"role": "user", "content": "x"},
149              {"role": "assistant", "content": "y"},
150          ]
151          # Claim there were N tool msgs, but the tail has none — simulates
152          # the interrupt-cancelled case.
153          agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
154          # Messages untouched
155          assert messages[-1]["content"] == "y"
156          # And the steer is back in pending so the fallback can grab it
157          assert agent._pending_steer == "ping"
158  
159  
class TestSteerThreadSafety:
    """Concurrent ``steer()`` calls must not lose any text."""

    def test_concurrent_steer_calls_preserve_all_text(self):
        """Fire 200 steers from separate threads and check every note
        survives in the drained text."""
        stub = _bare_agent()
        note_count = 200
        expected_notes = {f"note-{i}" for i in range(note_count)}

        # One thread per note, each calling steer() exactly once.
        workers = [
            threading.Thread(target=stub.steer, args=(f"note-{i}",))
            for i in range(note_count)
        ]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        drained = stub._drain_pending_steer()
        assert drained is not None
        # Each accepted note occupies its own line, so the line count and
        # the set of lines must both match exactly — nothing dropped.
        received = drained.split("\n")
        assert len(received) == note_count
        assert set(received) == expected_notes
180  
181  
class TestSteerClearedOnInterrupt:
    """Interaction between ``clear_interrupt()`` and a pending steer."""

    def test_clear_interrupt_drops_pending_steer(self):
        """A hard interrupt supersedes a pending steer: the next tool
        iteration will not happen, so delivering the steer later would be
        surprising."""
        stub = _bare_agent()
        # Minimal attribute surface that clear_interrupt() needs on a stub
        # built without __init__.
        interrupt_state = {
            "_interrupt_requested": True,
            "_interrupt_message": None,
            "_interrupt_thread_signal_pending": False,
            "_execution_thread_id": None,
            "_tool_worker_threads": None,
            "_tool_worker_threads_lock": None,
        }
        for attr_name, attr_value in interrupt_state.items():
            setattr(stub, attr_name, attr_value)

        stub.steer("will be dropped")
        assert stub._pending_steer == "will be dropped"

        stub.clear_interrupt()
        assert stub._pending_steer is None
201  
202  
class TestPreApiCallSteerDrain:
    """Steers that arrive during an API call are drained before the next
    API call rather than deferred until the next tool batch.

    This covers the scenario where /steer sent during model thinking only
    landed after the agent was completely done.

    NOTE(review): these tests drive ``steer()`` / ``_drain_pending_steer()``
    on the agent but re-implement the injection step locally (see
    ``_inject_into_last_tool_result``) instead of calling the main loop, so
    they pin the intended contract rather than the production code path.
    """

    @staticmethod
    def _inject_into_last_tool_result(messages: list[dict], steer_text: str) -> bool:
        """Append *steer_text* to the most recent tool-role message.

        Walks *messages* from the end, skipping any non-tool messages, and
        extends the first tool message found with a "User guidance:" marker
        followed by the steer text.  The message dict is mutated in place.

        Returns:
            True if a tool message was found and extended, False if
            *messages* contains no tool-role message (nothing is modified).
        """
        for message in reversed(messages):
            if message.get("role") == "tool":
                message["content"] += f"\n\nUser guidance: {steer_text}"
                return True
        return False

    def test_pre_api_drain_injects_into_last_tool_result(self):
        """If a steer is pending when the main loop starts building
        api_messages, it should be injected into the last tool result
        in the messages list."""
        agent = _bare_agent()
        # Simulate messages after a tool batch completed
        messages = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "ok", "tool_calls": [
                {"id": "tc1", "function": {"name": "terminal", "arguments": "{}"}}
            ]},
            {"role": "tool", "content": "output here", "tool_call_id": "tc1"},
        ]
        # Steer arrives during API call (set after tool execution)
        agent.steer("focus on error handling")
        # Simulate what the pre-API-call drain does:
        pre_api_steer = agent._drain_pending_steer()
        assert pre_api_steer == "focus on error handling"
        assert self._inject_into_last_tool_result(messages, pre_api_steer) is True
        assert "User guidance:" in messages[-1]["content"]
        assert "focus on error handling" in messages[-1]["content"]
        assert agent._pending_steer is None

    def test_pre_api_drain_restashes_when_no_tool_message(self):
        """If there are no tool results yet (first iteration), the steer
        should be put back into _pending_steer for the post-tool drain."""
        agent = _bare_agent()
        messages = [
            {"role": "user", "content": "hello"},
        ]
        agent.steer("early steer")
        pre_api_steer = agent._drain_pending_steer()
        assert pre_api_steer == "early steer"
        # No tool message exists, so nothing may be injected or modified.
        assert self._inject_into_last_tool_result(messages, pre_api_steer) is False
        assert messages == [{"role": "user", "content": "hello"}]
        # Restash
        agent._pending_steer = pre_api_steer
        assert agent._pending_steer == "early steer"

    def test_pre_api_drain_finds_tool_msg_past_assistant(self):
        """The pre-API drain should scan backwards past a non-tool message
        (e.g., if an assistant message was somehow appended after tools)
        and still find the tool result."""
        agent = _bare_agent()
        messages = [
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "let me check", "tool_calls": [
                {"id": "tc1", "function": {"name": "web_search", "arguments": "{}"}}
            ]},
            {"role": "tool", "content": "search results", "tool_call_id": "tc1"},
            # Trailing non-tool message: without it the backward scan would
            # hit the tool result immediately and the "scan past" behaviour
            # this test is named for would never be exercised.
            {"role": "assistant", "content": "thinking"},
        ]
        agent.steer("change approach")
        pre_api_steer = agent._drain_pending_steer()
        assert pre_api_steer is not None
        assert self._inject_into_last_tool_result(messages, pre_api_steer) is True
        assert "change approach" in messages[2]["content"]
        # The skipped assistant message itself must be left untouched.
        assert messages[3] == {"role": "assistant", "content": "thinking"}
277  
278  
class TestSteerCommandRegistry:
    """Registration of the /steer slash command in ``hermes_cli.commands``."""

    def test_steer_in_command_registry(self):
        """The /steer slash command must be registered so it reaches all
        platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus).
        """
        # Imported inside the test so a missing/broken hermes_cli only fails
        # these tests rather than collection of the whole module.
        from hermes_cli.commands import resolve_command

        cmd = resolve_command("steer")
        assert cmd is not None
        assert cmd.name == "steer"
        assert cmd.category == "Session"
        assert cmd.args_hint == "<prompt>"

    def test_steer_in_bypass_set(self):
        """When the agent is running, /steer MUST bypass the Level-1
        base-adapter queue so it reaches the gateway runner's /steer
        handler. Otherwise it would be queued as user text and only
        delivered at turn end — defeating the whole point.
        """
        from hermes_cli.commands import (
            ACTIVE_SESSION_BYPASS_COMMANDS,
            should_bypass_active_session,
        )

        assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS
        assert should_bypass_active_session("steer") is True
302  
303  
if __name__ == "__main__":  # pragma: no cover
    # Running this file directly hands it to pytest in verbose mode.
    _pytest_args = [__file__, "-v"]
    pytest.main(_pytest_args)