activation_patching_d1.py
1 # EXTENDS: DES-c386de6567f0 2 """ 3 Activation Patching for D1 (Externalization Boundary) — RDM-M5. 4 5 Patches activations between base and instruct models of the same architecture 6 at each transformer layer. Identifies which layers causally determine whether 7 a model delegates to tools (D1=0) or answers from memory (D1=1). 8 9 Experiment design: 10 - For each base/instruct pair (same architecture, different weights): 11 1. "Instruct→Base": Run base model, inject instruct's activation at layer L. 12 If base starts producing tool-call tokens → layer L carries the tool decision. 13 2. "Base→Instruct": Run instruct model, inject base's activation at layer L. 14 If instruct stops producing tool-call tokens → layer L is necessary for tools. 15 - Sweep ALL layers to build a per-layer causal effect curve. 16 - Compare critical layers across D1=0 (externalizer) and D1=1 (internalizer) pairs. 17 18 Three base/instruct pairs: 19 - Llama 3 8B (D1=0, externalizer): Does base→instruct patching suppress tools? 20 - Qwen 2.5 7B (D1=1, internalizer): Does instruct→base patching induce tools? 21 - Mistral 7B v0.3 (D1=1, internalizer): Does instruct→base patching induce tools? 22 23 Requires: MLX, mlx_lm (run on BRAIN — Mac Studio M3 Ultra, 256GB). 24 25 Usage: 26 python research/scripts/activation_patching_d1.py --pair llama8b 27 python research/scripts/activation_patching_d1.py --pair all 28 python research/scripts/activation_patching_d1.py --pair llama8b --layers 25-31 29 30 Outputs: 31 research/results/mechanistic/patching_{pair}_{timestamp}.json 32 research/results/mechanistic/patching_summary.json 33 """ 34 35 import argparse 36 import json 37 import sys 38 import time 39 from datetime import datetime, timezone 40 from pathlib import Path 41 42 import numpy as np 43 44 try: 45 import mlx.core as mx 46 import mlx.nn as nn 47 import mlx_lm 48 except ImportError: 49 print("ERROR: MLX not available. This script must run on Apple Silicon (BRAIN).") 50 sys.exit(1) 51 52 # --------------------------------------------------------------------------- 53 # Paths 54 # --------------------------------------------------------------------------- 55 REPO_ROOT = Path(__file__).resolve().parent.parent.parent 56 RESULTS_DIR = REPO_ROOT / "research" / "results" / "mechanistic" 57 RESULTS_DIR.mkdir(parents=True, exist_ok=True) 58 59 # --------------------------------------------------------------------------- 60 # Model pairs: base + instruct of the same architecture 61 # --------------------------------------------------------------------------- 62 PAIR_CATALOG = { 63 "llama8b": { 64 "base_id": "mlx-community/Meta-Llama-3-8B-4bit", 65 "instruct_id": "mlx-community/Meta-Llama-3-8B-Instruct-4bit", 66 "display": "Llama 3 8B", 67 "d1_group": "A", # D1=0, externalizer 68 "n_layers": 32, 69 "hidden_dim": 4096, 70 }, 71 "qwen7b": { 72 "base_id": "mlx-community/Qwen2.5-7B-4bit", 73 "instruct_id": "mlx-community/Qwen2.5-7B-Instruct-4bit", 74 "display": "Qwen 2.5 7B", 75 "d1_group": "B", # D1=1, internalizer 76 "n_layers": 28, 77 "hidden_dim": 3584, 78 }, 79 "mistral7b": { 80 "base_id": "mlx-community/Mistral-7B-v0.3-4bit", 81 "instruct_id": "mlx-community/Mistral-7B-Instruct-v0.3-4bit", 82 "display": "Mistral 7B v0.3", 83 "d1_group": "B", # D1=1, internalizer 84 "n_layers": 32, 85 "hidden_dim": 4096, 86 }, 87 } 88 89 # --------------------------------------------------------------------------- 90 # D1 probe conversation (shared with logit_lens_d1.py) 91 # --------------------------------------------------------------------------- 92 STATE_ITEMS = [ 93 "Mercury (planet 1, diameter 4879km)", 94 "Venus (planet 2, diameter 12104km)", 95 "Earth (planet 3, diameter 12756km)", 96 "Mars (planet 4, diameter 6792km)", 97 "Jupiter (planet 5, diameter 142984km)", 98 "Saturn (planet 6, diameter 120536km)", 99 "Uranus (planet 7, diameter 51118km)", 100 "Neptune (planet 8, diameter 49528km)", 101 "Ceres (dwarf planet, diameter 940km)", 102 "Pluto (dwarf planet, diameter 2376km)", 103 ] 104 105 SYSTEM_TEXT_XML = ( 106 "You are a helpful assistant with access to a persistent notepad.\n\n" 107 "You have access to two tools for persistent state management. Use them by " 108 "outputting structured calls in EXACTLY this format (no variation):\n\n" 109 "To save notes:\n" 110 "<tool_call>\n" 111 '{"tool": "save_notes", "content": "your content here"}\n' 112 "</tool_call>\n\n" 113 "To read notes:\n" 114 "<tool_call>\n" 115 '{"tool": "read_notes"}\n' 116 "</tool_call>\n\n" 117 "After a tool call the system will respond with a tool result block. You may " 118 "then continue your response. Always use save_notes immediately when given " 119 "items to remember — do not rely on your context window alone." 120 ) 121 122 123 def build_conversation(system_prompt): 124 """Build the multi-turn D1 conversation. 125 126 Uses system-in-user pattern for tokenizer compatibility. 127 """ 128 items_text = ( 129 system_prompt + "\n\n---\n\n" 130 "I need you to track these 10 items carefully. I will quiz you on " 131 "them later after several unrelated questions. Here are the items:\n\n" 132 ) 133 for i, item in enumerate(STATE_ITEMS, 1): 134 items_text += "%d. %s\n" % (i, item) 135 items_text += "\nPlease acknowledge that you have received all 10 items." 136 137 messages = [ 138 {"role": "user", "content": items_text}, 139 {"role": "assistant", "content": "I have received and noted all 10 items."}, 140 ] 141 distractions = [ 142 ("Write a haiku about the ocean.", 143 "Waves crash on the shore / Salt air fills my every breath / The tide pulls me home"), 144 ("What is the chemical formula for table salt?", 145 "NaCl - sodium chloride, formed by ionic bonding between Na+ and Cl-."), 146 ("Describe Romeo and Juliet in three sentences.", 147 "Two young lovers from feuding families fall in love. Their secret romance ends in tragedy. Both die, and their families finally reconcile."), 148 ("List the first 12 Fibonacci numbers.", 149 "0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89"), 150 ("Explain TCP vs UDP briefly.", 151 "TCP is connection-oriented with guaranteed delivery. UDP is connectionless and faster but unreliable."), 152 ] 153 for q, a in distractions: 154 messages.append({"role": "user", "content": q}) 155 messages.append({"role": "assistant", "content": a}) 156 messages.append({ 157 "role": "user", 158 "content": ( 159 "Now, without looking back at the conversation, please list ALL 10 " 160 "items I gave you at the start. Include full details for each." 161 ), 162 }) 163 return messages 164 165 166 # --------------------------------------------------------------------------- 167 # Activation capture 168 # --------------------------------------------------------------------------- 169 170 def capture_all_activations(model, input_ids, mask): 171 """Full forward pass saving the residual stream after every layer. 172 173 Returns: 174 activations: list of mx.array, one per layer (output of layer i) 175 final_logits: mx.array of shape (1, seq_len, vocab_size) 176 """ 177 h = model.model.embed_tokens(input_ids) 178 activations = [] 179 180 for i, layer in enumerate(model.model.layers): 181 h = layer(h, mask=mask) 182 if isinstance(h, tuple): 183 h = h[0] 184 mx.eval(h) 185 activations.append(h) 186 187 # Final norm + logits 188 if hasattr(model.model, "norm"): 189 h_normed = model.model.norm(h) 190 else: 191 h_normed = h 192 193 if hasattr(model, "lm_head"): 194 logits = model.lm_head(h_normed) 195 else: 196 logits = model.model.embed_tokens.as_linear(h_normed) 197 mx.eval(logits) 198 199 return activations, logits 200 201 202 def patched_forward(model, donor_activations, patch_layer, input_ids, mask): 203 """Forward pass with activation patching at a specific layer. 204 205 Runs the target model normally through layers 0..patch_layer-1, 206 then replaces the residual stream with the donor's output at patch_layer, 207 then continues with the target model's layers patch_layer+1..N-1. 208 209 Args: 210 model: target model to run 211 donor_activations: list of activations from the donor model 212 patch_layer: layer index at which to inject donor activations 213 input_ids: tokenized input 214 mask: causal attention mask 215 216 Returns: 217 logits: mx.array of shape (1, seq_len, vocab_size) 218 """ 219 h = model.model.embed_tokens(input_ids) 220 n_layers = len(model.model.layers) 221 222 for i in range(n_layers): 223 h = model.model.layers[i](h, mask=mask) 224 if isinstance(h, tuple): 225 h = h[0] 226 227 if i == patch_layer: 228 # PATCH: replace residual stream with donor's output at this layer 229 h = donor_activations[patch_layer] 230 231 mx.eval(h) 232 233 # Final norm + logits 234 if hasattr(model.model, "norm"): 235 h_normed = model.model.norm(h) 236 else: 237 h_normed = h 238 239 if hasattr(model, "lm_head"): 240 logits = model.lm_head(h_normed) 241 else: 242 logits = model.model.embed_tokens.as_linear(h_normed) 243 mx.eval(logits) 244 245 return logits 246 247 248 # --------------------------------------------------------------------------- 249 # Tool token measurement 250 # --------------------------------------------------------------------------- 251 252 def find_tool_token_ids(tokenizer): 253 """Map tool-related strings to their token IDs.""" 254 tool_strings = [ 255 "save_notes", "read_notes", "tool_call", "<tool_call>", 256 "save", "notes", "tool", "function", 257 '{"tool"', '"save_notes"', 258 ] 259 result = {} 260 for s in tool_strings: 261 ids = tokenizer.encode(s, add_special_tokens=False) 262 result[s] = ids 263 return result 264 265 266 def measure_tool_probability(logits, tokenizer): 267 """Measure probability of tool-call tokens at the last position. 268 269 Returns dict with per-token probabilities and summary metrics. 270 """ 271 last_logits = logits[0, -1, :] 272 probs = mx.softmax(last_logits) 273 mx.eval(probs) 274 275 tool_map = find_tool_token_ids(tokenizer) 276 tool_probs = {} 277 for name, ids_list in tool_map.items(): 278 max_p = 0.0 279 for tid in ids_list: 280 if tid < probs.shape[0]: 281 max_p = max(max_p, float(probs[tid].item())) 282 tool_probs[name] = round(max_p, 8) 283 284 # Top-5 tokens 285 sorted_idx = mx.argsort(last_logits)[::-1][:5] 286 mx.eval(sorted_idx) 287 top5 = [] 288 for idx in sorted_idx.tolist(): 289 tok_str = tokenizer.decode([idx]).replace("\n", "\\n") 290 p = float(probs[idx].item()) 291 top5.append({"id": idx, "token": tok_str, "prob": round(p, 6)}) 292 293 # Entropy 294 log_probs = mx.log(probs + 1e-10) 295 entropy = -float(mx.sum(probs * log_probs).item()) 296 297 return { 298 "tool_probs": tool_probs, 299 "top5": top5, 300 "entropy": round(entropy, 4), 301 "tool_call_prob": tool_probs.get("<tool_call>", 0.0), 302 } 303 304 305 # --------------------------------------------------------------------------- 306 # Main experiment 307 # --------------------------------------------------------------------------- 308 309 def run_patching_experiment(pair_name, pair_info, layer_range=None): 310 """Run full activation patching experiment for one base/instruct pair. 311 312 Returns a dict with per-layer patching results for both directions. 313 """ 314 print(f"\n{'='*70}") 315 print(f"ACTIVATION PATCHING: {pair_info['display']} (D1 group {pair_info['d1_group']})") 316 print(f"{'='*70}") 317 318 n_layers = pair_info["n_layers"] 319 if layer_range: 320 layers_to_test = list(range(layer_range[0], min(layer_range[1] + 1, n_layers))) 321 else: 322 layers_to_test = list(range(n_layers)) 323 324 # --- Load instruct model --- 325 print(f"\nLoading instruct: {pair_info['instruct_id']}...") 326 t0 = time.time() 327 instruct_model, instruct_tokenizer = mlx_lm.load(pair_info["instruct_id"]) 328 print(f" Loaded in {time.time()-t0:.1f}s") 329 330 # --- Tokenize with instruct tokenizer (both models share vocabulary) --- 331 messages = build_conversation(SYSTEM_TEXT_XML) 332 prompt_text = instruct_tokenizer.apply_chat_template( 333 messages, tokenize=False, add_generation_prompt=True 334 ) 335 tokens = instruct_tokenizer.encode(prompt_text) 336 input_ids = mx.array([tokens]) 337 mask = nn.MultiHeadAttention.create_additive_causal_mask(len(tokens)) 338 mask = mask.astype(mx.float16) 339 print(f" Prompt: {len(tokens)} tokens") 340 341 # --- Capture instruct activations --- 342 print(" Capturing instruct activations...") 343 t0 = time.time() 344 instruct_acts, instruct_logits = capture_all_activations( 345 instruct_model, input_ids, mask 346 ) 347 instruct_baseline = measure_tool_probability(instruct_logits, instruct_tokenizer) 348 print(f" Done in {time.time()-t0:.1f}s") 349 print(f" Instruct baseline <tool_call> prob: {instruct_baseline['tool_call_prob']:.6f}") 350 print(f" Instruct top-5: {[t['token'] for t in instruct_baseline['top5']]}") 351 352 # --- Load base model --- 353 print(f"\nLoading base: {pair_info['base_id']}...") 354 t0 = time.time() 355 base_model, base_tokenizer = mlx_lm.load(pair_info["base_id"]) 356 print(f" Loaded in {time.time()-t0:.1f}s") 357 358 # --- Capture base activations --- 359 # Note: using same input_ids (tokenized with instruct tokenizer). 360 # Base and instruct share the same tokenizer for same-family models. 361 print(" Capturing base activations...") 362 t0 = time.time() 363 base_acts, base_logits = capture_all_activations(base_model, input_ids, mask) 364 base_baseline = measure_tool_probability(base_logits, instruct_tokenizer) 365 print(f" Done in {time.time()-t0:.1f}s") 366 print(f" Base baseline <tool_call> prob: {base_baseline['tool_call_prob']:.6f}") 367 print(f" Base top-5: {[t['token'] for t in base_baseline['top5']]}") 368 369 # --- Patching: instruct → base (inject instruct activations into base) --- 370 print(f"\n--- Instruct → Base patching ({len(layers_to_test)} layers) ---") 371 i2b_results = {} 372 for L in layers_to_test: 373 t0 = time.time() 374 patched_logits = patched_forward(base_model, instruct_acts, L, input_ids, mask) 375 measurement = measure_tool_probability(patched_logits, instruct_tokenizer) 376 elapsed = time.time() - t0 377 378 effect = measurement["tool_call_prob"] - base_baseline["tool_call_prob"] 379 i2b_results[L] = { 380 "tool_call_prob": measurement["tool_call_prob"], 381 "effect": round(effect, 8), 382 "top5": measurement["top5"], 383 "entropy": measurement["entropy"], 384 "time_s": round(elapsed, 2), 385 } 386 marker = "***" if abs(effect) > 0.01 else "" 387 print(f" Layer {L:3d}/{n_layers} ({L/n_layers*100:5.1f}%): " 388 f"<tool_call>={measurement['tool_call_prob']:.6f} " 389 f"(effect={effect:+.6f}) {marker} [{elapsed:.1f}s]") 390 391 # --- Patching: base → instruct (inject base activations into instruct) --- 392 print(f"\n--- Base → Instruct patching ({len(layers_to_test)} layers) ---") 393 b2i_results = {} 394 for L in layers_to_test: 395 t0 = time.time() 396 patched_logits = patched_forward(instruct_model, base_acts, L, input_ids, mask) 397 measurement = measure_tool_probability(patched_logits, instruct_tokenizer) 398 elapsed = time.time() - t0 399 400 effect = measurement["tool_call_prob"] - instruct_baseline["tool_call_prob"] 401 b2i_results[L] = { 402 "tool_call_prob": measurement["tool_call_prob"], 403 "effect": round(effect, 8), 404 "top5": measurement["top5"], 405 "entropy": measurement["entropy"], 406 "time_s": round(elapsed, 2), 407 } 408 marker = "***" if abs(effect) > 0.01 else "" 409 print(f" Layer {L:3d}/{n_layers} ({L/n_layers*100:5.1f}%): " 410 f"<tool_call>={measurement['tool_call_prob']:.6f} " 411 f"(effect={effect:+.6f}) {marker} [{elapsed:.1f}s]") 412 413 # --- Find critical layers --- 414 i2b_effects = [(L, r["effect"]) for L, r in i2b_results.items()] 415 b2i_effects = [(L, r["effect"]) for L, r in b2i_results.items()] 416 417 i2b_peak = max(i2b_effects, key=lambda x: abs(x[1])) 418 b2i_peak = max(b2i_effects, key=lambda x: abs(x[1])) 419 420 # --- Clean up models to free memory --- 421 del base_model, instruct_model, base_acts, instruct_acts 422 del base_logits, instruct_logits 423 424 # --- Build result --- 425 result = { 426 "pair": pair_name, 427 "display": pair_info["display"], 428 "d1_group": pair_info["d1_group"], 429 "n_layers": n_layers, 430 "n_tokens": len(tokens), 431 "layers_tested": layers_to_test, 432 "base_id": pair_info["base_id"], 433 "instruct_id": pair_info["instruct_id"], 434 "baselines": { 435 "instruct": { 436 "tool_call_prob": instruct_baseline["tool_call_prob"], 437 "top5": instruct_baseline["top5"], 438 "entropy": instruct_baseline["entropy"], 439 }, 440 "base": { 441 "tool_call_prob": base_baseline["tool_call_prob"], 442 "top5": base_baseline["top5"], 443 "entropy": base_baseline["entropy"], 444 }, 445 }, 446 "instruct_to_base": {str(k): v for k, v in i2b_results.items()}, 447 "base_to_instruct": {str(k): v for k, v in b2i_results.items()}, 448 "summary": { 449 "i2b_peak_layer": i2b_peak[0], 450 "i2b_peak_effect": round(i2b_peak[1], 8), 451 "i2b_peak_depth_pct": round(i2b_peak[0] / n_layers * 100, 1), 452 "b2i_peak_layer": b2i_peak[0], 453 "b2i_peak_effect": round(b2i_peak[1], 8), 454 "b2i_peak_depth_pct": round(b2i_peak[0] / n_layers * 100, 1), 455 }, 456 "timestamp": datetime.now(timezone.utc).isoformat(), 457 } 458 459 print(f"\n--- Summary for {pair_info['display']} ---") 460 print(f" Instruct→Base peak: layer {i2b_peak[0]} ({i2b_peak[0]/n_layers*100:.1f}%), " 461 f"effect={i2b_peak[1]:+.6f}") 462 print(f" Base→Instruct peak: layer {b2i_peak[0]} ({b2i_peak[0]/n_layers*100:.1f}%), " 463 f"effect={b2i_peak[1]:+.6f}") 464 465 return result 466 467 468 def build_cross_pair_summary(all_results): 469 """Summarize patching effects across all pairs.""" 470 summary = { 471 "n_pairs": len(all_results), 472 "pairs": {}, 473 "timestamp": datetime.now(timezone.utc).isoformat(), 474 } 475 for r in all_results: 476 summary["pairs"][r["pair"]] = { 477 "display": r["display"], 478 "d1_group": r["d1_group"], 479 "n_layers": r["n_layers"], 480 "i2b_peak_layer": r["summary"]["i2b_peak_layer"], 481 "i2b_peak_depth_pct": r["summary"]["i2b_peak_depth_pct"], 482 "i2b_peak_effect": r["summary"]["i2b_peak_effect"], 483 "b2i_peak_layer": r["summary"]["b2i_peak_layer"], 484 "b2i_peak_depth_pct": r["summary"]["b2i_peak_depth_pct"], 485 "b2i_peak_effect": r["summary"]["b2i_peak_effect"], 486 "instruct_baseline_tool_prob": r["baselines"]["instruct"]["tool_call_prob"], 487 "base_baseline_tool_prob": r["baselines"]["base"]["tool_call_prob"], 488 } 489 return summary 490 491 492 def parse_layer_range(s): 493 """Parse layer range like '25-31' or '20'.""" 494 if "-" in s: 495 parts = s.split("-") 496 return (int(parts[0]), int(parts[1])) 497 else: 498 n = int(s) 499 return (n, n) 500 501 502 def main(): 503 parser = argparse.ArgumentParser( 504 description="Activation patching for D1 externalization boundary" 505 ) 506 parser.add_argument( 507 "--pair", nargs="+", default=["all"], 508 choices=list(PAIR_CATALOG.keys()) + ["all"], 509 help="Model pairs to test" 510 ) 511 parser.add_argument( 512 "--layers", type=str, default=None, 513 help="Layer range to test, e.g. '25-31' (default: all layers)" 514 ) 515 args = parser.parse_args() 516 517 pairs = list(PAIR_CATALOG.keys()) if "all" in args.pair else args.pair 518 layer_range = parse_layer_range(args.layers) if args.layers else None 519 520 print(f"Activation Patching — D1 Externalization Boundary") 521 print(f"Pairs: {pairs}") 522 print(f"Layer range: {layer_range or 'all'}") 523 print(f"Results dir: {RESULTS_DIR}") 524 525 all_results = [] 526 ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M") 527 528 for pair_name in pairs: 529 pair_info = PAIR_CATALOG[pair_name] 530 result = run_patching_experiment(pair_name, pair_info, layer_range) 531 532 # Save per-pair result 533 out_path = RESULTS_DIR / f"patching_{pair_name}_{ts}.json" 534 with open(out_path, "w") as f: 535 json.dump(result, f, indent=2) 536 print(f"\nSaved: {out_path}") 537 all_results.append(result) 538 539 # Cross-pair summary 540 if len(all_results) > 1: 541 summary = build_cross_pair_summary(all_results) 542 summary_path = RESULTS_DIR / "patching_summary.json" 543 with open(summary_path, "w") as f: 544 json.dump(summary, f, indent=2) 545 print(f"\nSaved summary: {summary_path}") 546 547 print("\n" + "=" * 70) 548 print("COMPLETE") 549 for r in all_results: 550 s = r["summary"] 551 print(f" {r['display']} (D1={r['d1_group']}): " 552 f"I→B peak L{s['i2b_peak_layer']} ({s['i2b_peak_depth_pct']}%), " 553 f"B→I peak L{s['b2i_peak_layer']} ({s['b2i_peak_depth_pct']}%)") 554 print("=" * 70) 555 556 557 if __name__ == "__main__": 558 main()