# tests/models/test_model.py
  1  import json
  2  import os
  3  import pathlib
  4  import time
  5  import uuid
  6  from datetime import date
  7  from unittest import mock
  8  
  9  import numpy as np
 10  import pandas as pd
 11  import pydantic
 12  import pytest
 13  import sklearn.datasets
 14  import sklearn.neighbors
 15  from packaging.version import Version
 16  from scipy.sparse import csc_matrix
 17  
 18  import mlflow
 19  from mlflow.exceptions import MlflowException
 20  from mlflow.models import Model, ModelSignature, infer_signature, set_model, validate_schema
 21  from mlflow.models.model import METADATA_FILES, SET_MODEL_ERROR
 22  from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex
 23  from mlflow.models.utils import _read_example, _save_example
 24  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 25  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 26  from mlflow.types.schema import ColSpec, DataType, ParamSchema, ParamSpec, Schema, TensorSpec
 27  from mlflow.utils.databricks_utils import DatabricksRuntimeVersion
 28  from mlflow.utils.file_utils import TempDir
 29  from mlflow.utils.model_utils import _validate_and_prepare_target_save_path
 30  from mlflow.utils.proto_json_utils import dataframe_from_raw_json
 31  
 32  
 33  @pytest.fixture(scope="module")
 34  def iris_data():
 35      iris = sklearn.datasets.load_iris()
 36      x = iris.data[:, :2]
 37      y = iris.target
 38      return x, y
 39  
 40  
 41  @pytest.fixture(scope="module")
 42  def sklearn_knn_model(iris_data):
 43      x, y = iris_data
 44      knn_model = sklearn.neighbors.KNeighborsClassifier()
 45      knn_model.fit(x, y)
 46      return knn_model
 47  
 48  
def test_model_save_load():
    """Round-trip a Model through save/load and check its equality semantics."""
    m = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    # Schema accessors simply mirror the signature fields.
    assert m.get_input_schema() == m.signature.inputs
    assert m.get_output_schema() == m.signature.outputs
    # Without a signature, both accessors return None.
    x = Model(artifact_path="some/other/path", run_id="1234")
    assert x.get_input_schema() is None
    assert x.get_output_schema() is None

    n = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    # Align the auto-generated fields so m and n can compare equal.
    n.utc_time_created = m.utc_time_created
    n.model_uuid = m.model_uuid
    assert m == n
    # Equality is sensitive to the signature field.
    n.signature = None
    assert m != n
    # Saving then loading must preserve equality and serialized forms.
    with TempDir() as tmp:
        m.save(tmp.path("MLmodel"))
        o = Model.load(tmp.path("MLmodel"))
    assert m == o
    assert m.to_json() == o.to_json()
    assert m.to_yaml() == o.to_yaml()
 87  
 88  
def test_model_load_remote(tmp_path, mock_s3_bucket):
    """Model.load resolves both a direct MLmodel URI and its parent directory on S3."""
    model = Model(
        artifact_path="model",
        run_id="123",
        flavors={"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}},
        signature=ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        ),
        saved_input_example_info={"x": 1, "y": 2},
    )
    model_path = tmp_path / "MLmodel"
    model.save(model_path)

    # Upload the saved MLmodel file into the mocked S3 bucket.
    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifact(str(model_path))

    # Loading by the file's full URI works...
    model_reloaded_1 = Model.load(f"{artifact_root}/MLmodel")
    assert model_reloaded_1 == model

    # ...and so does loading by the containing directory.
    model_reloaded_2 = Model.load(artifact_root)
    assert model_reloaded_2 == model
