test_chat_model.py
import json
import pathlib
import pickle
import uuid
from dataclasses import asdict

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import Model
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import load_serving_example
from mlflow.pyfunc.loaders.chat_model import _ChatModelPyfuncWrapper
from mlflow.tracing.constant import TraceTagKey
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import (
    CHAT_MODEL_INPUT_SCHEMA,
    CHAT_MODEL_OUTPUT_SCHEMA,
    ChatChoice,
    ChatChoiceDelta,
    ChatChunkChoice,
    ChatCompletionChunk,
    ChatCompletionResponse,
    ChatMessage,
    ChatParams,
    FunctionToolCallArguments,
    FunctionToolDefinition,
    ToolParamsSchema,
)
from mlflow.types.schema import ColSpec, DataType, Schema

from tests.helper_functions import (
    expect_status_code,
    pyfunc_serve_and_score_model,
)
from tests.tracing.helper import get_traces

# `None`s (`max_tokens` and `stop`) are excluded
DEFAULT_PARAMS = {
    "temperature": 1.0,
    "n": 1,
    "stream": False,
}


def get_mock_streaming_response(message, is_last_chunk=False):
    if is_last_chunk:
        return {
            "id": "123",
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": None,
                        "content": None,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }
    else:
        return {
            "id": "123",
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": message,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }


def get_mock_response(messages, params):
    return {
        "id": "123",
        "model": "MyChatModel",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": json.dumps([m.to_dict() for m in messages]),
                },
                "finish_reason": "stop",
            },
            {
                "index": 1,
                "message": {
                    "role": "user",
                    "content": json.dumps(params.to_dict()),
                },
                "finish_reason": "stop",
            },
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        },
    }


