# test_handlers.py — unit tests for mlflow.server.handlers
1 import json 2 import uuid 3 from dataclasses import asdict 4 from datetime import datetime, timezone 5 from unittest import mock 6 7 import pytest 8 from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan 9 10 import mlflow 11 from mlflow.entities import ( 12 GatewayBudgetPolicy, 13 Issue, 14 IssueSeverity, 15 IssueStatus, 16 RunStatus, 17 ScorerVersion, 18 Span, 19 Trace, 20 TraceData, 21 TraceInfo, 22 TraceState, 23 ViewType, 24 ) 25 from mlflow.entities._job import Job as JobEntity 26 from mlflow.entities._job_status import JobStatus 27 from mlflow.entities.gateway_budget_policy import ( 28 BudgetAction, 29 BudgetDuration, 30 BudgetDurationUnit, 31 BudgetTargetScope, 32 BudgetUnit, 33 ) 34 from mlflow.entities.model_registry import ( 35 ModelVersion, 36 ModelVersionTag, 37 PromptVersion, 38 RegisteredModel, 39 RegisteredModelTag, 40 ) 41 from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY, PROMPT_TEXT_TAG_KEY 42 from mlflow.entities.presigned_download import PresignedDownloadUrlResponse 43 from mlflow.entities.presigned_upload import CreatePresignedUploadResponse 44 from mlflow.entities.trace_location import TraceLocation as EntityTraceLocation 45 from mlflow.entities.trace_metrics import ( 46 AggregationType, 47 MetricAggregation, 48 MetricDataPoint, 49 MetricViewType, 50 ) 51 from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES 52 from mlflow.exceptions import MlflowException, MlflowNotImplementedException 53 from mlflow.gateway.budget_tracker.in_memory import InMemoryBudgetTracker 54 from mlflow.genai.scorers.online.entities import OnlineScoringConfig 55 from mlflow.protos.databricks_pb2 import ( 56 INTERNAL_ERROR, 57 INVALID_PARAMETER_VALUE, 58 NOT_IMPLEMENTED, 59 RESOURCE_DOES_NOT_EXIST, 60 ErrorCode, 61 ) 62 from mlflow.protos.issues_pb2 import ( 63 CreateIssue, 64 SearchIssues, 65 UpdateIssue, 66 ) 67 from mlflow.protos.model_registry_pb2 import ( 68 CreateModelVersion, 69 CreateRegisteredModel, 70 
DeleteModelVersion, 71 DeleteModelVersionTag, 72 DeleteRegisteredModel, 73 DeleteRegisteredModelAlias, 74 DeleteRegisteredModelTag, 75 GetLatestVersions, 76 GetModelVersion, 77 GetModelVersionByAlias, 78 GetModelVersionDownloadUri, 79 GetRegisteredModel, 80 RenameRegisteredModel, 81 SearchModelVersions, 82 SearchRegisteredModels, 83 SetModelVersionTag, 84 SetRegisteredModelAlias, 85 SetRegisteredModelTag, 86 TransitionModelVersionStage, 87 UpdateModelVersion, 88 UpdateRegisteredModel, 89 ) 90 from mlflow.protos.prompt_optimization_pb2 import ( 91 OPTIMIZER_TYPE_GEPA, 92 OPTIMIZER_TYPE_METAPROMPT, 93 OPTIMIZER_TYPE_UNSPECIFIED, 94 ) 95 from mlflow.protos.service_pb2 import ( 96 BatchGetTraceInfos, 97 BatchGetTraces, 98 CalculateTraceFilterCorrelation, 99 CreateExperiment, 100 DeleteScorer, 101 DeleteTraceTag, 102 DeleteTraceTagV3, 103 GatewayEndpoint, 104 GetGatewayEndpoint, 105 GetScorer, 106 GetTrace, 107 LinkPromptsToTrace, 108 ListScorers, 109 ListScorerVersions, 110 QueryTraceMetrics, 111 RegisterScorer, 112 SearchExperiments, 113 SearchLoggedModels, 114 SearchRuns, 115 SearchTraces, 116 SearchTracesV3, 117 SetTraceTag, 118 SetTraceTagV3, 119 TraceLocation, 120 ) 121 from mlflow.protos.webhooks_pb2 import ListWebhooks 122 from mlflow.server import ( 123 ARTIFACTS_DESTINATION_ENV_VAR, 124 BACKEND_STORE_URI_ENV_VAR, 125 SERVE_ARTIFACTS_ENV_VAR, 126 app, 127 ) 128 from mlflow.server.handlers import ( 129 ARTIFACT_STREAM_CHUNK_SIZE, 130 STATIC_PREFIX_ENV_VAR, 131 ModelRegistryStoreRegistryWrapper, 132 TrackingStoreRegistryWrapper, 133 _batch_get_trace_infos, 134 _batch_get_traces, 135 _calculate_trace_filter_correlation, 136 _cancel_prompt_optimization_job, 137 _convert_path_parameter_to_flask_format, 138 _create_dataset_handler, 139 _create_experiment, 140 _create_issue, 141 _create_model_version, 142 _create_presigned_upload_url, 143 _create_prompt_optimization_job, 144 _create_registered_model, 145 _delete_artifact_mlflow_artifacts, 146 _delete_dataset_handler, 
147 _delete_dataset_tag_handler, 148 _delete_model_version, 149 _delete_model_version_tag, 150 _delete_registered_model, 151 _delete_registered_model_alias, 152 _delete_registered_model_tag, 153 _delete_scorer, 154 _delete_trace_tag, 155 _delete_trace_tag_v3, 156 _deprecated_search_traces_v2, 157 _download_artifact, 158 _get_ajax_path, 159 _get_dataset_experiment_ids_handler, 160 _get_dataset_handler, 161 _get_dataset_records_handler, 162 _get_gateway_endpoint, 163 _get_issue, 164 _get_latest_versions, 165 _get_model_version, 166 _get_model_version_by_alias, 167 _get_model_version_download_uri, 168 _get_presigned_download_url, 169 _get_registered_model, 170 _get_request_message, 171 _get_rest_path, 172 _get_scorer, 173 _get_trace, 174 _get_trace_artifact_repo, 175 _get_workspace_scoped_repo_path_if_enabled, 176 _link_prompts_to_trace, 177 _list_artifacts_for_proxied_run_artifact_root, 178 _list_scorer_versions, 179 _list_scorers, 180 _list_webhooks, 181 _log_batch, 182 _query_trace_metrics, 183 _register_scorer, 184 _rename_registered_model, 185 _search_evaluation_datasets_handler, 186 _search_experiments, 187 _search_issues, 188 _search_logged_models, 189 _search_model_versions, 190 _search_registered_models, 191 _search_runs, 192 _search_traces_v3, 193 _set_dataset_tags_handler, 194 _set_model_version_tag, 195 _set_registered_model_alias, 196 _set_registered_model_tag, 197 _set_trace_tag, 198 _set_trace_tag_v3, 199 _transition_stage, 200 _update_issue, 201 _update_model_version, 202 _update_registered_model, 203 _upsert_dataset_records_handler, 204 _validate_source_run, 205 catch_mlflow_exception, 206 get_artifact_handler, 207 get_endpoints, 208 get_logged_model_artifact_handler, 209 get_model_version_artifact_handler, 210 get_trace_artifact_handler, 211 get_ui_telemetry_handler, 212 post_ui_telemetry_handler, 213 upload_artifact_handler, 214 ) 215 from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore 216 from 
mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository 217 from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository 218 from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository 219 from mlflow.store.entities.paged_list import PagedList 220 from mlflow.store.model_registry import ( 221 SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD, 222 SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, 223 ) 224 from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore 225 from mlflow.store.tracking import MAX_RESULTS_QUERY_TRACE_METRICS 226 from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore 227 from mlflow.telemetry.schemas import Record, Status 228 from mlflow.tracing.analysis import TraceFilterCorrelationResult 229 from mlflow.tracing.utils import build_otel_context 230 from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION 231 from mlflow.utils.proto_json_utils import message_to_json 232 from mlflow.utils.validation import MAX_BATCH_LOG_REQUEST_SIZE 233 from mlflow.utils.workspace_context import WorkspaceContext 234 from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME 235 236 237 @pytest.fixture 238 def mock_get_request_message(): 239 with mock.patch("mlflow.server.handlers._get_request_message") as m: 240 yield m 241 242 243 @pytest.fixture 244 def mock_get_request_json(): 245 with mock.patch("mlflow.server.handlers._get_request_json") as m: 246 yield m 247 248 249 @pytest.fixture 250 def mock_tracking_store(): 251 with mock.patch("mlflow.server.handlers._get_tracking_store") as m: 252 mock_store = mock.MagicMock() 253 m.return_value = mock_store 254 yield mock_store 255 256 257 @pytest.fixture 258 def mock_model_registry_store(): 259 with mock.patch("mlflow.server.handlers._get_model_registry_store") as m: 260 mock_store = mock.MagicMock() 261 mock_store.list_webhooks_by_event.return_value = PagedList([], None) 262 m.return_value = mock_store 263 
yield mock_store 264 265 266 @pytest.fixture 267 def enable_serve_artifacts(monkeypatch): 268 monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true") 269 270 271 @pytest.fixture 272 def mock_evaluation_dataset(): 273 from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset 274 275 dataset = mock.MagicMock() 276 dataset.dataset_id = "d-1234567890abcdef1234567890abcdef" 277 dataset.name = "test_dataset" 278 dataset.digest = "abc123" 279 dataset.created_time = 1234567890 280 dataset.last_update_time = 1234567890 281 dataset.created_by = "test_user" 282 dataset.last_updated_by = "test_user" 283 dataset.tags = {"env": "test", "version": "1.0"} 284 dataset.experiment_ids = ["0", "1"] 285 dataset._records = [] 286 dataset.schema = json.dumps({ 287 "inputs": {"question": "string"}, 288 "expectations": {"accuracy": "float"}, 289 }) 290 dataset.profile = json.dumps({"record_count": 0}) 291 292 proto_dataset = ProtoDataset() 293 proto_dataset.dataset_id = dataset.dataset_id 294 proto_dataset.name = dataset.name 295 proto_dataset.digest = dataset.digest 296 proto_dataset.created_time = dataset.created_time 297 proto_dataset.last_update_time = dataset.last_update_time 298 proto_dataset.created_by = dataset.created_by 299 proto_dataset.last_updated_by = dataset.last_updated_by 300 proto_dataset.schema = dataset.schema 301 proto_dataset.profile = dataset.profile 302 303 dataset.to_proto = mock.MagicMock(return_value=proto_dataset) 304 305 return dataset 306 307 308 @pytest.fixture 309 def mock_telemetry_config_cache(): 310 with mock.patch("mlflow.server.handlers._telemetry_config_cache", {}) as m: 311 yield m 312 313 314 @pytest.fixture 315 def bypass_telemetry_env_check(monkeypatch): 316 monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", False) 317 monkeypatch.setattr(mlflow.telemetry.utils, "_IS_IN_CI_ENV_OR_TESTING", False) 318 monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_DEV_VERSION", False) 319 320 321 @pytest.fixture 322 def 
mock_job_store(): 323 with mock.patch("mlflow.server.handlers._get_job_store") as m: 324 mock_store = mock.MagicMock() 325 m.return_value = mock_store 326 yield mock_store 327 328 329 def _create_mock_job( 330 job_id="job-123", 331 job_name="optimize_prompts", 332 status_name="PENDING", 333 params=None, 334 result=None, 335 creation_time=1234567890000, 336 status_details=None, 337 ): 338 from mlflow.entities._job import Job 339 from mlflow.entities._job_status import JobStatus 340 341 if params is None: 342 params = { 343 "experiment_id": "exp-123", 344 "prompt_uri": "prompts:/my-prompt/1", 345 "run_id": "run-456", 346 } 347 348 return Job( 349 job_id=job_id, 350 creation_time=creation_time, 351 job_name=job_name, 352 params=json.dumps(params), 353 timeout=None, 354 status=JobStatus.from_str(status_name), 355 result=json.dumps(result) if result and status_name == "SUCCEEDED" else result, 356 retry_count=0, 357 last_update_time=creation_time, 358 status_details=status_details, 359 ) 360 361 362 def _create_mock_run(run_id="run-456", params=None, metrics=None): 363 mock_run = mock.MagicMock() 364 mock_run.info.run_id = run_id 365 mock_run.data.params = params or {} 366 mock_run.data.metrics = metrics or {} 367 return mock_run 368 369 370 def test_health(): 371 with app.test_client() as c: 372 response = c.get("/health") 373 assert response.status_code == 200 374 assert response.get_data().decode() == "OK" 375 376 377 def test_version(): 378 with app.test_client() as c: 379 response = c.get("/version") 380 assert response.status_code == 200 381 assert response.get_data().decode() == mlflow.__version__ 382 383 384 def test_server_info(): 385 with app.test_client() as c: 386 response = c.get("/api/3.0/mlflow/server-info") 387 assert response.status_code == 200 388 data = response.get_json() 389 assert data["store_type"] == "SqlStore" 390 assert data["workspaces_enabled"] is False 391 392 393 def test_get_endpoints(): 394 endpoints = get_endpoints() 395 
create_experiment_endpoint = [e for e in endpoints if e[1] == _create_experiment] 396 assert len(create_experiment_endpoint) == 2 397 398 399 def test_convert_path_parameter_to_flask_format(): 400 converted = _convert_path_parameter_to_flask_format("/mlflow/trace") 401 assert "/mlflow/trace" == converted 402 403 converted = _convert_path_parameter_to_flask_format("/mlflow/trace/{request_id}") 404 assert "/mlflow/trace/<request_id>" == converted 405 406 converted = _convert_path_parameter_to_flask_format("/mlflow/{foo}/{bar}/{baz}") 407 assert "/mlflow/<foo>/<bar>/<baz>" == converted 408 409 410 def test_all_model_registry_endpoints_available(): 411 endpoints = {handler: method for (path, handler, method) in get_endpoints()} 412 413 # Test that each of the handler is enabled as an endpoint with appropriate method. 414 expected_endpoints = { 415 "POST": [ 416 _create_registered_model, 417 _create_model_version, 418 _rename_registered_model, 419 _transition_stage, 420 ], 421 "PATCH": [_update_registered_model, _update_model_version], 422 "DELETE": [_delete_registered_model, _delete_registered_model], 423 "GET": [ 424 _search_model_versions, 425 _get_latest_versions, 426 _get_registered_model, 427 _get_model_version, 428 _get_model_version_download_uri, 429 ], 430 } 431 # TODO: efficient mechanism to test endpoint path 432 for method, handlers in expected_endpoints.items(): 433 for handler in handlers: 434 assert handler in endpoints 435 assert endpoints[handler] == [method] 436 437 438 def test_can_parse_json(): 439 request = mock.MagicMock() 440 request.method = "POST" 441 request.content_type = "application/json" 442 request.get_json = mock.MagicMock() 443 request.get_json.return_value = {"name": "hello"} 444 msg = _get_request_message(CreateExperiment(), flask_request=request) 445 assert msg.name == "hello" 446 447 448 def test_can_parse_post_json_with_unknown_fields(): 449 request = mock.MagicMock() 450 request.method = "POST" 451 request.content_type = 
"application/json" 452 request.get_json = mock.MagicMock() 453 request.get_json.return_value = {"name": "hello", "WHAT IS THIS FIELD EVEN": "DOING"} 454 msg = _get_request_message(CreateExperiment(), flask_request=request) 455 assert msg.name == "hello" 456 457 458 def test_can_parse_post_json_with_content_type_params(): 459 request = mock.MagicMock() 460 request.method = "POST" 461 request.content_type = "application/json; charset=utf-8" 462 request.get_json = mock.MagicMock() 463 request.get_json.return_value = {"name": "hello"} 464 msg = _get_request_message(CreateExperiment(), flask_request=request) 465 assert msg.name == "hello" 466 467 468 def test_can_parse_get_json_with_unknown_fields(): 469 request = mock.MagicMock() 470 request.method = "GET" 471 request.args = {"name": "hello", "superDuperUnknown": "field"} 472 msg = _get_request_message(CreateExperiment(), flask_request=request) 473 assert msg.name == "hello" 474 475 476 # Previous versions of the client sent a doubly string encoded JSON blob, 477 # so this test ensures continued compliance with such clients. 478 def test_can_parse_json_string(): 479 request = mock.MagicMock() 480 request.method = "POST" 481 request.content_type = "application/json" 482 request.get_json = mock.MagicMock() 483 request.get_json.return_value = '{"name": "hello2"}' 484 msg = _get_request_message(CreateExperiment(), flask_request=request) 485 assert msg.name == "hello2" 486 487 488 def test_can_block_post_request_with_invalid_content_type(): 489 request = mock.MagicMock() 490 request.method = "POST" 491 request.content_type = "text/plain" 492 request.get_json = mock.MagicMock() 493 request.get_json.return_value = {"name": "hello"} 494 with pytest.raises(MlflowException, match=r"Bad Request. 
Content-Type"): 495 _get_request_message(CreateExperiment(), flask_request=request) 496 497 498 def test_can_block_post_request_with_missing_content_type(): 499 request = mock.MagicMock() 500 request.method = "POST" 501 request.content_type = None 502 request.get_json = mock.MagicMock() 503 request.get_json.return_value = {"name": "hello"} 504 with pytest.raises(MlflowException, match=r"Bad Request. Content-Type"): 505 _get_request_message(CreateExperiment(), flask_request=request) 506 507 508 def test_search_runs_default_view_type(mock_get_request_message, mock_tracking_store): 509 """ 510 Search Runs default view type is filled in as ViewType.ACTIVE_ONLY 511 """ 512 mock_get_request_message.return_value = SearchRuns(experiment_ids=["0"]) 513 mock_tracking_store.search_runs.return_value = PagedList([], None) 514 _search_runs() 515 _, kwargs = mock_tracking_store.search_runs.call_args 516 assert kwargs["run_view_type"] == ViewType.ACTIVE_ONLY 517 518 519 def test_search_runs_empty_page_token(mock_get_request_message, mock_tracking_store): 520 """ 521 Test that empty page_token from protobuf is converted to None before calling store 522 """ 523 # Create proto without setting page_token 524 search_runs_proto = SearchRuns() 525 search_runs_proto.experiment_ids.append("0") 526 search_runs_proto.max_results = 10 527 # Verify protobuf returns empty string for unset field 528 assert search_runs_proto.page_token == "" 529 530 mock_get_request_message.return_value = search_runs_proto 531 mock_tracking_store.search_runs.return_value = PagedList([], None) 532 533 _search_runs() 534 535 # Verify store was called with None, not empty string 536 mock_tracking_store.search_runs.assert_called_once() 537 call_kwargs = mock_tracking_store.search_runs.call_args.kwargs 538 assert call_kwargs["page_token"] is None # page_token should be None, not "" 539 540 541 def test_log_batch_api_req(mock_get_request_json): 542 mock_get_request_json.return_value = "a" * (MAX_BATCH_LOG_REQUEST_SIZE 
+ 1) 543 response = _log_batch() 544 assert response.status_code == 400 545 json_response = json.loads(response.get_data()) 546 assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE) 547 assert ( 548 f"Batched logging API requests must be at most {MAX_BATCH_LOG_REQUEST_SIZE} bytes" 549 in json_response["message"] 550 ) 551 552 553 def test_catch_mlflow_exception(): 554 @catch_mlflow_exception 555 def test_handler(): 556 raise MlflowException("test error", error_code=INTERNAL_ERROR) 557 558 response = test_handler() 559 json_response = json.loads(response.get_data()) 560 assert response.status_code == 500 561 assert json_response["error_code"] == ErrorCode.Name(INTERNAL_ERROR) 562 assert json_response["message"] == "test error" 563 564 565 def test_mlflow_server_with_installed_plugin(tmp_path, monkeypatch): 566 pytest.skip("FileStore is no longer supported.") 567 from mlflow_test_plugin.file_store import PluginFileStore 568 569 monkeypatch.setenv(BACKEND_STORE_URI_ENV_VAR, f"file-plugin:{tmp_path}") 570 monkeypatch.setattr(mlflow.server.handlers, "_tracking_store", None) 571 plugin_file_store = mlflow.server.handlers._get_tracking_store() 572 assert isinstance(plugin_file_store, PluginFileStore) 573 assert plugin_file_store.is_plugin 574 575 576 def jsonify(obj): 577 def _jsonify(obj): 578 return json.loads(message_to_json(obj.to_proto())) 579 580 if isinstance(obj, list): 581 return [_jsonify(o) for o in obj] 582 else: 583 return _jsonify(obj) 584 585 586 # Tests for Model Registry handlers 587 def test_create_registered_model(mock_get_request_message, mock_model_registry_store): 588 tags = [ 589 RegisteredModelTag(key="key", value="value"), 590 RegisteredModelTag(key="anotherKey", value="some other value"), 591 ] 592 mock_get_request_message.return_value = CreateRegisteredModel( 593 name="model_1", tags=[tag.to_proto() for tag in tags] 594 ) 595 rm = RegisteredModel("model_1", tags=tags) 596 
mock_model_registry_store.create_registered_model.return_value = rm 597 resp = _create_registered_model() 598 _, args = mock_model_registry_store.create_registered_model.call_args 599 assert args["name"] == "model_1" 600 assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags} 601 assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm)} 602 603 604 def test_get_registered_model(mock_get_request_message, mock_model_registry_store): 605 name = "model1" 606 mock_get_request_message.return_value = GetRegisteredModel(name=name) 607 rmd = RegisteredModel( 608 name=name, 609 creation_timestamp=111, 610 last_updated_timestamp=222, 611 description="Test model", 612 latest_versions=[], 613 ) 614 mock_model_registry_store.get_registered_model.return_value = rmd 615 resp = _get_registered_model() 616 _, args = mock_model_registry_store.get_registered_model.call_args 617 assert args == {"name": name} 618 assert json.loads(resp.get_data()) == {"registered_model": jsonify(rmd)} 619 620 621 def test_update_registered_model(mock_get_request_message, mock_model_registry_store): 622 name = "model_1" 623 description = "Test model" 624 mock_get_request_message.return_value = UpdateRegisteredModel( 625 name=name, description=description 626 ) 627 rm2 = RegisteredModel(name, description=description) 628 mock_model_registry_store.update_registered_model.return_value = rm2 629 resp = _update_registered_model() 630 _, args = mock_model_registry_store.update_registered_model.call_args 631 assert args == {"name": name, "description": "Test model"} 632 assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm2)} 633 634 635 def test_rename_registered_model(mock_get_request_message, mock_model_registry_store): 636 name = "model_1" 637 new_name = "model_2" 638 mock_get_request_message.return_value = RenameRegisteredModel(name=name, new_name=new_name) 639 rm2 = RegisteredModel(new_name) 640 
mock_model_registry_store.rename_registered_model.return_value = rm2 641 resp = _rename_registered_model() 642 _, args = mock_model_registry_store.rename_registered_model.call_args 643 assert args == {"name": name, "new_name": new_name} 644 assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm2)} 645 646 647 def test_delete_registered_model(mock_get_request_message, mock_model_registry_store): 648 name = "model_1" 649 mock_get_request_message.return_value = DeleteRegisteredModel(name=name) 650 _delete_registered_model() 651 _, args = mock_model_registry_store.delete_registered_model.call_args 652 assert args == {"name": name} 653 654 655 def test_search_registered_models(mock_get_request_message, mock_model_registry_store): 656 rmds = [ 657 RegisteredModel( 658 name="model_1", 659 creation_timestamp=111, 660 last_updated_timestamp=222, 661 description="Test model", 662 latest_versions=[], 663 ), 664 RegisteredModel( 665 name="model_2", 666 creation_timestamp=111, 667 last_updated_timestamp=333, 668 description="Another model", 669 latest_versions=[], 670 ), 671 ] 672 mock_get_request_message.return_value = SearchRegisteredModels() 673 mock_model_registry_store.search_registered_models.return_value = PagedList(rmds, None) 674 resp = _search_registered_models() 675 _, args = mock_model_registry_store.search_registered_models.call_args 676 assert args == { 677 "filter_string": "", 678 "max_results": SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, 679 "order_by": [], 680 "page_token": None, 681 } 682 assert json.loads(resp.get_data()) == {"registered_models": jsonify(rmds)} 683 684 mock_get_request_message.return_value = SearchRegisteredModels(filter="hello") 685 mock_model_registry_store.search_registered_models.return_value = PagedList(rmds[:1], "tok") 686 resp = _search_registered_models() 687 _, args = mock_model_registry_store.search_registered_models.call_args 688 assert args == { 689 "filter_string": "hello", 690 "max_results": 
SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, 691 "order_by": [], 692 "page_token": None, 693 } 694 assert json.loads(resp.get_data()) == { 695 "registered_models": jsonify(rmds[:1]), 696 "next_page_token": "tok", 697 } 698 699 mock_get_request_message.return_value = SearchRegisteredModels(filter="hi", max_results=5) 700 mock_model_registry_store.search_registered_models.return_value = PagedList([rmds[0]], "tik") 701 resp = _search_registered_models() 702 _, args = mock_model_registry_store.search_registered_models.call_args 703 assert args == {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": None} 704 assert json.loads(resp.get_data()) == { 705 "registered_models": jsonify([rmds[0]]), 706 "next_page_token": "tik", 707 } 708 709 mock_get_request_message.return_value = SearchRegisteredModels( 710 filter="hey", max_results=500, order_by=["a", "B desc"], page_token="prev" 711 ) 712 mock_model_registry_store.search_registered_models.return_value = PagedList(rmds, "DONE") 713 resp = _search_registered_models() 714 _, args = mock_model_registry_store.search_registered_models.call_args 715 assert args == { 716 "filter_string": "hey", 717 "max_results": 500, 718 "order_by": ["a", "B desc"], 719 "page_token": "prev", 720 } 721 assert json.loads(resp.get_data()) == { 722 "registered_models": jsonify(rmds), 723 "next_page_token": "DONE", 724 } 725 726 727 def test_get_latest_versions(mock_get_request_message, mock_model_registry_store): 728 name = "model1" 729 mock_get_request_message.return_value = GetLatestVersions(name=name) 730 mvds = [ 731 ModelVersion( 732 name=name, 733 version="5", 734 creation_timestamp=1, 735 last_updated_timestamp=12, 736 description="v 5", 737 user_id="u1", 738 current_stage="Production", 739 source="A/B", 740 run_id=uuid.uuid4().hex, 741 status="READY", 742 status_message=None, 743 ), 744 ModelVersion( 745 name=name, 746 version="1", 747 creation_timestamp=1, 748 last_updated_timestamp=1200, 749 description="v 1", 750 
user_id="u1", 751 current_stage="Archived", 752 source="A/B2", 753 run_id=uuid.uuid4().hex, 754 status="READY", 755 status_message=None, 756 ), 757 ModelVersion( 758 name=name, 759 version="12", 760 creation_timestamp=100, 761 last_updated_timestamp=None, 762 description="v 12", 763 user_id="u2", 764 current_stage="Staging", 765 source="A/B3", 766 run_id=uuid.uuid4().hex, 767 status="READY", 768 status_message=None, 769 ), 770 ] 771 mock_model_registry_store.get_latest_versions.return_value = mvds 772 resp = _get_latest_versions() 773 _, args = mock_model_registry_store.get_latest_versions.call_args 774 assert args == {"name": name, "stages": []} 775 assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)} 776 777 for stages in [[], ["None"], ["Staging"], ["Staging", "Production"]]: 778 mock_get_request_message.return_value = GetLatestVersions(name=name, stages=stages) 779 _get_latest_versions() 780 _, args = mock_model_registry_store.get_latest_versions.call_args 781 assert args == {"name": name, "stages": stages} 782 783 784 def test_create_model_version(mock_get_request_message, mock_model_registry_store): 785 run_id = uuid.uuid4().hex 786 tags = [ 787 ModelVersionTag(key="key", value="value"), 788 ModelVersionTag(key="anotherKey", value="some other value"), 789 ] 790 run_link = "localhost:5000/path/to/run" 791 mock_get_request_message.return_value = CreateModelVersion( 792 name="model_1", 793 source=f"runs:/{run_id}", 794 run_id=run_id, 795 run_link=run_link, 796 tags=[tag.to_proto() for tag in tags], 797 ) 798 mv = ModelVersion( 799 name="model_1", version="12", creation_timestamp=123, tags=tags, run_link=run_link 800 ) 801 mock_model_registry_store.create_model_version.return_value = mv 802 resp = _create_model_version() 803 _, args = mock_model_registry_store.create_model_version.call_args 804 assert args["name"] == "model_1" 805 assert args["source"] == f"runs:/{run_id}" 806 assert args["run_id"] == run_id 807 assert {tag.key: tag.value for 
tag in args["tags"]} == {tag.key: tag.value for tag in tags} 808 assert args["run_link"] == run_link 809 assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)} 810 811 812 @pytest.mark.parametrize( 813 "source", 814 [ 815 "file:///etc/passwd", 816 "file:///", 817 "/etc/passwd", 818 "file:///proc/self/environ", 819 "file://remote-host/etc/passwd", 820 "file://remote-host/", 821 ], 822 ) 823 def test_create_model_version_rejects_local_source_for_prompts( 824 mock_get_request_message, mock_model_registry_store, source 825 ): 826 mock_get_request_message.return_value = CreateModelVersion( 827 name="model_1", 828 source=source, 829 tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true").to_proto()], 830 ) 831 resp = _create_model_version() 832 assert resp.status_code == 400 833 assert "Invalid prompt source" in resp.get_json()["message"] 834 835 836 @pytest.mark.parametrize( 837 "source", 838 [ 839 "https://example.com/../../etc/passwd", 840 "http://example.com/path/..%2f..%2fsecret", 841 ], 842 ) 843 def test_create_model_version_rejects_traversal_source_for_prompts( 844 mock_get_request_message, mock_model_registry_store, source 845 ): 846 mock_get_request_message.return_value = CreateModelVersion( 847 name="model_1", 848 source=source, 849 tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true").to_proto()], 850 ) 851 resp = _create_model_version() 852 assert resp.status_code == 400 853 assert "Invalid model version source" in resp.get_json()["message"] 854 855 856 def test_set_registered_model_tag(mock_get_request_message, mock_model_registry_store): 857 name = "model1" 858 tag = RegisteredModelTag(key="some weird key", value="some value") 859 mock_get_request_message.return_value = SetRegisteredModelTag( 860 name=name, key=tag.key, value=tag.value 861 ) 862 _set_registered_model_tag() 863 _, args = mock_model_registry_store.set_registered_model_tag.call_args 864 assert args == {"name": name, "tag": tag} 865 866 867 def 
test_delete_registered_model_tag(mock_get_request_message, mock_model_registry_store):
    # NOTE(review): the leading `def ` of this test lies before the visible
    # chunk; body reproduced unchanged.
    name = "model1"
    key = "some weird key"
    mock_get_request_message.return_value = DeleteRegisteredModelTag(name=name, key=key)
    _delete_registered_model_tag()
    _, args = mock_model_registry_store.delete_registered_model_tag.call_args
    assert args == {"name": name, "key": key}


def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
    """GetModelVersion request is forwarded to the registry store and the
    store's ModelVersion entity is serialized into the JSON response."""
    mock_get_request_message.return_value = GetModelVersion(name="model1", version="32")
    # The store is mocked, so the returned entity deliberately does not match
    # the requested version ("5" vs "32"); only pass-through is asserted.
    mvd = ModelVersion(
        name="model1",
        version="5",
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
    )
    mock_model_registry_store.get_model_version.return_value = mvd
    resp = _get_model_version()
    _, args = mock_model_registry_store.get_model_version.call_args
    assert args == {"name": "model1", "version": "32"}
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mvd)}


