/ tests / gateway / test_proxy_mode.py
test_proxy_mode.py
  1  """Tests for gateway proxy mode — forwarding messages to a remote API server."""
  2  
  3  import asyncio
  4  import json
  5  import os
  6  from unittest.mock import AsyncMock, MagicMock, patch
  7  
  8  import pytest
  9  
 10  from gateway.config import Platform, StreamingConfig
 11  from gateway.platforms.base import resolve_proxy_url
 12  from gateway.run import GatewayRunner
 13  from gateway.session import SessionSource
 14  
 15  
 16  def _make_runner(proxy_url=None):
 17      """Create a minimal GatewayRunner for proxy tests."""
 18      runner = object.__new__(GatewayRunner)
 19      runner.adapters = {}
 20      runner.config = MagicMock()
 21      runner.config.streaming = StreamingConfig()
 22      runner._running_agents = {}
 23      runner._session_run_generation = {}
 24      runner._session_model_overrides = {}
 25      runner._agent_cache = {}
 26      runner._agent_cache_lock = None
 27      return runner
 28  
 29  
 30  def _make_source(platform=Platform.MATRIX):
 31      return SessionSource(
 32          platform=platform,
 33          chat_id="!room:server.org",
 34          chat_name="Test Room",
 35          chat_type="group",
 36          user_id="@user:server.org",
 37          user_name="testuser",
 38          thread_id=None,
 39      )
 40  
 41  
 42  class _FakeSSEResponse:
 43      """Simulates an aiohttp response with SSE streaming."""
 44  
 45      def __init__(self, status=200, sse_chunks=None, error_text=""):
 46          self.status = status
 47          self._sse_chunks = sse_chunks or []
 48          self._error_text = error_text
 49          self.content = self
 50  
 51      async def text(self):
 52          return self._error_text
 53  
 54      async def iter_any(self):
 55          for chunk in self._sse_chunks:
 56              if isinstance(chunk, str):
 57                  chunk = chunk.encode("utf-8")
 58              yield chunk
 59  
 60      async def __aenter__(self):
 61          return self
 62  
 63      async def __aexit__(self, *args):
 64          pass
 65  
 66  
 67  class _FakeSession:
 68      """Simulates an aiohttp.ClientSession with captured request args."""
 69  
 70      def __init__(self, response):
 71          self._response = response
 72          self.captured_url = None
 73          self.captured_json = None
 74          self.captured_headers = None
 75  
 76      def post(self, url, json=None, headers=None, **kwargs):
 77          self.captured_url = url
 78          self.captured_json = json
 79          self.captured_headers = headers
 80          return self._response
 81  
 82      async def __aenter__(self):
 83          return self
 84  
 85      async def __aexit__(self, *args):
 86          pass
 87  
 88  
 89  def _patch_aiohttp(session):
 90      """Patch aiohttp.ClientSession to return our fake session."""
 91      return patch(
 92          "aiohttp.ClientSession",
 93          return_value=session,
 94      )
 95  
 96  
 97  class TestGetProxyUrl:
 98      """Test _get_proxy_url() config resolution."""
 99  
