# fluent.py
  1  import json
  2  import logging
  3  import threading
  4  import uuid
  5  import warnings
  6  from typing import Any
  7  
  8  from pydantic import BaseModel
  9  
 10  import mlflow
 11  from mlflow.entities.logged_model import LoggedModel
 12  from mlflow.entities.model_registry import ModelVersion, Prompt, PromptVersion, RegisteredModel
 13  from mlflow.entities.model_registry.prompt_version import PromptModelConfig
 14  from mlflow.entities.run import Run
 15  from mlflow.environment_variables import MLFLOW_PRINT_MODEL_URLS_ON_CREATION
 16  from mlflow.exceptions import MlflowException
 17  from mlflow.models.model import MLMODEL_FILE_NAME
 18  from mlflow.prompt.registry_utils import require_prompt_registry
 19  from mlflow.protos.databricks_pb2 import (
 20      ALREADY_EXISTS,
 21      NOT_FOUND,
 22      RESOURCE_ALREADY_EXISTS,
 23      RESOURCE_DOES_NOT_EXIST,
 24      ErrorCode,
 25  )
 26  from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
 27  from mlflow.store.artifact.utils.models import _parse_model_id_if_present
 28  from mlflow.store.model_registry import (
 29      SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
 30      SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
 31  )
 32  from mlflow.telemetry.events import LoadPromptEvent
 33  from mlflow.telemetry.track import record_usage_event
 34  from mlflow.tracing.constant import SpanAttributeKey
 35  from mlflow.tracing.fluent import get_active_trace_id, get_current_active_span
 36  from mlflow.tracing.trace_manager import InMemoryTraceManager
 37  from mlflow.tracing.utils.prompt import update_linked_prompts_tag
 38  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 39  from mlflow.tracking.client import MlflowClient
 40  from mlflow.tracking.fluent import _get_latest_active_run, get_active_model_id
 41  from mlflow.utils import get_results_from_paginated_fn, mlflow_tags
 42  from mlflow.utils.databricks_utils import (
 43      _construct_databricks_uc_registered_model_url,
 44      get_workspace_id,
 45      get_workspace_url,
 46      stage_model_for_databricks_model_serving,
 47  )
 48  from mlflow.utils.env_pack import (
 49      EnvPackConfig,
 50      EnvPackType,
 51      _validate_env_pack,
 52      pack_env_for_databricks_model_serving,
 53  )
 54  from mlflow.utils.logging_utils import eprint
 55  from mlflow.utils.uri import is_databricks_unity_catalog_uri
 56  
# Module-level logger for this file.
_logger = logging.getLogger(__name__)


