# tests/pyfunc/test_chat_agent.py
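# Tests for the ChatAgent pyfunc flavor: saving and loading, signature and
# input-example handling, tracing, streaming, serving, and custom inputs/outputs.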
import json
from typing import Any
from uuid import uuid4

import pydantic
import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import Model
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import load_serving_example
from mlflow.pyfunc.loaders.chat_agent import _ChatAgentPyfuncWrapper
from mlflow.pyfunc.model import ChatAgent
from mlflow.tracing.constant import TraceTagKey
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.agent import (
    CHAT_AGENT_INPUT_EXAMPLE,
    CHAT_AGENT_INPUT_SCHEMA,
    CHAT_AGENT_OUTPUT_SCHEMA,
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentRequest,
    ChatAgentResponse,
    ChatContext,
)
from mlflow.types.schema import ColSpec, DataType, Schema

from tests.helper_functions import (
    expect_status_code,
    pyfunc_serve_and_score_model,
)
from tests.tracing.helper import get_traces


def get_mock_response(messages: list[ChatAgentMessage], message=None):
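    # Echo every input message back as an assistant message, optionally
    # overriding the content with a fixed string.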
    return {
        "messages": [
            {
                "role": "assistant",
                "content": message or msg.content,
                "name": "llm",
                "id": str(uuid4()),
            }
            for msg in messages
        ],
    }


class SimpleChatAgent(ChatAgent):
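    # Minimal well-behaved agent: predict echoes the input messages,
    # predict_stream yields five numbered chunks.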
    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(**mock_response)

    def predict_stream(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ):
        for i in range(5):
            mock_response = get_mock_response(messages, f"message {i}")
            mock_response["delta"] = mock_response["messages"][0]
            mock_response["delta"]["id"] = str(i)
            yield ChatAgentChunk(**mock_response)


class SimpleBadChatAgent(ChatAgent):
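    # Intentionally malformed agent: both methods build payloads that fail
    # pydantic validation, used to exercise error handling.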
    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(messages=mock_response)

    def predict_stream(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ):
        for i in range(5):
            mock_response = get_mock_response(messages, f"message {i}")
            mock_response["delta"] = mock_response["messages"][0]
            yield ChatAgentChunk(delta=mock_response)


class SimpleDictChatAgent(ChatAgent):
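    # Returns a plain dict (via model_dump) instead of a ChatAgentResponse.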
    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(**mock_response).model_dump()


class ChatAgentWithCustomInputs(ChatAgent):
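    # Echoes custom_inputs back on the response as custom_outputs.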
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(
            **mock_response,
            custom_outputs=custom_inputs,
        )


def test_chat_agent_save_load(tmp_path):
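    # Saving and reloading should wrap the agent in _ChatAgentPyfuncWrapper and
    # attach the fixed ChatAgent input/output schemas.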
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_AGENT_INPUT_SCHEMA
    assert output_schema == CHAT_AGENT_OUTPUT_SCHEMA


def test_chat_agent_save_load_dict_output(tmp_path):
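    # An agent whose predict returns a dict should save and load the same way.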
    model = SimpleDictChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_AGENT_INPUT_SCHEMA
    assert output_schema == CHAT_AGENT_OUTPUT_SCHEMA


def test_chat_agent_trace(tmp_path):
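    # predict() on the loaded model should emit exactly one trace named "predict".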
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    # the predict() call made while saving the chat agent should not generate a trace
    assert len(get_traces()) == 0

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [{"role": "user", "content": "Hello!"}]
    loaded_model.predict({"messages": messages})

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.tags[TraceTagKey.TRACE_NAME] == "predict"
    request = json.loads(traces[0].data.request)
    assert [{k: v for k, v in msg.items() if k != "id"} for msg in request["messages"]] == [
        {k: v for k, v in ChatAgentMessage(**msg).model_dump().items() if k != "id"}
        for msg in messages
    ]


def test_chat_agent_save_throws_with_signature(tmp_path):
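    # ChatAgent models have a fixed signature, so passing one explicitly should fail.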
    model = SimpleChatAgent()

    with pytest.raises(MlflowException, match="Please remove the `signature` parameter"):
        mlflow.pyfunc.save_model(
            python_model=model,
            path=tmp_path,
            signature=ModelSignature(
                inputs=Schema([ColSpec(name="test", type=DataType.string)]),
            ),
        )


@pytest.mark.parametrize(
    "ret",
    [
        "not a ChatAgentResponse",
        {"dict": "with", "bad": "keys"},
        {
            "id": "1",
            "created": 1,
            "model": "m",
            "choices": [{"bad": "choice"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        },
    ],
)
def test_save_throws_on_invalid_output(tmp_path, ret):
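    # Saving should fail fast when predict returns anything that isn't a valid
    # ChatAgentResponse (including ChatCompletion-style dicts).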
    class BadChatAgent(ChatAgent):
        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: ChatContext,
            custom_inputs: dict[str, Any],
        ) -> ChatAgentResponse:
            return ret

    model = BadChatAgent()
    with pytest.raises(
        MlflowException,
        match=("Failed to save ChatAgent. Ensure your model's predict"),
    ):
        mlflow.pyfunc.save_model(python_model=model, path=tmp_path)


def test_chat_agent_predict(tmp_path):
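    # The mock agent echoes input messages back as assistant messages.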
    model = ChatAgentWithCustomInputs()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)

    # test that a single dictionary will work
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    response = loaded_model.predict({"messages": messages})
    assert response["messages"][0]["content"] == "You are a helpful assistant"


def test_chat_agent_works_with_infer_signature_input_example():
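    # A full input example (messages, context, stream) should round-trip through
    # logging, produce the fixed schemas, and work in local serving.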
    model = SimpleChatAgent()
    input_example = {
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant!",
            },
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            },
        ],
        "context": {
            "conversation_id": "123",
            "user_id": "456",
        },
        "stream": False,  # this is the default
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_AGENT_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_AGENT_OUTPUT_SCHEMA
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    loaded_input_example = mlflow_model.load_input_example(local_path)
    # drop the auto-generated message IDs
    loaded_input_example["messages"] = [
        {k: v for k, v in msg.items() if k != "id"} for msg in loaded_input_example["messages"]
    ]
    assert loaded_input_example == input_example

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    model_response = json.loads(response.content)
    assert model_response["messages"][0]["content"] == "You are a helpful assistant!"


def test_chat_agent_logs_default_metadata_task():
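    # Logging a ChatAgent should set the "agent/v2/chat" task by default, and an
    # explicit metadata override should win.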
    model = SimpleChatAgent()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(name="model", python_model=model)
    assert model_info.signature.inputs == CHAT_AGENT_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_AGENT_OUTPUT_SCHEMA
    assert model_info.metadata["task"] == "agent/v2/chat"

    with mlflow.start_run():
        model_info_with_override = mlflow.pyfunc.log_model(
            name="model", python_model=model, metadata={"task": None}
        )
    assert model_info_with_override.metadata["task"] is None


def test_chat_agent_works_with_chat_agent_request_input_example():
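    # ChatAgentRequest-shaped input examples, with and without context, should be
    # logged verbatim and served successfully.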
    model = SimpleChatAgent()
    input_example_no_params = {"messages": [{"role": "user", "content": "What is rag?"}]}
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example_no_params
        )
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == input_example_no_params

    input_example_with_params = {
        "messages": [{"role": "user", "content": "What is rag?"}],
        "context": {"conversation_id": "121", "user_id": "123"},
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example_with_params
        )
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == input_example_with_params

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    model_response = json.loads(response.content)
    assert model_response["messages"][0]["content"] == "What is rag?"


def test_chat_agent_predict_stream(tmp_path):
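    # predict_stream on the loaded model should yield chunks in order.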
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "user", "content": "Hello!"},
    ]

    responses = list(loaded_model.predict_stream({"messages": messages}))
    for i, resp in enumerate(responses[:-1]):
        assert resp["delta"]["content"] == f"message {i}"


def test_chat_agent_can_receive_and_return_custom():
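    # custom_inputs should flow through to custom_outputs both in local predict
    # and behind a serving endpoint.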
    messages = [{"role": "user", "content": "Hello!"}]
    input_example = {
        "messages": messages,
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }

    model = ChatAgentWithCustomInputs()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # test that it works for normal pyfunc predict
    response = loaded_model.predict(input_example)
    assert response["custom_outputs"] == input_example["custom_inputs"]

    # test that it works in serving
    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    serving_response = json.loads(response.content)
    assert serving_response["custom_outputs"] == input_example["custom_inputs"]


def test_chat_agent_predict_wrapper():
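    # Calling predict/predict_stream directly should accept either a request dict
    # or the unpacked pydantic arguments and produce equivalent results.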
    model = ChatAgentWithCustomInputs()
    dict_input_example = {
        "messages": [{"role": "user", "content": "What is rag?"}],
        "context": {"conversation_id": "121", "user_id": "123"},
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }
    chat_agent_request = ChatAgentRequest(**dict_input_example)
    pydantic_input_example = (
        chat_agent_request.messages,
        chat_agent_request.context,
        chat_agent_request.custom_inputs,
    )
    dict_input_response = model.predict(dict_input_example)
    pydantic_input_response = model.predict(*pydantic_input_example)
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response
    no_context_dict_input_example = {**dict_input_example, "context": None}
    no_context_pydantic_input_example = (
        chat_agent_request.messages,
        None,
        chat_agent_request.custom_inputs,
    )
    dict_input_response = model.predict(no_context_dict_input_example)
    pydantic_input_response = model.predict(*no_context_pydantic_input_example)
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response

    model = SimpleChatAgent()
    dict_input_response = model.predict(dict_input_example)
    pydantic_input_response = model.predict(*pydantic_input_example)
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response
    assert list(model.predict_stream(dict_input_example)) == list(
        model.predict_stream(*pydantic_input_example)
    )

    with pytest.raises(MlflowException, match="Invalid dictionary input for a ChatAgent"):
        model.predict({"malformed dict": "bad"})
    with pytest.raises(MlflowException, match="Invalid dictionary input for a ChatAgent"):
        model.predict_stream({"malformed dict": "bad"})

    model = SimpleBadChatAgent()
    with pytest.raises(pydantic.ValidationError, match="validation error for ChatAgentResponse"):
        model.predict(dict_input_example)
    with pytest.raises(pydantic.ValidationError, match="validation error for ChatAgentChunk"):
        list(model.predict_stream(dict_input_example))


def test_chat_agent_predict_with_params(tmp_path):
    # codifies that `params` is part of the signature, since
    # `load_model_and_predict` in `utils/_capture_modules.py` expects a params field
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    response = loaded_model.predict(CHAT_AGENT_INPUT_EXAMPLE, params=None)
    assert response["messages"][0]["content"] == "Hello!"

    responses = list(loaded_model.predict_stream(CHAT_AGENT_INPUT_EXAMPLE, params=None))
    for i, resp in enumerate(responses[:-1]):
        assert resp["delta"]["content"] == f"message {i}"


def test_chat_agent_load_context_called_during_save(tmp_path):
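    # load_context should run before predict so that artifacts (here, a prefix)
    # are available when the loaded model serves requests.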
    class ChatAgentWithArtifacts(ChatAgent):
        def __init__(self):
            self.prefix = None

        def load_context(self, context):
            self.prefix = "loaded_prefix"

        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: ChatContext,
            custom_inputs: dict[str, Any],
        ) -> ChatAgentResponse:
            if self.prefix is None:
                raise ValueError("load_context was not called - prefix is None")
            return ChatAgentResponse(
                messages=[
                    {
                        "role": "assistant",
                        "content": f"{self.prefix}: {messages[0].content}",
                        "id": str(uuid4()),
                    }
                ]
            )

    model = ChatAgentWithArtifacts()
    save_path = tmp_path / "model"
    mlflow.pyfunc.save_model(
        python_model=model,
        path=save_path,
    )

    loaded_model = mlflow.pyfunc.load_model(save_path)
    response = loaded_model.predict({"messages": [{"role": "user", "content": "Hello!"}]})
    assert response["messages"][0]["content"] == "loaded_prefix: Hello!"