# activation_patching_d1.py
  1  # EXTENDS: DES-c386de6567f0
  2  """
  3  Activation Patching for D1 (Externalization Boundary) — RDM-M5.
  4  
  5  Patches activations between base and instruct models of the same architecture
  6  at each transformer layer. Identifies which layers causally determine whether
  7  a model delegates to tools (D1=0) or answers from memory (D1=1).
  8  
  9  Experiment design:
 10    - For each base/instruct pair (same architecture, different weights):
 11      1. "Instruct→Base": Run base model, inject instruct's activation at layer L.
 12         If base starts producing tool-call tokens → layer L carries the tool decision.
 13      2. "Base→Instruct": Run instruct model, inject base's activation at layer L.
 14         If instruct stops producing tool-call tokens → layer L is necessary for tools.
 15    - Sweep ALL layers to build a per-layer causal effect curve.
 16    - Compare critical layers across D1=0 (externalizer) and D1=1 (internalizer) pairs.
 17  
 18  Three base/instruct pairs:
 19    - Llama 3 8B (D1=0, externalizer): Does base→instruct patching suppress tools?
 20    - Qwen 2.5 7B (D1=1, internalizer): Does instruct→base patching induce tools?
 21    - Mistral 7B v0.3 (D1=1, internalizer): Does instruct→base patching induce tools?
 22  
 23  Requires: MLX, mlx_lm (run on BRAIN — Mac Studio M3 Ultra, 256GB).
 24  
 25  Usage:
 26      python research/scripts/activation_patching_d1.py --pair llama8b
 27      python research/scripts/activation_patching_d1.py --pair all
 28      python research/scripts/activation_patching_d1.py --pair llama8b --layers 25-31
 29  
 30  Outputs:
 31      research/results/mechanistic/patching_{pair}_{timestamp}.json
 32      research/results/mechanistic/patching_summary.json
 33  """
 34  
 35  import argparse
 36  import json
 37  import sys
 38  import time
 39  from datetime import datetime, timezone
 40  from pathlib import Path
 41  
 42  import numpy as np
 43  
 44  try:
 45      import mlx.core as mx
 46      import mlx.nn as nn
 47      import mlx_lm
 48  except ImportError:
 49      print("ERROR: MLX not available. This script must run on Apple Silicon (BRAIN).")
 50      sys.exit(1)
 51  
 52  # ---------------------------------------------------------------------------
 53  # Paths
 54  # ---------------------------------------------------------------------------
# Repo root resolved from this file's location (research/scripts/ -> repo root).
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
RESULTS_DIR = REPO_ROOT / "research" / "results" / "mechanistic"
# Created at import time so every output path below is guaranteed to exist.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
 58  
 59  # ---------------------------------------------------------------------------
 60  # Model pairs: base + instruct of the same architecture
 61  # ---------------------------------------------------------------------------
# Each entry pairs a base model with its instruction-tuned sibling so that
# activations can be swapped layer-for-layer (same depth, same hidden size).
# d1_group: "A" = D1=0 externalizer, "B" = D1=1 internalizer (see module docstring).
# n_layers/hidden_dim describe the expected architecture; hidden_dim is kept
# for reference and is not read elsewhere in this script.
PAIR_CATALOG = {
    "llama8b": {
        "base_id": "mlx-community/Meta-Llama-3-8B-4bit",
        "instruct_id": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
        "display": "Llama 3 8B",
        "d1_group": "A",  # D1=0, externalizer
        "n_layers": 32,
        "hidden_dim": 4096,
    },
    "qwen7b": {
        "base_id": "mlx-community/Qwen2.5-7B-4bit",
        "instruct_id": "mlx-community/Qwen2.5-7B-Instruct-4bit",
        "display": "Qwen 2.5 7B",
        "d1_group": "B",  # D1=1, internalizer
        "n_layers": 28,
        "hidden_dim": 3584,
    },
    "mistral7b": {
        "base_id": "mlx-community/Mistral-7B-v0.3-4bit",
        "instruct_id": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        "display": "Mistral 7B v0.3",
        "d1_group": "B",  # D1=1, internalizer
        "n_layers": 32,
        "hidden_dim": 4096,
    },
}
 88  
 89  # ---------------------------------------------------------------------------
 90  # D1 probe conversation (shared with logit_lens_d1.py)
 91  # ---------------------------------------------------------------------------
