# test_python_api.py
import datetime
import json
import os
import sys
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import scipy.sparse

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.python_api import (
    _CONTENT_TYPE_CSV,
    _CONTENT_TYPE_JSON,
    _serialize_input_data,
)
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.utils.env_manager import CONDA, LOCAL, UV, VIRTUALENV

from tests.tracing.helper import get_traces


@pytest.mark.parametrize(
    ("input_data", "expected_data", "content_type"),
    [
        # CSV string is parsed into a DataFrame before reaching the model.
        (
            "x,y\n1,3\n2,4",
            pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
            _CONTENT_TYPE_CSV,
        ),
        # Dict of lists is converted to a dict of numpy arrays.
        (
            {"a": [1]},
            {"a": np.array([1])},
            _CONTENT_TYPE_JSON,
        ),
        # Scalar becomes a 0-d numpy array.
        (
            1,
            np.array(1),
            _CONTENT_TYPE_JSON,
        ),
        # Numpy array round-trips unchanged.
        (
            np.array([1, 2, 3]),
            np.array([1, 2, 3]),
            _CONTENT_TYPE_JSON,
        ),
        # Sparse matrix is densified to a numpy array.
        (
            scipy.sparse.csc_matrix([[1, 2], [3, 4]]),
            np.array([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
        ),
        (
            # uLLM input, no change (LLM-style payload is passed through as-is)
            {"input": "some_data"},
            {"input": "some_data"},
            _CONTENT_TYPE_JSON,
        ),
    ],
)
@pytest.mark.parametrize(
    "env_manager",
    [VIRTUALENV, UV],
)
def test_predict(input_data, expected_data, content_type, env_manager):
    """End-to-end `mlflow.models.predict`: each supported input payload is
    deserialized inside the scoring process into the expected python object.

    The assertions live in the model's own `predict`, which runs in the
    subprocess created by the chosen `env_manager`; a failed assertion there
    surfaces as a prediction failure in this test.
    """

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            # Compare with the method appropriate for the deserialized type:
            # DataFrame.equals / np.array_equal / plain ==.
            if isinstance(model_input, pd.DataFrame):
                assert model_input.equals(expected_data)
            elif isinstance(model_input, np.ndarray):
                assert np.array_equal(model_input, expected_data)
            else:
                assert model_input == expected_data
            return {}

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            # pytest must be present in the restored env for the asserts above.
            extra_pip_requirements=["pytest"],
        )

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=input_data,
        content_type=content_type,
        env_manager=env_manager,
    )
@pytest.mark.parametrize(
    "env_manager",
    [VIRTUALENV, CONDA, UV],
)
def test_predict_with_pip_requirements_override(env_manager):
    """`pip_requirements_override` can add a new package (xgboost) and pin a
    different version of an already-required one (scikit-learn) in the
    restored environment; the model's predict asserts the resulting versions.
    """
    if env_manager == CONDA:
        if sys.platform == "win32":
            pytest.skip("Skipping conda tests on Windows")

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            # XGBoost should be installed by pip_requirements_override
            import xgboost

            assert xgboost.__version__ == "1.7.3"

            # Scikit-learn version should be overridden to 1.3.0 by pip_requirements_override
            import sklearn

            assert sklearn.__version__ == "1.3.0"

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            # Logged with 1.3.2 so the 1.3.0 override below is a real downgrade.
            extra_pip_requirements=["scikit-learn==1.3.2", "pytest"],
        )

    requirements_override = ["xgboost==1.7.3", "scikit-learn==1.3.0"]
    if env_manager == CONDA:
        # Install charset-normalizer with conda-forge to work around pip-vs-conda issue during
        # CI tests. At the beginning of the CI test, it installs MLflow dependencies via pip,
        # which includes charset-normalizer. Then when it runs this test case, the conda env
        # is created but charset-normalizer is installed via the default channel, which is one
        # major version behind the version installed via pip (as of 2024 Jan). As a result,
        # Python env confuses pip and conda versions and cause errors like "ImportError: cannot
        # import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant'".
        # To work around this, we install the latest version from conda-forge.
        # TODO: Implement better isolation approach for pip and conda environments during testing.
        requirements_override.append("conda-forge::charset-normalizer")

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data={"inputs": [1, 2, 3]},
        content_type=_CONTENT_TYPE_JSON,
        pip_requirements_override=requirements_override,
        env_manager=env_manager,
    )


