# tests/pyfunc/test_model_export_with_loader_module_and_data_path.py
  1  import os
  2  import pickle
  3  import types
  4  from unittest import mock
  5  
  6  import cloudpickle
  7  import numpy as np
  8  import pytest
  9  import sklearn.datasets
 10  import sklearn.neighbors
 11  import yaml
 12  
 13  import mlflow
 14  import mlflow.pyfunc
 15  from mlflow.exceptions import MlflowException
 16  from mlflow.models import Model, infer_signature
 17  from mlflow.models.utils import _read_example
 18  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 19  from mlflow.utils.environment import _mlflow_conda_env
 20  from mlflow.utils.file_utils import TempDir
 21  from mlflow.utils.model_utils import _get_flavor_configuration
 22  
 23  from tests.helper_functions import _assert_pip_requirements
 24  
 25  
 26  def _load_pyfunc(path):
 27      with open(path, "rb") as f:
 28          return pickle.load(f, encoding="latin1")
 29  
 30  
 31  @pytest.fixture
 32  def pyfunc_custom_env_file(tmp_path):
 33      conda_env = os.path.join(tmp_path, "conda_env.yml")
 34      _mlflow_conda_env(
 35          conda_env,
 36          additional_pip_deps=[
 37              "scikit-learn",
 38              "pytest",
 39              "cloudpickle",
 40              "-e " + os.path.dirname(mlflow.__path__[0]),
 41          ],
 42      )
 43      return conda_env
 44  
 45  
 46  @pytest.fixture
 47  def pyfunc_custom_env_dict():
 48      return _mlflow_conda_env(
 49          additional_pip_deps=[
 50              "scikit-learn",
 51              "pytest",
 52              "cloudpickle",
 53              "-e " + os.path.dirname(mlflow.__path__[0]),
 54          ],
 55      )
 56  
 57  
 58  @pytest.fixture(scope="module")
 59  def iris_data():
 60      iris = sklearn.datasets.load_iris()
 61      x = iris.data[:, :2]
 62      y = iris.target
 63      return x, y
 64  
 65  
 66  @pytest.fixture(scope="module")
 67  def sklearn_knn_model(iris_data):
 68      x, y = iris_data
 69      knn_model = sklearn.neighbors.KNeighborsClassifier()
 70      knn_model.fit(x, y)
 71      return knn_model
 72  
 73  
 74  @pytest.fixture
 75  def model_path(tmp_path):
 76      return os.path.join(tmp_path, "model")
 77  
 78  
 79  def test_model_save_load(sklearn_knn_model, iris_data, tmp_path, model_path):
 80      sk_model_path = os.path.join(tmp_path, "knn.pkl")
 81      with open(sk_model_path, "wb") as f:
 82          pickle.dump(sklearn_knn_model, f)
 83  
 84      model_config = Model(run_id="test", artifact_path="testtest")
 85      mlflow.pyfunc.save_model(
 86          path=model_path,
 87          data_path=sk_model_path,
 88          loader_module=__name__,
 89          code_paths=[__file__],
 90          mlflow_model=model_config,
 91      )
 92  
 93      reloaded_model_config = Model.load(os.path.join(model_path, "MLmodel"))
 94      assert model_config.__dict__ == reloaded_model_config.__dict__
 95      assert mlflow.pyfunc.FLAVOR_NAME in reloaded_model_config.flavors
 96      assert mlflow.pyfunc.PY_VERSION in reloaded_model_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
 97      reloaded_model = mlflow.pyfunc.load_model(model_path)
 98      np.testing.assert_array_equal(
 99          sklearn_knn_model.predict(iris_data[0]), reloaded_model.predict(iris_data[0])
100      )
101  
102  
def test_signature_and_examples_are_saved_correctly(sklearn_knn_model, iris_data):
    """Every (signature, input_example) combination is persisted (or omitted) as requested."""
    features, labels = iris_data
    inferred_signature = infer_signature(features, labels)
    sample_rows = features[:3]
    for sig in (None, inferred_signature):
        for example in (None, sample_rows):
            with TempDir() as tmp:
                data_file = tmp.path("skmodel")
                with open(data_file, "wb") as handle:
                    pickle.dump(sklearn_knn_model, handle)
                model_dir = tmp.path("model")
                mlflow.pyfunc.save_model(
                    path=model_dir,
                    data_path=data_file,
                    loader_module=__name__,
                    code_paths=[__file__],
                    signature=sig,
                    input_example=example,
                )
                saved = Model.load(model_dir)
                assert sig == saved.signature
                if example is None:
                    # No example supplied -> nothing must be recorded about one.
                    assert saved.saved_input_example_info is None
                else:
                    np.testing.assert_array_equal(_read_example(saved, model_dir), example)
