/ tests / test_mismatch.py
test_mismatch.py
 1  import warnings
 2  from importlib.metadata import PackageNotFoundError
 3  from unittest import mock
 4  
 5  import pytest
 6  
 7  from mlflow.mismatch import _check_version_mismatch
 8  
 9  
@pytest.mark.parametrize(
    ("mlflow_version", "skinny_version"),
    [
        ("1.0.0", "1.0.0"),
        ("1.0.0.dev0", "1.0.0"),
        ("1.0.0", "1.0.0.dev0"),
        ("1.0.0.dev0", "1.0.0.dev0"),
        ("1.0.0", None),
        (None, "1.0.0"),
        (None, None),
    ],
)
@pytest.mark.parametrize(
    "tracing_version",
    [None, "1.0.0", "1.0.0.dev0"],
)
def test_check_version_mismatch_no_warn(
    mlflow_version: str | None, skinny_version: str | None, tracing_version: str | None
):
    """Assert `_check_version_mismatch` emits no warning for compatible installs.

    A version of ``None`` simulates the package not being installed (the patched
    ``importlib.metadata.version`` raises ``PackageNotFoundError`` for it).
    """
    installed = {
        "mlflow": mlflow_version,
        "mlflow-skinny": skinny_version,
        "mlflow-tracing": tracing_version,
    }

    def fake_version(package_name: str) -> str:
        # Any package other than the three above indicates the code under test
        # queried something this test does not account for.
        if package_name not in installed:
            raise ValueError(f"Unexpected package: {package_name}")
        resolved = installed[package_name]
        if resolved is None:
            raise PackageNotFoundError
        return resolved

    with mock.patch("importlib.metadata.version", side_effect=fake_version) as patched:
        with warnings.catch_warnings():
            # Turn every warning into an exception so any emitted warning fails the test.
            warnings.simplefilter("error")
            _check_version_mismatch()

        patched.assert_called()
50  
51  
@pytest.mark.parametrize(
    ("mlflow_version", "skinny_version", "tracing_version", "expected"),
    [
        ("1.0.0", "1.0.1", "1.0.0", r"mlflow-skinny \(1\.0\.1\)"),
        ("1.0.0", "1.0.0", "1.0.1", r"mlflow-tracing \(1\.0\.1\)"),
        ("1.0.1", "1.0.0", "1.0.0", r"mlflow-skinny \(1\.0\.0\), mlflow-tracing \(1\.0\.0\)"),
        # mlflow-tracing is not installed: only the installed, mismatched
        # mlflow-skinny package should appear in the warning.
        ("1.0.0", "1.0.1", None, r"mlflow-skinny \(1\.0\.1\)"),
    ],
)
def test_check_version_mismatch_warn(
    mlflow_version: str,
    skinny_version: str,
    tracing_version: str | None,
    expected: str,
):
    """Assert `_check_version_mismatch` warns and names exactly the mismatched children.

    ``expected`` is a regex fragment for the comma-separated list of mismatched
    child packages that should appear in the warning message. A ``tracing_version``
    of ``None`` simulates ``mlflow-tracing`` not being installed.
    """

    def mock_version(package_name: str) -> str:
        if package_name == "mlflow":
            return mlflow_version
        elif package_name == "mlflow-skinny":
            return skinny_version
        elif package_name == "mlflow-tracing":
            if tracing_version is None:
                raise PackageNotFoundError
            return tracing_version
        # Fail loudly if the code under test queries a package this test does not model.
        raise ValueError(f"Unexpected package: {package_name}")

    with mock.patch("importlib.metadata.version", side_effect=mock_version) as mv:
        with pytest.warns(
            UserWarning,
            match=rf"Versions of mlflow \([.\w]+\) and child packages {expected} are different",
        ):
            _check_version_mismatch()

        mv.assert_called()