# Deprecation-notice template for prompt-registry APIs that moved to the
# `mlflow.genai` namespace; formatted with the name of the wrapped function.
PROMPT_API_MIGRATION_MSG = (
    "The `mlflow.{func_name}` API is moved to the `mlflow.genai` namespace. Please use "
    "`mlflow.genai.{func_name}` instead. The original API will be removed in the "
    "future release."
)
 65  
 66  
 67  def register_model(
 68      model_uri,
 69      name,
 70      await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
 71      *,
 72      tags: dict[str, Any] | None = None,
 73      env_pack: EnvPackType | EnvPackConfig | None = None,
 74  ) -> ModelVersion:
 75      """Create a new model version in model registry for the model files specified by ``model_uri``.
 76  
 77      Note that this method assumes the model registry backend URI is the same as that of the
 78      tracking backend.
 79  
 80      Args:
 81          model_uri: URI referring to the MLmodel directory. Supported URI schemes include:
 82  
 83              - ``runs:/`` URIs (e.g., ``runs:/<run_id>/<artifact_path>``) to register a model
 84                from a specific run. The run ID is recorded with the model version.
 85              - ``models:/`` URIs, which support two forms:
 86  
 87                - ``models:/<model_name>/<version>`` to promote an existing registered
 88                  model version. The source run lineage is preserved when the
 89                  referenced model version has an associated source run.
 90                - ``models:/<model_id>`` to create a new registered model version from a logged
 91                  model (for example, one returned by ``log_model``). The source
 92                  run lineage is preserved.
 93  
 94              - Local filesystem paths for registering locally-persisted MLflow models that were
 95                previously saved using ``save_model``.
 96  
 97          name: Name of the registered model under which to create a new model version. If a
 98              registered model with the given name does not exist, it will be created
 99              automatically.
100          await_registration_for: Number of seconds to wait for the model version to finish
101              being created and is in ``READY`` status. By default, the function
102              waits for five minutes. Specify 0 or None to skip waiting.
103          tags: A dictionary of key-value pairs that are converted into
104              :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
105          env_pack: Either a string or an EnvPackConfig. If specified,
106              the model dependencies are optionally first installed into the current Python
107              environment, and then the complete environment will be packaged and included
108              in the registered model artifacts. If the string shortcut "databricks_model_serving" is
109              used, then model dependencies will be installed in the current environment. This is
110              useful when deploying the model to a serving environment like Databricks Model Serving.
111  
112              .. Note:: Experimental: This parameter may change or be removed in a future
113                                      release without warning.
114  
115      Returns:
116          Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
117          backend.
118  
119      .. code-block:: python
120          :test:
121          :caption: Example
122  
123          import mlflow.sklearn
124          from mlflow.models import infer_signature
125          from sklearn.datasets import make_regression
126          from sklearn.ensemble import RandomForestRegressor
127  
128          mlflow.set_tracking_uri("sqlite:////tmp/mlruns.db")
129          params = {"n_estimators": 3, "random_state": 42}
130          X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
131          # Log MLflow entities
132          with mlflow.start_run() as run:
133              rfr = RandomForestRegressor(**params).fit(X, y)
134              signature = infer_signature(X, rfr.predict(X))
135              mlflow.log_params(params)
136              mlflow.sklearn.log_model(rfr, name="sklearn-model", signature=signature)
137          model_uri = f"runs:/{run.info.run_id}/sklearn-model"
138          mv = mlflow.register_model(model_uri, "RandomForestRegressionModel")
139          print(f"Name: {mv.name}")
140          print(f"Version: {mv.version}")
141  
142      .. code-block:: text
143          :caption: Output
144  
145          Name: RandomForestRegressionModel
146          Version: 1
147      """
148      return _register_model(
149          model_uri=model_uri,
150          name=name,
151          await_registration_for=await_registration_for,
152          tags=tags,
153          env_pack=env_pack,
154      )
155  
156  
def _register_model(
    model_uri,
    name,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    *,
    tags: dict[str, Any] | None = None,
    local_model_path=None,
    env_pack: EnvPackType | EnvPackConfig | None = None,
) -> ModelVersion:
    """Shared implementation behind ``mlflow.register_model``.

    Ensures the registered model exists, resolves ``model_uri`` to a source URI
    (and, when possible, a run ID and/or logged-model ID), creates the model
    version, links the new version back onto the logged model via a tag, and
    optionally packs the environment / stages the model for Databricks Model
    Serving.

    Args:
        model_uri: ``runs:/``, ``models:/``, or local-path URI of the model.
        name: Registered model name (created automatically if absent).
        await_registration_for: Seconds to wait for the version to reach READY.
        tags: Optional tags for the new model version.
        local_model_path: Optional local copy of the model artifacts; forwarded
            to the client / env-packing helpers to avoid re-downloading.
        env_pack: Optional env-packing config (string shortcut or
            ``EnvPackConfig``); validated by ``_validate_env_pack``.

    Returns:
        The created :py:class:`mlflow.entities.model_registry.ModelVersion`.
    """
    client = MlflowClient()
    try:
        create_model_response = client.create_registered_model(name)
        eprint(f"Successfully registered model '{create_model_response.name}'.")
    except MlflowException as e:
        # An already-existing registered model is fine — we just add a new
        # version to it. Any other registry error is re-raised.
        if e.error_code in (
            ErrorCode.Name(RESOURCE_ALREADY_EXISTS),
            ErrorCode.Name(ALREADY_EXISTS),
        ):
            eprint(
                f"Registered model {name!r} already exists. Creating a new version of this model..."
            )
        else:
            raise e

    run_id = None
    model_id = None
    source = model_uri
    if RunsArtifactRepository.is_runs_uri(model_uri):
        # If the uri is of the form runs:/...
        (run_id, artifact_path) = RunsArtifactRepository.parse_runs_uri(model_uri)
        runs_artifact_repo = RunsArtifactRepository(model_uri)
        # List artifacts in `<run_artifact_root>/<artifact_path>` to see if the run has artifacts.
        # If so use the run's artifact location as source.
        artifacts = runs_artifact_repo._list_run_artifacts()
        if MLMODEL_FILE_NAME in (art.path for art in artifacts):
            source = RunsArtifactRepository.get_underlying_uri(model_uri)
        # Otherwise check if there's a logged model with
        # name artifact_path and source_run_id run_id
        else:
            run = client.get_run(run_id)
            logged_models = _get_logged_models_from_run(run, artifact_path)
            if not logged_models:
                raise MlflowException(
                    f"Unable to find a logged_model with artifact_path {artifact_path} "
                    f"under run {run_id}",
                    error_code=ErrorCode.Name(NOT_FOUND),
                )
            if len(logged_models) > 1:
                # Without run outputs we cannot rank the candidate models, so
                # require the caller to disambiguate with a models:/ URI.
                if run.outputs is None:
                    raise MlflowException.invalid_parameter_value(
                        f"Multiple logged models found for run {run_id}. Cannot determine "
                        "which model to register. Please use `models:/<model_id>` instead."
                    )
                # If there are multiple such logged models, get the one logged at the largest step
                model_id_to_step = {m_o.model_id: m_o.step for m_o in run.outputs.model_outputs}
                model_id = max(logged_models, key=lambda lm: model_id_to_step[lm.model_id]).model_id
            else:
                model_id = logged_models[0].model_id
            source = f"models:/{model_id}"
            _logger.warning(
                f"Run with id {run_id} has no artifacts at artifact path {artifact_path!r}, "
                f"registering model based on {source} instead"
            )

    # Otherwise if the uri is of the form models:/..., try to get the model_id from the uri directly
    model_id = _parse_model_id_if_present(model_uri) if not model_id else model_id

    # Passing in the string value is a shortcut for passing in the EnvPackConfig
    # Validate early; `_validate_env_pack` will raise on invalid inputs.
    validated_env_pack = _validate_env_pack(env_pack)

    # Helper to avoid parameter drift below.
    def _create_model_version(local_model_path: str | None) -> ModelVersion:
        return client._create_model_version(
            name=name,
            source=source,
            run_id=run_id,
            tags=tags,
            await_creation_for=await_registration_for,
            local_model_path=local_model_path,
            model_id=model_id,
        )

    # If env_pack is supported and indicates Databricks Model Serving,
    # pack env locally and directly register the resulting artifacts.
    # This avoids storing artifacts prior to the final registered model version.
    if validated_env_pack:
        eprint(
            "Packing environment for Databricks Model Serving with install_dependencies "
            f"{validated_env_pack.install_dependencies}..."
        )
        with pack_env_for_databricks_model_serving(
            model_uri,
            enforce_pip_requirements=validated_env_pack.install_dependencies,
            local_model_path=local_model_path,
        ) as artifacts_path_with_env:
            create_version_response = _create_model_version(artifacts_path_with_env)
    else:
        create_version_response = _create_model_version(local_model_path)
    created_message = (
        f"Created version '{create_version_response.version}' of model "
        f"'{create_version_response.name}'"
    )
    # Print a link to the UC model version page if the model is in UC.
    registry_uri = mlflow.get_registry_uri()
    if (
        MLFLOW_PRINT_MODEL_URLS_ON_CREATION.get()
        and is_databricks_unity_catalog_uri(registry_uri)
        and (url := get_workspace_url())
    ):
        uc_model_url = _construct_databricks_uc_registered_model_url(
            url,
            create_version_response.name,
            create_version_response.version,
            get_workspace_id(),
        )
        created_message = "🔗 " + created_message + f": {uc_model_url}"
    else:
        created_message += "."
    eprint(created_message)

    # Record the (name, version) link on the logged model via the
    # MLFLOW_MODEL_VERSIONS tag, appending to any existing links.
    if model_id:
        new_value = [
            {
                "name": create_version_response.name,
                "version": create_version_response.version,
            }
        ]
        try:
            model = client.get_logged_model(model_id)
            if existing_value := model.tags.get(mlflow_tags.MLFLOW_MODEL_VERSIONS):
                new_value = json.loads(existing_value) + new_value

            client.set_logged_model_tags(
                model_id,
                {mlflow_tags.MLFLOW_MODEL_VERSIONS: json.dumps(new_value)},
            )
        except MlflowException as e:
            # The logged model may live in a different workspace; in that case
            # skip the back-link rather than failing the registration.
            if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
                _logger.warning(
                    "Unable to update logged model tags for model ID '%s': the logged model "
                    "does not exist in the current workspace. No model version link will be "
                    "recorded on the logged model.",
                    model_id,
                )
            else:
                raise

    if validated_env_pack:
        eprint(
            f"Staging model {create_version_response.name} "
            f"version {create_version_response.version} "
            "for Databricks Model Serving..."
        )
        try:
            stage_model_for_databricks_model_serving(
                model_name=create_version_response.name,
                model_version=create_version_response.version,
            )
        except Exception as e:
            # Best-effort: registration already succeeded, so staging failures
            # are reported but not raised.
            eprint(
                f"Failed to stage model for Databricks Model Serving: {e!s}. "
                "The model was registered successfully and is available for serving, but may take "
                "longer to deploy."
            )

    return create_version_response
