/ mlflow / types / utils.py
utils.py
  1  import logging
  2  import warnings
  3  from collections import defaultdict
  4  from copy import deepcopy
  5  from typing import Any, Dict, List
  6  
  7  import numpy as np
  8  import pandas as pd
  9  import pydantic
 10  
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 13  from mlflow.types import DataType
 14  from mlflow.types.schema import (
 15      HAS_PYSPARK,
 16      AnyType,
 17      Array,
 18      ColSpec,
 19      Map,
 20      Object,
 21      ParamSchema,
 22      ParamSpec,
 23      Property,
 24      Schema,
 25      SparkMLVector,
 26      TensorSpec,
 27  )
 28  
 29  MULTIPLE_TYPES_ERROR_MSG = (
 30      "Expected all values in the list to be of the same type. To specify a model signature "
 31      "with a list containing elements of multiple types, define the signature manually "
 32      "using the Array(AnyType()) type from mlflow.models.schema."
 33  )
 34  _logger = logging.getLogger(__name__)
 35  
 36  
 37  class TensorsNotSupportedException(MlflowException):
 38      def __init__(self, msg):
 39          super().__init__(f"Multidimensional arrays (aka tensors) are not supported. {msg}")
 40  
 41  
 42  def _get_tensor_shape(data, variable_dimension: int | None = 0) -> tuple[int, ...]:
 43      """Infer the shape of the inputted data.
 44  
 45      This method creates the shape of the tensor to store in the TensorSpec. The variable dimension
 46      is assumed to be the first dimension by default. This assumption can be overridden by inputting
 47      a different variable dimension or `None` to represent that the input tensor does not contain a
 48      variable dimension.
 49  
 50      Args:
 51          data: Dataset to infer from.
 52          variable_dimension: An optional integer representing a variable dimension.
 53  
 54      Returns:
 55          tuple: Shape of the inputted data (including a variable dimension)
 56      """
 57      from scipy.sparse import csc_matrix, csr_matrix
 58  
 59      if not isinstance(data, (np.ndarray, csr_matrix, csc_matrix)):
 60          raise TypeError(f"Expected numpy.ndarray or csc/csr matrix, got '{type(data)}'.")
 61      variable_input_data_shape = data.shape
 62      if variable_dimension is not None:
 63          try:
 64              variable_input_data_shape = list(variable_input_data_shape)
 65              variable_input_data_shape[variable_dimension] = -1
 66          except IndexError:
 67              raise MlflowException(
 68                  f"The specified variable_dimension {variable_dimension} is out of bounds with "
 69                  f"respect to the number of dimensions {data.ndim} in the input dataset"
 70              )
 71      return tuple(variable_input_data_shape)
 72  
 73  
 74  def clean_tensor_type(dtype: np.dtype):
 75      """
 76      This method strips away the size information stored in flexible datatypes such as np.str_ and
 77      np.bytes_. Other numpy dtypes are returned unchanged.
 78  
 79      Args:
 80          dtype: Numpy dtype of a tensor
 81  
 82      Returns:
 83          dtype: Cleaned numpy dtype
 84      """
 85      if not isinstance(dtype, np.dtype):
 86          raise TypeError(
 87              f"Expected `type` to be instance of `{np.dtype}`, received `{dtype.__class__}`"
 88          )
 89  
 90      # Special casing for np.str_ and np.bytes_
 91      if dtype.char == "U":
 92          return np.dtype("str")
 93      elif dtype.char == "S":
 94          return np.dtype("bytes")
 95      return dtype
 96  
 97  
 98  def _infer_colspec_type(data: Any) -> DataType | Array | Object | AnyType:
 99      """
100      Infer an MLflow Colspec type from the dataset.
101  
102      Args:
103          data: data to infer from.
104  
105      Returns:
106          Object
107      """
108      dtype = _infer_datatype(data)
109  
110      if dtype is None:
111          raise MlflowException(
112              f"Numpy array must include at least one non-empty item. Invalid input `{data}`."
113          )
114  
115      return dtype
116  
117  
class InvalidDataForSignatureInferenceError(MlflowException):
    """MlflowException tagged with INVALID_PARAMETER_VALUE for inputs that cannot be used to
    infer a model signature (e.g. Pydantic objects passed to ``_infer_datatype``)."""

    def __init__(self, message):
        super().__init__(error_code=INVALID_PARAMETER_VALUE, message=message)
