test_streaming_chunk.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings

import pytest

from haystack import Pipeline, component
from haystack.dataclasses import (
    ComponentInfo,
    ReasoningContent,
    StreamingChunk,
    ToolCall,
    ToolCallDelta,
    ToolCallResult,
)


@component
class ExampleComponent:
    """Minimal component used as a fixture by the ComponentInfo tests below."""

    def __init__(self):
        self.name = "test_component"

    def run(self) -> str:
        return "Test content"


def test_create_chunk_with_content_and_metadata():
    """Content and meta passed to the constructor are exposed unchanged."""
    chunk = StreamingChunk(content="Test content", meta={"key": "value"})

    assert chunk.content == "Test content"
    assert chunk.meta == {"key": "value"}


def test_create_chunk_with_only_content():
    """Omitting meta yields an empty dict."""
    chunk = StreamingChunk(content="Test content")

    assert chunk.content == "Test content"
    assert chunk.meta == {}


def test_access_content():
    """The content attribute is readable after construction."""
    chunk = StreamingChunk(content="Test content", meta={"key": "value"})
    assert chunk.content == "Test content"


def test_create_chunk_with_empty_content():
    """An empty content string is accepted and meta still defaults to an empty dict."""
    chunk = StreamingChunk(content="")
    assert chunk.content == ""
    assert chunk.meta == {}


def test_create_chunk_with_all_fields():
    """Content, meta and component_info can be set together."""
    component_info = ComponentInfo(type="test.component", name="test_component")
    chunk = StreamingChunk(content="Test content", meta={"key": "value"}, component_info=component_info)

    assert chunk.content == "Test content"
    assert chunk.meta == {"key": "value"}
    assert chunk.component_info == component_info


def test_create_chunk_with_content_and_tool_call():
    """Combining non-empty content with tool_calls raises ValueError."""
    with pytest.raises(ValueError):
        # Can't have content + tool_call at the same time
        StreamingChunk(
            content="Test content",
            meta={"key": "value"},
            tool_calls=[ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)],
        )


def test_create_chunk_with_content_and_tool_call_result():
    """Combining non-empty content with tool_call_result raises ValueError."""
    with pytest.raises(ValueError):
        # Can't have content + tool_call_result at the same time
        StreamingChunk(
            content="Test content",
            meta={"key": "value"},
            tool_call_result=ToolCallResult(
                result="output",
                origin=ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"}),
                error=False,
            ),
        )


def test_create_chunk_with_content_and_reasoning():
    """Combining non-empty content with reasoning raises ValueError with the expected message."""
    with pytest.raises(ValueError, match="Only one of `content`, `tool_call`, `tool_call_result`"):
        StreamingChunk(
            content="Test content", meta={"key": "value"}, reasoning=ReasoningContent(reasoning_text="thinking")
        )


def test_reasoning_and_no_index():
    """Setting reasoning without an index raises ValueError with the expected message."""
    with pytest.raises(
        ValueError, match="If `tool_call`, `tool_call_result` or `reasoning` is set, `index` must also be set."
    ):
        StreamingChunk(content="", meta={"key": "value"}, reasoning=ReasoningContent(reasoning_text="thinking"))


def test_component_info_from_component():
    """ComponentInfo.from_component reports the component's module-qualified class name as its type."""
    component_info = ComponentInfo.from_component(ExampleComponent())
    assert component_info.type == "test_streaming_chunk.ExampleComponent"


def test_component_info_from_component_with_name_from_pipeline():
    """After the component is added to a Pipeline, from_component reports the name it was added under."""
    pipeline = Pipeline()
    comp = ExampleComponent()
    pipeline.add_component("pipeline_component", comp)
    component_info = ComponentInfo.from_component(comp)
    assert component_info.type == "test_streaming_chunk.ExampleComponent"
    assert component_info.name == "pipeline_component"


def test_tool_call_delta():
    """ToolCallDelta exposes the id, tool_name, arguments and index it was built with."""
    tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0)
    assert tool_call.id == "123"
    assert tool_call.tool_name == "test_tool"
    assert tool_call.arguments == '{"arg1": "value1"}'
    assert tool_call.index == 0


def test_create_chunk_with_finish_reason():
    """Test creating a chunk with the new finish_reason field."""
    chunk = StreamingChunk(content="Test content", finish_reason="stop")

    assert chunk.content == "Test content"
    assert chunk.finish_reason == "stop"
    assert chunk.meta == {}


