# tests/gateway/test_utils.py
import json

import pytest
from fastapi import HTTPException

from mlflow.exceptions import MlflowException
from mlflow.gateway.exceptions import AIGatewayException
from mlflow.gateway.utils import (
    SearchRoutesToken,
    _is_valid_uri,
    assemble_uri_path,
    check_configuration_route_name_collisions,
    get_gateway_uri,
    handle_incomplete_chunks,
    is_valid_endpoint_name,
    parse_sse_lines,
    resolve_route_url,
    safe_stream,
    set_gateway_uri,
    stream_sse_data,
    to_sse_error_chunk,
    translate_http_exception,
)
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 23  
 24  
 25  @pytest.mark.parametrize(
 26      ("base_url", "route"),
 27      [
 28          ("http://127.0.0.1:6000", "gateway/test/invocations"),
 29          ("http://127.0.0.1:6000/", "/gateway/test/invocations"),
 30          ("http://127.0.0.1:6000/gateway", "/test/invocations"),
 31          ("http://127.0.0.1:6000/gateway/", "/test/invocations"),
 32          ("http://127.0.0.1:6000/gateway", "test/invocations"),
 33          ("http://127.0.0.1:6000/gateway/", "test/invocations"),
 34      ],
 35  )
 36  def test_resolve_route_url(base_url, route):
 37      assert resolve_route_url(base_url, route) == "http://127.0.0.1:6000/gateway/test/invocations"
 38  
 39  
 40  @pytest.mark.parametrize("base_url", ["databricks", "databricks://my.workspace"])
 41  def test_resolve_route_url_qualified_url_ignores_base(base_url):
 42      route = "https://my.databricks.workspace/api/2.0/gateway/chat/invocations"
 43  
 44      resolved = resolve_route_url(base_url, route)
 45  
 46      assert resolved == route
 47  
 48  
 49  @pytest.mark.parametrize(
 50      ("name", "expected"),
 51      [
 52          ("validName", True),
 53          ("valid-name", True),
 54          ("valid_name", True),
 55          ("valid.name", True),
 56          ("valid123", True),
 57          ("invalid name", False),
 58          ("invalid/name", False),
 59          ("invalid?name", False),
 60          ("", False),
 61          ("日本語", False),  # Japanese characters
 62          ("naïve", False),  # accented characters
 63          ("名前", False),  # Chinese characters
 64      ],
 65  )
 66  def test_is_valid_endpoint_name(name, expected):
 67      assert is_valid_endpoint_name(name) == expected
 68  
 69  
 70  def test_check_configuration_route_name_collisions():
 71      config = {"endpoints": [{"name": "name1"}, {"name": "name2"}, {"name": "name1"}]}
 72      with pytest.raises(
 73          MlflowException, match="Duplicate names found in endpoint / route configurations"
 74      ):
 75          check_configuration_route_name_collisions(config)
 76  
 77  
 78  @pytest.mark.parametrize(
 79      ("uri", "expected"),
 80      [
 81          ("http://localhost", True),
 82          ("databricks", True),
 83          ("localhost", False),
 84          ("http:/localhost", False),
 85          ("", False),
 86      ],
 87  )
 88  def test__is_valid_uri(uri, expected):
 89      assert _is_valid_uri(uri) == expected
 90  
 91  
 92  @pytest.mark.parametrize(
 93      ("paths", "expected"),
 94      [
 95          (["path1", "path2", "path3"], "/path1/path2/path3"),
 96          (["/path1/", "/path2/", "/path3/"], "/path1/path2/path3"),
 97          (["/path1//", "/path2//", "/path3//"], "/path1/path2/path3"),
 98          (["path1", "", "path3"], "/path1/path3"),
 99          (["", "", ""], "/"),
100          ([], "/"),
101      ],
102  )
103  def test_assemble_uri_path(paths, expected):
104      assert assemble_uri_path(paths) == expected
105  
106  
def test_set_gateway_uri(monkeypatch):
    """A well-formed URI is stored; a bare host without a scheme is rejected."""
    # Start from a cleared module-level URI; monkeypatch restores it on teardown.
    monkeypatch.setattr("mlflow.gateway.utils._gateway_uri", None)

    good_uri = "http://localhost"
    set_gateway_uri(good_uri)
    assert get_gateway_uri() == good_uri

    with pytest.raises(MlflowException, match="The gateway uri provided is missing required"):
        set_gateway_uri("localhost")
