/ scripts / regen_skill.py
regen_skill.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """Regenerate KB-derived sections in pyod skill files.
  4  
  5  Reads pyod.utils.knowledge.algorithms and rewrites the content between
  6  <!-- BEGIN KB-DERIVED: <section-name> --> and <!-- END KB-DERIVED: <section-name> -->
  7  markers in every *.md file under pyod/skills/.
  8  
  9  Hand-written content (everything outside the markers) is left untouched.
 10  
 11  Usage:
 12      python scripts/regen_skill.py              # regenerate in place
 13      python scripts/regen_skill.py --check      # dry-run; exit 1 if any file would change
    python scripts/regen_skill.py --verbose    # print the path of every regenerated file
 15  """
 16  from __future__ import annotations
 17  
import argparse
import functools
import re
import sys
from collections import Counter
from pathlib import Path
 23  
 24  REPO_ROOT = Path(__file__).resolve().parents[1]
 25  SKILLS_DIR = REPO_ROOT / "pyod" / "skills"
 26  
 27  # Regex to match a KB-DERIVED block. Captures the section name and the
 28  # old body (everything between the markers). DOTALL so . matches newline.
 29  _BLOCK_RE = re.compile(
 30      r"<!-- BEGIN KB-DERIVED: ([a-z0-9_\-]+) -->\n(.*?)<!-- END KB-DERIVED: \1 -->",
 31      re.DOTALL,
 32  )
 33  
 34  # Map raw KB requires tokens (Python package names) to user-facing pyproject
 35  # extras. The KB stores the runtime dependency name (e.g. "torch_geometric"),
 36  # but the install hint in the rendered skill must use the extra exposed in
 37  # pyproject.toml (e.g. "graph"). Unknown tokens raise KeyError so the
 38  # maintainer is forced to keep this mapping in sync with pyproject.toml.
 39  _REQUIRES_TO_EXTRA = {
 40      "torch_geometric": "graph",
 41      "torch": "torch",
 42      "xgboost": "xgboost",
 43      "suod": "suod",
 44      "combo": "combo",
 45  }
 46  
 47  
 48  def _load_kb():
 49      """Load pyod.utils.knowledge.algorithms once."""
 50      from pyod.utils.ad_engine import ADEngine
 51      return ADEngine().kb.algorithms
 52  
 53  
 54  def _format_complexity(complexity):
 55      """Format the KB complexity field as a one-line readable string.
 56  
 57      The KB stores complexity as a dict ``{"time": ..., "space": ...}``.
 58      Older entries may store a plain string. Empty/missing returns "?".
 59      """
 60      if not complexity:
 61          return "?"
 62      if isinstance(complexity, str):
 63          return complexity
 64      if isinstance(complexity, dict):
 65          time_s = complexity.get("time")
 66          space_s = complexity.get("space")
 67          parts = []
 68          if time_s:
 69              parts.append(f"time {time_s}")
 70          if space_s:
 71              parts.append(f"space {space_s}")
 72          return ", ".join(parts) if parts else "?"
 73      return str(complexity)
 74  
 75  
 76  def _format_paper(paper):
 77      """Format the KB paper field as a short human-facing reference.
 78  
 79      The KB stores paper as a dict ``{"id": ..., "short": ...}`` where
 80      ``short`` is e.g. "Liu et al., ICDM 2008". Older entries may be a
 81      plain string. Returns empty string for missing/empty paper.
 82      """
 83      if not paper:
 84          return ""
 85      if isinstance(paper, str):
 86          return paper
 87      if isinstance(paper, dict):
 88          return paper.get("short", "") or paper.get("id", "")
 89      return str(paper)
 90  
 91  
 92  def _format_requires(requires):
 93      """Format the KB requires list as comma-joined ``pyod[extra]`` hints.
 94  
 95      Each token must be in ``_REQUIRES_TO_EXTRA``; unknown tokens raise
 96      KeyError so the maintainer keeps the mapping current. Returns empty
 97      string for an empty/missing list.
 98      """
 99      if not requires:
100          return ""
101      extras = []
102      for req in requires:
103          if req not in _REQUIRES_TO_EXTRA:
104              raise KeyError(
105                  f"Unknown KB requires token {req!r}. Add it to "
106                  f"_REQUIRES_TO_EXTRA in {__file__} (map to the matching "
107                  f"pyproject.toml extra)."
108              )
109          extras.append(f"pyod[{_REQUIRES_TO_EXTRA[req]}]")
110      return ", ".join(extras)
111  
112  
def _select_algos(kb, modalities):
    """Return ``(name, algo)`` pairs whose ``data_types`` overlap *modalities*.

    Used by both the single-modality renderers and the combined text/image
    renderer. Results are in sorted-name order. Each detector appears at
    most once even when it matches several requested modalities (e.g.
    ``EmbeddingOD`` in both ``text`` and ``image``), because each KB key is
    visited exactly once.

    Args:
        kb: mapping of detector name -> KB entry dict.
        modalities: iterable of modality tags. It is materialised up front,
            so a one-shot iterator (e.g. a generator) is tested against
            every detector, not just the first.

    A missing or ``None`` ``data_types`` value is treated as "no modalities".
    """
    wanted = tuple(modalities)
    selected = []
    for name, algo in sorted(kb.items()):
        data_types = algo.get("data_types") or []
        if any(m in data_types for m in wanted):
            selected.append((name, algo))
    return selected
129  
130  
def _render_bullets(items):
    """Render a list of ``(name, algo)`` tuples as a markdown bullet list.

    Each bullet always has the name, full name, and complexity. It then
    has, in order, ``best for``, ``avoid when``, ``requires`` (as
    ``pyod[extra]`` hints), and ``paper``. Optional fields that are empty
    in the KB are omitted. A missing or empty ``full_name`` falls back to
    the detector name, so the bullet never shows ``()``.

    Returns:
        The bullets joined by newlines with a trailing newline, or a
        placeholder line when *items* is empty.

    Raises:
        KeyError: propagated from ``_format_requires`` for unknown tokens.
    """
    if not items:
        return "_No detectors registered._\n"
    lines = []
    for name, algo in items:
        full = algo.get("full_name") or name
        complexity = _format_complexity(algo.get("complexity"))
        optional = (
            ("best for", algo.get("best_for", "")),
            ("avoid when", algo.get("avoid_when", "")),
            ("requires", _format_requires(algo.get("requires", []))),
            ("paper", _format_paper(algo.get("paper"))),
        )
        segments = [f"- **{name}** ({full}) — complexity: {complexity}"]
        segments.extend(f"{label}: {value}" for label, value in optional if value)
        lines.append("; ".join(segments))
    return "\n".join(lines) + "\n"
159  
160  
def _render_detector_list(kb, modality):
    """Render the bullet list for exactly one modality (e.g. ``"tabular"``)."""
    selected = _select_algos(kb, [modality])
    return _render_bullets(selected)
164  
165  
def _render_combined_detector_list(kb, modalities):
    """Render a single deduplicated bullet list spanning all *modalities*."""
    selected = _select_algos(kb, modalities)
    return _render_bullets(selected)
169  
170  
def _render_total_count(kb):
    """Render a one-line summary of total detector counts by modality.

    The total is the number of KB entries. Each modality count is the
    number of detectors tagged with it, so a detector that lists the same
    modality twice is counted once. Known modalities appear in a fixed
    order; any others follow in sorted order so the output is stable.
    Underscores in every label are shown as hyphens (``time_series`` ->
    ``time-series``). A missing or ``None`` ``data_types`` counts toward
    the total only.
    """
    counts = Counter()
    for algo in kb.values():
        counts.update(set(algo.get("data_types") or []))
    total = len(kb)
    preferred = ["tabular", "time_series", "graph", "text", "image", "multimodal"]
    extras = sorted((key for key in counts if key not in preferred), key=str)
    parts = [
        f"{counts[key]} {str(key).replace('_', '-')}"
        for key in preferred + extras
        if counts[key]
    ]
    breakdown = ", ".join(parts) if parts else "none"
    return f"PyOD ships **{total}** detectors total ({breakdown}).\n"
189  
190  
def _render_benchmark_list(kb):
    """List every distinct benchmark ref cited anywhere in the KB, sorted.

    Entries whose ``benchmark_refs`` is missing, ``None``, or empty contribute
    nothing. If no refs exist at all, a placeholder line is returned.
    """
    cited = {
        ref
        for algo in kb.values()
        for ref in (algo.get("benchmark_refs") or [])
    }
    if not cited:
        return "_No benchmark refs registered._\n"
    bullets = "\n".join(f"- {ref}" for ref in sorted(cited))
    return f"Benchmarks referenced by PyOD detectors:\n{bullets}\n"
203  
204  
def _single_modality(modality):
    """Return a ``renderer(kb)`` bound to one modality tag."""
    return lambda kb: _render_detector_list(kb, modality)