def test_update_model_version(mock_get_request_message, mock_model_registry_store):
    """UpdateModelVersion request fields are passed through to the store."""
    name = "model1"
    version = "32"
    description = "Great model!"
    mock_get_request_message.return_value = UpdateModelVersion(
        name=name, version=version, description=description
    )

    mv = ModelVersion(name=name, version=version, creation_timestamp=123, description=description)
    mock_model_registry_store.update_model_version.return_value = mv
    _update_model_version()
    _, args = mock_model_registry_store.update_model_version.call_args
    assert args == {"name": name, "version": version, "description": description}


def test_transition_model_version_stage(mock_get_request_message, mock_model_registry_store):
    """Stage transition forwards name/version/stage to the store; the
    handler defaults archive_existing_versions to False when unset."""
    name = "model1"
    version = "32"
    stage = "Production"
    mock_get_request_message.return_value = TransitionModelVersionStage(
        name=name, version=version, stage=stage
    )
    mv = ModelVersion(name=name, version=version, creation_timestamp=123, current_stage=stage)
    mock_model_registry_store.transition_model_version_stage.return_value = mv
    _transition_stage()
    _, args = mock_model_registry_store.transition_model_version_stage.call_args
    assert args == {
        "name": name,
        "version": version,
        "stage": stage,
        "archive_existing_versions": False,
    }


def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
    """DeleteModelVersion request is forwarded verbatim to the store."""
    name = "model1"
    version = "32"
    mock_get_request_message.return_value = DeleteModelVersion(name=name, version=version)
    _delete_model_version()
    _, args = mock_model_registry_store.delete_model_version.call_args
    assert args == {"name": name, "version": version}


def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """Download-URI handler forwards the request and echoes the store's URI
    back as `artifact_uri` in the JSON response."""
    name = "model1"
    version = "32"
    mock_get_request_message.return_value = GetModelVersionDownloadUri(name=name, version=version)
    mock_model_registry_store.get_model_version_download_uri.return_value = "some/download/path"
    resp = _get_model_version_download_uri()
    _, args = mock_model_registry_store.get_model_version_download_uri.call_args
    assert args == {"name": name, "version": version}
    assert json.loads(resp.get_data()) == {"artifact_uri": "some/download/path"}


def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """Search handler passes filter/max_results/order_by/page_token through
    to the store, applies the server-side max-results threshold when the
    request omits max_results, and includes next_page_token only when the
    store returns one."""
    mvds = [
        ModelVersion(
            name="model_1",
            version="5",
            creation_timestamp=100,
            last_updated_timestamp=3200,
            description="v 5",
            user_id="u1",
            current_stage="Production",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="model_1",
            version="12",
            creation_timestamp=110,
            last_updated_timestamp=2000,
            description="v 12",
            user_id="u2",
            current_stage="Production",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="ads_model",
            version="8",
            creation_timestamp=200,
            last_updated_timestamp=1000,
            description="v 8",
            user_id="u1",
            current_stage="Staging",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="fraud_detection_model",
            version="345",
            creation_timestamp=1000,
            last_updated_timestamp=999,
            description="newest version",
            user_id="u12",
            current_stage="None",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
    ]
    # Case 1: filter only — max_results falls back to the server threshold,
    # and no next_page_token appears when the store returns None for it.
    mock_get_request_message.return_value = SearchModelVersions(filter="source_path = 'A/B/CD'")
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds, None)
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="source_path = 'A/B/CD'",
        max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
        order_by=[],
        page_token=None,
    )
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}

    # Case 2: store returns a pagination token — it is surfaced in the response.
    mock_get_request_message.return_value = SearchModelVersions(filter="name='model_1'")
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds[:1], "tok")
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="name='model_1'",
        max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
        order_by=[],
        page_token=None,
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify(mvds[:1]),
        "next_page_token": "tok",
    }

    # Case 3: explicit max_results overrides the threshold.
    mock_get_request_message.return_value = SearchModelVersions(filter="version<=12", max_results=2)
    mock_model_registry_store.search_model_versions.return_value = PagedList(
        [mvds[0], mvds[2]], "next"
    )
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="version<=12", max_results=2, order_by=[], page_token=None
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify([mvds[0], mvds[2]]),
        "next_page_token": "next",
    }

    # Case 4: all request fields populated are forwarded unchanged.
    mock_get_request_message.return_value = SearchModelVersions(
        filter="version<=12", max_results=2, order_by=["version DESC"], page_token="prev"
    )
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds[1:3], "next")
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="version<=12", max_results=2, order_by=["version DESC"], page_token="prev"
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify(mvds[1:3]),
        "next_page_token": "next",
    }
def test_set_model_version_tag(mock_get_request_message, mock_model_registry_store):
    """SetModelVersionTag request is converted into a ModelVersionTag entity
    before reaching the store."""
    name = "model1"
    version = "1"
    tag = ModelVersionTag(key="some weird key", value="some value")
    mock_get_request_message.return_value = SetModelVersionTag(
        name=name, version=version, key=tag.key, value=tag.value
    )
    _set_model_version_tag()
    _, args = mock_model_registry_store.set_model_version_tag.call_args
    assert args == {"name": name, "version": version, "tag": tag}


def test_delete_model_version_tag(mock_get_request_message, mock_model_registry_store):
    """DeleteModelVersionTag request is forwarded verbatim to the store."""
    name = "model1"
    version = "1"
    key = "some weird key"
    mock_get_request_message.return_value = DeleteModelVersionTag(
        name=name, version=version, key=key
    )
    _delete_model_version_tag()
    _, args = mock_model_registry_store.delete_model_version_tag.call_args
    assert args == {"name": name, "version": version, "key": key}


def test_set_registered_model_alias(mock_get_request_message, mock_model_registry_store):
    """SetRegisteredModelAlias request is forwarded verbatim to the store."""
    name = "model1"
    alias = "test_alias"
    version = "1"
    mock_get_request_message.return_value = SetRegisteredModelAlias(
        name=name, alias=alias, version=version
    )
    _set_registered_model_alias()
    _, args = mock_model_registry_store.set_registered_model_alias.call_args
    assert args == {"name": name, "alias": alias, "version": version}


def test_delete_registered_model_alias(mock_get_request_message, mock_model_registry_store):
    """DeleteRegisteredModelAlias request is forwarded verbatim to the store."""
    name = "model1"
    alias = "test_alias"
    mock_get_request_message.return_value = DeleteRegisteredModelAlias(name=name, alias=alias)
    _delete_registered_model_alias()
    _, args = mock_model_registry_store.delete_registered_model_alias.call_args
    assert args == {"name": name, "alias": alias}


def test_get_model_version_by_alias(mock_get_request_message, mock_model_registry_store):
    """Alias lookup forwards name/alias to the store and serializes the
    returned ModelVersion (including its aliases list) into the response."""
    name = "model1"
    alias = "test_alias"
    mock_get_request_message.return_value = GetModelVersionByAlias(name=name, alias=alias)
    mvd = ModelVersion(
        name="model1",
        version="5",
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
        aliases=["test_alias"],
    )
    mock_model_registry_store.get_model_version_by_alias.return_value = mvd
    resp = _get_model_version_by_alias()
    _, args = mock_model_registry_store.get_model_version_by_alias.call_args
    assert args == {"name": name, "alias": alias}
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mvd)}


# Paths covering absolute paths, plain and percent-encoded traversal
# (`..`, %2E%2E%2F), a NUL-byte suffix, and a file:// scheme smuggle.
@pytest.mark.parametrize(
    "path",
    [
        "/path",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "/file://etc/passwd",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_delete_artifact_mlflow_artifacts_throws_for_malicious_path(enable_serve_artifacts, path):
    """Artifact-deletion handler rejects malicious paths with HTTP 400 and
    an INVALID_PARAMETER_VALUE error code."""
    response = _delete_artifact_mlflow_artifacts(path)
    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert json_response["message"] == "Invalid path"


def test_get_presigned_download_url_success(enable_serve_artifacts):
    """Presigned-download handler returns url/headers/file_size from a repo
    implementing MultipartDownloadMixin."""
    from mlflow.store.artifact.artifact_repo import MultipartDownloadMixin

    class MockMultipartDownloadRepo(MultipartDownloadMixin):
        # Minimal stub: returns a fixed presigned-URL response regardless of path.
        def get_download_presigned_url(self, artifact_path, expiration=300):
            return PresignedDownloadUrlResponse(
                url="https://storage.example.com/presigned?token=abc",
                headers={"x-custom-header": "value"},
                file_size=1024,
            )

    artifact_path = "run_id/artifacts/model.pkl"
    with (
        app.test_request_context(method="GET"),
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo_mlflow_artifacts",
            return_value=MockMultipartDownloadRepo(),
        ),
    ):
        response = _get_presigned_download_url(artifact_path)

    assert response.status_code == 200
    data = json.loads(response.get_data())
    assert data["url"] == "https://storage.example.com/presigned?token=abc"
    assert data["headers"] == {"x-custom-header": "value"}
    assert data["file_size"] == 1024
@pytest.mark.parametrize(
    "path",
    [
        "/path",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "/file://etc/passwd",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_get_presigned_download_url_throws_for_malicious_path(enable_serve_artifacts, path):
    """Presigned-download handler rejects malicious paths with HTTP 400 and
    INVALID_PARAMETER_VALUE, mirroring the artifact-deletion handler."""
    response = _get_presigned_download_url(path)
    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert json_response["message"] == "Invalid path"


def test_get_presigned_download_url_unsupported_repo(enable_serve_artifacts, tmp_path):
    """A repo without multipart-download support yields HTTP 501 NOT_IMPLEMENTED
    with a message mentioning "multipart"."""
    with (
        app.test_request_context(method="GET"),
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo_mlflow_artifacts",
            return_value=LocalArtifactRepository(str(tmp_path)),
        ),
    ):
        response = _get_presigned_download_url("some/artifact/path")

    assert response.status_code == 501
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(NOT_IMPLEMENTED)
    assert "multipart" in json_response["message"].lower()


# --- Presigned upload URL handler tests ---


def test_create_presigned_upload_url_success():
    """Happy path: a repo implementing PresignedUploadMixin produces a 200
    response carrying the presigned URL and upload headers."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    class MockPresignedUploadRepo(PresignedUploadMixin):
        # Minimal stub returning a fixed S3-style presigned URL.
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            return CreatePresignedUploadResponse(
                presigned_url="https://s3.amazonaws.com/bucket/artifacts/model.pkl?X-Amz-Signature=abc",
                headers={"Content-Type": "application/octet-stream"},
            )

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    data = json.loads(response.get_data())
    assert "presigned_url" in data
    assert "X-Amz-Signature" in data["presigned_url"]
    assert data["headers"] == {"Content-Type": "application/octet-stream"}


def test_create_presigned_upload_url_unsupported_repo():
    """A repo without presigned-upload support yields HTTP 501 NOT_IMPLEMENTED
    with a message mentioning "presigned upload"."""
    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "file:///tmp/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=LocalArtifactRepository("/tmp/artifacts"),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 501
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(NOT_IMPLEMENTED)
    assert "presigned upload" in json_response["message"].lower()


# Artifact URIs that resolve back through the tracking server's own proxy
# must be rejected, not forwarded, to avoid a request loop.
@pytest.mark.parametrize(
    "artifact_uri",
    [
        "mlflow-artifacts:/0/abc123/artifacts",
        "http://mlflow-server:5000/api/2.0/mlflow-artifacts/artifacts",
        "https://mlflow-server/api/2.0/mlflow-artifacts/artifacts",
    ],
)
def test_create_presigned_upload_url_rejects_proxy_artifact_uri(artifact_uri):
    """Proxied artifact URIs are rejected with 400 INVALID_PARAMETER_VALUE
    and a message mentioning "proxied"."""
    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = artifact_uri

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert "proxied" in json_response["message"].lower()


def test_create_presigned_upload_url_invalid_run_id():
    """A RESOURCE_DOES_NOT_EXIST error from the store surfaces as HTTP 404
    with the matching error code."""
    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "nonexistent_run"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
    ):
        mock_store.return_value.get_run.side_effect = MlflowException(
            "Run 'nonexistent_run' not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
        response = _create_presigned_upload_url()

    assert response.status_code == 404
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
# Traversal attempts, both literal and percent-encoded; rejection happens
# before the tracking store is ever consulted.
@pytest.mark.parametrize(
    "path",
    [
        "../../../etc/passwd",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_create_presigned_upload_url_rejects_path_traversal(path):
    """Path-traversal uploads are rejected with 400 INVALID_PARAMETER_VALUE."""
    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = path

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
    ):
        response = _create_presigned_upload_url()

    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_create_presigned_upload_url_with_custom_expiration():
    """An explicit `expiration` in the request is passed through to the
    artifact repo."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    # Records the expiration the handler actually forwards to the repo.
    captured_expiration = {}

    class MockPresignedUploadRepo(PresignedUploadMixin):
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            captured_expiration["value"] = expiration
            return CreatePresignedUploadResponse(
                presigned_url="https://example.com/presigned",
                headers={},
            )

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"
    request_proto.expiration = 60

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    assert captured_expiration["value"] == 60


def test_create_presigned_upload_url_default_expiration():
    """When the request omits `expiration`, the repo receives the 900-second
    default."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    captured_expiration = {}

    class MockPresignedUploadRepo(PresignedUploadMixin):
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            captured_expiration["value"] = expiration
            return CreatePresignedUploadResponse(
                presigned_url="https://example.com/presigned",
                headers={},
            )

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    # Don't set expiration - should default to 900
    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    assert captured_expiration["value"] == 900


def test_create_presigned_upload_url_blocked_in_artifacts_only_mode(monkeypatch):
    """In artifacts-only server mode the endpoint is disabled with HTTP 503."""
    from mlflow.server import ARTIFACTS_ONLY_ENV_VAR

    monkeypatch.setenv(ARTIFACTS_ONLY_ENV_VAR, "true")

    with app.test_request_context(method="POST", content_type="application/json"):
        response = _create_presigned_upload_url()

    assert response.status_code == 503
    assert "artifacts-only" in response.get_data(as_text=True).lower()


# URIs abusing fragments (#) and encoded path segments (;..%2F) to escape
# the artifact root.
@pytest.mark.parametrize(
    "uri",
    [
        "http://host#/abc/etc/",
        "http://host/;..%2F..%2Fetc",
    ],
)
def test_local_file_read_write_by_pass_vulnerability(uri):
    """Regression test for a local file read/write bypass: artifact locations
    with fragments/params are rejected at experiment creation, and a local
    source path must live under the run's artifact directory."""
    request = mock.MagicMock()
    request.method = "POST"
    request.content_type = "application/json; charset=utf-8"
    request.get_json = mock.MagicMock()
    request.get_json.return_value = {
        "name": "hello",
        "artifact_location": uri,
    }
    msg = _get_request_message(CreateExperiment(), flask_request=request)
    with mock.patch("mlflow.server.handlers._get_request_message", return_value=msg):
        response = _create_experiment()
        json_response = json.loads(response.get_data())
        assert (
            json_response["message"] == "'artifact_location' URL can't include fragments or params."
        )

    # Test if source is a local filesystem path, `_validate_source` validates that the run
    # artifact_uri is also a local filesystem path.
    run_id = uuid.uuid4().hex
    with mock.patch("mlflow.server.handlers._get_tracking_store") as mock_get_tracking_store:
        mock_get_tracking_store().get_run(
            run_id
        ).info.artifact_uri = f"http://host/{run_id}/artifacts/abc"

        with pytest.raises(
            MlflowException,
            match=(
                "the run_id request parameter has to be specified and the local "
                "path has to be contained within the artifact directory of the "
                "run specified by the run_id"
            ),
        ):
            _validate_source_run("/local/path/xyz", run_id)
@pytest.mark.parametrize(
    ("location", "expected_class", "expected_uri"),
    [
        ("file:///0/traces/123", LocalArtifactRepository, "file:///0/traces/123"),
        ("s3://bucket/0/traces/123", S3ArtifactRepository, "s3://bucket/0/traces/123"),
        (
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
            AzureBlobArtifactRepository,
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
        ),
        # Proxy URI must be resolved to the actual storage URI
        (
            "https://127.0.0.1/api/2.0/mlflow-artifacts/artifacts/2/traces/123",
            S3ArtifactRepository,
            "s3://bucket/2/traces/123",
        ),
        ("mlflow-artifacts:/1/traces/123", S3ArtifactRepository, "s3://bucket/1/traces/123"),
    ],
)
def test_get_trace_artifact_repo(location, expected_class, expected_uri, monkeypatch):
    """The trace artifact-location tag maps to the right repository class,
    resolving proxy URIs to the configured artifacts destination."""
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")
    trace_info = TraceInfo(
        trace_id="123",
        trace_location=EntityTraceLocation.from_experiment_id("0"),
        request_time=0,
        execution_duration=1,
        state=TraceState.OK,
        tags={MLFLOW_ARTIFACT_LOCATION: location},
    )
    repo = _get_trace_artifact_repo(trace_info)
    assert isinstance(repo, expected_class)
    assert repo.artifact_uri == expected_uri


### Prompt Registry Tests ###
def test_create_prompt_as_registered_model(mock_get_request_message, mock_model_registry_store):
    """A registered model carrying the IS_PROMPT tag round-trips through the
    create-registered-model handler."""
    tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    mock_get_request_message.return_value = CreateRegisteredModel(
        name="model_1", tags=[tag.to_proto() for tag in tags]
    )
    rm = RegisteredModel("model_1", tags=tags)
    mock_model_registry_store.create_registered_model.return_value = rm
    resp = _create_registered_model()
    _, args = mock_model_registry_store.create_registered_model.call_args
    assert args["name"] == "model_1"
    assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags}
    assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm)}