121  
122  
def _infer_datatype(data: Any) -> DataType | Array | Object | AnyType | None:
    """
    Infer the MLflow type of an arbitrary input value.

    Inputs map to inferred types as follows:
        - dict -> Object. There is one Property per key; properties whose value is inferred
          as AnyType are marked as not required.
        - list -> Array
        - numpy.ndarray -> Array
        - scalar -> DataType
        - None, NaN, empty dictionary/list -> AnyType

    .. Note::
        Empty numpy arrays are inferred as None to keep the backward compatibility, as numpy
        arrays are used by some traditional ML flavors.
        e.g. numpy.array([]) -> None, numpy.array([[], []]) -> None
        Empty lists, on the other hand, are inferred as AnyType since AnyType was introduced.
        e.g. [] -> AnyType, [[], []] -> Array(Any)

    Raises:
        InvalidDataForSignatureInferenceError: If ``data`` is a Pydantic model instance.
        MlflowException: If a dictionary value is an empty numpy array, or if the value (or an
            element of it) is not a supported type.
    """
    if isinstance(data, pydantic.BaseModel):
        raise InvalidDataForSignatureInferenceError(
            message="MLflow does not support inferring model signature from input example "
            "with Pydantic objects. To use Pydantic objects, define your PythonModel's "
            "`predict` method with a Pydantic type hint, and model signature will be automatically "
            "inferred when logging the model. e.g. "
            "`def predict(self, model_input: list[PydanticType])`. Check "
            "https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel "
            "for more details."
        )

    if _is_none_or_nan(data):
        return AnyType()

    if isinstance(data, dict):
        if not data:
            return AnyType()
        props = []
        for key, value in data.items():
            value_type = _infer_datatype(value)
            if value_type is None:
                raise MlflowException("Dictionary value must not be an empty numpy array.")
            is_required = not isinstance(value_type, AnyType)
            props.append(Property(name=key, dtype=value_type, required=is_required))
        return Object(properties=props)

    if isinstance(data, list):
        return _infer_array_datatype(data) if data else AnyType()

    if isinstance(data, np.ndarray):
        return _infer_array_datatype(data)

    return _infer_scalar_datatype(data)
169  
170  
def _infer_array_datatype(data: list[Any] | np.ndarray) -> Array | None:
    """Infer the Array type of a list or numpy array from its elements.

    Elements whose type cannot be determined (``_infer_datatype`` returns None) are skipped.
    The first determined element type seeds the result. Later element types must agree with it:
    - Array/Object/Map/AnyType element types are merged.
    - Plain DataTypes must be equal, except that AnyType elements (e.g. None) are tolerated.
    If no element has a determinable type (including an empty input), None is returned to
    signal "undetermined".

    E.g.
        ["a", "b"] => Array(string)
        ["a", None] => Array(string)
        [["a", "b"], []] => Array(Array(string))
        [["a", "b"], None] => Array(Array(string))
        [] => None
        [None] => Array(Any)

    Args:
        data: data to infer from.

    Returns:
        Array(dtype) or None if undetermined

    Raises:
        MlflowException: If elements have incompatible types.
    """
    merged = None
    for element in data:
        element_type = _infer_datatype(element)
        if element_type is None:
            # Undetermined (e.g. an empty numpy array): contributes nothing.
            continue

        if merged is None:
            merged = Array(element_type)
            continue

        current = merged.dtype
        if isinstance(current, (Array, Object, Map, AnyType)):
            try:
                merged = Array(current._merge(element_type))
            except MlflowException as e:
                raise MlflowException.invalid_parameter_value(MULTIPLE_TYPES_ERROR_MSG) from e
        elif isinstance(current, DataType):
            if not isinstance(element_type, AnyType) and element_type != current:
                raise MlflowException.invalid_parameter_value(MULTIPLE_TYPES_ERROR_MSG)
        else:
            raise MlflowException.invalid_parameter_value(
                f"{element_type} is not a valid type for an item of a list or numpy array."
            )
    return merged
