fluent.py
1 import json 2 import logging 3 import threading 4 import uuid 5 import warnings 6 from typing import Any 7 8 from pydantic import BaseModel 9 10 import mlflow 11 from mlflow.entities.logged_model import LoggedModel 12 from mlflow.entities.model_registry import ModelVersion, Prompt, PromptVersion, RegisteredModel 13 from mlflow.entities.model_registry.prompt_version import PromptModelConfig 14 from mlflow.entities.run import Run 15 from mlflow.environment_variables import MLFLOW_PRINT_MODEL_URLS_ON_CREATION 16 from mlflow.exceptions import MlflowException 17 from mlflow.models.model import MLMODEL_FILE_NAME 18 from mlflow.prompt.registry_utils import require_prompt_registry 19 from mlflow.protos.databricks_pb2 import ( 20 ALREADY_EXISTS, 21 NOT_FOUND, 22 RESOURCE_ALREADY_EXISTS, 23 RESOURCE_DOES_NOT_EXIST, 24 ErrorCode, 25 ) 26 from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository 27 from mlflow.store.artifact.utils.models import _parse_model_id_if_present 28 from mlflow.store.model_registry import ( 29 SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT, 30 SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT, 31 ) 32 from mlflow.telemetry.events import LoadPromptEvent 33 from mlflow.telemetry.track import record_usage_event 34 from mlflow.tracing.constant import SpanAttributeKey 35 from mlflow.tracing.fluent import get_active_trace_id, get_current_active_span 36 from mlflow.tracing.trace_manager import InMemoryTraceManager 37 from mlflow.tracing.utils.prompt import update_linked_prompts_tag 38 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS 39 from mlflow.tracking.client import MlflowClient 40 from mlflow.tracking.fluent import _get_latest_active_run, get_active_model_id 41 from mlflow.utils import get_results_from_paginated_fn, mlflow_tags 42 from mlflow.utils.databricks_utils import ( 43 _construct_databricks_uc_registered_model_url, 44 get_workspace_id, 45 get_workspace_url, 46 stage_model_for_databricks_model_serving, 47 ) 48 
from mlflow.utils.env_pack import (
    EnvPackConfig,
    EnvPackType,
    _validate_env_pack,
    pack_env_for_databricks_model_serving,
)
from mlflow.utils.logging_utils import eprint
from mlflow.utils.uri import is_databricks_unity_catalog_uri

_logger = logging.getLogger(__name__)


PROMPT_API_MIGRATION_MSG = (
    "The `mlflow.{func_name}` API is moved to the `mlflow.genai` namespace. Please use "
    "`mlflow.genai.{func_name}` instead. The original API will be removed in the "
    "future release."
)


def register_model(
    model_uri,
    name,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    *,
    tags: dict[str, Any] | None = None,
    env_pack: EnvPackType | EnvPackConfig | None = None,
) -> ModelVersion:
    """Create a new model version in model registry for the model files specified by ``model_uri``.

    Note that this method assumes the model registry backend URI is the same as that of the
    tracking backend.

    Args:
        model_uri: URI referring to the MLmodel directory. Supported URI schemes include:

            - ``runs:/`` URIs (e.g., ``runs:/<run_id>/<artifact_path>``) to register a model
              from a specific run. The run ID is recorded with the model version.
            - ``models:/`` URIs, which support two forms:

              - ``models:/<model_name>/<version>`` to promote an existing registered
                model version. The source run lineage is preserved when the
                referenced model version has an associated source run.
              - ``models:/<model_id>`` to create a new registered model version from a logged
                model (for example, one returned by ``log_model``). The source
                run lineage is preserved.

            - Local filesystem paths for registering locally-persisted MLflow models that were
              previously saved using ``save_model``.

        name: Name of the registered model under which to create a new model version. If a
            registered model with the given name does not exist, it will be created
            automatically.
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        tags: A dictionary of key-value pairs that are converted into
            :py:class:`mlflow.entities.model_registry.ModelVersionTag` objects.
        env_pack: Either a string or an EnvPackConfig. If specified,
            the model dependencies are optionally first installed into the current Python
            environment, and then the complete environment will be packaged and included
            in the registered model artifacts. If the string shortcut "databricks_model_serving" is
            used, then model dependencies will be installed in the current environment. This is
            useful when deploying the model to a serving environment like Databricks Model Serving.

            .. Note:: Experimental: This parameter may change or be removed in a future
                release without warning.

    Returns:
        Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
        backend.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow.sklearn
        from mlflow.models import infer_signature
        from sklearn.datasets import make_regression
        from sklearn.ensemble import RandomForestRegressor

        mlflow.set_tracking_uri("sqlite:////tmp/mlruns.db")
        params = {"n_estimators": 3, "random_state": 42}
        X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
        # Log MLflow entities
        with mlflow.start_run() as run:
            rfr = RandomForestRegressor(**params).fit(X, y)
            signature = infer_signature(X, rfr.predict(X))
            mlflow.log_params(params)
            mlflow.sklearn.log_model(rfr, name="sklearn-model", signature=signature)
        model_uri = f"runs:/{run.info.run_id}/sklearn-model"
        mv = mlflow.register_model(model_uri, "RandomForestRegressionModel")
        print(f"Name: {mv.name}")
        print(f"Version: {mv.version}")

    .. code-block:: text
        :caption: Output

        Name: RandomForestRegressionModel
        Version: 1
    """
    return _register_model(
        model_uri=model_uri,
        name=name,
        await_registration_for=await_registration_for,
        tags=tags,
        env_pack=env_pack,
    )


