test_rest_tracking.py
1 """ 2 Integration test which starts a local Tracking Server on an ephemeral port, 3 and ensures we can use the tracking API to communicate with it. 4 """ 5 6 import json 7 import logging 8 import math 9 import os 10 import pathlib 11 import posixpath 12 import subprocess 13 import sys 14 import time 15 import urllib.parse 16 from dataclasses import asdict 17 from io import StringIO 18 from pathlib import Path 19 from unittest import mock 20 21 import flask 22 import pandas as pd 23 import pytest 24 import requests 25 from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan 26 27 import mlflow.experiments 28 import mlflow.pyfunc 29 from mlflow import MlflowClient 30 from mlflow.artifacts import download_artifacts 31 from mlflow.data.pandas_dataset import from_pandas 32 from mlflow.entities import ( 33 Dataset, 34 DatasetInput, 35 FallbackConfig, 36 FallbackStrategy, 37 GatewayEndpointModelConfig, 38 GatewayModelLinkageType, 39 GatewayResourceType, 40 InputTag, 41 IssueSeverity, 42 IssueStatus, 43 Metric, 44 Param, 45 RoutingStrategy, 46 RunInputs, 47 RunTag, 48 Span, 49 SpanEvent, 50 SpanStatusCode, 51 ViewType, 52 ) 53 from mlflow.entities.logged_model_input import LoggedModelInput 54 from mlflow.entities.logged_model_output import LoggedModelOutput 55 from mlflow.entities.logged_model_status import LoggedModelStatus 56 from mlflow.entities.span import SpanAttributeKey 57 from mlflow.entities.trace_data import TraceData 58 from mlflow.entities.trace_info import TraceInfo 59 from mlflow.entities.trace_location import TraceLocation 60 from mlflow.entities.trace_metrics import ( 61 AggregationType, 62 MetricAggregation, 63 MetricViewType, 64 ) 65 from mlflow.entities.trace_state import TraceState 66 from mlflow.entities.trace_status import TraceStatus 67 from mlflow.environment_variables import ( 68 _MLFLOW_GO_STORE_TESTING, 69 MLFLOW_SERVER_GRAPHQL_MAX_ALIASES, 70 MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS, 71 MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT, 72 ) 73 
from mlflow.exceptions import MlflowException, RestException
from mlflow.genai.datasets import (
    add_dataset_to_experiments,
    create_dataset,
    remove_dataset_from_experiments,
)
from mlflow.models import Model
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode
from mlflow.server import handlers
from mlflow.server.fastapi_app import app
from mlflow.server.handlers import initialize_backend_stores
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
from mlflow.tracing.analysis import TraceFilterCorrelationResult
from mlflow.tracing.client import TracingClient
from mlflow.tracing.constant import (
    TRACE_SCHEMA_VERSION_KEY,
    TraceMetricDimensionKey,
    TraceMetricKey,
)
from mlflow.tracing.utils import build_otel_context
from mlflow.utils import mlflow_tags
from mlflow.utils.file_utils import TempDir, path_to_local_file_uri
from mlflow.utils.mlflow_tags import (
    MLFLOW_DATASET_CONTEXT,
    MLFLOW_GIT_COMMIT,
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_PROJECT_ENTRY_POINT,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.os import is_windows
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.time import get_current_time_millis

from tests.helper_functions import get_safe_port
from tests.integration.utils import invoke_cli_runner
from tests.tracking.integration_test_utils import (
    ServerThread,
    _init_server,
    _send_rest_tracking_post_request,
)

_logger = logging.getLogger(__name__)


@pytest.fixture(params=["file", "sqlalchemy"])
def store_type(request):
    """Provides the store type for parameterized tests."""
    if request.param == "file":
        pytest.skip("FileStore is no longer supported.")
    return request.param


@pytest.fixture
def mlflow_client(store_type: str, tmp_path: Path, db_uri: str, monkeypatch):
    """Provides an MLflow Tracking API client pointed at the local tracking server."""
    # Set passphrase for secrets management (required for encryption)
    monkeypatch.setenv(
        "MLFLOW_CRYPTO_KEK_PASSPHRASE", "test-passphrase-at-least-32-characters-long"
    )

    if store_type == "file":
        backend_uri = tmp_path.joinpath("file").as_uri()
    elif store_type == "sqlalchemy":
        backend_uri = db_uri

    # Force-reset backend stores before each test.
    handlers._tracking_store = None
    handlers._model_registry_store = None
    initialize_backend_stores(backend_uri, default_artifact_root=tmp_path.as_uri())

    with ServerThread(app, get_safe_port()) as url:
        yield MlflowClient(url)


@pytest.fixture
def mlflow_client_with_secrets(tmp_path: Path, monkeypatch):
    """Provides an MLflow Tracking API client with fresh database for secrets management.

    Creates a fresh SQLite database for each test to avoid encryption state pollution.
    This is necessary because the KEK encryption state can persist across tests when
    using a shared cached database.
    """
    # Set passphrase for secrets management (required for encryption)
    monkeypatch.setenv(
        "MLFLOW_CRYPTO_KEK_PASSPHRASE", "test-passphrase-at-least-32-characters-long"
    )

    # Create fresh database for this test (not using cached_db)
    backend_uri = f"sqlite:///{tmp_path}/mlflow.db"
    artifact_uri = (tmp_path / "artifacts").as_uri()

    # Initialize the store (which creates tables)
    store = SqlAlchemyStore(backend_uri, artifact_uri)
    store.engine.dispose()

    # Force-reset backend stores before each test
    handlers._tracking_store = None
    handlers._model_registry_store = None
    initialize_backend_stores(backend_uri, default_artifact_root=artifact_uri)

    with ServerThread(app, get_safe_port()) as url:
        yield MlflowClient(url)


@pytest.fixture
def cli_env(mlflow_client):
    """Provides an environment for the MLflow CLI pointed at the local tracking server."""
    return {
        "LC_ALL": "en_US.UTF-8",
        "LANG": "en_US.UTF-8",
        "MLFLOW_TRACKING_URI": mlflow_client.tracking_uri,
    }


def create_experiments(client, names):
    """Create one experiment per name and return the new experiment IDs in order."""
    return [client.create_experiment(n) for n in names]


def test_create_get_search_experiment(mlflow_client):
    """Experiment creation round-trips name/location/tags and search honors view types."""
    experiment_id = mlflow_client.create_experiment(
        "My Experiment",
        artifact_location="my_location",
        tags={"key1": "val1", "key2": "val2"},
    )
    exp = mlflow_client.get_experiment(experiment_id)
    assert exp.name == "My Experiment"
    if is_windows():
        assert exp.artifact_location == pathlib.Path.cwd().joinpath("my_location").as_uri()
    else:
        assert exp.artifact_location == str(pathlib.Path.cwd().joinpath("my_location"))
    assert len(exp.tags) == 2
    assert exp.tags["key1"] == "val1"
    assert exp.tags["key2"] == "val2"

    experiments = mlflow_client.search_experiments()
    assert {e.name for e in experiments} == {"My Experiment", "Default"}
    mlflow_client.delete_experiment(experiment_id)
    assert {e.name for e in mlflow_client.search_experiments()} == {"Default"}
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.ACTIVE_ONLY)} == {
        "Default"
    }
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.DELETED_ONLY)} == {
        "My Experiment"
    }
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.ALL)} == {
        "My Experiment",
        "Default",
    }
    active_exps_paginated = mlflow_client.search_experiments(max_results=1)
    assert {e.name for e in active_exps_paginated} == {"Default"}
    assert active_exps_paginated.token is None

    all_exps_paginated = mlflow_client.search_experiments(max_results=1, view_type=ViewType.ALL)
    first_page_names = {e.name for e in all_exps_paginated}
    all_exps_second_page = mlflow_client.search_experiments(
        max_results=1, view_type=ViewType.ALL, page_token=all_exps_paginated.token
    )
    second_page_names = {e.name for e in all_exps_second_page}
    assert len(first_page_names) == 1
    assert len(second_page_names) == 1
    assert first_page_names.union(second_page_names) == {"Default", "My Experiment"}


def test_create_experiment_validation(mlflow_client):
    """The create-experiment endpoint rejects malformed payloads with HTTP 400."""

    def assert_bad_request(payload, expected_error_message):
        response = _send_rest_tracking_post_request(
            mlflow_client.tracking_uri,
            "/api/2.0/mlflow/experiments/create",
            payload,
        )
        assert response.status_code == 400
        assert expected_error_message in response.text

    assert_bad_request(
        {
            "name": 123,
        },
        "Invalid value 123 for parameter 'name'",
    )
    assert_bad_request({}, "Missing value for required parameter 'name'.")
    assert_bad_request(
        {
            "name": "experiment name",
            "artifact_location": 9.0,
            "tags": [{"key": "key", "value": "value"}],
        },
        "Invalid value 9.0 for parameter 'artifact_location'",
    )
    assert_bad_request(
        {
            "name": "experiment name",
            "artifact_location": "my_location",
            "tags": "5",
        },
        "Invalid value \\\"5\\\" for parameter 'tags'",
    )


def test_delete_restore_experiment(mlflow_client):
    """Deleting and restoring an experiment toggles its lifecycle stage."""
    experiment_id = mlflow_client.create_experiment("Deleterious")
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
    mlflow_client.delete_experiment(experiment_id)
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "deleted"
    mlflow_client.restore_experiment(experiment_id)
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"


def test_delete_restore_experiment_cli(mlflow_client, cli_env):
    """The `mlflow experiments` CLI can create, delete, and restore experiments."""
    experiment_name = "DeleteriousCLI"
    invoke_cli_runner(
        mlflow.experiments.commands,
        ["create", "--experiment-name", experiment_name],
        env=cli_env,
    )
    experiment_id = mlflow_client.get_experiment_by_name(experiment_name).experiment_id
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
    invoke_cli_runner(
        mlflow.experiments.commands, ["delete", "-x", str(experiment_id)], env=cli_env
    )
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "deleted"
    invoke_cli_runner(
        mlflow.experiments.commands, ["restore", "-x", str(experiment_id)], env=cli_env
    )
    assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"


def test_rename_experiment(mlflow_client):
    """Renaming an experiment is reflected on subsequent reads."""
    experiment_id = mlflow_client.create_experiment("BadName")
    assert mlflow_client.get_experiment(experiment_id).name == "BadName"
    mlflow_client.rename_experiment(experiment_id, "GoodName")
    assert mlflow_client.get_experiment(experiment_id).name == "GoodName"


def test_rename_experiment_cli(mlflow_client, cli_env):
    """The `mlflow experiments rename` CLI command renames an experiment."""
    bad_experiment_name = "CLIBadName"
    good_experiment_name = "CLIGoodName"

    invoke_cli_runner(
        mlflow.experiments.commands, ["create", "-n", bad_experiment_name], env=cli_env
    )
    experiment_id = mlflow_client.get_experiment_by_name(bad_experiment_name).experiment_id
    assert mlflow_client.get_experiment(experiment_id).name == bad_experiment_name
    invoke_cli_runner(
        mlflow.experiments.commands,
        [
            "rename",
            "--experiment-id",
            str(experiment_id),
            "--new-name",
            good_experiment_name,
        ],
        env=cli_env,
    )
    assert mlflow_client.get_experiment(experiment_id).name == good_experiment_name


@pytest.mark.parametrize("parent_run_id_kwarg", [None, "my-parent-id"])
def test_create_run_all_args(mlflow_client, parent_run_id_kwarg):
    """create_run persists user-supplied start time, run name, and tags."""
    user = "username"
    source_name = "Hello"
    entry_point = "entry"
    source_version = "abc"
    create_run_kwargs = {
        "start_time": 456,
        "run_name": "my name",
        "tags": {
            MLFLOW_USER: user,
            MLFLOW_SOURCE_TYPE: "LOCAL",
            MLFLOW_SOURCE_NAME: source_name,
            MLFLOW_PROJECT_ENTRY_POINT: entry_point,
            MLFLOW_GIT_COMMIT: source_version,
            # Use the parametrized parent run ID when provided, falling back to "7".
            MLFLOW_PARENT_RUN_ID: parent_run_id_kwarg or "7",
            "my": "tag",
            "other": "tag",
        },
    }
    experiment_id = mlflow_client.create_experiment(
        f"Run A Lot (parent_run_id={parent_run_id_kwarg})"
    )
    created_run = mlflow_client.create_run(experiment_id, **create_run_kwargs)
    run_id = created_run.info.run_id
    _logger.info(f"Run id={run_id}")
    fetched_run = mlflow_client.get_run(run_id)
    for run in [created_run, fetched_run]:
        assert run.info.run_id == run_id
        assert run.info.experiment_id == experiment_id
        assert run.info.user_id == user
        assert run.info.start_time == create_run_kwargs["start_time"]
        assert run.info.run_name == "my name"
        for tag in create_run_kwargs["tags"]:
            assert tag in run.data.tags
        assert run.data.tags.get(MLFLOW_USER) == user
        # BUGFIX: this previously read `== parent_run_id_kwarg or "7"`, which parses
        # as `(x == parent_run_id_kwarg) or "7"` and is always truthy, so the
        # assertion could never fail. Parenthesize the fallback instead.
        assert run.data.tags.get(MLFLOW_PARENT_RUN_ID) == (parent_run_id_kwarg or "7")
    assert [run.info for run in mlflow_client.search_runs([experiment_id])] == [run.info]


def test_create_run_defaults(mlflow_client):
    """create_run without arguments fills in sensible defaults (unknown user)."""
    experiment_id = mlflow_client.create_experiment("Run A Little")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    run = mlflow_client.get_run(run_id)
    assert run.info.run_id == run_id
    assert run.info.experiment_id == experiment_id
    assert run.info.user_id == "unknown"


def test_log_metrics_params_tags(mlflow_client):
    """Metrics (incl. NaN/±Inf), params, and tags round-trip through the server."""
    experiment_id = mlflow_client.create_experiment("Oh My")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    mlflow_client.log_metric(run_id, key="metric", value=123.456, timestamp=789, step=2)
    mlflow_client.log_metric(run_id, key="nan_metric", value=float("nan"))
    mlflow_client.log_metric(run_id, key="inf_metric", value=float("inf"))
    mlflow_client.log_metric(run_id, key="-inf_metric", value=-float("inf"))
    mlflow_client.log_metric(run_id, key="stepless-metric", value=987.654, timestamp=321)
    mlflow_client.log_param(run_id, "param", "value")
    mlflow_client.set_tag(run_id, "taggity", "do-dah")
    run = mlflow_client.get_run(run_id)
    assert run.data.metrics.get("metric") == 123.456
    assert math.isnan(run.data.metrics.get("nan_metric"))
    # Infinities may be clamped to the float64 max by the backend.
    assert run.data.metrics.get("inf_metric") >= 1.7976931348623157e308
    assert run.data.metrics.get("-inf_metric") <= -1.7976931348623157e308
    assert run.data.metrics.get("stepless-metric") == 987.654
    assert run.data.params.get("param") == "value"
    assert run.data.tags.get("taggity") == "do-dah"
    metric_history0 = mlflow_client.get_metric_history(run_id, "metric")
    assert len(metric_history0) == 1
    metric0 = metric_history0[0]
    assert metric0.key == "metric"
    assert metric0.value == 123.456
    assert metric0.timestamp == 789
    assert metric0.step == 2
    metric_history1 = mlflow_client.get_metric_history(run_id, "stepless-metric")
    assert len(metric_history1) == 1
    metric1 = metric_history1[0]
    assert metric1.key == "stepless-metric"
    assert metric1.value == 987.654
    assert metric1.timestamp == 321
    assert metric1.step == 0

    metric_history = mlflow_client.get_metric_history(run_id, "a_test_accuracy")
    assert metric_history == []


def test_log_metric_validation(mlflow_client):
    """The log-metric endpoint rejects malformed payloads with HTTP 400."""
    experiment_id = mlflow_client.create_experiment("metrics validation")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    def assert_bad_request(payload, expected_error_message):
        response = _send_rest_tracking_post_request(
            mlflow_client.tracking_uri,
            "/api/2.0/mlflow/runs/log-metric",
            payload,
        )
        assert response.status_code == 400
        assert expected_error_message in response.text

    assert_bad_request(
        {
            "run_id": 31,
            "key": "metric",
            "value": 41,
            "timestamp": 59,
            "step": 26,
        },
        "Invalid value 31 for parameter 'run_id' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": 31,
            "value": 41,
            "timestamp": 59,
            "step": 26,
        },
        "Invalid value 31 for parameter 'key' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": "foo",
            "value": 31,
            "timestamp": 59,
            "step": "foo",
        },
        "Invalid value \\\"foo\\\" for parameter 'step' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": "foo",
            "value": 31,
            "timestamp": "foo",
            "step": 41,
        },
        "Invalid value \\\"foo\\\" for parameter 'timestamp' supplied",
    )
    assert_bad_request(
        {
            "run_id": None,
            "key": "foo",
            "value": 31,
            "timestamp": 59,
            "step": 41,
        },
        "Missing value for required parameter 'run_id'",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            # Missing key
            "value": 31,
            "timestamp": 59,
            "step": 41,
        },
        "Missing value for required parameter 'key'",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": None,
            "value": 31,
            "timestamp": 59,
            "step": 41,
        },
        "Missing value for required parameter 'key'",
    )


def test_log_metric_model(mlflow_client: MlflowClient):
    """A metric logged with a model_id is attached to the logged model."""
    experiment_id = mlflow_client.create_experiment("metrics validation")
    run = mlflow_client.create_run(experiment_id)
    model = mlflow_client.create_logged_model(experiment_id)
    mlflow_client.log_metric(
        run.info.run_id,
        key="metric",
        value=0.5,
        timestamp=123456789,
        step=1,
        dataset_name="name",
        dataset_digest="digest",
        model_id=model.model_id,
    )

    model = mlflow_client.get_logged_model(model.model_id)
    assert model.metrics == [
        Metric(
            key="metric",
            value=0.5,
            timestamp=123456789,
            step=1,
            model_id=model.model_id,
            dataset_name="name",
            dataset_digest="digest",
            run_id=run.info.run_id,
        )
    ]


def test_log_param_validation(mlflow_client):
    """The log-parameter endpoint rejects malformed payloads with HTTP 400."""
    experiment_id = mlflow_client.create_experiment("params validation")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    def assert_bad_request(payload, expected_error_message):
        response = _send_rest_tracking_post_request(
            mlflow_client.tracking_uri,
            "/api/2.0/mlflow/runs/log-parameter",
            payload,
        )
        assert response.status_code == 400
        assert expected_error_message in response.text

    assert_bad_request(
        {
            "run_id": 31,
            "key": "param",
            "value": 41,
        },
        "Invalid value 31 for parameter 'run_id' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": 31,
            "value": 41,
        },
        "Invalid value 31 for parameter 'key' supplied",
    )


def test_log_param_with_empty_string_as_value(mlflow_client):
    """An empty string is a legal param value."""
    experiment_id = mlflow_client.create_experiment(
        test_log_param_with_empty_string_as_value.__name__
    )
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    mlflow_client.log_param(run_id, "param_key", "")
    assert {"param_key": ""}.items() <= mlflow_client.get_run(run_id).data.params.items()


def test_set_tag_with_empty_string_as_value(mlflow_client):
    """An empty string is a legal tag value."""
    experiment_id = mlflow_client.create_experiment(
        test_set_tag_with_empty_string_as_value.__name__
    )
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    mlflow_client.set_tag(run_id, "tag_key", "")
    assert {"tag_key": ""}.items() <= mlflow_client.get_run(run_id).data.tags.items()


def test_log_batch_containing_params_and_tags_with_empty_string_values(mlflow_client):
    """log_batch accepts empty-string param and tag values."""
    experiment_id = mlflow_client.create_experiment(
        test_log_batch_containing_params_and_tags_with_empty_string_values.__name__
    )
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    mlflow_client.log_batch(
        run_id=run_id,
        params=[Param("param_key", "")],
        tags=[RunTag("tag_key", "")],
    )
    assert {"param_key": ""}.items() <= mlflow_client.get_run(run_id).data.params.items()
    assert {"tag_key": ""}.items() <= mlflow_client.get_run(run_id).data.tags.items()


def test_set_tag_validation(mlflow_client):
    """The set-tag endpoint rejects malformed payloads and accepts legacy run_uuid."""
    experiment_id = mlflow_client.create_experiment("tags validation")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    def assert_bad_request(payload, expected_error_message):
        response = _send_rest_tracking_post_request(
            mlflow_client.tracking_uri,
            "/api/2.0/mlflow/runs/set-tag",
            payload,
        )
        assert response.status_code == 400
        assert expected_error_message in response.text

    assert_bad_request(
        {
            "run_id": 31,
            "key": "tag",
            "value": 41,
        },
        "Invalid value 31 for parameter 'run_id' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            "key": "param",
            "value": 41,
        },
        "Invalid value 41 for parameter 'value' supplied",
    )
    assert_bad_request(
        {
            "run_id": run_id,
            # Missing key
            "value": "value",
        },
        "Missing value for required parameter 'key'",
    )

    # The legacy `run_uuid` field is still accepted in place of `run_id`.
    response = _send_rest_tracking_post_request(
        mlflow_client.tracking_uri,
        "/api/2.0/mlflow/runs/set-tag",
        {
            "run_uuid": run_id,
            "key": "key",
            "value": "value",
        },
    )
    assert response.status_code == 200


