/ server / tests / test_routes_proxy.py
test_routes_proxy.py
  1  # Copyright 2026 Alibaba Group Holding Ltd.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  
 15  import asyncio
 16  from typing import Any, cast
 17  
 18  import httpx
 19  from fastapi.testclient import TestClient
 20  from websockets.typing import Origin
 21  
 22  import opensandbox_server.api.proxy as proxy_api
 23  from opensandbox_server.api import lifecycle
 24  from opensandbox_server.api.schema import Endpoint
 25  from opensandbox_server.middleware.auth import SANDBOX_API_KEY_HEADER
 26  from opensandbox_server.services.constants import OPEN_SANDBOX_EGRESS_AUTH_HEADER, OPEN_SANDBOX_INGRESS_HEADER
 27  from opensandbox_server.services.constants import OPEN_SANDBOX_SECURE_ACCESS_HEADER
 28  
 29  
 30  class _FakeStreamingResponse:
 31      def __init__(
 32          self, status_code: int = 200, headers: dict | None = None, chunks: list[bytes] | None = None
 33      ):
 34          self.status_code = status_code
 35          self.headers = httpx.Headers(headers or {})
 36          self._chunks = chunks or []
 37          self.aclose_called = False
 38  
 39      async def aiter_bytes(self):
 40          for chunk in self._chunks:
 41              yield chunk
 42  
 43      async def aclose(self):
 44          self.aclose_called = True
 45  
 46  
 47  class _FakeAsyncClient:
 48      def __init__(self):
 49          self.built = None
 50          self.response = _FakeStreamingResponse()
 51          self.raise_connect_error = False
 52          self.raise_generic_error = False
 53  
 54      def build_request(
 55          self,
 56          method: str,
 57          url: str,
 58          headers: dict,
 59          content,
 60          params: str | None = None,
 61      ):
 62          self.built = {
 63              "method": method,
 64              "url": url,
 65              "params": params,
 66              "headers": headers,
 67              "content": content,
 68          }
 69          return self.built
 70  
 71      async def send(self, req, stream: bool = True):
 72          if self.raise_connect_error:
 73              raise httpx.ConnectError("connection refused")
 74          if self.raise_generic_error:
 75              raise RuntimeError("unexpected proxy error")
 76          return self.response
 77  
 78  
 79  def _set_http_client(client: TestClient, fake_client: _FakeAsyncClient) -> None:
 80      cast(Any, client.app).state.http_client = fake_client
 81  
 82  
 83  class _FakeBackendWebSocket:
 84      def __init__(self, message: str = "backend-ready", subprotocol: str | None = "claw.v1"):
 85          self.message = message
 86          self.subprotocol = subprotocol
 87          self.sent: list[str | bytes] = []
 88          self.close_calls: list[tuple[int, str]] = []
 89          self._delivered = False
 90  
 91      async def send(self, payload: str | bytes) -> None:
 92          self.sent.append(payload)
 93  
 94      async def recv(self) -> str:
 95          if not self._delivered:
 96              self._delivered = True
 97              return self.message
 98          await asyncio.Future()
 99          raise AssertionError("unreachable")
100  
101      async def close(self, code: int = 1000, reason: str = "") -> None:
102          self.close_calls.append((code, reason))
103  
104  
class _FakeWebSocketConnector:
    """Callable that the WebSocket test patches in for ``websockets.connect``.

    Each call is recorded in ``self.calls`` and returns an async context
    manager whose ``__aenter__`` yields the configured fake backend.
    """

    def __init__(self, backend: _FakeBackendWebSocket):
        self.backend = backend
        # One dict per call: the target URI plus every keyword argument.
        self.calls: list[dict] = []

    def __call__(self, uri: str, **kwargs):
        recorded_call = {"uri": uri}
        recorded_call.update(kwargs)
        self.calls.append(recorded_call)

        # Bind the backend now so the context manager returns the one in use at call time.
        target = self.backend

        class _ContextManager:
            async def __aenter__(self):
                return target

            async def __aexit__(self, exc_type, exc, tb):
                # Falsy return: exceptions raised inside the ``async with`` propagate.
                return False

        return _ContextManager()