324  
325  
def _get_logged_models_from_run(source_run: Run, model_name: str) -> list[LoggedModel]:
    """Collect every logged model on the source run whose name is ``model_name``.

    Pages through the search results for the run's experiment and keeps only the
    models whose ``source_run_id`` matches the run.

    Args:
        source_run: Run whose logged models should be retrieved.
        model_name: Name of the logged model(s) to look up.
    """
    client = MlflowClient()
    run_id = source_run.info.run_id
    matches: list[LoggedModel] = []
    token = None

    while True:
        page = client.search_logged_models(
            experiment_ids=[source_run.info.experiment_id],
            # TODO: Filter by 'source_run_id' once Databricks backend supports it
            filter_string=f"name = '{model_name}'",
            page_token=token,
        )
        matches.extend(model for model in page if model.source_run_id == run_id)
        token = page.token
        if not token:
            return matches
352  
353  
def search_registered_models(
    max_results: int | None = None,
    filter_string: str | None = None,
    order_by: list[str] | None = None,
) -> list[RegisteredModel]:
    """Search the registry for registered models matching the filter criteria.

    Args:
        max_results: Maximum number of models desired. When omitted, all matching
            models are returned.
        filter_string: Filter query string (e.g.,
            ``"name = 'a_model_name' and tag.key = 'value1'"``); defaults to
            searching for all registered models. The following identifiers,
            comparators, and logical operators are supported.

            Identifiers
              - "name": registered model name.
              - "tags.<tag_key>": registered model tag. If "tag_key" contains
                spaces, it must be wrapped with backticks (e.g., "tags.`extra key`").

            Comparators
              - "=": Equal to.
              - "!=": Not equal to.
              - "LIKE": Case-sensitive pattern match.
              - "ILIKE": Case-insensitive pattern match.

            Logical operators
              - "AND": Combines two sub-queries and returns True if both of them
                are True.

        order_by: List of column names with ASC|DESC annotation, used to order the
            matching search results.

    Returns:
        A list of :py:class:`mlflow.entities.model_registry.RegisteredModel`
        objects that satisfy the search expressions.

    Example:

    .. code-block:: python

        import mlflow

        # Registered models whose name matches a prefix pattern
        for rm in mlflow.search_registered_models(filter_string="name LIKE 'Boston%'"):
            for mv in rm.latest_versions:
                print(f"name={mv.name}; run_id={mv.run_id}; version={mv.version}")

        # All registered models, ordered by ascending name
        results = mlflow.search_registered_models(order_by=["name ASC"])
    """

    def _fetch_page(page_size, token):
        # One backend call per page; overall pagination is handled by the
        # shared helper below.
        client = MlflowClient()
        return client.search_registered_models(
            max_results=page_size,
            filter_string=filter_string,
            order_by=order_by,
            page_token=token,
        )

    return get_results_from_paginated_fn(
        paginated_fn=_fetch_page,
        max_results_per_page=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        max_results=max_results,
    )
