# tests/store/test_workspace_model_coverage.py
 1  """Verify that every SQLAlchemy model with a ``workspace`` column is handled
 2  by at least one workspace store's ``_get_query`` method.
 3  
 4  If a new model is added with a ``workspace`` column but the developer forgets
 5  to add it to ``_get_query``, that model's queries will bypass workspace
 6  isolation.  This test catches that gap.
 7  """
 8  
 9  import ast
10  from pathlib import Path
11  
12  # Explicitly import all dbmodel modules so that every ORM model is registered
13  # with Base.registry before we inspect mappers.
14  import mlflow.store.model_registry.dbmodels.models  # noqa: F401
15  import mlflow.store.tracking.dbmodels.models  # noqa: F401
16  import mlflow.store.workspace.dbmodels.models  # noqa: F401
17  from mlflow.store.db.base_sql_model import Base
18  
# Anchor on this module's own location rather than the current working
# directory: climbing three directory levels from the resolved file path
# yields the root that the relative store paths below are joined onto, so
# they resolve the same way no matter where pytest is launched from.
_THIS_FILE = Path(__file__).resolve()
_REPO_ROOT = _THIS_FILE.parent.parent.parent

# Store modules (relative to ``_REPO_ROOT``) whose ``_get_query`` applies
# workspace filtering. A store missing from this list is never inspected.
WORKSPACE_STORE_PATHS = [
    "mlflow/store/tracking/sqlalchemy_workspace_store.py",
    "mlflow/store/model_registry/sqlalchemy_store.py",
    "mlflow/store/jobs/sqlalchemy_workspace_store.py",
]
29  
30  
def _sql_model_names(node: ast.AST) -> set[str]:
    """Return every bare ``Sql``-prefixed name referenced anywhere under *node*."""
    return {
        child.id
        for child in ast.walk(node)
        if isinstance(child, ast.Name) and child.id.startswith("Sql")
    }


def _models_handled_by_get_query(ws_path: str) -> set[str]:
    """Parse a workspace store file and return the ``Sql*`` model names
    referenced in comparisons inside its ``_get_query`` method.

    A model counts as handled when its name appears on the right-hand side of
    a comparison (``==``, ``is``, ``in``, ...) inside any function named
    ``_get_query``, either directly or through a module-level variable whose
    value references ``Sql*`` names (e.g. a tuple of models).

    Args:
        ws_path: Path of the store module, relative to ``_REPO_ROOT``.

    Returns:
        The set of ``Sql*`` class names the store's ``_get_query`` compares
        against. Empty if the file has no ``_get_query`` or no such
        comparisons.

    Raises:
        OSError: If the file cannot be read.
        SyntaxError: If the file is not valid Python.
    """
    # Python source files are UTF-8 by default; decode explicitly so the
    # result does not depend on the platform's locale encoding.
    tree = ast.parse((_REPO_ROOT / ws_path).read_text(encoding="utf-8"))

    # Collect module-level variables holding Sql* names (e.g. tuples of
    # models). Both plain and annotated assignments are recognised; an
    # annotation-only declaration has no value and is skipped. Only the
    # assigned value is scanned, never the annotation itself.
    var_models: dict[str, set[str]] = {}
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets = [node.target]
        else:
            continue
        names = _sql_model_names(node.value)
        if not names:
            continue
        for target in targets:
            if isinstance(target, ast.Name):
                var_models[target.id] = names

    # Extract model names from _get_query comparisons. ast.walk covers the
    # whole tree, so methods nested inside classes are found as well.
    models: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if node.name != "_get_query":
            continue
        for inner in ast.walk(node):
            if not isinstance(inner, ast.Compare):
                continue
            for comp in inner.comparators:
                models |= _sql_model_names(comp)
                # A comparator that is itself a known module-level variable
                # stands for every model that variable references.
                if isinstance(comp, ast.Name) and comp.id in var_models:
                    models |= var_models[comp.id]
    return models
70  
71  
def test_all_workspace_models_handled_in_get_query():
    """Fail if any mapped model carrying a ``workspace`` column is absent from
    every workspace store's ``_get_query``.

    The set of handled models is the union of what each store file listed in
    ``WORKSPACE_STORE_PATHS`` compares against; the set of expected models is
    taken from the mappers registered on ``Base``.
    """
    # Union of the model names recognised across all workspace stores.
    covered: set[str] = set().union(
        *(_models_handled_by_get_query(store_path) for store_path in WORKSPACE_STORE_PATHS)
    )

    # Class names of every registered ORM model exposing a ``workspace`` column.
    workspace_scoped: set[str] = {
        mapper.class_.__name__
        for mapper in Base.registry.mappers
        if any(column.key == "workspace" for column in mapper.columns)
    }

    uncovered = workspace_scoped - covered
    assert not uncovered, (
        f"These models have a `workspace` column but are not handled by any "
        f"workspace store's _get_query: {sorted(uncovered)}. "
        f"Add handling in the appropriate workspace store's _get_query method."
    )