213  
214  
# Exact-type lookup table for scalar values. _infer_scalar_datatype looks values up by
# type(value), so only the exact types listed here match. datetime is not included here;
# it is detected separately.
SCALAR_TO_DATATYPE_MAPPING = {
    **dict.fromkeys((bool, np.bool_), DataType.boolean),
    int: DataType.long,
    np.int64: DataType.long,
    np.int32: DataType.integer,
    float: DataType.double,
    np.float64: DataType.double,
    np.float32: DataType.float,
    **dict.fromkeys((str, np.str_, object), DataType.string),
    **dict.fromkeys((bytes, np.bytes_, bytearray), DataType.binary),
}
232  
233  
def _infer_scalar_datatype(data) -> DataType:
    """
    Infer the DataType of a scalar value.

    Lookups are tried in order:
    1. The exact type of ``data`` in SCALAR_TO_DATATYPE_MAPPING.
    2. A datetime check.
    3. If pyspark is available, whether ``data`` is an instance of the Spark type class produced
       by some DataType's ``to_spark()``.

    Raises:
        MlflowException: If none of the lookups match.
    """
    mapped = SCALAR_TO_DATATYPE_MAPPING.get(type(data))
    if mapped:
        return mapped

    if DataType.check_type(DataType.datetime, data):
        return DataType.datetime

    if HAS_PYSPARK:
        for candidate in DataType.all_types():
            spark_type_cls = type(candidate.to_spark())
            if isinstance(data, spark_type_cls):
                return candidate

    raise MlflowException.invalid_parameter_value(
        f"Data {data} is not one of the supported DataType"
    )
