/ tests / stress / test_concurrency_reclaim_race.py
test_concurrency_reclaim_race.py
"""Target the reclaim race specifically.

Workers claim tasks with a 1s TTL but sleep ~2s (+/- 0.3s) before
completing, and the reclaimer runs every 200ms. Scenario: a worker claims a
task, the reclaimer expires the claim mid-work, and the worker then tries to
complete the task AFTER its run has been reclaimed.

Intended behavior (per design) for that late complete_task:
- refuse it if the task was reclaimed and re-claimed by another worker
  (the original claim was invalidated);
- optionally "forgive" it by grace if no one else has picked the task up.

What complete_task actually does: it does not check claim_lock. It only
transitions a task from 'running' to 'done' via a CAS on status='running'.
So if the reclaimer has moved the task back to 'ready', the late worker's
complete_task fails because the CAS no longer matches. This is the CORRECT
behavior.

By the same logic, if another worker has already re-claimed the task, its
status is 'running' again, and the late CAS would presumably succeed.

Invariant being tested: the race between worker.complete and
dispatcher.reclaim must not produce a double run-close or any other
inconsistency.
"""
 23  
 24  import json
 25  import multiprocessing as mp
 26  import os
 27  import random
 28  import sqlite3
 29  import sys
 30  import tempfile
 31  import time
 32  from pathlib import Path
 33  
 34  NUM_WORKERS = 5
 35  NUM_TASKS = 50
 36  TTL = 1
 37  WORK_DURATION_S = 2.0  # longer than TTL => reclaimer wins
 38  WT = str(Path(__file__).resolve().parents[2])
 39  
 40  
 41  def worker_loop(worker_id: int, hermes_home: str, result_file: str) -> None:
 42      os.environ["HERMES_HOME"] = hermes_home
 43      os.environ["HOME"] = hermes_home
 44      sys.path.insert(0, WT)
 45      from hermes_cli import kanban_db as kb
 46  
 47      events = []
 48      start = time.monotonic()
 49      idle = 0
 50  
 51      while time.monotonic() - start < 40:
 52          conn = kb.connect()
 53          try:
 54              row = conn.execute(
 55                  "SELECT id FROM tasks WHERE status='ready' AND claim_lock IS NULL LIMIT 1"
 56              ).fetchone()
 57              if row is None:
 58                  idle += 1
 59                  if idle > 30:
 60                      break
 61                  time.sleep(0.05)
 62                  continue
 63              idle = 0
 64              tid = row["id"]
 65              try:
 66                  claimed = kb.claim_task(conn, tid, claimer=f"worker-{worker_id}",
 67                                          ttl_seconds=TTL)
 68              except sqlite3.OperationalError as e:
 69                  events.append({"kind": "sqlite_err", "op": "claim", "err": str(e)[:100]})
 70                  continue
 71              if claimed is None:
 72                  events.append({"kind": "lost_claim", "task": tid})
 73                  continue
 74              run = kb.latest_run(conn, tid)
 75              events.append({"kind": "claimed", "task": tid, "worker": worker_id,
 76                             "run_id": run.id})
 77  
 78              # Sleep longer than TTL so reclaimer has a chance to intervene
 79              time.sleep(WORK_DURATION_S + random.uniform(-0.3, 0.3))
 80  
 81              try:
 82                  ok = kb.complete_task(
 83                      conn, tid,
 84                      result=f"by worker-{worker_id}",
 85                      summary=f"worker-{worker_id} finished",
 86                  )
 87                  events.append({"kind": "complete_ok" if ok else "complete_refused",
 88                                 "task": tid, "worker": worker_id, "run_id": run.id})
 89              except sqlite3.OperationalError as e:
 90                  events.append({"kind": "sqlite_err", "op": "complete", "err": str(e)[:100]})
 91          finally:
 92              conn.close()
 93  
 94      with open(result_file, "w") as f:
 95          json.dump(events, f)
 96  
 97  
 98  def reclaimer_loop(hermes_home: str, result_file: str) -> None:
 99      os.environ["HERMES_HOME"] = hermes_home
