# haystack/components/evaluators/context_relevance.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from statistics import mean
  6  from typing import Any
  7  
  8  from haystack import component, default_from_dict, default_to_dict
  9  from haystack.components.evaluators.llm_evaluator import LLMEvaluator
 10  from haystack.components.generators.chat.types import ChatGenerator
 11  from haystack.core.serialization import component_to_dict
 12  from haystack.utils import deserialize_chatgenerator_inplace
 13  
 14  # Private global variable for default examples to include in the prompt if the user does not provide any examples
 15  _DEFAULT_EXAMPLES = [
 16      {
 17          "inputs": {
 18              "questions": "What is the capital of Germany?",
 19              "contexts": ["Berlin is the capital of Germany. Berlin and was founded in 1244."],
 20          },
 21          "outputs": {"relevant_statements": ["Berlin is the capital of Germany."]},
 22      },
 23      {
 24          "inputs": {
 25              "questions": "What is the capital of France?",
 26              "contexts": [
 27                  "Berlin is the capital of Germany and was founded in 1244.",
 28                  "Europe is a continent with 44 countries.",
 29                  "Madrid is the capital of Spain.",
 30              ],
 31          },
 32          "outputs": {"relevant_statements": []},
 33      },
 34      {
 35          "inputs": {"questions": "What is the capital of Italy?", "contexts": ["Rome is the capital of Italy."]},
 36          "outputs": {"relevant_statements": ["Rome is the capital of Italy."]},
 37      },
 38  ]
 39  
 40  
 41  @component
 42  class ContextRelevanceEvaluator(LLMEvaluator):
 43      """
 44      Evaluator that checks if a provided context is relevant to the question.
 45  
 46      An LLM breaks up a context into multiple statements and checks whether each statement
 47      is relevant for answering a question.
 48      The score for each context is either binary score of 1 or 0, where 1 indicates that the context is relevant
 49      to the question and 0 indicates that the context is not relevant.
 50      The evaluator also provides the relevant statements from the context and an average score over all the provided
 51      input questions contexts pairs.
 52  
 53      Usage example:
 54      ```python
 55      from haystack.components.evaluators import ContextRelevanceEvaluator
 56  
 57      questions = ["Who created the Python language?", "Why does Java needs a JVM?", "Is C++ better than Python?"]
 58      contexts = [
 59          [(
 60              "Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
 61              "language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
 62              "programmers write clear, logical code for both small and large-scale software projects."
 63          )],
 64          [(
 65              "Java is a high-level, class-based, object-oriented programming language that is designed to have as few "
 66              "implementation dependencies as possible. The JVM has two primary functions: to allow Java programs to run"
 67              "on any device or operating system (known as the 'write once, run anywhere' principle), and to manage and"
 68              "optimize program memory."
 69          )],
 70          [(
 71              "C++ is a general-purpose programming language created by Bjarne Stroustrup as an extension of the C "
 72              "programming language."
 73          )],
 74      ]
 75  
 76      evaluator = ContextRelevanceEvaluator()
 77      result = evaluator.run(questions=questions, contexts=contexts)
 78      print(result["score"])
 79      # 0.67
 80      print(result["individual_scores"])
 81      # [1,1,0]
 82      print(result["results"])
 83      # [{
 84      #   'relevant_statements': ['Python, created by Guido van Rossum in the late 1980s.'],
 85      #    'score': 1.0
 86      #  },
 87      #  {
 88      #   'relevant_statements': ['The JVM has two primary functions: to allow Java programs to run on any device or
 89      #                           operating system (known as the "write once, run anywhere" principle), and to manage and
 90      #                           optimize program memory'],
 91      #   'score': 1.0
 92      #  },
 93      #  {
 94      #   'relevant_statements': [],
 95      #   'score': 0.0
 96      #  }]
 97      ```
 98      """
 99  