def test_create_prompt_as_model_version(mock_get_request_message, mock_model_registry_store):
    """A prompt-tagged model version is created with unset source/run_id/
    run_link coerced to empty strings (proto3 default for string fields)."""
    tags = [
        ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
        ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value="some prompt text"),
    ]
    mock_get_request_message.return_value = CreateModelVersion(
        name="model_1",
        tags=[tag.to_proto() for tag in tags],
        source=None,
        run_id=None,
        run_link=None,
    )
    mv = ModelVersion(
        name="prompt_1", version="12", creation_timestamp=123, tags=tags, run_link=None
    )
    mock_model_registry_store.create_model_version.return_value = mv
    resp = _create_model_version()
    _, args = mock_model_registry_store.create_model_version.call_args
    assert args["name"] == "model_1"
    assert args["source"] == ""
    assert args["run_id"] == ""
    assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags}
    assert args["run_link"] == ""
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)}


def test_create_evaluation_dataset(mock_tracking_store, mock_evaluation_dataset):
    """Dataset-creation handler decodes the JSON body (tags arrive as a
    JSON-encoded string) and forwards keyword args to the store."""
    mock_tracking_store.create_dataset.return_value = mock_evaluation_dataset

    with app.test_request_context(
        method="POST",
        json={
            "name": "test_dataset",
            "experiment_ids": ["0", "1"],
            "tags": json.dumps({"env": "test"}),
        },
    ):
        _create_dataset_handler()

    mock_tracking_store.create_dataset.assert_called_once_with(
        name="test_dataset",
        experiment_ids=["0", "1"],
        tags={"env": "test"},
    )


def test_get_evaluation_dataset(mock_tracking_store, mock_evaluation_dataset):
    """Dataset-get handler passes the path dataset_id straight to the store."""
    mock_tracking_store.get_dataset.return_value = mock_evaluation_dataset

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    with app.test_request_context(method="GET"):
        _get_dataset_handler(dataset_id)

    mock_tracking_store.get_dataset.assert_called_once_with(dataset_id)


def test_delete_evaluation_dataset(mock_tracking_store):
    """Dataset-delete handler passes the path dataset_id straight to the store."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    with app.test_request_context(method="DELETE"):
        _delete_dataset_handler(dataset_id)

    mock_tracking_store.delete_dataset.assert_called_once_with(dataset_id)


def test_search_datasets(mock_tracking_store):
    """Dataset-search handler forwards all search parameters from the JSON
    body to the store as keyword arguments."""
    from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset

    datasets = []
    for i in range(2):
        ds = mock.MagicMock()
        ds.name = f"dataset_{i}"
        proto = ProtoDataset()
        proto.dataset_id = f"d-{i:032d}"
        proto.name = ds.name
        ds.to_proto.return_value = proto
        datasets.append(ds)

    paged_list = PagedList(datasets, "next_token")
    mock_tracking_store.search_datasets.return_value = paged_list

    with app.test_request_context(
        method="POST",
        json={
            "experiment_ids": ["0", "1"],
            "filter_string": "name = 'dataset_1'",
            "max_results": 10,
            "order_by": ["name DESC"],
            "page_token": "token123",
        },
    ):
        _search_evaluation_datasets_handler()

    mock_tracking_store.search_datasets.assert_called_once_with(
        experiment_ids=["0", "1"],
        filter_string="name = 'dataset_1'",
        max_results=10,
        order_by=["name DESC"],
        page_token="token123",
    )
def test_set_dataset_tags(mock_tracking_store):
    """Tag-set handler decodes the JSON-encoded tags string and forwards it."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    with app.test_request_context(
        method="POST",
        json={
            "tags": json.dumps({"env": "production", "version": "2.0"}),
        },
    ):
        _set_dataset_tags_handler(dataset_id)

    mock_tracking_store.set_dataset_tags.assert_called_once_with(
        dataset_id=dataset_id,
        tags={"env": "production", "version": "2.0"},
    )


def test_delete_dataset_tag(mock_tracking_store):
    """Tag-delete handler forwards the dataset_id and key path params."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    key = "deprecated_tag"
    with app.test_request_context(method="DELETE"):
        _delete_dataset_tag_handler(dataset_id, key)

    mock_tracking_store.delete_dataset_tag.assert_called_once_with(
        dataset_id=dataset_id,
        key=key,
    )


def test_upsert_dataset_records(mock_tracking_store):
    """Upsert handler decodes the JSON-encoded records list and reports the
    store's inserted/updated counts as *_count response fields."""
    mock_tracking_store.upsert_dataset_records.return_value = {
        "inserted": 2,
        "updated": 0,
    }

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    records = [
        {"inputs": {"q": "test1"}, "expectations": {"score": 0.9}},
        {"inputs": {"q": "test2"}, "expectations": {"score": 0.8}},
    ]

    with app.test_request_context(
        method="POST",
        json={
            "records": json.dumps(records),
        },
    ):
        resp = _upsert_dataset_records_handler(dataset_id)

    mock_tracking_store.upsert_dataset_records.assert_called_once_with(
        dataset_id=dataset_id,
        records=records,
    )

    response_data = json.loads(resp.get_data())
    assert response_data["inserted_count"] == 2
    assert response_data["updated_count"] == 0


def test_get_dataset_experiment_ids(mock_tracking_store):
    """Experiment-ids handler returns the store's list under `experiment_ids`."""
    mock_tracking_store.get_dataset_experiment_ids.return_value = [
        "exp1",
        "exp2",
        "exp3",
    ]

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    with app.test_request_context(method="GET"):
        resp = _get_dataset_experiment_ids_handler(dataset_id)

    mock_tracking_store.get_dataset_experiment_ids.assert_called_once_with(dataset_id=dataset_id)

    response_data = json.loads(resp.get_data())
    assert response_data["experiment_ids"] == ["exp1", "exp2", "exp3"]


def test_get_dataset_records(mock_tracking_store):
    """Records handler serializes records as a JSON string field, defaults
    max_results to 1000, and forwards max_results/page_token from the body."""
    records = []
    for i in range(3):
        record = mock.MagicMock()
        record.dataset_id = "d-1234567890abcdef1234567890abcdef"
        record.dataset_record_id = f"r-00{i}"
        record.inputs = {"question": f"test{i}"}
        record.expectations = {"score": 0.9 - i * 0.1}
        record.tags = {}
        record.created_time = 1234567890 + i
        record.last_update_time = 1234567890 + i
        record.to_dict.return_value = {
            "dataset_id": record.dataset_id,
            "dataset_record_id": record.dataset_record_id,
            "inputs": record.inputs,
            "expectations": record.expectations,
            "tags": record.tags,
            "created_time": record.created_time,
            "last_update_time": record.last_update_time,
        }
        records.append(record)

    mock_tracking_store._load_dataset_records.return_value = (records, None)

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    # No body: handler uses the default page size of 1000.
    with app.test_request_context(method="GET"):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=1000, page_token=None
    )

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 3
    assert records_data[0]["dataset_record_id"] == "r-000"

    mock_tracking_store._load_dataset_records.return_value = (records[:2], "token_page2")

    with app.test_request_context(
        method="GET",
        json={
            "max_results": 2,
            "page_token": None,
        },
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=2, page_token=None
    )

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 2
    assert response_data["next_page_token"] == "token_page2"

    mock_tracking_store._load_dataset_records.return_value = (records[2:], None)

    with app.test_request_context(
        method="GET",
        json={
            "max_results": 2,
            "page_token": "token_page2",
        },
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=2, page_token="token_page2"
    )

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 1
    # Last page: token is either absent or empty depending on serialization.
    assert "next_page_token" not in response_data or response_data["next_page_token"] == ""


def test_get_dataset_records_empty(mock_tracking_store):
    """An empty store result serializes to an empty records list and no
    (or empty) next_page_token."""
    mock_tracking_store._load_dataset_records.return_value = ([], None)

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    with app.test_request_context(method="GET"):
        resp = _get_dataset_records_handler(dataset_id)

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 0
    assert "next_page_token" not in response_data or response_data["next_page_token"] == ""


def test_get_dataset_records_pagination(mock_tracking_store):
    """Walks three pages (20 + 20 + 10 of 50 records), checking that tokens
    chain page-to-page and the final page carries no usable token."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    all_records = []
    for i in range(50):
        record = mock.Mock()
        record.dataset_record_id = f"r-{i:03d}"
        record.inputs = {"q": f"Question {i}"}
        record.expectations = {"a": f"Answer {i}"}
        record.tags = {}
        record.source_type = "TRACE"
        record.source_id = f"trace-{i}"
        record.created_time = 1609459200 + i
        record.to_dict.return_value = {
            "dataset_record_id": f"r-{i:03d}",
            "inputs": {"q": f"Question {i}"},
            "expectations": {"a": f"Answer {i}"},
            "tags": {},
            "source_type": "TRACE",
            "source_id": f"trace-{i}",
            "created_time": 1609459200 + i,
        }
        all_records.append(record)
    mock_tracking_store._load_dataset_records.return_value = (all_records[:20], "token_20")

    # Page 1: no token.
    with app.test_request_context(
        method="GET",
        json={"max_results": 20},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=20, page_token=None
    )

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 20
    assert response_data["next_page_token"] == "token_20"
    assert records_data[0]["dataset_record_id"] == "r-000"
    assert records_data[19]["dataset_record_id"] == "r-019"
    mock_tracking_store._load_dataset_records.return_value = (all_records[20:40], "token_40")

    # Page 2: token from page 1.
    with app.test_request_context(
        method="GET",
        json={"max_results": 20, "page_token": "token_20"},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=20, page_token="token_20"
    )

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 20
    assert response_data["next_page_token"] == "token_40"
    assert records_data[0]["dataset_record_id"] == "r-020"
    mock_tracking_store._load_dataset_records.return_value = (all_records[40:], None)

    # Page 3 (final): 10 remaining records, no next token.
    with app.test_request_context(
        method="GET",
        json={"max_results": 20, "page_token": "token_40"},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    response_data = json.loads(resp.get_data())
    records_data = json.loads(response_data["records"])
    assert len(records_data) == 10
    assert "next_page_token" not in response_data or response_data["next_page_token"] == ""
    assert records_data[0]["dataset_record_id"] == "r-040"
    assert records_data[9]["dataset_record_id"] == "r-049"


def test_register_scorer(mock_get_request_message, mock_tracking_store):
    """RegisterScorer request is forwarded positionally to the store and the
    resulting ScorerVersion is flattened into the JSON response."""
    experiment_id = "123"
    name = "accuracy_scorer"
    serialized_scorer = '{"name": "accuracy_scorer"}'

    mock_get_request_message.return_value = RegisterScorer(
        experiment_id=experiment_id, name=name, serialized_scorer=serialized_scorer
    )

    mock_scorer_version = ScorerVersion(
        experiment_id=experiment_id,
        scorer_name=name,
        scorer_version=1,
        serialized_scorer=serialized_scorer,
        creation_time=1234567890,
        scorer_id="test-scorer-id",
    )
    mock_tracking_store.register_scorer.return_value = mock_scorer_version

    resp = _register_scorer()

    mock_tracking_store.register_scorer.assert_called_once_with(
        experiment_id, name, serialized_scorer
    )

    response_data = json.loads(resp.get_data())
    assert response_data == {
        "version": 1,
        "scorer_id": "test-scorer-id",
        "experiment_id": experiment_id,
        "name": name,
        "serialized_scorer": serialized_scorer,
        "creation_time": 1234567890,
    }


def test_register_scorer_rejects_decorator_scorer(mock_get_request_message, mock_tracking_store):
    # NOTE(review): this test continues past the visible chunk; reproduced
    # unchanged up to the cut.
    from mlflow.genai.scorers.scorer_utils import DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR

    serialized_scorer = json.dumps({"name": "my_scorer", "call_source": " return 1.0\n"})
    mock_get_request_message.return_value = RegisterScorer(
        experiment_id="123", name="my_scorer", serialized_scorer=serialized_scorer
    )
    resp
    assert resp.status_code == 400
    assert DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR in resp.get_json()["message"]
    # Rejection must short-circuit before any store interaction.
    mock_tracking_store.register_scorer.assert_not_called()


def test_list_scorers(mock_get_request_message, mock_tracking_store):
    """Listing scorers returns every ScorerVersion the store reports for the experiment."""
    experiment_id = "123"

    mock_get_request_message.return_value = ListScorers(experiment_id=experiment_id)

    # Create mock scorers
    scorers = [
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=1,
            serialized_scorer="serialized_accuracy_scorer",
            creation_time=12345,
        ),
        ScorerVersion(
            experiment_id=123,
            scorer_name="safety_scorer",
            scorer_version=2,
            serialized_scorer="serialized_safety_scorer",
            creation_time=12345,
        ),
    ]

    mock_tracking_store.list_scorers.return_value = scorers

    resp = _list_scorers()

    # Verify the tracking store was called with correct arguments
    mock_tracking_store.list_scorers.assert_called_once_with(experiment_id)

    # Verify the response
    response_data = json.loads(resp.get_data())
    assert len(response_data["scorers"]) == 2
    assert response_data["scorers"][0]["scorer_name"] == "accuracy_scorer"
    assert response_data["scorers"][0]["scorer_version"] == 1
    assert response_data["scorers"][0]["serialized_scorer"] == "serialized_accuracy_scorer"
    assert response_data["scorers"][1]["scorer_name"] == "safety_scorer"
    assert response_data["scorers"][1]["scorer_version"] == 2
    assert response_data["scorers"][1]["serialized_scorer"] == "serialized_safety_scorer"


def test_list_scorer_versions(mock_get_request_message, mock_tracking_store):
    """Listing versions of one scorer returns each version's serialized payload."""
    experiment_id = "123"
    name = "accuracy_scorer"

    mock_get_request_message.return_value = ListScorerVersions(
        experiment_id=experiment_id, name=name
    )

    # Create mock scorers with multiple versions
    scorers = [
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=1,
            serialized_scorer="serialized_accuracy_scorer_v1",
            creation_time=12345,
        ),
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=2,
            serialized_scorer="serialized_accuracy_scorer_v2",
            creation_time=12345,
        ),
    ]

    mock_tracking_store.list_scorer_versions.return_value = scorers

    resp = _list_scorer_versions()

    # Verify the tracking store was called with correct arguments
    mock_tracking_store.list_scorer_versions.assert_called_once_with(experiment_id, name)

    # Verify the response
    response_data = json.loads(resp.get_data())
    assert len(response_data["scorers"]) == 2
    assert response_data["scorers"][0]["scorer_version"] == 1
    assert response_data["scorers"][0]["serialized_scorer"] == "serialized_accuracy_scorer_v1"
    assert response_data["scorers"][1]["scorer_version"] == 2
    assert response_data["scorers"][1]["serialized_scorer"] == "serialized_accuracy_scorer_v2"


def test_get_scorer_with_version(mock_get_request_message, mock_tracking_store):
    """GetScorer with an explicit version passes it through positionally."""
    experiment_id = "123"
    name = "accuracy_scorer"
    version = 2

    mock_get_request_message.return_value = GetScorer(
        experiment_id=experiment_id, name=name, version=version
    )

    # Mock the return value as a ScorerVersion entity
    mock_scorer_version = ScorerVersion(
        experiment_id=123,
        scorer_name="accuracy_scorer",
        scorer_version=2,
        serialized_scorer="serialized_accuracy_scorer_v2",
        creation_time=1640995200000,
    )
    mock_tracking_store.get_scorer.return_value = mock_scorer_version

    resp = _get_scorer()

    # Verify the tracking store was called with correct arguments (positional)
    mock_tracking_store.get_scorer.assert_called_once_with(experiment_id, name, version)

    # Verify the response
    response_data = json.loads(resp.get_data())
    assert response_data["scorer"]["experiment_id"] == 123
    assert response_data["scorer"]["scorer_name"] == "accuracy_scorer"
    assert response_data["scorer"]["scorer_version"] == 2
    assert response_data["scorer"]["serialized_scorer"] == "serialized_accuracy_scorer_v2"
    assert response_data["scorer"]["creation_time"] == 1640995200000


def test_get_scorer_without_version(mock_get_request_message, mock_tracking_store):
    """GetScorer without a version passes None, letting the store resolve latest."""
    experiment_id = "123"
    name = "accuracy_scorer"

    mock_get_request_message.return_value = GetScorer(experiment_id=experiment_id, name=name)

    # Mock the return value as a ScorerVersion entity
    mock_scorer_version = ScorerVersion(
        experiment_id=123,
        scorer_name="accuracy_scorer",
        scorer_version=3,
        serialized_scorer="serialized_accuracy_scorer_latest",
        creation_time=1640995200000,
    )
    mock_tracking_store.get_scorer.return_value = mock_scorer_version

    resp = _get_scorer()

    # Verify the tracking store was called with correct arguments (positional, version=None)
    mock_tracking_store.get_scorer.assert_called_once_with(experiment_id, name, None)

    # Verify the response
    response_data = json.loads(resp.get_data())
    assert response_data["scorer"]["experiment_id"] == 123
    assert response_data["scorer"]["scorer_name"] == "accuracy_scorer"
    assert response_data["scorer"]["scorer_version"] == 3
    assert response_data["scorer"]["serialized_scorer"] == "serialized_accuracy_scorer_latest"
    assert response_data["scorer"]["creation_time"] == 1640995200000


