test_llm.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from haystack import Document, Pipeline, component
from haystack.components.agents.agent import Agent
from haystack.components.generators.chat import LLM
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.chat_message import ChatRole
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import Tool
from haystack.tools.toolset import Toolset


@component
class MockChatGeneratorWithTools:
    """A mock chat generator that accepts a tools parameter."""

    def to_dict(self) -> dict[str, Any]:
        return {"type": "test_llm.MockChatGeneratorWithTools", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithTools":
        return cls()

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Reply with tools support")]}

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs
    ) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Async reply with tools support")]}


@component
class MockChatGenerator:
    """A mock chat generator that does NOT accept a tools parameter."""

    def to_dict(self) -> dict[str, Any]:
        return {"type": "test_llm.MockChatGenerator", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGenerator":
        return cls()

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], **kwargs) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Sync reply")]}

    @component.output_types(replies=list[ChatMessage])
    async def run_async(self, messages: list[ChatMessage], **kwargs) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Async reply")]}


class TestLLM:
    class TestInit:
        USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'

        def test_is_subclass_of_agent(self):
            assert issubclass(LLM, Agent)

        def test_defaults(self):
            llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
            assert llm.chat_generator is not None
            assert llm.tools == []
            assert llm.system_prompt is None
            assert llm.user_prompt == self.USER_PROMPT
            assert llm.required_variables == "*"
            assert llm.streaming_callback is None
            assert llm._tool_invoker is None

        def test_output_sockets(self):
            llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
            assert llm.__haystack_output__._sockets_dict == {
                "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]),
                "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]),
            }

        def test_detects_no_tools_support(self):
            llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
            assert llm._chat_generator_supports_tools is False

        def test_detects_tools_support(self):
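            # MockChatGeneratorWithTools.run() declares a `tools` parameter, so detection should succeed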
            llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT)
            assert llm._chat_generator_supports_tools is True

        def test_raises_if_user_prompt_has_no_variables(self):
            with pytest.raises(ValueError, match="at least one template variable"):
                LLM(
                    chat_generator=MockChatGenerator(),
                    user_prompt='{% message role="user" %}Hello world{% endmessage %}',
                )

        def test_raises_if_required_variables_empty(self):
            with pytest.raises(ValueError, match="required_variables must not be empty"):
                LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT, required_variables=[])

    class TestSerialization:
        def test_to_dict_excludes_agent_only_params(self, monkeypatch):
            monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
            user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
            llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.", user_prompt=user_prompt)

            serialized = llm.to_dict()

            assert serialized["type"] == "haystack.components.generators.chat.llm.LLM"
            assert "chat_generator" in serialized["init_parameters"]
            assert serialized["init_parameters"]["system_prompt"] == "You are helpful."

            agent_only_params = [
                "tools",
                "exit_conditions",
                "max_agent_steps",
                "raise_on_tool_invocation_failure",
                "tool_invoker_kwargs",
                "confirmation_strategies",
                "state_schema",
            ]
            for param in agent_only_params:
                assert param not in serialized["init_parameters"], (
                    f"Agent-only param '{param}' should not be serialized"
                )

        def test_to_dict_includes_llm_params(self, monkeypatch):
            monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
            llm = LLM(
                chat_generator=OpenAIChatGenerator(),
                system_prompt="Be concise.",
                user_prompt='{% message role="user" %}{{ query }}{% endmessage %}',
                required_variables=["query"],
            )

            serialized = llm.to_dict()

            assert serialized["init_parameters"]["system_prompt"] == "Be concise."
            assert "{{ query }}" in serialized["init_parameters"]["user_prompt"]
            assert serialized["init_parameters"]["required_variables"] == ["query"]
            assert serialized["init_parameters"]["streaming_callback"] is None

        def test_from_dict(self, monkeypatch):
            monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
            data = {
                "type": "haystack.components.generators.chat.llm.LLM",
                "init_parameters": {
                    "chat_generator": {
                        "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                        "init_parameters": {
                            "model": "gpt-4o-mini",
                            "streaming_callback": None,
                            "api_base_url": None,
                            "organization": None,
                            "generation_kwargs": {},
                            "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                            "timeout": None,
                            "max_retries": None,
                            "tools": None,
                            "tools_strict": False,
                            "http_client_kwargs": None,
                        },
                    },
                    "system_prompt": "You are helpful.",
                    "user_prompt": '{% message role="user" %}{{ query }}{% endmessage %}',
                    "required_variables": "*",
                    "streaming_callback": None,
                },
            }

            llm = LLM.from_dict(data)

            assert isinstance(llm, LLM)
            assert isinstance(llm.chat_generator, OpenAIChatGenerator)
            assert llm.system_prompt == "You are helpful."
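            # `tools` is an agent-only param excluded from serialization, so it restores to the default empty list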
            assert llm.tools == []

        def test_roundtrip(self, monkeypatch):
            monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
            user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
            original = LLM(
                chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.", user_prompt=user_prompt
            )

            restored = LLM.from_dict(original.to_dict())

            assert isinstance(restored, LLM)
            assert isinstance(restored.chat_generator, OpenAIChatGenerator)
            assert restored.system_prompt == original.system_prompt
            assert restored.tools == []

    class TestPipelineIntegration:
        @pytest.fixture()
        def document_store_with_docs(self):
            store = InMemoryDocumentStore()
            store.write_documents(
                [
                    Document(content="The Eiffel Tower is located in Paris."),
                    Document(content="The Brandenburg Gate is in Berlin."),
                    Document(content="The Colosseum is in Rome."),
                ]
            )
            return store

        def test_rag_pipeline(self, document_store_with_docs):
            user_prompt = (
                '{% message role="user" %}'
                "Use the following documents to answer the question.\n"
                "Documents:\n{% for doc in documents %}{{ doc.content }}\n{% endfor %}"
                "Question: {{ query }}"
                "{% endmessage %}"
            )
            llm = LLM(
                chat_generator=MockChatGenerator(),
                system_prompt="You are a knowledgeable assistant.",
                user_prompt=user_prompt,
                required_variables=["query", "documents"],
            )

            pipe = Pipeline()
            pipe.add_component("retriever", InMemoryBM25Retriever(document_store=document_store_with_docs))
            pipe.add_component("llm", llm)
            pipe.connect("retriever.documents", "llm.documents")

            query = "Where is the Colosseum?"
            result = pipe.run(data={"retriever": {"query": query}, "llm": {"query": query}})

            assert "llm" in result
            llm_output = result["llm"]
            assert "messages" in llm_output
            assert "last_message" in llm_output

            messages = llm_output["messages"]

            assert messages[0].is_from(ChatRole.SYSTEM)
            assert messages[0].text == "You are a knowledgeable assistant."

            user_messages = [m for m in messages if m.is_from(ChatRole.USER)]
            assert len(user_messages) == 1
            rendered = user_messages[0].text
            assert "Question: Where is the Colosseum?" in rendered
            assert "Documents:" in rendered
            assert "Colosseum" in rendered

            assert llm_output["last_message"].is_from(ChatRole.ASSISTANT)
            assert llm_output["last_message"].text == "Sync reply"