class SimpleChatModel(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse.from_dict(mock_response)

    def predict_stream(self, context, messages: list[ChatMessage], params: ChatParams):
        num_chunks = 10
        for i in range(num_chunks):
            mock_response = get_mock_streaming_response(
                f"message {i}", is_last_chunk=(i == num_chunks - 1)
            )
            yield ChatCompletionChunk.from_dict(mock_response)


class ChatModelWithContext(mlflow.pyfunc.ChatModel):
    def load_context(self, context):
        predict_path = pathlib.Path(context.artifacts["predict_fn"])
        self.predict_fn = pickle.loads(predict_path.read_bytes())

    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        message = ChatMessage(role="assistant", content=self.predict_fn())
        return ChatCompletionResponse.from_dict(get_mock_response([message], params))


class ChatModelWithTrace(mlflow.pyfunc.ChatModel):
    @mlflow.trace
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse.from_dict(mock_response)


class ChatModelWithMetadata(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse(
            **mock_response,
            custom_outputs=params.custom_inputs,
        )


class ChatModelWithToolCalling(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        tools = params.tools

        # call the first tool with some value for all the required params
        tool_name = tools[0].function.name
        tool_params = tools[0].function.parameters
        arguments = {}
        for param in tool_params.required:
            param_type = tool_params.properties[param].type
            if param_type == "string":
                arguments[param] = "some_value"
            elif param_type == "number":
                arguments[param] = 123
            elif param_type == "boolean":
                arguments[param] = True
            else:
                # keep the test example simple
                raise ValueError(f"Unsupported param type: {param_type}")

        tool_call = FunctionToolCallArguments(
            name=tool_name,
            arguments=json.dumps(arguments),
        ).to_tool_call(id=uuid.uuid4().hex)

        tool_message = ChatMessage(
            role="assistant",
            tool_calls=[tool_call],
        )

        return ChatCompletionResponse(choices=[ChatChoice(index=0, message=tool_message)])


def test_chat_model_save_load(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatModelPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_MODEL_INPUT_SCHEMA
    assert output_schema == CHAT_MODEL_OUTPUT_SCHEMA


def test_chat_model_with_trace(tmp_path):
    model = ChatModelWithTrace()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    # predict() call during saving chat model should not generate a trace
    assert len(get_traces()) == 0

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    loaded_model.predict({"messages": messages})

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.tags[TraceTagKey.TRACE_NAME] == "predict"
    request = json.loads(traces[0].data.request)
    assert request["messages"] == [asdict(ChatMessage.from_dict(msg)) for msg in messages]


def test_chat_model_save_throws_with_signature(tmp_path):
    model = SimpleChatModel()

    with pytest.raises(MlflowException, match="Please remove the `signature` parameter"):
        mlflow.pyfunc.save_model(
            python_model=model,
            path=tmp_path,
            signature=ModelSignature(
                Schema([ColSpec(name="test", type=DataType.string)]),
                Schema([ColSpec(name="test", type=DataType.string)]),
            ),
        )


def mock_predict():
    return "hello"


def test_chat_model_with_context_saves_successfully(tmp_path):
    model_path = tmp_path / "model"
    predict_path = tmp_path / "predict.pkl"
    predict_path.write_bytes(pickle.dumps(mock_predict))

    model = ChatModelWithContext()
    mlflow.pyfunc.save_model(
        python_model=model,
        path=model_path,
        artifacts={"predict_fn": str(predict_path)},
    )

    loaded_model = mlflow.pyfunc.load_model(model_path)
    messages = [{"role": "user", "content": "test"}]

    response = loaded_model.predict({"messages": messages})
    expected_response = json.dumps([{"role": "assistant", "content": "hello"}])
    assert response["choices"][0]["message"]["content"] == expected_response


@pytest.mark.parametrize(
    "ret",
    [
        "not a ChatCompletionResponse",
        {"dict": "with", "bad": "keys"},
        {
            "id": "1",
            "created": 1,
            "model": "m",
            "choices": [{"bad": "choice"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        },
    ],
)
def test_save_throws_on_invalid_output(tmp_path, ret):
    class BadChatModel(mlflow.pyfunc.ChatModel):
        def predict(self, context, messages, params) -> ChatCompletionResponse:
            return ret

    model = BadChatModel()
    with pytest.raises(
        MlflowException,
        match=(
            "Failed to save ChatModel. Please ensure that the model's "
            r"predict\(\) method returns a ChatCompletionResponse object"
        ),
    ):
        mlflow.pyfunc.save_model(python_model=model, path=tmp_path)


# test that we can predict with the model
def test_chat_model_predict(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    response = loaded_model.predict({"messages": messages})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == DEFAULT_PARAMS

    # override all params
    params_override = {
        "temperature": 0.5,
        "max_tokens": 10,
        "stop": ["\n"],
        "n": 2,
        "stream": True,
        "top_p": 0.1,
        "top_k": 20,
        "frequency_penalty": 0.5,
        "presence_penalty": -0.5,
    }
    response = loaded_model.predict({"messages": messages, **params_override})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == params_override

    # override a subset of params
    params_subset = {
        "max_tokens": 100,
    }
    response = loaded_model.predict({"messages": messages, **params_subset})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_works_in_serving():
    model = SimpleChatModel()
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    params_subset = {
        "max_tokens": 100,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=(messages, params_subset),
        )

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(messages)
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_works_with_infer_signature_input_example(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            }
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
        assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
        assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
        mlflow_model = Model.load(model_info.model_uri)
        local_path = _download_artifact_from_uri(model_info.model_uri)
        assert mlflow_model.load_input_example(local_path) == {
            "messages": input_example["messages"],
            **params_subset,
        }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(input_example["messages"])
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_logs_default_metadata_task(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            }
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
        assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
        assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
        assert model_info.metadata["task"] == "agent/v1/chat"

    with mlflow.start_run():
        model_info_with_override = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example, metadata={"task": None}
        )
        assert model_info_with_override.metadata["task"] is None


def test_chat_model_works_with_chat_message_input_example(tmp_path):
    model = SimpleChatModel()
    input_example = [
        ChatMessage(role="user", content="What is Retrieval-augmented Generation?", name="chat")
    ]
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
        assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
        assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
        mlflow_model = Model.load(model_info.model_uri)
        local_path = _download_artifact_from_uri(model_info.model_uri)
        assert mlflow_model.load_input_example(local_path) == {
            "messages": [message.to_dict() for message in input_example],
        }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(json.loads(inference_payload)["messages"])


def test_chat_model_works_with_infer_signature_multi_input_example(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "assistant",
                "content": "You are in helpful assistant!",
            },
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            },
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
        assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
        assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
        mlflow_model = Model.load(model_info.model_uri)
        local_path = _download_artifact_from_uri(model_info.model_uri)
        assert mlflow_model.load_input_example(local_path) == {
            "messages": input_example["messages"],
            **params_subset,
        }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(input_example["messages"])
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_predict_stream(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    responses = list(loaded_model.predict_stream({"messages": messages}))
    for i, resp in enumerate(responses[:-1]):
        assert resp["choices"][0]["delta"]["content"] == f"message {i}"

    assert responses[-1]["choices"][0]["delta"] == {}


def test_chat_model_can_receive_and_return_metadata():
    messages = [{"role": "user", "content": "Hello!"}]
    params = {
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }
    input_example = {
        "messages": messages,
        **params,
    }

    model = ChatModelWithMetadata()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # test that it works for normal pyfunc predict
    response = loaded_model.predict({"messages": messages, **params})
    assert response["custom_outputs"] == params["custom_inputs"]

    # test that it works in serving
    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    serving_response = json.loads(response.content)
    assert serving_response["custom_outputs"] == params["custom_inputs"]


def test_chat_model_can_use_tool_calls():
    messages = [{"role": "user", "content": "What's the weather?"}]

    weather_tool = (
        FunctionToolDefinition(
            name="get_weather",
            description="Get the weather for your current location",
            parameters=ToolParamsSchema(
                {
                    "city": {
                        "type": "string",
                        "description": "The city to get the weather for",
                    },
                    "unit": {"type": "string", "enum": ["F", "C"]},
                },
                required=["city", "unit"],
            ),
        )
        .to_tool_definition()
        .to_dict()
    )

    example = {
        "messages": messages,
        "tools": [weather_tool],
    }

    model = ChatModelWithToolCalling()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    response = loaded_model.predict(example)

    model_tool_calls = response["choices"][0]["message"]["tool_calls"]
    assert json.loads(model_tool_calls[0]["function"]["arguments"]) == {
        "city": "some_value",
        "unit": "some_value",
    }


def test_chat_model_without_context_in_predict():
    response = ChatCompletionResponse(
        choices=[ChatChoice(message=ChatMessage(role="assistant", content="hi"))]
    )
    chunk_response = ChatCompletionChunk(
        choices=[ChatChunkChoice(delta=ChatChoiceDelta(role="assistant", content="hi"))]
    )

    class Model(mlflow.pyfunc.ChatModel):
        def predict(self, messages: list[ChatMessage], params: ChatParams):
            return response

        def predict_stream(self, messages: list[ChatMessage], params: ChatParams):
            yield chunk_response

    model = Model()
    messages = [ChatMessage(role="user", content="hello?", name="chat")]
    assert model.predict(messages, ChatParams()) == response
    assert next(iter(model.predict_stream(messages, ChatParams()))) == chunk_response

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=messages
        )
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        input_data = {"messages": [{"role": "user", "content": "hello"}]}
        assert pyfunc_model.predict(input_data) == response.to_dict()
        assert next(iter(pyfunc_model.predict_stream(input_data))) == chunk_response.to_dict()