# tests/models/test_utils.py
import os
import random
import re
from typing import Any, NamedTuple
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sklearn.neighbors as knn
from sklearn import datasets
 11  
 12  import mlflow
 13  from mlflow import MlflowClient
 14  from mlflow.entities.model_registry import ModelVersion
 15  from mlflow.environment_variables import MLFLOW_DISABLE_SCHEMA_DETAILS
 16  from mlflow.exceptions import MlflowException
 17  from mlflow.models import add_libraries_to_model
 18  from mlflow.models.utils import (
 19      _config_context,
 20      _convert_llm_input_data,
 21      _enforce_array,
 22      _enforce_datatype,
 23      _enforce_mlflow_datatype,
 24      _enforce_object,
 25      _enforce_property,
 26      _flatten_nested_params,
 27      _validate_and_get_model_code_path,
 28      _validate_model_code_from_notebook,
 29      get_model_version_from_model_uri,
 30  )
 31  from mlflow.pyfunc import _enforce_schema, _validate_prediction_input
 32  from mlflow.types import DataType, Schema
 33  from mlflow.types.schema import Array, ColSpec, Object, Property
 34  
 35  
class ModelWithData(NamedTuple):
    """Bundle pairing a fitted model with data suitable for exercising it in tests."""

    # The trained estimator (e.g. a fitted sklearn classifier)
    model: Any
    # Input features to feed to the model at inference time
    inference_data: Any
 39  
 40  
 41  @pytest.fixture(scope="module")
 42  def sklearn_knn_model():
 43      iris = datasets.load_iris()
 44      X = iris.data[:, :2]  # we only take the first two features.
 45      y = iris.target
 46      knn_model = knn.KNeighborsClassifier()
 47      knn_model.fit(X, y)
 48      return ModelWithData(model=knn_model, inference_data=X)
 49  
 50  
 51  def random_int(lo=1, hi=1000000000):
 52      return random.randint(int(lo), int(hi))
 53  
 54  
 55  def test_adding_libraries_to_model_default(sklearn_knn_model):
 56      model_name = f"wheels-test-{random_int()}"
 57      artifact_path = "model"
 58      model_uri = f"models:/{model_name}/1"
 59      wheeled_model_uri = f"models:/{model_name}/2"
 60  
 61      # Log a model
 62      with mlflow.start_run():
 63          run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
 64          mlflow.sklearn.log_model(
 65              sklearn_knn_model.model,
 66              name=artifact_path,
 67              registered_model_name=model_name,
 68          )
 69  
 70      wheeled_model_info = add_libraries_to_model(model_uri)
 71      assert wheeled_model_info.run_id == run_id
 72  
 73      # Verify new model version created
 74      wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
 75      assert wheeled_model_version.run_id == run_id
 76      assert wheeled_model_version.name == model_name
 77  
 78  
 79  def test_adding_libraries_to_model_new_run(sklearn_knn_model):
 80      model_name = f"wheels-test-{random_int()}"
 81      artifact_path = "model"
 82      model_uri = f"models:/{model_name}/1"
 83      wheeled_model_uri = f"models:/{model_name}/2"
 84  
 85      # Log a model
 86      with mlflow.start_run():
 87          original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
 88          mlflow.sklearn.log_model(
 89              sklearn_knn_model.model,
 90              name=artifact_path,
 91              registered_model_name=model_name,
 92          )
 93  
 94      with mlflow.start_run():
 95          wheeled_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
 96          wheeled_model_info = add_libraries_to_model(model_uri)
 97      assert original_run_id != wheeled_run_id
 98      assert wheeled_model_info.run_id == wheeled_run_id
 99  
100      # Verify new model version created
101      wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
102      assert wheeled_model_version.run_id == wheeled_run_id
103      assert wheeled_model_version.name == model_name
104  
105  
def test_adding_libraries_to_model_run_id_passed(sklearn_knn_model):
    """An explicitly passed run_id wins over the run the model was originally logged in."""
    registered_name = f"wheels-test-{random_int()}"
    source_model_uri = f"models:/{registered_name}/1"
    wheeled_model_uri = f"models:/{registered_name}/2"

    # Register version 1 of the model under its own run
    with mlflow.start_run():
        original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=registered_name,
        )

    # Create a second (empty) run to host the wheeled model
    with mlflow.start_run():
        wheeled_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id

    wheeled_model_info = add_libraries_to_model(source_model_uri, run_id=wheeled_run_id)
    assert original_run_id != wheeled_run_id
    assert wheeled_model_info.run_id == wheeled_run_id

    # Version 2 should point at the explicitly provided run
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == wheeled_run_id
    assert wheeled_model_version.name == registered_name
