"""Knowledge graph helpers: entity extraction, persistence, dedup, merge.

Stores extracted entities at the source level (not chunk level) in three tables:
- kg_entities: unique entities per project (deduped by normalized name + type)
- kg_entity_mentions: which sources each entity appears in
- kg_entity_relationships: co-occurrence edges between entities
"""
import logging
import re
from collections import Counter, defaultdict
from datetime import datetime, timezone
from difflib import SequenceMatcher
from itertools import combinations
from typing import Optional

from restai.models.databasemodels import (
    KGEntityDatabase,
    KGEntityMentionDatabase,
    KGEntityRelationshipDatabase,
)

logger = logging.getLogger(__name__)

# NER labels we keep; anything else is discarded.  Note that the raw labels
# "PER" and "GPE" are accepted and then folded into "PERSON"/"LOC" below.
ALLOWED_TYPES = {"PER", "PERSON", "ORG", "LOC", "MISC", "DATE", "GPE"}
# Collapse equivalent NER labels onto one canonical type.
TYPE_NORMALIZATION = {
    "PER": "PERSON",
    "GPE": "LOC",
}


def normalize_entity_name(name: str) -> str:
    """Lowercase + collapse whitespace + strip punctuation around the edges."""
    s = (name or "").strip().lower()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
    return s


def _normalize_type(t: str) -> str:
    """Map equivalent NER labels (e.g. PER -> PERSON, GPE -> LOC) onto a canonical type."""
    return TYPE_NORMALIZATION.get(t, t)


def _collect_entities(raw: list[dict]) -> tuple[Counter, dict[tuple[str, str], str]]:
    """Filter and aggregate raw NER output.

    Shared by `find_entities_in_text` and `extract_and_persist` so the
    filtering rules (allowed types, name normalization) stay in one place.

    Args:
        raw: NER output items; each is a dict with a "word" and an
            "entity_group" (or legacy "entity") label.

    Returns:
        counts: Counter keyed by (normalized_name, type) with occurrence
            counts, in first-seen order.
        canonical: first surface form seen for each (normalized_name, type)
            key, used as the display name.
    """
    counts: Counter = Counter()
    canonical: dict[tuple[str, str], str] = {}
    for ent in raw:
        word = (ent.get("word") or "").strip()
        etype = _normalize_type((ent.get("entity_group") or ent.get("entity") or "").upper())
        if not word or etype not in ALLOWED_TYPES:
            continue
        normalized = normalize_entity_name(word)
        if not normalized:
            # Name was only punctuation/whitespace after normalization.
            continue
        key = (normalized, etype)
        counts[key] += 1
        canonical.setdefault(key, word)
    return counts, canonical


def find_entities_in_text(text: str, brain, model_name: Optional[str] = None) -> list[tuple[str, str]]:
    """Run NER and return [(canonical_name, type), ...] deduped within this call."""
    if not text:
        return []
    raw = brain.extract_entities_from_text(text, model_name=model_name)
    counts, canonical = _collect_entities(raw)
    return [(canonical[key], key[1]) for key in counts]


def _upsert_entities(
    session,
    project_id: int,
    per_source_counts: Counter,
    canonical_names: dict[tuple[str, str], str],
    now: datetime,
) -> dict[tuple[str, str], int]:
    """Create or update kg_entities rows; return (normalized, type) -> entity id."""
    entity_ids: dict[tuple[str, str], int] = {}
    for (normalized, etype), count in per_source_counts.items():
        existing = (
            session.query(KGEntityDatabase)
            .filter(
                KGEntityDatabase.project_id == project_id,
                KGEntityDatabase.normalized == normalized,
                KGEntityDatabase.entity_type == etype,
            )
            .first()
        )
        if existing:
            existing.mention_count = (existing.mention_count or 0) + count
            existing.updated_at = now
            entity_ids[(normalized, etype)] = existing.id
        else:
            ent = KGEntityDatabase(
                project_id=project_id,
                name=canonical_names[(normalized, etype)],
                normalized=normalized,
                entity_type=etype,
                mention_count=count,
                created_at=now,
                updated_at=now,
            )
            session.add(ent)
            # Flush so the autogenerated primary key is available immediately.
            session.flush()
            entity_ids[(normalized, etype)] = ent.id
    return entity_ids


