tests/pyfunc/test_pyfunc_class_methods.py
test_pyfunc_class_methods.py
 1  import mlflow
 2  from mlflow.pyfunc import PythonModel, load_model, log_model
 3  
 4  
 5  def test_unwrap_python_model_from_pyfunc_class():
 6      class MyModel(PythonModel):
 7          def __init__(self, param_1: str, param_2: int):
 8              self.param_1 = param_1
 9              self.param_2 = param_2
10  
11          def predict(self, context, model_input, params=None):
12              return model_input + self.param_2
13  
14          def upper_param_1(self):
15              return self.param_1.upper()
16  
17      with mlflow.start_run():
18          model = MyModel("this is test message", 2)
19          model_uri = log_model("mlruns", python_model=model).model_uri
20          loaded_model = load_model(model_uri).unwrap_python_model()
21          assert isinstance(loaded_model, MyModel)
22          assert loaded_model.param_1 == "this is test message"
23          assert loaded_model.param_2 == 2
24          assert loaded_model.predict(None, 1) == 3
25          assert loaded_model.upper_param_1() == "THIS IS TEST MESSAGE"