# mlflow/tracking/fluent.py
   1  """
   2  Internal module implementing the fluent API, allowing management of an active
   3  MLflow run. This module is exposed to users at the top-level :py:mod:`mlflow` module.
   4  """
   5  
   6  import atexit
   7  import contextlib
   8  import importlib
   9  import inspect
  10  import io
  11  import logging
  12  import os
  13  import threading
  14  from copy import deepcopy
  15  from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Union, overload
  16  
  17  import mlflow
  18  from mlflow.entities import Dataset as DatasetEntity
  19  from mlflow.entities import (
  20      DatasetInput,
  21      Experiment,
  22      InputTag,
  23      LoggedModel,
  24      LoggedModelInput,
  25      LoggedModelOutput,
  26      LoggedModelStatus,
  27      Metric,
  28      Param,
  29      Run,
  30      RunInputs,
  31      RunStatus,
  32      RunTag,
  33      ViewType,
  34  )
  35  from mlflow.entities.lifecycle_stage import LifecycleStage
  36  from mlflow.entities.trace_location import UnityCatalog
  37  from mlflow.environment_variables import (
  38      _MLFLOW_ACTIVE_MODEL_ID,
  39      _MLFLOW_ENABLE_SGC_RUN_RESUMPTION_FOR_DATABRICKS_JOBS,
  40      MLFLOW_ACTIVE_MODEL_ID,
  41      MLFLOW_ENABLE_ASYNC_LOGGING,
  42      MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING,
  43      MLFLOW_EXPERIMENT_ID,
  44      MLFLOW_EXPERIMENT_NAME,
  45      MLFLOW_RUN_ID,
  46      MLFLOW_TRACING_SQL_WAREHOUSE_ID,
  47  )
  48  from mlflow.exceptions import MlflowException
  49  from mlflow.protos.databricks_pb2 import (
  50      INVALID_PARAMETER_VALUE,
  51      RESOURCE_DOES_NOT_EXIST,
  52  )
  53  from mlflow.store.tracking import SEARCH_MAX_RESULTS_DEFAULT
  54  from mlflow.telemetry.events import AutologgingEvent
  55  from mlflow.telemetry.track import _record_event
  56  from mlflow.tracing.provider import (
  57      _get_trace_exporter,
  58  )
  59  from mlflow.tracking._tracking_service.client import TrackingServiceClient
  60  from mlflow.tracking._tracking_service.utils import _resolve_tracking_uri
  61  from mlflow.utils import get_results_from_paginated_fn
  62  from mlflow.utils.annotations import experimental
  63  from mlflow.utils.async_logging.run_operations import RunOperations
  64  from mlflow.utils.autologging_utils import (
  65      AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED,
  66      AUTOLOGGING_INTEGRATIONS,
  67      autologging_conf_lock,
  68      autologging_integration,
  69      autologging_is_disabled,
  70      is_testing,
  71  )
  72  from mlflow.utils.databricks_utils import (
  73      get_sgc_job_run_id,
  74      is_in_databricks_model_serving_environment,
  75      is_in_databricks_runtime,
  76  )
  77  from mlflow.utils.file_utils import TempDir
  78  from mlflow.utils.import_hooks import register_post_import_hook
  79  from mlflow.utils.mlflow_tags import (
  80      MLFLOW_DATABRICKS_SGC_RESUME_RUN_JOB_RUN_ID_PREFIX,
  81      MLFLOW_DATASET_CONTEXT,
  82      MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH,
  83      MLFLOW_EXPERIMENT_PRIMARY_METRIC_GREATER_IS_BETTER,
  84      MLFLOW_EXPERIMENT_PRIMARY_METRIC_NAME,
  85      MLFLOW_MODEL_IS_EXTERNAL,
  86      MLFLOW_PARENT_RUN_ID,
  87      MLFLOW_RUN_NAME,
  88      MLFLOW_RUN_NOTE,
  89  )
  90  from mlflow.utils.thread_utils import ThreadLocalVariable
  91  from mlflow.utils.time import get_current_time_millis
  92  from mlflow.utils.uri import is_databricks_uri
  93  from mlflow.utils.validation import (
  94      _validate_experiment_id_type,
  95      _validate_logged_model_name,
  96      _validate_run_id,
  97  )
  98  from mlflow.version import IS_TRACING_SDK_ONLY
  99  
 100  if not IS_TRACING_SDK_ONLY:
 101      from mlflow.data.dataset import Dataset
 102      from mlflow.tracking import _get_artifact_repo, _get_store, artifact_utils
 103      from mlflow.tracking.client import MlflowClient
 104      from mlflow.tracking.context import registry as context_registry
 105      from mlflow.tracking.default_experiment import registry as default_experiment_registry
 106  
 107  
 108  if TYPE_CHECKING:
 109      import matplotlib.figure
 110      import numpy
 111      import pandas
 112      import PIL
 113      import plotly
 114  
 115  
# ID of the experiment activated via `set_experiment`; None until one is set.
_active_experiment_id = None

# Result-size caps for search APIs that materialize runs as pandas DataFrames.
# NOTE(review): usage is outside this chunk — presumably `search_runs`; verify.
SEARCH_MAX_RESULTS_PANDAS = 100000
NUM_RUNS_PER_PAGE_PANDAS = 10000

_logger = logging.getLogger(__name__)


# Maps an active run's ID to its SystemMetricsMonitor so `end_run` can stop it.
run_id_to_system_metrics_monitor = {}


# Per-thread stack of ActiveRun objects; the top of the stack is the current run.
_active_run_stack = ThreadLocalVariable(default_factory=lambda: [])

# ID of the most recently ended run in this thread (set by `end_run`).
_last_active_run_id = ThreadLocalVariable(default_factory=lambda: None)
# ID of the most recently logged model in this thread; reset via
# `_reset_last_logged_model_id` (testing only).
_last_logged_model_id = ThreadLocalVariable(default_factory=lambda: None)
 131  
 132  
 133  def _reset_last_logged_model_id() -> None:
 134      """
 135      Should be called only for testing purposes.
 136      """
 137      _last_logged_model_id.set(None)
 138  
 139  
# Serializes the lookup/create section of `set_experiment` within this process.
_experiment_lock = threading.Lock()
 141  
 142  
 143  def set_experiment(
 144      experiment_name: str | None = None,
 145      experiment_id: str | None = None,
 146      trace_location: UnityCatalog | None = None,
 147  ) -> Experiment:
 148      """
 149      Set the given experiment as the active experiment. The experiment must either be specified by
 150      name via `experiment_name` or by ID via `experiment_id`. The experiment name and ID cannot
 151      both be specified.
 152  
 153      .. note::
 154          If the experiment being set by name does not exist, a new experiment will be
 155          created with the given name. After the experiment has been created, it will be set
 156          as the active experiment. On certain platforms, such as Databricks, the experiment name
 157          must be an absolute path, e.g. ``"/Users/<username>/my-experiment"``.
 158  
 159      Args:
 160          experiment_name: Case sensitive name of the experiment to be activated.
 161          experiment_id: ID of the experiment to be activated. If an experiment with this ID
 162              does not exist, an exception is thrown.
 163          trace_location: Optional UC trace location used to configure the experiment-derived
 164              tracing destination. Must be an instance of
 165              ``mlflow.entities.trace_location.UnityCatalog(...)``.
 166  
 167      Returns:
 168          An instance of :py:class:`mlflow.entities.Experiment` representing the new active
 169          experiment.
 170  
 171      .. code-block:: python
 172          :test:
 173          :caption: Example
 174  
 175          import mlflow
 176  
 177          # Set an experiment name, which must be unique and case-sensitive.
 178          experiment = mlflow.set_experiment("Social NLP Experiments")
 179          # Get Experiment Details
 180          print(f"Experiment_id: {experiment.experiment_id}")
 181          print(f"Artifact Location: {experiment.artifact_location}")
 182          print(f"Tags: {experiment.tags}")
 183          print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
 184  
 185      .. code-block:: text
 186          :caption: Output
 187  
 188          Experiment_id: 1
 189          Artifact Location: file:///.../mlruns/1
 190          Tags: {}
 191          Lifecycle_stage: active
 192      """
 193      if (experiment_name is not None and experiment_id is not None) or (
 194          experiment_name is None and experiment_id is None
 195      ):
 196          raise MlflowException(
 197              message="Must specify exactly one of: `experiment_id` or `experiment_name`.",
 198              error_code=INVALID_PARAMETER_VALUE,
 199          )
 200  
 201      client = TrackingServiceClient(_resolve_tracking_uri())
 202  
 203      is_newly_created = False
 204  
 205      with _experiment_lock:
 206          if experiment_id is None:
 207              experiment = client.get_experiment_by_name(experiment_name)
 208              if not experiment:
 209                  _logger.info(
 210                      "Experiment with name '%s' does not exist. Creating a new experiment.",
 211                      experiment_name,
 212                  )
 213                  try:
 214                      experiment_id = client.create_experiment(experiment_name)
 215                  except MlflowException as e:
 216                      if e.error_code == "RESOURCE_ALREADY_EXISTS":
 217                          # NB: If two simultaneous processes attempt to set the same experiment
 218                          # simultaneously, a race condition may be encountered here wherein
 219                          # experiment creation fails
 220                          return client.get_experiment_by_name(experiment_name)
 221                      raise
 222  
 223                  experiment = client.get_experiment(experiment_id)
 224                  is_newly_created = True
 225          else:
 226              experiment = client.get_experiment(experiment_id)
 227              if experiment is None:
 228                  raise MlflowException(
 229                      message=f"Experiment with ID '{experiment_id}' does not exist.",
 230                      error_code=RESOURCE_DOES_NOT_EXIST,
 231                  )
 232  
 233          if experiment.lifecycle_stage != LifecycleStage.ACTIVE:
 234              raise MlflowException(
 235                  message=(
 236                      f"Cannot set a deleted experiment {experiment.name!r} as the active"
 237                      " experiment. "
 238                      "You can restore the experiment, or permanently delete the "
 239                      "experiment to create a new one."
 240                  ),
 241                  error_code=INVALID_PARAMETER_VALUE,
 242              )
 243  
 244      if trace_location is not None and trace_location.table_prefix is None:
 245          trace_location = UnityCatalog(
 246              catalog_name=trace_location.catalog_name,
 247              schema_name=trace_location.schema_name,
 248              table_prefix=experiment.experiment_id,
 249          )
 250  
 251      try:
 252          resolved_location = _resolve_experiment_to_trace_location(
 253              experiment=experiment,
 254              trace_location=trace_location,
 255          )
 256      except MlflowException as e:
 257          if is_newly_created and trace_location is not None:
 258              raise MlflowException.invalid_parameter_value(
 259                  f"Experiment '{experiment.name}' (ID: {experiment.experiment_id}) was created "
 260                  f"but linking to trace location '{trace_location.full_table_prefix}' failed: "
 261                  f"{e.message} Please fix the issue and call set_experiment again to retry."
 262              ) from e
 263          raise
 264  
 265      global _active_experiment_id
 266      _active_experiment_id = experiment.experiment_id
 267  
 268      # Set 'MLFLOW_EXPERIMENT_ID' environment variable
 269      # so that subprocess can inherit it.
 270      MLFLOW_EXPERIMENT_ID.set(_active_experiment_id)
 271      if resolved_location is not None:
 272          experiment.trace_location = resolved_location
 273  
 274      _sync_trace_destination_and_provider(resolved_location)
 275  
 276      return experiment
 277  
 278  
 279  def _sync_trace_destination_and_provider(
 280      resolved_location: UnityCatalog | None,
 281  ) -> None:
 282      from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, provider
 283  
 284      # If the tracer provider has already been initialized, reset it so the
 285      # next trace re-derives the correct processor chain from the new experiment.
 286      if provider.once._done:
 287          provider.reset()
 288  
 289      _MLFLOW_TRACE_USER_DESTINATION.set(resolved_location)
 290  
 291  
 292  def _resolve_experiment_to_trace_location(
 293      experiment: Experiment,
 294      trace_location: UnityCatalog | None,
 295  ) -> UnityCatalog | None:
 296      """Resolve the trace destination for an experiment without mutating state.
 297  
 298      All validation and network calls happen here. The caller is responsible
 299      for committing the result (setting experiment-derived destination, etc.).
 300  
 301      Returns:
 302          The resolved UnityCatalog location if one was configured, or None.
 303      """
 304      if trace_location is None:
 305          return None
 306      if not isinstance(trace_location, UnityCatalog):
 307          raise MlflowException.invalid_parameter_value(
 308              "`trace_location` must be an instance of `mlflow.entities.trace_location.UnityCatalog`."
 309          )
 310  
 311      if not is_databricks_uri(_resolve_tracking_uri()):
 312          raise MlflowException.invalid_parameter_value(
 313              "`trace_location` is only supported with a Databricks tracking URI."
 314          )
 315  
 316      # Check if experiment is already linked via the destination path tag (no backend call).
 317      if destination_path := experiment.tags.get(MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH):
 318          if destination_path == trace_location.full_table_prefix:
 319              return experiment.trace_location
 320          raise MlflowException.invalid_parameter_value(
 321              f"Experiment '{experiment.name}' is already linked to a different "
 322              f"trace location '{destination_path}'."
 323          )
 324  
 325      # No existing link — register and link via backend.
 326      from mlflow.tracing.client import TracingClient
 327  
 328      tracing_client = TracingClient()
 329      resolved = tracing_client._create_or_get_trace_location(
 330          trace_location,
 331          MLFLOW_TRACING_SQL_WAREHOUSE_ID.get(),
 332      )
 333      tracing_client._link_trace_location(
 334          experiment_id=experiment.experiment_id,
 335          location=resolved,
 336      )
 337      return resolved
 338  
 339  
 340  def _set_experiment_primary_metric(
 341      experiment_id: str, primary_metric: str, greater_is_better: bool
 342  ):
 343      client = MlflowClient()
 344      client.set_experiment_tag(experiment_id, MLFLOW_EXPERIMENT_PRIMARY_METRIC_NAME, primary_metric)
 345      client.set_experiment_tag(
 346          experiment_id, MLFLOW_EXPERIMENT_PRIMARY_METRIC_GREATER_IS_BETTER, str(greater_is_better)
 347      )
 348  
 349  
 350  class ActiveRun(Run):
 351      """Wrapper around :py:class:`mlflow.entities.Run` to enable using Python ``with`` syntax."""
 352  
 353      def __init__(self, run):
 354          Run.__init__(self, run.info, run.data)
 355  
 356      def __enter__(self):
 357          return self
 358  
 359      def __exit__(self, exc_type, exc_val, exc_tb):
 360          active_run_stack = _active_run_stack.get()
 361  
 362          # Check if the run is still active. We check based on ID instead of
 363          # using referential equality, because some tools (e.g. AutoML) may
 364          # stop a run and start it again with the same ID to restore session state
 365          if any(r.info.run_id == self.info.run_id for r in active_run_stack):
 366              status = RunStatus.FINISHED if exc_type is None else RunStatus.FAILED
 367              end_run(RunStatus.to_string(status))
 368  
 369          return exc_type is None
 370  
 371  
 372  def _get_sgc_job_run_id_tag_key() -> str | None:
 373      """
 374      Get the SGC job run ID tag key for run resumption if enabled and available.
 375  
 376      Returns:
 377          str or None: The experiment tag key for SGC resumption, or None if not applicable.
 378      """
 379      if not _MLFLOW_ENABLE_SGC_RUN_RESUMPTION_FOR_DATABRICKS_JOBS.get():
 380          return None
 381  
 382      if sgc_job_run_id := get_sgc_job_run_id():
 383          return f"{MLFLOW_DATABRICKS_SGC_RESUME_RUN_JOB_RUN_ID_PREFIX}.{sgc_job_run_id}"
 384  
 385      return None
 386  
 387  
 388  def _get_sgc_mlflow_run_id_for_resumption(
 389      client, experiment_id: str | None, sgc_job_run_id_tag_key: str | None
 390  ) -> str | None:
 391      """
 392      Retrieves the MLflow run ID associated with a specific SGC job run ID tag key
 393      for potential run resumption.
 394  
 395      This function searches the experiment (specified by `experiment_id`, or the
 396      default if None) for an experiment tag named `sgc_job_run_id_tag_key`. If the
 397      tag exists, its value (the run ID to resume) is returned; otherwise, returns None.
 398  
 399      Args:
 400          client: MlflowClient instance used to query experiment information.
 401          experiment_id: The experiment ID to search, or None to use the default.
 402          sgc_job_run_id_tag_key: The experiment tag key that maps the SGC job run ID
 403              to an MLflow run ID.
 404  
 405      Returns:
 406          str or None: The MLflow run ID to resume, if found; otherwise None.
 407      """
 408      search_exp_id = experiment_id or _get_experiment_id()
 409  
 410      try:
 411          exp = client.get_experiment(search_exp_id)
 412          # Check if experiment has the tag for resumption
 413          if prev_mlflow_run_id := exp.tags.get(sgc_job_run_id_tag_key):
 414              _logger.info(
 415                  f"Resuming MLflow run: {prev_mlflow_run_id} "
 416                  f"using SGC tag key: {sgc_job_run_id_tag_key}"
 417              )
 418              return prev_mlflow_run_id
 419      except Exception as e:
 420          _logger.debug(f"Failed to retrieve SGC run ID: {e}", exc_info=True)
 421  
 422      return None
 423  
 424  