246  
247  
def _infer_schema(data: Any) -> Schema:
    """
    Infer an MLflow schema from a dataset.

    Some inputs are represented by :py:class:`TensorSpec`: numpy arrays, scipy sparse (csr/csc)
    matrices, and dictionaries whose values are all numpy arrays. All other inputted data types
    are specified by :py:class:`ColSpec`.

    A `TensorSpec` captures the data shape (default variable axis is 0), the data type (numpy.dtype)
    and an optional name for each individual tensor of the dataset.
    A `ColSpec` captures the data type (defined in :py:class:`DataType`) and an optional name for
    each individual column of the dataset.

    This method will raise an exception if the user data contains incompatible types or is not
    passed in one of the supported formats (containers).

    The input should be one of these:
      - pandas.DataFrame
      - pandas.Series
      - numpy.ndarray
      - dictionary of (name -> numpy.ndarray)
      - pyspark.sql.DataFrame
      - scipy.sparse.csr_matrix/csc_matrix
      - DataType
      - List[DataType]
      - Dict[str, Union[DataType, List, Dict]]
      - List[Dict[str, Union[DataType, List, Dict]]]

    The last two formats are used to represent complex data structures. For example,

        Input Data:
            [
                {
                    'text': 'some sentence',
                    'ids': ['id1'],
                    'dict': {'key': 'value'}
                },
                {
                    'text': 'some sentence',
                    'ids': ['id1', 'id2'],
                    'dict': {'key': 'value', 'key2': 'value2'}
                },
            ]

        The corresponding pandas DataFrame representation should look like this:

                        text         ids                                dict
            0  some sentence       [id1]                    {'key': 'value'}
            1  some sentence  [id1, id2]  {'key': 'value', 'key2': 'value2'}

        The inferred schema should look like this:

            Schema([
                ColSpec(type=DataType.string, name='text'),
                ColSpec(type=Array(dtype=DataType.string), name='ids'),
                ColSpec(
                    type=Object([
                        Property(name='key', dtype=DataType.string),
                        Property(name='key2', dtype=DataType.string, required=False)
                    ]),
                    name='dict',
                ),
            ])

    An empty list is inferred as a single string column for backward compatibility. If a
    non-tensor schema contains integer/long columns, a warning about missing values in integer
    columns is emitted.

    The element types should be mappable to one of :py:class:`mlflow.models.signature.DataType` for
    dataframes and to one of numpy types for tensors.

    Args:
        data: Dataset to infer from.

    Returns:
        Schema

    Raises:
        MlflowException: If ``data`` is not in one of the supported formats or its values cannot
            be mapped to MLflow types.
    """
    from scipy.sparse import csc_matrix, csr_matrix

    # To keep backward compatibility with < 2.9.0, an empty list is inferred as string.
    #   ref: https://github.com/mlflow/mlflow/pull/10125#discussion_r1372751487
    if isinstance(data, list) and not data:
        return Schema([ColSpec(DataType.string)])

    if isinstance(data, list) and all(isinstance(value, dict) for value in data):
        col_data_mapping = defaultdict(list)
        for item in data:
            for k, v in item.items():
                col_data_mapping[k].append(v)
        requiredness = {}
        for col in col_data_mapping:
            # A column is required only if every item contains it with a non-None value;
            # a missing key or an explicit None makes the column optional.
            requiredness[col] = all(item.get(col) is not None for item in data)

        # Each column's values form a non-empty list, which is inferred as an Array; the
        # column type is that Array's element type.
        schema = Schema([
            ColSpec(_infer_colspec_type(values).dtype, name=name, required=requiredness[name])
            for name, values in col_data_mapping.items()
        ])

    elif isinstance(data, dict):
        # dictionary of (name -> numpy.ndarray)
        if all(isinstance(values, np.ndarray) for values in data.values()):
            schema = Schema([
                TensorSpec(
                    type=clean_tensor_type(ndarray.dtype),
                    shape=_get_tensor_shape(ndarray),
                    name=name,
                )
                for name, ndarray in data.items()
            ])
        # Dict[str, Union[DataType, List, Dict]]
        else:
            if any(not isinstance(key, str) for key in data):
                raise MlflowException("The dictionary keys are not all strings.")
            schema = Schema([
                ColSpec(
                    _infer_colspec_type(value),
                    name=name,
                    required=_infer_required(value),
                )
                for name, value in data.items()
            ])
    # pandas.Series
    elif isinstance(data, pd.Series):
        name = getattr(data, "name", None)
        schema = Schema([
            ColSpec(
                type=_infer_pandas_column(data),
                name=name,
                required=_infer_required(data),
            )
        ])
    # pandas.DataFrame
    elif isinstance(data, pd.DataFrame):
        schema = Schema([
            ColSpec(
                type=_infer_pandas_column(data[col]),
                name=col,
                required=_infer_required(data[col]),
            )
            for col in data.columns
        ])
    # numpy.ndarray
    elif isinstance(data, np.ndarray):
        schema = Schema([
            TensorSpec(type=clean_tensor_type(data.dtype), shape=_get_tensor_shape(data))
        ])
    # scipy.sparse.csr_matrix/csc_matrix
    elif isinstance(data, (csc_matrix, csr_matrix)):
        schema = Schema([
            TensorSpec(type=clean_tensor_type(data.data.dtype), shape=_get_tensor_shape(data))
        ])
    # pyspark.sql.DataFrame
    elif _is_spark_df(data):
        schema = Schema([
            ColSpec(
                type=_infer_spark_type(field.dataType, data, field.name),
                name=field.name,
                # Avoid setting required field for spark dataframe
                # as the default value for spark df nullable is True
                # which counterparts to default required=True in ColSpec
            )
            for field in data.schema.fields
        ])
    elif isinstance(data, list):
        # Assume list as a single column
        # List[DataType]
        # e.g. ['some sentence', 'some sentence'] -> Schema([ColSpec(type=DataType.string)])
        # The corresponding pandas DataFrame representation should be pd.DataFrame(data)
        # We set required=True as unnamed optional inputs is not allowed
        # The (non-empty) list is inferred as an Array; the column type is its element type.
        schema = Schema([ColSpec(_infer_colspec_type(data).dtype)])
    else:
        # DataType
        # e.g. "some sentence" -> Schema([ColSpec(type=DataType.string)])
        try:
            # We set required=True as unnamed optional inputs is not allowed
            schema = Schema([ColSpec(_infer_colspec_type(data))])
        except MlflowException as e:
            raise MlflowException.invalid_parameter_value(
                "Failed to infer schema. Expected one of the following types:\n"
                "- pandas.DataFrame\n"
                "- pandas.Series\n"
                "- numpy.ndarray\n"
                "- dictionary of (name -> numpy.ndarray)\n"
                "- pyspark.sql.DataFrame\n"
                "- scipy.sparse.csr_matrix\n"
                "- scipy.sparse.csc_matrix\n"
                "- DataType\n"
                "- List[DataType]\n"
                "- Dict[str, Union[DataType, List, Dict]]\n"
                "- List[Dict[str, Union[DataType, List, Dict]]]\n"
                f"but got '{data}'.\n"
                f"Error: {e}",
            )
    if not schema.is_tensor_spec() and any(
        t in (DataType.integer, DataType.long) for t in schema.input_types()
    ):
        warnings.warn(
            "Hint: Inferred schema contains integer column(s). Integer columns in "
            "Python cannot represent missing values. If your input data contains "
            "missing values at inference time, it will be encoded as floats and will "
            "cause a schema enforcement error. The best way to avoid this problem is "
            "to infer the model schema based on a realistic data sample (training "
            "dataset) that includes missing values. Alternatively, you can declare "
            "integer columns as doubles (float64) whenever these columns may have "
            "missing values. See `Handling Integers With Missing Values "
            "<https://www.mlflow.org/docs/latest/models.html#"
            "handling-integers-with-missing-values>`_ for more details."
        )
    return schema