112  
113  
class TestFlavor:
    """Minimal fake flavor used to exercise Model.log / Model.save plumbing."""

    @classmethod
    def save_model(cls, path, mlflow_model, signature=None, input_example=None):
        # Record two dummy flavor entries so tests can assert on the flavors dict.
        mlflow_model.flavors.update(
            {"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}}
        )
        _validate_and_prepare_target_save_path(path)
        if signature is not None:
            mlflow_model.signature = signature
        if input_example is not None:
            _save_example(mlflow_model, input_example, path)
        mlflow_model.save(os.path.join(path, "MLmodel"))
125  
126  
def _log_model_with_signature_and_example(
    tmp_path, sig, input_example, metadata=None, resources=None
):
    """Log a TestFlavor model in a fresh experiment/run and download its artifacts.

    Args:
        tmp_path: Either a pytest ``tmp_path`` (``pathlib.Path``) or a legacy
            ``TempDir`` instance exposing ``.path("")``.
        sig: Optional ModelSignature to attach.
        input_example: Optional input example to save with the model.
        metadata: Optional user metadata dict.
        resources: Optional list of resource objects.

    Returns:
        Tuple of (local path of the downloaded model artifacts, the run object).
    """
    experiment_id = mlflow.create_experiment("test")

    with mlflow.start_run(experiment_id=experiment_id) as run:
        model = Model.log(
            "model",
            TestFlavor,
            signature=sig,
            input_example=input_example,
            metadata=metadata,
            resources=resources,
        )

    # TODO: remove this after replacing all `with TempDir(chdr=True) as tmp`
    # with tmp_path fixture
    # Bug fix: checking `pathlib.PosixPath` broke on Windows, where pytest's
    # tmp_path is a WindowsPath; use the platform-agnostic `pathlib.Path` base.
    output_path = tmp_path if isinstance(tmp_path, pathlib.Path) else tmp_path.path("")
    local_path = _download_artifact_from_uri(model.model_uri, output_path=output_path)
    return local_path, run
147  
148  
def test_model_log():
    """Model.log persists flavors, signature, and input example; Model.load restores them."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = {"x": 1, "y": 2}
        local_path, r = _log_model_with_signature_and_example(tmp, sig, input_example)

        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.run_id == r.info.run_id
        # Flavors written by TestFlavor.save_model must round-trip.
        assert loaded_model.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }
        assert loaded_model.signature == sig
        # The saved example is readable via the low-level helper...
        x = _read_example(
            Model(saved_input_example_info=loaded_model.saved_input_example_info), local_path
        )
        assert x == input_example
        # databricks_runtime is only recorded when logging from Databricks.
        assert not hasattr(loaded_model, "databricks_runtime")

        # ...and via the public load_input_example API.
        loaded_example = loaded_model.load_input_example(local_path)
        assert loaded_example == input_example

        assert Version(loaded_model.mlflow_version) == Version(mlflow.version.VERSION)
175  
176  
def test_model_log_without_run(tmp_path):
    """Logging outside of any active run leaves run_id unset."""
    assert Model.log("model", TestFlavor).run_id is None
180  
181  
def test_model_log_with_active_run(tmp_path):
    """Logging inside an active run attaches that run's id to the model info."""
    with mlflow.start_run() as active_run:
        info = Model.log("model", TestFlavor)
    assert info.run_id == active_run.info.run_id
186  
187  
def test_model_log_inactive_run_id(tmp_path):
    """An explicitly supplied run_id is honored even when that run is not active."""
    exp_id = mlflow.create_experiment("test", artifact_location=str(tmp_path))
    created = mlflow.MlflowClient().create_run(experiment_id=exp_id)
    info = Model.log("model", TestFlavor, run_id=created.info.run_id)
    assert info.run_id == created.info.run_id
193  
194  
def test_model_log_calls_maybe_render_agent_eval_recipe(tmp_path):
    """Model.log triggers the agent-eval recipe rendering hook exactly once."""
    signature = ModelSignature(
        inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
        outputs=Schema([ColSpec(name=None, type="double")]),
    )
    example = {"x": 1, "y": 2}
    with mock.patch("mlflow.models.display_utils.maybe_render_agent_eval_recipe") as render_mock:
        _log_model_with_signature_and_example(tmp_path, signature, example)
        render_mock.assert_called_once()
204  
205  
def test_model_info():
    """ModelInfo returned by Model.log matches get_model_info and the on-disk MLmodel."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = {"x": 1, "y": 2}

        experiment_id = mlflow.create_experiment("test")
        with mlflow.start_run(experiment_id=experiment_id) as run:
            model_info = Model.log("model", TestFlavor, signature=sig, input_example=input_example)
        model_uri = f"models:/{model_info.model_id}"

        model_info_fetched = mlflow.models.get_model_info(model_uri)
        local_path = _download_artifact_from_uri(model_uri, output_path=tmp.path(""))

        # Run id and URI agree between the logged and fetched views.
        assert model_info.run_id == run.info.run_id
        assert model_info_fetched.run_id == run.info.run_id
        assert model_info.model_uri == model_uri
        assert model_info_fetched.model_uri == model_uri

        # Timestamps and uuid agree with the on-disk MLmodel file.
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert model_info.utc_time_created == loaded_model.utc_time_created
        assert model_info_fetched.utc_time_created == loaded_model.utc_time_created
        assert model_info.model_uuid == loaded_model.model_uuid
        assert model_info_fetched.model_uuid == loaded_model.model_uuid

        assert model_info.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }

        # The saved input example round-trips through the info metadata.
        x = _read_example(
            Model(saved_input_example_info=model_info.saved_input_example_info), local_path
        )
        assert x == input_example

        model_signature = model_info_fetched.signature
        assert model_signature.to_dict() == sig.to_dict()

        assert model_info.mlflow_version == loaded_model.mlflow_version
        assert model_info_fetched.mlflow_version == loaded_model.mlflow_version
248  
249  
def test_model_info_with_model_version(tmp_path):
    """registered_model_version tracks registry versions, and is None when unregistered."""
    exp_id = mlflow.create_experiment("test", artifact_location=str(tmp_path))
    with mlflow.start_run(experiment_id=exp_id):
        # Each registration under the same name bumps the version.
        for expected_version in (1, 2):
            info = Model.log("model", TestFlavor, registered_model_name="model_abc")
            assert info.registered_model_version == expected_version
        # No registered_model_name -> no registry version.
        assert Model.log("model", TestFlavor).registered_model_version is None
259  
260  
def test_model_metadata():
    """Custom metadata supplied at log time round-trips through the MLmodel file."""
    with TempDir(chdr=True) as tmp:
        local_path, _ = _log_model_with_signature_and_example(
            tmp, None, None, {"metadata_key": "metadata_value"}
        )
        reloaded = Model.load(os.path.join(local_path, "MLmodel"))
        assert reloaded.metadata["metadata_key"] == "metadata_value"
267  
268  
def test_load_model_without_mlflow_version():
    """An MLmodel file lacking mlflow_version loads with the attribute set to None."""
    with TempDir(chdr=True) as tmp:
        yaml_path = tmp.path("MLmodel")
        with open(yaml_path, "w") as handle:
            Model(artifact_path="model", run_id="1234", mlflow_version=None).to_yaml(handle)
        assert Model.load(yaml_path).mlflow_version is None
278  
279  
def test_model_log_with_databricks_runtime():
    """The detected Databricks runtime version is persisted in the MLmodel file."""
    dbr = "8.3.x"
    with mlflow.start_run():
        with mock.patch(
            "mlflow.models.model.get_databricks_runtime_version", return_value=dbr
        ) as version_mock:
            logged = Model.log("path", TestFlavor, signature=None, input_example=None)
            version_mock.assert_called()

    assert Model.load(logged.model_uri).databricks_runtime == dbr
291  
292  
def test_model_log_with_databricks_runtime_gpu():
    """A client GPU runtime string survives logging and parses correctly."""
    dbr = "client.8.1-gpu"
    with mlflow.start_run():
        with mock.patch(
            "mlflow.models.model.get_databricks_runtime_version", return_value=dbr
        ) as version_mock:
            logged = Model.log("path", TestFlavor, signature=None, input_example=None)
            version_mock.assert_called()

    # The GPU suffix must be preserved verbatim in the MLmodel file.
    reloaded = Model.load(logged.model_uri)
    assert reloaded.databricks_runtime == dbr

    # The stored string parses into its components, with the GPU flag set.
    parsed = DatabricksRuntimeVersion.parse(reloaded.databricks_runtime)
    assert parsed.is_client_image is True
    assert parsed.major == 8
    assert parsed.minor == 1
    assert parsed.is_gpu_image is True
312  
313  
def test_model_log_with_input_example_succeeds():
    """A DataFrame example with mixed dtypes serializes and deserializes faithfully."""
    with TempDir(chdr=True) as tmp:
        sig = ModelSignature(
            inputs=Schema([
                ColSpec("integer", "a"),
                ColSpec("string", "b"),
                ColSpec("boolean", "c"),
                ColSpec("string", "d"),
                ColSpec("datetime", "e"),
            ]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )
        input_example = pd.DataFrame(
            {
                "a": np.int32(1),
                "b": "test string",
                "c": True,
                "d": date.today(),
                "e": np.datetime64("2020-01-01T00:00:00"),
            },
            index=[0],
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        # Read the raw saved example back using the signature's input schema.
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        x = dataframe_from_raw_json(path, schema=sig.inputs)

        # date column will get deserialized into string
        input_example["d"] = input_example["d"].apply(lambda x: x.isoformat())
        # datetime Datatype numpy type is [ns]
        input_example["e"] = input_example["e"].astype(np.dtype("datetime64[ns]"))
        pd.testing.assert_frame_equal(x, input_example)

        # The public loader returns the same (normalized) DataFrame.
        loaded_example = loaded_model.load_input_example(local_path)
        assert isinstance(loaded_example, pd.DataFrame)
        pd.testing.assert_frame_equal(loaded_example, input_example)
351  
352  
def test_model_input_example_with_params_log_load_succeeds(tmp_path):
    """A (DataFrame, params) tuple example saves and reloads both parts."""
    pdf = pd.DataFrame(
        {
            "a": np.int32(1),
            "b": "test string",
            "c": True,
            "d": date.today(),
            "e": np.datetime64("2020-01-01T00:00:00"),
        },
        index=[0],
    )
    # A 2-tuple example carries (inputs, inference params).
    input_example = (pdf, {"a": 1, "b": "string"})

    sig = ModelSignature(
        inputs=Schema([
            ColSpec("integer", "a"),
            ColSpec("string", "b"),
            ColSpec("boolean", "c"),
            ColSpec("string", "d"),
            ColSpec("datetime", "e"),
        ]),
        outputs=Schema([ColSpec(name=None, type="double")]),
        params=ParamSchema([
            ParamSpec("a", DataType.long, 1),
            ParamSpec("b", DataType.string, "string"),
        ]),
    )

    local_path, _ = _log_model_with_signature_and_example(tmp_path, sig, input_example)
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))

    # date column will get deserialized into string
    pdf["d"] = pdf["d"].apply(lambda x: x.isoformat())
    loaded_example = loaded_model.load_input_example(local_path)
    assert isinstance(loaded_example, pd.DataFrame)
    # datetime Datatype numpy type is [ns]
    pdf["e"] = pdf["e"].astype(np.dtype("datetime64[ns]"))
    pd.testing.assert_frame_equal(loaded_example, pdf)

    # The params half of the tuple is loaded separately.
    params = loaded_model.load_input_example_params(local_path)
    assert params == input_example[1]
394  
395  
def test_model_load_input_example_numpy():
    """A numpy input example round-trips through save and load_input_example."""
    with TempDir(chdr=True) as tmp:
        example = np.array([[3, 4, 5]], dtype=np.int32)
        signature = ModelSignature(
            inputs=Schema([TensorSpec(type=example.dtype, shape=example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, signature, example)
        reloaded = Model.load(os.path.join(local_path, "MLmodel")).load_input_example(local_path)

        assert isinstance(reloaded, np.ndarray)
        np.testing.assert_array_equal(example, reloaded)
410  
411  
def test_model_load_input_example_scipy():
    """A scipy sparse-matrix input example round-trips through save/load."""
    with TempDir(chdr=True) as tmp:
        example = csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8))
        signature = ModelSignature(
            inputs=Schema([TensorSpec(type=example.data.dtype, shape=example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, signature, example)
        reloaded = Model.load(os.path.join(local_path, "MLmodel")).load_input_example(local_path)

        assert isinstance(reloaded, csc_matrix)
        np.testing.assert_array_equal(example.data, reloaded.data)
426  
427  
def test_model_load_input_example_failures():
    """load_input_example raises when the example directory or file is missing."""
    with TempDir(chdr=True) as tmp:
        input_example = np.array([[3, 4, 5]], dtype=np.int32)
        sig = ModelSignature(
            inputs=Schema([TensorSpec(type=input_example.dtype, shape=input_example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, sig, input_example)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        # Sanity check: the example is loadable before we break things.
        loaded_example = loaded_model.load_input_example(local_path)
        assert loaded_example is not None

        # A nonexistent model directory fails with "No such artifact".
        with pytest.raises(MlflowException, match="No such artifact"):
            loaded_model.load_input_example(os.path.join(local_path, "folder_which_does_not_exist"))

        # Deleting the example file itself produces the same error.
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        os.remove(path)
        with pytest.raises(MlflowException, match="No such artifact"):
            loaded_model.load_input_example(local_path)
448  
449  
def test_model_load_input_example_no_signature():
    """load_input_example returns None when no example was saved with the model."""
    with TempDir(chdr=True) as tmp:
        example = np.array([[3, 4, 5]], dtype=np.int32)
        signature = ModelSignature(
            inputs=Schema([TensorSpec(type=example.dtype, shape=example.shape)]),
            outputs=Schema([ColSpec(name=None, type="double")]),
        )

        local_path, _ = _log_model_with_signature_and_example(tmp, signature, input_example=None)
        reloaded = Model.load(os.path.join(local_path, "MLmodel"))
        assert reloaded.load_input_example(local_path) is None
462  
463  
464  def _is_valid_uuid(val):
465      try:
466          uuid.UUID(str(val))
467          return True
468      except ValueError:
469          return False
470  
471  
def test_model_uuid():
    """Each Model gets a unique, serializable uuid; absent from dict -> None."""
    first = Model()
    assert first.model_uuid is not None
    assert _is_valid_uuid(first.model_uuid)

    # A second instance receives a different uuid.
    assert Model().model_uuid != first.model_uuid

    # The uuid survives a to_dict/from_dict round trip.
    as_dict = first.to_dict()
    assert as_dict["model_uuid"] == first.model_uuid
    assert Model.from_dict(as_dict).model_uuid == first.model_uuid

    # Missing key deserializes to None rather than a fresh uuid.
    del as_dict["model_uuid"]
    assert Model.from_dict(as_dict).model_uuid is None
488  
489  
def test_validate_schema(sklearn_knn_model, iris_data, tmp_path):
    """validate_schema accepts inputs and outputs matching an inferred signature."""
    model_dir = os.path.join(tmp_path, "sk_model")
    features, targets = iris_data
    signature = infer_signature(features, targets)
    mlflow.sklearn.save_model(
        sklearn_knn_model,
        model_dir,
        signature=signature,
    )

    validate_schema(features, signature.inputs)
    predictions = sklearn_knn_model.predict(features)
    reloaded = mlflow.sklearn.load_model(model_dir)
    np.testing.assert_array_equal(predictions, reloaded.predict(features))
    validate_schema(predictions, signature.outputs)
505  
506  
def test_save_load_input_example_without_conversion(tmp_path):
    """A dict example is stored as-is (type json_object) without DataFrame conversion."""
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    input_example = {
        "messages": [
            {"role": "user", "content": "Hello!"},
        ]
    }
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=input_example,
        )
        local_path = _download_artifact_from_uri(
            f"runs:/{run.info.run_id}/test_model", output_path=tmp_path
        )
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
    # The example was kept as a raw JSON object rather than converted.
    assert loaded_model.saved_input_example_info["type"] == "json_object"
    loaded_example = loaded_model.load_input_example(local_path)
    assert loaded_example == input_example
530  
531  
def test_save_load_input_example_with_pydantic_model(tmp_path):
    """Pydantic model examples are serialized to plain dicts (type json_object)."""
    class Message(pydantic.BaseModel):
        role: str
        content: str

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[Message], params=None):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=[Message(role="user", content="Hello!")],
        )
    local_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
    assert loaded_model.saved_input_example_info["type"] == "json_object"
    # The pydantic instance comes back as its dict representation.
    loaded_example = loaded_model.load_input_example(local_path)
    assert loaded_example == [{"role": "user", "content": "Hello!"}]
552  
553  
def test_model_saved_by_save_model_can_be_loaded(tmp_path, sklearn_knn_model):
    """A model saved directly (no run) loads with run_id and artifact_path unset."""
    mlflow.sklearn.save_model(sklearn_knn_model, tmp_path)
    model_info = Model.load(tmp_path).get_model_info()
    assert model_info.run_id is None
    assert model_info.artifact_path is None
559  
560  
def test_copy_metadata(mock_is_in_databricks, sklearn_knn_model):
    """Metadata files are copied to a `metadata` subdirectory only on Databricks."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model, name="model")

    artifact_path = mlflow.artifacts.download_artifacts(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    # Metadata should be copied only in Databricks
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == set(METADATA_FILES)
    else:
        assert not os.path.exists(metadata_path)
    mock_is_in_databricks.assert_called_once()
573  
574  
class LegacyTestFlavor:
    """Flavor whose save_model lacks signature/input_example kwargs (legacy API)."""

    @classmethod
    def save_model(cls, path, mlflow_model):
        # Same dummy flavor entries as TestFlavor, minus the newer kwargs.
        mlflow_model.flavors.update(
            {"flavor1": {"a": 1, "b": 2}, "flavor2": {"x": 1, "y": 2}}
        )
        _validate_and_prepare_target_save_path(path)
        mlflow_model.save(os.path.join(path, "MLmodel"))
582  
583  
def test_legacy_flavor(mock_is_in_databricks):
    """Legacy flavors still get MLmodel copied to metadata, but only on Databricks."""
    with mlflow.start_run():
        model_info = Model.log("model", LegacyTestFlavor)

    artifact_path = _download_artifact_from_uri(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    # Metadata should be copied only in Databricks
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == {"MLmodel"}
    else:
        assert not os.path.exists(metadata_path)
    mock_is_in_databricks.assert_called_once()
596  
597  
def test_pyfunc_set_model():
    """set_model accepts a PythonModel instance and stores it in the module global."""

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    set_model(MyModel())
    assert isinstance(mlflow.models.model.__mlflow_model__, mlflow.pyfunc.PythonModel)
605  
606  
def test_langchain_set_model():
    """set_model accepts a LangChain Runnable, even when called from a nested scope."""
    from langchain_core.runnables import RunnableLambda

    def create_runnable():
        def my_runnable(input):
            return f"Input was: {input}"

        runnable = RunnableLambda(my_runnable)
        set_model(runnable)

    create_runnable()
    # The module-level global holds the runnable set inside create_runnable().
    assert isinstance(mlflow.models.model.__mlflow_model__, RunnableLambda)
619  
620  
def test_error_set_model(sklearn_knn_model):
    """set_model rejects objects that are not supported model types."""
    with pytest.raises(mlflow.MlflowException, match=SET_MODEL_ERROR):
        set_model(sklearn_knn_model)
624  
625  
def test_model_resources():
    """Resource objects passed at log time serialize into the expected dict layout."""
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
        },
    }
    with TempDir(chdr=True) as tmp:
        # Mixed resource kinds are grouped by type under the "databricks" key.
        resources = [
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
        ]
        local_path, _ = _log_model_with_signature_and_example(tmp, None, None, resources=resources)
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.resources == expected_resources
648  
649  
def test_save_load_model_with_run_uri():
    """Model.load handles runs:/ URIs pointing at MLmodel, the dir, and the dir with a slash."""
    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[str], params=None):
            return model_input

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            input_example=["a", "b", "c"],
        )
    mlflow_model = Model.load(f"runs:/{run.info.run_id}/test_model/MLmodel")
    assert mlflow_model.load_input_example() == ["a", "b", "c"]

    # Directory URI (no trailing file) resolves to the same model...
    model = Model.load(f"runs:/{run.info.run_id}/test_model")
    assert model == mlflow_model

    # ...as does a trailing-slash variant.
    model = Model.load(f"runs:/{run.info.run_id}/test_model/")
    assert model == mlflow_model
