# test_log_model_artifact_path.py
from pathlib import Path

from clint.config import Config
from clint.index import SymbolIndex
from clint.linter import Position, Range, lint_file
from clint.rules.log_model_artifact_path import LogModelArtifactPath


def test_log_model_artifact_path(index: SymbolIndex) -> None:
    """Check that LogModelArtifactPath reports ``artifact_path`` in ``log_model`` calls.

    Only this rule is selected. The snippet contains three calls that must be
    reported (``artifact_path`` passed positionally or by keyword) and two that
    must not: one using the ``name=`` keyword and one on the ``spark`` flavor.
    """
    snippet = """
import mlflow

# Bad - using deprecated artifact_path positionally
mlflow.sklearn.log_model(model, "model")

# Bad - using deprecated artifact_path as keyword
mlflow.tensorflow.log_model(model, artifact_path="tf_model")

# Good - using the new 'name' parameter
mlflow.sklearn.log_model(model, name="my_model")

# Good - spark flavor is exempted from this rule
mlflow.spark.log_model(spark_model, "spark_model")

# Bad - another flavor with artifact_path
mlflow.pytorch.log_model(model, artifact_path="pytorch_model")
"""
    selected = Config(select={LogModelArtifactPath.name})
    found = lint_file(Path("test.py"), snippet, selected, index)

    # Rows of the three "Bad" calls in ``snippet``, counting the empty line
    # produced by the literal's leading newline as row 0.
    flagged_rows = (4, 7, 16)

    assert len(found) == len(flagged_rows)
    assert all(isinstance(hit.rule, LogModelArtifactPath) for hit in found)
    expected = [Range(Position(row, 0)) for row in flagged_rows]
    assert [hit.range for hit in found] == expected