test_workspace_model_coverage.py
"""Verify that every SQLAlchemy model with a ``workspace`` column is handled
by at least one workspace store's ``_get_query`` method.

If a new model is added with a ``workspace`` column but the developer forgets
to add it to ``_get_query``, that model's queries will bypass workspace
isolation. This test catches that gap.

The check is static: each store file is parsed with :mod:`ast`, the ``Sql*``
names that appear as right-hand operands of comparisons inside ``_get_query``
are collected, and the result is compared against the mapped classes
registered on ``Base.registry`` that expose a ``workspace`` column.
"""

import ast
from pathlib import Path

# Explicitly import all dbmodel modules so that every ORM model is registered
# with Base.registry before we inspect mappers.
import mlflow.store.model_registry.dbmodels.models  # noqa: F401
import mlflow.store.tracking.dbmodels.models  # noqa: F401
import mlflow.store.workspace.dbmodels.models  # noqa: F401
from mlflow.store.db.base_sql_model import Base

# Locate the repository root relative to this test file so that workspace
# store paths resolve correctly regardless of the pytest working directory.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent

# Every store that provides a ``_get_query`` with workspace filtering.
# NOTE(review): the model_registry entry points at ``sqlalchemy_store.py``
# while the other two point at ``sqlalchemy_workspace_store.py`` -- confirm
# that this difference is intentional.
WORKSPACE_STORE_PATHS = [
    "mlflow/store/tracking/sqlalchemy_workspace_store.py",
    "mlflow/store/model_registry/sqlalchemy_store.py",
    "mlflow/store/jobs/sqlalchemy_workspace_store.py",
]


def _sql_names(node: ast.AST) -> set[str]:
    """Return every bare ``Sql*`` identifier found in *node* or its children."""
    return {
        child.id
        for child in ast.walk(node)
        if isinstance(child, ast.Name) and child.id.startswith("Sql")
    }


def _models_handled_by_get_query(ws_path: str) -> set[str]:
    """Parse a workspace store file and return the ``Sql*`` model names
    referenced in comparisons inside its ``_get_query`` method.

    Args:
        ws_path: Path of the store module, relative to the repository root.

    Returns:
        The ``Sql*`` names found in the comparators (right-hand operands) of
        any comparison inside a function named ``_get_query``. When a
        comparator is a bare name bound at module level to an expression
        containing ``Sql*`` names (e.g. a tuple of models), those names are
        included as well.
    """
    # Decode explicitly as UTF-8 (the default Python source encoding) instead
    # of relying on the platform locale.
    source = (_REPO_ROOT / ws_path).read_text(encoding="utf-8")
    # Passing ``filename`` makes a SyntaxError point at the offending store.
    tree = ast.parse(source, filename=ws_path)

    # Collect module-level variables holding Sql* names (e.g. tuples of
    # models), whether they are bound by a plain or an annotated assignment.
    var_models: dict[str, set[str]] = {}
    for stmt in tree.body:
        if isinstance(stmt, ast.Assign):
            targets, value = stmt.targets, stmt.value
        elif isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
            targets, value = [stmt.target], stmt.value
        else:
            continue
        names = _sql_names(value)
        if not names:
            continue
        for target in targets:
            # Only simple ``NAME = ...`` bindings can later be referenced as a
            # bare comparator; attribute/subscript/tuple targets are skipped.
            if isinstance(target, ast.Name):
                var_models[target.id] = names

    # Extract model names from _get_query comparisons.
    models: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if node.name != "_get_query":
            continue
        for inner in ast.walk(node):
            if not isinstance(inner, ast.Compare):
                continue
            for comp in inner.comparators:
                # Direct references such as ``model is SqlFoo`` or
                # ``model in (SqlFoo, SqlBar)``.
                models |= _sql_names(comp)
                # Indirect references through a module-level variable.
                if isinstance(comp, ast.Name) and comp.id in var_models:
                    models |= var_models[comp.id]
    return models


def test_all_workspace_models_handled_in_get_query():
    """Every model with a workspace column must appear in at least one
    workspace store's ``_get_query``.
    """
    handled: set[str] = set()
    for ws_path in WORKSPACE_STORE_PATHS:
        handled |= _models_handled_by_get_query(ws_path)

    # Mapped classes whose mapper exposes a column keyed ``workspace``.
    models_with_column: set[str] = {
        mapper.class_.__name__
        for mapper in Base.registry.mappers
        if "workspace" in {col.key for col in mapper.columns}
    }

    missing = models_with_column - handled
    assert not missing, (
        "These models have a `workspace` column but are not handled by any "
        f"workspace store's _get_query: {sorted(missing)}. "
        "Add handling in the appropriate workspace store's _get_query method."
    )