457  
458  
def search_model_versions(
    max_results: int | None = None,
    filter_string: str | None = None,
    order_by: list[str] | None = None,
) -> list[ModelVersion]:
    """Search for model versions that satisfy the filter criteria.

    .. warning::

        The model version search results may not have aliases populated for performance reasons.

    Args:
        max_results: If passed, specifies the maximum number of models desired. If not
            passed, all models will be returned.
        filter_string: Filter query string
            (e.g., ``"name = 'a_model_name' and tag.key = 'value1'"``),
            defaults to searching for all model versions. The following identifiers, comparators,
            and logical operators are supported.

            Identifiers
              - ``name``: model name.
              - ``source_path``: model version source path.
              - ``run_id``: The id of the mlflow run that generates the model version.
              - ``tags.<tag_key>``: model version tag. If ``tag_key`` contains spaces, it must be
                wrapped with backticks (e.g., ``"tags.`extra key`"``).

            Comparators
              - ``=``: Equal to.
              - ``!=``: Not equal to.
              - ``LIKE``: Case-sensitive pattern match.
              - ``ILIKE``: Case-insensitive pattern match.
              - ``IN``: In a value list. Only ``run_id`` identifier supports ``IN`` comparator.

            Logical operators
              - ``AND``: Combines two sub-queries and returns True if both of them are True.

        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results.

    Returns:
        A list of :py:class:`mlflow.entities.model_registry.ModelVersion` objects
            that satisfy the search expressions.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        from sklearn.linear_model import LogisticRegression

        for _ in range(2):
            with mlflow.start_run():
                mlflow.sklearn.log_model(
                    LogisticRegression(),
                    name="Cordoba",
                    registered_model_name="CordobaWeatherForecastModel",
                )

        # Get all versions of the model filtered by name
        filter_string = "name = 'CordobaWeatherForecastModel'"
        results = mlflow.search_model_versions(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            print(f"name={res.name}; run_id={res.run_id}; version={res.version}")

        # Get the version of the model filtered by run_id
        filter_string = "run_id = 'ae9a606a12834c04a8ef1006d0cff779'"
        results = mlflow.search_model_versions(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            print(f"name={res.name}; run_id={res.run_id}; version={res.version}")

    .. code-block:: text
        :caption: Output

        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=ae9a606a12834c04a8ef1006d0cff779; version=2
        name=CordobaWeatherForecastModel; run_id=d8f028b5fedf4faf8e458f7693dfa7ce; version=1
        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=ae9a606a12834c04a8ef1006d0cff779; version=2
    """

    def pagination_wrapper_func(number_to_get, next_page_token):
        # Fetch a single page of model versions from the registry backend.
        return MlflowClient().search_model_versions(
            max_results=number_to_get,
            filter_string=filter_string,
            order_by=order_by,
            page_token=next_page_token,
        )

    return get_results_from_paginated_fn(
        paginated_fn=pagination_wrapper_func,
        max_results_per_page=SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
        max_results=max_results,
    )