def test_delete_scorer_with_version(mock_get_request_message, mock_tracking_store):
    """DeleteScorer with an explicit version forwards it to the store."""
    experiment_id = "123"
    name = "accuracy_scorer"
    version = 2

    mock_get_request_message.return_value = DeleteScorer(
        experiment_id=experiment_id, name=name, version=version
    )
    resp = _delete_scorer()

    # Verify the tracking store was called with correct arguments (positional)
    mock_tracking_store.delete_scorer.assert_called_once_with(experiment_id, name, version)

    # Verify the response (should be empty for delete operations)
    response_data = json.loads(resp.get_data())
    assert response_data == {}


def test_delete_scorer_without_version(mock_get_request_message, mock_tracking_store):
    """DeleteScorer without a version passes None (delete-all semantics left to store)."""
    experiment_id = "123"
    name = "accuracy_scorer"

    mock_get_request_message.return_value = DeleteScorer(experiment_id=experiment_id, name=name)

    resp = _delete_scorer()

    # Verify the tracking store was called with correct arguments (positional, version=None)
    mock_tracking_store.delete_scorer.assert_called_once_with(experiment_id, name, None)

    # Verify the response (should be empty for delete operations)
    response_data = json.loads(resp.get_data())
    assert response_data == {}


def test_get_online_scoring_configs_batch(mock_tracking_store):
    """Repeated scorer_ids query params are collected into one batched store call."""
    mock_configs = [
        OnlineScoringConfig(
            online_scoring_config_id="cfg-1",
            scorer_id="scorer-1",
            sample_rate=0.5,
            filter_string="status = 'OK'",
            experiment_id="exp1",
        ),
        OnlineScoringConfig(
            online_scoring_config_id="cfg-2",
            scorer_id="scorer-2",
            sample_rate=0.8,
            experiment_id="exp1",
        ),
    ]
    mock_tracking_store.get_online_scoring_configs.return_value = mock_configs

    with app.test_client() as c:
        resp = c.get(
            "/ajax-api/3.0/mlflow/scorers/online-configs",
            query_string=[("scorer_ids", "scorer-1"), ("scorer_ids", "scorer-2")],
        )
        assert resp.status_code == 200
        data = resp.get_json()
        assert "configs" in data
        assert isinstance(data["configs"], list)
        assert len(data["configs"]) == 2
        configs_by_id = {c["scorer_id"]: c for c in data["configs"]}
        assert configs_by_id["scorer-1"]["sample_rate"] == 0.5
        assert configs_by_id["scorer-1"]["filter_string"] == "status = 'OK'"
        assert configs_by_id["scorer-2"]["sample_rate"] == 0.8
        # filter_string was never set on cfg-2, so it must be absent/None.
        assert configs_by_id["scorer-2"].get("filter_string") is None

    mock_tracking_store.get_online_scoring_configs.assert_called_once_with(["scorer-1", "scorer-2"])


def test_get_online_scoring_configs_missing_param(mock_tracking_store):
    """Omitting the required scorer_ids query param is a 400 naming the parameter."""
    with app.test_client() as c:
        resp = c.get(
            "/ajax-api/3.0/mlflow/scorers/online-configs",
        )
        assert resp.status_code == 400
        data = resp.get_json()
        assert "scorer_ids" in data["message"]


def test_calculate_trace_filter_correlation(mock_get_request_message, mock_tracking_store):
    """Correlation request (with base_filter) is forwarded and all counts echoed."""
    experiment_ids = ["123", "456"]
    filter_string1 = "span.type = 'LLM'"
    filter_string2 = "feedback.quality > 0.8"
    base_filter = "request_time > 1000"

    mock_request = CalculateTraceFilterCorrelation(
        experiment_ids=experiment_ids,
        filter_string1=filter_string1,
        filter_string2=filter_string2,
        base_filter=base_filter,
    )
    mock_get_request_message.return_value = mock_request

    mock_result = TraceFilterCorrelationResult(
        npmi=0.456,
        npmi_smoothed=0.445,
        filter1_count=100,
        filter2_count=80,
        joint_count=50,
        total_count=200,
    )
    mock_tracking_store.calculate_trace_filter_correlation.return_value = mock_result

    resp = _calculate_trace_filter_correlation()

    mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with(
        experiment_ids=experiment_ids,
        filter_string1=filter_string1,
        filter_string2=filter_string2,
        base_filter=base_filter,
    )

    response_data = json.loads(resp.get_data())
    assert response_data["npmi"] == 0.456
    assert response_data["npmi_smoothed"] == 0.445
    assert response_data["filter1_count"] == 100
    assert response_data["filter2_count"] == 80
    assert response_data["joint_count"] == 50
    assert response_data["total_count"] == 200


def test_calculate_trace_filter_correlation_without_base_filter(
    mock_get_request_message, mock_tracking_store
):
    """Missing base_filter reaches the store as an explicit None."""
    experiment_ids = ["123"]
    filter_string1 = "span.type = 'LLM'"
    filter_string2 = "feedback.quality > 0.8"

    mock_request = CalculateTraceFilterCorrelation(
        experiment_ids=experiment_ids,
        filter_string1=filter_string1,
        filter_string2=filter_string2,
    )
    mock_get_request_message.return_value = mock_request

    mock_result = TraceFilterCorrelationResult(
        npmi=0.789,
        npmi_smoothed=0.775,
        filter1_count=50,
        filter2_count=40,
        joint_count=30,
        total_count=100,
    )
    mock_tracking_store.calculate_trace_filter_correlation.return_value = mock_result

    resp = _calculate_trace_filter_correlation()

    mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with(
        experiment_ids=experiment_ids,
        filter_string1=filter_string1,
        filter_string2=filter_string2,
        base_filter=None,
    )

    response_data = json.loads(resp.get_data())
    assert response_data["npmi"] == 0.789
    assert response_data["npmi_smoothed"] == 0.775
    assert response_data["filter1_count"] == 50
    assert response_data["filter2_count"] == 40
    assert response_data["joint_count"] == 30
    assert response_data["total_count"] == 100


def test_calculate_trace_filter_correlation_with_nan_npmi(
    mock_get_request_message, mock_tracking_store
):
    """NaN npmi values must be dropped from the JSON response, not serialized."""
    experiment_ids = ["123"]
    filter_string1 = "span.type = 'LLM'"
    filter_string2 = "feedback.quality > 0.8"

    mock_request = CalculateTraceFilterCorrelation(
        experiment_ids=experiment_ids,
        filter_string1=filter_string1,
        filter_string2=filter_string2,
    )
    mock_get_request_message.return_value = mock_request
mock_request 2293 2294 mock_result = TraceFilterCorrelationResult( 2295 npmi=float("nan"), 2296 npmi_smoothed=None, 2297 filter1_count=0, 2298 filter2_count=0, 2299 joint_count=0, 2300 total_count=100, 2301 ) 2302 mock_tracking_store.calculate_trace_filter_correlation.return_value = mock_result 2303 2304 resp = _calculate_trace_filter_correlation() 2305 2306 mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with( 2307 experiment_ids=experiment_ids, 2308 filter_string1=filter_string1, 2309 filter_string2=filter_string2, 2310 base_filter=None, 2311 ) 2312 2313 response_data = json.loads(resp.get_data()) 2314 assert "npmi" not in response_data 2315 assert "npmi_smoothed" not in response_data 2316 assert response_data["filter1_count"] == 0 2317 assert response_data["filter2_count"] == 0 2318 assert response_data["joint_count"] == 0 2319 assert response_data["total_count"] == 100 2320 2321 2322 def test_databricks_tracking_store_registration(): 2323 registry = TrackingStoreRegistryWrapper() 2324 2325 # Test that the correct store type is returned for databricks scheme 2326 store = registry.get_store("databricks", artifact_uri=None) 2327 assert isinstance(store, DatabricksTracingRestStore) 2328 2329 # Verify that the store was created with the right get_host_creds function 2330 # The RestStore should have a get_host_creds attribute that is a partial function 2331 assert hasattr(store, "get_host_creds") 2332 assert store.get_host_creds.func.__name__ == "get_databricks_host_creds" 2333 assert store.get_host_creds.args == ("databricks",) 2334 2335 2336 def test_databricks_model_registry_store_registration(): 2337 registry = ModelRegistryStoreRegistryWrapper() 2338 2339 # Test that the correct store type is returned for databricks 2340 store = registry.get_store("databricks") 2341 assert isinstance(store, ModelRegistryRestStore) 2342 2343 # Verify that the store was created with the right get_host_creds function 2344 assert hasattr(store, 
"get_host_creds") 2345 assert store.get_host_creds.func.__name__ == "get_databricks_host_creds" 2346 assert store.get_host_creds.args == ("databricks",) 2347 2348 # Test that the correct store type is returned for databricks-uc 2349 uc_store = registry.get_store("databricks-uc") 2350 assert isinstance(uc_store, UcModelRegistryStore) 2351 2352 # Verify that the UC store was created with the right get_host_creds function 2353 # Note: UcModelRegistryStore uses get_databricks_host_creds internally, 2354 # not get_databricks_uc_host_creds 2355 assert hasattr(uc_store, "get_host_creds") 2356 assert uc_store.get_host_creds.func.__name__ == "get_databricks_host_creds" 2357 assert uc_store.get_host_creds.args == ("databricks-uc",) 2358 2359 # Also verify it has tracking_uri set 2360 assert hasattr(uc_store, "tracking_uri") 2361 # The tracking_uri will be set based on environment/test config 2362 # In test environment, it may be set to a test sqlite database 2363 assert uc_store.tracking_uri is not None 2364 2365 2366 def test_search_experiments_empty_page_token(mock_get_request_message, mock_tracking_store): 2367 # Create proto without setting page_token - it defaults to empty string 2368 search_experiments_proto = SearchExperiments() 2369 search_experiments_proto.max_results = 10 2370 2371 # Verify that proto's default page_token is empty string 2372 assert search_experiments_proto.page_token == "" 2373 2374 mock_get_request_message.return_value = search_experiments_proto 2375 mock_tracking_store.search_experiments.return_value = PagedList([], None) 2376 2377 _search_experiments() 2378 2379 # Verify that search_experiments was called with page_token=None (not empty string) 2380 mock_tracking_store.search_experiments.assert_called_once() 2381 call_kwargs = mock_tracking_store.search_experiments.call_args.kwargs 2382 assert call_kwargs.get("page_token") is None 2383 assert call_kwargs.get("max_results") == 10 2384 2385 2386 def test_search_registered_models_empty_page_token( 
2387 mock_get_request_message, mock_model_registry_store 2388 ): 2389 # Create proto without setting page_token - it defaults to empty string 2390 search_registered_models_proto = SearchRegisteredModels() 2391 search_registered_models_proto.max_results = 10 2392 2393 # Verify that proto's default page_token is empty string 2394 assert search_registered_models_proto.page_token == "" 2395 2396 mock_get_request_message.return_value = search_registered_models_proto 2397 mock_model_registry_store.search_registered_models.return_value = PagedList([], None) 2398 2399 _search_registered_models() 2400 2401 # Verify that search_registered_models was called with page_token=None (not empty string) 2402 mock_model_registry_store.search_registered_models.assert_called_once() 2403 call_kwargs = mock_model_registry_store.search_registered_models.call_args.kwargs 2404 assert call_kwargs.get("page_token") is None 2405 assert call_kwargs.get("max_results") == 10 2406 2407 2408 def test_search_model_versions_empty_page_token( 2409 mock_get_request_message, mock_model_registry_store 2410 ): 2411 # Create proto without setting page_token - it defaults to empty string 2412 search_model_versions_proto = SearchModelVersions() 2413 search_model_versions_proto.max_results = 10 2414 2415 # Verify that proto's default page_token is empty string 2416 assert search_model_versions_proto.page_token == "" 2417 2418 mock_get_request_message.return_value = search_model_versions_proto 2419 mock_model_registry_store.search_model_versions.return_value = PagedList([], None) 2420 2421 _search_model_versions() 2422 2423 # Verify that search_model_versions was called with page_token=None (not empty string) 2424 mock_model_registry_store.search_model_versions.assert_called_once() 2425 call_kwargs = mock_model_registry_store.search_model_versions.call_args.kwargs 2426 assert call_kwargs.get("page_token") is None 2427 assert call_kwargs.get("max_results") == 10 2428 2429 2430 def 
def test_search_traces_v3_empty_page_token(mock_get_request_message, mock_tracking_store):
    """A proto-default empty-string page_token must reach the store as None."""
    # SearchTracesV3 requires at least one trace location.
    proto = SearchTracesV3()
    loc = TraceLocation()
    loc.mlflow_experiment.experiment_id = "1"
    proto.locations.append(loc)
    proto.max_results = 10
    assert proto.page_token == ""  # proto3 string default

    mock_get_request_message.return_value = proto
    mock_tracking_store.search_traces.return_value = ([], None)

    _search_traces_v3()

    mock_tracking_store.search_traces.assert_called_once()
    kwargs = mock_tracking_store.search_traces.call_args.kwargs
    assert kwargs.get("max_results") == 10
    assert kwargs.get("page_token") is None


def test_deprecated_search_traces_v2_empty_page_token(
    mock_get_request_message, mock_tracking_store
):
    """The deprecated v2 endpoint also normalizes an empty page_token to None."""
    proto = SearchTraces()
    proto.max_results = 10
    assert proto.page_token == ""  # proto3 string default

    mock_get_request_message.return_value = proto
    mock_tracking_store.search_traces.return_value = ([], None)

    _deprecated_search_traces_v2()

    mock_tracking_store.search_traces.assert_called_once()
    kwargs = mock_tracking_store.search_traces.call_args.kwargs
    assert kwargs.get("max_results") == 10
    assert kwargs.get("page_token") is None


def test_search_logged_models_empty_page_token(mock_get_request_message, mock_tracking_store):
    """A proto-default empty-string page_token must reach the store as None."""
    proto = SearchLoggedModels()
    proto.max_results = 10
    assert proto.page_token == ""  # proto3 string default

    mock_get_request_message.return_value = proto
    mock_tracking_store.search_logged_models.return_value = PagedList([], None)

    _search_logged_models()

    mock_tracking_store.search_logged_models.assert_called_once()
    kwargs = mock_tracking_store.search_logged_models.call_args.kwargs
    assert kwargs.get("max_results") == 10
    assert kwargs.get("page_token") is None


def test_list_webhooks_empty_page_token(mock_get_request_message, mock_model_registry_store):
    """A proto-default empty-string page_token must reach the store as None."""
    proto = ListWebhooks()
    proto.max_results = 10
    assert proto.page_token == ""  # proto3 string default

    mock_get_request_message.return_value = proto
    mock_model_registry_store.list_webhooks.return_value = PagedList([], None)

    _list_webhooks()

    mock_model_registry_store.list_webhooks.assert_called_once()
    kwargs = mock_model_registry_store.list_webhooks.call_args.kwargs
    assert kwargs.get("max_results") == 10
    assert kwargs.get("page_token") is None


def test_batch_get_traces_handler(mock_get_request_message, mock_tracking_store):
    """BatchGetTraces returns every trace (with its spans) found by the store."""
    first_id = "test-trace-123"
    second_id = "test-trace-456"

    mock_get_request_message.return_value = BatchGetTraces(trace_ids=[first_id, second_id])

    shared_span = Span(
        OTelReadableSpan(
            name="test",
            context=build_otel_context(123, 234),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={
                "mlflow.spanInputs": json.dumps("inputs"),
                "mlflow.spanOutputs": json.dumps("outputs"),
                "mlflow.spanType": json.dumps("span_type"),
            },
        )
    )

    def build_trace(trace_id, duration):
        # Minimal OK trace in experiment "1" carrying the shared single span.
        return Trace(
            info=TraceInfo(
                trace_id=trace_id,
                trace_location=EntityTraceLocation.from_experiment_id("1"),
                request_time=1234567890,
                execution_duration=duration,
                state=TraceState.OK,
            ),
            data=TraceData(spans=[shared_span]),
        )

    mock_tracking_store.batch_get_traces.return_value = [
        build_trace(first_id, 5000),
        build_trace(second_id, 3000),
    ]

    # Call the handler
    response = _batch_get_traces()

    # Verify the store was called with the correct trace IDs
    mock_tracking_store.batch_get_traces.assert_called_once_with([first_id, second_id], None)

    # Verify response was created
    assert response is not None
    assert response.status_code == 200
    traces = json.loads(response.get_data())["traces"]
    assert len(traces) == 2
    assert len(traces[0]["spans"]) == 1
    assert len(traces[1]["spans"]) == 1


def test_batch_get_traces_handler_empty_list(mock_get_request_message, mock_tracking_store):
    get_traces_proto = BatchGetTraces()

    mock_get_request_message.return_value = get_traces_proto
    mock_tracking_store.batch_get_traces.return_value = []
    response = _batch_get_traces()

    # An empty trace_ids list is still a valid (empty) batch request.
    mock_tracking_store.batch_get_traces.assert_called_once_with([], None)

    # Verify response was created
    assert response is not None
    assert response.status_code == 200


def test_batch_get_trace_infos_handler(mock_get_request_message, mock_tracking_store):
    """BatchGetTraceInfos returns trace metadata (no spans) for each requested ID."""
    trace_id_1 = "test-trace-123"
    trace_id_2 = "test-trace-456"

    mock_get_request_message.return_value = BatchGetTraceInfos(trace_ids=[trace_id_1, trace_id_2])

    mock_trace_info_1 = TraceInfo(
        trace_id=trace_id_1,
        trace_location=EntityTraceLocation.from_experiment_id("1"),
        request_time=1234567890,
        execution_duration=5000,
        state=TraceState.OK,
    )
    mock_trace_info_2 = TraceInfo(
        trace_id=trace_id_2,
        trace_location=EntityTraceLocation.from_experiment_id("1"),
        request_time=1234567890,
        execution_duration=3000,
        state=TraceState.OK,
    )

    mock_tracking_store.batch_get_trace_infos.return_value = [
        mock_trace_info_1,
        mock_trace_info_2,
    ]

    response = _batch_get_trace_infos()

    mock_tracking_store.batch_get_trace_infos.assert_called_once_with([trace_id_1, trace_id_2])

    assert response is not None
    assert response.status_code == 200
    trace_infos = json.loads(response.get_data())["trace_infos"]
    assert len(trace_infos) == 2
    assert trace_infos[0]["trace_id"] == trace_id_1
    assert trace_infos[1]["trace_id"] == trace_id_2


def test_get_trace_handler(mock_get_request_message, mock_tracking_store):
    """GetTrace forwards allow_partial=True and serializes info plus spans."""
    trace_id = "test-trace-123"

    get_trace_proto = GetTrace(trace_id=trace_id, allow_partial=True)
    mock_get_request_message.return_value = get_trace_proto

    otel_span = OTelReadableSpan(
        name="test",
        context=build_otel_context(123, 234),
        parent=None,
        start_time=100,
        end_time=200,
        attributes={
            "mlflow.spanInputs": json.dumps("inputs"),
            "mlflow.spanOutputs": json.dumps("outputs"),
            "mlflow.spanType": json.dumps("span_type"),
        },
    )
    mock_span = Span(otel_span)

    mock_trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mock_span]),
    )

    mock_tracking_store.get_trace.return_value = mock_trace

    response = _get_trace()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    assert response is not None
    assert response.status_code == 200
    response_data = json.loads(response.get_data())
    assert "trace" in response_data
    trace = response_data["trace"]
    assert trace["trace_info"]["trace_id"] == trace_id
    assert len(trace["spans"]) == 1


def test_get_trace_handler_with_allow_partial_false(mock_get_request_message, mock_tracking_store):
    """allow_partial=False from the proto is forwarded to the store unchanged."""
    trace_id = "test-trace-456"

    get_trace_proto = GetTrace(trace_id=trace_id, allow_partial=False)
    mock_get_request_message.return_value = get_trace_proto

    otel_span = OTelReadableSpan(
        name="test",
        context=build_otel_context(123, 234),
        parent=None,
        start_time=100,
        end_time=200,
        attributes={},
    )
    mock_span = Span(otel_span)

    mock_trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mock_span]),
    )

    mock_tracking_store.get_trace.return_value = mock_trace

    response = _get_trace()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=False)

    assert response is not None
    assert response.status_code == 200
    response_data = json.loads(response.get_data())
    assert "trace" in response_data


def test_get_trace_handler_not_found(mock_get_request_message, mock_tracking_store):
    """A RESOURCE_DOES_NOT_EXIST store error surfaces as an HTTP 404."""
    trace_id = "non-existent-trace"

    get_trace_proto = GetTrace(trace_id=trace_id)
    mock_get_request_message.return_value = get_trace_proto

    mock_tracking_store.get_trace.side_effect = MlflowException(
        f"Trace with ID {trace_id} is not found.",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )

    response = _get_trace()

    # Proto default for the unset allow_partial field is False.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=False)

    assert response is not None
    assert response.status_code == 404
    response_data = json.loads(response.get_data())
    assert "error_code" in response_data
    assert response_data["error_code"] == "RESOURCE_DOES_NOT_EXIST"


def test_get_trace_artifact_handler(mock_tracking_store):
    """The artifact handler serves the trace's span data as a JSON attachment."""
    trace_id = "test-trace-artifact-123"

    otel_span = OTelReadableSpan(
        name="test_span",
        context=build_otel_context(123, 234),
        parent=None,
        start_time=100,
        end_time=200,
        attributes={
            "mlflow.spanInputs": json.dumps({"input": "test_input"}),
            "mlflow.spanOutputs": json.dumps({"output": "test_output"}),
        },
    )
    mock_span = Span(otel_span)

    mock_trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mock_span]),
    )

    mock_tracking_store.get_trace.return_value = mock_trace
    mock_tracking_store.batch_get_traces.return_value = [mock_trace]

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    # Verify the store was called correctly
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    # Verify response headers and status
    assert response is not None
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"


def test_get_trace_artifact_handler_missing_request_id(mock_tracking_store):
    """Omitting the request_id query parameter is a 400 BAD_REQUEST."""
    with app.test_request_context(method="GET"):
        response = get_trace_artifact_handler()

    assert response.status_code == 400
    response_data = json.loads(response.get_data())
    assert "error_code" in response_data
    assert response_data["error_code"] == "BAD_REQUEST"
    assert 'must include the "request_id" query parameter' in response_data["message"]


