# test_utils.py
import pytest
from fastapi import HTTPException

from mlflow.exceptions import MlflowException
from mlflow.gateway.exceptions import AIGatewayException
from mlflow.gateway.utils import (
    SearchRoutesToken,
    _is_valid_uri,
    assemble_uri_path,
    check_configuration_route_name_collisions,
    get_gateway_uri,
    handle_incomplete_chunks,
    is_valid_endpoint_name,
    parse_sse_lines,
    resolve_route_url,
    safe_stream,
    set_gateway_uri,
    stream_sse_data,
    to_sse_error_chunk,
    translate_http_exception,
)
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


@pytest.mark.parametrize(
    ("base_url", "route"),
    [
        ("http://127.0.0.1:6000", "gateway/test/invocations"),
        ("http://127.0.0.1:6000/", "/gateway/test/invocations"),
        ("http://127.0.0.1:6000/gateway", "/test/invocations"),
        ("http://127.0.0.1:6000/gateway/", "/test/invocations"),
        ("http://127.0.0.1:6000/gateway", "test/invocations"),
        ("http://127.0.0.1:6000/gateway/", "test/invocations"),
    ],
)
def test_resolve_route_url(base_url, route):
    # Regardless of whether "/gateway" sits on the base or the route, and of
    # how slashes are placed at the join, the same URL must come out.
    joined = resolve_route_url(base_url, route)
    assert joined == "http://127.0.0.1:6000/gateway/test/invocations"


@pytest.mark.parametrize("base_url", ["databricks", "databricks://my.workspace"])
def test_resolve_route_url_qualified_url_ignores_base(base_url):
    # An already fully-qualified route is expected back unchanged.
    absolute_route = "https://my.databricks.workspace/api/2.0/gateway/chat/invocations"
    assert resolve_route_url(base_url, absolute_route) == absolute_route


@pytest.mark.parametrize(
    ("name", "expected"),
    [
        ("validName", True),
        ("valid-name", True),
        ("valid_name", True),
        ("valid.name", True),
        ("valid123", True),
        ("invalid name", False),
        ("invalid/name", False),
        ("invalid?name", False),
        ("", False),
        ("日本語", False),  # non-ASCII: Japanese script
        ("naïve", False),  # non-ASCII: accented Latin letter
        ("名前", False),  # non-ASCII: CJK ideographs
    ],
)
def test_is_valid_endpoint_name(name, expected):
    verdict = is_valid_endpoint_name(name)
    assert verdict == expected


def test_check_configuration_route_name_collisions():
    # "name1" appears twice, which must be rejected.
    duplicated_config = {
        "endpoints": [{"name": endpoint_name} for endpoint_name in ("name1", "name2", "name1")]
    }
    expected_message = "Duplicate names found in endpoint / route configurations"
    with pytest.raises(MlflowException, match=expected_message):
        check_configuration_route_name_collisions(duplicated_config)


@pytest.mark.parametrize(
    ("uri", "expected"),
    [
        ("http://localhost", True),
        ("databricks", True),
        ("localhost", False),
        ("http:/localhost", False),
        ("", False),
    ],
)
def test__is_valid_uri(uri, expected):
    verdict = _is_valid_uri(uri)
    assert verdict == expected


@pytest.mark.parametrize(
    ("paths", "expected"),
    [
        (["path1", "path2", "path3"], "/path1/path2/path3"),
        (["/path1/", "/path2/", "/path3/"], "/path1/path2/path3"),
        (["/path1//", "/path2//", "/path3//"], "/path1/path2/path3"),
        (["path1", "", "path3"], "/path1/path3"),
        (["", "", ""], "/"),
        ([], "/"),
    ],
)
def test_assemble_uri_path(paths, expected):
    assembled = assemble_uri_path(paths)
    assert assembled == expected


def test_set_gateway_uri(monkeypatch):
    monkeypatch.setattr("mlflow.gateway.utils._gateway_uri", None)

    accepted_uri = "http://localhost"
    set_gateway_uri(accepted_uri)
    assert get_gateway_uri() == accepted_uri

    # A bare host without a scheme must be refused.
    with pytest.raises(MlflowException, match="The gateway uri provided is missing required"):
        set_gateway_uri("localhost")


def test_get_gateway_uri(monkeypatch):
    module_uri_attr = "mlflow.gateway.utils._gateway_uri"
    monkeypatch.setattr(module_uri_attr, None)
    monkeypatch.delenv("MLFLOW_GATEWAY_URI", raising=False)

    # With the module-level URI cleared and MLFLOW_GATEWAY_URI unset, lookup raises.
    with pytest.raises(MlflowException, match="No Gateway server uri has been set"):
        get_gateway_uri()

    localhost_uri = "http://localhost"

    # A URI patched directly onto the module global is returned as-is.
    monkeypatch.setattr(module_uri_attr, localhost_uri)
    assert get_gateway_uri() == localhost_uri

    # A URI stored through the public setter is returned as well.
    monkeypatch.delenv("MLFLOW_GATEWAY_URI", raising=False)
    set_gateway_uri(localhost_uri)
    assert get_gateway_uri() == localhost_uri


