fluent.py
1 """ 2 Internal module implementing the fluent API, allowing management of an active 3 MLflow run. This module is exposed to users at the top-level :py:mod:`mlflow` module. 4 """ 5 6 import atexit 7 import contextlib 8 import importlib 9 import inspect 10 import io 11 import logging 12 import os 13 import threading 14 from copy import deepcopy 15 from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Union, overload 16 17 import mlflow 18 from mlflow.entities import Dataset as DatasetEntity 19 from mlflow.entities import ( 20 DatasetInput, 21 Experiment, 22 InputTag, 23 LoggedModel, 24 LoggedModelInput, 25 LoggedModelOutput, 26 LoggedModelStatus, 27 Metric, 28 Param, 29 Run, 30 RunInputs, 31 RunStatus, 32 RunTag, 33 ViewType, 34 ) 35 from mlflow.entities.lifecycle_stage import LifecycleStage 36 from mlflow.entities.trace_location import UnityCatalog 37 from mlflow.environment_variables import ( 38 _MLFLOW_ACTIVE_MODEL_ID, 39 _MLFLOW_ENABLE_SGC_RUN_RESUMPTION_FOR_DATABRICKS_JOBS, 40 MLFLOW_ACTIVE_MODEL_ID, 41 MLFLOW_ENABLE_ASYNC_LOGGING, 42 MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING, 43 MLFLOW_EXPERIMENT_ID, 44 MLFLOW_EXPERIMENT_NAME, 45 MLFLOW_RUN_ID, 46 MLFLOW_TRACING_SQL_WAREHOUSE_ID, 47 ) 48 from mlflow.exceptions import MlflowException 49 from mlflow.protos.databricks_pb2 import ( 50 INVALID_PARAMETER_VALUE, 51 RESOURCE_DOES_NOT_EXIST, 52 ) 53 from mlflow.store.tracking import SEARCH_MAX_RESULTS_DEFAULT 54 from mlflow.telemetry.events import AutologgingEvent 55 from mlflow.telemetry.track import _record_event 56 from mlflow.tracing.provider import ( 57 _get_trace_exporter, 58 ) 59 from mlflow.tracking._tracking_service.client import TrackingServiceClient 60 from mlflow.tracking._tracking_service.utils import _resolve_tracking_uri 61 from mlflow.utils import get_results_from_paginated_fn 62 from mlflow.utils.annotations import experimental 63 from mlflow.utils.async_logging.run_operations import RunOperations 64 from mlflow.utils.autologging_utils import ( 65 AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED, 66 AUTOLOGGING_INTEGRATIONS, 67 autologging_conf_lock, 68 autologging_integration, 69 autologging_is_disabled, 70 is_testing, 71 ) 72 from mlflow.utils.databricks_utils import ( 73 get_sgc_job_run_id, 74 is_in_databricks_model_serving_environment, 75 is_in_databricks_runtime, 76 ) 77 from mlflow.utils.file_utils import TempDir 78 from mlflow.utils.import_hooks import register_post_import_hook 79 from mlflow.utils.mlflow_tags import ( 80 MLFLOW_DATABRICKS_SGC_RESUME_RUN_JOB_RUN_ID_PREFIX, 81 MLFLOW_DATASET_CONTEXT, 82 MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH, 83 MLFLOW_EXPERIMENT_PRIMARY_METRIC_GREATER_IS_BETTER, 84 MLFLOW_EXPERIMENT_PRIMARY_METRIC_NAME, 85 MLFLOW_MODEL_IS_EXTERNAL, 86 MLFLOW_PARENT_RUN_ID, 87 MLFLOW_RUN_NAME, 88 MLFLOW_RUN_NOTE, 89 ) 90 from mlflow.utils.thread_utils import ThreadLocalVariable 91 from mlflow.utils.time import get_current_time_millis 92 from mlflow.utils.uri import is_databricks_uri 93 from mlflow.utils.validation import ( 94 _validate_experiment_id_type, 95 _validate_logged_model_name, 96 _validate_run_id, 97 ) 98 from mlflow.version import IS_TRACING_SDK_ONLY 99 100 if not IS_TRACING_SDK_ONLY: 101 from mlflow.data.dataset import Dataset 102 from mlflow.tracking import _get_artifact_repo, _get_store, artifact_utils 103 from mlflow.tracking.client import MlflowClient 104 from mlflow.tracking.context import registry as context_registry 105 from mlflow.tracking.default_experiment import registry as default_experiment_registry 106 107 108 if TYPE_CHECKING: 109 import matplotlib.figure 110 import numpy 111 import pandas 112 import PIL 113 import plotly 114 115 116 _active_experiment_id = None 117 118 SEARCH_MAX_RESULTS_PANDAS = 100000 119 NUM_RUNS_PER_PAGE_PANDAS = 10000 120 121 _logger = logging.getLogger(__name__) 122 123 124 run_id_to_system_metrics_monitor = {} 125 126 127 _active_run_stack = ThreadLocalVariable(default_factory=lambda: []) 128 129 _last_active_run_id = ThreadLocalVariable(default_factory=lambda: None) 130 _last_logged_model_id = ThreadLocalVariable(default_factory=lambda: None) 131 132 133 def _reset_last_logged_model_id() -> None: 134 """ 135 Should be called only for testing purposes. 136 """ 137 _last_logged_model_id.set(None) 138 139 140 _experiment_lock = threading.Lock() 141 142 143 def set_experiment( 144 experiment_name: str | None = None, 145 experiment_id: str | None = None, 146 trace_location: UnityCatalog | None = None, 147 ) -> Experiment: 148 """ 149 Set the given experiment as the active experiment. The experiment must either be specified by 150 name via `experiment_name` or by ID via `experiment_id`. The experiment name and ID cannot 151 both be specified. 152 153 .. note:: 154 If the experiment being set by name does not exist, a new experiment will be 155 created with the given name. After the experiment has been created, it will be set 156 as the active experiment. On certain platforms, such as Databricks, the experiment name 157 must be an absolute path, e.g. ``"/Users/<username>/my-experiment"``. 158 159 Args: 160 experiment_name: Case sensitive name of the experiment to be activated. 161 experiment_id: ID of the experiment to be activated. If an experiment with this ID 162 does not exist, an exception is thrown. 163 trace_location: Optional UC trace location used to configure the experiment-derived 164 tracing destination. Must be an instance of 165 ``mlflow.entities.trace_location.UnityCatalog(...)``. 166 167 Returns: 168 An instance of :py:class:`mlflow.entities.Experiment` representing the new active 169 experiment. 170 171 .. code-block:: python 172 :test: 173 :caption: Example 174 175 import mlflow 176 177 # Set an experiment name, which must be unique and case-sensitive. 178 experiment = mlflow.set_experiment("Social NLP Experiments") 179 # Get Experiment Details 180 print(f"Experiment_id: {experiment.experiment_id}") 181 print(f"Artifact Location: {experiment.artifact_location}") 182 print(f"Tags: {experiment.tags}") 183 print(f"Lifecycle_stage: {experiment.lifecycle_stage}") 184 185 .. code-block:: text 186 :caption: Output 187 188 Experiment_id: 1 189 Artifact Location: file:///.../mlruns/1 190 Tags: {} 191 Lifecycle_stage: active 192 """ 193 if (experiment_name is not None and experiment_id is not None) or ( 194 experiment_name is None and experiment_id is None 195 ): 196 raise MlflowException( 197 message="Must specify exactly one of: `experiment_id` or `experiment_name`.", 198 error_code=INVALID_PARAMETER_VALUE, 199 ) 200 201 client = TrackingServiceClient(_resolve_tracking_uri()) 202 203 is_newly_created = False 204 205 with _experiment_lock: 206 if experiment_id is None: 207 experiment = client.get_experiment_by_name(experiment_name) 208 if not experiment: 209 _logger.info( 210 "Experiment with name '%s' does not exist. Creating a new experiment.", 211 experiment_name, 212 ) 213 try: 214 experiment_id = client.create_experiment(experiment_name) 215 except MlflowException as e: 216 if e.error_code == "RESOURCE_ALREADY_EXISTS": 217 # NB: If two simultaneous processes attempt to set the same experiment 218 # simultaneously, a race condition may be encountered here wherein 219 # experiment creation fails 220 return client.get_experiment_by_name(experiment_name) 221 raise 222 223 experiment = client.get_experiment(experiment_id) 224 is_newly_created = True 225 else: 226 experiment = client.get_experiment(experiment_id) 227 if experiment is None: 228 raise MlflowException( 229 message=f"Experiment with ID '{experiment_id}' does not exist.", 230 error_code=RESOURCE_DOES_NOT_EXIST, 231 ) 232 233 if experiment.lifecycle_stage != LifecycleStage.ACTIVE: 234 raise MlflowException( 235 message=( 236 f"Cannot set a deleted experiment {experiment.name!r} as the active" 237 " experiment. " 238 "You can restore the experiment, or permanently delete the " 239 "experiment to create a new one." 240 ), 241 error_code=INVALID_PARAMETER_VALUE, 242 ) 243 244 if trace_location is not None and trace_location.table_prefix is None: 245 trace_location = UnityCatalog( 246 catalog_name=trace_location.catalog_name, 247 schema_name=trace_location.schema_name, 248 table_prefix=experiment.experiment_id, 249 ) 250 251 try: 252 resolved_location = _resolve_experiment_to_trace_location( 253 experiment=experiment, 254 trace_location=trace_location, 255 ) 256 except MlflowException as e: 257 if is_newly_created and trace_location is not None: 258 raise MlflowException.invalid_parameter_value( 259 f"Experiment '{experiment.name}' (ID: {experiment.experiment_id}) was created " 260 f"but linking to trace location '{trace_location.full_table_prefix}' failed: " 261 f"{e.message} Please fix the issue and call set_experiment again to retry." 262 ) from e 263 raise 264 265 global _active_experiment_id 266 _active_experiment_id = experiment.experiment_id 267 268 # Set 'MLFLOW_EXPERIMENT_ID' environment variable 269 # so that subprocess can inherit it. 270 MLFLOW_EXPERIMENT_ID.set(_active_experiment_id) 271 if resolved_location is not None: 272 experiment.trace_location = resolved_location 273 274 _sync_trace_destination_and_provider(resolved_location) 275 276 return experiment 277 278 279 def _sync_trace_destination_and_provider( 280 resolved_location: UnityCatalog | None, 281 ) -> None: 282 from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, provider 283 284 # If the tracer provider has already been initialized, reset it so the 285 # next trace re-derives the correct processor chain from the new experiment. 286 if provider.once._done: 287 provider.reset() 288 289 _MLFLOW_TRACE_USER_DESTINATION.set(resolved_location) 290 291 292 def _resolve_experiment_to_trace_location( 293 experiment: Experiment, 294 trace_location: UnityCatalog | None, 295 ) -> UnityCatalog | None: 296 """Resolve the trace destination for an experiment without mutating state. 297 298 All validation and network calls happen here. The caller is responsible 299 for committing the result (setting experiment-derived destination, etc.). 300 301 Returns: 302 The resolved UnityCatalog location if one was configured, or None. 303 """ 304 if trace_location is None: 305 return None 306 if not isinstance(trace_location, UnityCatalog): 307 raise MlflowException.invalid_parameter_value( 308 "`trace_location` must be an instance of `mlflow.entities.trace_location.UnityCatalog`." 309 ) 310 311 if not is_databricks_uri(_resolve_tracking_uri()): 312 raise MlflowException.invalid_parameter_value( 313 "`trace_location` is only supported with a Databricks tracking URI." 314 ) 315 316 # Check if experiment is already linked via the destination path tag (no backend call). 317 if destination_path := experiment.tags.get(MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH): 318 if destination_path == trace_location.full_table_prefix: 319 return experiment.trace_location 320 raise MlflowException.invalid_parameter_value( 321 f"Experiment '{experiment.name}' is already linked to a different " 322 f"trace location '{destination_path}'." 323 ) 324 325 # No existing link — register and link via backend. 326 from mlflow.tracing.client import TracingClient 327 328 tracing_client = TracingClient() 329 resolved = tracing_client._create_or_get_trace_location( 330 trace_location, 331 MLFLOW_TRACING_SQL_WAREHOUSE_ID.get(), 332 ) 333 tracing_client._link_trace_location( 334 experiment_id=experiment.experiment_id, 335 location=resolved, 336 ) 337 return resolved 338 339 340 def _set_experiment_primary_metric( 341 experiment_id: str, primary_metric: str, greater_is_better: bool 342 ): 343 client = MlflowClient() 344 client.set_experiment_tag(experiment_id, MLFLOW_EXPERIMENT_PRIMARY_METRIC_NAME, primary_metric) 345 client.set_experiment_tag( 346 experiment_id, MLFLOW_EXPERIMENT_PRIMARY_METRIC_GREATER_IS_BETTER, str(greater_is_better) 347 ) 348 349 350 class ActiveRun(Run): 351 """Wrapper around :py:class:`mlflow.entities.Run` to enable using Python ``with`` syntax.""" 352 353 def __init__(self, run): 354 Run.__init__(self, run.info, run.data) 355 356 def __enter__(self): 357 return self 358 359 def __exit__(self, exc_type, exc_val, exc_tb): 360 active_run_stack = _active_run_stack.get() 361 362 # Check if the run is still active. We check based on ID instead of 363 # using referential equality, because some tools (e.g. AutoML) may 364 # stop a run and start it again with the same ID to restore session state 365 if any(r.info.run_id == self.info.run_id for r in active_run_stack): 366 status = RunStatus.FINISHED if exc_type is None else RunStatus.FAILED 367 end_run(RunStatus.to_string(status)) 368 369 return exc_type is None 370 371 372 def _get_sgc_job_run_id_tag_key() -> str | None: 373 """ 374 Get the SGC job run ID tag key for run resumption if enabled and available. 375 376 Returns: 377 str or None: The experiment tag key for SGC resumption, or None if not applicable. 378 """ 379 if not _MLFLOW_ENABLE_SGC_RUN_RESUMPTION_FOR_DATABRICKS_JOBS.get(): 380 return None 381 382 if sgc_job_run_id := get_sgc_job_run_id(): 383 return f"{MLFLOW_DATABRICKS_SGC_RESUME_RUN_JOB_RUN_ID_PREFIX}.{sgc_job_run_id}" 384 385 return None 386 387 388 def _get_sgc_mlflow_run_id_for_resumption( 389 client, experiment_id: str | None, sgc_job_run_id_tag_key: str | None 390 ) -> str | None: 391 """ 392 Retrieves the MLflow run ID associated with a specific SGC job run ID tag key 393 for potential run resumption. 394 395 This function searches the experiment (specified by `experiment_id`, or the 396 default if None) for an experiment tag named `sgc_job_run_id_tag_key`. If the 397 tag exists, its value (the run ID to resume) is returned; otherwise, returns None. 398 399 Args: 400 client: MlflowClient instance used to query experiment information. 401 experiment_id: The experiment ID to search, or None to use the default. 402 sgc_job_run_id_tag_key: The experiment tag key that maps the SGC job run ID 403 to an MLflow run ID. 404 405 Returns: 406 str or None: The MLflow run ID to resume, if found; otherwise None. 407 """ 408 search_exp_id = experiment_id or _get_experiment_id() 409 410 try: 411 exp = client.get_experiment(search_exp_id) 412 # Check if experiment has the tag for resumption 413 if prev_mlflow_run_id := exp.tags.get(sgc_job_run_id_tag_key): 414 _logger.info( 415 f"Resuming MLflow run: {prev_mlflow_run_id} " 416 f"using SGC tag key: {sgc_job_run_id_tag_key}" 417 ) 418 return prev_mlflow_run_id 419 except Exception as e: 420 _logger.debug(f"Failed to retrieve SGC run ID: {e}", exc_info=True) 421 422 return None 423 424 425 def start_run( 426 run_id: str | None = None, 427 experiment_id: str | None = None, 428 run_name: str | None = None, 429 nested: bool = False, 430 parent_run_id: str | None = None, 431 tags: dict[str, Any] | None = None, 432 description: str | None = None, 433 log_system_metrics: bool | None = None, 434 ) -> ActiveRun: 435 """ 436 Start a new MLflow run, setting it as the active run under which metrics and parameters 437 will be logged. The return value can be used as a context manager within a ``with`` block; 438 otherwise, you must call ``end_run()`` to terminate the current run. 439 440 If you pass a ``run_id`` or the ``MLFLOW_RUN_ID`` environment variable is set, 441 ``start_run`` attempts to resume a run with the specified run ID and 442 other parameters are ignored. ``run_id`` takes precedence over ``MLFLOW_RUN_ID``. 443 444 If resuming an existing run, the run status is set to ``RunStatus.RUNNING``. 445 446 MLflow sets a variety of default tags on the run, as defined in 447 `MLflow system tags <../../tracking/tracking-api.html#system_tags>`_. 448 449 Args: 450 run_id: If specified, get the run with the specified UUID and log parameters 451 and metrics under that run. The run's end time is unset and its status 452 is set to running, but the run's other attributes (``source_version``, 453 ``source_type``, etc.) are not changed. 454 experiment_id: ID of the experiment under which to create the current run (applicable 455 only when ``run_id`` is not specified). If ``experiment_id`` argument 456 is unspecified, will look for valid experiment in the following order: 457 activated using ``set_experiment``, ``MLFLOW_EXPERIMENT_NAME`` 458 environment variable, ``MLFLOW_EXPERIMENT_ID`` environment variable, 459 or the default experiment as defined by the tracking server. 460 run_name: Name of new run, should be a non-empty string. Used only when ``run_id`` is 461 unspecified. If a new run is created and ``run_name`` is not specified, 462 a random name will be generated for the run. 463 nested: Controls whether run is nested in parent run. ``True`` creates a nested run. 464 parent_run_id: If specified, the current run will be nested under the the run with 465 the specified UUID. The parent run must be in the ACTIVE state. 466 tags: An optional dictionary of string keys and values to set as tags on the run. 467 If a run is being resumed, these tags are set on the resumed run. If a new run is 468 being created, these tags are set on the new run. 469 description: An optional string that populates the description box of the run. 470 If a run is being resumed, the description is set on the resumed run. 471 If a new run is being created, the description is set on the new run. 472 log_system_metrics: bool, defaults to None. If True, system metrics will be logged 473 to MLflow, e.g., cpu/gpu utilization. If None, we will check environment variable 474 `MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING` to determine whether to log system metrics. 475 System metrics logging is an experimental feature in MLflow 2.8 and subject to change. 476 477 Returns: 478 :py:class:`mlflow.ActiveRun` object that acts as a context manager wrapping the 479 run's state. 480 481 .. code-block:: python 482 :test: 483 :caption: Example 484 485 import mlflow 486 487 # Create nested runs 488 experiment_id = mlflow.create_experiment("experiment1") 489 with mlflow.start_run( 490 run_name="PARENT_RUN", 491 experiment_id=experiment_id, 492 tags={"version": "v1", "priority": "P1"}, 493 description="parent", 494 ) as parent_run: 495 mlflow.log_param("parent", "yes") 496 with mlflow.start_run( 497 run_name="CHILD_RUN", 498 experiment_id=experiment_id, 499 description="child", 500 nested=True, 501 ) as child_run: 502 mlflow.log_param("child", "yes") 503 print("parent run:") 504 print(f"run_id: {parent_run.info.run_id}") 505 print("description: {}".format(parent_run.data.tags.get("mlflow.note.content"))) 506 print("version tag value: {}".format(parent_run.data.tags.get("version"))) 507 print("priority tag value: {}".format(parent_run.data.tags.get("priority"))) 508 print("--") 509 510 # Search all child runs with a parent id 511 query = f"tags.mlflow.parentRunId = '{parent_run.info.run_id}'" 512 results = mlflow.search_runs(experiment_ids=[experiment_id], filter_string=query) 513 print("child runs:") 514 print(results[["run_id", "params.child", "tags.mlflow.runName"]]) 515 516 # Create a nested run under the existing parent run 517 with mlflow.start_run( 518 run_name="NEW_CHILD_RUN", 519 experiment_id=experiment_id, 520 description="new child", 521 parent_run_id=parent_run.info.run_id, 522 ) as child_run: 523 mlflow.log_param("new-child", "yes") 524 525 .. code-block:: text 526 :caption: Output 527 528 parent run: 529 run_id: 8979459433a24a52ab3be87a229a9cdf 530 description: starting a parent for experiment 7 531 version tag value: v1 532 priority tag value: P1 533 -- 534 child runs: 535 run_id params.child tags.mlflow.runName 536 0 7d175204675e40328e46d9a6a5a7ee6a yes CHILD_RUN 537 """ 538 active_run_stack = _active_run_stack.get() 539 _validate_experiment_id_type(experiment_id) 540 # back compat for int experiment_id 541 experiment_id = str(experiment_id) if isinstance(experiment_id, int) else experiment_id 542 if len(active_run_stack) > 0 and not nested: 543 raise Exception( 544 ( 545 "Run with UUID {} is already active. To start a new run, first end the " 546 + "current run with mlflow.end_run(). To start a nested " 547 + "run, call start_run with nested=True" 548 ).format(active_run_stack[0].info.run_id) 549 ) 550 client = MlflowClient() 551 sgc_job_run_id_tag_key: str | None = None 552 if run_id: 553 existing_run_id = run_id 554 elif run_id := MLFLOW_RUN_ID.get(): 555 existing_run_id = run_id 556 del os.environ[MLFLOW_RUN_ID.name] 557 # Get SGC job run ID tag key for run resumption if applicable 558 elif sgc_job_run_id_tag_key := _get_sgc_job_run_id_tag_key(): 559 existing_run_id = _get_sgc_mlflow_run_id_for_resumption( 560 client, experiment_id, sgc_job_run_id_tag_key 561 ) 562 else: 563 existing_run_id = None 564 if existing_run_id: 565 _validate_run_id(existing_run_id) 566 active_run_obj = client.get_run(existing_run_id) 567 # Check to see if experiment_id from environment matches experiment_id from set_experiment() 568 if ( 569 _active_experiment_id is not None 570 and _active_experiment_id != active_run_obj.info.experiment_id 571 ): 572 raise MlflowException( 573 f"Cannot start run with ID {existing_run_id} because active experiment ID " 574 "does not match environment run ID. Make sure --experiment-name " 575 "or --experiment-id matches experiment set with " 576 "set_experiment(), or just use command-line arguments" 577 ) 578 # Check if the current run has been deleted. 579 if active_run_obj.info.lifecycle_stage == LifecycleStage.DELETED: 580 raise MlflowException( 581 f"Cannot start run with ID {existing_run_id} because it is in the deleted state." 582 ) 583 # Use previous `end_time` because a value is required for `update_run_info`. 584 end_time = active_run_obj.info.end_time 585 _get_store().update_run_info( 586 existing_run_id, run_status=RunStatus.RUNNING, end_time=end_time, run_name=run_name 587 ) 588 tags = tags or {} 589 if description: 590 if MLFLOW_RUN_NOTE in tags: 591 raise MlflowException( 592 f"Description is already set via the tag {MLFLOW_RUN_NOTE} in tags." 593 f"Remove the key {MLFLOW_RUN_NOTE} from the tags or omit the description.", 594 error_code=INVALID_PARAMETER_VALUE, 595 ) 596 tags[MLFLOW_RUN_NOTE] = description 597 598 if tags: 599 client.log_batch( 600 run_id=existing_run_id, 601 tags=[RunTag(key, str(value)) for key, value in tags.items()], 602 ) 603 active_run_obj = client.get_run(existing_run_id) 604 else: 605 if parent_run_id: 606 _validate_run_id(parent_run_id) 607 # Make sure parent_run_id matches the current run id, if there is an active run 608 if len(active_run_stack) > 0 and parent_run_id != active_run_stack[-1].info.run_id: 609 current_run_id = active_run_stack[-1].info.run_id 610 raise MlflowException( 611 f"Current run with UUID {current_run_id} does not match the specified " 612 f"parent_run_id {parent_run_id}. To start a new nested run under " 613 f"the parent run with UUID {current_run_id}, first end the current run " 614 "with mlflow.end_run()." 615 ) 616 parent_run_obj = client.get_run(parent_run_id) 617 # Check if the specified parent_run has been deleted. 618 if parent_run_obj.info.lifecycle_stage == LifecycleStage.DELETED: 619 raise MlflowException( 620 f"Cannot start run under parent run with ID {parent_run_id} " 621 f"because it is in the deleted state." 622 ) 623 else: 624 parent_run_id = active_run_stack[-1].info.run_id if len(active_run_stack) > 0 else None 625 626 exp_id_for_run = experiment_id if experiment_id is not None else _get_experiment_id() 627 628 user_specified_tags = deepcopy(tags) or {} 629 if description: 630 if MLFLOW_RUN_NOTE in user_specified_tags: 631 raise MlflowException( 632 f"Description is already set via the tag {MLFLOW_RUN_NOTE} in tags." 633 f"Remove the key {MLFLOW_RUN_NOTE} from the tags or omit the description.", 634 error_code=INVALID_PARAMETER_VALUE, 635 ) 636 user_specified_tags[MLFLOW_RUN_NOTE] = description 637 if parent_run_id is not None: 638 user_specified_tags[MLFLOW_PARENT_RUN_ID] = parent_run_id 639 if run_name: 640 user_specified_tags[MLFLOW_RUN_NAME] = run_name 641 642 resolved_tags = context_registry.resolve_tags(user_specified_tags) 643 644 active_run_obj = client.create_run( 645 experiment_id=exp_id_for_run, 646 tags=resolved_tags, 647 run_name=run_name, 648 ) 649 650 # If SGC run resumption is enabled, set the experiment tag mapping 651 # SGC job_run_id to this run_id for future run resumption 652 if sgc_job_run_id_tag_key: 653 try: 654 client.set_experiment_tag( 655 exp_id_for_run, sgc_job_run_id_tag_key, active_run_obj.info.run_id 656 ) 657 _logger.info( 658 f"Set experiment tag {sgc_job_run_id_tag_key} = {active_run_obj.info.run_id} " 659 f"for SGC run resumption" 660 ) 661 except Exception as e: 662 _logger.debug( 663 f"Failed to set experiment tag for SGC resumption: {e}", exc_info=True 664 ) 665 666 if log_system_metrics is None: 667 # If `log_system_metrics` is not specified, we will check environment variable. 668 log_system_metrics = MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING.get() 669 670 if log_system_metrics: 671 if importlib.util.find_spec("psutil") is None: 672 raise MlflowException( 673 "Failed to start system metrics monitoring as package `psutil` is not installed. " 674 "Please run `pip install psutil` to resolve the issue, otherwise you can disable " 675 "system metrics logging by passing `log_system_metrics=False` to " 676 "`mlflow.start_run()` or calling `mlflow.disable_system_metrics_logging`." 677 ) 678 try: 679 from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor 680 681 system_monitor = SystemMetricsMonitor( 682 active_run_obj.info.run_id, 683 resume_logging=existing_run_id is not None, 684 ) 685 run_id_to_system_metrics_monitor[active_run_obj.info.run_id] = system_monitor 686 system_monitor.start() 687 except Exception as e: 688 _logger.error(f"Failed to start system metrics monitoring: {e}.") 689 690 active_run_stack.append(ActiveRun(active_run_obj)) 691 return active_run_stack[-1] 692 693 694 def end_run(status: str = RunStatus.to_string(RunStatus.FINISHED)) -> None: 695 """ 696 End an active MLflow run (if there is one). 697 698 .. code-block:: python 699 :test: 700 :caption: Example 701 702 import mlflow 703 704 # Start run and get status 705 mlflow.start_run() 706 run = mlflow.active_run() 707 print(f"run_id: {run.info.run_id}; status: {run.info.status}") 708 709 # End run and get status 710 mlflow.end_run() 711 run = mlflow.get_run(run.info.run_id) 712 print(f"run_id: {run.info.run_id}; status: {run.info.status}") 713 print("--") 714 715 # Check for any active runs 716 print(f"Active run: {mlflow.active_run()}") 717 718 .. code-block:: text 719 :caption: Output 720 721 run_id: b47ee4563368419880b44ad8535f6371; status: RUNNING 722 run_id: b47ee4563368419880b44ad8535f6371; status: FINISHED 723 -- 724 Active run: None 725 """ 726 active_run_stack = _active_run_stack.get() 727 if len(active_run_stack) > 0: 728 # Clear out the global existing run environment variable as well. 729 MLFLOW_RUN_ID.unset() 730 run = active_run_stack.pop() 731 last_active_run_id = run.info.run_id 732 _last_active_run_id.set(last_active_run_id) 733 MlflowClient().set_terminated(last_active_run_id, status) 734 if last_active_run_id in run_id_to_system_metrics_monitor: 735 system_metrics_monitor = run_id_to_system_metrics_monitor.pop(last_active_run_id) 736 system_metrics_monitor.finish() 737 738 739 def _safe_end_run(): 740 with contextlib.suppress(Exception): 741 end_run() 742 743 744 atexit.register(_safe_end_run) 745 746 747 def active_run() -> ActiveRun | None: 748 """ 749 Get the currently active ``Run``, or None if no such run exists. 750 751 .. attention:: 752 This API is **thread-local** and returns only the active run in the current thread. 753 If your application is multi-threaded and a run is started in a different thread, 754 this API will not retrieve that run. 755 756 **Note**: You cannot access currently-active run attributes 757 (parameters, metrics, etc.) through the run returned by ``mlflow.active_run``. In order 758 to access such attributes, use the :py:class:`mlflow.client.MlflowClient` as follows: 759 760 .. code-block:: python 761 :test: 762 :caption: Example 763 764 import mlflow 765 766 mlflow.start_run() 767 run = mlflow.active_run() 768 print(f"Active run_id: {run.info.run_id}") 769 mlflow.end_run() 770 771 .. code-block:: text 772 :caption: Output 773 774 Active run_id: 6f252757005748708cd3aad75d1ff462 775 """ 776 active_run_stack = _active_run_stack.get() 777 return active_run_stack[-1] if len(active_run_stack) > 0 else None 778 779 780 def last_active_run() -> Run | None: 781 """Gets the most recent active run. 782 783 Examples: 784 785 .. code-block:: python 786 :test: 787 :caption: To retrieve the most recent autologged run: 788 789 import mlflow 790 791 from sklearn.model_selection import train_test_split 792 from sklearn.datasets import load_diabetes 793 from sklearn.ensemble import RandomForestRegressor 794 795 mlflow.autolog() 796 797 db = load_diabetes() 798 X_train, X_test, y_train, y_test = train_test_split(db.data, db.target) 799 800 # Create and train models. 801 rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3) 802 rf.fit(X_train, y_train) 803 804 # Use the model to make predictions on the test dataset. 805 predictions = rf.predict(X_test) 806 autolog_run = mlflow.last_active_run() 807 808 .. code-block:: python 809 :test: 810 :caption: To get the most recently active run that ended: 811 812 import mlflow 813 814 mlflow.start_run() 815 mlflow.end_run() 816 run = mlflow.last_active_run() 817 818 .. code-block:: python 819 :test: 820 :caption: To retrieve the currently active run: 821 822 import mlflow 823 824 mlflow.start_run() 825 run = mlflow.last_active_run() 826 mlflow.end_run() 827 828 Returns: 829 The active run (this is equivalent to ``mlflow.active_run()``) if one exists. 830 Otherwise, the last run started from the current Python process that reached 831 a terminal status (i.e. FINISHED, FAILED, or KILLED). 832 """ 833 _active_run = active_run() 834 if _active_run is not None: 835 return _active_run 836 837 last_active_run_id = _last_active_run_id.get() 838 if last_active_run_id is None: 839 return None 840 return get_run(last_active_run_id) 841 842 843 def _get_latest_active_run(): 844 """ 845 Get active run from global context by checking all threads. The `mlflow.active_run` API 846 only returns active run from current thread. This API is useful for the case where one 847 needs to get a run started from a separate thread. 848 """ 849 all_active_runs = [ 850 run for run_stack in _active_run_stack.get_all_thread_values().values() for run in run_stack 851 ] 852 if all_active_runs: 853 return max(all_active_runs, key=lambda run: run.info.start_time) 854 return None 855 856 857 def get_run(run_id: str) -> Run: 858 """ 859 Fetch the run from backend store. The resulting Run contains a collection of run metadata -- 860 RunInfo as well as a collection of run parameters, tags, and metrics -- RunData. It also 861 contains a collection of run inputs (experimental), including information about datasets used by 862 the run -- RunInputs. In the case where multiple metrics with the same key are logged for the 863 run, the RunData contains the most recently logged value at the largest step for each metric. 864 865 Args: 866 run_id: Unique identifier for the run. 867 868 Returns: 869 A single Run object, if the run exists. Otherwise, raises an exception. 870 871 .. code-block:: python 872 :test: 873 :caption: Example 874 875 import mlflow 876 877 with mlflow.start_run() as run: 878 mlflow.log_param("p", 0) 879 run_id = run.info.run_id 880 print(f"run_id: {run_id}; lifecycle_stage: {mlflow.get_run(run_id).info.lifecycle_stage}") 881 882 .. code-block:: text 883 :caption: Output 884 885 run_id: 7472befefc754e388e8e922824a0cca5; lifecycle_stage: active 886 """ 887 return MlflowClient().get_run(run_id) 888 889 890 def get_parent_run(run_id: str) -> Run | None: 891 """Gets the parent run for the given run id if one exists. 892 893 Args: 894 run_id: Unique identifier for the child run. 895 896 Returns: 897 A single :py:class:`mlflow.entities.Run` object, if the parent run exists. Otherwise, 898 returns None. 899 900 .. code-block:: python 901 :test: 902 :caption: Example 903 904 import mlflow 905 906 # Create nested runs 907 with mlflow.start_run(): 908 with mlflow.start_run(nested=True) as child_run: 909 child_run_id = child_run.info.run_id 910 911 parent_run = mlflow.get_parent_run(child_run_id) 912 913 print(f"child_run_id: {child_run_id}") 914 print(f"parent_run_id: {parent_run.info.run_id}") 915 916 .. code-block:: text 917 :caption: Output 918 919 child_run_id: 7d175204675e40328e46d9a6a5a7ee6a 920 parent_run_id: 8979459433a24a52ab3be87a229a9cdf 921 """ 922 return MlflowClient().get_parent_run(run_id) 923 924 925 def log_param(key: str, value: Any, synchronous: bool | None = None) -> Any: 926 """ 927 Log a parameter (e.g. model hyperparameter) under the current run. If no run is active, 928 this method will create a new active run. 929 930 Args: 931 key: Parameter name. This string may only contain alphanumerics, underscores (_), dashes 932 (-), periods (.), spaces ( ), and slashes (/). All backend stores support keys up to 933 length 250, but some may support larger keys. 934 value: Parameter value, but will be string-ified if not. All built-in backend stores support 935 values up to length 6000, but some may support larger values. 936 synchronous: *Experimental* If True, blocks until the parameter is logged successfully. If 937 False, logs the parameter asynchronously and returns a future representing the logging 938 operation. If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`, 939 which defaults to False if not set. 940 941 Returns: 942 When `synchronous=True`, returns parameter value. When `synchronous=False`, returns an 943 :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that represents 944 future for logging operation. 945 946 .. code-block:: python 947 :test: 948 :caption: Example 949 950 import mlflow 951 952 with mlflow.start_run(): 953 value = mlflow.log_param("learning_rate", 0.01) 954 assert value == 0.01 955 value = mlflow.log_param("learning_rate", 0.02, synchronous=False) 956 """ 957 run_id = _get_or_start_run().info.run_id 958 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 959 return MlflowClient().log_param(run_id, key, value, synchronous=synchronous) 960 961 962 def flush_async_logging() -> None: 963 """Flush all pending async logging.""" 964 _get_store().flush_async_logging() 965 966 967 def _shut_down_async_logging() -> None: 968 """Shutdown the async logging and flush all pending data.""" 969 _get_store().shut_down_async_logging() 970 971 972 def flush_artifact_async_logging() -> None: 973 """Flush all pending artifact async logging.""" 974 run_id = _get_or_start_run().info.run_id 975 if _artifact_repo := _get_artifact_repo(run_id): 976 _artifact_repo.flush_async_logging() 977 978 979 def flush_trace_async_logging(terminate=False) -> None: 980 """ 981 Flush all pending trace async logging. 982 983 Args: 984 terminate: If True, shut down the logging threads after flushing. 985 """ 986 # Flush ALL batch span processors and their exporters' async queues. 987 # When set_destination() is called multiple times, each call creates a new 988 # tracer provider, processor, and exporter. The registry tracks all of them 989 # so we drain both layers: span queue → exporter → async DB write queue. 990 from mlflow.tracing.processor.base_mlflow import flush_all_batch_processors 991 992 try: 993 flush_all_batch_processors(terminate=terminate) 994 except Exception as e: 995 _logger.debug(f"Failed to flush batch processors: {e}", exc_info=True) 996 997 # When batch processor is disabled (no registry entries), the current exporter 998 # may still have an _async_queue that needs draining (SimpleSpanProcessor path). 999 try: 1000 if trace_exporter := _get_trace_exporter(): 1001 if hasattr(trace_exporter, "_async_queue"): 1002 trace_exporter._async_queue.flush(terminate=terminate) 1003 except Exception as e: 1004 _logger.debug(f"Failed to flush trace exporter async queue: {e}", exc_info=True) 1005 1006 1007 def set_experiment_tag(key: str, value: Any) -> None: 1008 """ 1009 Set a tag on the current experiment. Value is converted to a string. 1010 1011 Args: 1012 key: Tag name. This string may only contain alphanumerics, underscores (_), dashes (-), 1013 periods (.), spaces ( ), and slashes (/). All backend stores will support keys up to 1014 length 250, but some may support larger keys. 1015 value: Tag value, but will be string-ified if not. All backend stores will support values 1016 up to length 5000, but some may support larger values. 1017 1018 .. code-block:: python 1019 :test: 1020 :caption: Example 1021 1022 import mlflow 1023 1024 with mlflow.start_run(): 1025 mlflow.set_experiment_tag("release.version", "2.2.0") 1026 """ 1027 experiment_id = _get_experiment_id() 1028 MlflowClient().set_experiment_tag(experiment_id, key, value) 1029 1030 1031 def delete_experiment_tag(key: str) -> None: 1032 """ 1033 Delete a tag from the current experiment. 1034 1035 Args: 1036 key: Name of the tag to be deleted. 1037 1038 .. code-block:: python 1039 :test: 1040 :caption: Example 1041 1042 import mlflow 1043 1044 exp = mlflow.set_experiment("test-delete-tag") 1045 mlflow.set_experiment_tag("release.version", "1.0") 1046 mlflow.delete_experiment_tag("release.version") 1047 exp = mlflow.get_experiment(exp.experiment_id) 1048 assert "release.version" not in exp.tags 1049 """ 1050 experiment_id = _get_experiment_id() 1051 MlflowClient().delete_experiment_tag(experiment_id, key) 1052 1053 1054 def set_tag(key: str, value: Any, synchronous: bool | None = None) -> RunOperations | None: 1055 """ 1056 Set a tag under the current run. If no run is active, this method will create a new active 1057 run. 1058 1059 Args: 1060 key: Tag name. This string may only contain alphanumerics, underscores (_), dashes (-), 1061 periods (.), spaces ( ), and slashes (/). All backend stores will support keys up to 1062 length 250, but some may support larger keys. 1063 value: Tag value, but will be string-ified if not. All backend stores will support values 1064 up to length 5000, but some may support larger values. 1065 synchronous: *Experimental* If True, blocks until the tag is logged successfully. If False, 1066 logs the tag asynchronously and returns a future representing the logging operation. 1067 If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which 1068 defaults to False if not set. 1069 1070 Returns: 1071 When `synchronous=True`, returns None. When `synchronous=False`, returns an 1072 :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that 1073 represents future for logging operation. 1074 1075 .. code-block:: python 1076 :test: 1077 :caption: Example 1078 1079 import mlflow 1080 1081 # Set a tag. 1082 with mlflow.start_run(): 1083 mlflow.set_tag("release.version", "2.2.0") 1084 1085 # Set a tag in async fashion. 1086 with mlflow.start_run(): 1087 mlflow.set_tag("release.version", "2.2.1", synchronous=False) 1088 """ 1089 run_id = _get_or_start_run().info.run_id 1090 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 1091 return MlflowClient().set_tag(run_id, key, value, synchronous=synchronous) 1092 1093 1094 def delete_tag(key: str) -> None: 1095 """ 1096 Delete a tag from a run. This is irreversible. If no run is active, this method 1097 will create a new active run. 1098 1099 Args: 1100 key: Name of the tag 1101 1102 .. code-block:: python 1103 :test: 1104 :caption: Example 1105 1106 import mlflow 1107 1108 tags = {"engineering": "ML Platform", "engineering_remote": "ML Platform"} 1109 1110 with mlflow.start_run() as run: 1111 mlflow.set_tags(tags) 1112 1113 with mlflow.start_run(run_id=run.info.run_id): 1114 mlflow.delete_tag("engineering_remote") 1115 """ 1116 run_id = _get_or_start_run().info.run_id 1117 MlflowClient().delete_tag(run_id, key) 1118 1119 1120 def log_metric( 1121 key: str, 1122 value: float, 1123 step: int | None = None, 1124 synchronous: bool | None = None, 1125 timestamp: int | None = None, 1126 run_id: str | None = None, 1127 model_id: str | None = None, 1128 dataset: Union["Dataset", DatasetEntity] | None = None, 1129 ) -> RunOperations | None: 1130 """ 1131 Log a metric under the current run. If no run is active, this method will create 1132 a new active run. 1133 1134 Args: 1135 key: Metric name. This string may only contain alphanumerics, underscores (_), 1136 dashes (-), periods (.), spaces ( ), and slashes (/). 1137 All backend stores will support keys up to length 250, but some may 1138 support larger keys. 1139 value: Metric value. Note that some special values such as +/- Infinity may be 1140 replaced by other values depending on the store. For example, the 1141 SQLAlchemy store replaces +/- Infinity with max / min float values. 1142 All backend stores will support values up to length 5000, but some 1143 may support larger values. 1144 step: Metric step. Defaults to zero if unspecified. 1145 synchronous: *Experimental* If True, blocks until the metric is logged 1146 successfully. If False, logs the metric asynchronously and 1147 returns a future representing the logging operation. If None, read from environment 1148 variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set. 1149 timestamp: Time when this metric was calculated. Defaults to the current system time. 1150 run_id: If specified, log the metric to the specified run. If not specified, log the metric 1151 to the currently active run. 1152 model_id: The ID of the model associated with the metric. If not specified, use the current 1153 active model ID set by :py:func:`mlflow.set_active_model`. If no active model exists, 1154 the models IDs associated with the specified or active run will be used. 1155 dataset: The dataset associated with the metric. 1156 1157 Returns: 1158 When `synchronous=True`, returns None. 1159 When `synchronous=False`, returns `RunOperations` that represents future for 1160 logging operation. 1161 1162 .. code-block:: python 1163 :test: 1164 :caption: Example 1165 1166 import mlflow 1167 1168 # Log a metric 1169 with mlflow.start_run(): 1170 mlflow.log_metric("mse", 2500.00) 1171 1172 # Log a metric in async fashion. 1173 with mlflow.start_run(): 1174 mlflow.log_metric("mse", 2500.00, synchronous=False) 1175 """ 1176 run = _get_or_start_run() if run_id is None else MlflowClient().get_run(run_id) 1177 run_id = run.info.run_id 1178 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 1179 model_id = model_id or get_active_model_id() 1180 _log_inputs_for_metrics_if_necessary( 1181 run, 1182 [ 1183 Metric( 1184 key=key, 1185 value=value, 1186 timestamp=timestamp or get_current_time_millis(), 1187 step=step or 0, 1188 model_id=model_id, 1189 dataset_name=dataset.name if dataset is not None else None, 1190 dataset_digest=dataset.digest if dataset is not None else None, 1191 ), 1192 ], 1193 datasets=[dataset] if dataset is not None else None, 1194 ) 1195 timestamp = timestamp or get_current_time_millis() 1196 step = step or 0 1197 model_ids = ( 1198 [model_id] 1199 if model_id is not None 1200 else (_get_model_ids_for_new_metric_if_exist(run, step) or [None]) 1201 ) 1202 for model_id in model_ids: 1203 return MlflowClient().log_metric( 1204 run_id, 1205 key, 1206 value, 1207 timestamp, 1208 step, 1209 synchronous=synchronous, 1210 model_id=model_id, 1211 dataset_name=dataset.name if dataset is not None else None, 1212 dataset_digest=dataset.digest if dataset is not None else None, 1213 ) 1214 1215 1216 def _log_inputs_for_metrics_if_necessary( 1217 run: Run, metrics: list[Metric], datasets: list["Dataset"] | None = None 1218 ) -> None: 1219 client = MlflowClient() 1220 input_model_ids = ( 1221 {i.model_id for i in run.inputs.model_inputs} 1222 if run.inputs and run.inputs.model_inputs 1223 else set() 1224 ) 1225 output_model_ids = ( 1226 {o.model_id for o in run.outputs.model_outputs} 1227 if run.outputs and run.outputs.model_outputs 1228 else set() 1229 ) 1230 run_datasets = ( 1231 [(inp.dataset.name, inp.dataset.digest) for inp in run.inputs.dataset_inputs] 1232 if run.inputs 1233 else [] 1234 ) 1235 datasets = datasets or [] 1236 models_to_log = [] 1237 datasets_to_log = [] 1238 for metric in metrics: 1239 if ( 1240 metric.model_id is not None 1241 and metric.model_id not in input_model_ids | output_model_ids 1242 ): 1243 models_to_log.append(LoggedModelInput(model_id=metric.model_id)) 1244 if datasets and (metric.dataset_name, metric.dataset_digest) not in run_datasets: 1245 matching_dataset = next( 1246 ( 1247 dataset 1248 for dataset in datasets 1249 if dataset.name == metric.dataset_name 1250 and dataset.digest == metric.dataset_digest 1251 ), 1252 None, 1253 ) 1254 if matching_dataset is not None: 1255 if isinstance(matching_dataset, DatasetEntity): 1256 dataset_entity = matching_dataset 1257 else: 1258 dataset_entity = matching_dataset._to_mlflow_entity() 1259 datasets_to_log.append(DatasetInput(dataset_entity, tags=[])) 1260 if models_to_log or datasets_to_log: 1261 client.log_inputs(run.info.run_id, models=models_to_log, datasets=datasets_to_log) 1262 # update in-memory run inputs to avoid duplicate logging 1263 if run.inputs is None: 1264 run._inputs = RunInputs(dataset_inputs=datasets_to_log, model_inputs=models_to_log) 1265 else: 1266 run._inputs._model_inputs.extend(models_to_log) 1267 run._inputs._dataset_inputs.extend(datasets_to_log) 1268 1269 1270 def _get_model_ids_for_new_metric_if_exist(run: Run, metric_step: str) -> list[str]: 1271 outputs = run.outputs.model_outputs if run.outputs else [] 1272 model_outputs_at_step = [mo for mo in outputs if mo.step == metric_step] 1273 return [mo.model_id for mo in model_outputs_at_step] 1274 1275 1276 def log_metrics( 1277 metrics: dict[str, float], 1278 step: int | None = None, 1279 synchronous: bool | None = None, 1280 run_id: str | None = None, 1281 timestamp: int | None = None, 1282 model_id: str | None = None, 1283 dataset: Union["Dataset", DatasetEntity] | None = None, 1284 ) -> RunOperations | None: 1285 """ 1286 Log multiple metrics for the current run. If no run is active, this method will create a new 1287 active run. 1288 1289 Args: 1290 metrics: Dictionary of metric_name: String -> value: Float. Note that some special 1291 values such as +/- Infinity may be replaced by other values depending on 1292 the store. For example, sql based store may replace +/- Infinity with 1293 max / min float values. 1294 step: A single integer step at which to log the specified 1295 Metrics. If unspecified, each metric is logged at step zero. 1296 synchronous: *Experimental* If True, blocks until the metrics are logged 1297 successfully. If False, logs the metrics asynchronously and 1298 returns a future representing the logging operation. If None, read from environment 1299 variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set. 1300 run_id: Run ID. If specified, log metrics to the specified run. If not specified, log 1301 metrics to the currently active run. 1302 timestamp: Time when these metrics were calculated. Defaults to the current system time. 1303 model_id: The ID of the model associated with the metric. If not specified, use the current 1304 active model ID set by :py:func:`mlflow.set_active_model`. If no active model 1305 exists, the models IDs associated with the specified or active run will be used. 1306 dataset: The dataset associated with the metrics. 1307 1308 Returns: 1309 When `synchronous=True`, returns None. When `synchronous=False`, returns an 1310 :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that 1311 represents future for logging operation. 1312 1313 .. code-block:: python 1314 :test: 1315 :caption: Example 1316 1317 import mlflow 1318 1319 metrics = {"mse": 2500.00, "rmse": 50.00} 1320 1321 # Log a batch of metrics 1322 with mlflow.start_run(): 1323 mlflow.log_metrics(metrics) 1324 1325 # Log a batch of metrics in async fashion. 1326 with mlflow.start_run(): 1327 mlflow.log_metrics(metrics, synchronous=False) 1328 """ 1329 run = _get_or_start_run() if run_id is None else MlflowClient().get_run(run_id) 1330 run_id = run.info.run_id 1331 timestamp = timestamp or get_current_time_millis() 1332 step = step or 0 1333 dataset_name = dataset.name if dataset is not None else None 1334 dataset_digest = dataset.digest if dataset is not None else None 1335 model_id = model_id or get_active_model_id() 1336 model_ids = ( 1337 [model_id] 1338 if model_id is not None 1339 else (_get_model_ids_for_new_metric_if_exist(run, step) or [None]) 1340 ) 1341 metrics_arr = [ 1342 Metric( 1343 key, 1344 value, 1345 timestamp, 1346 step or 0, 1347 model_id=model_id, 1348 dataset_name=dataset_name, 1349 dataset_digest=dataset_digest, 1350 run_id=run_id, 1351 ) 1352 for key, value in metrics.items() 1353 for model_id in model_ids 1354 ] 1355 _log_inputs_for_metrics_if_necessary( 1356 run, metrics_arr, [dataset] if dataset is not None else None 1357 ) 1358 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 1359 return MlflowClient().log_batch( 1360 run_id=run_id, 1361 metrics=metrics_arr, 1362 params=[], 1363 tags=[], 1364 synchronous=synchronous, 1365 ) 1366 1367 1368 def log_params( 1369 params: dict[str, Any], synchronous: bool | None = None, run_id: str | None = None 1370 ) -> RunOperations | None: 1371 """ 1372 Log a batch of params for the current run. If no run is active, this method will create a 1373 new active run. 1374 1375 Args: 1376 params: Dictionary of param_name: String -> value: (String, but will be string-ified if 1377 not) 1378 synchronous: *Experimental* If True, blocks until the parameters are logged 1379 successfully. If False, logs the parameters asynchronously and 1380 returns a future representing the logging operation. If None, read from environment 1381 variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set. 1382 run_id: Run ID. If specified, log params to the specified run. If not specified, log 1383 params to the currently active run. 1384 1385 Returns: 1386 When `synchronous=True`, returns None. When `synchronous=False`, returns an 1387 :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that 1388 represents future for logging operation. 1389 1390 .. code-block:: python 1391 :test: 1392 :caption: Example 1393 1394 import mlflow 1395 1396 params = {"learning_rate": 0.01, "n_estimators": 10} 1397 1398 # Log a batch of parameters 1399 with mlflow.start_run(): 1400 mlflow.log_params(params) 1401 1402 # Log a batch of parameters in async fashion. 1403 with mlflow.start_run(): 1404 mlflow.log_params(params, synchronous=False) 1405 """ 1406 run_id = run_id or _get_or_start_run().info.run_id 1407 params_arr = [Param(key, str(value)) for key, value in params.items()] 1408 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 1409 return MlflowClient().log_batch( 1410 run_id=run_id, metrics=[], params=params_arr, tags=[], synchronous=synchronous 1411 ) 1412 1413 1414 def _create_dataset_input( 1415 dataset: Optional["Dataset"], 1416 context: str | None = None, 1417 tags: dict[str, str] | None = None, 1418 ) -> DatasetInput | None: 1419 if (context or tags) and dataset is None: 1420 raise MlflowException.invalid_parameter_value( 1421 "`dataset` must be specified if `context` or `tags` is specified." 1422 ) 1423 tags_to_log = [] 1424 if tags: 1425 tags_to_log = [InputTag(key=key, value=value) for key, value in tags.items()] 1426 if context: 1427 tags_to_log.append(InputTag(key=MLFLOW_DATASET_CONTEXT, value=context)) 1428 1429 return DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags_to_log) if dataset else None 1430 1431 1432 def log_input( 1433 dataset: Optional["Dataset"] = None, 1434 context: str | None = None, 1435 tags: dict[str, str] | None = None, 1436 model: LoggedModelInput | None = None, 1437 ) -> None: 1438 """ 1439 Log a dataset used in the current run. 1440 1441 Args: 1442 dataset: :py:class:`mlflow.data.dataset.Dataset` object to be logged. 1443 context: Context in which the dataset is used. For example: "training", "testing". 1444 This will be set as an input tag with key `mlflow.data.context`. 1445 tags: Tags to be associated with the dataset. Dictionary of tag_key -> tag_value. 1446 model: A :py:class:`mlflow.entities.LoggedModelInput` instance to log as input to 1447 the run. 1448 1449 .. code-block:: python 1450 :test: 1451 :caption: Example 1452 1453 import numpy as np 1454 import mlflow 1455 1456 array = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 1457 dataset = mlflow.data.from_numpy(array, source="data.csv") 1458 1459 # Log an input dataset used for training 1460 with mlflow.start_run(): 1461 mlflow.log_input(dataset, context="training") 1462 """ 1463 run_id = _get_or_start_run().info.run_id 1464 datasets = [_create_dataset_input(dataset, context, tags)] if dataset else None 1465 models = [model] if model else None 1466 1467 MlflowClient().log_inputs(run_id=run_id, datasets=datasets, models=models) 1468 1469 1470 def log_inputs( 1471 datasets: list[Optional["Dataset"]] | None = None, 1472 contexts: list[str | None] | None = None, 1473 tags_list: list[dict[str, str] | None] | None = None, 1474 models: list[LoggedModelInput | None] | None = None, 1475 ) -> None: 1476 """ 1477 Log a batch of datasets used in the current run. 1478 1479 The lists of `datasets`, `contexts`, `tags_list` must have the same length. 1480 The entries in these lists can be ``None``, which represents empty value to the 1481 corresponding input. 1482 1483 Args: 1484 datasets: List of :py:class:`mlflow.data.dataset.Dataset` object to be logged. 1485 contexts: List of context in which the dataset is used. For example: "training", "testing". 1486 This will be set as an input tag with key `mlflow.data.context`. 1487 tags_list: List of tags to be associated with the dataset. Dictionary of 1488 tag_key -> tag_value. 1489 models: List of :py:class:`mlflow.entities.LoggedModelInput` instance to log as input 1490 to the run. Currently only Databricks managed MLflow supports this argument. 1491 1492 .. code-block:: python 1493 :test: 1494 :caption: Example 1495 1496 import numpy as np 1497 import mlflow 1498 1499 array = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 1500 dataset = mlflow.data.from_numpy(array, source="data.csv") 1501 1502 array2 = np.asarray([[-1, 2, 3], [-4, 5, 6]]) 1503 dataset2 = mlflow.data.from_numpy(array2, source="data2.csv") 1504 1505 # Log 2 input datasets used for training and test, 1506 # the training dataset has no tag. 1507 # the test dataset has tags `{"my_tag": "tag_value"}`. 1508 with mlflow.start_run(): 1509 mlflow.log_inputs( 1510 [dataset, dataset2], 1511 contexts=["training", "test"], 1512 tags_list=[None, {"my_tag": "tag_value"}], 1513 models=None, 1514 ) 1515 """ 1516 from mlflow.utils.databricks_utils import is_databricks_uri 1517 1518 run_id = _get_or_start_run().info.run_id 1519 1520 datasets = datasets or [] 1521 contexts = contexts or [] 1522 tags_list = tags_list or [] 1523 if not (len(datasets) == len(contexts) == len(tags_list)): 1524 raise MlflowException( 1525 "`mlflow.log_inputs` requires `datasets`, `contexts`, `tags_list` to be " 1526 "non-empty list and have the same length." 1527 ) 1528 1529 if models and not is_databricks_uri(mlflow.get_tracking_uri()): 1530 raise MlflowException("'models' argument is only supported by Databricks managed MLflow.") 1531 1532 dataset_inputs = [ 1533 _create_dataset_input(dataset, context, tags) 1534 for dataset, context, tags in zip(datasets, contexts, tags_list) 1535 ] 1536 1537 MlflowClient().log_inputs(run_id=run_id, datasets=dataset_inputs, models=models) 1538 1539 1540 def set_experiment_tags(tags: dict[str, Any]) -> None: 1541 """ 1542 Set tags for the current active experiment. 1543 1544 Args: 1545 tags: Dictionary containing tag names and corresponding values. 1546 1547 .. code-block:: python 1548 :test: 1549 :caption: Example 1550 1551 import mlflow 1552 1553 tags = { 1554 "engineering": "ML Platform", 1555 "release.candidate": "RC1", 1556 "release.version": "2.2.0", 1557 } 1558 1559 # Set a batch of tags 1560 with mlflow.start_run(): 1561 mlflow.set_experiment_tags(tags) 1562 """ 1563 for key, value in tags.items(): 1564 set_experiment_tag(key, value) 1565 1566 1567 def set_tags(tags: dict[str, Any], synchronous: bool | None = None) -> RunOperations | None: 1568 """ 1569 Log a batch of tags for the current run. If no run is active, this method will create a 1570 new active run. 1571 1572 Args: 1573 tags: Dictionary of tag_name: String -> value: (String, but will be string-ified if 1574 not) 1575 synchronous: *Experimental* If True, blocks until tags are logged successfully. If False, 1576 logs tags asynchronously and returns a future representing the logging operation. 1577 If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which 1578 defaults to False if not set. 1579 1580 Returns: 1581 When `synchronous=True`, returns None. When `synchronous=False`, returns an 1582 :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that 1583 represents future for logging operation. 1584 1585 .. code-block:: python 1586 :test: 1587 :caption: Example 1588 1589 import mlflow 1590 1591 tags = { 1592 "engineering": "ML Platform", 1593 "release.candidate": "RC1", 1594 "release.version": "2.2.0", 1595 } 1596 1597 # Set a batch of tags 1598 with mlflow.start_run(): 1599 mlflow.set_tags(tags) 1600 1601 # Set a batch of tags in async fashion. 1602 with mlflow.start_run(): 1603 mlflow.set_tags(tags, synchronous=False) 1604 """ 1605 run_id = _get_or_start_run().info.run_id 1606 tags_arr = [RunTag(key, str(value)) for key, value in tags.items()] 1607 synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get() 1608 return MlflowClient().log_batch( 1609 run_id=run_id, metrics=[], params=[], tags=tags_arr, synchronous=synchronous 1610 ) 1611 1612 1613 def log_artifact( 1614 local_path: str, artifact_path: str | None = None, run_id: str | None = None 1615 ) -> None: 1616 """ 1617 Log a local file or directory as an artifact of the currently active run. If no run is 1618 active, this method will create a new active run. 1619 1620 Args: 1621 local_path: Path to the file to write. 1622 artifact_path: If provided, the directory in ``artifact_uri`` to write to. 1623 run_id: If specified, log the artifact to the specified run. If not specified, log the 1624 artifact to the currently active run. 1625 1626 .. code-block:: python 1627 :test: 1628 :caption: Example 1629 1630 import tempfile 1631 from pathlib import Path 1632 1633 import mlflow 1634 1635 # Create a features.txt artifact file 1636 features = "rooms, zipcode, median_price, school_rating, transport" 1637 with tempfile.TemporaryDirectory() as tmp_dir: 1638 path = Path(tmp_dir, "features.txt") 1639 path.write_text(features) 1640 # With artifact_path=None write features.txt under 1641 # root artifact_uri/artifacts directory 1642 with mlflow.start_run(): 1643 mlflow.log_artifact(path) 1644 """ 1645 run_id = run_id or _get_or_start_run().info.run_id 1646 MlflowClient().log_artifact(run_id, local_path, artifact_path) 1647 1648 1649 def log_artifacts( 1650 local_dir: str, artifact_path: str | None = None, run_id: str | None = None 1651 ) -> None: 1652 """ 1653 Log all the contents of a local directory as artifacts of the run. If no run is active, 1654 this method will create a new active run. 1655 1656 Args: 1657 local_dir: Path to the directory of files to write. 1658 artifact_path: If provided, the directory in ``artifact_uri`` to write to. 1659 run_id: If specified, log the artifacts to the specified run. If not specified, log the 1660 artifacts to the currently active run. 1661 1662 .. code-block:: python 1663 :test: 1664 :caption: Example 1665 1666 import json 1667 import tempfile 1668 from pathlib import Path 1669 1670 import mlflow 1671 1672 # Create some files to preserve as artifacts 1673 features = "rooms, zipcode, median_price, school_rating, transport" 1674 data = {"state": "TX", "Available": 25, "Type": "Detached"} 1675 with tempfile.TemporaryDirectory() as tmp_dir: 1676 tmp_dir = Path(tmp_dir) 1677 with (tmp_dir / "data.json").open("w") as f: 1678 json.dump(data, f, indent=2) 1679 with (tmp_dir / "features.json").open("w") as f: 1680 f.write(features) 1681 # Write all files in `tmp_dir` to root artifact_uri/states 1682 with mlflow.start_run(): 1683 mlflow.log_artifacts(tmp_dir, artifact_path="states") 1684 """ 1685 run_id = run_id or _get_or_start_run().info.run_id 1686 MlflowClient().log_artifacts(run_id, local_dir, artifact_path) 1687 1688 1689 def log_text(text: str, artifact_file: str, run_id: str | None = None) -> None: 1690 """ 1691 Log text as an artifact. 1692 1693 Args: 1694 text: String containing text to log. 1695 artifact_file: The run-relative artifact file path in posixpath format to which 1696 the text is saved (e.g. "dir/file.txt"). 1697 run_id: If specified, log the artifact to the specified run. If not specified, log the 1698 artifact to the currently active run. 1699 1700 .. code-block:: python 1701 :test: 1702 :caption: Example 1703 1704 import mlflow 1705 1706 with mlflow.start_run(): 1707 # Log text to a file under the run's root artifact directory 1708 mlflow.log_text("text1", "file1.txt") 1709 1710 # Log text in a subdirectory of the run's root artifact directory 1711 mlflow.log_text("text2", "dir/file2.txt") 1712 1713 # Log HTML text 1714 mlflow.log_text("<h1>header</h1>", "index.html") 1715 1716 """ 1717 run_id = run_id or _get_or_start_run().info.run_id 1718 MlflowClient().log_text(run_id, text, artifact_file) 1719 1720 1721 def log_dict(dictionary: dict[str, Any], artifact_file: str, run_id: str | None = None) -> None: 1722 """ 1723 Log a JSON/YAML-serializable object (e.g. `dict`) as an artifact. The serialization 1724 format (JSON or YAML) is automatically inferred from the extension of `artifact_file`. 1725 If the file extension doesn't exist or match any of [".json", ".yml", ".yaml"], 1726 JSON format is used. 1727 1728 Args: 1729 dictionary: Dictionary to log. 1730 artifact_file: The run-relative artifact file path in posixpath format to which 1731 the dictionary is saved (e.g. "dir/data.json"). 1732 run_id: If specified, log the dictionary to the specified run. If not specified, log the 1733 dictionary to the currently active run. 1734 1735 .. code-block:: python 1736 :test: 1737 :caption: Example 1738 1739 import mlflow 1740 1741 dictionary = {"k": "v"} 1742 1743 with mlflow.start_run(): 1744 # Log a dictionary as a JSON file under the run's root artifact directory 1745 mlflow.log_dict(dictionary, "data.json") 1746 1747 # Log a dictionary as a YAML file in a subdirectory of the run's root artifact directory 1748 mlflow.log_dict(dictionary, "dir/data.yml") 1749 1750 # If the file extension doesn't exist or match any of [".json", ".yaml", ".yml"], 1751 # JSON format is used. 1752 mlflow.log_dict(dictionary, "data") 1753 mlflow.log_dict(dictionary, "data.txt") 1754 1755 """ 1756 run_id = run_id or _get_or_start_run().info.run_id 1757 MlflowClient().log_dict(run_id, dictionary, artifact_file) 1758 1759 1760 @experimental(version="3.9.0") 1761 def log_stream( 1762 stream: io.BufferedIOBase | io.RawIOBase, artifact_file: str, run_id: str | None = None 1763 ) -> None: 1764 """ 1765 Log a binary file-like object (e.g., ``io.BytesIO``) as an artifact. 1766 1767 Args: 1768 stream: A binary file-like object supporting ``.read()`` method (e.g., ``io.BytesIO``). 1769 artifact_file: The run-relative artifact file path in posixpath format to which 1770 the stream content is saved (e.g. "dir/file.bin"). 1771 run_id: If specified, log the artifact to the specified run. If not specified, log the 1772 artifact to the currently active run. 1773 1774 .. code-block:: python 1775 :test: 1776 :caption: Example 1777 1778 import io 1779 1780 import mlflow 1781 1782 with mlflow.start_run(): 1783 # Log a BytesIO stream 1784 bytes_stream = io.BytesIO(b"binary content") 1785 mlflow.log_stream(bytes_stream, "binary_file.bin") 1786 1787 """ 1788 run_id = run_id or _get_or_start_run().info.run_id 1789 MlflowClient().log_stream(run_id, stream, artifact_file) 1790 1791 1792 def log_figure( 1793 figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 1794 artifact_file: str, 1795 *, 1796 save_kwargs: dict[str, Any] | None = None, 1797 ) -> None: 1798 """ 1799 Log a figure as an artifact. The following figure objects are supported: 1800 1801 - `matplotlib.figure.Figure`_ 1802 - `plotly.graph_objects.Figure`_ 1803 1804 .. _matplotlib.figure.Figure: 1805 https://matplotlib.org/api/_as_gen/matplotlib.figure.Figure.html 1806 1807 .. _plotly.graph_objects.Figure: 1808 https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html 1809 1810 Args: 1811 figure: Figure to log. 1812 artifact_file: The run-relative artifact file path in posixpath format to which 1813 the figure is saved (e.g. "dir/file.png"). 1814 save_kwargs: Additional keyword arguments passed to the method that saves the figure. 1815 1816 .. code-block:: python 1817 :test: 1818 :caption: Matplotlib Example 1819 1820 import mlflow 1821 import matplotlib.pyplot as plt 1822 1823 fig, ax = plt.subplots() 1824 ax.plot([0, 1], [2, 3]) 1825 1826 with mlflow.start_run(): 1827 mlflow.log_figure(fig, "figure.png") 1828 1829 .. code-block:: python 1830 :test: 1831 :caption: Plotly Example 1832 1833 import mlflow 1834 from plotly import graph_objects as go 1835 1836 fig = go.Figure(go.Scatter(x=[0, 1], y=[2, 3])) 1837 1838 with mlflow.start_run(): 1839 mlflow.log_figure(fig, "figure.html") 1840 """ 1841 run_id = _get_or_start_run().info.run_id 1842 MlflowClient().log_figure(run_id, figure, artifact_file, save_kwargs=save_kwargs) 1843 1844 1845 def log_image( 1846 image: Union["numpy.ndarray", "PIL.Image.Image", "mlflow.Image"], 1847 artifact_file: str | None = None, 1848 key: str | None = None, 1849 step: int | None = None, 1850 timestamp: int | None = None, 1851 synchronous: bool | None = False, 1852 ) -> None: 1853 """ 1854 Logs an image in MLflow, supporting two use cases: 1855 1856 1. Time-stepped image logging: 1857 Ideal for tracking changes or progressions through iterative processes (e.g., 1858 during model training phases). 1859 1860 - Usage: :code:`log_image(image, key=key, step=step, timestamp=timestamp)` 1861 1862 2. Artifact file image logging: 1863 Best suited for static image logging where the image is saved directly as a file 1864 artifact. 1865 1866 - Usage: :code:`log_image(image, artifact_file)` 1867 1868 The following image formats are supported: 1869 - `numpy.ndarray`_ 1870 - `PIL.Image.Image`_ 1871 1872 .. _numpy.ndarray: 1873 https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html 1874 1875 .. _PIL.Image.Image: 1876 https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image 1877 1878 - :class:`mlflow.Image`: An MLflow wrapper around PIL image for convenient image logging. 1879 1880 Numpy array support 1881 - data types: 1882 1883 - bool (useful for logging image masks) 1884 - integer [0, 255] 1885 - unsigned integer [0, 255] 1886 - float [0.0, 1.0] 1887 1888 .. warning:: 1889 1890 - Out-of-range integer values will raise ValueError. 1891 - Out-of-range float values will auto-scale with min/max and warn. 1892 1893 - shape (H: height, W: width): 1894 1895 - H x W (Grayscale) 1896 - H x W x 1 (Grayscale) 1897 - H x W x 3 (an RGB channel order is assumed) 1898 - H x W x 4 (an RGBA channel order is assumed) 1899 1900 Args: 1901 image: The image object to be logged. 1902 artifact_file: Specifies the path, in POSIX format, where the image 1903 will be stored as an artifact relative to the run's root directory (for 1904 example, "dir/image.png"). This parameter is kept for backward compatibility 1905 and should not be used together with `key`, `step`, or `timestamp`. 1906 key: Image name for time-stepped image logging. This string may only contain 1907 alphanumerics, underscores (_), dashes (-), periods (.), spaces ( ), and 1908 slashes (/). 1909 step: Integer training step (iteration) at which the image was saved. 1910 Defaults to 0. 1911 timestamp: Time when this image was saved. Defaults to the current system time. 1912 synchronous: *Experimental* If True, blocks until the image is logged successfully. 1913 1914 .. code-block:: python 1915 :caption: Time-stepped image logging numpy example 1916 1917 import mlflow 1918 import numpy as np 1919 1920 image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) 1921 1922 with mlflow.start_run(): 1923 mlflow.log_image(image, key="dogs", step=3) 1924 1925 .. code-block:: python 1926 :caption: Time-stepped image logging pillow example 1927 1928 import mlflow 1929 from PIL import Image 1930 1931 image = Image.new("RGB", (100, 100)) 1932 1933 with mlflow.start_run(): 1934 mlflow.log_image(image, key="dogs", step=3) 1935 1936 .. code-block:: python 1937 :caption: Time-stepped image logging with mlflow.Image example 1938 1939 import mlflow 1940 from PIL import Image 1941 1942 # If you have a preexisting saved image 1943 Image.new("RGB", (100, 100)).save("image.png") 1944 1945 image = mlflow.Image("image.png") 1946 with mlflow.start_run() as run: 1947 mlflow.log_image(run.info.run_id, image, key="dogs", step=3) 1948 1949 .. code-block:: python 1950 :caption: Legacy artifact file image logging numpy example 1951 1952 import mlflow 1953 import numpy as np 1954 1955 image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8) 1956 1957 with mlflow.start_run(): 1958 mlflow.log_image(image, "image.png") 1959 1960 .. code-block:: python 1961 :caption: Legacy artifact file image logging pillow example 1962 1963 import mlflow 1964 from PIL import Image 1965 1966 image = Image.new("RGB", (100, 100)) 1967 1968 with mlflow.start_run(): 1969 mlflow.log_image(image, "image.png") 1970 """ 1971 run_id = _get_or_start_run().info.run_id 1972 MlflowClient().log_image(run_id, image, artifact_file, key, step, timestamp, synchronous) 1973 1974 1975 def log_table( 1976 data: Union[dict[str, Any], "pandas.DataFrame"], 1977 artifact_file: str, 1978 run_id: str | None = None, 1979 ) -> None: 1980 """ 1981 Log a table to MLflow Tracking as a JSON artifact. If the artifact_file already exists 1982 in the run, the data would be appended to the existing artifact_file. 1983 1984 Args: 1985 data: Dictionary or pandas.DataFrame to log. 1986 artifact_file: The run-relative artifact file path in posixpath format to which 1987 the table is saved (e.g. "dir/file.json"). 1988 run_id: If specified, log the table to the specified run. If not specified, log the 1989 table to the currently active run. 1990 1991 .. code-block:: python 1992 :test: 1993 :caption: Dictionary Example 1994 1995 import mlflow 1996 1997 table_dict = { 1998 "inputs": ["What is MLflow?", "What is Databricks?"], 1999 "outputs": ["MLflow is ...", "Databricks is ..."], 2000 "toxicity": [0.0, 0.0], 2001 } 2002 with mlflow.start_run(): 2003 # Log the dictionary as a table 2004 mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json") 2005 2006 .. code-block:: python 2007 :test: 2008 :caption: Pandas DF Example 2009 2010 import mlflow 2011 import pandas as pd 2012 2013 table_dict = { 2014 "inputs": ["What is MLflow?", "What is Databricks?"], 2015 "outputs": ["MLflow is ...", "Databricks is ..."], 2016 "toxicity": [0.0, 0.0], 2017 } 2018 df = pd.DataFrame.from_dict(table_dict) 2019 with mlflow.start_run(): 2020 # Log the df as a table 2021 mlflow.log_table(data=df, artifact_file="qabot_eval_results.json") 2022 """ 2023 run_id = run_id or _get_or_start_run().info.run_id 2024 MlflowClient().log_table(run_id, data, artifact_file) 2025 2026 2027 def load_table( 2028 artifact_file: str, 2029 run_ids: list[str] | None = None, 2030 extra_columns: list[str] | None = None, 2031 ) -> "pandas.DataFrame": 2032 """ 2033 Load a table from MLflow Tracking as a pandas.DataFrame. The table is loaded from the 2034 specified artifact_file in the specified run_ids. The extra_columns are columns that 2035 are not in the table but are augmented with run information and added to the DataFrame. 2036 2037 Args: 2038 artifact_file: The run-relative artifact file path in posixpath format to which 2039 table to load (e.g. "dir/file.json"). 2040 run_ids: Optional list of run_ids to load the table from. If no run_ids are specified, 2041 the table is loaded from all runs in the current experiment. 2042 extra_columns: Optional list of extra columns to add to the returned DataFrame 2043 For example, if extra_columns=["run_id"], then the returned DataFrame 2044 will have a column named run_id. 2045 2046 Returns: 2047 pandas.DataFrame containing the loaded table if the artifact exists 2048 or else throw a MlflowException. 2049 2050 .. code-block:: python 2051 :test: 2052 :caption: Example with passing run_ids 2053 2054 import mlflow 2055 2056 table_dict = { 2057 "inputs": ["What is MLflow?", "What is Databricks?"], 2058 "outputs": ["MLflow is ...", "Databricks is ..."], 2059 "toxicity": [0.0, 0.0], 2060 } 2061 2062 with mlflow.start_run() as run: 2063 # Log the dictionary as a table 2064 mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json") 2065 run_id = run.info.run_id 2066 2067 loaded_table = mlflow.load_table( 2068 artifact_file="qabot_eval_results.json", 2069 run_ids=[run_id], 2070 # Append a column containing the associated run ID for each row 2071 extra_columns=["run_id"], 2072 ) 2073 2074 .. code-block:: python 2075 :test: 2076 :caption: Example with passing no run_ids 2077 2078 # Loads the table with the specified name for all runs in the given 2079 # experiment and joins them together 2080 import mlflow 2081 2082 table_dict = { 2083 "inputs": ["What is MLflow?", "What is Databricks?"], 2084 "outputs": ["MLflow is ...", "Databricks is ..."], 2085 "toxicity": [0.0, 0.0], 2086 } 2087 2088 with mlflow.start_run(): 2089 # Log the dictionary as a table 2090 mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json") 2091 2092 loaded_table = mlflow.load_table( 2093 "qabot_eval_results.json", 2094 # Append the run ID and the parent run ID to the table 2095 extra_columns=["run_id"], 2096 ) 2097 """ 2098 experiment_id = _get_experiment_id() 2099 return MlflowClient().load_table(experiment_id, artifact_file, run_ids, extra_columns) 2100 2101 2102 def _record_logged_model(mlflow_model, run_id=None): 2103 run_id = run_id or _get_or_start_run().info.run_id 2104 MlflowClient()._record_logged_model(run_id, mlflow_model) 2105 2106 2107 def get_experiment(experiment_id: str) -> Experiment: 2108 """Retrieve an experiment by experiment_id from the backend store 2109 2110 Args: 2111 experiment_id: The string-ified experiment ID returned from ``create_experiment``. 2112 2113 Returns: 2114 :py:class:`mlflow.entities.Experiment` 2115 2116 .. code-block:: python 2117 :test: 2118 :caption: Example 2119 2120 import mlflow 2121 2122 experiment = mlflow.get_experiment("0") 2123 print(f"Name: {experiment.name}") 2124 print(f"Artifact Location: {experiment.artifact_location}") 2125 print(f"Tags: {experiment.tags}") 2126 print(f"Lifecycle_stage: {experiment.lifecycle_stage}") 2127 print(f"Creation timestamp: {experiment.creation_time}") 2128 2129 .. code-block:: text 2130 :caption: Output 2131 2132 Name: Default 2133 Artifact Location: file:///.../mlruns/0 2134 Tags: {} 2135 Lifecycle_stage: active 2136 Creation timestamp: 1662004217511 2137 """ 2138 return MlflowClient().get_experiment(experiment_id) 2139 2140 2141 def get_experiment_by_name(name: str) -> Experiment | None: 2142 """ 2143 Retrieve an experiment by experiment name from the backend store 2144 2145 Args: 2146 name: The case sensitive experiment name. 2147 2148 Returns: 2149 An instance of :py:class:`mlflow.entities.Experiment` 2150 if an experiment with the specified name exists, otherwise None. 2151 2152 .. code-block:: python 2153 :test: 2154 :caption: Example 2155 2156 import mlflow 2157 2158 # Case sensitive name 2159 experiment = mlflow.get_experiment_by_name("Default") 2160 print(f"Experiment_id: {experiment.experiment_id}") 2161 print(f"Artifact Location: {experiment.artifact_location}") 2162 print(f"Tags: {experiment.tags}") 2163 print(f"Lifecycle_stage: {experiment.lifecycle_stage}") 2164 print(f"Creation timestamp: {experiment.creation_time}") 2165 2166 .. code-block:: text 2167 :caption: Output 2168 2169 Experiment_id: 0 2170 Artifact Location: file:///.../mlruns/0 2171 Tags: {} 2172 Lifecycle_stage: active 2173 Creation timestamp: 1662004217511 2174 """ 2175 return MlflowClient().get_experiment_by_name(name) 2176 2177 2178 def search_experiments( 2179 view_type: int = ViewType.ACTIVE_ONLY, 2180 max_results: int | None = None, 2181 filter_string: str | None = None, 2182 order_by: list[str] | None = None, 2183 ) -> list[Experiment]: 2184 """ 2185 Search for experiments that match the specified search query. 2186 2187 Args: 2188 view_type: One of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or ``ALL`` 2189 defined in :py:class:`mlflow.entities.ViewType`. 2190 max_results: If passed, specifies the maximum number of experiments desired. If not 2191 passed, all experiments will be returned. 2192 filter_string: Filter query string (e.g., ``"name = 'my_experiment'"``), defaults to 2193 searching for all experiments. The following identifiers, comparators, and logical 2194 operators are supported. 2195 2196 Identifiers 2197 - ``name``: Experiment name 2198 - ``creation_time``: Experiment creation time 2199 - ``last_update_time``: Experiment last update time 2200 - ``tags.<tag_key>``: Experiment tag. If ``tag_key`` contains 2201 spaces, it must be wrapped with backticks (e.g., ``"tags.`extra key`"``). 2202 2203 Comparators for string attributes and tags 2204 - ``=``: Equal to 2205 - ``!=``: Not equal to 2206 - ``LIKE``: Case-sensitive pattern match 2207 - ``ILIKE``: Case-insensitive pattern match 2208 2209 Comparators for numeric attributes 2210 - ``=``: Equal to 2211 - ``!=``: Not equal to 2212 - ``<``: Less than 2213 - ``<=``: Less than or equal to 2214 - ``>``: Greater than 2215 - ``>=``: Greater than or equal to 2216 2217 Logical operators 2218 - ``AND``: Combines two sub-queries and returns True if both of them are True. 2219 2220 order_by: List of columns to order by. The ``order_by`` column can contain an optional 2221 ``DESC`` or ``ASC`` value (e.g., ``"name DESC"``). The default ordering is ``ASC``, 2222 so ``"name"`` is equivalent to ``"name ASC"``. If unspecified, defaults to 2223 ``["last_update_time DESC"]``, which lists experiments updated most recently first. 2224 The following fields are supported: 2225 2226 - ``experiment_id``: Experiment ID 2227 - ``name``: Experiment name 2228 - ``creation_time``: Experiment creation time 2229 - ``last_update_time``: Experiment last update time 2230 2231 Returns: 2232 A list of :py:class:`Experiment <mlflow.entities.Experiment>` objects. 2233 2234 .. code-block:: python 2235 :test: 2236 :caption: Example 2237 2238 import mlflow 2239 2240 2241 def assert_experiment_names_equal(experiments, expected_names): 2242 actual_names = [e.name for e in experiments if e.name != "Default"] 2243 assert actual_names == expected_names, (actual_names, expected_names) 2244 2245 2246 mlflow.set_tracking_uri("sqlite:///:memory:") 2247 # Create experiments 2248 for name, tags in [ 2249 ("a", None), 2250 ("b", None), 2251 ("ab", {"k": "v"}), 2252 ("bb", {"k": "V"}), 2253 ]: 2254 mlflow.create_experiment(name, tags=tags) 2255 2256 # Search for experiments with name "a" 2257 experiments = mlflow.search_experiments(filter_string="name = 'a'") 2258 assert_experiment_names_equal(experiments, ["a"]) 2259 # Search for experiments with name starting with "a" 2260 experiments = mlflow.search_experiments(filter_string="name LIKE 'a%'") 2261 assert_experiment_names_equal(experiments, ["ab", "a"]) 2262 # Search for experiments with tag key "k" and value ending with "v" or "V" 2263 experiments = mlflow.search_experiments(filter_string="tags.k ILIKE '%v'") 2264 assert_experiment_names_equal(experiments, ["bb", "ab"]) 2265 # Search for experiments with name ending with "b" and tag {"k": "v"} 2266 experiments = mlflow.search_experiments(filter_string="name LIKE '%b' AND tags.k = 'v'") 2267 assert_experiment_names_equal(experiments, ["ab"]) 2268 # Sort experiments by name in ascending order 2269 experiments = mlflow.search_experiments(order_by=["name"]) 2270 assert_experiment_names_equal(experiments, ["a", "ab", "b", "bb"]) 2271 # Sort experiments by ID in descending order 2272 experiments = mlflow.search_experiments(order_by=["experiment_id DESC"]) 2273 assert_experiment_names_equal(experiments, ["bb", "ab", "b", "a"]) 2274 """ 2275 2276 def pagination_wrapper_func(number_to_get, next_page_token): 2277 return MlflowClient().search_experiments( 2278 view_type=view_type, 2279 max_results=number_to_get, 2280 filter_string=filter_string, 2281 order_by=order_by, 2282 page_token=next_page_token, 2283 ) 2284 2285 return get_results_from_paginated_fn( 2286 pagination_wrapper_func, 2287 SEARCH_MAX_RESULTS_DEFAULT, 2288 max_results, 2289 ) 2290 2291 2292 def create_experiment( 2293 name: str, 2294 artifact_location: str | None = None, 2295 tags: dict[str, Any] | None = None, 2296 trace_location: UnityCatalog | None = None, 2297 ) -> str: 2298 """ 2299 Create an experiment. 2300 2301 Args: 2302 name: The experiment name, must be a non-empty unique string. 2303 artifact_location: The location to store run artifacts. If not provided, the server picks 2304 an appropriate default. 2305 tags: An optional dictionary of string keys and values to set as tags on the experiment. 2306 trace_location: Optional UC trace location to link to the experiment. Must be an instance 2307 of ``mlflow.entities.trace_location.UnityCatalog(...)``. If ``table_prefix`` is not 2308 set, it defaults to the experiment ID. Note: call ``mlflow.set_experiment`` afterward 2309 to activate the experiment and sync the trace provider. 2310 2311 Returns: 2312 String ID of the created experiment. 2313 2314 .. code-block:: python 2315 :test: 2316 :caption: Example 2317 2318 import mlflow 2319 from pathlib import Path 2320 2321 # Create an experiment name, which must be unique and case sensitive 2322 experiment_id = mlflow.create_experiment( 2323 "Social NLP Experiments", 2324 artifact_location=Path.cwd().joinpath("mlruns").as_uri(), 2325 tags={"version": "v1", "priority": "P1"}, 2326 ) 2327 experiment = mlflow.get_experiment(experiment_id) 2328 print(f"Name: {experiment.name}") 2329 print(f"Experiment_id: {experiment.experiment_id}") 2330 print(f"Artifact Location: {experiment.artifact_location}") 2331 print(f"Tags: {experiment.tags}") 2332 print(f"Lifecycle_stage: {experiment.lifecycle_stage}") 2333 print(f"Creation timestamp: {experiment.creation_time}") 2334 2335 .. code-block:: text 2336 :caption: Output 2337 2338 Name: Social NLP Experiments 2339 Experiment_id: 1 2340 Artifact Location: file:///.../mlruns 2341 Tags: {'version': 'v1', 'priority': 'P1'} 2342 Lifecycle_stage: active 2343 Creation timestamp: 1662004217511 2344 """ 2345 client = MlflowClient() 2346 experiment_id = client.create_experiment(name, artifact_location, tags) 2347 2348 if trace_location is not None: 2349 experiment = client.get_experiment(experiment_id) 2350 2351 if trace_location.table_prefix is None: 2352 trace_location = UnityCatalog( 2353 catalog_name=trace_location.catalog_name, 2354 schema_name=trace_location.schema_name, 2355 table_prefix=experiment_id, 2356 ) 2357 2358 try: 2359 _resolve_experiment_to_trace_location( 2360 experiment=experiment, 2361 trace_location=trace_location, 2362 ) 2363 except MlflowException as e: 2364 raise MlflowException.invalid_parameter_value( 2365 f"Experiment '{name}' (ID: {experiment_id}) was created " 2366 f"but linking to trace location '{trace_location.full_table_prefix}' failed: " 2367 f"{e.message} Please delete the experiment and retry." 2368 ) from e 2369 2370 return experiment_id 2371 2372 2373 def delete_experiment(experiment_id: str) -> None: 2374 """ 2375 Delete an experiment from the backend store. 2376 2377 Args: 2378 experiment_id: The string-ified experiment ID returned from ``create_experiment``. 2379 2380 .. code-block:: python 2381 :test: 2382 :caption: Example 2383 2384 import mlflow 2385 2386 experiment_id = mlflow.create_experiment("New Experiment") 2387 mlflow.delete_experiment(experiment_id) 2388 2389 # Examine the deleted experiment details. 2390 experiment = mlflow.get_experiment(experiment_id) 2391 print(f"Name: {experiment.name}") 2392 print(f"Artifact Location: {experiment.artifact_location}") 2393 print(f"Lifecycle_stage: {experiment.lifecycle_stage}") 2394 print(f"Last Updated timestamp: {experiment.last_update_time}") 2395 2396 .. code-block:: text 2397 :caption: Output 2398 2399 Name: New Experiment 2400 Artifact Location: file:///.../mlruns/2 2401 Lifecycle_stage: deleted 2402 Last Updated timestamp: 1662004217511 2403 2404 """ 2405 MlflowClient().delete_experiment(experiment_id) 2406 2407 2408 def initialize_logged_model( 2409 name: str | None = None, 2410 source_run_id: str | None = None, 2411 tags: dict[str, str] | None = None, 2412 params: dict[str, str] | None = None, 2413 model_type: str | None = None, 2414 experiment_id: str | None = None, 2415 ) -> LoggedModel: 2416 """ 2417 Initialize a LoggedModel. Creates a LoggedModel with status ``PENDING`` and no artifacts. You 2418 must add artifacts to the model and finalize it to the ``READY`` state, for example by calling 2419 a flavor-specific ``log_model()`` method such as :py:func:`mlflow.pyfunc.log_model()`. 2420 2421 Args: 2422 name: The name of the model. If not specified, a random name will be generated. 2423 source_run_id: The ID of the run that the model is associated with. If unspecified and a 2424 run is active, the active run ID will be used. 2425 tags: A dictionary of string keys and values to set as tags on the model. 2426 params: A dictionary of string keys and values to set as parameters on the model. 2427 model_type: The type of the model. 2428 experiment_id: The experiment ID of the experiment to which the model belongs. 2429 2430 Returns: 2431 A new :py:class:`mlflow.entities.LoggedModel` object with status ``PENDING``. 2432 """ 2433 return _initialize_logged_model( 2434 name=name, 2435 source_run_id=source_run_id, 2436 tags=tags, 2437 params=params, 2438 model_type=model_type, 2439 experiment_id=experiment_id, 2440 flavor="initialize", 2441 ) 2442 2443 2444 def _initialize_logged_model( 2445 name: str | None = None, 2446 source_run_id: str | None = None, 2447 tags: dict[str, str] | None = None, 2448 params: dict[str, str] | None = None, 2449 model_type: str | None = None, 2450 experiment_id: str | None = None, 2451 # this is only for internal logging purpose 2452 flavor: str | None = None, 2453 ) -> LoggedModel: 2454 model = _create_logged_model( 2455 name=name, 2456 source_run_id=source_run_id, 2457 tags=tags, 2458 params=params, 2459 model_type=model_type, 2460 experiment_id=experiment_id, 2461 flavor=flavor, 2462 ) 2463 _last_logged_model_id.set(model.model_id) 2464 return model 2465 2466 2467 @contextlib.contextmanager 2468 def _use_logged_model(model: LoggedModel) -> Generator[LoggedModel, None, None]: 2469 """ 2470 Context manager to wrap a LoggedModel and update the model 2471 status after the context is exited. 2472 If any exception occurs, the model status is set to FAILED. 2473 Otherwise, it is set to READY. 2474 """ 2475 try: 2476 yield model 2477 except Exception: 2478 finalize_logged_model(model.model_id, LoggedModelStatus.FAILED) 2479 raise 2480 else: 2481 finalize_logged_model(model.model_id, LoggedModelStatus.READY) 2482 2483 2484 def create_external_model( 2485 name: str | None = None, 2486 source_run_id: str | None = None, 2487 tags: dict[str, str] | None = None, 2488 params: dict[str, str] | None = None, 2489 model_type: str | None = None, 2490 experiment_id: str | None = None, 2491 ) -> LoggedModel: 2492 """ 2493 Create a new LoggedModel whose artifacts are stored outside of MLflow. This is useful for 2494 tracking parameters and performance data (metrics, traces etc.) for a model, application, or 2495 generative AI agent that is not packaged using the MLflow Model format. 2496 2497 Args: 2498 name: The name of the model. If not specified, a random name will be generated. 2499 source_run_id: The ID of the run that the model is associated with. If unspecified and a 2500 run is active, the active run ID will be used. 2501 tags: A dictionary of string keys and values to set as tags on the model. 2502 params: A dictionary of string keys and values to set as parameters on the model. 2503 model_type: The type of the model. This is a user-defined string that can be used to 2504 search and compare related models. For example, setting ``model_type="agent"`` 2505 enables you to easily search for this model and compare it to other models of 2506 type ``"agent"`` in the future. 2507 experiment_id: The experiment ID of the experiment to which the model belongs. 2508 2509 Returns: 2510 A new :py:class:`mlflow.entities.LoggedModel` object with status ``READY``. 2511 """ 2512 from mlflow.models.model import MLMODEL_FILE_NAME, Model 2513 from mlflow.models.utils import get_external_mlflow_model_spec 2514 2515 tags = dict(tags) if tags else {} 2516 tags[MLFLOW_MODEL_IS_EXTERNAL] = "true" 2517 2518 client = MlflowClient() 2519 model = _create_logged_model( 2520 name=name, 2521 source_run_id=source_run_id, 2522 tags=tags, 2523 params=params, 2524 model_type=model_type, 2525 experiment_id=experiment_id, 2526 flavor="external", 2527 ) 2528 2529 # If a model is external, its artifacts (code, weights, etc.) are not stored in MLflow. 2530 # Accordingly, we finalize the model immediately after creation, since there aren't 2531 # any model artifacts for the client to upload to MLflow. Additionally, we create a 2532 # dummy MLModel file to ensure that the model can be registered to the Model Registry 2533 mlflow_model: Model = get_external_mlflow_model_spec(model) 2534 with TempDir() as tmp: 2535 mlflow_model.save(tmp.path(MLMODEL_FILE_NAME)) 2536 MlflowClient().log_model_artifacts( 2537 model_id=model.model_id, 2538 local_dir=tmp.path(), 2539 ) 2540 2541 model = client.finalize_logged_model(model_id=model.model_id, status=LoggedModelStatus.READY) 2542 _last_logged_model_id.set(model.model_id) 2543 2544 return model 2545 2546 2547 def _create_logged_model( 2548 name: str | None = None, 2549 source_run_id: str | None = None, 2550 tags: dict[str, str] | None = None, 2551 params: dict[str, str] | None = None, 2552 model_type: str | None = None, 2553 experiment_id: str | None = None, 2554 flavor: str | None = None, 2555 serialization_format: str | None = None, 2556 uses_uv: bool = False, 2557 ) -> LoggedModel: 2558 """ 2559 Create a new LoggedModel in the ``PENDING`` state. 2560 2561 Args: 2562 name: The name of the model. If not specified, a random name will be generated. 2563 source_run_id: The ID of the run that the model is associated with. If unspecified and a 2564 run is active, the active run ID will be used. 2565 tags: A dictionary of string keys and values to set as tags on the model. 2566 params: A dictionary of string keys and values to set as parameters on the model. 2567 model_type: The type of the model. This is a user-defined string that can be used to 2568 search and compare related models. For example, setting ``model_type="agent"`` 2569 enables you to easily search for this model and compare it to other models of 2570 type ``"agent"`` in the future. 2571 experiment_id: The experiment ID of the experiment to which the model belongs. 2572 flavor: The flavor of the model, recorded for telemetry and analytics only; it does not 2573 affect the stored LoggedModel. 2574 serialization_format: The serialization format of the model, recorded for telemetry and 2575 analytics only; it does not affect the stored LoggedModel. 2576 uses_uv: Whether the model uses uv dependency management, recorded for telemetry and 2577 analytics only; it does not affect the stored LoggedModel. 2578 2579 Returns: 2580 A new LoggedModel in the ``PENDING`` state. 2581 """ 2582 if source_run_id is None and (run := active_run()): 2583 source_run_id = run.info.run_id 2584 2585 if experiment_id is None and (run := active_run()): 2586 experiment_id = run.info.experiment_id 2587 elif experiment_id is None: 2588 experiment_id = _get_experiment_id() or ( 2589 get_run(source_run_id).info.experiment_id if source_run_id else None 2590 ) 2591 resolved_tags = context_registry.resolve_tags(tags) 2592 return MlflowClient()._create_logged_model( 2593 experiment_id=experiment_id, 2594 name=name, 2595 source_run_id=source_run_id, 2596 tags=resolved_tags, 2597 params=params, 2598 model_type=model_type, 2599 flavor=flavor, 2600 serialization_format=serialization_format, 2601 uses_uv=uses_uv, 2602 ) 2603 2604 2605 def log_model_params(params: dict[str, str], model_id: str | None = None) -> None: 2606 """ 2607 Log params to the specified logged model. 2608 2609 Args: 2610 params: Params to log on the model. 2611 model_id: ID of the model. If not specified, use the current active model ID. 2612 2613 Returns: 2614 None 2615 2616 Example: 2617 2618 .. code-block:: python 2619 :test: 2620 2621 import mlflow 2622 2623 2624 class DummyModel(mlflow.pyfunc.PythonModel): 2625 def predict(self, context, model_input: list[str]) -> list[str]: 2626 return model_input 2627 2628 2629 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 2630 mlflow.log_model_params(params={"param": "value"}, model_id=model_info.model_id) 2631 """ 2632 model_id = model_id or get_active_model_id() 2633 MlflowClient().log_model_params(model_id, params) 2634 2635 2636 def import_checkpoints( 2637 checkpoint_path: str, 2638 source_run_id: str | None = None, 2639 model_prefix: str | None = None, 2640 overwrite_checkpoints: bool = False, 2641 ) -> list[LoggedModel]: 2642 """ 2643 Create external models for all top-level files and directories under the specified 2644 checkpoint path. 2645 2646 This API only supports Databricks runtime currently. 2647 2648 Args: 2649 checkpoint_path: Path that contains the checkpoints. 2650 Only Databricks Unity Catalog Volume path is supported for now. 2651 It must follows the 2652 "/Volumes/<catalog_identifier>/<schema_identifier>/<volume_identifier>/<path_to_checkpoints_directory>" 2653 format specified https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-volumes#volume-naming-and-reference. 2654 Note: Each path must be isolated from other models and runs. 2655 source_run_id: ID of the MLflow source run that these checkpoints were trained with. 2656 If not provided, uses the current active run if available. 2657 model_prefix: String prefix to prepend to the name of each external model created from 2658 each checkpoint. If not provided, no prefix is applied. 2659 overwrite_checkpoints: If True and existing models are found with the same name in the 2660 associated experiment, they will be deleted and recreated to point to the latest 2661 checkpoint. Defaults to False. 2662 2663 Returns: 2664 List of imported models. If 'overwrite_checkpoints' is True, the list only contains 2665 new created models, otherwise the list contains new created models for the new model 2666 names and existing models for the existing model names. 2667 2668 Example: 2669 2670 .. code-block:: python 2671 2672 import mlflow 2673 2674 # Optionally start a run so `source_run_id` can be inferred 2675 with mlflow.start_run() as run: 2676 # ... training code that writes checkpoints to a UC Volume ... 2677 logged_models = mlflow.import_checkpoints( 2678 checkpoint_path=( 2679 "/Volumes/mycatalog/myschema/myvolume/mytrainingmodel/trainingrun1/checkpoints" 2680 ), 2681 # You can omit `source_run_id` if there is an active run. 2682 # source_run_id=run.info.run_id, 2683 model_prefix="my_model_", 2684 overwrite_checkpoints=True, 2685 ) 2686 """ 2687 from databricks.sdk import WorkspaceClient 2688 2689 # Validate checkpoint_path before accessing workspace files 2690 if not isinstance(checkpoint_path, str) or not checkpoint_path.strip().startswith("/Volumes/"): 2691 raise MlflowException( 2692 "Parameter 'checkpoint_path' must be a non-empty string pointing to a Unity Catalog " 2693 "Volume path that contains checkpoints, e.g. '/Volumes/...'", 2694 error_code=INVALID_PARAMETER_VALUE, 2695 ) 2696 2697 # Resolve source_run_id from the active run if not provided 2698 if source_run_id is None: 2699 if run := active_run(): 2700 source_run_id = run.info.run_id 2701 else: 2702 raise MlflowException.invalid_parameter_value( 2703 "Please set 'source_run_id' or start an active run before calling " 2704 "'import_checkpoints'." 2705 ) 2706 2707 # Resolve experiment ID to operate against 2708 exp_id = MlflowClient().get_run(source_run_id).info.experiment_id 2709 2710 ws = WorkspaceClient() 2711 top_level_paths = [ 2712 entry.path.rstrip("/") for entry in ws.files.list_directory_contents(checkpoint_path) 2713 ] 2714 2715 imported_models: list[LoggedModel] = [] 2716 client = MlflowClient() 2717 2718 if not top_level_paths: 2719 _logger.warning( 2720 f"No checkpoints were found at path '{checkpoint_path}'. " 2721 "Please verify that 'checkpoint_path' is correct and accessible." 2722 ) 2723 return [] 2724 2725 for sub_checkpoint_path in top_level_paths: 2726 base_name = os.path.basename(sub_checkpoint_path) 2727 2728 model_name = model_prefix + base_name if model_prefix else base_name 2729 2730 try: 2731 _validate_logged_model_name(model_name) 2732 except MlflowException as e: 2733 _logger.warning( 2734 f"The model name is invalid (root error: {e!s}), skip importing the " 2735 f"model with name '{model_name}' from checkpoint folder '{sub_checkpoint_path}'." 2736 ) 2737 continue 2738 2739 existing_models = [ 2740 model 2741 for model in search_logged_models( 2742 experiment_ids=[exp_id], 2743 filter_string=f"name = '{model_name}'", 2744 output_format="list", 2745 ) 2746 if model.source_run_id == source_run_id 2747 ] 2748 2749 if not existing_models or overwrite_checkpoints: 2750 # Create a new model pointing to this checkpoint path. 2751 created_model = create_external_model( 2752 name=model_name, 2753 source_run_id=source_run_id, 2754 tags={"original_artifact_path": sub_checkpoint_path}, 2755 experiment_id=exp_id, 2756 ) 2757 imported_models.append(created_model) 2758 else: 2759 imported_models.extend(existing_models) 2760 2761 if existing_models and overwrite_checkpoints: 2762 for model in existing_models: 2763 client.delete_logged_model(model.model_id) 2764 2765 return imported_models 2766 2767 2768 def finalize_logged_model( 2769 model_id: str, status: Literal["READY", "FAILED"] | LoggedModelStatus 2770 ) -> LoggedModel: 2771 """ 2772 Finalize a model by updating its status. 2773 2774 Args: 2775 model_id: ID of the model to finalize. 2776 status: Final status to set on the model. 2777 2778 Returns: 2779 The updated model. 2780 2781 Example: 2782 2783 .. code-block:: python 2784 :test: 2785 2786 import mlflow 2787 from mlflow.entities import LoggedModelStatus 2788 2789 model = mlflow.initialize_logged_model(name="model") 2790 logged_model = mlflow.finalize_logged_model( 2791 model_id=model.model_id, 2792 status=LoggedModelStatus.READY, 2793 ) 2794 assert logged_model.status == LoggedModelStatus.READY 2795 2796 """ 2797 return MlflowClient().finalize_logged_model(model_id, status) 2798 2799 2800 def get_logged_model(model_id: str) -> LoggedModel: 2801 """ 2802 Get a logged model by ID. 2803 2804 Args: 2805 model_id: The ID of the logged model. 2806 2807 Returns: 2808 The logged model. 2809 2810 Example: 2811 2812 .. code-block:: python 2813 :test: 2814 2815 import mlflow 2816 2817 2818 class DummyModel(mlflow.pyfunc.PythonModel): 2819 def predict(self, context, model_input: list[str]) -> list[str]: 2820 return model_input 2821 2822 2823 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 2824 logged_model = mlflow.get_logged_model(model_id=model_info.model_id) 2825 assert logged_model.model_id == model_info.model_id 2826 2827 """ 2828 return MlflowClient().get_logged_model(model_id) 2829 2830 2831 def last_logged_model() -> LoggedModel | None: 2832 """ 2833 Fetches the most recent logged model in the current session. 2834 If no model has been logged, None is returned. 2835 2836 Returns: 2837 The logged model. 2838 2839 2840 .. code-block:: python 2841 :test: 2842 :caption: Example 2843 2844 import mlflow 2845 2846 2847 class DummyModel(mlflow.pyfunc.PythonModel): 2848 def predict(self, context, model_input: list[str]) -> list[str]: 2849 return model_input 2850 2851 2852 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 2853 last_model = mlflow.last_logged_model() 2854 assert last_model.model_id == model_info.model_id 2855 """ 2856 if id := _last_logged_model_id.get(): 2857 return get_logged_model(id) 2858 2859 2860 @overload 2861 def search_logged_models( 2862 experiment_ids: list[str] | None = None, 2863 filter_string: str | None = None, 2864 datasets: list[dict[str, str]] | None = None, 2865 max_results: int | None = None, 2866 order_by: list[dict[str, Any]] | None = None, 2867 output_format: Literal["pandas"] = "pandas", 2868 ) -> "pandas.DataFrame": ... 2869 2870 2871 @overload 2872 def search_logged_models( 2873 experiment_ids: list[str] | None = None, 2874 filter_string: str | None = None, 2875 datasets: list[dict[str, str]] | None = None, 2876 max_results: int | None = None, 2877 order_by: list[dict[str, Any]] | None = None, 2878 output_format: Literal["list"] = "list", 2879 ) -> list[LoggedModel]: ... 2880 2881 2882 def search_logged_models( 2883 experiment_ids: list[str] | None = None, 2884 filter_string: str | None = None, 2885 datasets: list[dict[str, str]] | None = None, 2886 max_results: int | None = None, 2887 order_by: list[dict[str, Any]] | None = None, 2888 output_format: Literal["pandas", "list"] = "pandas", 2889 ) -> Union[list[LoggedModel], "pandas.DataFrame"]: 2890 """ 2891 Search for logged models that match the specified search criteria. 2892 2893 Args: 2894 experiment_ids: List of experiment IDs to search for logged models. If not specified, 2895 the active experiment will be used. 2896 filter_string: A SQL-like filter string to parse. The filter string syntax supports: 2897 2898 - Entity specification: 2899 - attributes: `attribute_name` (default if no prefix is specified) 2900 - metrics: `metrics.metric_name` 2901 - parameters: `params.param_name` 2902 - tags: `tags.tag_name` 2903 - Comparison operators: 2904 - For numeric entities (metrics and numeric attributes): <, <=, >, >=, =, != 2905 - For string entities (params, tags, string attributes): =, !=, IN, NOT IN 2906 - Multiple conditions can be joined with 'AND' 2907 - String values must be enclosed in single quotes 2908 2909 Example filter strings: 2910 - `creation_time > 100` 2911 - `metrics.rmse > 0.5 AND params.model_type = 'rf'` 2912 - `tags.release IN ('v1.0', 'v1.1')` 2913 - `params.optimizer != 'adam' AND metrics.accuracy >= 0.9` 2914 2915 datasets: List of dictionaries to specify datasets on which to apply metrics filters 2916 For example, a filter string with `metrics.accuracy > 0.9` and dataset with name 2917 "test_dataset" means we will return all logged models with accuracy > 0.9 on the 2918 test_dataset. Metric values from ANY dataset matching the criteria are considered. 2919 If no datasets are specified, then metrics across all datasets are considered in 2920 the filter. The following fields are supported: 2921 2922 dataset_name (str): 2923 Required. Name of the dataset. 2924 dataset_digest (str): 2925 Optional. Digest of the dataset. 2926 max_results: The maximum number of logged models to return. 2927 order_by: List of dictionaries to specify the ordering of the search results. The following 2928 fields are supported: 2929 2930 field_name (str): 2931 Required. Name of the field to order by, e.g. "metrics.accuracy". 2932 ascending (bool): 2933 Optional. Whether the order is ascending or not. 2934 dataset_name (str): 2935 Optional. If ``field_name`` refers to a metric, this field 2936 specifies the name of the dataset associated with the metric. Only metrics 2937 associated with the specified dataset name will be considered for ordering. 2938 This field may only be set if ``field_name`` refers to a metric. 2939 dataset_digest (str): 2940 Optional. If ``field_name`` refers to a metric, this field 2941 specifies the digest of the dataset associated with the metric. Only metrics 2942 associated with the specified dataset name and digest will be considered for 2943 ordering. This field may only be set if ``dataset_name`` is also set. 2944 2945 output_format: The output format of the search results. Supported values are 'pandas' 2946 and 'list'. 2947 2948 Returns: 2949 The search results in the specified output format. 2950 2951 Example: 2952 2953 .. code-block:: python 2954 :test: 2955 2956 import mlflow 2957 2958 2959 class DummyModel(mlflow.pyfunc.PythonModel): 2960 def predict(self, context, model_input: list[str]) -> list[str]: 2961 return model_input 2962 2963 2964 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 2965 another_model_info = mlflow.pyfunc.log_model( 2966 name="another_model", python_model=DummyModel() 2967 ) 2968 models = mlflow.search_logged_models(output_format="list") 2969 assert [m.name for m in models] == ["another_model", "model"] 2970 models = mlflow.search_logged_models( 2971 filter_string="name = 'another_model'", output_format="list" 2972 ) 2973 assert [m.name for m in models] == ["another_model"] 2974 models = mlflow.search_logged_models( 2975 order_by=[{"field_name": "creation_time", "ascending": True}], output_format="list" 2976 ) 2977 assert [m.name for m in models] == ["model", "another_model"] 2978 """ 2979 experiment_ids = experiment_ids or [_get_experiment_id()] 2980 client = MlflowClient() 2981 models = [] 2982 page_token = None 2983 while True: 2984 logged_models_page = client.search_logged_models( 2985 experiment_ids=experiment_ids, 2986 filter_string=filter_string, 2987 datasets=datasets, 2988 max_results=max_results, 2989 order_by=order_by, 2990 page_token=page_token, 2991 ) 2992 models.extend(logged_models_page.to_list()) 2993 if max_results is not None and len(models) >= max_results: 2994 break 2995 if not logged_models_page.token: 2996 break 2997 page_token = logged_models_page.token 2998 2999 # Only return at most max_results logged models if specified 3000 if max_results is not None: 3001 models = models[:max_results] 3002 3003 if output_format == "list": 3004 return models 3005 elif output_format == "pandas": 3006 import pandas as pd 3007 3008 model_dicts = [] 3009 for model in models: 3010 model_dict = model.to_dictionary() 3011 # Convert the status back from int to the enum string 3012 model_dict["status"] = LoggedModelStatus.from_int(model_dict["status"]) 3013 model_dicts.append(model_dict) 3014 3015 return pd.DataFrame(model_dicts) 3016 3017 else: 3018 raise MlflowException( 3019 f"Unsupported output format: {output_format!r}. Supported string values are " 3020 "'pandas' or 'list'", 3021 INVALID_PARAMETER_VALUE, 3022 ) 3023 3024 3025 def log_outputs(models: list[LoggedModelOutput] | None = None): 3026 """ 3027 Log outputs, such as models, to the active run. If there is no active run, a new run will be 3028 created. 3029 3030 Args: 3031 models: List of :py:class:`mlflow.entities.LoggedModelOutput` instances to log 3032 as outputs to the run. 3033 3034 Returns: 3035 None. 3036 """ 3037 run_id = _get_or_start_run().info.run_id 3038 MlflowClient().log_outputs(run_id, models=models) 3039 3040 3041 def delete_run(run_id: str) -> None: 3042 """ 3043 Deletes a run with the given ID. 3044 3045 Args: 3046 run_id: Unique identifier for the run to delete. 3047 3048 .. code-block:: python 3049 :test: 3050 :caption: Example 3051 3052 import mlflow 3053 3054 with mlflow.start_run() as run: 3055 mlflow.log_param("p", 0) 3056 3057 run_id = run.info.run_id 3058 mlflow.delete_run(run_id) 3059 3060 lifecycle_stage = mlflow.get_run(run_id).info.lifecycle_stage 3061 print(f"run_id: {run_id}; lifecycle_stage: {lifecycle_stage}") 3062 3063 .. code-block:: text 3064 :caption: Output 3065 3066 run_id: 45f4af3e6fd349e58579b27fcb0b8277; lifecycle_stage: deleted 3067 3068 """ 3069 MlflowClient().delete_run(run_id) 3070 3071 3072 def set_logged_model_tags(model_id: str, tags: dict[str, Any]) -> None: 3073 """ 3074 Set tags on the specified logged model. 3075 3076 Args: 3077 model_id: ID of the model. 3078 tags: Tags to set on the model. 3079 3080 Returns: 3081 None 3082 3083 Example: 3084 3085 .. code-block:: python 3086 :test: 3087 3088 import mlflow 3089 3090 3091 class DummyModel(mlflow.pyfunc.PythonModel): 3092 def predict(self, context, model_input: list[str]) -> list[str]: 3093 return model_input 3094 3095 3096 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 3097 mlflow.set_logged_model_tags(model_info.model_id, {"key": "value"}) 3098 model = mlflow.get_logged_model(model_info.model_id) 3099 assert model.tags["key"] == "value" 3100 """ 3101 MlflowClient().set_logged_model_tags(model_id, tags) 3102 3103 3104 def delete_logged_model_tag(model_id: str, key: str) -> None: 3105 """ 3106 Delete a tag from the specified logged model. 3107 3108 Args: 3109 model_id: ID of the model. 3110 key: Tag key to delete. 3111 3112 Example: 3113 3114 .. code-block:: python 3115 :test: 3116 3117 import mlflow 3118 3119 3120 class DummyModel(mlflow.pyfunc.PythonModel): 3121 def predict(self, context, model_input: list[str]) -> list[str]: 3122 return model_input 3123 3124 3125 model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel()) 3126 mlflow.set_logged_model_tags(model_info.model_id, {"key": "value"}) 3127 model = mlflow.get_logged_model(model_info.model_id) 3128 assert model.tags["key"] == "value" 3129 mlflow.delete_logged_model_tag(model_info.model_id, "key") 3130 model = mlflow.get_logged_model(model_info.model_id) 3131 assert "key" not in model.tags 3132 """ 3133 MlflowClient().delete_logged_model_tag(model_id, key) 3134 3135 3136 def get_artifact_uri(artifact_path: str | None = None) -> str: 3137 """ 3138 Get the absolute URI of the specified artifact in the currently active run. 3139 3140 If `path` is not specified, the artifact root URI of the currently active 3141 run will be returned; calls to ``log_artifact`` and ``log_artifacts`` write 3142 artifact(s) to subdirectories of the artifact root URI. 3143 3144 If no run is active, this method will create a new active run. 3145 3146 Args: 3147 artifact_path: The run-relative artifact path for which to obtain an absolute URI. 3148 For example, "path/to/artifact". If unspecified, the artifact root URI 3149 for the currently active run will be returned. 3150 3151 Returns: 3152 An *absolute* URI referring to the specified artifact or the currently active run's 3153 artifact root. For example, if an artifact path is provided and the currently active 3154 run uses an S3-backed store, this may be a uri of the form 3155 ``s3://<bucket_name>/path/to/artifact/root/path/to/artifact``. If an artifact path 3156 is not provided and the currently active run uses an S3-backed store, this may be a 3157 URI of the form ``s3://<bucket_name>/path/to/artifact/root``. 3158 3159 .. code-block:: python 3160 :test: 3161 :caption: Example 3162 3163 import tempfile 3164 3165 import mlflow 3166 3167 features = "rooms, zipcode, median_price, school_rating, transport" 3168 with tempfile.NamedTemporaryFile("w") as tmp_file: 3169 tmp_file.write(features) 3170 tmp_file.flush() 3171 3172 # Log the artifact in a directory "features" under the root artifact_uri/features 3173 with mlflow.start_run(): 3174 mlflow.log_artifact(tmp_file.name, artifact_path="features") 3175 3176 # Fetch the artifact uri root directory 3177 artifact_uri = mlflow.get_artifact_uri() 3178 print(f"Artifact uri: {artifact_uri}") 3179 3180 # Fetch a specific artifact uri 3181 artifact_uri = mlflow.get_artifact_uri(artifact_path="features/features.txt") 3182 print(f"Artifact uri: {artifact_uri}") 3183 3184 .. code-block:: text 3185 :caption: Output 3186 3187 Artifact uri: file:///.../0/a46a80f1c9644bd8f4e5dd5553fffce/artifacts 3188 Artifact uri: file:///.../0/a46a80f1c9644bd8f4e5dd5553fffce/artifacts/features/features.txt 3189 """ 3190 if not mlflow.active_run(): 3191 _logger.warning( 3192 "No active run found. A new active run will be created. If this is not intended, " 3193 "please create a run using `mlflow.start_run()` first." 3194 ) 3195 3196 return artifact_utils.get_artifact_uri( 3197 run_id=_get_or_start_run().info.run_id, artifact_path=artifact_path 3198 ) 3199 3200 3201 def search_runs( 3202 experiment_ids: list[str] | None = None, 3203 filter_string: str = "", 3204 run_view_type: int = ViewType.ACTIVE_ONLY, 3205 max_results: int = SEARCH_MAX_RESULTS_PANDAS, 3206 order_by: list[str] | None = None, 3207 output_format: str = "pandas", 3208 search_all_experiments: bool = False, 3209 experiment_names: list[str] | None = None, 3210 ) -> Union[list[Run], "pandas.DataFrame"]: 3211 """ 3212 Search for Runs that fit the specified criteria. 3213 3214 Args: 3215 experiment_ids: List of experiment IDs. Search can work with experiment IDs or 3216 experiment names, but not both in the same call. Values other than 3217 ``None`` or ``[]`` will result in error if ``experiment_names`` is 3218 also not ``None`` or ``[]``. ``None`` will default to the active 3219 experiment if ``experiment_names`` is ``None`` or ``[]``. 3220 filter_string: Filter query string, defaults to searching all runs. 3221 run_view_type: one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or ``ALL`` runs 3222 defined in :py:class:`mlflow.entities.ViewType`. 3223 max_results: The maximum number of runs to put in the dataframe. Default is 100,000 3224 to avoid causing out-of-memory issues on the user's machine. 3225 order_by: List of columns to order by (e.g., "metrics.rmse"). The ``order_by`` column 3226 can contain an optional ``DESC`` or ``ASC`` value. The default is ``ASC``. 3227 The default ordering is to sort by ``start_time DESC``, then ``run_id``. 3228 output_format: The output format to be returned. If ``pandas``, a ``pandas.DataFrame`` 3229 is returned and, if ``list``, a list of :py:class:`mlflow.entities.Run` 3230 is returned. 3231 search_all_experiments: Boolean specifying whether all experiments should be searched. 3232 Only honored if ``experiment_ids`` is ``[]`` or ``None``. 3233 experiment_names: List of experiment names. Search can work with experiment IDs or 3234 experiment names, but not both in the same call. Values other 3235 than ``None`` or ``[]`` will result in error if ``experiment_ids`` 3236 is also not ``None`` or ``[]``. ``None`` will default to the active 3237 experiment if ``experiment_ids`` is ``None`` or ``[]``. 3238 3239 Returns: 3240 If output_format is ``list``: a list of :py:class:`mlflow.entities.Run`. If 3241 output_format is ``pandas``: ``pandas.DataFrame`` of runs, where each metric, 3242 parameter, and tag is expanded into its own column named metrics.*, params.*, or 3243 tags.* respectively. For runs that don't have a particular metric, parameter, or tag, 3244 the value for the corresponding column is (NumPy) ``Nan``, ``None``, or ``None`` 3245 respectively. 3246 3247 .. code-block:: python 3248 :test: 3249 :caption: Example 3250 3251 import mlflow 3252 3253 # Create an experiment and log two runs under it 3254 experiment_name = "Social NLP Experiments" 3255 experiment_id = mlflow.create_experiment(experiment_name) 3256 with mlflow.start_run(experiment_id=experiment_id): 3257 mlflow.log_metric("m", 1.55) 3258 mlflow.set_tag("s.release", "1.1.0-RC") 3259 with mlflow.start_run(experiment_id=experiment_id): 3260 mlflow.log_metric("m", 2.50) 3261 mlflow.set_tag("s.release", "1.2.0-GA") 3262 # Search for all the runs in the experiment with the given experiment ID 3263 df = mlflow.search_runs([experiment_id], order_by=["metrics.m DESC"]) 3264 print(df[["metrics.m", "tags.s.release", "run_id"]]) 3265 print("--") 3266 # Search the experiment_id using a filter_string with tag 3267 # that has a case insensitive pattern 3268 filter_string = "tags.s.release ILIKE '%rc%'" 3269 df = mlflow.search_runs([experiment_id], filter_string=filter_string) 3270 print(df[["metrics.m", "tags.s.release", "run_id"]]) 3271 print("--") 3272 # Search for all the runs in the experiment with the given experiment name 3273 df = mlflow.search_runs(experiment_names=[experiment_name], order_by=["metrics.m DESC"]) 3274 print(df[["metrics.m", "tags.s.release", "run_id"]]) 3275 3276 .. code-block:: text 3277 :caption: Output 3278 3279 metrics.m tags.s.release run_id 3280 0 2.50 1.2.0-GA 147eed886ab44633902cc8e19b2267e2 3281 1 1.55 1.1.0-RC 5cc7feaf532f496f885ad7750809c4d4 3282 -- 3283 metrics.m tags.s.release run_id 3284 0 1.55 1.1.0-RC 5cc7feaf532f496f885ad7750809c4d4 3285 -- 3286 metrics.m tags.s.release run_id 3287 0 2.50 1.2.0-GA 147eed886ab44633902cc8e19b2267e2 3288 1 1.55 1.1.0-RC 5cc7feaf532f496f885ad7750809c4d4 3289 """ 3290 no_ids = experiment_ids is None or len(experiment_ids) == 0 3291 no_names = experiment_names is None or len(experiment_names) == 0 3292 no_ids_or_names = no_ids and no_names 3293 if not no_ids and not no_names: 3294 raise MlflowException( 3295 message="Only experiment_ids or experiment_names can be used, but not both", 3296 error_code=INVALID_PARAMETER_VALUE, 3297 ) 3298 3299 if search_all_experiments and no_ids_or_names: 3300 experiment_ids = [ 3301 exp.experiment_id for exp in search_experiments(view_type=ViewType.ACTIVE_ONLY) 3302 ] 3303 elif no_ids_or_names: 3304 experiment_ids = [_get_experiment_id()] 3305 elif not no_names: 3306 experiments = [] 3307 for n in experiment_names: 3308 if n is not None: 3309 if experiment_by_name := get_experiment_by_name(n): 3310 experiments.append(experiment_by_name) 3311 else: 3312 _logger.warning("Cannot retrieve experiment by name %s", n) 3313 experiment_ids = [e.experiment_id for e in experiments if e is not None] 3314 3315 if len(experiment_ids) == 0: 3316 runs = [] 3317 else: 3318 # Using an internal function as the linter doesn't like assigning a lambda, and inlining the 3319 # full thing is a mess 3320 def pagination_wrapper_func(number_to_get, next_page_token): 3321 return MlflowClient().search_runs( 3322 experiment_ids, 3323 filter_string, 3324 run_view_type, 3325 number_to_get, 3326 order_by, 3327 next_page_token, 3328 ) 3329 3330 runs = get_results_from_paginated_fn( 3331 pagination_wrapper_func, 3332 NUM_RUNS_PER_PAGE_PANDAS, 3333 max_results, 3334 ) 3335 3336 if output_format == "list": 3337 return runs # List[mlflow.entities.run.Run] 3338 elif output_format == "pandas": 3339 import numpy as np 3340 import pandas as pd 3341 3342 info = { 3343 "run_id": [], 3344 "experiment_id": [], 3345 "status": [], 3346 "artifact_uri": [], 3347 "start_time": [], 3348 "end_time": [], 3349 } 3350 params = {} 3351 metrics = {} 3352 tags = {} 3353 PARAM_NULL = None 3354 METRIC_NULL = np.nan 3355 TAG_NULL = None 3356 for i, run in enumerate(runs): 3357 info["run_id"].append(run.info.run_id) 3358 info["experiment_id"].append(run.info.experiment_id) 3359 info["status"].append(run.info.status) 3360 info["artifact_uri"].append(run.info.artifact_uri) 3361 info["start_time"].append(pd.to_datetime(run.info.start_time, unit="ms", utc=True)) 3362 info["end_time"].append(pd.to_datetime(run.info.end_time, unit="ms", utc=True)) 3363 3364 # Params 3365 param_keys = set(params.keys()) 3366 for key in param_keys: 3367 if key in run.data.params: 3368 params[key].append(run.data.params[key]) 3369 else: 3370 params[key].append(PARAM_NULL) 3371 new_params = set(run.data.params.keys()) - param_keys 3372 for p in new_params: 3373 params[p] = [PARAM_NULL] * i # Fill in null values for all previous runs 3374 params[p].append(run.data.params[p]) 3375 3376 # Metrics 3377 metric_keys = set(metrics.keys()) 3378 for key in metric_keys: 3379 if key in run.data.metrics: 3380 metrics[key].append(run.data.metrics[key]) 3381 else: 3382 metrics[key].append(METRIC_NULL) 3383 new_metrics = set(run.data.metrics.keys()) - metric_keys 3384 for m in new_metrics: 3385 metrics[m] = [METRIC_NULL] * i 3386 metrics[m].append(run.data.metrics[m]) 3387 3388 # Tags 3389 tag_keys = set(tags.keys()) 3390 for key in tag_keys: 3391 if key in run.data.tags: 3392 tags[key].append(run.data.tags[key]) 3393 else: 3394 tags[key].append(TAG_NULL) 3395 new_tags = set(run.data.tags.keys()) - tag_keys 3396 for t in new_tags: 3397 tags[t] = [TAG_NULL] * i 3398 tags[t].append(run.data.tags[t]) 3399 3400 data = {} 3401 data.update(info) 3402 for key, value in metrics.items(): 3403 data["metrics." + key] = value 3404 for key, value in params.items(): 3405 data["params." + key] = value 3406 for key, value in tags.items(): 3407 data["tags." + key] = value 3408 return pd.DataFrame(data) 3409 else: 3410 raise ValueError( 3411 f"Unsupported output format: {output_format}. Supported string values are 'pandas' " 3412 "or 'list'" 3413 ) 3414 3415 3416 def _get_or_start_run(): 3417 active_run_stack = _active_run_stack.get() 3418 if len(active_run_stack) > 0: 3419 return active_run_stack[-1] 3420 return start_run() 3421 3422 3423 def _get_experiment_id_from_env(): 3424 experiment_name = MLFLOW_EXPERIMENT_NAME.get() 3425 experiment_id = MLFLOW_EXPERIMENT_ID.get() 3426 if experiment_name is not None: 3427 if exp := MlflowClient().get_experiment_by_name(experiment_name): 3428 if experiment_id and experiment_id != exp.experiment_id: 3429 raise MlflowException( 3430 message=f"The provided {MLFLOW_EXPERIMENT_ID} environment variable " 3431 f"value `{experiment_id}` does not match the experiment id " 3432 f"`{exp.experiment_id}` for experiment name `{experiment_name}`", 3433 error_code=INVALID_PARAMETER_VALUE, 3434 ) 3435 else: 3436 return exp.experiment_id 3437 else: 3438 return MlflowClient().create_experiment(name=experiment_name) 3439 if experiment_id is not None: 3440 try: 3441 exp = MlflowClient().get_experiment(experiment_id) 3442 return exp.experiment_id 3443 except MlflowException as exc: 3444 raise MlflowException( 3445 message=f"The provided {MLFLOW_EXPERIMENT_ID} environment variable " 3446 f"value `{experiment_id}` does not exist in the tracking server. Provide a valid " 3447 f"experiment_id.", 3448 error_code=INVALID_PARAMETER_VALUE, 3449 ) from exc 3450 3451 3452 def _get_experiment_id() -> str | None: 3453 if _active_experiment_id: 3454 return _active_experiment_id 3455 else: 3456 return _get_experiment_id_from_env() or default_experiment_registry.get_experiment_id() 3457 3458 3459 @autologging_integration("mlflow") 3460 def autolog( 3461 log_input_examples: bool = False, 3462 log_model_signatures: bool = True, 3463 log_models: bool = True, 3464 log_datasets: bool = True, 3465 log_traces: bool = True, 3466 disable: bool = False, 3467 exclusive: bool = False, 3468 disable_for_unsupported_versions: bool = False, 3469 silent: bool = False, 3470 extra_tags: dict[str, str] | None = None, 3471 exclude_flavors: list[str] | None = None, 3472 ) -> None: 3473 """ 3474 Enables (or disables) and configures autologging for all supported integrations. 3475 3476 The parameters are passed to any autologging integrations that support them. 3477 3478 See the `tracking docs <../../tracking/autolog.html>`_ for a list of supported autologging 3479 integrations. 3480 3481 Note that framework-specific configurations set at any point will take precedence over 3482 any configurations set by this function. For example: 3483 3484 .. code-block:: python 3485 :test: 3486 3487 import mlflow 3488 3489 mlflow.autolog(log_models=False, exclusive=True) 3490 import sklearn 3491 3492 would enable autologging for `sklearn` with `log_models=False` and `exclusive=True`, 3493 but 3494 3495 .. code-block:: python 3496 :test: 3497 3498 import mlflow 3499 3500 mlflow.autolog(log_models=False, exclusive=True) 3501 3502 import sklearn 3503 3504 mlflow.sklearn.autolog(log_models=True) 3505 3506 would enable autologging for `sklearn` with `log_models=True` and `exclusive=False`, 3507 the latter resulting from the default value for `exclusive` in `mlflow.sklearn.autolog`; 3508 other framework autolog functions (e.g. `mlflow.tensorflow.autolog`) would use the 3509 configurations set by `mlflow.autolog` (in this instance, `log_models=False`, `exclusive=True`), 3510 until they are explicitly called by the user. 3511 3512 Args: 3513 log_input_examples: If ``True``, input examples from training datasets are collected and 3514 logged along with model artifacts during training. If ``False``, 3515 input examples are not logged. 3516 Note: Input examples are MLflow model attributes 3517 and are only collected if ``log_models`` is also ``True``. 3518 log_model_signatures: If ``True``, 3519 :py:class:`ModelSignatures <mlflow.models.ModelSignature>` 3520 describing model inputs and outputs are collected and logged along 3521 with model artifacts during training. If ``False``, signatures are 3522 not logged. Note: Model signatures are MLflow model attributes 3523 and are only collected if ``log_models`` is also ``True``. 3524 log_models: If ``True``, trained models are logged as MLflow model artifacts. 3525 If ``False``, trained models are not logged. 3526 Input examples and model signatures, which are attributes of MLflow models, 3527 are also omitted when ``log_models`` is ``False``. 3528 log_datasets: If ``True``, dataset information is logged to MLflow Tracking. 3529 If ``False``, dataset information is not logged. 3530 log_traces: If ``True``, traces are collected for integrations. 3531 If ``False``, no trace is collected. 3532 disable: If ``True``, disables all supported autologging integrations. If ``False``, 3533 enables all supported autologging integrations. 3534 exclusive: If ``True``, autologged content is not logged to user-created fluent runs. 3535 If ``False``, autologged content is logged to the active fluent run, 3536 which may be user-created. 3537 disable_for_unsupported_versions: If ``True``, disable autologging for versions of 3538 all integration libraries that have not been tested against this version 3539 of the MLflow client or are incompatible. 3540 silent: If ``True``, suppress all event logs and warnings from MLflow during autologging 3541 setup and training execution. If ``False``, show all events and warnings during 3542 autologging setup and training execution. 3543 extra_tags: A dictionary of extra tags to set on each managed run created by autologging. 3544 exclude_flavors: A list of flavor names that are excluded from the auto-logging. 3545 e.g. tensorflow, pyspark.ml 3546 3547 .. code-block:: python 3548 :test: 3549 :caption: Example 3550 3551 import numpy as np 3552 import mlflow.sklearn 3553 from mlflow import MlflowClient 3554 from sklearn.linear_model import LinearRegression 3555 3556 3557 def print_auto_logged_info(r): 3558 tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")} 3559 artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")] 3560 print(f"run_id: {r.info.run_id}") 3561 print(f"artifacts: {artifacts}") 3562 print(f"params: {r.data.params}") 3563 print(f"metrics: {r.data.metrics}") 3564 print(f"tags: {tags}") 3565 3566 3567 # prepare training data 3568 X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) 3569 y = np.dot(X, np.array([1, 2])) + 3 3570 3571 # Auto log all the parameters, metrics, and artifacts 3572 mlflow.autolog() 3573 model = LinearRegression() 3574 with mlflow.start_run() as run: 3575 model.fit(X, y) 3576 3577 # fetch the auto logged parameters and metrics for ended run 3578 print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id)) 3579 3580 .. code-block:: text 3581 :caption: Output 3582 3583 run_id: fd10a17d028c47399a55ab8741721ef7 3584 artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl'] 3585 params: {'copy_X': 'True', 3586 'normalize': 'False', 3587 'fit_intercept': 'True', 3588 'n_jobs': 'None'} 3589 metrics: {'training_score': 1.0, 3590 'training_root_mean_squared_error': 4.440892098500626e-16, 3591 'training_r2_score': 1.0, 3592 'training_mean_absolute_error': 2.220446049250313e-16, 3593 'training_mean_squared_error': 1.9721522630525295e-31} 3594 tags: {'estimator_class': 'sklearn.linear_model._base.LinearRegression', 3595 'estimator_name': 'LinearRegression'} 3596 """ 3597 locals_copy = locals().items() 3598 3599 # Mapping of library name to specific autolog function name. We use string like 3600 # "tensorflow.autolog" to avoid loading all flavor modules, so we only set autologging for 3601 # compatible modules. 3602 LIBRARY_TO_AUTOLOG_MODULE = { 3603 "tensorflow": "mlflow.tensorflow", 3604 "keras": "mlflow.keras", 3605 "xgboost": "mlflow.xgboost", 3606 "lightgbm": "mlflow.lightgbm", 3607 "statsmodels": "mlflow.statsmodels", 3608 "sklearn": "mlflow.sklearn", 3609 "pyspark": "mlflow.spark", 3610 "pyspark.ml": "mlflow.pyspark.ml", 3611 # TODO: Broaden this beyond pytorch_lightning as we add autologging support for more 3612 # Pytorch frameworks under mlflow.pytorch.autolog 3613 "pytorch_lightning": "mlflow.pytorch", 3614 "lightning": "mlflow.pytorch", 3615 "setfit": "mlflow.transformers", 3616 "transformers": "mlflow.transformers", 3617 # do not enable langchain autologging by default 3618 } 3619 3620 GENAI_LIBRARY_TO_AUTOLOG_MODULE = { 3621 "autogen": "mlflow.ag2", 3622 "agno": "mlflow.agno", 3623 "anthropic": "mlflow.anthropic", 3624 "autogen_agentchat": "mlflow.autogen", 3625 "openai": "mlflow.openai", 3626 "google.genai": "mlflow.gemini", 3627 "google.generativeai": "mlflow.gemini", 3628 "litellm": "mlflow.litellm", 3629 "llama_index.core": "mlflow.llama_index", 3630 "langchain": "mlflow.langchain", 3631 "dspy": "mlflow.dspy", 3632 "crewai": "mlflow.crewai", 3633 "smolagents": "mlflow.smolagents", 3634 "groq": "mlflow.groq", 3635 "strands": "mlflow.strands", 3636 "haystack": "mlflow.haystack", 3637 "boto3": "mlflow.bedrock", 3638 "mistralai": "mlflow.mistral", 3639 "pydantic_ai": "mlflow.pydantic_ai", 3640 } 3641 3642 # Currently, GenAI libraries are not enabled by `mlflow.autolog` in Databricks, 3643 # particularly when disable=False. This is because the function is automatically invoked 3644 # by system and we don't want to take the risk of enabling GenAI libraries all at once. 3645 # TODO: Remove this logic once a feature flag is implemented in Databricks Runtime init logic. 3646 if is_in_databricks_runtime() and (not disable): 3647 target_library_and_module = LIBRARY_TO_AUTOLOG_MODULE 3648 else: 3649 target_library_and_module = LIBRARY_TO_AUTOLOG_MODULE | GENAI_LIBRARY_TO_AUTOLOG_MODULE 3650 3651 if exclude_flavors: 3652 excluded_modules = [f"mlflow.{flavor}" for flavor in exclude_flavors] 3653 target_library_and_module = { 3654 k: v for k, v in target_library_and_module.items() if v not in excluded_modules 3655 } 3656 3657 def get_autologging_params(autolog_fn): 3658 try: 3659 needed_params = list(inspect.signature(autolog_fn).parameters.keys()) 3660 return {k: v for k, v in locals_copy if k in needed_params} 3661 except Exception: 3662 return {} 3663 3664 # Note: we need to protect `setup_autologging` with `autologging_conf_lock`, 3665 # because `setup_autologging` might be registered as post importing hook 3666 # and be executed asynchronously, so that it is out of current active 3667 # `autologging_conf_lock` scope. 3668 @autologging_conf_lock 3669 def setup_autologging(module): 3670 try: 3671 autologging_params = None 3672 autolog_module = importlib.import_module(target_library_and_module[module.__name__]) 3673 autolog_fn = autolog_module.autolog 3674 # Only call integration's autolog function with `mlflow.autolog` configs 3675 # if the integration's autolog function has not already been called by the user. 3676 # Logic is as follows: 3677 # - if a previous_config exists, that means either `mlflow.autolog` or 3678 # `mlflow.integration.autolog` was called. 3679 # - if the config contains `AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED`, the 3680 # configuration was set by `mlflow.autolog`, and so we can safely call `autolog_fn` 3681 # with `autologging_params`. 3682 # - if the config doesn't contain this key, the configuration was set by an 3683 # `mlflow.integration.autolog` call, so we should not call `autolog_fn` with 3684 # new configs. 3685 prev_config = AUTOLOGGING_INTEGRATIONS.get(autolog_fn.integration_name) 3686 if prev_config and not prev_config.get( 3687 AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED, False 3688 ): 3689 return 3690 3691 autologging_params = get_autologging_params(autolog_fn) 3692 autolog_fn(**autologging_params) 3693 AUTOLOGGING_INTEGRATIONS[autolog_fn.integration_name][ 3694 AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED 3695 ] = True 3696 if not autologging_is_disabled( 3697 autolog_fn.integration_name 3698 ) and not autologging_params.get("silent", False): 3699 _logger.info("Autologging successfully enabled for %s.", module.__name__) 3700 except Exception as e: 3701 if is_testing(): 3702 # Raise unexpected exceptions in test mode in order to detect 3703 # errors within dependent autologging integrations 3704 raise 3705 elif autologging_params is None or not autologging_params.get("silent", False): 3706 _logger.warning( 3707 "Exception raised while enabling autologging for %s: %s", 3708 module.__name__, 3709 str(e), 3710 ) 3711 3712 # for each autolog library (except pyspark), register a post-import hook. 3713 # this way, we do not send any errors to the user until we know they are using the library. 3714 # the post-import hook also retroactively activates for previously-imported libraries. 3715 for library in sorted(set(target_library_and_module) - {"pyspark", "pyspark.ml"}): 3716 register_post_import_hook(setup_autologging, library, overwrite=True) 3717 3718 if is_in_databricks_runtime(): 3719 # for pyspark, we activate autologging immediately, without waiting for a module import. 3720 # this is because on Databricks a SparkSession already exists and the user can directly 3721 # interact with it, and this activity should be logged. 3722 import pyspark as pyspark_module 3723 import pyspark.ml as pyspark_ml_module 3724 3725 setup_autologging(pyspark_module) 3726 setup_autologging(pyspark_ml_module) 3727 else: 3728 if "pyspark" in target_library_and_module: 3729 register_post_import_hook(setup_autologging, "pyspark", overwrite=True) 3730 if "pyspark.ml" in target_library_and_module: 3731 register_post_import_hook(setup_autologging, "pyspark.ml", overwrite=True) 3732 3733 _record_event(AutologgingEvent, {"flavor": "all", "log_traces": log_traces, "disable": disable}) 3734 3735 3736 _active_model_id_env_lock = threading.Lock() 3737 3738 3739 class ActiveModelContext: 3740 """ 3741 The context of the active model. 3742 3743 Args: 3744 model_id: The ID of the active model. 3745 set_by_user: Whether the active model was set by the user or not. 3746 """ 3747 3748 def __init__(self, model_id: str | None = None, set_by_user: bool = False): 3749 # use active model ID from environment variables as the default value for model_id 3750 # so that for subprocesses the default _ACTIVE_MODEL_CONTEXT.model_id 3751 # is still valid, and we don't need to read from env var. 3752 self._set_by_user = set_by_user 3753 if is_in_databricks_model_serving_environment(): 3754 # In Databricks, we set the active model ID to the environment variable 3755 # so that it can be used in the main process, since databricks serving 3756 # loads model from threads. 3757 with _active_model_id_env_lock: 3758 self._model_id = model_id or _get_active_model_id_from_env() 3759 if self._model_id: 3760 _MLFLOW_ACTIVE_MODEL_ID.set(self._model_id) 3761 else: 3762 self._model_id = model_id or _get_active_model_id_from_env() 3763 3764 def __repr__(self): 3765 return f"ActiveModelContext(model_id={self.model_id}, set_by_user={self.set_by_user})" 3766 3767 @property 3768 def model_id(self) -> str | None: 3769 return self._model_id 3770 3771 @property 3772 def set_by_user(self) -> bool: 3773 return self._set_by_user 3774 3775 3776 def _get_active_model_id_from_env() -> str | None: 3777 """ 3778 Get the active model ID from environment variables, with proper precedence handling. 3779 3780 This utility function reads the active model ID from environment variables with the following 3781 precedence order: 3782 1. MLFLOW_ACTIVE_MODEL_ID (public variable) - takes precedence if set 3783 2. _MLFLOW_ACTIVE_MODEL_ID (legacy internal variable) - used as fallback 3784 3785 Historical Context: 3786 The _MLFLOW_ACTIVE_MODEL_ID environment variable was originally created for internal MLflow 3787 use only. With the introduction of MLFLOW_ACTIVE_MODEL_ID as the public API, we prioritize 3788 the public variable to encourage migration to the public interface while maintaining 3789 backward compatibility by falling back to the legacy variable when only it is set. 3790 3791 Returns: 3792 The active model ID if found in environment variables, otherwise None. 3793 """ 3794 # Check public variable first to prioritize the public API 3795 public_model_id = MLFLOW_ACTIVE_MODEL_ID.get() 3796 if public_model_id is not None: 3797 return public_model_id 3798 3799 # Fallback to legacy internal variable for backward compatibility 3800 return _MLFLOW_ACTIVE_MODEL_ID.get() 3801 3802 3803 _ACTIVE_MODEL_CONTEXT = ThreadLocalVariable(default_factory=lambda: ActiveModelContext()) 3804 3805 3806 class ActiveModel(LoggedModel): 3807 """ 3808 Wrapper around :py:class:`mlflow.entities.LoggedModel` to enable using Python ``with`` syntax. 3809 """ 3810 3811 def __init__(self, logged_model: LoggedModel, set_by_user: bool): 3812 super().__init__(**logged_model.to_dictionary()) 3813 self.last_active_model_context = _ACTIVE_MODEL_CONTEXT.get() 3814 _set_active_model_id(self.model_id, set_by_user) 3815 3816 def __enter__(self): 3817 return self 3818 3819 def __exit__(self, exc_type, exc_val, exc_tb): 3820 if is_in_databricks_model_serving_environment(): 3821 # create a new instance of ActiveModelContext to make sure the 3822 # environment variable is updated in databricks serving environment 3823 _ACTIVE_MODEL_CONTEXT.set( 3824 ActiveModelContext( 3825 model_id=self.last_active_model_context.model_id, 3826 set_by_user=self.last_active_model_context.set_by_user, 3827 ) 3828 ) 3829 else: 3830 _ACTIVE_MODEL_CONTEXT.set(self.last_active_model_context) 3831 3832 3833 # NB: This function is only intended to be used publicly by users to set the 3834 # active model ID. MLflow internally should NEVER call this function directly, 3835 # since we need to differentiate between user and system set active model IDs. 3836 # For MLflow internal usage, use `_set_active_model` instead. 3837 3838 3839 def set_active_model(*, name: str | None = None, model_id: str | None = None) -> ActiveModel: 3840 """ 3841 Set the active model with the specified name or model ID, and it will be used for linking 3842 traces that are generated during the lifecycle of the model. The return value can be used as 3843 a context manager within a ``with`` block; otherwise, you must call ``set_active_model()`` 3844 to update active model. 3845 3846 Args: 3847 name: The name of the :py:class:`mlflow.entities.LoggedModel` to set as active. 3848 If a LoggedModel with the name does not exist, it will be created under the current 3849 experiment. If multiple LoggedModels with the name exist, the latest one will be 3850 set as active. 3851 model_id: The ID of the :py:class:`mlflow.entities.LoggedModel` to set as active. 3852 If no LoggedModel with the ID exists, an exception will be raised. 3853 3854 Returns: 3855 :py:class:`mlflow.ActiveModel` object that acts as a context manager wrapping the 3856 LoggedModel's state. 3857 3858 .. code-block:: python 3859 :test: 3860 :caption: Example 3861 3862 import mlflow 3863 3864 # Set the active model by name 3865 mlflow.set_active_model(name="my_model") 3866 3867 # Set the active model by model ID 3868 model = mlflow.create_external_model(name="test_model") 3869 mlflow.set_active_model(model_id=model.model_id) 3870 3871 # Use the active model in a context manager 3872 with mlflow.set_active_model(name="new_model"): 3873 print(mlflow.get_active_model_id()) 3874 3875 # Traces are automatically linked to the active model 3876 mlflow.set_active_model(name="my_model") 3877 3878 3879 @mlflow.trace 3880 def predict(model_input): 3881 return model_input 3882 3883 3884 predict("abc") 3885 traces = mlflow.search_traces( 3886 model_id=mlflow.get_active_model_id(), return_type="list", flush=True 3887 ) 3888 assert len(traces) == 1 3889 """ 3890 return _set_active_model(name=name, model_id=model_id, set_by_user=True) 3891 3892 3893 def _set_active_model( 3894 *, name: str | None = None, model_id: str | None = None, set_by_user: bool = False 3895 ) -> ActiveModel: 3896 if name is None and model_id is None: 3897 raise MlflowException.invalid_parameter_value( 3898 message="Either name or model_id must be provided", 3899 ) 3900 3901 if model_id is not None: 3902 logged_model = mlflow.get_logged_model(model_id) 3903 if name is not None and logged_model.name != name: 3904 raise MlflowException.invalid_parameter_value( 3905 f"LoggedModel with model_id {model_id!r} has name {logged_model.name!r}, which does" 3906 f" not match the provided name {name!r}." 3907 ) 3908 elif name is not None: 3909 logged_models = mlflow.search_logged_models( 3910 filter_string=f"name='{name}'", max_results=2, output_format="list" 3911 ) 3912 if len(logged_models) > 1: 3913 _logger.warning( 3914 f"Multiple LoggedModels found with name {name!r}, setting the latest one as active " 3915 "model." 3916 ) 3917 if len(logged_models) == 0: 3918 _logger.info(f"LoggedModel with name {name!r} does not exist, creating one...") 3919 logged_model = mlflow.create_external_model(name=name) 3920 else: 3921 logged_model = logged_models[0] 3922 return ActiveModel(logged_model=logged_model, set_by_user=set_by_user) 3923 3924 3925 def _set_active_model_id(model_id: str, set_by_user: bool = False) -> None: 3926 """ 3927 Set the active model ID in the active model context and update the 3928 corresponding environment variable. This should only be used when 3929 we know the LoggedModel with the model_id exists. 3930 This function should be used inside MLflow to set the active model 3931 while not blocking other code execution. 3932 """ 3933 try: 3934 _ACTIVE_MODEL_CONTEXT.set(ActiveModelContext(model_id, set_by_user)) 3935 except Exception as e: 3936 _logger.warning(f"Failed to set active model ID to {model_id}, error: {e}") 3937 else: 3938 _logger.info(f"Active model is set to the logged model with ID: {model_id}") 3939 if not set_by_user: 3940 _logger.info( 3941 "Use `mlflow.set_active_model` to set the active model " 3942 "to a different one if needed." 3943 ) 3944 3945 3946 def _get_active_model_context() -> ActiveModelContext: 3947 """ 3948 Get the active model context. This is used internally by MLflow to manage the active model 3949 context. 3950 """ 3951 return _ACTIVE_MODEL_CONTEXT.get() 3952 3953 3954 def get_active_model_id() -> str | None: 3955 """ 3956 Get the active model ID. If no active model is set with ``set_active_model()``, the 3957 default active model is set using model ID from the environment variable 3958 ``MLFLOW_ACTIVE_MODEL_ID`` or the legacy environment variable ``_MLFLOW_ACTIVE_MODEL_ID``. 3959 If neither is set, return None. Note that this function only get the active model ID from the 3960 current thread. 3961 3962 Returns: 3963 The active model ID if set, otherwise None. 3964 """ 3965 return _get_active_model_context().model_id 3966 3967 3968 def _get_active_model_id_global() -> str | None: 3969 """ 3970 Get the active model ID from the global context by checking all threads. 3971 This is useful when we need to get the active_model_id set by a different thread. 3972 """ 3973 # if the active model ID is set in the current thread, always use it 3974 if model_id_in_current_thread := get_active_model_id(): 3975 _logger.debug(f"Active model ID found in the current thread: {model_id_in_current_thread}") 3976 return model_id_in_current_thread 3977 model_ids = [ 3978 ctx.model_id 3979 for ctx in _ACTIVE_MODEL_CONTEXT.get_all_thread_values().values() 3980 if ctx.model_id is not None 3981 ] 3982 if model_ids: 3983 if len(set(model_ids)) > 1: 3984 _logger.debug( 3985 "Failed to get one active model id from all threads, multiple active model IDs " 3986 f"found: {set(model_ids)}." 3987 ) 3988 return 3989 return model_ids[0] 3990 _logger.debug("No active model ID found in any thread.") 3991 3992 3993 def clear_active_model() -> None: 3994 """ 3995 Clear the active model. This will clear the active model previously set by 3996 :py:func:`mlflow.set_active_model` or via the ``MLFLOW_ACTIVE_MODEL_ID`` environment variable 3997 or the ``_MLFLOW_ACTIVE_MODEL_ID`` legacy environment variable. 3998 3999 from current thread. To temporarily switch 4000 the active model, use ``with mlflow.set_active_model(...)`` instead. 4001 4002 .. code-block:: python 4003 :test: 4004 :caption: Example 4005 4006 import mlflow 4007 4008 # Set the active model by name 4009 mlflow.set_active_model(name="my_model") 4010 4011 # Clear the active model 4012 mlflow.clear_active_model() 4013 # Check that the active model is None 4014 assert mlflow.get_active_model_id() is None 4015 4016 # If you want to temporarily set the active model, 4017 # use `set_active_model` as a context manager instead 4018 with mlflow.set_active_model(name="my_model") as active_model: 4019 assert mlflow.get_active_model_id() == active_model.model_id 4020 assert mlflow.get_active_model_id() is None 4021 """ 4022 # reset the environment variables as well to avoid them being used when creating 4023 # ActiveModelContext 4024 MLFLOW_ACTIVE_MODEL_ID.unset() 4025 _MLFLOW_ACTIVE_MODEL_ID.unset() 4026 4027 # Reset the active model context to avoid the active model ID set by other threads 4028 # to be used when creating a new ActiveModelContext 4029 _ACTIVE_MODEL_CONTEXT.reset() 4030 # set_by_user is False because this API clears the state of active model 4031 # and MLflow might still set the active model in cases like `load_model` 4032 _ACTIVE_MODEL_CONTEXT.set(ActiveModelContext(set_by_user=False)) 4033 _logger.info("Active model is cleared")