100      def test_returns_none_when_not_configured(self, monkeypatch):
101          monkeypatch.delenv("GATEWAY_PROXY_URL", raising=False)
102          runner = _make_runner()
103          with patch("gateway.run._load_gateway_config", return_value={}):
104              assert runner._get_proxy_url() is None
105  
106      def test_reads_from_env_var(self, monkeypatch):
107          monkeypatch.setenv("GATEWAY_PROXY_URL", "http://192.168.1.100:8642")
108          runner = _make_runner()
109          assert runner._get_proxy_url() == "http://192.168.1.100:8642"
110  
111      def test_strips_trailing_slash(self, monkeypatch):
112          monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642/")
113          runner = _make_runner()
114          assert runner._get_proxy_url() == "http://host:8642"
115  
116      def test_reads_from_config_yaml(self, monkeypatch):
117          monkeypatch.delenv("GATEWAY_PROXY_URL", raising=False)
118          runner = _make_runner()
119          cfg = {"gateway": {"proxy_url": "http://10.0.0.1:8642"}}
120          with patch("gateway.run._load_gateway_config", return_value=cfg):
121              assert runner._get_proxy_url() == "http://10.0.0.1:8642"
122  
123      def test_env_var_overrides_config(self, monkeypatch):
124          monkeypatch.setenv("GATEWAY_PROXY_URL", "http://env-host:8642")
125          runner = _make_runner()
126          cfg = {"gateway": {"proxy_url": "http://config-host:8642"}}
127          with patch("gateway.run._load_gateway_config", return_value=cfg):
128              assert runner._get_proxy_url() == "http://env-host:8642"
129  
130      def test_empty_string_treated_as_unset(self, monkeypatch):
131          monkeypatch.setenv("GATEWAY_PROXY_URL", "  ")
132          runner = _make_runner()
133          with patch("gateway.run._load_gateway_config", return_value={}):
134              assert runner._get_proxy_url() is None
135  
136  
class TestResolveProxyUrl:
    """Test ``resolve_proxy_url()`` against the standard proxy env variables."""

    # Every proxy-related variable the tests must start without, in both cases.
    _PROXY_ENV_KEYS = (
        "HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
        "https_proxy", "http_proxy", "all_proxy", "NO_PROXY", "no_proxy",
    )

    @pytest.fixture(autouse=True)
    def _clean_proxy_env(self, monkeypatch):
        """Remove all proxy env vars before each test.

        The fixture shares the test's ``monkeypatch`` instance, so every
        change is undone after the test.
        """
        for key in self._PROXY_ENV_KEYS:
            monkeypatch.delenv(key, raising=False)

    def test_normalizes_socks_alias_from_all_proxy(self, monkeypatch):
        monkeypatch.setenv("ALL_PROXY", "socks://127.0.0.1:1080/")
        assert resolve_proxy_url() == "socks5://127.0.0.1:1080/"

    def test_no_proxy_bypasses_matching_host(self, monkeypatch):
        monkeypatch.setenv("HTTPS_PROXY", "http://proxy.example:8080")
        monkeypatch.setenv("NO_PROXY", "api.telegram.org")

        assert resolve_proxy_url(target_hosts="api.telegram.org") is None

    def test_no_proxy_bypasses_cidr_target(self, monkeypatch):
        monkeypatch.setenv("HTTPS_PROXY", "http://proxy.example:8080")
        monkeypatch.setenv("NO_PROXY", "149.154.160.0/20")

        assert resolve_proxy_url(target_hosts=["149.154.167.220"]) is None

    def test_no_proxy_ignored_without_target(self, monkeypatch):
        monkeypatch.setenv("HTTPS_PROXY", "http://proxy.example:8080")
        monkeypatch.setenv("NO_PROXY", "*")

        assert resolve_proxy_url() == "http://proxy.example:8080"
171  
172  
class TestRunAgentProxyDispatch:
    """Test that _run_agent() delegates to proxy when configured."""

    @pytest.mark.asyncio
    async def test_run_agent_delegates_to_proxy(self, monkeypatch):
        monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
        runner = _make_runner()
        source = _make_source()

        expected_result = {
            "final_response": "Hello from remote!",
            "messages": [
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "Hello from remote!"},
            ],
            "api_calls": 1,
            "tools": [],
        }

        runner._run_agent_via_proxy = AsyncMock(return_value=expected_result)

        # Patch the config loader as well, so the real gateway config can
        # never influence which path _run_agent() takes.
        with patch("gateway.run._load_gateway_config", return_value={}):
            result = await runner._run_agent(
                message="hi",
                context_prompt="",
                history=[],
                source=source,
                session_id="test-session-123",
                session_key="test-key",
                run_generation=7,
            )

        assert result["final_response"] == "Hello from remote!"
        # assert_awaited_once (not merely called) proves the coroutine ran.
        runner._run_agent_via_proxy.assert_awaited_once()
        assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7

    @pytest.mark.asyncio
    async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
        monkeypatch.delenv("GATEWAY_PROXY_URL", raising=False)
        runner = _make_runner()

        runner._run_agent_via_proxy = AsyncMock()

        with patch("gateway.run._load_gateway_config", return_value={}):
            try:
                await runner._run_agent(
                    message="hi",
                    context_prompt="",
                    history=[],
                    source=_make_source(),
                    session_id="test-session",
                )
            except Exception:
                # Expected: the bare runner cannot build a real agent.  Only
                # the dispatch decision is under test, so any failure on the
                # local path is deliberately ignored.
                pass

        runner._run_agent_via_proxy.assert_not_called()