def test_search_routes_token_decodes_correctly():
    original = SearchRoutesToken(12345)
    round_tripped = SearchRoutesToken.decode(original.encode())
    assert round_tripped.index == original.index


@pytest.mark.parametrize(
    "index",
    [
        "not an integer",
        -1,
        None,
        [1, 2, 3],
        {"key": "value"},
    ],
)
def test_search_routes_token_with_invalid_token_values(index):
    # Encoding accepts any value; decoding is where bad indices are rejected.
    serialized = SearchRoutesToken(index).encode()
    with pytest.raises(MlflowException, match="Invalid SearchRoutes token"):
        SearchRoutesToken.decode(serialized)


@pytest.mark.asyncio
async def test_translate_http_exception_handles_ai_gateway_exception():
    @translate_http_exception
    async def failing_handler():
        raise AIGatewayException(status_code=503, detail="AI Gateway error")

    with pytest.raises(HTTPException, match="AI Gateway error") as exc_info:
        await failing_handler()

    translated = exc_info.value
    assert translated.status_code == 503
    assert translated.detail == "AI Gateway error"


@pytest.mark.asyncio
async def test_translate_http_exception_handles_mlflow_exception():
    @translate_http_exception
    async def failing_handler():
        raise MlflowException("Invalid parameter", error_code=INVALID_PARAMETER_VALUE)

    with pytest.raises(HTTPException, match="Invalid parameter") as exc_info:
        await failing_handler()

    translated = exc_info.value
    assert translated.status_code == 400
    expected_detail = {
        "error_code": "INVALID_PARAMETER_VALUE",
        "message": "Invalid parameter",
    }
    assert translated.detail == expected_detail


@pytest.mark.asyncio
async def test_translate_http_exception_passes_through_other_exceptions():
    @translate_http_exception
    async def failing_handler():
        raise ValueError("Some value error")

    # Exceptions outside the handled families are re-raised untouched.
    with pytest.raises(ValueError, match="Some value error"):
        await failing_handler()


def test_parse_sse_lines_single_data_line():
    parsed = list(parse_sse_lines(b'data: {"message": "hello"}\n'))
    assert parsed == [{"message": "hello"}]


def test_parse_sse_lines_multiple_data_lines():
    parsed = list(parse_sse_lines(b'data: {"id": 1}\ndata: {"id": 2}\n'))
    assert parsed == [{"id": 1}, {"id": 2}]


def test_parse_sse_lines_with_event_lines():
    parsed = list(parse_sse_lines(b'event: message\ndata: {"content": "test"}\n'))
    assert parsed == [{"content": "test"}]


def test_parse_sse_lines_done_marker():
    parsed = list(parse_sse_lines(b"data: [DONE]\n"))
    assert parsed == []


def test_parse_sse_lines_empty_data():
    parsed = list(parse_sse_lines(b"data: \n"))
    assert parsed == []


def test_parse_sse_lines_empty_chunk():
    parsed = list(parse_sse_lines(b""))
    assert parsed == []


def test_parse_sse_lines_string_input():
    # str input is accepted alongside bytes.
    parsed = list(parse_sse_lines('data: {"key": "value"}\n'))
    assert parsed == [{"key": "value"}]


def test_parse_sse_lines_invalid_json():
    parsed = list(parse_sse_lines(b"data: {invalid json}\n"))
    assert parsed == []


def test_parse_sse_lines_mixed_valid_invalid():
    # The unparsable middle line is dropped; its neighbours survive.
    payload = b'data: {"valid": true}\ndata: invalid\ndata: {"also": "valid"}\n'
    parsed = list(parse_sse_lines(payload))
    assert parsed == [{"valid": True}, {"also": "valid"}]


def test_parse_sse_lines_invalid_utf8():
    parsed = list(parse_sse_lines(b"\xff\xfe"))
    assert parsed == []


def test_parse_sse_lines_non_data_lines_ignored():
    payload = b'id: 123\nretry: 1000\ndata: {"message": "test"}\n'
    parsed = list(parse_sse_lines(payload))
    assert parsed == [{"message": "test"}]
async def _stream_of(*items):
    """Yield each of *items* in order from an async generator (a stand-in upstream stream)."""
    for item in items:
        yield item


