test_hugging_face_api.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from huggingface_hub import (
    ChatCompletionInputStreamOptions,
    ChatCompletionOutput,
    ChatCompletionOutputComplete,
    ChatCompletionOutputFunctionDefinition,
    ChatCompletionOutputMessage,
    ChatCompletionOutputToolCall,
    ChatCompletionOutputUsage,
    ChatCompletionStreamOutput,
    ChatCompletionStreamOutputChoice,
    ChatCompletionStreamOutputDelta,
    ChatCompletionStreamOutputUsage,
)
from huggingface_hub.errors import RepositoryNotFoundError

from haystack import Pipeline
from haystack.components.generators.chat.hugging_face_api import (
    HuggingFaceAPIChatGenerator,
    _convert_chat_completion_stream_output_to_streaming_chunk,
    _convert_hfapi_tool_calls,
    _convert_tools_to_hfapi_tools,
    _resolve_schema_refs,
)
from haystack.dataclasses import ChatMessage, ImageContent, ReasoningContent, StreamingChunk, ToolCall
from haystack.tools import Tool
from haystack.tools.toolset import Toolset
from haystack.utils.auth import Secret
from haystack.utils.hf import HFGenerationAPIType


@pytest.fixture
def chat_messages():
    return [
        ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
        ChatMessage.from_user("Tell me about Berlin"),
    ]


def get_weather(city: str) -> dict[str, Any]:
    weather_info = {
        "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    return weather_info.get(city, {"weather": "unknown", "temperature": 0, "unit": "celsius"})


@pytest.fixture
def tools():
    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )
    return [weather_tool]


@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


@pytest.fixture
def mock_chat_completion():
    # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example

    with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputComplete(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
            created=1710498360,
        )

        mock_chat_completion.return_value = completion
        yield mock_chat_completion