def test_path_validation(mlflow_client):
    """Artifact endpoints reject path traversal (e.g. `../`) with HTTP 400."""
    experiment_id = mlflow_client.create_experiment("tags validation")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    invalid_path = "../path"

    def assert_response(resp):
        assert resp.status_code == 400
        # BUGFIX: previously read `response.json()`, which referenced the outer
        # `response` variable via closure instead of this function's `resp`
        # parameter — correct only by coincidence of call order.
        body = resp.json()
        assert body["error_code"] == "INVALID_PARAMETER_VALUE"
        assert body["message"] == "Invalid path"

    response = requests.get(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/artifacts/list",
        params={"run_id": run_id, "path": invalid_path},
    )
    assert_response(response)

    response = requests.get(
        f"{mlflow_client.tracking_uri}/get-artifact",
        params={"run_id": run_id, "path": invalid_path},
    )
    assert_response(response)

    response = requests.get(
        f"{mlflow_client.tracking_uri}//model-versions/get-artifact",
        params={"name": "model", "version": 1, "path": invalid_path},
    )
    assert_response(response)


def test_set_experiment_tag(mlflow_client):
    """Experiment tags can be set, updated, and are isolated per experiment."""
    experiment_id = mlflow_client.create_experiment("SetExperimentTagTest")
    mlflow_client.set_experiment_tag(experiment_id, "dataset", "imagenet1K")
    experiment = mlflow_client.get_experiment(experiment_id)
    assert "dataset" in experiment.tags
    assert experiment.tags["dataset"] == "imagenet1K"
    # test that updating a tag works
    mlflow_client.set_experiment_tag(experiment_id, "dataset", "birdbike")
    experiment = mlflow_client.get_experiment(experiment_id)
    assert "dataset" in experiment.tags
    assert experiment.tags["dataset"] == "birdbike"
    # test that setting a tag on 1 experiment does not impact another experiment.
    experiment_id_2 = mlflow_client.create_experiment("SetExperimentTagTest2")
    experiment2 = mlflow_client.get_experiment(experiment_id_2)
    assert len(experiment2.tags) == 0
    # test that setting a tag on different experiments maintain different values across experiments
    mlflow_client.set_experiment_tag(experiment_id_2, "dataset", "birds200")
    experiment = mlflow_client.get_experiment(experiment_id)
    experiment2 = mlflow_client.get_experiment(experiment_id_2)
    assert "dataset" in experiment.tags
    assert experiment.tags["dataset"] == "birdbike"
    assert "dataset" in experiment2.tags
    assert experiment2.tags["dataset"] == "birds200"
    # test can set multi-line tags
    mlflow_client.set_experiment_tag(experiment_id, "multiline tag", "value2\nvalue2\nvalue2")
    experiment = mlflow_client.get_experiment(experiment_id)
    assert "multiline tag" in experiment.tags
    assert experiment.tags["multiline tag"] == "value2\nvalue2\nvalue2"


def test_set_experiment_tag_with_empty_string_as_value(mlflow_client):
    """An empty string is a legal experiment tag value."""
    experiment_id = mlflow_client.create_experiment(
        test_set_experiment_tag_with_empty_string_as_value.__name__
    )
    mlflow_client.set_experiment_tag(experiment_id, "tag_key", "")
    assert {"tag_key": ""}.items() <= mlflow_client.get_experiment(experiment_id).tags.items()


def test_delete_experiment_tag(mlflow_client):
    """delete_experiment_tag removes a previously-set experiment tag."""
    experiment_id = mlflow_client.create_experiment("DeleteExperimentTagTest")
    mlflow_client.set_experiment_tag(experiment_id, "dataset", "imagenet1K")
    experiment = mlflow_client.get_experiment(experiment_id)
    assert experiment.tags["dataset"] == "imagenet1K"
    # test that deleting a tag works
    mlflow_client.delete_experiment_tag(experiment_id, "dataset")
    experiment = mlflow_client.get_experiment(experiment_id)
    assert "dataset" not in experiment.tags


def test_delete_tag(mlflow_client):
    """delete_tag removes run tags and raises for unknown runs/tags/deleted runs."""
    experiment_id = mlflow_client.create_experiment("DeleteTagExperiment")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    mlflow_client.log_metric(run_id, key="metric", value=123.456, timestamp=789, step=2)
    mlflow_client.log_metric(run_id, key="stepless-metric", value=987.654, timestamp=321)
    mlflow_client.log_param(run_id, "param", "value")
    mlflow_client.set_tag(run_id, "taggity", "do-dah")
    run = mlflow_client.get_run(run_id)
    assert "taggity" in run.data.tags
    assert run.data.tags["taggity"] == "do-dah"
    mlflow_client.delete_tag(run_id, "taggity")
    run = mlflow_client.get_run(run_id)
    assert "taggity" not in run.data.tags
    with pytest.raises(MlflowException, match=r"Run .+ not found"):
        mlflow_client.delete_tag("fake_run_id", "taggity")
    with pytest.raises(MlflowException, match="No tag with name: fakeTag"):
        mlflow_client.delete_tag(run_id, "fakeTag")
    mlflow_client.delete_run(run_id)
    with pytest.raises(MlflowException, match=f"The run {run_id} must be in"):
        mlflow_client.delete_tag(run_id, "taggity")


def test_log_batch(mlflow_client):
    """log_batch records metrics, params, and tags in a single request."""
    experiment_id = mlflow_client.create_experiment("Batch em up")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    mlflow_client.log_batch(
        run_id=run_id,
        metrics=[Metric("metric", 123.456, 789, 3)],
        params=[Param("param", "value")],
        tags=[RunTag("taggity", "do-dah")],
    )
    run = mlflow_client.get_run(run_id)
    assert run.data.metrics.get("metric") == 123.456
    assert run.data.params.get("param") == "value"
    assert run.data.tags.get("taggity") == "do-dah"
    metric_history = mlflow_client.get_metric_history(run_id, "metric")
    assert len(metric_history) == 1
    metric = metric_history[0]
    assert metric.key == "metric"
    assert metric.value == 123.456
    assert metric.timestamp == 789
    assert metric.step == 3


def test_log_batch_validation(mlflow_client):
    """The log-batch endpoint validates field types and required metric fields."""
    experiment_id = mlflow_client.create_experiment("log_batch validation")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    def assert_bad_request(payload, expected_error_message):
        response = _send_rest_tracking_post_request(
            mlflow_client.tracking_uri,
            "/api/2.0/mlflow/runs/log-batch",
            payload,
        )
        assert response.status_code == 400
        assert expected_error_message in response.text

    for request_parameter in ["metrics", "params", "tags"]:
        assert_bad_request(
            {
                "run_id": run_id,
                request_parameter: "foo",
            },
            f"Invalid value \\\"foo\\\" for parameter '{request_parameter}' supplied",
        )

    ## Should 400 if missing timestamp
    assert_bad_request(
        {"run_id": run_id, "metrics": [{"key": "mae", "value": 2.5}]},
        "Missing value for required parameter 'metrics[0].timestamp'",
    )

    ## Should 200 if timestamp provided but step is not
    response = _send_rest_tracking_post_request(
        mlflow_client.tracking_uri,
        "/api/2.0/mlflow/runs/log-batch",
        {
            "run_id": run_id,
            "metrics": [{"key": "mae", "value": 2.5, "timestamp": 123456789}],
        },
    )

    assert response.status_code == 200


@pytest.mark.xfail(reason="Tracking server does not support logged-model endpoints yet")
@pytest.mark.allow_infer_pip_requirements_fallback
def test_log_model(mlflow_client):
    """Logged models accumulate in the mlflow.log-model.history run tag."""
    experiment_id = mlflow_client.create_experiment("Log models")
    with TempDir(chdr=True):
        model_paths = [f"model/path/{i}" for i in range(3)]
        mlflow.set_tracking_uri(mlflow_client.tracking_uri)
        with mlflow.start_run(experiment_id=experiment_id) as run:
            for i, m in enumerate(model_paths):
                mlflow.pyfunc.log_model(name=m, loader_module="mlflow.pyfunc")
                mlflow.pyfunc.save_model(
                    m,
                    mlflow_model=Model(artifact_path=m, run_id=run.info.run_id),
                    loader_module="mlflow.pyfunc",
                )
                model = Model.load(os.path.join(m, "MLmodel"))
                run = mlflow.get_run(run.info.run_id)
                tag = run.data.tags["mlflow.log-model.history"]
                models = json.loads(tag)
                model.utc_time_created = models[i]["utc_time_created"]

                history_model_meta = models[i].copy()
                original_model_uuid = history_model_meta.pop("model_uuid")
                model_meta = model.get_tags_dict().copy()
                new_model_uuid = model_meta.pop("model_uuid")
                assert history_model_meta == model_meta
                assert original_model_uuid != new_model_uuid
                assert len(models) == i + 1
                for j in range(0, i + 1):
                    assert models[j]["artifact_path"] == model_paths[j]


def test_set_terminated_defaults(mlflow_client):
    """set_terminated without a status finishes the run with an end time."""
    experiment_id = mlflow_client.create_experiment("Terminator 1")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    assert mlflow_client.get_run(run_id).info.status == "RUNNING"
    assert mlflow_client.get_run(run_id).info.end_time is None
    mlflow_client.set_terminated(run_id)
    assert mlflow_client.get_run(run_id).info.status == "FINISHED"
    assert mlflow_client.get_run(run_id).info.end_time <= get_current_time_millis()


def test_set_terminated_status(mlflow_client):
    """set_terminated with an explicit status records that status."""
    experiment_id = mlflow_client.create_experiment("Terminator 2")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    assert mlflow_client.get_run(run_id).info.status == "RUNNING"
    assert mlflow_client.get_run(run_id).info.end_time is None
    mlflow_client.set_terminated(run_id, "FAILED")
    assert mlflow_client.get_run(run_id).info.status == "FAILED"
    assert mlflow_client.get_run(run_id).info.end_time <= get_current_time_millis()


def test_artifacts(mlflow_client, tmp_path):
    """Artifacts can be logged, listed, and downloaded via the tracking server."""
    experiment_id = mlflow_client.create_experiment("Art In Fact")
    experiment_info = mlflow_client.get_experiment(experiment_id)
    assert experiment_info.artifact_location.startswith(path_to_local_file_uri(str(tmp_path)))
    artifact_path = urllib.parse.urlparse(experiment_info.artifact_location).path
    assert posixpath.split(artifact_path)[-1] == experiment_id

    created_run = mlflow_client.create_run(experiment_id)
    assert created_run.info.artifact_uri.startswith(experiment_info.artifact_location)
    run_id = created_run.info.run_id
    src_dir = tmp_path.joinpath("test_artifacts_src")
    src_dir.mkdir()
    src_file = os.path.join(src_dir, "my.file")
    with open(src_file, "w") as f:
        f.write("Hello, World!")
    mlflow_client.log_artifact(run_id, src_file, None)
    mlflow_client.log_artifacts(run_id, src_dir, "dir")

    root_artifacts_list = mlflow_client.list_artifacts(run_id)
    assert {a.path for a in root_artifacts_list} == {"my.file", "dir"}

    dir_artifacts_list = mlflow_client.list_artifacts(run_id, "dir")
    assert {a.path for a in dir_artifacts_list} == {"dir/my.file"}

    all_artifacts = download_artifacts(
        run_id=run_id, artifact_path=".", tracking_uri=mlflow_client.tracking_uri
    )
    with open(f"{all_artifacts}/my.file") as f:
        assert f.read() == "Hello, World!"
    with open(f"{all_artifacts}/dir/my.file") as f:
        assert f.read() == "Hello, World!"

    dir_artifacts = download_artifacts(
        run_id=run_id, artifact_path="dir", tracking_uri=mlflow_client.tracking_uri
    )
    with open(f"{dir_artifacts}/my.file") as f:
        assert f.read() == "Hello, World!"
915 916 917 def test_search_pagination(mlflow_client): 918 experiment_id = mlflow_client.create_experiment("search_pagination") 919 runs = [mlflow_client.create_run(experiment_id, start_time=1).info.run_id for _ in range(0, 10)] 920 runs = sorted(runs) 921 result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=None) 922 assert [r.info.run_id for r in result] == runs[0:4] 923 assert result.token is not None 924 result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=result.token) 925 assert [r.info.run_id for r in result] == runs[4:8] 926 assert result.token is not None 927 result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=result.token) 928 assert [r.info.run_id for r in result] == runs[8:] 929 assert result.token is None 930 931 932 def test_search_validation(mlflow_client): 933 experiment_id = mlflow_client.create_experiment("search_validation") 934 with pytest.raises( 935 MlflowException, 936 match=r"Invalid value 123456789 for parameter 'max_results' supplied", 937 ): 938 mlflow_client.search_runs([experiment_id], max_results=123456789) 939 940 941 def test_get_experiment_by_name(mlflow_client): 942 name = "test_get_experiment_by_name" 943 experiment_id = mlflow_client.create_experiment(name) 944 res = mlflow_client.get_experiment_by_name(name) 945 assert res.experiment_id == experiment_id 946 assert res.name == name 947 assert mlflow_client.get_experiment_by_name("idontexist") is None 948 949 950 def test_get_experiment(mlflow_client): 951 name = "test_get_experiment" 952 experiment_id = mlflow_client.create_experiment(name) 953 res = mlflow_client.get_experiment(experiment_id) 954 assert res.experiment_id == experiment_id 955 assert res.name == name 956 957 958 def test_search_experiments(mlflow_client): 959 # To ensure the default experiment and non-default experiments have different creation_time 960 # for deterministic search results, send a request to the server and initialize the 
tracking 961 # store. 962 assert mlflow_client.search_experiments()[0].name == "Default" 963 964 experiments = [ 965 ("a", {"key": "value"}), 966 ("ab", {"key": "vaLue"}), 967 ("Abc", None), 968 ] 969 experiment_ids = [] 970 for name, tags in experiments: 971 # sleep for windows file system current_time precision in Python to enforce 972 # deterministic ordering based on last_update_time (creation_time due to no 973 # mutation of experiment state) 974 time.sleep(0.001) 975 experiment_ids.append(mlflow_client.create_experiment(name, tags=tags)) 976 977 # filter_string 978 experiments = mlflow_client.search_experiments(filter_string="attribute.name = 'a'") 979 assert [e.name for e in experiments] == ["a"] 980 experiments = mlflow_client.search_experiments(filter_string="attribute.name != 'a'") 981 assert [e.name for e in experiments] == ["Abc", "ab", "Default"] 982 experiments = mlflow_client.search_experiments(filter_string="name LIKE 'a%'") 983 assert [e.name for e in experiments] == ["ab", "a"] 984 experiments = mlflow_client.search_experiments(filter_string="tag.key = 'value'") 985 assert [e.name for e in experiments] == ["a"] 986 experiments = mlflow_client.search_experiments(filter_string="tag.key != 'value'") 987 assert [e.name for e in experiments] == ["ab"] 988 experiments = mlflow_client.search_experiments(filter_string="tag.key ILIKE '%alu%'") 989 assert [e.name for e in experiments] == ["ab", "a"] 990 991 # order_by 992 experiments = mlflow_client.search_experiments(order_by=["name DESC"]) 993 assert [e.name for e in experiments] == ["ab", "a", "Default", "Abc"] 994 995 # max_results 996 experiments = mlflow_client.search_experiments(max_results=2) 997 assert [e.name for e in experiments] == ["Abc", "ab"] 998 # page_token 999 experiments = mlflow_client.search_experiments(page_token=experiments.token) 1000 assert [e.name for e in experiments] == ["a", "Default"] 1001 1002 # view_type 1003 time.sleep(0.001) 1004 
mlflow_client.delete_experiment(experiment_ids[1]) 1005 experiments = mlflow_client.search_experiments(view_type=ViewType.ACTIVE_ONLY) 1006 assert [e.name for e in experiments] == ["Abc", "a", "Default"] 1007 experiments = mlflow_client.search_experiments(view_type=ViewType.DELETED_ONLY) 1008 assert [e.name for e in experiments] == ["ab"] 1009 experiments = mlflow_client.search_experiments(view_type=ViewType.ALL) 1010 assert [e.name for e in experiments] == ["Abc", "ab", "a", "Default"] 1011 1012 1013 def test_get_metric_history_bulk_rejects_invalid_requests(mlflow_client): 1014 def assert_response(resp, message_part): 1015 assert resp.status_code == 400 1016 response_json = resp.json() 1017 assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE" 1018 assert message_part in response_json.get("message", "") 1019 1020 response_no_run_ids_field = requests.get( 1021 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk", 1022 params={"metric_key": "key"}, 1023 ) 1024 assert_response( 1025 response_no_run_ids_field, 1026 "GetMetricHistoryBulk request must specify at least one run_id", 1027 ) 1028 1029 response_empty_run_ids = requests.get( 1030 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk", 1031 params={"run_id": [], "metric_key": "key"}, 1032 ) 1033 assert_response( 1034 response_empty_run_ids, 1035 "GetMetricHistoryBulk request must specify at least one run_id", 1036 ) 1037 1038 response_too_many_run_ids = requests.get( 1039 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk", 1040 params={"run_id": [f"id_{i}" for i in range(1000)], "metric_key": "key"}, 1041 ) 1042 assert_response( 1043 response_too_many_run_ids, 1044 "GetMetricHistoryBulk request cannot specify more than", 1045 ) 1046 1047 response_no_metric_key_field = requests.get( 1048 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk", 1049 params={"run_id": ["123"]}, 1050 ) 1051 assert_response( 
        response_no_metric_key_field,
        "GetMetricHistoryBulk request must specify a metric_key",
    )


def test_get_metric_history_bulk_returns_expected_metrics_in_expected_order(
    mlflow_client,
):
    """get-history-bulk returns metrics for every requested run, grouped by run_id."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    created_run1 = mlflow_client.create_run(experiment_id)
    run_id1 = created_run1.info.run_id
    created_run2 = mlflow_client.create_run(experiment_id)
    run_id2 = created_run2.info.run_id
    created_run3 = mlflow_client.create_run(experiment_id)
    run_id3 = created_run3.info.run_id

    metricA_history = [
        {"key": "metricA", "timestamp": 1, "step": 2, "value": 10.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 11.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 12.0},
        {"key": "metricA", "timestamp": 2, "step": 3, "value": 12.0},
    ]
    # run2 logs the same history shifted by +1.0 so the two runs are distinguishable.
    for metric in metricA_history:
        mlflow_client.log_metric(run_id1, **metric)
        metric_for_run2 = dict(metric)
        metric_for_run2["value"] += 1.0
        mlflow_client.log_metric(run_id2, **metric_for_run2)

    metricB_history = [
        {"key": "metricB", "timestamp": 7, "step": -2, "value": -100.0},
        {"key": "metricB", "timestamp": 8, "step": 0, "value": 0.0},
        {"key": "metricB", "timestamp": 8, "step": 0, "value": 1.0},
        {"key": "metricB", "timestamp": 9, "step": 1, "value": 12.0},
    ]
    for metric in metricB_history:
        mlflow_client.log_metric(run_id1, **metric)
        metric_for_run2 = dict(metric)
        metric_for_run2["value"] += 1.0
        mlflow_client.log_metric(run_id2, **metric_for_run2)

    response_run1_metricA = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1], "metric_key": "metricA"},
    )
    assert response_run1_metricA.status_code == 200
    assert response_run1_metricA.json().get("metrics") == [
        {**metric, "run_id": run_id1} for metric in metricA_history
    ]

    response_run2_metricB = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id2], "metric_key": "metricB"},
    )
    assert response_run2_metricB.status_code == 200
    assert response_run2_metricB.json().get("metrics") == [
        {**metric, "run_id": run_id2, "value": metric["value"] + 1.0} for metric in metricB_history
    ]

    response_run1_run2_metricA = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1, run_id2], "metric_key": "metricA"},
    )
    assert response_run1_run2_metricA.status_code == 200
    # Results are grouped by run_id; the stable sort preserves per-run logging order.
    assert response_run1_run2_metricA.json().get("metrics") == sorted(
        [{**metric, "run_id": run_id1} for metric in metricA_history]
        + [
            {**metric, "run_id": run_id2, "value": metric["value"] + 1.0}
            for metric in metricA_history
        ],
        key=lambda metric: metric["run_id"],
    )

    # run3 never logged metricB, so it contributes nothing to the response.
    response_run1_run2_run_3_metricB = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1, run_id2, run_id3], "metric_key": "metricB"},
    )
    assert response_run1_run2_run_3_metricB.status_code == 200
    assert response_run1_run2_run_3_metricB.json().get("metrics") == sorted(
        [{**metric, "run_id": run_id1} for metric in metricB_history]
        + [
            {**metric, "run_id": run_id2, "value": metric["value"] + 1.0}
            for metric in metricB_history
        ],
        key=lambda metric: metric["run_id"],
    )


def test_get_metric_history_bulk_respects_max_results(mlflow_client):
    """get-history-bulk truncates the returned history to max_results entries."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    run_id = mlflow_client.create_run(experiment_id).info.run_id
    max_results = 2

    metricA_history = [
        {"key": "metricA", "timestamp": 1, "step": 2, "value": 10.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 11.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 12.0},
        {"key": "metricA", "timestamp": 2, "step": 3, "value": 12.0},
    ]
    for metric in metricA_history:
        mlflow_client.log_metric(run_id, **metric)

    response_limited = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={
            "run_id": [run_id],
            "metric_key": "metricA",
            "max_results": max_results,
        },
    )
    assert response_limited.status_code == 200
    assert response_limited.json().get("metrics") == [
        {**metric, "run_id": run_id} for metric in metricA_history[:max_results]
    ]


def test_get_metric_history_bulk_calls_optimized_impl_when_expected(tmp_path):
    """The bulk handler delegates to the store's get_metric_history_bulk with the
    default max_results of 25000."""
    from mlflow.server.handlers import get_metric_history_bulk_handler

    path = path_to_local_file_uri(str(tmp_path.joinpath("sqlalchemy.db")))
    uri = ("sqlite://" if sys.platform == "win32" else "sqlite:////") + path[len("file://") :]
    # Wrap a real SQL store in a Mock so calls pass through but are recorded.
    mock_store = mock.Mock(wraps=SqlAlchemyStore(uri, str(tmp_path)))

    flask_app = flask.Flask("test_flask_app")

    class MockRequestArgs:
        # Minimal stand-in for flask's request.args (only what the handler uses).
        def __init__(self, args_dict):
            self.args_dict = args_dict

        def to_dict(
            self,
            flat,
        ):
            return self.args_dict

        def get(self, key, default=None):
            return self.args_dict.get(key, default)

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store", return_value=mock_store),
        flask_app.test_request_context() as mock_context,
    ):
        run_ids = [str(i) for i in range(10)]
        mock_context.request.args = MockRequestArgs({
            "run_id": run_ids,
            "metric_key": "mock_key",
        })

        get_metric_history_bulk_handler()

        mock_store.get_metric_history_bulk.assert_called_once_with(
            run_ids=run_ids,
            metric_key="mock_key",
            max_results=25000,
        )


