# register_model.py
# Smoke test for MLflow model registration of logged models: registering via
# models:/ URIs, via the low-level MlflowClient API, via legacy runs:/ URIs,
# and directly from log_model(registered_model_name=...).
import json

from sklearn.linear_model import LinearRegression

import mlflow

# One client instance is reused for every registry call in this script.
client = mlflow.MlflowClient()

# Log the same fitted model twice in one run. The second copy is logged at
# step=2 so the runs:/ registration later can check which one it resolves to.
with mlflow.start_run():
    model = LinearRegression().fit([[1], [2]], [3, 4])
    model_info = mlflow.sklearn.log_model(
        model,
        name="model",
        params={
            "alpha": 0.5,
            "l1_ratio": 0.5,
        },
    )
    model_info_2 = mlflow.sklearn.log_model(
        model,
        name="model",
        step=2,
        params={
            "alpha": 0.5,
            "l1_ratio": 0.5,
        },
    )

# After one registration, the logged model's "mlflow.modelVersions" tag
# (JSON-encoded) is expected to hold exactly one entry.
mlflow.register_model(model_info.model_uri, name="model")
m = mlflow.get_logged_model(model_info.model_id)
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 1
print(m.tags)
assert m.model_id == model_info.model_id

# Registering the same logged model under a second name adds a second entry.
mlflow.register_model(model_info.model_uri, name="hello")
m = mlflow.get_logged_model(model_info.model_id)
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 2
print(m.tags)

# Low-level client API: the created version links back to the logged model
# and exposes the logged model's params (as strings).
# NOTE(review): create_registered_model presumably fails if "model_client"
# already exists, so this step assumes a fresh registry -- confirm.
client.create_registered_model("model_client")
created = client.create_model_version(
    "model_client", model_info.model_uri, model_id=model_info.model_id
)
# Fetch the version that was just created rather than assuming it is 1.
m = client.get_model_version("model_client", created.version)
print(m)
assert m.model_id == model_info.model_id
assert m.params == {
    "alpha": "0.5",
    "l1_ratio": "0.5",
}

# Support backwards compatibility for runs:/... in addition to models:/...
# A runs:/<run_id>/model URI is ambiguous here: two logged models named
# "model" exist in that run. The registered version should point at the one
# logged at the largest step.
model_uri = f"runs:/{model_info.run_id}/model"
registered = mlflow.register_model(model_uri, name="model_from_runs_path")
# Use the version number register_model() returned instead of assuming 1.
mv = client.get_model_version("model_from_runs_path", registered.version)
assert mv.model_id == model_info_2.model_id  # model at largest step is registered

# Register model in log_model() directly
with mlflow.start_run():
    model_1 = LinearRegression().fit([[1], [2]], [3, 4])
    model_info_1 = mlflow.sklearn.log_model(
        model_1, name="model_1", registered_model_name="model_1"
    )

m = mlflow.get_logged_model(model_info_1.model_id)
assert len(json.loads(m.tags["mlflow.modelVersions"])) == 1
print(m.tags)

# Look up the version log_model() just registered rather than assuming it is 1.
mv = client.get_model_version("model_1", model_info_1.registered_model_version)
assert mv.model_id == model_info_1.model_id