132  
133  
def test_adding_libraries_to_model_new_model_name(sklearn_knn_model):
    """Passing registered_model_name registers the wheeled model under a new name."""
    source_name = f"wheels-test-{random_int()}"
    wheeled_name = f"wheels-test-{random_int()}"
    source_model_uri = f"models:/{source_name}/1"
    wheeled_model_uri = f"models:/{wheeled_name}/1"

    # Register version 1 of the source model
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=source_name,
        )

    with mlflow.start_run():
        new_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        wheeled_model_info = add_libraries_to_model(
            source_model_uri, registered_model_name=wheeled_name
        )
    assert wheeled_model_info.run_id == new_run_id

    # The wheeled model becomes version 1 of the *new* registered name
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == new_run_id
    assert wheeled_model_version.name == wheeled_name
    assert wheeled_name != source_name
161  
162  
def test_adding_libraries_to_model_when_version_source_None(sklearn_knn_model):
    """A model version carrying no run linkage still gets wheeled under a fresh run."""
    registered_name = f"wheels-test-{random_int()}"
    source_model_uri = f"models:/{registered_name}/1"

    # Register version 1 normally so the model artifacts exist
    with mlflow.start_run():
        original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name=registered_name,
        )

    # Simulate the registry returning a version stripped of its source run
    sourceless_version = ModelVersion(name=registered_name, version=1, creation_timestamp=124)
    assert sourceless_version.run_id is None
    with mock.patch.object(
        MlflowClient, "get_model_version", return_value=sourceless_version
    ) as get_model_version_mock:
        wheeled_model_info = add_libraries_to_model(source_model_uri)
        # A new run must be created since the version has no run to attach to
        assert wheeled_model_info.run_id is not None
        assert wheeled_model_info.run_id != original_run_id
        get_model_version_mock.assert_called_once_with(registered_name, "1")
186  
187  
@pytest.mark.parametrize(
    ("data", "data_type"),
    [
        ("string", DataType.string),
        (np.int32(1), DataType.integer),
        (np.int32(1), DataType.long),
        (np.int32(1), DataType.double),
        (True, DataType.boolean),
        (1.0, DataType.double),
        (np.float32(0.1), DataType.float),
        (np.float32(0.1), DataType.double),
        (np.int64(100), DataType.long),
        (np.datetime64("2023-10-13 00:00:00"), DataType.datetime),
    ],
)
def test_enforce_datatype(data, data_type):
    """Values compatible with the declared DataType compare equal after enforcement."""
    enforced = _enforce_datatype(data, data_type)
    assert enforced == data
205  
206  
def test_enforce_datatype_with_errors():
    """_enforce_datatype rejects non-DataType dtype arguments and incompatible values."""
    # dtype must be a DataType enum member, not a plain string
    with pytest.raises(MlflowException, match=r"Expected dtype to be DataType, got str"):
        _enforce_datatype("string", "string")

    # An int cannot be enforced against a string dtype
    with pytest.raises(
        MlflowException, match=r"Failed to enforce schema of data `123` with dtype `string`"
    ):
        _enforce_datatype(123, DataType.string)
215  
216  
@pytest.mark.parametrize(
    "dtype",
    [
        pd.StringDtype(),
        "string",
        object,
        None,  # infers object in pandas <3.0, StringDtype in pandas 3.0
    ],
)
def test_enforce_mlflow_datatype_with_string_dtype(dtype):
    """String-typed pandas Series pass through untouched (pandas 3.0 compatibility)."""
    values = pd.Series(["a", "b", "c"], dtype=dtype)
    enforced = _enforce_mlflow_datatype("col", values, DataType.string)
    # The very same Series object must come back, not a converted copy
    assert enforced is values