117  
118  
def test_get_gateway_uri(monkeypatch):
    """get_gateway_uri raises when nothing is configured and otherwise returns the stored URI.

    The module-level ``_gateway_uri`` is patched through ``monkeypatch`` so every
    phase starts from a known state and the original value is restored on teardown.
    """
    monkeypatch.setattr("mlflow.gateway.utils._gateway_uri", None)
    monkeypatch.delenv("MLFLOW_GATEWAY_URI", raising=False)

    with pytest.raises(MlflowException, match="No Gateway server uri has been set"):
        get_gateway_uri()

    valid_uri = "http://localhost"
    monkeypatch.setattr("mlflow.gateway.utils._gateway_uri", valid_uri)
    assert get_gateway_uri() == valid_uri

    # Clear the stored URI again: without this reset the value is already
    # ``valid_uri`` and the assertion below would pass even if set_gateway_uri()
    # stored nothing.
    monkeypatch.setattr("mlflow.gateway.utils._gateway_uri", None)
    set_gateway_uri(valid_uri)
    assert get_gateway_uri() == valid_uri
133  
134  
def test_search_routes_token_decodes_correctly():
    """Encoding then decoding a token round-trips its index."""
    original = SearchRoutesToken(12345)
    round_tripped = SearchRoutesToken.decode(original.encode())
    assert round_tripped.index == original.index
140  
141  
@pytest.mark.parametrize(
    "index",
    [
        "not an integer",
        -1,
        None,
        [1, 2, 3],
        {"key": "value"},
    ],
)
def test_search_routes_token_with_invalid_token_values(index):
    """Decoding rejects tokens whose index is negative or not an integer."""
    # Encoding itself must succeed; only decode() is expected to complain.
    payload = SearchRoutesToken(index).encode()
    with pytest.raises(MlflowException, match="Invalid SearchRoutes token"):
        SearchRoutesToken.decode(payload)
157  
158  
@pytest.mark.asyncio
async def test_translate_http_exception_handles_ai_gateway_exception():
    """An AIGatewayException surfaces as an HTTPException with the same status and detail."""

    @translate_http_exception
    async def failing_handler():
        raise AIGatewayException(status_code=503, detail="AI Gateway error")

    with pytest.raises(HTTPException, match="AI Gateway error") as raised:
        await failing_handler()

    http_error = raised.value
    assert http_error.status_code == 503
    assert http_error.detail == "AI Gateway error"
170  
171  
@pytest.mark.asyncio
async def test_translate_http_exception_handles_mlflow_exception():
    """An MlflowException becomes an HTTPException carrying its error code and message.

    INVALID_PARAMETER_VALUE is expected to map to HTTP 400.
    """

    @translate_http_exception
    async def failing_handler():
        raise MlflowException("Invalid parameter", error_code=INVALID_PARAMETER_VALUE)

    with pytest.raises(HTTPException, match="Invalid parameter") as raised:
        await failing_handler()

    http_error = raised.value
    expected_detail = {
        "error_code": "INVALID_PARAMETER_VALUE",
        "message": "Invalid parameter",
    }
    assert http_error.status_code == 400
    assert http_error.detail == expected_detail
186  
187  
@pytest.mark.asyncio
async def test_translate_http_exception_passes_through_other_exceptions():
    """A ValueError is not translated and reaches the caller as-is."""

    @translate_http_exception
    async def failing_handler():
        raise ValueError("Some value error")

    with pytest.raises(ValueError, match="Some value error"):
        await failing_handler()
196  
197  
def test_parse_sse_lines_single_data_line():
    """One ``data:`` line yields exactly one decoded JSON object."""
    events = [*parse_sse_lines(b'data: {"message": "hello"}\n')]
    assert events == [{"message": "hello"}]
202  
203  
def test_parse_sse_lines_multiple_data_lines():
    """Each ``data:`` line in a chunk is decoded, in order."""
    events = [*parse_sse_lines(b'data: {"id": 1}\ndata: {"id": 2}\n')]
    assert events == [{"id": 1}, {"id": 2}]
208  
209  
def test_parse_sse_lines_with_event_lines():
    """``event:`` lines are skipped; only the ``data:`` payload is returned."""
    events = [*parse_sse_lines(b'event: message\ndata: {"content": "test"}\n')]
    assert events == [{"content": "test"}]