def test_get_metric_history_respects_max_results(mlflow_client):
    """get-history honors max_results and returns metrics in logged order."""
    experiment_id = mlflow_client.create_experiment("test max_results")
    run = mlflow_client.create_run(experiment_id)
    run_id = run.info.run_id

    metric_history = [
        {"key": "test_metric", "value": float(i), "step": i, "timestamp": 1000 + i}
        for i in range(5)
    ]
    for metric in metric_history:
        mlflow_client.log_metric(run_id, **metric)

    # Test without max_results - should return all metrics
    all_metrics = mlflow_client.get_metric_history(run_id, "test_metric")
    assert len(all_metrics) == 5

    # Test with max_results=3 - should return only 3 metrics
    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={"run_id": run_id, "metric_key": "test_metric", "max_results": 3},
    )
    assert response.status_code == 200
    response_data = response.json()
    assert len(response_data["metrics"]) == 3

    returned_metrics = response_data["metrics"]
    for i, metric in enumerate(returned_metrics):
        assert metric["key"] == "test_metric"
        assert metric["value"] == float(i)
        # Under the Go store, integer fields arrive JSON-serialized as strings.
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["step"]) == i
        else:
            assert metric["step"] == i


def test_get_metric_history_with_page_token(mlflow_client):
    """Page through a 10-entry metric history in pages of 4 (4 + 4 + 2)."""
    experiment_id = mlflow_client.create_experiment("test page_token")
    run = mlflow_client.create_run(experiment_id)
    run_id = run.info.run_id

    metric_history = [
        {"key": "test_metric", "value": float(i), "step": i, "timestamp": 1000 + i}
        for i in range(10)
    ]
    for metric in metric_history:
        mlflow_client.log_metric(run_id, **metric)

    page_size = 4

    first_response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={
            "run_id": run_id,
            "metric_key": "test_metric",
            "max_results":
                page_size,
        },
    )
    assert first_response.status_code == 200
    first_data = first_response.json()
    first_metrics = first_data["metrics"]
    first_token = first_data.get("next_page_token")

    assert first_token is not None
    assert len(first_metrics) == 4

    second_response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={
            "run_id": run_id,
            "metric_key": "test_metric",
            "max_results": page_size,
            "page_token": first_token,
        },
    )
    assert second_response.status_code == 200
    second_data = second_response.json()
    second_metrics = second_data["metrics"]
    second_token = second_data.get("next_page_token")

    assert second_token is not None
    assert len(second_metrics) == 4

    third_response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={
            "run_id": run_id,
            "metric_key": "test_metric",
            "max_results": page_size,
            "page_token": second_token,
        },
    )
    assert third_response.status_code == 200
    third_data = third_response.json()
    third_metrics = third_data["metrics"]
    third_token = third_data.get("next_page_token")

    # Final page: only 2 of the 10 metrics remain and no further token is issued.
    assert third_token is None
    assert len(third_metrics) == 2

    all_paginated_metrics = first_metrics + second_metrics + third_metrics
    assert len(all_paginated_metrics) == 10

    for i, metric in enumerate(all_paginated_metrics):
        assert metric["key"] == "test_metric"
        assert metric["value"] == float(i)
        # Under the Go store, integer fields arrive JSON-serialized as strings.
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["step"]) == i
        else:
            assert metric["step"] == i
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["timestamp"]) == 1000 + i
        else:
            assert metric["timestamp"] == 1000 + i

    # Test with invalid page_token
    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={
            "run_id": run_id,
            "metric_key": "test_metric",
            "page_token": "invalid_token",
        },
    )
    assert response.status_code == 400
    response_data = response.json()
    assert "INVALID_PARAMETER_VALUE" in response_data.get("error_code", "")


def test_get_metric_history_bulk_interval_rejects_invalid_requests(mlflow_client):
    """get-history-bulk-interval validates run_ids, metric_key, max_results, and step bounds."""

    def assert_response(resp, message_part):
        assert resp.status_code == 400
        response_json = resp.json()
        assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in response_json.get("message", "")

    url = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk-interval"

    assert_response(
        requests.get(url, params={"metric_key": "key"}),
        "Missing value for required parameter 'run_ids'.",
    )

    assert_response(
        requests.get(url, params={"run_ids": [], "metric_key": "key"}),
        "Missing value for required parameter 'run_ids'.",
    )

    assert_response(
        requests.get(
            url,
            params={"run_ids": [f"id_{i}" for i in range(1000)], "metric_key": "key"},
        ),
        "GetMetricHistoryBulkInterval request must specify at most 100 run_ids.",
    )

    assert_response(
        requests.get(url, params={"run_ids": ["123"], "metric_key": "key", "max_results": 0}),
        "max_results must be between 1 and 2500",
    )

    assert_response(
        requests.get(url, params={"run_ids": ["123"], "metric_key": ""}),
        "Missing value for required parameter 'metric_key'",
    )

    assert_response(
        requests.get(url, params={"run_ids": ["123"], "max_results": 5}),
        "Missing value for required parameter 'metric_key'",
    )

    assert_response(
        requests.get(
            url,
            params={
                "run_ids": ["123"],
                "metric_key": "key",
                "start_step": 1,
                "end_step": 0,
                "max_results": 5,
            },
        ),
        "end_step must be greater than start_step. ",
    )

    assert_response(
        requests.get(
            url,
            params={
                "run_ids": ["123"],
                "metric_key": "key",
                "start_step": 1,
                "max_results": 5,
            },
        ),
        "If either start step or end step are specified, both must be specified.",
    )


def test_get_metric_history_bulk_interval_respects_max_results(mlflow_client):
    """get-history-bulk-interval samples steps down to ~max_results per run."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    run_id1 = mlflow_client.create_run(experiment_id).info.run_id
    metric_history = [
        {"key": "metricA", "timestamp": 1, "step": i, "value": 10.0} for i in range(10)
    ]
    for metric in metric_history:
        mlflow_client.log_metric(run_id1, **metric)

    url = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk-interval"
    response_limited = requests.get(
        url,
        params={"run_ids": [run_id1], "metric_key": "metricA", "max_results": 5},
    )
    assert response_limited.status_code == 200
    # Sampling of 10 steps at max_results=5: evens plus the final step.
    expected_steps = [0, 2, 4, 6, 8, 9]
    expected_metrics = [
        {**metric, "run_id": run_id1}
        for metric in metric_history
        if metric["step"] in expected_steps
    ]
    assert response_limited.json().get("metrics") == expected_metrics

    # with start_step and end_step
    response_limited = requests.get(
        url,
        params={
            "run_ids": [run_id1],
            "metric_key": "metricA",
            "start_step": 0,
            "end_step": 4,
            "max_results": 5,
        },
    )
    assert response_limited.status_code == 200
    assert response_limited.json().get("metrics") == [
        {**metric, "run_id": run_id1} for metric in metric_history[:5]
    ]

    # multiple runs
    run_id2 = mlflow_client.create_run(experiment_id).info.run_id
    metric_history2 = [
        {"key": "metricA", "timestamp": 1, "step": i, "value": 10.0} for i in range(20)
    ]
    for metric in metric_history2:
        mlflow_client.log_metric(run_id2, **metric)
    response_limited = requests.get(
        url,
        params={
            "run_ids": [run_id1, run_id2],
            "metric_key": "metricA",
            "max_results": 5,
        },
    )
    # With two runs the sampled step set is shared across both histories.
    expected_steps = [0, 4, 8, 9, 12, 16, 19]
    expected_metrics = []
    for run_id, metric_history in [
        (run_id1, metric_history),
        (run_id2, metric_history2),
    ]:
        expected_metrics.extend([
            {**metric, "run_id": run_id}
            for metric in metric_history
            if metric["step"] in expected_steps
        ])
    assert response_limited.json().get("metrics") == expected_metrics

    # test metrics with same steps
    metric_history_timestamp2 = [
        {"key": "metricA", "timestamp": 2, "step": i, "value": 10.0} for i in range(10)
    ]
    for metric in metric_history_timestamp2:
        mlflow_client.log_metric(run_id1, **metric)

    response_limited = requests.get(
        url,
        params={"run_ids": [run_id1], "metric_key": "metricA", "max_results": 5},
    )
    assert response_limited.status_code == 200
    # Both timestamps (1 and 2) are returned for each sampled step.
    expected_steps = [0, 2, 4, 6, 8, 9]
    expected_metrics = [
        {"key": "metricA", "timestamp": j, "step": i, "value": 10.0, "run_id": run_id1}
        for i in expected_steps
        for j in [1, 2]
    ]
    assert response_limited.json().get("metrics") == expected_metrics


def test_search_dataset_handler_rejects_invalid_requests(mlflow_client):
    """search-datasets requires between one and the server's maximum experiment_ids."""

    def assert_response(resp, message_part):
        assert resp.status_code == 400
        response_json = resp.json()
        assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in response_json.get("message", "")

    response_no_experiment_id_field = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets",
        json={},
    )
    assert_response(
        response_no_experiment_id_field,
        "SearchDatasets request must specify at least one experiment_id.",
    )

    response_empty_experiment_id_field = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets",
        json={"experiment_ids": []},
    )
    assert_response(
        response_empty_experiment_id_field,
        "SearchDatasets request must specify at least one experiment_id.",
    )

    response_too_many_experiment_ids = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets",
        json={"experiment_ids": [f"id_{i}" for i in range(1000)]},
    )
    assert_response(
        response_too_many_experiment_ids,
        "SearchDatasets request cannot specify more than",
    )


def test_search_dataset_handler_returns_expected_results(mlflow_client):
    """search-datasets summarizes logged dataset inputs (name/digest/context) per experiment."""
    experiment_id = mlflow_client.create_experiment("log inputs test")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    dataset1 = Dataset(
        name="name1",
        digest="digest1",
        source_type="source_type1",
        source="source1",
    )
    dataset_inputs1 = [
        DatasetInput(
            dataset=dataset1,
            tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")],
        )
    ]
    mlflow_client.log_inputs(run_id, dataset_inputs1)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets",
        json={"experiment_ids": [experiment_id]},
    )
    # The "context" field comes from the MLFLOW_DATASET_CONTEXT input tag.
    expected = {
        "experiment_id": experiment_id,
        "name": "name1",
        "digest": "digest1",
        "context": "training",
    }

    assert response.status_code == 200
    assert response.json().get("dataset_summaries") == [expected]


def test_create_model_version_with_path_source(mlflow_client):
    """Local filesystem sources are accepted only with a run_id whose artifact dir contains them."""
    name = "model"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": run.info.artifact_uri[len("file://") :],
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # run_id is not specified
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": run.info.artifact_uri[len("file://") :],
        },
    )
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # run_id is specified but source is not in the run's artifact directory
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "/tmp",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]


def test_create_model_version_with_non_local_source(mlflow_client):
    """Remote (mlflow-artifacts) sources are accepted unless they attempt path traversal."""
    name = "model"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": run.info.artifact_uri[len("file://") :],
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # Test that remote uri's supplied as a source with absolute paths work fine
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts:/models",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # A single trailing slash
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts:/models/",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # Multiple trailing slashes
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts:/models///",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # Multiple slashes
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts:/models/foo///bar",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts://host:9000/models",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # Multiple dots
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts://host:9000/models/artifact/..../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # Test that invalid remote uri's cannot be created
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts://host:9000/models/../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "http://host:9000/models/../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "https://host/api/2.0/mlflow-artifacts/artifacts/../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "s3a://my_bucket/api/2.0/mlflow-artifacts/artifacts/../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "ftp://host:8888/api/2.0/mlflow-artifacts/artifacts/../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    # Percent-encoded traversal ("..%2f..%2f") must also be rejected.
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts://host:9000/models/..%2f..%2fartifacts",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    # An embedded encoded NUL byte in the path is rejected as well.
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "mlflow-artifacts://host:9000/models/artifact%00",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "If supplying a source as an http, https," in response.json()["message"]

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": f"dbfs:/{run.info.run_id}/artifacts/a%3f/../../../../../../../../../../",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 400
    assert "Invalid model version source" in response.json()["message"]

    # Logged-model sources: the model's own locations are valid, arbitrary
    # file:// paths with a model_id are not.
    model = mlflow_client.create_logged_model(experiment_id=exp_id)
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": model.artifact_location,
            "model_id": model.model_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": model.model_uri,
            "model_id": model.model_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "file:///path/to/model",
            "model_id": model.model_id,
        },
    )
    assert response.status_code == 400


def test_create_model_version_with_file_uri(mlflow_client):
    """file:// sources are resolved against the run's artifact directory."""
    name = "test"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)
    assert run.info.artifact_uri.startswith("file://")
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": run.info.artifact_uri,
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": f"{run.info.artifact_uri}/model",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": f"{run.info.artifact_uri}/.",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": f"{run.info.artifact_uri}/model/..",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 200

    # run_id is not specified
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": run.info.artifact_uri,
        },
    )
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # run_id is specified but source is not in the run's artifact directory
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "file:///tmp",
        },
    )
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # A file URI carrying a remote host component is not a valid remote uri.
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create",
        json={
            "name": name,
            "source": "file://123.456.789.123/path/to/source",
            "run_id": run.info.run_id,
        },
    )
    assert response.status_code == 500, response.json()
    assert "is not a valid remote uri" in response.json()["message"]


def test_create_model_version_with_validation_regex(db_uri: str):
    """A server started with MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX
    only accepts model-version sources matching that regex."""
    port = get_safe_port()
    with subprocess.Popen(
        [
            sys.executable,
            "-m",
            "mlflow",
            "server",
            "--port",
            str(port),
            "--backend-store-uri",
            db_uri,
        ],
        env=(
            os.environ.copy()
            | {
                "MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX": r"^mlflow-artifacts:/.*$",
                "MLFLOW_SERVER_ENABLE_JOB_EXECUTION": "false",
            }
        ),
    ) as proc:
        try:
            # Wait for the server to start
            for _ in range(10):
                try:
                    if requests.get(f"http://localhost:{port}/health").ok:
                        break
                except requests.ConnectionError:
                    time.sleep(1)
            else:
                raise RuntimeError("Failed to connect to the MLflow server")

            # Test that the validation regex works as expected
            client = MlflowClient(f"http://localhost:{port}")
            name = "test"
            client.create_registered_model(name)
            # Invalid source
            with pytest.raises(MlflowException, match="Invalid model version source"):
                client.create_model_version(name, source="s3://path/to/model")
            # Valid source
            experiment_id = client.create_experiment("test")
            run = client.create_run(experiment_id=experiment_id)
            assert run.info.artifact_uri.startswith("mlflow-artifacts:/")
            client.create_model_version(
                name, source=f"{run.info.artifact_uri}/model", run_id=run.info.run_id
            )
        finally:
            # Always stop the spawned server, even if the assertions fail.
            proc.terminate()
            proc.wait()