# Section name, as written in the BEGIN/END markers -> callable(kb) that
# returns the markdown body for that section.
_SECTION_RENDERERS = {
    "tabular-detector-list": _single_modality("tabular"),
    "time-series-detector-list": _single_modality("time_series"),
    "graph-detector-list": _single_modality("graph"),
    "text-image-detector-list": (
        lambda kb: _render_combined_detector_list(kb, ["text", "image"])
    ),
    "total-detector-count": _render_total_count,
    "benchmark-list": _render_benchmark_list,
}
216  
217  
def render_section(section_name):
    """Return the freshly rendered markdown body for *section_name*.

    The renderer is looked up before the KB is loaded, so an unknown name
    raises KeyError without touching the KB.
    """
    return _SECTION_RENDERERS[section_name](_load_kb())
223  
224  
def _regenerate_text(text, kb, path=None):
    """Return *text* with every KB-DERIVED block re-rendered from *kb*.

    Shared by ``regen_file`` (which writes the result) and ``check_files``
    (which only compares it), so both modes agree on what "up to date"
    means.

    Args:
        text: full contents of a skill markdown file.
        kb: the algorithm table returned by ``_load_kb()``.
        path: optional file path; only used to point error messages at the
            offending file.

    Raises:
        KeyError: if a block names a section missing from
            ``_SECTION_RENDERERS``.
    """

    def _replace(match):
        section_name = match.group(1)
        if section_name not in _SECTION_RENDERERS:
            location = f" in {path}" if path is not None else ""
            raise KeyError(
                f"Unknown KB-DERIVED section name {section_name!r}{location}. "
                f"Add it to _SECTION_RENDERERS in {__file__}."
            )
        new_body = _SECTION_RENDERERS[section_name](kb)
        # Make sure the END marker always starts on its own line.
        if not new_body.endswith("\n"):
            new_body += "\n"
        return (
            f"<!-- BEGIN KB-DERIVED: {section_name} -->\n"
            f"{new_body}"
            f"<!-- END KB-DERIVED: {section_name} -->"
        )

    return _BLOCK_RE.sub(_replace, text)


def regen_file(path):
    """Regenerate every KB-DERIVED block in a single file in place.

    Returns True if the file was rewritten, False if it was already up to
    date (in which case it is not touched).

    Raises:
        KeyError: on an unknown section name in the file.
    """
    text = path.read_text(encoding="utf-8")
    new_text = _regenerate_text(text, _load_kb(), path)
    if new_text == text:
        return False
    path.write_text(new_text, encoding="utf-8")
    return True


def find_skill_files():
    """Yield every *.md file under pyod/skills/ recursively, in sorted order."""
    yield from sorted(SKILLS_DIR.rglob("*.md"))


def regen_all(verbose=False):
    """Regenerate every skill file. Returns the number of files modified.

    With ``verbose=True``, prints each modified file's repo-relative path.
    """
    modified = 0
    for path in find_skill_files():
        if regen_file(path):
            modified += 1
            if verbose:
                rel = path.relative_to(REPO_ROOT)
                print(f"regenerated: {rel}")
    return modified


def check_files(paths):
    """Dry-run: return 0 if no file would change, 1 otherwise.

    Used by --check mode and by tests. Does NOT mutate any file. It renders
    each file exactly as ``regen_file`` would (via the same helper), compares
    the result with the current contents, and prints
    ``would regenerate: <path>`` to stderr for each stale file.

    Raises:
        KeyError: on an unknown section name; the message names the file.
    """
    kb = _load_kb()
    diffs = 0
    for path in paths:
        text = path.read_text(encoding="utf-8")
        if _regenerate_text(text, kb, path) != text:
            print(f"would regenerate: {path}", file=sys.stderr)
            diffs += 1
    return 1 if diffs else 0
305  
306  
def main(argv=None):
    """Command-line entry point; returns the process exit code.

    Without flags, regenerates all skill files and returns 0. With
    ``--check``, performs a dry run and returns ``check_files``'s result
    (0 when everything is current, 1 otherwise).
    """
    parser = argparse.ArgumentParser(
        description="Regenerate KB-derived sections in pyod skill files."
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Dry-run: exit 1 if any file would change, 0 otherwise.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print every regenerated file path.",
    )
    args = parser.parse_args(argv)

    if not args.check:
        count = regen_all(verbose=args.verbose)
        print(f"Regenerated {count} file(s).")
        return 0

    status = check_files(list(find_skill_files()))
    if status == 0:
        print("All skill files are up to date.")
    return status


if __name__ == "__main__":
    raise SystemExit(main())