def _upsert_mentions(
    session,
    project_id: int,
    source: str,
    per_source_counts: Counter,
    entity_ids: dict[tuple[str, str], int],
    now: datetime,
) -> None:
    """Create or increment the (entity x source) mention rows."""
    for key, count in per_source_counts.items():
        eid = entity_ids[key]
        existing_mention = (
            session.query(KGEntityMentionDatabase)
            .filter(
                KGEntityMentionDatabase.entity_id == eid,
                KGEntityMentionDatabase.source == source,
            )
            .first()
        )
        if existing_mention:
            existing_mention.mention_count = (existing_mention.mention_count or 0) + count
        else:
            session.add(KGEntityMentionDatabase(
                entity_id=eid,
                project_id=project_id,
                source=source,
                mention_count=count,
                created_at=now,
            ))


def _upsert_edges(session, project_id: int, ids_in_source: list[int], now: datetime) -> None:
    """Increment (or create) a co-occurrence edge for every pair of entity ids.

    Ids must be sorted ascending so edges are stored with from_entity_id <
    to_entity_id, matching the dedup convention used in `merge_entities`.
    """
    for a, b in combinations(ids_in_source, 2):
        existing_edge = (
            session.query(KGEntityRelationshipDatabase)
            .filter(
                KGEntityRelationshipDatabase.project_id == project_id,
                KGEntityRelationshipDatabase.from_entity_id == a,
                KGEntityRelationshipDatabase.to_entity_id == b,
            )
            .first()
        )
        if existing_edge:
            existing_edge.weight = (existing_edge.weight or 0) + 1
        else:
            session.add(KGEntityRelationshipDatabase(
                project_id=project_id,
                from_entity_id=a,
                to_entity_id=b,
                weight=1,
                created_at=now,
            ))


def extract_and_persist(project_id: int, source: str, text: str, brain, db) -> int:
    """Extract entities from text and persist them. Returns the count of distinct entities found.

    `db` is the DBWrapper instance (uses db.db for the SQLAlchemy session).

    Persists three things in one transaction:
    - kg_entities rows (created or mention_count incremented),
    - kg_entity_mentions rows for this `source`,
    - kg_entity_relationships co-occurrence edges between all entities
      found together in this source.
    """
    if not text:
        return 0
    raw = brain.extract_entities_from_text(text, model_name=None)
    if not raw:
        return 0

    # Aggregate counts within this source.
    per_source_counts, canonical_names = _collect_entities(raw)
    if not per_source_counts:
        return 0

    now = datetime.now(timezone.utc)
    session = db.db

    entity_ids = _upsert_entities(session, project_id, per_source_counts, canonical_names, now)
    _upsert_mentions(session, project_id, source, per_source_counts, entity_ids, now)
    # Sorted so each unordered pair maps to exactly one (from, to) edge.
    _upsert_edges(session, project_id, sorted(entity_ids.values()), now)

    session.commit()
    return len(per_source_counts)


def extract_and_persist_safe(project_id: int, source: str, text: str, brain, db_factory) -> None:
    """Background-task safe wrapper. Creates a fresh DB session and handles errors."""
    try:
        db = db_factory()
        try:
            extract_and_persist(project_id, source, text, brain, db)
        finally:
            # Best effort close: the session may already be unusable.
            try:
                db.db.close()
            except Exception:
                pass
    except Exception as e:
        logger.exception("Background entity extraction failed for project %s source %s: %s", project_id, source, e)