def test_create_chunk_with_finish_reason_and_meta():
    """Test creating a chunk with both finish_reason field and meta."""
    chunk = StreamingChunk(
        content="Test content", finish_reason="stop", meta={"model": "gpt-4", "usage": {"tokens": 10}}
    )

    assert chunk.content == "Test content"
    assert chunk.finish_reason == "stop"
    assert chunk.meta["model"] == "gpt-4"
    assert chunk.meta["usage"]["tokens"] == 10


# Parametrized so that every value is reported individually instead of stopping at the first failure.
@pytest.mark.parametrize("value", ["stop", "length", "tool_calls", "content_filter", "tool_call_results"])
def test_finish_reason_standard_values(value):
    """Test all standard finish_reason values including the new Haystack-specific ones."""
    chunk = StreamingChunk(content="Test content", finish_reason=value)
    assert chunk.finish_reason == value


def test_finish_reason_tool_call_results():
    """Test specifically the new tool_call_results finish reason."""
    chunk = StreamingChunk(content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"})

    assert chunk.finish_reason == "tool_call_results"
    assert chunk.meta["finish_reason"] == "tool_call_results"
    assert chunk.content == ""


def test_to_dict_tool_call_result():
    """Test the to_dict method for StreamingChunk with tool_call_result."""
    component_info = ComponentInfo.from_component(ExampleComponent())
    tool_call_result = ToolCallResult(
        result="output", origin=ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"}), error=False
    )

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=component_info,
        tool_call_result=tool_call_result,
        finish_reason="tool_call_results",
    )

    d = chunk.to_dict()

    assert d["content"] == ""
    assert d["meta"] == {"key": "value"}
    assert d["index"] == 0
    assert d["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"
    assert d["tool_call_result"]["result"] == "output"
    assert d["tool_call_result"]["error"] is False
    assert d["tool_call_result"]["origin"]["id"] == "123"
    assert d["tool_call_result"]["origin"]["arguments"]["arg1"] == "value1"
    assert d["finish_reason"] == "tool_call_results"
    assert d["reasoning"] is None


def test_to_dict_tool_calls():
    """Test the to_dict method for StreamingChunk with tool_calls."""
    component_info = ComponentInfo.from_component(ExampleComponent())
    tool_calls = [
        ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}', index=0),
        ToolCallDelta(id="456", tool_name="another_tool", arguments='{"arg2": "value2"}', index=1),
    ]

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=component_info,
        tool_calls=tool_calls,
        finish_reason="tool_calls",
    )

    d = chunk.to_dict()

    assert d["content"] == ""
    assert d["meta"] == {"key": "value"}
    assert d["index"] == 0
    assert d["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"
    assert len(d["tool_calls"]) == 2
    assert d["tool_calls"][0]["id"] == "123"
    assert d["tool_calls"][0]["index"] == 0
    assert d["tool_calls"][1]["id"] == "456"
    assert d["tool_calls"][1]["index"] == 1
    assert d["finish_reason"] == "tool_calls"
    assert d["reasoning"] is None


def test_to_dict_reasoning():
    """Test the to_dict method for StreamingChunk with reasoning."""
    component_info = ComponentInfo.from_component(ExampleComponent())
    reasoning = ReasoningContent(reasoning_text="thinking", extra={"step": 1})

    chunk = StreamingChunk(
        content="",
        meta={"key": "value"},
        index=0,
        component_info=component_info,
        reasoning=reasoning,
        finish_reason="stop",
    )

    d = chunk.to_dict()

    assert d["content"] == ""
    assert d["meta"] == {"key": "value"}
    assert d["index"] == 0
    assert d["component_info"]["type"] == "test_streaming_chunk.ExampleComponent"
    assert d["reasoning"]["reasoning_text"] == "thinking"
    assert d["reasoning"]["extra"]["step"] == 1
    assert d["finish_reason"] == "stop"
    assert d["tool_calls"] is None
    assert d["tool_call_result"] is None


def test_from_dict_tool_call_result():
    """Test the from_dict method for StreamingChunk with tool_call_result."""
    component_info = {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"}
    tool_call_result = {
        "result": "output",
        "origin": {"id": "123", "tool_name": "test_tool", "arguments": {"arg1": "value1"}},
        "error": False,
    }

    data = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": component_info,
        "tool_call_result": tool_call_result,
        "finish_reason": "tool_call_results",
    }

    chunk = StreamingChunk.from_dict(data)

    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0
    assert chunk.component_info.type == "test_streaming_chunk.ExampleComponent"
    assert chunk.component_info.name == "test_component"
    assert chunk.tool_call_result.result == "output"
    assert chunk.tool_call_result.error is False
    assert chunk.tool_call_result.origin.id == "123"
    # The input carries a finish_reason; check it like the sibling from_dict tests do.
    assert chunk.finish_reason == "tool_call_results"
    assert chunk.reasoning is None


def test_from_dict_tool_calls():
    """Test the from_dict method for StreamingChunk with tool_calls."""
    component_info = {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"}
    tool_calls = [{"id": "123", "tool_name": "test_tool", "arguments": '{"arg1": "value1"}', "index": 0}]

    data = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": component_info,
        "tool_calls": tool_calls,
        "finish_reason": "tool_calls",
    }

    chunk = StreamingChunk.from_dict(data)

    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0
    assert chunk.component_info.type == "test_streaming_chunk.ExampleComponent"
    assert chunk.component_info.name == "test_component"
    assert chunk.tool_calls[0].tool_name == "test_tool"
    assert chunk.tool_calls[0].index == 0
    assert chunk.finish_reason == "tool_calls"
    assert chunk.reasoning is None


def test_from_dict_reasoning():
    """Test the from_dict method for StreamingChunk with reasoning."""
    component_info = {"type": "test_streaming_chunk.ExampleComponent", "name": "test_component"}
    reasoning = {"reasoning_text": "thinking", "extra": {"step": 1}}

    data = {
        "content": "",
        "meta": {"key": "value"},
        "index": 0,
        "component_info": component_info,
        "reasoning": reasoning,
        "finish_reason": "stop",
    }

    chunk = StreamingChunk.from_dict(data)

    assert chunk.content == ""
    assert chunk.meta == {"key": "value"}
    assert chunk.index == 0
    assert chunk.component_info.type == "test_streaming_chunk.ExampleComponent"
    assert chunk.component_info.name == "test_component"
    assert chunk.reasoning.reasoning_text == "thinking"
    assert chunk.reasoning.extra["step"] == 1
    assert chunk.finish_reason == "stop"
    assert chunk.tool_calls is None
    assert chunk.tool_call_result is None


def test_tool_call_delta_no_warning_on_init():
    """Constructing a ToolCallDelta emits no warning."""
    with warnings.catch_warnings():
        # Escalate every warning to an exception so any emitted warning fails the test.
        warnings.simplefilter("error", Warning)
        ToolCallDelta(index=0, tool_name="t")


def test_tool_call_delta_warn_on_inplace_mutation():
    """Assigning to a ToolCallDelta attribute emits a warning that mentions dataclasses.replace."""
    tcd = ToolCallDelta(index=0, tool_name="t")
    with pytest.warns(Warning, match="dataclasses.replace"):
        tcd.tool_name = "other"


def test_component_info_no_warning_on_init():
    """Constructing a ComponentInfo emits no warning."""
    with warnings.catch_warnings():
        # Escalate every warning to an exception so any emitted warning fails the test.
        warnings.simplefilter("error", Warning)
        ComponentInfo(type="test.component", name="my_component")


def test_component_info_warn_on_inplace_mutation():
    """Assigning to a ComponentInfo attribute emits a warning that mentions dataclasses.replace."""
    ci = ComponentInfo(type="test.component", name="my_component")
    with pytest.warns(Warning, match="dataclasses.replace"):
        ci.name = "other"


def test_streaming_chunk_no_warning_on_init():
    """Constructing a StreamingChunk emits no warning."""
    with warnings.catch_warnings():
        # Escalate every warning to an exception so any emitted warning fails the test.
        warnings.simplefilter("error", Warning)
        StreamingChunk(content="test")


def test_streaming_chunk_warn_on_inplace_mutation():
    """Assigning to a StreamingChunk attribute emits a warning that mentions dataclasses.replace."""
    chunk = StreamingChunk(content="test")
    with pytest.warns(Warning, match="dataclasses.replace"):
        chunk.content = "other"