# test_check_types.py
"""Verify check types."""
# pyright: reportUnusedCallResult=false

from bisect import bisect_right
from dataclasses import dataclass
from itertools import accumulate, chain, combinations, repeat
from re import escape
from string import ascii_letters, ascii_lowercase

from hypothesis import assume, given
from hypothesis import strategies as st
from rstr import xeger

from proselint.registry.checks import BATCH_COUNT, Check, Padding, engine, types
from tests.common import engine_from

PADDING_STRATEGY = st.sampled_from(Padding)


@dataclass(frozen=True)
class _TermStrategies:
    """Group the text, pair, and pair-list strategies for one alphabet."""

    # Non-empty text drawn from the alphabet.
    item_text: st.SearchStrategy[str]
    # Two items where neither contains the other (compared in lowercase).
    term_pair: st.SearchStrategy[tuple[str, str]]
    # Between 1 and BATCH_COUNT pairs whose terms never contain one another.
    term_pairs: st.SearchStrategy[list[tuple[str, str]]]

    @staticmethod
    def _not_substring(pair: tuple[str, str]) -> bool:
        """Return True when neither lowercased term contains the other."""
        first, second = map(str.lower, pair)
        return not (first in second or second in first)

    @classmethod
    def from_alphabet(cls, alphabet: str) -> "_TermStrategies":
        """Build all three strategies from the characters in `alphabet`."""

        def terms_are_disjoint(candidate: list[tuple[str, str]]) -> bool:
            # Flatten the pairs so every term is compared with every other
            # term, not just with its own partner.
            flattened = chain.from_iterable(candidate)
            return all(
                cls._not_substring(combo)
                for combo in combinations(flattened, 2)
            )

        single: st.SearchStrategy[str] = st.text(
            min_size=1, alphabet=alphabet
        )
        couple: st.SearchStrategy[tuple[str, str]] = st.tuples(
            single, single
        ).filter(cls._not_substring)
        couples: st.SearchStrategy[list[tuple[str, str]]] = st.lists(
            couple, min_size=1, max_size=BATCH_COUNT
        ).filter(terms_are_disjoint)

        return cls(single, couple, couples)


def n_counts(data: st.DataObject, n: int) -> list[int]:
    """Draw exactly `n` repetition counts, each between 1 and 100."""
    counts_strategy = st.lists(st.integers(1, 100), min_size=n, max_size=n)
    return data.draw(counts_strategy)


LOWER_STRATEGIES = _TermStrategies.from_alphabet(ascii_lowercase)
ITEM_TEXT_STRATEGY = LOWER_STRATEGIES.item_text
TERM_PAIR_STRATEGY = LOWER_STRATEGIES.term_pair
TERM_PAIRS_STRATEGY = LOWER_STRATEGIES.term_pairs


CASED_STRATEGIES = _TermStrategies.from_alphabet(ascii_letters)


# Consistency
@given(TERM_PAIR_STRATEGY, st.text(), st.text())
def test_consistency_smoke(
    term_pair: tuple[str, str], path: str, noise: str
) -> None:
    """Yield nothing when neither term of the pair occurs in the text."""
    lowered_noise = noise.lower()
    assume(not any(term in lowered_noise for term in term_pair))
    check = Check(
        check_type=types.Consistency(term_pairs=(term_pair,)),
        path=path,
        message="{} || {}",
    )
    assert list(check.check(noise)) == []


@given(TERM_PAIR_STRATEGY, st.text(), st.integers(0, 1))
def test_consistency_one_in_text(
    term_pair: tuple[str, str], path: str, choice: int
) -> None:
    """Yield nothing when only one term of the pair occurs in the text."""
    check = Check(
        check_type=types.Consistency(term_pairs=(term_pair,)),
        path=path,
        message="{} || {}",
    )
    lone_term = term_pair[choice]
    assert list(check.check(lone_term)) == []


@given(
    TERM_PAIR_STRATEGY,
    st.text(),
    st.tuples(st.integers(1, 100), st.integers(1, 100)),
)
def test_consistency_both_in_text(
    term_pair: tuple[str, str], path: str, count: tuple[int, int]
) -> None:
    """Yield one result per minority occurrence when both terms occur."""
    check = Check(
        check_type=types.Consistency(term_pairs=(term_pair,)),
        path=path,
        message="{} || {}",
    )

    first, second = term_pair
    n_first, n_second = count
    text = f"{first} " * n_first + f"{second} " * n_second
    results = list(check.check(text))
    assert len(results) == min(count)
    assert all(result.check_path == path for result in results)

    # The message names the more frequent term first; on a tie the second
    # term of the pair takes that slot.
    if n_first > n_second:
        expected_message = f"{first} || {second}"
    else:
        expected_message = f"{second} || {first}"
    assert all(result.message == expected_message for result in results)