def merge_entities(db, primary_id: int, secondary_id: int) -> bool:
    """Merge secondary entity into primary. Moves all mentions and relationships, then deletes secondary.

    Returns True if successful, False if either entity doesn't exist, they're
    the same, or they belong to different projects.
    """
    if primary_id == secondary_id:
        return False
    session = db.db
    primary = session.query(KGEntityDatabase).filter(KGEntityDatabase.id == primary_id).first()
    secondary = session.query(KGEntityDatabase).filter(KGEntityDatabase.id == secondary_id).first()
    if not primary or not secondary or primary.project_id != secondary.project_id:
        return False

    now = datetime.now(timezone.utc)

    # Move mentions: if a mention of the same source exists for primary, sum counts; otherwise repoint.
    secondary_mentions = (
        session.query(KGEntityMentionDatabase)
        .filter(KGEntityMentionDatabase.entity_id == secondary_id)
        .all()
    )
    for sm in secondary_mentions:
        existing = (
            session.query(KGEntityMentionDatabase)
            .filter(
                KGEntityMentionDatabase.entity_id == primary_id,
                KGEntityMentionDatabase.source == sm.source,
            )
            .first()
        )
        if existing:
            existing.mention_count = (existing.mention_count or 0) + (sm.mention_count or 0)
            session.delete(sm)
        else:
            sm.entity_id = primary_id

    # Move relationships: any edge involving secondary -> repoint to primary, dedup.
    secondary_edges = (
        session.query(KGEntityRelationshipDatabase)
        .filter(
            (KGEntityRelationshipDatabase.from_entity_id == secondary_id)
            | (KGEntityRelationshipDatabase.to_entity_id == secondary_id)
        )
        .all()
    )
    for edge in secondary_edges:
        new_from = primary_id if edge.from_entity_id == secondary_id else edge.from_entity_id
        new_to = primary_id if edge.to_entity_id == secondary_id else edge.to_entity_id
        if new_from == new_to:
            # A primary<->secondary edge would become a self-loop; drop it.
            session.delete(edge)
            continue
        # Normalize order (from < to) so we can dedup against existing edges.
        a, b = sorted([new_from, new_to])
        existing = (
            session.query(KGEntityRelationshipDatabase)
            .filter(
                KGEntityRelationshipDatabase.project_id == primary.project_id,
                KGEntityRelationshipDatabase.from_entity_id == a,
                KGEntityRelationshipDatabase.to_entity_id == b,
            )
            .first()
        )
        if existing and existing.id != edge.id:
            # Another edge already covers this pair: fold the weight in.
            existing.weight = (existing.weight or 0) + (edge.weight or 0)
            session.delete(edge)
        else:
            edge.from_entity_id = a
            edge.to_entity_id = b

    primary.mention_count = (primary.mention_count or 0) + (secondary.mention_count or 0)
    primary.updated_at = now
    session.delete(secondary)
    session.commit()
    return True


def compute_potential_duplicates(db, project_id: int, threshold: float = 0.85, limit: int = 100) -> list[dict]:
    """Find entity pairs with similar names within the same type.

    Args:
        db: DBWrapper instance (uses db.db for the SQLAlchemy session).
        project_id: project whose entities are compared.
        threshold: minimum SequenceMatcher ratio (0..1) to report a pair.
        limit: maximum number of candidate pairs returned.

    Returns:
        Candidate pairs sorted by similarity (descending), each a dict with
        entity ids, display names, and the rounded similarity score.

    Note: O(n^2) per type — fine for per-project entity counts, revisit if
    projects grow to many thousands of entities.
    """
    session = db.db
    entities = (
        session.query(KGEntityDatabase)
        .filter(KGEntityDatabase.project_id == project_id)
        .all()
    )
    by_type: dict[str, list[KGEntityDatabase]] = defaultdict(list)
    for e in entities:
        by_type[e.entity_type].append(e)

    candidates: list[dict] = []
    for ents in by_type.values():
        for a, b in combinations(ents, 2):
            if a.normalized == b.normalized:
                continue  # Identical normalized names are handled by upsert dedup.
            ratio = SequenceMatcher(None, a.normalized, b.normalized).ratio()
            if ratio >= threshold:
                candidates.append({
                    "entity_a_id": a.id,
                    "entity_a_name": a.name,
                    "entity_b_id": b.id,
                    "entity_b_name": b.name,
                    "similarity": round(ratio, 3),
                })
    candidates.sort(key=lambda c: c["similarity"], reverse=True)
    return candidates[:limit]