check_mlflow_lazily_imports_ml_packages.py
"""
Tests that `import mlflow` and `mlflow.autolog()` do not import ML packages.

Run as a script. It raises ``AssertionError`` in two cases:

- ``import mlflow`` or ``mlflow.autolog()`` leaves any of the listed ML packages
  in ``sys.modules``.
- Any of the listed packages cannot be imported afterwards.
"""

import importlib
import logging
import sys

import mlflow

logger = logging.getLogger()


def _fail_if(condition, message):
    """Raise ``AssertionError(message)`` when ``condition`` is truthy.

    Used instead of a bare ``assert`` statement so the checks still run under
    ``python -O``, which strips ``assert`` statements entirely.
    """
    if condition:
        raise AssertionError(message)


def main():
    """Run the lazy-import checks, raising ``AssertionError`` on the first failure."""
    ml_packages = {
        "catboost",
        "h2o",
        "lightgbm",
        "onnx",
        "pytorch_lightning",
        "pyspark.ml",
        "shap",
        "sklearn",
        "spacy",
        "statsmodels",
        "tensorflow",
        "torch",
        "xgboost",
        "pmdarima",
        "transformers",
        "sentence_transformers",
    }

    # `import mlflow` already ran at module load time, so anything it pulled in is
    # visible in sys.modules at this point.
    imported = ml_packages.intersection(set(sys.modules))
    _fail_if(imported, f"mlflow imports {imported} when it's imported but it should not")

    mlflow.autolog()
    imported = ml_packages.intersection(set(sys.modules))
    _fail_if(imported, f"`mlflow.autolog` imports {imported} but it should not")

    # A package that is not installed can never appear in sys.modules, so the two
    # checks above would pass trivially for it. Ensure the ML packages are importable.
    failed_to_import = []
    for package in sorted(ml_packages):
        try:
            importlib.import_module(package)
        except ImportError:
            logger.exception("Failed to import %s", package)
            failed_to_import.append(package)

    message = (
        f"Failed to import {failed_to_import}. Please install packages that provide these modules."
    )
    _fail_if(failed_to_import, message)


if __name__ == "__main__":
    main()