def test_get_trace_artifact_handler_trace_not_found(mock_tracking_store):
    """A missing trace propagates the store's RESOURCE_DOES_NOT_EXIST as 404."""
    trace_id = "non-existent-trace"
    mock_tracking_store.get_trace.side_effect = MlflowException(
        f"Trace with ID {trace_id} is not found.",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    assert response.status_code == 404
    response_data = json.loads(response.get_data())
    assert "error_code" in response_data
    assert response_data["error_code"] == "RESOURCE_DOES_NOT_EXIST"
    assert f"Trace with ID {trace_id} is not found" in response_data["message"]


def test_get_trace_artifact_handler_fallback_to_batch_get_traces(mock_tracking_store):
    """When get_trace is unimplemented, the handler falls back to batch_get_traces."""
    trace_id = "test-trace-batch-123"

    otel_span = OTelReadableSpan(
        name="test_span_batch",
        context=build_otel_context(456, 789),
        parent=None,
        start_time=100,
        end_time=200,
        attributes={
            "mlflow.spanInputs": json.dumps({"input": "batch_input"}),
            "mlflow.spanOutputs": json.dumps({"output": "batch_output"}),
        },
    )
    mock_span = Span(otel_span)

    mock_trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("2"),
            request_time=1234567890,
            execution_duration=3000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[mock_span]),
    )

    # Simulate get_trace not being implemented
    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    mock_tracking_store.batch_get_traces.return_value = [mock_trace]

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    # Verify both methods were called
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)

    # Verify successful response
    assert response is not None
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"


def test_get_trace_artifact_handler_batch_get_traces_not_found(mock_tracking_store):
    """Fallback path: batch_get_traces returning no trace yields a 404."""
    trace_id = "non-existent-batch-trace"

    # Simulate get_trace not being implemented
    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    # batch_get_traces returns empty list (trace not found)
    mock_tracking_store.batch_get_traces.return_value = []

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    # Verify both methods were called
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)

    # Verify 404 response
    assert response.status_code == 404
    response_data = json.loads(response.get_data())
    assert "error_code" in response_data
    assert response_data["error_code"] == "RESOURCE_DOES_NOT_EXIST"
    assert f"Trace with id={trace_id} not found" in response_data["message"]


def test_get_trace_artifact_handler_fallback_to_artifact_repo(mock_tracking_store):
    trace_id = "test-trace-artifact-repo-123"

    trace_info = TraceInfo(
        trace_id=trace_id,
        trace_location=EntityTraceLocation.from_experiment_id("3"),
        request_time=1234567890,
        execution_duration=4000,
        state=TraceState.OK,
    )

    trace_data = {
        "spans": [
            {
                "name": "artifact_span",
                "context": {"trace_id": trace_id, "span_id": "123"},
                "parent_id": None,
                "start_time": 100,
                "end_time": 200,
                "status_code": "OK",
                "status_message": "",
                "attributes": {},
                "events": [],
            }
        ]
    }

    # Simulate batch_get_traces not being implemented
    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    mock_tracking_store.batch_get_traces.side_effect = MlflowNotImplementedException(
        "batch_get_traces is not implemented"
    )
    mock_tracking_store.get_trace_info.return_value = trace_info

    # Mock the artifact repo
    mock_artifact_repo = mock.MagicMock()
    mock_artifact_repo.download_trace_data.return_value = trace_data

    with mock.patch(
        "mlflow.server.handlers._get_trace_artifact_repo", return_value=mock_artifact_repo
    ):
        with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
            response = get_trace_artifact_handler()

    # Verify the fallback path was taken
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)
mock_tracking_store.get_trace_info.assert_called_once_with(trace_id) 2931 mock_artifact_repo.download_trace_data.assert_called_once() 2932 2933 # Verify successful response 2934 assert response is not None 2935 assert response.status_code == 200 2936 assert response.headers["Content-Disposition"] == "attachment; filename=traces.json" 2937 2938 2939 def test_get_trace_artifact_handler_with_attachment_path(mock_tracking_store): 2940 trace_id = "tr-test-attachment-123" 2941 attachment_id = "a1b2c3d4-e5f6-4890-abcd-ef1234567890" 2942 2943 trace_info = TraceInfo( 2944 trace_id=trace_id, 2945 trace_location=EntityTraceLocation.from_experiment_id("3"), 2946 request_time=1234567890, 2947 execution_duration=4000, 2948 state=TraceState.OK, 2949 ) 2950 2951 mock_tracking_store.get_trace_info.return_value = trace_info 2952 2953 mock_artifact_repo = mock.MagicMock() 2954 mock_artifact_repo.download_trace_attachment.return_value = b"\x89PNG fake image" 2955 2956 with mock.patch( 2957 "mlflow.server.handlers._get_trace_artifact_repo", return_value=mock_artifact_repo 2958 ): 2959 query = {"request_id": trace_id, "path": attachment_id} 2960 with app.test_request_context(method="GET", query_string=query): 2961 response = get_trace_artifact_handler() 2962 2963 mock_tracking_store.get_trace_info.assert_called_once_with(trace_id) 2964 mock_artifact_repo.download_trace_attachment.assert_called_once_with(attachment_id) 2965 assert response.status_code == 200 2966 assert response.headers["Content-Type"] == "application/octet-stream" 2967 assert response.headers["Content-Disposition"] == f"attachment; filename={attachment_id}" 2968 assert response.headers["X-Content-Type-Options"] == "nosniff" 2969 2970 2971 def test_get_trace_artifact_handler_attachment_missing_request_id(): 2972 query = {"path": "a1b2c3d4-e5f6-4890-abcd-ef1234567890"} 2973 with app.test_request_context(method="GET", query_string=query): 2974 response = get_trace_artifact_handler() 2975 assert response.status_code == 400 


def test_get_trace_artifact_handler_attachment_trace_not_found(mock_tracking_store):
    # An attachment request for an unknown trace should produce an HTTP 404.
    mock_tracking_store.get_trace_info.return_value = None

    query = {"request_id": "tr-nonexistent", "path": "a1b2c3d4-e5f6-4890-abcd-ef1234567890"}
    with app.test_request_context(method="GET", query_string=query):
        response = get_trace_artifact_handler()
    assert response.status_code == 404


def test_delete_trace_tag_v2_handler(mock_get_request_message, mock_tracking_store):
    """Test v2 delete_trace_tag handler with request_id parameter.

    Verifies that when the Flask route uses request_id path parameter,
    the _delete_trace_tag handler is called and invokes store.delete_trace_tag().
    """

    request_id = "tr-123v2"
    tag_key = "tk"

    # Create the request message
    request_msg = DeleteTraceTag(key=tag_key)
    mock_get_request_message.return_value = request_msg

    # Call the v2 handler with request_id parameter
    response = _delete_trace_tag(request_id=request_id)

    # Verify the store method was called with correct parameters
    mock_tracking_store.delete_trace_tag.assert_called_once_with(request_id, tag_key)

    assert response is not None
    assert response.status_code == 200


def test_delete_trace_tag_v3_handler(mock_get_request_message, mock_tracking_store):
    """Test v3 delete_trace_tag handler with trace_id parameter.

    Verifies that when the Flask route uses trace_id path parameter,
    the _delete_trace_tag_v3 handler is called and invokes store.delete_trace_tag().
    This is similar to v2 but uses the v3 proto message and route parameter naming.
    """

    trace_id = "tr-v3-456"
    tag_key = "tk"

    # Create the request message with V3
    request_msg = DeleteTraceTagV3(key=tag_key)
    mock_get_request_message.return_value = request_msg

    # Call the v3 handler with trace_id parameter
    response = _delete_trace_tag_v3(trace_id=trace_id)

    # Verify the store method was called with correct parameters
    # Both v2 and v3 call the same store method
    mock_tracking_store.delete_trace_tag.assert_called_once_with(trace_id, tag_key)

    assert response is not None
    assert response.status_code == 200


def test_set_trace_tag_v2_handler(mock_get_request_message, mock_tracking_store):
    """Test v2 set_trace_tag handler with request_id parameter.

    Verifies that when the Flask route uses request_id path parameter,
    the _set_trace_tag handler is called and invokes store.set_trace_tag().
    """
    trace_id = "tr-test-v2-123"
    tag_key = "tk"
    tag_value = "tv"

    # Create the request message
    request_msg = SetTraceTag(key=tag_key, value=tag_value)
    mock_get_request_message.return_value = request_msg

    # Call the v2 handler with request_id parameter
    response = _set_trace_tag(request_id=trace_id)

    # Verify the store method was called with correct parameters
    mock_tracking_store.set_trace_tag.assert_called_once_with(trace_id, tag_key, tag_value)

    # Verify response was created (200 status)
    assert response is not None
    assert response.status_code == 200


def test_set_trace_tag_v3_handler(mock_get_request_message, mock_tracking_store):
    """Test v3 set_trace_tag handler with trace_id parameter.

    Verifies that when the Flask route uses trace_id path parameter,
    the _set_trace_tag_v3 handler is called and invokes store.set_trace_tag().
    This is similar to v2 but uses the v3 proto message and route parameter naming.
    """
    trace_id = "tr-test-v3-456"
    tag_key = "tk"
    tag_value = "tv"

    # Create the request message (v3 version)
    request_msg = SetTraceTagV3(key=tag_key, value=tag_value)
    mock_get_request_message.return_value = request_msg

    # Call the v3 handler with trace_id parameter
    response = _set_trace_tag_v3(trace_id=trace_id)

    # Verify the store method was called with correct parameters
    # Note: Both handlers call the same store method
    mock_tracking_store.set_trace_tag.assert_called_once_with(trace_id, tag_key, tag_value)

    # Verify response was created (200 status)
    assert response is not None
    assert response.status_code == 200


def test_link_prompts_to_trace_handler(mock_get_request_message, mock_tracking_store):
    """Test link_prompts_to_trace handler.

    Verifies that the handler correctly parses the request and calls
    store.link_prompts_to_trace() with the appropriate PromptVersion objects.
    """
    trace_id = "tr-test-123"
    prompt_versions_refs = [
        LinkPromptsToTrace.PromptVersionRef(name="prompt1", version="1"),
        LinkPromptsToTrace.PromptVersionRef(name="prompt2", version="2"),
    ]

    # Create the request message
    request_msg = LinkPromptsToTrace(trace_id=trace_id, prompt_versions=prompt_versions_refs)
    mock_get_request_message.return_value = request_msg

    # Call the handler
    response = _link_prompts_to_trace()

    # Verify the store method was called with correct parameters
    # The handler should convert PromptVersionRef to PromptVersion objects
    call_args = mock_tracking_store.link_prompts_to_trace.call_args
    assert call_args[1]["trace_id"] == trace_id

    prompt_versions = call_args[1]["prompt_versions"]
    assert len(prompt_versions) == 2
    assert isinstance(prompt_versions[0], PromptVersion)
    assert prompt_versions[0].name == "prompt1"
    # Note: the proto carries versions as strings; the handler converts to int.
    assert prompt_versions[0].version == 1
    assert isinstance(prompt_versions[1], PromptVersion)
    assert prompt_versions[1].name == "prompt2"
    assert prompt_versions[1].version == 2

    # Verify response was created (200 status)
    assert response is not None
    assert response.status_code == 200


def test_list_providers():
    # The supported-providers endpoint should list at least the openai provider.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/supported-providers")
        assert response.status_code == 200
        data = response.get_json()
        assert "providers" in data
        assert isinstance(data["providers"], list)
        assert len(data["providers"]) > 0
        assert "openai" in data["providers"]


def test_list_providers_with_allowed_filter(monkeypatch):
    # MLFLOW_GATEWAY_ALLOWED_PROVIDERS should restrict the provider list to
    # exactly the allowed entries.
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/supported-providers")
        assert response.status_code == 200
        data = response.get_json()
        assert "openai" in data["providers"]
        assert "anthropic" in data["providers"]
        assert "gemini" not in data["providers"]
        assert "bedrock" not in data["providers"]


def test_list_models():
    # Listing models for a specific provider should return a non-empty list.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/supported-models?provider=openai")
        assert response.status_code == 200
        data = response.get_json()
        assert "models" in data
        assert isinstance(data["models"], list)
        assert len(data["models"]) > 0


def test_list_models_all_providers():
    # Omitting the provider parameter should list models across all providers.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/supported-models")
        assert response.status_code == 200
        data = response.get_json()
        assert "models" in data
        assert isinstance(data["models"], list)
        assert len(data["models"]) > 0


def test_get_provider_config():
    # openai's provider config should default to api_key auth.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=openai")
        assert response.status_code == 200
        data = response.get_json()
        assert "auth_modes" in data
        assert "default_mode" in data
        assert data["default_mode"] == "api_key"
        assert len(data["auth_modes"]) >= 1
        api_key_mode = data["auth_modes"][0]
        assert api_key_mode["mode"] == "api_key"


def test_get_provider_config_with_multiple_auth_modes():
    # bedrock exposes several auth modes (access keys, IAM role) with their
    # respective secret/config fields.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=bedrock")
        assert response.status_code == 200
        data = response.get_json()

        assert "auth_modes" in data
        assert data["default_mode"] == "api_key"
        assert len(data["auth_modes"]) >= 2

        access_keys_mode = next(m for m in data["auth_modes"] if m["mode"] == "access_keys")
        assert len(access_keys_mode["secret_fields"]) == 2
        assert any(f["name"] == "aws_secret_access_key" for f in access_keys_mode["secret_fields"])
        assert any(f["name"] == "aws_region_name" for f in access_keys_mode["config_fields"])

        iam_role_mode = next(m for m in data["auth_modes"] if m["mode"] == "iam_role")
        assert any(f["name"] == "aws_role_name" for f in iam_role_mode["config_fields"])


def test_get_provider_config_missing_provider():
    # The provider query parameter is mandatory; expect a 400 without it.
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/provider-config")
        assert response.status_code == 400


def test_get_provider_config_with_allowed_filter(monkeypatch):
    # Provider config for providers outside MLFLOW_GATEWAY_ALLOWED_PROVIDERS
    # should be rejected with a "not allowed" message.
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "anthropic")
    with app.test_client() as c:
        response = c.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=openai")
        assert response.status_code == 400
        data = response.get_json()
        assert "not allowed" in data["message"]

        response = c.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=anthropic")
        assert response.status_code == 200


@pytest.mark.parametrize(
    "invalid_name",
    [
        "invalid name",  # space
        "invalid/name",  # slash
        "invalid?name",  # question mark
        "invalid&name",  # ampersand
        "invalid#name",  # hash
        "invalid@name",  # at sign
        "invalid:name",  # colon
        "日本語",  # unicode (Japanese)
        "naïve",  # unicode (accented)
    ],
)
def test_create_gateway_endpoint_rejects_invalid_name(mock_get_request_message, invalid_name):
    # Endpoint names with URL-unsafe or non-ASCII characters must be rejected
    # with INVALID_PARAMETER_VALUE.
    from mlflow.protos.service_pb2 import CreateGatewayEndpoint
    from mlflow.server.handlers import _create_gateway_endpoint

    request_msg = CreateGatewayEndpoint()
    request_msg.name = invalid_name
    mock_get_request_message.return_value = request_msg

    response = _create_gateway_endpoint()

    assert response.status_code == 400
    response_data = json.loads(response.get_data())
    assert "Invalid endpoint name" in response_data["message"]
    assert response_data["error_code"] == "INVALID_PARAMETER_VALUE"
