/ test / dataclasses / test_streaming_chunk.py
test_streaming_chunk.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import warnings
  6  
  7  import pytest
  8  
  9  from haystack import Pipeline, component
 10  from haystack.dataclasses import (
 11      ComponentInfo,
 12      ReasoningContent,
 13      StreamingChunk,
 14      ToolCall,
 15      ToolCallDelta,
 16      ToolCallResult,
 17  )
 18  
 19  
 20  @component
 21  class ExampleComponent:
 22      def __init__(self):
 23          self.name = "test_component"
 24  
 25      def run(self) -> str:
 26          return "Test content"
 27  
 28  
 29  def test_create_chunk_with_content_and_metadata():
 30      chunk = StreamingChunk(content="Test content", meta={"key": "value"})
 31  
 32      assert chunk.content == "Test content"
 33      assert chunk.meta == {"key": "value"}
 34  
 35  
 36  def test_create_chunk_with_only_content():
 37      chunk = StreamingChunk(content="Test content")
 38  
 39      assert chunk.content == "Test content"
 40      assert chunk.meta == {}
 41  
 42  
 43  def test_access_content():
 44      chunk = StreamingChunk(content="Test content", meta={"key": "value"})
 45      assert chunk.content == "Test content"
 46  
 47  
 48  def test_create_chunk_with_empty_content():
 49      chunk = StreamingChunk(content="")
 50      assert chunk.content == ""
 51      assert chunk.meta == {}
 52  
 53  
 54  def test_create_chunk_with_all_fields():
 55      component_info = ComponentInfo(type="test.component", name="test_component")
 56      chunk = StreamingChunk(content="Test content", meta={"key": "value"}, component_info=component_info)
 57  
 58      assert chunk.content == "Test content"
 59      assert chunk.meta == {"key": "value"}
 60      assert chunk.component_info == component_info
 61  
 62  
 63  def test_create_chunk_with_content_and_tool_call():
 64      with pytest.raises(ValueError):
 65          # Can't have content + tool_call at the same time
 66          StreamingChunk(
 67              content="Test content",
 68              meta={"key": "value"},
 69              tool_calls=[ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)],
 70          )
 71  
 72  
 73  def test_create_chunk_with_content_and_tool_call_result():
 74      with pytest.raises(ValueError):
 75          # Can't have content + tool_call_result at the same time
 76          StreamingChunk(
 77              content="Test content",
 78              meta={"key": "value"},
 79              tool_call_result=ToolCallResult(
 80                  result="output",
 81                  origin=ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"}),
 82                  error=False,
 83              ),
 84          )
 85  
 86  
 87  def test_create_chunk_with_content_and_reasoning():
 88      with pytest.raises(ValueError, match="Only one of `content`, `tool_call`, `tool_call_result`"):
 89          StreamingChunk(
 90              content="Test content", meta={"key": "value"}, reasoning=ReasoningContent(reasoning_text="thinking")
 91          )
 92  
 93  
 94  def test_reasoning_and_no_index():
 95      with pytest.raises(
 96          ValueError, match="If `tool_call`, `tool_call_result` or `reasoning` is set, `index` must also be set."
 97      ):
 98          StreamingChunk(content="", meta={"key": "value"}, reasoning=ReasoningContent(reasoning_text="thinking"))
 99  
100  
def test_component_info_from_component():
    """``from_component`` reports the component's dotted module/class path as ``type``."""
    info = ComponentInfo.from_component(ExampleComponent())
    expected_type = "test_streaming_chunk.ExampleComponent"
    assert info.type == expected_type
104  
105  
def test_component_info_from_component_with_name_from_pipeline():
    """After ``Pipeline.add_component``, ``from_component`` reports the pipeline-assigned name."""
    pipe = Pipeline()
    example = ExampleComponent()
    pipe.add_component("pipeline_component", example)

    info = ComponentInfo.from_component(example)

    assert info.type == "test_streaming_chunk.ExampleComponent"
    assert info.name == "pipeline_component"
113  
114  
def test_tool_call_delta():
    """ToolCallDelta stores every field it is given."""
    delta = ToolCallDelta(index=0, id="123", tool_name="test_tool", arguments='{"arg1": "value1"}')
    assert (delta.id, delta.tool_name, delta.arguments, delta.index) == (
        "123",
        "test_tool",
        '{"arg1": "value1"}',
        0,
    )
121  
122  
def test_create_chunk_with_finish_reason():
    """``finish_reason`` is stored on its own field and leaves ``meta`` empty."""
    streaming_chunk = StreamingChunk(finish_reason="stop", content="Test content")

    assert streaming_chunk.finish_reason == "stop"
    assert streaming_chunk.content == "Test content"
    assert streaming_chunk.meta == {}