async def _collect(async_iterable):
    """Drain *async_iterable* and return everything it produced, in order."""
    return [item async for item in async_iterable]


@pytest.mark.asyncio
async def test_stream_sse_data_yields_parsed_json():
    source = _stream_of(b'data: {"chunk": 1}\n', b'data: {"chunk": 2}\n')
    assert await _collect(stream_sse_data(source)) == [{"chunk": 1}, {"chunk": 2}]


@pytest.mark.asyncio
async def test_stream_sse_data_skips_done():
    source = _stream_of(b'data: {"chunk": 1}\n', b"data: [DONE]\n")
    assert await _collect(stream_sse_data(source)) == [{"chunk": 1}]


@pytest.mark.asyncio
async def test_stream_sse_data_skips_empty_lines():
    # Empty and whitespace-only chunks are interleaved with real events.
    source = _stream_of(
        b"",
        b'data: {"chunk": 1}\n',
        b" ",
        b'data: {"chunk": 2}\n',
    )
    assert await _collect(stream_sse_data(source)) == [{"chunk": 1}, {"chunk": 2}]


@pytest.mark.asyncio
async def test_stream_sse_data_skips_invalid_json():
    source = _stream_of(
        b'data: {"valid": true}\n',
        b"data: not json\n",
        b'data: {"also_valid": true}\n',
    )
    assert await _collect(stream_sse_data(source)) == [{"valid": True}, {"also_valid": True}]


@pytest.mark.asyncio
async def test_handle_incomplete_chunks_complete_lines():
    source = _stream_of(b"line1\nline2\n")
    assert await _collect(handle_incomplete_chunks(source)) == [b"line1", b"line2"]


@pytest.mark.asyncio
async def test_handle_incomplete_chunks_split_across_chunks():
    # "hello" and "world" each straddle a chunk boundary.
    source = _stream_of(b"he", b"llo\nwor", b"ld\n")
    assert await _collect(handle_incomplete_chunks(source)) == [b"hello", b"world"]


@pytest.mark.asyncio
async def test_handle_incomplete_chunks_trailing_data():
    # The final line has no terminating newline but must still be emitted.
    source = _stream_of(b"line1\nline2")
    assert await _collect(handle_incomplete_chunks(source)) == [b"line1", b"line2"]


@pytest.mark.asyncio
async def test_handle_incomplete_chunks_no_newline():
    source = _stream_of(b"no newline at all")
    assert await _collect(handle_incomplete_chunks(source)) == [b"no newline at all"]


def test_to_sse_error_chunk():
    formatted = to_sse_error_chunk(ValueError("Something went wrong"))
    expected = (
        'data: {"error": {"message": "Something went wrong", "type": "ValueError"}}\n\n'
    )
    assert formatted == expected


def test_to_sse_error_chunk_with_custom_exception():
    # The class name is part of the expected output, so it must stay "CustomError".
    class CustomError(Exception):
        pass

    formatted = to_sse_error_chunk(CustomError("Custom error message"))
    expected = (
        'data: {"error": {"message": "Custom error message", "type": "CustomError"}}\n\n'
    )
    assert formatted == expected


@pytest.mark.asyncio
async def test_safe_stream_passes_through_chunks():
    source = _stream_of("chunk1", "chunk2", "chunk3")
    assert await _collect(safe_stream(source)) == ["chunk1", "chunk2", "chunk3"]


@pytest.mark.asyncio
async def test_safe_stream_catches_exception_and_yields_error_chunk():
    async def failing_stream():
        yield "chunk1"
        raise RuntimeError("Stream failed")

    emitted = await _collect(safe_stream(failing_stream()))
    assert len(emitted) == 2
    first_chunk, error_chunk = emitted
    assert first_chunk == "chunk1"
    for fragment in ('"error"', '"message": "Stream failed"', '"type": "RuntimeError"'):
        assert fragment in error_chunk


@pytest.mark.asyncio
async def test_safe_stream_as_bytes():
    async def failing_stream():
        yield b"chunk1"
        raise ValueError("Bytes stream failed")

    emitted = await _collect(safe_stream(failing_stream(), as_bytes=True))
    assert len(emitted) == 2
    first_chunk, error_chunk = emitted
    assert first_chunk == b"chunk1"
    # With as_bytes=True the synthesized error chunk must also be bytes.
    assert isinstance(error_chunk, bytes)
    for fragment in (b'"error"', b'"message": "Bytes stream failed"', b'"type": "ValueError"'):
        assert fragment in error_chunk


@pytest.mark.asyncio
async def test_safe_stream_no_exception():
    source = _stream_of("data1", "data2")
    assert await _collect(safe_stream(source)) == ["data1", "data2"]