test_responses_agent.py
1 import functools 2 import pathlib 3 import pickle 4 from typing import Generator 5 from uuid import uuid4 6 7 import pytest 8 9 import mlflow 10 from mlflow.entities.span import SpanType 11 from mlflow.exceptions import MlflowException 12 from mlflow.models.signature import ModelSignature 13 from mlflow.pyfunc.loaders.responses_agent import _ResponsesAgentPyfuncWrapper 14 from mlflow.pyfunc.model import _DEFAULT_RESPONSES_AGENT_METADATA_TASK, ResponsesAgent 15 from mlflow.types.responses import ( 16 _HAS_LANGCHAIN_BASE_MESSAGE, 17 RESPONSES_AGENT_INPUT_EXAMPLE, 18 RESPONSES_AGENT_INPUT_SCHEMA, 19 RESPONSES_AGENT_OUTPUT_SCHEMA, 20 ResponsesAgentRequest, 21 ResponsesAgentResponse, 22 ResponsesAgentStreamEvent, 23 output_to_responses_items_stream, 24 ) 25 26 from tests.tracing.helper import get_traces, purge_traces 27 28 if _HAS_LANGCHAIN_BASE_MESSAGE: 29 pass 30 from mlflow.types.schema import ColSpec, DataType, Schema 31 32 33 def get_mock_response(request: ResponsesAgentRequest): 34 return { 35 "output": [ 36 { 37 "type": "message", 38 "id": str(uuid4()), 39 "status": "completed", 40 "role": "assistant", 41 "content": [ 42 { 43 "type": "output_text", 44 "text": request.input[0].content, 45 } 46 ], 47 } 48 ], 49 } 50 51 52 def get_stream_mock_response(): 53 yield from [ 54 { 55 "type": "response.output_item.added", 56 "output_index": 0, 57 "item": { 58 "type": "message", 59 "id": "1", 60 "status": "in_progress", 61 "role": "assistant", 62 "content": [], 63 }, 64 }, 65 { 66 "type": "response.content_part.added", 67 "item_id": "1", 68 "output_index": 0, 69 "content_index": 0, 70 "part": {"type": "output_text", "text": "", "annotations": []}, 71 }, 72 { 73 "type": "response.output_text.delta", 74 "item_id": "1", 75 "output_index": 0, 76 "content_index": 0, 77 "delta": "Deb", 78 }, 79 { 80 "type": "response.output_text.delta", 81 "item_id": "1", 82 "output_index": 0, 83 "content_index": 0, 84 "delta": "rid", 85 }, 86 { 87 "type": "response.output_text.done", 88 "item_id": "1", 89 "output_index": 0, 90 "content_index": 0, 91 "text": "Debrid", 92 }, 93 { 94 "type": "response.content_part.done", 95 "item_id": "1", 96 "output_index": 0, 97 "content_index": 0, 98 "part": { 99 "type": "output_text", 100 "text": "Debrid", 101 "annotations": [], 102 }, 103 }, 104 ] 105 106 107 class SimpleResponsesAgent(ResponsesAgent): 108 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 109 mock_response = get_mock_response(request) 110 return ResponsesAgentResponse(**mock_response) 111 112 def predict_stream( 113 self, request: ResponsesAgentRequest 114 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 115 yield from [ResponsesAgentStreamEvent(**r) for r in get_stream_mock_response()] 116 117 118 class ResponsesAgentWithContext(ResponsesAgent): 119 def load_context(self, context): 120 predict_path = pathlib.Path(context.artifacts["predict_fn"]) 121 self.predict_fn = pickle.loads(predict_path.read_bytes()) 122 123 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 124 return ResponsesAgentResponse( 125 output=[ 126 { 127 "type": "message", 128 "id": "test-id", 129 "status": "completed", 130 "role": "assistant", 131 "content": [ 132 { 133 "type": "output_text", 134 "text": self.predict_fn(), 135 } 136 ], 137 } 138 ] 139 ) 140 141 def predict_stream( 142 self, request: ResponsesAgentRequest 143 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 144 yield ResponsesAgentStreamEvent( 145 type="response.output_item.added", 146 output_index=0, 147 item=self.create_text_output_item(self.predict_fn(), "test-id"), 148 ) 149 150 151 def mock_responses_predict(): 152 return "hello from context" 153 154 155 def test_responses_agent_with_context(tmp_path): 156 predict_path = tmp_path / "predict.pkl" 157 predict_path.write_bytes(pickle.dumps(mock_responses_predict)) 158 159 model = ResponsesAgentWithContext() 160 161 with mlflow.start_run(): 162 model_info = mlflow.pyfunc.log_model( 163 name="model", 164 python_model=model, 165 artifacts={"predict_fn": str(predict_path)}, 166 ) 167 168 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 169 170 # Test predict 171 response = loaded_model.predict(RESPONSES_AGENT_INPUT_EXAMPLE) 172 assert response["output"][0]["content"][0]["text"] == "hello from context" 173 174 # Test predict_stream 175 responses = list(loaded_model.predict_stream(RESPONSES_AGENT_INPUT_EXAMPLE)) 176 assert len(responses) == 1 177 assert responses[0]["item"]["content"][0]["text"] == "hello from context" 178 179 180 def test_responses_agent_save_load_signatures(tmp_path): 181 model = SimpleResponsesAgent() 182 mlflow.pyfunc.save_model(python_model=model, path=tmp_path) 183 184 loaded_model = mlflow.pyfunc.load_model(tmp_path) 185 assert isinstance(loaded_model._model_impl, _ResponsesAgentPyfuncWrapper) 186 input_schema = loaded_model.metadata.get_input_schema() 187 output_schema = loaded_model.metadata.get_output_schema() 188 assert input_schema == RESPONSES_AGENT_INPUT_SCHEMA 189 assert output_schema == RESPONSES_AGENT_OUTPUT_SCHEMA 190 191 192 def test_responses_agent_log_default_task(): 193 model = SimpleResponsesAgent() 194 with mlflow.start_run(): 195 model_info = mlflow.pyfunc.log_model(name="model", python_model=model) 196 assert model_info.metadata["task"] == _DEFAULT_RESPONSES_AGENT_METADATA_TASK 197 198 with mlflow.start_run(): 199 model_info_with_override = mlflow.pyfunc.log_model( 200 name="model", python_model=model, metadata={"task": None} 201 ) 202 assert model_info_with_override.metadata["task"] is None 203 204 205 def test_responses_agent_predict(tmp_path): 206 model_path = tmp_path / "model" 207 model = SimpleResponsesAgent() 208 response = model.predict(RESPONSES_AGENT_INPUT_EXAMPLE) 209 assert response.output[0].content[0]["type"] == "output_text" 210 response = model.predict_stream(RESPONSES_AGENT_INPUT_EXAMPLE) 211 assert next(response).type == "response.output_item.added" 212 mlflow.pyfunc.save_model(python_model=model, path=model_path) 213 loaded_model = mlflow.pyfunc.load_model(model_path) 214 response = loaded_model.predict(RESPONSES_AGENT_INPUT_EXAMPLE) 215 assert response["output"][0]["type"] == "message" 216 assert response["output"][0]["content"][0]["type"] == "output_text" 217 assert response["output"][0]["content"][0]["text"] == "Hello!" 218 219 220 def test_responses_agent_predict_stream(tmp_path): 221 model_path = tmp_path / "model" 222 model = SimpleResponsesAgent() 223 mlflow.pyfunc.save_model(python_model=model, path=model_path) 224 loaded_model = mlflow.pyfunc.load_model(model_path) 225 responses = list(loaded_model.predict_stream(RESPONSES_AGENT_INPUT_EXAMPLE)) 226 # most of this test is that the predict_stream parsing works in _ResponsesAgentPyfuncWrapper 227 for r in responses: 228 assert "type" in r 229 230 231 def test_responses_agent_with_pydantic_input(): 232 model = SimpleResponsesAgent() 233 response = model.predict(ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE)) 234 assert response.output[0].content[0]["text"] == "Hello!" 235 236 237 class CustomInputsResponsesAgent(ResponsesAgent): 238 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 239 mock_response = get_mock_response(request) 240 return ResponsesAgentResponse(**mock_response, custom_outputs=request.custom_inputs) 241 242 def predict_stream(self, request: ResponsesAgentRequest): 243 for r in get_stream_mock_response(): 244 r["custom_outputs"] = request.custom_inputs 245 yield r 246 247 248 def test_responses_agent_custom_inputs(tmp_path): 249 model = CustomInputsResponsesAgent() 250 mlflow.pyfunc.save_model(python_model=model, path=tmp_path) 251 loaded_model = mlflow.pyfunc.load_model(tmp_path) 252 payload = {**RESPONSES_AGENT_INPUT_EXAMPLE, "custom_inputs": {"asdf": "asdf"}} 253 response = loaded_model.predict(payload) 254 assert response["custom_outputs"] == {"asdf": "asdf"} 255 responses = list( 256 loaded_model.predict_stream({ 257 **RESPONSES_AGENT_INPUT_EXAMPLE, 258 "custom_inputs": {"asdf": "asdf"}, 259 }) 260 ) 261 for r in responses: 262 assert r["custom_outputs"] == {"asdf": "asdf"} 263 264 265 def test_responses_agent_predict_with_params(tmp_path): 266 # needed because `load_model_and_predict` in `utils/_capture_modules.py` expects a params field 267 model = SimpleResponsesAgent() 268 mlflow.pyfunc.save_model(python_model=model, path=tmp_path) 269 loaded_model = mlflow.pyfunc.load_model(tmp_path) 270 response = loaded_model.predict(RESPONSES_AGENT_INPUT_EXAMPLE, params=None) 271 assert response["output"][0]["type"] == "message" 272 273 274 def test_responses_agent_save_throws_with_signature(tmp_path): 275 model = SimpleResponsesAgent() 276 277 with pytest.raises(MlflowException, match="Please remove the `signature` parameter"): 278 mlflow.pyfunc.save_model( 279 python_model=model, 280 path=tmp_path, 281 signature=ModelSignature( 282 inputs=Schema([ColSpec(name="test", type=DataType.string)]), 283 ), 284 ) 285 286 287 def test_responses_agent_throws_with_invalid_output(tmp_path): 288 class BadResponsesAgent(ResponsesAgent): 289 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 290 return {"output": [{"type": "message", "content": [{"type": "output_text"}]}]} 291 292 model = BadResponsesAgent() 293 with pytest.raises( 294 MlflowException, match="Failed to save ResponsesAgent. Ensure your model's predict" 295 ): 296 mlflow.pyfunc.save_model(python_model=model, path=tmp_path) 297 298 299 @pytest.mark.parametrize( 300 ("input", "outputs"), 301 [ 302 # 1. Normal text input output 303 ( 304 RESPONSES_AGENT_INPUT_EXAMPLE, 305 { 306 "output": [ 307 { 308 "type": "message", 309 "id": "test", 310 "status": "completed", 311 "role": "assistant", 312 "content": [{"type": "output_text", "text": "Dummy output"}], 313 } 314 ], 315 }, 316 ), 317 # 2. Image input 318 ( 319 { 320 "input": [ 321 { 322 "role": "user", 323 "content": [ 324 {"type": "input_text", "text": "what is in this image?"}, 325 {"type": "input_image", "image_url": "test.jpg"}, 326 ], 327 } 328 ], 329 }, 330 { 331 "output": [ 332 { 333 "type": "message", 334 "id": "test", 335 "status": "completed", 336 "role": "assistant", 337 "content": [{"type": "output_text", "text": "Dummy output"}], 338 } 339 ], 340 }, 341 ), 342 # 3. Tool calling 343 ( 344 { 345 "input": [ 346 { 347 "role": "user", 348 "content": "What is the weather like in Boston today?", 349 } 350 ], 351 "tools": [ 352 { 353 "type": "function", 354 "name": "get_current_weather", 355 "parameters": { 356 "type": "object", 357 "properties": {"location": {"type": "string"}}, 358 "required": ["location", "unit"], 359 }, 360 } 361 ], 362 }, 363 { 364 "output": [ 365 { 366 "arguments": '{"location":"Boston, MA","unit":"celsius"}', 367 "call_id": "function_call_1", 368 "name": "get_current_weather", 369 "type": "function_call", 370 "id": "fc_6805c835567481918c27724bbe931dc40b1b7951a48825bb", 371 "status": "completed", 372 } 373 ] 374 }, 375 ), 376 ], 377 ) 378 def test_responses_agent_trace(input, outputs): 379 class TracedResponsesAgent(ResponsesAgent): 380 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 381 return ResponsesAgentResponse(**outputs) 382 383 def predict_stream( 384 self, request: ResponsesAgentRequest 385 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 386 for item in outputs["output"]: 387 yield ResponsesAgentStreamEvent( 388 type="response.output_item.done", 389 item=item, 390 ) 391 392 model = TracedResponsesAgent() 393 model.predict(ResponsesAgentRequest(**input)) 394 395 traces = get_traces() 396 assert len(traces) == 1 397 spans = traces[0].data.spans 398 assert len(spans) == 1 399 assert spans[0].name == "predict" 400 assert spans[0].span_type == SpanType.AGENT 401 402 list(model.predict_stream(ResponsesAgentRequest(**input))) 403 404 traces = get_traces() 405 assert len(traces) == 2 406 spans = traces[0].data.spans 407 assert len(spans) == 1 408 assert spans[0].name == "predict_stream" 409 assert spans[0].span_type == SpanType.AGENT 410 411 assert "output" in spans[0].outputs 412 assert spans[0].outputs["output"] == outputs["output"] 413 414 415 def test_responses_agent_custom_trace_configurations(): 416 # Agent with custom span names and attributes 417 class CustomTracedAgent(ResponsesAgent): 418 @mlflow.trace( 419 name="custom_predict", span_type=SpanType.AGENT, attributes={"custom": "value"} 420 ) 421 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 422 return ResponsesAgentResponse(**get_mock_response(request)) 423 424 @mlflow.trace( 425 name="custom_predict_stream", 426 span_type=SpanType.AGENT, 427 attributes={"stream": "true"}, 428 output_reducer=ResponsesAgent.responses_agent_output_reducer, 429 ) 430 def predict_stream( 431 self, request: ResponsesAgentRequest 432 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 433 yield from [ResponsesAgentStreamEvent(**r) for r in get_stream_mock_response()] 434 435 purge_traces() 436 437 agent = CustomTracedAgent() 438 agent.predict(ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE)) 439 440 traces_predict = get_traces() 441 assert len(traces_predict) == 1 442 spans_predict = traces_predict[0].data.spans 443 assert len(spans_predict) == 1 444 assert spans_predict[0].name == "custom_predict" 445 assert spans_predict[0].span_type == SpanType.AGENT 446 assert spans_predict[0].attributes.get("custom") == "value" 447 448 purge_traces() 449 list(agent.predict_stream(ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE))) 450 451 traces_stream = get_traces() 452 assert len(traces_stream) == 1 453 spans_stream = traces_stream[0].data.spans 454 assert len(spans_stream) == 1 455 assert spans_stream[0].name == "custom_predict_stream" 456 assert spans_stream[0].span_type == SpanType.AGENT 457 assert spans_stream[0].attributes.get("stream") == "true" 458 459 460 def test_responses_agent_non_mlflow_decorators(): 461 # Create a custom decorator to test with 462 def custom_decorator(func): 463 @functools.wraps(func) 464 def wrapper(*args, **kwargs): 465 return func(*args, **kwargs) 466 467 return wrapper 468 469 class MixedDecoratedAgent(ResponsesAgent): 470 @custom_decorator 471 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 472 return ResponsesAgentResponse(**get_mock_response(request)) 473 474 # Just a regular method (no decorator) to test that it gets auto-traced 475 def predict_stream( 476 self, request: ResponsesAgentRequest 477 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 478 yield from [ResponsesAgentStreamEvent(**r) for r in get_stream_mock_response()] 479 480 # Both methods should get auto-traced since they don't have __mlflow_traced__ 481 agent = MixedDecoratedAgent() 482 agent.predict(ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE)) 483 484 traces_mixed_predict = get_traces() 485 assert len(traces_mixed_predict) == 1 486 spans_mixed_predict = traces_mixed_predict[0].data.spans 487 assert len(spans_mixed_predict) == 1 488 assert spans_mixed_predict[0].name == "predict" 489 assert spans_mixed_predict[0].span_type == SpanType.AGENT 490 491 purge_traces() 492 list(agent.predict_stream(ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE))) 493 494 traces_mixed_stream = get_traces() 495 assert len(traces_mixed_stream) == 1 496 spans_mixed_stream = traces_mixed_stream[0].data.spans 497 assert len(spans_mixed_stream) == 1 498 assert spans_mixed_stream[0].name == "predict_stream" 499 assert spans_mixed_stream[0].span_type == SpanType.AGENT 500 501 502 @pytest.mark.parametrize( 503 ("chunks", "expected_output"), 504 [ 505 ( 506 [ 507 { 508 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 509 "choices": [{"delta": {"content": "", "role": "assistant"}, "index": 0}], 510 "object": "chat.completion.chunk", 511 }, 512 { 513 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 514 "choices": [ 515 { 516 "delta": { 517 "content": [ 518 { 519 "type": "reasoning", 520 "summary": [{"type": "summary_text", "text": "We"}], 521 } 522 ] 523 }, 524 "index": 0, 525 } 526 ], 527 "object": "chat.completion.chunk", 528 }, 529 { 530 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 531 "choices": [ 532 { 533 "delta": { 534 "content": [ 535 { 536 "type": "reasoning", 537 "summary": [{"type": "summary_text", "text": " need"}], 538 } 539 ] 540 }, 541 "index": 0, 542 } 543 ], 544 "object": "chat.completion.chunk", 545 }, 546 { 547 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 548 "choices": [{"delta": {"content": ""}, "index": 0}], 549 "object": "chat.completion.chunk", 550 }, 551 { 552 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 553 "choices": [{"delta": {"content": "Hello"}, "index": 0}], 554 "object": "chat.completion.chunk", 555 }, 556 { 557 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 558 "choices": [{"delta": {"content": "!"}, "index": 0}], 559 "object": "chat.completion.chunk", 560 }, 561 ], 562 [ 563 ResponsesAgentStreamEvent( 564 type="response.output_text.delta", 565 custom_outputs=None, 566 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 567 delta="", 568 ), 569 ResponsesAgentStreamEvent( 570 type="response.output_item.done", 571 custom_outputs=None, 572 item={ 573 "type": "reasoning", 574 "summary": [{"type": "summary_text", "text": "We need"}], 575 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 576 }, 577 ), 578 ResponsesAgentStreamEvent( 579 type="response.output_text.delta", 580 custom_outputs=None, 581 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 582 delta="", 583 ), 584 ResponsesAgentStreamEvent( 585 type="response.output_text.delta", 586 custom_outputs=None, 587 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 588 delta="Hello", 589 ), 590 ResponsesAgentStreamEvent( 591 type="response.output_text.delta", 592 custom_outputs=None, 593 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 594 delta="!", 595 ), 596 ResponsesAgentStreamEvent( 597 type="response.output_item.done", 598 custom_outputs=None, 599 item={ 600 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 601 "content": [{"text": "Hello!", "type": "output_text", "annotations": []}], 602 "role": "assistant", 603 "type": "message", 604 }, 605 ), 606 ], 607 ), 608 ( 609 [ 610 { 611 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 612 "choices": [ 613 { 614 "delta": {"content": "", "role": "assistant"}, 615 "finish_reason": None, 616 "index": 0, 617 "logprobs": None, 618 } 619 ], 620 "object": "chat.completion.chunk", 621 }, 622 { 623 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 624 "choices": [ 625 { 626 "delta": { 627 "content": [ 628 { 629 "type": "reasoning", 630 "summary": [ 631 { 632 "type": "summary_text", 633 "text": "We need to respond. The user just says " 634 '"hi". We can reply friendly.', 635 } 636 ], 637 }, 638 {"type": "text", "text": "Hello! How can I help you today?"}, 639 ] 640 }, 641 "finish_reason": None, 642 "index": 0, 643 "logprobs": None, 644 } 645 ], 646 "object": "chat.completion.chunk", 647 }, 648 { 649 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 650 "choices": [ 651 { 652 "delta": {"content": ""}, 653 "finish_reason": "stop", 654 "index": 0, 655 "logprobs": None, 656 } 657 ], 658 "object": "chat.completion.chunk", 659 }, 660 ], 661 [ 662 ResponsesAgentStreamEvent( 663 type="response.output_text.delta", 664 custom_outputs=None, 665 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 666 delta="", 667 ), 668 ResponsesAgentStreamEvent( 669 type="response.output_text.delta", 670 custom_outputs=None, 671 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 672 delta="Hello! How can I help you today?", 673 ), 674 ResponsesAgentStreamEvent( 675 type="response.output_item.done", 676 custom_outputs=None, 677 item={ 678 "type": "reasoning", 679 "summary": [ 680 { 681 "type": "summary_text", 682 "text": 'We need to respond. The user just says "hi". ' 683 "We can reply friendly.", 684 } 685 ], 686 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 687 }, 688 ), 689 ResponsesAgentStreamEvent( 690 type="response.output_text.delta", 691 custom_outputs=None, 692 item_id="chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 693 delta="", 694 ), 695 ResponsesAgentStreamEvent( 696 type="response.output_item.done", 697 custom_outputs=None, 698 item={ 699 "id": "chatcmpl_fd04a20f-f348-45e1-af37-68cf3bb08bdb", 700 "content": [ 701 { 702 "text": "Hello! How can I help you today?", 703 "type": "output_text", 704 "annotations": [], 705 } 706 ], 707 "role": "assistant", 708 "type": "message", 709 }, 710 ), 711 ], 712 ), 713 ( 714 [ 715 { 716 "id": "msg_bdrk_016AC1ojH743YLHDfgnf4B7Y", 717 "choices": [ 718 { 719 "delta": {"content": "Hello", "role": "assistant"}, 720 "finish_reason": None, 721 "index": 0, 722 } 723 ], 724 "object": "chat.completion.chunk", 725 }, 726 { 727 "id": "msg_bdrk_016AC1ojH743YLHDfgnf4B7Y", 728 "choices": [ 729 { 730 "delta": {"content": " there! I'", "role": "assistant"}, 731 "finish_reason": None, 732 "index": 0, 733 } 734 ], 735 "object": "chat.completion.chunk", 736 }, 737 ], 738 [ 739 ResponsesAgentStreamEvent( 740 type="response.output_text.delta", 741 custom_outputs=None, 742 item_id="msg_bdrk_016AC1ojH743YLHDfgnf4B7Y", 743 delta="Hello", 744 ), 745 ResponsesAgentStreamEvent( 746 type="response.output_text.delta", 747 custom_outputs=None, 748 item_id="msg_bdrk_016AC1ojH743YLHDfgnf4B7Y", 749 delta=" there! I'", 750 ), 751 ResponsesAgentStreamEvent( 752 type="response.output_item.done", 753 custom_outputs=None, 754 item={ 755 "id": "msg_bdrk_016AC1ojH743YLHDfgnf4B7Y", 756 "content": [ 757 {"text": "Hello there! I'", "type": "output_text", "annotations": []} 758 ], 759 "role": "assistant", 760 "type": "message", 761 }, 762 ), 763 ], 764 ), 765 ( 766 [ 767 { 768 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 769 "choices": [ 770 { 771 "delta": {"content": "I", "role": "assistant"}, 772 "finish_reason": None, 773 "index": 0, 774 } 775 ], 776 "object": "chat.completion.chunk", 777 }, 778 { 779 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 780 "choices": [ 781 { 782 "delta": {"content": " can help you calculate 4*", "role": "assistant"}, 783 "finish_reason": None, 784 "index": 0, 785 } 786 ], 787 "object": "chat.completion.chunk", 788 }, 789 { 790 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 791 "choices": [ 792 { 793 "delta": { 794 "content": None, 795 "role": "assistant", 796 "tool_calls": [ 797 { 798 "index": 0, 799 "id": "toolu_bdrk_01XKD5j3Ru1dk3jnm69xkXUL", 800 "function": { 801 "arguments": "", 802 "name": "system__ai__python_exec", 803 }, 804 "type": "function", 805 } 806 ], 807 }, 808 "finish_reason": None, 809 "index": 0, 810 } 811 ], 812 "object": "chat.completion.chunk", 813 }, 814 { 815 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 816 "choices": [ 817 { 818 "delta": { 819 "content": None, 820 "role": "assistant", 821 "tool_calls": [{"index": 0, "function": {"arguments": ""}}], 822 }, 823 "finish_reason": None, 824 "index": 0, 825 } 826 ], 827 "object": "chat.completion.chunk", 828 }, 829 { 830 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 831 "choices": [ 832 { 833 "delta": { 834 "content": None, 835 "role": "assistant", 836 "tool_calls": [ 837 {"index": 0, "function": {"arguments": '{"code": "#'}} 838 ], 839 }, 840 "finish_reason": None, 841 "index": 0, 842 } 843 ], 844 "created": 1757977465, 845 "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", 846 "object": "chat.completion.chunk", 847 }, 848 { 849 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 850 "choices": [ 851 { 852 "delta": { 853 "content": None, 854 "role": "assistant", 855 "tool_calls": [{"index": 0, "function": {"arguments": " Calc"}}], 856 }, 857 "finish_reason": None, 858 "index": 0, 859 } 860 ], 861 "object": "chat.completion.chunk", 862 }, 863 { 864 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 865 "choices": [ 866 { 867 "delta": { 868 "content": None, 869 "role": "assistant", 870 "tool_calls": [ 871 {"index": 0, "function": {"arguments": "ulate 4*3"}} 872 ], 873 }, 874 "finish_reason": None, 875 "index": 0, 876 } 877 ], 878 "object": "chat.completion.chunk", 879 }, 880 { 881 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 882 "choices": [ 883 { 884 "delta": {"content": "", "role": "assistant"}, 885 "finish_reason": "tool_calls", 886 "index": 0, 887 } 888 ], 889 "object": "chat.completion.chunk", 890 }, 891 ], 892 [ 893 ResponsesAgentStreamEvent( 894 type="response.output_text.delta", 895 custom_outputs=None, 896 item_id="msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 897 delta="I", 898 ), 899 ResponsesAgentStreamEvent( 900 type="response.output_text.delta", 901 custom_outputs=None, 902 item_id="msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 903 delta=" can help you calculate 4*", 904 ), 905 ResponsesAgentStreamEvent( 906 type="response.output_text.delta", 907 custom_outputs=None, 908 item_id="msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 909 delta="", 910 ), 911 ResponsesAgentStreamEvent( 912 type="response.output_item.done", 913 custom_outputs=None, 914 item={ 915 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 916 "content": [ 917 { 918 "text": "I can help you calculate 4*", 919 "type": "output_text", 920 "annotations": [], 921 } 922 ], 923 "role": "assistant", 924 "type": "message", 925 }, 926 ), 927 ResponsesAgentStreamEvent( 928 type="response.output_item.done", 929 custom_outputs=None, 930 item={ 931 "type": "function_call", 932 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 933 "call_id": "toolu_bdrk_01XKD5j3Ru1dk3jnm69xkXUL", 934 "name": "system__ai__python_exec", 935 "arguments": '{"code": "# Calculate 4*3', 936 }, 937 ), 938 ], 939 ), 940 ( 941 [ 942 { 943 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 944 "choices": [ 945 { 946 "delta": { 947 "content": None, 948 "role": "assistant", 949 "tool_calls": [ 950 { 951 "index": 0, 952 "id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 953 "function": { 954 "arguments": "", 955 "name": "system__ai__python_exec", 956 }, 957 "type": "function", 958 } 959 ], 960 }, 961 "finish_reason": None, 962 "index": 0, 963 "logprobs": None, 964 } 965 ], 966 "object": "chat.completion.chunk", 967 }, 968 { 969 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 970 "choices": [ 971 { 972 "delta": { 973 "content": None, 974 "tool_calls": [ 975 { 976 "index": 0, 977 "function": { 978 "arguments": '{\n "code": "result = 4 * 3\\n' 979 'print(result)"\n}' 980 }, 981 } 982 ], 983 }, 984 "finish_reason": None, 985 "index": 0, 986 "logprobs": None, 987 } 988 ], 989 "object": "chat.completion.chunk", 990 }, 991 { 992 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 993 "choices": [ 994 { 995 "delta": { 996 "content": None, 997 "tool_calls": [{"index": 0, "function": {"arguments": ""}}], 998 }, 999 "finish_reason": "tool_calls", 1000 "index": 0, 1001 "logprobs": None, 1002 } 1003 ], 1004 "object": "chat.completion.chunk", 1005 }, 1006 ], 1007 [ 1008 ResponsesAgentStreamEvent( 1009 type="response.output_item.done", 1010 custom_outputs=None, 1011 item={ 1012 "type": "function_call", 1013 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 1014 "call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1015 "name": "system__ai__python_exec", 1016 "arguments": '{\n "code": "result = 4 * 3\\nprint(result)"\n}', 1017 }, 1018 ) 1019 ], 1020 ), 1021 # Parallel tool calls: verifies arguments are assembled per tool call index 1022 # Before fix, all arguments were concatenated into first tool call, causing JSON errors 1023 ( 1024 [ 1025 # Text content 1026 { 1027 "id": "msg1", 1028 "choices": [{"delta": {"content": "Calling tools."}, "index": 0}], 1029 "object": "chat.completion.chunk", 1030 }, 1031 # Tool 0: search - init + args 1032 { 1033 "id": "msg1", 1034 "choices": [ 1035 { 1036 "delta": { 1037 "tool_calls": [ 1038 { 1039 "index": 0, 1040 "id": "call_0", 1041 "function": {"name": "search", "arguments": ""}, 1042 } 1043 ] 1044 }, 1045 "index": 0, 1046 } 1047 ], 1048 "object": "chat.completion.chunk", 1049 }, 1050 { 1051 "id": "msg1", 1052 "choices": [ 1053 { 1054 "delta": { 1055 "tool_calls": [ 1056 { 1057 "index": 0, 1058 "function": {"arguments": '{"query": "ML best practices"}'}, 1059 } 1060 ] 1061 }, 1062 "index": 0, 1063 } 1064 ], 1065 "object": "chat.completion.chunk", 1066 }, 1067 # Tool 1: weather - init + args 1068 { 1069 "id": "msg1", 1070 "choices": [ 1071 { 1072 "delta": { 1073 "tool_calls": [ 1074 { 1075 "index": 1, 1076 "id": "call_1", 1077 "function": {"name": "weather", "arguments": ""}, 1078 } 1079 ] 1080 }, 1081 "index": 0, 1082 } 1083 ], 1084 "object": "chat.completion.chunk", 1085 }, 1086 { 1087 "id": "msg1", 1088 "choices": [ 1089 { 1090 "delta": { 1091 "tool_calls": [ 1092 { 1093 "index": 1, 1094 "function": {"arguments": '{"location": "Seattle"}'}, 1095 } 1096 ] 1097 }, 1098 "index": 0, 1099 } 1100 ], 1101 "object": "chat.completion.chunk", 1102 }, 1103 # Tool 2: calculate - init + args 1104 { 1105 "id": "msg1", 1106 "choices": [ 1107 { 1108 "delta": { 1109 "tool_calls": [ 1110 { 1111 "index": 2, 1112 "id": "call_2", 1113 "function": {"name": "calc", "arguments": ""}, 1114 } 1115 ] 1116 }, 1117 "index": 0, 1118 } 1119 ], 1120 "object": "chat.completion.chunk", 1121 }, 1122 { 1123 "id": "msg1", 1124 "choices": [ 1125 { 1126 "delta": { 1127 "tool_calls": [ 1128 {"index": 2, "function": {"arguments": '{"expr": "42*17"}'}} 1129 ] 1130 }, 1131 "index": 0, 1132 } 1133 ], 1134 "object": "chat.completion.chunk", 1135 }, 1136 # Final chunk 1137 { 1138 "id": "msg1", 1139 "choices": [ 1140 {"delta": {"content": ""}, "finish_reason": "tool_calls", "index": 0} 1141 ], 1142 "object": "chat.completion.chunk", 1143 }, 1144 ], 1145 [ 1146 ResponsesAgentStreamEvent( 1147 type="response.output_text.delta", item_id="msg1", delta="Calling tools." 1148 ), 1149 ResponsesAgentStreamEvent( 1150 type="response.output_text.delta", item_id="msg1", delta="" 1151 ), 1152 ResponsesAgentStreamEvent( 1153 type="response.output_item.done", 1154 item={ 1155 "id": "msg1", 1156 "content": [ 1157 {"text": "Calling tools.", "type": "output_text", "annotations": []} 1158 ], 1159 "role": "assistant", 1160 "type": "message", 1161 }, 1162 ), 1163 ResponsesAgentStreamEvent( 1164 type="response.output_item.done", 1165 item={ 1166 "type": "function_call", 1167 "id": "msg1", 1168 "call_id": "call_0", 1169 "name": "search", 1170 "arguments": '{"query": "ML best practices"}', 1171 }, 1172 ), 1173 ResponsesAgentStreamEvent( 1174 type="response.output_item.done", 1175 item={ 1176 "type": "function_call", 1177 "id": "msg1", 1178 "call_id": "call_1", 1179 "name": "weather", 1180 "arguments": '{"location": "Seattle"}', 1181 }, 1182 ), 1183 ResponsesAgentStreamEvent( 1184 type="response.output_item.done", 1185 item={ 1186 "type": "function_call", 1187 "id": "msg1", 1188 "call_id": "call_2", 1189 "name": "calc", 1190 "arguments": '{"expr": "42*17"}', 1191 }, 1192 ), 1193 ], 1194 ), 1195 ], 1196 ) 1197 def test_responses_agent_output_to_responses_items_stream(chunks, expected_output): 1198 """ 1199 In order of the parameters: 1200 1. gpt oss with no tools streaming 1201 - other models don't differentiate between w/ and w/o tools streaming 1202 2. gpt oss with tools streaming 1203 3. claude no tool call streaming 1204 4. claude tool call streaming 1205 """ 1206 aggregator = [] 1207 converted_output = list(ResponsesAgent.output_to_responses_items_stream(chunks, aggregator)) 1208 assert converted_output == expected_output 1209 expected_aggregator = [ 1210 event.item for event in expected_output if event.type == "response.output_item.done" 1211 ] 1212 assert aggregator == expected_aggregator 1213 1214 1215 def test_create_text_delta(): 1216 result = ResponsesAgent.create_text_delta("Hello", "test-id") 1217 expected = { 1218 "type": "response.output_text.delta", 1219 "item_id": "test-id", 1220 "delta": "Hello", 1221 } 1222 assert result == expected 1223 1224 1225 def test_create_annotation_added(): 1226 annotation = {"type": "citation", "text": "Reference"} 1227 result = ResponsesAgent.create_annotation_added("test-id", annotation, 1) 1228 expected = { 1229 "type": "response.output_text.annotation.added", 1230 "item_id": "test-id", 1231 "annotation_index": 1, 1232 "annotation": annotation, 1233 } 1234 assert result == expected 1235 1236 # Test with default annotation_index 1237 result_default = ResponsesAgent.create_annotation_added("test-id", annotation) 1238 expected_default = { 1239 "type": "response.output_text.annotation.added", 1240 "item_id": "test-id", 1241 "annotation_index": 0, 1242 "annotation": annotation, 1243 } 1244 assert result_default == expected_default 1245 1246 1247 def test_create_text_output_item(): 1248 # Test without annotations 1249 result = ResponsesAgent.create_text_output_item("Hello world", "test-id") 1250 expected = { 1251 "id": "test-id", 1252 "content": [ 1253 { 1254 "text": "Hello world", 1255 "type": "output_text", 1256 "annotations": [], 1257 } 1258 ], 1259 "role": "assistant", 1260 "type": "message", 1261 } 1262 assert result == expected 1263 1264 # Test with annotations 1265 annotations = [{"type": "citation", "text": "Reference"}] 1266 result_with_annotations = ResponsesAgent.create_text_output_item( 1267 "Hello world", "test-id", annotations 1268 ) 1269 expected_with_annotations = { 1270 "id": "test-id", 1271 "content": [ 1272 { 1273 "text": "Hello world", 1274 "type": "output_text", 1275 "annotations": annotations, 1276 } 1277 ], 1278 "role": "assistant", 1279 "type": "message", 1280 } 1281 assert result_with_annotations == expected_with_annotations 1282 1283 1284 def test_create_reasoning_item(): 1285 result = ResponsesAgent.create_reasoning_item("test-id", "This is my reasoning") 1286 expected = { 1287 "type": "reasoning", 1288 "summary": [ 1289 { 1290 "type": "summary_text", 1291 "text": "This is my reasoning", 1292 } 1293 ], 1294 "id": "test-id", 1295 } 1296 assert result == expected 1297 1298 1299 def test_create_function_call_item(): 1300 result = ResponsesAgent.create_function_call_item( 1301 "test-id", "call-123", "get_weather", '{"location": "Boston"}' 1302 ) 1303 expected = { 1304 "type": "function_call", 1305 "id": "test-id", 1306 "call_id": "call-123", 1307 "name": "get_weather", 1308 "arguments": '{"location": "Boston"}', 1309 } 1310 assert result == expected 1311 1312 1313 def test_create_function_call_output_item(): 1314 result = ResponsesAgent.create_function_call_output_item("call-123", "Sunny, 75°F") 1315 expected = { 1316 "type": "function_call_output", 1317 "call_id": "call-123", 1318 "output": "Sunny, 75°F", 1319 } 1320 assert result == expected 1321 1322 1323 @pytest.mark.parametrize( 1324 ("responses_input", "cc_msgs"), 1325 [ 1326 ( 1327 [ 1328 {"type": "user", "content": "what is 4*3 in python"}, 1329 {"type": "reasoning", "summary": "I can help you calculate 4*3"}, 1330 { 1331 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 1332 "content": [{"text": "I can help you calculate 4*", "type": "output_text"}], 1333 "role": "assistant", 1334 "type": "message", 1335 }, 1336 { 1337 "type": "mcp_approval_request", 1338 "id": "mcp_approval_request_123", 1339 "arguments": "{}", 1340 "name": "system__ai__python_exec", 1341 "server_label": "python_exec", 1342 }, 1343 { 1344 "type": "mcp_approval_response", 1345 "id": "mcp_approval_response_123", 1346 "approval_request_id": "mcp_approval_request_123", 1347 "approve": True, 1348 "reason": "The request was approved", 1349 }, 1350 { 1351 "type": "function_call", 1352 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 1353 "call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1354 "name": "system__ai__python_exec", 1355 "arguments": '{\n "code": "result = 4 * 3\\nprint(result)"\n}', 1356 }, 1357 { 1358 "type": "function_call_output", 1359 "call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1360 "output": "12\n", 1361 }, 1362 ], 1363 [ 1364 {"content": "what is 4*3 in python"}, 1365 {"role": "assistant", "content": '"I can help you calculate 4*3"'}, 1366 {"role": "assistant", "content": "I can help you calculate 4*"}, 1367 { 1368 "role": "assistant", 1369 "content": "mcp approval request", 1370 "tool_calls": [ 1371 { 1372 "id": "mcp_approval_request_123", 1373 "type": "function", 1374 "function": { 1375 "arguments": "{}", 1376 "name": "system__ai__python_exec", 1377 }, 1378 } 1379 ], 1380 }, 1381 { 1382 "role": "tool", 1383 "content": "True", 1384 "tool_call_id": "mcp_approval_request_123", 1385 }, 1386 { 1387 "role": "assistant", 1388 "content": "tool call", 1389 "tool_calls": [ 1390 { 1391 "id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1392 "type": "function", 1393 "function": { 1394 "arguments": '{\n "code": "result = 4 * 3\\nprint(result)"\n}', 1395 "name": "system__ai__python_exec", 1396 }, 1397 } 1398 ], 1399 }, 1400 { 1401 "role": "tool", 1402 "content": "12\n", 1403 "tool_call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1404 }, 1405 ], 1406 ) 1407 ], 1408 ) 1409 def test_prep_msgs_for_cc_llm(responses_input, cc_msgs): 1410 result = ResponsesAgent.prep_msgs_for_cc_llm(responses_input) 1411 assert result == cc_msgs 1412 1413 1414 @pytest.mark.parametrize( 1415 ("responses_input", "cc_msgs"), 1416 [ 1417 ( 1418 [ 1419 {"type": "user", "content": "what is 4*3 in python"}, 1420 {"type": "reasoning", "summary": "I can help you calculate 4*3"}, 1421 { 1422 "id": "msg_bdrk_015YdA8hjVSHWxpAdecgHqj3", 1423 "content": [{"text": "I can help you calculate 4*", "type": "output_text"}], 1424 "role": "assistant", 1425 "type": "message", 1426 }, 1427 { 1428 "type": "function_call", 1429 "id": "chatcmpl_56a443d8-bf71-4f71-aff5-082191c4db1e", 1430 "call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1431 "name": "system__ai__python_exec", 1432 "arguments": "", 1433 }, 1434 { 1435 "type": "function_call_output", 1436 "call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1437 "output": "12\n", 1438 }, 1439 ], 1440 [ 1441 {"content": "what is 4*3 in python"}, 1442 {"role": "assistant", "content": '"I can help you calculate 4*3"'}, 1443 {"role": "assistant", "content": "I can help you calculate 4*"}, 1444 { 1445 "role": "assistant", 1446 "content": "tool call", 1447 "tool_calls": [ 1448 { 1449 "id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1450 "type": "function", 1451 "function": { 1452 "arguments": "{}", 1453 "name": "system__ai__python_exec", 1454 }, 1455 } 1456 ], 1457 }, 1458 { 1459 "role": "tool", 1460 "content": "12\n", 1461 "tool_call_id": "call_39565342-e7d7-4ed5-a3e3-ea115a7f9fc6", 1462 }, 1463 ], 1464 ) 1465 ], 1466 ) 1467 def test_prep_msgs_for_cc_llm_empty_arguments(responses_input, cc_msgs): 1468 result = ResponsesAgent.prep_msgs_for_cc_llm(responses_input) 1469 assert result == cc_msgs 1470 1471 1472 def test_cc_stream_to_responses_stream_handles_multiple_invalid_chunks(): 1473 chunks_with_mixed_validity = [ 1474 {"choices": None, "id": "msg-1"}, 1475 {"choices": [], "id": "msg-2"}, 1476 {"choices": [{"delta": {"content": "valid"}}], "id": "msg-3"}, 1477 {"choices": None, "id": "msg-4"}, 1478 {"choices": [{"delta": {"content": " content"}}], "id": "msg-5"}, 1479 ] 1480 1481 events = list(output_to_responses_items_stream(iter(chunks_with_mixed_validity))) 1482 1483 # Should only process chunks with valid choices 1484 # Expected: 2 delta events + 1 done event (content gets aggregated) 1485 assert len(events) == 3 1486 assert events[0].type == "response.output_text.delta" 1487 assert events[0].delta == "valid" 1488 assert events[1].type == "response.output_text.delta" 1489 assert events[1].delta == " content" 1490 assert events[2].type == "response.output_item.done"