/ test / components / generators / chat / test_llm.py
test_llm.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  
  7  import pytest
  8  
  9  from haystack import Document, Pipeline, component
 10  from haystack.components.agents.agent import Agent
 11  from haystack.components.generators.chat import LLM
 12  from haystack.components.generators.chat.openai import OpenAIChatGenerator
 13  from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 14  from haystack.core.component.types import OutputSocket
 15  from haystack.dataclasses import ChatMessage
 16  from haystack.dataclasses.chat_message import ChatRole
 17  from haystack.document_stores.in_memory import InMemoryDocumentStore
 18  from haystack.tools import Tool
 19  from haystack.tools.toolset import Toolset
 20  
 21  
 22  @component
 23  class MockChatGeneratorWithTools:
 24      """A mock chat generator that accepts a tools parameter."""
 25  
 26      def to_dict(self) -> dict[str, Any]:
 27          return {"type": "test_llm.MockChatGeneratorWithTools", "data": {}}
 28  
 29      @classmethod
 30      def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithTools":
 31          return cls()
 32  
 33      @component.output_types(replies=list[ChatMessage])
 34      def run(self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs) -> dict[str, Any]:
 35          return {"replies": [ChatMessage.from_assistant("Reply with tools support")]}
 36  
 37      @component.output_types(replies=list[ChatMessage])
 38      async def run_async(
 39          self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs
 40      ) -> dict[str, Any]:
 41          return {"replies": [ChatMessage.from_assistant("Async reply with tools support")]}
 42  
 43  
 44  @component
 45  class MockChatGenerator:
 46      """A mock chat generator that does NOT accept a tools parameter."""
 47  
 48      def to_dict(self) -> dict[str, Any]:
 49          return {"type": "test_llm.MockChatGenerator", "data": {}}
 50  
 51      @classmethod
 52      def from_dict(cls, data: dict[str, Any]) -> "MockChatGenerator":
 53          return cls()
 54  
 55      @component.output_types(replies=list[ChatMessage])
 56      def run(self, messages: list[ChatMessage], **kwargs) -> dict[str, Any]:
 57          return {"replies": [ChatMessage.from_assistant("Sync reply")]}
 58  
 59      @component.output_types(replies=list[ChatMessage])
 60      async def run_async(self, messages: list[ChatMessage], **kwargs) -> dict[str, Any]:
 61          return {"replies": [ChatMessage.from_assistant("Async reply")]}
 62  
 63  
 64  class TestLLM:
 65      class TestInit:
 66          USER_PROMPT = '{% message role="user" %}{{ query }}{% endmessage %}'
 67  
 68          def test_is_subclass_of_agent(self):
 69              assert issubclass(LLM, Agent)
 70  
 71          def test_defaults(self):
 72              llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
 73              assert llm.chat_generator is not None
 74              assert llm.tools == []
 75              assert llm.system_prompt is None
 76              assert llm.user_prompt == self.USER_PROMPT
 77              assert llm.required_variables == "*"
 78              assert llm.streaming_callback is None
 79              assert llm._tool_invoker is None
 80  
 81          def test_output_sockets(self):
 82              llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
 83              assert llm.__haystack_output__._sockets_dict == {
 84                  "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]),
 85                  "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]),
 86              }
 87  
 88          def test_detects_no_tools_support(self):
 89              llm = LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT)
 90              assert llm._chat_generator_supports_tools is False
 91  
 92          def test_detects_tools_support(self):
 93              llm = LLM(chat_generator=MockChatGeneratorWithTools(), user_prompt=self.USER_PROMPT)
 94              assert llm._chat_generator_supports_tools is True
 95  
 96          def test_raises_if_user_prompt_has_no_variables(self):
 97              with pytest.raises(ValueError, match="at least one template variable"):
 98                  LLM(
 99                      chat_generator=MockChatGenerator(),
100                      user_prompt='{% message role="user" %}Hello world{% endmessage %}',
101                  )
102  
103          def test_raises_if_required_variables_empty(self):
104              with pytest.raises(ValueError, match="required_variables must not be empty"):
105                  LLM(chat_generator=MockChatGenerator(), user_prompt=self.USER_PROMPT, required_variables=[])
106  
107      class TestSerialization:
108          def test_to_dict_excludes_agent_only_params(self, monkeypatch):
109              monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
110              user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
111              llm = LLM(chat_generator=OpenAIChatGenerator(), system_prompt="You are helpful.", user_prompt=user_prompt)
112  
113              serialized = llm.to_dict()
114  
115              assert serialized["type"] == "haystack.components.generators.chat.llm.LLM"
116              assert "chat_generator" in serialized["init_parameters"]
117              assert serialized["init_parameters"]["system_prompt"] == "You are helpful."
118  
119              agent_only_params = [
120                  "tools",
121                  "exit_conditions",
122                  "max_agent_steps",
123                  "raise_on_tool_invocation_failure",
124                  "tool_invoker_kwargs",
125                  "confirmation_strategies",
126                  "state_schema",
127              ]
128              for param in agent_only_params:
129                  assert param not in serialized["init_parameters"], (
130                      f"Agent-only param '{param}' should not be serialized"
131                  )
132  
133          def test_to_dict_includes_llm_params(self, monkeypatch):
134              monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
135              llm = LLM(
136                  chat_generator=OpenAIChatGenerator(),
137                  system_prompt="Be concise.",
138                  user_prompt='{% message role="user" %}{{ query }}{% endmessage %}',
139                  required_variables=["query"],
140              )
141  
142              serialized = llm.to_dict()
143  
144              assert serialized["init_parameters"]["system_prompt"] == "Be concise."
145              assert "{{ query }}" in serialized["init_parameters"]["user_prompt"]
146              assert serialized["init_parameters"]["required_variables"] == ["query"]
147              assert serialized["init_parameters"]["streaming_callback"] is None
148  
149          def test_from_dict(self, monkeypatch):
150              monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
151              data = {
152                  "type": "haystack.components.generators.chat.llm.LLM",
153                  "init_parameters": {
154                      "chat_generator": {
155                          "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
156                          "init_parameters": {
157                              "model": "gpt-4o-mini",
158                              "streaming_callback": None,
159                              "api_base_url": None,
160                              "organization": None,
161                              "generation_kwargs": {},
162                              "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
163                              "timeout": None,
164                              "max_retries": None,
165                              "tools": None,
166                              "tools_strict": False,
167                              "http_client_kwargs": None,
168                          },
169                      },
170                      "system_prompt": "You are helpful.",
171                      "user_prompt": '{% message role="user" %}{{ query }}{% endmessage %}',
172                      "required_variables": "*",
173                      "streaming_callback": None,
174                  },
175              }
176  
177              llm = LLM.from_dict(data)
178  
179              assert isinstance(llm, LLM)
180              assert isinstance(llm.chat_generator, OpenAIChatGenerator)
181              assert llm.system_prompt == "You are helpful."
182              assert llm.tools == []
183  
184          def test_roundtrip(self, monkeypatch):
185              monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
186              user_prompt = '{% message role="user" %}{{ query }}{% endmessage %}'
187              original = LLM(
188                  chat_generator=OpenAIChatGenerator(), system_prompt="You are a poet.", user_prompt=user_prompt
189              )
190  
191              restored = LLM.from_dict(original.to_dict())
192  
193              assert isinstance(restored, LLM)
194              assert isinstance(restored.chat_generator, OpenAIChatGenerator)
195              assert restored.system_prompt == original.system_prompt
196              assert restored.tools == []
197  
198      class TestPipelineIntegration:
199          @pytest.fixture()
200          def document_store_with_docs(self):
201              store = InMemoryDocumentStore()
202              store.write_documents(
203                  [
204                      Document(content="The Eiffel Tower is located in Paris."),
205                      Document(content="The Brandenburg Gate is in Berlin."),
206                      Document(content="The Colosseum is in Rome."),
207                  ]
208              )
209              return store
210  
211          def test_rag_pipeline(self, document_store_with_docs):
212              user_prompt = (
213                  '{% message role="user" %}'
214                  "Use the following documents to answer the question.\n"
215                  "Documents:\n{% for doc in documents %}{{ doc.content }}\n{% endfor %}"
216                  "Question: {{ query }}"
217                  "{% endmessage %}"
218              )
219              llm = LLM(
220                  chat_generator=MockChatGenerator(),
221                  system_prompt="You are a knowledgeable assistant.",
222                  user_prompt=user_prompt,
223                  required_variables=["query", "documents"],
224              )
225  
226              pipe = Pipeline()
227              pipe.add_component("retriever", InMemoryBM25Retriever(document_store=document_store_with_docs))
228              pipe.add_component("llm", llm)
229              pipe.connect("retriever.documents", "llm.documents")
230  
231              query = "Where is the Colosseum?"
232              result = pipe.run(data={"retriever": {"query": query}, "llm": {"query": query}})
233  
234              assert "llm" in result
235              llm_output = result["llm"]
236              assert "messages" in llm_output
237              assert "last_message" in llm_output
238  
239              messages = llm_output["messages"]
240  
241              assert messages[0].is_from(ChatRole.SYSTEM)
242              assert messages[0].text == "You are a knowledgeable assistant."
243  
244              user_messages = [m for m in messages if m.is_from(ChatRole.USER)]
245              assert len(user_messages) == 1
246              rendered = user_messages[0].text
247              assert "Question: Where is the Colosseum?" in rendered
248              assert "Documents:" in rendered
249              assert "Colosseum" in rendered
250  
251              assert llm_output["last_message"].is_from(ChatRole.ASSISTANT)
252              assert llm_output["last_message"].text == "Sync reply"