554  
555  
def set_model_version_tag(
    name: str,
    version: str | None = None,
    key: str | None = None,
    value: Any = None,
) -> None:
    """Set a tag on a registered model version.

    Args:
        name: Registered model name.
        version: Registered model version.
        key: Tag key to log; required.
        value: Tag value to log; required.
    """
    client = MlflowClient()
    return client.set_model_version_tag(name=name, version=version, key=key, value=value)
577  
578  
@require_prompt_registry
def register_prompt(
    name: str,
    template: str | list[dict[str, Any]],
    commit_message: str | None = None,
    tags: dict[str, str] | None = None,
    response_format: type[BaseModel] | dict[str, Any] | None = None,
    model_config: "PromptModelConfig | dict[str, Any] | None" = None,
) -> PromptVersion:
    """
    Register a new :py:class:`Prompt <mlflow.entities.Prompt>` in the MLflow Prompt Registry.

    A :py:class:`Prompt <mlflow.entities.Prompt>` is a pair of name and
    template content at minimum. With MLflow Prompt Registry, you can create, manage, and
    version control prompts with the MLflow's robust model tracking framework.

    If there is no registered prompt with the given name, a new prompt will be created.
    Otherwise, a new version of the existing prompt will be created.

    Args:
        name: The name of the prompt.
        template: The template content of the prompt. Can be either:
            - A string containing text with variables enclosed in double curly braces,
              e.g. {{variable}}, which will be replaced with actual values by the `format` method.
            - A list of dictionaries representing chat messages, where each message has
              'role' and 'content' keys (e.g., [{"role": "user", "content": "Hello {{name}}"}])

            .. note::

                If you want to use the prompt with a framework that uses single curly braces
                e.g. LangChain, you can use the `to_single_brace_format` method to convert the
                loaded prompt to a format that uses single curly braces.

                .. code-block:: python

                    prompt = client.load_prompt("my_prompt")
                    langchain_format = prompt.to_single_brace_format()

        commit_message: A message describing the changes made to the prompt, similar to a
            Git commit message. Optional.
        tags: A dictionary of tags associated with the **prompt version**.
            This is useful for storing version-specific information, such as the author of
            the changes. Optional.
        response_format: Optional Pydantic class or dictionary defining the expected response
            structure. This can be used to specify the schema for structured outputs from LLM calls.
        model_config: Optional PromptModelConfig instance or dictionary containing model-specific
            configuration. Using PromptModelConfig provides validation and type safety.

    Returns:
        A :py:class:`Prompt <mlflow.entities.Prompt>` object that was created.

    Example:

    .. code-block:: python

        import mlflow
        from pydantic import BaseModel

        # Register a text prompt
        mlflow.register_prompt(
            name="greeting_prompt",
            template="Respond to the user's message as a {{style}} AI.",
            response_format={"type": "string", "description": "A friendly response"},
        )

        # Register a chat prompt with multiple messages
        mlflow.register_prompt(
            name="assistant_prompt",
            template=[
                {"role": "system", "content": "You are a helpful {{style}} assistant."},
                {"role": "user", "content": "{{question}}"},
            ],
            response_format={"type": "object", "properties": {"answer": {"type": "string"}}},
        )

        # Load the prompt from the registry
        prompt = mlflow.load_prompt("greeting_prompt")

        # Use the prompt in your application
        import openai

        openai_client = openai.OpenAI()
        openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": prompt.format(style="friendly")},
                {"role": "user", "content": "Hello, how are you?"},
            ],
        )

        # Update the prompt with a new version
        prompt = mlflow.register_prompt(
            name="greeting_prompt",
            template="Respond to the user's message as a {{style}} AI. {{greeting}}",
            commit_message="Add a greeting to the prompt.",
            tags={"author": "Bob"},
        )
    """
    # This API has moved to the `mlflow.genai` namespace; steer callers there.
    # stacklevel=3 attributes the warning to the user's call site rather than
    # to this wrapper or the decorator.
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="register_prompt"),
        category=FutureWarning,
        stacklevel=3,
    )

    return MlflowClient().register_prompt(
        name=name,
        template=template,
        commit_message=commit_message,
        tags=tags,
        response_format=response_format,
        model_config=model_config,
    )
