# test_chat_agent.py
"""Tests for the MLflow ``ChatAgent`` pyfunc flavor.

Covers save/load round-trips, schema inference, tracing, serving,
custom inputs/outputs, streaming, and input validation.
"""

import json
from typing import Any
from uuid import uuid4

import pydantic
import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import Model
from mlflow.models.signature import ModelSignature
from mlflow.models.utils import load_serving_example
from mlflow.pyfunc.loaders.chat_agent import _ChatAgentPyfuncWrapper
from mlflow.pyfunc.model import ChatAgent
from mlflow.tracing.constant import TraceTagKey
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.agent import (
    CHAT_AGENT_INPUT_EXAMPLE,
    CHAT_AGENT_INPUT_SCHEMA,
    CHAT_AGENT_OUTPUT_SCHEMA,
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentRequest,
    ChatAgentResponse,
    ChatContext,
)
from mlflow.types.schema import ColSpec, DataType, Schema

from tests.helper_functions import (
    expect_status_code,
    pyfunc_serve_and_score_model,
)
from tests.tracing.helper import get_traces


def get_mock_response(messages: list[ChatAgentMessage], message=None):
    """Echo each input message back as an assistant message dict.

    If ``message`` is given, it overrides every message's content;
    otherwise each input message's own content is echoed.
    """
    return {
        "messages": [
            {
                "role": "assistant",
                "content": message or msg.content,
                "name": "llm",
                # Fresh UUID per message; tests strip "id" before comparing.
                "id": str(uuid4()),
            }
            for msg in messages
        ],
    }


class SimpleChatAgent(ChatAgent):
    """Minimal well-behaved agent: echoes input, supports streaming."""

    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(**mock_response)

    def predict_stream(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ):
        # Emit five chunks with deterministic ids "0".."4" and contents
        # "message 0".."message 4".
        for i in range(5):
            mock_response = get_mock_response(messages, f"message {i}")
            mock_response["delta"] = mock_response["messages"][0]
            mock_response["delta"]["id"] = str(i)
            yield ChatAgentChunk(**mock_response)


class SimpleBadChatAgent(ChatAgent):
    """Agent that deliberately returns malformed payloads.

    Used to assert that pydantic validation errors surface from both
    ``predict`` and ``predict_stream``.
    """

    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        # Intentionally wrong: passes the whole response dict as `messages`.
        return ChatAgentResponse(messages=mock_response)

    def predict_stream(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ):
        for i in range(5):
            mock_response = get_mock_response(messages, f"message {i}")
            mock_response["delta"] = mock_response["messages"][0]
            # Intentionally wrong: `delta` gets the whole dict, not a message.
            yield ChatAgentChunk(delta=mock_response)


class SimpleDictChatAgent(ChatAgent):
    """Agent whose ``predict`` returns a plain dict instead of a pydantic model."""

    @mlflow.trace
    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(**mock_response).model_dump()


class ChatAgentWithCustomInputs(ChatAgent):
    """Agent that reflects ``custom_inputs`` back as ``custom_outputs``."""

    def predict(
        self, messages: list[ChatAgentMessage], context: ChatContext, custom_inputs: dict[str, Any]
    ) -> ChatAgentResponse:
        mock_response = get_mock_response(messages)
        return ChatAgentResponse(
            **mock_response,
            custom_outputs=custom_inputs,
        )


def test_chat_agent_save_load(tmp_path):
    """Saved ChatAgent loads with the wrapper and the fixed agent schemas."""
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_AGENT_INPUT_SCHEMA
    assert output_schema == CHAT_AGENT_OUTPUT_SCHEMA


def test_chat_agent_save_load_dict_output(tmp_path):
    """A dict-returning predict still saves/loads with the fixed schemas."""
    model = SimpleDictChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    input_schema = loaded_model.metadata.get_input_schema()
    output_schema = loaded_model.metadata.get_output_schema()
    assert input_schema == CHAT_AGENT_INPUT_SCHEMA
    assert output_schema == CHAT_AGENT_OUTPUT_SCHEMA


def test_chat_agent_trace(tmp_path):
    """``@mlflow.trace`` on predict produces exactly one trace per call."""
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    # predict() call during saving chat model should not generate a trace
    assert len(get_traces()) == 0

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [{"role": "user", "content": "Hello!"}]
    loaded_model.predict({"messages": messages})

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.tags[TraceTagKey.TRACE_NAME] == "predict"
    request = json.loads(traces[0].data.request)
    # Compare with auto-generated message "id" fields stripped on both sides.
    assert [{k: v for k, v in msg.items() if k != "id"} for msg in request["messages"]] == [
        {k: v for k, v in ChatAgentMessage(**msg).model_dump().items() if k != "id"}
        for msg in messages
    ]