130  
131  
def test_create_chunk_with_finish_reason_and_meta():
    """``finish_reason`` and arbitrary nested ``meta`` can be set together."""
    metadata = {"model": "gpt-4", "usage": {"tokens": 10}}
    streaming_chunk = StreamingChunk(content="Test content", finish_reason="stop", meta=metadata)

    assert streaming_chunk.content == "Test content"
    assert streaming_chunk.finish_reason == "stop"
    stored = streaming_chunk.meta
    assert stored["model"] == "gpt-4"
    assert stored["usage"]["tokens"] == 10
142  
143  
@pytest.mark.parametrize("value", ["stop", "length", "tool_calls", "content_filter", "tool_call_results"])
def test_finish_reason_standard_values(value):
    """Each standard finish_reason value, including the Haystack-specific ``tool_call_results``, is stored as given.

    Parametrized rather than looped so every value is reported as its own test case and
    one failing value does not hide the others.
    """
    chunk = StreamingChunk(content="Test content", finish_reason=value)
    assert chunk.finish_reason == value
151  
152  
def test_finish_reason_tool_call_results():
    """The Haystack-specific ``tool_call_results`` finish reason is accepted alongside a meta copy."""
    reason = "tool_call_results"
    streaming_chunk = StreamingChunk(content="", finish_reason=reason, meta={"finish_reason": reason})

    assert streaming_chunk.finish_reason == reason
    assert streaming_chunk.meta["finish_reason"] == reason
    assert streaming_chunk.content == ""
160  
161  
def test_to_dict_tool_call_result():
    """Test the to_dict method for StreamingChunk with tool_call_result.

    Besides the populated fields, the unused ``tool_calls`` and ``reasoning`` slots must
    serialize as ``None``.
    """
    component_info = ComponentInfo.from_component(ExampleComponent())
    tool_call_result = ToolCallResult(
        result="output", origin=ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"}), error=False
    )

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=component_info,
        tool_call_result=tool_call_result,
        finish_reason="tool_call_results",
    )

    d = chunk.to_dict()

    assert d["content"] == ""
    assert d["meta"] == {"key": "value"}
    assert d["index"] == 0
    assert d["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"
    assert d["tool_call_result"]["result"] == "output"
    assert d["tool_call_result"]["error"] is False
    assert d["tool_call_result"]["origin"]["id"] == "123"
    assert d["tool_call_result"]["origin"]["tool_name"] == "test_tool"
    assert d["tool_call_result"]["origin"]["arguments"]["arg1"] == "value1"
    assert d["finish_reason"] == "tool_call_results"
    # Fields that were not set must still be present and serialized as None.
    assert d["tool_calls"] is None
    assert d["reasoning"] is None
190  
191  
def test_to_dict_tool_calls():
    """Test the to_dict method for StreamingChunk with tool_calls.

    Checks that every ToolCallDelta is serialized with its own fields and that the unused
    ``tool_call_result`` and ``reasoning`` slots serialize as ``None``.
    """
    component_info = ComponentInfo.from_component(ExampleComponent())
    tool_calls = [
        ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0),
        ToolCallDelta(id="456", tool_name="another_tool", arguments='{"arg2": "value2"}', index=1),
    ]

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=component_info,
        tool_calls=tool_calls,
        finish_reason="tool_calls",
    )

    d = chunk.to_dict()

    assert d["content"] == ""
    assert d["meta"] == {"key": "value"}
    assert d["index"] == 0
    assert d["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"
    assert len(d["tool_calls"]) == 2
    assert d["tool_calls"][0]["id"] == "123"
    assert d["tool_calls"][0]["tool_name"] == "test_tool"
    assert d["tool_calls"][0]["arguments"] == '{"arg1": "value1"}'
    assert d["tool_calls"][0]["index"] == 0
    assert d["tool_calls"][1]["id"] == "456"
    assert d["tool_calls"][1]["tool_name"] == "another_tool"
    assert d["tool_calls"][1]["index"] == 1
    assert d["finish_reason"] == "tool_calls"
    # Fields that were not set must still be present and serialized as None.
    assert d["tool_call_result"] is None
    assert d["reasoning"] is None
222  
223  
def test_to_dict_reasoning():
    """Serializing a reasoning chunk keeps the reasoning payload and leaves tool fields as None."""
    info = ComponentInfo.from_component(ExampleComponent())
    streaming_chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=info,
        reasoning=ReasoningContent(reasoning_text="thinking", extra={"step": 1}),
        finish_reason="stop",
    )

    serialized = streaming_chunk.to_dict()

    assert serialized["content"] == ""
    assert serialized["meta"] == {"key": "value"}
    assert serialized["index"] == 0
    assert serialized["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"

    reasoning_dict = serialized["reasoning"]
    assert reasoning_dict["reasoning_text"] == "thinking"
    assert reasoning_dict["extra"]["step"] == 1

    assert serialized["finish_reason"] == "stop"
    assert serialized["tool_calls"] is None
    assert serialized["tool_call_result"] is None
249  
250  
def test_from_dict_tool_call_result():
    """Test the from_dict method for StreamingChunk with tool_call_result.

    Checks the nested ToolCallResult/ToolCall round-trip and the ``finish_reason``. It also
    checks that fields absent from the payload (``tool_calls``, ``reasoning``) come back as
    ``None``.
    """
    component_info = {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"}
    tool_call_result = {
        "result": "output",
        "origin": {"id": "123", "tool_name": "test_tool", "arguments": {"arg1": "value1"}},
        "error": False,
    }

    data = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": component_info,
        "tool_call_result": tool_call_result,
        "finish_reason": "tool_call_results",
    }

    chunk = StreamingChunk.from_dict(data)

    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0
    assert chunk.component_info.type == "test_streaming_chunk.ExampleComponent"
    assert chunk.component_info.name == "test_component"
    assert chunk.tool_call_result.result == "output"
    assert chunk.tool_call_result.error is False
    assert chunk.tool_call_result.origin.id == "123"
    assert chunk.tool_call_result.origin.tool_name == "test_tool"
    assert chunk.tool_call_result.origin.arguments == {"arg1": "value1"}
    assert chunk.finish_reason == "tool_call_results"
    assert chunk.tool_calls is None
    assert chunk.reasoning is None
280  
281  
def test_from_dict_tool_calls():
    """Test the from_dict method for StreamingChunk with tool_calls.

    Checks that each tool-call dict is rebuilt with all of its fields. It also checks that
    fields absent from the payload (``tool_call_result``, ``reasoning``) come back as ``None``.
    """
    component_info = {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"}
    tool_calls = [{"id": "123", "tool_name": "test_tool", "arguments": '{"arg1": "value1"}', "index": 0}]

    data = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": component_info,
        "tool_calls": tool_calls,
        "finish_reason": "tool_calls",
    }

    chunk = StreamingChunk.from_dict(data)

    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0
    assert chunk.component_info.type == "test_streaming_chunk.ExampleComponent"
    assert chunk.component_info.name == "test_component"
    assert len(chunk.tool_calls) == 1
    assert chunk.tool_calls[0].id == "123"
    assert chunk.tool_calls[0].tool_name == "test_tool"
    assert chunk.tool_calls[0].arguments == '{"arg1": "value1"}'
    assert chunk.tool_calls[0].index == 0
    assert chunk.finish_reason == "tool_calls"
    assert chunk.tool_call_result is None
    assert chunk.reasoning is None
307  
308  
def test_from_dict_reasoning():
    """Deserializing a reasoning chunk rebuilds ReasoningContent and leaves tool fields as None."""
    payload = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"},
        "reasoning": {"reasoning_text": "thinking", "extra": {"step": 1}},
        "finish_reason": "stop",
    }

    restored = StreamingChunk.from_dict(payload)

    assert restored.content == ""
    assert restored.meta == {"key": "value"}
    assert restored.index == 0

    info = restored.component_info
    assert info.type == "test_streaming_chunk.ExampleComponent"
    assert info.name == "test_component"

    assert restored.reasoning.reasoning_text == "thinking"
    assert restored.reasoning.extra["step"] == 1
    assert restored.finish_reason == "stop"
    assert restored.tool_calls is None
    assert restored.tool_call_result is None
