# mlflow/types/responses.py
import json
from collections.abc import Sequence
from itertools import tee
from typing import Any, Generator, Iterator
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, model_validator

from mlflow.types.agent import ChatContext
from mlflow.types.responses_helpers import (
    BaseRequestPayload,
    Message,
    OutputItem,
    Response,
    ResponseCompletedEvent,
    ResponseErrorEvent,
    ResponseOutputItemDoneEvent,
    ResponseTextAnnotationDeltaEvent,
    ResponseTextDeltaEvent,
)

__all__ = [
    "ResponsesAgentRequest",
    "ResponsesAgentResponse",
    "ResponsesAgentStreamEvent",
]

from mlflow.types.schema import Schema
from mlflow.types.type_hints import _infer_schema_from_type_hint
from mlflow.utils.autologging_utils.logging_and_warnings import (
    MlflowEventsAndWarningsBehaviorGlobally,
)


class ResponsesAgentRequest(BaseRequestPayload):
    """Request object for ResponsesAgent.

    Args:
        input: List of simple `role` and `content` messages or output items. See examples at
            https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#testing-out-your-agent
            and
            https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.
        custom_inputs (Dict[str, Any]): An optional param to provide arbitrary additional context
            to the model. The dictionary values must be JSON-serializable.
            **Optional** defaults to ``None``
        context (:py:class:`mlflow.types.agent.ChatContext`): The context to be used in the chat
            endpoint. Includes conversation_id and user_id. **Optional** defaults to ``None``
    """

    input: list[Message | OutputItem]
    custom_inputs: dict[str, Any] | None = None
    context: ChatContext | None = None
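
# Illustrative usage sketch (not from the original module): building a request from plain
# dicts; the message text and custom_inputs keys below are made-up values.
#
#   request = ResponsesAgentRequest(
#       input=[{"role": "user", "content": "What is MLflow?"}],
#       custom_inputs={"session_id": "abc-123"},
#   )
#   request.model_dump()  # validated request as a plain dict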


class ResponsesAgentResponse(Response):
    """Response object for ResponsesAgent.

    Args:
        output: List of output items. See examples at
            https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.
        reasoning: Reasoning parameters
        usage: Usage information
        custom_outputs (Dict[str, Any]): An optional param to provide arbitrary additional context
            from the model. The dictionary values must be JSON-serializable. **Optional**, defaults
            to ``None``
    """

    custom_outputs: dict[str, Any] | None = None


class ResponsesAgentStreamEvent(BaseModel):
    """Stream event for ResponsesAgent.
    See examples at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output

    Args:
        type (str): Type of the stream event
        custom_outputs (Dict[str, Any]): An optional param to provide arbitrary additional context
            from the model. The dictionary values must be JSON-serializable. **Optional**, defaults
            to ``None``
    """

    model_config = ConfigDict(extra="allow")
    type: str
    custom_outputs: dict[str, Any] | None = None

    @model_validator(mode="after")
    def check_type(self) -> "ResponsesAgentStreamEvent":
        type = self.type
        if type == "response.output_item.done":
            ResponseOutputItemDoneEvent(**self.model_dump())
        elif type == "response.output_text.delta":
            ResponseTextDeltaEvent(**self.model_dump())
        elif type == "response.output_text.annotation.added":
            ResponseTextAnnotationDeltaEvent(**self.model_dump())
        elif type == "error":
            ResponseErrorEvent(**self.model_dump())
        elif type == "response.completed":
            ResponseCompletedEvent(**self.model_dump())
        """
        unvalidated types: {
            "response.created",
            "response.in_progress",
            "response.failed",
            "response.incomplete",
            "response.content_part.added",
            "response.content_part.done",
            "response.output_text.done",
            "response.output_item.added",
            "response.refusal.delta",
            "response.refusal.done",
            "response.function_call_arguments.delta",
            "response.function_call_arguments.done",
            "response.file_search_call.in_progress",
            "response.file_search_call.searching",
            "response.file_search_call.completed",
            "response.web_search_call.in_progress",
            "response.web_search_call.searching",
            "response.web_search_call.completed",
            "response.error",
        }
        """
        return self
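
# Illustrative sketch (not from the original module): `check_type` runs as an "after" model
# validator, so an event whose payload does not match its declared `type` is rejected at
# construction time. A valid event can be built with the helpers defined later in this file,
# e.g.:
#
#   ResponsesAgentStreamEvent(**create_text_delta("Hello", item_id="msg_1"))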


with MlflowEventsAndWarningsBehaviorGlobally(
    reroute_warnings=False,
    disable_event_logs=True,
    disable_warnings=True,
):
    properties = _infer_schema_from_type_hint(ResponsesAgentRequest).to_dict()[0]["properties"]
    formatted_properties = [{**prop, "name": name} for name, prop in properties.items()]
    RESPONSES_AGENT_INPUT_SCHEMA = Schema.from_json(json.dumps(formatted_properties))
    RESPONSES_AGENT_OUTPUT_SCHEMA = _infer_schema_from_type_hint(ResponsesAgentResponse)
RESPONSES_AGENT_INPUT_EXAMPLE = {"input": [{"role": "user", "content": "Hello!"}]}
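
# Illustrative sketch (not from the original module): the bundled input example conforms to
# ResponsesAgentRequest, the type from which the inferred input schema above is derived.
#
#   ResponsesAgentRequest(**RESPONSES_AGENT_INPUT_EXAMPLE)  # parses without error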

try:
    from langchain_core.messages import BaseMessage

    _HAS_LANGCHAIN_BASE_MESSAGE = True
except ImportError:
    _HAS_LANGCHAIN_BASE_MESSAGE = False


def responses_agent_output_reducer(
    chunks: list[ResponsesAgentStreamEvent | dict[str, Any]],
):
    """Output reducer for ResponsesAgent streaming."""
    output_items = []
    for chunk in chunks:
        # Handle both dict and pydantic object formats
        if isinstance(chunk, dict):
            chunk_type = chunk.get("type")
            if chunk_type == "response.output_item.done":
                output_items.append(chunk.get("item"))
        else:
            # Pydantic object (ResponsesAgentStreamEvent)
            if hasattr(chunk, "type") and chunk.type == "response.output_item.done":
                output_items.append(chunk.item)

    return ResponsesAgentResponse(output=output_items).model_dump(exclude_none=True)
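
# Illustrative sketch (not from the original module): reducing a short stream into a single
# response dict with the helpers defined later in this file. Only "response.output_item.done"
# chunks contribute to the output; the ids and text are made-up values.
#
#   chunks = [
#       create_text_delta("Hi", item_id="msg_1"),  # delta events are skipped by the reducer
#       {"type": "response.output_item.done", "item": create_text_output_item("Hi", id="msg_1")},
#   ]
#   responses_agent_output_reducer(chunks)
#   # returns a dict whose "output" list holds the aggregated assistant message item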


def create_text_delta(delta: str, item_id: str) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the text delta schema for
    streaming.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output.
    """
    return {
        "type": "response.output_text.delta",
        "item_id": item_id,
        "delta": delta,
    }


def create_annotation_added(
    item_id: str, annotation: dict[str, Any], annotation_index: int | None = 0
) -> dict[str, Any]:
    """Helper method to create annotation added event."""
    return {
        "type": "response.output_text.annotation.added",
        "item_id": item_id,
        "annotation_index": annotation_index,
        "annotation": annotation,
    }


def create_text_output_item(
    text: str, id: str, annotations: list[dict[str, Any]] | None = None
) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the text output item schema.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

    Args:
        text (str): The text to be outputted.
        id (str): The id of the output item.
        annotations (Optional[list[dict]]): The annotations of the output item.
    """
    content_item = {
        "text": text,
        "type": "output_text",
        "annotations": annotations or [],
    }
    return {
        "id": id,
        "content": [content_item],
        "role": "assistant",
        "type": "message",
    }
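
# Illustrative sketch (not from the original module): the typical streaming pattern pairs
# text deltas with a final aggregated output item, mirroring the streaming docs linked above.
# The id and text values are made up.
#
#   yield ResponsesAgentStreamEvent(**create_text_delta("Hel", item_id="msg_1"))
#   yield ResponsesAgentStreamEvent(**create_text_delta("lo", item_id="msg_1"))
#   yield ResponsesAgentStreamEvent(
#       type="response.output_item.done",
#       item=create_text_output_item("Hello", id="msg_1"),
#   )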


def create_reasoning_item(id: str, reasoning_text: str) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the reasoning item schema.

    Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.
    """
    return {
        "type": "reasoning",
        "summary": [
            {
                "type": "summary_text",
                "text": reasoning_text,
            }
        ],
        "id": id,
    }


def create_function_call_item(id: str, call_id: str, name: str, arguments: str) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the function call item schema.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

    Args:
        id (str): The id of the output item.
        call_id (str): The id of the function call.
        name (str): The name of the function to be called.
        arguments (str): The arguments to be passed to the function.
    """
    return {
        "type": "function_call",
        "id": id,
        "call_id": call_id,
        "name": name,
        "arguments": arguments,
    }


def create_function_call_output_item(call_id: str, output: str) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the function call output item
    schema.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

    Args:
        call_id (str): The id of the function call.
        output (str): The output of the function call.
    """
    return {
        "type": "function_call_output",
        "call_id": call_id,
        "output": output,
    }
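
# Illustrative sketch (not from the original module): a tool-calling round trip expressed as
# Responses output items. The ids, tool name, and arguments are made-up values.
#
#   call = create_function_call_item(
#       id="fc_1", call_id="call_1", name="get_weather", arguments='{"city": "Paris"}'
#   )
#   result = create_function_call_output_item(call_id="call_1", output='{"temp_c": 21}')
#   # Both dicts can be emitted as "response.output_item.done" stream events or placed in
#   # ResponsesAgentResponse(output=[...]).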


def create_mcp_approval_request_item(
    id: str, arguments: str, name: str, server_label: str
) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the MCP approval request item schema.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

    Args:
        id (str): The unique id of the approval request.
        arguments (str): A JSON string of arguments for the tool.
        name (str): The name of the tool to run.
        server_label (str): The label of the MCP server making the request.
    """
    return {
        "type": "mcp_approval_request",
        "id": id,
        "arguments": arguments,
        "name": name,
        "server_label": server_label,
    }


def create_mcp_approval_response_item(
    id: str,
    approval_request_id: str,
    approve: bool,
    reason: str | None = None,
) -> dict[str, Any]:
    """Helper method to create a dictionary conforming to the MCP approval response item schema.

    Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

    Args:
        id (str): The unique id of the approval response.
        approval_request_id (str): The id of the approval request being answered.
        approve (bool): Whether the request was approved.
        reason (Optional[str]): The reason for the approval.
    """
    return {
        "type": "mcp_approval_response",
        "id": id,
        "approval_request_id": approval_request_id,
        "approve": approve,
        "reason": reason,
    }
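
# Illustrative sketch (not from the original module): an approval response answers a prior
# approval request via `approval_request_id`. The ids and labels are made-up values.
#
#   request = create_mcp_approval_request_item(
#       id="mcpr_1", arguments='{"path": "/tmp"}', name="list_files", server_label="files"
#   )
#   answer = create_mcp_approval_response_item(
#       id="mcpa_1", approval_request_id="mcpr_1", approve=True
#   )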


def responses_to_cc(message: dict[str, Any]) -> list[dict[str, Any]]:
    """Convert from a Responses API output item to a list of ChatCompletion messages."""
    msg_type = message.get("type")
    if msg_type == "function_call":
        return [
            {
                "role": "assistant",
                "content": "tool call",  # empty content is not supported by claude models
                "tool_calls": [
                    {
                        "id": message["call_id"],
                        "type": "function",
                        "function": {
                            "arguments": message.get("arguments") or "{}",
                            "name": message["name"],
                        },
                    }
                ],
            }
        ]
    elif msg_type == "message" and isinstance(message.get("content"), list):
        return [
            {"role": message["role"], "content": content["text"]} for content in message["content"]
        ]
    elif msg_type == "reasoning":
        return [{"role": "assistant", "content": json.dumps(message["summary"])}]
    elif msg_type == "function_call_output":
        output = message["output"]
        # Convert non-string output to string for ChatCompletion compatibility
        if not isinstance(output, str):
            try:
                output = json.dumps(output)
            except (TypeError, ValueError):
                output = str(output)
        return [
            {
                "role": "tool",
                "content": output,
                "tool_call_id": message["call_id"],
            }
        ]
    elif msg_type == "mcp_approval_request":
        return [
            {
                "role": "assistant",
                "content": "mcp approval request",
                "tool_calls": [
                    {
                        "id": message["id"],
                        "type": "function",
                        "function": {
                            "arguments": message.get("arguments") or "{}",
                            "name": message["name"],
                        },
                    }
                ],
            }
        ]
    elif msg_type == "mcp_approval_response":
        return [
            {
                "role": "tool",
                "content": str(message["approve"]),
                "tool_call_id": message["approval_request_id"],
            }
        ]
    compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
    filtered = {k: v for k, v in message.items() if k in compatible_keys}
    return [filtered] if filtered else []
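
# Illustrative sketch (not from the original module): converting a function_call output item
# into its ChatCompletion form; the values are made up and the result follows directly from
# the "function_call" branch above.
#
#   responses_to_cc(create_function_call_item("fc_1", "call_1", "get_weather", '{"city": "Paris"}'))
#   # -> [{"role": "assistant", "content": "tool call",
#   #      "tool_calls": [{"id": "call_1", "type": "function",
#   #                      "function": {"arguments": '{"city": "Paris"}', "name": "get_weather"}}]}]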


def to_chat_completions_input(
    responses_input: Sequence[dict[str, Any] | Message | OutputItem],
) -> list[dict[str, Any]]:
    """Convert from Responses input items to ChatCompletion dictionaries."""
    cc_msgs = []
    for msg in responses_input:
        if isinstance(msg, BaseModel):
            cc_msgs.extend(responses_to_cc(msg.model_dump()))
        else:
            cc_msgs.extend(responses_to_cc(msg))
    return cc_msgs
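
# Illustrative sketch (not from the original module): typical use is to feed a request's
# `input` straight into a ChatCompletion-style LLM call; the request text is a made-up value.
#
#   request = ResponsesAgentRequest(input=[{"role": "user", "content": "What is MLflow?"}])
#   to_chat_completions_input(request.input)
#   # returns plain dicts with "role"/"content" (plus any other ChatCompletion-compatible keys)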


def output_to_responses_items_stream(
    chunks: Iterator[dict[str, Any]],
    aggregator: list[dict[str, Any]] | None = None,
) -> Generator[ResponsesAgentStreamEvent, None, None]:
    """
    For streaming, convert from various message format dicts to Responses output items,
    returning a generator of ResponsesAgentStreamEvent objects.

    If `aggregator` is provided, it will be extended with the aggregated output item dicts.

    Handles an iterator of ChatCompletion chunks or LangChain BaseMessage objects.
    """
    peeking_iter, chunks = tee(chunks)
    first_chunk = next(peeking_iter)
    if _HAS_LANGCHAIN_BASE_MESSAGE and isinstance(first_chunk, BaseMessage):
        yield from _langchain_message_stream_to_responses_stream(chunks, aggregator)
    else:
        yield from _cc_stream_to_responses_stream(chunks, aggregator)
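
# Illustrative sketch (not from the original module): converting a stream of ChatCompletion
# chunks (made-up values below) into Responses stream events while collecting the aggregated
# output items.
#
#   cc_chunks = iter([
#       {"id": "msg_1", "choices": [{"delta": {"content": "Hel"}}]},
#       {"id": "msg_1", "choices": [{"delta": {"content": "lo"}}]},
#   ])
#   items: list[dict] = []
#   events = list(output_to_responses_items_stream(cc_chunks, aggregator=items))
#   # events: two "response.output_text.delta" events followed by one
#   # "response.output_item.done" event; items now holds the aggregated text output item.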


if _HAS_LANGCHAIN_BASE_MESSAGE:

    def _stringify_content(content: Any) -> str:
        """Ensure content is a string, JSON-serializing if necessary."""
        if isinstance(content, str):
            return content
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)

    def _langchain_message_stream_to_responses_stream(
        chunks: Iterator[BaseMessage],
        aggregator: list[dict[str, Any]] | None = None,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Convert from a stream of LangChain BaseMessage objects to a stream of
        ResponsesAgentStreamEvent objects. Skips user or human messages.
        """
        for chunk in chunks:
            message = chunk.model_dump()
            role = message["type"]
            if role == "ai":
                if message.get("content"):
                    text_output_item = create_text_output_item(
                        text=message["content"],
                        id=message.get("id") or str(uuid4()),
                    )
                    if aggregator is not None:
                        aggregator.append(text_output_item)
                    yield ResponsesAgentStreamEvent(
                        type="response.output_item.done", item=text_output_item
                    )
                if tool_calls := message.get("tool_calls"):
                    for tool_call in tool_calls:
                        function_call_item = create_function_call_item(
                            id=tool_call.get("id") or message.get("id") or str(uuid4()),
                            call_id=tool_call["id"],
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["args"]),
                        )
                        if aggregator is not None:
                            aggregator.append(function_call_item)
                        yield ResponsesAgentStreamEvent(
                            type="response.output_item.done", item=function_call_item
                        )

            elif role == "tool":
                function_call_output_item = create_function_call_output_item(
                    call_id=message["tool_call_id"],
                    output=_stringify_content(message["content"]),
                )
                if aggregator is not None:
                    aggregator.append(function_call_output_item)
                yield ResponsesAgentStreamEvent(
                    type="response.output_item.done", item=function_call_output_item
                )
            elif role in ("user", "human"):
                continue


def _cc_stream_to_responses_stream(
    chunks: Iterator[dict[str, Any]],
    aggregator: list[dict[str, Any]] | None = None,
) -> Generator[ResponsesAgentStreamEvent, None, None]:
    """
    Convert from stream of ChatCompletion chunks to a stream of
    ResponsesAgentStreamEvent objects.
    """
    llm_content = ""
    reasoning_content = ""
    tool_calls: dict[int, dict[str, Any]] = {}  # index -> tool_call dict
    msg_id = None
    for chunk in chunks:
        if chunk.get("choices") is None or len(chunk["choices"]) == 0:
            continue
        delta = chunk["choices"][0]["delta"]
        msg_id = chunk.get("id", None)
        content = delta.get("content", None)
        if tc := delta.get("tool_calls"):
            for tool_call_delta in tc:
                idx = tool_call_delta.get("index", 0)
                if idx not in tool_calls:
                    # First chunk for this tool call contains id and name
                    tool_calls[idx] = {
                        "id": tool_call_delta.get("id"),
                        "function": {
                            "name": tool_call_delta.get("function", {}).get("name", ""),
                            "arguments": tool_call_delta.get("function", {}).get("arguments", ""),
                        },
                    }
                else:
                    # Subsequent chunks only contain argument fragments
                    tool_calls[idx]["function"]["arguments"] += tool_call_delta.get(
                        "function", {}
                    ).get("arguments", "")
        elif content is not None:
            # logic for content item format
            # https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#contentitem
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "reasoning":
                            reasoning_content += item.get("summary", [])[0].get("text", "")
                        if item.get("type") == "text" and item.get("text"):
                            llm_content += item["text"]
                            yield ResponsesAgentStreamEvent(
                                **create_text_delta(item["text"], item_id=msg_id)
                            )
            elif reasoning_content != "":
                # reasoning content is done streaming
                reasoning_item = create_reasoning_item(msg_id, reasoning_content)
                if aggregator is not None:
                    aggregator.append(reasoning_item)
                yield ResponsesAgentStreamEvent(
                    type="response.output_item.done",
                    item=reasoning_item,
                )
                reasoning_content = ""

            if isinstance(content, str):
                llm_content += content
                yield ResponsesAgentStreamEvent(**create_text_delta(content, item_id=msg_id))

    # yield an `output_item.done` `output_text` event that aggregates the stream
    # this enables tracing and payload logging
    if llm_content:
        text_output_item = create_text_output_item(llm_content, msg_id)
        if aggregator is not None:
            aggregator.append(text_output_item)
        yield ResponsesAgentStreamEvent(
            type="response.output_item.done",
            item=text_output_item,
        )
    for idx in sorted(tool_calls.keys()):
        tool_call = tool_calls[idx]
        function_call_item = create_function_call_item(
            msg_id,
            tool_call["id"],
            tool_call["function"]["name"],
            tool_call["function"]["arguments"],
        )
        if aggregator is not None:
            aggregator.append(function_call_item)
        yield ResponsesAgentStreamEvent(
            type="response.output_item.done",
            item=function_call_item,
        )