def test_chat_agent_save_throws_with_signature(tmp_path):
    """An explicit signature is rejected — ChatAgent signatures are fixed."""
    model = SimpleChatAgent()

    with pytest.raises(MlflowException, match="Please remove the `signature` parameter"):
        mlflow.pyfunc.save_model(
            python_model=model,
            path=tmp_path,
            signature=ModelSignature(
                inputs=Schema([ColSpec(name="test", type=DataType.string)]),
            ),
        )


@pytest.mark.parametrize(
    "ret",
    [
        # Not a mapping at all.
        "not a ChatAgentResponse",
        # A mapping with keys that don't match the response schema.
        {"dict": "with", "bad": "keys"},
        # Looks like an OpenAI-style chat completion, not a ChatAgentResponse.
        {
            "id": "1",
            "created": 1,
            "model": "m",
            "choices": [{"bad": "choice"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        },
    ],
)
def test_save_throws_on_invalid_output(tmp_path, ret):
    """Saving fails when predict's output can't validate as ChatAgentResponse."""

    class BadChatAgent(ChatAgent):
        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: ChatContext,
            custom_inputs: dict[str, Any],
        ) -> ChatAgentResponse:
            return ret

    model = BadChatAgent()
    with pytest.raises(
        MlflowException,
        match=("Failed to save ChatAgent. Ensure your model's predict"),
    ):
        mlflow.pyfunc.save_model(python_model=model, path=tmp_path)


def test_chat_agent_predict(tmp_path):
    """A loaded model accepts a single request dict and echoes messages."""
    model = ChatAgentWithCustomInputs()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)

    # test that a single dictionary will work
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello!"},
    ]

    response = loaded_model.predict({"messages": messages})
    assert response["messages"][0]["content"] == "You are a helpful assistant"


def test_chat_agent_works_with_infer_signature_input_example():
    """Logging with an input example infers the fixed schemas and serves."""
    model = SimpleChatAgent()
    input_example = {
        "messages": [
            {
                "role": "system",
                "content": "You are in helpful assistant!",
            },
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?",
            },
        ],
        "context": {
            "conversation_id": "123",
            "user_id": "456",
        },
        "stream": False,  # this is set by default
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example
        )
    assert model_info.signature.inputs == CHAT_AGENT_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_AGENT_OUTPUT_SCHEMA
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    loaded_input_example = mlflow_model.load_input_example(local_path)
    # drop the generated UUID
    loaded_input_example["messages"] = [
        {k: v for k, v in msg.items() if k != "id"} for msg in loaded_input_example["messages"]
    ]
    assert loaded_input_example == input_example

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    model_response = json.loads(response.content)
    assert model_response["messages"][0]["content"] == "You are in helpful assistant!"


def test_chat_agent_logs_default_metadata_task():
    """The "task" metadata defaults to agent/v2/chat and can be overridden."""
    model = SimpleChatAgent()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(name="model", python_model=model)
    assert model_info.signature.inputs == CHAT_AGENT_INPUT_SCHEMA
    assert model_info.signature.outputs == CHAT_AGENT_OUTPUT_SCHEMA
    assert model_info.metadata["task"] == "agent/v2/chat"

    with mlflow.start_run():
        model_info_with_override = mlflow.pyfunc.log_model(
            name="model", python_model=model, metadata={"task": None}
        )
    assert model_info_with_override.metadata["task"] is None


def test_chat_agent_works_with_chat_agent_request_input_example():
    """Input examples round-trip with and without a context, and serve."""
    model = SimpleChatAgent()
    input_example_no_params = {"messages": [{"role": "user", "content": "What is rag?"}]}
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example_no_params
        )
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == input_example_no_params

    input_example_with_params = {
        "messages": [{"role": "user", "content": "What is rag?"}],
        "context": {"conversation_id": "121", "user_id": "123"},
    }
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=model, input_example=input_example_with_params
        )
    mlflow_model = Model.load(model_info.model_uri)
    local_path = _download_artifact_from_uri(model_info.model_uri)
    assert mlflow_model.load_input_example(local_path) == input_example_with_params

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    expect_status_code(response, 200)
    model_response = json.loads(response.content)
    assert model_response["messages"][0]["content"] == "What is rag?"


def test_chat_agent_predict_stream(tmp_path):
    """Streaming predict yields chunks with the expected delta contents."""
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    messages = [
        {"role": "user", "content": "Hello!"},
    ]

    responses = list(loaded_model.predict_stream({"messages": messages}))
    # NOTE(review): the final chunk is skipped — presumably it is a summary
    # chunk appended by the wrapper; confirm against the loader implementation.
    for i, resp in enumerate(responses[:-1]):
        assert resp["delta"]["content"] == f"message {i}"


