# streaming_chunk.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Awaitable, Callable
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, overload

from haystack.core.component import Component
from haystack.dataclasses.chat_message import ReasoningContent, ToolCallResult
from haystack.utils.asynchronous import is_callable_async_compatible
from haystack.utils.dataclasses import _warn_on_inplace_mutation

# Type alias for standard finish_reason values following OpenAI's convention
# plus Haystack-specific value ("tool_call_results")
FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]


@_warn_on_inplace_mutation
@dataclass
class ToolCallDelta:
    """
    Represents a Tool call prepared by the model, usually contained in an assistant message.

    :param index: The index of the Tool call in the list of Tool calls.
    :param tool_name: The name of the Tool to call.
    :param arguments: Either the full arguments in JSON format or a delta of the arguments.
    :param id: The ID of the Tool call.
    :param extra: Dictionary of extra information about the Tool call. Use to store provider-specific
        information. To avoid serialization issues, values should be JSON serializable.
    """

    index: int
    tool_name: str | None = field(default=None)
    arguments: str | None = field(default=None)
    id: str | None = field(default=None)
    extra: dict[str, Any] | None = field(default=None)

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a dictionary representation of the ToolCallDelta.

        :returns: A dictionary with keys 'index', 'tool_name', 'arguments', 'id', and 'extra'.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolCallDelta":
        """
        Creates a ToolCallDelta from a serialized representation.

        :param data: Dictionary containing ToolCallDelta's attributes.
        :returns: A ToolCallDelta instance (an instance of the subclass when called on a subclass).
        """
        # `cls` instead of the hard-coded class name so subclasses deserialize to their own type.
        return cls(**data)


@_warn_on_inplace_mutation
@dataclass
class ComponentInfo:
    """
    The `ComponentInfo` class encapsulates information about a component.

    :param type: The type of the component.
    :param name: The name of the component assigned when adding it to a pipeline.
    """

    type: str
    name: str | None = field(default=None)

    @classmethod
    def from_component(cls, component: Component) -> "ComponentInfo":
        """
        Create a `ComponentInfo` object from a `Component` instance.

        :param component:
            The `Component` instance.
        :returns:
            The `ComponentInfo` object with the type and name of the given component.
        """
        # Fully qualified "<module>.<ClassName>" of the component instance.
        component_type = f"{component.__class__.__module__}.{component.__class__.__name__}"
        # `__component_name__` is read if present; otherwise the name stays None.
        component_name = getattr(component, "__component_name__", None)
        return cls(type=component_type, name=component_name)

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a dictionary representation of ComponentInfo.

        :returns: A dictionary with keys 'type' and 'name'.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo":
        """
        Creates a ComponentInfo from a serialized representation.

        :param data: Dictionary containing ComponentInfo's attributes.
        :returns: A ComponentInfo instance (an instance of the subclass when called on a subclass).
        """
        # `cls` instead of the hard-coded class name, consistent with `from_component`.
        return cls(**data)