def _register_model(
    model_uri,
    name,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    *,
    tags: dict[str, Any] | None = None,
    local_model_path=None,
    env_pack: EnvPackType | EnvPackConfig | None = None,
) -> ModelVersion:
    """Backend implementation of :py:func:`register_model`.

    Creates the registered model if needed, resolves ``model_uri`` to a concrete
    source (and possibly a ``model_id``), optionally packs the environment for
    Databricks Model Serving, creates the model version, and links the new
    version back to the logged model's tags when applicable.
    """
    client = MlflowClient()
    try:
        create_model_response = client.create_registered_model(name)
        eprint(f"Successfully registered model '{create_model_response.name}'.")
    except MlflowException as e:
        # An existing registered model is fine: we simply add a new version to it.
        if e.error_code in (
            ErrorCode.Name(RESOURCE_ALREADY_EXISTS),
            ErrorCode.Name(ALREADY_EXISTS),
        ):
            eprint(
                f"Registered model {name!r} already exists. Creating a new version of this model..."
            )
        else:
            raise e

    run_id = None
    model_id = None
    source = model_uri
    if RunsArtifactRepository.is_runs_uri(model_uri):
        # If the uri is of the form runs:/...
        (run_id, artifact_path) = RunsArtifactRepository.parse_runs_uri(model_uri)
        runs_artifact_repo = RunsArtifactRepository(model_uri)
        # List artifacts in `<run_artifact_root>/<artifact_path>` to see if the run has artifacts.
        # If so use the run's artifact location as source.
        artifacts = runs_artifact_repo._list_run_artifacts()
        if MLMODEL_FILE_NAME in (art.path for art in artifacts):
            source = RunsArtifactRepository.get_underlying_uri(model_uri)
        # Otherwise check if there's a logged model with
        # name artifact_path and source_run_id run_id
        else:
            run = client.get_run(run_id)
            logged_models = _get_logged_models_from_run(run, artifact_path)
            if not logged_models:
                raise MlflowException(
                    f"Unable to find a logged_model with artifact_path {artifact_path} "
                    f"under run {run_id}",
                    error_code=ErrorCode.Name(NOT_FOUND),
                )
            if len(logged_models) > 1:
                if run.outputs is None:
                    raise MlflowException.invalid_parameter_value(
                        f"Multiple logged models found for run {run_id}. Cannot determine "
                        "which model to register. Please use `models:/<model_id>` instead."
                    )
                # If there are multiple such logged models, get the one logged at the largest step
                model_id_to_step = {m_o.model_id: m_o.step for m_o in run.outputs.model_outputs}
                model_id = max(logged_models, key=lambda lm: model_id_to_step[lm.model_id]).model_id
            else:
                model_id = logged_models[0].model_id
            source = f"models:/{model_id}"
            _logger.warning(
                f"Run with id {run_id} has no artifacts at artifact path {artifact_path!r}, "
                f"registering model based on {source} instead"
            )

    # Otherwise if the uri is of the form models:/..., try to get the model_id from the uri directly
    model_id = _parse_model_id_if_present(model_uri) if not model_id else model_id

    # Passing in the string value is a shortcut for passing in the EnvPackConfig
    # Validate early; `_validate_env_pack` will raise on invalid inputs.
    validated_env_pack = _validate_env_pack(env_pack)

    # Helper to avoid parameter drift below.
    def _create_model_version(local_model_path: str | None) -> ModelVersion:
        return client._create_model_version(
            name=name,
            source=source,
            run_id=run_id,
            tags=tags,
            await_creation_for=await_registration_for,
            local_model_path=local_model_path,
            model_id=model_id,
        )

    # If env_pack is supported and indicates Databricks Model Serving,
    # pack env locally and directly register the resulting artifacts.
    # This avoids storing artifacts prior to the final registered model version.
    if validated_env_pack:
        eprint(
            "Packing environment for Databricks Model Serving with install_dependencies "
            f"{validated_env_pack.install_dependencies}..."
        )
        with pack_env_for_databricks_model_serving(
            model_uri,
            enforce_pip_requirements=validated_env_pack.install_dependencies,
            local_model_path=local_model_path,
        ) as artifacts_path_with_env:
            create_version_response = _create_model_version(artifacts_path_with_env)
    else:
        create_version_response = _create_model_version(local_model_path)
    created_message = (
        f"Created version '{create_version_response.version}' of model "
        f"'{create_version_response.name}'"
    )
    # Print a link to the UC model version page if the model is in UC.
    registry_uri = mlflow.get_registry_uri()
    if (
        MLFLOW_PRINT_MODEL_URLS_ON_CREATION.get()
        and is_databricks_unity_catalog_uri(registry_uri)
        and (url := get_workspace_url())
    ):
        uc_model_url = _construct_databricks_uc_registered_model_url(
            url,
            create_version_response.name,
            create_version_response.version,
            get_workspace_id(),
        )
        created_message = "🔗 " + created_message + f": {uc_model_url}"
    else:
        created_message += "."
    eprint(created_message)

    if model_id:
        # Record the new registered model version on the logged model's
        # `mlflow.modelVersions` tag, appending to any versions already linked.
        new_value = [
            {
                "name": create_version_response.name,
                "version": create_version_response.version,
            }
        ]
        try:
            model = client.get_logged_model(model_id)
            if existing_value := model.tags.get(mlflow_tags.MLFLOW_MODEL_VERSIONS):
                new_value = json.loads(existing_value) + new_value

            client.set_logged_model_tags(
                model_id,
                {mlflow_tags.MLFLOW_MODEL_VERSIONS: json.dumps(new_value)},
            )
        except MlflowException as e:
            if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
                _logger.warning(
                    "Unable to update logged model tags for model ID '%s': the logged model "
                    "does not exist in the current workspace. No model version link will be "
                    "recorded on the logged model.",
                    model_id,
                )
            else:
                raise

    if validated_env_pack:
        eprint(
            f"Staging model {create_version_response.name} "
            f"version {create_version_response.version} "
            "for Databricks Model Serving..."
        )
        try:
            stage_model_for_databricks_model_serving(
                model_name=create_version_response.name,
                model_version=create_version_response.version,
            )
        except Exception as e:
            # Best-effort: staging failure should not fail registration.
            eprint(
                f"Failed to stage model for Databricks Model Serving: {e!s}. "
                "The model was registered successfully and is available for serving, but may take "
                "longer to deploy."
            )

    return create_version_response


