/ dev / clint / tests / rules / test_log_model_artifact_path.py
test_log_model_artifact_path.py
 1  from pathlib import Path
 2  
 3  from clint.config import Config
 4  from clint.index import SymbolIndex
 5  from clint.linter import Position, Range, lint_file
 6  from clint.rules.log_model_artifact_path import LogModelArtifactPath
 7  
 8  
 9  def test_log_model_artifact_path(index: SymbolIndex) -> None:
10      code = """
11  import mlflow
12  
13  # Bad - using deprecated artifact_path positionally
14  mlflow.sklearn.log_model(model, "model")
15  
16  # Bad - using deprecated artifact_path as keyword
17  mlflow.tensorflow.log_model(model, artifact_path="tf_model")
18  
19  # Good - using the new 'name' parameter
20  mlflow.sklearn.log_model(model, name="my_model")
21  
22  # Good - spark flavor is exempted from this rule
23  mlflow.spark.log_model(spark_model, "spark_model")
24  
25  # Bad - another flavor with artifact_path
26  mlflow.pytorch.log_model(model, artifact_path="pytorch_model")
27  """
28      config = Config(select={LogModelArtifactPath.name})
29      violations = lint_file(Path("test.py"), code, config, index)
30      assert len(violations) == 3
31      assert all(isinstance(v.rule, LogModelArtifactPath) for v in violations)
32      assert violations[0].range == Range(Position(4, 0))
33      assert violations[1].range == Range(Position(7, 0))
34      assert violations[2].range == Range(Position(16, 0))