def test_chat_agent_can_receive_and_return_custom():
    """custom_inputs flow through pyfunc predict and serving unchanged."""
    messages = [{"role": "user", "content": "Hello!"}]
    input_example = {
        "messages": messages,
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }

    model = ChatAgentWithCustomInputs()
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=model,
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # test that it works for normal pyfunc predict
    response = loaded_model.predict(input_example)
    assert response["custom_outputs"] == input_example["custom_inputs"]

    # test that it works in serving
    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_uri=model_info.model_uri,
        data=inference_payload,
        content_type="application/json",
        extra_args=["--env-manager", "local"],
    )

    serving_response = json.loads(response.content)
    assert serving_response["custom_outputs"] == input_example["custom_inputs"]


def test_chat_agent_predict_wrapper():
    """Dict input and positional pydantic input produce identical results."""
    model = ChatAgentWithCustomInputs()
    dict_input_example = {
        "messages": [{"role": "user", "content": "What is rag?"}],
        "context": {"conversation_id": "121", "user_id": "123"},
        "custom_inputs": {"image_url": "example", "detail": "high", "other_dict": {"key": "value"}},
    }
    chat_agent_request = ChatAgentRequest(**dict_input_example)
    pydantic_input_example = (
        chat_agent_request.messages,
        chat_agent_request.context,
        chat_agent_request.custom_inputs,
    )
    dict_input_response = model.predict(dict_input_example)
    pydantic_input_response = model.predict(*pydantic_input_example)
    # ids are random UUIDs, so remove them before comparing responses.
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response
    # Same comparison with no context supplied.
    no_context_dict_input_example = {**dict_input_example, "context": None}
    no_context_pydantic_input_example = (
        chat_agent_request.messages,
        None,
        chat_agent_request.custom_inputs,
    )
    dict_input_response = model.predict(no_context_dict_input_example)
    pydantic_input_response = model.predict(*no_context_pydantic_input_example)
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response

    model = SimpleChatAgent()
    dict_input_response = model.predict(dict_input_example)
    pydantic_input_response = model.predict(*pydantic_input_example)
    assert dict_input_response.messages[0].id is not None
    del dict_input_response.messages[0].id
    assert pydantic_input_response.messages[0].id is not None
    del pydantic_input_response.messages[0].id
    assert dict_input_response == pydantic_input_response
    # Stream chunks use deterministic ids, so they compare directly.
    assert list(model.predict_stream(dict_input_example)) == list(
        model.predict_stream(*pydantic_input_example)
    )

    with pytest.raises(MlflowException, match="Invalid dictionary input for a ChatAgent"):
        model.predict({"malformed dict": "bad"})
    with pytest.raises(MlflowException, match="Invalid dictionary input for a ChatAgent"):
        model.predict_stream({"malformed dict": "bad"})

    model = SimpleBadChatAgent()
    with pytest.raises(pydantic.ValidationError, match="validation error for ChatAgentResponse"):
        model.predict(dict_input_example)
    with pytest.raises(pydantic.ValidationError, match="validation error for ChatAgentChunk"):
        list(model.predict_stream(dict_input_example))


def test_chat_agent_predict_with_params(tmp_path):
    # test to codify having params in the signature
    # needed because `load_model_and_predict` in `utils/_capture_modules.py` expects a params field
    model = SimpleChatAgent()
    mlflow.pyfunc.save_model(python_model=model, path=tmp_path)

    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert isinstance(loaded_model._model_impl, _ChatAgentPyfuncWrapper)
    response = loaded_model.predict(CHAT_AGENT_INPUT_EXAMPLE, params=None)
    assert response["messages"][0]["content"] == "Hello!"

    responses = list(loaded_model.predict_stream(CHAT_AGENT_INPUT_EXAMPLE, params=None))
    for i, resp in enumerate(responses[:-1]):
        assert resp["delta"]["content"] == f"message {i}"


def test_chat_agent_load_context_called_during_save(tmp_path):
    """load_context runs during save so predict sees initialized state."""

    class ChatAgentWithArtifacts(ChatAgent):
        def __init__(self):
            # Set only by load_context; predict fails if it was never called.
            self.prefix = None

        def load_context(self, context):
            self.prefix = "loaded_prefix"

        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: ChatContext,
            custom_inputs: dict[str, Any],
        ) -> ChatAgentResponse:
            if self.prefix is None:
                raise ValueError("load_context was not called - prefix is None")
            return ChatAgentResponse(
                messages=[
                    {
                        "role": "assistant",
                        "content": f"{self.prefix}: {messages[0].content}",
                        "id": str(uuid4()),
                    }
                ]
            )

    model = ChatAgentWithArtifacts()
    save_path = tmp_path / "model"
    mlflow.pyfunc.save_model(
        python_model=model,
        path=save_path,
    )

    loaded_model = mlflow.pyfunc.load_model(save_path)
    response = loaded_model.predict({"messages": [{"role": "user", "content": "Hello!"}]})
    assert response["messages"][0]["content"] == "loaded_prefix: Hello!"