# tests/pyfunc/test_rag_model.py
 1  import json
 2  from dataclasses import asdict
 3  
 4  import mlflow
 5  from mlflow.models.model import Model
 6  from mlflow.models.rag_signatures import (
 7      ChainCompletionChoice,
 8      ChatCompletionRequest,
 9      ChatCompletionResponse,
10      Message,
11  )
12  from mlflow.models.signature import ModelSignature
13  
14  from tests.helper_functions import expect_status_code, pyfunc_serve_and_score_model
15  
16  
class TestRagModel(mlflow.pyfunc.PythonModel):
    """Minimal RAG-signature pyfunc model that echoes the first message back."""

    def predict(self, context, model_input: ChatCompletionRequest):
        """Return a ``ChatCompletionResponse`` dict echoing the incoming content.

        Only ``choices`` is populated explicitly; every other field of the
        response (e.g. ``object``) is intentionally left to its dataclass
        default so the test can validate default population.
        """
        echoed = model_input.messages[0].content
        reply = Message(role="assistant", content=echoed)
        response = ChatCompletionResponse(choices=[ChainCompletionChoice(message=reply)])
        return asdict(response)
27  
28  
def test_rag_model_works_with_type_hint(tmp_path):
    """Save, load, invoke, and serve a RAG-signature pyfunc model end to end."""
    payload = {"messages": [{"role": "user", "content": "What is mlflow?"}]}
    mlflow.pyfunc.save_model(
        python_model=TestRagModel(),
        path=tmp_path,
        signature=ModelSignature(
            inputs=ChatCompletionRequest(), outputs=ChatCompletionResponse()
        ),
        input_example=payload,
    )

    # Direct invocation through the loaded pyfunc wrapper echoes the message
    # and fills the defaulted `object` field.
    pyfunc_model = mlflow.pyfunc.load_model(tmp_path)
    prediction = pyfunc_model.predict(payload)
    assert prediction["choices"][0]["message"]["content"] == "What is mlflow?"
    assert prediction["object"] == "chat.completion"

    # The input example must round-trip through the saved MLmodel metadata.
    assert Model.load(tmp_path).load_input_example(tmp_path) == payload

    # Scoring-server invocation must produce the same echoed response.
    served = pyfunc_serve_and_score_model(
        model_uri=tmp_path,
        data=json.dumps(payload),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )
    expect_status_code(served, 200)
    body = json.loads(served.content)
    assert body["choices"][0]["message"]["content"] == "What is mlflow?"
    assert body["object"] == "chat.completion"