/ restai / knowledge_graph.py
knowledge_graph.py
  1  """Knowledge graph helpers: entity extraction, persistence, dedup, merge.
  2  
  3  Stores extracted entities at the source level (not chunk level) in three tables:
  4  - kg_entities: unique entities per project (deduped by normalized name + type)
  5  - kg_entity_mentions: which sources each entity appears in
  6  - kg_entity_relationships: co-occurrence edges between entities
  7  """
  8  import logging
  9  import re
 10  from collections import Counter, defaultdict
 11  from datetime import datetime, timezone
 12  from difflib import SequenceMatcher
 13  from typing import Optional
 14  
 15  from restai.models.databasemodels import (
 16      KGEntityDatabase,
 17      KGEntityMentionDatabase,
 18      KGEntityRelationshipDatabase,
 19  )
 20  
 21  logger = logging.getLogger(__name__)
 22  
 23  ALLOWED_TYPES = {"PER", "PERSON", "ORG", "LOC", "MISC", "DATE", "GPE"}
 24  TYPE_NORMALIZATION = {
 25      "PER": "PERSON",
 26      "GPE": "LOC",
 27  }
 28  
 29  
 30  def normalize_entity_name(name: str) -> str:
 31      """Lowercase + collapse whitespace + strip punctuation around the edges."""
 32      s = (name or "").strip().lower()
 33      s = re.sub(r"\s+", " ", s)
 34      s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
 35      return s
 36  
 37  
 38  def _normalize_type(t: str) -> str:
 39      return TYPE_NORMALIZATION.get(t, t)
 40  
 41  
 42  def find_entities_in_text(text: str, brain, model_name: Optional[str] = None) -> list[tuple[str, str]]:
 43      """Run NER and return [(canonical_name, type), ...] deduped within this call."""
 44      if not text:
 45          return []
 46      raw = brain.extract_entities_from_text(text, model_name=model_name)
 47      seen: dict[tuple[str, str], str] = {}
 48      for ent in raw:
 49          word = (ent.get("word") or "").strip()
 50          etype = _normalize_type((ent.get("entity_group") or ent.get("entity") or "").upper())
 51          if not word or etype not in ALLOWED_TYPES:
 52              continue
 53          normalized = normalize_entity_name(word)
 54          if not normalized:
 55              continue
 56          key = (normalized, etype)
 57          if key not in seen:
 58              seen[key] = word
 59      return [(canonical, etype) for (norm, etype), canonical in seen.items()]
 60  
 61  
 62  def extract_and_persist(project_id: int, source: str, text: str, brain, db) -> int:
 63      """Extract entities from text and persist them. Returns the count of distinct entities found.
 64  
 65      `db` is the DBWrapper instance (uses db.db for the SQLAlchemy session).
 66      """
 67      if not text:
 68          return 0
 69      raw = brain.extract_entities_from_text(text, model_name=None)
 70      if not raw:
 71          return 0
 72  
 73      # Aggregate counts within this source
 74      per_source_counts: Counter = Counter()
 75      canonical_names: dict[tuple[str, str], str] = {}
 76      for ent in raw:
 77          word = (ent.get("word") or "").strip()
 78          etype = _normalize_type((ent.get("entity_group") or ent.get("entity") or "").upper())
 79          if not word or etype not in ALLOWED_TYPES:
 80              continue
 81          normalized = normalize_entity_name(word)
 82          if not normalized:
 83              continue
 84          key = (normalized, etype)
 85          per_source_counts[key] += 1
 86          if key not in canonical_names:
 87              canonical_names[key] = word
 88  
 89      if not per_source_counts:
 90          return 0
 91  
 92      now = datetime.now(timezone.utc)
 93      session = db.db
 94      entity_ids: dict[tuple[str, str], int] = {}
 95  
 96      # Upsert entities
 97      for (normalized, etype), count in per_source_counts.items():
 98          existing = (
 99              session.query(KGEntityDatabase)
100              .filter(
101                  KGEntityDatabase.project_id == project_id,
102                  KGEntityDatabase.normalized == normalized,
103                  KGEntityDatabase.entity_type == etype,
104              )
105              .first()
106          )
107          if existing:
108              existing.mention_count = (existing.mention_count or 0) + count
109              existing.updated_at = now
110              entity_ids[(normalized, etype)] = existing.id
111          else:
112              ent = KGEntityDatabase(
113                  project_id=project_id,
114                  name=canonical_names[(normalized, etype)],
115                  normalized=normalized,
116                  entity_type=etype,
117                  mention_count=count,
118                  created_at=now,
119                  updated_at=now,
120              )
121              session.add(ent)
122              session.flush()
123              entity_ids[(normalized, etype)] = ent.id
124  
125      # Upsert mentions (entity × source)
126      for (normalized, etype), count in per_source_counts.items():
127          eid = entity_ids[(normalized, etype)]
128          existing_mention = (
129              session.query(KGEntityMentionDatabase)
130              .filter(
131                  KGEntityMentionDatabase.entity_id == eid,
132                  KGEntityMentionDatabase.source == source,
133              )
134              .first()
135          )
136          if existing_mention:
137              existing_mention.mention_count = (existing_mention.mention_count or 0) + count
138          else:
139              session.add(KGEntityMentionDatabase(
140                  entity_id=eid,
141                  project_id=project_id,
142                  source=source,
143                  mention_count=count,
144                  created_at=now,
145              ))
146  
147      # Co-occurrence edges: for every pair of distinct entities in this source,
148      # increment the edge weight (or create the edge).
149      ids_in_source = sorted(entity_ids.values())
150      for i, a in enumerate(ids_in_source):
151          for b in ids_in_source[i + 1:]:
152              existing_edge = (
153                  session.query(KGEntityRelationshipDatabase)
154                  .filter(
155                      KGEntityRelationshipDatabase.project_id == project_id,
156                      KGEntityRelationshipDatabase.from_entity_id == a,
157                      KGEntityRelationshipDatabase.to_entity_id == b,
158                  )
159                  .first()
160              )
161              if existing_edge:
162                  existing_edge.weight = (existing_edge.weight or 0) + 1
163              else:
164                  session.add(KGEntityRelationshipDatabase(
165                      project_id=project_id,
166                      from_entity_id=a,
167                      to_entity_id=b,
168                      weight=1,
169                      created_at=now,
170                  ))
171  
172      session.commit()
173      return len(per_source_counts)
174  
175  
def extract_and_persist_safe(project_id: int, source: str, text: str, brain, db_factory) -> None:
    """Background-task safe wrapper. Creates a fresh DB session and handles errors.

    ``db_factory()`` is called to obtain a DB wrapper. Its ``.db`` session is
    closed when extraction finishes, whether or not extraction succeeded. Any
    exception, including one from ``db_factory`` itself, is logged with a
    traceback and not propagated.
    """
    try:
        db = db_factory()
        try:
            extract_and_persist(project_id, source, text, brain, db)
        finally:
            try:
                db.db.close()
            except Exception:
                # Cleanup stays best-effort, but close failures are now
                # visible in the logs instead of being silently dropped.
                logger.warning(
                    "Failed to close DB session after entity extraction for project %s source %s",
                    project_id,
                    source,
                    exc_info=True,
                )
    except Exception as e:
        logger.exception("Background entity extraction failed for project %s source %s: %s", project_id, source, e)