692  
693  
@require_prompt_registry
def search_prompts(
    filter_string: str | None = None,
    max_results: int | None = None,
) -> list[Prompt]:
    """
    Search for prompts in the MLflow Prompt Registry.

    This call returns prompt metadata for prompts that have been marked
    as prompts (i.e. tagged with `mlflow.prompt.is_prompt=true`). We can
    further restrict results via a standard registry filter expression.

    Args:
        filter_string (Optional[str]):
            An additional registry-search expression to apply (e.g.
            `"name LIKE 'my_prompt%'"`).  For Unity Catalog registries, must include
            catalog and schema: "catalog = 'catalog_name' AND schema = 'schema_name'".
        max_results (Optional[int]):
            The maximum number of prompts to return.

    Returns:
        A list of :py:class:`Prompt <mlflow.entities.Prompt>` objects representing prompt metadata:

        - name: The prompt name
        - description: The prompt description
        - tags: Prompt-level tags
        - creation_timestamp: When the prompt was created

        To get the actual prompt template content,
        use :py:func:`mlflow.genai.load_prompt()` API with a specific version:

        .. code-block:: python

            import mlflow

            # Search for prompts
            prompts = mlflow.genai.search_prompts(filter_string="name LIKE 'greeting%'")

            # Get prompts by experiment
            prompts = mlflow.genai.search_prompts(filter_string='experiment_id = "1"')

            # Get specific version content
            for prompt in prompts:
                prompt_version = mlflow.genai.load_prompt(prompt.name, version="1")
                print(f"Template: {prompt_version.template}")
    """
    # This API has moved to the `mlflow.genai` namespace; steer callers there.
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="search_prompts"),
        category=FutureWarning,
        stacklevel=3,
    )

    def pagination_wrapper_func(number_to_get, next_page_token):
        # Fetch a single page of prompts from the registry backend.
        return MlflowClient().search_prompts(
            filter_string=filter_string, max_results=number_to_get, page_token=next_page_token
        )

    return get_results_from_paginated_fn(
        pagination_wrapper_func,
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        max_results,
    )