335  
336  
def test_tool_call_delta_no_warning_on_init():
    """Constructing a ToolCallDelta must not emit any warning."""
    with warnings.catch_warnings():
        # Promote every warning to an exception so any emission fails the test.
        warnings.simplefilter("error")
        ToolCallDelta(tool_name="t", index=0)
341  
342  
def test_tool_call_delta_warn_on_inplace_mutation():
    """Assigning to a ToolCallDelta field warns with a message mentioning ``dataclasses.replace``."""
    delta = ToolCallDelta(tool_name="t", index=0)
    with pytest.warns(Warning, match="dataclasses.replace"):
        delta.tool_name = "other"
347  
348  
def test_component_info_no_warning_on_init():
    """Constructing a ComponentInfo must not emit any warning."""
    with warnings.catch_warnings():
        # Promote every warning to an exception so any emission fails the test.
        warnings.simplefilter("error")
        ComponentInfo(name="my_component", type="test.component")
353  
354  
def test_component_info_warn_on_inplace_mutation():
    """Assigning to a ComponentInfo field warns with a message mentioning ``dataclasses.replace``."""
    info = ComponentInfo(name="my_component", type="test.component")
    with pytest.warns(Warning, match="dataclasses.replace"):
        info.name = "other"
359  
360  
def test_streaming_chunk_no_warning_on_init():
    """Constructing a StreamingChunk must not emit any warning."""
    with warnings.catch_warnings():
        # Promote every warning to an exception so any emission fails the test.
        warnings.simplefilter("error")
        StreamingChunk(content="test")
365  
366  
def test_streaming_chunk_warn_on_inplace_mutation():
    """Assigning to a StreamingChunk field warns with a message mentioning ``dataclasses.replace``."""
    streaming_chunk = StreamingChunk(content="test")
    with pytest.warns(Warning, match="dataclasses.replace"):
        streaming_chunk.content = "other"