189  
190  
def _merge_mentions(session, primary_id: int, secondary_id: int) -> None:
    """Fold the secondary entity's per-source mention rows into the primary's.

    If the primary already has a mention for a source, the counts are summed
    and the secondary row is deleted. Otherwise the secondary row is repointed
    to the primary.
    """
    secondary_mentions = (
        session.query(KGEntityMentionDatabase)
        .filter(KGEntityMentionDatabase.entity_id == secondary_id)
        .all()
    )
    primary_by_source: dict = {}
    for m in (
        session.query(KGEntityMentionDatabase)
        .filter(KGEntityMentionDatabase.entity_id == primary_id)
        .all()
    ):
        primary_by_source.setdefault(m.source, m)

    for sm in secondary_mentions:
        existing = primary_by_source.get(sm.source)
        if existing is not None:
            existing.mention_count = (existing.mention_count or 0) + (sm.mention_count or 0)
            session.delete(sm)
        else:
            sm.entity_id = primary_id
            # A repointed row now belongs to the primary.
            primary_by_source[sm.source] = sm


def _merge_edges(session, project_id: int, primary_id: int, secondary_id: int) -> None:
    """Repoint every edge touching the secondary entity to the primary, deduplicating.

    The former primary<->secondary edge would become a self-loop, so it is
    deleted. An edge that collides with an existing primary edge adds its
    weight to that edge and is deleted. All other edges are repointed with
    endpoints stored in sorted (smaller id first) order.
    """
    rel = KGEntityRelationshipDatabase
    primary_edges: dict = {}
    for edge in (
        session.query(rel)
        .filter(
            rel.project_id == project_id,
            (rel.from_entity_id == primary_id) | (rel.to_entity_id == primary_id),
        )
        .all()
    ):
        primary_edges.setdefault((edge.from_entity_id, edge.to_entity_id), edge)

    secondary_edges = (
        session.query(rel)
        .filter(
            (rel.from_entity_id == secondary_id)
            | (rel.to_entity_id == secondary_id)
        )
        .all()
    )
    for edge in secondary_edges:
        new_from = primary_id if edge.from_entity_id == secondary_id else edge.from_entity_id
        new_to = primary_id if edge.to_entity_id == secondary_id else edge.to_entity_id
        if new_from == new_to:
            session.delete(edge)
            continue
        # Normalize order so we can dedup.
        a, b = sorted((new_from, new_to))
        existing = primary_edges.get((a, b))
        if existing is not None and existing is not edge:
            existing.weight = (existing.weight or 0) + (edge.weight or 0)
            session.delete(edge)
        else:
            edge.from_entity_id = a
            edge.to_entity_id = b
            primary_edges[(a, b)] = edge


