/ core / attention / coherence_risk.py
coherence_risk.py
  1  """
  2  Coherence Risk Detection
  3  
  4  Detects when concurrent sessions risk architectural divergence.
  5  
  6  The problem: You're in Session A designing an attention system.
  7  Session B is also making architecture decisions. If they diverge,
  8  you'll have to reconcile later - expensive.
  9  
 10  The solution: Detect high-coherence-risk situations and increase
 11  membrane permeability automatically. The sessions should "feel"
 12  each other more strongly when they're making overlapping decisions.
 13  
 14  Risk factors:
 15  1. Semantic overlap - same concepts being discussed
 16  2. Decision weight - architecture/principle decisions vs implementation
 17  3. Recency - both sessions active recently
 18  4. Conflict signals - different conclusions on same topic
 19  
 20  When risk is high, membrane opens. Sessions become more aware.
 21  """
 22  
 23  from dataclasses import dataclass, field
 24  from datetime import datetime, timedelta
 25  from typing import Optional, List, Dict, Set, Tuple
 26  from enum import Enum
 27  
 28  from .membrane import PermeabilityLevel
 29  
 30  
 31  class DecisionType(Enum):
 32      """Types of decisions with different coherence weights."""
 33      PRINCIPLE = "principle"      # Foundational choices (highest weight)
 34      ARCHITECTURE = "architecture"  # Structural decisions
 35      INTERFACE = "interface"      # API/contract decisions
 36      IMPLEMENTATION = "implementation"  # How to do something
 37      NAMING = "naming"            # What to call things
 38      STYLE = "style"              # Formatting, conventions
 39  
 40  
 41  # Weight multipliers for decision types
 42  DECISION_WEIGHTS = {
 43      DecisionType.PRINCIPLE: 1.0,
 44      DecisionType.ARCHITECTURE: 0.9,
 45      DecisionType.INTERFACE: 0.7,
 46      DecisionType.IMPLEMENTATION: 0.3,
 47      DecisionType.NAMING: 0.5,
 48      DecisionType.STYLE: 0.1,
 49  }
 50  
 51  
 52  @dataclass
 53  class SessionDecision:
 54      """A decision made in a session."""
 55      session_id: str
 56      timestamp: datetime
 57      decision_type: DecisionType
 58      topic: str
 59      content: str
 60      confidence: float = 0.5  # How confident the session seems
 61  
 62      # For conflict detection
 63      position: Optional[str] = None  # "for", "against", "undecided"
 64      alternatives_considered: List[str] = field(default_factory=list)
 65  
 66  
 67  @dataclass
 68  class CoherenceRisk:
 69      """Assessment of coherence risk between sessions."""
 70      risk_score: float  # 0-1, higher = more risk
 71      semantic_overlap: float
 72      decision_weight: float
 73      recency_factor: float
 74      conflict_detected: bool
 75  
 76      overlapping_topics: List[str] = field(default_factory=list)
 77      conflicting_decisions: List[Tuple[str, str]] = field(default_factory=list)
 78  
 79      recommended_permeability: PermeabilityLevel = PermeabilityLevel.ATTRACTORS
 80  
 81  
 82  class CoherenceRiskDetector:
 83      """
 84      Detects coherence risk across concurrent sessions.
 85  
 86      Usage:
 87          detector = CoherenceRiskDetector()
 88  
 89          # Record decisions from sessions
 90          detector.record_decision(SessionDecision(
 91              session_id="abc123",
 92              timestamp=datetime.now(),
 93              decision_type=DecisionType.ARCHITECTURE,
 94              topic="attention_system",
 95              content="Using continuous stream instead of batch compaction"
 96          ))
 97  
 98          # Check risk between sessions
 99          risk = detector.assess_risk("abc123", "def456")
100          print(f"Risk: {risk.risk_score:.0%}, recommend: {risk.recommended_permeability}")
101  
102          # Get recommended permeability for a session
103          permeability = detector.get_recommended_permeability("abc123")
104      """
105  
106      # Thresholds for permeability recommendations
107      ALERT_THRESHOLD = 0.7
108      SUMMARY_THRESHOLD = 0.4
109      ATTRACTOR_THRESHOLD = 0.2
110  
111      # Decision keywords for type detection
112      DECISION_KEYWORDS = {
113          DecisionType.PRINCIPLE: [
114              'principle', 'philosophy', 'approach', 'belief', 'always', 'never',
115              'should', 'must', 'fundamental', 'core', 'foundational'
116          ],
117          DecisionType.ARCHITECTURE: [
118              'architecture', 'design', 'structure', 'system', 'component',
119              'module', 'layer', 'pattern', 'flow', 'pipeline'
120          ],
121          DecisionType.INTERFACE: [
122              'interface', 'api', 'contract', 'protocol', 'schema',
123              'endpoint', 'method', 'function', 'signature'
124          ],
125          DecisionType.IMPLEMENTATION: [
126              'implement', 'code', 'write', 'build', 'create',
127              'fix', 'bug', 'feature', 'todo'
128          ],
129          DecisionType.NAMING: [
130              'name', 'call', 'rename', 'term', 'vocabulary',
131              'convention', 'label'
132          ],
133      }
134  
135      def __init__(
136          self,
137          recency_window_minutes: int = 30,
138          overlap_threshold: float = 0.3
139      ):
140          self.recency_window = timedelta(minutes=recency_window_minutes)
141          self.overlap_threshold = overlap_threshold
142  
143          # Decision storage
144          self._decisions: Dict[str, List[SessionDecision]] = {}  # session_id -> decisions
145          self._topic_decisions: Dict[str, List[SessionDecision]] = {}  # topic -> decisions
146  
147      def record_decision(self, decision: SessionDecision) -> None:
148          """Record a decision from a session."""
149          # Store by session
150          if decision.session_id not in self._decisions:
151              self._decisions[decision.session_id] = []
152          self._decisions[decision.session_id].append(decision)
153  
154          # Store by topic
155          if decision.topic not in self._topic_decisions:
156              self._topic_decisions[decision.topic] = []
157          self._topic_decisions[decision.topic].append(decision)
158  
159      def detect_decision_type(self, content: str) -> DecisionType:
160          """Detect decision type from content."""
161          content_lower = content.lower()
162  
163          scores = {}
164          for dtype, keywords in self.DECISION_KEYWORDS.items():
165              score = sum(1 for kw in keywords if kw in content_lower)
166              scores[dtype] = score
167  
168          if not any(scores.values()):
169              return DecisionType.IMPLEMENTATION
170  
171          return max(scores, key=scores.get)
172  
173      def assess_risk(
174          self,
175          session_a: str,
176          session_b: str
177      ) -> CoherenceRisk:
178          """
179          Assess coherence risk between two sessions.
180  
181          Returns a CoherenceRisk with score and recommendation.
182          """
183          now = datetime.now()
184  
185          # Get recent decisions from each session
186          decisions_a = self._get_recent_decisions(session_a, now)
187          decisions_b = self._get_recent_decisions(session_b, now)
188  
189          if not decisions_a or not decisions_b:
190              return CoherenceRisk(
191                  risk_score=0.0,
192                  semantic_overlap=0.0,
193                  decision_weight=0.0,
194                  recency_factor=0.0,
195                  conflict_detected=False,
196                  recommended_permeability=PermeabilityLevel.CLOSED
197              )
198  
199          # Compute semantic overlap (shared topics)
200          topics_a = set(d.topic for d in decisions_a)
201          topics_b = set(d.topic for d in decisions_b)
202          overlapping = topics_a & topics_b
203  
204          if not topics_a or not topics_b:
205              semantic_overlap = 0.0
206          else:
207              semantic_overlap = len(overlapping) / min(len(topics_a), len(topics_b))
208  
209          # Compute decision weight (max weight of overlapping decisions)
210          decision_weight = 0.0
211          for topic in overlapping:
212              for d in self._topic_decisions.get(topic, []):
213                  weight = DECISION_WEIGHTS.get(d.decision_type, 0.3)
214                  decision_weight = max(decision_weight, weight)
215  
216          # Compute recency factor
217          most_recent_a = max(d.timestamp for d in decisions_a)
218          most_recent_b = max(d.timestamp for d in decisions_b)
219          time_gap = abs((most_recent_a - most_recent_b).total_seconds())
220          recency_factor = max(0, 1 - time_gap / self.recency_window.total_seconds())
221  
222          # Detect conflicts (same topic, different positions)
223          conflicts = []
224          for topic in overlapping:
225              topic_decisions = self._topic_decisions.get(topic, [])
226              positions = {}
227              for d in topic_decisions:
228                  if d.position:
229                      if d.session_id not in positions:
230                          positions[d.session_id] = set()
231                      positions[d.session_id].add(d.position)
232  
233              # Check if sessions have different positions
234              if len(positions) >= 2:
235                  all_positions = list(positions.values())
236                  if all_positions[0] != all_positions[1]:
237                      conflicts.append((topic, "position_mismatch"))
238  
239          conflict_detected = len(conflicts) > 0
240  
241          # Compute overall risk score
242          risk_score = (
243              semantic_overlap * 0.4 +
244              decision_weight * 0.3 +
245              recency_factor * 0.2 +
246              (0.3 if conflict_detected else 0.0)
247          )
248          risk_score = min(1.0, risk_score)
249  
250          # Determine recommended permeability
251          if risk_score >= self.ALERT_THRESHOLD or conflict_detected:
252              recommended = PermeabilityLevel.ALERTS
253          elif risk_score >= self.SUMMARY_THRESHOLD:
254              recommended = PermeabilityLevel.SUMMARIES
255          elif risk_score >= self.ATTRACTOR_THRESHOLD:
256              recommended = PermeabilityLevel.ATTRACTORS
257          else:
258              recommended = PermeabilityLevel.CLOSED
259  
260          return CoherenceRisk(
261              risk_score=risk_score,
262              semantic_overlap=semantic_overlap,
263              decision_weight=decision_weight,
264              recency_factor=recency_factor,
265              conflict_detected=conflict_detected,
266              overlapping_topics=list(overlapping),
267              conflicting_decisions=conflicts,
268              recommended_permeability=recommended
269          )
270  
271      def _get_recent_decisions(
272          self,
273          session_id: str,
274          now: datetime
275      ) -> List[SessionDecision]:
276          """Get decisions from a session within the recency window."""
277          if session_id not in self._decisions:
278              return []
279  
280          cutoff = now - self.recency_window
281          return [
282              d for d in self._decisions[session_id]
283              if d.timestamp >= cutoff
284          ]
285  
286      def get_recommended_permeability(
287          self,
288          session_id: str
289      ) -> PermeabilityLevel:
290          """
291          Get recommended permeability for a session based on risk with all others.
292  
293          Returns the highest permeability recommended across all session pairs.
294          """
295          max_permeability = PermeabilityLevel.CLOSED
296  
297          for other_session in self._decisions.keys():
298              if other_session == session_id:
299                  continue
300  
301              risk = self.assess_risk(session_id, other_session)
302  
303              # Take the highest (most open) permeability
304              if risk.recommended_permeability.value > max_permeability.value:
305                  max_permeability = risk.recommended_permeability
306  
307          return max_permeability
308  
309      def get_high_risk_topics(
310          self,
311          threshold: float = 0.5
312      ) -> List[Tuple[str, float, List[str]]]:
313          """
314          Get topics with high coherence risk.
315  
316          Returns list of (topic, risk_score, session_ids).
317          """
318          high_risk = []
319  
320          for topic, decisions in self._topic_decisions.items():
321              session_ids = list(set(d.session_id for d in decisions))
322              if len(session_ids) < 2:
323                  continue
324  
325              # Compute topic-specific risk
326              max_weight = max(
327                  DECISION_WEIGHTS.get(d.decision_type, 0.3)
328                  for d in decisions
329              )
330  
331              # Check for position conflicts
332              positions = {}
333              for d in decisions:
334                  if d.position:
335                      if d.session_id not in positions:
336                          positions[d.session_id] = d.position
337  
338              conflict_factor = 0.3 if len(set(positions.values())) > 1 else 0.0
339  
340              risk = max_weight * 0.7 + conflict_factor
341  
342              if risk >= threshold:
343                  high_risk.append((topic, risk, session_ids))
344  
345          return sorted(high_risk, key=lambda x: x[1], reverse=True)
346  
347  
class AdaptiveMembrane:
    """
    Membrane whose permeability tracks coherence risk.

    Holds a base permeability. For each session it opens up to the risk
    detector's recommendation whenever that is more open than the base.
    """

    def __init__(
        self,
        risk_detector: CoherenceRiskDetector,
        base_permeability: PermeabilityLevel = PermeabilityLevel.ATTRACTORS
    ):
        self.risk_detector = risk_detector
        self.base_permeability = base_permeability

    def get_permeability_for_session(
        self,
        session_id: str
    ) -> PermeabilityLevel:
        """
        Current permeability for ``session_id``.

        This is the more open of the base level and the detector's
        recommendation.

        Raises:
            ValueError: if either level is not one of the five ranked levels.
        """
        recommended = self.risk_detector.get_recommended_permeability(session_id)

        # Ranked from most closed to most open.
        openness_order = [
            PermeabilityLevel.CLOSED,
            PermeabilityLevel.ATTRACTORS,
            PermeabilityLevel.SUMMARIES,
            PermeabilityLevel.ALERTS,
            PermeabilityLevel.OPEN
        ]

        base_rank = openness_order.index(self.base_permeability)
        recommended_rank = openness_order.index(recommended)
        return openness_order[max(base_rank, recommended_rank)]

    def should_alert(self, session_id: str, topic: str) -> bool:
        """
        Whether ``session_id`` should be alerted about ``topic``.

        True when ``topic`` is among the detector's high-risk topics
        (threshold 0.5) and ``session_id`` is one of its sessions.
        """
        flagged = self.risk_detector.get_high_risk_topics(threshold=0.5)
        return any(
            topic == flagged_topic and session_id in involved
            for flagged_topic, _score, involved in flagged
        )