231  
232  
def test_enforce_object():
    """Dicts matching the Object schema — with or without optional members — pass through."""
    schema_obj = Object([
        Property("a", DataType.string),
        Property("b", DataType.binary, required=False),
        Property("c", Array(DataType.string)),
        Property(
            "d",
            Object([
                Property("str", DataType.string),
                Property("arr", Array(DataType.double), required=False),
            ]),
        ),
    ])

    # Fully populated payload, including both optional members
    payload = {
        "a": "some_sentence",
        "b": b"some_bytes",
        "c": ["sentence1", "sentence2"],
        "d": {"str": "value", "arr": [0.1, 0.2]},
    }
    assert _enforce_object(payload, schema_obj) == payload

    # Payload omitting the optional "b" and nested "arr" members
    payload = {"a": "some_sentence", "c": ["sentence1", "sentence2"], "d": {"str": "some_value"}}
    assert _enforce_object(payload, schema_obj) == payload
256  
257  
def test_enforce_object_with_errors():
    """_enforce_object rejects wrong container types, bad schema args, and property mismatches."""
    # The data argument must be a dict, not a list
    with pytest.raises(MlflowException, match=r"Expected data to be dictionary, got list"):
        _enforce_object(["some_sentence"], Object([Property("a", DataType.string)]))

    # The schema argument must be an Object, not a bare Property
    with pytest.raises(MlflowException, match=r"Expected obj to be Object, got Property"):
        _enforce_object({"a": "some_sentence"}, Property("a", DataType.string))

    obj = Object([Property("a", DataType.string), Property("b", DataType.string, required=False)])
    # Required property "a" absent from the payload
    with pytest.raises(MlflowException, match=r"Missing required properties: {'a'}"):
        _enforce_object({}, obj)

    # Property "c" is not declared in the schema
    with pytest.raises(
        MlflowException, match=r"Invalid properties not defined in the schema found: {'c'}"
    ):
        _enforce_object({"a": "some_sentence", "c": "some_sentence"}, obj)

    # Declared property present but with the wrong value type
    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce schema for key `a`. Expected type string, received type int",
    ):
        _enforce_object({"a": 1}, obj)
279  
280  
def test_enforce_property():
    """_enforce_property accepts scalar, array, and nested-object data matching the schema."""
    # Scalar string against a string property
    data = "some_sentence"
    prop = Property("a", DataType.string)
    assert _enforce_property(data, prop) == data

    # List of strings against a string-array property: returned unchanged
    data = ["some_sentence1", "some_sentence2"]
    prop = Property("a", Array(DataType.string))
    assert _enforce_property(data, prop) == data

    # Same list against a binary-array property: strings are coerced to bytes
    prop = Property("a", Array(DataType.binary))
    assert _enforce_property(data, prop) == [b"some_sentence1", b"some_sentence2"]

    # Numpy int32 array against an integer-array property
    data = np.array([np.int32(1), np.int32(2)])
    prop = Property("a", Array(DataType.integer))
    assert (_enforce_property(data, prop) == data).all()

    # Nested object with optional members; the Property's own name is not validated here
    data = {
        "a": "some_sentence",
        "b": b"some_bytes",
        "c": ["sentence1", "sentence2"],
        "d": {"str": "value", "arr": [0.1, 0.2]},
    }
    prop = Property(
        "any_name",
        Object([
            Property("a", DataType.string),
            Property("b", DataType.binary, required=False),
            Property("c", Array(DataType.string), required=False),
            Property(
                "d",
                Object([
                    Property("str", DataType.string),
                    Property("arr", Array(DataType.double), required=False),
                ]),
            ),
        ]),
    )
    assert _enforce_property(data, prop) == data
    # Optional members ("b", "c", and nested "arr") may be omitted entirely
    data = {"a": "some_sentence", "d": {"str": "some_value"}}
    assert _enforce_property(data, prop) == data
