/ tests / future / descriptors / test_text_match.py
test_text_match.py
  1  import pandas as pd
  2  import pytest
  3  
  4  from evidently.core.datasets import Dataset
  5  from evidently.descriptors import TextMatch
  6  
  7  
  8  @pytest.fixture
  9  def sample_data():
 10      return pd.DataFrame(
 11          {
 12              "description": [
 13                  "This is urgent and important message",
 14                  "This is just a test message",
 15                  "URGENT: Please respond immediately",
 16                  "Normal message without keywords",
 17                  "Contains both urgent and important words",
 18                  "Spam test message for filtering",
 19                  "Empty message",
 20                  "Message with numbers 123 and symbols @#$",
 21              ],
 22              "keywords": [
 23                  ["urgent", "important"],
 24                  ["test"],
 25                  ["urgent"],
 26                  [],
 27                  ["urgent", "important"],
 28                  ["spam", "test"],
 29                  [],
 30                  ["numbers"],
 31              ],
 32              "single_keyword": [
 33                  "urgent",
 34                  "test",
 35                  "urgent",
 36                  "",
 37                  "urgent",
 38                  "spam",
 39                  "",
 40                  "numbers",
 41              ],
 42          }
 43      )
 44  
 45  
 46  @pytest.fixture
 47  def sample_dataset(sample_data):
 48      return Dataset.from_pandas(sample_data)
 49  
 50  
 51  def test_contains_any_mode(sample_dataset):
 52      descriptor = TextMatch(
 53          column_name="description",
 54          match_items=["urgent", "important"],
 55          match_type="contains",
 56          match_mode="any",
 57          case_sensitive=False,
 58      )
 59      sample_dataset.add_descriptor(descriptor)
 60      result = sample_dataset.column(descriptor.alias)
 61  
 62      expected = [True, False, True, False, True, False, False, False]
 63      assert result.data.tolist() == expected
 64  
 65  
 66  def test_contains_all_mode(sample_dataset):
 67      descriptor = TextMatch(
 68          column_name="description",
 69          match_items=["urgent", "important"],
 70          match_type="contains",
 71          match_mode="all",
 72          case_sensitive=False,
 73      )
 74      sample_dataset.add_descriptor(descriptor)
 75      result = sample_dataset.column(descriptor.alias)
 76  
 77      expected = [True, False, False, False, True, False, False, False]
 78      assert result.data.tolist() == expected
 79  
 80  
 81  def test_not_contains_any_mode(sample_dataset):
 82      descriptor = TextMatch(
 83          column_name="description",
 84          match_items=["spam", "test"],
 85          match_type="not_contains",
 86          match_mode="any",
 87          case_sensitive=False,
 88      )
 89      sample_dataset.add_descriptor(descriptor)
 90      result = sample_dataset.column(descriptor.alias)
 91  
 92      expected = [True, True, True, True, True, False, True, True]
 93      assert result.data.tolist() == expected
 94  
 95  
 96  def test_not_contains_all_mode(sample_dataset):
 97      descriptor = TextMatch(
 98          column_name="description",
 99          match_items=["spam", "test"],
100          match_type="not_contains",
101          match_mode="all",
102          case_sensitive=False,
103      )
104      sample_dataset.add_descriptor(descriptor)
105      result = sample_dataset.column(descriptor.alias)
106  
107      expected = [True, False, True, True, True, False, True, True]
108      assert result.data.tolist() == expected
109  
110  
def test_exact_match(sample_dataset):
    """Exact matching compares the whole value, not substrings.

    "urgent" and "test" occur inside several descriptions but never as the entire
    value, so they must not match. "Empty message" is a complete description and
    must match its own row. Without this positive case the test could not tell
    exact matching apart from a matcher that never matches anything.
    """
    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent", "test", "Empty message"],
        match_type="exact",
        case_sensitive=False,
    )
    sample_dataset.add_descriptor(descriptor)
    result = sample_dataset.column(descriptor.alias)

    expected = [False, False, False, False, False, False, True, False]
    assert result.data.tolist() == expected
120  
121  
def test_regex_match(sample_dataset):
    """A case-insensitive word-boundary regex flags rows with "urgent" or "important"."""
    pattern = r"\b(urgent|important)\b"
    regex_match = TextMatch(column_name="description", match_items=[pattern], match_type="regex", case_sensitive=False)
    sample_dataset.add_descriptor(regex_match)

    flags = sample_dataset.column(regex_match.alias).data.tolist()
    assert flags == [True, False, True, False, True, False, False, False]
131  
132  
def test_case_sensitive_true(sample_dataset):
    """With case_sensitive=True, "URGENT" matches only the upper-case occurrence."""
    strict = TextMatch(column_name="description", match_items=["URGENT"], match_type="contains", case_sensitive=True)
    sample_dataset.add_descriptor(strict)

    flags = sample_dataset.column(strict.alias).data.tolist()
    assert flags == [False, False, True, False, False, False, False, False]
