test_rag_model.py
import json
from dataclasses import asdict

import mlflow
from mlflow.models.model import Model
from mlflow.models.rag_signatures import (
    ChainCompletionChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    Message,
)
from mlflow.models.signature import ModelSignature

from tests.helper_functions import expect_status_code, pyfunc_serve_and_score_model


class TestRagModel(mlflow.pyfunc.PythonModel):
    """Pyfunc model that echoes the first incoming chat message back as the assistant."""

    def predict(self, context, model_input: ChatCompletionRequest):
        """Return a chat-completion dict whose single choice repeats the first message.

        Args:
            context: Pyfunc model context; not used by this model.
            model_input: Chat request; only ``messages[0].content`` is read.

        Returns:
            The ``ChatCompletionResponse`` converted to a plain dict via ``asdict``.
        """
        echoed = Message(role="assistant", content=model_input.messages[0].content)
        # NB: ``object`` is intentionally not passed so that its default population
        # is what gets validated by the test below.
        completion = ChatCompletionResponse(choices=[ChainCompletionChoice(message=echoed)])
        return asdict(completion)


def test_rag_model_works_with_type_hint(tmp_path):
    """Save the echo model, then check it via load/predict, input example, and serving."""
    question = "What is mlflow?"
    input_example = {"messages": [{"role": "user", "content": question}]}

    def check_echo(payload):
        # Both the in-process and the served response must echo the question and
        # carry the default ``object`` value.
        assert payload["choices"][0]["message"]["content"] == question
        assert payload["object"] == "chat.completion"

    mlflow.pyfunc.save_model(
        python_model=TestRagModel(),
        path=tmp_path,
        signature=ModelSignature(
            inputs=ChatCompletionRequest(), outputs=ChatCompletionResponse()
        ),
        input_example=input_example,
    )

    # the saved model can be loaded and invoked
    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    check_echo(loaded_model.predict(input_example))

    # the input example was stored alongside the model
    assert Model.load(tmp_path).load_input_example(tmp_path) == input_example

    # the saved model can be served and scored with a JSON payload
    served = pyfunc_serve_and_score_model(
        model_uri=tmp_path,
        data=json.dumps(input_example),
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )
    expect_status_code(served, 200)
    check_echo(json.loads(served.content))