@pytest.mark.parametrize("env_manager", [VIRTUALENV, CONDA, UV])
def test_predict_with_model_alias(env_manager):
    """A `models:/<name>@<alias>` URI resolves to the registered model version,
    and `extra_envs` is forwarded into the scoring subprocess.
    """

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            # Proves extra_envs reached the subprocess environment.
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
            registered_model_name="model_name",
        )
        client = mlflow.MlflowClient()
        # Alias points at version 1 — the version just registered above.
        client.set_registered_model_alias("model_name", "test_alias", 1)

    mlflow.models.predict(
        model_uri="models:/model_name@test_alias",
        input_data="abc",
        env_manager=env_manager,
        extra_envs={"TEST": "test"},
    )


@pytest.mark.parametrize("env_manager", [VIRTUALENV, CONDA, UV])
def test_predict_with_extra_envs(env_manager):
    """`extra_envs` entries are visible as environment variables inside the
    model's predict for every isolated env manager."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data="abc",
        content_type=_CONTENT_TYPE_JSON,
        env_manager=env_manager,
        extra_envs={"TEST": "test"},
    )


def test_predict_with_extra_envs_errors():
    """`extra_envs` is rejected with env_manager=LOCAL, and without extra_envs
    the model's `TEST` assertion fails, surfacing as a prediction error."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            assert os.environ["TEST"] == "test"
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    # extra_envs requires an isolated environment manager.
    with pytest.raises(
        MlflowException,
        match=r"Extra environment variables are only "
        r"supported when env_manager is set to 'virtualenv', 'conda' or 'uv'",
    ):
        mlflow.models.predict(
            model_uri=model_info.model_uri,
            input_data="abc",
            content_type=_CONTENT_TYPE_JSON,
            env_manager=LOCAL,
            extra_envs={"TEST": "test"},
        )

    # No extra_envs supplied: the model's os.environ["TEST"] assert fails in
    # the scoring process and is wrapped into this generic prediction error.
    with pytest.raises(
        MlflowException, match=r"An exception occurred while running model prediction"
    ):
        mlflow.models.predict(
            model_uri=model_info.model_uri,
            input_data="abc",
            content_type=_CONTENT_TYPE_JSON,
        )


@pytest.fixture
def mock_backend():
    """Patch `get_flavor_backend` so predict() calls hit a MagicMock backend
    instead of building a real environment."""
    mock_backend = mock.MagicMock()
    with mock.patch("mlflow.models.python_api.get_flavor_backend", return_value=mock_backend):
        yield mock_backend


def test_predict_with_both_input_data_and_path_raise(mock_backend):
    """Supplying both input_data and input_path is an error."""
    with pytest.raises(MlflowException, match=r"Both input_data and input_path are provided"):
        mlflow.models.predict(
            model_uri="runs:/test/Model",
            input_data={"inputs": [1, 2, 3]},
            input_path="input.csv",
            content_type=_CONTENT_TYPE_CSV,
        )


def test_predict_invalid_content_type(mock_backend):
    """An unknown content_type is rejected before any backend work happens."""
    with pytest.raises(MlflowException, match=r"Content type must be one of"):
        mlflow.models.predict(
            model_uri="runs:/test/Model",
            input_data={"inputs": [1, 2, 3]},
            content_type="any",
        )


def test_predict_with_input_none(mock_backend):
    """With no input at all, the backend is invoked once with input_path=None
    (i.e. "no input" is forwarded rather than rejected)."""
    mlflow.models.predict(
        model_uri="runs:/test/Model",
        content_type=_CONTENT_TYPE_CSV,
    )

    mock_backend.predict.assert_called_once_with(
        model_uri="runs:/test/Model",
        input_path=None,
        output_path=None,
        content_type=_CONTENT_TYPE_CSV,
        pip_requirements_override=None,
        extra_envs=None,
    )