def _get_logged_models_from_run(source_run: Run, model_name: str) -> list[LoggedModel]:
    """Get all logged models from the source run that have the specified model name.

    Args:
        source_run: Source run from which to retrieve logged models.
        model_name: Name of the model to retrieve.
    """
    client = MlflowClient()
    logged_models = []
    page_token = None

    while True:
        logged_models_page = client.search_logged_models(
            experiment_ids=[source_run.info.experiment_id],
            # TODO: Filter by 'source_run_id' once Databricks backend supports it
            filter_string=f"name = '{model_name}'",
            page_token=page_token,
        )
        logged_models.extend(
            m for m in logged_models_page if m.source_run_id == source_run.info.run_id
        )
        if not logged_models_page.token:
            break
        page_token = logged_models_page.token

    return logged_models


def search_registered_models(
    max_results: int | None = None,
    filter_string: str | None = None,
    order_by: list[str] | None = None,
) -> list[RegisteredModel]:
    """Search for registered models that satisfy the filter criteria.

    Args:
        max_results: If passed, specifies the maximum number of models desired. If not
            passed, all models will be returned.
        filter_string: Filter query string (e.g., "name = 'a_model_name' and tag.key = 'value1'"),
            defaults to searching for all registered models. The following identifiers, comparators,
            and logical operators are supported.

            Identifiers
              - "name": registered model name.
              - "tags.<tag_key>": registered model tag. If "tag_key" contains spaces, it must be
                wrapped with backticks (e.g., "tags.`extra key`").

            Comparators
              - "=": Equal to.
              - "!=": Not equal to.
              - "LIKE": Case-sensitive pattern match.
              - "ILIKE": Case-insensitive pattern match.

            Logical operators
              - "AND": Combines two sub-queries and returns True if both of them are True.

        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results.

    Returns:
        A list of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
        that satisfy the search expressions.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        from sklearn.linear_model import LogisticRegression

        with mlflow.start_run():
            mlflow.sklearn.log_model(
                LogisticRegression(),
                name="Cordoba",
                registered_model_name="CordobaWeatherForecastModel",
            )
            mlflow.sklearn.log_model(
                LogisticRegression(),
                name="Boston",
                registered_model_name="BostonWeatherForecastModel",
            )

        # Get search results filtered by the registered model name
        filter_string = "name = 'CordobaWeatherForecastModel'"
        results = mlflow.search_registered_models(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print(f"name={mv.name}; run_id={mv.run_id}; version={mv.version}")

        # Get search results filtered by the registered model name that matches
        # prefix pattern
        filter_string = "name LIKE 'Boston%'"
        results = mlflow.search_registered_models(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print(f"name={mv.name}; run_id={mv.run_id}; version={mv.version}")

        # Get all registered models and order them by ascending order of the names
        results = mlflow.search_registered_models(order_by=["name ASC"])
        print("-" * 80)
        for res in results:
            for mv in res.latest_versions:
                print(f"name={mv.name}; run_id={mv.run_id}; version={mv.version}")

    .. code-block:: text
        :caption: Output

        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        --------------------------------------------------------------------------------
        name=BostonWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        --------------------------------------------------------------------------------
        name=BostonWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
        name=CordobaWeatherForecastModel; run_id=248c66a666744b4887bdeb2f9cf7f1c6; version=1
    """

    def pagination_wrapper_func(number_to_get, next_page_token):
        return MlflowClient().search_registered_models(
            max_results=number_to_get,
            filter_string=filter_string,
            order_by=order_by,
            page_token=next_page_token,
        )

    return get_results_from_paginated_fn(
        pagination_wrapper_func,
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        max_results,
    )