@pytest.fixture
def mock_chat_completion_async():
    with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion:
        completion = ChatCompletionOutput(
            choices=[
                ChatCompletionOutputComplete(
                    finish_reason="eos_token",
                    index=0,
                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
            created=1710498360,
        )

        # Use AsyncMock to properly mock the async method
        mock_chat_completion.return_value = completion
        mock_chat_completion.__call__ = AsyncMock(return_value=completion)

        yield mock_chat_completion


# used to test serialization of streaming_callback
def streaming_callback_handler(x):
    return x


class TestHuggingFaceAPIChatGenerator:
    def test_init_invalid_api_type(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, "stop": ["stop"], "max_tokens": 512}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        # check that client and async_client are initialized
        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
        model = "HuggingFaceH4/zephyr-7b-alpha"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": model},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
            tools=tools,
        )

        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator.api_params == {"model": model}
        assert generator.generation_kwargs == {**generation_kwargs, "stop": ["stop"], "max_tokens": 512}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools == tools

        assert generator._client.model == model
        assert generator._async_client.model == model

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id", response=MagicMock())
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tgi(self):
        url = "https://some_model.com"
        generation_kwargs = {"temperature": 0.6}
        stop_words = ["stop"]
        streaming_callback = None

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE,
            api_params={"url": url},
            token=None,
            generation_kwargs=generation_kwargs,
            stop_words=stop_words,
            streaming_callback=streaming_callback,
        )

        assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE
        assert generator.api_params == {"url": url}
        assert generator.generation_kwargs == {**generation_kwargs, "stop": ["stop"], "max_tokens": 512}
        assert generator.streaming_callback == streaming_callback
        assert generator.tools is None

        assert generator._client.model == url
        assert generator._async_client.model == url

    def test_init_tgi_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tgi_no_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools):
        duplicate_tools = [tools[0], tools[0]]
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "irrelevant"},
                tools=duplicate_tools,
            )

    def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools):
        with pytest.raises(ValueError):
            HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "irrelevant"},
                tools=tools,
                streaming_callback=streaming_callback_handler,
            )

    def test_to_dict(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )

        result = generator.to_dict()
        init_params = result["init_parameters"]

        assert init_params["api_type"] == "serverless_inference_api"
        assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert init_params["streaming_callback"] is None
        assert init_params["tools"] == [
            {
                "type": "haystack.tools.tool.Tool",
                "data": {
                    "description": "description",
                    "function": "builtins.print",
                    "inputs_from_state": None,
                    "name": "name",
                    "outputs_to_state": None,
                    "outputs_to_string": None,
                    "parameters": {"x": {"type": "string"}},
                },
            }
        ]

    def test_from_dict(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )
        result = generator.to_dict()

        # now deserialize, call from_dict
        generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
        assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
        assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
        assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
        assert generator_2.streaming_callback is None
        assert generator_2.tools == [tool]

    def test_serde_in_pipeline(self, mock_check_valid_model):
        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            token=Secret.from_env_var("ENV_VAR", strict=False),
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            tools=[tool],
        )

        pipeline = Pipeline()
        pipeline.add_component("generator", generator)

        pipeline_dict = pipeline.to_dict()
        assert pipeline_dict == {
            "metadata": {},
            "max_runs_per_component": 100,
            "connection_type_validation": True,
            "components": {
                "generator": {
                    "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
                    "init_parameters": {
                        "api_type": "serverless_inference_api",
                        "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"},
                        "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False},
                        "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512},
                        "streaming_callback": None,
                        "tools": [
                            {
                                "type": "haystack.tools.tool.Tool",
                                "data": {
                                    "inputs_from_state": None,
                                    "name": "name",
                                    "outputs_to_state": None,
                                    "outputs_to_string": None,
                                    "description": "description",
                                    "parameters": {"x": {"type": "string"}},
                                    "function": "builtins.print",
                                },
                            }
                        ],
                    },
                }
            },
            "connections": [],
        }

        pipeline_yaml = pipeline.dumps()

        new_pipeline = Pipeline.loads(pipeline_yaml)
        assert new_pipeline == pipeline

    def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = generator.run(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        hf_messages = [
            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
            {"role": "user", "content": "Tell me about Berlin"},
        ]
        assert kwargs == {
            "temperature": 0.6,
            "stop": ["stop", "words"],
            "max_tokens": 512,
            "tools": None,
            "messages": hf_messages,
        }

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(__iter__=mock_iter)
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_streaming_callback_in_run_method(
        self, mock_check_valid_model, mock_chat_completion, chat_messages
    ):
        streaming_call_count = 0

        # Define the streaming callback function
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
        )

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(__iter__=mock_iter)
        mock_chat_completion.return_value = mock_response

        # Generate text response with streaming callback
        response = generator.run(chat_messages, streaming_callback=streaming_callback_fn)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
        component = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_handler,
        )

        with pytest.raises(ValueError):
            message = ChatMessage.from_user("irrelevant")
            component.run([message], tools=tools)

    def test_run_with_tools(self, mock_check_valid_model, tools):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
            tools=tools,
        )

        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            role="assistant",
                            content=None,
                            tool_calls=[
                                ChatCompletionOutputToolCall(
                                    function=ChatCompletionOutputFunctionDefinition(
                                        arguments={"city": "Paris"}, name="weather", description=None
                                    ),
                                    id="0",
                                    type="function",
                                )
                            ],
                        ),
                        logprobs=None,
                    )
                ],
                created=1729074760,
                id="",
                model="meta-llama/Llama-3.1-70B-Instruct",
                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
            )
            mock_chat_completion.return_value = completion

            messages = [ChatMessage.from_user("What is the weather in Paris?")]
            response = generator.run(messages=messages)

            assert isinstance(response, dict)
            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) == 1
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].tool_calls[0].tool_name == "weather"
            assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
            assert response["replies"][0].tool_calls[0].id == "0"
            assert response["replies"][0].meta == {
                "finish_reason": "tool_calls",
                "index": 0,
                "model": "meta-llama/Llama-3.1-70B-Instruct",
                "usage": {"completion_tokens": 30, "prompt_tokens": 426},
            }

    def test_convert_hfapi_tool_calls_empty(self):
        hfapi_tool_calls = None
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

        hfapi_tool_calls = []
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    def test_convert_hfapi_tool_calls_dict_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments={"city": "Paris"}, name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "weather"
        assert tool_calls[0].arguments == {"city": "Paris"}
        assert tool_calls[0].id == "0"

    def test_convert_hfapi_tool_calls_str_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments='{"city": "Paris"}', name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "weather"
        assert tool_calls[0].arguments == {"city": "Paris"}
        assert tool_calls[0].id == "0"

    def test_convert_hfapi_tool_calls_invalid_str_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments="not a valid JSON string", name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    def test_convert_hfapi_tool_calls_invalid_type_arguments(self):
        hfapi_tool_calls = [
            ChatCompletionOutputToolCall(
                function=ChatCompletionOutputFunctionDefinition(
                    arguments=["this", "is", "a", "list"], name="weather", description=None
                ),
                id="0",
                type="function",
            )
        ]
        tool_calls = _convert_hfapi_tool_calls(hfapi_tool_calls)
        assert len(tool_calls) == 0

    @pytest.mark.parametrize(
        "hf_stream_output, expected_stream_chunk, dummy_previous_chunks",
        [
            (
                ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=" Paris"), index=0
                        )
                    ],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                ),
                StreamingChunk(
                    content=" Paris",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "finish_reason": None,
                    },
                    index=0,
                    start=True,
                ),
                [],
            ),
            (
                ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role="assistant", content=""),
                            index=0,
                            finish_reason="stop",
                        )
                    ],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                ),
                StreamingChunk(
                    content="",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "finish_reason": "stop",
                    },
                    finish_reason="stop",
                ),
                [0],
            ),
            (
                ChatCompletionStreamOutput(
                    choices=[],
                    created=1748339326,
                    id="",
                    model="microsoft/Phi-3.5-mini-instruct",
                    system_fingerprint="3.2.1-sha-4d28897",
                    usage=ChatCompletionStreamOutputUsage(completion_tokens=2, prompt_tokens=21, total_tokens=23),
                ),
                StreamingChunk(
                    content="",
                    meta={
                        "received_at": "2025-05-27T12:14:28.228852",
                        "model": "microsoft/Phi-3.5-mini-instruct",
                        "usage": {"completion_tokens": 2, "prompt_tokens": 21},
                    },
                ),
                [0, 1],
            ),
        ],
    )
    def test_convert_chat_completion_stream_output_to_streaming_chunk(
        self, hf_stream_output, expected_stream_chunk, dummy_previous_chunks
    ):
        converted_stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(
            chunk=hf_stream_output, previous_chunks=dummy_previous_chunks
        )
        # Remove timestamp from comparison since it's always the current time
        converted_stream_chunk.meta.pop("received_at", None)
        expected_stream_chunk.meta.pop("received_at", None)
        assert converted_stream_chunk == expected_stream_chunk

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=2, reruns_delay=10)
    def test_live_run_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen2.5-7B-Instruct", "provider": "together"},
            generation_kwargs={"max_tokens": 20},
        )

        # No need for instruction tokens here since we use the chat_completion endpoint which handles the chat
        # templating for us.
        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
        ]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert response["replies"][0].text is not None
        meta = response["replies"][0].meta
        assert "usage" in meta
        assert "prompt_tokens" in meta["usage"]
        assert meta["usage"]["prompt_tokens"] > 0
        assert "completion_tokens" in meta["usage"]
        assert meta["usage"]["completion_tokens"] > 0
        assert meta["model"] == "Qwen/Qwen2.5-7B-Instruct"
        assert meta["finish_reason"] is not None

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=2, reruns_delay=10)
    def test_live_run_serverless_streaming(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen2.5-7B-Instruct", "provider": "together"},
            generation_kwargs={"max_tokens": 20},
            streaming_callback=streaming_callback_handler,
        )

        # No need for instruction tokens here since we use the chat_completion endpoint which handles the chat
        # templating for us.
        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
        ]
        response = generator.run(messages=messages)

        print(response)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
        assert response["replies"][0].text is not None

        response_meta = response["replies"][0].meta
        assert "completion_start_time" in response_meta
        assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now()
        assert "usage" in response_meta
        assert "prompt_tokens" in response_meta["usage"]
        assert response_meta["usage"]["prompt_tokens"] >= 0
        assert "completion_tokens" in response_meta["usage"]
        assert response_meta["usage"]["completion_tokens"] >= 0
        # internally, Together calls this "Qwen/Qwen2.5-7B-Instruct-Turbo"
        assert "Qwen/Qwen2.5-7B-Instruct" in response_meta["model"]
        assert response_meta["finish_reason"] is not None

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    def test_live_run_with_tools(self, tools):
        """
        We test the round trip: generate tool call, pass tool message, generate response.

        The model used here (Qwen/Qwen3-VL-30B-A3B-Instruct) is not gated and kept in a warm state.
        """

        chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen3-VL-30B-A3B-Instruct", "provider": "fireworks-ai"},
            generation_kwargs={"temperature": 0.5},
        )

        results = generator.run(chat_messages, tools=tools)
        assert len(results["replies"]) == 1
        message = results["replies"][0]

        assert message.tool_calls
        tool_call = message.tool_call
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"
        assert "city" in tool_call.arguments
        assert "Paris" in tool_call.arguments["city"]
        assert message.meta["finish_reason"] == "tool_calls"

        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]

        # the model tends to make tool calls if provided with tools, so we don't pass them here
        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})

        assert len(results["replies"]) == 1
        final_message = results["replies"][0]
        assert not final_message.tool_calls
        assert len(final_message.text) > 0
        assert "paris" in final_message.text.lower() and "22" in final_message.text

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    def test_live_run_multimodal(self, test_files_path):
        image_path = test_files_path / "images" / "apple.jpg"
        # Resize the image to keep this test fast
        image_content = ImageContent.from_file_path(file_path=image_path, size=(100, 100))
        messages = [ChatMessage.from_user(content_parts=["What does this image show? Max 5 words", image_content])]

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen3-VL-30B-A3B-Instruct", "provider": "fireworks-ai"},
            generation_kwargs={"max_tokens": 20},
        )

        response = generator.run(messages=messages)

        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        message = response["replies"][0]
        assert message.text
        assert len(message.text) > 0
        assert any(word in message.text.lower() for word in ["apple", "fruit", "red"])

    @pytest.mark.asyncio
    async def test_run_async(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            generation_kwargs={"temperature": 0.6},
            stop_words=["stop", "words"],
            streaming_callback=None,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        hf_messages = [
            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
            {"role": "user", "content": "Tell me about Berlin"},
        ]
        assert kwargs == {
            "temperature": 0.6,
            "stop": ["stop", "words"],
            "max_tokens": 512,
            "tools": None,
            "messages": hf_messages,
        }

        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_streaming(self, mock_check_valid_model, mock_chat_completion_async, chat_messages):
        streaming_call_count = 0

        async def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        # Create a fake streamed response
        async def mock_aiter(self):
            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
                        index=0,
                        finish_reason=None,
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

            yield ChatCompletionStreamOutput(
                choices=[
                    ChatCompletionStreamOutputChoice(
                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                created=1710498504,
            )

        mock_response = Mock(__aiter__=mock_aiter)
        mock_chat_completion_async.return_value = mock_response

        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            streaming_callback=streaming_callback_fn,
        )

        response = await generator.run_async(messages=chat_messages)

        # check kwargs passed to chat_completion
        _, kwargs = mock_chat_completion_async.call_args
        assert kwargs == {
            "stop": [],
            "stream": True,
            "max_tokens": 512,
            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
        }

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.asyncio
    async def test_run_async_with_tools(self, tools, mock_check_valid_model):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
            tools=tools,
        )

        with patch("huggingface_hub.AsyncInferenceClient.chat_completion", autospec=True) as mock_chat_completion_async:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            role="assistant",
                            content=None,
                            tool_calls=[
                                ChatCompletionOutputToolCall(
                                    function=ChatCompletionOutputFunctionDefinition(
                                        arguments={"city": "Paris"}, name="weather", description=None
                                    ),
                                    id="0",
                                    type="function",
                                )
                            ],
                        ),
                        logprobs=None,
                    )
                ],
                created=1729074760,
                id="",
                model="meta-llama/Llama-3.1-70B-Instruct",
                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
            )
            mock_chat_completion_async.return_value = completion

            messages = [ChatMessage.from_user("What is the weather in Paris?")]
            response = await generator.run_async(messages=messages)

            assert isinstance(response, dict)
            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) == 1
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].tool_calls[0].tool_name == "weather"
            assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
            assert response["replies"][0].tool_calls[0].id == "0"
            assert response["replies"][0].meta == {
                "finish_reason": "tool_calls",
                "index": 0,
                "model": "meta-llama/Llama-3.1-70B-Instruct",
                "usage": {"completion_tokens": 30, "prompt_tokens": 426},
            }

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=2, reruns_delay=10)
    @pytest.mark.asyncio
    async def test_live_run_async_serverless(self):
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "Qwen/Qwen2.5-7B-Instruct", "provider": "together"},
            generation_kwargs={"max_tokens": 20},
        )

        messages = [
            ChatMessage.from_user("What is the capital of France? Be concise only provide the capital, nothing else.")
        ]
        try:
            response = await generator.run_async(messages=messages)

            assert "replies" in response
            assert isinstance(response["replies"], list)
            assert len(response["replies"]) > 0
            assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
            assert response["replies"][0].text is not None

            meta = response["replies"][0].meta
            assert "usage" in meta
            assert "prompt_tokens" in meta["usage"]
            assert meta["usage"]["prompt_tokens"] > 0
            assert "completion_tokens" in meta["usage"]
            assert meta["usage"]["completion_tokens"] > 0
            assert meta["model"] == "Qwen/Qwen2.5-7B-Instruct"
            assert meta["finish_reason"] is not None
        finally:
            await generator._async_client.close()

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.flaky(reruns=2, reruns_delay=10)
    def test_live_run_multi_turn_with_reasoning_model(self):
        """
        Test multi-turn conversation with a reasoning model.

        This test verifies that:
        1. Reasoning content is captured from the model's response
        2. When the assistant message (with reasoning) is sent back in a multi-turn conversation,
           the API call succeeds (reasoning is dropped during conversion since HF API doesn't support it)
        """
        # Note: Using a model that supports reasoning AND a provider that actually follows the spec defined in
        # huggingface-hub. Reasoning content especially seems to be non-standard across providers and is either left
        # in the main response or put in a new field that is not part of the official API.
        # One combo that does respect the spec is together + openai/gpt-oss-20b, which returns reasoning content
        # in the expected field of the response.
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "openai/gpt-oss-20b", "provider": "together"},
            generation_kwargs={"max_tokens": 300},
        )

        # First turn: ask a question
        messages = [ChatMessage.from_user("What is 2 + 2? Answer briefly.")]
        response = generator.run(messages=messages)

        assert "replies" in response
        assert len(response["replies"]) > 0
        first_reply = response["replies"][0]
        assert first_reply.text is not None
        assert first_reply.reasoning is not None

        # Second turn: send a follow-up including the assistant's previous response
        # This tests that convert_message_to_hf_format properly handles messages
        # that may contain ReasoningContent (it should skip it)
        follow_up_messages = [
            ChatMessage.from_user("What is 2 + 2? Answer briefly."),
            first_reply,  # Include the assistant's response with reasoning
            ChatMessage.from_user("Now what is 3 + 3? Answer briefly."),
        ]
        follow_up_response = generator.run(messages=follow_up_messages)

        # Verify the second turn succeeds
        assert "replies" in follow_up_response
        assert len(follow_up_response["replies"]) > 0
        assert follow_up_response["replies"][0].text is not None
        assert follow_up_response["replies"][0].reasoning is not None

    def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
        toolset = Toolset(tools)
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        assert generator.tools == toolset

    def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
        toolset = Toolset(tools)
        component = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        data = component.to_dict()

        deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)

        assert isinstance(deserialized_component.tools, Toolset)
        assert len(deserialized_component.tools) == len(tools)
        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)

    def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
        """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
        toolset = Toolset(tools[:1])
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
        )
        data = generator.to_dict()

        expected_tools_data = {
            "type": "haystack.tools.toolset.Toolset",
            "data": {
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather",
                            "description": "useful to determine the weather in a given location",
                            "parameters": {
                                "type": "object",
                                "properties": {"city": {"type": "string"}},
                                "required": ["city"],
                            },
                            "function": "generators.chat.test_hugging_face_api.get_weather",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    }
                ]
            },
        }
        assert data["init_parameters"]["tools"] == expected_tools_data

    def test_convert_tools_to_hfapi_tools(self):
        assert _convert_tools_to_hfapi_tools(None) is None
        assert _convert_tools_to_hfapi_tools([]) is None

        tool = Tool(
            name="weather",
            description="useful to determine the weather in a given location",
            parameters={"city": {"type": "string"}},
            function=get_weather,
        )
        hf_tools = _convert_tools_to_hfapi_tools([tool])
        assert len(hf_tools) == 1
        assert hf_tools[0].type == "function"
        assert hf_tools[0].function.name == "weather"
        assert hf_tools[0].function.description == "useful to determine the weather in a given location"
        assert hf_tools[0].function.parameters == {"city": {"type": "string"}}

    def test_convert_tools_to_hfapi_tools_legacy(self):
        # this satisfies the check hasattr(ChatCompletionInputFunctionDefinition, "arguments")
        mock_class = MagicMock()

        with patch(
"haystack.components.generators.chat.hugging_face_api.ChatCompletionInputFunctionDefinition", mock_class 1220 ): 1221 tool = Tool( 1222 name="weather", 1223 description="useful to determine the weather in a given location", 1224 parameters={"city": {"type": "string"}}, 1225 function=get_weather, 1226 ) 1227 _convert_tools_to_hfapi_tools([tool]) 1228 1229 mock_class.assert_called_once_with( 1230 name="weather", 1231 arguments={"city": {"type": "string"}}, 1232 description="useful to determine the weather in a given location", 1233 ) 1234 1235 def test_warm_up_with_tools(self, mock_check_valid_model): 1236 """Test that warm_up() calls warm_up on tools and is idempotent.""" 1237 1238 # Create a mock tool that tracks if warm_up() was called 1239 class MockTool(Tool): 1240 warm_up_call_count = 0 # Class variable to track calls 1241 1242 def __init__(self): 1243 super().__init__( 1244 name="mock_tool", 1245 description="A mock tool for testing", 1246 parameters={"x": {"type": "string"}}, 1247 function=lambda x: x, 1248 ) 1249 1250 def warm_up(self): 1251 MockTool.warm_up_call_count += 1 1252 1253 # Reset the class variable before test 1254 MockTool.warm_up_call_count = 0 1255 mock_tool = MockTool() 1256 1257 # Create HuggingFaceAPIChatGenerator with the mock tool 1258 component = HuggingFaceAPIChatGenerator( 1259 api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, 1260 api_params={"model": "HuggingFaceH4/zephyr-7b-alpha"}, 1261 tools=[mock_tool], 1262 ) 1263 1264 # Verify initial state - warm_up not called yet 1265 assert MockTool.warm_up_call_count == 0 1266 assert not component._is_warmed_up 1267 1268 # Call warm_up() on the generator 1269 component.warm_up() 1270 1271 # Assert that the tool's warm_up() was called 1272 assert MockTool.warm_up_call_count == 1 1273 assert component._is_warmed_up 1274 1275 # Call warm_up() again and verify it's idempotent (only warms up once) 1276 component.warm_up() 1277 1278 # The tool's warm_up should still only have been called once 1279 assert MockTool.warm_up_call_count == 1 1280 assert component._is_warmed_up 1281 1282 def test_warm_up_with_no_tools(self, mock_check_valid_model): 1283 """Test that warm_up() works when no tools are provided.""" 1284 component = HuggingFaceAPIChatGenerator( 1285 api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-alpha"} 1286 ) 1287 1288 # Verify initial state 1289 assert not component._is_warmed_up 1290 assert component.tools is None 1291 1292 # Call warm_up() - should not raise an error 1293 component.warm_up() 1294 1295 # Verify the component is warmed up 1296 assert component._is_warmed_up 1297 1298 # Call warm_up() again - should be idempotent 1299 component.warm_up() 1300 assert component._is_warmed_up 1301 1302 def test_warm_up_with_multiple_tools(self, mock_check_valid_model): 1303 """Test that warm_up() works with multiple tools.""" 1304 # Track warm_up calls 1305 warm_up_calls = [] 1306 1307 class MockTool(Tool): 1308 def __init__(self, tool_name): 1309 super().__init__( 1310 name=tool_name, 1311 description=f"Mock tool {tool_name}", 1312 parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, 1313 function=lambda x: f"{tool_name} result: {x}", 1314 ) 1315 1316 def warm_up(self): 1317 warm_up_calls.append(self.name) 1318 1319 mock_tool1 = MockTool("tool1") 1320 mock_tool2 = MockTool("tool2") 1321 1322 # Use a LIST of tools, not a Toolset 1323 component = HuggingFaceAPIChatGenerator( 1324 
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "HuggingFaceH4/zephyr-7b-alpha"},
            tools=[mock_tool1, mock_tool2],
        )

        # Call warm_up()
        component.warm_up()

        # Assert that both tools' warm_up() were called
        assert "tool1" in warm_up_calls
        assert "tool2" in warm_up_calls
        assert component._is_warmed_up

        # Track count
        call_count = len(warm_up_calls)

        # Verify idempotency
        component.warm_up()
        assert len(warm_up_calls) == call_count

    def test_run_with_reasoning_non_streaming(self, mock_check_valid_model, chat_messages):
        """Test that reasoning content is correctly extracted from non-streaming responses."""
        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
            reasoning_text = "Let me think about this. France is a country in Europe. Its capital city is Paris."
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="eos_token",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            content="The capital of France is Paris.", role="assistant", reasoning=reasoning_text
                        ),
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                usage=ChatCompletionOutputUsage(completion_tokens=20, prompt_tokens=17, total_tokens=37),
                created=1710498360,
            )
            mock_chat_completion.return_value = completion

            generator = HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            )

            response = generator.run(chat_messages)

            assert "replies" in response
            assert len(response["replies"]) == 1
            reply = response["replies"][0]
            assert reply.text == "The capital of France is Paris."
            assert reply.reasoning is not None
            assert isinstance(reply.reasoning, ReasoningContent)
            assert reply.reasoning.reasoning_text == reasoning_text

    def test_run_without_reasoning_non_streaming(self, mock_check_valid_model, mock_chat_completion, chat_messages):
        """Test that responses without reasoning work correctly (backward compatibility)."""
        generator = HuggingFaceAPIChatGenerator(
            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
        )

        response = generator.run(chat_messages)

        assert "replies" in response
        assert len(response["replies"]) == 1
        reply = response["replies"][0]
        assert reply.text == "The capital of France is Paris."
        assert reply.reasoning is None

    def test_run_with_reasoning_streaming(self, mock_check_valid_model, chat_messages):
        """Test that reasoning content is correctly extracted from streaming responses."""
        streaming_chunks_received = []

        def streaming_callback_fn(chunk: StreamingChunk):
            streaming_chunks_received.append(chunk)

        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
            # Create a fake streamed response with reasoning
            def mock_iter(self):
                # First chunk with reasoning
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(
                                role="assistant", content=None, reasoning="Let me think..."
                            ),
                            index=0,
                            finish_reason=None,
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )
                # Second chunk with more reasoning
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(
                                role=None, content=None, reasoning=" The capital of France is Paris."
                            ),
                            index=0,
                            finish_reason=None,
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )
                # Third chunk with actual content
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role=None, content="Paris", reasoning=None),
                            index=0,
                            finish_reason=None,
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )
                # Final chunk with finish reason
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role=None, content=None, reasoning=None),
                            index=0,
                            finish_reason="stop",
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )

            mock_response = Mock(__iter__=mock_iter)
            mock_chat_completion.return_value = mock_response

            generator = HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
                streaming_callback=streaming_callback_fn,
            )

            response = generator.run(chat_messages)

            # Check streaming chunks received with reasoning
            assert len(streaming_chunks_received) == 4
            assert streaming_chunks_received[0].reasoning is not None
            assert streaming_chunks_received[0].reasoning.reasoning_text == "Let me think..."
            assert streaming_chunks_received[1].reasoning is not None
            assert streaming_chunks_received[1].reasoning.reasoning_text == " The capital of France is Paris."

            # Check final message
            assert "replies" in response
            assert len(response["replies"]) == 1
            reply = response["replies"][0]
            assert reply.text == "Paris"
            assert reply.reasoning is not None
            assert isinstance(reply.reasoning, ReasoningContent)
            assert reply.reasoning.reasoning_text == "Let me think... The capital of France is Paris."

    @pytest.mark.asyncio
    async def test_run_async_with_reasoning_non_streaming(self, mock_check_valid_model, chat_messages):
        """Test that reasoning content is correctly extracted from async non-streaming responses."""
        with patch(
            "huggingface_hub.AsyncInferenceClient.chat_completion", new_callable=AsyncMock
        ) as mock_chat_completion:
            completion = ChatCompletionOutput(
                choices=[
                    ChatCompletionOutputComplete(
                        finish_reason="eos_token",
                        index=0,
                        message=ChatCompletionOutputMessage(
                            content="The capital of France is Paris.",
                            role="assistant",
                            reasoning="Let me reason about this question step by step.",
                        ),
                    )
                ],
                id="some_id",
                model="some_model",
                system_fingerprint="some_fingerprint",
                usage=ChatCompletionOutputUsage(completion_tokens=20, prompt_tokens=17, total_tokens=37),
                created=1710498360,
            )
            mock_chat_completion.return_value = completion

            generator = HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
            )

            response = await generator.run_async(chat_messages)

            assert "replies" in response
            assert len(response["replies"]) == 1
            reply = response["replies"][0]
            assert reply.text == "The capital of France is Paris."
            assert reply.reasoning is not None
            assert isinstance(reply.reasoning, ReasoningContent)
            assert reply.reasoning.reasoning_text == "Let me reason about this question step by step."

    @pytest.mark.asyncio
    async def test_run_async_with_reasoning_streaming(self, mock_check_valid_model, chat_messages):
        """Test that reasoning content is correctly extracted from async streaming responses."""
        streaming_chunks_received = []

        async def streaming_callback_fn(chunk: StreamingChunk):
            streaming_chunks_received.append(chunk)

        with patch(
            "huggingface_hub.AsyncInferenceClient.chat_completion", new_callable=AsyncMock
        ) as mock_chat_completion:
            # Create async iterable for streaming
            async def mock_aiter():
                # First chunk with reasoning
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(
                                role="assistant", content=None, reasoning="Thinking..."
                            ),
                            index=0,
                            finish_reason=None,
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )
                # Second chunk with content
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role=None, content="Paris", reasoning=None),
                            index=0,
                            finish_reason=None,
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )
                # Final chunk
                yield ChatCompletionStreamOutput(
                    choices=[
                        ChatCompletionStreamOutputChoice(
                            delta=ChatCompletionStreamOutputDelta(role=None, content=None, reasoning=None),
                            index=0,
                            finish_reason="stop",
                        )
                    ],
                    id="some_id",
                    model="some_model",
                    system_fingerprint="some_fingerprint",
                    created=1710498504,
                )

            mock_chat_completion.return_value = mock_aiter()

            generator = HuggingFaceAPIChatGenerator(
                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
                streaming_callback=streaming_callback_fn,
            )

            response = await generator.run_async(chat_messages)

            # Check streaming chunks
            assert len(streaming_chunks_received) == 3
            assert streaming_chunks_received[0].reasoning is not None
            assert streaming_chunks_received[0].reasoning.reasoning_text == "Thinking..."

            # Check final message
            assert "replies" in response
            assert len(response["replies"]) == 1
            reply = response["replies"][0]
            assert reply.text == "Paris"
            assert reply.reasoning is not None
            assert isinstance(reply.reasoning, ReasoningContent)
            assert reply.reasoning.reasoning_text == "Thinking..."

    def test_convert_chat_completion_stream_output_to_streaming_chunk_with_reasoning(self):
        """Test that reasoning is correctly extracted from streaming chunks."""
        # In streaming mode, reasoning and content come in separate chunks
        chunk = ChatCompletionStreamOutput(
            choices=[
                ChatCompletionStreamOutputChoice(
                    delta=ChatCompletionStreamOutputDelta(
                        role="assistant", content=None, reasoning="Let me think about this."
                    ),
                    index=0,
                    finish_reason=None,
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            created=1710498504,
        )

        streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=chunk, previous_chunks=[])

        assert streaming_chunk.content == ""
        assert streaming_chunk.reasoning is not None
        assert isinstance(streaming_chunk.reasoning, ReasoningContent)
        assert streaming_chunk.reasoning.reasoning_text == "Let me think about this."

    def test_convert_chat_completion_stream_output_to_streaming_chunk_without_reasoning(self):
        """Test that chunks without reasoning still work correctly."""
        chunk = ChatCompletionStreamOutput(
            choices=[
                ChatCompletionStreamOutputChoice(
                    delta=ChatCompletionStreamOutputDelta(role="assistant", content="Hello"),
                    index=0,
                    finish_reason=None,
                )
            ],
            id="some_id",
            model="some_model",
            system_fingerprint="some_fingerprint",
            created=1710498504,
        )

        streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=chunk, previous_chunks=[])

        assert streaming_chunk.content == "Hello"
        assert streaming_chunk.reasoning is None

    def test_resolve_schema_refs_no_defs(self):
        """Schema without $defs is returned as-is."""
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        assert _resolve_schema_refs(schema) == schema

    def test_resolve_schema_refs_expands_defs(self):
        """Schema with $defs and $ref is expanded correctly."""
        schema = {
            "$defs": {
                "User": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                    "required": ["name"],
                }
            },
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/User"}},
            "required": ["user"],
        }
        resolved = _resolve_schema_refs(schema)
        assert "$defs" not in resolved
        assert "$ref" not in resolved["properties"]["user"]
        assert resolved["properties"]["user"]["type"] == "object"
        assert resolved["properties"]["user"]["properties"]["name"] == {"type": "string"}

    def test_resolve_schema_refs_nested_refs(self):
        """Schema with nested $ref references is expanded correctly."""
        schema = {
            "$defs": {
                "Address": {"type": "object", "properties": {"street": {"type": "string"}}},
                "User": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}, "address": {"$ref": "#/$defs/Address"}},
                },
            },
            "type": "object",
            "properties": {"user": {"$ref": "#/$defs/User"}},
        }
        resolved = _resolve_schema_refs(schema)
        assert "$defs" not in resolved
        user = resolved["properties"]["user"]
        assert user["properties"]["address"]["type"] == "object"
        assert user["properties"]["address"]["properties"]["street"] == {"type": "string"}

    def test_convert_tools_to_hfapi_tools_resolves_defs(self):
        """Tool schemas with $defs are resolved before passing to HF API."""
        tool = Tool(
            name="get_user",
            description="Get user info",
            parameters={
                "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}},
                "type": "object",
                "properties": {"user": {"$ref": "#/$defs/User"}},
            },
            function=lambda user: user,
        )
        hf_tools = _convert_tools_to_hfapi_tools([tool])
        assert hf_tools is not None
        assert len(hf_tools) == 1
        # Depending on the huggingface_hub version, the JSON schema lives in `parameters` or the legacy `arguments`
        params = hf_tools[0].function.parameters or hf_tools[0].function.arguments
        assert "$defs" not in params
        assert params["properties"]["user"]["type"] == "object"