755  
756  
@require_prompt_registry
@record_usage_event(LoadPromptEvent)
def load_prompt(
    name_or_uri: str,
    version: str | int | None = None,
    allow_missing: bool = False,
    link_to_model: bool = True,
    model_id: str | None = None,
    cache_ttl_seconds: float | None = None,
) -> PromptVersion | None:
    """
    Load a :py:class:`Prompt <mlflow.entities.Prompt>` from the MLflow Prompt Registry.

    The prompt can be specified by name and version, or by URI.

    Args:
        name_or_uri: The name of the prompt, or the URI in the format "prompts:/name/version".
        version: The version of the prompt (required when using name, not allowed when using URI).
        allow_missing: If True, return None instead of raising Exception if the specified prompt
            is not found.
        link_to_model: If True, the prompt will be linked to the model with the ID specified
                       by `model_id`, or the active model ID if `model_id` is None and
                       there is an active model.
        model_id: The ID of the model to which to link the prompt, if `link_to_model` is True.
        cache_ttl_seconds: Time-to-live in seconds for the cached prompt. If not specified,
            uses the value from `MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS` environment variable for
            alias-based prompts (default 60), and the value from
            `MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS` environment variable for version-based prompts
            (default None, no TTL). Set to 0 to bypass the cache and always fetch from the server.

    Returns:
        The loaded prompt version, or None when ``allow_missing=True`` and the specified
        prompt does not exist.

    Example:

    .. code-block:: python

        import mlflow

        # Load a specific version of the prompt
        prompt = mlflow.load_prompt("my_prompt", version=1)

        # Load a specific version of the prompt by URI
        prompt = mlflow.load_prompt("prompts:/my_prompt/1")

        # Load a prompt version with an alias "production"
        prompt = mlflow.load_prompt("prompts:/my_prompt@production")

        # Load with custom cache TTL (5 minutes)
        prompt = mlflow.load_prompt("my_prompt", version=1, cache_ttl_seconds=300)

        # Bypass cache entirely
        prompt = mlflow.load_prompt("my_prompt", version=1, cache_ttl_seconds=0)

    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="load_prompt"),
        category=FutureWarning,
        stacklevel=3,
    )

    client = MlflowClient()

    # Load prompt with caching (handled by client)
    prompt = client.load_prompt(
        name_or_uri=name_or_uri,
        version=version,
        allow_missing=allow_missing,
        cache_ttl_seconds=cache_ttl_seconds,
    )
    if prompt is None:
        # Only reachable when allow_missing=True and the prompt was not found;
        # implicitly returns None, skipping all linking side effects below.
        return

    # If there is an active MLflow run, associate the prompt with the run.
    # Note that we do this synchronously because it's unlikely that run linking occurs
    # in a latency sensitive environment, since runs aren't typically used in real-time /
    # production scenarios
    # NB: We shouldn't use `active_run()` here because it only returns the active run
    # from the current thread. It doesn't work in multi-threaded scenarios such as
    # MLflow GenAI evaluation.
    if run := _get_latest_active_run():
        client.link_prompt_version_to_run(run.info.run_id, prompt)

    if link_to_model:
        # Fall back to the active model ID when the caller didn't provide one.
        model_id = model_id or get_active_model_id()
        if model_id is not None:
            # Run linking in background thread to avoid blocking prompt loading. Prompt linking
            # is not critical for the user's workflow (if the prompt fails to link, the user's
            # workflow is minorly affected), so we handle it asynchronously and gracefully
            # handle any failures without impacting the core prompt loading functionality.

            def _link_prompt_async() -> None:
                try:
                    client.link_prompt_version_to_model(
                        name=prompt.name,
                        version=prompt.version,
                        model_id=model_id,
                    )
                except Exception:
                    # NB: We should still load the prompt even if linking fails, since the prompt
                    # is critical to the caller's application logic
                    _logger.warning(
                        f"Failed to link prompt '{prompt.name}' version '{prompt.version}'"
                        f" to model '{model_id}'.",
                        exc_info=True,
                    )

            # Start linking in background - don't wait for completion
            # NOTE(review): thread is created without daemon=True, so a slow or hung
            # linking call can delay interpreter shutdown — confirm this is intended.
            link_thread = threading.Thread(
                target=_link_prompt_async, name=f"link_prompt_thread-{uuid.uuid4().hex[:8]}"
            )
            link_thread.start()

    # If a trace is currently active, register the prompt with the in-memory
    # trace manager under that trace ID.
    if trace_id := get_active_trace_id():
        InMemoryTraceManager.get_instance().register_prompt(
            trace_id=trace_id,
            prompt=prompt,
        )

    # Set prompt version information as span attributes if there's an active span
    if span := get_current_active_span():
        current_value = span.attributes.get(SpanAttributeKey.LINKED_PROMPTS)
        updated_value = update_linked_prompts_tag(current_value, [prompt])
        span.set_attribute(SpanAttributeKey.LINKED_PROMPTS, updated_value)

    return prompt
880  
881  
@require_prompt_registry
def set_prompt_alias(name: str, alias: str, version: int) -> None:
    """
    Assign an alias to a specific version of a
    :py:class:`Prompt <mlflow.entities.Prompt>` in the MLflow Prompt Registry.

    Args:
        name: The name of the prompt.
        alias: The alias to set for the prompt.
        version: The version of the prompt.

    Example:

    .. code-block:: python

        import mlflow

        # Point the "production" alias at version 1 of the prompt
        mlflow.set_prompt_alias(name="my_prompt", version=1, alias="production")

        # Load the prompt through its alias (use "@" to reference an alias)
        prompt = mlflow.load_prompt("prompts:/my_prompt@production")

        # Move the alias to a newer version of the prompt
        mlflow.set_prompt_alias(name="my_prompt", version=2, alias="production")

        # Remove the alias entirely
        mlflow.delete_prompt_alias(name="my_prompt", alias="production")
    """
    # Emit the standard migration notice for this deprecated fluent API.
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="set_prompt_alias"),
        category=FutureWarning,
        stacklevel=3,
    )

    # Delegate the actual alias assignment to the registry client.
    registry_client = MlflowClient()
    registry_client.set_prompt_alias(name=name, alias=alias, version=version)
917  
918  
@require_prompt_registry
def delete_prompt_alias(name: str, alias: str) -> None:
    """
    Remove an alias from a :py:class:`Prompt <mlflow.entities.Prompt>` in the
    MLflow Prompt Registry.

    Args:
        name: The name of the prompt.
        alias: The alias to delete for the prompt.
    """
    # Emit the standard migration notice for this deprecated fluent API.
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="delete_prompt_alias"),
        category=FutureWarning,
        stacklevel=3,
    )

    # Delegate alias removal to the registry client.
    registry_client = MlflowClient()
    registry_client.delete_prompt_alias(name=name, alias=alias)