127  
128  
def test_model_log_load(sklearn_knn_model, iris_data, tmp_path):
    """Log a loader-module pyfunc model to a run, then download, reload and compare predictions."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=artifact_path,
            data_path=pickled_model,
            loader_module=__name__,
            code_paths=[__file__],
        )
        run_id = mlflow.active_run().info.run_id
        local_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{artifact_path}")

    logged_config = Model.load(os.path.join(local_model_path, "MLmodel"))
    assert mlflow.pyfunc.FLAVOR_NAME in logged_config.flavors
    assert mlflow.pyfunc.PY_VERSION in logged_config.flavors[mlflow.pyfunc.FLAVOR_NAME]

    reloaded = mlflow.pyfunc.load_model(local_model_path)
    # Metadata attached to the loaded model must match the MLmodel file on disk.
    assert logged_config.to_yaml() == reloaded.metadata.to_yaml()
    features = iris_data[0]
    np.testing.assert_array_equal(sklearn_knn_model.predict(features), reloaded.predict(features))
154  
155  
@pytest.mark.skip(
    reason="In MLflow 3.0, `log_model` does not start a run. Consider removing this test."
)
def test_model_log_load_no_active_run(sklearn_knn_model, iris_data, tmp_path):
    """Logging with no active run should implicitly create one (legacy behavior; skipped)."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    artifact_path = "pyfunc_model"
    assert mlflow.active_run() is None
    mlflow.pyfunc.log_model(
        name=artifact_path,
        data_path=pickled_model,
        loader_module=__name__,
        code_paths=[__file__],
    )
    run_id = mlflow.active_run().info.run_id
    local_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{artifact_path}")

    logged_config = Model.load(os.path.join(local_model_path, "MLmodel"))
    assert mlflow.pyfunc.FLAVOR_NAME in logged_config.flavors
    assert mlflow.pyfunc.PY_VERSION in logged_config.flavors[mlflow.pyfunc.FLAVOR_NAME]

    reloaded = mlflow.pyfunc.load_model(local_model_path)
    features = iris_data[0]
    np.testing.assert_array_equal(sklearn_knn_model.predict(features), reloaded.predict(features))
    mlflow.end_run()
184  
185  
def test_save_model_with_unsupported_argument_combinations_throws_exception(model_path):
    """save_model without loader_module or python_model must raise MlflowException."""
    expected_message = "Either `loader_module` or `python_model` must be specified"
    with pytest.raises(MlflowException, match=expected_message):
        mlflow.pyfunc.save_model(path=model_path, data_path="/path/to/data")
191  
192  
def test_log_model_with_unsupported_argument_combinations_throws_exception():
    """log_model without loader_module or python_model must raise MlflowException."""
    expected_message = "Either `loader_module` or `python_model` must be specified"
    with mlflow.start_run():
        with pytest.raises(MlflowException, match=expected_message):
            mlflow.pyfunc.log_model(name="pyfunc_model", data_path="/path/to/data")