122  
123  
def test_proxy_forwards_filtered_headers_and_query(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Check what a proxied POST sends to the backend and relays to the caller.

    The backend's status, body and ``x-backend`` header must be relayed.
    The outgoing request must keep ``X-Trace`` and the query string, and must
    not carry credentials, the API key header, or hop-by-hop headers
    (including the one named in ``Connection``).
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert (sandbox_id, port) == ("sbx-123", 44772)
            assert resolve_internal is True
            return Endpoint(endpoint="10.57.1.91:40109")

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=201,
        headers={"x-backend": "yes"},
        chunks=[b"proxy-ok"],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    request_headers = dict(auth_headers)
    request_headers.update(
        {
            "Authorization": "Bearer top-secret",
            "Cookie": "sid=secret",
            "Connection": "keep-alive, X-Hop-Temp",
            "Upgrade": "h2c",
            "Trailer": "X-Checksum",
            "X-Hop-Temp": "drop-me",
            "X-Trace": "trace-1",
        }
    )

    reply = client.post(
        "/v1/sandboxes/sbx-123/proxy/44772/api/run",
        params={"q": "search"},
        headers=request_headers,
        content=b'{"hello":"world"}',
    )

    # What the caller sees.
    assert reply.status_code == 201
    assert reply.content == b"proxy-ok"
    assert reply.headers.get("x-backend") == "yes"

    # What the backend was sent.
    outgoing = upstream.built
    assert outgoing is not None
    assert outgoing["method"] == "POST"
    assert outgoing["url"] == "http://10.57.1.91:40109/api/run"
    assert outgoing["params"] == "q=search"

    sent = {name.lower(): value for name, value in outgoing["headers"].items()}
    must_be_dropped = (
        "host",
        "connection",
        "upgrade",
        "trailer",
        "authorization",
        "cookie",
        SANDBOX_API_KEY_HEADER.lower(),
        "x-hop-temp",
    )
    for header_name in must_be_dropped:
        assert header_name not in sent
    assert sent.get("x-trace") == "trace-1"

    # The streamed backend response must have been closed.
    assert upstream.response.aclose_called is True
185  
186  
def test_proxy_root_path_forwards_endpoint_headers_and_query(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Proxy a GET with no trailing path and check the endpoint's headers are sent.

    The stubbed endpoint has a ``/base`` path and an ingress header. The test
    asserts the outgoing URL, the query string, the ingress header and the
    client's ``X-Trace`` header.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert (sandbox_id, port) == ("sbx-123", 44772)
            assert resolve_internal is True
            ingress_headers = {OPEN_SANDBOX_INGRESS_HEADER: "sbx-123-44772"}
            return Endpoint(endpoint="10.57.1.91:40109/base", headers=ingress_headers)

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(chunks=[b"root-ok"])
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    request_headers = {**auth_headers, "X-Trace": "trace-root"}
    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772",
        params={"q": "search"},
        headers=request_headers,
    )

    assert reply.status_code == 200
    assert reply.content == b"root-ok"

    outgoing = upstream.built
    assert outgoing is not None
    assert outgoing["url"] == "http://10.57.1.91:40109/base"
    assert outgoing["params"] == "q=search"

    sent = {name.lower(): value for name, value in outgoing["headers"].items()}
    assert sent["opensandbox-ingress-to"] == "sbx-123-44772"
    assert sent["x-trace"] == "trace-root"
225  
226  
def test_proxy_does_not_auto_inject_secure_access_header(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """The endpoint's secure-access header must not be added to the proxied request.

    The stubbed endpoint returns both an ingress header and a secure-access
    header. The client sends no secure-access header, so the outgoing request
    is expected to carry the ingress header only.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert sandbox_id == "sbx-123"
            assert port == 44772
            assert resolve_internal is True
            return Endpoint(
                endpoint="10.57.1.91:40109/base",
                headers={
                    OPEN_SANDBOX_INGRESS_HEADER: "sbx-123-44772",
                    OPEN_SANDBOX_SECURE_ACCESS_HEADER: "secure-token",
                },
            )

    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    fake_client = _FakeAsyncClient()
    fake_client.response = _FakeStreamingResponse(chunks=[b"root-ok"])
    _set_http_client(client, fake_client)

    response = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772",
        params={"q": "search"},
        headers={**auth_headers, "X-Trace": "trace-root"},
    )

    assert response.status_code == 200
    # Guard before subscripting: ``built`` stays ``None`` if no request was built,
    # and a plain assertion failure is clearer than a ``TypeError``.
    assert fake_client.built is not None
    lowered_headers = {
        key.lower(): value for key, value in fake_client.built["headers"].items()
    }
    assert lowered_headers["opensandbox-ingress-to"] == "sbx-123-44772"
    assert OPEN_SANDBOX_SECURE_ACCESS_HEADER.lower() not in lowered_headers
264  
265  
def test_proxy_forwards_client_supplied_secure_access_header(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """A secure-access header sent by the client is forwarded with the client's value.

    The stubbed endpoint also carries a secure-access header with a different
    value. The outgoing request must contain the client's value
    (``client-token``), not the endpoint's.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert sandbox_id == "sbx-123"
            assert port == 44772
            assert resolve_internal is True
            return Endpoint(
                endpoint="10.57.1.91:40109/base",
                headers={
                    OPEN_SANDBOX_INGRESS_HEADER: "sbx-123-44772",
                    OPEN_SANDBOX_SECURE_ACCESS_HEADER: "server-side-token",
                },
            )

    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    fake_client = _FakeAsyncClient()
    fake_client.response = _FakeStreamingResponse(chunks=[b"root-ok"])
    _set_http_client(client, fake_client)

    response = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772",
        headers={
            **auth_headers,
            OPEN_SANDBOX_SECURE_ACCESS_HEADER: "client-token",
        },
    )

    assert response.status_code == 200
    # Guard before subscripting: ``built`` stays ``None`` if no request was built,
    # and a plain assertion failure is clearer than a ``TypeError``.
    assert fake_client.built is not None
    lowered_headers = {
        key.lower(): value for key, value in fake_client.built["headers"].items()
    }
    assert lowered_headers[OPEN_SANDBOX_SECURE_ACCESS_HEADER.lower()] == "client-token"
304  
305  
def test_proxy_forwards_get_request_with_query_params(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Verify a proxied GET keeps its query parameters and sends no body.

    Regression test for issue #484: GET requests with query parameters failed
    with 400 MISSING_QUERY when use_server_proxy was enabled. The query string
    must reach httpx through ``params`` rather than being embedded in the URL.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert (sandbox_id, port) == ("sbx-123", 44772)
            assert resolve_internal is True
            return Endpoint(endpoint="10.57.1.91:40109")

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=200,
        headers={"content-type": "application/json"},
        chunks=[b'[{"name":"file.txt","size":100}]'],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772/files/search",
        params={"path": "/workspace"},
        headers=auth_headers,
    )

    assert reply.status_code == 200

    outgoing = upstream.built
    assert outgoing is not None
    assert outgoing["method"] == "GET"
    # The URL carries no query string; it travels separately, percent-encoded.
    assert outgoing["url"] == "http://10.57.1.91:40109/files/search"
    assert outgoing["params"] == "path=%2Fworkspace"
    assert outgoing["content"] is None
347  
348  
def test_proxy_forwards_delete_request_with_body(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Verify a proxied DELETE that carries a payload is forwarded with content.

    Guards against the proxy dropping the body of DELETE requests: the
    outgoing request must have a non-``None`` ``content``.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            return Endpoint(endpoint="10.57.1.91:40109")

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=200,
        headers={"content-type": "application/json"},
        chunks=[b'{"deleted":true}'],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.request(
        "DELETE",
        "/v1/sandboxes/sbx-123/proxy/44772/resources",
        headers=auth_headers,
        content=b'{"id": "resource-123"}',
    )

    assert reply.status_code == 200

    outgoing = upstream.built
    assert outgoing is not None
    assert outgoing["method"] == "DELETE"
    assert outgoing["content"] is not None
385  
386  
def test_proxy_filters_response_hop_by_hop_headers(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Hop-by-hop headers in the backend response must not reach the caller.

    The fake backend answers with ``Connection``, ``Keep-Alive``, ``Trailer``
    and a header named in ``Connection``. Only ``x-backend`` is expected to
    survive.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert resolve_internal is True
            return Endpoint(endpoint="10.57.1.91:40109")

    backend_headers = {
        "x-backend": "yes",
        "Connection": "keep-alive, X-Hop-Temp",
        "Keep-Alive": "timeout=5",
        "Trailer": "X-Checksum",
        "X-Hop-Temp": "drop-me",
    }
    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=200,
        headers=backend_headers,
        chunks=[b"proxy-ok"],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772/healthz",
        headers=auth_headers,
    )

    assert reply.status_code == 200
    assert reply.content == b"proxy-ok"
    assert reply.headers.get("x-backend") == "yes"
    for stripped_name in ("connection", "keep-alive", "trailer", "x-hop-temp"):
        assert reply.headers.get(stripped_name) is None
426  
427  
def test_proxy_rejects_websocket_upgrade(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """A plain HTTP GET carrying ``Upgrade: websocket`` is answered with 400."""

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            return Endpoint(endpoint="10.57.1.91:40109")

    _set_http_client(client, _FakeAsyncClient())
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    upgrade_headers = {**auth_headers, "Upgrade": "websocket"}
    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772/ws",
        headers=upgrade_headers,
    )

    assert reply.status_code == 400
    error_body = reply.json()
    assert error_body["message"] == "Websocket upgrade is not supported yet"
448  
449  
def test_proxy_rejects_websocket_upgrade_for_post_and_mixed_case_header(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """The upgrade rejection also applies to POST and to a mixed-case ``WebSocket`` value."""

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            return Endpoint(endpoint="10.57.1.91:40109")

    _set_http_client(client, _FakeAsyncClient())
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    upgrade_headers = {**auth_headers, "Upgrade": "WebSocket"}
    reply = client.post(
        "/v1/sandboxes/sbx-123/proxy/44772/ws",
        headers=upgrade_headers,
        content=b"{}",
    )

    assert reply.status_code == 400
    error_body = reply.json()
    assert error_body["message"] == "Websocket upgrade is not supported yet"
471  
472  
def test_proxy_websocket_relays_messages_and_forwards_safe_headers(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """Exercise the WebSocket proxy route against a fake backend connection.

    Checks that one message is relayed in each direction, that the backend is
    closed with code 1000, and that the backend connect call receives the
    expected URI, origin, subprotocols and extra headers. Credentials and
    ``Origin`` must be absent from the extra headers.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert (sandbox_id, port) == ("sbx-123", 44772)
            assert resolve_internal is True
            ingress_headers = {OPEN_SANDBOX_INGRESS_HEADER: "sbx-123-44772"}
            return Endpoint(endpoint="10.57.1.91:40109/proxy/44772", headers=ingress_headers)

    fake_backend = _FakeBackendWebSocket()
    fake_connect = _FakeWebSocketConnector(fake_backend)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())
    # Swap out the connect function the proxy module uses to reach the backend.
    monkeypatch.setattr(proxy_api.websockets, "connect", fake_connect)

    handshake_headers = {
        **auth_headers,
        "Authorization": "Bearer top-secret",
        "Cookie": "sid=secret",
        "Origin": "https://ui.example.com",
        "X-Trace": "trace-ws",
    }
    with client.websocket_connect(
        "/v1/sandboxes/sbx-123/proxy/44772/ws?token=abc",
        headers=handshake_headers,
        subprotocols=["claw.v1"],
    ) as websocket:
        assert websocket.receive_text() == "backend-ready"
        websocket.send_text("client-ready")

    # Relay results, checked once the client side has disconnected.
    assert fake_backend.sent == ["client-ready"]
    first_close_code = fake_backend.close_calls[0][0]
    assert first_close_code == 1000

    connect_call = fake_connect.calls[0]
    assert connect_call["uri"] == "ws://10.57.1.91:40109/proxy/44772/ws?token=abc"
    assert connect_call["origin"] == Origin("https://ui.example.com")
    assert connect_call["subprotocols"] == ["claw.v1"]

    extra_headers = connect_call["additional_headers"] or {}
    sent = {name.lower(): value for name, value in extra_headers.items()}
    for withheld_name in ("authorization", "cookie", "origin"):
        assert withheld_name not in sent
    assert sent["opensandbox-ingress-to"] == "sbx-123-44772"
    assert sent["x-trace"] == "trace-ws"
523  
524  
def test_proxy_maps_connect_error_to_502(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """A connect error raised while sending to the backend yields a 502 response."""

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            return Endpoint(endpoint="10.57.1.91:40109")

    # Make the fake client's ``send`` raise ``httpx.ConnectError``.
    unreachable_upstream = _FakeAsyncClient()
    unreachable_upstream.raise_connect_error = True
    _set_http_client(client, unreachable_upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772/healthz",
        headers=auth_headers,
    )

    assert reply.status_code == 502
    error_message = reply.json()["message"]
    assert "Could not connect to the backend sandbox" in error_message
547  
548  
def test_proxy_maps_unexpected_error_to_500(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """A non-connection error raised while sending to the backend yields a 500 response."""

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            return Endpoint(endpoint="10.57.1.91:40109")

    # Make the fake client's ``send`` raise a plain ``RuntimeError``.
    failing_upstream = _FakeAsyncClient()
    failing_upstream.raise_generic_error = True
    _set_http_client(client, failing_upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/44772/healthz",
        headers=auth_headers,
    )

    assert reply.status_code == 500
    error_message = reply.json()["message"]
    assert "An internal error occurred in the proxy" in error_message
571  
572  
def test_proxy_forwards_18080_without_server_side_egress_auth_check(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """A request to port 18080 without an egress auth header is still forwarded.

    The fake backend answers 401. The test asserts that this response is
    relayed to the caller and that the request was built for the backend URL.
    """

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert port == 18080
            assert resolve_internal is True
            return Endpoint(endpoint="10.57.1.91:18080")

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=401,
        headers={"content-type": "application/json"},
        chunks=[b'{"code":"UNAUTHORIZED"}'],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/18080/policy",
        headers=auth_headers,
    )

    # The 401 originates from the fake backend.
    assert reply.status_code == 401
    assert reply.json()["code"] == "UNAUTHORIZED"

    outgoing = upstream.built
    assert outgoing is not None
    assert outgoing["url"] == "http://10.57.1.91:18080/policy"
603  
604  
def test_proxy_forwards_egress_auth_header_for_18080(
    client: TestClient,
    auth_headers: dict,
    monkeypatch,
) -> None:
    """An egress auth header sent by the client is passed through to port 18080 unchanged."""

    class StubService:
        @staticmethod
        def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint:
            assert port == 18080
            assert resolve_internal is True
            return Endpoint(endpoint="10.57.1.91:18080")

    upstream = _FakeAsyncClient()
    upstream.response = _FakeStreamingResponse(
        status_code=200,
        headers={"content-type": "application/json"},
        chunks=[b'{"status":"ok"}'],
    )
    _set_http_client(client, upstream)
    monkeypatch.setattr(lifecycle, "sandbox_service", StubService())

    request_headers = {**auth_headers, OPEN_SANDBOX_EGRESS_AUTH_HEADER: "egress-token"}
    reply = client.get(
        "/v1/sandboxes/sbx-123/proxy/18080/policy",
        headers=request_headers,
    )

    assert reply.status_code == 200

    outgoing = upstream.built
    assert outgoing is not None
    sent = {name.lower(): value for name, value in outgoing["headers"].items()}
    assert sent[OPEN_SANDBOX_EGRESS_AUTH_HEADER.lower()] == "egress-token"