3250 @pytest.mark.parametrize( 3251 "invalid_name", 3252 [ 3253 "invalid name", # space 3254 "invalid/name", # slash 3255 "invalid?name", # question mark 3256 "invalid&name", # ampersand 3257 ], 3258 ) 3259 def test_update_gateway_endpoint_rejects_invalid_name(mock_get_request_message, invalid_name): 3260 from mlflow.protos.service_pb2 import UpdateGatewayEndpoint 3261 from mlflow.server.handlers import _update_gateway_endpoint 3262 3263 request_msg = UpdateGatewayEndpoint() 3264 request_msg.endpoint_id = "test-endpoint-id" 3265 request_msg.name = invalid_name 3266 mock_get_request_message.return_value = request_msg 3267 3268 response = _update_gateway_endpoint() 3269 3270 assert response.status_code == 400 3271 response_data = json.loads(response.get_data()) 3272 assert "Invalid endpoint name" in response_data["message"] 3273 assert response_data["error_code"] == "INVALID_PARAMETER_VALUE" 3274 3275 3276 def test_get_gateway_endpoint_by_endpoint_id(mock_get_request_message, mock_tracking_store): 3277 request_msg = GetGatewayEndpoint() 3278 request_msg.endpoint_id = "ep-123" 3279 mock_get_request_message.return_value = request_msg 3280 3281 mock_endpoint = mock.MagicMock() 3282 mock_endpoint.to_proto.return_value = GatewayEndpoint(endpoint_id="ep-123") 3283 mock_tracking_store.get_gateway_endpoint.return_value = mock_endpoint 3284 3285 response = _get_gateway_endpoint() 3286 3287 mock_tracking_store.get_gateway_endpoint.assert_called_once_with( 3288 endpoint_id="ep-123", name=None 3289 ) 3290 assert response.status_code == 200 3291 3292 3293 def test_get_gateway_endpoint_by_name(mock_get_request_message, mock_tracking_store): 3294 3295 request_msg = GetGatewayEndpoint() 3296 request_msg.name = "my-endpoint" 3297 mock_get_request_message.return_value = request_msg 3298 3299 mock_endpoint = mock.MagicMock() 3300 mock_endpoint.to_proto.return_value = GatewayEndpoint(endpoint_id="ep-456", name="my-endpoint") 3301 mock_tracking_store.get_gateway_endpoint.return_value = 
mock_endpoint 3302 3303 response = _get_gateway_endpoint() 3304 3305 mock_tracking_store.get_gateway_endpoint.assert_called_once_with( 3306 endpoint_id=None, name="my-endpoint" 3307 ) 3308 assert response.status_code == 200 3309 3310 3311 def test_query_trace_metrics_handler(mock_get_request_message, mock_tracking_store): 3312 experiment_ids = ["exp1", "exp2"] 3313 metric_name = "latency" 3314 3315 # Create aggregation protos 3316 aggregations_proto = [ 3317 MetricAggregation(aggregation_type=AggregationType.AVG).to_proto(), 3318 MetricAggregation( 3319 aggregation_type=AggregationType.PERCENTILE, percentile_value=95.0 3320 ).to_proto(), 3321 ] 3322 3323 # Create the request message 3324 request_msg = QueryTraceMetrics( 3325 experiment_ids=experiment_ids, 3326 view_type=MetricViewType.TRACES.to_proto(), 3327 metric_name=metric_name, 3328 aggregations=aggregations_proto, 3329 dimensions=["status", "model"], 3330 filters=["status = 'OK'"], 3331 time_interval_seconds=3600, 3332 start_time_ms=1000000, 3333 end_time_ms=2000000, 3334 max_results=100, 3335 page_token="token123", 3336 ) 3337 mock_get_request_message.return_value = request_msg 3338 3339 # Create mock result 3340 mock_data_points = [ 3341 MetricDataPoint( 3342 metric_name="latency", 3343 dimensions={"status": "OK", "model": "gpt-4"}, 3344 values={"AVG": 150.5, "P95.0": 200.0}, 3345 ), 3346 MetricDataPoint( 3347 metric_name="latency", 3348 dimensions={"status": "ERROR", "model": "gpt-4"}, 3349 values={"AVG": 50.0, "P95.0": 75.0}, 3350 ), 3351 ] 3352 3353 # Create a mock result object with next_page_token attribute 3354 mock_result = mock.MagicMock() 3355 mock_result.__iter__ = mock.MagicMock(return_value=iter(mock_data_points)) 3356 mock_result.token = "next_token" 3357 mock_tracking_store.query_trace_metrics.return_value = mock_result 3358 3359 # Call the handler 3360 response = _query_trace_metrics() 3361 3362 mock_tracking_store.query_trace_metrics.assert_called_once_with( 3363 
experiment_ids=experiment_ids, 3364 view_type=MetricViewType.TRACES, 3365 metric_name=metric_name, 3366 aggregations=[ 3367 MetricAggregation(aggregation_type=AggregationType.AVG), 3368 MetricAggregation(aggregation_type=AggregationType.PERCENTILE, percentile_value=95.0), 3369 ], 3370 dimensions=["status", "model"], 3371 filters=["status = 'OK'"], 3372 time_interval_seconds=3600, 3373 start_time_ms=1000000, 3374 end_time_ms=2000000, 3375 max_results=100, 3376 page_token="token123", 3377 ) 3378 3379 assert response is not None 3380 assert response.status_code == 200 3381 response_data = json.loads(response.get_data()) 3382 assert "data_points" in response_data 3383 assert len(response_data["data_points"]) == 2 3384 assert response_data["data_points"][0] == asdict(mock_data_points[0]) 3385 assert response_data["data_points"][1] == asdict(mock_data_points[1]) 3386 assert response_data["next_page_token"] == "next_token" 3387 3388 3389 def test_query_trace_metrics_handler_empty_result(mock_get_request_message, mock_tracking_store): 3390 request_msg = QueryTraceMetrics( 3391 experiment_ids=["exp1"], 3392 view_type=MetricViewType.TRACES.to_proto(), 3393 metric_name="latency", 3394 aggregations=[MetricAggregation(aggregation_type=AggregationType.AVG).to_proto()], 3395 ) 3396 mock_get_request_message.return_value = request_msg 3397 3398 mock_result = mock.MagicMock() 3399 mock_result.__iter__ = mock.MagicMock(return_value=iter([])) 3400 mock_result.token = None 3401 mock_tracking_store.query_trace_metrics.return_value = mock_result 3402 3403 response = _query_trace_metrics() 3404 3405 mock_tracking_store.query_trace_metrics.assert_called_once_with( 3406 experiment_ids=["exp1"], 3407 view_type=MetricViewType.TRACES, 3408 metric_name="latency", 3409 aggregations=[MetricAggregation(aggregation_type=AggregationType.AVG)], 3410 dimensions=None, 3411 filters=None, 3412 time_interval_seconds=None, 3413 start_time_ms=None, 3414 end_time_ms=None, 3415 
max_results=MAX_RESULTS_QUERY_TRACE_METRICS, 3416 page_token=None, 3417 ) 3418 3419 assert response is not None 3420 assert response.status_code == 200 3421 response_data = json.loads(response.get_data()) 3422 assert response_data == {} 3423 3424 3425 def test_invoke_scorer_missing_experiment_id(): 3426 with app.test_client() as c: 3427 response = c.post( 3428 "/ajax-api/3.0/mlflow/scorer/invoke", 3429 json={"serialized_scorer": "test", "trace_ids": ["trace1"]}, 3430 ) 3431 assert response.status_code == 400 3432 data = response.get_json() 3433 assert "experiment_id" in data["message"] 3434 3435 3436 def test_invoke_scorer_missing_serialized_scorer(): 3437 with app.test_client() as c: 3438 response = c.post( 3439 "/ajax-api/3.0/mlflow/scorer/invoke", 3440 json={"experiment_id": "123", "trace_ids": ["trace1"]}, 3441 ) 3442 assert response.status_code == 400 3443 data = response.get_json() 3444 assert "serialized_scorer" in data["message"] 3445 3446 3447 def test_invoke_scorer_missing_trace_ids(): 3448 with app.test_client() as c: 3449 response = c.post( 3450 "/ajax-api/3.0/mlflow/scorer/invoke", 3451 json={"experiment_id": "123", "serialized_scorer": "test"}, 3452 ) 3453 assert response.status_code == 400 3454 data = response.get_json() 3455 assert "Please select at least one trace to evaluate" in data["message"] 3456 3457 3458 def test_invoke_scorer_submits_jobs(mock_tracking_store): 3459 serialized_scorer = json.dumps({ 3460 "name": "test_judge", 3461 "aggregations": [], 3462 "description": None, 3463 "is_session_level_scorer": False, 3464 "mlflow_version": mlflow.__version__, 3465 "serialization_version": 1, 3466 "builtin_scorer_class": None, 3467 "builtin_scorer_pydantic_data": None, 3468 "call_source": None, 3469 "call_signature": None, 3470 "original_func_name": None, 3471 "instructions_judge_pydantic_data": { 3472 "instructions": "Test: {{ inputs }}", 3473 "model": "openai:/gpt-4", 3474 "feedback_value_type": { 3475 "enum": ["Yes", "No"], 3476 "title": 
"Result", 3477 "type": "string", 3478 }, 3479 }, 3480 }) 3481 3482 with mock.patch("mlflow.server.jobs.submit_job") as mock_submit: 3483 mock_job = mock.MagicMock() 3484 mock_job.job_id = "test-job-123" 3485 mock_submit.return_value = mock_job 3486 3487 with app.test_client() as c: 3488 response = c.post( 3489 "/ajax-api/3.0/mlflow/scorer/invoke", 3490 json={ 3491 "experiment_id": "exp-123", 3492 "serialized_scorer": serialized_scorer, 3493 "trace_ids": ["trace1", "trace2"], 3494 }, 3495 ) 3496 assert response.status_code == 200 3497 data = response.get_json() 3498 assert "jobs" in data 3499 assert len(data["jobs"]) == 1 3500 assert data["jobs"][0]["job_id"] == "test-job-123" 3501 assert data["jobs"][0]["trace_ids"] == ["trace1", "trace2"] 3502 3503 mock_submit.assert_called_once() 3504 3505 3506 def test_get_ui_telemetry_handler( 3507 test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check 3508 ): 3509 config = { 3510 "disable_telemetry": False, 3511 "disable_ui_telemetry": False, 3512 "disable_ui_events": ["event1", "event2"], 3513 "ui_rollout_percentage": 50, 3514 } 3515 3516 with mock.patch( 3517 "mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config 3518 ) as mock_fetch: 3519 response = get_ui_telemetry_handler() 3520 3521 assert response is not None 3522 assert response.status_code == 200 3523 3524 response_data = json.loads(response.get_data()) 3525 3526 assert response_data["disable_ui_telemetry"] is False 3527 assert response_data["disable_ui_events"] == ["event1", "event2"] 3528 # rollout percent gets converted to a float as that is the proto definition 3529 assert response_data["ui_rollout_percentage"] == 50.0 3530 assert "config" in mock_telemetry_config_cache 3531 assert mock_fetch.call_count == 1 3532 mock_fetch.reset_mock() 3533 3534 # subsequent call should hit cache 3535 response = get_ui_telemetry_handler() 3536 mock_fetch.assert_not_called() 3537 assert response_data["disable_ui_telemetry"] is False 3538 assert 
response_data["disable_ui_events"] == ["event1", "event2"] 3539 assert response_data["ui_rollout_percentage"] == 50.0 3540 3541 3542 def test_get_ui_telemetry_handler_disabled_by_config( 3543 test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check 3544 ): 3545 config = { 3546 "disable_telemetry": True, 3547 "disable_ui_telemetry": False, 3548 "disable_ui_events": [], 3549 "ui_rollout_percentage": 0, 3550 } 3551 3552 with mock.patch( 3553 "mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config 3554 ) as mock_fetch: 3555 response = get_ui_telemetry_handler() 3556 assert response is not None 3557 assert response.status_code == 200 3558 response_data = json.loads(response.get_data()) 3559 3560 # if disable_telemetry is True, the server should always report 3561 # that UI telemetry is disabled regardless of disable_ui_telemetry 3562 assert response_data["disable_ui_telemetry"] is True 3563 assert response_data["ui_rollout_percentage"] == 0.0 3564 assert response_data["disable_ui_events"] == [] 3565 assert mock_fetch.call_count == 1 3566 3567 3568 def test_get_ui_telemetry_handler_disabled_by_env( 3569 test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check, monkeypatch 3570 ): 3571 monkeypatch.setenv("DO_NOT_TRACK", "true") 3572 with mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config") as mock_fetch: 3573 response = get_ui_telemetry_handler() 3574 assert response is not None 3575 assert response.status_code == 200 3576 response_data = json.loads(response.get_data()) 3577 3578 # if telemetry is disabled by env var, the server should always report 3579 # that UI telemetry is disabled, and no config fetch should happen 3580 mock_fetch.assert_not_called() 3581 assert response_data["disable_ui_telemetry"] is True 3582 assert response_data["ui_rollout_percentage"] == 0.0 3583 assert response_data["disable_ui_events"] == [] 3584 3585 3586 def test_get_ui_telemetry_handler_fallback_values( 3587 test_app_context, 
mock_telemetry_config_cache, bypass_telemetry_env_check 3588 ): 3589 config_without_ui_fields = { 3590 "disable_telemetry": False, 3591 "rollout_percentage": 100, 3592 } 3593 3594 # test fallback values if we forget to define UI config fields 3595 with mock.patch("requests.get", return_value=config_without_ui_fields): 3596 response = get_ui_telemetry_handler() 3597 3598 assert response is not None 3599 assert response.status_code == 200 3600 3601 response_data = json.loads(response.get_data()) 3602 3603 assert response_data["disable_ui_telemetry"] is True 3604 assert response_data["ui_rollout_percentage"] == 0 3605 assert response_data["disable_ui_events"] == [] 3606 3607 # test fallback values if we fail to fetch the config 3608 with mock.patch("requests.get", return_value=mock.Mock(status_code=404)): 3609 response = get_ui_telemetry_handler() 3610 3611 assert response.status_code == 200 3612 3613 response_data = json.loads(response.get_data()) 3614 assert response_data["disable_ui_telemetry"] is True 3615 assert response_data["ui_rollout_percentage"] == 0 3616 assert response_data["disable_ui_events"] == [] 3617 3618 3619 def test_post_ui_telemetry_handler_success( 3620 test_app, mock_telemetry_config_cache, bypass_telemetry_env_check 3621 ): 3622 event1 = { 3623 "event_name": "test_event_1", 3624 "timestamp_ns": 1234567890000000, 3625 "params": {"key1": "value1"}, 3626 "installation_id": "install-123", 3627 "session_id": "session-456", 3628 } 3629 3630 event2 = { 3631 "event_name": "test_event_2", 3632 "timestamp_ns": 1234567890000001, 3633 "params": {"key2": "value2"}, 3634 "installation_id": "install-123", 3635 "session_id": "session-456", 3636 } 3637 request = json.dumps({"records": [event1, event2]}) 3638 config = {"disable_ui_telemetry": False, "disable_telemetry": False} 3639 mock_client = mock.MagicMock() 3640 3641 server_install_id = "server-install-789" 3642 with ( 3643 test_app.test_request_context( 3644 "/ui-telemetry", method="POST", data=request, 
content_type="application/json" 3645 ), 3646 mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config), 3647 mock.patch("mlflow.server.handlers.get_telemetry_client", return_value=mock_client), 3648 mock.patch( 3649 "mlflow.server.handlers.get_or_create_installation_id", 3650 return_value=server_install_id, 3651 ), 3652 ): 3653 response = post_ui_telemetry_handler() 3654 3655 assert response is not None 3656 assert response.status_code == 200 3657 3658 response_data = json.loads(response.get_data()) 3659 3660 assert response_data["status"] == "success" 3661 assert mock_client.add_records.call_count == 1 3662 assert mock_client.add_records.call_args[0][0] == [ 3663 Record( 3664 **event1, 3665 duration_ms=0, 3666 status=Status.SUCCESS, 3667 server_installation_id=server_install_id, 3668 ), 3669 Record( 3670 **event2, 3671 duration_ms=0, 3672 status=Status.SUCCESS, 3673 server_installation_id=server_install_id, 3674 ), 3675 ] 3676 3677 3678 def test_post_ui_telemetry_handler_telemetry_disabled_by_config( 3679 test_app, mock_telemetry_config_cache, bypass_telemetry_env_check 3680 ): 3681 event = { 3682 "event_name": "test_event_1", 3683 "timestamp_ns": 1234567890000000, 3684 "params": {"key1": "value1"}, 3685 "installation_id": "install-123", 3686 "session_id": "session-456", 3687 } 3688 3689 request = json.dumps({"records": [event]}) 3690 3691 config = {"disable_ui_telemetry": True} 3692 3693 mock_client = mock.MagicMock() 3694 3695 with ( 3696 test_app.test_request_context( 3697 "/ui-telemetry", method="POST", data=request, content_type="application/json" 3698 ), 3699 mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config), 3700 mock.patch("mlflow.server.handlers.get_telemetry_client", return_value=mock_client), 3701 ): 3702 response = post_ui_telemetry_handler() 3703 3704 assert response is not None 3705 assert response.status_code == 200 3706 3707 response_data = json.loads(response.get_data()) 3708 3709 assert 
response_data["status"] == "disabled" 3710 mock_client.add_record.assert_not_called() 3711 3712 3713 def test_post_ui_telemetry_handler_telemetry_disabled_by_env( 3714 test_app, mock_telemetry_config_cache, bypass_telemetry_env_check, monkeypatch 3715 ): 3716 monkeypatch.setenv("DO_NOT_TRACK", "true") 3717 request = json.dumps({"records": []}) 3718 with ( 3719 test_app.test_request_context( 3720 "/ui-telemetry", method="POST", data=request, content_type="application/json" 3721 ), 3722 mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config") as mock_fetch, 3723 mock.patch("mlflow.server.handlers.get_telemetry_client") as mock_get_client, 3724 ): 3725 response = post_ui_telemetry_handler() 3726 3727 assert response is not None 3728 assert response.status_code == 200 3729 3730 response_data = json.loads(response.get_data()) 3731 3732 assert response_data["status"] == "disabled" 3733 3734 # assert that no fetch happens and no client is retrieved 3735 mock_fetch.assert_not_called() 3736 mock_get_client.assert_not_called() 3737 3738 3739 def test_download_artifact_streams_in_chunks(enable_serve_artifacts, tmp_path): 3740 # Create a test file with binary data larger than the chunk size (2MB + 1000 bytes) 3741 test_file_size = ARTIFACT_STREAM_CHUNK_SIZE * 2 + 1000 3742 test_data = b"x" * test_file_size 3743 3744 artifact_path = "test_model/model.pkl" 3745 test_file = tmp_path / "model.pkl" 3746 test_file.write_bytes(test_data) 3747 3748 with ( 3749 app.test_request_context(method="GET"), 3750 mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo, 3751 mock.patch("mlflow.server.handlers.tempfile.TemporaryDirectory") as mock_tmp_dir, 3752 ): 3753 # Setup mocks 3754 mock_tmp_dir_instance = mock.MagicMock() 3755 mock_tmp_dir_instance.name = str(tmp_path) 3756 mock_tmp_dir.return_value = mock_tmp_dir_instance 3757 3758 mock_artifact_repo = mock.MagicMock() 3759 mock_artifact_repo.download_artifacts.return_value = str(test_file) 3760 
mock_repo.return_value = mock_artifact_repo 3761 3762 # Call the function and capture the response 3763 response = _download_artifact(artifact_path) 3764 3765 # Extract chunks from the response by iterating over its data 3766 response_chunks = list(response.response) 3767 3768 # Verify that data was streamed in chunks, not line by line 3769 # For a 2MB+ binary file, line-by-line would produce many small chunks 3770 # Chunk-based streaming should produce exactly 3 chunks (2*1MB + 1000 bytes) 3771 assert len(response_chunks) == 3, f"Expected 3 chunks, got {len(response_chunks)}" 3772 3773 # Verify chunk sizes 3774 assert len(response_chunks[0]) == ARTIFACT_STREAM_CHUNK_SIZE 3775 assert len(response_chunks[1]) == ARTIFACT_STREAM_CHUNK_SIZE 3776 assert len(response_chunks[2]) == 1000 3777 3778 # Verify that all data is correctly streamed 3779 streamed_data = b"".join(response_chunks) 3780 assert streamed_data == test_data 3781 3782 3783 def test_create_prompt_optimization_job(mock_tracking_store): 3784 mock_job_entity = JobEntity( 3785 job_id="job-123", 3786 creation_time=1234567890, 3787 job_name="optimize_prompts", 3788 params='{"run_id": "run-456"}', 3789 timeout=None, 3790 status=JobStatus.PENDING, 3791 result=None, 3792 retry_count=0, 3793 last_update_time=1234567890, 3794 status_details=None, 3795 ) 3796 3797 mock_run = mock.MagicMock() 3798 mock_run.info.run_id = "run-456" 3799 mock_tracking_store.create_run.return_value = mock_run 3800 3801 mock_dataset = mock.MagicMock() 3802 mock_dataset._to_mlflow_entity.return_value = mock.MagicMock() 3803 3804 with ( 3805 mock.patch("mlflow.server.jobs.submit_job", return_value=mock_job_entity), 3806 mock.patch("mlflow.server.handlers._get_user", return_value="test_user"), 3807 mock.patch( 3808 "mlflow.genai.datasets.get_dataset", return_value=mock_dataset 3809 ) as mock_get_dataset, 3810 ): 3811 with app.test_request_context( 3812 method="POST", 3813 json={ 3814 "experiment_id": "exp-123", 3815 "source_prompt_uri": 
"prompts:/my-prompt/1", 3816 "config": { 3817 "optimizer_type": OPTIMIZER_TYPE_GEPA, 3818 "dataset_id": "dataset-123", 3819 "scorers": ["Correctness", "Safety"], 3820 "optimizer_config_json": '{"reflection_model": "openai:/gpt-4"}', 3821 }, 3822 "tags": [{"key": "env", "value": "test"}], 3823 }, 3824 ): 3825 response = _create_prompt_optimization_job() 3826 3827 mock_get_dataset.assert_called_once_with(dataset_id="dataset-123") 3828 3829 mock_tracking_store.create_run.assert_called_once() 3830 call_kwargs = mock_tracking_store.create_run.call_args[1] 3831 assert call_kwargs["experiment_id"] == "exp-123" 3832 assert call_kwargs["user_id"] == "test_user" 3833 3834 mock_tracking_store.log_batch.assert_called_once() 3835 logged_params = mock_tracking_store.log_batch.call_args[1]["params"] 3836 param_dict = {p.key: p.value for p in logged_params} 3837 assert param_dict["source_prompt_uri"] == "prompts:/my-prompt/1" 3838 assert param_dict["optimizer_type"] == "gepa" 3839 assert param_dict["dataset_id"] == "dataset-123" 3840 assert param_dict["scorer_names"] == '["Correctness", "Safety"]' 3841 3842 response_data = json.loads(response.get_data()) 3843 assert response_data["job"]["job_id"] == "job-123" 3844 assert response_data["job"]["run_id"] == "run-456" 3845 assert response_data["job"]["state"]["status"] == "JOB_STATUS_PENDING" 3846 assert response_data["job"]["experiment_id"] == "exp-123" 3847 assert response_data["job"]["source_prompt_uri"] == "prompts:/my-prompt/1" 3848 3849 3850 def test_create_prompt_optimization_job_zero_shot(mock_tracking_store): 3851 mock_job_entity = JobEntity( 3852 job_id="job-999", 3853 creation_time=1234567890, 3854 job_name="optimize_prompts", 3855 params='{"run_id": "run-999"}', 3856 timeout=None, 3857 status=JobStatus.PENDING, 3858 result=None, 3859 retry_count=0, 3860 last_update_time=1234567890, 3861 status_details=None, 3862 ) 3863 3864 mock_run = mock.MagicMock() 3865 mock_run.info.run_id = "run-999" 3866 
mock_tracking_store.create_run.return_value = mock_run 3867 3868 with ( 3869 mock.patch("mlflow.server.jobs.submit_job", return_value=mock_job_entity), 3870 mock.patch("mlflow.server.handlers._get_user", return_value="test_user"), 3871 ): 3872 with app.test_request_context( 3873 method="POST", 3874 json={ 3875 "experiment_id": "exp-123", 3876 "source_prompt_uri": "prompts:/my-prompt/1", 3877 "config": { 3878 "optimizer_type": OPTIMIZER_TYPE_METAPROMPT, 3879 "scorers": [], # Empty scorers for zero-shot 3880 # No dataset_id - zero-shot optimization 3881 }, 3882 }, 3883 ): 3884 response = _create_prompt_optimization_job() 3885 3886 response_data = json.loads(response.get_data()) 3887 assert response_data["job"]["job_id"] == "job-999" 3888 assert response_data["job"]["run_id"] == "run-999" 3889 assert response_data["job"]["state"]["status"] == "JOB_STATUS_PENDING" 3890 3891 mock_tracking_store.log_batch.assert_called_once() 3892 logged_params = mock_tracking_store.log_batch.call_args[1]["params"] 3893 param_dict = {p.key: p.value for p in logged_params} 3894 assert param_dict["dataset_id"] == "" # Empty string for None 3895 assert param_dict["scorer_names"] == "[]" # Empty list 3896 3897 3898 def test_create_prompt_optimization_job_missing_prompt_uri(mock_tracking_store): 3899 with app.test_request_context( 3900 method="POST", 3901 json={ 3902 "experiment_id": "exp-123", 3903 "config": { 3904 "optimizer_type": 1, 3905 "dataset_id": "dataset-123", 3906 "scorers": ["Correctness"], 3907 }, 3908 }, 3909 ): 3910 response = _create_prompt_optimization_job() 3911 assert response.status_code == 400 3912 json_response = json.loads(response.get_data()) 3913 assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE) 3914 assert "source_prompt_uri" in json_response["message"] 3915 3916 3917 def test_create_prompt_optimization_job_unspecified_optimizer_type(mock_tracking_store): 3918 with app.test_request_context( 3919 method="POST", 3920 json={ 3921 
"experiment_id": "exp-123", 3922 "source_prompt_uri": "prompts:/my-prompt/1", 3923 "config": { 3924 "optimizer_type": OPTIMIZER_TYPE_UNSPECIFIED, 3925 "dataset_id": "dataset-123", 3926 "scorers": ["Correctness"], 3927 }, 3928 }, 3929 ): 3930 response = _create_prompt_optimization_job() 3931 assert response.status_code == 400 3932 json_response = json.loads(response.get_data()) 3933 assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE) 3934 assert "optimizer_type is required" in json_response["message"] 3935 3936 3937 def test_create_prompt_optimization_job_invalid_optimizer_config_json(mock_tracking_store): 3938 with app.test_request_context( 3939 method="POST", 3940 json={ 3941 "experiment_id": "exp-123", 3942 "source_prompt_uri": "prompts:/my-prompt/1", 3943 "config": { 3944 "optimizer_type": 1, 3945 "dataset_id": "dataset-123", 3946 "scorers": ["Correctness"], 3947 "optimizer_config_json": "invalid json {", 3948 }, 3949 }, 3950 ): 3951 response = _create_prompt_optimization_job() 3952 assert response.status_code == 400 3953 json_response = json.loads(response.get_data()) 3954 assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE) 3955 assert "Invalid JSON in optimizer_config_json" in json_response["message"] 3956 3957 3958 def test_create_prompt_optimization_job_missing_experiment_id(mock_tracking_store): 3959 with app.test_request_context( 3960 method="POST", 3961 json={ 3962 "experiment_id": "", # Empty experiment_id 3963 "source_prompt_uri": "prompts:/my-prompt/1", 3964 "config": { 3965 "optimizer_type": 1, 3966 "dataset_id": "dataset-123", 3967 "scorers": ["Correctness"], 3968 }, 3969 }, 3970 ): 3971 response = _create_prompt_optimization_job() 3972 assert response.status_code == 400 3973 json_response = json.loads(response.get_data()) 3974 assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE) 3975 assert "experiment_id is required" in json_response["message"] 3976 3977 3978 def 
test_cancel_prompt_optimization_job(): 3979 mock_job_entity = JobEntity( 3980 job_id="job-123", 3981 creation_time=1234567890, 3982 job_name="optimize_prompts", 3983 params=( 3984 '{"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1", ' 3985 '"run_id": "run-456"}' 3986 ), 3987 timeout=None, 3988 status=JobStatus.CANCELED, 3989 result=None, 3990 retry_count=0, 3991 last_update_time=1234567890, 3992 status_details=None, 3993 ) 3994 3995 with ( 3996 mock.patch("mlflow.server.jobs.cancel_job", return_value=mock_job_entity), 3997 mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store, 3998 ): 3999 mock_tracking_store = mock.Mock() 4000 mock_store.return_value = mock_tracking_store 4001 with app.test_request_context(method="POST"): 4002 response = _cancel_prompt_optimization_job("job-123") 4003 4004 # Verify that the underlying run was terminated 4005 mock_tracking_store.update_run_info.assert_called_once() 4006 call_args = mock_tracking_store.update_run_info.call_args 4007 assert call_args.kwargs["run_id"] == "run-456" 4008 assert call_args.kwargs["run_status"] == RunStatus.KILLED 4009 assert call_args.kwargs["run_name"] is None 4010 assert "end_time" in call_args.kwargs 4011 4012 response_data = json.loads(response.get_data()) 4013 assert response_data["job"]["job_id"] == "job-123" 4014 assert response_data["job"]["state"]["status"] == "JOB_STATUS_CANCELED" 4015 assert response_data["job"]["experiment_id"] == "exp-123" 4016 assert response_data["job"]["source_prompt_uri"] == "prompts:/my-prompt/1" 4017 assert response_data["job"]["run_id"] == "run-456" 4018 4019 4020 def test_get_prompt_optimization_job_pending(mock_tracking_store): 4021 mock_job = _create_mock_job(status_name="PENDING") 4022 4023 mock_run = _create_mock_run( 4024 params={ 4025 "source_prompt_uri": "prompts:/my-prompt/1", 4026 "optimizer_type": "gepa", 4027 "dataset_id": "dataset-789", 4028 "scorer_names": '["Correctness"]', 4029 } 4030 ) 4031 
mock_tracking_store.get_run.return_value = mock_run 4032 4033 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4034 with app.test_client() as c: 4035 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4036 assert response.status_code == 200 4037 4038 data = response.get_json() 4039 assert "job" in data 4040 job = data["job"] 4041 assert job["job_id"] == "job-123" 4042 assert job["run_id"] == "run-456" 4043 assert job["experiment_id"] == "exp-123" 4044 assert job["source_prompt_uri"] == "prompts:/my-prompt/1" 4045 assert job["state"]["status"] == "JOB_STATUS_PENDING" 4046 4047 4048 def test_get_prompt_optimization_job_succeeded_with_result(mock_tracking_store): 4049 mock_job = _create_mock_job( 4050 status_name="SUCCEEDED", 4051 result={"optimized_prompt_uri": "prompts:/my-prompt/2"}, 4052 ) 4053 4054 mock_run = _create_mock_run( 4055 params={ 4056 "source_prompt_uri": "prompts:/my-prompt/1", 4057 "optimizer_type": "gepa", 4058 "dataset_id": "dataset-789", 4059 "scorer_names": '["Correctness", "Safety"]', 4060 }, 4061 metrics={ 4062 "initial_eval_score.Correctness": 0.65, 4063 "initial_eval_score.Safety": 0.80, 4064 "final_eval_score.Correctness": 0.89, 4065 "final_eval_score.Safety": 0.95, 4066 }, 4067 ) 4068 mock_tracking_store.get_run.return_value = mock_run 4069 4070 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4071 with app.test_client() as c: 4072 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4073 assert response.status_code == 200 4074 4075 data = response.get_json() 4076 job = data["job"] 4077 assert job["state"]["status"] == "JOB_STATUS_COMPLETED" 4078 assert job["optimized_prompt_uri"] == "prompts:/my-prompt/2" 4079 # Verify metrics are populated from the run 4080 assert job["initial_eval_scores"]["Correctness"] == 0.65 4081 assert job["initial_eval_scores"]["Safety"] == 0.80 4082 assert job["final_eval_scores"]["Correctness"] == 0.89 4083 assert 
job["final_eval_scores"]["Safety"] == 0.95 4084 4085 4086 def test_get_prompt_optimization_job_succeeded_run_fetch_fails(mock_tracking_store): 4087 mock_job = _create_mock_job( 4088 status_name="SUCCEEDED", 4089 result={"optimized_prompt_uri": "prompts:/my-prompt/2"}, 4090 ) 4091 4092 # Simulate run fetch failing (e.g., run not found) 4093 mock_tracking_store.get_run.side_effect = Exception("Run not found") 4094 4095 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4096 with app.test_client() as c: 4097 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4098 assert response.status_code == 200 4099 4100 data = response.get_json() 4101 job = data["job"] 4102 assert job["state"]["status"] == "JOB_STATUS_COMPLETED" 4103 assert job["optimized_prompt_uri"] == "prompts:/my-prompt/2" 4104 # Metrics should not be present when run fetch fails 4105 assert "initial_eval_scores" not in job or job["initial_eval_scores"] == {} 4106 4107 4108 def test_get_prompt_optimization_job_failed_with_error(mock_tracking_store): 4109 mock_job = _create_mock_job( 4110 status_name="FAILED", 4111 result="Optimization failed: Invalid scorer", 4112 ) 4113 4114 mock_run = _create_mock_run() 4115 mock_tracking_store.get_run.return_value = mock_run 4116 4117 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4118 with app.test_client() as c: 4119 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4120 assert response.status_code == 200 4121 4122 data = response.get_json() 4123 job = data["job"] 4124 assert job["state"]["status"] == "JOB_STATUS_FAILED" 4125 assert "Optimization failed" in job["state"]["error_message"] 4126 4127 4128 def test_get_prompt_optimization_job_without_run_id(mock_tracking_store): 4129 mock_job = _create_mock_job( 4130 params={"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1"} 4131 ) 4132 4133 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4134 with 
app.test_client() as c: 4135 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4136 assert response.status_code == 200 4137 data = response.get_json() 4138 job = data["job"] 4139 assert job["job_id"] == "job-123" 4140 assert job["experiment_id"] == "exp-123" 4141 assert "run_id" not in job # run_id is not set 4142 4143 4144 def test_get_prompt_optimization_job_with_progress(mock_tracking_store): 4145 mock_job = _create_mock_job( 4146 status_name="RUNNING", 4147 params={ 4148 "experiment_id": "exp-123", 4149 "prompt_uri": "prompts:/my-prompt/1", 4150 "run_id": "run-456", 4151 "optimizer_config": {"max_metric_calls": 200, "reflection_model": "openai:/gpt-4o"}, 4152 }, 4153 ) 4154 4155 mock_run = _create_mock_run( 4156 params={ 4157 "source_prompt_uri": "prompts:/my-prompt/1", 4158 "optimizer_type": "gepa", 4159 }, 4160 metrics={ 4161 "total_metric_calls": 86, 4162 "eval_score": 0.75, 4163 }, 4164 ) 4165 mock_tracking_store.get_run.return_value = mock_run 4166 4167 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4168 with app.test_client() as c: 4169 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4170 assert response.status_code == 200 4171 4172 data = response.get_json() 4173 job = data["job"] 4174 assert job["state"]["status"] == "JOB_STATUS_IN_PROGRESS" 4175 # Progress should be 86 / 200 = 0.43 4176 assert job["state"]["metadata"]["progress"] == "0.43" 4177 4178 4179 def test_get_prompt_optimization_job_progress_capped_at_one(mock_tracking_store): 4180 mock_job = _create_mock_job( 4181 status_name="RUNNING", 4182 params={ 4183 "experiment_id": "exp-123", 4184 "prompt_uri": "prompts:/my-prompt/1", 4185 "run_id": "run-456", 4186 "optimizer_config": {"max_metric_calls": 100, "reflection_model": "openai:/gpt-4o"}, 4187 }, 4188 ) 4189 4190 mock_run = _create_mock_run( 4191 metrics={ 4192 "total_metric_calls": 150, # Exceeds max_metric_calls 4193 }, 4194 ) 4195 
mock_tracking_store.get_run.return_value = mock_run 4196 4197 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4198 with app.test_client() as c: 4199 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4200 assert response.status_code == 200 4201 4202 data = response.get_json() 4203 job = data["job"] 4204 # Progress should be capped at 1.0, not 1.5 4205 assert job["state"]["metadata"]["progress"] == "1.0" 4206 4207 4208 def test_get_prompt_optimization_job_no_progress_without_max_metric_calls(mock_tracking_store): 4209 mock_job = _create_mock_job( 4210 status_name="RUNNING", 4211 params={ 4212 "experiment_id": "exp-123", 4213 "prompt_uri": "prompts:/my-prompt/1", 4214 "run_id": "run-456", 4215 "optimizer_config": {"reflection_model": "openai:/gpt-4o"}, 4216 }, 4217 ) 4218 4219 mock_run = _create_mock_run( 4220 metrics={ 4221 "total_metric_calls": 50, 4222 }, 4223 ) 4224 mock_tracking_store.get_run.return_value = mock_run 4225 4226 with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job): 4227 with app.test_client() as c: 4228 response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4229 assert response.status_code == 200 4230 4231 data = response.get_json() 4232 job = data["job"] 4233 # Progress should NOT be set when max_metric_calls is not configured 4234 assert "progress" not in job["state"].get("status_details", {}) 4235 4236 4237 def test_search_prompt_optimization_jobs_returns_multiple_jobs(mock_job_store): 4238 mock_jobs = [ 4239 _create_mock_job( 4240 job_id="job-1", 4241 status_name="SUCCEEDED", 4242 result={"optimized_prompt_uri": "prompts:/opt/1"}, 4243 ), 4244 _create_mock_job(job_id="job-2", status_name="RUNNING"), 4245 _create_mock_job(job_id="job-3", status_name="PENDING"), 4246 ] 4247 mock_job_store.list_jobs.return_value = iter(mock_jobs) 4248 4249 with app.test_client() as c: 4250 response = c.post( 4251 "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search", 4252 
json={"experiment_id": "exp-123"}, 4253 ) 4254 assert response.status_code == 200 4255 4256 data = response.get_json() 4257 assert "jobs" in data 4258 assert len(data["jobs"]) == 3 4259 4260 job_ids = [job["job_id"] for job in data["jobs"]] 4261 assert "job-1" in job_ids 4262 assert "job-2" in job_ids 4263 assert "job-3" in job_ids 4264 4265 mock_job_store.list_jobs.assert_called_once_with( 4266 job_name="optimize_prompts", 4267 params={"experiment_id": "exp-123"}, 4268 ) 4269 4270 4271 def test_search_prompt_optimization_jobs_returns_empty_list(mock_job_store): 4272 mock_job_store.list_jobs.return_value = iter([]) 4273 4274 with app.test_client() as c: 4275 response = c.post( 4276 "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search", 4277 json={"experiment_id": "exp-456"}, 4278 ) 4279 assert response.status_code == 200 4280 4281 data = response.get_json() 4282 assert data.get("jobs", []) == [] 4283 4284 4285 def test_search_prompt_optimization_jobs_missing_experiment_id(): 4286 with app.test_client() as c: 4287 response = c.post( 4288 "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search", 4289 json={}, 4290 ) 4291 assert response.status_code == 400 4292 4293 4294 def test_search_prompt_optimization_jobs_includes_succeeded_job_result(mock_job_store): 4295 mock_job = _create_mock_job( 4296 job_id="job-1", 4297 status_name="SUCCEEDED", 4298 result={"optimized_prompt_uri": "prompts:/optimized/1"}, 4299 ) 4300 mock_job_store.list_jobs.return_value = iter([mock_job]) 4301 4302 with app.test_client() as c: 4303 response = c.post( 4304 "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search", 4305 json={"experiment_id": "exp-123"}, 4306 ) 4307 assert response.status_code == 200 4308 4309 data = response.get_json() 4310 assert len(data["jobs"]) == 1 4311 assert data["jobs"][0]["optimized_prompt_uri"] == "prompts:/optimized/1" 4312 4313 4314 def test_search_prompt_optimization_jobs_includes_failed_job_error(mock_job_store): 4315 mock_job = _create_mock_job( 4316 
job_id="job-1", 4317 status_name="FAILED", 4318 result="Some error occurred", 4319 ) 4320 mock_job_store.list_jobs.return_value = iter([mock_job]) 4321 4322 with app.test_client() as c: 4323 response = c.post( 4324 "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search", 4325 json={"experiment_id": "exp-123"}, 4326 ) 4327 assert response.status_code == 200 4328 4329 data = response.get_json() 4330 assert len(data["jobs"]) == 1 4331 assert "Some error occurred" in data["jobs"][0]["state"]["error_message"] 4332 4333 4334 def test_delete_prompt_optimization_job_success(mock_job_store, mock_tracking_store): 4335 mock_job = _create_mock_job( 4336 status_name="SUCCEEDED", 4337 result={"optimized_prompt_uri": "prompts:/optimized/1"}, 4338 ) 4339 mock_job_store.get_job.return_value = mock_job 4340 mock_tracking_store.get_run.return_value = mock.MagicMock() # Run exists 4341 4342 with app.test_client() as c: 4343 response = c.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4344 assert response.status_code == 200 4345 4346 mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"]) 4347 mock_tracking_store.get_run.assert_called_once_with("run-456") 4348 mock_tracking_store.delete_run.assert_called_once_with("run-456") 4349 4350 4351 def test_delete_prompt_optimization_job_without_run_id(mock_job_store, mock_tracking_store): 4352 mock_job = _create_mock_job( 4353 params={"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1"} 4354 ) 4355 mock_job_store.get_job.return_value = mock_job 4356 4357 with app.test_client() as c: 4358 response = c.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4359 assert response.status_code == 200 4360 4361 mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"]) 4362 mock_tracking_store.delete_run.assert_not_called() 4363 4364 4365 def test_delete_prompt_optimization_job_skips_run_deletion_when_run_not_found( 4366 mock_job_store, mock_tracking_store 4367 ): 4368 mock_job = 
_create_mock_job( 4369 status_name="SUCCEEDED", 4370 result={"optimized_prompt_uri": "prompts:/optimized/1"}, 4371 ) 4372 mock_job_store.get_job.return_value = mock_job 4373 mock_tracking_store.get_run.side_effect = MlflowException("Run not found") 4374 4375 with app.test_client() as c: 4376 response = c.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123") 4377 assert response.status_code == 200 4378 4379 mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"]) 4380 # delete_run should not be called since run doesn't exist 4381 mock_tracking_store.delete_run.assert_not_called() 4382 4383 4384 def test_get_workspace_scoped_repo_path_if_enabled_allows_legacy_default_artifacts(monkeypatch): 4385 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4386 with WorkspaceContext(DEFAULT_WORKSPACE_NAME): 4387 assert ( 4388 _get_workspace_scoped_repo_path_if_enabled("1/legacy/artifact") == "1/legacy/artifact" 4389 ) 4390 4391 4392 def test_get_workspace_scoped_repo_path_if_enabled_still_scopes_non_default(monkeypatch): 4393 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4394 with WorkspaceContext("team-blue"): 4395 scoped = _get_workspace_scoped_repo_path_if_enabled("2/new/artifact") 4396 assert scoped.startswith("workspaces/team-blue/2/new/artifact") 4397 4398 4399 def test_get_workspace_scoped_repo_path_if_enabled_prevents_cross_workspace_access(monkeypatch): 4400 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4401 4402 with WorkspaceContext("team-a"): 4403 with pytest.raises(MlflowException, match="targets workspace 'team-b'"): 4404 _get_workspace_scoped_repo_path_if_enabled("workspaces/team-b/secret.txt") 4405 4406 with pytest.raises(MlflowException, match="targets workspace 'other'"): 4407 _get_workspace_scoped_repo_path_if_enabled("workspaces/other/data/model.pkl") 4408 4409 4410 def test_get_workspace_scoped_repo_path_if_enabled_rejects_empty_workspace_in_path(monkeypatch): 4411 
monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4412 4413 with WorkspaceContext("team-a"): 4414 with pytest.raises(MlflowException, match="must include a workspace name"): 4415 _get_workspace_scoped_repo_path_if_enabled("workspaces/") 4416 4417 with pytest.raises(MlflowException, match="must include a workspace name"): 4418 _get_workspace_scoped_repo_path_if_enabled("workspaces//data.txt") 4419 4420 4421 def test_get_workspace_scoped_repo_path_if_enabled_allows_matching_workspace_prefix(monkeypatch): 4422 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4423 4424 with WorkspaceContext("team-a"): 4425 result = _get_workspace_scoped_repo_path_if_enabled("workspaces/team-a/data.txt") 4426 assert result == "workspaces/team-a/data.txt" 4427 4428 result = _get_workspace_scoped_repo_path_if_enabled("/workspaces/team-a/nested/path") 4429 assert result == "workspaces/team-a/nested/path" 4430 4431 4432 def test_get_workspace_scoped_repo_path_if_enabled_default_workspace_cross_access_blocked( 4433 monkeypatch, 4434 ): 4435 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4436 4437 with WorkspaceContext(DEFAULT_WORKSPACE_NAME): 4438 result = _get_workspace_scoped_repo_path_if_enabled("legacy/artifact.txt") 4439 assert result == "legacy/artifact.txt" 4440 4441 with pytest.raises(MlflowException, match="targets workspace 'team-b'"): 4442 _get_workspace_scoped_repo_path_if_enabled("workspaces/team-b/data.txt") 4443 4444 4445 def test_get_workspace_scoped_repo_path_if_enabled_requires_active_workspace(monkeypatch): 4446 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4447 4448 with pytest.raises(MlflowException, match="Active workspace is required"): 4449 _get_workspace_scoped_repo_path_if_enabled("some/path") 4450 4451 4452 def test_get_artifact_handler_applies_workspace_scoping(monkeypatch): 4453 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true") 4454 monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true") 4455 
monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket") 4456 4457 mock_run = mock.MagicMock() 4458 mock_run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts" 4459 4460 mock_artifact_repo = mock.MagicMock() 4461 mock_artifact_repo.download_artifacts.return_value = "/tmp/artifact.txt" 4462 4463 with ( 4464 mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store, 4465 mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo, 4466 mock.patch("mlflow.server.handlers._send_artifact") as mock_send, 4467 ): 4468 mock_store.return_value.get_run.return_value = mock_run 4469 mock_repo.return_value = mock_artifact_repo 4470 4471 with WorkspaceContext("team-blue"): 4472 with app.test_request_context( 4473 method="GET", query_string={"run_id": "run1", "path": "model/weights.bin"} 4474 ): 4475 get_artifact_handler() 4476 4477 mock_send.assert_called_once() 4478 artifact_path = mock_send.call_args[0][1] 4479 assert artifact_path.startswith("workspaces/team-blue/") 4480 4481 4482 def test_get_artifact_handler_no_scoping_when_workspaces_disabled(monkeypatch): 4483 monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false") 4484 monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true") 4485 monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket") 4486 4487 mock_run = mock.MagicMock() 4488 mock_run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts" 4489 4490 mock_artifact_repo = mock.MagicMock() 4491 4492 with ( 4493 mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store, 4494 mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo, 4495 mock.patch("mlflow.server.handlers._send_artifact") as mock_send, 4496 ): 4497 mock_store.return_value.get_run.return_value = mock_run 4498 mock_repo.return_value = mock_artifact_repo 4499 4500 with app.test_request_context( 4501 method="GET", query_string={"run_id": "run1", "path": "model/weights.bin"} 4502 ): 4503 
            get_artifact_handler()

        mock_send.assert_called_once()
        artifact_path = mock_send.call_args[0][1]
        assert not artifact_path.startswith("workspaces/")