453  
454  
def _infer_numpy_dtype(dtype) -> DataType:
    """
    Map a numpy dtype (or pandas extension dtype) to an MLflow DataType.

    The mapping is driven by ``dtype.kind`` and ``dtype.itemsize``:
      - booleans ("b") -> boolean
      - signed ints up to 32 bits and unsigned ints up to 16 bits -> integer
      - signed 64-bit ints and unsigned 32-bit ints -> long
      - floats up to 32 bits -> float, floats up to 64 bits -> double
      - unicode strings ("U") -> string, byte strings ("S") -> binary
      - datetime64 ("M") -> datetime

    Args:
        dtype: A ``numpy.dtype`` or, when available, a pandas ``ExtensionDtype``.

    Returns:
        The corresponding DataType.

    Raises:
        TypeError: If ``dtype`` is not a supported dtype object.
        MlflowException: For object dtypes, whose type must be inferred from the values. Also
            for dtypes with no mapping above (e.g. unsigned 64-bit ints, floats wider than 64
            bits, complex numbers).
    """
    supported_types = np.dtype
    try:
        from pandas.core.dtypes.base import ExtensionDtype

        supported_types = (np.dtype, ExtensionDtype)
    except ImportError:
        # This version of pandas does not support extension types
        pass
    if not isinstance(dtype, supported_types):
        raise TypeError(f"Expected numpy.dtype or pandas.ExtensionDtype, got '{type(dtype)}'.")

    kind = dtype.kind
    if kind == "b":
        return DataType.boolean
    elif kind in {"i", "u"}:
        # uint32 is widened to long; uint64 falls through to the unsupported-type error below.
        if dtype.itemsize < 4 or (kind == "i" and dtype.itemsize == 4):
            return DataType.integer
        elif dtype.itemsize < 8 or (kind == "i" and dtype.itemsize == 8):
            return DataType.long
    elif kind == "f":
        if dtype.itemsize <= 4:
            return DataType.float
        elif dtype.itemsize <= 8:
            return DataType.double
    elif kind == "U":
        return DataType.string
    elif kind == "S":
        return DataType.binary
    elif kind == "O":
        # Object dtypes carry no element type information. Raise an MlflowException (an
        # Exception subclass, so existing `except Exception` handlers still catch it)
        # instead of a bare Exception.
        # NOTE(review): `_map_numpy_array` is not defined in this module; the hint may be stale.
        raise MlflowException(
            "Can not infer object without looking at the values, call _map_numpy_array instead."
        )
    elif kind == "M":
        return DataType.datetime
    raise MlflowException(f"Unsupported numpy data type '{dtype}', kind '{dtype.kind}'")
493  
494  
def _is_none_or_nan(x):
    """
    Return whether ``x`` is a missing-value marker.

    Floats (including float subclasses) count as missing when they are NaN. Any other value is
    missing only if it is one of the singletons None, pd.NA or pd.NaT, compared by identity.
    """
    if not isinstance(x, float):
        # NB: We can't use pd.isna() because the input can be a series.
        return any(x is marker for marker in (None, pd.NA, pd.NaT))
    return np.isnan(x)
500  
501  
def _infer_required(col) -> bool:
    """
    Decide whether a column or value should be marked as required.

    Lists and pandas Series are required when none of their elements is None/NaN, as judged by
    ``_is_none_or_nan``. Any other value is required when it is not itself None/NaN.
    """
    elements = col if isinstance(col, (list, pd.Series)) else (col,)
    return all(not _is_none_or_nan(element) for element in elements)
