# utils.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any

from haystack import logging
from haystack.dataclasses import ChatMessage, ReasoningContent, StreamingChunk, ToolCall

logger = logging.getLogger(__name__)


def print_streaming_chunk(chunk: StreamingChunk) -> None:
    """
    Callback function to handle and display streaming output chunks.

    This function processes a `StreamingChunk` object by:
    - Printing tool call metadata (if any), including function names and arguments, as they arrive.
    - Printing tool call results when available.
    - Printing the main content (e.g., text tokens) of the chunk as it is received.

    The function outputs data directly to stdout and flushes output buffers to ensure immediate display during
    streaming.

    :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
        tool results.
    """
    if chunk.start and chunk.index and chunk.index > 0:
        # If this is the start of a new content block but not the first content block, print two new lines
        print("\n\n", flush=True, end="")

    ## Tool Call streaming
    if chunk.tool_calls:
        # Typically, if there are multiple tool calls in the chunk this means that the tool calls are fully
        # formed and not just a delta.
        for tool_call in chunk.tool_calls:
            # If chunk.start is True indicates beginning of a tool call
            # Also presence of tool_call.tool_name indicates the start of a tool call too
            if chunk.start:
                # If there is more than one tool call in the chunk, we print two new lines to separate them.
                # We know there is more than one tool call if the index of the tool call is greater than the
                # index of the chunk.
                if chunk.index and tool_call.index > chunk.index:
                    print("\n\n", flush=True, end="")

                print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

            # print the tool arguments
            if tool_call.arguments:
                print(tool_call.arguments, flush=True, end="")

    ## Tool Call Result streaming
    # Print tool call results if available (from ToolInvoker)
    if chunk.tool_call_result:
        # Tool Call Result is fully formed so delta accumulation is not needed
        print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="")

    ## Normal content streaming
    # Print the main content of the chunk (from ChatGenerator)
    if chunk.content:
        if chunk.start:
            print("[ASSISTANT]\n", flush=True, end="")
        print(chunk.content, flush=True, end="")

    ## Reasoning content streaming
    # Print the reasoning content of the chunk (from ChatGenerator)
    if chunk.reasoning:
        if chunk.start:
            print("[REASONING]\n", flush=True, end="")
        print(chunk.reasoning.reasoning_text, flush=True, end="")

    # End of LLM assistant message so we add two new lines
    # This ensures spacing between multiple LLM messages (e.g. Agent) or multiple Tool Call Results
    if chunk.finish_reason is not None:
        print("\n\n", flush=True, end="")


def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> ChatMessage:
    """
    Connects the streaming chunks into a single ChatMessage.

    :param chunks: The list of all `StreamingChunk` objects.

    :returns: The ChatMessage.
    """
    text = "".join(chunk.content for chunk in chunks)

    # Collect per-chunk logprobs when the provider supplies them; single meta lookup per chunk.
    logprobs = [chunk_logprobs for chunk in chunks if (chunk_logprobs := chunk.meta.get("logprobs"))]

    # Accumulate reasoning content from chunks
    reasoning_parts = [chunk.reasoning.reasoning_text for chunk in chunks if chunk.reasoning]
    reasoning = ReasoningContent(reasoning_text="".join(reasoning_parts)) if reasoning_parts else None

    # Process tool calls if present in any chunk
    tool_call_data: dict[int, dict[str, str]] = {}  # Track tool calls by index
    for chunk in chunks:
        if chunk.tool_calls:
            for tool_call in chunk.tool_calls:
                # We use the index of the tool_call to track the tool call across chunks since the ID is not
                # always provided
                if tool_call.index not in tool_call_data:
                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}

                # Save the ID if present
                if tool_call.id is not None:
                    tool_call_data[tool_call.index]["id"] = tool_call.id

                if tool_call.tool_name is not None:
                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
                if tool_call.arguments is not None:
                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

    # Convert accumulated tool call data into ToolCall objects
    tool_calls = []
    for key in sorted(tool_call_data):
        tool_call_dict = tool_call_data[key]
        try:
            # Only json.loads can raise here; a malformed payload means the provider streamed invalid JSON.
            arguments = json.loads(tool_call_dict["arguments"]) if tool_call_dict.get("arguments") else {}
        except json.JSONDecodeError:
            logger.warning(
                "The LLM provider returned a malformed JSON string for tool call arguments. This tool call "
                "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
                "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                _id=tool_call_dict["id"],
                _name=tool_call_dict["name"],
                _arguments=tool_call_dict["arguments"],
            )
        else:
            tool_calls.append(
                ToolCall(id=tool_call_dict["id"], tool_name=tool_call_dict["name"], arguments=arguments)
            )

    # finish_reason can appear in different places so we look for the last one
    finish_reasons = [chunk.finish_reason for chunk in chunks if chunk.finish_reason]
    finish_reason = finish_reasons[-1] if finish_reasons else None

    # usage info can appear in different chunks depending on the API provider
    # (e.g., OpenAI returns it in the last chunk with empty choices, but Qwen3 may return it differently)
    # so we look for the last non-None usage value across all chunks
    usage = None
    for chunk in reversed(chunks):
        chunk_usage = chunk.meta.get("usage")
        if chunk_usage is not None:
            usage = chunk_usage
            break

    meta = {
        "model": chunks[-1].meta.get("model"),
        "index": 0,
        "finish_reason": finish_reason,
        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
        "usage": usage,
    }

    if logprobs:
        meta["logprobs"] = logprobs

    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, reasoning=reasoning, meta=meta)


def _serialize_object(obj: Any) -> Any:
    """Convert an object to a serializable dict recursively"""
    # Pydantic-style objects expose model_dump(), which already recurses into nested models.
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    # Plain objects: serialize public attributes only (private "_" names are skipped).
    if hasattr(obj, "__dict__"):
        return {k: _serialize_object(v) for k, v in obj.__dict__.items() if not k.startswith("_")}
    if isinstance(obj, dict):
        return {k: _serialize_object(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_serialize_object(item) for item in obj]
    return obj