# tests/models/test_python_api.py
import datetime
import json
import os
import sys
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import scipy.sparse

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.python_api import (
    _CONTENT_TYPE_CSV,
    _CONTENT_TYPE_JSON,
    _serialize_input_data,
)
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.utils.env_manager import CONDA, LOCAL, UV, VIRTUALENV

from tests.tracing.helper import get_traces


@pytest.mark.parametrize(
    ("input_data", "expected_data", "content_type"),
    [
        (
            "x,y\n1,3\n2,4",
            pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
            _CONTENT_TYPE_CSV,
        ),
        (
            {"a": [1]},
            {"a": np.array([1])},
            _CONTENT_TYPE_JSON,
        ),
        (
            1,
            np.array(1),
            _CONTENT_TYPE_JSON,
        ),
        (
            np.array([1, 2, 3]),
            np.array([1, 2, 3]),
            _CONTENT_TYPE_JSON,
        ),
        (
            scipy.sparse.csc_matrix([[1, 2], [3, 4]]),
            np.array([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
        ),
        (
            # uLLM input, no change
            {"input": "some_data"},
            {"input": "some_data"},
            _CONTENT_TYPE_JSON,
        ),
    ],
)
@pytest.mark.parametrize(
    "env_manager",
    [VIRTUALENV, UV],
)
def test_predict(input_data, expected_data, content_type, env_manager):
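    """Each supported input type should deserialize to the expected object inside the
    restored model environment; the model's own predict() performs the assertions."""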
    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            if isinstance(model_input, pd.DataFrame):
                assert model_input.equals(expected_data)
            elif isinstance(model_input, np.ndarray):
                assert np.array_equal(model_input, expected_data)
            else:
                assert model_input == expected_data
            return {}

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            extra_pip_requirements=["pytest"],
        )

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=input_data,
        content_type=content_type,
        env_manager=env_manager,
    )


@pytest.mark.parametrize(
    "env_manager",
    [VIRTUALENV, CONDA, UV],
)
def test_predict_with_pip_requirements_override(env_manager):
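    """pip_requirements_override should both add a new package (xgboost) and override
    a logged requirement (scikit-learn 1.3.2 -> 1.3.0) in the restored environment."""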
    if env_manager == CONDA and sys.platform == "win32":
        pytest.skip("Skipping conda tests on Windows")

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            # XGBoost should be installed by pip_requirements_override
            import xgboost

            assert xgboost.__version__ == "1.7.3"

            # Scikit-learn version should be overridden to 1.3.0 by pip_requirements_override
            import sklearn

            assert sklearn.__version__ == "1.3.0"

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            extra_pip_requirements=["scikit-learn==1.3.2", "pytest"],
        )

    requirements_override = ["xgboost==1.7.3", "scikit-learn==1.3.0"]
    if env_manager == CONDA:
        # Install charset-normalizer from conda-forge to work around a pip-vs-conda issue during
        # CI tests. At the start of the CI run, MLflow dependencies are installed via pip,
        # including charset-normalizer. When this test case then creates the conda env,
        # charset-normalizer is installed via the default channel, which is one major version
        # behind the version installed via pip (as of Jan 2024). As a result, the environment
        # mixes the pip and conda versions and raises errors like "ImportError: cannot import
        # name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant'".
        # To work around this, we install the latest version from conda-forge.
        # TODO: Implement a better isolation approach for pip and conda environments during testing.
        requirements_override.append("conda-forge::charset-normalizer")
    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data={"inputs": [1, 2, 3]},
        content_type=_CONTENT_TYPE_JSON,
        pip_requirements_override=requirements_override,
        env_manager=env_manager,
    )


@pytest.mark.parametrize("env_manager", [VIRTUALENV, CONDA, UV])
def test_predict_with_model_alias(env_manager):
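    """A models:/<name>@<alias> URI should resolve to the aliased registered model version."""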
    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            registered_model_name="model_name",
        )
    client = mlflow.MlflowClient()
    client.set_registered_model_alias("model_name", "test_alias", 1)

    mlflow.models.predict(
        model_uri="models:/model_name@test_alias",
        input_data="abc",
        env_manager=env_manager,
        extra_envs={"TEST": "test"},
    )


@pytest.mark.parametrize("env_manager", [VIRTUALENV, CONDA, UV])
def test_predict_with_extra_envs(env_manager):
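    """Variables passed via extra_envs should be set in the prediction subprocess."""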
    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data="abc",
        content_type=_CONTENT_TYPE_JSON,
        env_manager=env_manager,
        extra_envs={"TEST": "test"},
    )


def test_predict_with_extra_envs_errors():
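    """extra_envs requires a non-local env_manager; and without extra_envs, the model's
    env-var assertion fails in the subprocess and surfaces as an MlflowException."""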
    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    with pytest.raises(
        MlflowException,
        match=r"Extra environment variables are only "
        r"supported when env_manager is set to 'virtualenv', 'conda' or 'uv'",
    ):
        mlflow.models.predict(
            model_uri=model_info.model_uri,
            input_data="abc",
            content_type=_CONTENT_TYPE_JSON,
            env_manager=LOCAL,
            extra_envs={"TEST": "test"},
        )

    with pytest.raises(
        MlflowException, match=r"An exception occurred while running model prediction"
    ):
        mlflow.models.predict(
            model_uri=model_info.model_uri,
            input_data="abc",
            content_type=_CONTENT_TYPE_JSON,
        )


@pytest.fixture
def mock_backend():
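    # Replace the flavor backend with a mock so predict() can be exercised without
    # building an environment or running a real model.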
    mock_backend = mock.MagicMock()
    with mock.patch("mlflow.models.python_api.get_flavor_backend", return_value=mock_backend):
        yield mock_backend


def test_predict_with_both_input_data_and_path_raise(mock_backend):
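    """Supplying both input_data and input_path is ambiguous and should raise."""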
    with pytest.raises(MlflowException, match=r"Both input_data and input_path are provided"):
        mlflow.models.predict(
            model_uri="runs:/test/Model",
            input_data={"inputs": [1, 2, 3]},
            input_path="input.csv",
            content_type=_CONTENT_TYPE_CSV,
        )


def test_predict_invalid_content_type(mock_backend):
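    """An unsupported content_type should be rejected."""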
    with pytest.raises(MlflowException, match=r"Content type must be one of"):
        mlflow.models.predict(
            model_uri="runs:/test/Model",
            input_data={"inputs": [1, 2, 3]},
            content_type="any",
        )


def test_predict_with_input_none(mock_backend):
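    """With neither input_data nor input_path, the backend should receive input_path=None."""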
    mlflow.models.predict(
        model_uri="runs:/test/Model",
        content_type=_CONTENT_TYPE_CSV,
    )

    mock_backend.predict.assert_called_once_with(
        model_uri="runs:/test/Model",
        input_path=None,
        output_path=None,
        content_type=_CONTENT_TYPE_CSV,
        pip_requirements_override=None,
        extra_envs=None,
    )


@pytest.mark.parametrize(
    ("input_data", "content_type", "expected"),
    [
        # String (convert to serving input)
        ("[1, 2, 3]", _CONTENT_TYPE_JSON, '{"inputs": "[1, 2, 3]"}'),
        # uLLM String (no change)
        ({"input": "data"}, _CONTENT_TYPE_JSON, '{"input": "data"}'),
        ("x,y,z\n1,2,3\n4,5,6", _CONTENT_TYPE_CSV, "x,y,z\n1,2,3\n4,5,6"),
        # Bool
        (True, _CONTENT_TYPE_JSON, '{"inputs": true}'),
        # Int
        (1, _CONTENT_TYPE_JSON, '{"inputs": 1}'),
        # Float
        (1.0, _CONTENT_TYPE_JSON, '{"inputs": 1.0}'),
        # Datetime
        (
            datetime.datetime(2021, 1, 1, 0, 0, 0),
            _CONTENT_TYPE_JSON,
            '{"inputs": "2021-01-01T00:00:00"}',
        ),
        # List
        ([1, 2, 3], _CONTENT_TYPE_CSV, "0\n1\n2\n3\n"),  # a header '0' is added by pandas
        ([[1, 2, 3], [4, 5, 6]], _CONTENT_TYPE_CSV, "0,1,2\n1,2,3\n4,5,6\n"),
        # Dict (pandas)
        (
            {"x": [1, 2], "y": [3, 4]},
            _CONTENT_TYPE_CSV,
            "x,y\n1,3\n2,4\n",
        ),
        # Dict (json)
        ({"a": [1, 2, 3]}, _CONTENT_TYPE_JSON, '{"inputs": {"a": [1, 2, 3]}}'),
        # Pandas DataFrame (csv)
        (pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}), _CONTENT_TYPE_CSV, "x,y\n1,4\n2,5\n3,6\n"),
        # Pandas DataFrame (json)
        (
            pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}),
            _CONTENT_TYPE_JSON,
            '{"dataframe_split": {"columns": ["x", "y"], "data": [[1, 4], [2, 5], [3, 6]]}}',
        ),
        # Numpy Array
        (np.array([1, 2, 3]), _CONTENT_TYPE_JSON, '{"inputs": [1, 2, 3]}'),
        # CSC Matrix
        (
            scipy.sparse.csc_matrix([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
            '{"inputs": [[1, 2], [3, 4]]}',
        ),
        # CSR Matrix
        (
            scipy.sparse.csr_matrix([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
            '{"inputs": [[1, 2], [3, 4]]}',
        ),
    ],
)
def test_serialize_input_data(input_data, content_type, expected):
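    """_serialize_input_data should produce the expected serving payload; JSON is
    compared structurally (key order doesn't matter), CSV verbatim."""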
    if content_type == _CONTENT_TYPE_JSON:
        assert json.loads(_serialize_input_data(input_data, content_type)) == json.loads(expected)
    else:
        assert _serialize_input_data(input_data, content_type) == expected


@pytest.mark.parametrize(
    ("input_data", "content_type"),
    [
        # Invalid input datatype for the content type
        (1, _CONTENT_TYPE_CSV),
        ({1, 2, 3}, _CONTENT_TYPE_CSV),
        # Invalid string
        ("x,y\n1,2\n3,4,5\n", _CONTENT_TYPE_CSV),
        # Invalid list
        ([[1, 2], [3, 4], 5], _CONTENT_TYPE_CSV),
        # Invalid dict (unserializable)
        ({"x": 1, "y": {1, 2, 3}}, _CONTENT_TYPE_JSON),
    ],
)
def test_serialize_input_data_invalid_format(input_data, content_type):
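    """Inputs that cannot be serialized for the given content type should raise MlflowException."""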
    with pytest.raises(MlflowException):  # noqa: PT011
        _serialize_input_data(input_data, content_type)


def test_predict_use_current_experiment():
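    """Traces produced by the model in the prediction subprocess should be logged to
    the experiment that is active in the parent process."""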
    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, context, model_input: list[str]):
            return model_input

    exp_id = mlflow.set_experiment("test_experiment").experiment_id
    client = mlflow.MlflowClient()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    assert len(client.search_traces(locations=[exp_id])) == 0
    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=["a", "b", "c"],
        env_manager=VIRTUALENV,
    )
    traces = client.search_traces(locations=[exp_id])
    assert len(traces) == 1
    assert json.loads(traces[0].data.request)["model_input"] == ["a", "b", "c"]


def test_predict_traces_link_to_active_model():
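    """Traces produced during predict() should be linked to the active model via
    TraceMetadataKey.MODEL_ID in the trace metadata."""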
    model = mlflow.set_active_model(name="test_model")

    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, context, model_input: list[str]):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    traces = get_traces()
    assert len(traces) == 0

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=["a", "b", "c"],
        env_manager=VIRTUALENV,
    )
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id