/ analysis / generate_paper_data.py
generate_paper_data.py
  1  # EXTENDS: DES-c386de6567f0
  2  #!/usr/bin/env python3
  3  """
  4  Canonical Paper Data Generator — Paper Integrity Layer 2
  5  =========================================================
  6  
  7  Reads classified_trials.json and computes every metric used in the paper.
  8  Outputs:
  9    research/data/paper_claims.json     — canonical source for all paper numbers
 10    research/latex/generated_macros.tex — LaTeX macros for all numbers
 11  
 12  Every number in the paper should trace back to this script.
 13  
 14  Usage:
 15    .venv/bin/python3 research/scripts/generate_paper_data.py
 16  """
 17  
 18  from __future__ import annotations
 19  
 20  import json
 21  import math
 22  import sys
 23  from collections import defaultdict
 24  from datetime import datetime
 25  from pathlib import Path
 26  from statistics import mean, stdev
 27  
# Repository root is three levels above this script (research/scripts/...).
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
# Input: per-trial classifications produced by classify_trials.py.
CLASSIFIED = REPO_ROOT / "research" / "data" / "classified_trials.json"
# Outputs: canonical paper numbers (JSON) and generated LaTeX macros (TeX).
OUTPUT_JSON = REPO_ROOT / "research" / "data" / "paper_claims.json"
OUTPUT_TEX = REPO_ROOT / "research" / "latex" / "generated_macros.tex"

# Trial formats aggregated and compared throughout the analysis.
FORMATS = ["native_api", "text_xml", "pythonic_text"]
 34  
# ---- Model metadata (lab assignments) ----------------------------------------

# Maps internal model keys to the lab/company that produced each model.
# Keys absent from this table are reported under lab "Unknown" (see the
# MODEL_LABS.get(..., "Unknown") lookups in main()).
MODEL_LABS = {
    "claude_37_sonnet": "Anthropic", "claude_haiku": "Anthropic",
    "claude_sonnet": "Anthropic",
    "command_r_plus": "Cohere",
    "deepseek_r1": "DeepSeek", "deepseek_r1_distill_llama": "DeepSeek",
    "deepseek_r1_distill_qwen": "DeepSeek", "deepseek_v3": "DeepSeek",
    "ernie_45": "Baidu",
    "gemini_25_pro": "Google", "gemini_flash": "Google",
    "gemma3_27b": "Google", "gemma_27b": "Google",
    "glm_45": "Zhipu", "glm_47": "Zhipu", "glm_5": "Zhipu",
    "gpt41": "OpenAI", "gpt41_mini": "OpenAI",
    "gpt4o": "OpenAI", "gpt4o_mini": "OpenAI",
    "grok3": "xAI",
    "hunyuan": "Tencent", "hunyuan_t1": "Tencent",
    "jamba_large": "AI21",
    "kimi_k2": "Moonshot",
    "llama33_70b": "Meta", "llama4_maverick": "Meta", "llama4_scout": "Meta",
    "longcat_flash": "Meituan",
    "mimo_flash": "Xiaomi",
    "minimax_m25": "MiniMax",
    "mistral_large": "Mistral", "mixtral_8x22b": "Mistral",
    "nova_pro": "Amazon",
    "o3_mini": "OpenAI",
    "phi4": "Microsoft",
    "qwen35_397b": "Alibaba", "qwen3_235b": "Alibaba",
    "qwen_72b": "Alibaba", "qwq_32b": "Alibaba",
    "seed_16": "ByteDance", "seed_20_lite": "ByteDance",
    "seed_20_pro": "ByteDance",
    "step_flash": "StepFun",
}
 67  
 68  
 69  def wilson_ci(p: float, n: int, z: float = 1.96) -> tuple[float, float]:
 70      """Wilson score confidence interval for a proportion."""
 71      if n == 0:
 72          return (0.0, 1.0)
 73      denom = 1 + z * z / n
 74      centre = (p + z * z / (2 * n)) / denom
 75      margin = z * math.sqrt((p * (1 - p) + z * z / (4 * n)) / n) / denom
 76      return (max(0.0, centre - margin), min(1.0, centre + margin))
 77  
 78  
 79  def fisher_exact_2x2(a: int, b: int, c: int, d: int) -> float:
 80      """Two-sided Fisher's exact test for a 2x2 table. Returns p-value."""
 81      try:
 82          from scipy.stats import fisher_exact
 83          _, p = fisher_exact([[a, b], [c, d]], alternative="two-sided")
 84          return float(p)
 85      except ImportError:
 86          return float("nan")
 87  
 88  
 89  def cliffs_delta(x: list[float], y: list[float]) -> float | None:
 90      """Cliff's delta non-parametric effect size."""
 91      if not x or not y:
 92          return None
 93      nx, ny = len(x), len(y)
 94      count = sum((1 if xi > yi else -1 if xi < yi else 0)
 95                  for xi in x for yi in y)
 96      return count / (nx * ny)
 97  
 98  
 99  def latex_safe(name: str) -> str:
100      """Convert a metric name to a valid LaTeX macro name."""
101      return name.replace("_", "").replace(".", "").replace("-", "")
102  
103  
def main():
    """Compute every paper metric from classified trials.

    Pipeline:
      1. Load classified_trials.json (exit status 1 if missing).
      2. Aggregate d1 scores per model per format over valid trials
         (classification LIVE or INTERNALIZER with a non-null d1_score).
      3. Run a per-model Fisher exact test for format sensitivity, with
         Benjamini-Hochberg FDR correction across models.
      4. Write paper_claims.json and generated_macros.tex, then print a
         summary of the canonical numbers.
    """
    if not CLASSIFIED.exists():
        print(f"ERROR: {CLASSIFIED} not found. Run classify_trials.py first.")
        sys.exit(1)

    with open(CLASSIFIED) as f:
        data = json.load(f)

    trials = data["trials"]
    policy = data["classification_policy"]

    # ---- Aggregate per model per format (valid trials only) ------------------
    per_model_format = defaultdict(lambda: defaultdict(list))
    all_valid = []

    for t in trials:
        if t["classification"] in ("LIVE", "INTERNALIZER"):
            d1 = t["d1_score"]
            if d1 is not None:
                per_model_format[t["model_key"]][t["format"]].append(d1)
                all_valid.append(t)

    # ---- Compute per-model metrics -------------------------------------------
    models = {}
    for model_key in sorted(per_model_format.keys()):
        fmt_data = per_model_format[model_key]

        model_entry = {
            "model_key": model_key,
            "lab": MODEL_LABS.get(model_key, "Unknown"),
            "formats": {},
        }

        for fmt in FORMATS:
            vals = fmt_data.get(fmt, [])
            if vals:
                m = mean(vals)
                # NOTE(review): Wilson CI is defined for binomial proportions;
                # d1 scores appear to be continuous in [0, 1] — confirm this
                # interval is intended for a mean of continuous scores.
                ci = wilson_ci(m, len(vals))
                model_entry["formats"][fmt] = {
                    "d1_mean": round(m, 4),
                    "d1_values": [round(v, 4) for v in vals],
                    "n": len(vals),
                    "ci_lower": round(ci[0], 4),
                    "ci_upper": round(ci[1], 4),
                }
            else:
                model_entry["formats"][fmt] = {
                    "d1_mean": None, "n": 0,
                    "ci_lower": None, "ci_upper": None,
                }

        # Weighted mean across formats, weighted by per-format trial count.
        # Formats with no data have n == 0 and a None mean, so they are
        # excluded from both numerator and denominator consistently.
        total_n = sum(model_entry["formats"][f]["n"] for f in FORMATS)
        if total_n > 0:
            weighted = sum(
                model_entry["formats"][f]["d1_mean"] * model_entry["formats"][f]["n"]
                for f in FORMATS
                if model_entry["formats"][f]["d1_mean"] is not None
            ) / total_n
            model_entry["d1_weighted_mean"] = round(weighted, 4)
        else:
            model_entry["d1_weighted_mean"] = None

        model_entry["total_valid_trials"] = total_n
        models[model_key] = model_entry

    # ---- Aggregate statistics ------------------------------------------------
    total_valid_trials = len(all_valid)
    total_excluded = sum(1 for t in trials if t["classification"] == "EXCLUDED")
    total_internalizer = sum(1 for t in trials if t["classification"] == "INTERNALIZER")
    total_raw = len(trials)
    unique_models = len(models)
    unique_labs = len(set(MODEL_LABS.get(k, "Unknown") for k in models))

    # Exclusion rate; max(1, ...) guards against division by zero on an
    # empty trials list.
    exclusion_rate = round(total_excluded / max(1, total_raw), 3)

    # ---- Format sensitivity (Fisher exact test per model) --------------------
    fisher_results = {}
    fisher_significant_count = 0
    for model_key, entry in models.items():
        # Collect formats with enough trials for a meaningful 2x2 test.
        available = []
        for fmt_name in FORMATS:
            fmt_stats = entry["formats"].get(fmt_name, {})
            if fmt_stats.get("n", 0) >= 5:  # minimum for meaningful test
                available.append((fmt_name, fmt_stats))

        # Need at least 2 formats with data; use the first two available.
        if len(available) >= 2:
            f1_name, f1 = available[0]
            f2_name, f2 = available[1]
            # Reconstruct approximate externalize/internalize counts from the
            # per-format mean d1 (d1 treated as an internalization rate).
            # `or 0` also maps a None mean to 0.
            # 2x2 table: [ext_f1, int_f1] vs [ext_f2, int_f2].
            d1_f1 = f1["d1_mean"] or 0
            d1_f2 = f2["d1_mean"] or 0
            n1, n2 = f1["n"], f2["n"]
            ext_f1 = round(n1 * (1 - d1_f1))
            int_f1 = n1 - ext_f1
            ext_f2 = round(n2 * (1 - d1_f2))
            int_f2 = n2 - ext_f2
            p = fisher_exact_2x2(ext_f1, int_f1, ext_f2, int_f2)

            fisher_results[model_key] = {
                "formats_compared": [f1_name, f2_name],
                "p_value": round(p, 6) if not math.isnan(p) else None,
                "significant_005": bool(p < 0.05) if not math.isnan(p) else None,
            }
            if not math.isnan(p) and p < 0.05:
                fisher_significant_count += 1

    # ---- BH-FDR correction (Benjamini-Hochberg step-up) ----------------------
    # FIX: the adjusted q-value is q_i = min_{j >= i}(p_j * m / j), i.e. a
    # running minimum taken from the largest rank downwards. The previous
    # code used the raw p * m / rank per rank, which can be non-monotone and
    # wrongly declare a small-p test non-significant.
    p_values = [(k, v["p_value"]) for k, v in fisher_results.items()
                if v["p_value"] is not None]
    p_values.sort(key=lambda x: x[1])
    m_tests = len(p_values)
    running_q = 1.0
    for rank in range(m_tests, 0, -1):
        model_key, p = p_values[rank - 1]
        running_q = min(running_q, p * m_tests / rank)
        fisher_results[model_key]["bh_q"] = round(running_q, 6)
        fisher_results[model_key]["significant_fdr005"] = bool(running_q < 0.05)

    fisher_sig_fdr = sum(1 for v in fisher_results.values()
                         if v.get("significant_fdr005"))

    # ---- Build paper claims --------------------------------------------------
    claims = {
        "generated_at": datetime.now().isoformat(),
        "generated_by": "research/scripts/generate_paper_data.py",
        "classification_policy": policy,

        # Aggregate numbers
        "aggregate": {
            "total_models": unique_models,
            "total_labs": unique_labs,
            "total_valid_trials": total_valid_trials,
            "total_live_trials": total_valid_trials - total_internalizer,
            "total_internalizer_trials": total_internalizer,
            "total_excluded_trials": total_excluded,
            "total_raw_trials": total_raw,
            "exclusion_rate": exclusion_rate,
            "exclusion_rate_pct": round(exclusion_rate * 100, 1),
            "fisher_significant_uncorrected": fisher_significant_count,
            "fisher_significant_fdr": fisher_sig_fdr,
            "fisher_testable_models": len(fisher_results),
        },

        # Per-model data
        "models": models,

        # Format sensitivity
        "fisher_tests": fisher_results,
    }

    # ---- Generate LaTeX macros -----------------------------------------------
    macros = [
        "% Auto-generated by research/scripts/generate_paper_data.py",
        f"% Generated: {datetime.now().isoformat()}",
        "% DO NOT EDIT MANUALLY. Regenerate with:",
        "%   .venv/bin/python3 research/scripts/generate_paper_data.py",
        "",
        "% ---- Aggregate statistics ----",
        f"\\newcommand{{\\nModels}}{{{unique_models}}}",
        f"\\newcommand{{\\nLabs}}{{{unique_labs}}}",
        f"\\newcommand{{\\nValidTrials}}{{{total_valid_trials:,}}}",
        f"\\newcommand{{\\nLiveTrials}}{{{total_valid_trials - total_internalizer:,}}}",
        f"\\newcommand{{\\nInternalizerTrials}}{{{total_internalizer}}}",
        f"\\newcommand{{\\nExcludedTrials}}{{{total_excluded:,}}}",
        f"\\newcommand{{\\nRawTrials}}{{{total_raw:,}}}",
        f"\\newcommand{{\\exclusionRatePct}}{{{round(exclusion_rate * 100, 0):.0f}}}",
        f"\\newcommand{{\\nFisherSigFDR}}{{{fisher_sig_fdr}}}",
        f"\\newcommand{{\\nFisherTestable}}{{{len(fisher_results)}}}",
        "",
        "% ---- Per-model D1 (weighted mean across formats) ----",
    ]

    for model_key, entry in sorted(models.items()):
        safe_name = latex_safe(model_key)
        wm = entry["d1_weighted_mean"]
        n = entry["total_valid_trials"]
        if wm is not None:
            macros.append(
                f"\\newcommand{{\\d{safe_name}Mean}}{{{wm:.3f}}}"
            )
            macros.append(
                f"\\newcommand{{\\d{safe_name}N}}{{{n}}}"
            )
        # Per-format macros, only for formats with data.
        for fmt in FORMATS:
            fmt_safe = latex_safe(fmt)
            fd = entry["formats"].get(fmt, {})
            fm = fd.get("d1_mean")
            fn = fd.get("n", 0)
            if fm is not None:
                macros.append(
                    f"\\newcommand{{\\d{safe_name}{fmt_safe}Mean}}{{{fm:.3f}}}"
                )
                macros.append(
                    f"\\newcommand{{\\d{safe_name}{fmt_safe}N}}{{{fn}}}"
                )

    # Write outputs
    OUTPUT_JSON.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_TEX.parent.mkdir(parents=True, exist_ok=True)

    with open(OUTPUT_JSON, "w") as f:
        json.dump(claims, f, indent=2)

    with open(OUTPUT_TEX, "w") as f:
        f.write("\n".join(macros) + "\n")

    # Print summary
    print(f"Paper claims: {OUTPUT_JSON}")
    print(f"LaTeX macros: {OUTPUT_TEX}")
    print("\n=== Canonical Numbers ===")
    print(f"Models: {unique_models} from {unique_labs} labs")
    print(f"Valid trials: {total_valid_trials:,} "
          f"({total_valid_trials - total_internalizer:,} LIVE + "
          f"{total_internalizer} INTERNALIZER)")
    print(f"Excluded: {total_excluded:,} ({exclusion_rate*100:.1f}%)")
    print(f"Raw: {total_raw:,}")
    print(f"Format sensitivity: {fisher_sig_fdr}/{len(fisher_results)} "
          f"significant (FDR q<0.05)")
    print(f"\nLaTeX macros generated: {sum(1 for line in macros if line.startswith(chr(92)))}")
337  
338  
# Entry point guard: run the generator only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()