@pytest.mark.xfail(reason="Tracking server does not support logged-model endpoints yet")
def test_logging_model_with_local_artifact_uri(mlflow_client):
    """Register a sklearn model logged to a local (file://) artifact store, then load it."""
    from sklearn.linear_model import LogisticRegression

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    with mlflow.start_run() as run:
        assert run.info.artifact_uri.startswith("file://")
        mlflow.sklearn.log_model(LogisticRegression(), name="model",
registered_model_name="rmn") 1944 mlflow.pyfunc.load_model("models:/rmn/1") 1945 1946 1947 def test_log_input(mlflow_client, tmp_path): 1948 df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"]) 1949 path = tmp_path / "temp.csv" 1950 df.to_csv(path) 1951 dataset = from_pandas(df, source=path) 1952 1953 mlflow.set_tracking_uri(mlflow_client.tracking_uri) 1954 1955 with mlflow.start_run() as run: 1956 mlflow.log_input(dataset, "train", {"foo": "baz"}) 1957 1958 dataset_inputs = mlflow_client.get_run(run.info.run_id).inputs.dataset_inputs 1959 1960 assert len(dataset_inputs) == 1 1961 assert dataset_inputs[0].dataset.name == "dataset" 1962 assert dataset_inputs[0].dataset.digest == "f0f3e026" 1963 assert dataset_inputs[0].dataset.source_type == "local" 1964 assert json.loads(dataset_inputs[0].dataset.source) == {"uri": str(path)} 1965 assert json.loads(dataset_inputs[0].dataset.schema) == { 1966 "mlflow_colspec": [ 1967 {"name": "a", "type": "long", "required": True}, 1968 {"name": "b", "type": "long", "required": True}, 1969 {"name": "c", "type": "long", "required": True}, 1970 ] 1971 } 1972 assert json.loads(dataset_inputs[0].dataset.profile) == { 1973 "num_rows": 2, 1974 "num_elements": 6, 1975 } 1976 1977 assert len(dataset_inputs[0].tags) == 2 1978 assert dataset_inputs[0].tags[0].key == "foo" 1979 assert dataset_inputs[0].tags[0].value == "baz" 1980 assert dataset_inputs[0].tags[1].key == mlflow_tags.MLFLOW_DATASET_CONTEXT 1981 assert dataset_inputs[0].tags[1].value == "train" 1982 1983 1984 def test_create_model_version_model_id(mlflow_client): 1985 name = "model" 1986 mlflow_client.create_registered_model(name) 1987 exp_id = mlflow_client.create_experiment("test") 1988 model = mlflow_client.create_logged_model(experiment_id=exp_id) 1989 mlflow_client.create_model_version( 1990 name=name, 1991 source=model.artifact_location, 1992 model_id=model.model_id, 1993 ) 1994 model = mlflow_client.get_logged_model(model.model_id) 1995 assert 
model.tags["mlflow.modelVersions"] == '[{"name": "model", "version": 1}]' 1996 mlflow_client.create_model_version( 1997 name=name, 1998 source=model.artifact_location, 1999 model_id=model.model_id, 2000 ) 2001 model = mlflow_client.get_logged_model(model.model_id) 2002 assert ( 2003 model.tags["mlflow.modelVersions"] 2004 == '[{"name": "model", "version": 1}, {"name": "model", "version": 2}]' 2005 ) 2006 2007 2008 def test_log_inputs(mlflow_client): 2009 experiment_id = mlflow_client.create_experiment("log inputs test") 2010 created_run = mlflow_client.create_run(experiment_id) 2011 run_id = created_run.info.run_id 2012 2013 dataset1 = Dataset( 2014 name="name1", 2015 digest="digest1", 2016 source_type="source_type1", 2017 source="source1", 2018 ) 2019 dataset_inputs1 = [DatasetInput(dataset=dataset1, tags=[InputTag(key="tag1", value="value1")])] 2020 2021 mlflow_client.log_inputs(run_id, dataset_inputs1) 2022 run = mlflow_client.get_run(run_id) 2023 assert len(run.inputs.dataset_inputs) == 1 2024 2025 assert isinstance(run.inputs, RunInputs) 2026 assert isinstance(run.inputs.dataset_inputs[0], DatasetInput) 2027 assert isinstance(run.inputs.dataset_inputs[0].dataset, Dataset) 2028 assert run.inputs.dataset_inputs[0].dataset.name == "name1" 2029 assert run.inputs.dataset_inputs[0].dataset.digest == "digest1" 2030 assert run.inputs.dataset_inputs[0].dataset.source_type == "source_type1" 2031 assert run.inputs.dataset_inputs[0].dataset.source == "source1" 2032 assert len(run.inputs.dataset_inputs[0].tags) == 1 2033 assert run.inputs.dataset_inputs[0].tags[0].key == "tag1" 2034 assert run.inputs.dataset_inputs[0].tags[0].value == "value1" 2035 2036 2037 def test_log_inputs_validation(mlflow_client): 2038 def assert_bad_request(payload, expected_error_message): 2039 response = _send_rest_tracking_post_request( 2040 mlflow_client.tracking_uri, 2041 "/api/2.0/mlflow/runs/log-inputs", 2042 payload, 2043 ) 2044 assert response.status_code == 400 2045 assert 
expected_error_message in response.text 2046 2047 dataset = Dataset( 2048 name="name1", 2049 digest="digest1", 2050 source_type="source_type1", 2051 source="source1", 2052 ) 2053 tags = [InputTag(key="tag1", value="value1")] 2054 dataset_inputs = [ 2055 json.loads(message_to_json(DatasetInput(dataset=dataset, tags=tags).to_proto())) 2056 ] 2057 assert_bad_request( 2058 { 2059 "datasets": dataset_inputs, 2060 }, 2061 "Missing value for required parameter 'run_id'", 2062 ) 2063 2064 2065 def test_log_inputs_model(mlflow_client): 2066 experiment_id = mlflow_client.create_experiment("log inputs test") 2067 run = mlflow_client.create_run(experiment_id) 2068 model = mlflow_client.create_logged_model(experiment_id=experiment_id) 2069 dataset = Dataset( 2070 name="name1", 2071 digest="digest1", 2072 source_type="source_type1", 2073 source="source1", 2074 ) 2075 dataset_inputs = [ 2076 DatasetInput( 2077 dataset=dataset, 2078 tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")], 2079 ) 2080 ] 2081 mlflow_client.log_inputs( 2082 run.info.run_id, 2083 models=[LoggedModelInput(model_id=model.model_id)], 2084 datasets=dataset_inputs, 2085 ) 2086 run = mlflow_client.get_run(run.info.run_id) 2087 assert len(run.inputs.model_inputs) == 1 2088 2089 2090 def test_update_run_name_without_changing_status(mlflow_client): 2091 experiment_id = mlflow_client.create_experiment("update run name") 2092 created_run = mlflow_client.create_run(experiment_id) 2093 mlflow_client.set_terminated(created_run.info.run_id, "FINISHED") 2094 2095 mlflow_client.update_run(created_run.info.run_id, name="name_abc") 2096 updated_run_info = mlflow_client.get_run(created_run.info.run_id).info 2097 assert updated_run_info.run_name == "name_abc" 2098 assert updated_run_info.status == "FINISHED" 2099 2100 2101 def test_create_promptlab_run_handler_rejects_invalid_requests(mlflow_client): 2102 def assert_response(resp, message_part): 2103 assert resp.status_code == 400 2104 response_json = resp.json() 
2105 assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE" 2106 assert message_part in response_json.get("message", "") 2107 2108 response = requests.post( 2109 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2110 json={}, 2111 ) 2112 assert_response( 2113 response, 2114 "CreatePromptlabRun request must specify experiment_id.", 2115 ) 2116 2117 response = requests.post( 2118 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2119 json={"experiment_id": "123"}, 2120 ) 2121 assert_response( 2122 response, 2123 "CreatePromptlabRun request must specify prompt_template.", 2124 ) 2125 2126 response = requests.post( 2127 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2128 json={"experiment_id": "123", "prompt_template": "my_prompt_template"}, 2129 ) 2130 assert_response( 2131 response, 2132 "CreatePromptlabRun request must specify prompt_parameters.", 2133 ) 2134 2135 response = requests.post( 2136 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2137 json={ 2138 "experiment_id": "123", 2139 "prompt_template": "my_prompt_template", 2140 "prompt_parameters": [{"key": "my_key", "value": "my_value"}], 2141 }, 2142 ) 2143 assert_response( 2144 response, 2145 "CreatePromptlabRun request must specify model_route.", 2146 ) 2147 2148 response = requests.post( 2149 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2150 json={ 2151 "experiment_id": "123", 2152 "prompt_template": "my_prompt_template", 2153 "prompt_parameters": [{"key": "my_key", "value": "my_value"}], 2154 "model_route": "my_route", 2155 }, 2156 ) 2157 assert_response( 2158 response, 2159 "CreatePromptlabRun request must specify model_input.", 2160 ) 2161 2162 response = requests.post( 2163 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2164 json={ 2165 "experiment_id": "123", 2166 "prompt_template": "my_prompt_template", 
2167 "prompt_parameters": [{"key": "my_key", "value": "my_value"}], 2168 "model_route": "my_route", 2169 "model_input": "my_input", 2170 }, 2171 ) 2172 assert_response( 2173 response, 2174 "CreatePromptlabRun request must specify mlflow_version.", 2175 ) 2176 2177 response = requests.post( 2178 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2179 json={ 2180 "experiment_id": "123", 2181 "prompt_template": "my_prompt_template", 2182 "prompt_parameters": [{"key": "my_key", "value": "my_value"}], 2183 "model_route": "my_route", 2184 "model_input": "my_input", 2185 "mlflow_version": "1.0.0", 2186 }, 2187 ) 2188 2189 2190 def test_create_promptlab_run_handler_returns_expected_results(mlflow_client): 2191 experiment_id = mlflow_client.create_experiment("log inputs test") 2192 2193 response = requests.post( 2194 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run", 2195 json={ 2196 "experiment_id": experiment_id, 2197 "run_name": "my_run_name", 2198 "prompt_template": "my_prompt_template", 2199 "prompt_parameters": [{"key": "my_key", "value": "my_value"}], 2200 "model_route": "my_route", 2201 "model_parameters": [{"key": "temperature", "value": "0.1"}], 2202 "model_input": "my_input", 2203 "model_output": "my_output", 2204 "model_output_parameters": [{"key": "latency", "value": "100"}], 2205 "mlflow_version": "1.0.0", 2206 "user_id": "username", 2207 "start_time": 456, 2208 }, 2209 ) 2210 assert response.status_code == 200 2211 run_json = response.json() 2212 assert run_json["run"]["info"]["run_name"] == "my_run_name" 2213 assert run_json["run"]["info"]["experiment_id"] == experiment_id 2214 assert run_json["run"]["info"]["user_id"] == "username" 2215 assert run_json["run"]["info"]["status"] == "FINISHED" 2216 assert run_json["run"]["info"]["start_time"] == 456 2217 2218 assert {"key": "model_route", "value": "my_route"} in run_json["run"]["data"]["params"] 2219 assert {"key": "prompt_template", "value": 
"my_prompt_template"} in run_json["run"]["data"][ 2220 "params" 2221 ] 2222 assert {"key": "temperature", "value": "0.1"} in run_json["run"]["data"]["params"] 2223 2224 assert { 2225 "key": "mlflow.loggedArtifacts", 2226 "value": '[{"path": "eval_results_table.json", "type": "table"}]', 2227 } in run_json["run"]["data"]["tags"] 2228 assert {"key": "mlflow.runSourceType", "value": "PROMPT_ENGINEERING"} in run_json["run"][ 2229 "data" 2230 ]["tags"] 2231 2232 2233 def test_gateway_proxy_handler_rejects_invalid_requests(mlflow_client): 2234 def assert_response(resp, message_part): 2235 assert resp.status_code == 400 2236 response_json = resp.json() 2237 assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE" 2238 assert message_part in response_json.get("message", "") 2239 2240 with _init_server( 2241 backend_uri=mlflow_client.tracking_uri, 2242 root_artifact_uri=mlflow_client.tracking_uri, 2243 extra_env={"MLFLOW_DEPLOYMENTS_TARGET": "http://localhost:5001"}, 2244 server_type="flask", 2245 ) as url: 2246 patched_client = MlflowClient(url) 2247 2248 response = requests.post( 2249 f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy", 2250 json={}, 2251 ) 2252 assert_response( 2253 response, 2254 "Deployments proxy request must specify a gateway_path.", 2255 ) 2256 2257 response = requests.post( 2258 f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy", 2259 json={"gateway_path": "foo/bar"}, 2260 ) 2261 assert_response( 2262 response, 2263 "Invalid gateway_path: foo/bar for method: POST", 2264 ) 2265 2266 response = requests.post( 2267 f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy", 2268 json={"gateway_path": "foo/bar/baz"}, 2269 ) 2270 assert_response( 2271 response, 2272 "Invalid gateway_path: foo/bar/baz for method: POST", 2273 ) 2274 2275 response = requests.get( 2276 f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy", 2277 params={"gateway_path": "hello/world"}, 2278 ) 2279 assert_response( 
2280 response, 2281 "Invalid gateway_path: hello/world for method: GET", 2282 ) 2283 2284 # Unsupported method 2285 response = requests.delete( 2286 f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy", 2287 ) 2288 assert response.status_code == 405 2289 2290 2291 def test_upload_artifact_handler_rejects_invalid_requests(mlflow_client): 2292 def assert_response(resp, message_part): 2293 assert resp.status_code == 400 2294 response_json = resp.json() 2295 assert response_json.get("error_code") == "INVALID_PARAMETER_VALUE" 2296 assert message_part in response_json.get("message", "") 2297 2298 experiment_id = mlflow_client.create_experiment("upload_artifacts_test") 2299 created_run = mlflow_client.create_run(experiment_id) 2300 2301 response = requests.post( 2302 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", params={} 2303 ) 2304 assert_response(response, "Request must specify run_uuid.") 2305 2306 response = requests.post( 2307 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", 2308 params={ 2309 "run_uuid": created_run.info.run_id, 2310 }, 2311 ) 2312 assert_response(response, "Request must specify path.") 2313 2314 response = requests.post( 2315 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", 2316 params={"run_uuid": created_run.info.run_id, "path": ""}, 2317 ) 2318 assert_response(response, "Request must specify path.") 2319 2320 response = requests.post( 2321 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", 2322 params={"run_uuid": created_run.info.run_id, "path": "../test.txt"}, 2323 ) 2324 assert_response(response, "Invalid path") 2325 2326 response = requests.post( 2327 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", 2328 params={ 2329 "run_uuid": created_run.info.run_id, 2330 "path": "test.txt", 2331 }, 2332 ) 2333 assert_response(response, "Request must specify data.") 2334 2335 2336 def test_upload_artifact_handler(mlflow_client): 2337 
experiment_id = mlflow_client.create_experiment("upload_artifacts_test") 2338 created_run = mlflow_client.create_run(experiment_id) 2339 2340 response = requests.post( 2341 f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact", 2342 params={ 2343 "run_uuid": created_run.info.run_id, 2344 "path": "test.txt", 2345 }, 2346 data="hello world", 2347 ) 2348 assert response.status_code == 200 2349 2350 response = requests.get( 2351 f"{mlflow_client.tracking_uri}/get-artifact", 2352 params={ 2353 "run_uuid": created_run.info.run_id, 2354 "path": "test.txt", 2355 }, 2356 ) 2357 assert response.status_code == 200 2358 assert response.text == "hello world" 2359 2360 2361 def test_graphql_handler(mlflow_client): 2362 response = requests.post( 2363 f"{mlflow_client.tracking_uri}/graphql", 2364 json={ 2365 "query": 'query testQuery {test(inputString: "abc") { output }}', 2366 "operationName": "testQuery", 2367 }, 2368 headers={"content-type": "application/json; charset=utf-8"}, 2369 ) 2370 assert response.status_code == 200 2371 2372 2373 def test_graphql_handler_batching_raise_error(mlflow_client): 2374 # Test max root fields limit 2375 batch_query = ( 2376 "query testQuery {" 2377 + " ".join([ 2378 f"key_{i}: " + 'test(inputString: "abc") { output }' 2379 for i in range(int(MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS.get()) + 2) 2380 ]) 2381 + "}" 2382 ) 2383 response = requests.post( 2384 f"{mlflow_client.tracking_uri}/graphql", 2385 json={ 2386 "query": batch_query, 2387 "operationName": "testQuery", 2388 }, 2389 headers={"content-type": "application/json; charset=utf-8"}, 2390 ) 2391 assert response.status_code == 200 2392 assert ( 2393 f"GraphQL queries should have at most {MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS.get()}" 2394 in response.json()["errors"][0] 2395 ) 2396 2397 # Test max aliases limit 2398 batch_query = ( 2399 'query testQuery {mlflowGetExperiment(input: {experimentId: "123"}) {' 2400 + " ".join( 2401 f"experiment_{i}: " + "experiment { name }" 2402 for 
i in range(int(MLFLOW_SERVER_GRAPHQL_MAX_ALIASES.get()) + 2) 2403 ) 2404 + "}}" 2405 ) 2406 response = requests.post( 2407 f"{mlflow_client.tracking_uri}/graphql", 2408 json={ 2409 "query": batch_query, 2410 "operationName": "testQuery", 2411 }, 2412 ) 2413 assert response.status_code == 200 2414 assert ( 2415 f"queries should have at most {MLFLOW_SERVER_GRAPHQL_MAX_ALIASES.get()} aliases" 2416 in response.json()["errors"][0] 2417 ) 2418 2419 # Test max depth limit 2420 inner = "name" 2421 for _ in range(12): 2422 inner = f"name {{ {inner} }}" 2423 deep_query = ( 2424 'query testQuery { mlflowGetExperiment(input: {experimentId: "123"}) { experiment { ' 2425 + inner 2426 + " } } }" 2427 ) 2428 response = requests.post( 2429 f"{mlflow_client.tracking_uri}/graphql", 2430 json={ 2431 "query": deep_query, 2432 "operationName": "testQuery", 2433 }, 2434 ) 2435 assert response.status_code == 200 2436 assert "Query exceeds maximum depth of 10" in response.json()["errors"][0] 2437 2438 # Test max selections limit 2439 # Exceed the 1000 selection limit 2440 selections = [f"field_{i} {{ name }}" for i in range(1002)] 2441 selections_query = ( 2442 'query testQuery { mlflowGetExperiment(input: {experimentId: "123"}) { experiment { ' 2443 + " ".join(selections) 2444 + " } } }" 2445 ) 2446 response = requests.post( 2447 f"{mlflow_client.tracking_uri}/graphql", 2448 json={ 2449 "query": selections_query, 2450 "operationName": "testQuery", 2451 }, 2452 ) 2453 assert response.status_code == 200 2454 assert "Query exceeds maximum total selections of 1000" in response.json()["errors"][0] 2455 2456 2457 def test_get_experiment_graphql(mlflow_client): 2458 experiment_id = mlflow_client.create_experiment("GraphqlTest") 2459 response = requests.post( 2460 f"{mlflow_client.tracking_uri}/graphql", 2461 json={ 2462 "query": 'query testQuery {mlflowGetExperiment(input: {experimentId: "' 2463 + experiment_id 2464 + '"}) { experiment { name } }}', 2465 "operationName": "testQuery", 2466 }, 
2467 headers={"content-type": "application/json; charset=utf-8"}, 2468 ) 2469 assert response.status_code == 200 2470 json = response.json() 2471 assert json["data"]["mlflowGetExperiment"]["experiment"]["name"] == "GraphqlTest" 2472 2473 2474 def test_get_run_and_experiment_graphql(mlflow_client): 2475 name = "GraphqlTest" 2476 mlflow_client.create_registered_model(name) 2477 experiment_id = mlflow_client.create_experiment(name) 2478 created_run = mlflow_client.create_run(experiment_id) 2479 run_id = created_run.info.run_id 2480 mlflow_client.create_model_version("GraphqlTest", "runs:/graphql_test/model", run_id) 2481 response = requests.post( 2482 f"{mlflow_client.tracking_uri}/graphql", 2483 json={ 2484 "query": f""" 2485 query testQuery @component(name: "Test") {{ 2486 mlflowGetRun(input: {{runId: "{run_id}"}}) {{ 2487 run {{ 2488 info {{ 2489 status 2490 }} 2491 experiment {{ 2492 name 2493 }} 2494 modelVersions {{ 2495 name 2496 }} 2497 }} 2498 }} 2499 }} 2500 """, 2501 "operationName": "testQuery", 2502 }, 2503 headers={"content-type": "application/json; charset=utf-8"}, 2504 ) 2505 assert response.status_code == 200 2506 json = response.json() 2507 assert json["errors"] is None 2508 assert json["data"]["mlflowGetRun"]["run"]["info"]["status"] == created_run.info.status 2509 assert json["data"]["mlflowGetRun"]["run"]["experiment"]["name"] == name 2510 assert json["data"]["mlflowGetRun"]["run"]["modelVersions"][0]["name"] == name 2511 2512 2513 def test_legacy_start_and_end_trace_v2(mlflow_client): 2514 experiment_id = mlflow_client.create_experiment("start end trace") 2515 2516 # Trace CRUD APIs are not directly exposed as public API of MlflowClient, 2517 # so we use the underlying tracking client to test them. 
2518 store = mlflow_client._tracing_client.store 2519 2520 # Helper function to remove auto-added system tags (mlflow.xxx) from testing 2521 def _exclude_system_tags(tags: dict[str, str]): 2522 return {k: v for k, v in tags.items() if not k.startswith("mlflow.")} 2523 2524 trace_info = store.deprecated_start_trace_v2( 2525 experiment_id=experiment_id, 2526 timestamp_ms=1000, 2527 request_metadata={ 2528 "meta1": "apple", 2529 "meta2": "grape", 2530 }, 2531 tags={ 2532 "tag1": "football", 2533 "tag2": "basketball", 2534 }, 2535 ) 2536 assert trace_info.request_id is not None 2537 assert trace_info.experiment_id == experiment_id 2538 assert trace_info.timestamp_ms == 1000 2539 assert trace_info.execution_time_ms == 0 2540 assert trace_info.status == TraceStatus.IN_PROGRESS 2541 assert trace_info.request_metadata == { 2542 "meta1": "apple", 2543 "meta2": "grape", 2544 } 2545 assert _exclude_system_tags(trace_info.tags) == { 2546 "tag1": "football", 2547 "tag2": "basketball", 2548 } 2549 2550 trace_info = store.deprecated_end_trace_v2( 2551 request_id=trace_info.request_id, 2552 timestamp_ms=3000, 2553 status=TraceStatus.OK, 2554 request_metadata={ 2555 "meta1": "orange", 2556 "meta3": "banana", 2557 }, 2558 tags={ 2559 "tag1": "soccer", 2560 "tag3": "tennis", 2561 }, 2562 ) 2563 assert trace_info.request_id is not None 2564 assert trace_info.experiment_id == experiment_id 2565 assert trace_info.timestamp_ms == 1000 2566 assert trace_info.execution_time_ms == 2000 2567 assert trace_info.status == TraceStatus.OK 2568 assert trace_info.request_metadata == { 2569 "meta1": "orange", 2570 "meta2": "grape", 2571 "meta3": "banana", 2572 } 2573 assert _exclude_system_tags(trace_info.tags) == { 2574 "tag1": "soccer", 2575 "tag2": "basketball", 2576 "tag3": "tennis", 2577 } 2578 2579 2580 def test_start_trace(mlflow_client): 2581 mlflow.set_tracking_uri(mlflow_client.tracking_uri) 2582 experiment_id = mlflow.set_experiment("start end trace").experiment_id 2583 2584 # Helper 
function to remove auto-added system tags (mlflow.xxx) from testing 2585 def _exclude_system_keys(d: dict[str, str]): 2586 return {k: v for k, v in d.items() if not k.startswith("mlflow.")} 2587 2588 with mock.patch("mlflow.tracing.export.mlflow_v3._logger.warning") as mock_warning: 2589 with mlflow.start_span(name="test") as span: 2590 mlflow.update_current_trace( 2591 tags={ 2592 "tag1": "football", 2593 "tag2": "basketball", 2594 }, 2595 metadata={ 2596 "meta1": "apple", 2597 "meta2": "grape", 2598 }, 2599 ) 2600 2601 trace = mlflow_client.get_trace(span.trace_id, flush=True) 2602 assert trace.info.trace_id == span.trace_id 2603 assert trace.info.experiment_id == experiment_id 2604 assert trace.info.request_time > 0 2605 assert trace.info.execution_duration is not None 2606 assert trace.info.state == TraceState.OK 2607 assert _exclude_system_keys(trace.info.trace_metadata) == { 2608 "meta1": "apple", 2609 "meta2": "grape", 2610 } 2611 assert trace.info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == "3" 2612 assert _exclude_system_keys(trace.info.tags) == { 2613 "tag1": "football", 2614 "tag2": "basketball", 2615 } 2616 2617 # No "Failed to log span to MLflow backend" warning should be issued 2618 for call in mock_warning.call_args_list: 2619 assert "Failed to log span to MLflow backend" not in str(call) 2620 2621 2622 def test_get_trace(mlflow_client): 2623 mlflow.set_tracking_uri(mlflow_client.tracking_uri) 2624 experiment_id = mlflow_client.create_experiment("get trace") 2625 span = mlflow_client.start_trace(name="test", experiment_id=experiment_id) 2626 mlflow_client.end_trace(request_id=span.request_id, status=TraceStatus.OK) 2627 trace = mlflow_client.get_trace(span.request_id, flush=True) 2628 assert trace is not None 2629 assert trace.info.request_id == span.request_id 2630 assert trace.info.experiment_id == experiment_id 2631 assert trace.info.state == TraceState.OK 2632 assert len(trace.data.spans) == 1 2633 assert trace.data.spans[0].name == "test" 2634 
assert trace.data.spans[0].status.status_code == SpanStatusCode.OK 2635 assert trace.data.spans[0].status.description == "" 2636 2637 2638 def test_search_traces(mlflow_client): 2639 mlflow.set_tracking_uri(mlflow_client.tracking_uri) 2640 experiment_id = mlflow_client.create_experiment("search traces") 2641 2642 # Create test traces 2643 def _create_trace(name, status): 2644 span = mlflow_client.start_trace(name=name, experiment_id=experiment_id) 2645 mlflow_client.end_trace(request_id=span.request_id, status=status) 2646 return span.request_id 2647 2648 # Flush between creations to ensure distinct timestamps. Without this, all three traces 2649 # can land in the same millisecond on a fast local server, making max_results ordering 2650 # non-deterministic. 2651 request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK) 2652 mlflow.flush_trace_async_logging() 2653 request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK) 2654 mlflow.flush_trace_async_logging() 2655 request_id_3 = _create_trace(name="trace3", status=TraceStatus.ERROR) 2656 mlflow.flush_trace_async_logging() 2657 2658 def _get_request_ids(traces): 2659 return [t.info.request_id for t in traces] 2660 2661 # Validate search 2662 traces = mlflow_client.search_traces(locations=[experiment_id]) 2663 assert set(_get_request_ids(traces)) == {request_id_3, request_id_2, request_id_1} 2664 assert traces.token is None 2665 2666 traces = mlflow_client.search_traces( 2667 locations=[experiment_id], 2668 filter_string="status = 'OK'", 2669 order_by=["timestamp ASC"], 2670 ) 2671 assert set(_get_request_ids(traces)) == {request_id_1, request_id_2} 2672 assert traces.token is None 2673 2674 traces = mlflow_client.search_traces( 2675 locations=[experiment_id], 2676 max_results=2, 2677 ) 2678 assert set(_get_request_ids(traces)) == {request_id_3, request_id_2} 2679 assert traces.token is not None 2680 traces = mlflow_client.search_traces( 2681 locations=[experiment_id], 2682 page_token=traces.token, 
2683 ) 2684 assert _get_request_ids(traces) == [request_id_1] 2685 assert traces.token is None 2686 2687 2688 def test_search_traces_parameter_validation(mlflow_client): 2689 with pytest.raises( 2690 MlflowException, 2691 match="Locations must be a list of experiment IDs", 2692 ): 2693 mlflow_client.search_traces(locations=["catalog.schema"]) 2694 2695 2696 def test_search_traces_match_text(mlflow_client, store_type): 2697 if store_type == "file": 2698 pytest.skip("File store doesn't support full text search") 2699 2700 mlflow.set_tracking_uri(mlflow_client.tracking_uri) 2701 experiment_id = mlflow_client.create_experiment("search traces full text") 2702 2703 # Create test traces 2704 def _create_trace(name, attributes): 2705 span = mlflow_client.start_trace(name=name, experiment_id=experiment_id) 2706 span.set_attributes(attributes) 2707 mlflow_client.end_trace(request_id=span.trace_id, status=TraceStatus.OK) 2708 return span.trace_id 2709 2710 trace_id_1 = _create_trace(name="trace1", attributes={"test": "value1"}) 2711 trace_id_2 = _create_trace(name="trace2", attributes={"test": "value2"}) 2712 trace_id_3 = _create_trace(name="trace3", attributes={"test3": "I like it"}) 2713 2714 traces = mlflow_client.search_traces(locations=[experiment_id], flush=True) 2715 assert len([t.info.trace_id for t in traces]) == 3 2716 assert traces.token is None 2717 2718 traces = mlflow_client.search_traces( 2719 locations=[experiment_id], filter_string="trace.text LIKE '%trace%'" 2720 ) 2721 assert len([t.info.trace_id for t in traces]) == 3 2722 assert traces.token is None 2723 2724 traces = mlflow_client.search_traces( 2725 locations=[experiment_id], filter_string="trace.text LIKE '%value%'" 2726 ) 2727 assert {t.info.trace_id for t in traces} == {trace_id_1, trace_id_2} 2728 2729 traces = mlflow_client.search_traces( 2730 locations=[experiment_id], filter_string="trace.text LIKE '%I like it%'" 2731 ) 2732 assert [t.info.trace_id for t in traces] == [trace_id_3] 2733 2734 2735 
def test_delete_traces(mlflow_client):
    """Exercise delete_traces by experiment, by max_traces limit, and by explicit trace ID."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("delete traces")

    def _create_trace(name, status):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        mlflow_client.end_trace(request_id=span.request_id, status=status)
        return span.request_id

    def _is_trace_exists(request_id):
        # A missing trace surfaces as a RESOURCE_DOES_NOT_EXIST RestException.
        try:
            trace_info = mlflow_client._tracing_client.get_trace_info(request_id)
            return trace_info is not None
        except RestException as e:
            if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
                return False
            raise

    # Case 1: Delete all traces under experiment ID
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()
    assert _is_trace_exists(request_id_1)
    assert _is_trace_exists(request_id_2)

    deleted_count = mlflow_client.delete_traces(experiment_id, max_timestamp_millis=int(1e15))
    assert deleted_count == 2
    assert not _is_trace_exists(request_id_1)
    assert not _is_trace_exists(request_id_2)

    # Case 2: Delete with max_traces limit
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    time.sleep(0.1)  # Add some time gap to avoid timestamp collision
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()

    deleted_count = mlflow_client.delete_traces(
        experiment_id, max_traces=1, max_timestamp_millis=int(1e15)
    )
    assert deleted_count == 1
    # TODO: Currently the deletion order in the file store is random (based on
    # the order of the trace files in the directory), so we don't validate which
    # one is deleted. Uncomment the following lines once the deletion order is fixed.
    # assert not _is_trace_exists(request_id_1)  # Old created trace should be deleted
    # assert _is_trace_exists(request_id_2)

    # Case 3: Delete with explicit request ID
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()

    deleted_count = mlflow_client.delete_traces(experiment_id, trace_ids=[request_id_1])
    assert deleted_count == 1
    assert not _is_trace_exists(request_id_1)
    assert _is_trace_exists(request_id_2)


def test_calculate_trace_filter_correlation(mlflow_client, store_type):
    """Verify NPMI correlation statistics between two trace filter strings."""
    if store_type == "file":
        pytest.skip("File store doesn't support calculate_trace_filter_correlation")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("correlation test")

    def _create_trace(name, tags):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id, tags=tags)
        mlflow_client.end_trace(request_id=span.request_id, status=TraceStatus.OK)
        return span.request_id

    # 6 prod/TOOL traces plus 4 dev traces (1 TOOL, 3 LLM) => 10 total, 7 TOOL.
    for i in range(6):
        _create_trace(f"trace-prod-tool-{i}", {"env": "prod", "span_type": "TOOL"})

    for i in range(4):
        _create_trace(f"trace-dev-{i}", {"env": "dev", "span_type": "LLM" if i >= 1 else "TOOL"})

    client = TracingClient(tracking_uri=mlflow_client.tracking_uri)

    mlflow.flush_trace_async_logging()

    # Positively correlated filters: all prod traces are TOOL traces.
    result = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'prod'",
        filter_string2="tags.span_type = 'TOOL'",
    )

    assert isinstance(result, TraceFilterCorrelationResult)
    assert result.total_count == 10
    assert result.filter1_count == 6
    assert result.filter2_count == 7
    assert result.joint_count == 6
    assert 0.6 < result.npmi < 0.8
    assert result.npmi_smoothed is not None

    # All LLM traces are dev traces, so this pair is also positively correlated.
    result2 = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'dev'",
        filter_string2="tags.span_type = 'LLM'",
    )

    assert result2.total_count == 10
    assert result2.filter1_count == 4
    assert result2.filter2_count == 3
    assert result2.joint_count == 3
    assert result2.npmi > 0.5

    # A filter matching nothing yields an undefined (NaN) NPMI.
    result3 = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'staging'",
        filter_string2="tags.span_type = 'TOOL'",
    )

    assert result3.total_count == 10
    assert result3.filter1_count == 0
    assert result3.filter2_count == 7
    assert result3.joint_count == 0
    assert math.isnan(result3.npmi)

    with pytest.raises(MlflowException, match="Invalid"):
        client.calculate_trace_filter_correlation(
            experiment_ids=[experiment_id],
            filter_string1="invalid.filter = 'test'",
            filter_string2="tags.span_type = 'TOOL'",
        )


def test_set_and_delete_trace_tag(mlflow_client):
    """Set and delete tags on an existing trace and read them back."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("set delete tag")

    # Create test trace
    trace_info = mlflow_client._tracing_client.start_trace(
        TraceInfo(
            trace_id="tr-1234",
            trace_location=TraceLocation.from_experiment_id(experiment_id),
            request_time=1000,
            execution_duration=2000,
            state=TraceState.OK,
            tags={
                "tag1": "red",
                "tag2": "blue",
            },
        )
    )

    # Validate set tag
    mlflow_client.set_trace_tag(trace_info.request_id, "tag1", "green")
    trace_info = mlflow_client._tracing_client.get_trace_info(trace_info.request_id)
    assert trace_info.tags["tag1"] == "green"

    # Validate delete tag
    mlflow_client.delete_trace_tag(trace_info.request_id, "tag2")
    trace_info = mlflow_client._tracing_client.get_trace_info(trace_info.request_id)
    assert "tag2" not in trace_info.tags