# Ten factual items the model is asked to remember across distractor turns.
# The system prompt instructs the model to persist such items via save_notes;
# an internalizer instead answers the final quiz from its context window alone.
STATE_ITEMS = [
    "Mercury (planet 1, diameter 4879km)",
    "Venus (planet 2, diameter 12104km)",
    "Earth (planet 3, diameter 12756km)",
    "Mars (planet 4, diameter 6792km)",
    "Jupiter (planet 5, diameter 142984km)",
    "Saturn (planet 6, diameter 120536km)",
    "Uranus (planet 7, diameter 51118km)",
    "Neptune (planet 8, diameter 49528km)",
    "Ceres (dwarf planet, diameter 940km)",
    "Pluto (dwarf planet, diameter 2376km)",
]
104  
# System prompt granting two XML-wrapped-JSON tools (save_notes/read_notes)
# and instructing immediate save_notes use. build_conversation folds this
# into the first user turn rather than a system message.
SYSTEM_TEXT_XML = (
    "You are a helpful assistant with access to a persistent notepad.\n\n"
    "You have access to two tools for persistent state management. Use them by "
    "outputting structured calls in EXACTLY this format (no variation):\n\n"
    "To save notes:\n"
    "<tool_call>\n"
    '{"tool": "save_notes", "content": "your content here"}\n'
    "</tool_call>\n\n"
    "To read notes:\n"
    "<tool_call>\n"
    '{"tool": "read_notes"}\n'
    "</tool_call>\n\n"
    "After a tool call the system will respond with a tool result block. You may "
    "then continue your response. Always use save_notes immediately when given "
    "items to remember — do not rely on your context window alone."
)
121  
122  
def build_conversation(system_prompt):
    """Construct the multi-turn D1 probe conversation.

    The system prompt is folded into the first user message (system-in-user
    pattern) so the same chat template works across tokenizers. The model is
    given 10 items to remember, distracted with five unrelated Q/A turns,
    then quizzed on the items.
    """
    numbered_items = "".join(
        "%d. %s\n" % (idx, entry) for idx, entry in enumerate(STATE_ITEMS, 1)
    )
    opening = (
        system_prompt + "\n\n---\n\n"
        "I need you to track these 10 items carefully. I will quiz you on "
        "them later after several unrelated questions. Here are the items:\n\n"
        + numbered_items
        + "\nPlease acknowledge that you have received all 10 items."
    )

    conversation = [
        {"role": "user", "content": opening},
        {"role": "assistant", "content": "I have received and noted all 10 items."},
    ]

    distraction_turns = (
        ("Write a haiku about the ocean.",
         "Waves crash on the shore / Salt air fills my every breath / The tide pulls me home"),
        ("What is the chemical formula for table salt?",
         "NaCl - sodium chloride, formed by ionic bonding between Na+ and Cl-."),
        ("Describe Romeo and Juliet in three sentences.",
         "Two young lovers from feuding families fall in love. Their secret romance ends in tragedy. Both die, and their families finally reconcile."),
        ("List the first 12 Fibonacci numbers.",
         "0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89"),
        ("Explain TCP vs UDP briefly.",
         "TCP is connection-oriented with guaranteed delivery. UDP is connectionless and faster but unreliable."),
    )
    for question, answer in distraction_turns:
        conversation.append({"role": "user", "content": question})
        conversation.append({"role": "assistant", "content": answer})

    conversation.append({
        "role": "user",
        "content": (
            "Now, without looking back at the conversation, please list ALL 10 "
            "items I gave you at the start. Include full details for each."
        ),
    })
    return conversation
164  
165  
166  # ---------------------------------------------------------------------------
167  # Activation capture
168  # ---------------------------------------------------------------------------
169  
def capture_all_activations(model, input_ids, mask):
    """Run a full forward pass, recording the residual stream after each layer.

    Args:
        model: mlx_lm model exposing .model.embed_tokens / .model.layers
        input_ids: (1, seq_len) token array
        mask: additive causal attention mask

    Returns:
        activations: list where activations[i] is the residual stream
            emitted by layer i
        final_logits: mx.array of shape (1, seq_len, vocab_size)
    """
    activations = []
    h = model.model.embed_tokens(input_ids)

    for layer in model.model.layers:
        out = layer(h, mask=mask)
        # Some layer implementations return (hidden_state, ...) tuples.
        h = out[0] if isinstance(out, tuple) else out
        mx.eval(h)  # force materialization so the saved activation is concrete
        activations.append(h)

    # Apply the final norm before projecting to logits, when present.
    h_normed = model.model.norm(h) if hasattr(model.model, "norm") else h

    # Tied-embedding models project via embed_tokens instead of lm_head.
    if hasattr(model, "lm_head"):
        logits = model.lm_head(h_normed)
    else:
        logits = model.model.embed_tokens.as_linear(h_normed)
    mx.eval(logits)

    return activations, logits