142  
143  
def test_case_sensitive_false(sample_dataset):
    """With case_sensitive=False, "URGENT" matches every spelling of "urgent"."""
    relaxed = TextMatch(column_name="description", match_items=["URGENT"], match_type="contains", case_sensitive=False)
    sample_dataset.add_descriptor(relaxed)

    flags = sample_dataset.column(relaxed.alias).data.tolist()
    assert flags == [True, False, True, False, True, False, False, False]
153  
154  
def test_column_to_column_contains(sample_dataset):
    """Each description is searched for the keywords listed in its own row.

    With the default case sensitivity, "URGENT" (row 2) is not matched by "urgent".
    Rows whose keyword list is empty never match.
    """
    per_row = TextMatch(column_name="description", match_items="keywords", match_type="contains", match_mode="any")
    sample_dataset.add_descriptor(per_row)

    flags = sample_dataset.column(per_row.alias).data.tolist()
    assert flags == [True, True, False, False, True, True, False, True]
162  
163  
def test_column_to_column_not_contains(sample_dataset):
    """not_contains/any against each row's own keyword list.

    A row is flagged when one of its keywords is absent under the default case
    sensitivity: "urgent" vs "URGENT" in row 2, and "spam" vs "Spam" in row 5.
    Empty keyword lists yield False.
    """
    per_row = TextMatch(
        column_name="description",
        match_items="keywords",
        match_type="not_contains",
        match_mode="any",
    )
    sample_dataset.add_descriptor(per_row)

    flags = sample_dataset.column(per_row.alias).data.tolist()
    assert flags == [False, False, True, False, False, True, False, False]
173  
174  
def test_word_boundaries(sample_dataset):
    """Whole-word, case-insensitive matching of "urgent" finds every occurrence."""
    whole_word = TextMatch(
        column_name="description",
        match_items=["urgent"],
        match_type="contains",
        case_sensitive=False,
        word_boundaries=True,
    )
    sample_dataset.add_descriptor(whole_word)

    flags = sample_dataset.column(whole_word.alias).data.tolist()
    assert flags == [True, False, True, False, True, False, False, False]
188  
189  
def test_lemmatization(sample_dataset):
    """Lemmatized whole-word matching finds "filtering" only in the spam row.

    NOTE(review): "filtering" also appears verbatim in that row, so this
    expectation would likely hold without lemmatization too. It shows the option
    runs, not that lemmatizing changes the outcome.
    """
    lemmatized = TextMatch(
        column_name="description",
        match_items=["filtering"],
        match_type="contains",
        case_sensitive=False,
        word_boundaries=True,
        lemmatize=True,
    )
    sample_dataset.add_descriptor(lemmatized)

    flags = sample_dataset.column(lemmatized.alias).data.tolist()
    assert flags == [False, False, False, False, False, True, False, False]
204  
205  
def test_empty_items_list(sample_dataset):
    """An empty match list never matches in contains/any mode."""
    no_items = TextMatch(column_name="description", match_items=[], match_type="contains", match_mode="any")
    sample_dataset.add_descriptor(no_items)
    flags = sample_dataset.column(no_items.alias).data.tolist()

    row_count = len(sample_dataset.as_dataframe())
    assert flags == [False] * row_count
213  
214  
def test_single_item_list(sample_dataset):
    """A one-element list finds "urgent" case-insensitively in rows 0, 2 and 4."""
    lone = TextMatch(column_name="description", match_items=["urgent"], match_type="contains", case_sensitive=False)
    sample_dataset.add_descriptor(lone)

    flags = sample_dataset.column(lone.alias).data.tolist()
    assert flags == [True, False, True, False, True, False, False, False]
224  
225  
def test_none_values(sample_dataset):
    """A missing (None) description must not be reported as a match."""
    # Copy first: the frame returned by as_dataframe() may be the dataset's own
    # storage, and appending a row in place would silently mutate the fixture.
    df = sample_dataset.as_dataframe().copy()
    df.loc[len(df)] = {"description": None, "keywords": [], "single_keyword": ""}
    dataset_with_none = Dataset.from_pandas(df)

    descriptor = TextMatch(column_name="description", match_items=["urgent"], match_type="contains")
    dataset_with_none.add_descriptor(descriptor)
    result = dataset_with_none.column(descriptor.alias)

    # One output per input row, including the appended None row.
    assert len(result.data) == len(df)
    assert not result.data.iloc[-1]
236  
237  
def test_missing_column(sample_dataset):
    """Referencing an absent source column fails when the descriptor is added."""
    bad_source = TextMatch(column_name="nonexistent_column", match_items=["urgent"], match_type="contains")
    error_pattern = "Column 'nonexistent_column' is not found in dataset.*"

    with pytest.raises(ValueError, match=error_pattern):
        sample_dataset.add_descriptor(bad_source)