def test_query_trace_metrics(mlflow_client, store_type):
    """Aggregate trace counts grouped by trace status via the store's metrics query API."""
    if store_type == "file":
        pytest.skip("File store doesn't support query trace metrics")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("query trace metrics")

    # Create test traces
    def _create_trace(name, status):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        mlflow_client.end_trace(request_id=span.request_id, status=status)
        return span.request_id

    _create_trace(name="trace1", status=TraceStatus.OK)
    _create_trace(name="trace2", status=TraceStatus.OK)
    _create_trace(name="trace3", status=TraceStatus.ERROR)

    mlflow.flush_trace_async_logging()

    metrics = mlflow_client._tracing_client.store.query_trace_metrics(
        experiment_ids=[experiment_id],
        view_type=MetricViewType.TRACES,
        metric_name=TraceMetricKey.TRACE_COUNT,
        aggregations=[MetricAggregation(aggregation_type=AggregationType.COUNT)],
        dimensions=[TraceMetricDimensionKey.TRACE_STATUS],
    )
    # One aggregate row per distinct status, ordered ERROR then OK.
    assert len(metrics) == 2
    assert asdict(metrics[0]) == {
        "metric_name": TraceMetricKey.TRACE_COUNT,
        "dimensions": {TraceMetricDimensionKey.TRACE_STATUS: "ERROR"},
        "values": {"COUNT": 1},
    }

    assert asdict(metrics[1]) == {
        "metric_name": TraceMetricKey.TRACE_COUNT,
        "dimensions": {TraceMetricDimensionKey.TRACE_STATUS: "OK"},
        "values": {"COUNT": 2},
    }


@pytest.mark.parametrize("allow_partial", [True, False])
def test_get_trace_handler(mlflow_client, allow_partial: bool, store_type):
    """Fetch a trace through the ajax-api v3 get-trace endpoint."""
    if store_type == "file":
        pytest.skip("File store doesn't support get trace handler")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    with mlflow.start_span(name="test") as span:
        span.set_attributes({"fruit": "apple"})

    mlflow.flush_trace_async_logging()

    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/3.0/mlflow/traces/get",
        params={"trace_id": span.trace_id, "allow_partial": allow_partial},
    )

    assert response.status_code == 200

    trace = response.json()["trace"]
    assert trace["trace_info"]["trace_id"] == span.trace_id
    assert len(trace["spans"]) == 1
    assert trace["spans"][0]["name"] == "test"
    attributes = trace["spans"][0]["attributes"]
    assert {"key": "fruit", "value": {"string_value": "apple"}} in attributes


def test_get_trace_artifact_handler(mlflow_client):
    """Download the serialized span data for a trace via the get-trace-artifact endpoint."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    with mlflow.start_span(name="test") as span:
        span.set_attributes({"fruit": "apple"})
        span.add_event(SpanEvent("test_event", timestamp=99999, attributes={"foo": "bar"}))

    mlflow.flush_trace_async_logging()

    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/get-trace-artifact",
        params={"request_id": span.trace_id},
    )
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"

    # Validate content
    trace_data = TraceData.from_dict(json.loads(response.text))
    assert trace_data.spans[0].to_dict() == span.to_dict()


def test_link_traces_to_run_and_search_traces(mlflow_client, store_type):
    """Link traces to a run (implicitly and explicitly) and filter searches by run_id."""
    # Skip file store because it doesn't support linking traces to runs
    if store_type == "file":
        pytest.skip("File store doesn't support linking traces to runs")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow.set_experiment("link traces to run test").experiment_id

    run = mlflow_client.create_run(experiment_id)
    run_id = run.info.run_id

    # 1. Trace created under a run
    with mlflow.start_run(run_id=run_id):
        with mlflow.start_span(name="trace1") as span1:
            span1.set_attributes({"test": "value1"})
            trace_id_1 = span1.trace_id

    # 2. Trace associated with a run
    with mlflow.start_span(name="trace2") as span2:
        span2.set_attributes({"test": "value2"})
        trace_id_2 = span2.trace_id
    mlflow_client.link_traces_to_run(trace_ids=[trace_id_2], run_id=run_id)

    # 3. Trace not associated with a run
    with mlflow.start_span(name="trace3") as span3:
        span3.set_attributes({"test": "value3"})
        trace_id_3 = span3.trace_id

    # Search traces without run_id filter - should return all traces in experiment
    all_traces = mlflow_client.search_traces(locations=[experiment_id], flush=True)
    assert {t.info.trace_id for t in all_traces} == {trace_id_1, trace_id_2, trace_id_3}

    # Search traces with run_id filter - should return only linked traces
    linked_traces = mlflow_client.search_traces(
        locations=[experiment_id], filter_string=f"attribute.run_id = '{run_id}'"
    )
    linked_trace_ids = [t.info.trace_id for t in linked_traces]
    assert len(linked_trace_ids) == 2
    assert set(linked_trace_ids) == {trace_id_1, trace_id_2}


def test_get_metric_history_bulk_interval_graphql(mlflow_client):
    """Query bulk-interval metric history through the GraphQL endpoint."""
    name = "GraphqlTest"
    mlflow_client.create_registered_model(name)
    experiment_id = mlflow_client.create_experiment(name)
    created_run = mlflow_client.create_run(experiment_id)

    metric_name = "metric_0"
    for i in range(10):
        mlflow_client.log_metric(created_run.info.run_id, metric_name, i, step=i)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    mlflowGetMetricHistoryBulkInterval(input: {{
                        runIds: ["{created_run.info.run_id}"],
                        metricKey: "{metric_name}",
                    }}) {{
                        metrics {{
                            key
                            timestamp
                            value
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    # Named `json_response` (not `json`) to avoid shadowing the stdlib json module.
    json_response = response.json()
    expected = [{"key": metric_name, "timestamp": mock.ANY, "value": i} for i in range(10)]
    assert json_response["data"]["mlflowGetMetricHistoryBulkInterval"]["metrics"] == expected


def test_search_runs_graphql(mlflow_client):
    """Search runs through the GraphQL mutation; newest run is returned first."""
    name = "GraphqlTest"
    mlflow_client.create_registered_model(name)
    experiment_id = mlflow_client.create_experiment(name)
    created_run_1 = mlflow_client.create_run(experiment_id)
    created_run_2 = mlflow_client.create_run(experiment_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                mutation testMutation {{
                    mlflowSearchRuns(input: {{ experimentIds: ["{experiment_id}"] }}) {{
                        runs {{
                            info {{
                                runId
                            }}
                        }}
                    }}
                }}
            """,
            "operationName": "testMutation",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    json_response = response.json()
    expected = [
        {"info": {"runId": created_run_2.info.run_id}},
        {"info": {"runId": created_run_1.info.run_id}},
    ]
    assert json_response["data"]["mlflowSearchRuns"]["runs"] == expected