506  
507  
def _infer_pandas_column(col: pd.Series) -> DataType:
    """
    Infer the MLflow type of a single pandas Series.

    Object-dtype series are first passed through ``infer_objects()``. If they are still object
    dtype, the type is inferred from the Python values: it is the element type of the Array
    inferred for ``col.to_list()``. Otherwise the numpy or pandas extension dtype is mapped
    directly.

    Args:
        col: The series to inspect.

    Returns:
        The inferred MLflow type.

    Raises:
        TypeError: If ``col`` is not a pandas Series.
        MlflowException: If the series is not 1D, or if its object values cannot be inferred and
            it is not a string-typed series.
    """
    if not isinstance(col, pd.Series):
        raise TypeError(f"Expected pandas.Series, got '{type(col)}'.")
    if len(col.values.shape) > 1:
        raise MlflowException(f"Expected 1d array, got array with shape {col.shape}")

    if col.dtype.kind == "O":
        # Give pandas a chance to find a more specific dtype first.
        col = col.infer_objects()

    if col.dtype.kind != "O":
        # NB: The following works for numpy types as well as pandas extension types.
        return _infer_numpy_dtype(col.dtype)

    try:
        # The series as a list is inferred as an Array; the column type is its element type.
        return _infer_colspec_type(col.to_list()).dtype
    except Exception as e:
        # For backwards compatibility, we fall back to string
        # if the provided array is of string type
        if pd.api.types.is_string_dtype(col):
            return DataType.string
        raise MlflowException(f"Failed to infer schema for pandas.Series {col}. Error: {e}")
531  
532  
def _infer_spark_type(x, data=None, col_name=None) -> DataType:
    """
    Convert a Spark SQL data type into the corresponding MLflow schema type.

    Args:
        x: The ``pyspark.sql.types`` data type instance to convert.
        data: The Spark DataFrame holding the column. Only needed for ``MapType``, whose keys
            are collected from the data to build an ``Object`` type.
        col_name: Name of the top-level column in ``data``. Only needed for ``MapType``.

    Returns:
        A DataType for scalar Spark types, ``Array`` for ``ArrayType``, ``Object`` for
        ``StructType`` and ``MapType``, or ``SparkMLVector`` for ``VectorUDT``.

    Raises:
        MlflowException: In any of these cases:
            - The Spark type has no MLflow equivalent. This includes numeric types other than
              integral, float and double, e.g. ``DecimalType``.
            - A ``MapType`` is given without ``data``/``col_name``.
            - A map's values are themselves maps.
    """
    import pyspark.sql.types
    from pyspark.ml.linalg import VectorUDT
    from pyspark.sql.functions import col, collect_list

    spark_types = pyspark.sql.types

    # A flat chain (rather than nesting under NumericType) guarantees that numeric types
    # without a mapping reach the "unsupported" error below instead of returning None.
    if isinstance(x, spark_types.LongType):
        return DataType.long
    elif isinstance(x, spark_types.IntegralType):
        # Every other integral type maps to integer.
        return DataType.integer
    elif isinstance(x, spark_types.FloatType):
        return DataType.float
    elif isinstance(x, spark_types.DoubleType):
        return DataType.double
    elif isinstance(x, spark_types.BooleanType):
        return DataType.boolean
    elif isinstance(x, spark_types.StringType):
        return DataType.string
    elif isinstance(x, spark_types.BinaryType):
        return DataType.binary
    # NB: Spark differentiates date and timestamps, so both are mapped to DataType.datetime.
    elif isinstance(x, (spark_types.DateType, spark_types.TimestampType)):
        return DataType.datetime
    elif isinstance(x, spark_types.ArrayType):
        return Array(_infer_spark_type(x.elementType))
    elif isinstance(x, spark_types.StructType):
        return Object(
            properties=[
                Property(
                    name=f.name,
                    dtype=_infer_spark_type(f.dataType),
                    required=not f.nullable,
                )
                for f in x.fields
            ]
        )
    elif isinstance(x, spark_types.MapType):
        if data is None or col_name is None:
            raise MlflowException("Cannot infer schema for MapType without data and column name.")
        # Map MapType to StructType
        # Note that MapType assumes all values are of same type,
        # if they're not then spark picks the first item's type
        # and tries to convert rest to that type.
        # e.g.
        # >>> spark.createDataFrame([{"col": {"a": 1, "b": "b"}}]).show()
        # +-------------------+
        # |                col|
        # +-------------------+
        # |{a -> 1, b -> null}|
        # +-------------------+
        if isinstance(x.valueType, spark_types.MapType):
            raise MlflowException(
                "Please construct spark DataFrame with schema using StructType "
                "for dictionary/map fields, MLflow schema inference only supports "
                "scalar, array and struct types."
            )

        # Backtick-quote the column name so names containing dots, spaces or other special
        # characters are parsed as one top-level column; embedded backticks are doubled.
        quoted_col_name = "`" + str(col_name).replace("`", "``") + "`"
        merged_keys = (
            data
            .selectExpr(f"map_keys({quoted_col_name}) as keys")
            .agg(collect_list(col("keys")).alias("merged_keys"))
            .head()
            .merged_keys
        )
        keys = {key for sublist in merged_keys for key in sublist}
        return Object(
            properties=[
                Property(
                    name=k,
                    dtype=_infer_spark_type(x.valueType),
                )
                for k in keys
            ]
        )
    elif isinstance(x, VectorUDT):
        return SparkMLVector()

    # Anything else, including numeric types such as DecimalType, has no MLflow equivalent.
    raise MlflowException.invalid_parameter_value(
        f"Unsupported Spark Type '{type(x)}' for MLflow schema."
    )