def test_get_model_version_artifact_handler_applies_workspace_scoping(monkeypatch):
    """Model-version artifact downloads must be rooted under the active workspace prefix."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    mock_artifact_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_model_registry_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers._send_artifact") as mock_send,
    ):
        mock_store.return_value.get_model_version_download_uri.return_value = (
            "mlflow-artifacts:/models/MyModel/1"
        )
        mock_repo.return_value = mock_artifact_repo

        with WorkspaceContext("team-red"):
            with app.test_request_context(
                method="GET", query_string={"name": "MyModel", "version": "1", "path": "model.pkl"}
            ):
                get_model_version_artifact_handler()

        # The second positional arg of _send_artifact is the resolved artifact path;
        # it must carry the workspace prefix when workspaces are enabled.
        mock_send.assert_called_once()
        artifact_path = mock_send.call_args[0][1]
        assert artifact_path.startswith("workspaces/team-red/")


def test_get_logged_model_artifact_handler_applies_workspace_scoping(monkeypatch):
    """Logged-model artifact downloads must be rooted under the active workspace prefix."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    mock_logged_model = mock.MagicMock()
    mock_logged_model.artifact_location = "mlflow-artifacts:/exp1/run1/artifacts/model"

    mock_artifact_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers._send_artifact") as mock_send,
    ):
        mock_store.return_value.get_logged_model.return_value = mock_logged_model
        mock_repo.return_value = mock_artifact_repo

        with WorkspaceContext("team-green"):
            with app.test_request_context(
                method="GET", query_string={"artifact_file_path": "MLmodel"}
            ):
                get_logged_model_artifact_handler("model123")

        mock_send.assert_called_once()
        artifact_path = mock_send.call_args[0][1]
        assert artifact_path.startswith("workspaces/team-green/")