def start_run(
    run_id: str | None = None,
    experiment_id: str | None = None,
    run_name: str | None = None,
    nested: bool = False,
    parent_run_id: str | None = None,
    tags: dict[str, Any] | None = None,
    description: str | None = None,
    log_system_metrics: bool | None = None,
) -> ActiveRun:
    """
    Start a new MLflow run, setting it as the active run under which metrics and parameters
    will be logged. The return value can be used as a context manager within a ``with`` block;
    otherwise, you must call ``end_run()`` to terminate the current run.

    If you pass a ``run_id`` or the ``MLFLOW_RUN_ID`` environment variable is set,
    ``start_run`` attempts to resume a run with the specified run ID and
    other parameters are ignored. ``run_id`` takes precedence over ``MLFLOW_RUN_ID``.

    If resuming an existing run, the run status is set to ``RunStatus.RUNNING``.

    MLflow sets a variety of default tags on the run, as defined in
    `MLflow system tags <../../tracking/tracking-api.html#system_tags>`_.

    Args:
        run_id: If specified, get the run with the specified UUID and log parameters
            and metrics under that run. The run's end time is unset and its status
            is set to running, but the run's other attributes (``source_version``,
            ``source_type``, etc.) are not changed.
        experiment_id: ID of the experiment under which to create the current run (applicable
            only when ``run_id`` is not specified). If ``experiment_id`` argument
            is unspecified, will look for valid experiment in the following order:
            activated using ``set_experiment``, ``MLFLOW_EXPERIMENT_NAME``
            environment variable, ``MLFLOW_EXPERIMENT_ID`` environment variable,
            or the default experiment as defined by the tracking server.
        run_name: Name of new run, should be a non-empty string. Used only when ``run_id`` is
            unspecified. If a new run is created and ``run_name`` is not specified,
            a random name will be generated for the run.
        nested: Controls whether run is nested in parent run. ``True`` creates a nested run.
        parent_run_id: If specified, the current run will be nested under the the run with
            the specified UUID. The parent run must be in the ACTIVE state.
        tags: An optional dictionary of string keys and values to set as tags on the run.
            If a run is being resumed, these tags are set on the resumed run. If a new run is
            being created, these tags are set on the new run.
        description: An optional string that populates the description box of the run.
            If a run is being resumed, the description is set on the resumed run.
            If a new run is being created, the description is set on the new run.
        log_system_metrics: bool, defaults to None. If True, system metrics will be logged
            to MLflow, e.g., cpu/gpu utilization. If None, we will check environment variable
            `MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING` to determine whether to log system metrics.
            System metrics logging is an experimental feature in MLflow 2.8 and subject to change.

    Returns:
        :py:class:`mlflow.ActiveRun` object that acts as a context manager wrapping the
        run's state.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Create nested runs
        experiment_id = mlflow.create_experiment("experiment1")
        with mlflow.start_run(
            run_name="PARENT_RUN",
            experiment_id=experiment_id,
            tags={"version": "v1", "priority": "P1"},
            description="parent",
        ) as parent_run:
            mlflow.log_param("parent", "yes")
            with mlflow.start_run(
                run_name="CHILD_RUN",
                experiment_id=experiment_id,
                description="child",
                nested=True,
            ) as child_run:
                mlflow.log_param("child", "yes")
        print("parent run:")
        print(f"run_id: {parent_run.info.run_id}")
        print("description: {}".format(parent_run.data.tags.get("mlflow.note.content")))
        print("version tag value: {}".format(parent_run.data.tags.get("version")))
        print("priority tag value: {}".format(parent_run.data.tags.get("priority")))
        print("--")

        # Search all child runs with a parent id
        query = f"tags.mlflow.parentRunId = '{parent_run.info.run_id}'"
        results = mlflow.search_runs(experiment_ids=[experiment_id], filter_string=query)
        print("child runs:")
        print(results[["run_id", "params.child", "tags.mlflow.runName"]])

        # Create a nested run under the existing parent run
        with mlflow.start_run(
            run_name="NEW_CHILD_RUN",
            experiment_id=experiment_id,
            description="new child",
            parent_run_id=parent_run.info.run_id,
        ) as child_run:
            mlflow.log_param("new-child", "yes")

    .. code-block:: text
        :caption: Output

        parent run:
        run_id: 8979459433a24a52ab3be87a229a9cdf
        description: starting a parent for experiment 7
        version tag value: v1
        priority tag value: P1
        --
        child runs:
                                     run_id params.child tags.mlflow.runName
        0  7d175204675e40328e46d9a6a5a7ee6a          yes           CHILD_RUN
    """
    active_run_stack = _active_run_stack.get()
    _validate_experiment_id_type(experiment_id)
    # back compat for int experiment_id
    experiment_id = str(experiment_id) if isinstance(experiment_id, int) else experiment_id
    if len(active_run_stack) > 0 and not nested:
        raise Exception(
            (
                "Run with UUID {} is already active. To start a new run, first end the "
                + "current run with mlflow.end_run(). To start a nested "
                + "run, call start_run with nested=True"
            ).format(active_run_stack[0].info.run_id)
        )
    client = MlflowClient()
    sgc_job_run_id_tag_key: str | None = None
    # Resolve the run to resume, in precedence order: explicit `run_id` argument,
    # the MLFLOW_RUN_ID environment variable, then SGC job-run resumption.
    if run_id:
        existing_run_id = run_id
    elif run_id := MLFLOW_RUN_ID.get():
        existing_run_id = run_id
        # Consume the env var so it is honored only once in this process.
        del os.environ[MLFLOW_RUN_ID.name]
    # Get SGC job run ID tag key for run resumption if applicable
    elif sgc_job_run_id_tag_key := _get_sgc_job_run_id_tag_key():
        existing_run_id = _get_sgc_mlflow_run_id_for_resumption(
            client, experiment_id, sgc_job_run_id_tag_key
        )
    else:
        existing_run_id = None
    if existing_run_id:
        # Resume path: restore the run to RUNNING and apply tags/description.
        _validate_run_id(existing_run_id)
        active_run_obj = client.get_run(existing_run_id)
        # Check to see if experiment_id from environment matches experiment_id from set_experiment()
        if (
            _active_experiment_id is not None
            and _active_experiment_id != active_run_obj.info.experiment_id
        ):
            raise MlflowException(
                f"Cannot start run with ID {existing_run_id} because active experiment ID "
                "does not match environment run ID. Make sure --experiment-name "
                "or --experiment-id matches experiment set with "
                "set_experiment(), or just use command-line arguments"
            )
        # Check if the current run has been deleted.
        if active_run_obj.info.lifecycle_stage == LifecycleStage.DELETED:
            raise MlflowException(
                f"Cannot start run with ID {existing_run_id} because it is in the deleted state."
            )
        # Use previous `end_time` because a value is required for `update_run_info`.
        end_time = active_run_obj.info.end_time
        _get_store().update_run_info(
            existing_run_id, run_status=RunStatus.RUNNING, end_time=end_time, run_name=run_name
        )
        tags = tags or {}
        if description:
            if MLFLOW_RUN_NOTE in tags:
                raise MlflowException(
                    f"Description is already set via the tag {MLFLOW_RUN_NOTE} in tags."
                    f"Remove the key {MLFLOW_RUN_NOTE} from the tags or omit the description.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            tags[MLFLOW_RUN_NOTE] = description

        if tags:
            client.log_batch(
                run_id=existing_run_id,
                tags=[RunTag(key, str(value)) for key, value in tags.items()],
            )
        # Re-fetch so the returned run reflects the tags logged above.
        active_run_obj = client.get_run(existing_run_id)
    else:
        # Creation path: determine the parent (explicit `parent_run_id`, or the
        # innermost active run for nested runs), then create a new run.
        if parent_run_id:
            _validate_run_id(parent_run_id)
            # Make sure parent_run_id matches the current run id, if there is an active run
            if len(active_run_stack) > 0 and parent_run_id != active_run_stack[-1].info.run_id:
                current_run_id = active_run_stack[-1].info.run_id
                raise MlflowException(
                    f"Current run with UUID {current_run_id} does not match the specified "
                    f"parent_run_id {parent_run_id}. To start a new nested run under "
                    f"the parent run with UUID {current_run_id}, first end the current run "
                    "with mlflow.end_run()."
                )
            parent_run_obj = client.get_run(parent_run_id)
            # Check if the specified parent_run has been deleted.
            if parent_run_obj.info.lifecycle_stage == LifecycleStage.DELETED:
                raise MlflowException(
                    f"Cannot start run under parent run with ID {parent_run_id} "
                    f"because it is in the deleted state."
                )
        else:
            parent_run_id = active_run_stack[-1].info.run_id if len(active_run_stack) > 0 else None

        exp_id_for_run = experiment_id if experiment_id is not None else _get_experiment_id()

        # Deep-copy so the caller's `tags` dict is never mutated on this path.
        user_specified_tags = deepcopy(tags) or {}
        if description:
            if MLFLOW_RUN_NOTE in user_specified_tags:
                raise MlflowException(
                    f"Description is already set via the tag {MLFLOW_RUN_NOTE} in tags."
                    f"Remove the key {MLFLOW_RUN_NOTE} from the tags or omit the description.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            user_specified_tags[MLFLOW_RUN_NOTE] = description
        if parent_run_id is not None:
            user_specified_tags[MLFLOW_PARENT_RUN_ID] = parent_run_id
        if run_name:
            user_specified_tags[MLFLOW_RUN_NAME] = run_name

        resolved_tags = context_registry.resolve_tags(user_specified_tags)

        active_run_obj = client.create_run(
            experiment_id=exp_id_for_run,
            tags=resolved_tags,
            run_name=run_name,
        )

        # If SGC run resumption is enabled, set the experiment tag mapping
        # SGC job_run_id to this run_id for future run resumption
        if sgc_job_run_id_tag_key:
            try:
                client.set_experiment_tag(
                    exp_id_for_run, sgc_job_run_id_tag_key, active_run_obj.info.run_id
                )
                _logger.info(
                    f"Set experiment tag {sgc_job_run_id_tag_key} = {active_run_obj.info.run_id} "
                    f"for SGC run resumption"
                )
            except Exception as e:
                # Best-effort: resumption bookkeeping must not fail the run start.
                _logger.debug(
                    f"Failed to set experiment tag for SGC resumption: {e}", exc_info=True
                )

    if log_system_metrics is None:
        # If `log_system_metrics` is not specified, we will check environment variable.
        log_system_metrics = MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING.get()

    if log_system_metrics:
        if importlib.util.find_spec("psutil") is None:
            raise MlflowException(
                "Failed to start system metrics monitoring as package `psutil` is not installed. "
                "Please run `pip install psutil` to resolve the issue, otherwise you can disable "
                "system metrics logging by passing `log_system_metrics=False` to "
                "`mlflow.start_run()` or calling `mlflow.disable_system_metrics_logging`."
            )
        try:
            from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor

            system_monitor = SystemMetricsMonitor(
                active_run_obj.info.run_id,
                resume_logging=existing_run_id is not None,
            )
            run_id_to_system_metrics_monitor[active_run_obj.info.run_id] = system_monitor
            system_monitor.start()
        except Exception as e:
            # Monitoring is best-effort; a failure here must not abort the run.
            _logger.error(f"Failed to start system metrics monitoring: {e}.")

    active_run_stack.append(ActiveRun(active_run_obj))
    return active_run_stack[-1]
 692  
 693  
 694  def end_run(status: str = RunStatus.to_string(RunStatus.FINISHED)) -> None:
 695      """
 696      End an active MLflow run (if there is one).
 697  
 698      .. code-block:: python
 699          :test:
 700          :caption: Example
 701  
 702          import mlflow
 703  
 704          # Start run and get status
 705          mlflow.start_run()
 706          run = mlflow.active_run()
 707          print(f"run_id: {run.info.run_id}; status: {run.info.status}")
 708  
 709          # End run and get status
 710          mlflow.end_run()
 711          run = mlflow.get_run(run.info.run_id)
 712          print(f"run_id: {run.info.run_id}; status: {run.info.status}")
 713          print("--")
 714  
 715          # Check for any active runs
 716          print(f"Active run: {mlflow.active_run()}")
 717  
 718      .. code-block:: text
 719          :caption: Output
 720  
 721          run_id: b47ee4563368419880b44ad8535f6371; status: RUNNING
 722          run_id: b47ee4563368419880b44ad8535f6371; status: FINISHED
 723          --
 724          Active run: None
 725      """
 726      active_run_stack = _active_run_stack.get()
 727      if len(active_run_stack) > 0:
 728          # Clear out the global existing run environment variable as well.
 729          MLFLOW_RUN_ID.unset()
 730          run = active_run_stack.pop()
 731          last_active_run_id = run.info.run_id
 732          _last_active_run_id.set(last_active_run_id)
 733          MlflowClient().set_terminated(last_active_run_id, status)
 734          if last_active_run_id in run_id_to_system_metrics_monitor:
 735              system_metrics_monitor = run_id_to_system_metrics_monitor.pop(last_active_run_id)
 736              system_metrics_monitor.finish()
 737  
 738  
def _safe_end_run():
    # Best-effort wrapper around end_run(): swallow any exception so a failing
    # tracking backend never masks the program's own exit path.
    with contextlib.suppress(Exception):
        end_run()


# Terminate the active run (if any) when the Python interpreter exits.
atexit.register(_safe_end_run)
 745  
 746  
 747  def active_run() -> ActiveRun | None:
 748      """
 749      Get the currently active ``Run``, or None if no such run exists.
 750  
 751      .. attention::
 752          This API is **thread-local** and returns only the active run in the current thread.
 753          If your application is multi-threaded and a run is started in a different thread,
 754          this API will not retrieve that run.
 755  
 756      **Note**: You cannot access currently-active run attributes
 757      (parameters, metrics, etc.) through the run returned by ``mlflow.active_run``. In order
 758      to access such attributes, use the :py:class:`mlflow.client.MlflowClient` as follows:
 759  
 760      .. code-block:: python
 761          :test:
 762          :caption: Example
 763  
 764          import mlflow
 765  
 766          mlflow.start_run()
 767          run = mlflow.active_run()
 768          print(f"Active run_id: {run.info.run_id}")
 769          mlflow.end_run()
 770  
 771      .. code-block:: text
 772          :caption: Output
 773  
 774          Active run_id: 6f252757005748708cd3aad75d1ff462
 775      """
 776      active_run_stack = _active_run_stack.get()
 777      return active_run_stack[-1] if len(active_run_stack) > 0 else None
 778  
 779  
 780  def last_active_run() -> Run | None:
 781      """Gets the most recent active run.
 782  
 783      Examples:
 784  
 785      .. code-block:: python
 786          :test:
 787          :caption: To retrieve the most recent autologged run:
 788  
 789          import mlflow
 790  
 791          from sklearn.model_selection import train_test_split
 792          from sklearn.datasets import load_diabetes
 793          from sklearn.ensemble import RandomForestRegressor
 794  
 795          mlflow.autolog()
 796  
 797          db = load_diabetes()
 798          X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)
 799  
 800          # Create and train models.
 801          rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3)
 802          rf.fit(X_train, y_train)
 803  
 804          # Use the model to make predictions on the test dataset.
 805          predictions = rf.predict(X_test)
 806          autolog_run = mlflow.last_active_run()
 807  
 808      .. code-block:: python
 809          :test:
 810          :caption: To get the most recently active run that ended:
 811  
 812          import mlflow
 813  
 814          mlflow.start_run()
 815          mlflow.end_run()
 816          run = mlflow.last_active_run()
 817  
 818      .. code-block:: python
 819          :test:
 820          :caption: To retrieve the currently active run:
 821  
 822          import mlflow
 823  
 824          mlflow.start_run()
 825          run = mlflow.last_active_run()
 826          mlflow.end_run()
 827  
 828      Returns:
 829          The active run (this is equivalent to ``mlflow.active_run()``) if one exists.
 830          Otherwise, the last run started from the current Python process that reached
 831          a terminal status (i.e. FINISHED, FAILED, or KILLED).
 832      """
 833      _active_run = active_run()
 834      if _active_run is not None:
 835          return _active_run
 836  
 837      last_active_run_id = _last_active_run_id.get()
 838      if last_active_run_id is None:
 839          return None
 840      return get_run(last_active_run_id)
 841  
 842  
 843  def _get_latest_active_run():
 844      """
 845      Get active run from global context by checking all threads. The `mlflow.active_run` API
 846      only returns active run from current thread. This API is useful for the case where one
 847      needs to get a run started from a separate thread.
 848      """
 849      all_active_runs = [
 850          run for run_stack in _active_run_stack.get_all_thread_values().values() for run in run_stack
 851      ]
 852      if all_active_runs:
 853          return max(all_active_runs, key=lambda run: run.info.start_time)
 854      return None
 855  
 856  
 857  def get_run(run_id: str) -> Run:
 858      """
 859      Fetch the run from backend store. The resulting Run contains a collection of run metadata --
 860      RunInfo as well as a collection of run parameters, tags, and metrics -- RunData. It also
 861      contains a collection of run inputs (experimental), including information about datasets used by
 862      the run -- RunInputs. In the case where multiple metrics with the same key are logged for the
 863      run, the RunData contains the most recently logged value at the largest step for each metric.
 864  
 865      Args:
 866          run_id: Unique identifier for the run.
 867  
 868      Returns:
 869          A single Run object, if the run exists. Otherwise, raises an exception.
 870  
 871      .. code-block:: python
 872          :test:
 873          :caption: Example
 874  
 875          import mlflow
 876  
 877          with mlflow.start_run() as run:
 878              mlflow.log_param("p", 0)
 879          run_id = run.info.run_id
 880          print(f"run_id: {run_id}; lifecycle_stage: {mlflow.get_run(run_id).info.lifecycle_stage}")
 881  
 882      .. code-block:: text
 883          :caption: Output
 884  
 885          run_id: 7472befefc754e388e8e922824a0cca5; lifecycle_stage: active
 886      """
 887      return MlflowClient().get_run(run_id)
 888  
 889  
 890  def get_parent_run(run_id: str) -> Run | None:
 891      """Gets the parent run for the given run id if one exists.
 892  
 893      Args:
 894          run_id: Unique identifier for the child run.
 895  
 896      Returns:
 897          A single :py:class:`mlflow.entities.Run` object, if the parent run exists. Otherwise,
 898          returns None.
 899  
 900      .. code-block:: python
 901          :test:
 902          :caption: Example
 903  
 904          import mlflow
 905  
 906          # Create nested runs
 907          with mlflow.start_run():
 908              with mlflow.start_run(nested=True) as child_run:
 909                  child_run_id = child_run.info.run_id
 910  
 911          parent_run = mlflow.get_parent_run(child_run_id)
 912  
 913          print(f"child_run_id: {child_run_id}")
 914          print(f"parent_run_id: {parent_run.info.run_id}")
 915  
 916      .. code-block:: text
 917          :caption: Output
 918  
 919          child_run_id: 7d175204675e40328e46d9a6a5a7ee6a
 920          parent_run_id: 8979459433a24a52ab3be87a229a9cdf
 921      """
 922      return MlflowClient().get_parent_run(run_id)
 923  
 924  
 925  def log_param(key: str, value: Any, synchronous: bool | None = None) -> Any:
 926      """
 927      Log a parameter (e.g. model hyperparameter) under the current run. If no run is active,
 928      this method will create a new active run.
 929  
 930      Args:
 931          key: Parameter name. This string may only contain alphanumerics, underscores (_), dashes
 932              (-), periods (.), spaces ( ), and slashes (/). All backend stores support keys up to
 933              length 250, but some may support larger keys.
 934          value: Parameter value, but will be string-ified if not. All built-in backend stores support
 935              values up to length 6000, but some may support larger values.
 936          synchronous: *Experimental* If True, blocks until the parameter is logged successfully. If
 937              False, logs the parameter asynchronously and returns a future representing the logging
 938              operation. If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`,
 939              which defaults to False if not set.
 940  
 941      Returns:
 942          When `synchronous=True`, returns parameter value. When `synchronous=False`, returns an
 943          :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that represents
 944          future for logging operation.
 945  
 946      .. code-block:: python
 947          :test:
 948          :caption: Example
 949  
 950          import mlflow
 951  
 952          with mlflow.start_run():
 953              value = mlflow.log_param("learning_rate", 0.01)
 954              assert value == 0.01
 955              value = mlflow.log_param("learning_rate", 0.02, synchronous=False)
 956      """
 957      run_id = _get_or_start_run().info.run_id
 958      synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get()
 959      return MlflowClient().log_param(run_id, key, value, synchronous=synchronous)
 960  
 961  
 962  def flush_async_logging() -> None:
 963      """Flush all pending async logging."""
 964      _get_store().flush_async_logging()
 965  
 966  
 967  def _shut_down_async_logging() -> None:
 968      """Shutdown the async logging and flush all pending data."""
 969      _get_store().shut_down_async_logging()
 970  
 971  
 972  def flush_artifact_async_logging() -> None:
 973      """Flush all pending artifact async logging."""
 974      run_id = _get_or_start_run().info.run_id
 975      if _artifact_repo := _get_artifact_repo(run_id):
 976          _artifact_repo.flush_async_logging()
 977  
 978  
 979  def flush_trace_async_logging(terminate=False) -> None:
 980      """
 981      Flush all pending trace async logging.
 982  
 983      Args:
 984          terminate: If True, shut down the logging threads after flushing.
 985      """
 986      # Flush ALL batch span processors and their exporters' async queues.
 987      # When set_destination() is called multiple times, each call creates a new
 988      # tracer provider, processor, and exporter. The registry tracks all of them
 989      # so we drain both layers: span queue → exporter → async DB write queue.
 990      from mlflow.tracing.processor.base_mlflow import flush_all_batch_processors
 991  
 992      try:
 993          flush_all_batch_processors(terminate=terminate)
 994      except Exception as e:
 995          _logger.debug(f"Failed to flush batch processors: {e}", exc_info=True)
 996  
 997      # When batch processor is disabled (no registry entries), the current exporter
 998      # may still have an _async_queue that needs draining (SimpleSpanProcessor path).
 999      try:
1000          if trace_exporter := _get_trace_exporter():
1001              if hasattr(trace_exporter, "_async_queue"):
1002                  trace_exporter._async_queue.flush(terminate=terminate)
1003      except Exception as e:
1004          _logger.debug(f"Failed to flush trace exporter async queue: {e}", exc_info=True)
1005  
1006  
def set_experiment_tag(key: str, value: Any) -> None:
    """
    Set a tag on the current experiment. Value is converted to a string.

    Args:
        key: Tag name. This string may only contain alphanumerics, underscores (_), dashes (-),
            periods (.), spaces ( ), and slashes (/). All backend stores will support keys up to
            length 250, but some may support larger keys.
        value: Tag value, but will be string-ified if not. All backend stores will support values
            up to length 5000, but some may support larger values.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        with mlflow.start_run():
            mlflow.set_experiment_tag("release.version", "2.2.0")
    """
    # Resolve the current experiment, then delegate to the client API.
    MlflowClient().set_experiment_tag(_get_experiment_id(), key, value)
1029  
1030  
def delete_experiment_tag(key: str) -> None:
    """
    Delete a tag from the current experiment.

    Args:
        key: Name of the tag to be deleted.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        exp = mlflow.set_experiment("test-delete-tag")
        mlflow.set_experiment_tag("release.version", "1.0")
        mlflow.delete_experiment_tag("release.version")
        exp = mlflow.get_experiment(exp.experiment_id)
        assert "release.version" not in exp.tags
    """
    # Resolve the current experiment, then delegate to the client API.
    MlflowClient().delete_experiment_tag(_get_experiment_id(), key)
1052  
1053  
def set_tag(key: str, value: Any, synchronous: bool | None = None) -> RunOperations | None:
    """
    Set a tag under the current run. If no run is active, this method will create a new active
    run.

    Args:
        key: Tag name. This string may only contain alphanumerics, underscores (_), dashes (-),
            periods (.), spaces ( ), and slashes (/). All backend stores will support keys up to
            length 250, but some may support larger keys.
        value: Tag value, but will be string-ified if not. All backend stores will support values
            up to length 5000, but some may support larger values.
        synchronous: *Experimental* If True, blocks until the tag is logged successfully. If False,
            logs the tag asynchronously and returns a future representing the logging operation.
            If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which
            defaults to False if not set.

    Returns:
        When `synchronous=True`, returns None. When `synchronous=False`, returns an
        :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that
        represents future for logging operation.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Set a tag.
        with mlflow.start_run():
            mlflow.set_tag("release.version", "2.2.0")

        # Set a tag in async fashion.
        with mlflow.start_run():
            mlflow.set_tag("release.version", "2.2.1", synchronous=False)
    """
    active = _get_or_start_run()
    if synchronous is None:
        # Fall back to the environment-variable default for async logging.
        synchronous = not MLFLOW_ENABLE_ASYNC_LOGGING.get()
    return MlflowClient().set_tag(active.info.run_id, key, value, synchronous=synchronous)
1092  
1093  
def delete_tag(key: str) -> None:
    """
    Delete a tag from a run. This is irreversible. If no run is active, this method
    will create a new active run.

    Args:
        key: Name of the tag

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        tags = {"engineering": "ML Platform", "engineering_remote": "ML Platform"}

        with mlflow.start_run() as run:
            mlflow.set_tags(tags)

        with mlflow.start_run(run_id=run.info.run_id):
            mlflow.delete_tag("engineering_remote")
    """
    active = _get_or_start_run()
    MlflowClient().delete_tag(active.info.run_id, key)
1118  
1119  
def log_metric(
    key: str,
    value: float,
    step: int | None = None,
    synchronous: bool | None = None,
    timestamp: int | None = None,
    run_id: str | None = None,
    model_id: str | None = None,
    dataset: Union["Dataset", DatasetEntity] | None = None,
) -> RunOperations | None:
    """
    Log a metric under the current run. If no run is active, this method will create
    a new active run.

    Args:
        key: Metric name. This string may only contain alphanumerics, underscores (_),
            dashes (-), periods (.), spaces ( ), and slashes (/).
            All backend stores will support keys up to length 250, but some may
            support larger keys.
        value: Metric value. Note that some special values such as +/- Infinity may be
            replaced by other values depending on the store. For example, the
            SQLAlchemy store replaces +/- Infinity with max / min float values.
            All backend stores will support values up to length 5000, but some
            may support larger values.
        step: Metric step. Defaults to zero if unspecified.
        synchronous: *Experimental* If True, blocks until the metric is logged
            successfully. If False, logs the metric asynchronously and
            returns a future representing the logging operation. If None, read from environment
            variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set.
        timestamp: Time when this metric was calculated. Defaults to the current system time.
        run_id: If specified, log the metric to the specified run. If not specified, log the metric
            to the currently active run.
        model_id: The ID of the model associated with the metric. If not specified, use the current
            active model ID set by :py:func:`mlflow.set_active_model`. If no active model exists,
            the models IDs associated with the specified or active run will be used.
        dataset: The dataset associated with the metric.

    Returns:
        When `synchronous=True`, returns None.
        When `synchronous=False`, returns `RunOperations` that represents future for
        logging operation.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Log a metric
        with mlflow.start_run():
            mlflow.log_metric("mse", 2500.00)

        # Log a metric in async fashion.
        with mlflow.start_run():
            mlflow.log_metric("mse", 2500.00, synchronous=False)
    """
    run = _get_or_start_run() if run_id is None else MlflowClient().get_run(run_id)
    run_id = run.info.run_id
    synchronous = synchronous if synchronous is not None else not MLFLOW_ENABLE_ASYNC_LOGGING.get()
    model_id = model_id or get_active_model_id()
    # Normalize timestamp/step once up front so the Metric used for input logging
    # and the metric actually logged share identical values (previously each call
    # site computed its own current-time fallback, which could differ by a few ms).
    timestamp = timestamp or get_current_time_millis()
    step = step or 0
    dataset_name = dataset.name if dataset is not None else None
    dataset_digest = dataset.digest if dataset is not None else None
    # Make sure any model/dataset referenced by this metric is attached to the run
    # as an input before the metric itself is written.
    _log_inputs_for_metrics_if_necessary(
        run,
        [
            Metric(
                key=key,
                value=value,
                timestamp=timestamp,
                step=step,
                model_id=model_id,
                dataset_name=dataset_name,
                dataset_digest=dataset_digest,
            ),
        ],
        datasets=[dataset] if dataset is not None else None,
    )
    # When no model ID is given (explicitly or via the active model), fall back to
    # the models the run produced at this step; [None] logs a plain run metric.
    model_ids = (
        [model_id]
        if model_id is not None
        else (_get_model_ids_for_new_metric_if_exist(run, step) or [None])
    )
    client = MlflowClient()
    result: RunOperations | None = None
    # BUG FIX: the previous implementation returned from inside this loop, so only
    # the first model ID ever received the metric. Log to every model ID (matching
    # `log_metrics`, which emits one Metric per (key, model_id) pair) and return
    # the handle of the last logging operation.
    for mid in model_ids:
        result = client.log_metric(
            run_id,
            key,
            value,
            timestamp,
            step,
            synchronous=synchronous,
            model_id=mid,
            dataset_name=dataset_name,
            dataset_digest=dataset_digest,
        )
    return result
1214  
1215  
def _log_inputs_for_metrics_if_necessary(
    run: Run, metrics: list[Metric], datasets: list["Dataset"] | None = None
) -> None:
    """
    Record run inputs implied by ``metrics`` before those metrics are logged.

    A metric may reference a model (``metric.model_id``) and/or a dataset
    (``metric.dataset_name`` / ``metric.dataset_digest``). Any referenced model or
    dataset not already attached to ``run`` is logged as a run input here, and the
    in-memory ``run`` object is updated to match so that repeated calls do not log
    duplicate inputs.

    Args:
        run: The run whose inputs may be extended. NOTE: mutated in place (its
            private ``_inputs`` attribute is updated).
        metrics: Metrics about to be logged for ``run``.
        datasets: Candidate dataset objects to resolve metric dataset references
            against; entries may be fluent ``Dataset`` wrappers or
            ``mlflow.entities.Dataset`` instances.
    """
    client = MlflowClient()
    # Model IDs the run already knows about, either as inputs or as outputs.
    input_model_ids = (
        {i.model_id for i in run.inputs.model_inputs}
        if run.inputs and run.inputs.model_inputs
        else set()
    )
    output_model_ids = (
        {o.model_id for o in run.outputs.model_outputs}
        if run.outputs and run.outputs.model_outputs
        else set()
    )
    # (name, digest) pairs of datasets already logged as inputs to the run.
    run_datasets = (
        [(inp.dataset.name, inp.dataset.digest) for inp in run.inputs.dataset_inputs]
        if run.inputs
        else []
    )
    datasets = datasets or []
    models_to_log = []
    datasets_to_log = []
    for metric in metrics:
        # Metric tied to a model the run has not seen -> log the model as an input.
        if (
            metric.model_id is not None
            and metric.model_id not in input_model_ids | output_model_ids
        ):
            models_to_log.append(LoggedModelInput(model_id=metric.model_id))
        # Metric tied to a dataset not yet logged for the run -> find the matching
        # candidate by (name, digest) and log it as an input.
        if datasets and (metric.dataset_name, metric.dataset_digest) not in run_datasets:
            matching_dataset = next(
                (
                    dataset
                    for dataset in datasets
                    if dataset.name == metric.dataset_name
                    and dataset.digest == metric.dataset_digest
                ),
                None,
            )
            if matching_dataset is not None:
                # Candidates may already be entities or fluent wrappers that must
                # be converted first.
                if isinstance(matching_dataset, DatasetEntity):
                    dataset_entity = matching_dataset
                else:
                    dataset_entity = matching_dataset._to_mlflow_entity()
                datasets_to_log.append(DatasetInput(dataset_entity, tags=[]))
    if models_to_log or datasets_to_log:
        client.log_inputs(run.info.run_id, models=models_to_log, datasets=datasets_to_log)
        # update in-memory run inputs to avoid duplicate logging
        if run.inputs is None:
            run._inputs = RunInputs(dataset_inputs=datasets_to_log, model_inputs=models_to_log)
        else:
            run._inputs._model_inputs.extend(models_to_log)
            run._inputs._dataset_inputs.extend(datasets_to_log)
1268  
1269  
1270  def _get_model_ids_for_new_metric_if_exist(run: Run, metric_step: str) -> list[str]:
1271      outputs = run.outputs.model_outputs if run.outputs else []
1272      model_outputs_at_step = [mo for mo in outputs if mo.step == metric_step]
1273      return [mo.model_id for mo in model_outputs_at_step]
1274  
1275  
def log_metrics(
    metrics: dict[str, float],
    step: int | None = None,
    synchronous: bool | None = None,
    run_id: str | None = None,
    timestamp: int | None = None,
    model_id: str | None = None,
    dataset: Union["Dataset", DatasetEntity] | None = None,
) -> RunOperations | None:
    """
    Log multiple metrics for the current run. If no run is active, this method will create a new
    active run.

    Args:
        metrics: Dictionary of metric_name: String -> value: Float. Note that some special
            values such as +/- Infinity may be replaced by other values depending on
            the store. For example, sql based store may replace +/- Infinity with
            max / min float values.
        step: A single integer step at which to log the specified
            Metrics. If unspecified, each metric is logged at step zero.
        synchronous: *Experimental* If True, blocks until the metrics are logged
            successfully. If False, logs the metrics asynchronously and
            returns a future representing the logging operation. If None, read from environment
            variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set.
        run_id: Run ID. If specified, log metrics to the specified run. If not specified, log
            metrics to the currently active run.
        timestamp: Time when these metrics were calculated. Defaults to the current system time.
        model_id: The ID of the model associated with the metric. If not specified, use the current
            active model ID set by :py:func:`mlflow.set_active_model`. If no active model
            exists, the models IDs associated with the specified or active run will be used.
        dataset: The dataset associated with the metrics.

    Returns:
        When `synchronous=True`, returns None. When `synchronous=False`, returns an
        :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that
        represents future for logging operation.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        metrics = {"mse": 2500.00, "rmse": 50.00}

        # Log a batch of metrics
        with mlflow.start_run():
            mlflow.log_metrics(metrics)

        # Log a batch of metrics in async fashion.
        with mlflow.start_run():
            mlflow.log_metrics(metrics, synchronous=False)
    """
    run = MlflowClient().get_run(run_id) if run_id is not None else _get_or_start_run()
    run_id = run.info.run_id
    # Normalize the optional arguments once.
    timestamp = timestamp or get_current_time_millis()
    step = step or 0
    dataset_name = dataset.name if dataset is not None else None
    dataset_digest = dataset.digest if dataset is not None else None
    model_id = model_id or get_active_model_id()
    if model_id is not None:
        model_ids = [model_id]
    else:
        # No explicit/active model: attach metrics to the models produced by the
        # run at this step, or to the run alone ([None]) when there are none.
        model_ids = _get_model_ids_for_new_metric_if_exist(run, step) or [None]
    # One Metric entity per (metric key, model ID) combination.
    metrics_arr = []
    for key, value in metrics.items():
        for mid in model_ids:
            metrics_arr.append(
                Metric(
                    key,
                    value,
                    timestamp,
                    step,
                    model_id=mid,
                    dataset_name=dataset_name,
                    dataset_digest=dataset_digest,
                    run_id=run_id,
                )
            )
    _log_inputs_for_metrics_if_necessary(
        run, metrics_arr, [dataset] if dataset is not None else None
    )
    if synchronous is None:
        synchronous = not MLFLOW_ENABLE_ASYNC_LOGGING.get()
    return MlflowClient().log_batch(
        run_id=run_id,
        metrics=metrics_arr,
        params=[],
        tags=[],
        synchronous=synchronous,
    )
1366  
1367  
def log_params(
    params: dict[str, Any], synchronous: bool | None = None, run_id: str | None = None
) -> RunOperations | None:
    """
    Log a batch of params for the current run. If no run is active, this method will create a
    new active run.

    Args:
        params: Dictionary of param_name: String -> value: (String, but will be string-ified if
            not)
        synchronous: *Experimental* If True, blocks until the parameters are logged
            successfully. If False, logs the parameters asynchronously and
            returns a future representing the logging operation. If None, read from environment
            variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which defaults to False if not set.
        run_id: Run ID. If specified, log params to the specified run. If not specified, log
            params to the currently active run.

    Returns:
        When `synchronous=True`, returns None. When `synchronous=False`, returns an
        :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that
        represents future for logging operation.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        params = {"learning_rate": 0.01, "n_estimators": 10}

        # Log a batch of parameters
        with mlflow.start_run():
            mlflow.log_params(params)

        # Log a batch of parameters in async fashion.
        with mlflow.start_run():
            mlflow.log_params(params, synchronous=False)
    """
    run_id = run_id or _get_or_start_run().info.run_id
    # Values are always string-ified before being sent to the backend.
    param_entities = [Param(name, str(val)) for name, val in params.items()]
    if synchronous is None:
        synchronous = not MLFLOW_ENABLE_ASYNC_LOGGING.get()
    return MlflowClient().log_batch(
        run_id=run_id, metrics=[], params=param_entities, tags=[], synchronous=synchronous
    )
1412  
1413  
def _create_dataset_input(
    dataset: Optional["Dataset"],
    context: str | None = None,
    tags: dict[str, str] | None = None,
) -> DatasetInput | None:
    """
    Build a ``DatasetInput`` entity from a fluent dataset plus optional context/tags.

    Returns ``None`` when no dataset is given; raises when ``context`` or ``tags``
    are supplied without a dataset.
    """
    if dataset is None and (context or tags):
        raise MlflowException.invalid_parameter_value(
            "`dataset` must be specified if `context` or `tags` is specified."
        )
    input_tags = []
    if tags:
        input_tags = [InputTag(key=k, value=v) for k, v in tags.items()]
    if context:
        # The context ("training", "testing", ...) rides along as a reserved tag.
        input_tags.append(InputTag(key=MLFLOW_DATASET_CONTEXT, value=context))

    if not dataset:
        return None
    return DatasetInput(dataset=dataset._to_mlflow_entity(), tags=input_tags)
1430  
1431  
def log_input(
    dataset: Optional["Dataset"] = None,
    context: str | None = None,
    tags: dict[str, str] | None = None,
    model: LoggedModelInput | None = None,
) -> None:
    """
    Log a dataset used in the current run.

    Args:
        dataset: :py:class:`mlflow.data.dataset.Dataset` object to be logged.
        context: Context in which the dataset is used. For example: "training", "testing".
            This will be set as an input tag with key `mlflow.data.context`.
        tags: Tags to be associated with the dataset. Dictionary of tag_key -> tag_value.
        model: A :py:class:`mlflow.entities.LoggedModelInput` instance to log as input to
            the run.

    .. code-block:: python
        :test:
        :caption: Example

        import numpy as np
        import mlflow

        array = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        dataset = mlflow.data.from_numpy(array, source="data.csv")

        # Log an input dataset used for training
        with mlflow.start_run():
            mlflow.log_input(dataset, context="training")
    """
    run_id = _get_or_start_run().info.run_id
    dataset_inputs = None
    if dataset:
        dataset_inputs = [_create_dataset_input(dataset, context, tags)]
    model_inputs = [model] if model else None
    MlflowClient().log_inputs(run_id=run_id, datasets=dataset_inputs, models=model_inputs)
1468  
1469  
def log_inputs(
    datasets: list[Optional["Dataset"]] | None = None,
    contexts: list[str | None] | None = None,
    tags_list: list[dict[str, str] | None] | None = None,
    models: list[LoggedModelInput | None] | None = None,
) -> None:
    """
    Log a batch of datasets used in the current run.

    The lists of `datasets`, `contexts`, `tags_list` must have the same length.
    The entries in these lists can be ``None``, which represents empty value to the
    corresponding input.

    Args:
        datasets: List of :py:class:`mlflow.data.dataset.Dataset` object to be logged.
        contexts: List of context in which the dataset is used. For example: "training", "testing".
            This will be set as an input tag with key `mlflow.data.context`.
        tags_list: List of tags to be associated with the dataset. Dictionary of
            tag_key -> tag_value.
        models: List of :py:class:`mlflow.entities.LoggedModelInput` instance to log as input
            to the run. Currently only Databricks managed MLflow supports this argument.

    .. code-block:: python
        :test:
        :caption: Example

        import numpy as np
        import mlflow

        array = np.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        dataset = mlflow.data.from_numpy(array, source="data.csv")

        array2 = np.asarray([[-1, 2, 3], [-4, 5, 6]])
        dataset2 = mlflow.data.from_numpy(array2, source="data2.csv")

        # Log 2 input datasets used for training and test,
        # the training dataset has no tag.
        # the test dataset has tags `{"my_tag": "tag_value"}`.
        with mlflow.start_run():
            mlflow.log_inputs(
                [dataset, dataset2],
                contexts=["training", "test"],
                tags_list=[None, {"my_tag": "tag_value"}],
                models=None,
            )
    """
    from mlflow.utils.databricks_utils import is_databricks_uri

    run_id = _get_or_start_run().info.run_id

    datasets = datasets or []
    contexts = contexts or []
    tags_list = tags_list or []
    # The three lists are zipped together below; a length mismatch would
    # silently drop trailing entries, so fail fast instead. (Empty lists are
    # allowed — e.g. when only `models` is provided.)
    if not (len(datasets) == len(contexts) == len(tags_list)):
        raise MlflowException(
            "`mlflow.log_inputs` requires `datasets`, `contexts`, and `tags_list` to "
            "have the same length."
        )

    if models and not is_databricks_uri(mlflow.get_tracking_uri()):
        raise MlflowException("'models' argument is only supported by Databricks managed MLflow.")

    dataset_inputs = [
        _create_dataset_input(dataset, context, tags)
        for dataset, context, tags in zip(datasets, contexts, tags_list)
    ]

    MlflowClient().log_inputs(run_id=run_id, datasets=dataset_inputs, models=models)
1538  
1539  
def set_experiment_tags(tags: dict[str, Any]) -> None:
    """
    Set multiple tags on the current active experiment.

    Args:
        tags: Dictionary containing tag names and corresponding values.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        tags = {
            "engineering": "ML Platform",
            "release.candidate": "RC1",
            "release.version": "2.2.0",
        }

        # Set a batch of tags
        with mlflow.start_run():
            mlflow.set_experiment_tags(tags)
    """
    # Delegate each entry to the single-tag fluent API.
    for tag_key, tag_value in tags.items():
        set_experiment_tag(tag_key, tag_value)
1565  
1566  
def set_tags(tags: dict[str, Any], synchronous: bool | None = None) -> RunOperations | None:
    """
    Log a batch of tags for the current run. If no run is active, this method will create a
    new active run.

    Args:
        tags: Dictionary of tag_name: String -> value: (String, but will be string-ified if
            not)
        synchronous: *Experimental* If True, blocks until tags are logged successfully. If False,
            logs tags asynchronously and returns a future representing the logging operation.
            If None, read from environment variable `MLFLOW_ENABLE_ASYNC_LOGGING`, which
            defaults to False if not set.

    Returns:
        When `synchronous=True`, returns None. When `synchronous=False`, returns an
        :py:class:`mlflow.utils.async_logging.run_operations.RunOperations` instance that
        represents future for logging operation.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        tags = {
            "engineering": "ML Platform",
            "release.candidate": "RC1",
            "release.version": "2.2.0",
        }

        # Set a batch of tags
        with mlflow.start_run():
            mlflow.set_tags(tags)

        # Set a batch of tags in async fashion.
        with mlflow.start_run():
            mlflow.set_tags(tags, synchronous=False)
    """
    active_run_id = _get_or_start_run().info.run_id
    # Tag values are always stringified before being sent to the backend.
    run_tags = [RunTag(name, str(value)) for name, value in tags.items()]
    if synchronous is None:
        synchronous = not MLFLOW_ENABLE_ASYNC_LOGGING.get()
    client = MlflowClient()
    return client.log_batch(
        run_id=active_run_id, metrics=[], params=[], tags=run_tags, synchronous=synchronous
    )
1611  
1612  
def log_artifact(
    local_path: str, artifact_path: str | None = None, run_id: str | None = None
) -> None:
    """
    Log a local file or directory as an artifact of the currently active run. If no run is
    active, this method will create a new active run.

    Args:
        local_path: Path to the file to write.
        artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        run_id: If specified, log the artifact to the specified run. If not specified, log the
            artifact to the currently active run.

    .. code-block:: python
        :test:
        :caption: Example

        import tempfile
        from pathlib import Path

        import mlflow

        # Create a features.txt artifact file
        features = "rooms, zipcode, median_price, school_rating, transport"
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir, "features.txt")
            path.write_text(features)
            # With artifact_path=None write features.txt under
            # root artifact_uri/artifacts directory
            with mlflow.start_run():
                mlflow.log_artifact(path)
    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_artifact(target_run_id, local_path, artifact_path)
1647  
1648  
def log_artifacts(
    local_dir: str, artifact_path: str | None = None, run_id: str | None = None
) -> None:
    """
    Log all the contents of a local directory as artifacts of the run. If no run is active,
    this method will create a new active run.

    Args:
        local_dir: Path to the directory of files to write.
        artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        run_id: If specified, log the artifacts to the specified run. If not specified, log the
            artifacts to the currently active run.

    .. code-block:: python
        :test:
        :caption: Example

        import json
        import tempfile
        from pathlib import Path

        import mlflow

        # Create some files to preserve as artifacts
        features = "rooms, zipcode, median_price, school_rating, transport"
        data = {"state": "TX", "Available": 25, "Type": "Detached"}
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            with (tmp_dir / "data.json").open("w") as f:
                json.dump(data, f, indent=2)
            with (tmp_dir / "features.json").open("w") as f:
                f.write(features)
            # Write all files in `tmp_dir` to root artifact_uri/states
            with mlflow.start_run():
                mlflow.log_artifacts(tmp_dir, artifact_path="states")
    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_artifacts(target_run_id, local_dir, artifact_path)
1687  
1688  
def log_text(text: str, artifact_file: str, run_id: str | None = None) -> None:
    """
    Log text as an artifact.

    Args:
        text: String containing text to log.
        artifact_file: The run-relative artifact file path in posixpath format to which
            the text is saved (e.g. "dir/file.txt").
        run_id: If specified, log the artifact to the specified run. If not specified, log the
            artifact to the currently active run.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        with mlflow.start_run():
            # Log text to a file under the run's root artifact directory
            mlflow.log_text("text1", "file1.txt")

            # Log text in a subdirectory of the run's root artifact directory
            mlflow.log_text("text2", "dir/file2.txt")

            # Log HTML text
            mlflow.log_text("<h1>header</h1>", "index.html")

    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_text(target_run_id, text, artifact_file)
1719  
1720  
def log_dict(dictionary: dict[str, Any], artifact_file: str, run_id: str | None = None) -> None:
    """
    Log a JSON/YAML-serializable object (e.g. `dict`) as an artifact. The serialization
    format (JSON or YAML) is automatically inferred from the extension of `artifact_file`.
    If the file extension doesn't exist or match any of [".json", ".yml", ".yaml"],
    JSON format is used.

    Args:
        dictionary: Dictionary to log.
        artifact_file: The run-relative artifact file path in posixpath format to which
            the dictionary is saved (e.g. "dir/data.json").
        run_id: If specified, log the dictionary to the specified run. If not specified, log the
            dictionary to the currently active run.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        dictionary = {"k": "v"}

        with mlflow.start_run():
            # Log a dictionary as a JSON file under the run's root artifact directory
            mlflow.log_dict(dictionary, "data.json")

            # Log a dictionary as a YAML file in a subdirectory of the run's root artifact directory
            mlflow.log_dict(dictionary, "dir/data.yml")

            # If the file extension doesn't exist or match any of [".json", ".yaml", ".yml"],
            # JSON format is used.
            mlflow.log_dict(dictionary, "data")
            mlflow.log_dict(dictionary, "data.txt")

    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_dict(target_run_id, dictionary, artifact_file)
1758  
1759  
@experimental(version="3.9.0")
def log_stream(
    stream: io.BufferedIOBase | io.RawIOBase, artifact_file: str, run_id: str | None = None
) -> None:
    """
    Log a binary file-like object (e.g., ``io.BytesIO``) as an artifact.

    Args:
        stream: A binary file-like object supporting ``.read()`` method (e.g., ``io.BytesIO``).
        artifact_file: The run-relative artifact file path in posixpath format to which
            the stream content is saved (e.g. "dir/file.bin").
        run_id: If specified, log the artifact to the specified run. If not specified, log the
            artifact to the currently active run.

    .. code-block:: python
        :test:
        :caption: Example

        import io

        import mlflow

        with mlflow.start_run():
            # Log a BytesIO stream
            bytes_stream = io.BytesIO(b"binary content")
            mlflow.log_stream(bytes_stream, "binary_file.bin")

    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_stream(target_run_id, stream, artifact_file)
1790  
1791  
def log_figure(
    figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"],
    artifact_file: str,
    *,
    save_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Log a figure as an artifact. The following figure objects are supported:

    - `matplotlib.figure.Figure`_
    - `plotly.graph_objects.Figure`_

    .. _matplotlib.figure.Figure:
        https://matplotlib.org/api/_as_gen/matplotlib.figure.Figure.html

    .. _plotly.graph_objects.Figure:
        https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html

    Args:
        figure: Figure to log.
        artifact_file: The run-relative artifact file path in posixpath format to which
            the figure is saved (e.g. "dir/file.png").
        save_kwargs: Additional keyword arguments passed to the method that saves the figure.

    .. code-block:: python
        :test:
        :caption: Matplotlib Example

        import mlflow
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.plot([0, 1], [2, 3])

        with mlflow.start_run():
            mlflow.log_figure(fig, "figure.png")

    .. code-block:: python
        :test:
        :caption: Plotly Example

        import mlflow
        from plotly import graph_objects as go

        fig = go.Figure(go.Scatter(x=[0, 1], y=[2, 3]))

        with mlflow.start_run():
            mlflow.log_figure(fig, "figure.html")
    """
    active_run_id = _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_figure(active_run_id, figure, artifact_file, save_kwargs=save_kwargs)
1843  
1844  
def log_image(
    image: Union["numpy.ndarray", "PIL.Image.Image", "mlflow.Image"],
    artifact_file: str | None = None,
    key: str | None = None,
    step: int | None = None,
    timestamp: int | None = None,
    synchronous: bool | None = False,
) -> None:
    """
    Logs an image in MLflow, supporting two use cases:

    1. Time-stepped image logging:
        Ideal for tracking changes or progressions through iterative processes (e.g.,
        during model training phases).

        - Usage: :code:`log_image(image, key=key, step=step, timestamp=timestamp)`

    2. Artifact file image logging:
        Best suited for static image logging where the image is saved directly as a file
        artifact.

        - Usage: :code:`log_image(image, artifact_file)`

    The following image formats are supported:
        - `numpy.ndarray`_
        - `PIL.Image.Image`_

        .. _numpy.ndarray:
            https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html

        .. _PIL.Image.Image:
            https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image

        - :class:`mlflow.Image`: An MLflow wrapper around PIL image for convenient image logging.

    Numpy array support
        - data types:

            - bool (useful for logging image masks)
            - integer [0, 255]
            - unsigned integer [0, 255]
            - float [0.0, 1.0]

            .. warning::

                - Out-of-range integer values will raise ValueError.
                - Out-of-range float values will auto-scale with min/max and warn.

        - shape (H: height, W: width):

            - H x W (Grayscale)
            - H x W x 1 (Grayscale)
            - H x W x 3 (an RGB channel order is assumed)
            - H x W x 4 (an RGBA channel order is assumed)

    Args:
        image: The image object to be logged.
        artifact_file: Specifies the path, in POSIX format, where the image
            will be stored as an artifact relative to the run's root directory (for
            example, "dir/image.png"). This parameter is kept for backward compatibility
            and should not be used together with `key`, `step`, or `timestamp`.
        key: Image name for time-stepped image logging. This string may only contain
            alphanumerics, underscores (_), dashes (-), periods (.), spaces ( ), and
            slashes (/).
        step: Integer training step (iteration) at which the image was saved.
            Defaults to 0.
        timestamp: Time when this image was saved. Defaults to the current system time.
        synchronous: *Experimental* If True, blocks until the image is logged successfully.

    .. code-block:: python
        :caption: Time-stepped image logging numpy example

        import mlflow
        import numpy as np

        image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

        with mlflow.start_run():
            mlflow.log_image(image, key="dogs", step=3)

    .. code-block:: python
        :caption: Time-stepped image logging pillow example

        import mlflow
        from PIL import Image

        image = Image.new("RGB", (100, 100))

        with mlflow.start_run():
            mlflow.log_image(image, key="dogs", step=3)

    .. code-block:: python
        :caption: Time-stepped image logging with mlflow.Image example

        import mlflow
        from PIL import Image

        # If you have a preexisting saved image
        Image.new("RGB", (100, 100)).save("image.png")

        image = mlflow.Image("image.png")
        with mlflow.start_run():
            mlflow.log_image(image, key="dogs", step=3)

    .. code-block:: python
        :caption: Legacy artifact file image logging numpy example

        import mlflow
        import numpy as np

        image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

        with mlflow.start_run():
            mlflow.log_image(image, "image.png")

    .. code-block:: python
        :caption: Legacy artifact file image logging pillow example

        import mlflow
        from PIL import Image

        image = Image.new("RGB", (100, 100))

        with mlflow.start_run():
            mlflow.log_image(image, "image.png")
    """
    # Use the active run, starting one if necessary; argument validation
    # (artifact_file vs. key/step/timestamp) is delegated to the client.
    run_id = _get_or_start_run().info.run_id
    MlflowClient().log_image(run_id, image, artifact_file, key, step, timestamp, synchronous)
1973  
1974  
def log_table(
    data: Union[dict[str, Any], "pandas.DataFrame"],
    artifact_file: str,
    run_id: str | None = None,
) -> None:
    """
    Log a table to MLflow Tracking as a JSON artifact. If the artifact_file already exists
    in the run, the data would be appended to the existing artifact_file.

    Args:
        data: Dictionary or pandas.DataFrame to log.
        artifact_file: The run-relative artifact file path in posixpath format to which
            the table is saved (e.g. "dir/file.json").
        run_id: If specified, log the table to the specified run. If not specified, log the
            table to the currently active run.

    .. code-block:: python
        :test:
        :caption: Dictionary Example

        import mlflow

        table_dict = {
            "inputs": ["What is MLflow?", "What is Databricks?"],
            "outputs": ["MLflow is ...", "Databricks is ..."],
            "toxicity": [0.0, 0.0],
        }
        with mlflow.start_run():
            # Log the dictionary as a table
            mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json")

    .. code-block:: python
        :test:
        :caption: Pandas DF Example

        import mlflow
        import pandas as pd

        table_dict = {
            "inputs": ["What is MLflow?", "What is Databricks?"],
            "outputs": ["MLflow is ...", "Databricks is ..."],
            "toxicity": [0.0, 0.0],
        }
        df = pd.DataFrame.from_dict(table_dict)
        with mlflow.start_run():
            # Log the df as a table
            mlflow.log_table(data=df, artifact_file="qabot_eval_results.json")
    """
    # Fall back to the active run (starting one if needed) when no run_id given.
    target_run_id = run_id or _get_or_start_run().info.run_id
    client = MlflowClient()
    client.log_table(target_run_id, data, artifact_file)
2025  
2026  
def load_table(
    artifact_file: str,
    run_ids: list[str] | None = None,
    extra_columns: list[str] | None = None,
) -> "pandas.DataFrame":
    """
    Load a table from MLflow Tracking as a pandas.DataFrame. The table is loaded from the
    specified artifact_file in the specified run_ids. The extra_columns are columns that
    are not in the table but are augmented with run information and added to the DataFrame.

    Args:
        artifact_file: The run-relative artifact file path in posixpath format to which
            table to load (e.g. "dir/file.json").
        run_ids: Optional list of run_ids to load the table from. If no run_ids are specified,
            the table is loaded from all runs in the current experiment.
        extra_columns: Optional list of extra columns to add to the returned DataFrame
            For example, if extra_columns=["run_id"], then the returned DataFrame
            will have a column named run_id.

    Returns:
        pandas.DataFrame containing the loaded table if the artifact exists
        or else throw a MlflowException.

    .. code-block:: python
        :test:
        :caption: Example with passing run_ids

        import mlflow

        table_dict = {
            "inputs": ["What is MLflow?", "What is Databricks?"],
            "outputs": ["MLflow is ...", "Databricks is ..."],
            "toxicity": [0.0, 0.0],
        }

        with mlflow.start_run() as run:
            # Log the dictionary as a table
            mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json")
            run_id = run.info.run_id

        loaded_table = mlflow.load_table(
            artifact_file="qabot_eval_results.json",
            run_ids=[run_id],
            # Append a column containing the associated run ID for each row
            extra_columns=["run_id"],
        )

    .. code-block:: python
        :test:
        :caption: Example with passing no run_ids

        # Loads the table with the specified name for all runs in the given
        # experiment and joins them together
        import mlflow

        table_dict = {
            "inputs": ["What is MLflow?", "What is Databricks?"],
            "outputs": ["MLflow is ...", "Databricks is ..."],
            "toxicity": [0.0, 0.0],
        }

        with mlflow.start_run():
            # Log the dictionary as a table
            mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json")

        loaded_table = mlflow.load_table(
            "qabot_eval_results.json",
            # Append the run ID and the parent run ID to the table
            extra_columns=["run_id"],
        )
    """
    # Resolve the current experiment and delegate the lookup to the client.
    client = MlflowClient()
    return client.load_table(_get_experiment_id(), artifact_file, run_ids, extra_columns)
2100  
2101  
def _record_logged_model(mlflow_model, run_id=None):
    # Internal helper: associate ``mlflow_model`` with the given run, defaulting
    # to the active run (starting one if necessary).
    target_run_id = run_id or _get_or_start_run().info.run_id
    MlflowClient()._record_logged_model(target_run_id, mlflow_model)
2105  
2106  
def get_experiment(experiment_id: str) -> Experiment:
    """Retrieve an experiment by experiment_id from the backend store

    Args:
        experiment_id: The string-ified experiment ID returned from ``create_experiment``.

    Returns:
        :py:class:`mlflow.entities.Experiment`

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        experiment = mlflow.get_experiment("0")
        print(f"Name: {experiment.name}")
        print(f"Artifact Location: {experiment.artifact_location}")
        print(f"Tags: {experiment.tags}")
        print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
        print(f"Creation timestamp: {experiment.creation_time}")

    .. code-block:: text
        :caption: Output

        Name: Default
        Artifact Location: file:///.../mlruns/0
        Tags: {}
        Lifecycle_stage: active
        Creation timestamp: 1662004217511
    """
    client = MlflowClient()
    return client.get_experiment(experiment_id)
2139  
2140  
def get_experiment_by_name(name: str) -> Experiment | None:
    """
    Retrieve an experiment by experiment name from the backend store

    Args:
        name: The case sensitive experiment name.

    Returns:
        An instance of :py:class:`mlflow.entities.Experiment`
        if an experiment with the specified name exists, otherwise None.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Case sensitive name
        experiment = mlflow.get_experiment_by_name("Default")
        print(f"Experiment_id: {experiment.experiment_id}")
        print(f"Artifact Location: {experiment.artifact_location}")
        print(f"Tags: {experiment.tags}")
        print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
        print(f"Creation timestamp: {experiment.creation_time}")

    .. code-block:: text
        :caption: Output

        Experiment_id: 0
        Artifact Location: file:///.../mlruns/0
        Tags: {}
        Lifecycle_stage: active
        Creation timestamp: 1662004217511
    """
    client = MlflowClient()
    return client.get_experiment_by_name(name)
2176  
2177  
2178  def search_experiments(
2179      view_type: int = ViewType.ACTIVE_ONLY,
2180      max_results: int | None = None,
2181      filter_string: str | None = None,
2182      order_by: list[str] | None = None,
2183  ) -> list[Experiment]:
2184      """
2185      Search for experiments that match the specified search query.
2186  
2187      Args:
2188          view_type: One of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or ``ALL``
2189              defined in :py:class:`mlflow.entities.ViewType`.
2190          max_results: If passed, specifies the maximum number of experiments desired. If not
2191              passed, all experiments will be returned.
2192          filter_string: Filter query string (e.g., ``"name = 'my_experiment'"``), defaults to
2193              searching for all experiments. The following identifiers, comparators, and logical
2194              operators are supported.
2195  
2196              Identifiers
2197                - ``name``: Experiment name
2198                - ``creation_time``: Experiment creation time
2199                - ``last_update_time``: Experiment last update time
2200                - ``tags.<tag_key>``: Experiment tag. If ``tag_key`` contains
2201                  spaces, it must be wrapped with backticks (e.g., ``"tags.`extra key`"``).
2202  
2203              Comparators for string attributes and tags
2204                  - ``=``: Equal to
2205                  - ``!=``: Not equal to
2206                  - ``LIKE``: Case-sensitive pattern match
2207                  - ``ILIKE``: Case-insensitive pattern match
2208  
2209              Comparators for numeric attributes
2210                  - ``=``: Equal to
2211                  - ``!=``: Not equal to
2212                  - ``<``: Less than
2213                  - ``<=``: Less than or equal to
2214                  - ``>``: Greater than
2215                  - ``>=``: Greater than or equal to
2216  
2217              Logical operators
2218                - ``AND``: Combines two sub-queries and returns True if both of them are True.
2219  
2220          order_by: List of columns to order by. The ``order_by`` column can contain an optional
2221              ``DESC`` or ``ASC`` value (e.g., ``"name DESC"``). The default ordering is ``ASC``,
2222              so ``"name"`` is equivalent to ``"name ASC"``. If unspecified, defaults to
2223              ``["last_update_time DESC"]``, which lists experiments updated most recently first.
2224              The following fields are supported:
2225  
2226                  - ``experiment_id``: Experiment ID
2227                  - ``name``: Experiment name
2228                  - ``creation_time``: Experiment creation time
2229                  - ``last_update_time``: Experiment last update time
2230  
2231      Returns:
2232          A list of :py:class:`Experiment <mlflow.entities.Experiment>` objects.
2233  
2234      .. code-block:: python
2235          :test:
2236          :caption: Example
2237  
2238          import mlflow
2239  
2240  
2241          def assert_experiment_names_equal(experiments, expected_names):
2242              actual_names = [e.name for e in experiments if e.name != "Default"]
2243              assert actual_names == expected_names, (actual_names, expected_names)
2244  
2245  
2246          mlflow.set_tracking_uri("sqlite:///:memory:")
2247          # Create experiments
2248          for name, tags in [
2249              ("a", None),
2250              ("b", None),
2251              ("ab", {"k": "v"}),
2252              ("bb", {"k": "V"}),
2253          ]:
2254              mlflow.create_experiment(name, tags=tags)
2255  
2256          # Search for experiments with name "a"
2257          experiments = mlflow.search_experiments(filter_string="name = 'a'")
2258          assert_experiment_names_equal(experiments, ["a"])
2259          # Search for experiments with name starting with "a"
2260          experiments = mlflow.search_experiments(filter_string="name LIKE 'a%'")
2261          assert_experiment_names_equal(experiments, ["ab", "a"])
2262          # Search for experiments with tag key "k" and value ending with "v" or "V"
2263          experiments = mlflow.search_experiments(filter_string="tags.k ILIKE '%v'")
2264          assert_experiment_names_equal(experiments, ["bb", "ab"])
2265          # Search for experiments with name ending with "b" and tag {"k": "v"}
2266          experiments = mlflow.search_experiments(filter_string="name LIKE '%b' AND tags.k = 'v'")
2267          assert_experiment_names_equal(experiments, ["ab"])
2268          # Sort experiments by name in ascending order
2269          experiments = mlflow.search_experiments(order_by=["name"])
2270          assert_experiment_names_equal(experiments, ["a", "ab", "b", "bb"])
2271          # Sort experiments by ID in descending order
2272          experiments = mlflow.search_experiments(order_by=["experiment_id DESC"])
2273          assert_experiment_names_equal(experiments, ["bb", "ab", "b", "a"])
2274      """
2275  
2276      def pagination_wrapper_func(number_to_get, next_page_token):
2277          return MlflowClient().search_experiments(
2278              view_type=view_type,
2279              max_results=number_to_get,
2280              filter_string=filter_string,
2281              order_by=order_by,
2282              page_token=next_page_token,
2283          )
2284  
2285      return get_results_from_paginated_fn(
2286          pagination_wrapper_func,
2287          SEARCH_MAX_RESULTS_DEFAULT,
2288          max_results,
2289      )
2290  
2291  
def create_experiment(
    name: str,
    artifact_location: str | None = None,
    tags: dict[str, Any] | None = None,
    trace_location: UnityCatalog | None = None,
) -> str:
    """
    Create an experiment.

    Args:
        name: The experiment name, must be a non-empty unique string.
        artifact_location: The location to store run artifacts. If not provided, the server picks
            an appropriate default.
        tags: An optional dictionary of string keys and values to set as tags on the experiment.
        trace_location: Optional UC trace location to link to the experiment. Must be an instance
            of ``mlflow.entities.trace_location.UnityCatalog(...)``. If ``table_prefix`` is not
            set, it defaults to the experiment ID. Note: call ``mlflow.set_experiment`` afterward
            to activate the experiment and sync the trace provider.

    Returns:
        String ID of the created experiment.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        from pathlib import Path

        # Create an experiment name, which must be unique and case sensitive
        experiment_id = mlflow.create_experiment(
            "Social NLP Experiments",
            artifact_location=Path.cwd().joinpath("mlruns").as_uri(),
            tags={"version": "v1", "priority": "P1"},
        )
        experiment = mlflow.get_experiment(experiment_id)
        print(f"Name: {experiment.name}")
        print(f"Experiment_id: {experiment.experiment_id}")
        print(f"Artifact Location: {experiment.artifact_location}")
        print(f"Tags: {experiment.tags}")
        print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
        print(f"Creation timestamp: {experiment.creation_time}")

    .. code-block:: text
        :caption: Output

        Name: Social NLP Experiments
        Experiment_id: 1
        Artifact Location: file:///.../mlruns
        Tags: {'version': 'v1', 'priority': 'P1'}
        Lifecycle_stage: active
        Creation timestamp: 1662004217511
    """
    client = MlflowClient()
    experiment_id = client.create_experiment(name, artifact_location, tags)

    if trace_location is not None:
        experiment = client.get_experiment(experiment_id)

        # Default the UC table prefix to the newly created experiment's ID when the
        # caller did not set one.
        if trace_location.table_prefix is None:
            trace_location = UnityCatalog(
                catalog_name=trace_location.catalog_name,
                schema_name=trace_location.schema_name,
                table_prefix=experiment_id,
            )

        try:
            _resolve_experiment_to_trace_location(
                experiment=experiment,
                trace_location=trace_location,
            )
        except MlflowException as e:
            # The experiment itself was already created; tell the caller how to
            # recover from the half-configured state rather than failing silently.
            raise MlflowException.invalid_parameter_value(
                f"Experiment '{name}' (ID: {experiment_id}) was created "
                f"but linking to trace location '{trace_location.full_table_prefix}' failed: "
                f"{e.message} Please delete the experiment and retry."
            ) from e

    return experiment_id
2371  
2372  
def delete_experiment(experiment_id: str) -> None:
    """
    Mark the given experiment as deleted in the backend store.

    Args:
        experiment_id: The string-ified experiment ID returned from ``create_experiment``.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        experiment_id = mlflow.create_experiment("New Experiment")
        mlflow.delete_experiment(experiment_id)

        # Examine the deleted experiment details.
        experiment = mlflow.get_experiment(experiment_id)
        print(f"Name: {experiment.name}")
        print(f"Artifact Location: {experiment.artifact_location}")
        print(f"Lifecycle_stage: {experiment.lifecycle_stage}")
        print(f"Last Updated timestamp: {experiment.last_update_time}")

    .. code-block:: text
        :caption: Output

        Name: New Experiment
        Artifact Location: file:///.../mlruns/2
        Lifecycle_stage: deleted
        Last Updated timestamp: 1662004217511

    """
    client = MlflowClient()
    client.delete_experiment(experiment_id)
2406  
2407  
def initialize_logged_model(
    name: str | None = None,
    source_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    params: dict[str, str] | None = None,
    model_type: str | None = None,
    experiment_id: str | None = None,
) -> LoggedModel:
    """
    Create a LoggedModel in the ``PENDING`` state with no artifacts attached.

    The returned model must later be moved to the ``READY`` state — typically by a
    flavor-specific ``log_model()`` call such as :py:func:`mlflow.pyfunc.log_model()` —
    once its artifacts have been added and it is finalized.

    Args:
        name: Name for the model; a random name is generated when omitted.
        source_run_id: ID of the run to associate the model with. Defaults to the active
                       run's ID when a run is active and this is unspecified.
        tags: String key/value tags to set on the model.
        params: String key/value parameters to set on the model.
        model_type: User-defined type string for the model.
        experiment_id: ID of the experiment the model belongs to.

    Returns:
        A new :py:class:`mlflow.entities.LoggedModel` object with status ``PENDING``.
    """
    return _initialize_logged_model(
        name=name,
        source_run_id=source_run_id,
        tags=tags,
        params=params,
        model_type=model_type,
        experiment_id=experiment_id,
        flavor="initialize",
    )
2442  
2443  
def _initialize_logged_model(
    name: str | None = None,
    source_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    params: dict[str, str] | None = None,
    model_type: str | None = None,
    experiment_id: str | None = None,
    # this is only for internal logging purpose
    flavor: str | None = None,
) -> LoggedModel:
    """Create a PENDING LoggedModel and record it as the session's last-logged model."""
    created = _create_logged_model(
        name=name,
        source_run_id=source_run_id,
        tags=tags,
        params=params,
        model_type=model_type,
        experiment_id=experiment_id,
        flavor=flavor,
    )
    # Remember the model so `last_logged_model()` can retrieve it later.
    _last_logged_model_id.set(created.model_id)
    return created
2465  
2466  
@contextlib.contextmanager
def _use_logged_model(model: LoggedModel) -> Generator[LoggedModel, None, None]:
    """
    Context manager that yields ``model`` and finalizes its status on exit:
    ``FAILED`` when the body raises an ``Exception`` (which is then re-raised),
    ``READY`` otherwise.
    """
    try:
        yield model
    except Exception:
        finalize_logged_model(model.model_id, LoggedModelStatus.FAILED)
        raise
    # Reached only when no exception escaped the body above.
    finalize_logged_model(model.model_id, LoggedModelStatus.READY)
2482  
2483  
def create_external_model(
    name: str | None = None,
    source_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    params: dict[str, str] | None = None,
    model_type: str | None = None,
    experiment_id: str | None = None,
) -> LoggedModel:
    """
    Create a new LoggedModel whose artifacts are stored outside of MLflow. This is useful for
    tracking parameters and performance data (metrics, traces etc.) for a model, application, or
    generative AI agent that is not packaged using the MLflow Model format.

    Args:
        name: The name of the model. If not specified, a random name will be generated.
        source_run_id: The ID of the run that the model is associated with. If unspecified and a
                       run is active, the active run ID will be used.
        tags: A dictionary of string keys and values to set as tags on the model.
        params: A dictionary of string keys and values to set as parameters on the model.
        model_type: The type of the model. This is a user-defined string that can be used to
                    search and compare related models. For example, setting ``model_type="agent"``
                    enables you to easily search for this model and compare it to other models of
                    type ``"agent"`` in the future.
        experiment_id: The experiment ID of the experiment to which the model belongs.

    Returns:
        A new :py:class:`mlflow.entities.LoggedModel` object with status ``READY``.
    """
    from mlflow.models.model import MLMODEL_FILE_NAME, Model
    from mlflow.models.utils import get_external_mlflow_model_spec

    # Copy the caller's tags (never mutate the input dict) and mark the model external.
    tags = dict(tags) if tags else {}
    tags[MLFLOW_MODEL_IS_EXTERNAL] = "true"

    client = MlflowClient()
    model = _create_logged_model(
        name=name,
        source_run_id=source_run_id,
        tags=tags,
        params=params,
        model_type=model_type,
        experiment_id=experiment_id,
        flavor="external",
    )

    # If a model is external, its artifacts (code, weights, etc.) are not stored in MLflow.
    # Accordingly, we finalize the model immediately after creation, since there aren't
    # any model artifacts for the client to upload to MLflow. Additionally, we create a
    # dummy MLModel file to ensure that the model can be registered to the Model Registry
    mlflow_model: Model = get_external_mlflow_model_spec(model)
    with TempDir() as tmp:
        mlflow_model.save(tmp.path(MLMODEL_FILE_NAME))
        # Reuse the client constructed above rather than instantiating a second one.
        client.log_model_artifacts(
            model_id=model.model_id,
            local_dir=tmp.path(),
        )

    model = client.finalize_logged_model(model_id=model.model_id, status=LoggedModelStatus.READY)
    _last_logged_model_id.set(model.model_id)

    return model
2545  
2546  
def _create_logged_model(
    name: str | None = None,
    source_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    params: dict[str, str] | None = None,
    model_type: str | None = None,
    experiment_id: str | None = None,
    flavor: str | None = None,
    serialization_format: str | None = None,
    uses_uv: bool = False,
) -> LoggedModel:
    """
    Create a new LoggedModel in the ``PENDING`` state.

    Args:
        name: Model name; a random one is generated when omitted.
        source_run_id: ID of the run to associate the model with; defaults to the active
                       run's ID when a run is active and this is unspecified.
        tags: String key/value tags to set on the model.
        params: String key/value parameters to set on the model.
        model_type: User-defined type string used to search and compare related models
                    (e.g. ``"agent"``).
        experiment_id: ID of the experiment the model belongs to.
        flavor: Model flavor, recorded for telemetry and analytics only; it does not
                affect the stored LoggedModel.
        serialization_format: Serialization format, recorded for telemetry and analytics
                              only; it does not affect the stored LoggedModel.
        uses_uv: Whether the model uses uv dependency management, recorded for telemetry
                 and analytics only; it does not affect the stored LoggedModel.

    Returns:
        A new LoggedModel in the ``PENDING`` state.
    """
    current_run = active_run()

    # Inherit the run association from the active run when not given explicitly.
    if source_run_id is None and current_run:
        source_run_id = current_run.info.run_id

    # Resolve the experiment: active run first, then the configured default
    # experiment, finally the source run's experiment (if a source run exists).
    if experiment_id is None:
        if current_run:
            experiment_id = current_run.info.experiment_id
        else:
            experiment_id = _get_experiment_id() or (
                get_run(source_run_id).info.experiment_id if source_run_id else None
            )

    client = MlflowClient()
    return client._create_logged_model(
        experiment_id=experiment_id,
        name=name,
        source_run_id=source_run_id,
        tags=context_registry.resolve_tags(tags),
        params=params,
        model_type=model_type,
        flavor=flavor,
        serialization_format=serialization_format,
        uses_uv=uses_uv,
    )
2603  
2604  
def log_model_params(params: dict[str, str], model_id: str | None = None) -> None:
    """
    Log a batch of params on the specified logged model.

    Args:
        params: Params to log on the model.
        model_id: ID of the target model. Falls back to the current active model ID
            when not provided.

    Returns:
        None

    Example:

    .. code-block:: python
        :test:

        import mlflow


        class DummyModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input: list[str]) -> list[str]:
                return model_input


        model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
        mlflow.log_model_params(params={"param": "value"}, model_id=model_info.model_id)
    """
    target_model_id = model_id or get_active_model_id()
    client = MlflowClient()
    client.log_model_params(target_model_id, params)
2634  
2635  
def import_checkpoints(
    checkpoint_path: str,
    source_run_id: str | None = None,
    model_prefix: str | None = None,
    overwrite_checkpoints: bool = False,
) -> list[LoggedModel]:
    """
    Create external models for all top-level files and directories under the specified
    checkpoint path.

    This API only supports Databricks runtime currently.

    Args:
        checkpoint_path: Path that contains the checkpoints.
            Only Databricks Unity Catalog Volume path is supported for now.
            It must follow the
            "/Volumes/<catalog_identifier>/<schema_identifier>/<volume_identifier>/<path_to_checkpoints_directory>"
            format specified https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-volumes#volume-naming-and-reference.
            Note: Each path must be isolated from other models and runs.
        source_run_id: ID of the MLflow source run that these checkpoints were trained with.
            If not provided, uses the current active run if available.
        model_prefix: String prefix to prepend to the name of each external model created from
            each checkpoint. If not provided, no prefix is applied.
        overwrite_checkpoints: If True and existing models are found with the same name in the
            associated experiment, they will be deleted and recreated to point to the latest
            checkpoint. Defaults to False.

    Returns:
        List of imported models. If 'overwrite_checkpoints' is True, the list only contains
            new created models, otherwise the list contains new created models for the new model
            names and existing models for the existing model names.

    Example:

    .. code-block:: python

        import mlflow

        # Optionally start a run so `source_run_id` can be inferred
        with mlflow.start_run() as run:
            # ... training code that writes checkpoints to a UC Volume ...
            logged_models = mlflow.import_checkpoints(
                checkpoint_path=(
                    "/Volumes/mycatalog/myschema/myvolume/mytrainingmodel/trainingrun1/checkpoints"
                ),
                # You can omit `source_run_id` if there is an active run.
                # source_run_id=run.info.run_id,
                model_prefix="my_model_",
                overwrite_checkpoints=True,
            )
    """
    # Imported lazily: the Databricks SDK is only needed (and available) on Databricks.
    from databricks.sdk import WorkspaceClient

    # Validate checkpoint_path before accessing workspace files
    if not isinstance(checkpoint_path, str) or not checkpoint_path.strip().startswith("/Volumes/"):
        raise MlflowException(
            "Parameter 'checkpoint_path' must be a non-empty string pointing to a Unity Catalog "
            "Volume path that contains checkpoints, e.g. '/Volumes/...'",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Resolve source_run_id from the active run if not provided
    if source_run_id is None:
        if run := active_run():
            source_run_id = run.info.run_id
        else:
            raise MlflowException.invalid_parameter_value(
                "Please set 'source_run_id' or start an active run before calling "
                "'import_checkpoints'."
            )

    # Resolve experiment ID to operate against
    exp_id = MlflowClient().get_run(source_run_id).info.experiment_id

    # One model is created per immediate child (file or directory) of checkpoint_path.
    ws = WorkspaceClient()
    top_level_paths = [
        entry.path.rstrip("/") for entry in ws.files.list_directory_contents(checkpoint_path)
    ]

    imported_models: list[LoggedModel] = []
    client = MlflowClient()

    if not top_level_paths:
        # Best-effort: an empty directory is a warning, not an error.
        _logger.warning(
            f"No checkpoints were found at path '{checkpoint_path}'. "
            "Please verify that 'checkpoint_path' is correct and accessible."
        )
        return []

    for sub_checkpoint_path in top_level_paths:
        base_name = os.path.basename(sub_checkpoint_path)

        model_name = model_prefix + base_name if model_prefix else base_name

        # Skip invalid names (rather than aborting the entire import) so one bad
        # checkpoint folder does not block the rest.
        try:
            _validate_logged_model_name(model_name)
        except MlflowException as e:
            _logger.warning(
                f"The model name is invalid (root error: {e!s}), skip importing the "
                f"model with name '{model_name}' from checkpoint folder '{sub_checkpoint_path}'."
            )
            continue

        # Only models with the same name AND the same source run count as "existing"
        # for overwrite purposes.
        existing_models = [
            model
            for model in search_logged_models(
                experiment_ids=[exp_id],
                filter_string=f"name = '{model_name}'",
                output_format="list",
            )
            if model.source_run_id == source_run_id
        ]

        if not existing_models or overwrite_checkpoints:
            # Create a new model pointing to this checkpoint path.
            created_model = create_external_model(
                name=model_name,
                source_run_id=source_run_id,
                tags={"original_artifact_path": sub_checkpoint_path},
                experiment_id=exp_id,
            )
            imported_models.append(created_model)
        else:
            imported_models.extend(existing_models)

        # NOTE(review): when overwriting, the replacement model is created *before*
        # the old ones are deleted, so two models with the same name exist
        # transiently — confirm the backend tolerates duplicate names here.
        if existing_models and overwrite_checkpoints:
            for model in existing_models:
                client.delete_logged_model(model.model_id)

    return imported_models
2766  
2767  
def finalize_logged_model(
    model_id: str, status: Literal["READY", "FAILED"] | LoggedModelStatus
) -> LoggedModel:
    """
    Set a model's terminal status and return the refreshed model entity.

    Args:
        model_id: ID of the model to finalize.
        status: Final status to set on the model.

    Returns:
        The updated model.

    Example:

    .. code-block:: python
        :test:

        import mlflow
        from mlflow.entities import LoggedModelStatus

        model = mlflow.initialize_logged_model(name="model")
        logged_model = mlflow.finalize_logged_model(
            model_id=model.model_id,
            status=LoggedModelStatus.READY,
        )
        assert logged_model.status == LoggedModelStatus.READY

    """
    client = MlflowClient()
    return client.finalize_logged_model(model_id, status)
2798  
2799  
def get_logged_model(model_id: str) -> LoggedModel:
    """
    Fetch the logged model with the given ID.

    Args:
        model_id: The ID of the logged model.

    Returns:
        The logged model.

    Example:

    .. code-block:: python
        :test:

        import mlflow


        class DummyModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input: list[str]) -> list[str]:
                return model_input


        model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
        logged_model = mlflow.get_logged_model(model_id=model_info.model_id)
        assert logged_model.model_id == model_info.model_id

    """
    client = MlflowClient()
    return client.get_logged_model(model_id)
2829  
2830  
def last_logged_model() -> LoggedModel | None:
    """
    Fetches the most recent logged model in the current session.
    If no model has been logged, None is returned.

    Returns:
        The logged model, or ``None`` if no model has been logged in this session.


    .. code-block:: python
        :test:
        :caption: Example

        import mlflow


        class DummyModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input: list[str]) -> list[str]:
                return model_input


        model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
        last_model = mlflow.last_logged_model()
        assert last_model.model_id == model_info.model_id
    """
    # Use a descriptive name (original shadowed the `id` builtin) and make the
    # "nothing logged yet" path an explicit `return None`.
    model_id = _last_logged_model_id.get()
    if model_id:
        return get_logged_model(model_id)
    return None
2858  
2859  
# Typing overloads for `search_logged_models`: the return type is determined by the
# `output_format` argument — a pandas DataFrame for the default "pandas", or a plain
# list of LoggedModel entities for "list". The runtime implementation follows below.
@overload
def search_logged_models(
    experiment_ids: list[str] | None = None,
    filter_string: str | None = None,
    datasets: list[dict[str, str]] | None = None,
    max_results: int | None = None,
    order_by: list[dict[str, Any]] | None = None,
    output_format: Literal["pandas"] = "pandas",
) -> "pandas.DataFrame": ...


@overload
def search_logged_models(
    experiment_ids: list[str] | None = None,
    filter_string: str | None = None,
    datasets: list[dict[str, str]] | None = None,
    max_results: int | None = None,
    order_by: list[dict[str, Any]] | None = None,
    output_format: Literal["list"] = "list",
) -> list[LoggedModel]: ...
2880  
2881  
2882  def search_logged_models(
2883      experiment_ids: list[str] | None = None,
2884      filter_string: str | None = None,
2885      datasets: list[dict[str, str]] | None = None,
2886      max_results: int | None = None,
2887      order_by: list[dict[str, Any]] | None = None,
2888      output_format: Literal["pandas", "list"] = "pandas",
2889  ) -> Union[list[LoggedModel], "pandas.DataFrame"]:
2890      """
2891      Search for logged models that match the specified search criteria.
2892  
2893      Args:
2894          experiment_ids: List of experiment IDs to search for logged models. If not specified,
2895              the active experiment will be used.
2896          filter_string: A SQL-like filter string to parse. The filter string syntax supports:
2897  
2898              - Entity specification:
2899                  - attributes: `attribute_name` (default if no prefix is specified)
2900                  - metrics: `metrics.metric_name`
2901                  - parameters: `params.param_name`
2902                  - tags: `tags.tag_name`
2903              - Comparison operators:
2904                  - For numeric entities (metrics and numeric attributes): <, <=, >, >=, =, !=
2905                  - For string entities (params, tags, string attributes): =, !=, IN, NOT IN
2906              - Multiple conditions can be joined with 'AND'
2907              - String values must be enclosed in single quotes
2908  
2909              Example filter strings:
2910                  - `creation_time > 100`
2911                  - `metrics.rmse > 0.5 AND params.model_type = 'rf'`
2912                  - `tags.release IN ('v1.0', 'v1.1')`
2913                  - `params.optimizer != 'adam' AND metrics.accuracy >= 0.9`
2914  
2915          datasets: List of dictionaries to specify datasets on which to apply metrics filters
2916              For example, a filter string with `metrics.accuracy > 0.9` and dataset with name
2917              "test_dataset" means we will return all logged models with accuracy > 0.9 on the
2918              test_dataset. Metric values from ANY dataset matching the criteria are considered.
2919              If no datasets are specified, then metrics across all datasets are considered in
2920              the filter. The following fields are supported:
2921  
2922              dataset_name (str):
2923                  Required. Name of the dataset.
2924              dataset_digest (str):
2925                  Optional. Digest of the dataset.
2926          max_results: The maximum number of logged models to return.
2927          order_by: List of dictionaries to specify the ordering of the search results. The following
2928              fields are supported:
2929  
2930              field_name (str):
2931                  Required. Name of the field to order by, e.g. "metrics.accuracy".
2932              ascending (bool):
2933                  Optional. Whether the order is ascending or not.
2934              dataset_name (str):
2935                  Optional. If ``field_name`` refers to a metric, this field
2936                  specifies the name of the dataset associated with the metric. Only metrics
2937                  associated with the specified dataset name will be considered for ordering.
2938                  This field may only be set if ``field_name`` refers to a metric.
2939              dataset_digest (str):
2940                  Optional. If ``field_name`` refers to a metric, this field
2941                  specifies the digest of the dataset associated with the metric. Only metrics
2942                  associated with the specified dataset name and digest will be considered for
2943                  ordering. This field may only be set if ``dataset_name`` is also set.
2944  
2945          output_format: The output format of the search results. Supported values are 'pandas'
2946              and 'list'.
2947  
2948      Returns:
2949          The search results in the specified output format.
2950  
2951      Example:
2952  
2953      .. code-block:: python
2954          :test:
2955  
2956          import mlflow
2957  
2958  
2959          class DummyModel(mlflow.pyfunc.PythonModel):
2960              def predict(self, context, model_input: list[str]) -> list[str]:
2961                  return model_input
2962  
2963  
2964          model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
2965          another_model_info = mlflow.pyfunc.log_model(
2966              name="another_model", python_model=DummyModel()
2967          )
2968          models = mlflow.search_logged_models(output_format="list")
2969          assert [m.name for m in models] == ["another_model", "model"]
2970          models = mlflow.search_logged_models(
2971              filter_string="name = 'another_model'", output_format="list"
2972          )
2973          assert [m.name for m in models] == ["another_model"]
2974          models = mlflow.search_logged_models(
2975              order_by=[{"field_name": "creation_time", "ascending": True}], output_format="list"
2976          )
2977          assert [m.name for m in models] == ["model", "another_model"]
2978      """
2979      experiment_ids = experiment_ids or [_get_experiment_id()]
2980      client = MlflowClient()
2981      models = []
2982      page_token = None
2983      while True:
2984          logged_models_page = client.search_logged_models(
2985              experiment_ids=experiment_ids,
2986              filter_string=filter_string,
2987              datasets=datasets,
2988              max_results=max_results,
2989              order_by=order_by,
2990              page_token=page_token,
2991          )
2992          models.extend(logged_models_page.to_list())
2993          if max_results is not None and len(models) >= max_results:
2994              break
2995          if not logged_models_page.token:
2996              break
2997          page_token = logged_models_page.token
2998  
2999      # Only return at most max_results logged models if specified
3000      if max_results is not None:
3001          models = models[:max_results]
3002  
3003      if output_format == "list":
3004          return models
3005      elif output_format == "pandas":
3006          import pandas as pd
3007  
3008          model_dicts = []
3009          for model in models:
3010              model_dict = model.to_dictionary()
3011              # Convert the status back from int to the enum string
3012              model_dict["status"] = LoggedModelStatus.from_int(model_dict["status"])
3013              model_dicts.append(model_dict)
3014  
3015          return pd.DataFrame(model_dicts)
3016  
3017      else:
3018          raise MlflowException(
3019              f"Unsupported output format: {output_format!r}. Supported string values are "
3020              "'pandas' or 'list'",
3021              INVALID_PARAMETER_VALUE,
3022          )
3023  
3024  
def log_outputs(models: list[LoggedModelOutput] | None = None):
    """
    Log outputs, such as models, to the active run. If there is no active run, a new run will be
    created.

    Args:
        models: List of :py:class:`mlflow.entities.LoggedModelOutput` instances to log
            as outputs to the run.

    Returns:
        None.
    """
    # Resolve (or lazily create) the active run, then delegate to the client API.
    active_run = _get_or_start_run()
    MlflowClient().log_outputs(active_run.info.run_id, models=models)
3039  
3040  
def delete_run(run_id: str) -> None:
    """
    Mark the run with the given ID as deleted.

    As shown in the example below, the run is moved to the ``deleted``
    lifecycle stage rather than being physically removed.

    Args:
        run_id: Unique identifier for the run to delete.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        with mlflow.start_run() as run:
            mlflow.log_param("p", 0)

        run_id = run.info.run_id
        mlflow.delete_run(run_id)

        lifecycle_stage = mlflow.get_run(run_id).info.lifecycle_stage
        print(f"run_id: {run_id}; lifecycle_stage: {lifecycle_stage}")

    .. code-block:: text
        :caption: Output

        run_id: 45f4af3e6fd349e58579b27fcb0b8277; lifecycle_stage: deleted

    """
    # Thin wrapper over the low-level client API.
    client = MlflowClient()
    client.delete_run(run_id)
3070  
3071  
def set_logged_model_tags(model_id: str, tags: dict[str, Any]) -> None:
    """
    Attach the given tags to the specified logged model.

    Args:
        model_id: ID of the model.
        tags: Tags to set on the model.

    Returns:
        None

    Example:

    .. code-block:: python
        :test:

        import mlflow


        class DummyModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input: list[str]) -> list[str]:
                return model_input


        model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
        mlflow.set_logged_model_tags(model_info.model_id, {"key": "value"})
        model = mlflow.get_logged_model(model_info.model_id)
        assert model.tags["key"] == "value"
    """
    # Thin wrapper over the low-level client API.
    client = MlflowClient()
    client.set_logged_model_tags(model_id, tags)
3102  
3103  
def delete_logged_model_tag(model_id: str, key: str) -> None:
    """
    Remove a tag from the specified logged model.

    Args:
        model_id: ID of the model.
        key: Tag key to delete.

    Example:

    .. code-block:: python
        :test:

        import mlflow


        class DummyModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input: list[str]) -> list[str]:
                return model_input


        model_info = mlflow.pyfunc.log_model(name="model", python_model=DummyModel())
        mlflow.set_logged_model_tags(model_info.model_id, {"key": "value"})
        model = mlflow.get_logged_model(model_info.model_id)
        assert model.tags["key"] == "value"
        mlflow.delete_logged_model_tag(model_info.model_id, "key")
        model = mlflow.get_logged_model(model_info.model_id)
        assert "key" not in model.tags
    """
    # Thin wrapper over the low-level client API.
    client = MlflowClient()
    client.delete_logged_model_tag(model_id, key)
3134  
3135  
def get_artifact_uri(artifact_path: str | None = None) -> str:
    """
    Return the absolute URI of the specified artifact in the currently active run.

    When ``artifact_path`` is omitted, the artifact root URI of the currently
    active run is returned; ``log_artifact`` and ``log_artifacts`` write
    artifact(s) to subdirectories of that root.

    If no run is active, this method creates a new active run (and warns).

    Args:
        artifact_path: The run-relative artifact path for which to obtain an absolute URI.
            For example, "path/to/artifact". If unspecified, the artifact root URI
            for the currently active run will be returned.

    Returns:
        An *absolute* URI referring to the specified artifact or the currently active run's
        artifact root. For example, if an artifact path is provided and the currently active
        run uses an S3-backed store, this may be a uri of the form
        ``s3://<bucket_name>/path/to/artifact/root/path/to/artifact``. If an artifact path
        is not provided and the currently active run uses an S3-backed store, this may be a
        URI of the form ``s3://<bucket_name>/path/to/artifact/root``.

    .. code-block:: python
        :test:
        :caption: Example

        import tempfile

        import mlflow

        features = "rooms, zipcode, median_price, school_rating, transport"
        with tempfile.NamedTemporaryFile("w") as tmp_file:
            tmp_file.write(features)
            tmp_file.flush()

            # Log the artifact in a directory "features" under the root artifact_uri/features
            with mlflow.start_run():
                mlflow.log_artifact(tmp_file.name, artifact_path="features")

                # Fetch the artifact uri root directory
                artifact_uri = mlflow.get_artifact_uri()
                print(f"Artifact uri: {artifact_uri}")

                # Fetch a specific artifact uri
                artifact_uri = mlflow.get_artifact_uri(artifact_path="features/features.txt")
                print(f"Artifact uri: {artifact_uri}")

    .. code-block:: text
        :caption: Output

        Artifact uri: file:///.../0/a46a80f1c9644bd8f4e5dd5553fffce/artifacts
        Artifact uri: file:///.../0/a46a80f1c9644bd8f4e5dd5553fffce/artifacts/features/features.txt
    """
    # Warn before _get_or_start_run implicitly creates a run.
    if not mlflow.active_run():
        _logger.warning(
            "No active run found. A new active run will be created. If this is not intended, "
            "please create a run using `mlflow.start_run()` first."
        )

    active_run_id = _get_or_start_run().info.run_id
    return artifact_utils.get_artifact_uri(run_id=active_run_id, artifact_path=artifact_path)
3199  
3200  
def search_runs(
    experiment_ids: list[str] | None = None,
    filter_string: str = "",
    run_view_type: int = ViewType.ACTIVE_ONLY,
    max_results: int = SEARCH_MAX_RESULTS_PANDAS,
    order_by: list[str] | None = None,
    output_format: str = "pandas",
    search_all_experiments: bool = False,
    experiment_names: list[str] | None = None,
) -> Union[list[Run], "pandas.DataFrame"]:
    """
    Search for Runs that fit the specified criteria.

    Args:
        experiment_ids: List of experiment IDs. Search can work with experiment IDs or
            experiment names, but not both in the same call. Values other than
            ``None`` or ``[]`` will result in error if ``experiment_names`` is
            also not ``None`` or ``[]``. ``None`` will default to the active
            experiment if ``experiment_names`` is ``None`` or ``[]``.
        filter_string: Filter query string, defaults to searching all runs.
        run_view_type: one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or ``ALL`` runs
            defined in :py:class:`mlflow.entities.ViewType`.
        max_results: The maximum number of runs to put in the dataframe. Default is 100,000
            to avoid causing out-of-memory issues on the user's machine.
        order_by: List of columns to order by (e.g., "metrics.rmse"). The ``order_by`` column
            can contain an optional ``DESC`` or ``ASC`` value. The default is ``ASC``.
            The default ordering is to sort by ``start_time DESC``, then ``run_id``.
        output_format: The output format to be returned. If ``pandas``, a ``pandas.DataFrame``
            is returned and, if ``list``, a list of :py:class:`mlflow.entities.Run`
            is returned.
        search_all_experiments: Boolean specifying whether all experiments should be searched.
            Only honored if ``experiment_ids`` is ``[]`` or ``None``.
        experiment_names: List of experiment names. Search can work with experiment IDs or
            experiment names, but not both in the same call. Values other
            than ``None`` or ``[]`` will result in error if ``experiment_ids``
            is also not ``None`` or ``[]``. ``None`` will default to the active
            experiment if ``experiment_ids`` is ``None`` or ``[]``.

    Returns:
        If output_format is ``list``: a list of :py:class:`mlflow.entities.Run`. If
        output_format is ``pandas``: ``pandas.DataFrame`` of runs, where each metric,
        parameter, and tag is expanded into its own column named metrics.*, params.*, or
        tags.* respectively. For runs that don't have a particular metric, parameter, or tag,
        the value for the corresponding column is (NumPy) ``NaN``, ``None``, or ``None``
        respectively.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Create an experiment and log two runs under it
        experiment_name = "Social NLP Experiments"
        experiment_id = mlflow.create_experiment(experiment_name)
        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.log_metric("m", 1.55)
            mlflow.set_tag("s.release", "1.1.0-RC")
        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.log_metric("m", 2.50)
            mlflow.set_tag("s.release", "1.2.0-GA")
        # Search for all the runs in the experiment with the given experiment ID
        df = mlflow.search_runs([experiment_id], order_by=["metrics.m DESC"])
        print(df[["metrics.m", "tags.s.release", "run_id"]])
        print("--")
        # Search the experiment_id using a filter_string with tag
        # that has a case insensitive pattern
        filter_string = "tags.s.release ILIKE '%rc%'"
        df = mlflow.search_runs([experiment_id], filter_string=filter_string)
        print(df[["metrics.m", "tags.s.release", "run_id"]])
        print("--")
        # Search for all the runs in the experiment with the given experiment name
        df = mlflow.search_runs(experiment_names=[experiment_name], order_by=["metrics.m DESC"])
        print(df[["metrics.m", "tags.s.release", "run_id"]])

    .. code-block:: text
        :caption: Output

           metrics.m tags.s.release                            run_id
        0       2.50       1.2.0-GA  147eed886ab44633902cc8e19b2267e2
        1       1.55       1.1.0-RC  5cc7feaf532f496f885ad7750809c4d4
        --
           metrics.m tags.s.release                            run_id
        0       1.55       1.1.0-RC  5cc7feaf532f496f885ad7750809c4d4
        --
           metrics.m tags.s.release                            run_id
        0       2.50       1.2.0-GA  147eed886ab44633902cc8e19b2267e2
        1       1.55       1.1.0-RC  5cc7feaf532f496f885ad7750809c4d4
    """
    # Experiment IDs and names are mutually exclusive ways of scoping the search.
    no_ids = experiment_ids is None or len(experiment_ids) == 0
    no_names = experiment_names is None or len(experiment_names) == 0
    no_ids_or_names = no_ids and no_names
    if not no_ids and not no_names:
        raise MlflowException(
            message="Only experiment_ids or experiment_names can be used, but not both",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Resolve the experiment IDs to search:
    #   1. no explicit scope + search_all_experiments -> every active experiment
    #   2. no explicit scope -> the currently active experiment
    #   3. names provided -> look up each name, warning on (and skipping) missing ones
    if search_all_experiments and no_ids_or_names:
        experiment_ids = [
            exp.experiment_id for exp in search_experiments(view_type=ViewType.ACTIVE_ONLY)
        ]
    elif no_ids_or_names:
        experiment_ids = [_get_experiment_id()]
    elif not no_names:
        experiments = []
        for n in experiment_names:
            if n is not None:
                if experiment_by_name := get_experiment_by_name(n):
                    experiments.append(experiment_by_name)
                else:
                    _logger.warning("Cannot retrieve experiment by name %s", n)
        experiment_ids = [e.experiment_id for e in experiments if e is not None]

    if len(experiment_ids) == 0:
        runs = []
    else:
        # Using an internal function as the linter doesn't like assigning a lambda, and inlining the
        # full thing is a mess
        def pagination_wrapper_func(number_to_get, next_page_token):
            return MlflowClient().search_runs(
                experiment_ids,
                filter_string,
                run_view_type,
                number_to_get,
                order_by,
                next_page_token,
            )

        # Fetch up to max_results runs, NUM_RUNS_PER_PAGE_PANDAS at a time.
        runs = get_results_from_paginated_fn(
            pagination_wrapper_func,
            NUM_RUNS_PER_PAGE_PANDAS,
            max_results,
        )

    if output_format == "list":
        return runs  # List[mlflow.entities.run.Run]
    elif output_format == "pandas":
        import numpy as np
        import pandas as pd

        # Fixed run-info columns present for every run.
        info = {
            "run_id": [],
            "experiment_id": [],
            "status": [],
            "artifact_uri": [],
            "start_time": [],
            "end_time": [],
        }
        # Dynamic columns discovered incrementally while iterating over the runs.
        # For each run we append one value per already-known column (null when the
        # run lacks it), and backfill nulls for earlier runs when a new column is
        # first encountered, so every column list stays exactly i+1 entries long.
        params = {}
        metrics = {}
        tags = {}
        PARAM_NULL = None
        METRIC_NULL = np.nan
        TAG_NULL = None
        for i, run in enumerate(runs):
            info["run_id"].append(run.info.run_id)
            info["experiment_id"].append(run.info.experiment_id)
            info["status"].append(run.info.status)
            info["artifact_uri"].append(run.info.artifact_uri)
            # Timestamps are stored as epoch milliseconds; expose them as
            # timezone-aware (UTC) pandas datetimes.
            info["start_time"].append(pd.to_datetime(run.info.start_time, unit="ms", utc=True))
            info["end_time"].append(pd.to_datetime(run.info.end_time, unit="ms", utc=True))

            # Params
            param_keys = set(params.keys())
            for key in param_keys:
                if key in run.data.params:
                    params[key].append(run.data.params[key])
                else:
                    params[key].append(PARAM_NULL)
            new_params = set(run.data.params.keys()) - param_keys
            for p in new_params:
                params[p] = [PARAM_NULL] * i  # Fill in null values for all previous runs
                params[p].append(run.data.params[p])

            # Metrics
            metric_keys = set(metrics.keys())
            for key in metric_keys:
                if key in run.data.metrics:
                    metrics[key].append(run.data.metrics[key])
                else:
                    metrics[key].append(METRIC_NULL)
            new_metrics = set(run.data.metrics.keys()) - metric_keys
            for m in new_metrics:
                metrics[m] = [METRIC_NULL] * i
                metrics[m].append(run.data.metrics[m])

            # Tags
            tag_keys = set(tags.keys())
            for key in tag_keys:
                if key in run.data.tags:
                    tags[key].append(run.data.tags[key])
                else:
                    tags[key].append(TAG_NULL)
            new_tags = set(run.data.tags.keys()) - tag_keys
            for t in new_tags:
                tags[t] = [TAG_NULL] * i
                tags[t].append(run.data.tags[t])

        # Assemble the final frame: info columns keep their names, while metric,
        # param, and tag columns are namespaced with "metrics."/"params."/"tags.".
        data = {}
        data.update(info)
        for key, value in metrics.items():
            data["metrics." + key] = value
        for key, value in params.items():
            data["params." + key] = value
        for key, value in tags.items():
            data["tags." + key] = value
        return pd.DataFrame(data)
    else:
        raise ValueError(
            f"Unsupported output format: {output_format}. Supported string values are 'pandas' "
            "or 'list'"
        )
3414  
3415  
def _get_or_start_run():
    # Return the innermost active run, creating a new one if the stack is empty.
    stack = _active_run_stack.get()
    return stack[-1] if stack else start_run()
3421  
3422  
def _get_experiment_id_from_env():
    """
    Resolve an experiment ID from the MLFLOW_EXPERIMENT_NAME / MLFLOW_EXPERIMENT_ID
    environment variables.

    The name takes precedence: an experiment with that name is looked up (and created
    if missing), and a simultaneously-set ID must agree with it. When only the ID is
    set, it must refer to an existing experiment. Returns ``None`` when neither
    variable is set.
    """
    env_name = MLFLOW_EXPERIMENT_NAME.get()
    env_id = MLFLOW_EXPERIMENT_ID.get()
    if env_name is not None:
        client = MlflowClient()
        experiment = client.get_experiment_by_name(env_name)
        if not experiment:
            # Name doesn't exist yet: create it and use the new ID.
            return client.create_experiment(name=env_name)
        if env_id and env_id != experiment.experiment_id:
            # Both variables set but pointing at different experiments.
            raise MlflowException(
                message=f"The provided {MLFLOW_EXPERIMENT_ID} environment variable "
                f"value `{env_id}` does not match the experiment id "
                f"`{experiment.experiment_id}` for experiment name `{env_name}`",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return experiment.experiment_id
    if env_id is not None:
        try:
            return MlflowClient().get_experiment(env_id).experiment_id
        except MlflowException as exc:
            raise MlflowException(
                message=f"The provided {MLFLOW_EXPERIMENT_ID} environment variable "
                f"value `{env_id}` does not exist in the tracking server. Provide a valid "
                f"experiment_id.",
                error_code=INVALID_PARAMETER_VALUE,
            ) from exc
3450  
3451  
def _get_experiment_id() -> str | None:
    # Precedence: explicitly activated experiment, then environment variables,
    # then the registry's default experiment.
    return (
        _active_experiment_id
        or _get_experiment_id_from_env()
        or default_experiment_registry.get_experiment_id()
    )
3457  
3458  
3459  @autologging_integration("mlflow")
3460  def autolog(
3461      log_input_examples: bool = False,
3462      log_model_signatures: bool = True,
3463      log_models: bool = True,
3464      log_datasets: bool = True,
3465      log_traces: bool = True,
3466      disable: bool = False,
3467      exclusive: bool = False,
3468      disable_for_unsupported_versions: bool = False,
3469      silent: bool = False,
3470      extra_tags: dict[str, str] | None = None,
3471      exclude_flavors: list[str] | None = None,
3472  ) -> None:
3473      """
3474      Enables (or disables) and configures autologging for all supported integrations.
3475  
3476      The parameters are passed to any autologging integrations that support them.
3477  
3478      See the `tracking docs <../../tracking/autolog.html>`_ for a list of supported autologging
3479      integrations.
3480  
3481      Note that framework-specific configurations set at any point will take precedence over
3482      any configurations set by this function. For example:
3483  
3484      .. code-block:: python
3485          :test:
3486  
3487          import mlflow
3488  
3489          mlflow.autolog(log_models=False, exclusive=True)
3490          import sklearn
3491  
3492      would enable autologging for `sklearn` with `log_models=False` and `exclusive=True`,
3493      but
3494  
3495      .. code-block:: python
3496          :test:
3497  
3498          import mlflow
3499  
3500          mlflow.autolog(log_models=False, exclusive=True)
3501  
3502          import sklearn
3503  
3504          mlflow.sklearn.autolog(log_models=True)
3505  
3506      would enable autologging for `sklearn` with `log_models=True` and `exclusive=False`,
3507      the latter resulting from the default value for `exclusive` in `mlflow.sklearn.autolog`;
3508      other framework autolog functions (e.g. `mlflow.tensorflow.autolog`) would use the
3509      configurations set by `mlflow.autolog` (in this instance, `log_models=False`, `exclusive=True`),
3510      until they are explicitly called by the user.
3511  
3512      Args:
3513          log_input_examples: If ``True``, input examples from training datasets are collected and
3514              logged along with model artifacts during training. If ``False``,
3515              input examples are not logged.
3516              Note: Input examples are MLflow model attributes
3517              and are only collected if ``log_models`` is also ``True``.
3518          log_model_signatures: If ``True``,
3519              :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
3520              describing model inputs and outputs are collected and logged along
3521              with model artifacts during training. If ``False``, signatures are
3522              not logged. Note: Model signatures are MLflow model attributes
3523              and are only collected if ``log_models`` is also ``True``.
3524          log_models: If ``True``, trained models are logged as MLflow model artifacts.
3525              If ``False``, trained models are not logged.
3526              Input examples and model signatures, which are attributes of MLflow models,
3527              are also omitted when ``log_models`` is ``False``.
3528          log_datasets: If ``True``, dataset information is logged to MLflow Tracking.
3529              If ``False``, dataset information is not logged.
3530          log_traces: If ``True``, traces are collected for integrations.
3531              If ``False``, no trace is collected.
3532          disable: If ``True``, disables all supported autologging integrations. If ``False``,
3533              enables all supported autologging integrations.
3534          exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
3535              If ``False``, autologged content is logged to the active fluent run,
3536              which may be user-created.
3537          disable_for_unsupported_versions: If ``True``, disable autologging for versions of
3538              all integration libraries that have not been tested against this version
3539              of the MLflow client or are incompatible.
3540          silent: If ``True``, suppress all event logs and warnings from MLflow during autologging
3541              setup and training execution. If ``False``, show all events and warnings during
3542              autologging setup and training execution.
3543          extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
3544          exclude_flavors: A list of flavor names that are excluded from the auto-logging.
3545              e.g. tensorflow, pyspark.ml
3546  
3547      .. code-block:: python
3548          :test:
3549          :caption: Example
3550  
3551          import numpy as np
3552          import mlflow.sklearn
3553          from mlflow import MlflowClient
3554          from sklearn.linear_model import LinearRegression
3555  
3556  
3557          def print_auto_logged_info(r):
3558              tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
3559              artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")]
3560              print(f"run_id: {r.info.run_id}")
3561              print(f"artifacts: {artifacts}")
3562              print(f"params: {r.data.params}")
3563              print(f"metrics: {r.data.metrics}")
3564              print(f"tags: {tags}")
3565  
3566  
3567          # prepare training data
3568          X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
3569          y = np.dot(X, np.array([1, 2])) + 3
3570  
3571          # Auto log all the parameters, metrics, and artifacts
3572          mlflow.autolog()
3573          model = LinearRegression()
3574          with mlflow.start_run() as run:
3575              model.fit(X, y)
3576  
3577          # fetch the auto logged parameters and metrics for ended run
3578          print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id))
3579  
3580      .. code-block:: text
3581          :caption: Output
3582  
3583          run_id: fd10a17d028c47399a55ab8741721ef7
3584          artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
3585          params: {'copy_X': 'True',
3586                   'normalize': 'False',
3587                   'fit_intercept': 'True',
3588                   'n_jobs': 'None'}
3589          metrics: {'training_score': 1.0,
3590                    'training_root_mean_squared_error': 4.440892098500626e-16,
3591                    'training_r2_score': 1.0,
3592                    'training_mean_absolute_error': 2.220446049250313e-16,
3593                    'training_mean_squared_error': 1.9721522630525295e-31}
3594          tags: {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
3595                 'estimator_name': 'LinearRegression'}
3596      """
3597      locals_copy = locals().items()
3598  
3599      # Mapping of library name to specific autolog function name. We use string like
3600      # "tensorflow.autolog" to avoid loading all flavor modules, so we only set autologging for
3601      # compatible modules.
3602      LIBRARY_TO_AUTOLOG_MODULE = {
3603          "tensorflow": "mlflow.tensorflow",
3604          "keras": "mlflow.keras",
3605          "xgboost": "mlflow.xgboost",
3606          "lightgbm": "mlflow.lightgbm",
3607          "statsmodels": "mlflow.statsmodels",
3608          "sklearn": "mlflow.sklearn",
3609          "pyspark": "mlflow.spark",
3610          "pyspark.ml": "mlflow.pyspark.ml",
3611          # TODO: Broaden this beyond pytorch_lightning as we add autologging support for more
3612          # Pytorch frameworks under mlflow.pytorch.autolog
3613          "pytorch_lightning": "mlflow.pytorch",
3614          "lightning": "mlflow.pytorch",
3615          "setfit": "mlflow.transformers",
3616          "transformers": "mlflow.transformers",
3617          # do not enable langchain autologging by default
3618      }
3619  
3620      GENAI_LIBRARY_TO_AUTOLOG_MODULE = {
3621          "autogen": "mlflow.ag2",
3622          "agno": "mlflow.agno",
3623          "anthropic": "mlflow.anthropic",
3624          "autogen_agentchat": "mlflow.autogen",
3625          "openai": "mlflow.openai",
3626          "google.genai": "mlflow.gemini",
3627          "google.generativeai": "mlflow.gemini",
3628          "litellm": "mlflow.litellm",
3629          "llama_index.core": "mlflow.llama_index",
3630          "langchain": "mlflow.langchain",
3631          "dspy": "mlflow.dspy",
3632          "crewai": "mlflow.crewai",
3633          "smolagents": "mlflow.smolagents",
3634          "groq": "mlflow.groq",
3635          "strands": "mlflow.strands",
3636          "haystack": "mlflow.haystack",
3637          "boto3": "mlflow.bedrock",
3638          "mistralai": "mlflow.mistral",
3639          "pydantic_ai": "mlflow.pydantic_ai",
3640      }
3641  
3642      # Currently, GenAI libraries are not enabled by `mlflow.autolog` in Databricks,
3643      # particularly when disable=False. This is because the function is automatically invoked
3644      # by system and we don't want to take the risk of enabling GenAI libraries all at once.
3645      # TODO: Remove this logic once a feature flag is implemented in Databricks Runtime init logic.
3646      if is_in_databricks_runtime() and (not disable):
3647          target_library_and_module = LIBRARY_TO_AUTOLOG_MODULE
3648      else:
3649          target_library_and_module = LIBRARY_TO_AUTOLOG_MODULE | GENAI_LIBRARY_TO_AUTOLOG_MODULE
3650  
3651      if exclude_flavors:
3652          excluded_modules = [f"mlflow.{flavor}" for flavor in exclude_flavors]
3653          target_library_and_module = {
3654              k: v for k, v in target_library_and_module.items() if v not in excluded_modules
3655          }
3656  
    def get_autologging_params(autolog_fn):
        """Build the kwargs to pass to an integration's ``autolog`` function.

        Filters ``locals_copy`` (captured from the enclosing scope; presumably
        the (name, value) pairs of the arguments passed to ``mlflow.autolog`` --
        TODO confirm against the enclosing function) down to the parameter names
        that ``autolog_fn`` actually accepts, so each integration only receives
        options it understands. Returns an empty dict on any failure.
        """
        try:
            needed_params = list(inspect.signature(autolog_fn).parameters.keys())
            return {k: v for k, v in locals_copy if k in needed_params}
        except Exception:
            # Best-effort: fall back to calling autolog_fn with its own defaults.
            return {}
3663  
    # Note: we need to protect `setup_autologging` with `autologging_conf_lock`,
    # because `setup_autologging` might be registered as post importing hook
    # and be executed asynchronously, so that it is out of current active
    # `autologging_conf_lock` scope.
    @autologging_conf_lock
    def setup_autologging(module):
        """Enable the MLflow autologging integration mapped to ``module``.

        Invoked either directly (pyspark on Databricks) or as a post-import
        hook once the target library is imported. Skips integrations whose
        configuration was set by an explicit ``mlflow.<flavor>.autolog`` call.
        """
        try:
            # Initialized before the import so the except block can tell whether
            # we failed before or after resolving the autologging parameters.
            autologging_params = None
            autolog_module = importlib.import_module(target_library_and_module[module.__name__])
            autolog_fn = autolog_module.autolog
            # Only call integration's autolog function with `mlflow.autolog` configs
            # if the integration's autolog function has not already been called by the user.
            # Logic is as follows:
            # - if a previous_config exists, that means either `mlflow.autolog` or
            #   `mlflow.integration.autolog` was called.
            # - if the config contains `AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED`, the
            #   configuration was set by `mlflow.autolog`, and so we can safely call `autolog_fn`
            #   with `autologging_params`.
            # - if the config doesn't contain this key, the configuration was set by an
            #   `mlflow.integration.autolog` call, so we should not call `autolog_fn` with
            #   new configs.
            prev_config = AUTOLOGGING_INTEGRATIONS.get(autolog_fn.integration_name)
            if prev_config and not prev_config.get(
                AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED, False
            ):
                return

            autologging_params = get_autologging_params(autolog_fn)
            autolog_fn(**autologging_params)
            # Record that this integration was configured globally so a later
            # `mlflow.autolog` call is allowed to reconfigure it.
            AUTOLOGGING_INTEGRATIONS[autolog_fn.integration_name][
                AUTOLOGGING_CONF_KEY_IS_GLOBALLY_CONFIGURED
            ] = True
            if not autologging_is_disabled(
                autolog_fn.integration_name
            ) and not autologging_params.get("silent", False):
                _logger.info("Autologging successfully enabled for %s.", module.__name__)
        except Exception as e:
            if is_testing():
                # Raise unexpected exceptions in test mode in order to detect
                # errors within dependent autologging integrations
                raise
            elif autologging_params is None or not autologging_params.get("silent", False):
                # `silent=True` suppresses this warning when the user asked for quiet setup.
                _logger.warning(
                    "Exception raised while enabling autologging for %s: %s",
                    module.__name__,
                    str(e),
                )
3711  
3712      # for each autolog library (except pyspark), register a post-import hook.
3713      # this way, we do not send any errors to the user until we know they are using the library.
3714      # the post-import hook also retroactively activates for previously-imported libraries.
3715      for library in sorted(set(target_library_and_module) - {"pyspark", "pyspark.ml"}):
3716          register_post_import_hook(setup_autologging, library, overwrite=True)
3717  
3718      if is_in_databricks_runtime():
3719          # for pyspark, we activate autologging immediately, without waiting for a module import.
3720          # this is because on Databricks a SparkSession already exists and the user can directly
3721          #   interact with it, and this activity should be logged.
3722          import pyspark as pyspark_module
3723          import pyspark.ml as pyspark_ml_module
3724  
3725          setup_autologging(pyspark_module)
3726          setup_autologging(pyspark_ml_module)
3727      else:
3728          if "pyspark" in target_library_and_module:
3729              register_post_import_hook(setup_autologging, "pyspark", overwrite=True)
3730          if "pyspark.ml" in target_library_and_module:
3731              register_post_import_hook(setup_autologging, "pyspark.ml", overwrite=True)
3732  
3733      _record_event(AutologgingEvent, {"flavor": "all", "log_traces": log_traces, "disable": disable})
3734  
3735  
# Serializes reads/writes of the active-model environment variable performed by
# ActiveModelContext.__init__ in the Databricks model serving environment.
_active_model_id_env_lock = threading.Lock()
3737  
3738  
3739  class ActiveModelContext:
3740      """
3741      The context of the active model.
3742  
3743      Args:
3744          model_id: The ID of the active model.
3745          set_by_user: Whether the active model was set by the user or not.
3746      """
3747  
3748      def __init__(self, model_id: str | None = None, set_by_user: bool = False):
3749          # use active model ID from environment variables as the default value for model_id
3750          # so that for subprocesses the default _ACTIVE_MODEL_CONTEXT.model_id
3751          # is still valid, and we don't need to read from env var.
3752          self._set_by_user = set_by_user
3753          if is_in_databricks_model_serving_environment():
3754              # In Databricks, we set the active model ID to the environment variable
3755              # so that it can be used in the main process, since databricks serving
3756              # loads model from threads.
3757              with _active_model_id_env_lock:
3758                  self._model_id = model_id or _get_active_model_id_from_env()
3759                  if self._model_id:
3760                      _MLFLOW_ACTIVE_MODEL_ID.set(self._model_id)
3761          else:
3762              self._model_id = model_id or _get_active_model_id_from_env()
3763  
3764      def __repr__(self):
3765          return f"ActiveModelContext(model_id={self.model_id}, set_by_user={self.set_by_user})"
3766  
3767      @property
3768      def model_id(self) -> str | None:
3769          return self._model_id
3770  
3771      @property
3772      def set_by_user(self) -> bool:
3773          return self._set_by_user
3774  
3775  
def _get_active_model_id_from_env() -> str | None:
    """
    Read the active model ID from environment variables.

    Precedence order:
    1. ``MLFLOW_ACTIVE_MODEL_ID`` (public variable) -- wins when set.
    2. ``_MLFLOW_ACTIVE_MODEL_ID`` (legacy internal variable) -- fallback.

    Historical Context:
    ``_MLFLOW_ACTIVE_MODEL_ID`` was originally created for internal MLflow use
    only. With ``MLFLOW_ACTIVE_MODEL_ID`` introduced as the public API, the
    public variable is consulted first to encourage migration, while the legacy
    variable is still honored for backward compatibility.

    Returns:
        The active model ID if found in environment variables, otherwise None.
    """
    # Check the public variable first, then the legacy one.
    for env_var in (MLFLOW_ACTIVE_MODEL_ID, _MLFLOW_ACTIVE_MODEL_ID):
        value = env_var.get()
        if value is not None:
            return value
    return None
3801  
3802  
# Thread-local active-model context; each thread lazily gets a fresh default
# ActiveModelContext (which may pick up the model ID from environment variables).
_ACTIVE_MODEL_CONTEXT = ThreadLocalVariable(default_factory=lambda: ActiveModelContext())
3804  
3805  
class ActiveModel(LoggedModel):
    """
    Wrapper around :py:class:`mlflow.entities.LoggedModel` to enable using Python ``with`` syntax.
    """

    def __init__(self, logged_model: LoggedModel, set_by_user: bool):
        super().__init__(**logged_model.to_dictionary())
        # Snapshot the current context so __exit__ can restore it afterwards.
        self.last_active_model_context = _ACTIVE_MODEL_CONTEXT.get()
        _set_active_model_id(self.model_id, set_by_user)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        previous = self.last_active_model_context
        if is_in_databricks_model_serving_environment():
            # Construct a fresh ActiveModelContext so the environment-variable
            # mirror is updated in the Databricks serving environment.
            restored = ActiveModelContext(
                model_id=previous.model_id,
                set_by_user=previous.set_by_user,
            )
            _ACTIVE_MODEL_CONTEXT.set(restored)
        else:
            _ACTIVE_MODEL_CONTEXT.set(previous)
3831  
3832  
3833  # NB: This function is only intended to be used publicly by users to set the
3834  # active model ID. MLflow internally should NEVER call this function directly,
3835  # since we need to differentiate between user and system set active model IDs.
3836  # For MLflow internal usage, use `_set_active_model` instead.
3837  
3838  
def set_active_model(*, name: str | None = None, model_id: str | None = None) -> ActiveModel:
    """
    Make the LoggedModel identified by ``name`` or ``model_id`` the active model.

    Traces generated while the model is active are linked to it. The return
    value can be used as a context manager within a ``with`` block to restore
    the previous active model on exit; otherwise the active model stays set
    until ``set_active_model()`` is called again.

    Args:
        name: Name of the :py:class:`mlflow.entities.LoggedModel` to activate.
            If no LoggedModel with this name exists, one is created under the
            current experiment; if several exist, the latest one is activated.
        model_id: ID of the :py:class:`mlflow.entities.LoggedModel` to activate.
            If no LoggedModel with the ID exists, an exception will be raised.

    Returns:
        :py:class:`mlflow.ActiveModel` object that acts as a context manager wrapping the
        LoggedModel's state.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Set the active model by name
        mlflow.set_active_model(name="my_model")

        # Set the active model by model ID
        model = mlflow.create_external_model(name="test_model")
        mlflow.set_active_model(model_id=model.model_id)

        # Use the active model in a context manager
        with mlflow.set_active_model(name="new_model"):
            print(mlflow.get_active_model_id())

        # Traces are automatically linked to the active model
        mlflow.set_active_model(name="my_model")


        @mlflow.trace
        def predict(model_input):
            return model_input


        predict("abc")
        traces = mlflow.search_traces(
            model_id=mlflow.get_active_model_id(), return_type="list", flush=True
        )
        assert len(traces) == 1
    """
    # Delegate to the internal helper, flagging that the user set this model.
    return _set_active_model(name=name, model_id=model_id, set_by_user=True)
3891  
3892  
def _set_active_model(
    *, name: str | None = None, model_id: str | None = None, set_by_user: bool = False
) -> ActiveModel:
    """Resolve a LoggedModel from ``name``/``model_id`` and make it active.

    Internal counterpart of :py:func:`set_active_model`; ``set_by_user``
    records whether the user (vs. MLflow internals) requested the change.

    Raises:
        MlflowException: If neither argument is given, or if both are given
            and the stored model's name does not match ``name``.
    """
    if name is None and model_id is None:
        raise MlflowException.invalid_parameter_value(
            message="Either name or model_id must be provided",
        )

    if model_id is not None:
        # Look the model up by ID; when a name is also supplied it must agree.
        logged_model = mlflow.get_logged_model(model_id)
        if name is not None and logged_model.name != name:
            raise MlflowException.invalid_parameter_value(
                f"LoggedModel with model_id {model_id!r} has name {logged_model.name!r}, which does"
                f" not match the provided name {name!r}."
            )
        return ActiveModel(logged_model=logged_model, set_by_user=set_by_user)

    # Only a name was supplied: find the latest match, creating one if absent.
    matches = mlflow.search_logged_models(
        filter_string=f"name='{name}'", max_results=2, output_format="list"
    )
    if len(matches) > 1:
        _logger.warning(
            f"Multiple LoggedModels found with name {name!r}, setting the latest one as active "
            "model."
        )
    if not matches:
        _logger.info(f"LoggedModel with name {name!r} does not exist, creating one...")
        logged_model = mlflow.create_external_model(name=name)
    else:
        logged_model = matches[0]
    return ActiveModel(logged_model=logged_model, set_by_user=set_by_user)
3923  
3924  
def _set_active_model_id(model_id: str, set_by_user: bool = False) -> None:
    """
    Set the active model ID in the active model context and update the
    corresponding environment variable. This should only be used when
    we know the LoggedModel with the model_id exists.
    This function should be used inside MLflow to set the active model
    while not blocking other code execution: failures are logged, not raised.
    """
    try:
        _ACTIVE_MODEL_CONTEXT.set(ActiveModelContext(model_id, set_by_user))
    except Exception as e:
        _logger.warning(f"Failed to set active model ID to {model_id}, error: {e}")
        return
    _logger.info(f"Active model is set to the logged model with ID: {model_id}")
    if not set_by_user:
        # Hint the user at the public API when MLflow set the model implicitly.
        _logger.info(
            "Use `mlflow.set_active_model` to set the active model "
            "to a different one if needed."
        )
3944  
3945  
def _get_active_model_context() -> ActiveModelContext:
    """Return the current thread's active model context (MLflow-internal)."""
    return _ACTIVE_MODEL_CONTEXT.get()
3952  
3953  
def get_active_model_id() -> str | None:
    """
    Get the active model ID for the current thread.

    If no active model was set with ``set_active_model()``, the default comes
    from the ``MLFLOW_ACTIVE_MODEL_ID`` environment variable or the legacy
    ``_MLFLOW_ACTIVE_MODEL_ID`` environment variable. Returns None when neither
    is set. Note that this function only reads the current thread's context.

    Returns:
        The active model ID if set, otherwise None.
    """
    context = _get_active_model_context()
    return context.model_id
3966  
3967  
def _get_active_model_id_global() -> str | None:
    """
    Get the active model ID by looking across all threads' contexts.

    Useful when the active model ID was set by a different thread. Returns None
    when no thread has one, or when several threads disagree on the ID.
    """
    # An ID set in the current thread always wins.
    if current_thread_model_id := get_active_model_id():
        _logger.debug(f"Active model ID found in the current thread: {current_thread_model_id}")
        return current_thread_model_id

    candidate_ids = {
        ctx.model_id
        for ctx in _ACTIVE_MODEL_CONTEXT.get_all_thread_values().values()
        if ctx.model_id is not None
    }
    if not candidate_ids:
        _logger.debug("No active model ID found in any thread.")
        return None
    if len(candidate_ids) > 1:
        # Ambiguous: different threads hold different active models.
        _logger.debug(
            "Failed to get one active model id from all threads, multiple active model IDs "
            f"found: {candidate_ids}."
        )
        return None
    return next(iter(candidate_ids))
3991  
3992  
def clear_active_model() -> None:
    """
    Clear the active model for the current thread.

    This removes an active model previously set by :py:func:`mlflow.set_active_model`
    or via the ``MLFLOW_ACTIVE_MODEL_ID`` environment variable (or the legacy
    ``_MLFLOW_ACTIVE_MODEL_ID`` environment variable). To temporarily switch the
    active model, use ``with mlflow.set_active_model(...)`` instead.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow

        # Set the active model by name
        mlflow.set_active_model(name="my_model")

        # Clear the active model
        mlflow.clear_active_model()
        # Check that the active model is None
        assert mlflow.get_active_model_id() is None

        # If you want to temporarily set the active model,
        # use `set_active_model` as a context manager instead
        with mlflow.set_active_model(name="my_model") as active_model:
            assert mlflow.get_active_model_id() == active_model.model_id
        assert mlflow.get_active_model_id() is None
    """
    # Unset both environment variables first so a freshly constructed
    # ActiveModelContext does not pick the old ID back up.
    MLFLOW_ACTIVE_MODEL_ID.unset()
    _MLFLOW_ACTIVE_MODEL_ID.unset()

    # Drop contexts set by any thread, then install a clean default.
    _ACTIVE_MODEL_CONTEXT.reset()
    # set_by_user=False: this API clears the active-model state, and MLflow
    # itself may still set an active model later (e.g. in `load_model`).
    _ACTIVE_MODEL_CONTEXT.set(ActiveModelContext(set_by_user=False))
    _logger.info("Active model is cleared")