# utils.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  from typing import Any
  7  
  8  from haystack import logging
  9  from haystack.dataclasses import ChatMessage, ReasoningContent, StreamingChunk, ToolCall
 10  
 11  logger = logging.getLogger(__name__)
 12  
 13  
 14  def print_streaming_chunk(chunk: StreamingChunk) -> None:
 15      """
 16      Callback function to handle and display streaming output chunks.
 17  
 18      This function processes a `StreamingChunk` object by:
 19      - Printing tool call metadata (if any), including function names and arguments, as they arrive.
 20      - Printing tool call results when available.
 21      - Printing the main content (e.g., text tokens) of the chunk as it is received.
 22  
 23      The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
 24      streaming.
 25  
 26      :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
 27          tool results.
 28      """
 29      if chunk.start and chunk.index and chunk.index > 0:
 30          # If this is the start of a new content block but not the first content block, print two new lines
 31          print("\n\n", flush=True, end="")
 32  
 33      ## Tool Call streaming
 34      if chunk.tool_calls:
 35          # Typically, if there are multiple tool calls in the chunk this means that the tool calls are fully formed and
 36          # not just a delta.
 37          for tool_call in chunk.tool_calls:
 38              # If chunk.start is True indicates beginning of a tool call
 39              # Also presence of tool_call.tool_name indicates the start of a tool call too
 40              if chunk.start:
 41                  # If there is more than one tool call in the chunk, we print two new lines to separate them
 42                  # We know there is more than one tool call if the index of the tool call is greater than the index of
 43                  # the chunk.
 44                  if chunk.index and tool_call.index > chunk.index:
 45                      print("\n\n", flush=True, end="")
 46  
 47                  print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
 48  
 49              # print the tool arguments
 50              if tool_call.arguments:
 51                  print(tool_call.arguments, flush=True, end="")
 52  
 53      ## Tool Call Result streaming
 54      # Print tool call results if available (from ToolInvoker)
 55      if chunk.tool_call_result:
 56          # Tool Call Result is fully formed so delta accumulation is not needed
 57          print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="")
 58  
 59      ## Normal content streaming
 60      # Print the main content of the chunk (from ChatGenerator)
 61      if chunk.content:
 62          if chunk.start:
 63              print("[ASSISTANT]\n", flush=True, end="")
 64          print(chunk.content, flush=True, end="")
 65  
 66      ## Reasoning content streaming
 67      # Print the reasoning content of the chunk (from ChatGenerator)
 68      if chunk.reasoning:
 69          if chunk.start:
 70              print("[REASONING]\n", flush=True, end="")
 71          print(chunk.reasoning.reasoning_text, flush=True, end="")
 72  
 73      # End of LLM assistant message so we add two new lines
 74      # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
 75      if chunk.finish_reason is not None:
 76          print("\n\n", flush=True, end="")
 77  
 78  
 79  def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> ChatMessage:
 80      """
 81      Connects the streaming chunks into a single ChatMessage.
 82  
 83      :param chunks: The list of all `StreamingChunk` objects.
 84  
 85      :returns: The ChatMessage.
 86      """
 87      text = "".join([chunk.content for chunk in chunks])
 88      logprobs = []
 89      for chunk in chunks:
 90          if chunk.meta.get("logprobs"):
 91              logprobs.append(chunk.meta.get("logprobs"))
 92      tool_calls = []
 93  
 94      # Accumulate reasoning content from chunks
 95      reasoning_parts = [chunk.reasoning.reasoning_text for chunk in chunks if chunk.reasoning]
 96      reasoning = ReasoningContent(reasoning_text="".join(reasoning_parts)) if reasoning_parts else None
 97  
 98      # Process tool calls if present in any chunk
 99      tool_call_data: dict[int, dict[str, str]] = {}  # Track tool calls by index
100      for chunk in chunks:
101          if chunk.tool_calls:
102              for tool_call in chunk.tool_calls:
103                  # We use the index of the tool_call to track the tool call across chunks since the ID is not always
104                  # provided
105                  if tool_call.index not in tool_call_data:
106                      tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
107  
108                  # Save the ID if present
109                  if tool_call.id is not None:
110                      tool_call_data[tool_call.index]["id"] = tool_call.id
111  
112                  if tool_call.tool_name is not None:
113                      tool_call_data[tool_call.index]["name"] += tool_call.tool_name
114                  if tool_call.arguments is not None:
115                      tool_call_data[tool_call.index]["arguments"] += tool_call.arguments
116  
117      # Convert accumulated tool call data into ToolCall objects
118      sorted_keys = sorted(tool_call_data.keys())
119      for key in sorted_keys:
120          tool_call_dict = tool_call_data[key]
121          try:
122              arguments = json.loads(tool_call_dict.get("arguments", "{}")) if tool_call_dict.get("arguments") else {}
123              tool_calls.append(ToolCall(id=tool_call_dict["id"], tool_name=tool_call_dict["name"], arguments=arguments))
124          except json.JSONDecodeError:
125              logger.warning(
126                  "The LLM provider returned a malformed JSON string for tool call arguments. This tool call "
127                  "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
128                  "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
129                  _id=tool_call_dict["id"],
130                  _name=tool_call_dict["name"],
131                  _arguments=tool_call_dict["arguments"],
132              )
133  
134      # finish_reason can appear in different places so we look for the last one
135      finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
136      finish_reason = finish_reasons[-1] if finish_reasons else None
137  
138      # usage info can appear in different chunks depending on the API provider
139      # (e.g., OpenAI returns it in the last chunk with empty choices, but Qwen3 may return it differently)
140      # so we look for the last non-None usage value across all chunks
141      usage = None
142      for chunk in reversed(chunks):
143          chunk_usage = chunk.meta.get("usage")
144          if chunk_usage is not None:
145              usage = chunk_usage
146              break
147  
148      meta = {
149          "model": chunks[-1].meta.get("model"),
150          "index": 0,
151          "finish_reason": finish_reason,
152          "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
153          "usage": usage,
154      }
155  
156      if logprobs:
157          meta["logprobs"] = logprobs
158  
159      return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, reasoning=reasoning, meta=meta)
160  
161  
162  def _serialize_object(obj: Any) -> Any:
163      """Convert an object to a serializable dict recursively"""
164      if hasattr(obj, "model_dump"):
165          return obj.model_dump()
166      if hasattr(obj, "__dict__"):
167          return {k: _serialize_object(v) for k, v in obj.__dict__.items() if not k.startswith("_")}
168      if isinstance(obj, dict):
169          return {k: _serialize_object(v) for k, v in obj.items()}
170      if isinstance(obj, list):
171          return [_serialize_object(item) for item in obj]
172      return obj