# check_migration.py
"""
Usage
-----
export MLFLOW_TRACKING_URI=sqlite:///mlruns.db

# pre migration
python tests/db/check_migration.py pre-migration

# post migration
python tests/db/check_migration.py post-migration
"""

import os
import uuid
from pathlib import Path

import click
import pandas as pd
import sqlalchemy as sa

import mlflow
from mlflow.store.model_registry.dbmodels.models import (
    SqlModelVersion,
    SqlModelVersionTag,
    SqlRegisteredModel,
    SqlRegisteredModelTag,
)
from mlflow.store.tracking.dbmodels.models import (
    SqlExperiment,
    SqlExperimentTag,
    SqlLatestMetric,
    SqlMetric,
    SqlParam,
    SqlRun,
    SqlTag,
)

# Tables snapshotted before the migration and compared row-for-row afterwards.
TABLES = [
    SqlExperiment.__tablename__,
    SqlRun.__tablename__,
    SqlMetric.__tablename__,
    SqlParam.__tablename__,
    SqlTag.__tablename__,
    SqlExperimentTag.__tablename__,
    SqlLatestMetric.__tablename__,
    SqlRegisteredModel.__tablename__,
    SqlModelVersion.__tablename__,
    SqlRegisteredModelTag.__tablename__,
    SqlModelVersionTag.__tablename__,
]
# Directory where the pre-migration table snapshots are pickled.
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots"
# Tables that carry a `workspace` column after the migration. Every row in
# these tables must be backfilled with the "default" workspace.
WORKSPACE_TABLES = {
    "experiments",
    "registered_models",
    "model_versions",
    "registered_model_tags",
    "model_version_tags",
    "registered_model_aliases",
    "evaluation_datasets",
    "webhooks",
    "jobs",
}


class Model(mlflow.pyfunc.PythonModel):
    """Trivial pyfunc model used only to produce a loggable model artifact."""

    def predict(self, context, model_input, params=None):
        return [0]


def log_everything():
    """Insert at least one row into every table this script verifies.

    Creates an experiment, a run (params/metrics/tags), a logged model, a
    registered model with a version/alias/tags, a webhook, plus direct-SQL
    rows in ``evaluation_datasets`` and ``jobs`` (which have no client API).
    """
    exp_id = mlflow.create_experiment(uuid.uuid4().hex, tags={"tag": "experiment"})
    mlflow.set_experiment(experiment_id=exp_id)
    with mlflow.start_run() as run:
        mlflow.log_params({"param": "value"})
        mlflow.log_metrics({"metric": 0.1})
        mlflow.set_tags({"tag": "run"})
        model_info = mlflow.pyfunc.log_model(  # clint: disable=log-model-artifact-path
            "model", python_model=Model()
        )

    client = mlflow.MlflowClient()
    registered_model_name = uuid.uuid4().hex
    client.create_registered_model(
        registered_model_name, tags={"tag": "registered_model"}, description="description"
    )
    model_version = client.create_model_version(
        registered_model_name,
        model_info.model_uri,
        run_id=run.info.run_id,
        tags={"tag": "model_version"},
        run_link="run_link",
        description="description",
    )
    client.set_registered_model_alias(
        name=registered_model_name,
        alias="prod",
        version=model_version.version,
    )
    # Create an additional experiment/model to ensure workspace backfills cover multiple resources.
    mlflow.create_experiment(uuid.uuid4().hex)
    client.create_registered_model(uuid.uuid4().hex)
    client.create_webhook(
        name=f"migration-webhook-{uuid.uuid4().hex}",
        url="https://example.com/hook",
        events=["model_version.created"],
        description="workspace-migration-check",
    )

    # `evaluation_datasets` and `jobs` have no fluent/client API, so seed them
    # with raw SQL. Reflect the live schema rather than importing ORM models.
    engine = sa.create_engine(os.environ["MLFLOW_TRACKING_URI"])
    try:
        metadata = sa.MetaData()
        evaluation_datasets_table = sa.Table(
            "evaluation_datasets",
            metadata,
            autoload_with=engine,
        )
        jobs_table = sa.Table(
            "jobs",
            metadata,
            autoload_with=engine,
        )
        with engine.begin() as conn:
            conn.execute(
                sa.insert(evaluation_datasets_table).values(
                    dataset_id=uuid.uuid4().hex,
                    name="workspace-migration-dataset",
                    schema="{}",
                    profile="{}",
                    digest=uuid.uuid4().hex,
                    created_time=0,
                    last_update_time=0,
                    created_by="user",
                    last_updated_by="user",
                )
            )
            conn.execute(
                sa.insert(jobs_table).values(
                    id=uuid.uuid4().hex,
                    creation_time=0,
                    job_name="tests.db.check_migration.log_everything",
                    params="{}",
                    timeout=None,
                    status=0,
                    result=None,
                    retry_count=0,
                    last_update_time=0,
                )
            )
    finally:
        # The original leaked this ad-hoc engine; release its connection pool.
        engine.dispose()


def connect_to_mlflow_db():
    """Return a raw SQLAlchemy connection to the tracking database."""
    return sa.create_engine(os.environ["MLFLOW_TRACKING_URI"]).connect()


@click.group()
def cli():
    pass


@cli.command()
@click.option("--verbose", is_flag=True, default=False)
def pre_migration(verbose):
    """Seed the database and pickle a per-table snapshot for later comparison."""
    # Log several times so snapshot/backfill checks cover multiple rows per table.
    for _ in range(5):
        log_everything()
    # parents=True makes this robust if the snapshot directory tree is missing.
    SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
    with connect_to_mlflow_db() as conn:
        for table in TABLES:
            df = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn)
            df.to_pickle(SNAPSHOTS_DIR / f"{table}.pkl")
            if verbose:
                click.secho(f"\n{table}\n", fg="blue")
                click.secho(df.head(5).to_markdown(index=False))


@cli.command()
def post_migration():
    """Verify the migration preserved data and backfilled workspace columns."""
    with connect_to_mlflow_db() as conn:
        # Pre-migration data must survive unchanged. Compare only the snapshot's
        # columns: the migration may legitimately add new columns.
        for table in TABLES:
            df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn)
            df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl")
            pd.testing.assert_frame_equal(df_actual[df_expected.columns], df_expected)
        # Every workspace-scoped row must be backfilled with the default
        # workspace. Sorted iteration keeps check/failure order deterministic.
        for table in sorted(WORKSPACE_TABLES):
            df = pd.read_sql(sa.text(f"SELECT DISTINCT workspace FROM {table}"), conn)
            assert not df["workspace"].isna().any(), f"{table} contains NULL workspace values"
            assert set(df["workspace"]) == {"default"}, f"{table} contains non-default workspaces"


if __name__ == "__main__":
    cli()