def test_list_artifacts_graphql(mlflow_client, tmp_path):
    """List run artifacts (root and a subdirectory) through the GraphQL endpoint."""
    name = "GraphqlTest"
    experiment_id = mlflow_client.create_experiment(name)
    created_run_id = mlflow_client.create_run(experiment_id).info.run_id
    file_path = tmp_path / "test.txt"
    file_path.write_text("hello world")
    mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix())
    mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix(), "testDir")

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    files: mlflowListArtifacts(input: {{
                        runId: "{created_run_id}",
                    }}) {{
                        files {{
                            path
                            isDir
                            fileSize
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    json_response = response.json()
    # fileSize is serialized as a string; "hello world" is 11 bytes.
    file_expected = [
        {"path": "test.txt", "isDir": False, "fileSize": "11"},
        {"path": "testDir", "isDir": True, "fileSize": "0"},
    ]
    assert json_response["data"]["files"]["files"] == file_expected

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    subdir: mlflowListArtifacts(input: {{
                        runId: "{created_run_id}",
                        path: "testDir",
                    }}) {{
                        files {{
                            path
                            isDir
                            fileSize
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    json_response = response.json()
    subdir_expected = [
        {"path": "testDir/test.txt", "isDir": False, "fileSize": "11"},
    ]
    assert json_response["data"]["subdir"]["files"] == subdir_expected


def test_search_datasets_graphql(mlflow_client):
    """Search dataset summaries (with and without a context tag) through GraphQL."""
    name = "GraphqlTest"
    experiment_id = mlflow_client.create_experiment(name)
    created_run_id = mlflow_client.create_run(experiment_id).info.run_id
    dataset1 = Dataset(
        name="test-dataset-1",
        digest="12345",
        source_type="script",
        source="test",
    )
    dataset_input1 = DatasetInput(dataset=dataset1, tags=[])
    dataset2 = Dataset(
        name="test-dataset-2",
        digest="12346",
        source_type="script",
        source="test",
    )
    dataset_input2 = DatasetInput(
        dataset=dataset2, tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")]
    )
    mlflow_client.log_inputs(created_run_id, [dataset_input1, dataset_input2])

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                mutation testMutation {{
                    mlflowSearchDatasets(input:{{experimentIds: ["{experiment_id}"]}}) {{
                        datasetSummaries {{
                            experimentId
                            name
                            digest
                            context
                        }}
                    }}
                }}
            """,
            "operationName": "testMutation",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    json_response = response.json()

    # Result order is not guaranteed, so compare after sorting by digest.
    def sort_dataset_summaries(summaries):
        return sorted(summaries, key=lambda x: x["digest"])

    expected = sort_dataset_summaries([
        {
            "experimentId": experiment_id,
            "name": "test-dataset-2",
            "digest": "12346",
            "context": "training",
        },
        {
            "experimentId": experiment_id,
            "name": "test-dataset-1",
            "digest": "12345",
            "context": "",
        },
    ])
    assert (
        sort_dataset_summaries(json_response["data"]["mlflowSearchDatasets"]["datasetSummaries"])
        == expected
    )


def test_create_logged_model(mlflow_client: MlflowClient):
    """Create logged models with each optional field and read them back."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.model_id == loaded_model.model_id

    model = mlflow_client.create_logged_model(exp_id, name="my_model")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.name == "my_model"

    model = mlflow_client.create_logged_model(exp_id, model_type="LLM")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.model_type == "LLM"

    model = mlflow_client.create_logged_model(exp_id, source_run_id="123")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.source_run_id == "123"

    model = mlflow_client.create_logged_model(exp_id, params={"param": "value"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.params == {"param": "value"}

    model = mlflow_client.create_logged_model(exp_id, tags={"tag": "value"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert model.tags == {"tag": "value"}


def test_log_logged_model_params(mlflow_client: MlflowClient):
    """Log params on an existing logged model and read them back."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    mlflow_client.log_model_params(model.model_id, {"param": "value"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.params == {"param": "value"}


def test_finalize_logged_model(mlflow_client: MlflowClient):
    """Finalize a logged model to READY and then FAILED."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    finalized_model = mlflow_client.finalize_logged_model(model.model_id, LoggedModelStatus.READY)
    assert finalized_model.status == LoggedModelStatus.READY

    finalized_model = mlflow_client.finalize_logged_model(model.model_id, LoggedModelStatus.FAILED)
    assert finalized_model.status == LoggedModelStatus.FAILED


def test_delete_logged_model(mlflow_client: MlflowClient):
    """Delete a logged model; subsequent get fails and search no longer returns it."""
    exp_id = mlflow_client.create_experiment("delete_logged_model")
    model = mlflow_client.create_logged_model(experiment_id=exp_id)
    mlflow_client.delete_logged_model(model.model_id)
    with pytest.raises(MlflowException, match="not found"):
        mlflow_client.get_logged_model(model.model_id)

    models = mlflow_client.search_logged_models(experiment_ids=[exp_id])
    assert len(models) == 0


def test_set_logged_model_tags(mlflow_client: MlflowClient):
    """Set logged-model tags; a second set merges and overwrites existing keys."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    mlflow_client.set_logged_model_tags(model.model_id, {"tag1": "value1", "tag2": "value2"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.tags == {"tag1": "value1", "tag2": "value2"}

    mlflow_client.set_logged_model_tags(model.model_id, {"tag1": "value3"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.tags == {"tag1": "value3", "tag2": "value2"}


def test_delete_logged_model_tag(mlflow_client: MlflowClient):
    """Delete a logged-model tag; deleting it again raises."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    mlflow_client.set_logged_model_tags(model.model_id, {"tag1": "value1", "tag2": "value2"})
    mlflow_client.delete_logged_model_tag(model.model_id, "tag1")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.tags == {"tag2": "value2"}

    with pytest.raises(MlflowException, match="No tag with key"):
        mlflow_client.delete_logged_model_tag(model.model_id, "tag1")


def test_search_logged_models(mlflow_client: MlflowClient):
    """Search logged models with pagination, filter, dataset filter, and ordering."""
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model_1 = mlflow_client.create_logged_model(exp_id)
    time.sleep(0.001)  # to ensure different created time
    models = mlflow_client.search_logged_models(experiment_ids=[exp_id])
    assert [m.name for m in models] == [model_1.name]

    # max_results
    model_2 = mlflow_client.create_logged_model(exp_id)
    page_1 = mlflow_client.search_logged_models(experiment_ids=[exp_id], max_results=1)
    assert [m.name for m in page_1] == [model_2.name]
    assert page_1.token is not None

    # pagination
    page_2 = mlflow_client.search_logged_models(
        experiment_ids=[exp_id], max_results=1, page_token=page_1.token
    )
    assert [m.name for m in page_2] == [model_1.name]
    assert page_2.token is None

    # filter_string
    models = mlflow_client.search_logged_models(
        experiment_ids=[exp_id], filter_string=f"name = {model_1.name!r}"
    )
    assert [m.name for m in models] == [model_1.name]

    # datasets
    run_1 = mlflow_client.create_run(exp_id)
    mlflow_client.log_metric(
        run_1.info.run_id,
        key="metric",
        value=1,
        dataset_name="dataset",
        dataset_digest="123",
        model_id=model_1.model_id,
    )
    models = mlflow_client.search_logged_models(
        experiment_ids=[exp_id],
        datasets=[{"dataset_name": "dataset", "dataset_digest": "123"}],
    )

    assert [m.name for m in models] == [model_1.name]

    # order_by
    models = mlflow_client.search_logged_models(
        experiment_ids=[exp_id],
        order_by=[{"field_name": "creation_timestamp", "ascending": False}],
    )
    assert [m.name for m in models] == [model_2.name, model_1.name]


def test_log_outputs(mlflow_client: MlflowClient):
    """Log model outputs on a run and read them back."""
    exp_id = mlflow_client.create_experiment("log_outputs")
    run = mlflow_client.create_run(experiment_id=exp_id)
    model = mlflow_client.create_logged_model(experiment_id=exp_id)
    model_outputs = [LoggedModelOutput(model.model_id, 1)]
    mlflow_client.log_outputs(run.info.run_id, model_outputs)
    run = mlflow_client.get_run(run.info.run_id)
    assert run.outputs.model_outputs == model_outputs


def test_list_logged_model_artifacts(mlflow_client: MlflowClient):
    """List a logged model's artifact directory through the ajax-api endpoint."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    model_info = mlflow.pyfunc.log_model(name="model", python_model=Model())
    resp = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/logged-models/{model_info.model_id}/artifacts/directories"
    )
    assert resp.status_code == 200
    data = resp.json()
    paths = [f["path"] for f in data["files"]]
    assert "MLmodel" in paths


def test_get_logged_model_artifact(mlflow_client: MlflowClient):
    """Download a single logged-model artifact file through the ajax-api endpoint."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    model_info = mlflow.pyfunc.log_model(name="model", python_model=Model())
    resp = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/logged-models/{model_info.model_id}/artifacts/files",
        params={"artifact_file_path": "MLmodel"},
    )
    assert resp.status_code == 200
    assert model_info.model_id in resp.text


def test_suppress_url_printing(mlflow_client: MlflowClient, monkeypatch):
    """The run URL is not printed when MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT is set."""
    monkeypatch.setenv(MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT.name, "true")
    exp_id = mlflow_client.create_experiment("test_suppress_url_printing")
    run = mlflow_client.create_run(experiment_id=exp_id)
    captured_output = StringIO()
    monkeypatch.setattr(sys, "stdout", captured_output)
    mlflow_client._tracking_client._log_url(run.info.run_id)
    assert captured_output.getvalue() == ""


def test_log_url_includes_workspace_when_set(mlflow_client: MlflowClient, monkeypatch):
    """The printed run URL carries a workspace query param when a workspace is active."""
    exp_id = mlflow_client.create_experiment("test_log_url_workspace")
    run = mlflow_client.create_run(experiment_id=exp_id)
    captured_output = StringIO()
    monkeypatch.setattr(sys, "stdout", captured_output)
    monkeypatch.setattr(
        "mlflow.tracking._tracking_service.client.get_workspace_url", lambda: "http://localhost"
    )
    monkeypatch.setattr(
        "mlflow.tracking._tracking_service.client.get_request_workspace", lambda: "team-space"
    )

    mlflow_client._tracking_client._log_url(run.info.run_id)

    out = captured_output.getvalue()
    expected_fragment = f"/#/experiments/{exp_id}/runs/{run.info.run_id}?workspace=team-space"
    assert expected_fragment in out


def test_assessments_end_to_end(mlflow_client):
    """CRUD lifecycle for trace assessments: create, get, update, override, delete, restore."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    # Set up experiment and trace
    experiment_id = mlflow_client.create_experiment("assessment_crud_test")
    trace_info = mlflow_client.start_trace(name="test_trace", experiment_id=experiment_id)
    mlflow_client.end_trace(request_id=trace_info.request_id)
    mlflow.flush_trace_async_logging()

    # CREATE initial feedback assessment
    feedback_payload = {
        "assessment": {
            "assessment_name": "quality_score",
            "feedback": {"value": {"rating": 4, "comments": "Good response"}},
            "source": {"source_type": "HUMAN", "source_id": "evaluator@company.com"},
            "rationale": "Response was accurate and helpful",
            "metadata": {"model": "gpt-4", "version": "1.0"},
        }
    }

    # CREATE assessment
    create_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=feedback_payload,
    )
    assert create_response.status_code == 200
    assessment = create_response.json()["assessment"]
    assessment_id = assessment["assessment_id"]

    # Verify creation
    assert assessment["assessment_name"] == "quality_score"
    assert assessment["feedback"]["value"]["rating"] == 4
    assert assessment["source"]["source_type"] == "HUMAN"
    assert assessment["valid"] is True

    # GET assessment
    get_response = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_response.status_code == 200
    retrieved = get_response.json()["assessment"]
    assert retrieved["assessment_id"] == assessment_id
    assert retrieved["feedback"]["value"]["rating"] == 4

    # UPDATE assessment
    update_payload = {
        "assessment": {
            "assessment_id": assessment_id,
            "trace_id": trace_info.request_id,
            "assessment_name": "updated_quality_score",
            "feedback": {"value": {"rating": 5, "comments": "Excellent response"}},
            "rationale": "Actually, the response was excellent",
            "metadata": {"model": "gpt-4", "version": "2.0"},
        },
        "update_mask": "assessmentName,feedback,rationale,metadata",
    }

    update_response = requests.patch(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}",
        json=update_payload,
    )
    assert update_response.status_code == 200
    updated = update_response.json()["assessment"]
    assert updated["assessment_name"] == "updated_quality_score"
    assert updated["feedback"]["value"]["rating"] == 5
    assert updated["rationale"] == "Actually, the response was excellent"

    # CREATE override assessment
    override_payload = {
        "assessment": {
            "assessment_name": "corrected_quality_score",
            "feedback": {"value": {"rating": 3, "comments": "Actually needs improvement"}},
            "source": {"source_type": "HUMAN", "source_id": "senior_evaluator@company.com"},
            "overrides": assessment_id,
        }
    }

    override_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=override_payload,
    )
    assert override_response.status_code == 200
    override_assessment = override_response.json()["assessment"]
    override_id = override_assessment["assessment_id"]

    # Verify original is now invalid
    get_original = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_original.status_code == 200
    assert get_original.json()["assessment"]["valid"] is False

    # Verify override is valid
    get_override = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert get_override.status_code == 200
    assert get_override.json()["assessment"]["valid"] is True
    assert get_override.json()["assessment"]["overrides"] == assessment_id

    # DELETE override assessment (should restore original)
    delete_response = requests.delete(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert delete_response.status_code == 200

    # Verify override is deleted
    get_deleted = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert get_deleted.status_code == 404

    # Verify original is restored to valid
    get_restored = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_restored.status_code == 200
    assert get_restored.json()["assessment"]["valid"] is True

    # CREATE expectation assessment to test different type
    expectation_payload = {
        "assessment": {
            "assessment_name": "response_time_check",
            "expectation": {"value": {"threshold_ms": 1000, "actual_ms": 750, "passed": True}},
            "source": {"source_type": "CODE", "source_id": "automated_test"},
        }
    }

    expectation_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=expectation_payload,
    )
    assert expectation_response.status_code == 200
    expectation = expectation_response.json()["assessment"]
    expectation_id = expectation["assessment_id"]

    # Verify expectation was created correctly
    expectation_value = json.loads(expectation["expectation"]["serialized_value"]["value"])
    assert expectation_value["passed"] is True
    assert expectation_value["threshold_ms"] == 1000
    assert expectation_value["actual_ms"] == 750
    assert expectation["source"]["source_type"] == "CODE"

    # Clean up - delete remaining assessments
    for aid in [assessment_id, expectation_id]:
        delete_resp = requests.delete(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{aid}"
        )
        assert delete_resp.status_code == 200


def test_graphql_nan_metric_handling(mlflow_client):
    """GraphQL serializes NaN metric values as null instead of failing."""
    experiment_id = mlflow_client.create_experiment("test_graphql_nan_metrics")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id

    # Log a normal metric and a NaN metric
    mlflow_client.log_metric(run_id, key="normal_metric", value=123, timestamp=1, step=1)
    mlflow_client.log_metric(run_id, key="nan_metric", value=math.nan, timestamp=2, step=2)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    mlflowGetRun(input: {{runId: "{run_id}"}}) {{
                        run {{
                            data {{
                                metrics {{
                                    key
                                    value
                                    timestamp
                                    step
                                }}
                            }}
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    json_response = response.json()
    assert json_response["errors"] is None

    metrics = json_response["data"]["mlflowGetRun"]["run"]["data"]["metrics"]

    # Find the normal metric and nan metric
    normal_metric = None
    nan_metric = None
    for metric in metrics:
        if metric["key"] == "normal_metric":
            normal_metric = metric
        elif metric["key"] == "nan_metric":
            nan_metric = metric

    # Verify normal metric has a numeric value
    assert normal_metric is not None
    assert normal_metric["key"] == "normal_metric"
    assert normal_metric["value"] == 123
    assert normal_metric["timestamp"] == "1"
    assert normal_metric["step"] == "1"

    # Verify NaN metric has null value
    assert nan_metric is not None
    assert nan_metric["key"] == "nan_metric"
    assert nan_metric["value"] is None
    assert nan_metric["timestamp"] == "2"
    assert nan_metric["step"] == "2"


def test_create_and_get_evaluation_dataset(mlflow_client, store_type):
    """Create an evaluation dataset with tags and read it back by ID."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    experiment_id = mlflow_client.create_experiment("eval_dataset_test")

    dataset = mlflow_client.create_dataset(
        name="test_eval_dataset",
        experiment_id=experiment_id,
        tags={"environment": "test", "version": "1.0"},
    )

    assert dataset.name == "test_eval_dataset"
    assert dataset.experiment_ids == [experiment_id]
    assert dataset.tags["environment"] == "test"
    assert dataset.tags["version"] == "1.0"
    assert dataset.dataset_id is not None

    retrieved = mlflow_client.get_dataset(dataset.dataset_id)
    assert retrieved.name == dataset.name
    assert retrieved.dataset_id == dataset.dataset_id
    assert retrieved.tags == dataset.tags


def test_search_evaluation_datasets(mlflow_client, store_type):
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    exp1 = mlflow_client.create_experiment("eval_search_exp1")
    exp2 = mlflow_client.create_experiment("eval_search_exp2")

    mlflow_client.create_dataset(
        name="search_dataset_1", experiment_id=exp1, tags={"team": "ml", "status": "active"}
    )

    mlflow_client.create_dataset(
        name="search_dataset_2",
        experiment_id=[exp1, exp2],
        tags={"team": "data", "status": "active"},
    )

    mlflow_client.create_dataset(
        name="search_dataset_3", experiment_id=exp2, tags={"team": "ml", "status": "archived"}
    )

    all_datasets = mlflow_client.search_datasets()
    assert len(all_datasets) >= 3

    exp1_datasets
= mlflow_client.search_datasets(experiment_ids=exp1) 3691 dataset_names = [d.name for d in exp1_datasets] 3692 assert "search_dataset_1" in dataset_names 3693 assert "search_dataset_2" in dataset_names 3694 3695 ml_datasets = mlflow_client.search_datasets(filter_string="tags.team = 'ml'") 3696 ml_names = [d.name for d in ml_datasets] 3697 assert "search_dataset_1" in ml_names 3698 assert "search_dataset_3" in ml_names 3699 assert "search_dataset_2" not in ml_names 3700 3701 ordered_datasets = mlflow_client.search_datasets(order_by=["name ASC"]) 3702 names = [d.name for d in ordered_datasets] 3703 assert names == sorted(names) 3704 3705 3706 def test_evaluation_dataset_tag_operations(mlflow_client, store_type): 3707 if store_type == "file": 3708 pytest.skip("Evaluation datasets not supported for FileStore") 3709 3710 experiment_id = mlflow_client.create_experiment("eval_tags_test") 3711 3712 dataset = mlflow_client.create_dataset( 3713 name="tag_test_dataset", 3714 experiment_id=experiment_id, 3715 tags={"initial": "value", "env": "dev"}, 3716 ) 3717 3718 mlflow_client.set_dataset_tags(dataset.dataset_id, {"env": "staging", "new_tag": "new_value"}) 3719 3720 updated = mlflow_client.get_dataset(dataset.dataset_id) 3721 assert updated.tags["initial"] == "value" # Original tag preserved 3722 assert updated.tags["env"] == "staging" # Updated tag 3723 assert updated.tags["new_tag"] == "new_value" # New tag added 3724 3725 mlflow_client.delete_dataset_tag(dataset.dataset_id, "new_tag") 3726 3727 final = mlflow_client.get_dataset(dataset.dataset_id) 3728 assert "new_tag" not in final.tags 3729 assert final.tags["env"] == "staging" # Other tags preserved 3730 3731 3732 def test_evaluation_dataset_delete(mlflow_client, store_type): 3733 if store_type == "file": 3734 pytest.skip("Evaluation datasets not supported for FileStore") 3735 3736 experiment_id = mlflow_client.create_experiment("eval_delete_test") 3737 3738 dataset = mlflow_client.create_dataset( 3739 
name="delete_test_dataset", experiment_id=experiment_id, tags={"to_delete": "yes"} 3740 ) 3741 3742 retrieved = mlflow_client.get_dataset(dataset.dataset_id) 3743 assert retrieved.name == "delete_test_dataset" 3744 3745 mlflow_client.delete_dataset(dataset.dataset_id) 3746 3747 with pytest.raises(MlflowException, match="not found"): 3748 mlflow_client.get_dataset(dataset.dataset_id) 3749 3750 3751 def test_evaluation_dataset_upsert_records(mlflow_client, store_type): 3752 if store_type == "file": 3753 pytest.skip("Evaluation datasets not supported for FileStore") 3754 3755 experiment_id = mlflow_client.create_experiment("upsert_records_test") 3756 3757 dataset = mlflow_client.create_dataset( 3758 name="test_upsert_dataset", 3759 experiment_id=experiment_id, 3760 tags={"test": "upsert"}, 3761 ) 3762 3763 initial_records = [ 3764 { 3765 "inputs": {"question": "What is MLflow?"}, 3766 "expectations": {"answer": "MLflow is an ML platform"}, 3767 "tags": {"difficulty": "easy"}, 3768 }, 3769 { 3770 "inputs": {"question": "What is Python?"}, 3771 "expectations": {"answer": "Python is a programming language"}, 3772 "tags": {"difficulty": "easy"}, 3773 }, 3774 ] 3775 3776 # NB: MlflowClient doesn't have upsert_dataset_records method - merge_records() calls 3777 # the store directly. We make HTTP requests here to test the REST API handler end-to-end. 
3778 response = requests.post( 3779 f"{mlflow_client.tracking_uri}/api/3.0/mlflow/datasets/{dataset.dataset_id}/records", 3780 json={"records": json.dumps(initial_records)}, 3781 ) 3782 assert response.status_code == 200 3783 result = response.json() 3784 assert result["inserted_count"] == 2 3785 assert result["updated_count"] == 0 3786 3787 update_records = [ 3788 { 3789 "inputs": {"question": "What is MLflow?"}, 3790 "expectations": {"answer": "MLflow is an open-source ML platform"}, 3791 "tags": {"difficulty": "easy", "updated": "true"}, 3792 }, 3793 { 3794 "inputs": {"question": "What is Docker?"}, 3795 "expectations": {"answer": "Docker is a containerization platform"}, 3796 "tags": {"difficulty": "medium"}, 3797 }, 3798 ] 3799 3800 response = requests.post( 3801 f"{mlflow_client.tracking_uri}/api/3.0/mlflow/datasets/{dataset.dataset_id}/records", 3802 json={"records": json.dumps(update_records)}, 3803 ) 3804 assert response.status_code == 200 3805 result = response.json() 3806 assert result["inserted_count"] == 1 3807 assert result["updated_count"] == 1 3808 3809 response = requests.post( 3810 f"{mlflow_client.tracking_uri}/api/3.0/mlflow/datasets/invalid-id/records", 3811 json={"records": json.dumps(initial_records)}, 3812 ) 3813 assert response.status_code != 200 3814 3815 3816 def test_add_dataset_to_experiments_rest_tracking(mlflow_client, store_type): 3817 if store_type == "file": 3818 pytest.skip("File store doesn't support dataset operations") 3819 exp1 = mlflow_client.create_experiment("dataset_exp_1") 3820 exp2 = mlflow_client.create_experiment("dataset_exp_2") 3821 exp3 = mlflow_client.create_experiment("dataset_exp_3") 3822 3823 dataset = create_dataset( 3824 name="test_multi_exp_dataset", 3825 experiment_id=[exp1], 3826 tags={"test": "multi_exp"}, 3827 ) 3828 3829 assert len(dataset.experiment_ids) == 1 3830 assert exp1 in dataset.experiment_ids 3831 3832 updated_dataset = add_dataset_to_experiments( 3833 dataset_id=dataset.dataset_id, 3834 
experiment_ids=[exp2, exp3], 3835 ) 3836 3837 assert len(updated_dataset.experiment_ids) == 3 3838 assert exp1 in updated_dataset.experiment_ids 3839 assert exp2 in updated_dataset.experiment_ids 3840 assert exp3 in updated_dataset.experiment_ids 3841 3842 retrieved = mlflow_client.get_dataset(dataset.dataset_id) 3843 assert len(retrieved.experiment_ids) == 3 3844 assert exp1 in retrieved.experiment_ids 3845 assert exp2 in retrieved.experiment_ids 3846 assert exp3 in retrieved.experiment_ids 3847 3848 3849 def test_remove_dataset_from_experiments_rest_tracking(mlflow_client, store_type): 3850 if store_type == "file": 3851 pytest.skip("File store doesn't support dataset operations") 3852 exp1 = mlflow_client.create_experiment("dataset_remove_exp_1") 3853 exp2 = mlflow_client.create_experiment("dataset_remove_exp_2") 3854 exp3 = mlflow_client.create_experiment("dataset_remove_exp_3") 3855 3856 dataset = create_dataset( 3857 name="test_remove_exp_dataset", 3858 experiment_id=[exp1, exp2, exp3], 3859 tags={"test": "remove_exp"}, 3860 ) 3861 3862 assert len(dataset.experiment_ids) == 3 3863 3864 updated_dataset = remove_dataset_from_experiments( 3865 dataset_id=dataset.dataset_id, 3866 experiment_ids=[exp2], 3867 ) 3868 3869 assert len(updated_dataset.experiment_ids) == 2 3870 assert exp1 in updated_dataset.experiment_ids 3871 assert exp2 not in updated_dataset.experiment_ids 3872 assert exp3 in updated_dataset.experiment_ids 3873 3874 retrieved = mlflow_client.get_dataset(dataset.dataset_id) 3875 assert len(retrieved.experiment_ids) == 2 3876 3877 updated_dataset = remove_dataset_from_experiments( 3878 dataset_id=dataset.dataset_id, 3879 experiment_ids=[exp1, exp3], 3880 ) 3881 3882 assert len(updated_dataset.experiment_ids) == 0 3883 3884 retrieved = mlflow_client.get_dataset(dataset.dataset_id) 3885 assert len(retrieved.experiment_ids) == 0 3886 3887 3888 def test_add_multiple_experiments_at_once_rest_tracking(mlflow_client, store_type): 3889 if store_type == "file": 
3890 pytest.skip("File store doesn't support dataset operations") 3891 exps = [mlflow_client.create_experiment(f"bulk_add_exp_{i}") for i in range(5)] 3892 3893 dataset = create_dataset( 3894 name="test_bulk_add_dataset", 3895 experiment_id=[exps[0]], 3896 tags={"test": "bulk_add"}, 3897 ) 3898 3899 updated_dataset = add_dataset_to_experiments( 3900 dataset_id=dataset.dataset_id, 3901 experiment_ids=exps[1:], 3902 ) 3903 3904 assert len(updated_dataset.experiment_ids) == 5 3905 for exp in exps: 3906 assert exp in updated_dataset.experiment_ids 3907 3908 3909 def test_dataset_experiment_association_error_cases_rest_tracking(mlflow_client, store_type): 3910 if store_type == "file": 3911 pytest.skip("File store doesn't support dataset operations") 3912 exp1 = mlflow_client.create_experiment("error_test_exp") 3913 3914 with pytest.raises(MlflowException, match="not found"): 3915 add_dataset_to_experiments( 3916 dataset_id="d-nonexistent1234567890abcdef1234", 3917 experiment_ids=[exp1], 3918 ) 3919 3920 with pytest.raises(MlflowException, match="not found"): 3921 remove_dataset_from_experiments( 3922 dataset_id="d-nonexistent1234567890abcdef1234", 3923 experiment_ids=[exp1], 3924 ) 3925 3926 3927 def test_idempotent_add_experiments_rest_tracking(mlflow_client, store_type): 3928 if store_type == "file": 3929 pytest.skip("File store doesn't support dataset operations") 3930 exp1 = mlflow_client.create_experiment("idempotent_test_exp_1") 3931 exp2 = mlflow_client.create_experiment("idempotent_test_exp_2") 3932 3933 dataset = create_dataset( 3934 name="test_idempotent_dataset", 3935 experiment_id=[exp1, exp2], 3936 tags={"test": "idempotent"}, 3937 ) 3938 3939 assert len(dataset.experiment_ids) == 2 3940 3941 updated_dataset = add_dataset_to_experiments( 3942 dataset_id=dataset.dataset_id, 3943 experiment_ids=[exp1], 3944 ) 3945 3946 assert len(updated_dataset.experiment_ids) == 2 3947 assert exp1 in updated_dataset.experiment_ids 3948 assert exp2 in 
updated_dataset.experiment_ids 3949 3950 3951 def test_idempotent_remove_experiments_rest_tracking(mlflow_client, store_type): 3952 if store_type == "file": 3953 pytest.skip("File store doesn't support dataset operations") 3954 exp1 = mlflow_client.create_experiment("remove_idempotent_test_exp_1") 3955 exp2 = mlflow_client.create_experiment("remove_idempotent_test_exp_2") 3956 3957 dataset = create_dataset( 3958 name="test_remove_idempotent_dataset", 3959 experiment_id=[exp1], 3960 tags={"test": "remove_idempotent"}, 3961 ) 3962 3963 assert len(dataset.experiment_ids) == 1 3964 3965 updated_dataset = remove_dataset_from_experiments( 3966 dataset_id=dataset.dataset_id, 3967 experiment_ids=[exp2], 3968 ) 3969 3970 assert len(updated_dataset.experiment_ids) == 1 3971 assert exp1 in updated_dataset.experiment_ids 3972 3973 3974 def test_client_api_add_remove_experiments_rest_tracking(mlflow_client, store_type): 3975 if store_type == "file": 3976 pytest.skip("File store doesn't support dataset operations") 3977 exp1 = mlflow_client.create_experiment("client_api_exp_1") 3978 exp2 = mlflow_client.create_experiment("client_api_exp_2") 3979 exp3 = mlflow_client.create_experiment("client_api_exp_3") 3980 3981 dataset = mlflow_client.create_dataset( 3982 name="test_client_api_dataset", 3983 experiment_id=[exp1], 3984 tags={"test": "client_api"}, 3985 ) 3986 3987 updated_dataset = mlflow_client.add_dataset_to_experiments( 3988 dataset_id=dataset.dataset_id, 3989 experiment_ids=[exp2, exp3], 3990 ) 3991 3992 assert len(updated_dataset.experiment_ids) == 3 3993 3994 updated_dataset = mlflow_client.remove_dataset_from_experiments( 3995 dataset_id=dataset.dataset_id, 3996 experiment_ids=[exp2], 3997 ) 3998 3999 assert len(updated_dataset.experiment_ids) == 2 4000 assert exp1 in updated_dataset.experiment_ids 4001 assert exp2 not in updated_dataset.experiment_ids 4002 assert exp3 in updated_dataset.experiment_ids 4003 4004 4005 def test_scorer_CRUD(mlflow_client, store_type): 4006 
if store_type == "file": 4007 pytest.skip("File store doesn't support scorer CRUD operations") 4008 experiment_id = mlflow_client.create_experiment("test_scorer_api_experiment") 4009 4010 # Get the RestStore object directly 4011 store = mlflow_client._tracking_client.store 4012 4013 # Test register scorer 4014 scorer_data = {"name": "test_scorer", "original_func_name": "test_func"} 4015 serialized_scorer = json.dumps(scorer_data) 4016 4017 version = store.register_scorer(experiment_id, "test_scorer", serialized_scorer) 4018 assert version.scorer_version == 1 4019 4020 # Test list scorers 4021 scorers = store.list_scorers(experiment_id) 4022 assert len(scorers) == 1 4023 assert scorers[0].scorer_name == "test_scorer" 4024 assert scorers[0].scorer_version == 1 4025 4026 # Test list scorer versions 4027 versions = store.list_scorer_versions(str(experiment_id), "test_scorer") 4028 assert len(versions) == 1 4029 assert versions[0].scorer_name == "test_scorer" 4030 assert versions[0].scorer_version == 1 4031 4032 # Test get scorer (latest version) 4033 scorer = store.get_scorer(str(experiment_id), "test_scorer") 4034 assert scorer.scorer_name == "test_scorer" 4035 assert scorer.scorer_version == 1 4036 4037 # Test get scorer (specific version) 4038 scorer_v1 = store.get_scorer(str(experiment_id), "test_scorer", version=1) 4039 assert scorer_v1.scorer_name == "test_scorer" 4040 assert scorer_v1.scorer_version == 1 4041 4042 # Test register second version 4043 scorer_data_v2 = { 4044 "name": "test_scorer_v2", 4045 "original_func_name": "test_func_v2", 4046 } 4047 serialized_scorer_v2 = json.dumps(scorer_data_v2) 4048 4049 version_v2 = store.register_scorer(str(experiment_id), "test_scorer", serialized_scorer_v2) 4050 assert version_v2.scorer_version == 2 4051 4052 # Verify list scorers returns latest version 4053 scorers_after_v2 = store.list_scorers(str(experiment_id)) 4054 assert len(scorers_after_v2) == 1 4055 assert scorers_after_v2[0].scorer_version == 2 4056 4057 # 
Verify list versions returns both versions 4058 versions_after_v2 = store.list_scorer_versions(str(experiment_id), "test_scorer") 4059 assert len(versions_after_v2) == 2 4060 4061 # Test delete specific version 4062 store.delete_scorer(str(experiment_id), "test_scorer", version=1) 4063 4064 # Verify version 1 is deleted 4065 versions_after_delete = store.list_scorer_versions(str(experiment_id), "test_scorer") 4066 assert len(versions_after_delete) == 1 4067 assert versions_after_delete[0].scorer_version == 2 4068 4069 # Test delete all versions 4070 store.delete_scorer(str(experiment_id), "test_scorer") 4071 4072 # Verify all versions are deleted 4073 scorers_after_delete_all = store.list_scorers(str(experiment_id)) 4074 assert len(scorers_after_delete_all) == 0 4075 4076 # Clean up 4077 mlflow_client.delete_experiment(experiment_id) 4078 4079 4080 @pytest.mark.parametrize( 4081 "filter_string", 4082 [ 4083 "status = 'OK'", 4084 None, 4085 ], 4086 ) 4087 def test_online_scoring_config(mlflow_client_with_secrets, filter_string): 4088 """ 4089 Smoke test for online scoring configuration REST APIs. 4090 Tests upsert_online_scoring_config and get_online_scoring_configs with both 4091 string and None filter values (None is sent by UI when filter field is blank). 
4092 """ 4093 experiment_id = mlflow_client_with_secrets.create_experiment("test_online_scoring") 4094 store = mlflow_client_with_secrets._tracking_client.store 4095 4096 secret = store.create_gateway_secret( 4097 secret_name="test-secret", secret_value={"api_key": "sk-test"}, provider="openai" 4098 ) 4099 model_def = store.create_gateway_model_definition( 4100 name="test-model", secret_id=secret.secret_id, provider="openai", model_name="gpt-4" 4101 ) 4102 endpoint = store.create_gateway_endpoint( 4103 name="test-endpoint", 4104 model_configs=[ 4105 GatewayEndpointModelConfig( 4106 model_definition_id=model_def.model_definition_id, 4107 linkage_type=GatewayModelLinkageType.PRIMARY, 4108 ) 4109 ], 4110 ) 4111 4112 scorer_data = {"instructions_judge_pydantic_data": {"model": f"gateway:/{endpoint.name}"}} 4113 serialized_scorer = json.dumps(scorer_data) 4114 scorer_version = store.register_scorer(experiment_id, "my_scorer", serialized_scorer) 4115 scorer_id = scorer_version.scorer_id 4116 4117 config = store.upsert_online_scoring_config( 4118 experiment_id=experiment_id, 4119 scorer_name="my_scorer", 4120 sample_rate=0.5, 4121 filter_string=filter_string, 4122 ) 4123 assert config.scorer_id == scorer_id 4124 assert config.sample_rate == 0.5 4125 assert config.filter_string == filter_string 4126 assert config.experiment_id == experiment_id 4127 4128 configs = store.get_online_scoring_configs([scorer_id]) 4129 assert len(configs) == 1 4130 assert configs[0].scorer_id == scorer_id 4131 assert configs[0].sample_rate == 0.5 4132 assert configs[0].filter_string == filter_string 4133 4134 # Update with different filter string to test update functionality 4135 updated_filter = "status = 'COMPLETED'" 4136 updated_config = store.upsert_online_scoring_config( 4137 experiment_id=experiment_id, 4138 scorer_name="my_scorer", 4139 sample_rate=0.8, 4140 filter_string=updated_filter, 4141 ) 4142 assert updated_config.scorer_id == scorer_id 4143 assert updated_config.sample_rate == 0.8 
4144 assert updated_config.filter_string == updated_filter 4145 4146 configs_after_update = store.get_online_scoring_configs([scorer_id]) 4147 assert len(configs_after_update) == 1 4148 assert configs_after_update[0].sample_rate == 0.8 4149 assert configs_after_update[0].filter_string == updated_filter 4150 4151 4152 @pytest.mark.parametrize("use_async", [False, True]) 4153 @pytest.mark.asyncio 4154 async def test_rest_store_logs_spans_via_otel_endpoint(mlflow_client, store_type, use_async): 4155 """ 4156 End-to-end test that verifies RestStore can log spans to a running server via OTLP endpoint. 4157 4158 This test: 4159 1. Creates spans using MLflow's span entities 4160 2. Uses RestStore.log_spans or log_spans_async to send them via OTLP protocol 4161 3. Verifies the spans were stored and can be retrieved 4162 """ 4163 if store_type == "file": 4164 pytest.skip("FileStore does not support OTLP span logging") 4165 4166 experiment_id = mlflow_client.create_experiment(f"rest_store_otel_test_{use_async}") 4167 root_span = mlflow_client.start_trace( 4168 f"rest_store_otel_trace_{use_async}", experiment_id=experiment_id 4169 ) 4170 otel_span = OTelReadableSpan( 4171 name=f"test-rest-store-span-{use_async}", 4172 context=build_otel_context( 4173 trace_id=int(root_span.trace_id[3:], 16), # Remove 'tr-' prefix and convert to int 4174 span_id=0x1234567890ABCDEF, 4175 ), 4176 parent=None, 4177 start_time=1000000000, 4178 end_time=2000000000, 4179 attributes={ 4180 SpanAttributeKey.REQUEST_ID: root_span.trace_id, 4181 "test.attribute": json.dumps(f"test-value-{use_async}"), # JSON-encoded string value 4182 }, 4183 resource=None, 4184 ) 4185 mlflow_span_to_log = Span(otel_span) 4186 # Call either sync or async version based on parametrization 4187 if use_async: 4188 # Use await to execute the async method 4189 result_spans = await mlflow_client._tracking_client.store.log_spans_async( 4190 location=experiment_id, spans=[mlflow_span_to_log] 4191 ) 4192 else: 4193 result_spans = 
mlflow_client._tracking_client.store.log_spans( 4194 location=experiment_id, spans=[mlflow_span_to_log] 4195 ) 4196 4197 # Verify the spans were returned (indicates successful logging) 4198 assert len(result_spans) == 1 4199 assert result_spans[0].name == f"test-rest-store-span-{use_async}" 4200 4201 4202 # ============================================================================= 4203 # Secrets and Endpoints E2E Tests 4204 # ============================================================================= 4205 4206 4207 def test_create_and_get_secret(mlflow_client_with_secrets): 4208 store = mlflow_client_with_secrets._tracking_client.store 4209 4210 secret = store.create_gateway_secret( 4211 secret_name="test-api-key", 4212 secret_value={"api_key": "sk-test-12345"}, 4213 provider="openai", 4214 ) 4215 4216 assert secret.secret_name == "test-api-key" 4217 assert secret.provider == "openai" 4218 assert secret.secret_id is not None 4219 4220 fetched = store.get_secret_info(secret.secret_id) 4221 assert fetched.secret_name == "test-api-key" 4222 assert fetched.provider == "openai" 4223 assert fetched.secret_id == secret.secret_id 4224 4225 4226 def test_update_secret(mlflow_client_with_secrets): 4227 store = mlflow_client_with_secrets._tracking_client.store 4228 4229 secret = store.create_gateway_secret( 4230 secret_name="test-key", 4231 secret_value={"api_key": "initial-value"}, 4232 provider="anthropic", 4233 ) 4234 4235 updated = store.update_gateway_secret( 4236 secret_id=secret.secret_id, 4237 secret_value={"api_key": "updated-value"}, 4238 ) 4239 4240 assert updated.secret_id == secret.secret_id 4241 assert updated.secret_name == "test-key" 4242 4243 4244 def test_list_secret_infos(mlflow_client_with_secrets): 4245 store = mlflow_client_with_secrets._tracking_client.store 4246 4247 secret1 = store.create_gateway_secret( 4248 secret_name="openai-key", 4249 secret_value={"api_key": "sk-openai"}, 4250 provider="openai", 4251 ) 4252 store.create_gateway_secret( 4253 
secret_name="anthropic-key", 4254 secret_value={"api_key": "sk-ant"}, 4255 provider="anthropic", 4256 ) 4257 4258 all_secrets = store.list_secret_infos() 4259 assert len(all_secrets) >= 2 4260 4261 openai_secrets = store.list_secret_infos(provider="openai") 4262 assert len(openai_secrets) >= 1 4263 assert any(s.secret_id == secret1.secret_id for s in openai_secrets) 4264 4265 4266 def test_delete_secret(mlflow_client_with_secrets): 4267 store = mlflow_client_with_secrets._tracking_client.store 4268 4269 secret = store.create_gateway_secret( 4270 secret_name="temp-key", 4271 secret_value={"api_key": "temp-value"}, 4272 ) 4273 4274 store.delete_gateway_secret(secret.secret_id) 4275 4276 all_secrets = store.list_secret_infos() 4277 assert not any(s.secret_id == secret.secret_id for s in all_secrets) 4278 4279 4280 def test_create_secret_with_dict_value(mlflow_client_with_secrets): 4281 store = mlflow_client_with_secrets._tracking_client.store 4282 4283 secret = store.create_gateway_secret( 4284 secret_name="aws-creds", 4285 secret_value={"aws_access_key_id": "AKIATEST1234", "aws_secret_access_key": "secret123abc"}, 4286 provider="bedrock", 4287 ) 4288 4289 assert secret.secret_name == "aws-creds" 4290 assert secret.provider == "bedrock" 4291 assert secret.secret_id is not None 4292 assert isinstance(secret.masked_values, dict) 4293 assert secret.masked_values == { 4294 "aws_access_key_id": "AKI...1234", 4295 "aws_secret_access_key": "sec...3abc", 4296 } 4297 4298 4299 def test_update_secret_with_dict_value(mlflow_client_with_secrets): 4300 store = mlflow_client_with_secrets._tracking_client.store 4301 4302 secret = store.create_gateway_secret( 4303 secret_name="aws-creds-update", 4304 secret_value={"api_key": "initial-value-1234"}, 4305 provider="bedrock", 4306 ) 4307 4308 assert isinstance(secret.masked_values, dict) 4309 assert secret.masked_values == {"api_key": "ini...1234"} 4310 4311 updated = store.update_gateway_secret( 4312 secret_id=secret.secret_id, 4313 
secret_value={ 4314 "aws_access_key_id": "NEWKEY123456", 4315 "aws_secret_access_key": "newsecret1234", 4316 }, 4317 ) 4318 4319 assert updated.secret_id == secret.secret_id 4320 assert updated.secret_name == "aws-creds-update" 4321 assert isinstance(updated.masked_values, dict) 4322 assert updated.masked_values == { 4323 "aws_access_key_id": "NEW...3456", 4324 "aws_secret_access_key": "new...1234", 4325 } 4326 4327 4328 def test_create_and_update_compound_secret_via_rest(mlflow_client_with_secrets): 4329 store = mlflow_client_with_secrets._tracking_client.store 4330 4331 secret = store.create_gateway_secret( 4332 secret_name="bedrock-aws-creds", 4333 secret_value={ 4334 "aws_access_key_id": "AKIAORIGINAL1234", 4335 "aws_secret_access_key": "original-secret-key-1234", 4336 }, 4337 provider="bedrock", 4338 auth_config={"auth_mode": "access_keys", "aws_region_name": "us-east-1"}, 4339 ) 4340 4341 assert secret.secret_name == "bedrock-aws-creds" 4342 assert secret.provider == "bedrock" 4343 assert isinstance(secret.masked_values, dict) 4344 assert secret.masked_values == { 4345 "aws_access_key_id": "AKI...1234", 4346 "aws_secret_access_key": "ori...1234", 4347 } 4348 4349 fetched = store.get_secret_info(secret_id=secret.secret_id) 4350 assert fetched.secret_id == secret.secret_id 4351 assert isinstance(fetched.masked_values, dict) 4352 assert fetched.masked_values == secret.masked_values 4353 4354 updated = store.update_gateway_secret( 4355 secret_id=secret.secret_id, 4356 secret_value={ 4357 "aws_access_key_id": "AKIAROTATED5678", 4358 "aws_secret_access_key": "rotated-secret-key-5678", 4359 }, 4360 ) 4361 4362 assert updated.secret_id == secret.secret_id 4363 assert updated.last_updated_at > secret.created_at 4364 assert isinstance(updated.masked_values, dict) 4365 assert updated.masked_values == { 4366 "aws_access_key_id": "AKI...5678", 4367 "aws_secret_access_key": "rot...5678", 4368 } 4369 4370 4371 def test_create_and_get_endpoint(mlflow_client_with_secrets): 
4372 store = mlflow_client_with_secrets._tracking_client.store 4373 4374 secret = store.create_gateway_secret( 4375 secret_name="test-api-key", 4376 secret_value={"api_key": "sk-test-12345"}, 4377 provider="openai", 4378 ) 4379 secret2 = store.create_gateway_secret( 4380 secret_name="test-api-key-fallback", 4381 secret_value={"api_key": "sk-test-67890"}, 4382 provider="anthropic", 4383 ) 4384 4385 model_def = store.create_gateway_model_definition( 4386 name="test-model-def", 4387 secret_id=secret.secret_id, 4388 provider="openai", 4389 model_name="gpt-4", 4390 ) 4391 model_def_fallback = store.create_gateway_model_definition( 4392 name="test-model-def-fallback", 4393 secret_id=secret2.secret_id, 4394 provider="anthropic", 4395 model_name="claude-3-5-sonnet", 4396 ) 4397 4398 endpoint = store.create_gateway_endpoint( 4399 name="test-endpoint", 4400 model_configs=[ 4401 GatewayEndpointModelConfig( 4402 model_definition_id=model_def.model_definition_id, 4403 linkage_type=GatewayModelLinkageType.PRIMARY, 4404 weight=1.0, 4405 ), 4406 GatewayEndpointModelConfig( 4407 model_definition_id=model_def_fallback.model_definition_id, 4408 linkage_type=GatewayModelLinkageType.FALLBACK, 4409 weight=1.0, 4410 fallback_order=0, 4411 ), 4412 ], 4413 routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT, 4414 fallback_config=FallbackConfig( 4415 strategy=FallbackStrategy.SEQUENTIAL, 4416 max_attempts=2, 4417 ), 4418 ) 4419 4420 assert endpoint.name == "test-endpoint" 4421 assert endpoint.endpoint_id is not None 4422 assert len(endpoint.model_mappings) == 2 4423 assert endpoint.model_mappings[0].model_definition.model_name == "gpt-4" 4424 assert endpoint.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT 4425 assert endpoint.fallback_config is not None 4426 assert endpoint.fallback_config.strategy == FallbackStrategy.SEQUENTIAL 4427 assert endpoint.fallback_config.max_attempts == 2 4428 4429 fetched = store.get_gateway_endpoint(endpoint.endpoint_id) 4430 assert 
fetched.name == "test-endpoint"
    assert fetched.endpoint_id == endpoint.endpoint_id
    assert len(fetched.model_mappings) == 2
    assert fetched.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert fetched.fallback_config is not None
    assert fetched.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert fetched.fallback_config.max_attempts == 2


def test_create_endpoint_with_usage_tracking(mlflow_client_with_secrets):
    """An endpoint created with usage_tracking=True gets a backing `gateway/<name>` experiment."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="usage-tracking-test-key",
        secret_value={"api_key": "sk-usage-tracking-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="usage-tracking-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="usage-tracking-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            )
        ],
        usage_tracking=True,
    )

    assert endpoint.usage_tracking is True
    experiment_id = endpoint.experiment_id

    # Experiment is automatically created with usage tracking enabled
    experiment = mlflow_client_with_secrets.get_experiment(experiment_id)
    assert experiment.name == "gateway/usage-tracking-endpoint"


def test_update_endpoint(mlflow_client_with_secrets):
    """update_gateway_endpoint can rename an endpoint, add a fallback model,
    and set routing strategy and fallback config in one call."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="test-api-key-2",
        secret_value={"api_key": "sk-test-67890"},
        provider="anthropic",
    )
    secret2 = store.create_gateway_secret(
        secret_name="test-api-key-2-fallback",
        secret_value={"api_key": "sk-test-99999"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="test-model-def-2",
        secret_id=secret.secret_id,
        provider="anthropic",
        model_name="claude-3-5-sonnet",
    )
    model_def_fallback = store.create_gateway_model_definition(
        name="test-model-def-2-fallback",
        secret_id=secret2.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="initial-name",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    updated = store.update_gateway_endpoint(
        endpoint_id=endpoint.endpoint_id,
        name="updated-name",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
            GatewayEndpointModelConfig(
                model_definition_id=model_def_fallback.model_definition_id,
                linkage_type=GatewayModelLinkageType.FALLBACK,
                weight=1.0,
                fallback_order=0,
            ),
        ],
        routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT,
        fallback_config=FallbackConfig(
            strategy=FallbackStrategy.SEQUENTIAL,
            max_attempts=3,
        ),
    )

    assert updated.endpoint_id == endpoint.endpoint_id
    assert updated.name == "updated-name"
    assert updated.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert updated.fallback_config is not None
    assert updated.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert updated.fallback_config.max_attempts == 3
    assert len(updated.model_mappings) == 2


def test_list_endpoints(mlflow_client_with_secrets):
    """list_gateway_endpoints returns endpoints both with and without fallback
    config, preserving their routing/fallback settings."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret1 = store.create_gateway_secret(
        secret_name="test-api-key-3",
        secret_value={"api_key": "sk-test-11111"},
        provider="openai",
    )
    secret2 = store.create_gateway_secret(
        secret_name="test-api-key-4",
        secret_value={"api_key": "sk-test-22222"},
        provider="openai",
    )
    secret3 = store.create_gateway_secret(
        secret_name="test-api-key-fallback-3",
        secret_value={"api_key": "sk-test-44444"},
        provider="anthropic",
    )

    model_def1 = store.create_gateway_model_definition(
        name="test-model-def-3",
        secret_id=secret1.secret_id,
        provider="openai",
        model_name="gpt-4",
    )
    model_def2 = store.create_gateway_model_definition(
        name="test-model-def-4",
        secret_id=secret2.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )
    model_def3 = store.create_gateway_model_definition(
        name="test-model-def-fallback-3",
        secret_id=secret3.secret_id,
        provider="anthropic",
        model_name="claude-3-5-sonnet",
    )

    # Create endpoint without fallback
    endpoint1 = store.create_gateway_endpoint(
        name="endpoint-1",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )
    # Create endpoint with fallback
    endpoint2 = store.create_gateway_endpoint(
        name="endpoint-2",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def2.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
            GatewayEndpointModelConfig(
                model_definition_id=model_def3.model_definition_id,
                linkage_type=GatewayModelLinkageType.FALLBACK,
                weight=1.0,
                fallback_order=0,
            ),
        ],
        routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT,
        fallback_config=FallbackConfig(
            strategy=FallbackStrategy.SEQUENTIAL,
            max_attempts=2,
        ),
    )

    all_endpoints = store.list_gateway_endpoints()
    # >= because other tests sharing the fixture may have created endpoints too.
    assert len(all_endpoints) >= 2
    endpoint_ids = {e.endpoint_id for e in all_endpoints}
    assert endpoint1.endpoint_id in endpoint_ids
    assert endpoint2.endpoint_id in endpoint_ids

    # Find and verify endpoints
    found_ep1 = next(e for e in all_endpoints if e.endpoint_id == endpoint1.endpoint_id)
    found_ep2 = next(e for e in all_endpoints if e.endpoint_id == endpoint2.endpoint_id)

    assert found_ep1.routing_strategy is None
    assert found_ep1.fallback_config is None

    assert found_ep2.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert found_ep2.fallback_config is not None
    assert found_ep2.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert found_ep2.fallback_config.max_attempts == 2


def test_delete_endpoint(mlflow_client_with_secrets):
    """A deleted endpoint no longer appears in list_gateway_endpoints."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="test-api-key-5",
        secret_value={"api_key": "sk-test-33333"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="test-model-def-5",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="temp-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    store.delete_gateway_endpoint(endpoint.endpoint_id)

    all_endpoints = store.list_gateway_endpoints()
    assert not any(e.endpoint_id == endpoint.endpoint_id for e in all_endpoints)


def test_model_definitions(mlflow_client_with_secrets):
    """Full CRUD round-trip (create/get/update/list/delete) for gateway model definitions."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="model-secret",
        secret_value={"api_key": "sk-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="test-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    assert model_def.name == "test-model-def"
    assert model_def.secret_id == secret.secret_id
    assert model_def.provider == "openai"
    assert model_def.model_name == "gpt-4"
    assert model_def.model_definition_id is not None

    fetched = store.get_gateway_model_definition(model_def.model_definition_id)
    assert fetched.model_definition_id == model_def.model_definition_id
    assert fetched.name == "test-model-def"

    updated = store.update_gateway_model_definition(
        model_definition_id=model_def.model_definition_id,
        model_name="gpt-4-turbo",
    )
    assert updated.model_definition_id == model_def.model_definition_id
    assert updated.model_name == "gpt-4-turbo"

    all_defs = store.list_gateway_model_definitions()
    assert any(d.model_definition_id == model_def.model_definition_id for d in all_defs)

    store.delete_gateway_model_definition(model_def.model_definition_id)

    all_defs_after = store.list_gateway_model_definitions()
    assert not any(d.model_definition_id == model_def.model_definition_id for d in all_defs_after)


def test_attach_detach_model_to_endpoint(mlflow_client_with_secrets):
    """Models can be attached to and detached from an existing endpoint,
    and the endpoint's mapping count reflects each operation."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="attach-detach-secret",
        secret_value={"api_key": "sk-test-attach"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="attach-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    model_def2 = store.create_gateway_model_definition(
        name="attach-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    endpoint = store.create_gateway_endpoint(
        name="attach-test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    assert len(endpoint.model_mappings) == 1
    assert endpoint.model_mappings[0].model_definition.model_name == "gpt-4"

    mapping = store.attach_model_to_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_config=GatewayEndpointModelConfig(
            model_definition_id=model_def2.model_definition_id,
            linkage_type=GatewayModelLinkageType.PRIMARY,
            weight=1.0,
        ),
    )

    assert mapping.endpoint_id == endpoint.endpoint_id
    assert mapping.model_definition_id == model_def2.model_definition_id

    fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint.model_mappings) == 2

    store.detach_model_from_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_definition_id=model_def2.model_definition_id,
    )

    fetched_endpoint_after = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint_after.model_mappings) == 1


def test_endpoint_bindings(mlflow_client_with_secrets):
    """Endpoint bindings can be created, filtered by endpoint/resource-type/
    resource-id (individually and combined), and deleted."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="binding-secret",
        secret_value={"api_key": "sk-test-44444"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="binding-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    model_def2 = store.create_gateway_model_definition(
        name="binding-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    endpoint1 = store.create_gateway_endpoint(
        name="binding-test-endpoint-1",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    endpoint2 = store.create_gateway_endpoint(
        name="binding-test-endpoint-2",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def2.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    binding1 = store.create_endpoint_binding(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-123",
    )

    binding2 = store.create_endpoint_binding(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-456",
    )

    binding3 = store.create_endpoint_binding(
        endpoint_id=endpoint2.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-789",
    )

    assert binding1.endpoint_id == endpoint1.endpoint_id
    assert binding1.resource_type == GatewayResourceType.SCORER
    assert binding1.resource_id == "job-123"

    # Filter by endpoint: only endpoint1's two bindings come back.
    bindings_endpoint1 = store.list_endpoint_bindings(endpoint_id=endpoint1.endpoint_id)
    assert len(bindings_endpoint1) == 2
    resource_ids = {b.resource_id for b in bindings_endpoint1}
    assert binding1.resource_id in resource_ids
    assert binding2.resource_id in resource_ids
    assert binding3.resource_id not in resource_ids

    bindings_by_type = store.list_endpoint_bindings(resource_type=GatewayResourceType.SCORER)
    assert len(bindings_by_type) >= 3

    bindings_by_resource = store.list_endpoint_bindings(resource_id="job-123")
    assert len(bindings_by_resource) == 1
    assert bindings_by_resource[0].resource_id == binding1.resource_id

    # Filters combine conjunctively.
    bindings_multi = store.list_endpoint_bindings(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
    )
    assert len(bindings_multi) == 2

    store.delete_endpoint_binding(
        endpoint_id=binding1.endpoint_id,
        resource_type=binding1.resource_type.value,
        resource_id=binding1.resource_id,
    )

    bindings_after = store.list_endpoint_bindings(endpoint_id=endpoint1.endpoint_id)
    assert len(bindings_after) == 1
    assert not any(b.resource_id == binding1.resource_id for b in bindings_after)


def test_secrets_and_endpoints_integration(mlflow_client_with_secrets):
    """End-to-end lifecycle across secrets, model definitions, endpoints,
    attachments, and bindings, including full teardown."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="integration-test-key",
        secret_value={"api_key": "sk-integration-test"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="integration-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    model_def2 = store.create_gateway_model_definition(
        name="integration-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="integration-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    mapping = store.attach_model_to_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_config=GatewayEndpointModelConfig(
            model_definition_id=model_def2.model_definition_id,
            linkage_type=GatewayModelLinkageType.PRIMARY,
            weight=1.0,
        ),
    )

    binding = store.create_endpoint_binding(
        endpoint_id=endpoint.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="integration-job",
    )
4918 4919 fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id) 4920 assert len(fetched_endpoint.model_mappings) == 2 4921 mapping_ids = {m.mapping_id for m in fetched_endpoint.model_mappings} 4922 assert mapping.mapping_id in mapping_ids 4923 4924 bindings = store.list_endpoint_bindings(resource_id="integration-job") 4925 assert len(bindings) == 1 4926 assert bindings[0].resource_id == binding.resource_id 4927 4928 store.delete_endpoint_binding( 4929 endpoint_id=binding.endpoint_id, 4930 resource_type=binding.resource_type.value, 4931 resource_id=binding.resource_id, 4932 ) 4933 store.detach_model_from_endpoint( 4934 endpoint_id=endpoint.endpoint_id, 4935 model_definition_id=model_def2.model_definition_id, 4936 ) 4937 store.delete_gateway_endpoint(endpoint.endpoint_id) 4938 store.delete_gateway_model_definition(model_def1.model_definition_id) 4939 store.delete_gateway_model_definition(model_def2.model_definition_id) 4940 store.delete_gateway_secret(secret.secret_id) 4941 4942 4943 def test_list_providers(mlflow_client_with_secrets): 4944 import requests 4945 4946 base_url = mlflow_client_with_secrets._tracking_client.tracking_uri 4947 response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-providers") 4948 assert response.status_code == 200 4949 data = response.json() 4950 assert "providers" in data 4951 assert isinstance(data["providers"], list) 4952 assert len(data["providers"]) > 0 4953 assert "openai" in data["providers"] 4954 4955 4956 def test_list_models(mlflow_client_with_secrets): 4957 import requests 4958 4959 base_url = mlflow_client_with_secrets._tracking_client.tracking_uri 4960 response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-models") 4961 assert response.status_code == 200 4962 data = response.json() 4963 assert "models" in data 4964 assert isinstance(data["models"], list) 4965 assert len(data["models"]) > 0 4966 4967 model = data["models"][0] 4968 assert "model" in model 4969 assert 
"provider" in model 4970 assert "mode" in model 4971 assert all(not m["model"].startswith("ft:") for m in data["models"]) 4972 4973 response = requests.get( 4974 f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-models", params={"provider": "openai"} 4975 ) 4976 assert response.status_code == 200 4977 filtered_data = response.json() 4978 assert all(m["provider"] == "openai" for m in filtered_data["models"]) 4979 4980 4981 def test_get_provider_config(mlflow_client_with_secrets): 4982 import requests 4983 4984 base_url = mlflow_client_with_secrets._tracking_client.tracking_uri 4985 4986 # Test simple provider (openai) - should have single api_key auth mode 4987 response = requests.get( 4988 f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config", 4989 params={"provider": "openai"}, 4990 ) 4991 assert response.status_code == 200 4992 data = response.json() 4993 assert "auth_modes" in data 4994 assert "default_mode" in data 4995 assert data["default_mode"] == "api_key" 4996 assert len(data["auth_modes"]) >= 1 4997 api_key_mode = data["auth_modes"][0] 4998 assert api_key_mode["mode"] == "api_key" 4999 5000 # Test multi-mode provider (bedrock) - should have multiple auth modes 5001 response = requests.get( 5002 f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config", 5003 params={"provider": "bedrock"}, 5004 ) 5005 assert response.status_code == 200 5006 data = response.json() 5007 assert "auth_modes" in data 5008 assert data["default_mode"] == "api_key" 5009 assert len(data["auth_modes"]) >= 2 # api_key, access_keys, iam_role 5010 5011 # Check access_keys mode structure 5012 access_keys_mode = next(m for m in data["auth_modes"] if m["mode"] == "access_keys") 5013 assert len(access_keys_mode["secret_fields"]) == 2 # access_key_id, secret_access_key 5014 assert any(f["name"] == "aws_secret_access_key" for f in access_keys_mode["secret_fields"]) 5015 assert any(f["name"] == "aws_region_name" for f in access_keys_mode["config_fields"]) 5016 5017 # Check iam_role mode 
exists 5018 iam_role_mode = next(m for m in data["auth_modes"] if m["mode"] == "iam_role") 5019 assert any(f["name"] == "aws_role_name" for f in iam_role_mode["config_fields"]) 5020 5021 # Unknown providers get a generic fallback 5022 response = requests.get( 5023 f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config", 5024 params={"provider": "unknown_provider"}, 5025 ) 5026 assert response.status_code == 200 5027 data = response.json() 5028 assert data["default_mode"] == "api_key" 5029 assert data["auth_modes"][0]["mode"] == "api_key" 5030 assert data["auth_modes"][0]["config_fields"][0]["name"] == "api_base" 5031 5032 # Missing provider parameter returns 400 5033 response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config") 5034 assert response.status_code == 400 5035 5036 5037 def test_get_secrets_config_with_custom_passphrase(mlflow_client_with_secrets): 5038 base_url = mlflow_client_with_secrets._tracking_client.tracking_uri 5039 5040 response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/secrets/config") 5041 assert response.status_code == 200 5042 data = response.json() 5043 assert data["secrets_available"] is True 5044 assert data["using_default_passphrase"] is False 5045 5046 5047 def test_get_secrets_config_with_default_passphrase(tmp_path: Path, monkeypatch): 5048 from tests.tracking.integration_test_utils import ServerThread, get_safe_port 5049 5050 monkeypatch.delenv("MLFLOW_CRYPTO_KEK_PASSPHRASE", raising=False) 5051 5052 backend_uri = f"sqlite:///{tmp_path}/mlflow.db" 5053 artifact_uri = (tmp_path / "artifacts").as_uri() 5054 5055 store = SqlAlchemyStore(backend_uri, artifact_uri) 5056 store.engine.dispose() 5057 5058 handlers._tracking_store = None 5059 handlers._model_registry_store = None 5060 initialize_backend_stores(backend_uri, default_artifact_root=artifact_uri) 5061 5062 with ServerThread(app, get_safe_port()) as url: 5063 response = requests.get(f"{url}/ajax-api/3.0/mlflow/gateway/secrets/config") 5064 
        assert response.status_code == 200
        data = response.json()
        assert data["secrets_available"] is True
        assert data["using_default_passphrase"] is True


def test_endpoint_with_orphaned_model_definition(mlflow_client_with_secrets):
    """Deleting a secret orphans the model definition: the endpoint keeps its
    mapping, but the mapping's secret_id/secret_name become None."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="orphan-test-key",
        secret_value={"api_key": "sk-orphan-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="orphan-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="orphan-test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    assert len(endpoint.model_mappings) == 1
    assert endpoint.model_mappings[0].model_definition.secret_id == secret.secret_id
    assert endpoint.model_mappings[0].model_definition.secret_name == "orphan-test-key"

    store.delete_gateway_secret(secret.secret_id)

    # Mapping survives the secret deletion with its secret fields nulled out.
    fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint.model_mappings) == 1
    assert fetched_endpoint.model_mappings[0].model_definition.secret_id is None
    assert fetched_endpoint.model_mappings[0].model_definition.secret_name is None


def test_update_model_definition_provider(mlflow_client_with_secrets):
    """update_gateway_model_definition can switch provider and model name
    together, and the change persists on re-fetch."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="provider-update-secret",
        secret_value={"api_key": "sk-provider-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="provider-update-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    assert model_def.provider == "openai"
    assert model_def.model_name == "gpt-4"

    updated = store.update_gateway_model_definition(
        model_definition_id=model_def.model_definition_id,
        provider="anthropic",
        model_name="claude-3-5-haiku-latest",
    )

    assert updated.provider == "anthropic"
    assert updated.model_name == "claude-3-5-haiku-latest"

    fetched = store.get_gateway_model_definition(model_def.model_definition_id)
    assert fetched.provider == "anthropic"
    assert fetched.model_name == "claude-3-5-haiku-latest"

    store.delete_gateway_model_definition(model_def.model_definition_id)
    store.delete_gateway_secret(secret.secret_id)


def test_create_issue_with_all_fields(mlflow_client, store_type):
    """POST /issues with every optional field set echoes them all back."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("Issue Test")
    run = mlflow_client.create_run(experiment_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "High latency issue",
            "description": "API calls are taking too long",
            "status": IssueStatus.PENDING.value,
            "source_run_id": run.info.run_id,
            "root_causes": ["Database query inefficiency", "Network latency"],
            "severity": IssueSeverity.HIGH.value,
            "created_by": "test-user",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "issue" in data
    issue = data["issue"]
    assert issue["experiment_id"] == experiment_id
    assert issue["name"] == "High latency issue"
    assert issue["description"] == "API calls are taking too long"
    assert issue["status"] == IssueStatus.PENDING.value
    assert issue["source_run_id"] == run.info.run_id
    assert issue["root_causes"] == ["Database query inefficiency", "Network latency"]
    assert issue["severity"] == IssueSeverity.HIGH.value
    assert issue["created_by"] == "test-user"
    assert "issue_id" in issue
    assert "created_timestamp" in issue
    assert "last_updated_timestamp" in issue


def test_create_issue_minimal_fields(mlflow_client, store_type):
    """POST /issues with only required fields defaults the status to PENDING."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Minimal")

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Test issue",
            "description": "Test description",
        },
    )
    assert response.status_code == 200
    data = response.json()
    issue = data["issue"]
    assert issue["experiment_id"] == experiment_id
    assert issue["name"] == "Test issue"
    assert issue["description"] == "Test description"
    assert issue["status"] == IssueStatus.PENDING.value
    assert "issue_id" in issue


def test_create_issue_with_required_fields(mlflow_client, store_type):
    """POST /issues honors an explicitly supplied status value."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Required Fields")

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Issue with required fields only",
            "description": "Testing issue creation with required fields",
            "status": IssueStatus.RESOLVED.value,
        },
    )
    assert response.status_code == 200
    data = response.json()
    issue = data["issue"]
    assert issue["status"] == IssueStatus.RESOLVED.value
    assert "issue_id" in issue
    assert "created_timestamp" in issue
    assert "last_updated_timestamp" in issue


def test_create_issue_invalid_experiment(mlflow_client, store_type):
    """POST /issues against a nonexistent experiment returns 404."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": "999999",
            "name": "Test issue",
            "description": "Test description",
        },
    )
    assert response.status_code == 404
    data = response.json()
    assert data["error_code"] == "RESOURCE_DOES_NOT_EXIST"


def test_get_issue(mlflow_client, store_type):
    """GET /issues/{id} returns the previously created issue."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Get")

    create_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Test issue",
            "description": "Test description",
            "severity": IssueSeverity.MEDIUM.value,
        },
    )
    issue_id = create_response.json()["issue"]["issue_id"]

    get_response = requests.get(f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/{issue_id}")
    assert get_response.status_code == 200
    data = get_response.json()
    issue = data["issue"]
    assert issue["issue_id"] == issue_id
    assert issue["name"] == "Test issue"
    assert issue["severity"] == IssueSeverity.MEDIUM.value


def test_get_issue_not_found(mlflow_client, store_type):
    """GET /issues/{id} for an unknown id returns 404."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    response = requests.get(f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/nonexistent-issue")
    assert response.status_code == 404
    data = response.json()
    assert data["error_code"] == "RESOURCE_DOES_NOT_EXIST"