214  
215  
def test_parse_sse_lines_done_marker():
    """The ``[DONE]`` sentinel produces no output."""
    events = [*parse_sse_lines(b"data: [DONE]\n")]
    assert events == []
220  
221  
def test_parse_sse_lines_empty_data():
    """An empty ``data:`` payload is skipped."""
    events = [*parse_sse_lines(b"data: \n")]
    assert events == []
226  
227  
def test_parse_sse_lines_empty_chunk():
    """An empty chunk yields nothing."""
    events = [*parse_sse_lines(b"")]
    assert events == []
231  
232  
def test_parse_sse_lines_string_input():
    """A ``str`` chunk is accepted as well as ``bytes``."""
    events = [*parse_sse_lines('data: {"key": "value"}\n')]
    assert events == [{"key": "value"}]
237  
238  
def test_parse_sse_lines_invalid_json():
    """Malformed JSON in a ``data:`` line is dropped rather than raised."""
    events = [*parse_sse_lines(b"data: {invalid json}\n")]
    assert events == []
243  
244  
def test_parse_sse_lines_mixed_valid_invalid():
    """A malformed line is dropped while its valid neighbours are still returned."""
    raw = b'data: {"valid": true}\ndata: invalid\ndata: {"also": "valid"}\n'
    events = [*parse_sse_lines(raw)]
    assert events == [{"valid": True}, {"also": "valid"}]
249  
250  
def test_parse_sse_lines_invalid_utf8():
    """Bytes that are not valid UTF-8 yield nothing instead of raising."""
    events = [*parse_sse_lines(b"\xff\xfe")]
    assert events == []
255  
256  
def test_parse_sse_lines_non_data_lines_ignored():
    """``id:`` and ``retry:`` fields are ignored; only ``data:`` payloads are returned."""
    events = [*parse_sse_lines(b'id: 123\nretry: 1000\ndata: {"message": "test"}\n')]
    assert events == [{"message": "test"}]
261  
262  
@pytest.mark.asyncio
async def test_stream_sse_data_yields_parsed_json():
    """Every ``data:`` chunk from the source stream is yielded as parsed JSON."""

    async def fake_source():
        for raw in (b'data: {"chunk": 1}\n', b'data: {"chunk": 2}\n'):
            yield raw

    parsed = [item async for item in stream_sse_data(fake_source())]
    assert parsed == [{"chunk": 1}, {"chunk": 2}]
271  
272  
@pytest.mark.asyncio
async def test_stream_sse_data_skips_done():
    """The ``[DONE]`` terminator is consumed without producing an item."""

    async def fake_source():
        for raw in (b'data: {"chunk": 1}\n', b"data: [DONE]\n"):
            yield raw

    parsed = [item async for item in stream_sse_data(fake_source())]
    assert parsed == [{"chunk": 1}]
281  
282  
@pytest.mark.asyncio
async def test_stream_sse_data_skips_empty_lines():
    """Empty and whitespace-only chunks are skipped between real events."""

    async def fake_source():
        for raw in (b"", b'data: {"chunk": 1}\n', b"   ", b'data: {"chunk": 2}\n'):
            yield raw

    parsed = [item async for item in stream_sse_data(fake_source())]
    assert parsed == [{"chunk": 1}, {"chunk": 2}]
293  
294  
@pytest.mark.asyncio
async def test_stream_sse_data_skips_invalid_json():
    """A chunk whose payload is not JSON is dropped; the stream keeps going."""

    async def fake_source():
        for raw in (
            b'data: {"valid": true}\n',
            b"data: not json\n",
            b'data: {"also_valid": true}\n',
        ):
            yield raw

    parsed = [item async for item in stream_sse_data(fake_source())]
    assert parsed == [{"valid": True}, {"also_valid": True}]
304  
305  
@pytest.mark.asyncio
async def test_handle_incomplete_chunks_complete_lines():
    """Newline-terminated lines are emitted one at a time with the newline stripped."""

    async def fake_source():
        yield b"line1\nline2\n"

    lines = [line async for line in handle_incomplete_chunks(fake_source())]
    assert lines == [b"line1", b"line2"]