@_warn_on_inplace_mutation
@dataclass
class StreamingChunk:
    """
    The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata.

    This structure facilitates the handling and processing of streamed data in a systematic manner.

    :param content: The content of the message chunk as a string.
    :param meta: A dictionary containing metadata related to the message chunk.
    :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
        such as the component name and type.
    :param index: An optional integer index representing which content block this chunk belongs to.
    :param tool_calls: An optional list of ToolCallDelta objects representing tool calls associated with the message
        chunk.
    :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
    :param start: A boolean indicating whether this chunk marks the start of a content block.
    :param finish_reason: An optional value indicating the reason the generation finished.
        Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
        plus Haystack-specific value "tool_call_results".
    :param reasoning: An optional ReasoningContent object representing the reasoning content associated
        with the message chunk.
    :raises ValueError: If more than one of `content`, `tool_calls`, `tool_call_result` or `reasoning` is set,
        or if `tool_calls`, `tool_call_result` or `reasoning` is set without an `index`.
    """

    content: str
    meta: dict[str, Any] = field(default_factory=dict, hash=False)
    component_info: ComponentInfo | None = field(default=None)
    index: int | None = field(default=None)
    tool_calls: list[ToolCallDelta] | None = field(default=None)
    tool_call_result: ToolCallResult | None = field(default=None)
    start: bool = field(default=False)
    finish_reason: FinishReason | None = field(default=None)
    reasoning: ReasoningContent | None = field(default=None)

    def __post_init__(self) -> None:
        # The payload fields are mutually exclusive. Truthiness is used, so an empty string or an
        # empty list counts as "not set".
        fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result, self.reasoning))
        if fields_set > 1:
            raise ValueError(
                "Only one of `content`, `tool_call`, `tool_call_result` or `reasoning` may be set in a StreamingChunk. "
                f"Got content: '{self.content}', tool_call: '{self.tool_calls}', "
                f"tool_call_result: '{self.tool_call_result}', reasoning: '{self.reasoning}'."
            )

        # NOTE: We don't enforce this for self.content otherwise it would be a breaking change
        if (self.tool_calls or self.tool_call_result or self.reasoning) and self.index is None:
            raise ValueError("If `tool_call`, `tool_call_result` or `reasoning` is set, `index` must also be set.")

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a dictionary representation of the StreamingChunk.

        :returns: Serialized dictionary representation of the calling object.
        """
        return {
            "content": self.content,
            "meta": self.meta,
            "component_info": self.component_info.to_dict() if self.component_info else None,
            "index": self.index,
            "tool_calls": [tc.to_dict() for tc in self.tool_calls] if self.tool_calls else None,
            "tool_call_result": self.tool_call_result.to_dict() if self.tool_call_result else None,
            "start": self.start,
            "finish_reason": self.finish_reason,
            "reasoning": self.reasoning.to_dict() if self.reasoning else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "StreamingChunk":
        """
        Creates a deserialized StreamingChunk instance from a serialized representation.

        :param data: Dictionary containing the StreamingChunk's attributes.
        :returns: A StreamingChunk instance (an instance of the subclass when called on a subclass).
        :raises ValueError: If `content` is missing from `data`.
        """
        if "content" not in data:
            raise ValueError("Missing required field `content` in StreamingChunk deserialization.")

        # A missing key or an explicit `None` (as written by `to_dict` round-trips from other producers)
        # both fall back to an empty dict, so `meta` always honours its `dict` type.
        meta = data.get("meta")
        if meta is None:
            meta = {}

        component_info = data.get("component_info")
        tool_calls = data.get("tool_calls")
        tool_call_result = data.get("tool_call_result")
        reasoning = data.get("reasoning")

        return cls(
            content=data["content"],
            meta=meta,
            component_info=ComponentInfo.from_dict(component_info) if component_info else None,
            index=data.get("index"),
            tool_calls=[ToolCallDelta.from_dict(tc) for tc in tool_calls] if tool_calls else None,
            tool_call_result=ToolCallResult.from_dict(tool_call_result) if tool_call_result else None,
            start=data.get("start", False),
            finish_reason=data.get("finish_reason"),
            reasoning=ReasoningContent.from_dict(reasoning) if reasoning else None,
        )


SyncStreamingCallbackT = Callable[[StreamingChunk], None]
AsyncStreamingCallbackT = Callable[[StreamingChunk], Awaitable[None]]

StreamingCallbackT = SyncStreamingCallbackT | AsyncStreamingCallbackT


def _validate_streaming_callback(callback: StreamingCallbackT | None, requires_async: bool, label: str) -> None:
    """
    Checks that a streaming callback's async compatibility matches what the caller requires.

    :param callback: The callback to check. `None` is always accepted.
    :param requires_async: Whether the callback must be async compatible.
    :param label: Name of the callback ("init" or "runtime") used in the error message.
    :raises ValueError: If the callback's async compatibility does not match `requires_async`.
    """
    if callback is None:
        return
    is_async = is_callable_async_compatible(callback)
    if requires_async and not is_async:
        raise ValueError(f"The {label} callback must be async compatible.")
    if not requires_async and is_async:
        raise ValueError(f"The {label} callback cannot be a coroutine.")


@overload
def select_streaming_callback(
    init_callback: StreamingCallbackT | None,
    runtime_callback: StreamingCallbackT | None,
    requires_async: Literal[False],
) -> SyncStreamingCallbackT | None: ...
@overload
def select_streaming_callback(
    init_callback: StreamingCallbackT | None, runtime_callback: StreamingCallbackT | None, requires_async: Literal[True]
) -> AsyncStreamingCallbackT | None: ...


def select_streaming_callback(
    init_callback: StreamingCallbackT | None, runtime_callback: StreamingCallbackT | None, requires_async: bool
) -> StreamingCallbackT | None:
    """
    Picks the correct streaming callback given an optional initial and runtime callback.

    The runtime callback takes precedence over the initial callback.

    :param init_callback:
        The initial callback.
    :param runtime_callback:
        The runtime callback.
    :param requires_async:
        Whether the selected callback must be async compatible.
    :returns:
        The selected callback, or `None` if neither callback is given.
    :raises ValueError:
        If either provided callback does not match the required sync/async mode.
    """
    # Both callbacks are validated, even though only one is returned.
    _validate_streaming_callback(init_callback, requires_async, "init")
    _validate_streaming_callback(runtime_callback, requires_async, "runtime")

    # Explicit `None` check: precedence must not depend on the callable's truthiness.
    return runtime_callback if runtime_callback is not None else init_callback