402  
403  
def create_coherence_system() -> Tuple[CoherenceRiskDetector, AdaptiveMembrane]:
    """
    Build a default-configured detector and a membrane wired to it.

    Returns:
        (CoherenceRiskDetector, AdaptiveMembrane), where the membrane reads
        from that same detector instance.
    """
    shared_detector = CoherenceRiskDetector()
    return shared_detector, AdaptiveMembrane(shared_detector)
413  
414  
if __name__ == "__main__":
    print("=== Coherence Risk Detection ===\n")

    detector, membrane = create_coherence_system()

    # Scripted demo: each entry is (banner, decisions). A row is
    # (session_id, decision_type, topic, content, position). The two sessions
    # take opposite positions on "context_management".
    scripted_sessions = [
        ("Recording decisions from Session A...", [
            ("session_a", DecisionType.ARCHITECTURE, "context_management",
             "Using stream model instead of sawtooth compaction", "for"),
            ("session_a", DecisionType.PRINCIPLE, "attention_system",
             "Attention is all you need - the core insight", "for"),
        ]),
        ("Recording decisions from Session B...", [
            ("session_b", DecisionType.ARCHITECTURE, "context_management",
             "Considering batch compaction for simplicity", "against"),
            ("session_b", DecisionType.ARCHITECTURE, "aha_detection",
             "Two types: discovery and architectural", None),
        ]),
    ]

    for banner, rows in scripted_sessions:
        print(banner)
        for sid, dtype, topic, content, position in rows:
            detector.record_decision(SessionDecision(
                session_id=sid,
                timestamp=datetime.now(),
                decision_type=dtype,
                topic=topic,
                content=content,
                position=position,
            ))

    # Pairwise assessment
    print("\n--- Risk Assessment ---")
    pair_risk = detector.assess_risk("session_a", "session_b")
    print(f"Risk score: {pair_risk.risk_score:.0%}")
    print(f"Semantic overlap: {pair_risk.semantic_overlap:.0%}")
    print(f"Decision weight: {pair_risk.decision_weight:.0%}")
    print(f"Conflict detected: {pair_risk.conflict_detected}")
    print(f"Overlapping topics: {pair_risk.overlapping_topics}")
    print(f"Recommended permeability: {pair_risk.recommended_permeability.value}")

    # Topic-level view across all sessions
    print("\n--- High Risk Topics ---")
    for hot_topic, hot_score, involved in detector.get_high_risk_topics():
        print(f"  [{hot_score:.0%}] {hot_topic} ({', '.join(s[:8] for s in involved)})")

    # What the adaptive membrane would do for session A
    print("\n--- Adaptive Membrane ---")
    session_a_perm = membrane.get_permeability_for_session("session_a")
    print(f"Session A permeability: {session_a_perm.value}")

    print("\n=== Coherence protected ===")