615  
616  
def _is_spark_df(x) -> bool:
    """
    Return True if ``x`` is a pyspark DataFrame, otherwise False.

    Both the classic ``pyspark.sql.dataframe.DataFrame`` and the
    ``pyspark.sql.connect.dataframe.DataFrame`` class are recognised. Returns False whenever
    the corresponding pyspark module cannot be imported.
    """
    try:
        import pyspark.sql.dataframe
    except ImportError:
        return False
    if isinstance(x, pyspark.sql.dataframe.DataFrame):
        return True

    # For spark 4.0
    try:
        import pyspark.sql.connect.dataframe
    except ImportError:
        return False
    return isinstance(x, pyspark.sql.connect.dataframe.DataFrame)
632  
633  
def _validate_input_dictionary_contains_only_strings_and_lists_of_strings(data) -> None:
    """
    Validate the keys and values of a dictionary input.

    Keys must be strings or non-bool integers. Values must either all be numpy arrays, or each
    be a str, bytes, or a list containing only str/bytes items.

    Args:
        data: The dictionary to validate.

    Raises:
        MlflowException: If any key or value violates the rules above.
    """
    # bool is a subclass of int, so booleans must be rejected explicitly.
    invalid_keys = [
        key for key in data.keys() if isinstance(key, bool) or not isinstance(key, (str, int))
    ]
    if invalid_keys:
        raise MlflowException(
            f"The dictionary keys are not all strings or indexes. Invalid keys: {invalid_keys}"
        )

    ndarray_flags = [isinstance(value, np.ndarray) for value in data.values()]
    if any(ndarray_flags) and not all(ndarray_flags):
        raise MlflowException("The dictionary values are not all numpy.ndarray.")

    def _is_valid_value(value) -> bool:
        if isinstance(value, list):
            return all(isinstance(item, (str, bytes)) for item in value)
        return isinstance(value, (np.ndarray, str, bytes))

    invalid_values = [key for key, value in data.items() if not _is_valid_value(value)]
    if invalid_values:
        raise MlflowException.invalid_parameter_value(
            "Invalid values in dictionary. If passing a dictionary containing strings, all "
            "values must be either strings or lists of strings. If passing a dictionary containing "
            "numeric values, the data must be enclosed in a numpy.ndarray. The following keys "
            f"in the input dictionary are invalid: {invalid_values}",
        )
661  
662  
def _is_list_str(type_hint: Any) -> bool:
    """Return True if ``type_hint`` is ``List[str]`` or ``list[str]``."""
    accepted_hints = (
        List[str],  # noqa: UP006
        list[str],
    )
    return type_hint in accepted_hints
