/ tests / gateway / conftest.py
conftest.py
"""Shared fixtures for gateway tests.

The ``_ensure_telegram_mock`` helper guarantees that a minimal mock of
the ``telegram`` package is registered in :data:`sys.modules` **before**
any test file triggers ``from gateway.platforms.telegram import ...``;
``_ensure_discord_mock`` does the same for ``discord``. Both are called
at module level when pytest imports this conftest.

Without this, ``pytest-xdist`` workers that happen to collect
``test_telegram_caption_merge.py`` (bare top-level import, no per-file
mock) first will cache ``ChatType = None`` from the production
ImportError fallback, causing 30+ downstream test failures wherever
``ChatType.GROUP`` / ``ChatType.SUPERGROUP`` is accessed.

Individual test files may still call their own ``_ensure_telegram_mock``;
such per-file helpers are expected to short-circuit when a mock is
already present.

Plugin-adapter anti-pattern guard
---------------------------------
Tests for platform plugins (``plugins/platforms/<name>/adapter.py``)
must load the adapter via
:func:`tests.gateway._plugin_adapter_loader.load_plugin_adapter`, not by
adding the plugin directory to ``sys.path`` and doing a bare
``from adapter import ...``. The guard at the bottom of this file parses
every ``test_*.py`` under this directory with :mod:`ast` from
``pytest_configure`` (i.e. before test collection) and aborts the
session with a ``pytest.UsageError`` pointing to the helper if the
anti-pattern is detected.

Rationale: every plugin ships its own ``adapter.py``, and two tests each
inserting their plugin dir on ``sys.path[0]`` race for
``sys.modules["adapter"]`` in the same xdist worker. Whichever is
collected first wins; the other fails with ``ImportError``, and the
polluted ``sys.path`` cascades into unrelated tests. See PR #17764 for
the incident.
"""
 33  
 34  import ast
 35  import sys
 36  from pathlib import Path
 37  from unittest.mock import MagicMock
 38  
 39  import pytest
 40  
 41  
 42  def _ensure_telegram_mock() -> None:
 43      """Install a comprehensive telegram mock in sys.modules.
 44  
 45      Idempotent — skips when the real library is already imported.
 46      Uses ``sys.modules[name] = mod`` (overwrite) instead of
 47      ``setdefault`` so it wins even if a partial/broken import
 48      already cached a module with ``ChatType = None``.
 49      """
 50      if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
 51          return  # Real library is installed — nothing to mock
 52  
 53      mod = MagicMock()
 54      mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
 55      mod.constants.ParseMode.MARKDOWN = "Markdown"
 56      mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
 57      mod.constants.ParseMode.HTML = "HTML"
 58      mod.constants.ChatType.PRIVATE = "private"
 59      mod.constants.ChatType.GROUP = "group"
 60      mod.constants.ChatType.SUPERGROUP = "supergroup"
 61      mod.constants.ChatType.CHANNEL = "channel"
 62  
 63      # Real exception classes so ``except (NetworkError, ...)`` clauses
 64      # in production code don't blow up with TypeError.
 65      mod.error.NetworkError = type("NetworkError", (OSError,), {})
 66      mod.error.TimedOut = type("TimedOut", (OSError,), {})
 67      mod.error.BadRequest = type("BadRequest", (Exception,), {})
 68      mod.error.Forbidden = type("Forbidden", (Exception,), {})
 69      mod.error.InvalidToken = type("InvalidToken", (Exception,), {})
 70      mod.error.RetryAfter = type("RetryAfter", (Exception,), {"retry_after": 1})
 71      mod.error.Conflict = type("Conflict", (Exception,), {})
 72  
 73      # Update.ALL_TYPES used in start_polling()
 74      mod.Update.ALL_TYPES = []
 75  
 76      for name in (
 77          "telegram",
 78          "telegram.ext",
 79          "telegram.constants",
 80          "telegram.request",
 81      ):
 82          sys.modules[name] = mod
 83      sys.modules["telegram.error"] = mod.error
 84  
 85  
 86  def _ensure_discord_mock() -> None:
 87      """Install a comprehensive discord mock in sys.modules.
 88  
 89      Idempotent — skips when the real library is already imported.
 90      Uses ``sys.modules[name] = mod`` (overwrite) instead of
 91      ``setdefault`` so it wins even if a partial/broken import already
 92      cached the module.
 93  
 94      This mock is comprehensive — it includes **all** attributes needed by
 95      every gateway discord test file.  Individual test files should call
 96      this function (it short-circuits when already present) rather than
 97      maintaining their own mock setup.
 98      """
 99      if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
