/ tests / check_mlflow_lazily_imports_ml_packages.py
check_mlflow_lazily_imports_ml_packages.py
 1  """
Tests that `import mlflow` and `mlflow.autolog()` do not import ML packages, and that those ML packages are installed and importable.
 3  """
 4  
 5  import importlib
 6  import logging
 7  import sys
 8  
 9  import mlflow
10  
# Module-level logger (named after this module rather than the root logger), used to
# report packages that fail to import.
logger = logging.getLogger(__name__)
12  
13  
def _imported_packages(packages):
    """Return the sorted names from *packages* that are currently present in ``sys.modules``."""
    # Sorting makes assertion messages stable across runs (set iteration order of
    # strings depends on hash randomization).
    return sorted(name for name in packages if name in sys.modules)


def _unimportable_packages(packages):
    """Import each of *packages* in sorted order and return the sorted names that failed.

    Each failure is logged with its traceback via ``logger.exception``.
    """
    failed = []
    for package in sorted(packages):
        try:
            importlib.import_module(package)
        except ImportError:  # also covers ModuleNotFoundError, a subclass of ImportError
            logger.exception("Failed to import %s", package)
            failed.append(package)
    return failed


def main():
    """Check that mlflow imports ML packages lazily and that those packages are installed.

    Raises:
        AssertionError: if any listed ML package is in ``sys.modules`` after
            ``import mlflow`` or after ``mlflow.autolog()``, or if any listed
            package cannot be imported.
    """
    ml_packages = {
        "catboost",
        "h2o",
        "lightgbm",
        "onnx",
        "pytorch_lightning",
        "pyspark.ml",
        "shap",
        "sklearn",
        "spacy",
        "statsmodels",
        "tensorflow",
        "torch",
        "xgboost",
        "pmdarima",
        "transformers",
        "sentence_transformers",
    }

    # ``mlflow`` is imported at module level, so anything it imports eagerly is
    # already in ``sys.modules`` by the time ``main`` runs.
    imported = _imported_packages(ml_packages)
    assert not imported, f"mlflow imports {imported} when it's imported but it should not"

    mlflow.autolog()
    imported = _imported_packages(ml_packages)
    assert not imported, f"`mlflow.autolog` imports {imported} but it should not"

    # Ensure that the ML packages are importable: a package that is not installed can
    # never appear in ``sys.modules``, which would make the checks above pass vacuously.
    failed_to_import = _unimportable_packages(ml_packages)
    assert not failed_to_import, (
        f"Failed to import {failed_to_import}. Please install packages that provide these modules."
    )


if __name__ == "__main__":
    main()