201  
202  
def test_log_model_persists_specified_conda_env_file_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_file
):
    """A conda env supplied as a YAML file is copied into the logged model directory."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=artifact_path,
            data_path=pickled_model,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_file,
        )
        run_id = mlflow.active_run().info.run_id

    local_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{artifact_path}")
    flavor_conf = _get_flavor_configuration(
        model_path=local_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    persisted_env_path = os.path.join(local_model_path, flavor_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(persisted_env_path)
    # The persisted env must be a copy inside the model dir, not the original file.
    assert persisted_env_path != pyfunc_custom_env_file

    with open(pyfunc_custom_env_file) as handle:
        expected_env = yaml.safe_load(handle)
    with open(persisted_env_path) as handle:
        persisted_env = yaml.safe_load(handle)
    assert persisted_env == expected_env
235  
236  
def test_log_model_persists_specified_conda_env_dict_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_dict
):
    """A conda env supplied as a dict is serialized into the logged model directory."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=artifact_path,
            data_path=pickled_model,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_dict,
        )
        run_id = mlflow.active_run().info.run_id

    local_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{artifact_path}")
    flavor_conf = _get_flavor_configuration(
        model_path=local_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    persisted_env_path = os.path.join(local_model_path, flavor_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(persisted_env_path)

    # The YAML written to the model dir must round-trip back to the supplied dict.
    with open(persisted_env_path) as handle:
        persisted_env = yaml.safe_load(handle)
    assert persisted_env == pyfunc_custom_env_dict
266  
267  
def test_log_model_persists_requirements_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_dict
):
    """The pip deps from the supplied conda env end up verbatim in requirements.txt."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=artifact_path,
            data_path=pickled_model,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_dict,
        )
        run_id = mlflow.active_run().info.run_id

    local_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{artifact_path}")
    requirements_file = os.path.join(local_model_path, "requirements.txt")
    assert os.path.exists(requirements_file)

    with open(requirements_file) as handle:
        logged_requirements = handle.read().split("\n")

    # The conda env's pip section is its last "dependencies" entry by construction.
    assert pyfunc_custom_env_dict["dependencies"][-1]["pip"] == logged_requirements
295  
296  
def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model, tmp_path
):
    """With no conda_env given, the default pyfunc pip requirements are recorded."""
    pickled_model = os.path.join(tmp_path, "knn.pkl")
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="pyfunc_model",
            data_path=pickled_model,
            loader_module=__name__,
            code_paths=[__file__],
        )
    _assert_pip_requirements(model_info.model_uri, mlflow.pyfunc.get_default_pip_requirements())
313  
314  
def test_streamable_model_save_load(tmp_path, model_path):
    """A loader-module model flagged streamable exposes a working predict_stream generator."""

    class StreamableModel:
        # Minimal model implementing both predict and predict_stream.
        def __init__(self):
            pass

        def predict(self, model_input, params=None):
            pass

        def predict_stream(self, model_input, params=None):
            yield "test1"
            yield "test2"

    streamable_instance = StreamableModel()
    pickled_model = os.path.join(tmp_path, "model.pkl")
    with open(pickled_model, "wb") as handle:
        cloudpickle.dump(streamable_instance, handle)

    config = Model(run_id="test", artifact_path="testtest")
    mlflow.pyfunc.save_model(
        path=model_path,
        data_path=pickled_model,
        loader_module=__name__,
        code_paths=[__file__],
        mlflow_model=config,
        streamable=True,
    )

    reloaded = mlflow.pyfunc.load_model(model_uri=model_path)
    stream = reloaded.predict_stream("single-input")
    # predict_stream must be lazy (a generator), and yield the pickled model's outputs in order.
    assert isinstance(stream, types.GeneratorType)
    assert list(stream) == ["test1", "test2"]
348  
349  
def test_log_loader_module_model_does_not_emit_pickle_warning(sklearn_knn_model, tmp_path):
    """Loader-module models must not trigger the CloudPickle serialization warning."""
    pickled_model = tmp_path / "knn.pkl"
    with open(pickled_model, "wb") as handle:
        pickle.dump(sklearn_knn_model, handle)

    with mlflow.start_run():
        with mock.patch("mlflow.pyfunc._logger.warning") as warning_spy:
            mlflow.pyfunc.log_model(
                name="pyfunc_model",
                data_path=pickled_model,
                loader_module=__name__,
                code_paths=[__file__],
            )

    pickle_warning = (
        "Passing a Python object as `python_model` causes it to be serialized using CloudPickle"
    )
    # Collect the first positional arg of each warning call (the log message template).
    emitted = [call_args[0] for call_args, _ in warning_spy.call_args_list if call_args]
    assert all(pickle_warning not in message for message in emitted)