def search_model_versions(
    max_results: int | None = None,
    filter_string: str | None = None,
    order_by: list[str] | None = None,
) -> list[ModelVersion]:
    """Search for model versions that satisfy the filter criteria.

    .. warning::

        The model version search results may not have aliases populated for performance reasons.

    Args:
        max_results: If passed, specifies the maximum number of models desired. If not
            passed, all models will be returned.
        filter_string: Filter query string
            (e.g., ``"name = 'a_model_name' and tag.key = 'value1'"``),
            defaults to searching for all model versions. The following identifiers, comparators,
            and logical operators are supported.

            Identifiers
              - ``name``: model name.
              - ``source_path``: model version source path.
              - ``run_id``: The id of the mlflow run that generates the model version.
              - ``tags.<tag_key>``: model version tag. If ``tag_key`` contains spaces, it must be
                wrapped with backticks (e.g., ``"tags.`extra key`"``).

            Comparators
              - ``=``: Equal to.
              - ``!=``: Not equal to.
              - ``LIKE``: Case-sensitive pattern match.
              - ``ILIKE``: Case-insensitive pattern match.
              - ``IN``: In a value list. Only ``run_id`` identifier supports ``IN`` comparator.

            Logical operators
              - ``AND``: Combines two sub-queries and returns True if both of them are True.

        order_by: List of column names with ASC|DESC annotation, to be used for ordering
            matching search results.

    Returns:
        A list of :py:class:`mlflow.entities.model_registry.ModelVersion` objects
        that satisfy the search expressions.

    .. code-block:: python
        :test:
        :caption: Example

        import mlflow
        from sklearn.linear_model import LogisticRegression

        for _ in range(2):
            with mlflow.start_run():
                mlflow.sklearn.log_model(
                    LogisticRegression(),
                    name="Cordoba",
                    registered_model_name="CordobaWeatherForecastModel",
                )

        # Get all versions of the model filtered by name
        filter_string = "name = 'CordobaWeatherForecastModel'"
        results = mlflow.search_model_versions(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            print(f"name={res.name}; run_id={res.run_id}; version={res.version}")

        # Get the version of the model filtered by run_id
        filter_string = "run_id = 'ae9a606a12834c04a8ef1006d0cff779'"
        results = mlflow.search_model_versions(filter_string=filter_string)
        print("-" * 80)
        for res in results:
            print(f"name={res.name}; run_id={res.run_id}; version={res.version}")

    .. code-block:: text
        :caption: Output

        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=ae9a606a12834c04a8ef1006d0cff779; version=2
        name=CordobaWeatherForecastModel; run_id=d8f028b5fedf4faf8e458f7693dfa7ce; version=1
        --------------------------------------------------------------------------------
        name=CordobaWeatherForecastModel; run_id=ae9a606a12834c04a8ef1006d0cff779; version=2
    """

    def pagination_wrapper_func(number_to_get, next_page_token):
        return MlflowClient().search_model_versions(
            max_results=number_to_get,
            filter_string=filter_string,
            order_by=order_by,
            page_token=next_page_token,
        )

    return get_results_from_paginated_fn(
        paginated_fn=pagination_wrapper_func,
        max_results_per_page=SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
        max_results=max_results,
    )


