test_pyfunc_class_methods.py
import mlflow
from mlflow.pyfunc import PythonModel, load_model, log_model


def test_unwrap_python_model_from_pyfunc_class():
    """Round-trip a custom PythonModel through log/load, then unwrap it.

    The unwrapped object must be an instance of the user's own class, with its
    constructor state restored and its extra (non-``predict``) method callable.
    """
    message = "this is test message"
    offset = 2

    class MyModel(PythonModel):
        """Toy model: adds ``param_2`` to its input and carries a string."""

        def __init__(self, param_1: str, param_2: int):
            self.param_1 = param_1
            self.param_2 = param_2

        def predict(self, context, model_input, params=None):
            # ``context`` and ``params`` are part of the signature but unused here.
            return model_input + self.param_2

        def upper_param_1(self):
            # Extra method beyond ``predict``; the test checks it survives unwrapping.
            return self.param_1.upper()

    with mlflow.start_run():
        logged = log_model("mlruns", python_model=MyModel(message, offset))
        unwrapped = load_model(logged.model_uri).unwrap_python_model()

        assert isinstance(unwrapped, MyModel)
        assert unwrapped.param_1 == "this is test message"
        assert unwrapped.param_2 == 2
        assert unwrapped.predict(None, 1) == 3
        assert unwrapped.upper_param_1() == "THIS IS TEST MESSAGE"