def test_upload_artifact_handler_applies_workspace_scoping(monkeypatch):
    """Artifact uploads for a run must be logged under the active workspace prefix."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts"
    mock_run.info.experiment_id = "exp1"
    mock_run.info.run_id = "run1"

    mock_artifact_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
    ):
        mock_store.return_value.get_run.return_value = mock_run
        mock_repo.return_value = mock_artifact_repo

        with WorkspaceContext("team-purple"):
            with app.test_request_context(
                method="POST",
                query_string={"run_uuid": "run1", "path": "output.txt"},
                data=b"test data",
            ):
                upload_artifact_handler()

        # log_artifact(local_file, artifact_path): the destination path (arg index 1)
        # must be workspace-scoped.
        mock_artifact_repo.log_artifact.assert_called_once()
        logged_path = mock_artifact_repo.log_artifact.call_args[0][1]
        assert logged_path.startswith("workspaces/team-purple/")


def test_list_artifacts_for_proxied_run_artifact_root_applies_workspace_scoping(monkeypatch):
    """Proxied artifact listing must query paths under the active workspace prefix."""
    from mlflow.store.artifact.artifact_repo import ArtifactRepository

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    # spec= restricts the mock to the real ArtifactRepository interface so typos
    # in attribute access fail loudly.
    mock_artifact_repo = mock.MagicMock(spec=ArtifactRepository)
    mock_artifact_repo.list_artifacts.return_value = []

    with (
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        WorkspaceContext("team-orange"),
    ):
        mock_repo.return_value = mock_artifact_repo

        _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root="mlflow-artifacts:/exp1/run1/artifacts",
            relative_path="model",
        )

        mock_artifact_repo.list_artifacts.assert_called_once()
        listed_path = mock_artifact_repo.list_artifacts.call_args[0][0]
        assert listed_path.startswith("workspaces/team-orange/")


# ==================== Budget Window Tests ====================


def _make_budget_policy(
    budget_policy_id="bp-test",
    budget_amount=100.0,
    duration=None,
):
    """Build a minimal GatewayBudgetPolicy for tests.

    Defaults to a global, USD, alert-only policy with a 1-day window; pass
    `duration` to override the window length.
    """
    return GatewayBudgetPolicy(
        budget_policy_id=budget_policy_id,
        budget_unit=BudgetUnit.USD,
        budget_amount=budget_amount,
        duration=duration or BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1),
        target_scope=BudgetTargetScope.GLOBAL,
        budget_action=BudgetAction.ALERT,
        created_at=0,
        last_updated_at=0,
    )


def test_list_budget_windows_empty():
    """With no tracked policies the windows endpoint returns an empty list."""
    with (
        app.test_client() as c,
        mock.patch("mlflow.server.handlers.get_budget_tracker") as mock_tracker,
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        mock_tracker.return_value.get_all_windows.return_value = []
        response = c.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert response.status_code == 200
    assert response.json.get("windows", []) == []


def test_list_budget_windows_returns_window_data():
    """A tracked policy's window reports its id, spend, and a consistent time span."""
    tracker = InMemoryBudgetTracker()
    policy = _make_budget_policy(budget_policy_id="bp-1", budget_amount=50.0)
    tracker.refresh_policies([policy])
    tracker.record_cost(12.5)

    with (
        app.test_client() as c,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        response = c.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert response.status_code == 200
    data = response.json
    assert len(data["windows"]) == 1
    window = data["windows"][0]
    assert window["budget_policy_id"] == "bp-1"
    assert window["current_spend"] == 12.5
    # Sanity floor: any real window start must be after the year 2000.
    min_ms = int(datetime(2000, 1, 1, tzinfo=timezone.utc).timestamp() * 1000)
    assert window["window_start_ms"] >= min_ms
    assert window["window_end_ms"] > window["window_start_ms"]
    # Policy uses duration_unit=DAYS, duration_value=1 → exactly 1 day
    assert window["window_end_ms"] - window["window_start_ms"] == 86_400_000


def test_list_budget_windows_multiple_policies():
    """Each policy gets its own window; a recorded cost counts toward all of them."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_budget_policy(budget_policy_id="bp-1", budget_amount=100.0)
    policy2 = _make_budget_policy(budget_policy_id="bp-2", budget_amount=200.0)
    tracker.refresh_policies([policy1, policy2])
    tracker.record_cost(30.0)

    with (
        app.test_client() as c,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        response = c.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert response.status_code == 200
    data = response.json
    policy_ids = {w["budget_policy_id"] for w in data["windows"]}
    assert policy_ids == {"bp-1", "bp-2"}
    windows_by_id = {w["budget_policy_id"]: w for w in data["windows"]}
    assert windows_by_id["bp-1"]["current_spend"] == 30.0
    assert windows_by_id["bp-2"]["current_spend"] == 30.0


def test_list_budget_windows_zero_spend():
    """A policy with no recorded cost reports a current_spend of exactly 0.0."""
    tracker = InMemoryBudgetTracker()
    policy = _make_budget_policy(budget_amount=100.0)
    tracker.refresh_policies([policy])

    with (
        app.test_client() as c,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        response = c.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert response.status_code == 200
    window = response.json["windows"][0]
    assert window["budget_policy_id"] == "bp-test"
    assert window["current_spend"] == 0.0


def test_create_issue_with_all_fields():
    """_create_issue forwards every optional field to the store and echoes them in JSON."""
    request_message = CreateIssue()
    request_message.experiment_id = "exp-123"
    request_message.name = "High latency"
    request_message.description = "API calls are taking too long"
    request_message.status = "pending"
    request_message.source_run_id = "run-123"
    request_message.root_causes.extend(["Database query inefficiency", "Network latency"])
    request_message.categories.extend(["performance", "database"])
    request_message.severity = IssueSeverity.HIGH.value
    request_message.created_by = "user@example.com"

    issue = Issue(
        issue_id="iss-123",
        experiment_id="exp-123",
        name="High latency",
        description="API calls are taking too long",
        status=IssueStatus.PENDING,
        source_run_id="run-123",
        root_causes=["Database query inefficiency", "Network latency"],
        categories=["performance", "database"],
        severity=IssueSeverity.HIGH,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
        created_by="user@example.com",
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.create_issue.return_value = issue

        response = _create_issue()

        mock_store.return_value.create_issue.assert_called_once()
        call_kwargs = mock_store.return_value.create_issue.call_args[1]
        assert call_kwargs["experiment_id"] == "exp-123"
        assert call_kwargs["name"] == "High latency"
        assert call_kwargs["description"] == "API calls are taking too long"
        # status is converted to the enum; severity stays the raw string value.
        assert call_kwargs["status"] == IssueStatus.PENDING
        assert call_kwargs["source_run_id"] == "run-123"
        assert call_kwargs["root_causes"] == ["Database query inefficiency", "Network latency"]
        assert call_kwargs["categories"] == ["performance", "database"]
        assert call_kwargs["severity"] == IssueSeverity.HIGH.value
        assert call_kwargs["created_by"] == "user@example.com"

        json_response = json.loads(response.get_data())
        assert json_response["issue"]["issue_id"] == "iss-123"
        assert json_response["issue"]["root_causes"] == [
            "Database query inefficiency",
            "Network latency",
        ]
        assert json_response["issue"]["categories"] == ["performance", "database"]


def test_create_issue_without_optional_fields():
    """Omitted optional fields reach the store as None (or are absent entirely)."""
    request_message = CreateIssue()
    request_message.experiment_id = "exp-456"
    request_message.name = "Error handling issue"
    request_message.description = "Errors are not being caught properly"

    issue = Issue(
        issue_id="iss-456",
        experiment_id="exp-456",
        name="Error handling issue",
        description="Errors are not being caught properly",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.create_issue.return_value = issue

        response = _create_issue()

        mock_store.return_value.create_issue.assert_called_once()
        call_kwargs = mock_store.return_value.create_issue.call_args[1]
        assert call_kwargs["source_run_id"] is None
        assert call_kwargs["root_causes"] is None
        # severity is dropped from the kwargs entirely when unset.
        assert "severity" not in call_kwargs

        json_response = json.loads(response.get_data())
        assert json_response["issue"]["issue_id"] == "iss-456"


def test_create_issue_with_default_status():
    """When status is unset the handler omits it so the store applies its default."""
    request_message = CreateIssue()
    request_message.experiment_id = "exp-789"
    request_message.name = "Test issue"
    request_message.description = "Test description"

    issue = Issue(
        issue_id="iss-789",
        experiment_id="exp-789",
        name="Test issue",
        description="Test description",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.create_issue.return_value = issue

        _create_issue()

        call_kwargs = mock_store.return_value.create_issue.call_args[1]
        # Status should not be in kwargs when not provided (store uses default)
        assert "status" not in call_kwargs


def test_get_issue():
    """_get_issue fetches by id and serializes enum/list fields into the JSON body."""
    issue = Issue(
        issue_id="iss-get-123",
        experiment_id="exp-123",
        name="Test issue",
        description="Test description",
        status=IssueStatus.RESOLVED,
        severity=IssueSeverity.HIGH,
        root_causes=["Root cause 1"],
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store:
        mock_store.return_value.get_issue.return_value = issue

        with app.test_request_context():
            response = _get_issue("iss-get-123")

        mock_store.return_value.get_issue.assert_called_once_with("iss-get-123")

        json_response = json.loads(response.get_data())
        assert json_response["issue"]["issue_id"] == "iss-get-123"
        assert json_response["issue"]["name"] == "Test issue"
        # Severity is serialized as its lowercase string value.
        assert json_response["issue"]["severity"] == "high"
        assert json_response["issue"]["root_causes"] == ["Root cause 1"]


def test_get_issue_not_found():
    """A missing issue surfaces as an HTTP 404 with the MLflow error payload."""
    with mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store:
        mock_store.return_value.get_issue.side_effect = MlflowException(
            "Issue not found", error_code=RESOURCE_DOES_NOT_EXIST
        )

        with app.test_request_context():
            response = _get_issue("nonexistent-id")

        # The @catch_mlflow_exception decorator catches and returns error as JSON
        assert response.status_code == 404
        json_response = json.loads(response.get_data())
        assert json_response["error_code"] == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
        assert "Issue not found" in json_response["message"]


def test_update_issue():
    """_update_issue forwards updated fields to the store and returns the new state."""
    request_message = UpdateIssue()
    request_message.issue_id = "iss-update-123"
    request_message.name = "Updated issue name"
    request_message.description = "Updated description"
    request_message.status = "resolved"
    request_message.severity = "medium"

    updated_issue = Issue(
        issue_id="iss-update-123",
        experiment_id="exp-123",
        name="Updated issue name",
        description="Updated description",
        status=IssueStatus.RESOLVED,
        severity=IssueSeverity.MEDIUM,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567900,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.update_issue.return_value = updated_issue

        response = _update_issue("iss-update-123")

        mock_store.return_value.update_issue.assert_called_once()
        call_kwargs = mock_store.return_value.update_issue.call_args[1]
        assert call_kwargs["issue_id"] == "iss-update-123"
        assert call_kwargs["name"] == "Updated issue name"
        assert call_kwargs["description"] == "Updated description"
        # status is converted to the enum; severity stays the raw string value.
        assert call_kwargs["status"] == IssueStatus.RESOLVED
        assert call_kwargs["severity"] == IssueSeverity.MEDIUM.value

        json_response = json.loads(response.get_data())
        assert json_response["issue"]["issue_id"] == "iss-update-123"
        assert json_response["issue"]["name"] == "Updated issue name"
        assert json_response["issue"]["severity"] == "medium"


def test_search_issues_all():
    """An empty search request returns all issues and omits unset filter kwargs."""
    request_message = SearchIssues()

    issues = [
        Issue(
            issue_id="iss-1",
            experiment_id="exp-1",
            name="Issue 1",
            description="Description 1",
            status=IssueStatus.PENDING,
            created_timestamp=1234567890,
            last_updated_timestamp=1234567890,
        ),
        Issue(
            issue_id="iss-2",
            experiment_id="exp-1",
            name="Issue 2",
            description="Description 2",
            status=IssueStatus.RESOLVED,
            created_timestamp=1234567891,
            last_updated_timestamp=1234567891,
        ),
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.search_issues.return_value = PagedList(issues, token="next-token")

        response = _search_issues()

        mock_store.return_value.search_issues.assert_called_once()
        call_kwargs = mock_store.return_value.search_issues.call_args[1]
        # max_results not specified in request, so it's not passed to store
        # The store will use its own default parameter value (SEARCH_ISSUES_DEFAULT_MAX_RESULTS)
        assert "max_results" not in call_kwargs
        assert call_kwargs["experiment_id"] is None
        assert call_kwargs["filter_string"] is None

        json_response = json.loads(response.get_data())
        assert len(json_response["issues"]) == 2
        assert json_response["issues"][0]["issue_id"] == "iss-1"
        assert json_response["issues"][1]["issue_id"] == "iss-2"
        assert json_response["next_page_token"] == "next-token"


def test_search_issues_with_filters():
    """Experiment id, filter string, and max_results pass through to the store unchanged."""
    request_message = SearchIssues()
    request_message.experiment_id = "exp-specific"
    request_message.filter_string = "status = 'resolved' AND source_run_id = 'run-specific'"
    request_message.max_results = 50

    issues = [
        Issue(
            issue_id="iss-filtered",
            experiment_id="exp-specific",
            name="Filtered issue",
            description="Description",
            status=IssueStatus.RESOLVED,
            source_run_id="run-specific",
            created_timestamp=1234567890,
            last_updated_timestamp=1234567890,
        ),
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.search_issues.return_value = PagedList(issues, token=None)

        response = _search_issues()

        call_kwargs = mock_store.return_value.search_issues.call_args[1]
        assert call_kwargs["experiment_id"] == "exp-specific"
        assert (
            call_kwargs["filter_string"] == "status = 'resolved' AND source_run_id = 'run-specific'"
        )
        assert call_kwargs["max_results"] == 50

        json_response = json.loads(response.get_data())
        assert len(json_response["issues"]) == 1
        assert json_response["issues"][0]["issue_id"] == "iss-filtered"
        # A None token serializes as an empty next_page_token.
        assert json_response["next_page_token"] == ""


def test_search_issues_with_pagination():
    """page_token and max_results are forwarded; the store's next token is echoed back."""
    request_message = SearchIssues()
    request_message.max_results = 10
    request_message.page_token = "token-123"

    issues = [
        Issue(
            issue_id=f"iss-{i}",
            experiment_id="exp-1",
            name=f"Issue {i}",
            description=f"Description {i}",
            status=IssueStatus.PENDING,
            created_timestamp=1234567890 + i,
            last_updated_timestamp=1234567890 + i,
        )
        for i in range(10)
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.search_issues.return_value = PagedList(issues, token="token-456")

        response = _search_issues()

        call_kwargs = mock_store.return_value.search_issues.call_args[1]
        assert call_kwargs["max_results"] == 10
        assert call_kwargs["page_token"] == "token-123"

        json_response = json.loads(response.get_data())
        assert len(json_response["issues"]) == 10
        assert json_response["next_page_token"] == "token-456"


def test_search_issues_empty_results():
    """No matches yields an empty (or absent) issues list and an empty next token."""
    request_message = SearchIssues()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.search_issues.return_value = PagedList([], token=None)

        response = _search_issues()

        json_response = json.loads(response.get_data())
        assert len(json_response.get("issues", [])) == 0
        assert json_response["next_page_token"] == ""


def test_search_issues_with_trace_count():
    """include_trace_count propagates to the store and trace_count appears per issue."""
    request_message = SearchIssues()
    request_message.include_trace_count = True

    issues = [
        Issue(
            issue_id="iss-1",
            experiment_id="exp-1",
            name="Issue with traces",
            description="Has 2 traces",
            status=IssueStatus.PENDING,
            created_timestamp=1234567890,
            last_updated_timestamp=1234567890,
            trace_count=2,
        ),
        Issue(
            issue_id="iss-2",
            experiment_id="exp-1",
            name="Issue without traces",
            description="Has no traces",
            status=IssueStatus.PENDING,
            created_timestamp=1234567891,
            last_updated_timestamp=1234567891,
            trace_count=0,
        ),
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.search_issues.return_value = PagedList(issues, token=None)

        response = _search_issues()

        call_kwargs = mock_store.return_value.search_issues.call_args[1]
        assert call_kwargs["include_trace_count"] is True

        json_response = json.loads(response.get_data())
        assert len(json_response["issues"]) == 2
        assert json_response["issues"][0]["trace_count"] == 2
        assert json_response["issues"][1]["trace_count"] == 0


def test_create_issue_with_empty_lists():
    """Empty repeated proto fields are normalized to None before hitting the store."""
    request_message = CreateIssue()
    request_message.experiment_id = "exp-123"
    request_message.name = "Test issue"
    request_message.description = "Test description"

    issue = Issue(
        issue_id="iss-empty-lists",
        experiment_id="exp-123",
        name="Test issue",
        description="Test description",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=request_message),
    ):
        mock_store.return_value.create_issue.return_value = issue

        _create_issue()

        call_kwargs = mock_store.return_value.create_issue.call_args[1]
        # Empty lists should be passed as None
        assert call_kwargs["root_causes"] is None


def test_invoke_issue_detection_handler_success(monkeypatch):
    """The invoke endpoint fetches credentials, starts a run, and submits a job
    whose model param is the 'provider:/model' URI."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    mock_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    mock_run_info = mock.MagicMock()
    mock_run_info.run_id = "run-123"
    mock_run = mock.MagicMock()
    mock_run.info = mock_run_info

    request_json = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1", "trace-2"],
        "categories": ["correctness", "safety"],
        "provider": "openai",
        "model": "gpt-4o",
        "secret_id": "secret-123",
    }

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ) as mock_fetch_creds,
        mock.patch("mlflow.server.jobs.submit_job", return_value=mock_job) as mock_submit_job,
        mock.patch("mlflow.start_run", return_value=mock_run),
        mock.patch("mlflow.set_tag"),
        mock.patch("mlflow.end_run"),
        app.test_client() as c,
    ):
        resp = c.post(
            "/ajax-api/3.0/mlflow/issues/invoke",
            json=request_json,
        )
        assert resp.status_code == 200
        json_response = resp.get_json()

        assert json_response["job_id"] == "job-123"
        assert json_response["run_id"] == "run-123"

        mock_fetch_creds.assert_called_once_with(mock_store.return_value, "openai", "secret-123")
        mock_submit_job.assert_called_once()
        call_kwargs = mock_submit_job.call_args.kwargs
        assert call_kwargs["params"]["experiment_id"] == "exp-123"
        assert call_kwargs["params"]["trace_ids"] == ["trace-1", "trace-2"]
        assert call_kwargs["params"]["categories"] == ["correctness", "safety"]
        # provider + model are combined into a model URI for the job.
        assert call_kwargs["params"]["model"] == "openai:/gpt-4o"
        # Fetched credentials are injected into the job's environment.
        assert call_kwargs["extra_envs"] == {"OPENAI_API_KEY": "test-key"}


def test_invoke_issue_detection_handler_with_endpoint(monkeypatch):
    """When endpoint_name is given instead of model, a gateway:/ model URI is used."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    mock_job = JobEntity(
        job_id="job-456",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    mock_run_info = mock.MagicMock()
    mock_run_info.run_id = "run-456"
    mock_run = mock.MagicMock()
    mock_run.info = mock_run_info

    request_json = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1"],
        "categories": ["correctness"],
        "provider": "openai",
        "endpoint_name": "my-endpoint",
        "secret_id": "secret-123",
    }

    with (
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ),
        mock.patch("mlflow.server.jobs.submit_job", return_value=mock_job) as mock_submit_job,
        mock.patch("mlflow.start_run", return_value=mock_run),
        mock.patch("mlflow.set_tag"),
        mock.patch("mlflow.end_run"),
        app.test_client() as c,
    ):
        resp = c.post(
            "/ajax-api/3.0/mlflow/issues/invoke",
            json=request_json,
        )
        assert resp.status_code == 200
        json_response = resp.get_json()

        assert json_response["job_id"] == "job-456"
        assert json_response["run_id"] == "run-456"

        call_kwargs = mock_submit_job.call_args.kwargs
        assert call_kwargs["params"]["model"] == "gateway:/my-endpoint"


def test_invoke_issue_detection_handler_missing_required_params(monkeypatch):
    """Omitting both 'model' and 'endpoint_name' is rejected with an explanatory error."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    request_json = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1"],
        "categories": ["correctness"],
        "provider": "openai",
        # Missing both 'model' and 'endpoint_name'
        "secret_id": "secret-123",
    }

    with (
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ),
        app.test_client() as c,
    ):
        resp = c.post(
            "/ajax-api/3.0/mlflow/issues/invoke",
            json=request_json,
        )
        assert resp.status_code == 500
        json_response = resp.get_json()
        assert (
            "Either 'endpoint_name' or both 'provider' and 'model' must be provided"
            in json_response["message"]
        )


def test_get_job_success(mock_job_store):
    """A succeeded job's JSON-encoded result is decoded into the response body."""
    mock_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.SUCCEEDED,
        result='{"summary": "Found 3 issues", "issues": 3, "total_traces_analyzed": 10}',
        retry_count=0,
        last_update_time=1234567900000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.get_job", return_value=mock_job),
        app.test_client() as c,
    ):
        resp = c.get("/ajax-api/3.0/mlflow/jobs/job-123")
        assert resp.status_code == 200
        json_response = resp.get_json()

        assert json_response["status"] == "SUCCEEDED"
        assert json_response["result"]["summary"] == "Found 3 issues"
        assert json_response["result"]["issues"] == 3
        assert json_response["result"]["total_traces_analyzed"] == 10
        assert json_response["status_details"] is None


def test_get_job_pending(mock_job_store):
    """A pending job reports PENDING status with no result and no status details."""
    mock_job = JobEntity(
        job_id="job-pending",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.get_job", return_value=mock_job),
        app.test_client() as c,
    ):
        resp = c.get("/ajax-api/3.0/mlflow/jobs/job-pending")
        assert resp.status_code == 200
        json_response = resp.get_json()

        assert json_response["status"] == "PENDING"
        assert json_response["result"] is None
        assert json_response["status_details"] is None


def test_cancel_job_success(mock_job_store):
    """PATCHing the cancel endpoint calls cancel_job once and returns CANCELED."""
    mock_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.CANCELED,
        result=None,
        retry_count=0,
        last_update_time=1234567900000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.cancel_job", return_value=mock_job) as mock_cancel,
        app.test_client() as c,
    ):
        resp = c.patch("/ajax-api/3.0/mlflow/jobs/cancel/job-123")
        assert resp.status_code == 200
        json_response = resp.get_json()

        assert json_response["status"] == "CANCELED"
        mock_cancel.assert_called_once_with("job-123")


def test_get_rest_path_respects_static_prefix(monkeypatch):
    """REST and AJAX path builders both honor the optional static URL prefix."""
    # Without prefix, both return bare paths
    assert _get_rest_path("/mlflow/experiments/search") == "/api/2.0/mlflow/experiments/search"
    assert _get_ajax_path("/mlflow/experiments/search") == "/ajax-api/2.0/mlflow/experiments/search"

    # With prefix, both should include the prefix
    monkeypatch.setenv(STATIC_PREFIX_ENV_VAR, "/myapp")
    assert (
        _get_rest_path("/mlflow/experiments/search") == "/myapp/api/2.0/mlflow/experiments/search"
    )
    assert (
        _get_ajax_path("/mlflow/experiments/search")
        == "/myapp/ajax-api/2.0/mlflow/experiments/search"
    )