def set_model_version_tag(
    name: str,
    version: str | None = None,
    key: str | None = None,
    value: Any = None,
) -> None:
    """
    Set a tag for the model version.

    Args:
        name: Registered model name.
        version: Registered model version.
        key: Tag key to log. key is required.
        value: Tag value to log. value is required.
    """
    return MlflowClient().set_model_version_tag(
        name=name,
        version=version,
        key=key,
        value=value,
    )


@require_prompt_registry
def register_prompt(
    name: str,
    template: str | list[dict[str, Any]],
    commit_message: str | None = None,
    tags: dict[str, str] | None = None,
    response_format: type[BaseModel] | dict[str, Any] | None = None,
    model_config: "PromptModelConfig | dict[str, Any] | None" = None,
) -> PromptVersion:
    """
    Register a new :py:class:`Prompt <mlflow.entities.Prompt>` in the MLflow Prompt Registry.

    A :py:class:`Prompt <mlflow.entities.Prompt>` is a pair of name and
    template content at minimum. With MLflow Prompt Registry, you can create, manage, and
    version control prompts with the MLflow's robust model tracking framework.

    If there is no registered prompt with the given name, a new prompt will be created.
    Otherwise, a new version of the existing prompt will be created.


    Args:
        name: The name of the prompt.
        template: The template content of the prompt. Can be either:
            - A string containing text with variables enclosed in double curly braces,
              e.g. {{variable}}, which will be replaced with actual values by the `format` method.
            - A list of dictionaries representing chat messages, where each message has
              'role' and 'content' keys (e.g., [{"role": "user", "content": "Hello {{name}}"}])

            .. note::

                If you want to use the prompt with a framework that uses single curly braces
                e.g. LangChain, you can use the `to_single_brace_format` method to convert the
                loaded prompt to a format that uses single curly braces.

                .. code-block:: python

                    prompt = client.load_prompt("my_prompt")
                    langchain_format = prompt.to_single_brace_format()

        commit_message: A message describing the changes made to the prompt, similar to a
            Git commit message. Optional.
        tags: A dictionary of tags associated with the **prompt version**.
            This is useful for storing version-specific information, such as the author of
            the changes. Optional.
        response_format: Optional Pydantic class or dictionary defining the expected response
            structure. This can be used to specify the schema for structured outputs from LLM calls.
        model_config: Optional PromptModelConfig instance or dictionary containing model-specific
            configuration. Using PromptModelConfig provides validation and type safety.

    Returns:
        A :py:class:`Prompt <mlflow.entities.Prompt>` object that was created.

    Example:

    .. code-block:: python

        import mlflow
        from pydantic import BaseModel

        # Register a text prompt
        mlflow.register_prompt(
            name="greeting_prompt",
            template="Respond to the user's message as a {{style}} AI.",
            response_format={"type": "string", "description": "A friendly response"},
        )

        # Register a chat prompt with multiple messages
        mlflow.register_prompt(
            name="assistant_prompt",
            template=[
                {"role": "system", "content": "You are a helpful {{style}} assistant."},
                {"role": "user", "content": "{{question}}"},
            ],
            response_format={"type": "object", "properties": {"answer": {"type": "string"}}},
        )

        # Load the prompt from the registry
        prompt = mlflow.load_prompt("greeting_prompt")

        # Use the prompt in your application
        import openai

        openai_client = openai.OpenAI()
        openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": prompt.format(style="friendly")},
                {"role": "user", "content": "Hello, how are you?"},
            ],
        )

        # Update the prompt with a new version
        prompt = mlflow.register_prompt(
            name="greeting_prompt",
            template="Respond to the user's message as a {{style}} AI. {{greeting}}",
            commit_message="Add a greeting to the prompt.",
            tags={"author": "Bob"},
        )
    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="register_prompt"),
        category=FutureWarning,
        stacklevel=3,
    )

    return MlflowClient().register_prompt(
        name=name,
        template=template,
        commit_message=commit_message,
        tags=tags,
        response_format=response_format,
        model_config=model_config,
    )


@require_prompt_registry
def search_prompts(
    filter_string: str | None = None,
    max_results: int | None = None,
) -> list[Prompt]:
    """
    Search for prompts in the MLflow Prompt Registry.

    This call returns prompt metadata for prompts that have been marked
    as prompts (i.e. tagged with `mlflow.prompt.is_prompt=true`). We can
    further restrict results via a standard registry filter expression.

    Args:
        filter_string (Optional[str]):
            An additional registry-search expression to apply (e.g.
            `"name LIKE 'my_prompt%'"`). For Unity Catalog registries, must include
            catalog and schema: "catalog = 'catalog_name' AND schema = 'schema_name'".
        max_results (Optional[int]):
            The maximum number of prompts to return.

    Returns:
        A list of :py:class:`Prompt <mlflow.entities.Prompt>` objects representing prompt metadata:

        - name: The prompt name
        - description: The prompt description
        - tags: Prompt-level tags
        - creation_timestamp: When the prompt was created

        To get the actual prompt template content,
        use :py:func:`mlflow.genai.load_prompt()` API with a specific version:

    .. code-block:: python

        import mlflow

        # Search for prompts
        prompts = mlflow.genai.search_prompts(filter_string="name LIKE 'greeting%'")

        # Get prompts by experiment
        prompts = mlflow.genai.search_prompts(filter_string='experiment_id = "1"')

        # Get specific version content
        for prompt in prompts:
            prompt_version = mlflow.genai.load_prompt(prompt.name, version="1")
            print(f"Template: {prompt_version.template}")
    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="search_prompts"),
        category=FutureWarning,
        stacklevel=3,
    )

    def pagination_wrapper_func(number_to_get, next_page_token):
        return MlflowClient().search_prompts(
            filter_string=filter_string, max_results=number_to_get, page_token=next_page_token
        )

    return get_results_from_paginated_fn(
        pagination_wrapper_func,
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        max_results,
    )


