# test_text_match.py
"""Tests for the ``TextMatch`` descriptor.

Each test builds a ``TextMatch`` over the shared sample data, attaches it to
the dataset and compares the produced column against hand-written row-wise
expectations, or checks that invalid configurations raise ``ValueError``.
"""

import pandas as pd
import pytest

from evidently.core.datasets import Dataset
from evidently.descriptors import TextMatch


def _descriptor_values(dataset, descriptor):
    """Attach *descriptor* to *dataset* and return the data of the column it produced."""
    dataset.add_descriptor(descriptor)
    return dataset.column(descriptor.alias).data


@pytest.fixture
def sample_data():
    """Return an eight-row frame with a text column and two keyword columns."""
    return pd.DataFrame(
        {
            "description": [
                "This is urgent and important message",
                "This is just a test message",
                "URGENT: Please respond immediately",
                "Normal message without keywords",
                "Contains both urgent and important words",
                "Spam test message for filtering",
                "Empty message",
                "Message with numbers 123 and symbols @#$",
            ],
            "keywords": [
                ["urgent", "important"],
                ["test"],
                ["urgent"],
                [],
                ["urgent", "important"],
                ["spam", "test"],
                [],
                ["numbers"],
            ],
            "single_keyword": [
                "urgent",
                "test",
                "urgent",
                "",
                "urgent",
                "spam",
                "",
                "numbers",
            ],
        }
    )


@pytest.fixture
def sample_dataset(sample_data):
    """Wrap the sample frame in a ``Dataset``."""
    return Dataset.from_pandas(sample_data)


def test_contains_any_mode(sample_dataset):
    """``contains``/``any``, case-insensitive: rows with at least one item are flagged."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent", "important"],
        match_type="contains",
        match_mode="any",
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, False, True, False, False, False]


def test_contains_all_mode(sample_dataset):
    """``contains``/``all``, case-insensitive: only rows with every item are flagged."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent", "important"],
        match_type="contains",
        match_mode="all",
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, False, False, True, False, False, False]


def test_not_contains_any_mode(sample_dataset):
    """``not_contains``/``any``: only the row holding both items is expected to be False."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["spam", "test"],
        match_type="not_contains",
        match_mode="any",
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, True, True, True, True, False, True, True]


def test_not_contains_all_mode(sample_dataset):
    """``not_contains``/``all``: rows holding any of the items are expected to be False."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["spam", "test"],
        match_type="not_contains",
        match_mode="all",
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, True, True, False, True, True]


