# tests/features/test_llm_judge.py
import json
import re
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import pandas as pd
import pytest

from evidently.legacy.descriptors import NegativityLLMEval
from evidently.legacy.features.llm_judge import BinaryClassificationPromptTemplate
from evidently.legacy.features.llm_judge import LLMJudge
from evidently.legacy.features.llm_judge import LLMMessage
from evidently.legacy.features.llm_judge import LLMWrapper
from evidently.legacy.metric_preset import TextEvals
from evidently.legacy.options.base import Options
from evidently.legacy.report import Report
from evidently.legacy.utils.data_preprocessing import DataDefinition
from evidently.legacy.utils.llm.errors import LLMResponseParseError
from evidently.legacy.utils.llm.wrapper import LLMResult
from evidently.legacy.utils.llm.wrapper import llm_provider


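# Helper: builds a minimal BinaryClassificationPromptTemplate so the response
# parser can be exercised with and without reasoning / score output.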
def _LLMPromptTemplate(
    include_reasoning: bool,
    target_category: str,
    non_target_category: str,
    score_range: Optional[Tuple[float, float]] = None,
):
    return BinaryClassificationPromptTemplate(
        criteria="",
        instructions_template="",
        include_reasoning=include_reasoning,
        target_category=target_category,
        non_target_category=non_target_category,
        include_score=score_range is not None,
        score_range=score_range or (0, 1),
    )


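# Each parametrized case maps a raw LLM response (a JSON string) to the dict the
# template's parser is expected to return; parse failures would be asserted via
# pytest.raises in the test body below.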
@pytest.mark.parametrize(
    "template,results",
    [
        (
            _LLMPromptTemplate(include_reasoning=False, target_category="FIRST", non_target_category="SECOND"),
            {
                json.dumps({"result": "FIRST"}): {"result": "FIRST"},
                json.dumps({"result": "SECOND"}): {"result": "SECOND"},
            },
        ),
        (
            _LLMPromptTemplate(include_reasoning=True, target_category="FIRST", non_target_category="SECOND"),
            {
                json.dumps({"result": "FIRST", "reasoning": "Reason"}): {"result": "FIRST", "reasoning": "Reason"},
                json.dumps({"result": "SECOND", "reasoning": "Reason"}): {"result": "SECOND", "reasoning": "Reason"},
            },
        ),
        (
            _LLMPromptTemplate(
                include_reasoning=False, target_category="FIRST", non_target_category="SECOND", score_range=(0, 1)
            ),
            {
                json.dumps({"result": 0}): {"result": 0},
                json.dumps({"result": 1}): {"result": 1},
            },
        ),
        (
            _LLMPromptTemplate(
                include_reasoning=True, target_category="FIRST", non_target_category="SECOND", score_range=(0, 1)
            ),
            {
                json.dumps({"result": 0, "reasoning": "Reason"}): {"result": 0, "reasoning": "Reason"},
                json.dumps({"result": 1, "reasoning": "Reason"}): {"result": 1, "reasoning": "Reason"},
            },
        ),
    ],
)
def test_parse_response(
    template: BinaryClassificationPromptTemplate,
    results: Dict[str, Union[LLMResponseParseError, Dict[str, Union[str, float]]]],
):
    for response, expected_result in results.items():
        if isinstance(expected_result, LLMResponseParseError):
            with pytest.raises(expected_result.__class__):
                template.get_parser()(response)
        else:
            assert template.get_parser()(response) == expected_result


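# Mock LLM provider registered under the name "mock": instead of calling a real
# model, it extracts the judged text from between the prompt's
# ___text_starts_here___ / ___text_ends_here___ markers and returns it as the category.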
@llm_provider("mock", None)
class MockLLMWrapper(LLMWrapper):
    def __init__(self, model: str, options: Options):
        self.model = model

    async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]:
        text = messages[-1].content
        # first captured text block, first character (the test inputs are single characters)
        cat = re.findall("___text_starts_here___\n(.*)\n___text_ends_here___", text)[0][0]
        return LLMResult(json.dumps({"category": cat}), 0, 0)


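# Single-column judge: the mock provider echoes each row's text back as its category.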
def test_llm_judge():
    llm_judge = LLMJudge(
        input_column="text",
        provider="mock",
        model="",
        template=BinaryClassificationPromptTemplate(target_category="A", non_target_category="B"),
    )

    data = pd.DataFrame({"text": ["A", "B"]})

    dd = DataDefinition(columns={}, reference_present=False)
    fts = llm_judge.generate_features(data, dd, Options())
    pd.testing.assert_frame_equal(fts, pd.DataFrame({"category": ["A", "B"]}))


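# Multi-column judge: two input columns are rendered into the prompt; the mock
# still returns the first text block as the category.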
def test_multicol_llm_judge():
    llm_judge = LLMJudge(
        input_columns={"text": "input", "text2": "input2"},
        provider="mock",
        model="",
        template=BinaryClassificationPromptTemplate(target_category="A", non_target_category="B"),
    )

    data = pd.DataFrame({"text": ["A", "B"], "text2": ["C", "D"]})

    dd = DataDefinition(columns={}, reference_present=False)
    fts = llm_judge.generate_features(data, dd, Options())
    pd.testing.assert_frame_equal(fts, pd.DataFrame({"category": ["A", "B"]}))


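# End-to-end check: a Report with a NegativityLLMEval descriptor backed by the
# mock provider should produce the expected column summary for the generated
# "Negativity" column.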
def test_run_snapshot_with_llm_judge():
    data = pd.DataFrame({"text": ["A", "B"], "text2": ["C", "D"]})
    neg_eval = NegativityLLMEval(
        input_columns={"text": "input", "text2": "input2"},
        provider="mock",
        model="",
        template=BinaryClassificationPromptTemplate(target_category="A", non_target_category="B"),
    )
    report = Report(metrics=[TextEvals("text", descriptors=[neg_eval])])

    report.run(current_data=data, reference_data=None)
    report._inner_suite.raise_for_error()
    assert report.as_dict() == {
        "metrics": [
            {
                "metric": "ColumnSummaryMetric",
                "result": {
                    "column_name": "Negativity",
                    "column_type": "cat",
                    "current_characteristics": {
                        "count": 2,
                        "missing": 0,
                        "missing_percentage": 0.0,
                        "most_common": "A",
                        "most_common_percentage": 50.0,
                        "new_in_current_values_count": None,
                        "number_of_rows": 2,
                        "unique": 2,
                        "unique_percentage": 100.0,
                        "unused_in_current_values_count": None,
                    },
                    "reference_characteristics": None,
                },
            }
        ]
    }