@require_prompt_registry
@record_usage_event(LoadPromptEvent)
def load_prompt(
    name_or_uri: str,
    version: str | int | None = None,
    allow_missing: bool = False,
    link_to_model: bool = True,
    model_id: str | None = None,
    cache_ttl_seconds: float | None = None,
) -> PromptVersion:
    """
    Load a :py:class:`Prompt <mlflow.entities.Prompt>` from the MLflow Prompt Registry.

    The prompt can be specified by name and version, or by URI.

    Args:
        name_or_uri: The name of the prompt, or the URI in the format "prompts:/name/version".
        version: The version of the prompt (required when using name, not allowed when using URI).
        allow_missing: If True, return None instead of raising Exception if the specified prompt
            is not found.
        link_to_model: If True, the prompt will be linked to the model with the ID specified
            by `model_id`, or the active model ID if `model_id` is None and
            there is an active model.
        model_id: The ID of the model to which to link the prompt, if `link_to_model` is True.
        cache_ttl_seconds: Time-to-live in seconds for the cached prompt. If not specified,
            uses the value from `MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS` environment variable for
            alias-based prompts (default 60), and the value from
            `MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS` environment variable for version-based prompts
            (default None, no TTL). Set to 0 to bypass the cache and always fetch from the server.

    Example:

    .. code-block:: python

        import mlflow

        # Load a specific version of the prompt
        prompt = mlflow.load_prompt("my_prompt", version=1)

        # Load a specific version of the prompt by URI
        prompt = mlflow.load_prompt("prompts:/my_prompt/1")

        # Load a prompt version with an alias "production"
        prompt = mlflow.load_prompt("prompts:/my_prompt@production")

        # Load with custom cache TTL (5 minutes)
        prompt = mlflow.load_prompt("my_prompt", version=1, cache_ttl_seconds=300)

        # Bypass cache entirely
        prompt = mlflow.load_prompt("my_prompt", version=1, cache_ttl_seconds=0)

    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="load_prompt"),
        category=FutureWarning,
        stacklevel=3,
    )

    client = MlflowClient()

    # Load prompt with caching (handled by client)
    prompt = client.load_prompt(
        name_or_uri=name_or_uri,
        version=version,
        allow_missing=allow_missing,
        cache_ttl_seconds=cache_ttl_seconds,
    )
    if prompt is None:
        return

    # If there is an active MLflow run, associate the prompt with the run.
    # Note that we do this synchronously because it's unlikely that run linking occurs
    # in a latency sensitive environment, since runs aren't typically used in real-time /
    # production scenarios
    # NB: We shouldn't use `active_run()` here because it only returns the active run
    # from the current thread. It doesn't work in multi-threaded scenarios such as
    # MLflow GenAI evaluation.
    if run := _get_latest_active_run():
        client.link_prompt_version_to_run(run.info.run_id, prompt)

    if link_to_model:
        model_id = model_id or get_active_model_id()
        if model_id is not None:
            # Run linking in background thread to avoid blocking prompt loading. Prompt linking
            # is not critical for the user's workflow (if the prompt fails to link, the user's
            # workflow is minorly affected), so we handle it asynchronously and gracefully
            # handle any failures without impacting the core prompt loading functionality.

            def _link_prompt_async():
                try:
                    client.link_prompt_version_to_model(
                        name=prompt.name,
                        version=prompt.version,
                        model_id=model_id,
                    )
                except Exception:
                    # NB: We should still load the prompt even if linking fails, since the prompt
                    # is critical to the caller's application logic
                    _logger.warning(
                        f"Failed to link prompt '{prompt.name}' version '{prompt.version}'"
                        f" to model '{model_id}'.",
                        exc_info=True,
                    )

            # Start linking in background - don't wait for completion
            link_thread = threading.Thread(
                target=_link_prompt_async, name=f"link_prompt_thread-{uuid.uuid4().hex[:8]}"
            )
            link_thread.start()

    if trace_id := get_active_trace_id():
        InMemoryTraceManager.get_instance().register_prompt(
            trace_id=trace_id,
            prompt=prompt,
        )

    # Set prompt version information as span attributes if there's an active span
    if span := get_current_active_span():
        current_value = span.attributes.get(SpanAttributeKey.LINKED_PROMPTS)
        updated_value = update_linked_prompts_tag(current_value, [prompt])
        span.set_attribute(SpanAttributeKey.LINKED_PROMPTS, updated_value)

    return prompt


@require_prompt_registry
def set_prompt_alias(name: str, alias: str, version: int) -> None:
    """
    Set an alias for a :py:class:`Prompt <mlflow.entities.Prompt>` in the MLflow Prompt Registry.

    Args:
        name: The name of the prompt.
        alias: The alias to set for the prompt.
        version: The version of the prompt.

    Example:

    .. code-block:: python

        import mlflow

        # Set an alias for the prompt
        mlflow.set_prompt_alias(name="my_prompt", version=1, alias="production")

        # Load the prompt by alias (use "@" to specify the alias)
        prompt = mlflow.load_prompt("prompts:/my_prompt@production")

        # Switch the alias to a new version of the prompt
        mlflow.set_prompt_alias(name="my_prompt", version=2, alias="production")

        # Delete the alias
        mlflow.delete_prompt_alias(name="my_prompt", alias="production")
    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="set_prompt_alias"),
        category=FutureWarning,
        stacklevel=3,
    )

    MlflowClient().set_prompt_alias(name=name, version=version, alias=alias)


@require_prompt_registry
def delete_prompt_alias(name: str, alias: str) -> None:
    """
    Delete an alias for a :py:class:`Prompt <mlflow.entities.Prompt>` in the MLflow Prompt Registry.

    Args:
        name: The name of the prompt.
        alias: The alias to delete for the prompt.
    """
    warnings.warn(
        PROMPT_API_MIGRATION_MSG.format(func_name="delete_prompt_alias"),
        category=FutureWarning,
        stacklevel=3,
    )

    MlflowClient().delete_prompt_alias(name=name, alias=alias)