/ tests / test_check_types.py
test_check_types.py
  1  """Verify check types."""
  2  # pyright: reportUnusedCallResult=false
  3  
  4  from bisect import bisect_right
  5  from dataclasses import dataclass
  6  from itertools import accumulate, chain, combinations, repeat
  7  from re import escape
  8  from string import ascii_letters, ascii_lowercase
  9  
 10  from hypothesis import assume, given
 11  from hypothesis import strategies as st
 12  from rstr import xeger
 13  
 14  from proselint.registry.checks import BATCH_COUNT, Check, Padding, engine, types
 15  from tests.common import engine_from
 16  
 17  PADDING_STRATEGY = st.sampled_from(Padding)
 18  
 19  
 20  @dataclass(frozen=True)
 21  class _TermStrategies:
 22      item_text: st.SearchStrategy[str]
 23      term_pair: st.SearchStrategy[tuple[str, str]]
 24      term_pairs: st.SearchStrategy[list[tuple[str, str]]]
 25  
 26      @staticmethod
 27      def _not_substring(pair: tuple[str, str]) -> bool:
 28          a, b = (s.lower() for s in pair)
 29          return a not in b and b not in a
 30  
 31      @classmethod
 32      def from_alphabet(cls, alphabet: str) -> "_TermStrategies":
 33          item: st.SearchStrategy[str] = st.text(min_size=1, alphabet=alphabet)
 34          pair: st.SearchStrategy[tuple[str, str]] = st.tuples(item, item).filter(
 35              cls._not_substring
 36          )
 37          pairs: st.SearchStrategy[list[tuple[str, str]]] = st.lists(
 38              pair, min_size=1, max_size=BATCH_COUNT
 39          ).filter(
 40              lambda xs: all(map(cls._not_substring, combinations(chain(*xs), 2)))
 41          )
 42  
 43          return cls(item, pair, pairs)
 44  
 45  
 46  def n_counts(data: st.DataObject, n: int) -> list[int]:
 47      """Generate a list of repetition counts for a collection of items."""
 48      return data.draw(st.lists(st.integers(1, 100), min_size=n, max_size=n))
 49  
 50  
 51  LOWER_STRATEGIES = _TermStrategies.from_alphabet(ascii_lowercase)
 52  ITEM_TEXT_STRATEGY = LOWER_STRATEGIES.item_text
 53  TERM_PAIR_STRATEGY = LOWER_STRATEGIES.term_pair
 54  TERM_PAIRS_STRATEGY = LOWER_STRATEGIES.term_pairs
 55  
 56  
 57  CASED_STRATEGIES = _TermStrategies.from_alphabet(ascii_letters)
 58  
 59  
 60  # Consistency
 61  @given(TERM_PAIR_STRATEGY, st.text(), st.text())
 62  def test_consistency_smoke(
 63      term_pair: tuple[str, str], path: str, noise: str
 64  ) -> None:
 65      """Return no matches when no elements are present."""
 66      assume(all(term not in noise.lower() for term in term_pair))
 67      check_type = types.Consistency(term_pairs=(term_pair,))
 68      check = Check(check_type=check_type, path=path, message="{} || {}")
 69      assert list(check.check(noise)) == []
 70  
 71  
 72  @given(TERM_PAIR_STRATEGY, st.text(), st.integers(0, 1))
 73  def test_consistency_one_in_text(
 74      term_pair: tuple[str, str], path: str, choice: int
 75  ) -> None:
 76      """Return no matches when only one element is present."""
 77      check_type = types.Consistency(term_pairs=(term_pair,))
 78      check = Check(check_type=check_type, path=path, message="{} || {}")
 79      assert list(check.check(term_pair[choice])) == []
 80  
 81  
 82  @given(
 83      TERM_PAIR_STRATEGY,
 84      st.text(),
 85      st.tuples(st.integers(1, 100), st.integers(1, 100)),
 86  )
 87  def test_consistency_both_in_text(
 88      term_pair: tuple[str, str], path: str, count: tuple[int, int]
 89  ) -> None:
 90      """Return correct matches when both elements are present."""
 91      check_type = types.Consistency(term_pairs=(term_pair,))
 92      check = Check(check_type=check_type, path=path, message="{} || {}")
 93  
 94      text = f"{term_pair[0]} " * count[0] + f"{term_pair[1]} " * count[1]
 95      results = list(check.check(text))
 96      assert len(results) == min(count)
 97      assert all(result.check_path == path for result in results)
 98      idx_minority = count[0] > count[1]
 99      assert all(
100          result.message
101          == f"{term_pair[not idx_minority]} || {term_pair[idx_minority]}"
102          for result in results
103      )
104  
105  
106  # PreferredForms
# Mappings of term -> preferred replacement built from mutually disjoint pairs:
# no key or value is a case-insensitive substring of any other key or value.
PREFERRED_ITEMS_STRATEGY: st.SearchStrategy[dict[str, str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    dict, LOWER_STRATEGIES.term_pairs
)

# The same kind of mapping, drawn from mixed-case ASCII letters.
PREFERRED_ITEMS_CASED_STRATEGY: st.SearchStrategy[dict[str, str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    dict, CASED_STRATEGIES.term_pairs
)
116  
117  
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.text())
def test_preferred_smoke(
    items: dict[str, str], padding: Padding, path: str, noise: str
) -> None:
    """Report nothing for noise that contains none of the terms."""
    lowered = noise.lower()
    assume(not any(term in lowered for term in items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    results = list(check.check(noise))
    assert results == []
132  
133  
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_values_smoke(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Report nothing when the text holds only the preferred replacements."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    content = " ".join(
        value
        for value, times in zip(items.values(), count, strict=True)
        for _ in range(times)
    )
    results = list(check.check(content))
    assert results == []
149  
150  
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_in_text(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Flag every padded occurrence of a term with its preferred replacement."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredForms(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    wraps_nonwords = padding == Padding.NONWORDS_IN_TEXT

    # One padded sample per term, repeated `times` times, with the expected
    # replacement tracked alongside each repetition.
    selected_matches: list[str] = []
    expected_replacements: list[str] = []
    for (term, replacement), times in zip(items.items(), count, strict=True):
        sample = xeger(padding.format(escape(term)))
        # NOTE: rstr.xeger does not work for nonword boundaries, so this is
        # necessary. See leapfrogonline/rstr#45.
        if wraps_nonwords:
            sample = f"_{sample}_"
        selected_matches.extend(repeat(sample, times))
        expected_replacements.extend(repeat(replacement, times))

    results = list(check.check("  ".join(selected_matches)))
    assert len(results) == sum(count)

    for result, raw_entry, replacement in zip(
        results, selected_matches, expected_replacements, strict=True
    ):
        entry = raw_entry.strip()
        if wraps_nonwords:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == f"{replacement} || {entry}"
        assert result.replacements == replacement
191  
192  
193  # PreferredFormsSimple
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.text())
def test_preferred_s_smoke(
    items: dict[str, str], padding: Padding, path: str, noise: str
) -> None:
    """Report nothing for noise that contains none of the terms."""
    lowered = noise.lower()
    assume(not any(term in lowered for term in items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    results = list(check.check(noise))
    assert results == []
208  
209  
@given(PREFERRED_ITEMS_CASED_STRATEGY, PADDING_STRATEGY, st.text())
def test_preferred_s_case_sensitive(
    items: dict[str, str], padding: Padding, path: str
) -> None:
    """Return correct matches in case-sensitive mode for mixed-case terms.

    The text contains each mixed-case term exactly as written. Every reported
    match must carry the replacement of the term at the same position.
    """
    # Build the check type once and use that same instance both inside the
    # Check and for the direct matching call below.
    check_type = types.PreferredFormsSimple(items=items, padding=padding)
    check = Check(
        check_type=check_type,
        path=path,
        message="{} || {}",
        engine=engine_from(
            padding, engine.RegexOptions(case_insensitive=False)
        ),
    )

    text = " ".join(items.keys())
    results = list(check_type.check(text, check))

    # NOTE(review): zip is non-strict, so only the results that were produced
    # are validated; this does not assert that every term was found.
    for result, (original, replacement) in zip(
        results, items.items(), strict=False
    ):
        assert result.check_path == path
        assert result.message == f"{replacement} || {original}"
        assert result.replacements == replacement
234  
235  
@given(PREFERRED_ITEMS_CASED_STRATEGY, PADDING_STRATEGY, st.text())
def test_preferred_s_case_insensitive(
    items: dict[str, str], padding: Padding, path: str
) -> None:
    """Match terms regardless of case when the engine ignores case."""
    normalised = {key.lower(): value for key, value in items.items()}

    check_type = types.PreferredFormsSimple(items=normalised, padding=padding)
    check = Check(
        check_type=check_type,
        path=path,
        message="{} || {}",
        engine=engine_from(padding, engine.RegexOptions(case_insensitive=True)),
    )

    # Even positions are case-swapped, odd positions are upper-cased.
    text_variations = [
        key.upper() if position % 2 else key.swapcase()
        for position, key in enumerate(items)
    ]
    results = list(check_type.check(" ".join(text_variations), check))

    for result, variant in zip(results, text_variations, strict=False):
        expected = normalised[variant.lower()]
        assert result.check_path == path
        assert result.message == f"{expected} || {variant}"
        assert result.replacements == expected
264  
265  
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_s_values_smoke(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Report nothing when the text holds only the preferred replacements."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    content = " ".join(
        value
        for value, times in zip(items.values(), count, strict=True)
        for _ in range(times)
    )
    results = list(check.check(content))
    assert results == []
281  
282  
@given(PREFERRED_ITEMS_STRATEGY, PADDING_STRATEGY, st.text(), st.data())
def test_preferred_s_in_text(
    items: dict[str, str], padding: Padding, path: str, data: st.DataObject
) -> None:
    """Flag every padded occurrence of a term with its preferred replacement."""
    count = n_counts(data, len(items))
    check = Check(
        check_type=types.PreferredFormsSimple(items=items, padding=padding),
        path=path,
        message="{} || {}",
        engine=engine_from(padding),
    )
    wraps_nonwords = padding == Padding.NONWORDS_IN_TEXT

    samples: list[str] = []
    expected_replacements: list[str] = []
    for (term, replacement), times in zip(items.items(), count, strict=True):
        sample = xeger(padding.format(escape(term)))
        # NOTE: rstr.xeger does not work for nonword boundaries, so this is
        # necessary. See leapfrogonline/rstr#45.
        if wraps_nonwords:
            sample = f"_{sample}_"
        samples += [sample] * times
        expected_replacements += [replacement] * times

    results = list(check.check("  ".join(samples)))
    assert len(results) == sum(count)

    for result, raw_entry, replacement in zip(
        results, samples, expected_replacements, strict=True
    ):
        entry = raw_entry.strip()
        if wraps_nonwords:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == f"{replacement} || {entry}"
        assert result.replacements == replacement
323  
324  
325  # Existence
# Tuples of 1..BATCH_COUNT lowercase terms in which no term is a substring of
# another. The comparison is case-sensitive, which is fine here because every
# term is drawn from lowercase letters.
EXISTENCE_ITEMS_STRATEGY: st.SearchStrategy[tuple[str, ...]] = st.builds(  # pyright: ignore[reportAssignmentType]
    tuple,
    st.lists(ITEM_TEXT_STRATEGY, min_size=1, max_size=BATCH_COUNT).filter(
        lambda terms: not any(
            left in right or right in left
            for left, right in combinations(terms, 2)
        )
    ),
)
335  
336  
@st.composite
def items_split(draw: st.DrawFn) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Partition drawn existence items into ``(items, exceptions)``.

    At least one item is always kept; the exceptions tuple may be empty.
    """
    terms = draw(EXISTENCE_ITEMS_STRATEGY)
    boundary = draw(st.integers(min_value=1, max_value=len(terms)))
    head, tail = terms[:boundary], terms[boundary:]
    return head, tail
343  
344  
@given(items_split(), PADDING_STRATEGY, st.text(), st.text())
def test_existence_smoke(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    noise: str,
) -> None:
    """Report nothing for noise that contains none of the items."""
    items, exceptions = items_exceptions
    lowered = noise.lower()
    assume(not any(item in lowered for item in items))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )
    results = list(check.check(noise))
    assert results == []
365  
366  
@given(items_split(), PADDING_STRATEGY, st.text(), st.data())
def test_existence_exceptions_smoke(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    data: st.DataObject,
) -> None:
    """Report nothing when the text holds only exceptions."""
    items, exceptions = items_exceptions
    count = n_counts(data, len(exceptions))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )
    content = " ".join(
        exception
        for exception, times in zip(exceptions, count, strict=True)
        for _ in range(times)
    )
    results = list(check.check(content))
    assert results == []
388  
389  
@given(items_split(), PADDING_STRATEGY, st.text(), st.data())
def test_existence_in_text(
    items_exceptions: tuple[tuple[str, ...], tuple[str, ...]],
    padding: Padding,
    path: str,
    data: st.DataObject,
) -> None:
    """Flag each occurrence of an item while ignoring the exceptions."""
    items, exceptions = items_exceptions
    count = n_counts(data, len(items) + len(exceptions))
    check = Check(
        check_type=types.Existence(
            items=items, padding=padding, exceptions=exceptions
        ),
        path=path,
        message="{}",
        engine=engine_from(padding),
    )
    wraps_nonwords = padding == Padding.NONWORDS_IN_TEXT

    # Item samples come first in the text, followed by exception samples.
    selected_matches: list[str] = []
    for term, times in zip(chain(items, exceptions), count, strict=True):
        sample = xeger(padding.format(escape(term)))
        # NOTE: rstr.xeger does not work for nonword boundaries, so this is
        # necessary. See leapfrogonline/rstr#45.
        if wraps_nonwords:
            sample = f"_{sample}_"
        selected_matches.extend(repeat(sample, times))

    results = list(check.check("  ".join(selected_matches)))
    # Only the leading groups (one per item) should be reported.
    assert len(results) == sum(count[: len(items)])

    for result, raw_entry in zip(results, selected_matches, strict=False):
        entry = raw_entry.strip()
        if wraps_nonwords:
            entry = entry[1:-1]

        assert result.check_path == path
        assert result.message == entry
        assert result.replacements is None
434  
435  
436  # ExistenceSimple
@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.text())
def test_existence_s_smoke(
    items_exceptions: tuple[str, ...],
    path: str,
    noise: str,
) -> None:
    """Report nothing for noise that does not contain the pattern."""
    # The first drawn term is the pattern; the remainder are exceptions.
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    assume(pattern not in noise.lower())
    check = Check(
        check_type=types.ExistenceSimple(pattern=pattern, exceptions=exceptions),
        path=path,
        message="{}",
    )
    results = list(check.check(noise))
    assert results == []
449  
450  
@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.data())
def test_existence_s_exceptions_smoke(
    items_exceptions: tuple[str, ...],
    path: str,
    data: st.DataObject,
) -> None:
    """Report nothing when the text holds only exceptions."""
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    count = n_counts(data, len(exceptions))
    check = Check(
        check_type=types.ExistenceSimple(pattern=pattern, exceptions=exceptions),
        path=path,
        message="{}",
    )
    content = " ".join(
        exception
        for exception, times in zip(exceptions, count, strict=True)
        for _ in range(times)
    )
    results = list(check.check(content))
    assert results == []
464  
465  
@given(EXISTENCE_ITEMS_STRATEGY, st.text(), st.data())
def test_existence_s_in_text(
    items_exceptions: tuple[str, ...],
    path: str,
    data: st.DataObject,
) -> None:
    """Flag each occurrence of the pattern while ignoring the exceptions."""
    pattern = items_exceptions[0]
    exceptions = items_exceptions[1:]
    count = n_counts(data, len(items_exceptions))
    check = Check(
        check_type=types.ExistenceSimple(pattern=pattern, exceptions=exceptions),
        path=path,
        message="{}",
    )
    # The pattern's samples come first, followed by each exception's samples.
    selected_matches = [
        sample
        for term, times in zip(items_exceptions, count, strict=True)
        for sample in repeat(xeger(escape(term)), times)
    ]

    results = list(check.check("  ".join(selected_matches)))
    # Only the first group (the pattern's own occurrences) is reported.
    assert len(results) == count[0]
    for result, entry in zip(results, selected_matches, strict=False):
        assert result.check_path == path
        assert result.message == entry.strip()
        assert result.replacements is None
492  
493  
494  # ReverseExistence
# Strings matching the default tokenizer, built from letters plus ' - and _.
TOKEN_STRATEGY = st.from_regex(
    types._DEFAULT_TOKENIZER,  # pyright: ignore[reportPrivateUsage]
    alphabet=f"{ascii_letters}'-_",
)
TOKENS_STRATEGY = st.lists(TOKEN_STRATEGY, min_size=1, max_size=BATCH_COUNT)

# Regexes (expanded later with xeger) for text the tokenizer should skip:
# one or two word characters, or words containing digits.
NON_TOKEN_STRATEGY = st.sampled_from([r"\w\w?", r"\w*\d+\w*"])
NON_TOKENS_STRATEGY = st.lists(
    NON_TOKEN_STRATEGY, min_size=1, max_size=BATCH_COUNT
)

# Allowed-word sets built from mutually non-overlapping lowercase terms.
REV_ALLOWED_STRATEGY: st.SearchStrategy[set[str]] = st.builds(  # pyright: ignore[reportUnknownVariableType]
    set, EXISTENCE_ITEMS_STRATEGY
)
507  
508  # NOTE: traditional smoke is not applicable here, since any tokenizable noise
509  # beyond the allowed set would trigger a result.
510  
511  
@given(REV_ALLOWED_STRATEGY, st.text(), st.data())
def test_rev_existence_allowed_smoke(
    allowed: set[str], path: str, data: st.DataObject
) -> None:
    """Return no matches when only allowed items are present."""
    count = n_counts(data, len(allowed))
    check_type = types.ReverseExistence(allowed=allowed)
    check = Check(check_type=check_type, path=path, message="{}")
    content = " ".join(chain.from_iterable(map(repeat, allowed, count)))
    # Materialize the results and require them to be empty. A bare
    # `assert check.check(content)` only tests truthiness. For a lazy
    # iterable (every other test here wraps check.check in list()), that is
    # always true, so the assertion could never fail. For a list it would
    # demand a match, which contradicts the "no matches" contract.
    assert list(check.check(content)) == []
522  
523  
@given(REV_ALLOWED_STRATEGY, st.text(), NON_TOKENS_STRATEGY)
def test_rev_existence_non_token_smoke(
    allowed: set[str], path: str, non_tokens: list[str]
) -> None:
    """Report nothing when the text holds only non-tokenizable strings."""
    check = Check(
        check_type=types.ReverseExistence(allowed=allowed),
        path=path,
        message="{}",
    )
    # Each entry of `non_tokens` is a regex; expand it into concrete text.
    content = " ".join(xeger(regex) for regex in non_tokens)
    results = list(check.check(content))
    assert results == []
532  
533  
@given(REV_ALLOWED_STRATEGY, st.text(), TOKENS_STRATEGY)
def test_rev_existence_forbidden(
    allowed: set[str], path: str, tokens: list[str]
) -> None:
    """Flag every token that is not in the allowed set."""
    # Discard examples where a generated token happens to be allowed.
    assume(not any(token in allowed for token in tokens))
    check = Check(
        check_type=types.ReverseExistence(allowed=allowed),
        path=path,
        message="{}",
    )
    results = list(check.check(" ".join(tokens)))

    assert len(results) == len(tokens)

    for result, token in zip(results, tokens, strict=True):
        assert result.check_path == path
        # Reported text has leading/trailing apostrophes and hyphens removed.
        assert result.message == token.strip("'-")
        assert result.replacements is None