321  
322  
def test_enforce_property_with_errors():
    """_enforce_property rejects type mismatches and missing required object members."""
    # Scalar type mismatch: int where a string is declared
    with pytest.raises(
        MlflowException, match=r"Failed to enforce schema of data `123` with dtype `string`"
    ):
        _enforce_property(123, Property("a", DataType.string))

    # Required member "a" absent from the object payload
    with pytest.raises(MlflowException, match=r"Missing required properties: {'a'}"):
        _enforce_property(
            {"b": ["some_sentence1", "some_sentence2"]},
            Property(
                "any_name",
                Object([Property("a", DataType.string), Property("b", Array(DataType.string))]),
            ),
        )

    # Member "a" present but holds a list where a string is declared
    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce schema for key `a`. Expected type string, received type list",
    ):
        _enforce_property(
            {"a": ["some_sentence1", "some_sentence2"]},
            Property("any_name", Object([Property("a", DataType.string)])),
        )
346  
347  
@pytest.mark.parametrize(
    ("data", "schema"),
    [
        # 1. Flat list
        (["some_sentence1", "some_sentence2"], Array(DataType.string)),
        # 2. Nested list
        (
            [
                [["a", "b"], ["c", "d"]],
                [["e", "f", "g"], ["h"]],
                [[]],
            ],
            Array(Array(Array(DataType.string))),
        ),
        # 3. Array of Object
        (
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(
                Object([
                    Property("a", DataType.string),
                    Property("b", DataType.string, required=False),
                    Property("c", Array(DataType.string), required=False),
                ])
            ),
        ),
        # 4. Empty list
        ([], Array(DataType.string)),
    ],
)
def test_enforce_array_on_list(data, schema):
    """Lists (flat, nested, of objects, or empty) matching the schema compare equal afterward."""
    enforced = _enforce_array(data, schema)
    assert enforced == data
382  
383  
@pytest.mark.parametrize(
    ("data", "schema"),
    [
        # 1. 1D array
        (np.array(["some_sentence1", "some_sentence2"]), Array(DataType.string)),
        # 2. 2D array
        (
            np.array([
                ["a", "b"],
                ["c", "d"],
            ]),
            Array(Array(DataType.string)),
        ),
        # 3. Empty array
        (np.array([[], []]), Array(Array(DataType.string))),
    ],
)
def test_enforce_array_on_numpy_array(data, schema):
    """Numpy arrays matching the schema come back element-for-element equal."""
    enforced = _enforce_array(data, schema)
    assert (enforced == data).all()
403  
404  
def test_enforce_array_with_errors():
    """_enforce_array rejects non-list data, wrong element types, and bad object members."""
    # Data must be a list or numpy array, not a scalar string
    with pytest.raises(MlflowException, match=r"Expected data to be list or numpy array, got str"):
        _enforce_array("abc", Array(DataType.string))

    # Element type mismatch: ints where strings are declared
    with pytest.raises(MlflowException, match=r"Incompatible input types"):
        _enforce_array([123, 456, 789], Array(DataType.string))

    # Nested array with mixed type elements
    with pytest.raises(MlflowException, match=r"Incompatible input types"):
        _enforce_array([["a", "b"], [1, 2]], Array(Array(DataType.string)))

    # Nested array with different nest level
    with pytest.raises(MlflowException, match=r"Expected data to be list or numpy array, got str"):
        _enforce_array([["a", "b"], "c"], Array(Array(DataType.string)))

    # Missing properties in Object
    with pytest.raises(MlflowException, match=r"Missing required properties: {'b'}"):
        _enforce_array(
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(Object([Property("a", DataType.string), Property("b", DataType.string)])),
        )

    # Extra properties
    with pytest.raises(
        MlflowException, match=r"Invalid properties not defined in the schema found: {'c'}"
    ):
        _enforce_array(
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(
                Object([
                    Property("a", DataType.string),
                    Property("b", DataType.string, required=False),
                ])
            ),
        )