100      os.environ["HOME"] = hermes_home
101      sys.path.insert(0, WT)
102      from hermes_cli import kanban_db as kb
103  
104      events = []
105      start = time.monotonic()
106      while time.monotonic() - start < 42:
107          conn = kb.connect()
108          try:
109              try:
110                  n = kb.release_stale_claims(conn)
111                  if n:
112                      events.append({"kind": "reclaimed", "count": n,
113                                     "t": time.monotonic() - start})
114              except sqlite3.OperationalError as e:
115                  events.append({"kind": "sqlite_err", "err": str(e)[:100]})
116          finally:
117              conn.close()
118          time.sleep(0.2)
119      with open(result_file, "w") as f:
120          json.dump(events, f)
121  
122  
def _load_events(path: str) -> list:
    """Return the JSON event list a child process wrote to *path*.

    A missing file yields ``[]``. ``main`` relies on child exit codes, not
    file presence, to detect a child that died before writing.
    """
    if not os.path.isfile(path):
        return []
    with open(path) as fh:
        return json.load(fh)


def _check_db_invariants(kb) -> list:
    """Audit the kanban DB after the race; print counts and return failures.

    Args:
        kb: The ``hermes_cli.kanban_db`` module. ``main`` imports it only
            after setting HERMES_HOME, so it is passed in.

    Returns:
        A list of human-readable failure strings (empty if all checks pass).
    """
    failures = []
    conn = kb.connect()
    try:
        # Any task stuck with current_run_id pointing at a closed run?
        bad = conn.execute("""
            SELECT t.id, t.status, t.current_run_id, r.ended_at, r.outcome
            FROM tasks t
            JOIN task_runs r ON r.id = t.current_run_id
            WHERE r.ended_at IS NOT NULL
        """).fetchall()
        for row in bad:
            failures.append(
                f"INVARIANT VIOLATION: task {row['id']} status={row['status']} "
                f"current_run_id={row['current_run_id']} but run ended "
                f"outcome={row['outcome']}"
            )
        # Every run with NULL ended_at should still have a task pointing at it.
        orphans = conn.execute("""
            SELECT r.id, r.task_id
            FROM task_runs r
            LEFT JOIN tasks t ON t.current_run_id = r.id
            WHERE r.ended_at IS NULL AND t.id IS NULL
        """).fetchall()
        for row in orphans:
            failures.append(f"ORPHAN OPEN RUN: run {row['id']} on task {row['task_id']}")
        # A task may be completed at most once. Several 'completed' runs on
        # one task are the double run-close this race must not produce.
        doubles = conn.execute("""
            SELECT task_id, COUNT(*) AS n
            FROM task_runs
            WHERE outcome='completed'
            GROUP BY task_id
            HAVING COUNT(*) > 1
        """).fetchall()
        for row in doubles:
            failures.append(
                f"DOUBLE COMPLETION: task {row['task_id']} has {row['n']} completed runs"
            )
        # Event counts
        claim_evts = conn.execute(
            "SELECT COUNT(*) FROM task_events WHERE kind='claimed'").fetchone()[0]
        reclaim_evts = conn.execute(
            "SELECT COUNT(*) FROM task_events WHERE kind='reclaimed'").fetchone()[0]
        comp_evts = conn.execute(
            "SELECT COUNT(*) FROM task_events WHERE kind='completed'").fetchone()[0]
        print(f"\nDB event counts: claimed={claim_evts} reclaimed={reclaim_evts} completed={comp_evts}")
        # Every reclaimed run must have ended_at set.
        unended_reclaims = conn.execute(
            "SELECT COUNT(*) FROM task_runs WHERE outcome='reclaimed' AND ended_at IS NULL"
        ).fetchone()[0]
        if unended_reclaims:
            failures.append(f"UNENDED RECLAIMED RUNS: {unended_reclaims}")
        # Run outcome totals.
        comp_runs = conn.execute(
            "SELECT COUNT(*) FROM task_runs WHERE outcome='completed'"
        ).fetchone()[0]
        reclaim_runs = conn.execute(
            "SELECT COUNT(*) FROM task_runs WHERE outcome='reclaimed'"
        ).fetchone()[0]
        print(f"DB run outcomes: completed={comp_runs} reclaimed={reclaim_runs}")
    finally:
        conn.close()

    if reclaim_runs == 0:
        failures.append("NO RECLAIMS HAPPENED — test didn't stress what it was supposed to")
    return failures


