import json
import pathlib
import pickle
import uuid
from dataclasses import asdict

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import Model
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import load_serving_example
from mlflow.pyfunc.loaders.chat_model import _ChatModelPyfuncWrapper
from mlflow.tracing.constant import TraceTagKey
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.llm import (
    CHAT_MODEL_INPUT_SCHEMA,
    CHAT_MODEL_OUTPUT_SCHEMA,
    ChatChoice,
    ChatChoiceDelta,
    ChatChunkChoice,
    ChatCompletionChunk,
    ChatCompletionResponse,
    ChatMessage,
    ChatParams,
    FunctionToolCallArguments,
    FunctionToolDefinition,
    ToolParamsSchema,
)
from mlflow.types.schema import ColSpec, DataType, Schema

from tests.helper_functions import (
    expect_status_code,
    pyfunc_serve_and_score_model,
)
from tests.tracing.helper import get_traces

# Params that default to `None` (`max_tokens` and `stop`) are excluded
DEFAULT_PARAMS = {
    "temperature": 1.0,
    "n": 1,
    "stream": False,
}


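# Helpers that build OpenAI-style mock payloads. get_mock_response echoes the incoming
# messages (choice 0) and params (choice 1) back as JSON-encoded message content so tests
# can assert on exactly what the pyfunc wrapper forwarded; get_mock_streaming_response
# builds a single chunk containing the given message, or an all-`None` delta when
# `is_last_chunk=True`.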
def get_mock_streaming_response(message, is_last_chunk=False):
    if is_last_chunk:
        return {
            "id": "123",
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": None,
                        "content": None,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }
    else:
        return {
            "id": "123",
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": message,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }


def get_mock_response(messages, params):
    return {
        "id": "123",
        "model": "MyChatModel",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": json.dumps([m.to_dict() for m in messages]),
                },
                "finish_reason": "stop",
            },
            {
                "index": 1,
                "message": {
                    "role": "user",
                    "content": json.dumps(params.to_dict()),
                },
                "finish_reason": "stop",
            },
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        },
    }


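# Echoes the request back via get_mock_response, and streams ten fixed chunks, so tests
# can verify what predict()/predict_stream() received.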
class SimpleChatModel(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse.from_dict(mock_response)

    def predict_stream(self, context, messages: list[ChatMessage], params: ChatParams):
        num_chunks = 10
        for i in range(num_chunks):
            mock_response = get_mock_streaming_response(
                f"message {i}", is_last_chunk=(i == num_chunks - 1)
            )
            yield ChatCompletionChunk.from_dict(mock_response)


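# Loads a pickled predict function from the logged artifacts in load_context(), so tests
# can verify that ChatModel resolves artifacts correctly.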
class ChatModelWithContext(mlflow.pyfunc.ChatModel):
    def load_context(self, context):
        predict_path = pathlib.Path(context.artifacts["predict_fn"])
        self.predict_fn = pickle.loads(predict_path.read_bytes())

    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        message = ChatMessage(role="assistant", content=self.predict_fn())
        return ChatCompletionResponse.from_dict(get_mock_response([message], params))


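# predict() is wrapped with @mlflow.trace so tests can check when traces are emitted.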
class ChatModelWithTrace(mlflow.pyfunc.ChatModel):
    @mlflow.trace
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse.from_dict(mock_response)


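# Round-trips params.custom_inputs into the response's custom_outputs field.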
class ChatModelWithMetadata(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        mock_response = get_mock_response(messages, params)
        return ChatCompletionResponse(
            **mock_response,
            custom_outputs=params.custom_inputs,
        )


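# Responds with a tool call to the first tool in params.tools, filling each required
# parameter with a placeholder value chosen by its declared type.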
class ChatModelWithToolCalling(mlflow.pyfunc.ChatModel):
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        tools = params.tools

        # call the first tool with some value for all the required params
        tool_name = tools[0].function.name
        tool_params = tools[0].function.parameters
        arguments = {}
        for param in tool_params.required:
            param_type = tool_params.properties[param].type
            if param_type == "string":
                arguments[param] = "some_value"
            elif param_type == "number":
                arguments[param] = 123
            elif param_type == "boolean":
                arguments[param] = True
            else:
                # keep the test example simple
                raise ValueError(f"Unsupported param type: {param_type}")

        tool_call = FunctionToolCallArguments(
            name=tool_name,
            arguments=json.dumps(arguments),
        ).to_tool_call(id=uuid.uuid4().hex)

        tool_message = ChatMessage(
            role="assistant",
            tool_calls=[tool_call],
        )

        return ChatCompletionResponse(choices=[ChatChoice(index=0, message=tool_message)])


def test_chat_model_save_load(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatModelPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_MODEL_INPUT_SCHEMA
    assert output_schema == CHAT_MODEL_OUTPUT_SCHEMA


def test_chat_model_with_trace(tmp_path):
    model = ChatModelWithTrace()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    # predict() call during saving chat model should not generate a trace
    assert len(get_traces()) == 0

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    loaded_model.predict({"messages": messages})

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.tags[TraceTagKey.TRACE_NAME] == "predict"
    request = json.loads(traces[0].data.request)
    assert request["messages"] == [asdict(ChatMessage.from_dict(msg)) for msg in messages]


def test_chat_model_save_throws_with_signature(tmp_path):
    model = SimpleChatModel()

    with pytest.raises(MlflowException, match="Please remove the `signature` parameter"):
        mlflow.pyfunc.save_model(
            python_model=model,
            path=tmp_path,
            signature=ModelSignature(
                Schema([ColSpec(name="test", type=DataType.string)]),
                Schema([ColSpec(name="test", type=DataType.string)]),
            ),
        )


def mock_predict():
    return "hello"


def test_chat_model_with_context_saves_successfully(tmp_path):
    model_path = tmp_path / "model"
    predict_path = tmp_path / "predict.pkl"
    predict_path.write_bytes(pickle.dumps(mock_predict))

    model = ChatModelWithContext()
    mlflow.pyfunc.save_model(
        python_model=model,
        path=model_path,
        artifacts={"predict_fn": str(predict_path)},
    )

    loaded_model = mlflow.pyfunc.load_model(model_path)
    messages = [{"role": "user", "content": "test"}]

    response = loaded_model.predict({"messages": messages})
    expected_response = json.dumps([{"role": "assistant", "content": "hello"}])
    assert response["choices"][0]["message"]["content"] == expected_response


@pytest.mark.parametrize(
    "ret",
    [
        "not a ChatCompletionResponse",
        {"dict": "with", "bad": "keys"},
        {
            "id": "1",
            "created": 1,
            "model": "m",
            "choices": [{"bad": "choice"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        },
    ],
)
def test_save_throws_on_invalid_output(tmp_path, ret):
    class BadChatModel(mlflow.pyfunc.ChatModel):
        def predict(self, context, messages, params) -> ChatCompletionResponse:
            return ret

    model = BadChatModel()
    with pytest.raises(
        MlflowException,
        match=(
            "Failed to save ChatModel. Please ensure that the model's "
            r"predict\(\) method returns a ChatCompletionResponse object"
        ),
    ):
        mlflow.pyfunc.save_model(python_model=model, path=tmp_path)


# test that we can predict with the model
def test_chat_model_predict(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    response = loaded_model.predict({"messages": messages})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == DEFAULT_PARAMS

    # override all params
    params_override = {
        "temperature": 0.5,
        "max_tokens": 10,
        "stop": ["\n"],
        "n": 2,
        "stream": True,
        "top_p": 0.1,
        "top_k": 20,
        "frequency_penalty": 0.5,
        "presence_penalty": -0.5,
    }
    response = loaded_model.predict({"messages": messages, **params_override})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == params_override

    # override a subset of params
    params_subset = {
        "max_tokens": 100,
    }
    response = loaded_model.predict({"messages": messages, **params_subset})
    assert response["choices"][0]["message"]["content"] == json.dumps(messages)
    assert json.loads(response["choices"][1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


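# The serving tests below log the model with an input example, then score that example
# against a locally served model (via pyfunc_serve_and_score_model) and assert on the
# JSON response.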
def test_chat_model_works_in_serving():
    model = SimpleChatModel()
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]
    params_subset = {
        "max_tokens": 100,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=(messages, params_subset),
        )

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(messages)
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_works_with_infer_signature_input_example(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            }
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == {
        "messages": input_example["messages"],
        **params_subset,
    }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(input_example["messages"])
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_logs_default_metadata_task(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            }
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
    assert model_info.metadata["task"] == "agent/v1/chat"

    with mlflow.start_run():
        model_info_with_override = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example, metadata={"task": None}
        )
    assert model_info_with_override.metadata["task"] is None


def test_chat_model_works_with_chat_message_input_example(tmp_path):
    model = SimpleChatModel()
    input_example = [
        ChatMessage(role="user", content="What is Retrieval-augmented Generation?", name="chat")
    ]
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == {
        "messages": [message.to_dict() for message in input_example],
    }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(json.loads(inference_payload)["messages"])


def test_chat_model_works_with_infer_signature_multi_input_example(tmp_path):
    model = SimpleChatModel()
    params_subset = {
        "max_tokens": 100,
    }
    input_example = {
        "messages": [
            {
                "role": "assistant",
                "content": "You are a helpful assistant!",
            },
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            },
        ],
        **params_subset,
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == {
        "messages": input_example["messages"],
        **params_subset,
    }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    choices = json.loads(response.content)["choices"]
    assert choices[0]["message"]["content"] == json.dumps(input_example["messages"])
    assert json.loads(choices[1]["message"]["content"]) == {
        **DEFAULT_PARAMS,
        **params_subset,
    }


def test_chat_model_predict_stream(tmp_path):
    model = SimpleChatModel()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    responses = list(loaded_model.predict_stream({"messages": messages}))
    for i, resp in enumerate(responses[:-1]):
        assert resp["choices"][0]["delta"]["content"] == f"message {i}"

    assert responses[-1]["choices"][0]["delta"] == {}


def test_chat_model_can_receive_and_return_metadata():
    messages = [{"role": "user", "content": "Hello!"}]
    params = {
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }
    input_example = {
        "messages": messages,
        **params,
    }

    model = ChatModelWithMetadata()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # test that it works for normal pyfunc predict
    response = loaded_model.predict({"messages": messages, **params})
    assert response["custom_outputs"] == params["custom_inputs"]

    # test that it works in serving
    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    serving_response = json.loads(response.content)
    assert serving_response["custom_outputs"] == params["custom_inputs"]


def test_chat_model_can_use_tool_calls():
    messages = [{"role": "user", "content": "What's the weather?"}]

    weather_tool = (
        FunctionToolDefinition(
            name="get_weather",
            description="Get the weather for your current location",
            parameters=ToolParamsSchema(
                {
                    "city": {
                        "type": "string",
                        "description": "The city to get the weather for",
                    },
                    "unit": {"type": "string", "enum": ["F", "C"]},
                },
                required=["city", "unit"],
            ),
        )
        .to_tool_definition()
        .to_dict()
    )

    example = {
        "messages": messages,
        "tools": [weather_tool],
    }

    model = ChatModelWithToolCalling()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    response = loaded_model.predict(example)

    model_tool_calls = response["choices"][0]["message"]["tool_calls"]
    assert json.loads(model_tool_calls[0]["function"]["arguments"]) == {
        "city": "some_value",
        "unit": "some_value",
    }


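# ChatModel subclasses may also define predict/predict_stream without the `context`
# argument; this test checks both direct calls and calls through the loaded pyfunc model.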
def test_chat_model_without_context_in_predict():
    response = ChatCompletionResponse(
        choices=[ChatChoice(message=ChatMessage(role="assistant", content="hi"))]
    )
    chunk_response = ChatCompletionChunk(
        choices=[ChatChunkChoice(delta=ChatChoiceDelta(role="assistant", content="hi"))]
    )

    class Model(mlflow.pyfunc.ChatModel):
        def predict(self, messages: list[ChatMessage], params: ChatParams):
            return response

        def predict_stream(self, messages: list[ChatMessage], params: ChatParams):
            yield chunk_response

    model = Model()
    messages = [ChatMessage(role="user", content="hello?", name="chat")]
    assert model.predict(messages, ChatParams()) == response
    assert next(iter(model.predict_stream(messages, ChatParams()))) == chunk_response

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=messages
        )
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    input_data = {"messages": [{"role": "user", "content": "hello"}]}
    assert pyfunc_model.predict(input_data) == response.to_dict()
    assert next(iter(pyfunc_model.predict_stream(input_data))) == chunk_response.to_dict()