100          return  # Real library is installed — nothing to mock
101  
102      from types import SimpleNamespace
103  
104      discord_mod = MagicMock()
105      discord_mod.Intents.default.return_value = MagicMock()
106      discord_mod.Client = MagicMock
107      discord_mod.File = MagicMock
108      discord_mod.DMChannel = type("DMChannel", (), {})
109      discord_mod.Thread = type("Thread", (), {})
110      discord_mod.ForumChannel = type("ForumChannel", (), {})
111      discord_mod.Interaction = object
112      discord_mod.Message = type("Message", (), {})
113  
114      # Embed: accept the kwargs production code / tests use
115      # (title, description, color). MagicMock auto-attributes work too,
116      # but some tests construct and inspect .title/.description directly.
117      class _FakeEmbed:
118          def __init__(self, *, title=None, description=None, color=None, **_):
119              self.title = title
120              self.description = description
121              self.color = color
122      discord_mod.Embed = _FakeEmbed
123  
124      # ui.View / ui.Select / ui.Button: real classes (not MagicMock) so
125      # tests that subclass ModelPickerView / iterate .children / clear
126      # items work.
127      class _FakeView:
128          def __init__(self, timeout=None):
129              self.timeout = timeout
130              self.children = []
131          def add_item(self, item):
132              self.children.append(item)
133          def clear_items(self):
134              self.children.clear()
135  
136      class _FakeSelect:
137          def __init__(self, *, placeholder=None, options=None, custom_id=None, **_):
138              self.placeholder = placeholder
139              self.options = options or []
140              self.custom_id = custom_id
141              self.callback = None
142              self.disabled = False
143  
144      class _FakeButton:
145          def __init__(self, *, label=None, style=None, custom_id=None, emoji=None,
146                       url=None, disabled=False, row=None, sku_id=None, **_):
147              self.label = label
148              self.style = style
149              self.custom_id = custom_id
150              self.emoji = emoji
151              self.url = url
152              self.disabled = disabled
153              self.row = row
154              self.sku_id = sku_id
155              self.callback = None
156  
157      class _FakeSelectOption:
158          def __init__(self, *, label=None, value=None, description=None, **_):
159              self.label = label
160              self.value = value
161              self.description = description
162      discord_mod.SelectOption = _FakeSelectOption
163  
164      discord_mod.ui = SimpleNamespace(
165          View=_FakeView,
166          Select=_FakeSelect,
167          Button=_FakeButton,
168          button=lambda *a, **k: (lambda fn: fn),
169      )
170      discord_mod.ButtonStyle = SimpleNamespace(
171          success=1, primary=2, secondary=2, danger=3,
172          green=1, grey=2, blurple=2, red=3,
173      )
174      discord_mod.Color = SimpleNamespace(
175          orange=lambda: 1, green=lambda: 2, blue=lambda: 3,
176          red=lambda: 4, purple=lambda: 5, greyple=lambda: 6,
177      )
178  
179      # app_commands — needed by _register_slash_commands auto-registration
180      class _FakeGroup:
181          def __init__(self, *, name, description, parent=None):
182              self.name = name
183              self.description = description
184              self.parent = parent
185              self._children: dict = {}
186              if parent is not None:
187                  parent.add_command(self)
188  
189          def add_command(self, cmd):
190              self._children[cmd.name] = cmd
191  
192      class _FakeCommand:
193          def __init__(self, *, name, description, callback, parent=None):
194              self.name = name
195              self.description = description
196              self.callback = callback
197              self.parent = parent
198  
199      discord_mod.app_commands = SimpleNamespace(
200          describe=lambda **kwargs: (lambda fn: fn),
201          choices=lambda **kwargs: (lambda fn: fn),
202          Choice=lambda **kwargs: SimpleNamespace(**kwargs),
203          Group=_FakeGroup,
204          Command=_FakeCommand,
205      )
206  
207      ext_mod = MagicMock()
208      commands_mod = MagicMock()
209      commands_mod.Bot = MagicMock
210      ext_mod.commands = commands_mod
211  
212      for name in ("discord", "discord.ext", "discord.ext.commands"):
213          sys.modules[name] = discord_mod
214      sys.modules["discord.ext"] = ext_mod
215      sys.modules["discord.ext.commands"] = commands_mod
216  
217  
# Install both third-party mocks as soon as pytest imports this conftest,
# i.e. before any test module in this directory runs its top-level imports.
# Order matters only for readability: telegram first, then discord.
for _install_mock in (_ensure_telegram_mock, _ensure_discord_mock):
    _install_mock()
del _install_mock
221  
222  
# ---------------------------------------------------------------------------
# Plugin-adapter anti-pattern guard
# ---------------------------------------------------------------------------

# Directory holding this conftest; the guard scans test files beneath it.
_GATEWAY_DIR = Path(__file__).resolve().parents[0]