def main():
    """Seed tasks, race NUM_WORKERS workers against one reclaimer, then audit.

    A fresh temp directory serves as HERMES_HOME/HOME, and the children's
    result files are written there too. Exits with status 1, printing up to
    20 failures, if a child process exits abnormally or a DB invariant is
    violated. Otherwise it prints a success line.
    """
    home = tempfile.mkdtemp(prefix="hermes_reclaim_race_")
    os.environ["HERMES_HOME"] = home
    os.environ["HOME"] = home
    sys.path.insert(0, WT)
    from hermes_cli import kanban_db as kb

    kb.init_db()
    conn = kb.connect()
    try:
        for i in range(NUM_TASKS):
            kb.create_task(conn, title=f"t{i}", assignee="shared",
                           tenant="reclaim-race")
    finally:
        conn.close()
    print(f"Seeded {NUM_TASKS} tasks. TTL={TTL}s, work_duration={WORK_DURATION_S}s")
    print("(worker work > TTL guarantees reclaims)")

    ctx = mp.get_context("spawn")
    # Keep result files inside this run's private temp dir. Fixed /tmp paths
    # can be stale from an earlier run (and get re-read if a child dies
    # before writing), or be clobbered by a concurrent run.
    out_dir = Path(home)
    worker_results = [str(out_dir / f"rc_worker_{i}.json") for i in range(NUM_WORKERS)]
    reclaim_result = str(out_dir / "rc_reclaim.json")
    procs = []
    for i in range(NUM_WORKERS):
        p = ctx.Process(target=worker_loop, args=(i, home, worker_results[i]),
                        name=f"worker-{i}")
        p.start()
        procs.append(p)
    r = ctx.Process(target=reclaimer_loop, args=(home, reclaim_result), name="reclaimer")
    r.start()
    procs.append(r)

    failures = []
    for p in procs:
        p.join(timeout=60)
        if p.is_alive():
            p.terminate()
            p.join()
        # A crashed or terminated child leaves partial (or no) results. Flag
        # it rather than letting the audit pass on incomplete data.
        if p.exitcode != 0:
            failures.append(f"CHILD PROCESS FAILED: {p.name} exitcode={p.exitcode}")

    # Aggregate.
    all_events = [e for path in worker_results for e in _load_events(path)]
    reclaim_events = _load_events(reclaim_result)

    op_counts = {}
    for e in all_events:
        op_counts[e["kind"]] = op_counts.get(e["kind"], 0) + 1
    # The reclaimer log also holds error events; count only actual reclaims.
    reclaim_passes = [e for e in reclaim_events if e["kind"] == "reclaimed"]
    total_reclaims = sum(e["count"] for e in reclaim_passes)
    print(f"\nReclaimer fired {len(reclaim_passes)} times, total tasks reclaimed: {total_reclaims}")
    reclaim_errors = len(reclaim_events) - len(reclaim_passes)
    if reclaim_errors:
        print(f"Reclaimer error events: {reclaim_errors}")
    print("Worker events:")
    for k in sorted(op_counts):
        print(f"  {k:<25} {op_counts[k]}")

    # Invariant checks
    failures.extend(_check_db_invariants(kb))

    if failures:
        print(f"\nFAILURES ({len(failures)}):")
        for f in failures[:20]:
            print(f"  {f}")
        sys.exit(1)
    else:
        print("\n✔ RECLAIM RACE INVARIANTS HELD")


if __name__ == "__main__":
    # Required with the "spawn" start method: each child re-imports this
    # module and must not re-run the orchestrator.
    main()