def test_update_issue(mlflow_client, store_type):
    """PATCH /issues/{id} updates name, description, status, and severity."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id
= mlflow_client.create_experiment("Issue Test Update")

    create_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Original name",
            "description": "Original description",
            "status": IssueStatus.PENDING.value,
        },
    )
    issue_id = create_response.json()["issue"]["issue_id"]

    update_response = requests.patch(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/{issue_id}",
        json={
            "issue_id": issue_id,
            "name": "Updated name",
            "description": "Updated description",
            "status": IssueStatus.RESOLVED.value,
            "severity": IssueSeverity.HIGH.value,
        },
    )
    assert update_response.status_code == 200
    data = update_response.json()
    issue = data["issue"]
    assert issue["issue_id"] == issue_id
    assert issue["name"] == "Updated name"
    assert issue["description"] == "Updated description"
    assert issue["status"] == IssueStatus.RESOLVED.value
    assert issue["severity"] == IssueSeverity.HIGH.value


def test_search_issues_no_filters(mlflow_client, store_type):
    """POST /issues/search with an empty body returns every issue."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Search")

    for i in range(3):
        requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search", json={}
    )
    assert search_response.status_code == 200
    data = search_response.json()
    assert "issues" in data
    assert len(data["issues"]) == 3
    assert {issue["name"] for issue in data["issues"]} == {"Issue 0", "Issue 1", "Issue 2"}
    assert {issue["status"] for issue in data["issues"]} == {IssueStatus.PENDING.value}