668  
669  
def _is_list_dict_str(type_hint: Any) -> bool:
    """Return True if ``type_hint`` spells a list of str-to-str dicts.

    Every combination of ``typing`` and builtin generics is accepted: ``List``/``list`` for the
    outer container and ``Dict``/``dict`` for the element type.
    """
    accepted_hints = (
        List[Dict[str, str]],  # noqa: UP006
        list[Dict[str, str]],  # noqa: UP006
        List[dict[str, str]],  # noqa: UP006
        list[dict[str, str]],
    )
    return type_hint in accepted_hints
677  
678  
def _get_array_depth(l: Any) -> int:
    """
    Return the nesting depth of ``l``.

    numpy arrays report their ``ndim``. A list is one level deeper than its deepest element,
    and an empty list has depth 1. Anything else has depth 0.
    """
    if isinstance(l, np.ndarray):
        return l.ndim
    if not isinstance(l, list):
        return 0
    if not l:
        return 1
    return 1 + max(map(_get_array_depth, l))
685  
686  
def _infer_type_and_shape(value):
    """
    Infer the MLflow type and shape of a single parameter value.

    Args:
        value: The parameter value. Supported forms are:
            - a 1D list or numpy array
            - a datetime value
            - a numpy scalar (per ``np.isscalar``)
            - a dictionary, which is inferred as an Object type

    Returns:
        A ``(type, shape)`` tuple where ``shape`` is ``(-1,)`` for 1D arrays and ``None``
        for scalars and dictionaries.

    Raises:
        MlflowException: If a list/array is not 1D, a scalar's dtype cannot be mapped to an
            MLflow type, or the value has an unsupported type.
    """
    if isinstance(value, (list, np.ndarray)):
        ndim = _get_array_depth(value)
        if ndim != 1:
            raise MlflowException.invalid_parameter_value(
                f"Expected parameters to be 1D array or scalar, got {ndim}D array",
            )
        # NOTE(review): all() is vacuously True for an empty list/array, so an empty value is
        # inferred as a datetime array -- confirm this is intended.
        if all(DataType.check_type(DataType.datetime, v) for v in value):
            return DataType.datetime, (-1,)
        value_type = _infer_numpy_dtype(np.array(value).dtype)
        return value_type, (-1,)
    elif DataType.check_type(DataType.datetime, value):
        return DataType.datetime, None
    elif np.isscalar(value):
        try:
            value_type = _infer_numpy_dtype(np.array(value).dtype)
        except Exception as e:
            raise MlflowException.invalid_parameter_value(
                f"Failed to infer schema for parameter {value}: {e!r}"
            ) from e
        return value_type, None
    elif isinstance(value, dict):
        # reuse _infer_schema to infer schema for dict, wrapping it in a dictionary is
        # necessary to make sure value is inferred as Object
        schema = _infer_schema({"value": value})
        object_type = schema.inputs[0].type
        return object_type, None
    raise MlflowException.invalid_parameter_value(
        f"Expected parameters to be 1D array or scalar, got {type(value).__name__}",
    )
717  
718  
def _infer_param_schema(parameters: dict[str, Any]):
    """
    Build a ParamSchema from a mapping of parameter names to default values.

    Each value's type and shape come from ``_infer_type_and_shape``, and a deep copy of the value
    becomes the parameter's default. Failures are collected across all parameters and reported
    together.

    Args:
        parameters: Mapping of parameter name to default value.

    Returns:
        A ParamSchema with one ParamSpec per parameter, in the dict's iteration order.

    Raises:
        MlflowException: If ``parameters`` is not a dict, or if inference or ParamSpec
            construction fails for any parameter.
    """
    if not isinstance(parameters, dict):
        raise MlflowException.invalid_parameter_value(
            f"Expected parameters to be dict, got {type(parameters).__name__}",
        )

    specs = []
    failures = []
    for param_name, default_value in parameters.items():
        try:
            dtype, shape = _infer_type_and_shape(default_value)
            spec = ParamSpec(
                name=param_name, dtype=dtype, default=deepcopy(default_value), shape=shape
            )
        except Exception as err:
            failures.append((param_name, default_value, err))
        else:
            specs.append(spec)

    if failures:
        raise MlflowException.invalid_parameter_value(
            f"Failed to infer schema for parameters: {failures}",
        )

    return ParamSchema(specs)