100      def __init__(
101          self,
102          examples: list[dict[str, Any]] | None = None,
103          progress_bar: bool = True,
104          raise_on_failure: bool = True,
105          chat_generator: ChatGenerator | None = None,
106      ) -> None:
107          """
108          Creates an instance of ContextRelevanceEvaluator.
109  
110          If no LLM is specified using the `chat_generator` parameter, the component will use OpenAI in JSON mode.
111  
112          :param examples:
113              Optional few-shot examples conforming to the expected input and output format of ContextRelevanceEvaluator.
114              Default examples will be used if none are provided.
115              Each example must be a dictionary with keys "inputs" and "outputs".
116              "inputs" must be a dictionary with keys "questions" and "contexts".
117              "outputs" must be a dictionary with "relevant_statements".
118              Expected format:
119              ```python
120              [{
121                  "inputs": {
122                      "questions": "What is the capital of Italy?", "contexts": ["Rome is the capital of Italy."],
123                  },
124                  "outputs": {
125                      "relevant_statements": ["Rome is the capital of Italy."],
126                  },
127              }]
128              ```
129          :param progress_bar:
130              Whether to show a progress bar during the evaluation.
131          :param raise_on_failure:
132              Whether to raise an exception if the API call fails.
133          :param chat_generator:
134              a ChatGenerator instance which represents the LLM.
135              In order for the component to work, the LLM should be configured to return a JSON object. For example,
136              when using the OpenAIChatGenerator, you should pass `{"response_format": {"type": "json_object"}}` in the
137              `generation_kwargs`.
138          """
139  
140          self.instructions = (
141              "Please extract only sentences from the provided context which are absolutely relevant and "
142              "required to answer the following question. If no relevant sentences are found, or if you "
143              "believe the question cannot be answered from the given context, return an empty list, example: []"
144          )
145          self.inputs = [("questions", list[str]), ("contexts", list[list[str]])]
146          self.outputs = ["relevant_statements"]
147          self.examples = examples or _DEFAULT_EXAMPLES
148  
149          super(ContextRelevanceEvaluator, self).__init__(  # noqa: UP008
150              instructions=self.instructions,
151              inputs=self.inputs,
152              outputs=self.outputs,
153              examples=self.examples,
154              chat_generator=chat_generator,
155              raise_on_failure=raise_on_failure,
156              progress_bar=progress_bar,
157          )
158  
159      @component.output_types(score=float, results=list[dict[str, Any]])
160      def run(self, **inputs: Any) -> dict[str, Any]:
161          """
162          Run the LLM evaluator.
163  
164          :param questions:
165              A list of questions.
166          :param contexts:
167              A list of lists of contexts. Each list of contexts corresponds to one question.
168          :returns:
169              A dictionary with the following outputs:
170                  - `score`: Mean context relevance score over all the provided input questions.
171                  - `results`: A list of dictionaries with `relevant_statements` and `score` for each input context.
172          """
173          result = super(ContextRelevanceEvaluator, self).run(**inputs)  # noqa: UP008
174  
175          for idx, res in enumerate(result["results"]):
176              if res is None:
177                  result["results"][idx] = {"relevant_statements": [], "score": float("nan")}
178                  continue
179              if len(res["relevant_statements"]) > 0:
180                  res["score"] = 1
181              else:
182                  res["score"] = 0
183  
184          # calculate average context relevance score over all queries
185          result["score"] = mean([res["score"] for res in result["results"]])
186          result["individual_scores"] = [res["score"] for res in result["results"]]  # useful for the EvaluationRunResult
187  
188          return result
189  
190      def to_dict(self) -> dict[str, Any]:
191          """
192          Serialize this component to a dictionary.
193  
194          :returns:
195              A dictionary with serialized data.
196          """
197          return default_to_dict(
198              self,
199              chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"),
200              examples=self.examples,
201              progress_bar=self.progress_bar,
202              raise_on_failure=self.raise_on_failure,
203          )
204  
205      @classmethod
206      def from_dict(cls, data: dict[str, Any]) -> "ContextRelevanceEvaluator":
207          """
208          Deserialize this component from a dictionary.
209  
210          :param data:
211              The dictionary representation of this component.
212          :returns:
213              The deserialized component instance.
214          """
215          if data["init_parameters"].get("chat_generator"):
216              deserialize_chatgenerator_inplace(data["init_parameters"], key="chat_generator")
217          return default_from_dict(cls, data)