/ haystack / dataclasses / streaming_chunk.py
streaming_chunk.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from collections.abc import Awaitable, Callable
  6  from dataclasses import asdict, dataclass, field
  7  from typing import Any, Literal, overload
  8  
  9  from haystack.core.component import Component
 10  from haystack.dataclasses.chat_message import ReasoningContent, ToolCallResult
 11  from haystack.utils.asynchronous import is_callable_async_compatible
 12  from haystack.utils.dataclasses import _warn_on_inplace_mutation
 13  
# Type alias for the accepted `finish_reason` values: the standard values following OpenAI's convention
# ("stop", "length", "tool_calls", "content_filter") plus the Haystack-specific value "tool_call_results".
FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "tool_call_results"]
 17  
 18  
 19  @_warn_on_inplace_mutation
 20  @dataclass
 21  class ToolCallDelta:
 22      """
 23      Represents a Tool call prepared by the model, usually contained in an assistant message.
 24  
 25      :param index: The index of the Tool call in the list of Tool calls.
 26      :param tool_name: The name of the Tool to call.
 27      :param arguments: Either the full arguments in JSON format or a delta of the arguments.
 28      :param id: The ID of the Tool call.
 29      :param extra: Dictionary of extra information about the Tool call. Use to store provider-specific
 30          information. To avoid serialization issues, values should be JSON serializable.
 31      """
 32  
 33      index: int
 34      tool_name: str | None = field(default=None)
 35      arguments: str | None = field(default=None)
 36      id: str | None = field(default=None)
 37      extra: dict[str, Any] | None = field(default=None)
 38  
 39      def to_dict(self) -> dict[str, Any]:
 40          """
 41          Returns a dictionary representation of the ToolCallDelta.
 42  
 43          :returns: A dictionary with keys 'index', 'tool_name', 'arguments', 'id', and 'extra'.
 44          """
 45          return asdict(self)
 46  
 47      @classmethod
 48      def from_dict(cls, data: dict[str, Any]) -> "ToolCallDelta":
 49          """
 50          Creates a ToolCallDelta from a serialized representation.
 51  
 52          :param data: Dictionary containing ToolCallDelta's attributes.
 53          :returns: A ToolCallDelta instance.
 54          """
 55          return ToolCallDelta(**data)
 56  
 57  
 58  @_warn_on_inplace_mutation
 59  @dataclass
 60  class ComponentInfo:
 61      """
 62      The `ComponentInfo` class encapsulates information about a component.
 63  
 64      :param type: The type of the component.
 65      :param name: The name of the component assigned when adding it to a pipeline.
 66  
 67      """
 68  
 69      type: str
 70      name: str | None = field(default=None)
 71  
 72      @classmethod
 73      def from_component(cls, component: Component) -> "ComponentInfo":
 74          """
 75          Create a `ComponentInfo` object from a `Component` instance.
 76  
 77          :param component:
 78              The `Component` instance.
 79          :returns:
 80              The `ComponentInfo` object with the type and name of the given component.
 81          """
 82          component_type = f"{component.__class__.__module__}.{component.__class__.__name__}"
 83          component_name = getattr(component, "__component_name__", None)
 84          return cls(type=component_type, name=component_name)
 85  
 86      def to_dict(self) -> dict[str, Any]:
 87          """
 88          Returns a dictionary representation of ComponentInfo.
 89  
 90          :returns: A dictionary with keys 'type' and 'name'.
 91          """
 92          return asdict(self)
 93  
 94      @classmethod
 95      def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo":
 96          """
 97          Creates a ComponentInfo from a serialized representation.
 98  
 99          :param data: Dictionary containing ComponentInfo's attributes.
100          :returns: A ComponentInfo instance.
101          """
102          return ComponentInfo(**data)
103  
104  
@_warn_on_inplace_mutation
@dataclass
class StreamingChunk:
    """
    The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata.

    This structure facilitates the handling and processing of streamed data in a systematic manner.

    At most one of `content`, `tool_calls`, `tool_call_result` and `reasoning` may be set on a single chunk.
    When `tool_calls`, `tool_call_result` or `reasoning` is set, `index` must be set as well.

    :param content: The content of the message chunk as a string.
    :param meta: A dictionary containing metadata related to the message chunk.
    :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
        such as the component name and type.
    :param index: An optional integer index representing which content block this chunk belongs to.
    :param tool_calls: An optional list of ToolCallDelta objects representing tool calls associated with the message
        chunk.
    :param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
    :param start: A boolean indicating whether this chunk marks the start of a content block.
    :param finish_reason: An optional value indicating the reason the generation finished.
        Standard values follow OpenAI's convention: "stop", "length", "tool_calls", "content_filter",
        plus Haystack-specific value "tool_call_results".
    :param reasoning: An optional ReasoningContent object representing the reasoning content associated
        with the message chunk.
    """

    content: str
    meta: dict[str, Any] = field(default_factory=dict, hash=False)
    component_info: ComponentInfo | None = field(default=None)
    index: int | None = field(default=None)
    tool_calls: list[ToolCallDelta] | None = field(default=None)
    tool_call_result: ToolCallResult | None = field(default=None)
    start: bool = field(default=False)
    finish_reason: FinishReason | None = field(default=None)
    reasoning: ReasoningContent | None = field(default=None)

    def __post_init__(self) -> None:
        """
        Validates the combination of payload fields.

        :raises ValueError: If more than one of `content`, `tool_calls`, `tool_call_result` and `reasoning` is set,
            or if `tool_calls`, `tool_call_result` or `reasoning` is set while `index` is None.
        """
        # A field counts as "set" when it is truthy, so an empty string or an empty list does not count.
        fields_set = sum(bool(x) for x in (self.content, self.tool_calls, self.tool_call_result, self.reasoning))
        if fields_set > 1:
            raise ValueError(
                "Only one of `content`, `tool_calls`, `tool_call_result` or `reasoning` may be set in a "
                "StreamingChunk. "
                f"Got content: '{self.content}', tool_calls: '{self.tool_calls}', "
                f"tool_call_result: '{self.tool_call_result}', reasoning: '{self.reasoning}'."
            )

        # NOTE: We don't enforce this for self.content otherwise it would be a breaking change
        if (self.tool_calls or self.tool_call_result or self.reasoning) and self.index is None:
            raise ValueError("If `tool_calls`, `tool_call_result` or `reasoning` is set, `index` must also be set.")

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a dictionary representation of the StreamingChunk.

        Nested objects are serialized through their own `to_dict` methods. `meta` is included as-is (the same
        dictionary object, not a copy). Unset or empty `tool_calls` are serialized as None.

        :returns: Serialized dictionary representation of the calling object.
        """
        return {
            "content": self.content,
            "meta": self.meta,
            "component_info": self.component_info.to_dict() if self.component_info else None,
            "index": self.index,
            "tool_calls": [tc.to_dict() for tc in self.tool_calls] if self.tool_calls else None,
            "tool_call_result": self.tool_call_result.to_dict() if self.tool_call_result else None,
            "start": self.start,
            "finish_reason": self.finish_reason,
            "reasoning": self.reasoning.to_dict() if self.reasoning else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "StreamingChunk":
        """
        Creates a deserialized StreamingChunk instance from a serialized representation.

        Only `content` is required. Missing optional keys fall back to the dataclass defaults, and nested
        entries are rebuilt through the `from_dict` method of their respective classes.

        :param data: Dictionary containing the StreamingChunk's attributes.
        :returns: A StreamingChunk instance (an instance of the subclass when called on a subclass).
        :raises ValueError: If `content` is missing from `data`.
        """
        if "content" not in data:
            raise ValueError("Missing required field `content` in StreamingChunk deserialization.")

        component_info = data.get("component_info")
        tool_calls = data.get("tool_calls")
        tool_call_result = data.get("tool_call_result")
        reasoning = data.get("reasoning")

        # Instantiate through `cls` so that subclasses deserialize into their own type.
        return cls(
            content=data["content"],
            # `or {}` also covers an explicit `"meta": None`, which would otherwise leave `meta` as None.
            meta=data.get("meta") or {},
            component_info=ComponentInfo.from_dict(component_info) if component_info else None,
            index=data.get("index"),
            tool_calls=[ToolCallDelta.from_dict(tc) for tc in tool_calls] if tool_calls else None,
            tool_call_result=ToolCallResult.from_dict(tool_call_result) if tool_call_result else None,
            start=data.get("start", False),
            finish_reason=data.get("finish_reason"),
            reasoning=ReasoningContent.from_dict(reasoning) if reasoning else None,
        )
194  
195  
# Synchronous streaming callback: called with a StreamingChunk and returns nothing.
SyncStreamingCallbackT = Callable[[StreamingChunk], None]
# Asynchronous streaming callback: called with a StreamingChunk and returns an awaitable.
AsyncStreamingCallbackT = Callable[[StreamingChunk], Awaitable[None]]

# Either kind of streaming callback.
StreamingCallbackT = SyncStreamingCallbackT | AsyncStreamingCallbackT
200  
201  
# Typing overload: with `requires_async=False` the return type is narrowed to a synchronous callback (or None).
@overload
def select_streaming_callback(
    init_callback: StreamingCallbackT | None,
    runtime_callback: StreamingCallbackT | None,
    requires_async: Literal[False],
) -> SyncStreamingCallbackT | None: ...
# Typing overload: with `requires_async=True` the return type is narrowed to an asynchronous callback (or None).
@overload
def select_streaming_callback(
    init_callback: StreamingCallbackT | None, runtime_callback: StreamingCallbackT | None, requires_async: Literal[True]
) -> AsyncStreamingCallbackT | None: ...
212  
213  
def select_streaming_callback(
    init_callback: StreamingCallbackT | None, runtime_callback: StreamingCallbackT | None, requires_async: bool
) -> StreamingCallbackT | None:
    """
    Picks the correct streaming callback given an optional initial and runtime callback.

    The runtime callback takes precedence over the initial callback. Both callbacks are validated against
    `requires_async` whenever they are provided, even the one that ends up not being selected.

    :param init_callback:
        The initial callback.
    :param runtime_callback:
        The runtime callback.
    :param requires_async:
        Whether the selected callback must be async compatible.
    :returns:
        The runtime callback if it is not None, otherwise the initial callback (which may itself be None).
    :raises ValueError:
        If a provided callback is not async compatible while `requires_async` is True, or is async compatible
        while `requires_async` is False.
    """
    # The init callback is checked first, so its error is the one raised when both are invalid.
    for label, callback in (("init", init_callback), ("runtime", runtime_callback)):
        if callback is None:
            continue
        async_compatible = is_callable_async_compatible(callback)
        if requires_async and not async_compatible:
            raise ValueError(f"The {label} callback must be async compatible.")
        if not requires_async and async_compatible:
            raise ValueError(f"The {label} callback cannot be a coroutine.")

    # Select by identity with None rather than by truthiness: a callable object can legitimately be falsy
    # (e.g. it defines `__len__` or `__bool__`) and must not be skipped because of that.
    return runtime_callback if runtime_callback is not None else init_callback