243  
244  
def test_missing_match_column(sample_dataset):
    """A match_items column name that is not in the dataset fails on add_descriptor."""
    bad_items = TextMatch(column_name="description", match_items="nonexistent_column", match_type="contains")
    error_pattern = "Column 'nonexistent_column' is not found in dataset.*"

    with pytest.raises(ValueError, match=error_pattern):
        sample_dataset.add_descriptor(bad_items)
250  
251  
def test_regex_multiple_patterns_error(sample_dataset):
    """Regex mode rejects more than one pattern when the descriptor is added."""
    patterns = [r"\b(urgent|important)\b", r"\btest\b"]
    two_patterns = TextMatch(column_name="description", match_items=patterns, match_type="regex")

    with pytest.raises(ValueError, match="Regex matching requires exactly one pattern.*"):
        sample_dataset.add_descriptor(two_patterns)
259  
260  
def test_invalid_match_type():
    """Constructing TextMatch with an unsupported match_type raises ValueError naming the field."""
    bad_kwargs = dict(column_name="description", match_items=["urgent"], match_type="invalid")

    with pytest.raises(ValueError, match=".*match_type\n  unexpected value.*"):
        TextMatch(**bad_kwargs)
268  
269  
@pytest.mark.parametrize("match_type", ["contains", "not_contains", "exact", "regex"])
@pytest.mark.parametrize("match_mode", ["any", "all"])
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_parameter_combinations(sample_dataset, match_type, match_mode, case_sensitive):
    """Every supported match_type/match_mode/case_sensitive combination yields one bool per row."""
    # Exact and regex matching don't use match_mode, so run each of them once ("any").
    if match_type == "exact" and match_mode == "all":
        pytest.skip("Exact match doesn't use match_mode")
    if match_type == "regex" and match_mode == "all":
        pytest.skip("Regex match doesn't use match_mode")

    # Regex mode accepts exactly one pattern (several raise, see
    # test_regex_multiple_patterns_error), so fold both keywords into a single
    # alternation rather than skipping regex entirely.
    if match_type == "regex":
        match_items = [r"\b(urgent|important)\b"]
    else:
        match_items = ["urgent", "important"]

    descriptor = TextMatch(
        column_name="description",
        match_items=match_items,
        match_type=match_type,
        match_mode=match_mode,
        case_sensitive=case_sensitive,
    )
    sample_dataset.add_descriptor(descriptor)
    result = sample_dataset.column(descriptor.alias)

    assert len(result.data) == len(sample_dataset.as_dataframe())
    assert all(isinstance(x, bool) for x in result.data)
291  
292  
@pytest.mark.parametrize("lemmatize", [True, False])
@pytest.mark.parametrize("word_boundaries", [True, False])
def test_processing_combinations(sample_dataset, lemmatize, word_boundaries):
    """Every lemmatize/word_boundaries combination yields one boolean per row."""
    options = {"lemmatize": lemmatize, "word_boundaries": word_boundaries, "case_sensitive": False}
    descriptor = TextMatch(column_name="description", match_items=["urgent"], match_type="contains", **options)
    sample_dataset.add_descriptor(descriptor)
    values = sample_dataset.column(descriptor.alias).data

    assert len(values) == len(sample_dataset.as_dataframe())
    assert all(isinstance(value, bool) for value in values)
309  
310  
def test_alias_generation_list_items():
    """An auto-generated alias mentions the source column and the match type."""
    alias = TextMatch(
        column_name="description", match_items=["urgent", "important"], match_type="contains", match_mode="any"
    ).alias

    assert alias is not None
    for fragment in ("description", "contains"):
        assert fragment in alias
318  
319  
def test_alias_generation_column_items():
    """An auto-generated alias for column-driven matching names both columns."""
    alias = TextMatch(
        column_name="description", match_items="keywords", match_type="contains", match_mode="all"
    ).alias

    assert alias is not None
    assert all(name in alias for name in ("description", "keywords"))
325  
326  
def test_custom_alias():
    """An explicitly supplied alias is used verbatim."""
    wanted = "custom_text_match"
    named = TextMatch(column_name="description", match_items=["urgent"], match_type="contains", alias=wanted)
    assert named.alias == wanted
331  
332  
def test_list_items_input_columns():
    """With literal match items, only the searched column is an input."""
    literal = TextMatch(column_name="description", match_items=["urgent", "important"], match_type="contains")
    assert literal.list_input_columns() == ["description"]
337  
338  
def test_column_items_input_columns():
    """Naming a column in match_items adds it to the inputs alongside the searched column."""
    by_column = TextMatch(column_name="description", match_items="keywords", match_type="contains")
    assert {"description", "keywords"} == set(by_column.list_input_columns())