446  
447  
def test_model_code_validation():
    """_validate_model_code_from_notebook warns on dbutils/magic usage and comments out magics."""
    # Invalid code with dbutils
    invalid_code = "dbutils.library.restartPython()\nsome_python_variable = 5"

    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(invalid_code)
        mock_warning.assert_called_once_with(
            "The model file uses 'dbutils' commands which are not supported. To ensure your "
            "code functions correctly, make sure that it does not rely on these dbutils "
            "commands for correctness."
        )

    # Code with commented magic commands displays warning
    warning_code = "# dbutils.library.restartPython()\n# MAGIC %run ../wheel_installer"

    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(warning_code)
        mock_warning.assert_called_once_with(
            "The model file uses magic commands which have been commented out. To ensure your code "
            "functions correctly, make sure that it does not rely on these magic commands for "
            "correctness."
        )

    # Code with commented pip magic commands does not warn
    warning_code = "# MAGIC %pip install mlflow"
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(warning_code)
        mock_warning.assert_not_called()

    # Test valid code: returned as UTF-8 bytes, content unchanged
    valid_code = "some_valid_python_code = 'valid'"

    validated_code = _validate_model_code_from_notebook(valid_code).decode("utf-8")
    assert validated_code == valid_code

    # Test uncommented magic commands: they get rewritten as "# MAGIC ..." comments
    code_with_magic_command = (
        "valid_python_code = 'valid'\n%pip install sqlparse\nvalid_python_code = 'valid'\n# Comment"
    )
    expected_validated_code = (
        "valid_python_code = 'valid'\n# MAGIC %pip install sqlparse\nvalid_python_code = "
        "'valid'\n# Comment"
    )

    validated_code_with_magic_command = _validate_model_code_from_notebook(
        code_with_magic_command
    ).decode("utf-8")
    assert validated_code_with_magic_command == expected_validated_code
496  
497  
def test_config_context():
    """_config_context sets the global model-config path inside the block and resets it after."""
    with _config_context("tests/langchain/config.yml"):
        # Inside the context the global config path is the one supplied
        assert mlflow.models.model_config.__mlflow_model_config__ == "tests/langchain/config.yml"

    # On exit the global config path is restored to None
    assert mlflow.models.model_config.__mlflow_model_config__ is None
503  
504  
def test_flatten_nested_params():
    """_flatten_nested_params flattens nested dicts into separator-joined keys ('/' by default)."""
    nested_params = {
        "a": 1,
        "b": {"c": 2, "d": {"e": 3}},
        "f": {"g": {"h": 4}},
    }
    expected_flattened_params = {
        "a": 1,
        "b.c": 2,
        "b.d.e": 3,
        "f.g.h": 4,
    }
    # Custom separators are honored
    assert _flatten_nested_params(nested_params, sep=".") == expected_flattened_params
    assert _flatten_nested_params(nested_params, sep="/") == {
        "a": 1,
        "b/c": 2,
        "b/d/e": 3,
        "f/g/h": 4,
    }
    # Degenerate inputs: empty and already-flat dicts are returned as-is
    assert _flatten_nested_params({}) == {}

    params = {"a": 1, "b": 2, "c": 3}
    assert _flatten_nested_params(params) == params

    # Non-dict leaf values (lists, strings, None) are preserved, not recursed into
    params = {
        "a": 1,
        "b": {"c": 2, "d": {"e": 3, "f": [1, 2, 3]}, "g": "hello"},
        "h": {"i": None},
    }
    expected_flattened_params = {
        "a": 1,
        "b/c": 2,
        "b/d/e": 3,
        "b/d/f": [1, 2, 3],
        "b/g": "hello",
        "h/i": None,
    }
    assert _flatten_nested_params(params) == expected_flattened_params

    # Non-string keys are stringified in the flattened result
    nested_params = {1: {2: {3: 4}}, "a": {"b": {"c": 5}}}
    expected_flattened_params_mixed = {
        "1/2/3": 4,
        "a/b/c": 5,
    }
    assert _flatten_nested_params(nested_params) == expected_flattened_params_mixed

    # Realistic RAG-style config: only the nested dicts are flattened
    rag_params = {
        "workspace_url": "https://e2-dogfood.staging.cloud.databricks.com",
        "vector_search_endpoint_name": "dbdemos_vs_endpoint",
        "vector_search_index": "monitoring.rag.databricks_docs_index",
        "embedding_model_endpoint_name": "databricks-bge-large-en",
        "embedding_model_query_instructions": "Represent this sentence for searching",
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant for Databricks users.",
        "retriever_config": {"k": 5, "use_mmr": "false"},
        "llm_parameters": {"temperature": 0.01, "max_tokens": 200},
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
        "secret_scope": "dbdemos",
        "secret_key": "rag_sunish",
    }

    expected_rag_flattened_params = {
        "workspace_url": "https://e2-dogfood.staging.cloud.databricks.com",
        "vector_search_endpoint_name": "dbdemos_vs_endpoint",
        "vector_search_index": "monitoring.rag.databricks_docs_index",
        "embedding_model_endpoint_name": "databricks-bge-large-en",
        "embedding_model_query_instructions": "Represent this sentence for searching",
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant for Databricks users.",
        "retriever_config/k": 5,
        "retriever_config/use_mmr": "false",
        "llm_parameters/temperature": 0.01,
        "llm_parameters/max_tokens": 200,
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
        "secret_scope": "dbdemos",
        "secret_key": "rag_sunish",
    }

    assert _flatten_nested_params(rag_params) == expected_rag_flattened_params
