classify_trials.py
#!/usr/bin/env python3
# EXTENDS: DES-c386de6567f0
"""
Trial Classification Pipeline — Paper Integrity Layer 1
========================================================

Reads ALL raw format_sensitivity_*.json files and classifies each trial:

  LIVE         — behavioral measurement (save_count>0 OR items_recalled>0)
  INTERNALIZER — genuine D1=1.0 behavior (no saves, no recall, but model
                 engaged: total_tokens >= 500 AND errors is null)
  EXCLUDED     — infrastructure failure (total_tokens < 500 OR errors non-null)

INTERNALIZER trials are valid D1=1.0 measurements and are included in analysis.
EXCLUDED trials are NOT behavioral data and are removed from all computations.

Prompt sensitivity variants (minimal, emphatic, t0.5, t1.0) are tagged
separately and excluded from the main analysis.

Outputs:
  research/data/classified_trials.json  — every trial with classification
  research/data/trial_summary.json      — aggregate statistics

Usage:
  .venv/bin/python3 research/scripts/classify_trials.py
"""

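# Sketch of the classified_trials.json structure this script writes (shape only;
# the trial entry shown is a hypothetical example, not real data):
#
#   {
#     "classification_policy": {"LIVE": "...", "INTERNALIZER": "...", ...},
#     "trials": [
#       {"model_key": "...", "format": "native_api", "trial_num": 1,
#        "classification": "LIVE", "d1_score": 0.0, "save_count": 3, ...},
#       ...
#     ]
#   }
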
from __future__ import annotations

import json
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent.parent
RESULTS_DIR = REPO_ROOT / "benchmarks" / "capability_hubris" / "results"
OUTPUT_DIR = REPO_ROOT / "research" / "data"
OUTPUT_TRIALS = OUTPUT_DIR / "classified_trials.json"
OUTPUT_SUMMARY = OUTPUT_DIR / "trial_summary.json"

PROMPT_SENSITIVITY_TAGS = ["_minimal_", "_emphatic_", "_t0.5_", "_t1.0_"]
FORMATS = ["native_api", "text_xml", "pythonic_text"]

# Threshold: trials with fewer tokens than this are infrastructure failures.
# A model that genuinely engages produces 500+ tokens across the conversation.
# API failures typically produce 0-100 tokens.
MIN_ENGAGEMENT_TOKENS = 500


def classify_trial(trial: dict) -> str:
    """Classify a single trial into LIVE, INTERNALIZER, or EXCLUDED."""
    save_count = trial.get("save_count", 0) or 0
    items_recalled = trial.get("items_recalled", 0) or 0
    total_tokens = trial.get("total_tokens", 0) or 0
    errors = trial.get("errors")

    # If the model saved or recalled anything, it's a live behavioral measurement
    if save_count > 0 or items_recalled > 0:
        return "LIVE"

    # No saves and no recall — could be infra failure or genuine internalization
    # Infrastructure failure: low tokens or explicit errors
    if total_tokens < MIN_ENGAGEMENT_TOKENS or errors is not None:
        return "EXCLUDED"

    # Genuine internalizer: model engaged (500+ tokens) but chose not to save
    # and failed to recall. This IS a valid D1=1.0 measurement.
    return "INTERNALIZER"

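# Illustrative only: hypothetical trial dicts (made-up field values, not real
# data) and the classification each would receive under the policy above.
#
#   classify_trial({"save_count": 2, "items_recalled": 0,
#                   "total_tokens": 1400, "errors": None})        # -> "LIVE"
#   classify_trial({"save_count": 0, "items_recalled": 0,
#                   "total_tokens": 900, "errors": None})         # -> "INTERNALIZER"
#   classify_trial({"save_count": 0, "items_recalled": 0,
#                   "total_tokens": 42, "errors": "API timeout"}) # -> "EXCLUDED"
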

def is_prompt_sensitivity_file(filename: str) -> bool:
    """Check if a file is a prompt sensitivity variant."""
    return any(tag in filename for tag in PROMPT_SENSITIVITY_TAGS)

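# For example, a file named "format_sensitivity_<model>_minimal_<timestamp>.json"
# (a hypothetical name, shown only to illustrate the tag matching) would be
# treated as a prompt sensitivity variant because it contains "_minimal_".
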

def main():
    # Find all result files
    all_files = sorted(RESULTS_DIR.glob("format_sensitivity_*.json"))

    if not all_files:
        print(f"ERROR: No result files found in {RESULTS_DIR}")
        sys.exit(1)

    # Process all files
    classified = []
    per_model = defaultdict(lambda: defaultdict(lambda: {
        "live": 0, "internalizer": 0, "excluded": 0,
        "d1_values": [],  # all valid D1 values (LIVE + INTERNALIZER)
    }))
    per_file_stats = []
    prompt_sensitivity_trials = []

    for filepath in all_files:
        filename = filepath.name
        is_ps = is_prompt_sensitivity_file(filename)

        try:
            with open(filepath) as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            print(f"WARNING: Cannot read {filename}: {e}")
            continue

        model_key = data.get("model_key", "unknown")
        model_display = data.get("model_display", model_key)
        file_timestamp = data.get("timestamp", "")
        config = data.get("config", {})

        file_live = file_int = file_excl = 0

        for trial in data.get("trials", []):
            classification = classify_trial(trial)
            fmt = trial.get("format", "unknown")

            entry = {
                "model_key": model_key,
                "model_display": model_display,
                "format": fmt,
                "trial_num": trial.get("trial", 0),
                "classification": classification,
                "d1_score": trial.get("d1_score"),
                "save_count": trial.get("save_count", 0),
                "items_recalled": trial.get("items_recalled", 0),
                "total_tokens": trial.get("total_tokens", 0),
                "errors": trial.get("errors"),
                "duration_s": trial.get("duration_s", 0),
                "source_file": filename,
                "file_timestamp": file_timestamp,
                "is_prompt_sensitivity": is_ps,
            }

            if is_ps:
                prompt_sensitivity_trials.append(entry)
            else:
                classified.append(entry)

                # Aggregate per model+format
                bucket = per_model[model_key][fmt]
                bucket[classification.lower()] += 1
                if classification in ("LIVE", "INTERNALIZER"):
                    d1 = trial.get("d1_score")
                    if d1 is not None:
                        bucket["d1_values"].append(d1)

            if classification == "LIVE":
                file_live += 1
            elif classification == "INTERNALIZER":
                file_int += 1
            else:
                file_excl += 1

        per_file_stats.append({
            "filename": filename,
            "model_key": model_key,
            "timestamp": file_timestamp,
            "is_prompt_sensitivity": is_ps,
            "n_trials": len(data.get("trials", [])),
            "live": file_live,
            "internalizer": file_int,
            "excluded": file_excl,
            "formats": config.get("formats", []),
        })

    # Compute summary statistics
    total_std_trials = len(classified)
    total_live = sum(1 for t in classified if t["classification"] == "LIVE")
    total_int = sum(1 for t in classified if t["classification"] == "INTERNALIZER")
    total_excl = sum(1 for t in classified if t["classification"] == "EXCLUDED")
    total_valid = total_live + total_int

    if total_std_trials == 0:
        print("ERROR: All result files are prompt sensitivity variants; "
              "no standard trials to classify.")
        sys.exit(1)

    # Per-model summary
    model_summaries = {}
    unique_models = set()
    for model_key, fmt_data in sorted(per_model.items()):
        unique_models.add(model_key)
        model_total = {"live": 0, "internalizer": 0, "excluded": 0, "valid_d1": []}
        per_format = {}
        for fmt, counts in fmt_data.items():
            per_format[fmt] = {
                "live": counts["live"],
                "internalizer": counts["internalizer"],
                "excluded": counts["excluded"],
                "valid_trials": counts["live"] + counts["internalizer"],
                "d1_values": counts["d1_values"],
                "d1_mean": (sum(counts["d1_values"]) / len(counts["d1_values"])
                            if counts["d1_values"] else None),
            }
            model_total["live"] += counts["live"]
            model_total["internalizer"] += counts["internalizer"]
            model_total["excluded"] += counts["excluded"]
            model_total["valid_d1"].extend(counts["d1_values"])

        model_summaries[model_key] = {
            "formats": per_format,
            "total_valid": model_total["live"] + model_total["internalizer"],
            "total_excluded": model_total["excluded"],
            "total_raw": (model_total["live"] + model_total["internalizer"]
                          + model_total["excluded"]),
            "exclusion_rate": round(
                model_total["excluded"]
                / max(1, model_total["live"] + model_total["internalizer"]
                      + model_total["excluded"]),
                3
            ),
        }

    summary = {
        "generated_at": datetime.now().isoformat(),
        "classification_policy": {
            "LIVE": "save_count > 0 OR items_recalled > 0",
            "INTERNALIZER": (
                "save_count == 0 AND items_recalled == 0 AND "
                f"total_tokens >= {MIN_ENGAGEMENT_TOKENS} AND errors is null"
            ),
            "EXCLUDED": (
                f"total_tokens < {MIN_ENGAGEMENT_TOKENS} OR errors is not null"
            ),
            "min_engagement_tokens": MIN_ENGAGEMENT_TOKENS,
            "note": (
                "INTERNALIZER trials are valid D1=1.0 measurements. "
                "The model genuinely engaged but chose not to externalize. "
                "EXCLUDED trials are infrastructure failures (API errors, "
                "credit exhaustion, routing failures) and are NOT behavioral data."
            ),
        },
        "file_counts": {
            "total_files": len(all_files),
            "standard_files": len(all_files) - len([
                f for f in all_files if is_prompt_sensitivity_file(f.name)
            ]),
            "prompt_sensitivity_files": len([
                f for f in all_files if is_prompt_sensitivity_file(f.name)
            ]),
        },
        "trial_counts": {
            "total_standard_trials": total_std_trials,
            "live": total_live,
            "internalizer": total_int,
            "excluded": total_excl,
            "total_valid": total_valid,
            "exclusion_rate": round(total_excl / max(1, total_std_trials), 3),
            "prompt_sensitivity_trials": len(prompt_sensitivity_trials),
        },
        "model_count": len(unique_models),
        "per_model": model_summaries,
        "per_file": per_file_stats,
    }

    # Write outputs
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    with open(OUTPUT_TRIALS, "w") as f:
        json.dump({
            "classification_policy": summary["classification_policy"],
            "trials": classified,
        }, f, indent=2)

    with open(OUTPUT_SUMMARY, "w") as f:
        json.dump(summary, f, indent=2)

    # Print results
    print(f"Files processed: {len(all_files)} ({summary['file_counts']['standard_files']} standard, "
          f"{summary['file_counts']['prompt_sensitivity_files']} prompt sensitivity)")
    print(f"\nStandard trials: {total_std_trials}")
    print(f"  LIVE:         {total_live:>5} ({total_live/total_std_trials*100:.1f}%)")
    print(f"  INTERNALIZER: {total_int:>5} ({total_int/total_std_trials*100:.1f}%)")
    print(f"  EXCLUDED:     {total_excl:>5} ({total_excl/total_std_trials*100:.1f}%)")
    print(f"  Valid total:  {total_valid:>5} ({total_valid/total_std_trials*100:.1f}%)")
    print(f"\nExclusion rate: {total_excl/total_std_trials*100:.1f}% "
          f"(was 65% with old classification)")
    print(f"Valid behavioral trials: {total_valid} "
          f"(was ~2,058 with old classification)")
    print(f"\nModels: {len(unique_models)}")
    print("\nOutputs:")
    print(f"  {OUTPUT_TRIALS}")
    print(f"  {OUTPUT_SUMMARY}")


if __name__ == "__main__":
    main()
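
# Downstream usage sketch (hypothetical consumer code, not part of this script's
# pipeline): load the classified trials and keep only valid behavioral data.
#
#   import json
#   with open("research/data/classified_trials.json") as f:
#       trials = json.load(f)["trials"]
#   valid = [t for t in trials if t["classification"] in ("LIVE", "INTERNALIZER")]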