def test_search_issues_by_experiment(mlflow_client, store_type):
    """Search filtered by experiment_id returns only that experiment's issues."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    exp1 = mlflow_client.create_experiment("Issue Test Search Exp1")
    exp2 = mlflow_client.create_experiment("Issue Test Search Exp2")

    requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": exp1,
            "name": "Issue in exp1",
            "description": "Description",
        },
    )

    requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": exp2,
            "name": "Issue in exp2",
            "description": "Description",
        },
    )

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": exp1},
    )
    assert search_response.status_code == 200
    data = search_response.json()
    issues = data["issues"]
    assert len(issues) == 1
    assert issues[0]["experiment_id"] == exp1
    assert issues[0]["name"] == "Issue in exp1"


def test_search_issues_by_status(mlflow_client, store_type):
    """Search supports a status predicate in filter_string."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Search Status")

    requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Draft issue",
            "description": "Description",
            "status": IssueStatus.PENDING.value,
        },
    )

    requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Confirmed issue",
            "description": "Description",
            "status": IssueStatus.RESOLVED.value,
        },
    )

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id, "filter_string": "status = 'resolved'"},
    )
    assert search_response.status_code == 200
    data = search_response.json()
    issues = data["issues"]
    assert all(issue["status"] == IssueStatus.RESOLVED.value for issue in issues)
    assert any(issue["name"] == "Confirmed issue" for issue in issues)


def test_search_issues_with_pagination(mlflow_client, store_type):
    """Search paginates via max_results and next_page_token; the final page
    carries an empty token."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Pagination")

    for i in range(15):
        requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )

    first_page = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id, "max_results": 10},
    )
    assert first_page.status_code == 200
    first_data = first_page.json()
    assert len(first_data["issues"]) == 10
    assert "next_page_token" in first_data
    assert first_data["next_page_token"] != ""

    second_page = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={
            "experiment_id": experiment_id,
            "max_results": 10,
            "page_token": first_data["next_page_token"],
        },
    )
    assert second_page.status_code == 200
    second_data = second_page.json()
    assert len(second_data["issues"]) == 5
    assert second_data["next_page_token"] == ""


def test_search_issues_sorted_by_timestamp(mlflow_client, store_type):
    """Search returns all issues created with distinct timestamps.

    NOTE(review): only set-equality of ids is asserted here; ordering is
    not actually verified despite the comment below — confirm intent.
    """
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Sort")

    # Create issues with slight delays to ensure different timestamps
    issue_ids = []
    for i in range(3):
        response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )
        issue_ids.append(response.json()["issue"]["issue_id"])
        time.sleep(0.01)  # Small delay to ensure different timestamps

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id},
    )
    assert search_response.status_code == 200
    data = search_response.json()
    issues = data["issues"]
    assert len(issues) == 3
    # Issues should be returned (default order is by created_timestamp descending)
    assert {issue["issue_id"] for issue in issues} == set(issue_ids)