mlflow/models/utils.py
   1  import base64
   2  import datetime as dt
   3  import decimal
   4  import importlib
   5  import json
   6  import logging
   7  import os
   8  import re
   9  import shutil
  10  import sys
  11  import tempfile
  12  import uuid
  13  from contextlib import contextmanager
  14  from copy import deepcopy
  15  from pathlib import Path
  16  from typing import Any, Dict, List, Union
  17  
  18  import numpy as np
  19  import pandas as pd
  20  import pydantic
  21  
  22  import mlflow
  23  from mlflow.entities import LoggedModel
  24  from mlflow.environment_variables import MLFLOW_DISABLE_SCHEMA_DETAILS
  25  from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
  26  from mlflow.models import Model
  27  from mlflow.models.model_config import _set_model_config
  28  from mlflow.store.artifact.utils.models import get_model_name_and_version
  29  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  30  from mlflow.types import DataType, ParamSchema, ParamSpec, Schema, TensorSpec
  31  from mlflow.types.schema import AnyType, Array, Map, Object, Property
  32  from mlflow.types.utils import (
  33      TensorsNotSupportedException,
  34      _infer_param_schema,
  35      _is_none_or_nan,
  36      clean_tensor_type,
  37  )
  38  from mlflow.utils.databricks_utils import is_in_databricks_runtime
  39  from mlflow.utils.file_utils import create_tmp_dir, get_local_path_or_none
  40  from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL
  41  from mlflow.utils.proto_json_utils import (
  42      NumpyEncoder,
  43      dataframe_from_parsed_json,
  44      parse_inputs_data,
  45      parse_tf_serving_input,
  46  )
  47  from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri
  48  
  49  try:
  50      from scipy.sparse import csc_matrix, csr_matrix
  51  
  52      HAS_SCIPY = True
  53  except ImportError:
  54      HAS_SCIPY = False
  55  
  56  try:
  57      from pyspark.sql import DataFrame as SparkDataFrame
  58      from pyspark.sql import Row
  59      from pyspark.sql.types import (
  60          ArrayType,
  61          BinaryType,
  62          DateType,
  63          FloatType,
  64          IntegerType,
  65          ShortType,
  66          StructType,
  67          TimestampType,
  68      )
  69  
  70      HAS_PYSPARK = True
  71  except ImportError:
  72      SparkDataFrame = None
  73      HAS_PYSPARK = False
  74  
  75  
  76  INPUT_EXAMPLE_PATH = "artifact_path"
  77  EXAMPLE_DATA_KEY = "inputs"
  78  EXAMPLE_PARAMS_KEY = "params"
  79  EXAMPLE_FILENAME = "input_example.json"
  80  SERVING_INPUT_PATH = "serving_input_path"
  81  SERVING_INPUT_FILENAME = "serving_input_example.json"
  82  
  83  # TODO: import from scoring_server after refactoring
  84  DF_SPLIT = "dataframe_split"
  85  INPUTS = "inputs"
  86  SERVING_PARAMS_KEY = "params"
  87  
  88  ModelInputExample = Union[
  89      pd.DataFrame, np.ndarray, dict, list, "csr_matrix", "csc_matrix", str, bytes, tuple
  90  ]
  91  
  92  PyFuncLLMSingleInput = dict[str, Any] | bool | bytes | float | int | str
  93  
  94  PyFuncLLMOutputChunk = dict[str, Any] | str
  95  
  96  PyFuncInput = Union[
  97      pd.DataFrame,
  98      pd.Series,
  99      np.ndarray,
 100      "csc_matrix",
 101      "csr_matrix",
 102      List[Any],  # noqa: UP006
 103      Dict[str, Any],  # noqa: UP006
 104      dt.datetime,
 105      bool,
 106      bytes,
 107      float,
 108      int,
 109      str,
 110  ]
 111  PyFuncOutput = pd.DataFrame | pd.Series | np.ndarray | list | str | dict[str, Any]
 112  
 113  if HAS_PYSPARK:
 114      PyFuncInput = PyFuncInput | SparkDataFrame
 115      PyFuncOutput = PyFuncOutput | SparkDataFrame
 116  
 117  _logger = logging.getLogger(__name__)
 118  
 119  _FEATURE_STORE_FLAVOR = "databricks.feature_store.mlflow_model"
 120  
 121  
 122  def _is_scalar(x):
 123      return np.isscalar(x) or x is None
 124  
 125  
 126  def _validate_params(params):
 127      try:
 128          _infer_param_schema(params)
 129      except MlflowException:
 130          _logger.warning(f"Invalid params found in input example: {params}")
 131          raise
 132  
 133  
 134  def _is_ndarray(x):
 135      return isinstance(x, np.ndarray) or (
 136          isinstance(x, dict) and all(isinstance(ary, np.ndarray) for ary in x.values())
 137      )
 138  
 139  
 140  def _is_sparse_matrix(x):
 141      if not HAS_SCIPY:
 142          # we can safely assume that if no scipy is installed,
 143          # the user won't log scipy sparse matrices
 144          return False
 145      return isinstance(x, (csc_matrix, csr_matrix))
 146  
 147  
 148  def _handle_ndarray_nans(x: np.ndarray):
 149      if np.issubdtype(x.dtype, np.number):
 150          return np.where(np.isnan(x), None, x)
 151      else:
 152          return x
 153  
 154  
 155  def _handle_ndarray_input(input_array: np.ndarray | dict[str, Any]):
 156      if isinstance(input_array, dict):
 157          result = {}
 158          for name in input_array.keys():
 159              result[name] = _handle_ndarray_nans(input_array[name]).tolist()
 160          return result
 161      else:
 162          return _handle_ndarray_nans(input_array).tolist()
 163  
 164  
 165  def _handle_sparse_matrix(x: Union["csr_matrix", "csc_matrix"]):
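          """
          Serialize a scipy CSR/CSC matrix into a jsonable dict of its ``data``, ``indices``,
          ``indptr`` and ``shape`` components; ``_read_sparse_matrix_from_json`` is the inverse.
          """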
 166      return {
 167          "data": _handle_ndarray_nans(x.data).tolist(),
 168          "indices": x.indices.tolist(),
 169          "indptr": x.indptr.tolist(),
 170          "shape": list(x.shape),
 171      }
 172  
 173  
 174  def _handle_dataframe_nans(df: pd.DataFrame):
 175      return df.where(df.notnull(), None)
 176  
 177  
 178  def _coerce_to_pandas_df(input_ex):
 179      if isinstance(input_ex, dict):
 180          # We need to be compatible with infer_schema's behavior, where
 181          # it infers each value's type directly.
 182          if all(
 183              isinstance(x, str) or (isinstance(x, list) and all(_is_scalar(y) for y in x))
 184              for x in input_ex.values()
 185          ):
 186              # e.g.
 187              # data = {"a": "a", "b": ["a", "b", "c"]}
 188              # >>> pd.DataFrame([data])
 189              #    a          b
 190              # 0  a  [a, b, c]
 191              _logger.info(
 192                  "We convert input dictionaries to pandas DataFrames such that "
 193                  "each key represents a column, collectively constituting a "
 194                  "single row of data. If you would like to save data as "
 195                  "multiple rows, please convert your data to a pandas "
 196                  "DataFrame before passing to input_example."
 197              )
 198          input_ex = pd.DataFrame([input_ex])
 199      elif np.isscalar(input_ex):
 200          input_ex = pd.DataFrame([input_ex])
 201      elif not isinstance(input_ex, pd.DataFrame):
 202          input_ex = None
 203      return input_ex
 204  
 205  
 206  def _convert_dataframe_to_split_dict(df):
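          """
          Convert a DataFrame to a pandas "split"-oriented dict, dropping the row index and,
          when the columns are just the default RangeIndex, the column names as well.
          """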
 207      result = _handle_dataframe_nans(df).to_dict(orient="split")
 208      # Do not include row index
 209      del result["index"]
 210      if all(df.columns == range(len(df.columns))):
 211          # No need to write default column index out
 212          del result["columns"]
 213      return result
 214  
 215  
 216  def _contains_nd_array(data):
 217      import numpy as np
 218  
 219      if isinstance(data, np.ndarray):
 220          return True
 221      if isinstance(data, list):
 222          return any(_contains_nd_array(x) for x in data)
 223      if isinstance(data, dict):
 224          return any(_contains_nd_array(x) for x in data.values())
 225      return False
 226  
 227  
 228  class _Example:
 229      """
 230      Represents an input example for MLflow model.
 231  
 232      Contains jsonable data that can be saved with the model and metadata about the exported format
 233      that can be saved with :py:class:`Model <mlflow.models.Model>`.
 234  
 235      The _Example is created from example data provided by the user. The example(s) can be
 236      provided as a pandas.DataFrame, numpy.ndarray, python dictionary, or python list. The
 237      assumption is that the example contains jsonable elements (see the storage format section
 238      below). A pandas DataFrame or numpy array example is converted to a json-serializable
 239      object before it is saved. If the example is a tuple, the first element is treated as the
 240      example data and the second element as the example params.
 241  
 242      NOTE: serving input example is not supported for sparse matrices yet.
 243  
 244      Metadata:
 245  
 246      The _Example metadata contains the following information:
 247          - artifact_path: Relative path to the serialized example within the model directory.
 248          - serving_input_path: Relative path to the serialized example used for model serving
 249              within the model directory.
 250          - type: Type of example data provided by the user. Supported types are:
 251              - ndarray
 252              - dataframe
 253              - json_object
 254              - sparse_matrix_csc
 255              - sparse_matrix_csr
 256              If the `type` is `dataframe`, `pandas_orient` is also stored in the metadata. This
 257          attribute specifies how the dataframe is encoded in json. For example, the "split"
 258          value signals that the data is stored as an object with columns and data attributes.
 259  
 260      Storage Format:
 261  
 262      The examples are stored as json for portability and readability. Therefore, the contents of the
 263      example(s) must be jsonable. MLflow will make the following conversions automatically on behalf
 264      of the user:
 265  
 266          - binary values: :py:class:`bytes` or :py:class:`bytearray` are converted to base64
 267            encoded strings.
 268          - numpy types: Numpy types are converted to the corresponding python types or their closest
 269            equivalent.
 270          - csc/csr matrix: similar to a 2-dimensional numpy array, csc/csr matrices are
 271            converted to the corresponding python types or their closest equivalent.
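
          Example (an illustrative sketch; the exact JSON layout depends on the input type):

          .. code-block:: python

              import pandas as pd

              example = _Example(pd.DataFrame({"a": [1, 2]}))
              example.info
              # {'artifact_path': 'input_example.json', 'type': 'dataframe',
              #  'pandas_orient': 'split', 'serving_input_path': 'serving_input_example.json'}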
 272      """
 273  
 274      def __init__(self, input_example: ModelInputExample):
 275          try:
 276              import pyspark.sql
 277  
 278              if isinstance(input_example, pyspark.sql.DataFrame):
 279                  raise MlflowException(
 280                      "Examples can not be provided as a Spark DataFrame. "
 281                      "Please make sure your example is small and "
 282                      "convert it to a pandas DataFrame."
 283                  )
 284          except ImportError:
 285              pass
 286  
 287          self.info = {
 288              INPUT_EXAMPLE_PATH: EXAMPLE_FILENAME,
 289          }
 290  
 291          self._inference_data, self._inference_params = _split_input_data_and_params(
 292              deepcopy(input_example)
 293          )
 294          if self._inference_params:
 295              self.info[EXAMPLE_PARAMS_KEY] = "true"
 296          model_input = deepcopy(self._inference_data)
 297  
 298          if isinstance(model_input, pydantic.BaseModel):
 299              model_input = model_input.model_dump()
 300  
 301          is_unified_llm_input = False
 302          if isinstance(model_input, dict):
 303              """
 304              Supported types are:
 305              - Dict[str, Union[DataType, List, Dict]] --> type: json_object
 306              - Dict[str, numpy.ndarray] --> type: ndarray
 307              """
 308              if any(isinstance(values, np.ndarray) for values in model_input.values()):
 309                  if not all(isinstance(values, np.ndarray) for values in model_input.values()):
 310                      raise MlflowException.invalid_parameter_value(
 311                          "Mixed types in dictionary are not supported as input examples. "
 312                          "Found numpy arrays and other types."
 313                      )
 314                  self.info["type"] = "ndarray"
 315                  model_input = _handle_ndarray_input(model_input)
 316                  self.serving_input = {INPUTS: model_input}
 317              else:
 318                  from mlflow.pyfunc.utils.serving_data_parser import is_unified_llm_input
 319  
 320                  self.info["type"] = "json_object"
 321                  is_unified_llm_input = is_unified_llm_input(model_input)
 322                  if is_unified_llm_input:
 323                      self.serving_input = model_input
 324                  else:
 325                      self.serving_input = {INPUTS: model_input}
 326          elif isinstance(model_input, np.ndarray):
 327              """type: ndarray"""
 328              model_input = _handle_ndarray_input(model_input)
 329              self.info["type"] = "ndarray"
 330              self.serving_input = {INPUTS: model_input}
 331          elif isinstance(model_input, list):
 332              """
 333              Supported types are:
 334              - List[DataType]
 335              - List[Dict[str, Union[DataType, List, Dict]]]
 336              --> type: json_object
 337              """
 338              if _contains_nd_array(model_input):
 339                  raise TensorsNotSupportedException(
 340                      "Numpy arrays in list are not supported as input examples."
 341                  )
 342              self.info["type"] = "json_object"
 343              self.serving_input = {INPUTS: model_input}
 344          elif _is_sparse_matrix(model_input):
 345              """
 346              Supported types are:
 347              - scipy.sparse.csr_matrix
 348              - scipy.sparse.csc_matrix
 349              Note: This type of input is not supported by the scoring server yet
 350              """
 351              if isinstance(model_input, csc_matrix):
 352                  example_type = "sparse_matrix_csc"
 353              else:
 354                  example_type = "sparse_matrix_csr"
 355              self.info["type"] = example_type
 356              self.serving_input = {INPUTS: model_input.toarray()}
 357              model_input = _handle_sparse_matrix(model_input)
 358          elif isinstance(model_input, pd.DataFrame):
 359              model_input = _convert_dataframe_to_split_dict(model_input)
 360              self.serving_input = {DF_SPLIT: model_input}
 361              orient = "split" if "columns" in model_input else "values"
 362              self.info.update({
 363                  "type": "dataframe",
 364                  "pandas_orient": orient,
 365              })
 366          elif np.isscalar(model_input) or isinstance(model_input, dt.datetime):
 367              self.info["type"] = "json_object"
 368              self.serving_input = {INPUTS: model_input}
 369          else:
 370              raise MlflowException.invalid_parameter_value(
 371                  "Expected one of the following types:\n"
 372                  "- pandas.DataFrame\n"
 373                  "- numpy.ndarray\n"
 374                  "- dictionary of (name -> numpy.ndarray)\n"
 375                  "- scipy.sparse.csr_matrix\n"
 376                  "- scipy.sparse.csc_matrix\n"
 377                  "- dict\n"
 378                  "- list\n"
 379                  "- scalars\n"
 380                  "- datetime.datetime\n"
 381                  "- pydantic model instance\n"
 382                  f"but got '{type(model_input)}'",
 383              )
 384  
 385          if self._inference_params is not None:
 386              """
 387              Save input data and params with their respective keys, so we can load them separately.
 388              """
 389              model_input = {
 390                  EXAMPLE_DATA_KEY: model_input,
 391                  EXAMPLE_PARAMS_KEY: self._inference_params,
 392              }
 393              if self.serving_input:
 394                  if is_unified_llm_input:
 395                      self.serving_input = {
 396                          **(self.serving_input or {}),
 397                          **self._inference_params,
 398                      }
 399                  else:
 400                      self.serving_input = {
 401                          **(self.serving_input or {}),
 402                          SERVING_PARAMS_KEY: self._inference_params,
 403                      }
 404  
 405          self.json_input_example = json.dumps(model_input, cls=NumpyEncoder)
 406          if self.serving_input:
 407              self.json_serving_input = json.dumps(self.serving_input, cls=NumpyEncoder, indent=2)
 408              self.info[SERVING_INPUT_PATH] = SERVING_INPUT_FILENAME
 409          else:
 410              self.json_serving_input = None
 411  
 412      def save(self, parent_dir_path: str):
 413          """
 414          Save the example as json at ``parent_dir_path``/`self.info['artifact_path']`.
 415          Save serving input as json at ``parent_dir_path``/`self.info['serving_input_path']`.
 416          """
 417          with open(os.path.join(parent_dir_path, self.info[INPUT_EXAMPLE_PATH]), "w") as f:
 418              f.write(self.json_input_example)
 419          if self.json_serving_input:
 420              with open(os.path.join(parent_dir_path, self.info[SERVING_INPUT_PATH]), "w") as f:
 421                  f.write(self.json_serving_input)
 422  
 423      @property
 424      def inference_data(self):
 425          """
 426          Returns the input example in a form that PyFunc wrapped models can score.
 427          """
 428          return self._inference_data
 429  
 430      @property
 431      def inference_params(self):
 432          """
 433          Returns the params dictionary that PyFunc wrapped models can use for scoring.
 434          """
 435          return self._inference_params
 436  
 437  
 438  def _contains_params(input_example):
 439      # For tuple input, we assume the first item is input_example data
 440      # and the second item is params dictionary.
 441      return (
 442          isinstance(input_example, tuple)
 443          and len(input_example) == 2
 444          and isinstance(input_example[1], dict)
 445      )
 446  
 447  
 448  def _split_input_data_and_params(input_example):
 449      if _contains_params(input_example):
 450          input_data, inference_params = input_example
 451          _validate_params(inference_params)
 452          return input_data, inference_params
 453      return input_example, None
 454  
 455  
 456  def convert_input_example_to_serving_input(input_example) -> str | None:
 457      """
 458      Helper function to convert a model's input example to a serving input example that
 459      can be used for model inference in the scoring server.
 460  
 461      Args:
 462          input_example: model input example. Supported types are pandas.DataFrame, numpy.ndarray,
 463              dictionary of (name -> numpy.ndarray), list, scalars and dicts with json serializable
 464              values.
 465  
 466      Returns:
 467          serving input example as a json string
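
          Example (a minimal sketch):

          .. code-block:: python

              serving_input = convert_input_example_to_serving_input({"x": [1, 2, 3]})
              # a JSON string like '{"inputs": {"x": [1, 2, 3]}}', pretty-printed with indent=2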
 468      """
 469      if input_example is None:
 470          return None
 471  
 472      example = _Example(input_example)
 473      return example.json_serving_input
 474  
 475  
 476  def _save_example(
 477      mlflow_model: Model, input_example: ModelInputExample | None, path: str
 478  ) -> _Example | None:
 479      """
 480      Saves the example to a file at the given path and updates the passed Model with example metadata.
 481  
 482      The metadata is a dictionary with the following fields:
 483        - 'artifact_path': example path relative to the model directory.
 484        - 'type': Type of example. Supported values include 'dataframe', 'ndarray' and 'json_object'.
 485        - One of the following fields, depending on the `type`:
 486              - 'pandas_orient': Used to store dataframes. Determines the json encoding for dataframe
 487                                 examples in terms of pandas orient convention. Defaults to 'split'.
 488              - 'format': Used to store tensors. Determines the standard used to store a tensor input
 489                          example. MLflow uses a JSON-formatted string representation of TF serving
 490                          input.
 491  
 492      Args:
 493          mlflow_model: Model metadata that will get updated with the example metadata.
 494          path: Where to store the example file. Should be the model directory.
 495  
 496      Returns:
 497          _Example object that contains saved input example.
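
          Example (a minimal sketch; assumes the target directory already exists):

          .. code-block:: python

              model = Model()
              _save_example(model, pd.DataFrame({"a": [1]}), "/tmp/model")
              model.saved_input_example_info["type"]  # 'dataframe'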
 498      """
 499      if input_example is None:
 500          return None
 501  
 502      example = _Example(input_example)
 503      example.save(path)
 504      mlflow_model.saved_input_example_info = example.info
 505      return example
 506  
 507  
 508  def _get_mlflow_model_input_example_dict(
 509      mlflow_model: Model, uri_or_path: str
 510  ) -> dict[str, Any] | None:
 511      """
 512      Args:
 513          mlflow_model: Model metadata.
 514          uri_or_path: Model or run URI, or path to the `model` directory.
 515              e.g. models:/<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
 516              or /path/to/model
 517  
 518      Returns:
 519          Input example or None if the model has no example.
 520      """
 521      if mlflow_model.saved_input_example_info is None:
 522          return None
 523      example_type = mlflow_model.saved_input_example_info["type"]
 524      if example_type not in [
 525          "dataframe",
 526          "ndarray",
 527          "sparse_matrix_csc",
 528          "sparse_matrix_csr",
 529          "json_object",
 530      ]:
 531          raise MlflowException(f"This version of mlflow can not load example of type {example_type}")
 532      return json.loads(
 533          _read_file_content(uri_or_path, mlflow_model.saved_input_example_info[INPUT_EXAMPLE_PATH])
 534      )
 535  
 536  
 537  def _load_serving_input_example(mlflow_model: Model, path: str) -> str | None:
 538      """
 539      Load serving input example from a model directory. Returns None if there is no serving input
 540      example.
 541  
 542      Args:
 543          mlflow_model: Model metadata.
 544          path: Path to the model directory.
 545  
 546      Returns:
 547          Serving input example or None if the model has no serving input example.
 548      """
 549      if mlflow_model.saved_input_example_info is None:
 550          return None
 551      serving_input_path = mlflow_model.saved_input_example_info.get(SERVING_INPUT_PATH)
 552      if serving_input_path is None:
 553          return None
 554      with open(os.path.join(path, serving_input_path)) as handle:
 555          return handle.read()
 556  
 557  
 558  def load_serving_example(model_uri_or_path: str):
 559      """
 560      Load serving input example from a model directory or URI.
 561  
 562      Args:
 563          model_uri_or_path: Model URI or path to the `model` directory.
 564              e.g. models:/<model_name>/<model_version> or /path/to/model
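
          Example (a minimal sketch; ``models:/my_model/1`` is a placeholder URI):

          .. code-block:: python

              serving_example = load_serving_example("models:/my_model/1")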
 565      """
 566      return _read_file_content(model_uri_or_path, SERVING_INPUT_FILENAME)
 567  
 568  
 569  def _read_file_content(uri_or_path: str, file_name: str):
 570      """
 571      Read file content from a model directory or URI.
 572  
 573      Args:
 574          uri_or_path: Model or run URI, or path to the `model` directory.
 575              e.g. models:/<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
 576              or /path/to/model
 577          file_name: Name of the file to read.
 578      """
 579      from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
 580  
 581      if ModelsArtifactRepository._is_logged_model_uri(uri_or_path):
 582          uri_or_path = ModelsArtifactRepository.get_underlying_uri(uri_or_path)
 583  
 584      file_path = str(uri_or_path).rstrip("/") + "/" + file_name
 585      if os.path.exists(file_path):
 586          with open(file_path) as handle:
 587              return handle.read()
 588      else:
 589          with tempfile.TemporaryDirectory() as tmpdir:
 590              local_file_path = _download_artifact_from_uri(file_path, output_path=tmpdir)
 591              with open(local_file_path) as handle:
 592                  return handle.read()
 593  
 594  
 595  def _read_example(mlflow_model: Model, uri_or_path: str):
 596      """
 597      Read example from a model directory. Returns None if there is no example metadata (i.e. the
 598      model was saved without an example). Raises FileNotFoundError if the example metadata
 599      exists but the example file is missing.
 600  
 601      Args:
 602          mlflow_model: Model metadata.
 603          uri_or_path: Model or run URI, or path to the `model` directory.
 604                  e.g. models:/<model_name>/<model_version>, runs:/<run_id>/<artifact_path>
 605                  or /path/to/model
 606  
 607      Returns:
 608          Input example data or None if the model has no example.
 609      """
 610      input_example = _get_mlflow_model_input_example_dict(mlflow_model, uri_or_path)
 611      if input_example is None:
 612          return None
 613  
 614      example_type = mlflow_model.saved_input_example_info["type"]
 615      input_schema = mlflow_model.signature.inputs if mlflow_model.signature is not None else None
 616      if mlflow_model.saved_input_example_info.get(EXAMPLE_PARAMS_KEY, None):
 617          input_example = input_example[EXAMPLE_DATA_KEY]
 618      if example_type == "json_object":
 619          return input_example
 620      if example_type == "ndarray":
 621          return parse_inputs_data(input_example, schema=input_schema)
 622      if example_type in ["sparse_matrix_csc", "sparse_matrix_csr"]:
 623          return _read_sparse_matrix_from_json(input_example, example_type)
 624      if example_type == "dataframe":
 625          return dataframe_from_parsed_json(input_example, pandas_orient="split", schema=input_schema)
 626      raise MlflowException(
 627          "Malformed input example metadata. The 'type' field must be one of "
 628          "'dataframe', 'ndarray', 'sparse_matrix_csc', 'sparse_matrix_csr' or 'json_object'."
 629      )
 630  
 631  
 632  def _read_example_params(mlflow_model: Model, path: str):
 633      """
 634      Read the params of an input_example from a model directory. Returns None if there are no
 635      params in the input_example or the model was saved without an example.
 636      """
 637      if (
 638          mlflow_model.saved_input_example_info is None
 639          or mlflow_model.saved_input_example_info.get(EXAMPLE_PARAMS_KEY, None) is None
 640      ):
 641          return None
 642      input_example_dict = _get_mlflow_model_input_example_dict(mlflow_model, path)
 643      return input_example_dict[EXAMPLE_PARAMS_KEY]
 644  
 645  
 646  def _read_tensor_input_from_json(path_or_data, schema=None):
 647      if isinstance(path_or_data, str) and os.path.exists(path_or_data):
 648          with open(path_or_data) as handle:
 649              inp_dict = json.load(handle)
 650      else:
 651          inp_dict = path_or_data
 652      return parse_tf_serving_input(inp_dict, schema)
 653  
 654  
 655  def _read_sparse_matrix_from_json(path_or_data, example_type):
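          """
          Rebuild a scipy CSR/CSC matrix from the component dict produced by
          ``_handle_sparse_matrix``, loading it from a json file first when a path is given.
          """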
 656      if isinstance(path_or_data, str) and os.path.exists(path_or_data):
 657          with open(path_or_data) as handle:
 658              matrix_data = json.load(handle)
 659      else:
 660          matrix_data = path_or_data
 661      data = matrix_data["data"]
 662      indices = matrix_data["indices"]
 663      indptr = matrix_data["indptr"]
 664      shape = tuple(matrix_data["shape"])
 665  
 666      if example_type == "sparse_matrix_csc":
 667          return csc_matrix((data, indices, indptr), shape=shape)
 668      else:
 669          return csr_matrix((data, indices, indptr), shape=shape)
 670  
 671  
 672  def plot_lines(data_series, xlabel, ylabel, legend_loc=None, line_kwargs=None, title=None):
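          """
          Plot one line per ``(label, x_values, y_values)`` tuple in ``data_series`` on a
          single matplotlib figure and return the ``(fig, ax)`` pair.
          """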
 673      import matplotlib.pyplot as plt
 674  
 675      fig, ax = plt.subplots()
 676  
 677      if line_kwargs is None:
 678          line_kwargs = {}
 679  
 680      for label, data_x, data_y in data_series:
 681          ax.plot(data_x, data_y, label=label, **line_kwargs)
 682  
 683      if legend_loc:
 684          ax.legend(loc=legend_loc)
 685  
 686      ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
 687  
 688      return fig, ax
 689  
 690  
 691  def _enforce_tensor_spec(
 692      values: Union[np.ndarray, "csc_matrix", "csr_matrix"],
 693      tensor_spec: TensorSpec,
 694  ):
 695      """
 696      Enforce that the input tensor shape and type match the provided tensor spec.
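
          Example (a minimal sketch):

          .. code-block:: python

              spec = TensorSpec(np.dtype("float64"), (-1, 2))
              _enforce_tensor_spec(np.zeros((3, 2)), spec)  # passes; -1 matches any row count
              _enforce_tensor_spec(np.zeros((3, 3)), spec)  # raises MlflowException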
 697      """
 698      expected_shape = tensor_spec.shape
 699      expected_type = tensor_spec.type
 700      actual_shape = values.shape
 701      actual_type = values.dtype if isinstance(values, np.ndarray) else values.data.dtype
 702  
 703      # This logic is for handling "ragged" arrays. The first check is for a standard numpy shape
 704      # representation of a ragged array. The second is for handling a more manual specification
 705      # of shape while still supporting a ragged-array input.
 706      if len(expected_shape) == 1 and expected_shape[0] == -1 and expected_type == np.dtype("O"):
 707          # Sample spec: Tensor('object', (-1,))
 708          # Will pass on any provided input
 709          return values
 710      if (
 711          len(expected_shape) > 1
 712          and -1 in expected_shape[1:]
 713          and len(actual_shape) == 1
 714          and actual_type == np.dtype("O")
 715      ):
 716          # Sample spec: Tensor('float64', (-1, -1, -1, 3))
 717          # Will pass on inputs which are ragged arrays: shape==(x,), dtype=='object'
 718          return values
 719  
 720      if len(expected_shape) != len(actual_shape):
 721          raise MlflowException(
 722              f"Shape of input {actual_shape} does not match expected shape {expected_shape}."
 723          )
 724      for expected, actual in zip(expected_shape, actual_shape):
 725          if expected == -1:
 726              continue
 727          if expected != actual:
 728              raise MlflowException(
 729                  f"Shape of input {actual_shape} does not match expected shape {expected_shape}."
 730              )
 731      if clean_tensor_type(actual_type) != expected_type:
 732          raise MlflowException(
 733              f"dtype of input {actual_type} does not match expected dtype {expected_type}"
 734          )
 735      return values
 736  
 737  
 738  def _enforce_mlflow_datatype(name, values: pd.Series, t: DataType):
 739      """
 740      Enforce that the input column type matches the type declared in the model input schema.
 741  
 742      The following type conversions are allowed:
 743  
 744      1. object -> string
 745      2. int -> long (upcast)
 746      3. float -> double (upcast)
 747      4. int -> double (safe conversion)
 748      5. np.datetime64[x] -> datetime (any precision)
 749      6. object -> datetime
 750  
 751      NB: pandas does not have a native decimal data type. When a user trains and infers
 752      a model from a pyspark dataframe that contains a decimal type, the schema will be
 753      treated as float64.
 754      7. decimal -> double
 755  
 756      Any other type mismatch will raise an error.
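
          For example (a sketch of conversion 4; an int32 column is safely upcast to double):

          .. code-block:: python

              s = pd.Series([1, 2, 3], dtype="int32")
              _enforce_mlflow_datatype("x", s, DataType.double).dtype  # float64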
 757      """
 758  
 759      if values.dtype == object and t not in (DataType.binary, DataType.string):
 760          values = values.infer_objects()
 761  
 762      if t == DataType.string and (
 763          values.dtype == object or isinstance(values.dtype, pd.StringDtype)
 764      ):
 765          # NB: the object can contain any type and we currently cannot cast to pandas Strings
 766          # due to how None is cast
 767          return values
 768  
 769      # NB: Comparison of pandas and numpy data type fails when numpy data type is on the left hand
 770      # side of the comparison operator. It works, however, if pandas type is on the left hand side.
 771      # That is because pandas is aware of numpy.
 772      if t.to_pandas() == values.dtype or t.to_numpy() == values.dtype:
 773          # The types are already compatible => conversion is not necessary.
 774          return values
 775  
 776      if t == DataType.binary and values.dtype.kind == t.binary.to_numpy().kind:
 777          # NB: bytes in numpy have variable itemsize depending on the length of the longest
 778          # element in the array (column). Since MLflow binary type is length agnostic, we ignore
 779          # itemsize when matching binary columns.
 780          return values
 781  
 782      if t == DataType.datetime and values.dtype.kind == t.to_numpy().kind:
 783          # NB: datetime values have variable precision denoted by brackets, e.g. datetime64[ns]
 784          # denotes nanosecond precision. Since MLflow datetime type is precision agnostic, we
 785          # ignore precision when matching datetime columns.
 786          try:
 787              return values.astype(np.dtype("datetime64[ns]"))
 788          except TypeError as e:
 789              raise MlflowException(
 790                  "Please ensure that the input data of datetime column only contains timezone-naive "
 791                  f"datetime objects. Error: {e}"
 792              )
 793  
 794      if t == DataType.datetime and (values.dtype == object or values.dtype == t.to_python()):
 795          # NB: Pyspark date columns get converted to object when converted to a pandas
 796          # DataFrame. To respect the original typing, we convert the column to datetime.
 797          try:
 798              return values.astype(np.dtype("datetime64[ns]"), errors="raise")
 799          except ValueError as e:
 800              raise MlflowException(
 801                  f"Failed to convert column {name} from type {values.dtype} to {t}."
 802              ) from e
 803  
 804      if t == DataType.boolean and values.dtype == object:
 805          # Should not convert type otherwise it converts None to boolean False
 806          return values
 807  
 808      if t == DataType.double and values.dtype == decimal.Decimal:
 809          # NB: Pyspark Decimal columns get converted to decimal.Decimal when converted to a
 810          # pandas DataFrame. In order to support training on decimal data from a spark data
 811          # frame, we add this conversion even though we might lose precision.
 812          try:
 813              return pd.to_numeric(values, errors="raise")
 814          except ValueError:
 815              raise MlflowException(
 816                  f"Failed to convert column {name} from type {values.dtype} to {t}."
 817              )
 818  
 819      numpy_type = t.to_numpy()
 820      if values.dtype.kind == numpy_type.kind:
 821          is_upcast = values.dtype.itemsize <= numpy_type.itemsize
 822      elif values.dtype.kind == "u" and numpy_type.kind == "i":
 823          is_upcast = values.dtype.itemsize < numpy_type.itemsize
 824      elif values.dtype.kind in ("i", "u") and numpy_type == np.float64:
 825          # allow (u)int => double conversion
 826          is_upcast = values.dtype.itemsize <= 6
 827      else:
 828          is_upcast = False
 829  
 830      if is_upcast:
 831          return values.astype(numpy_type, errors="raise")
 832      else:
 833          # support converting long -> float/double for 0 and 1 values
 834          def all_zero_or_ones(xs):
 835              return all(pd.isnull(x) or x in [0, 1] for x in xs)
 836  
 837          if (
 838              values.dtype == np.int64
 839              and numpy_type in (np.float32, np.float64)
 840              and all_zero_or_ones(values)
 841          ):
 842              return values.astype(numpy_type, errors="raise")
 843  
 844          # NB: conversion between incompatible types (e.g. floats -> ints or
 845          # double -> float) are not allowed. While supported by pandas and numpy,
 846          # these conversions alter the values significantly.
 847          def all_ints(xs):
 848              return all(pd.isnull(x) or int(x) == x for x in xs)
 849  
 850          hint = ""
 851          if (
 852              values.dtype == np.float64
 853              and numpy_type.kind in ("i", "u")
 854              and values.hasnans
 855              and all_ints(values)
 856          ):
 857              hint = (
 858                  " Hint: the type mismatch is likely caused by missing values. "
 859                  "Integer columns in python can not represent missing values and are therefore "
 860                  "encoded as floats. The best way to avoid this problem is to infer the model "
 861                  "schema based on a realistic data sample (training dataset) that includes missing "
 862                  "values. Alternatively, you can declare integer columns as doubles (float64) "
 863                  "whenever these columns may have missing values. See `Handling Integers With "
 864                  "Missing Values <https://www.mlflow.org/docs/latest/models.html#"
 865                  "handling-integers-with-missing-values>`_ for more details."
 866              )
 867  
 868          raise MlflowException(
 869              f"Incompatible input types for column {name}. "
 870              f"Can not safely convert {values.dtype} to {numpy_type}.{hint}"
 871          )
 872  
 873  
 874  # dtype -> possible value types mapping
 875  _ALLOWED_CONVERSIONS_FOR_PARAMS = {
 876      DataType.long: (DataType.integer,),
 877      DataType.float: (DataType.integer, DataType.long),
 878      DataType.double: (DataType.integer, DataType.long, DataType.float),
 879  }
 880  
 881  
 882  def _enforce_param_datatype(value: Any, dtype: DataType):
 883      """
 884      Enforce that the value matches the data type. This is used to enforce the params datatype.
 885      The returned data is of a python built-in type or a datetime object.
 886  
 887      The following type conversions are allowed:
 888  
 889      1. int -> long, float, double
 890      2. long -> float, double
 891      3. float -> double
 892      4. any -> datetime (try conversion)
 893  
 894      Any other type mismatch will raise an error.
 895  
 896      Args:
 897          value: parameter value
 898          dtype: expected data type
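
          Example (an illustrative sketch):

          .. code-block:: python

              _enforce_param_datatype(1, DataType.double)  # 1.0 (int upcast to double)
              _enforce_param_datatype("2023-01-01", DataType.datetime)  # datetime.date(2023, 1, 1)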
 899      """
 900      if value is None:
 901          return
 902  
 903      if dtype == DataType.datetime:
 904          try:
 905              datetime_value = np.datetime64(value).item()
 906              if isinstance(datetime_value, int):
 907                  raise MlflowException.invalid_parameter_value(
 908                      f"Failed to convert value to `{dtype}`. "
 909                      f"It must be convertible to datetime.date/datetime, got `{value}`"
 910                  )
 911              return datetime_value
 912          except ValueError as e:
 913              raise MlflowException.invalid_parameter_value(
 914                  f"Failed to convert value `{value}` from type `{type(value)}` to `{dtype}`"
 915              ) from e
 916  
 917      # Note that np.isscalar(datetime.date(...)) is False
 918      if not np.isscalar(value):
 919          raise MlflowException.invalid_parameter_value(
 920              f"Value must be a scalar for type `{dtype}`, got `{value}`"
 921          )
 922  
 923      # Always convert to python native type for params
 924      if DataType.check_type(dtype, value):
 925          return dtype.to_python()(value)
 926  
 927      if dtype in _ALLOWED_CONVERSIONS_FOR_PARAMS and any(
 928          DataType.check_type(t, value) for t in _ALLOWED_CONVERSIONS_FOR_PARAMS[dtype]
 929      ):
 930          try:
 931              return dtype.to_python()(value)
 932          except ValueError as e:
 933              raise MlflowException.invalid_parameter_value(
 934                  f"Failed to convert value `{value}` from type `{type(value)}` to `{dtype}`"
 935              ) from e
 936  
 937      raise MlflowException.invalid_parameter_value(
 938          f"Can not safely convert `{type(value)}` to `{dtype}` for value `{value}`"
 939      )
 940  
 941  
 942  def _enforce_unnamed_col_schema(pf_input: pd.DataFrame, input_schema: Schema):
 943      """Enforce the input columns conform to the model's column-based signature."""
 944      input_names = pf_input.columns[: len(input_schema.inputs)]
 945      input_types = input_schema.input_types()
 946      new_pf_input = {}
 947      for i, x in enumerate(input_names):
 948          if isinstance(input_types[i], DataType):
 949              new_pf_input[x] = _enforce_mlflow_datatype(x, pf_input[x], input_types[i])
 950          # If the input_type is objects/arrays/maps, we assume pf_input must be a pandas DataFrame.
 951          # Otherwise, the schema is not valid.
 952          else:
 953              new_pf_input[x] = pd.Series(
 954                  [_enforce_type(obj, input_types[i]) for obj in pf_input[x]], name=x
 955              )
 956      return pd.DataFrame(new_pf_input)
 957  
 958  
 959  def _enforce_named_col_schema(pf_input: pd.DataFrame, input_schema: Schema):
 960      """Enforce the input columns conform to the model's column-based signature."""
 961      input_names = input_schema.input_names()
 962      input_dict = input_schema.input_dict()
 963      new_pf_input = {}
 964      for name in input_names:
 965          input_type = input_dict[name].type
 966          required = input_dict[name].required
 967          if name not in pf_input:
 968              if required:
 969                  raise MlflowException(
 970                      f"The input column '{name}' is required by the model "
 971                      "signature but missing from the input data."
 972                  )
 973              else:
 974                  continue
 975          if isinstance(input_type, DataType):
 976              new_pf_input[name] = _enforce_mlflow_datatype(name, pf_input[name], input_type)
 977          # If the input_type is objects/arrays/maps, we assume pf_input must be a pandas DataFrame.
 978          # Otherwise, the schema is not valid.
 979          else:
 980              new_pf_input[name] = pd.Series(
 981                  [_enforce_type(obj, input_type, required) for obj in pf_input[name]], name=name
 982              )
 983      return pd.DataFrame(new_pf_input)
 984  
 985  
 986  def _reshape_and_cast_pandas_column_values(name, pd_series, tensor_spec):
 987      if tensor_spec.shape[0] != -1 or -1 in tensor_spec.shape[1:]:
 988          raise MlflowException(
 989              "For pandas dataframe input, the first dimension of shape must be a variable "
 990              "dimension and other dimensions must be fixed, but in model signature the shape "
 991              f"of {'input ' + name if name else 'the unnamed input'} is {tensor_spec.shape}."
 992          )
 993  
 994      if np.isscalar(pd_series[0]):
 995          for shape in [(-1,), (-1, 1)]:
 996              if tensor_spec.shape == shape:
 997                  return _enforce_tensor_spec(
 998                      np.array(pd_series, dtype=tensor_spec.type).reshape(shape), tensor_spec
 999                  )
1000          raise MlflowException(
1001              f"The input pandas dataframe column '{name}' contains scalar "
1002              "values, which requires the shape to be (-1,) or (-1, 1), but got tensor spec "
1003              f"shape of {tensor_spec.shape}.",
1004              error_code=INVALID_PARAMETER_VALUE,
1005          )
1006      elif isinstance(pd_series[0], list) and np.isscalar(pd_series[0][0]):
1007          # If the pandas column contains list-type values, the shape and type
1008          # information is lost, so we do not enforce the shape and type here.
1009          # Instead, we reshape the list of array values to the required shape
1010          # and cast the values to the required type, raising a reshape error
1011          # if either step fails.
1012          reshape_err_msg = (
1013              f"The value in the Input DataFrame column '{name}' could not be converted to the "
1014              f"expected shape of: '{tensor_spec.shape}'. Ensure that all of the input list "
1015              "elements are of uniform length and that the data can be coerced to the tensor "
1016              f"type '{tensor_spec.type}'"
1017          )
1018          try:
1019              flattened_numpy_arr = np.vstack(pd_series.tolist())
1020              reshaped_numpy_arr = flattened_numpy_arr.reshape(tensor_spec.shape).astype(
1021                  tensor_spec.type
1022              )
1023          except ValueError:
1024              raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1025          if len(reshaped_numpy_arr) != len(pd_series):
1026              raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1027          return reshaped_numpy_arr
1028      elif isinstance(pd_series[0], np.ndarray):
1029          reshape_err_msg = (
1030              f"The value in the Input DataFrame column '{name}' could not be converted to the "
1031              f"expected shape of: '{tensor_spec.shape}'. Ensure that all of the input numpy "
1032              "array elements are of uniform length and can be reshaped to the expected shape above."
1033          )
1034          try:
1035              # Because a numpy array includes precise type information, we don't convert the
1036              # type here, so that the subsequent schema validation can apply a strict type
1037              # check to the numpy array column.
1038              reshaped_numpy_arr = np.vstack(pd_series.tolist()).reshape(tensor_spec.shape)
1039          except ValueError:
1040              raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1041          if len(reshaped_numpy_arr) != len(pd_series):
1042              raise MlflowException(reshape_err_msg, error_code=INVALID_PARAMETER_VALUE)
1043          return reshaped_numpy_arr
1044      else:
1045          raise MlflowException(
1046              "Because the model signature requires tensor spec input, the input "
1047              "pandas dataframe values should be scalar values, python lists of "
1048              "scalar values, or numpy arrays of scalar values; "
1049              "other types are not supported.",
1050              error_code=INVALID_PARAMETER_VALUE,
1051          )
1052  
1053  
1054  def _enforce_tensor_schema(pf_input: PyFuncInput, input_schema: Schema):
1055      """Enforce the input tensor(s) conforms to the model's tensor-based signature."""
1056  
1057      def _is_sparse_matrix(x):
1058          if not HAS_SCIPY:
1059              # we can safely assume that it's not a sparse matrix if scipy is not installed
1060              return False
1061          return isinstance(x, (csr_matrix, csc_matrix))
1062  
1063      if input_schema.has_input_names():
1064          if isinstance(pf_input, dict):
1065              new_pf_input = {}
1066              for col_name, tensor_spec in zip(input_schema.input_names(), input_schema.inputs):
1067                  if not isinstance(pf_input[col_name], np.ndarray):
1068                      raise MlflowException(
1069                          "This model contains a tensor-based model signature with input names,"
1070                          " which suggests a dictionary input mapping input name to a numpy"
1071                          f" array, but a dict with value type {type(pf_input[col_name])} was found.",
1072                          error_code=INVALID_PARAMETER_VALUE,
1073                      )
1074                  new_pf_input[col_name] = _enforce_tensor_spec(pf_input[col_name], tensor_spec)
1075          elif isinstance(pf_input, pd.DataFrame):
1076              new_pf_input = {}
1077              for col_name, tensor_spec in zip(input_schema.input_names(), input_schema.inputs):
1078                  pd_series = pf_input[col_name]
1079                  new_pf_input[col_name] = _reshape_and_cast_pandas_column_values(
1080                      col_name, pd_series, tensor_spec
1081                  )
1082          else:
1083              raise MlflowException(
1084                  "This model contains a tensor-based model signature with input names, which"
1085                  " suggests a dictionary input mapping input name to tensor, or a pandas"
1086                  " DataFrame input containing columns mapping input name to flattened list value"
1087                  f" from tensor, but an input of type {type(pf_input)} was found.",
1088                  error_code=INVALID_PARAMETER_VALUE,
1089              )
1090      else:
1091          tensor_spec = input_schema.inputs[0]
1092          if isinstance(pf_input, pd.DataFrame):
1093              num_input_columns = len(pf_input.columns)
1094              if pf_input.empty:
1095                  raise MlflowException("Input DataFrame is empty.")
1096              elif num_input_columns == 1:
1097                  new_pf_input = _reshape_and_cast_pandas_column_values(
1098                      None, pf_input[pf_input.columns[0]], tensor_spec
1099                  )
1100              else:
1101                  if tensor_spec.shape != (-1, num_input_columns):
1102                      raise MlflowException(
1103                          "This model contains a model signature with an unnamed input. Since the "
1104                          "input data is a pandas DataFrame containing multiple columns, "
1105                          "the input shape must be of the structure "
1106                          "(-1, number_of_dataframe_columns). "
1107                          f"Instead, the input DataFrame passed had {num_input_columns} columns and "
1108                          f"an input shape of {tensor_spec.shape} with all values within the "
1109                          "DataFrame of scalar type. Please adjust the passed in DataFrame to "
1110                          "match the expected structure",
1111                          error_code=INVALID_PARAMETER_VALUE,
1112                      )
1113                  new_pf_input = _enforce_tensor_spec(pf_input.to_numpy(), tensor_spec)
1114          elif isinstance(pf_input, np.ndarray) or _is_sparse_matrix(pf_input):
1115              new_pf_input = _enforce_tensor_spec(pf_input, tensor_spec)
1116          else:
1117              raise MlflowException(
1118                  "This model contains a tensor-based model signature with no input names,"
1119                  " which suggests a numpy array input or a pandas dataframe input with"
1120                  f" proper column values, but an input of type {type(pf_input)} was found.",
1121                  error_code=INVALID_PARAMETER_VALUE,
1122              )
1123      return new_pf_input
1124  
1125  
1126  def _enforce_schema(pf_input: PyFuncInput, input_schema: Schema, flavor: str | None = None):
1127      """
1128      Enforce that the provided input matches the model's input schema.
1129  
1130      For signatures with input names, we check that there are no missing inputs and reorder the
1131      inputs to match the ordering declared in the schema if necessary. Any extra columns are ignored.
1132  
1133      For column-based signatures, we make sure the types of the input match the types specified
1134      in the schema or can be safely converted to match the input schema.
1135  
1136      For Pyspark DataFrame inputs, MLflow casts a sample of the PySpark DataFrame into a Pandas
1137      DataFrame. MLflow will only enforce the schema on a subset of the data rows.
1138  
1139      For tensor-based signatures, we make sure the shape and type of the input match the shape
1140      and type specified in the model's input schema.
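
          Example (an illustrative sketch of column-based enforcement):

          .. code-block:: python

              from mlflow.types import ColSpec, Schema

              schema = Schema([ColSpec("double", "x")])
              enforced = _enforce_schema(pd.DataFrame({"x": [1, 2]}, dtype="int32"), schema)
              enforced["x"].dtype  # float64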
1141      """
1142  
1143      def _is_scalar(x):
1144          return np.isscalar(x) or x is None
1145  
1146      original_pf_input = pf_input
1147      if isinstance(pf_input, pd.Series):
1148          pf_input = pd.DataFrame(pf_input)
1149      if not input_schema.is_tensor_spec():
1150          # convert single DataType to pandas DataFrame
1151          if np.isscalar(pf_input):
1152              pf_input = pd.DataFrame([pf_input])
1153          elif isinstance(pf_input, dict):
1154              # keys are column names
1155              if any(
1156                  isinstance(col_spec.type, (Array, Object)) for col_spec in input_schema.inputs
1157              ) or all(
1158                  _is_scalar(value)
1159                  or (isinstance(value, list) and all(isinstance(item, str) for item in value))
1160                  for value in pf_input.values()
1161              ):
1162                  pf_input = pd.DataFrame([pf_input])
1163              else:
1164                  try:
1165                      # This check is specifically to handle the serving structural cast for
1166                      # certain inputs for the transformers implementation. Due to the fact that
1167                      # specific Pipeline types in transformers support passing input data
1168                      # of the form Dict[str, str] in which the value is a scalar string, model
1169                      # serving will cast this entry as a numpy array with shape () and size 1.
1170                      # This is seen as a scalar input when attempting to create a Pandas
1171                      # DataFrame from such a numpy structure and requires the array to be
1172                      # encapsulated in a list in order to prevent a ValueError exception for
1173                      # requiring an index if passing in all scalar values thrown by Pandas.
1174                      if all(
1175                          isinstance(value, np.ndarray)
1176                          and value.dtype.type == np.str_
1177                          and value.size == 1
1178                          and value.shape == ()
1179                          for value in pf_input.values()
1180                      ):
1181                          pf_input = pd.DataFrame([pf_input])
1182                      elif any(
1183                          isinstance(value, np.ndarray) and value.ndim > 1
1184                          for value in pf_input.values()
1185                      ):
1186                          # Pandas DataFrames can't be constructed with embedded multi-dimensional
1187                          # numpy arrays. Accordingly, we convert any multi-dimensional numpy
1188                          # arrays to lists before constructing a DataFrame. This is safe because
1189                          # ColSpec model signatures do not support array columns, so subsequent
1190                          # validation logic will result in a clear "incompatible input types"
1191                          # exception. This is preferable to a pandas DataFrame construction error
1192                          pf_input = pd.DataFrame({
1193                              key: (
1194                                  value.tolist()
1195                                  if (isinstance(value, np.ndarray) and value.ndim > 1)
1196                                  else value
1197                              )
1198                              for key, value in pf_input.items()
1199                          })
1200                      else:
1201                          pf_input = pd.DataFrame(pf_input)
1202                  except Exception as e:
1203                      raise MlflowException(
1204                          "This model contains a column-based signature, which suggests a DataFrame"
1205                          " input. There was an error casting the input data to a DataFrame:"
1206                          f" {e}"
1207                      )
1208          elif isinstance(pf_input, (list, np.ndarray, pd.Series)):
1209              pf_input = pd.DataFrame(pf_input)
1210          elif HAS_PYSPARK and isinstance(pf_input, SparkDataFrame):
1211              pf_input = pf_input.limit(10).toPandas()
1212              for field in original_pf_input.schema.fields:
1213                  if isinstance(field.dataType, (StructType, ArrayType)):
1214                      pf_input[field.name] = pf_input[field.name].apply(
1215                          lambda row: convert_complex_types_pyspark_to_pandas(row, field.dataType)
1216                      )
1217          if not isinstance(pf_input, pd.DataFrame):
1218              raise MlflowException(
1219                  f"Expected input to be DataFrame. Found: {type(pf_input).__name__}"
1220              )
1221  
1222      if input_schema.has_input_names():
1223          # make sure there are no missing columns
1224          input_names = input_schema.required_input_names()
1225          optional_names = input_schema.optional_input_names()
1226          expected_required_cols = set(input_names)
1227          actual_cols = set()
1228          optional_cols = set(optional_names)
1229          if len(expected_required_cols) == 1 and isinstance(pf_input, np.ndarray):
1230              # for schemas with a single column, match input with column
1231              pf_input = {input_names[0]: pf_input}
1232              actual_cols = expected_required_cols
1233          elif isinstance(pf_input, pd.DataFrame):
1234              actual_cols = set(pf_input.columns)
1235          elif isinstance(pf_input, dict):
1236              actual_cols = set(pf_input.keys())
1237          missing_cols = expected_required_cols - actual_cols
1238          extra_cols = actual_cols - expected_required_cols - optional_cols
1239          # Preserve order from the original columns, since missing/extra columns are likely to
1240          # be in the same order.
1241          missing_cols = [c for c in input_names if c in missing_cols]
1242          extra_cols = [c for c in actual_cols if c in extra_cols]
1243          if missing_cols:
1244              # If the user has set MLFLOW_DISABLE_SCHEMA_DETAILS to true, we raise a generic error
1245              if MLFLOW_DISABLE_SCHEMA_DETAILS.get():
1246                  message = "Input schema validation failed. Mismatched or missing input(s)."
1247                  if extra_cols:
1248                      message += " Note that there were extra inputs provided."
1249              else:
1250                  message = f"Model is missing inputs {missing_cols}."
1251                  if extra_cols:
1252                      message += f" Note that there were extra inputs: {extra_cols}."
1253              raise MlflowException(message)
1254  
1255          if extra_cols:
1256              _logger.warning(
1257                  "Found extra inputs in the model input that are not defined in the model "
1258                  f"signature: `{extra_cols}`. These inputs will be ignored."
1259              )
1260      elif not input_schema.is_tensor_spec():
1261          # The model signature does not specify column names => we can only verify column count.
1262          num_actual_columns = len(pf_input.columns)
1263          if num_actual_columns < len(input_schema.inputs):
1264              raise MlflowException(
1265                  "Model inference is missing inputs. The model signature declares "
1266                  "{} inputs but the provided value only has "
1267                  "{} inputs. Note: the inputs were not named in the signature so we can "
1268                  "only verify their count.".format(len(input_schema.inputs), num_actual_columns)
1269              )
1270      if input_schema.is_tensor_spec():
1271          return _enforce_tensor_schema(pf_input, input_schema)
1272      elif HAS_PYSPARK and isinstance(original_pf_input, SparkDataFrame):
1273          return _enforce_pyspark_dataframe_schema(
1274              original_pf_input, pf_input, input_schema, flavor=flavor
1275          )
1276      else:
1277          # pf_input must be a pandas Dataframe at this point
1278          return (
1279              _enforce_named_col_schema(pf_input, input_schema)
1280              if input_schema.has_input_names()
1281              else _enforce_unnamed_col_schema(pf_input, input_schema)
1282          )
1283  
1284  
1285  def _enforce_pyspark_dataframe_schema(
1286      original_pf_input: SparkDataFrame,
1287      pf_input_as_pandas,
1288      input_schema: Schema,
1289      flavor: str | None = None,
1290  ):
1291      """
1292      Enforce that the input PySpark DataFrame conforms to the model's input schema.
1293  
1294      This function creates a new DataFrame that only includes the columns from the original
1295      DataFrame that are declared in the model's input schema. Any extra columns in the original
1296      DataFrame are dropped. Note that this function does not modify the original DataFrame.
1297  
1298      Args:
1299          original_pf_input: Original input PySpark DataFrame.
1300          pf_input_as_pandas: Input DataFrame converted to pandas.
1301          input_schema: Expected schema of the input DataFrame.
1302          flavor: Optional model flavor. If specified, it is used to handle specific behaviors
1303              for different model flavors. Currently, only the '_FEATURE_STORE_FLAVOR' is
1304              handled specially.
1305  
1306      Returns:
1307          New PySpark DataFrame that conforms to the model's input schema.
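
          Example (illustrative; assumes an active SparkSession ``spark`` and an
          ``input_schema`` declaring only columns "a" and "b")::

              df = spark.createDataFrame([(1.0, "x", 42)], ["a", "b", "extra"])
              enforced = _enforce_pyspark_dataframe_schema(df, df.toPandas(), input_schema)
              # ``enforced`` is a copy of ``df`` with the "extra" column dropped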
1308      """
1309      if not HAS_PYSPARK:
1310          raise MlflowException("PySpark is not installed. Cannot handle a PySpark DataFrame.")
1311      new_pf_input = original_pf_input.alias("pf_input_copy")
1312      if input_schema.has_input_names():
1313          _enforce_named_col_schema(pf_input_as_pandas, input_schema)
1314          input_names = input_schema.input_names()
1316      else:
1317          _enforce_unnamed_col_schema(pf_input_as_pandas, input_schema)
1318          input_names = pf_input_as_pandas.columns[: len(input_schema.inputs)]
1319      columns_to_drop = []
1320      columns_not_dropped_for_feature_store_model = []
1321      for col, dtype in new_pf_input.dtypes:
1322          if col not in input_names:
1323              # to support backwards compatibility with feature store models
1324              if any(x in dtype for x in ["array", "map", "struct"]):
1325                  if flavor == _FEATURE_STORE_FLAVOR:
1326                      columns_not_dropped_for_feature_store_model.append(col)
1327                      continue
1328              columns_to_drop.append(col)
1329      if columns_not_dropped_for_feature_store_model:
1330          _logger.warning(
1331              "The following columns are not in the model signature but "
1332              "are not dropped for the feature store model: %s",
1333              ", ".join(columns_not_dropped_for_feature_store_model),
1334          )
1335      return new_pf_input.drop(*columns_to_drop)
1336  
1337  
1338  def _enforce_datatype(data: Any, dtype: DataType, required=True):
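          # Illustrative (hypothetical) behavior: _enforce_datatype(1, DataType.long)
          # returns np.int64(1), while _enforce_datatype("abc", DataType.long) raises
          # an MlflowException because the string cannot be cast to long.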
1339      if not required and _is_none_or_nan(data):
1340          return None
1341  
1342      if not isinstance(dtype, DataType):
1343          raise MlflowException(f"Expected dtype to be DataType, got {type(dtype).__name__}")
1344      if not np.isscalar(data):
1345          raise MlflowException(f"Expected data to be scalar, got {type(data).__name__}")
1346      # Reuse logic in _enforce_mlflow_datatype for type conversion
1347      pd_series = pd.Series(data)
1348      try:
1349          pd_series = _enforce_mlflow_datatype("", pd_series, dtype)
1350      except MlflowException:
1351          # INVALID_PARAMETER_VALUE signals that the data failed schema enforcement
1352          raise MlflowException(
1353              f"Failed to enforce schema of data `{data}` with dtype `{dtype.name}`",
1354              error_code=INVALID_PARAMETER_VALUE,
1356          )
1357      return pd_series[0]
1358  
1359  
1360  def _enforce_array(data: Any, arr: Array, required: bool = True):
1361      """
1362      Enforce data against an Array type.
1363      If the field is required, then the data must be provided.
1364      If Array's internal dtype is AnyType, then None and empty lists are also accepted.
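
          Illustrative example: _enforce_array(["a", "b"], Array(DataType.string))
          returns ["a", "b"]; a non-list, non-ndarray input raises an MlflowException.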
1365      """
1366      if not required or isinstance(arr.dtype, AnyType):
1367          if data is None or (isinstance(data, (list, np.ndarray)) and len(data) == 0):
1368              return data
1369  
1370      if not isinstance(data, (list, np.ndarray)):
1371          raise MlflowException(f"Expected data to be list or numpy array, got {type(data).__name__}")
1372  
1373      if isinstance(arr.dtype, DataType):
1374          # TODO: this is still significantly slower than direct np.asarray dtype conversion
1375          # pd.Series conversion can be removed once we support direct validation on the numpy array
1376          data_enforced = (
1377              _enforce_mlflow_datatype("", pd.Series(data), arr.dtype).to_numpy(
1378                  dtype=arr.dtype.to_numpy()
1379              )
1380              if len(data) > 0
1381              else data
1382          )
1383      else:
1384          data_enforced = [_enforce_type(x, arr.dtype, required=required) for x in data]
1385  
1386      if isinstance(data, list) and isinstance(data_enforced, np.ndarray):
1387          data_enforced = data_enforced.tolist()
1388      elif isinstance(data, np.ndarray) and isinstance(data_enforced, list):
1389          data_enforced = np.array(data_enforced)
1390  
1391      return data_enforced
1392  
1393  
1394  def _enforce_property(data: Any, property: Property):
1395      return _enforce_type(data, property.dtype, required=property.required)
1396  
1397  
1398  def _enforce_object(data: dict[str, Any], obj: Object, required: bool = True):
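          # Illustrative (hypothetical) example: for Object([Property("a", DataType.string)]),
          # {"a": "x"} passes, {} raises for the missing required property "a", and
          # {"a": "x", "b": 1} raises for the undeclared property "b".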
1399      if HAS_PYSPARK and isinstance(data, Row):
1400          data = None if len(data) == 0 else data.asDict(True)
1401      if not required and (data is None or data == {}):
1402          return data
1403      if not isinstance(data, dict):
1404          raise MlflowException(
1405              f"Failed to enforce schema of '{data}' with type '{obj}'. "
1406              f"Expected data to be dictionary, got {type(data).__name__}"
1407          )
1408      if not isinstance(obj, Object):
1409          raise MlflowException(
1410              f"Failed to enforce schema of '{data}' with type '{obj}'. "
1411              f"Expected obj to be Object, got {type(obj).__name__}"
1412          )
1413      properties = {prop.name: prop for prop in obj.properties}
1414      required_props = {k for k, prop in properties.items() if prop.required}
1415      if missing_props := required_props - set(data.keys()):
1416          raise MlflowException(f"Missing required properties: {missing_props}")
1417      if invalid_props := data.keys() - properties.keys():
1418          raise MlflowException(
1419              f"Found invalid properties not defined in the schema: {invalid_props}"
1420          )
1421      for k, v in data.items():
1422          try:
1423              data[k] = _enforce_property(v, properties[k])
1424          except MlflowException as e:
1425              raise MlflowException(
1426                  f"Failed to enforce schema for key `{k}`. "
1427                  f"Expected type {properties[k].to_dict()[k]['type']}, "
1428                  f"received type {type(v).__name__}"
1429              ) from e
1430      return data
1431  
1432  
1433  def _enforce_map(data: Any, map_type: Map, required: bool = True):
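          # Illustrative (hypothetical) example: _enforce_map({"k": 1}, Map(DataType.long))
          # returns the dict with the value coerced to np.int64; a non-string key such as
          # {1: "v"} raises an MlflowException.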
1434      if (not required or isinstance(map_type.value_type, AnyType)) and (data is None or data == {}):
1435          return data
1436  
1437      if not isinstance(data, dict):
1438          raise MlflowException(f"Expected data to be a dict, got {type(data).__name__}")
1439  
1440      if not all(isinstance(k, str) for k in data):
1441          raise MlflowException("Expected all keys in map type data to be of string type.")
1442  
1443      return {k: _enforce_type(v, map_type.value_type, required=required) for k, v in data.items()}
1444  
1445  
1446  def _enforce_type(data: Any, data_type: DataType | Array | Object | Map, required=True):
1447      if isinstance(data_type, DataType):
1448          return _enforce_datatype(data, data_type, required=required)
1449      if isinstance(data_type, Array):
1450          return _enforce_array(data, data_type, required=required)
1451      if isinstance(data_type, Object):
1452          return _enforce_object(data, data_type, required=required)
1453      if isinstance(data_type, Map):
1454          return _enforce_map(data, data_type, required=required)
1455      if isinstance(data_type, AnyType):
1456          return data
1457      raise MlflowException(f"Invalid data type: {data_type!r}")
1458  
1459  
1460  def validate_schema(data: PyFuncInput, expected_schema: Schema) -> None:
1461      """
1462      Validate that the input data has the expected schema.
1463  
1464      Args:
1465          data: Input data to be validated. Supported types are:
1466  
1467              - pandas.DataFrame
1468              - pandas.Series
1469              - numpy.ndarray
1470              - scipy.sparse.csc_matrix
1471              - scipy.sparse.csr_matrix
1472              - List[Any]
1473              - Dict[str, Any]
1474              - str
1475  
1476          expected_schema: Expected Schema of the input data.
1477  
1478      Raises:
1479          mlflow.exceptions.MlflowException: when the input data does not match the schema.
1480  
1481      .. code-block:: python
1482          :caption: Example usage of validate_schema
1483  
1484          import mlflow.models
1485  
1486          # Suppose you've already got a model_uri
1487          model_info = mlflow.models.get_model_info(model_uri)
1488          # Get model signature directly
1489          model_signature = model_info.signature
1490          # validate schema
1491          mlflow.models.validate_schema(input_data, model_signature.inputs)
1492      """
1493  
1494      _enforce_schema(data, expected_schema)
1495  
1496  
1497  def add_libraries_to_model(model_uri, run_id=None, registered_model_name=None):
1498      """
1499      Given a registered model_uri (e.g. models:/<model_name>/<model_version>), this utility
1500      re-logs the model along with all the required model libraries back to the Model Registry.
1501      The required model libraries are stored along with the model as model artifacts. In
1502      addition, supporting files to the model (e.g. conda.yaml, requirements.txt) are modified
1503      to use the added libraries.
1504  
1505      By default, this utility creates a new model version under the same registered model specified
1506      by ``model_uri``. This behavior can be overridden by specifying the ``registered_model_name``
1507      argument.
1508  
1509      Args:
1510          model_uri: A registered model uri in the Model Registry of the form
1511              models:/<model_name>/<model_version/stage/latest>
1512          run_id: The ID of the run to which the model with libraries is logged. If None, the model
1513              with libraries is logged to the source run corresponding to the model version
1514              specified by ``model_uri``; if the model version does not have a source run, a
1515              new run is created.
1516          registered_model_name: The new model version (model with its libraries) is
1517              registered under the inputted registered_model_name. If None, a
1518              new version is logged to the existing model in the Model Registry.
1519  
1520      .. note::
1521          This utility only operates on a model that has been registered to the Model Registry.
1522  
1523      .. note::
1524          The libraries are only compatible with the platform on which they are added. Cross platform
1525          libraries are not supported.
1526  
1527      .. code-block:: python
1528          :caption: Example
1529  
1530          # Create and log a model to the Model Registry
1531          import pandas as pd
1532          from sklearn import datasets
1533          from sklearn.ensemble import RandomForestClassifier
1534          import mlflow
1535          import mlflow.sklearn
1536          from mlflow.models import infer_signature
1537  
1538          with mlflow.start_run():
1539              iris = datasets.load_iris()
1540              iris_train = pd.DataFrame(iris.data, columns=iris.feature_names)
1541              clf = RandomForestClassifier(max_depth=7, random_state=0)
1542              clf.fit(iris_train, iris.target)
1543              signature = infer_signature(iris_train, clf.predict(iris_train))
1544              mlflow.sklearn.log_model(
1545                  clf,
1546                  name="iris_rf",
1547                  signature=signature,
1548                  registered_model_name="model-with-libs",
1549              )
1550  
1551          # model uri for the above model
1552          model_uri = "models:/model-with-libs/1"
1553  
1554          # Import utility
1555          from mlflow.models.utils import add_libraries_to_model
1556  
1557          # Log libraries to the original run of the model
1558          add_libraries_to_model(model_uri)
1559  
1560          # Log libraries to some run_id
1561          existing_run_id = "21df94e6bdef4631a9d9cb56f211767f"
1562          add_libraries_to_model(model_uri, run_id=existing_run_id)
1563  
1564          # Log libraries to a new run
1565          with mlflow.start_run():
1566              add_libraries_to_model(model_uri)
1567  
1568          # Log libraries to a new registered model named 'new-model'
1569          with mlflow.start_run():
1570              add_libraries_to_model(model_uri, registered_model_name="new-model")
1571      """
1572  
1573      import mlflow
1574      from mlflow.models.wheeled_model import WheeledModel
1575  
1576      if mlflow.active_run() is None:
1577          if run_id is None:
1578              run_id = get_model_version_from_model_uri(model_uri).run_id
1579          with mlflow.start_run(run_id):
1580              return WheeledModel.log_model(model_uri, registered_model_name)
1581      else:
1582          return WheeledModel.log_model(model_uri, registered_model_name)
1583  
1584  
1585  def get_model_version_from_model_uri(model_uri):
1586      """
1587      Helper function to fetch a model version from a model uri of the form
1588      models:/<model_name>/<model_version/stage/latest>.
1589      """
1590      import mlflow
1591      from mlflow import MlflowClient
1592  
1593      databricks_profile_uri = (
1594          get_databricks_profile_uri_from_artifact_uri(model_uri) or mlflow.get_registry_uri()
1595      )
1596      client = MlflowClient(registry_uri=databricks_profile_uri)
1597      (name, version) = get_model_name_and_version(client, model_uri)
1598      return client.get_model_version(name, version)
1599  
1600  
1601  def _enforce_params_schema(params: dict[str, Any] | None, schema: ParamSchema | None):
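          # Illustrative (hypothetical) example: given a schema with a single parameter
          # "temperature" defaulting to 0.5, _enforce_params_schema({}, schema) returns
          # {"temperature": 0.5}; unrecognized keys are dropped with a warning.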
1602      if schema is None:
1603          if params in [None, {}]:
1604              return params
1605          params_info = (
1606              f"Ignoring provided params: {list(params.keys())}"
1607              if isinstance(params, dict)
1608              else "Ignoring invalid params (not a dictionary)."
1609          )
1610          _logger.warning(
1611              "`params` can only be specified at inference time if the model signature "
1612              f"defines a params schema. This model does not define a params schema. {params_info}",
1613          )
1614          return {}
1615      params = {} if params is None else params
1616      if not isinstance(params, dict):
1617          raise MlflowException.invalid_parameter_value(
1618              f"Parameters must be a dictionary. Got type '{type(params).__name__}'.",
1619          )
1620      if not isinstance(schema, ParamSchema):
1621          raise MlflowException.invalid_parameter_value(
1622              "Parameters schema must be an instance of ParamSchema. "
1623              f"Got type '{type(schema).__name__}'.",
1624          )
1625      if any(not isinstance(k, str) for k in params.keys()):
1626          _logger.warning(
1627              "Keys in parameters should be of type `str`, but received non-string keys. "
1628              "Converting all keys to string..."
1629          )
1630          params = {str(k): v for k, v in params.items()}
1631  
1632      allowed_keys = {param.name for param in schema.params}
1633      if ignored_keys := set(params) - allowed_keys:
1634          _logger.warning(
1635              f"Unrecognized params {list(ignored_keys)} are ignored for inference. "
1636              f"Supported params are: {allowed_keys}. "
1637              "To enable them, please add the corresponding schema in the ModelSignature."
1638          )
1639  
1640      params = {k: params[k] for k in params if k in allowed_keys}
1641  
1642      invalid_params = set()
1643      for param_spec in schema.params:
1644          if param_spec.name in params:
1645              try:
1646                  params[param_spec.name] = ParamSpec.validate_param_spec(
1647                      params[param_spec.name], param_spec
1648                  )
1649              except MlflowException as e:
1650                  invalid_params.add((param_spec.name, e.message))
1651          else:
1652              params[param_spec.name] = param_spec.default
1653  
1654      if invalid_params:
1655          raise MlflowException.invalid_parameter_value(
1656              f"Invalid parameters found: {invalid_params!r}",
1657          )
1658  
1659      return params
1660  
1661  
1662  def convert_complex_types_pyspark_to_pandas(value, dataType):
1663      # This function is needed because the default `asDict` function in PySpark
1664      # converts the data to Python types, which is not compatible with the schema enforcement.
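          # Illustrative example: an IntegerType value 1 becomes np.int32(1), and a
          # DateType value becomes a "YYYY-MM-DD" string.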
1665      type_mapping = {
1666          IntegerType: lambda v: np.int32(v),
1667          ShortType: lambda v: np.int16(v),
1668          FloatType: lambda v: np.float32(v),
1669          DateType: lambda v: v.strftime("%Y-%m-%d"),
1670          TimestampType: lambda v: v.strftime("%Y-%m-%d %H:%M:%S.%f"),
1671          BinaryType: lambda v: np.bytes_(v),
1672      }
1673      if value is None:
1674          return None
1675      if isinstance(dataType, StructType):
1676          return {
1677              field.name: convert_complex_types_pyspark_to_pandas(value[field.name], field.dataType)
1678              for field in dataType.fields
1679          }
1680      elif isinstance(dataType, ArrayType):
1681          return [
1682              convert_complex_types_pyspark_to_pandas(elem, dataType.elementType) for elem in value
1683          ]
1684      if converter := type_mapping.get(type(dataType)):
1685          return converter(value)
1686      return value
1687  
1688  
1689  def _is_in_comment(line, start):
1690      """
1691      Check if the code at the index "start" of the line is in a comment.
1692  
1693      Limitations: This function does not handle multi-line comments, and the # symbol could be in a
1694      string, or otherwise not indicate a comment.
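
          e.g. _is_in_comment("x = 1  # dbutils", 10) is True because a "#" appears
          before index 10.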
1695      """
1696      return "#" in line[:start]
1697  
1698  
1699  def _is_in_string_only(line, search_string):
1700      """
1701      Check whether every occurrence of search_string in the line is inside a quoted string.
1702  
1703      Limitations: This function does not handle multi-line strings.
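
          e.g. _is_in_string_only('print("dbutils")', "dbutils") is True because the
          match occurs only inside quotes, while _is_in_string_only('dbutils.fs.ls("/")',
          "dbutils") is False.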
1704      """
1705      # Regex for matching double quotes and everything inside
1706      double_quotes_regex = r"\"(\\.|[^\"])*\""
1707  
1708      # Regex for matching single quotes and everything inside
1709      single_quotes_regex = r"\'(\\.|[^\'])*\'"
1710  
1711      # Regex for matching search_string exactly
1712      search_string_regex = rf"({re.escape(search_string)})"
1713  
1714      # Concatenate the patterns using the OR operator '|'
1715      # This matches left to right - quotes first, search_string last
1716      pattern = double_quotes_regex + r"|" + single_quotes_regex + r"|" + search_string_regex
1717  
1718      # Iterate through all matches in the line
1719      for match in re.finditer(pattern, line):
1720          # If the regex matched on the search_string, we know that it did not match in quotes since
1721          # that is the order. So we know that the search_string exists outside of quotes
1722          # (at least once).
1723          if match.group() == search_string:
1724              return False
1725      return True
1726  
1727  
1728  def _validate_model_code_from_notebook(code):
1729      """
1730      Validate there isn't any code that would work in a notebook but not as exported Python file.
1731      For now, this checks for dbutils and magic commands.
1732      """
1733  
1734      output_code_list = []
1735      for line in code.splitlines():
1736          for match in re.finditer(r"\bdbutils\b", line):
1737              start = match.start()
1738              if not _is_in_comment(line, start) and not _is_in_string_only(line, "dbutils"):
1739                  _logger.warning(
1740                      "The model file uses 'dbutils' commands which are not supported. To ensure "
1741                      "your code functions correctly, make sure that it does not rely on these "
1742                      "dbutils commands."
1743                  )
1744          # Prefix any line containing MAGIC commands with a comment. When there is better support
1745          # for the Databricks workspace export API, we can get rid of this.
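              # Illustrative example: "%sql SELECT 1" becomes "# MAGIC %sql SELECT 1".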
1746          if line.startswith("%"):
1747              output_code_list.append("# MAGIC " + line)
1748          else:
1749              output_code_list.append(line)
1750      output_code = "\n".join(output_code_list)
1751  
1752      magic_regex = r"^# MAGIC %((?!pip)\S+).*"
1753      if re.search(magic_regex, output_code, re.MULTILINE):
1754          _logger.warning(
1755              "The model file uses magic commands which have been commented out. To ensure your code "
1756              "functions correctly, make sure that it does not rely on these magic commands for "
1757              "correctness."
1758          )
1759  
1760      return output_code.encode("utf-8")
1761  
1762  
1763  def _convert_llm_ndarray_to_list(data):
1764      """
1765      Convert numpy array in the input data to list, because numpy array is not json serializable.
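
          e.g. {"a": np.array([1, 2])} becomes {"a": [1, 2]}.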
1766      """
1767      if isinstance(data, np.ndarray):
1768          return data.tolist()
1769      if isinstance(data, list):
1770          return [_convert_llm_ndarray_to_list(d) for d in data]
1771      if isinstance(data, dict):
1772          return {k: _convert_llm_ndarray_to_list(v) for k, v in data.items()}
1773      # scalar values are also converted to numpy types, but they are
1774      # not acceptable by the model
1775      if np.isscalar(data) and isinstance(data, np.generic):
1776          return data.item()
1777      return data
1778  
1779  
1780  def _convert_llm_input_data(data: Any) -> list[Any] | dict[str, Any]:
1781      """
1782      Convert input data to a format that can be passed to the model with GenAI flavors such as
1783      LangChain and LLamaIndex.
1784  
1785      Args:
1786          data: Input data to be converted. We assume it is a single request payload, but it can be
1787              in any format such as a single scalar value, a dictionary, list (with one element),
1788              Pandas DataFrame, etc.
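
          Illustrative example: pd.DataFrame({0: ["hi"]}) is converted to ["hi"], while a
          DataFrame with named columns is converted to a list of record dicts.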
1789      """
1790      # This handles pyfunc / spark_udf inputs with a model signature. Schema enforcement converts
1791      # the input data to a pandas DataFrame, so we convert it back.
1792      if isinstance(data, pd.DataFrame):
1793          # if the data only contains a single key as 0, we assume the input
1794          # is either a string or list of strings
1795          if list(data.columns) == [0]:
1796              data = data.to_dict("list")[0]
1797          else:
1798              data = data.to_dict(orient="records")
1799  
1800      return _convert_llm_ndarray_to_list(data)
1801  
1802  
1803  def _databricks_path_exists(path: Path) -> bool:
1804      """
1805      Check if a path exists in Databricks workspace.
1806      """
1807      if not is_in_databricks_runtime():
1808          return False
1809  
1810      from databricks.sdk import WorkspaceClient
1811      from databricks.sdk.errors import ResourceDoesNotExist
1812  
1813      client = WorkspaceClient()
1814      try:
1815          client.workspace.get_status(str(path))
1816          return True
1817      except ResourceDoesNotExist:
1818          return False
1819  
1820  
1821  def _validate_and_get_model_code_path(model_code_path: str, temp_dir: str) -> str:
1822      """
1823      Validate that the model code path exists. If the model file cannot be opened directly on
1824      Databricks, create a temp file in temp_dir and validate its contents if it's a notebook.
1825  
1826      Returns either `model_code_path` or a temp file path with the contents of the notebook.
1827      """
1828  
1829      # If the path is not an absolute path, convert it
1830      model_code_path = Path(model_code_path).resolve()
1831  
1832      if not (model_code_path.exists() or _databricks_path_exists(model_code_path)):
1833          additional_message = (
1834              f" Perhaps you meant '{model_code_path}.py'?" if not model_code_path.suffix else ""
1835          )
1836  
1837          raise MlflowException.invalid_parameter_value(
1838              f"The provided model path '{model_code_path}' does not exist. "
1839              f"Ensure the file path is valid and try again.{additional_message}"
1840          )
1841  
1842      try:
1843          # If `model_code_path` points to a notebook on Databricks, this line throws either
1844          # a `FileNotFoundError` or an `OSError`. In this case, try to export the notebook as
1845          # a Python file.
1846          with open(model_code_path):
1847              pass
1848  
1849          return str(model_code_path)
1850      except Exception:
1851          pass
1852  
1853      try:
1854          from databricks.sdk import WorkspaceClient
1855          from databricks.sdk.service.workspace import ExportFormat
1856  
1857          w = WorkspaceClient()
1858          response = w.workspace.export(path=model_code_path, format=ExportFormat.SOURCE)
1859          decoded_content = base64.b64decode(response.content)
1860      except Exception:
1861          raise MlflowException.invalid_parameter_value(
1862              f"The provided model path '{model_code_path}' is not a valid Python file path or a "
1863              "Databricks Notebook file path containing the code for defining the chain "
1864              "instance. Ensure the file path is valid and try again."
1865          )
1866  
1867      _validate_model_code_from_notebook(decoded_content.decode("utf-8"))
1868      path = os.path.join(temp_dir, "model.py")
1869      with open(path, "wb") as f:
1870          f.write(decoded_content)
1871      return path
1872  
1873  
1874  @contextmanager
1875  def _config_context(config: str | dict[str, Any] | None = None):
1876      # If the config is None, set it to "" so that when the model is loaded,
1877      # ModelConfig can correctly determine whether a config was provided
1879      if config is None:
1880          config = ""
1881  
1882      _set_model_config(config)
1883      try:
1884          yield
1885      finally:
1886          _set_model_config(None)
1887  
1888  
1889  class MockDbutils:
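          """
          Stand-in for `dbutils` that swallows attribute access and calls.

          Illustrative example: MockDbutils().fs.ls("/") resolves each attribute lookup
          to a fresh MockDbutils and the final call returns None, so code referencing
          dbutils outside Databricks does not fail.
          """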
1890      def __init__(self, real_dbutils=None):
1891          self.real_dbutils = real_dbutils
1892  
1893      def __getattr__(self, name):
1894          try:
1895              if self.real_dbutils:
1896                  return getattr(self.real_dbutils, name)
1897          except AttributeError:
1898              pass
1899          return MockDbutils()
1900  
1901      def __call__(self, *args, **kwargs):
1902          pass
1903  
1904  
1905  @contextmanager
1906  def _mock_dbutils(globals_dict):
1907      module_name = "dbutils"
1908      original_module = sys.modules.get(module_name)
1909      sys.modules[module_name] = MockDbutils(original_module)
1910  
1911      # Inject module directly into the global namespace in case it is referenced without an import
1912      original_global = globals_dict.get(module_name)
1913      globals_dict[module_name] = MockDbutils(original_module)
1914  
1915      try:
1916          yield
1917      finally:
1918          if original_module is not None:
1919              sys.modules[module_name] = original_module
1920          else:
1921              del sys.modules[module_name]
1922  
1923          if original_global is not None:
1924              globals_dict[module_name] = original_global
1925          else:
1926              del globals_dict[module_name]
1927  
1928  
1929  # Python's module caching mechanism prevents the re-importation of previously loaded modules by
1930  # default. Once a module is imported, it's added to `sys.modules`, and subsequent import attempts
1931  # retrieve the cached module rather than re-importing it.
1932  # Here, we want to import the `code path` module multiple times during a single runtime session.
1933  # This function addresses this by dynamically importing the `code path` module under a unique,
1934  # dynamically generated module name. This bypasses the caching mechanism, as each import is
1935  # considered a separate module by the Python interpreter.
1936  def _load_model_code_path(code_path: str, model_config: str | dict[str, Any] | None):
1937      with _config_context(model_config):
1938          try:
1939              new_module_name = f"code_model_{uuid.uuid4().hex}"
1940              spec = importlib.util.spec_from_file_location(new_module_name, code_path)
1941              module = importlib.util.module_from_spec(spec)
1942              sys.modules[new_module_name] = module
1943              # Since dbutils only works in a Databricks environment, we need to mock it
1944              with _mock_dbutils(module.__dict__):
1945                  spec.loader.exec_module(module)
1946          except ImportError as e:
1947              raise MlflowException(
1948                  f"Failed to import code model from {code_path}. Error: {e!s}"
1949              ) from e
1950          except Exception as e:
1951              raise MlflowException(
1952                  f"Failed to run user code from {code_path}. "
1953                  f"Error: {e!s}. "
1954                  "Review the stack trace for more information."
1955              ) from e
1956  
1957      if mlflow.models.model.__mlflow_model__ is None:
1958          raise MlflowException(
1959              "If the model is logged as code, ensure the model is set using "
1960              "mlflow.models.set_model() within the code file."
1961          )
1962      return mlflow.models.model.__mlflow_model__
1963  
1964  
1965  def _flatten_nested_params(
1966      d: dict[str, Any], parent_key: str = "", sep: str = "/"
1967  ) -> dict[str, str]:
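          # e.g. {"a": {"b": 1}, "c": 2} -> {"a/b": 1, "c": 2}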
1968      items: dict[str, Any] = {}
1969      for k, v in d.items():
1970          new_key = f"{parent_key}{sep}{k}" if parent_key else k
1971          if isinstance(v, dict):
1972              items.update(_flatten_nested_params(v, new_key, sep=sep))
1973          else:
1974              items[new_key] = v
1975      return items
1976  
1977  
1978  # NB: this function should always be kept in sync with the serving
1979  # process in scoring_server invocations.
1980  def validate_serving_input(model_uri: str, serving_input: str | dict[str, Any]):
1981      """
1982      Helper function to validate that the model can be served and that the provided
1983      input is valid prior to serving the model.
1984  
1985      Args:
1986          model_uri: URI of the model to be served.
1987          serving_input: Input data to be validated. Should be a dictionary or a JSON string.
1988  
1989      Returns:
1990          The prediction result from the model.
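
          Example (illustrative; assumes ``model_uri`` points to a logged model)::

              serving_input = '{"inputs": [[1.0, 2.0]]}'
              validate_serving_input(model_uri, serving_input)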
1991      """
1992      from mlflow.pyfunc.scoring_server import _parse_json_data
1993      from mlflow.pyfunc.utils.environment import _simulate_serving_environment
1994  
1995      # An sklearn model might not have the python_function flavor if it
1996      # doesn't define a predict function. In that case, the model
1997      # cannot be served anyway.
1998  
1999      output_dir = None if get_local_path_or_none(model_uri) else create_tmp_dir()
2000  
2001      try:
2002          pyfunc_model = mlflow.pyfunc.load_model(model_uri, dst_path=output_dir)
2003          parsed_input = _parse_json_data(
2004              serving_input,
2005              pyfunc_model.metadata,
2006              pyfunc_model.metadata.get_input_schema(),
2007          )
2008          with _simulate_serving_environment():
2009              return pyfunc_model.predict(parsed_input.data, params=parsed_input.params)
2010      finally:
2011          if output_dir and os.path.exists(output_dir):
2012              shutil.rmtree(output_dir)
2013  
2014  
2015  def get_external_mlflow_model_spec(logged_model: LoggedModel) -> Model:
2016      """
2017      Create the MLflow Model specification for a given logged model whose artifacts
2018      (code, weights, etc.) are stored externally outside of MLflow.
2019  
2020      Args:
2021          logged_model: The external logged model for which to create an MLflow Model specification.
2022  
2023      Returns:
2024          Model: MLflow Model specification for the given logged model with external artifacts.
2025      """
2026      from mlflow.models.signature import infer_signature
2027  
2028      return Model(
2029          artifact_path=logged_model.artifact_location,
2030          model_uuid=logged_model.model_id,
2031          model_id=logged_model.model_id,
2032          run_id=logged_model.source_run_id,
2033          # Include a dummy signature so that the model can be registered to the Databricks Unity
2034          # Catalog Model Registry.
2035          # TODO: Remove this once the Databricks Unity Catalog Model Registry supports registration
2036          # of models without signatures
2037          signature=infer_signature(model_input=True, model_output=True),
2038          metadata={
2039              # Add metadata to the logged model indicating that its artifacts are stored externally.
2040              # This helps downstream consumers of the model, such as the Model Registry, easily
2041              # and consistently identify that the model's artifacts are external
2042              MLFLOW_MODEL_IS_EXTERNAL: True,
2043          },
2044      )