def test_exact_match(sample_dataset):
    """With ``match_type="exact"`` none of the sample descriptions is expected to match."""
    descriptor = TextMatch(
        column_name="description", match_items=["urgent", "test"], match_type="exact", case_sensitive=False
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [False, False, False, False, False, False, False, False]


def test_regex_match(sample_dataset):
    """A single case-insensitive regex pattern flags the rows it matches."""
    descriptor = TextMatch(
        column_name="description", match_items=[r"\b(urgent|important)\b"], match_type="regex", case_sensitive=False
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, False, True, False, False, False]


def test_case_sensitive_true(sample_dataset):
    """Case-sensitive ``contains`` flags only the row with upper-case ``URGENT``."""
    descriptor = TextMatch(
        column_name="description", match_items=["URGENT"], match_type="contains", case_sensitive=True
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [False, False, True, False, False, False, False, False]


def test_case_sensitive_false(sample_dataset):
    """Case-insensitive ``contains`` flags every row with ``urgent`` in any casing."""
    descriptor = TextMatch(
        column_name="description", match_items=["URGENT"], match_type="contains", case_sensitive=False
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, False, True, False, False, False]


def test_column_to_column_contains(sample_dataset):
    """``match_items`` given as a column name: each row is checked against its own keywords."""
    descriptor = TextMatch(column_name="description", match_items="keywords", match_type="contains", match_mode="any")
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, True, False, False, True, True, False, True]


def test_column_to_column_not_contains(sample_dataset):
    """``not_contains`` with per-row keywords taken from the ``keywords`` column."""
    descriptor = TextMatch(
        column_name="description", match_items="keywords", match_type="not_contains", match_mode="any"
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [False, False, True, False, False, True, False, False]


def test_word_boundaries(sample_dataset):
    """``word_boundaries=True`` still flags the rows where ``urgent`` appears as a word."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent"],
        match_type="contains",
        word_boundaries=True,
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, False, True, False, False, False]


def test_lemmatization(sample_dataset):
    """With ``lemmatize`` and ``word_boundaries`` on, only the ``filtering`` row is flagged."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["filtering"],
        match_type="contains",
        lemmatize=True,
        word_boundaries=True,
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [False, False, False, False, False, True, False, False]


def test_empty_items_list(sample_dataset):
    """An empty item list is expected to yield False for every row."""
    descriptor = TextMatch(column_name="description", match_items=[], match_type="contains", match_mode="any")
    values = _descriptor_values(sample_dataset, descriptor)

    expected = [False] * len(sample_dataset.as_dataframe())
    assert values.tolist() == expected


def test_single_item_list(sample_dataset):
    """A one-element item list works without an explicit ``match_mode``."""
    descriptor = TextMatch(
        column_name="description", match_items=["urgent"], match_type="contains", case_sensitive=False
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert values.tolist() == [True, False, True, False, True, False, False, False]


def test_none_values(sample_dataset):
    """A row whose text is ``None`` is expected to produce a falsy result."""
    df = sample_dataset.as_dataframe()
    # Append one extra row with a missing description.
    df.loc[len(df)] = {"description": None, "keywords": [], "single_keyword": ""}
    dataset_with_none = Dataset.from_pandas(df)

    descriptor = TextMatch(column_name="description", match_items=["urgent"], match_type="contains")
    values = _descriptor_values(dataset_with_none, descriptor)

    assert not values.iloc[-1]


def test_missing_column(sample_dataset):
    """Pointing ``column_name`` at an unknown column raises ``ValueError`` on attach."""
    descriptor = TextMatch(column_name="nonexistent_column", match_items=["urgent"], match_type="contains")

    with pytest.raises(ValueError, match="Column 'nonexistent_column' is not found in dataset.*"):
        sample_dataset.add_descriptor(descriptor)


def test_missing_match_column(sample_dataset):
    """Pointing ``match_items`` at an unknown column raises ``ValueError`` on attach."""
    descriptor = TextMatch(column_name="description", match_items="nonexistent_column", match_type="contains")

    with pytest.raises(ValueError, match="Column 'nonexistent_column' is not found in dataset.*"):
        sample_dataset.add_descriptor(descriptor)


def test_regex_multiple_patterns_error(sample_dataset):
    """Regex matching with more than one pattern raises ``ValueError`` on attach."""
    descriptor = TextMatch(
        column_name="description", match_items=[r"\b(urgent|important)\b", r"\btest\b"], match_type="regex"
    )

    with pytest.raises(ValueError, match="Regex matching requires exactly one pattern.*"):
        sample_dataset.add_descriptor(descriptor)


def test_invalid_match_type():
    """An unsupported ``match_type`` is rejected when the descriptor is constructed."""
    # `match` is applied with re.search, so no surrounding ".*" is needed.
    # `\s+` keeps the check independent of how the validation error indents
    # the message line that follows the field name.
    with pytest.raises(ValueError, match=r"match_type\s+unexpected value"):
        TextMatch(
            column_name="description",
            match_items=["urgent"],
            match_type="invalid",  # type: ignore
        )


@pytest.mark.parametrize("match_type", ["contains", "not_contains", "exact", "regex"])
@pytest.mark.parametrize("match_mode", ["any", "all"])
@pytest.mark.parametrize("case_sensitive", [True, False])
def test_parameter_combinations(sample_dataset, match_type, match_mode, case_sensitive):
    """Smoke-test parameter combinations: one plain ``bool`` per input row."""
    if match_type == "exact" and match_mode == "all":
        pytest.skip("Exact match doesn't use match_mode")
    if match_type == "regex":
        pytest.skip("Regex match doesn't use match_mode")

    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent", "important"],
        match_type=match_type,
        match_mode=match_mode,
        case_sensitive=case_sensitive,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert len(values) == len(sample_dataset.as_dataframe())
    assert all(isinstance(x, bool) for x in values)


@pytest.mark.parametrize("lemmatize", [True, False])
@pytest.mark.parametrize("word_boundaries", [True, False])
def test_processing_combinations(sample_dataset, lemmatize, word_boundaries):
    """Smoke-test ``lemmatize`` x ``word_boundaries``: one plain ``bool`` per input row."""
    descriptor = TextMatch(
        column_name="description",
        match_items=["urgent"],
        match_type="contains",
        lemmatize=lemmatize,
        word_boundaries=word_boundaries,
        case_sensitive=False,
    )
    values = _descriptor_values(sample_dataset, descriptor)

    assert len(values) == len(sample_dataset.as_dataframe())
    assert all(isinstance(x, bool) for x in values)


def test_alias_generation_list_items():
    """The generated alias mentions the source column and the match type."""
    descriptor = TextMatch(
        column_name="description", match_items=["urgent", "important"], match_type="contains", match_mode="any"
    )
    assert descriptor.alias is not None
    assert "description" in descriptor.alias
    assert "contains" in descriptor.alias


def test_alias_generation_column_items():
    """With column-based items the generated alias mentions both column names."""
    descriptor = TextMatch(column_name="description", match_items="keywords", match_type="contains", match_mode="all")
    assert descriptor.alias is not None
    assert "description" in descriptor.alias
    assert "keywords" in descriptor.alias


def test_custom_alias():
    """An explicitly supplied alias is kept as-is."""
    custom_alias = "custom_text_match"
    descriptor = TextMatch(column_name="description", match_items=["urgent"], match_type="contains", alias=custom_alias)
    assert descriptor.alias == custom_alias


def test_list_items_input_columns():
    """With literal items only the text column is reported as an input column."""
    descriptor = TextMatch(column_name="description", match_items=["urgent", "important"], match_type="contains")
    input_columns = descriptor.list_input_columns()
    assert input_columns == ["description"]


def test_column_items_input_columns():
    """With column-based items both columns are reported as input columns."""
    descriptor = TextMatch(column_name="description", match_items="keywords", match_type="contains")
    input_columns = descriptor.list_input_columns()
    assert set(input_columns) == {"description", "keywords"}