# PreferredForms
PREFERRED_ITEMS_STRATEGY: st.SearchStrategy[dict[str, str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    dict,
    LOWER_STRATEGIES.term_pairs,
)

PREFERRED_ITEMS_CASED_STRATEGY: st.SearchStrategy[dict[str, str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    dict,
    CASED_STRATEGIES.term_pairs,
)
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.text())
def test_preferred_smoke(
    items: dict[str, str], padding: Padding, path: str, noise: str
) -> None:
    """Yield nothing when none of the flagged forms occur in the text."""
    lowered_noise = noise.lower()
    assume(not any(form in lowered_noise for form in items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    assert list(check.check(noise)) == []


@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_values_smoke(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Yield nothing when the text holds only the preferred replacements."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    repeated_values = (
        repeat(value, times)
        for value, times in zip(items.values(), count, strict=False)
    )
    content = " ".join(chain.from_iterable(repeated_values))
    assert list(check.check(content)) == []


@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_in_text(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Yield one correctly labelled result per flagged form in the text."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )

    # One generated string per key, repeated as often as its drawn count.
    selected_matches: list[str] = []
    for form, times in zip(items, count, strict=True):
        generated = xeger(padding.format(escape(form)))
        selected_matches.extend(repeat(generated, times))

    # NOTE: rstr.xeger does not work for nonword boundaries, so this is
    # necessary. See leapfrogonline/rstr#45.
    if padding == Padding.NONWORDS_IN_TEXT:
        selected_matches = [f"_{match}_" for match in selected_matches]

    results = list(check.check(" ".join(selected_matches)))
    assert len(results) == sum(count)

    # The running totals map a result index back to the key that produced
    # it, and therefore to the replacement it should carry.
    boundaries = list(accumulate(count))
    preferred = list(items.values())
    for idx, result in enumerate(results):
        replacement = preferred[bisect_right(boundaries, idx)]
        entry = selected_matches[idx].strip()
        if padding == Padding.NONWORDS_IN_TEXT:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == f"{replacement} || {entry}"
        assert result.replacements == replacement


# PreferredFormsSimple
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.text())
def test_preferred_s_smoke(
    items: dict[str, str], padding: Padding, path: str, noise: str
) -> None:
    """Yield nothing when none of the flagged forms occur in the text."""
    lowered_noise = noise.lower()
    assume(not any(form in lowered_noise for form in items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    assert list(check.check(noise)) == []


@given(PREFERRED_ITEMS_CASED_STRATEGY, PADDING_STRATEGY, st.text())
def test_preferred_s_case_sensitive(
    items: dict[str, str], padding: Padding, path: str
) -> None:
    """Label results correctly when the engine is case-sensitive."""
    case_sensitive = engine.RegexOptions(case_insensitive=False)
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding, case_sensitive),
    )

    text = " ".join(items)
    # A second check type with the same arguments is run directly against
    # the check built above.
    check_type = types.PreferredFormsSimple(items=items, padding=padding)
    results = list(check_type.check(text, check))

    expected_pairs = items.items()
    for result, pair in zip(results, expected_pairs, strict=False):
        original, replacement = pair
        assert result.check_path == path
        assert result.message == f"{replacement} || {original}"
        assert result.replacements == replacement


@given(PREFERRED_ITEMS_CASED_STRATEGY, PADDING_STRATEGY, st.text())
def test_preferred_s_case_insensitive(
    items: dict[str, str], padding: Padding, path: str
) -> None:
    """Label results correctly when the engine ignores case."""
    normalised = {form.lower(): value for form, value in items.items()}

    check_type = types.PreferredFormsSimple(items=normalised, padding=padding)
    check = Check(
        check_type=check_type,
        path=path,
        message="{} || {}",
        engine=engine_from(padding, engine.RegexOptions(case_insensitive=True)),
    )

    # Alternate between swapped case (even positions) and upper case (odd
    # positions) so the text differs in case from the lowercased keys.
    text_variations = [
        form.upper() if position % 2 else form.swapcase()
        for position, form in enumerate(items)
    ]

    joined = " ".join(text_variations)
    results = list(check_type.check(joined, check))

    for result, text in zip(results, text_variations, strict=False):
        replacements = normalised[text.lower()]

        assert result.check_path == path
        assert result.message == f"{replacements} || {text}"
        assert result.replacements == replacements


@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_s_values_smoke(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Yield nothing when the text holds only the preferred replacements."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    repeated_values = (
        repeat(value, times)
        for value, times in zip(items.values(), count, strict=False)
    )
    content = " ".join(chain.from_iterable(repeated_values))
    assert list(check.check(content)) == []


@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_s_in_text(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Yield one correctly labelled result per flagged form in the text."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )

    # One generated string per key, repeated as often as its drawn count.
    selected_matches: list[str] = []
    for form, times in zip(items, count, strict=True):
        generated = xeger(padding.format(escape(form)))
        selected_matches.extend(repeat(generated, times))

    # NOTE: rstr.xeger does not work for nonword boundaries, so this is
    # necessary. See leapfrogonline/rstr#45.
    if padding == Padding.NONWORDS_IN_TEXT:
        selected_matches = [f"_{match}_" for match in selected_matches]

    results = list(check.check(" ".join(selected_matches)))
    assert len(results) == sum(count)

    # The running totals map a result index back to the key that produced
    # it, and therefore to the replacement it should carry.
    boundaries = list(accumulate(count))
    preferred = list(items.values())
    for idx, result in enumerate(results):
        replacement = preferred[bisect_right(boundaries, idx)]
        entry = selected_matches[idx].strip()
        if padding == Padding.NONWORDS_IN_TEXT:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == f"{replacement} || {entry}"
        assert result.replacements == replacement


# Existence
EXISTENCE_ITEMS_STRATEGY: st.SearchStrategy[tuple[str, ...]] = st.builds(  # pyright: ignore[reportAssignmentType]
    tuple,
    st.lists(ITEM_TEXT_STRATEGY, min_size=1, max_size=BATCH_COUNT).filter(
        # Reject lists where any item contains another item.
        lambda x: all(
            left not in right and right not in left
            for left, right in combinations(x, 2)
        )
    ),
)


@st.composite
def items_split(draw: st.DrawFn) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Draw existence items and cut them into (items, exceptions)."""
    drawn = draw(EXISTENCE_ITEMS_STRATEGY)
    # The cut point is at least 1, so the items half is never empty; the
    # exceptions half is empty when the cut lands at the end.
    cut = draw(st.integers(1, len(drawn)))
    head, tail = drawn[:cut], drawn[cut:]
    return (head, tail)
@given(items_split(), PADDING_STRATEGY, st.text(), st.text())
def test_existence_smoke(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    noise: str,
) -> None:
    """Yield nothing when none of the items occur in the text."""
    items, exceptions = items_exceptions
    lowered_noise = noise.lower()
    assume(not any(item in lowered_noise for item in items))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )
    assert list(check.check(noise)) == []


@given(items_split(), PADDING_STRATEGY, st.text(), st.data())
def test_existence_exceptions_smoke(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    data: st.DataObject,
) -> None:
    """Yield nothing when the text holds only exceptions."""
    items, exceptions = items_exceptions
    count = n_counts(data, len(exceptions))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )
    repeated_exceptions = (
        repeat(exception, times)
        for exception, times in zip(exceptions, count, strict=False)
    )
    content = " ".join(chain.from_iterable(repeated_exceptions))
    assert list(check.check(content)) == []


@given(items_split(), PADDING_STRATEGY, st.text(), st.data())
def test_existence_in_text(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    data: st.DataObject,
) -> None:
    """Yield one result per item occurrence and none for exceptions."""
    items, exceptions = items_exceptions
    count = n_counts(data, len(items) + len(exceptions))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )

    # Items come first and exceptions after, each generated once and then
    # repeated as often as its drawn count.
    selected_matches: list[str] = []
    for term, times in zip(chain(items, exceptions), count, strict=True):
        generated = xeger(padding.format(escape(term)))
        selected_matches.extend(repeat(generated, times))

    # NOTE: rstr.xeger does not work for nonword boundaries, so this is
    # necessary. See leapfrogonline/rstr#45.
    if padding == Padding.NONWORDS_IN_TEXT:
        selected_matches = [f"_{match}_" for match in selected_matches]

    results = list(check.check(" ".join(selected_matches)))
    # Only the occurrences generated from `items` are expected to match.
    expected_total = sum(count[: len(items)])
    assert len(results) == expected_total

    for result, raw_entry in zip(results, selected_matches, strict=False):
        entry = raw_entry.strip()
        if padding == Padding.NONWORDS_IN_TEXT:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == entry
        assert result.replacements is None


# ExistenceSimple
@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.text())
def test_existence_s_smoke(
    items_exceptions: tuple[str, ...],
    path: str,
    noise: str,
) -> None:
    """Yield nothing when the pattern does not occur in the text."""
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    assume(pattern not in noise.lower())
    check = Check(
        check_type=types.ExistenceSimple(
            pattern=pattern, exceptions=exceptions
        ),
        path=path,
        message="{}",
    )
    assert list(check.check(noise)) == []


@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.data())
def test_existence_s_exceptions_smoke(
    items_exceptions: tuple[str, ...],
    path: str,
    data: st.DataObject,
) -> None:
    """Yield nothing when the text holds only exceptions."""
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    count = n_counts(data, len(exceptions))
    check = Check(
        check_type=types.ExistenceSimple(
            pattern=pattern, exceptions=exceptions
        ),
        path=path,
        message="{}",
    )
    repeated_exceptions = (
        repeat(exception, times)
        for exception, times in zip(exceptions, count, strict=False)
    )
    content = " ".join(chain.from_iterable(repeated_exceptions))
    assert list(check.check(content)) == []


@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.data())
def test_existence_s_in_text(
    items_exceptions: tuple[str, ...],
    path: str,
    data: st.DataObject,
) -> None:
    """Yield one result per pattern occurrence and none for exceptions."""
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    count = n_counts(data, len(items_exceptions))
    check = Check(
        check_type=types.ExistenceSimple(
            pattern=pattern, exceptions=exceptions
        ),
        path=path,
        message="{}",
    )

    # The pattern comes first and the exceptions after, each generated once
    # and then repeated as often as its drawn count.
    selected_matches: list[str] = []
    for term, times in zip(items_exceptions, count, strict=True):
        selected_matches.extend(repeat(xeger(escape(term)), times))

    results = list(check.check(" ".join(selected_matches)))
    # Only the first entry is the pattern, so only its repetitions match.
    assert len(results) == count[0]
    for result, entry in zip(results, selected_matches, strict=False):
        assert result.check_path == path
        assert result.message == entry.strip()
        assert result.replacements is None


# ReverseExistence
TOKEN_STRATEGY = st.from_regex(
    types._DEFAULT_TOKENIZER,  # pyright: ignore[reportPrivateUsage]
    alphabet=f"{ascii_letters}'-_",
)
TOKENS_STRATEGY = st.lists(TOKEN_STRATEGY, min_size=1, max_size=BATCH_COUNT)
NON_TOKEN_STRATEGY = st.one_of(st.just(r"\w\w?"), st.just(r"\w*\d+\w*"))
NON_TOKENS_STRATEGY = st.lists(
    NON_TOKEN_STRATEGY, min_size=1, max_size=BATCH_COUNT
)
REV_ALLOWED_STRATEGY: st.SearchStrategy[set[str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    set, EXISTENCE_ITEMS_STRATEGY
)

# NOTE: traditional smoke is not applicable here, since any tokenizable noise
# beyond the allowed set would trigger a result.
@given(REV_ALLOWED_STRATEGY, st.text(), st.data())
def test_rev_existence_allowed_smoke(
    allowed: set[str], path: str, data: st.DataObject
) -> None:
    """Return no matches when only allowed items are present.

    The text is built solely from members of `allowed`, each repeated a
    drawn number of times, so the check is expected to yield nothing.
    """
    count = n_counts(data, len(allowed))
    check_type = types.ReverseExistence(allowed=allowed)
    check = Check(check_type=check_type, path=path, message="{}")
    content = " ".join(chain.from_iterable(map(repeat, allowed, count)))
    # Materialise the results before asserting: the bare return value of
    # `check.check` is an iterable whose truthiness says nothing about
    # whether it yields any results, so `assert check.check(content)` could
    # never fail.
    assert list(check.check(content)) == []


@given(REV_ALLOWED_STRATEGY, st.text(), NON_TOKENS_STRATEGY)
def test_rev_existence_non_token_smoke(
    allowed: set[str], path: str, non_tokens: list[str]
) -> None:
    """Return no matches when only non-tokenizable items are present.

    Each entry of `non_tokens` is a regex; `xeger` turns it into a concrete
    string, and the joined strings form the text under test.
    """
    check_type = types.ReverseExistence(allowed=allowed)
    check = Check(check_type=check_type, path=path, message="{}")
    content = " ".join(map(xeger, non_tokens))
    assert list(check.check(content)) == []


@given(REV_ALLOWED_STRATEGY, st.text(), TOKENS_STRATEGY)
def test_rev_existence_forbidden(
    allowed: set[str], path: str, tokens: list[str]
) -> None:
    """Return correct matches when items not in the allowed set are present.

    Every generated token is outside `allowed`, so each one must produce
    exactly one result, in order.
    """
    # Discard examples where a generated token happens to be allowed.
    assume(allowed.isdisjoint(tokens))
    check_type = types.ReverseExistence(allowed=allowed)
    check = Check(check_type=check_type, path=path, message="{}")
    results = list(check.check(" ".join(tokens)))

    assert len(results) == len(tokens)

    for result, token in zip(results, tokens, strict=True):
        assert result.check_path == path
        # The reported text is the token without leading or trailing
        # apostrophes and hyphens.
        assert result.message == token.strip("'-")
        assert result.replacements is None