/ tests / db / check_migration.py
check_migration.py
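"""
Check that an MLflow database migration preserves data written before it.

`pre-migration` populates the tracking database and pickles a snapshot of every
table in TABLES; after the database has been migrated, `post-migration` compares
those tables against the snapshots and checks that workspace columns were
backfilled with the `default` workspace.

Usage
-----
export MLFLOW_TRACKING_URI=sqlite:///mlruns.db

# pre migration
python tests/db/check_migration.py pre-migration

# post migration
python tests/db/check_migration.py post-migration
"""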
 12  
 13  import os
 14  import uuid
 15  from pathlib import Path
 16  
 17  import click
 18  import pandas as pd
 19  import sqlalchemy as sa
 20  
 21  import mlflow
 22  from mlflow.store.model_registry.dbmodels.models import (
 23      SqlModelVersion,
 24      SqlModelVersionTag,
 25      SqlRegisteredModel,
 26      SqlRegisteredModelTag,
 27  )
 28  from mlflow.store.tracking.dbmodels.models import (
 29      SqlExperiment,
 30      SqlExperimentTag,
 31      SqlLatestMetric,
 32      SqlMetric,
 33      SqlParam,
 34      SqlRun,
 35      SqlTag,
 36  )
 37  
 38  TABLES = [
 39      SqlExperiment.__tablename__,
 40      SqlRun.__tablename__,
 41      SqlMetric.__tablename__,
 42      SqlParam.__tablename__,
 43      SqlTag.__tablename__,
 44      SqlExperimentTag.__tablename__,
 45      SqlLatestMetric.__tablename__,
 46      SqlRegisteredModel.__tablename__,
 47      SqlModelVersion.__tablename__,
 48      SqlRegisteredModelTag.__tablename__,
 49      SqlModelVersionTag.__tablename__,
 50  ]
 51  SNAPSHOTS_DIR = Path(__file__).parent / "snapshots"
 52  WORKSPACE_TABLES = {
 53      "experiments",
 54      "registered_models",
 55      "model_versions",
 56      "registered_model_tags",
 57      "model_version_tags",
 58      "registered_model_aliases",
 59      "evaluation_datasets",
 60      "webhooks",
 61      "jobs",
 62  }
 63  
 64  
 65  class Model(mlflow.pyfunc.PythonModel):
 66      def predict(self, context, model_input, params=None):
 67          return [0]
 68  
 69  
 70  def log_everything():
 71      exp_id = mlflow.create_experiment(uuid.uuid4().hex, tags={"tag": "experiment"})
 72      mlflow.set_experiment(experiment_id=exp_id)
 73      with mlflow.start_run() as run:
 74          mlflow.log_params({"param": "value"})
 75          mlflow.log_metrics({"metric": 0.1})
 76          mlflow.set_tags({"tag": "run"})
 77          model_info = mlflow.pyfunc.log_model(  # clint: disable=log-model-artifact-path
 78              "model", python_model=Model()
 79          )
 80  
 81      client = mlflow.MlflowClient()
 82      registered_model_name = uuid.uuid4().hex
 83      client.create_registered_model(
 84          registered_model_name, tags={"tag": "registered_model"}, description="description"
 85      )
 86      model_version = client.create_model_version(
 87          registered_model_name,
 88          model_info.model_uri,
 89          run_id=run.info.run_id,
 90          tags={"tag": "model_version"},
 91          run_link="run_link",
 92          description="description",
 93      )
 94      client.set_registered_model_alias(
 95          name=registered_model_name,
 96          alias="prod",
 97          version=model_version.version,
 98      )
 99      # Create an additional experiment/model to ensure workspace backfills cover multiple resources.
100      mlflow.create_experiment(uuid.uuid4().hex)
101      client.create_registered_model(uuid.uuid4().hex)
102      client.create_webhook(
103          name=f"migration-webhook-{uuid.uuid4().hex}",
104          url="https://example.com/hook",
105          events=["model_version.created"],
106          description="workspace-migration-check",
107      )
108      engine = sa.create_engine(os.environ["MLFLOW_TRACKING_URI"])
109      metadata = sa.MetaData()
110      evaluation_datasets_table = sa.Table(
111          "evaluation_datasets",
112          metadata,
113          autoload_with=engine,
114      )
115      jobs_table = sa.Table(
116          "jobs",
117          metadata,
118          autoload_with=engine,
119      )
120      with engine.begin() as conn:
121          conn.execute(
122              sa.insert(evaluation_datasets_table).values(
123                  dataset_id=uuid.uuid4().hex,
124                  name="workspace-migration-dataset",
125                  schema="{}",
126                  profile="{}",
127                  digest=uuid.uuid4().hex,
128                  created_time=0,
129                  last_update_time=0,
130                  created_by="user",
131                  last_updated_by="user",
132              )
133          )
134          conn.execute(
135              sa.insert(jobs_table).values(
136                  id=uuid.uuid4().hex,
137                  creation_time=0,
138                  job_name="tests.db.check_migration.log_everything",
139                  params="{}",
140                  timeout=None,
141                  status=0,
142                  result=None,
143                  retry_count=0,
144                  last_update_time=0,
145              )
146          )
147  
148  
def connect_to_mlflow_db():
    """Open a new SQLAlchemy connection to the database at ``MLFLOW_TRACKING_URI``.

    Raises ``KeyError`` if the environment variable is unset. The caller owns the
    returned connection; the commands in this file use it as a context manager.
    """
    tracking_uri = os.environ["MLFLOW_TRACKING_URI"]
    engine = sa.create_engine(tracking_uri)
    return engine.connect()
151  
152  
@click.group()
def cli():
    # Root command group with no behavior of its own. The work happens in the
    # subcommands attached with @cli.command(). Comments are used instead of a
    # docstring because click would show a docstring as --help text.
    return None
156  
157  
@cli.command()
@click.option("--verbose", is_flag=True, default=False)
def pre_migration(verbose):
    # Seed the database with five rounds of entities, then pickle a full dump of
    # every table in TABLES into SNAPSHOTS_DIR for post_migration to compare.
    rounds = 5
    for _ in range(rounds):
        log_everything()

    SNAPSHOTS_DIR.mkdir(exist_ok=True)
    with connect_to_mlflow_db() as conn:
        for table_name in TABLES:
            snapshot = pd.read_sql(sa.text(f"SELECT * FROM {table_name}"), conn)
            snapshot.to_pickle(SNAPSHOTS_DIR / f"{table_name}.pkl")
            if not verbose:
                continue
            # With --verbose, show a header and the first rows of each snapshot.
            click.secho(f"\n{table_name}\n", fg="blue")
            preview = snapshot.head(5).to_markdown(index=False)
            click.secho(preview)
171  
172  
@cli.command()
def post_migration():
    """Verify that the database migration preserved pre-migration data.

    For every table in TABLES, compare the current rows against the snapshot
    written by pre-migration. Only columns that existed before the migration
    are compared, so newly added columns are ignored. Then check that every
    table in WORKSPACE_TABLES has rows and that each row's workspace is
    "default".
    """
    with connect_to_mlflow_db() as conn:
        for table in TABLES:
            df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn)
            df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl")
            # Report a dropped column by table name instead of a bare KeyError.
            missing = df_expected.columns.difference(df_actual.columns)
            assert missing.empty, f"{table} is missing pre-migration columns: {list(missing)}"
            # `obj` puts the table name into assert_frame_equal's failure message.
            pd.testing.assert_frame_equal(
                df_actual[df_expected.columns], df_expected, obj=table
            )
        # Iterate in sorted order. Iterating the set directly depends on
        # string-hash randomization, which would make the first reported
        # failure vary between runs.
        for table in sorted(WORKSPACE_TABLES):
            df = pd.read_sql(sa.text(f"SELECT DISTINCT workspace FROM {table}"), conn)
            workspaces = df["workspace"]
            assert not workspaces.isna().any(), f"{table} contains NULL workspace values"
            found = set(workspaces)
            # An empty table would otherwise be reported as "non-default workspaces".
            assert found, f"{table} has no rows, so its workspace backfill cannot be verified"
            assert found == {"default"}, (
                f"{table} contains non-default workspaces: {sorted(map(str, found))}"
            )
184  
185  
if __name__ == "__main__":
    # Run the click group; the subcommand (pre-migration / post-migration) is
    # taken from the command line.
    cli()