313  
314  
@pytest.mark.asyncio
async def test_handle_incomplete_chunks_split_across_chunks():
    """Line fragments spread over several chunks are reassembled into whole lines."""

    async def fake_source():
        for fragment in (b"he", b"llo\nwor", b"ld\n"):
            yield fragment

    lines = [line async for line in handle_incomplete_chunks(fake_source())]
    assert lines == [b"hello", b"world"]
324  
325  
@pytest.mark.asyncio
async def test_handle_incomplete_chunks_trailing_data():
    """A final line lacking a newline is still emitted when the stream ends."""

    async def fake_source():
        yield b"line1\nline2"

    lines = [line async for line in handle_incomplete_chunks(fake_source())]
    assert lines == [b"line1", b"line2"]
333  
334  
@pytest.mark.asyncio
async def test_handle_incomplete_chunks_no_newline():
    """A stream containing no newline at all is emitted as a single line."""

    async def fake_source():
        yield b"no newline at all"

    lines = [line async for line in handle_incomplete_chunks(fake_source())]
    assert lines == [b"no newline at all"]
342  
343  
def test_to_sse_error_chunk():
    """A built-in exception is serialized into a single SSE ``data:`` event."""
    expected = 'data: {"error": {"message": "Something went wrong", "type": "ValueError"}}\n\n'
    assert to_sse_error_chunk(ValueError("Something went wrong")) == expected
350  
351  
def test_to_sse_error_chunk_with_custom_exception():
    """A user-defined exception is reported under its own class name."""

    class CustomError(Exception):
        pass

    expected = 'data: {"error": {"message": "Custom error message", "type": "CustomError"}}\n\n'
    assert to_sse_error_chunk(CustomError("Custom error message")) == expected
361  
362  
@pytest.mark.asyncio
async def test_safe_stream_passes_through_chunks():
    """With no failure, safe_stream forwards every chunk in order."""

    async def upstream():
        for piece in ("chunk1", "chunk2", "chunk3"):
            yield piece

    forwarded = [piece async for piece in safe_stream(upstream())]
    assert forwarded == ["chunk1", "chunk2", "chunk3"]
372  
373  
@pytest.mark.asyncio
async def test_safe_stream_catches_exception_and_yields_error_chunk():
    """An exception raised mid-stream becomes a trailing SSE error chunk.

    Chunks produced before the failure pass through unchanged. The error chunk
    is parsed as JSON instead of substring-matched, so the assertions do not
    depend on the serializer's whitespace and do verify that ``message`` and
    ``type`` are nested under ``error``.
    """

    async def mock_stream():
        yield "chunk1"
        raise RuntimeError("Stream failed")

    results = [chunk async for chunk in safe_stream(mock_stream())]
    assert len(results) == 2
    assert results[0] == "chunk1"

    error_chunk = results[1]
    prefix = "data: "
    assert error_chunk.startswith(prefix)
    # json.loads tolerates the trailing blank-line terminator of the SSE event.
    payload = json.loads(error_chunk[len(prefix) :])
    assert payload["error"]["message"] == "Stream failed"
    assert payload["error"]["type"] == "RuntimeError"
386  
387  
@pytest.mark.asyncio
async def test_safe_stream_as_bytes():
    """With ``as_bytes=True`` the trailing error chunk is emitted as bytes.

    The payload is decoded as JSON instead of substring-matched, so the check is
    independent of serializer whitespace and confirms ``message`` and ``type``
    are nested under ``error``.
    """

    async def mock_stream():
        yield b"chunk1"
        raise ValueError("Bytes stream failed")

    results = [chunk async for chunk in safe_stream(mock_stream(), as_bytes=True)]
    assert len(results) == 2
    assert results[0] == b"chunk1"

    error_chunk = results[1]
    assert isinstance(error_chunk, bytes)
    prefix = b"data: "
    assert error_chunk.startswith(prefix)
    # json.loads accepts UTF-8 bytes and ignores the trailing newlines.
    payload = json.loads(error_chunk[len(prefix) :])
    assert payload["error"]["message"] == "Bytes stream failed"
    assert payload["error"]["type"] == "ValueError"
401  
402  
@pytest.mark.asyncio
async def test_safe_stream_no_exception():
    """A clean stream yields exactly its own chunks and no trailing error chunk."""

    async def upstream():
        for piece in ("data1", "data2"):
            yield piece

    forwarded = [piece async for piece in safe_stream(upstream())]
    assert forwarded == ["data1", "data2"]