# Appended to the UsageError raised when the anti-pattern is found.
_GUARD_HINT = (
    "Plugin adapter tests must use ``from tests.gateway._plugin_adapter_loader "
    "import load_plugin_adapter`` and call "
    "``load_plugin_adapter('<plugin_name>')`` instead of inserting "
    "``plugins/platforms/<name>/`` on sys.path and doing a bare "
    "``import adapter`` / ``from adapter import ...``. "
    "See the 'Plugin-adapter anti-pattern guard' docstring in "
    "tests/gateway/conftest.py."
)
236  
237  
def _call_mentions_plugin_platforms(call: ast.Call) -> bool:
    """Return True if the string literals in *call* spell a ``plugins/platforms`` path.

    Looks for the consecutive path segments ``plugins`` then ``platforms``
    across all string literals of the call, in source order. This covers a
    single literal (``".../plugins/platforms/..."``, including Windows
    backslashes and f-string fragments) as well as paths split across
    literals such as ``Path(...) / "plugins" / "platforms"`` or
    ``os.path.join(root, "plugins", "platforms", name)``.
    """
    literals = [
        node
        for node in ast.walk(call)
        if isinstance(node, ast.Constant) and isinstance(node.value, str)
    ]
    # ast.walk is breadth-first; restore source order so adjacent path
    # pieces ("plugins", then "platforms") end up next to each other.
    literals.sort(
        key=lambda node: (getattr(node, "lineno", 0), getattr(node, "col_offset", 0))
    )
    path_text = "/".join(node.value.replace("\\", "/") for node in literals)
    segments = [seg for seg in path_text.split("/") if seg]
    return any(
        first == "plugins" and second == "platforms"
        for first, second in zip(segments, segments[1:])
    )


def _scan_for_plugin_adapter_antipattern(source: str) -> list[str]:
    """Return a list of offending-line descriptions, or [] if clean.

    Flags two things anywhere in the module (top level or nested inside
    functions/classes):

    1. ``sys.path.insert/append/extend(...)`` calls whose string literals
       point into ``plugins/platforms``.
    2. Bare ``import adapter`` / ``from adapter import ...`` (absolute
       imports only; relative ``from .adapter import X`` is allowed).

    Sources that cannot be parsed yield [] so pytest reports the real
    problem when it imports the file.
    """
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # ValueError: NUL bytes in the source on Python < 3.12 (3.12+ raises
        # SyntaxError instead). Let pytest surface the real error.
        return []

    offenses: list[str] = []

    # Pass 1: sys.path.insert / append / extend pointing into plugins/platforms.
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        is_sys_path_mutation = (
            isinstance(func, ast.Attribute)
            and func.attr in ("insert", "append", "extend")
            and isinstance(func.value, ast.Attribute)
            and func.value.attr == "path"
            and isinstance(func.value.value, ast.Name)
            and func.value.value.id == "sys"
        )
        if is_sys_path_mutation and _call_mentions_plugin_platforms(node):
            offenses.append(
                f"line {node.lineno}: sys.path.{func.attr}(...) points into "
                f"plugins/platforms/"
            )

    # Pass 2: bare `import adapter` / `from adapter import ...` anywhere
    # (module level OR inside functions — both are symptoms of the same
    # pattern).
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == "adapter":
                    offenses.append(
                        f"line {node.lineno}: ``import adapter`` "
                        f"(bare — resolves to whichever plugin's adapter.py "
                        f"is first on sys.path)"
                    )
        elif isinstance(node, ast.ImportFrom):
            if node.module == "adapter" and node.level == 0:
                offenses.append(
                    f"line {node.lineno}: ``from adapter import ...`` "
                    f"(bare — resolves to whichever plugin's adapter.py "
                    f"is first on sys.path)"
                )

    return offenses
306  
307  
def pytest_configure(config):
    """Reject plugin-adapter tests that use the sys.path anti-pattern.

    Intended to run once per pytest session on the controller, BEFORE any
    xdist worker is spawned. Every ``test_*.py`` under ``tests/gateway/``
    (recursively) is scanned with ``_scan_for_plugin_adapter_antipattern``;
    if any file matches, the whole session fails with a
    ``pytest.UsageError`` listing the offending lines — before a polluted
    ``sys.path`` can cascade across workers.

    Skipped on xdist worker processes (they carry ``config.workerinput``)
    so the filesystem is scanned only once.

    NOTE(review): this hook only fires where this conftest is actually
    loaded. If pytest does not treat ``tests/gateway/conftest.py`` as an
    initial conftest for the given command line, an xdist controller (which
    does not collect) may never import it — confirm the guard fires under
    ``-n``.
    """
    # Only run on the xdist controller (or in non-xdist runs).
    if hasattr(config, "workerinput"):
        return

    repo_root = _GATEWAY_DIR.parent.parent
    violations: list[str] = []
    # sorted(): rglob order is filesystem-dependent; keep the report stable.
    for path in sorted(_GATEWAY_DIR.rglob("test_*.py")):
        try:
            source = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            # Unreadable / non-UTF-8 file: skip it and let pytest's own
            # import of the module report the problem.
            continue
        # Cheap pre-filter before parsing. Check "platforms" rather than
        # "plugins/platforms" so the Path-operator form
        # ``... / "plugins" / "platforms"`` is not skipped.
        if "adapter" not in source and "platforms" not in source:
            continue
        offenses = _scan_for_plugin_adapter_antipattern(source)
        if offenses:
            violations.append(
                f"  {path.relative_to(repo_root)}:\n    "
                + "\n    ".join(offenses)
            )

    if violations:
        raise pytest.UsageError(
            "Plugin-adapter-import anti-pattern detected in gateway tests:\n"
            + "\n".join(violations)
            + "\n\n"
            + _GUARD_HINT
        )
345