# examples/mlflow-3/register_model.py
 1  import json
 2  
 3  from sklearn.linear_model import LinearRegression
 4  
 5  import mlflow
 6  
 7  client = mlflow.MlflowClient()
 8  
 9  with mlflow.start_run():
10      model = LinearRegression().fit([[1], [2]], [3, 4])
11      model_info = mlflow.sklearn.log_model(
12          model,
13          name="model",
14          params={
15              "alpha": 0.5,
16              "l1_ratio": 0.5,
17          },
18      )
19      model_info_2 = mlflow.sklearn.log_model(
20          model,
21          name="model",
22          step=2,
23          params={
24              "alpha": 0.5,
25              "l1_ratio": 0.5,
26          },
27      )
28  
def _register_and_fetch(registered_name):
    """Register ``model_info``'s LoggedModel under *registered_name*.

    Returns the LoggedModel re-fetched by id after the registration, so its
    tags reflect the new model version.
    """
    mlflow.register_model(model_info.model_uri, name=registered_name)
    return mlflow.get_logged_model(model_info.model_id)


# Each registration of the same LoggedModel, even under a different
# registered-model name, should add one entry to its "mlflow.modelVersions" tag.
m = _register_and_fetch("model")
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 1
print(m.tags)
# Sanity check: the re-fetched entity is the one that was requested.
assert m.model_id == model_info.model_id

m = _register_and_fetch("hello")
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 2
print(m.tags)

# Client used for the explicit registry calls from here on.
client = mlflow.MlflowClient()

# Create a registered model, then a version of it through the client, linking
# the version to the LoggedModel via ``model_id``.
client.create_registered_model("model_client")
client.create_model_version(
    "model_client", model_info.model_uri, model_id=model_info.model_id
)
# Bound to ``mv`` (a ModelVersion), consistent with the other ModelVersion
# lookups in this script; ``m`` is reserved for LoggedModels.
mv = client.get_model_version("model_client", 1)
print(mv)
assert mv.model_id == model_info.model_id
# The version reports the LoggedModel's params. They were logged as floats
# but are compared here as strings.
assert mv.params == {
    "alpha": "0.5",
    "l1_ratio": "0.5",
}

# Backwards compatibility: registration still accepts a legacy
# ``runs:/<run_id>/<artifact_name>`` URI in addition to ``models:/...``.
legacy_run_id = model_info.run_id
model_uri = f"runs:/{legacy_run_id}/model"
mlflow.register_model(model_uri=model_uri, name="model_from_runs_path")

# That run holds two LoggedModels named "model"; the runs:/ URI is expected to
# resolve to the one logged at the largest step (model_info_2, step=2).
mv = client.get_model_version(name="model_from_runs_path", version=1)
assert mv.model_id == model_info_2.model_id

# Passing ``registered_model_name`` to log_model() registers the new
# LoggedModel as part of logging, with no separate register_model() call.
with mlflow.start_run():
    model_1 = LinearRegression()
    model_1.fit([[1], [2]], [3, 4])
    model_info_1 = mlflow.sklearn.log_model(
        model_1,
        name="model_1",
        registered_model_name="model_1",
    )

# The registration should appear as exactly one entry in the LoggedModel's
# "mlflow.modelVersions" tag ...
m = mlflow.get_logged_model(model_info_1.model_id)
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 1
print(m.tags)

# ... and version 1 of registered model "model_1" should point back to it.
mv = client.get_model_version("model_1", 1)
assert mv.model_id == model_info_1.model_id