def merge_entities(db, primary_id: int, secondary_id: int) -> bool:
    """Merge secondary entity into primary. Moves all mentions and relationships, then deletes secondary.

    ``db`` is the DBWrapper instance; ``db.db`` is the SQLAlchemy session.
    The primary's ``mention_count`` becomes the sum of both counts, and its
    ``updated_at`` is refreshed.

    Returns True if successful, False if either entity doesn't exist, they're
    the same, or they belong to different projects.

    Raises:
        Any database error. The session is rolled back first, so a failed
        merge never leaves half-applied changes pending in ``db.db``.
    """
    if primary_id == secondary_id:
        return False
    session = db.db
    primary = session.query(KGEntityDatabase).filter(KGEntityDatabase.id == primary_id).first()
    secondary = session.query(KGEntityDatabase).filter(KGEntityDatabase.id == secondary_id).first()
    if not primary or not secondary or primary.project_id != secondary.project_id:
        return False

    try:
        _merge_mentions(session, primary_id, secondary_id)
        _merge_edges(session, primary.project_id, primary_id, secondary_id)
        primary.mention_count = (primary.mention_count or 0) + (secondary.mention_count or 0)
        primary.updated_at = datetime.now(timezone.utc)
        session.delete(secondary)
        session.commit()
    except Exception:
        session.rollback()
        raise
    return True
265  
266  
def compute_potential_duplicates(db, project_id: int, threshold: float = 0.85, limit: int = 100) -> list[dict]:
    """Find entity pairs with similar names within the same type.

    Only entities of the same ``entity_type`` in ``project_id`` are compared,
    using ``difflib.SequenceMatcher(None, a.normalized, b.normalized).ratio()``.
    Two distinct rows with the exact same normalized name are reported with
    similarity 1.0; they are not merged, since they have different ids.
    NOTE(review): such rows presumably come from concurrent ingestion.

    Args:
        db: DBWrapper instance; ``db.db`` is the SQLAlchemy session.
        project_id: Project whose entities are compared.
        threshold: Minimum similarity (0..1) for a pair to be returned.
        limit: Maximum number of pairs returned.

    Returns:
        Up to ``limit`` dicts, highest similarity first, with keys
        ``entity_a_id``, ``entity_a_name``, ``entity_b_id``, ``entity_b_name``
        and ``similarity`` (rounded to 3 decimals).
    """
    session = db.db
    entities = (
        session.query(KGEntityDatabase)
        .filter(KGEntityDatabase.project_id == project_id)
        .all()
    )
    by_type: dict[str, list[KGEntityDatabase]] = defaultdict(list)
    for e in entities:
        by_type[e.entity_type].append(e)

    candidates: list[dict] = []
    for ents in by_type.values():
        for i, a in enumerate(ents):
            for b in ents[i + 1:]:
                if a.normalized == b.normalized:
                    ratio = 1.0
                else:
                    matcher = SequenceMatcher(None, a.normalized, b.normalized)
                    # real_quick_ratio() and quick_ratio() are documented upper
                    # bounds on ratio(). Skipping on them never drops a real match.
                    if matcher.real_quick_ratio() < threshold or matcher.quick_ratio() < threshold:
                        continue
                    ratio = matcher.ratio()
                if ratio >= threshold:
                    candidates.append({
                        "entity_a_id": a.id,
                        "entity_a_name": a.name,
                        "entity_b_id": b.id,
                        "entity_b_name": b.name,
                        "similarity": round(ratio, 3),
                    })
    candidates.sort(key=lambda c: c["similarity"], reverse=True)
    return candidates[:limit]