228  
229  
class TestRunAgentViaProxy:
    """Test the actual proxy HTTP forwarding logic."""

    @staticmethod
    def _configure_proxy_env(monkeypatch, url="http://host:8642", key=None):
        """Point the gateway at *url*.

        ``GATEWAY_PROXY_KEY`` is set to *key*, or removed when *key* is None.
        """
        monkeypatch.setenv("GATEWAY_PROXY_URL", url)
        if key is None:
            monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
        else:
            monkeypatch.setenv("GATEWAY_PROXY_KEY", key)

    @staticmethod
    def _single_event_response(content):
        """Return a 200 SSE response carrying one content delta then ``[DONE]``.

        Both events arrive in a single bytes chunk.
        """
        # Compact separators reproduce the exact wire text used elsewhere in
        # this file (no spaces after ':' or ',').
        event = json.dumps(
            {"choices": [{"delta": {"content": content}}]}, separators=(",", ":")
        )
        body = f"data: {event}\n\ndata: [DONE]\n\n"
        return _FakeSSEResponse(status=200, sse_chunks=[body.encode("utf-8")])

    @staticmethod
    async def _call_proxy(runner, session, **kwargs):
        """Await ``runner._run_agent_via_proxy(**kwargs)`` against *session*.

        For the duration of the call, the gateway config loader returns an
        empty config and ``aiohttp.ClientSession(...)`` returns *session*.
        ``aiohttp.ClientTimeout`` is replaced by a mock.
        """
        with patch("gateway.run._load_gateway_config", return_value={}):
            with _patch_aiohttp(session):
                with patch("aiohttp.ClientTimeout"):
                    return await runner._run_agent_via_proxy(**kwargs)

    @pytest.mark.asyncio
    async def test_builds_correct_request(self, monkeypatch):
        self._configure_proxy_env(monkeypatch, key="test-key-123")
        runner = _make_runner()

        # All three SSE events arrive in ONE network chunk (adjacent literals
        # concatenate), so this also checks that several events in a single
        # chunk are all parsed.
        payload = (
            'data: {"choices":[{"delta":{"content":"Hello"}}]}\n\n'
            'data: {"choices":[{"delta":{"content":" world"}}]}\n\n'
            "data: [DONE]\n\n"
        )
        session = _FakeSession(_FakeSSEResponse(status=200, sse_chunks=[payload]))

        result = await self._call_proxy(
            runner,
            session,
            message="How are you?",
            context_prompt="You are helpful.",
            history=[
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ],
            source=_make_source(),
            session_id="session-abc",
        )

        # Verify request URL
        assert session.captured_url == "http://host:8642/v1/chat/completions"

        # Verify auth header
        assert session.captured_headers["Authorization"] == "Bearer test-key-123"

        # Verify session ID header
        assert session.captured_headers["X-Hermes-Session-Id"] == "session-abc"

        # Verify messages include system, history, and current message
        messages = session.captured_json["messages"]
        assert messages[0] == {"role": "system", "content": "You are helpful."}
        assert messages[1] == {"role": "user", "content": "Hello"}
        assert messages[2] == {"role": "assistant", "content": "Hi there!"}
        assert messages[3] == {"role": "user", "content": "How are you?"}

        # Verify streaming is requested
        assert session.captured_json["stream"] is True

        # Verify response was assembled
        assert result["final_response"] == "Hello world"

    @pytest.mark.asyncio
    async def test_handles_http_error(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()

        resp = _FakeSSEResponse(status=401, error_text="Unauthorized: invalid API key")
        session = _FakeSession(resp)

        result = await self._call_proxy(
            runner,
            session,
            message="hi",
            context_prompt="",
            history=[],
            source=_make_source(),
            session_id="test",
        )

        assert "Proxy error (401)" in result["final_response"]
        assert result["api_calls"] == 0

    @pytest.mark.asyncio
    async def test_handles_connection_error(self, monkeypatch):
        self._configure_proxy_env(monkeypatch, url="http://unreachable:8642")
        runner = _make_runner()

        class _ErrorSession:
            """Session whose post() fails as if the remote host refused the connection."""

            def post(self, *args, **kwargs):
                raise ConnectionError("Connection refused")

            async def __aenter__(self):
                return self

            async def __aexit__(self, *args):
                pass

        result = await self._call_proxy(
            runner,
            _ErrorSession(),
            message="hi",
            context_prompt="",
            history=[],
            source=_make_source(),
            session_id="test",
        )

        assert "Proxy connection error" in result["final_response"]

    @pytest.mark.asyncio
    async def test_skips_tool_messages_in_history(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()
        session = _FakeSession(self._single_event_response("ok"))

        history = [
            {"role": "user", "content": "search for X"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]},
            {"role": "tool", "content": "search results...", "tool_call_id": "tc1"},
            {"role": "assistant", "content": "Found results."},
        ]

        await self._call_proxy(
            runner,
            session,
            message="tell me more",
            context_prompt="",
            history=history,
            source=_make_source(),
            session_id="test",
        )

        # Only user and assistant with content should be forwarded
        messages = session.captured_json["messages"]
        roles = [m["role"] for m in messages]
        assert "tool" not in roles
        # assistant with None content should be skipped
        assert all(m.get("content") for m in messages)
        # Pin exactly what survives: the two text turns plus the new message,
        # in their original order.
        assert roles == ["user", "assistant", "user"]
        assert [m["content"] for m in messages] == [
            "search for X",
            "Found results.",
            "tell me more",
        ]

    @pytest.mark.asyncio
    async def test_result_shape_matches_run_agent(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()
        session = _FakeSession(self._single_event_response("answer"))

        result = await self._call_proxy(
            runner,
            session,
            message="hi",
            context_prompt="",
            history=[{"role": "user", "content": "prev"}, {"role": "assistant", "content": "ok"}],
            source=_make_source(),
            session_id="sess-123",
        )

        # Required keys that callers depend on
        assert "final_response" in result
        assert result["final_response"] == "answer"
        assert "messages" in result
        assert "api_calls" in result
        assert "tools" in result
        assert "history_offset" in result
        assert result["history_offset"] == 2  # len(history)
        assert "session_id" in result
        assert result["session_id"] == "sess-123"

    @pytest.mark.asyncio
    async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()
        # The session has already moved on to generation 2; this run is generation 1.
        runner._session_run_generation["test-key"] = 2

        resp = _FakeSSEResponse(
            status=200,
            sse_chunks=[
                'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
                "data: [DONE]\n\n",
            ],
        )
        session = _FakeSession(resp)

        result = await self._call_proxy(
            runner,
            session,
            message="hi",
            context_prompt="",
            history=[],
            source=_make_source(),
            session_id="sess-123",
            session_key="test-key",
            run_generation=1,
        )

        assert result["final_response"] == ""
        assert result["messages"] == []
        assert result["api_calls"] == 0

    @pytest.mark.asyncio
    async def test_no_auth_header_without_key(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()
        session = _FakeSession(self._single_event_response("ok"))

        await self._call_proxy(
            runner,
            session,
            message="hi",
            context_prompt="",
            history=[],
            source=_make_source(),
            session_id="test",
        )

        assert "Authorization" not in session.captured_headers

    @pytest.mark.asyncio
    async def test_no_system_message_when_context_empty(self, monkeypatch):
        self._configure_proxy_env(monkeypatch)
        runner = _make_runner()
        session = _FakeSession(self._single_event_response("ok"))

        await self._call_proxy(
            runner,
            session,
            message="hello",
            context_prompt="",
            history=[],
            source=_make_source(),
            session_id="test",
        )

        # No system message should appear when context_prompt is empty
        messages = session.captured_json["messages"]
        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "hello"
502  
503  
class TestEnvVarRegistration:
    """Check that both proxy settings are declared in ``OPTIONAL_ENV_VARS``."""

    @staticmethod
    def _registered(name):
        """Assert *name* is a registered optional env var and return its metadata."""
        from hermes_cli.config import OPTIONAL_ENV_VARS
        assert name in OPTIONAL_ENV_VARS
        return OPTIONAL_ENV_VARS[name]

    def test_proxy_url_in_optional_env_vars(self):
        meta = self._registered("GATEWAY_PROXY_URL")
        assert meta["category"] == "messaging"
        # The URL is not flagged as a password field.
        assert meta["password"] is False

    def test_proxy_key_in_optional_env_vars(self):
        meta = self._registered("GATEWAY_PROXY_KEY")
        assert meta["category"] == "messaging"
        # The key is a credential, so it must be flagged as a password field.
        assert meta["password"] is True