responses.py
1 import json 2 from collections.abc import Sequence 3 from itertools import tee 4 from typing import Any, Generator, Iterator 5 from uuid import uuid4 6 7 from pydantic import BaseModel, ConfigDict, model_validator 8 9 from mlflow.types.agent import ChatContext 10 from mlflow.types.responses_helpers import ( 11 BaseRequestPayload, 12 Message, 13 OutputItem, 14 Response, 15 ResponseCompletedEvent, 16 ResponseErrorEvent, 17 ResponseOutputItemDoneEvent, 18 ResponseTextAnnotationDeltaEvent, 19 ResponseTextDeltaEvent, 20 ) 21 22 __all__ = [ 23 "ResponsesAgentRequest", 24 "ResponsesAgentResponse", 25 "ResponsesAgentStreamEvent", 26 ] 27 28 from mlflow.types.schema import Schema 29 from mlflow.types.type_hints import _infer_schema_from_type_hint 30 from mlflow.utils.autologging_utils.logging_and_warnings import ( 31 MlflowEventsAndWarningsBehaviorGlobally, 32 ) 33 34 35 class ResponsesAgentRequest(BaseRequestPayload): 36 """Request object for ResponsesAgent. 37 38 Args: 39 input: List of simple `role` and `content` messages or output items. See examples at 40 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#testing-out-your-agent 41 and 42 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 43 custom_inputs (Dict[str, Any]): An optional param to provide arbitrary additional context 44 to the model. The dictionary values must be JSON-serializable. 45 **Optional** defaults to ``None`` 46 context (:py:class:`mlflow.types.agent.ChatContext`): The context to be used in the chat 47 endpoint. Includes conversation_id and user_id. **Optional** defaults to ``None`` 48 """ 49 50 input: list[Message | OutputItem] 51 custom_inputs: dict[str, Any] | None = None 52 context: ChatContext | None = None 53 54 55 class ResponsesAgentResponse(Response): 56 """Response object for ResponsesAgent. 57 58 Args: 59 output: List of output items. See examples at 60 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 61 reasoning: Reasoning parameters 62 usage: Usage information 63 custom_outputs (Dict[str, Any]): An optional param to provide arbitrary additional context 64 from the model. The dictionary values must be JSON-serializable. **Optional**, defaults 65 to ``None`` 66 """ 67 68 custom_outputs: dict[str, Any] | None = None 69 70 71 class ResponsesAgentStreamEvent(BaseModel): 72 """Stream event for ResponsesAgent. 73 See examples at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output 74 75 Args: 76 type (str): Type of the stream event 77 custom_outputs (Dict[str, Any]): An optional param to provide arbitrary additional context 78 from the model. The dictionary values must be JSON-serializable. **Optional**, defaults 79 to ``None`` 80 """ 81 82 model_config = ConfigDict(extra="allow") 83 type: str 84 custom_outputs: dict[str, Any] | None = None 85 86 @model_validator(mode="after") 87 def check_type(self) -> "ResponsesAgentStreamEvent": 88 type = self.type 89 if type == "response.output_item.done": 90 ResponseOutputItemDoneEvent(**self.model_dump()) 91 elif type == "response.output_text.delta": 92 ResponseTextDeltaEvent(**self.model_dump()) 93 elif type == "response.output_text.annotation.added": 94 ResponseTextAnnotationDeltaEvent(**self.model_dump()) 95 elif type == "error": 96 ResponseErrorEvent(**self.model_dump()) 97 elif type == "response.completed": 98 ResponseCompletedEvent(**self.model_dump()) 99 """ 100 unvalidated types: { 101 "response.created", 102 "response.in_progress", 103 "response.completed", 104 "response.failed", 105 "response.incomplete", 106 "response.content_part.added", 107 "response.content_part.done", 108 "response.output_text.done", 109 "response.output_item.added", 110 "response.refusal.delta", 111 "response.refusal.done", 112 "response.function_call_arguments.delta", 113 "response.function_call_arguments.done", 114 "response.file_search_call.in_progress", 115 "response.file_search_call.searching", 116 "response.file_search_call.completed", 117 "response.web_search_call.in_progress", 118 "response.web_search_call.searching", 119 "response.web_search_call.completed", 120 "response.error", 121 } 122 """ 123 return self 124 125 126 with MlflowEventsAndWarningsBehaviorGlobally( 127 reroute_warnings=False, 128 disable_event_logs=True, 129 disable_warnings=True, 130 ): 131 properties = _infer_schema_from_type_hint(ResponsesAgentRequest).to_dict()[0]["properties"] 132 formatted_properties = [{**prop, "name": name} for name, prop in properties.items()] 133 RESPONSES_AGENT_INPUT_SCHEMA = Schema.from_json(json.dumps(formatted_properties)) 134 RESPONSES_AGENT_OUTPUT_SCHEMA = _infer_schema_from_type_hint(ResponsesAgentResponse) 135 RESPONSES_AGENT_INPUT_EXAMPLE = {"input": [{"role": "user", "content": "Hello!"}]} 136 137 try: 138 from langchain_core.messages import BaseMessage 139 140 _HAS_LANGCHAIN_BASE_MESSAGE = True 141 except ImportError: 142 _HAS_LANGCHAIN_BASE_MESSAGE = False 143 144 145 def responses_agent_output_reducer( 146 chunks: list[ResponsesAgentStreamEvent | dict[str, Any]], 147 ): 148 """Output reducer for ResponsesAgent streaming.""" 149 output_items = [] 150 for chunk in chunks: 151 # Handle both dict and pydantic object formats 152 if isinstance(chunk, dict): 153 chunk_type = chunk.get("type") 154 if chunk_type == "response.output_item.done": 155 output_items.append(chunk.get("item")) 156 else: 157 # Pydantic object (ResponsesAgentStreamEvent) 158 if hasattr(chunk, "type") and chunk.type == "response.output_item.done": 159 output_items.append(chunk.item) 160 161 return ResponsesAgentResponse(output=output_items).model_dump(exclude_none=True) 162 163 164 def create_text_delta(delta: str, item_id: str) -> dict[str, Any]: 165 """Helper method to create a dictionary conforming to the text delta schema for 166 streaming. 167 168 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output. 169 """ 170 return { 171 "type": "response.output_text.delta", 172 "item_id": item_id, 173 "delta": delta, 174 } 175 176 177 def create_annotation_added( 178 item_id: str, annotation: dict[str, Any], annotation_index: int | None = 0 179 ) -> dict[str, Any]: 180 """Helper method to create annotation added event.""" 181 return { 182 "type": "response.output_text.annotation.added", 183 "item_id": item_id, 184 "annotation_index": annotation_index, 185 "annotation": annotation, 186 } 187 188 189 def create_text_output_item( 190 text: str, id: str, annotations: list[dict[str, Any]] | None = None 191 ) -> dict[str, Any]: 192 """Helper method to create a dictionary conforming to the text output item schema. 193 194 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 195 196 Args: 197 text (str): The text to be outputted. 198 id (str): The id of the output item. 199 annotations (Optional[list[dict]]): The annotations of the output item. 200 """ 201 content_item = { 202 "text": text, 203 "type": "output_text", 204 "annotations": annotations or [], 205 } 206 return { 207 "id": id, 208 "content": [content_item], 209 "role": "assistant", 210 "type": "message", 211 } 212 213 214 def create_reasoning_item(id: str, reasoning_text: str) -> dict[str, Any]: 215 """Helper method to create a dictionary conforming to the reasoning item schema. 216 217 Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output. 218 """ 219 return { 220 "type": "reasoning", 221 "summary": [ 222 { 223 "type": "summary_text", 224 "text": reasoning_text, 225 } 226 ], 227 "id": id, 228 } 229 230 231 def create_function_call_item(id: str, call_id: str, name: str, arguments: str) -> dict[str, Any]: 232 """Helper method to create a dictionary conforming to the function call item schema. 233 234 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 235 236 Args: 237 id (str): The id of the output item. 238 call_id (str): The id of the function call. 239 name (str): The name of the function to be called. 240 arguments (str): The arguments to be passed to the function. 241 """ 242 return { 243 "type": "function_call", 244 "id": id, 245 "call_id": call_id, 246 "name": name, 247 "arguments": arguments, 248 } 249 250 251 def create_function_call_output_item(call_id: str, output: str) -> dict[str, Any]: 252 """Helper method to create a dictionary conforming to the function call output item 253 schema. 254 255 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 256 257 Args: 258 call_id (str): The id of the function call. 259 output (str): The output of the function call. 260 """ 261 return { 262 "type": "function_call_output", 263 "call_id": call_id, 264 "output": output, 265 } 266 267 268 def create_mcp_approval_request_item( 269 id: str, arguments: str, name: str, server_label: str 270 ) -> dict[str, Any]: 271 """Helper method to create a dictionary conforming to the MCP approval request item schema. 272 273 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 274 275 Args: 276 id (str): The unique id of the approval request. 277 arguments (str): A JSON string of arguments for the tool. 278 name (str): The name of the tool to run. 279 server_label (str): The label of the MCP server making the request. 280 """ 281 return { 282 "type": "mcp_approval_request", 283 "id": id, 284 "arguments": arguments, 285 "name": name, 286 "server_label": server_label, 287 } 288 289 290 def create_mcp_approval_response_item( 291 id: str, 292 approval_request_id: str, 293 approve: bool, 294 reason: str | None = None, 295 ) -> dict[str, Any]: 296 """Helper method to create a dictionary conforming to the MCP approval response item schema. 297 298 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output. 299 300 Args: 301 id (str): The unique id of the approval response. 302 approval_request_id (str): The id of the approval request being answered. 303 approve (bool): Whether the request was approved. 304 reason (Optional[str]): The reason for the approval. 305 """ 306 return { 307 "type": "mcp_approval_response", 308 "id": id, 309 "approval_request_id": approval_request_id, 310 "approve": approve, 311 "reason": reason, 312 } 313 314 315 def responses_to_cc(message: dict[str, Any]) -> list[dict[str, Any]]: 316 """Convert from a Responses API output item to a list of ChatCompletion messages.""" 317 msg_type = message.get("type") 318 if msg_type == "function_call": 319 return [ 320 { 321 "role": "assistant", 322 "content": "tool call", # empty content is not supported by claude models 323 "tool_calls": [ 324 { 325 "id": message["call_id"], 326 "type": "function", 327 "function": { 328 "arguments": message.get("arguments") or "{}", 329 "name": message["name"], 330 }, 331 } 332 ], 333 } 334 ] 335 elif msg_type == "message" and isinstance(message.get("content"), list): 336 return [ 337 {"role": message["role"], "content": content["text"]} for content in message["content"] 338 ] 339 elif msg_type == "reasoning": 340 return [{"role": "assistant", "content": json.dumps(message["summary"])}] 341 elif msg_type == "function_call_output": 342 output = message["output"] 343 # Convert non-string output to string for ChatCompletion compatibility 344 if not isinstance(output, str): 345 try: 346 output = json.dumps(output) 347 except (TypeError, ValueError): 348 output = str(output) 349 return [ 350 { 351 "role": "tool", 352 "content": output, 353 "tool_call_id": message["call_id"], 354 } 355 ] 356 elif msg_type == "mcp_approval_request": 357 return [ 358 { 359 "role": "assistant", 360 "content": "mcp approval request", 361 "tool_calls": [ 362 { 363 "id": message["id"], 364 "type": "function", 365 "function": { 366 "arguments": message.get("arguments") or "{}", 367 "name": message["name"], 368 }, 369 } 370 ], 371 } 372 ] 373 elif msg_type == "mcp_approval_response": 374 return [ 375 { 376 "role": "tool", 377 "content": str(message["approve"]), 378 "tool_call_id": message["approval_request_id"], 379 } 380 ] 381 compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"] 382 filtered = {k: v for k, v in message.items() if k in compatible_keys} 383 return [filtered] if filtered else [] 384 385 386 def to_chat_completions_input( 387 responses_input: Sequence[dict[str, Any] | Message | OutputItem], 388 ) -> list[dict[str, Any]]: 389 """Convert from Responses input items to ChatCompletion dictionaries.""" 390 cc_msgs = [] 391 for msg in responses_input: 392 if isinstance(msg, BaseModel): 393 cc_msgs.extend(responses_to_cc(msg.model_dump())) 394 else: 395 cc_msgs.extend(responses_to_cc(msg)) 396 return cc_msgs 397 398 399 def output_to_responses_items_stream( 400 chunks: Iterator[dict[str, Any]], 401 aggregator: list[dict[str, Any]] | None = None, 402 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 403 """ 404 For streaming, convert from various message format dicts to Responses output items, 405 returning a generator of ResponsesAgentStreamEvent objects. 406 407 If `aggregator` is provided, it will be extended with the aggregated output item dicts. 408 409 Handles an iterator of ChatCompletion chunks or LangChain BaseMessage objects. 410 """ 411 peeking_iter, chunks = tee(chunks) 412 first_chunk = next(peeking_iter) 413 if _HAS_LANGCHAIN_BASE_MESSAGE and isinstance(first_chunk, BaseMessage): 414 yield from _langchain_message_stream_to_responses_stream(chunks, aggregator) 415 else: 416 yield from _cc_stream_to_responses_stream(chunks, aggregator) 417 418 419 if _HAS_LANGCHAIN_BASE_MESSAGE: 420 421 def _stringify_content(content: Any) -> str: 422 """Ensure content is a string, JSON-serializing if necessary.""" 423 if isinstance(content, str): 424 return content 425 try: 426 return json.dumps(content) 427 except (TypeError, ValueError): 428 return str(content) 429 430 def _langchain_message_stream_to_responses_stream( 431 chunks: Iterator[BaseMessage], 432 aggregator: list[dict[str, Any]] | None = None, 433 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 434 """Convert from a stream of LangChain BaseMessage objects to a stream of 435 ResponsesAgentStreamEvent objects. Skips user or human messages. 436 """ 437 for chunk in chunks: 438 message = chunk.model_dump() 439 role = message["type"] 440 if role == "ai": 441 if message.get("content"): 442 text_output_item = create_text_output_item( 443 text=message["content"], 444 id=message.get("id") or str(uuid4()), 445 ) 446 if aggregator is not None: 447 aggregator.append(text_output_item) 448 yield ResponsesAgentStreamEvent( 449 type="response.output_item.done", item=text_output_item 450 ) 451 if tool_calls := message.get("tool_calls"): 452 for tool_call in tool_calls: 453 function_call_item = create_function_call_item( 454 id=tool_call.get("id") or message.get("id") or str(uuid4()), 455 call_id=tool_call["id"], 456 name=tool_call["name"], 457 arguments=json.dumps(tool_call["args"]), 458 ) 459 if aggregator is not None: 460 aggregator.append(function_call_item) 461 yield ResponsesAgentStreamEvent( 462 type="response.output_item.done", item=function_call_item 463 ) 464 465 elif role == "tool": 466 function_call_output_item = create_function_call_output_item( 467 call_id=message["tool_call_id"], 468 output=_stringify_content(message["content"]), 469 ) 470 if aggregator is not None: 471 aggregator.append(function_call_output_item) 472 yield ResponsesAgentStreamEvent( 473 type="response.output_item.done", item=function_call_output_item 474 ) 475 elif role == "user" or "human": 476 continue 477 478 479 def _cc_stream_to_responses_stream( 480 chunks: Iterator[dict[str, Any]], 481 aggregator: list[dict[str, Any]] | None = None, 482 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 483 """ 484 Convert from stream of ChatCompletion chunks to a stream of 485 ResponsesAgentStreamEvent objects. 486 """ 487 llm_content = "" 488 reasoning_content = "" 489 tool_calls: dict[int, dict[str, Any]] = {} # index -> tool_call dict 490 msg_id = None 491 for chunk in chunks: 492 if chunk.get("choices") is None or len(chunk["choices"]) == 0: 493 continue 494 delta = chunk["choices"][0]["delta"] 495 msg_id = chunk.get("id", None) 496 content = delta.get("content", None) 497 if tc := delta.get("tool_calls"): 498 for tool_call_delta in tc: 499 idx = tool_call_delta.get("index", 0) 500 if idx not in tool_calls: 501 # First chunk for this tool call contains id and name 502 tool_calls[idx] = { 503 "id": tool_call_delta.get("id"), 504 "function": { 505 "name": tool_call_delta.get("function", {}).get("name", ""), 506 "arguments": tool_call_delta.get("function", {}).get("arguments", ""), 507 }, 508 } 509 else: 510 # Subsequent chunks only contain argument fragments 511 tool_calls[idx]["function"]["arguments"] += tool_call_delta.get( 512 "function", {} 513 ).get("arguments", "") 514 elif content is not None: 515 # logic for content item format 516 # https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#contentitem 517 if isinstance(content, list): 518 for item in content: 519 if isinstance(item, dict): 520 if item.get("type") == "reasoning": 521 reasoning_content += item.get("summary", [])[0].get("text", "") 522 if item.get("type") == "text" and item.get("text"): 523 llm_content += item["text"] 524 yield ResponsesAgentStreamEvent( 525 **create_text_delta(item["text"], item_id=msg_id) 526 ) 527 elif reasoning_content != "": 528 # reasoning content is done streaming 529 reasoning_item = create_reasoning_item(msg_id, reasoning_content) 530 if aggregator is not None: 531 aggregator.append(reasoning_item) 532 yield ResponsesAgentStreamEvent( 533 type="response.output_item.done", 534 item=reasoning_item, 535 ) 536 reasoning_content = "" 537 538 if isinstance(content, str): 539 llm_content += content 540 yield ResponsesAgentStreamEvent(**create_text_delta(content, item_id=msg_id)) 541 542 # yield an `output_item.done` `output_text` event that aggregates the stream 543 # this enables tracing and payload logging 544 if llm_content: 545 text_output_item = create_text_output_item(llm_content, msg_id) 546 if aggregator is not None: 547 aggregator.append(text_output_item) 548 yield ResponsesAgentStreamEvent( 549 type="response.output_item.done", 550 item=text_output_item, 551 ) 552 553 for idx in sorted(tool_calls.keys()): 554 tool_call = tool_calls[idx] 555 function_call_output_item = create_function_call_item( 556 msg_id, 557 tool_call["id"], 558 tool_call["function"]["name"], 559 tool_call["function"]["arguments"], 560 ) 561 if aggregator is not None: 562 aggregator.append(function_call_output_item) 563 yield ResponsesAgentStreamEvent( 564 type="response.output_item.done", 565 item=function_call_output_item, 566 )