200  
201  
def patched_forward(model, donor_activations, patch_layer, input_ids, mask):
    """Forward pass with the residual stream replaced at one layer.

    The donor's activation *fully overwrites* the residual stream at
    `patch_layer`, so the target model's layers 0..patch_layer cannot
    influence the output. We therefore start directly from the donor
    activation and run only layers patch_layer+1..N-1 — logits are
    identical to the previous compute-everything implementation, at
    roughly half the cost for mid-network patches.

    Args:
        model: target model to run
        donor_activations: list of per-layer activations from the donor
            model (as produced by capture_all_activations on the same
            input_ids/mask)
        patch_layer: layer index at which to inject donor activations
        input_ids: tokenized input (kept for interface compatibility; the
            patch supersedes anything derived from it below patch_layer)
        mask: causal attention mask for the remaining layers

    Returns:
        logits: mx.array of shape (1, seq_len, vocab_size)

    Raises:
        ValueError: if patch_layer is outside [0, n_layers). The previous
            implementation silently returned *unpatched* logits in that case.
    """
    n_layers = len(model.model.layers)
    if not 0 <= patch_layer < n_layers:
        raise ValueError(
            f"patch_layer {patch_layer} out of range for {n_layers}-layer model"
        )

    # PATCH: adopt the donor's residual stream at patch_layer, then continue
    # with the target model's remaining layers.
    h = donor_activations[patch_layer]

    for i in range(patch_layer + 1, n_layers):
        h = model.model.layers[i](h, mask=mask)
        if isinstance(h, tuple):
            h = h[0]
        mx.eval(h)

    # Final norm + logits (tied-embedding models lack a separate lm_head).
    if hasattr(model.model, "norm"):
        h_normed = model.model.norm(h)
    else:
        h_normed = h

    if hasattr(model, "lm_head"):
        logits = model.lm_head(h_normed)
    else:
        logits = model.model.embed_tokens.as_linear(h_normed)
    mx.eval(logits)

    return logits
246  
247  
248  # ---------------------------------------------------------------------------
249  # Tool token measurement
250  # ---------------------------------------------------------------------------
251  
def find_tool_token_ids(tokenizer):
    """Return a mapping from tool-related surface strings to their token IDs.

    Each string is encoded without special tokens; multi-token strings map
    to the full list of piece IDs.
    """
    tool_strings = (
        "save_notes", "read_notes", "tool_call", "<tool_call>",
        "save", "notes", "tool", "function",
        '{"tool"', '"save_notes"',
    )
    return {
        s: tokenizer.encode(s, add_special_tokens=False) for s in tool_strings
    }
264  
265  
def measure_tool_probability(logits, tokenizer):
    """Measure probability of tool-call tokens at the last position.

    Args:
        logits: mx.array of shape (1, seq_len, vocab_size); only the final
            position (the next-token distribution) is examined.
        tokenizer: used to map tool strings to token IDs and decode top tokens.

    Returns dict with per-token probabilities and summary metrics:
        tool_probs: per tool string, max probability over its piece tokens
        top5: the five most likely next tokens (id, decoded text, prob)
        entropy: Shannon entropy of the next-token distribution, in nats
        tool_call_prob: shortcut for tool_probs["<tool_call>"]
    """
    last_logits = logits[0, -1, :]
    probs = mx.softmax(last_logits)
    mx.eval(probs)

    tool_map = find_tool_token_ids(tokenizer)
    tool_probs = {}
    for name, ids_list in tool_map.items():
        # A string may encode to several tokens; report its most probable piece.
        max_p = 0.0
        for tid in ids_list:
            # Guard: encoded id may exceed the logit vocab size — presumably
            # due to padded/extended tokenizer vocabularies; TODO confirm.
            if tid < probs.shape[0]:
                max_p = max(max_p, float(probs[tid].item()))
        tool_probs[name] = round(max_p, 8)

    # Top-5 tokens: argsort is ascending, so reverse before taking the head.
    sorted_idx = mx.argsort(last_logits)[::-1][:5]
    mx.eval(sorted_idx)
    top5 = []
    for idx in sorted_idx.tolist():
        # Escape newlines so the token renders on one line in logs/JSON.
        tok_str = tokenizer.decode([idx]).replace("\n", "\\n")
        p = float(probs[idx].item())
        top5.append({"id": idx, "token": tok_str, "prob": round(p, 6)})

    # Entropy over the full distribution; epsilon avoids log(0).
    log_probs = mx.log(probs + 1e-10)
    entropy = -float(mx.sum(probs * log_probs).item())

    return {
        "tool_probs": tool_probs,
        "top5": top5,
        "entropy": round(entropy, 4),
        "tool_call_prob": tool_probs.get("<tool_call>", 0.0),
    }