584  
585  
@pytest.mark.parametrize(
    ("data", "target", "target_type"),
    [
        (pd.DataFrame([{"a": [1, 2, 3]}]), [{"a": [1, 2, 3]}], list),
        (pd.DataFrame([{"a": np.array([1, 2, 3])}]), [{"a": [1, 2, 3]}], list),
        (pd.DataFrame([{0: np.array(["abc"])[0]}]), ["abc"], list),
        (np.array([1, 2, 3]), [1, 2, 3], list),
        (np.array([123])[0], 123, int),
        (np.array(["abc"])[0], "abc", str),
    ],
)
def test_convert_llm_input_data(data, target, target_type):
    """_convert_llm_input_data unwraps DataFrames/ndarrays/numpy scalars into plain Python values."""
    result = _convert_llm_input_data(data)
    assert result == target
    # Exact-type check: `is` (not `==`, flake8 E721) so numpy scalar types that
    # merely compare equal to a builtin type cannot slip through.
    assert type(result) is target_type
601  
602  
@pytest.mark.parametrize(
    ("model_path", "error_message"),
    [
        (
            "model.py",
            f"The provided model path '{os.getcwd()}/model.py' does not exist. "
            "Ensure the file path is valid and try again.",
        ),
        (
            "model",
            f"The provided model path '{os.getcwd()}/model' does not exist. "
            "Ensure the file path is valid and try again. "
            f"Perhaps you meant '{os.getcwd()}/model.py'?",
        ),
    ],
)
def test_validate_and_get_model_code_path_not_found(model_path, error_message, tmp_path):
    """A nonexistent model path raises, suggesting a '.py' sibling when one would match."""
    # The expected message contains regex metacharacters ('.', '?', and whatever
    # appears in the cwd path), and pytest.raises treats `match` as a regex —
    # escape it so the assertion matches the message literally.
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        _validate_and_get_model_code_path(model_path, tmp_path)
622  
623  
def test_validate_and_get_model_code_path_success(tmp_path):
    """An existing model file path is returned unchanged."""
    existing_path = os.path.abspath(__file__)
    resolved = _validate_and_get_model_code_path(existing_path, tmp_path)

    assert resolved == existing_path
630  
631  
def test_suppress_schema_error(monkeypatch):
    """With MLFLOW_DISABLE_SCHEMA_DETAILS set, schema errors show only a generic message."""
    monkeypatch.setenv(MLFLOW_DISABLE_SCHEMA_DETAILS.name, "true")
    input_schema = Schema([
        ColSpec("double", "id"),
        ColSpec("string", "name"),
    ])
    # The "name" column required by the schema is absent from this input
    bad_input = pd.DataFrame({"id": [1, 2]}, dtype="float64")

    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce model input schema. Please check your input data.",
    ):
        _validate_prediction_input(bad_input, None, input_schema, None)
645  
646  
def test_enforce_schema_with_missing_and_extra_columns(monkeypatch):
    """The generic failure message covers both missing and extra columns when details are off."""
    monkeypatch.setenv(MLFLOW_DISABLE_SCHEMA_DETAILS.name, "true")
    expected_schema = Schema([
        ColSpec("long", "id"),
        ColSpec("string", "name"),
    ])
    # "name" is missing and "extra_col" is not declared in the schema
    frame = pd.DataFrame({"id": [1, 2], "extra_col": ["mlflow", "oss"]})
    with pytest.raises(
        MlflowException, match=r"Input schema validation failed.*extra inputs provided"
    ):
        _enforce_schema(frame, expected_schema)