669  
670  
def test_save_model_with_prompts():
    """Prompts (objects or URIs) are recorded on the model and linked to the run."""
    prompt_1 = mlflow.register_prompt("prompt-1", "Hello, {{title}} {{name}}!")
    time.sleep(0.001)  # To avoid timestamp precision issue in Windows
    prompt_2 = mlflow.register_prompt("prompt-2", "Hello, {{title}} {{name}}!")

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, model_input: list[str]):
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=MyModel(),
            # The 'prompts' parameter should accept both prompt object and URI
            prompts=[prompt_1, prompt_2.uri],
        )

    # Both forms are normalized to prompt URIs on the model info.
    assert model_info.prompts == [prompt_1.uri, prompt_2.uri]

    # Prompts should be recorded in the yaml file
    model = Model.load(model_info.model_uri)
    assert model.prompts == [prompt_1.uri, prompt_2.uri]

    # Check that prompts were linked to the run via the linkedPrompts tag
    from mlflow.tracing.constant import TraceTagKey

    run = mlflow.MlflowClient().get_run(model_info.run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert linked_prompts_tag is not None

    linked_prompts = json.loads(linked_prompts_tag)
    assert len(linked_prompts) == 2
    assert {p["name"] for p in linked_prompts} == {prompt_1.name, prompt_2.name}
704  
705  
def test_logged_model_status():
    """LoggedModel status is READY after success and FAILED after a save error."""
    def predict_fn(model_input: list[str]):
        return model_input

    model_info = mlflow.pyfunc.log_model(
        name="test_model",
        python_model=predict_fn,
        input_example=["a", "b", "c"],
    )
    logged_model = mlflow.get_logged_model(model_info.model_id)
    assert logged_model.status == "READY"

    # Force the underlying save to blow up; log_model should surface the error...
    with pytest.raises(Exception, match=r"mock exception"):
        with mock.patch(
            "mlflow.pyfunc.model._save_model_with_class_artifacts_params",
            side_effect=Exception("mock exception"),
        ):
            mlflow.pyfunc.log_model(
                name="test_model",
                python_model=predict_fn,
                input_example=["a", "b", "c"],
            )
    # ...and the most recently logged model should be marked FAILED.
    logged_model = mlflow.last_logged_model()
    assert logged_model.status == "FAILED"
730  
731  
def test_model_log_links_prompts_to_logged_model():
    """Prompt versions passed to Model.log are linked to both the run and the LoggedModel."""
    client = mlflow.MlflowClient()

    # Create actual prompts in the registry
    client.create_prompt(name="test_prompt_1")
    prompt_1 = client.create_prompt_version(name="test_prompt_1", template="Hello {{name}}")
    client.create_prompt(name="test_prompt_2")
    prompt_2 = client.create_prompt_version(name="test_prompt_2", template="Goodbye {{name}}")

    with mlflow.start_run() as run:
        model_info = Model.log("model", TestFlavor, prompts=[prompt_1, prompt_2])

    # Verify prompts were linked to the run
    run_data = client.get_run(run.info.run_id)
    linked_prompts_tag = run_data.data.tags.get("mlflow.linkedPrompts")
    assert linked_prompts_tag is not None
    linked_prompts = json.loads(linked_prompts_tag)
    assert len(linked_prompts) == 2
    assert {p["name"] for p in linked_prompts} == {"test_prompt_1", "test_prompt_2"}

    # Verify prompts were linked to the LoggedModel
    logged_model = client.get_logged_model(model_info.model_id)
    model_linked_prompts_tag = logged_model.tags.get("mlflow.linkedPrompts")
    assert model_linked_prompts_tag is not None
    model_linked_prompts = json.loads(model_linked_prompts_tag)
    assert len(model_linked_prompts) == 2
    assert {p["name"] for p in model_linked_prompts} == {"test_prompt_1", "test_prompt_2"}
759  
760  
def test_get_model_info_with_logged_model():
    """get_model_info on a logged-model URI reports the same id and name as log_model."""
    def model(model_input: list[str]) -> list[str]:
        return model_input

    model_info_log_model = mlflow.pyfunc.log_model(
        name="test_model", python_model=model, input_example=["a", "b", "c"]
    )
    model_info_get_model_info = mlflow.models.get_model_info(model_info_log_model.model_uri)
    assert model_info_log_model.model_id == model_info_get_model_info.model_id
    assert model_info_log_model.name == model_info_get_model_info.name