303  
304  
305  # ---------------------------------------------------------------------------
306  # Main experiment
307  # ---------------------------------------------------------------------------
308  
def _run_patch_sweep(target_model, donor_acts, layers_to_test, n_layers,
                     input_ids, mask, tokenizer, baseline_prob):
    """One patching sweep: inject donor activations into the target per layer.

    Shared by both directions (instruct→base and base→instruct), which
    previously duplicated this loop verbatim.

    Args:
        target_model: model that runs the (patched) forward pass
        donor_acts: per-layer activations captured from the other model
        layers_to_test: layer indices to patch — one forward pass each
        n_layers: total layer count (used only for the progress printout)
        input_ids, mask: shared tokenized prompt and causal mask
        tokenizer: used to map tool strings to token IDs
        baseline_prob: target model's unpatched <tool_call> probability

    Returns:
        dict mapping layer index -> per-layer measurement; "effect" is the
        patched <tool_call> probability minus the target baseline.
    """
    results = {}
    for L in layers_to_test:
        t0 = time.time()
        patched_logits = patched_forward(target_model, donor_acts, L, input_ids, mask)
        measurement = measure_tool_probability(patched_logits, tokenizer)
        elapsed = time.time() - t0

        effect = measurement["tool_call_prob"] - baseline_prob
        results[L] = {
            "tool_call_prob": measurement["tool_call_prob"],
            "effect": round(effect, 8),
            "top5": measurement["top5"],
            "entropy": measurement["entropy"],
            "time_s": round(elapsed, 2),
        }
        # Flag layers whose patch moves tool probability by more than 0.01.
        marker = "***" if abs(effect) > 0.01 else ""
        print(f"  Layer {L:3d}/{n_layers} ({L/n_layers*100:5.1f}%): "
              f"<tool_call>={measurement['tool_call_prob']:.6f} "
              f"(effect={effect:+.6f}) {marker} [{elapsed:.1f}s]")
    return results