@pytest.mark.parametrize(
    ("input_data", "content_type", "expected"),
    [
        # String (convert to serving input)
        ("[1, 2, 3]", _CONTENT_TYPE_JSON, '{"inputs": "[1, 2, 3]"}'),
        # uLLM String (no change)
        ({"input": "data"}, _CONTENT_TYPE_JSON, '{"input": "data"}'),
        ("x,y,z\n1,2,3\n4,5,6", _CONTENT_TYPE_CSV, "x,y,z\n1,2,3\n4,5,6"),
        # Bool
        (True, _CONTENT_TYPE_JSON, '{"inputs": true}'),
        # Int
        (1, _CONTENT_TYPE_JSON, '{"inputs": 1}'),
        # Float
        (1.0, _CONTENT_TYPE_JSON, '{"inputs": 1.0}'),
        # Datetime
        (
            datetime.datetime(2021, 1, 1, 0, 0, 0),
            _CONTENT_TYPE_JSON,
            '{"inputs": "2021-01-01T00:00:00"}',
        ),
        # List
        ([1, 2, 3], _CONTENT_TYPE_CSV, "0\n1\n2\n3\n"),  # a header '0' is added by pandas
        ([[1, 2, 3], [4, 5, 6]], _CONTENT_TYPE_CSV, "0,1,2\n1,2,3\n4,5,6\n"),
        # Dict (pandas)
        (
            {
                "x": [
                    1,
                    2,
                ],
                "y": [3, 4],
            },
            _CONTENT_TYPE_CSV,
            "x,y\n1,3\n2,4\n",
        ),
        # Dict (json)
        ({"a": [1, 2, 3]}, _CONTENT_TYPE_JSON, '{"inputs": {"a": [1, 2, 3]}}'),
        # Pandas DataFrame (csv)
        (pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}), _CONTENT_TYPE_CSV, "x,y\n1,4\n2,5\n3,6\n"),
        # Pandas DataFrame (json)
        (
            pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}),
            _CONTENT_TYPE_JSON,
            '{"dataframe_split": {"columns": ["x", "y"], "data": [[1, 4], [2, 5], [3, 6]]}}',
        ),
        # Numpy Array
        (np.array([1, 2, 3]), _CONTENT_TYPE_JSON, '{"inputs": [1, 2, 3]}'),
        # CSC Matrix
        (
            scipy.sparse.csc_matrix([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
            '{"inputs": [[1, 2], [3, 4]]}',
        ),
        # CSR Matrix
        (
            scipy.sparse.csr_matrix([[1, 2], [3, 4]]),
            _CONTENT_TYPE_JSON,
            '{"inputs": [[1, 2], [3, 4]]}',
        ),
    ],
)
def test_serialize_input_data(input_data, content_type, expected):
    """`_serialize_input_data` renders each supported input type to the
    expected CSV text or serving-input JSON payload.

    JSON outputs are compared via json.loads so key order / whitespace
    differences don't matter; CSV is compared as an exact string.
    """
    if content_type == _CONTENT_TYPE_JSON:
        assert json.loads(_serialize_input_data(input_data, content_type)) == json.loads(expected)
    else:
        assert _serialize_input_data(input_data, content_type) == expected


@pytest.mark.parametrize(
    ("input_data", "content_type"),
    [
        # Invalid input datatype for the content type
        (1, _CONTENT_TYPE_CSV),
        ({1, 2, 3}, _CONTENT_TYPE_CSV),
        # Invalid string (ragged rows)
        ("x,y\n1,2\n3,4,5\n", _CONTENT_TYPE_CSV),
        # Invalid list (ragged nesting)
        ([[1, 2], [3, 4], 5], _CONTENT_TYPE_CSV),
        # Invalid dict (unserializable set value)
        ({"x": 1, "y": {1, 2, 3}}, _CONTENT_TYPE_JSON),
    ],
)
def test_serialize_input_data_invalid_format(input_data, content_type):
    """Malformed inputs raise MlflowException rather than producing bad payloads."""
    with pytest.raises(MlflowException):  # noqa: PT011
        _serialize_input_data(input_data, content_type)


def test_predict_use_current_experiment():
    """Traces emitted by the model during predict() land in the caller's
    current experiment, and capture the request payload."""

    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, context, model_input: list[str]):
            return model_input

    exp_id = mlflow.set_experiment("test_experiment").experiment_id
    client = mlflow.MlflowClient()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    # No traces before prediction; exactly one after.
    assert len(client.search_traces(locations=[exp_id])) == 0
    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=["a", "b", "c"],
        env_manager=VIRTUALENV,
    )
    traces = client.search_traces(locations=[exp_id])
    assert len(traces) == 1
    assert json.loads(traces[0].data.request)["model_input"] == ["a", "b", "c"]


def test_predict_traces_link_to_active_model():
    """Traces produced during predict() carry the active model's id in their
    request metadata, linking trace to model."""
    model = mlflow.set_active_model(name="test_model")

    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, context, model_input: list[str]):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=TestModel(),
        )

    traces = get_traces()
    assert len(traces) == 0

    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=["a", "b", "c"],
        env_manager=VIRTUALENV,
    )
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id