def run_patching_experiment(pair_name, pair_info, layer_range=None):
    """Run full activation patching experiment for one base/instruct pair.

    Loads both models, captures per-layer activations on the shared D1 probe
    prompt, sweeps both patching directions via _run_patch_sweep, and locates
    the layer with the largest absolute effect on <tool_call> probability in
    each direction.

    Args:
        pair_name: key into PAIR_CATALOG (recorded in the result)
        pair_info: catalog entry with model IDs, group, and layer count
        layer_range: optional (start, end) inclusive layer bounds; None = all

    Returns:
        JSON-serializable dict with baselines, per-layer patching results for
        both directions, and a summary of peak-effect layers.

    Raises:
        ValueError: if layer_range selects no layers (previously this crashed
            later in max() with an opaque error).
    """
    print(f"\n{'='*70}")
    print(f"ACTIVATION PATCHING: {pair_info['display']} (D1 group {pair_info['d1_group']})")
    print(f"{'='*70}")

    n_layers = pair_info["n_layers"]
    if layer_range:
        # Clamp the requested range to the model's actual depth.
        layers_to_test = list(range(layer_range[0], min(layer_range[1] + 1, n_layers)))
    else:
        layers_to_test = list(range(n_layers))
    if not layers_to_test:
        raise ValueError(
            f"No layers to test for range {layer_range} (model has {n_layers} layers)"
        )

    # --- Load instruct model ---
    print(f"\nLoading instruct: {pair_info['instruct_id']}...")
    t0 = time.time()
    instruct_model, instruct_tokenizer = mlx_lm.load(pair_info["instruct_id"])
    print(f"  Loaded in {time.time()-t0:.1f}s")

    # --- Tokenize with instruct tokenizer (both models share vocabulary) ---
    messages = build_conversation(SYSTEM_TEXT_XML)
    prompt_text = instruct_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    tokens = instruct_tokenizer.encode(prompt_text)
    input_ids = mx.array([tokens])
    mask = nn.MultiHeadAttention.create_additive_causal_mask(len(tokens))
    mask = mask.astype(mx.float16)
    print(f"  Prompt: {len(tokens)} tokens")

    # --- Capture instruct activations ---
    print("  Capturing instruct activations...")
    t0 = time.time()
    instruct_acts, instruct_logits = capture_all_activations(
        instruct_model, input_ids, mask
    )
    instruct_baseline = measure_tool_probability(instruct_logits, instruct_tokenizer)
    print(f"  Done in {time.time()-t0:.1f}s")
    print(f"  Instruct baseline <tool_call> prob: {instruct_baseline['tool_call_prob']:.6f}")
    print(f"  Instruct top-5: {[t['token'] for t in instruct_baseline['top5']]}")

    # --- Load base model ---
    print(f"\nLoading base: {pair_info['base_id']}...")
    t0 = time.time()
    base_model, base_tokenizer = mlx_lm.load(pair_info["base_id"])
    print(f"  Loaded in {time.time()-t0:.1f}s")

    # --- Capture base activations ---
    # Note: using same input_ids (tokenized with instruct tokenizer).
    # Base and instruct share the same tokenizer for same-family models.
    print("  Capturing base activations...")
    t0 = time.time()
    base_acts, base_logits = capture_all_activations(base_model, input_ids, mask)
    base_baseline = measure_tool_probability(base_logits, instruct_tokenizer)
    print(f"  Done in {time.time()-t0:.1f}s")
    print(f"  Base baseline <tool_call> prob: {base_baseline['tool_call_prob']:.6f}")
    print(f"  Base top-5: {[t['token'] for t in base_baseline['top5']]}")

    # --- Patching: instruct → base (inject instruct activations into base) ---
    print(f"\n--- Instruct → Base patching ({len(layers_to_test)} layers) ---")
    i2b_results = _run_patch_sweep(
        base_model, instruct_acts, layers_to_test, n_layers,
        input_ids, mask, instruct_tokenizer, base_baseline["tool_call_prob"],
    )

    # --- Patching: base → instruct (inject base activations into instruct) ---
    print(f"\n--- Base → Instruct patching ({len(layers_to_test)} layers) ---")
    b2i_results = _run_patch_sweep(
        instruct_model, base_acts, layers_to_test, n_layers,
        input_ids, mask, instruct_tokenizer, instruct_baseline["tool_call_prob"],
    )

    # --- Find critical layers (largest absolute effect in each direction) ---
    i2b_effects = [(L, r["effect"]) for L, r in i2b_results.items()]
    b2i_effects = [(L, r["effect"]) for L, r in b2i_results.items()]

    i2b_peak = max(i2b_effects, key=lambda x: abs(x[1]))
    b2i_peak = max(b2i_effects, key=lambda x: abs(x[1]))

    # --- Clean up models to free memory before building the result ---
    del base_model, instruct_model, base_acts, instruct_acts
    del base_logits, instruct_logits

    # --- Build result ---
    result = {
        "pair": pair_name,
        "display": pair_info["display"],
        "d1_group": pair_info["d1_group"],
        "n_layers": n_layers,
        "n_tokens": len(tokens),
        "layers_tested": layers_to_test,
        "base_id": pair_info["base_id"],
        "instruct_id": pair_info["instruct_id"],
        "baselines": {
            "instruct": {
                "tool_call_prob": instruct_baseline["tool_call_prob"],
                "top5": instruct_baseline["top5"],
                "entropy": instruct_baseline["entropy"],
            },
            "base": {
                "tool_call_prob": base_baseline["tool_call_prob"],
                "top5": base_baseline["top5"],
                "entropy": base_baseline["entropy"],
            },
        },
        # JSON object keys must be strings, so layer indices are stringified.
        "instruct_to_base": {str(k): v for k, v in i2b_results.items()},
        "base_to_instruct": {str(k): v for k, v in b2i_results.items()},
        "summary": {
            "i2b_peak_layer": i2b_peak[0],
            "i2b_peak_effect": round(i2b_peak[1], 8),
            "i2b_peak_depth_pct": round(i2b_peak[0] / n_layers * 100, 1),
            "b2i_peak_layer": b2i_peak[0],
            "b2i_peak_effect": round(b2i_peak[1], 8),
            "b2i_peak_depth_pct": round(b2i_peak[0] / n_layers * 100, 1),
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    print(f"\n--- Summary for {pair_info['display']} ---")
    print(f"  Instruct→Base peak: layer {i2b_peak[0]} ({i2b_peak[0]/n_layers*100:.1f}%), "
          f"effect={i2b_peak[1]:+.6f}")
    print(f"  Base→Instruct peak: layer {b2i_peak[0]} ({b2i_peak[0]/n_layers*100:.1f}%), "
          f"effect={b2i_peak[1]:+.6f}")

    return result
466  
467  
def build_cross_pair_summary(all_results):
    """Condense per-pair patching results into one cross-pair summary dict."""
    pair_entries = {}
    for result in all_results:
        peaks = result["summary"]
        baselines = result["baselines"]
        pair_entries[result["pair"]] = {
            "display": result["display"],
            "d1_group": result["d1_group"],
            "n_layers": result["n_layers"],
            "i2b_peak_layer": peaks["i2b_peak_layer"],
            "i2b_peak_depth_pct": peaks["i2b_peak_depth_pct"],
            "i2b_peak_effect": peaks["i2b_peak_effect"],
            "b2i_peak_layer": peaks["b2i_peak_layer"],
            "b2i_peak_depth_pct": peaks["b2i_peak_depth_pct"],
            "b2i_peak_effect": peaks["b2i_peak_effect"],
            "instruct_baseline_tool_prob": baselines["instruct"]["tool_call_prob"],
            "base_baseline_tool_prob": baselines["base"]["tool_call_prob"],
        }
    return {
        "n_pairs": len(all_results),
        "pairs": pair_entries,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
490  
491  
def parse_layer_range(s):
    """Parse a layer range spec into an inclusive (start, end) tuple.

    '25-31' -> (25, 31); a single number '20' -> (20, 20).

    Raises:
        ValueError: if the spec is malformed, contains a negative index,
            or start > end. (Previously '31-25' silently produced an empty
            sweep and '1-2-3' silently dropped trailing parts.)
    """
    spec = s.strip()
    if "-" in spec:
        lo_text, _, hi_text = spec.partition("-")
        # int() raises ValueError on garbage, including extra '-' sections
        # left in hi_text (e.g. '1-2-3').
        lo, hi = int(lo_text), int(hi_text)
    else:
        lo = hi = int(spec)
    if lo < 0:
        raise ValueError(f"Invalid layer range {s!r}: negative layer index")
    if lo > hi:
        raise ValueError(f"Invalid layer range {s!r}: start exceeds end")
    return (lo, hi)
500  
501  
def main():
    """CLI entry point: run the patching experiment for each requested pair."""
    parser = argparse.ArgumentParser(
        description="Activation patching for D1 externalization boundary"
    )
    parser.add_argument(
        "--pair", nargs="+", default=["all"],
        choices=list(PAIR_CATALOG.keys()) + ["all"],
        help="Model pairs to test"
    )
    parser.add_argument(
        "--layers", type=str, default=None,
        help="Layer range to test, e.g. '25-31' (default: all layers)"
    )
    args = parser.parse_args()

    # "all" anywhere in the list expands to the full catalog.
    pairs = list(PAIR_CATALOG.keys()) if "all" in args.pair else args.pair
    layer_range = parse_layer_range(args.layers) if args.layers else None

    print("Activation Patching — D1 Externalization Boundary")
    print(f"Pairs: {pairs}")
    print(f"Layer range: {layer_range or 'all'}")
    print(f"Results dir: {RESULTS_DIR}")

    run_stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M")
    all_results = []

    for pair_name in pairs:
        result = run_patching_experiment(pair_name, PAIR_CATALOG[pair_name], layer_range)

        # Persist each pair's result as soon as it finishes, so a crash on a
        # later pair does not lose earlier work.
        out_path = RESULTS_DIR / f"patching_{pair_name}_{run_stamp}.json"
        with open(out_path, "w") as f:
            json.dump(result, f, indent=2)
        print(f"\nSaved: {out_path}")
        all_results.append(result)

    # A cross-pair summary is only written when more than one pair was run.
    if len(all_results) > 1:
        summary = build_cross_pair_summary(all_results)
        summary_path = RESULTS_DIR / "patching_summary.json"
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)
        print(f"\nSaved summary: {summary_path}")

    print("\n" + "=" * 70)
    print("COMPLETE")
    for r in all_results:
        s = r["summary"]
        print(f"  {r['display']} (D1={r['d1_group']}): "
              f"I→B peak L{s['i2b_peak_layer']} ({s['i2b_peak_depth_pct']}%), "
              f"B→I peak L{s['b2i_peak_layer']} ({s['b2i_peak_depth_pct']}%)")
    print("=" * 70)
555  
556  
557  if __name__ == "__main__":
558      main()