/ src / lib / server / langgraph-checkpoint.ts
langgraph-checkpoint.ts
  1  import type { RunnableConfig } from '@langchain/core/runnables'
  2  import {
  3    BaseCheckpointSaver,
  4    WRITES_IDX_MAP,
  5    copyCheckpoint,
  6    maxChannelVersion,
  7    TASKS,
  8  } from '@langchain/langgraph-checkpoint'
  9  import type {
 10    Checkpoint,
 11    CheckpointListOptions,
 12    CheckpointMetadata,
 13    CheckpointPendingWrite,
 14    CheckpointTuple,
 15    PendingWrite,
 16  } from '@langchain/langgraph-checkpoint'
 17  import Database from 'better-sqlite3'
 18  import path from 'path'
 19  import { DATA_DIR } from './data-dir'
 20  
 21  const DB_PATH = path.join(DATA_DIR, 'swarmclaw.db')
 22  
 23  function getDb(dbPath = DB_PATH): Database.Database {
 24    const db = new Database(dbPath)
 25    db.pragma('journal_mode = WAL')
 26    db.pragma('busy_timeout = 5000')
 27    return db
 28  }
 29  
 30  function hasColumn(db: Database.Database, table: string, column: string): boolean {
 31    const rows = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>
 32    return rows.some((row) => row.name === column)
 33  }
 34  
 35  function ensureSchema(db: Database.Database): void {
 36    db.exec(`
 37      CREATE TABLE IF NOT EXISTS langgraph_checkpoints (
 38        thread_id TEXT NOT NULL,
 39        checkpoint_ns TEXT NOT NULL DEFAULT '',
 40        checkpoint_id TEXT NOT NULL,
 41        parent_checkpoint_id TEXT,
 42        type TEXT NOT NULL DEFAULT 'json',
 43        checkpoint BLOB NOT NULL,
 44        metadata BLOB NOT NULL DEFAULT '{}',
 45        created_at INTEGER NOT NULL DEFAULT (unixepoch()),
 46        PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
 47      )
 48    `)
 49    db.exec(`
 50      CREATE TABLE IF NOT EXISTS langgraph_writes (
 51        thread_id TEXT NOT NULL,
 52        checkpoint_ns TEXT NOT NULL DEFAULT '',
 53        checkpoint_id TEXT NOT NULL,
 54        task_id TEXT NOT NULL,
 55        idx INTEGER NOT NULL,
 56        channel TEXT NOT NULL,
 57        type TEXT NOT NULL DEFAULT 'json',
 58        value BLOB,
 59        PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
 60      )
 61    `)
 62    if (!hasColumn(db, 'langgraph_checkpoints', 'metadata_type')) {
 63      try {
 64        db.exec(`ALTER TABLE langgraph_checkpoints ADD COLUMN metadata_type TEXT NOT NULL DEFAULT 'json'`)
 65      } catch (err: unknown) {
 66        // Tolerate "duplicate column" from concurrent build workers
 67        if (!(err instanceof Error) || !err.message.includes('duplicate column')) throw err
 68      }
 69    }
 70  }
 71  
 72  const initDb = getDb()
 73  ensureSchema(initDb)
 74  initDb.close()
 75  
 76  function getThreadId(config: RunnableConfig): string {
 77    return (config.configurable?.thread_id as string) || ''
 78  }
 79  
 80  function getOptionalCheckpointNs(config: RunnableConfig): string | undefined {
 81    const raw = config.configurable?.checkpoint_ns
 82    return typeof raw === 'string' ? raw : undefined
 83  }
 84  
 85  function getCheckpointNs(config: RunnableConfig, fallback = ''): string {
 86    return getOptionalCheckpointNs(config) ?? fallback
 87  }
 88  
 89  function getConfiguredCheckpointId(config: RunnableConfig): string | undefined {
 90    const raw = config.configurable?.checkpoint_id
 91    return typeof raw === 'string' && raw.trim() ? raw : undefined
 92  }
 93  
 94  function normalizeSerializedValue(value: unknown): Uint8Array | string | undefined {
 95    if (value == null) return undefined
 96    if (typeof value === 'string') return value
 97    if (value instanceof Uint8Array) return value
 98    if (Buffer.isBuffer(value)) return new Uint8Array(value)
 99    return Buffer.from(String(value))
100  }
101  
/**
 * Parses a stored value as plain JSON text. It is used as the fallback when
 * the serde cannot decode a 'json'-typed value (presumably rows written in an
 * older format).
 *
 * Byte input is decoded with Buffer's default UTF-8 `toString()`.
 * `undefined` is passed through unchanged.
 */
function readLegacyJson<T>(value: Uint8Array | string | undefined): T {
  if (value === undefined || value === null) return undefined as T
  const json = typeof value !== 'string' ? Buffer.from(value).toString() : value
  return JSON.parse(json) as T
}
107  
/** Shape of a `SELECT *` row from `langgraph_checkpoints`. */
interface CheckpointRow {
  thread_id: string
  checkpoint_ns: string
  checkpoint_id: string
  /** Id of the checkpoint this one was created from, if any. */
  parent_checkpoint_id: string | null
  /** Serde type tag for `checkpoint`. */
  type: string
  checkpoint: Buffer | string
  /** Serde type tag for `metadata`; optional because the column is added by migration. */
  metadata_type?: string | null
  metadata: Buffer | string
  created_at: number
}
119  
/** Columns selected from `langgraph_writes` when loading pending writes. */
interface PendingWriteRow {
  task_id: string
  /** Write position within the task (see WRITES_IDX_MAP for fixed indices). */
  idx: number
  channel: string
  /** Serde type tag for `value`. */
  type: string
  value: Buffer | string | null
}
127  
/**
 * LangGraph checkpoint saver that persists to SQLite via better-sqlite3.
 *
 * Checkpoints are stored in `langgraph_checkpoints`, keyed by
 * (thread_id, checkpoint_ns, checkpoint_id). Pending writes are stored in
 * `langgraph_writes`, keyed by
 * (thread_id, checkpoint_ns, checkpoint_id, task_id, idx). Every value is
 * encoded with the inherited `serde`, and the serializer's type tag is stored
 * next to it so it can be decoded later.
 */
export class SqliteCheckpointSaver extends BaseCheckpointSaver {
  private readonly db: Database.Database

  /**
   * Opens (or creates) the database at `dbPath` and ensures the checkpoint
   * tables exist.
   *
   * @param dbPath SQLite file to use; defaults to the shared application database.
   */
  constructor(dbPath = DB_PATH) {
    super()
    this.db = getDb(dbPath)
    ensureSchema(this.db)
  }

  /**
   * Decodes a stored value with the serde.
   *
   * A missing or blank type tag is treated as 'json'. If decoding a 'json'
   * value fails, the raw text is parsed with JSON.parse instead. Null or
   * undefined column values decode to `undefined`.
   */
  private async deserializeValue<T>(type: string | undefined, value: unknown): Promise<T> {
    const normalized = normalizeSerializedValue(value)
    if (normalized == null) return undefined as T
    const serializationType = typeof type === 'string' && type.trim() ? type : 'json'
    try {
      return await this.serde.loadsTyped(serializationType, normalized) as T
    } catch (err) {
      // Only 'json' payloads have a plain-text fallback; other types fail loudly.
      if (serializationType !== 'json') throw err
      return readLegacyJson<T>(normalized)
    }
  }

  /**
   * Loads the pending writes recorded against one checkpoint, ordered by
   * task id and then write index, as `[taskId, channel, value]` triples.
   */
  private async loadPendingWrites(
    threadId: string,
    checkpointNs: string,
    checkpointId: string,
  ): Promise<CheckpointPendingWrite[]> {
    const rows = this.db.prepare(
      `SELECT task_id, idx, channel, type, value
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?
       ORDER BY task_id ASC, idx ASC`
    ).all(threadId, checkpointNs, checkpointId) as PendingWriteRow[]

    return Promise.all(rows.map(async (row): Promise<CheckpointPendingWrite> => [
      row.task_id,
      row.channel,
      await this.deserializeValue(row.type, row.value),
    ]))
  }

  /**
   * Upgrades checkpoints whose format version `v` is below 4. Any `TASKS`
   * writes saved against the parent checkpoint are copied into
   * `channel_values[TASKS]`. The `TASKS` channel version is then set to the
   * current maximum channel version, or to a fresh version when no channel
   * has one yet. Mutates `checkpoint` in place.
   */
  private async migratePendingSends(
    checkpoint: Checkpoint,
    threadId: string,
    checkpointNs: string,
    parentCheckpointId: string,
  ): Promise<void> {
    if (checkpoint.v >= 4) return

    const rows = this.db.prepare(
      `SELECT type, value
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? AND channel = ?
       ORDER BY idx ASC`
    ).all(threadId, checkpointNs, parentCheckpointId, TASKS) as Array<{ type: string; value: Buffer | string | null }>

    if (!rows.length) return

    const pendingSends = await Promise.all(rows.map((row) => this.deserializeValue(row.type, row.value)))
    checkpoint.channel_values ??= {}
    checkpoint.channel_values[TASKS] = pendingSends
    checkpoint.channel_versions ??= {}
    checkpoint.channel_versions[TASKS] = Object.keys(checkpoint.channel_versions).length > 0
      ? maxChannelVersion(...Object.values(checkpoint.channel_versions))
      : this.getNextVersion(undefined)
  }

  /**
   * Builds the CheckpointTuple for a stored row. It deserializes the
   * checkpoint, applies the pre-v4 pending-sends migration, and attaches the
   * parent config and the checkpoint's pending writes. Shared by getTuple()
   * and list() so both return identically shaped tuples.
   */
  private async rowToTuple(
    row: CheckpointRow,
    threadId: string,
    metadata: CheckpointMetadata,
  ): Promise<CheckpointTuple> {
    const checkpoint = await this.deserializeValue<Checkpoint>(row.type, row.checkpoint)
    if (row.parent_checkpoint_id) {
      await this.migratePendingSends(checkpoint, threadId, row.checkpoint_ns, row.parent_checkpoint_id)
    }

    const parentConfig: RunnableConfig | undefined = row.parent_checkpoint_id
      ? {
          configurable: {
            thread_id: threadId,
            checkpoint_ns: row.checkpoint_ns,
            checkpoint_id: row.parent_checkpoint_id,
          },
        }
      : undefined

    return {
      config: {
        configurable: {
          thread_id: threadId,
          checkpoint_ns: row.checkpoint_ns,
          checkpoint_id: row.checkpoint_id,
        },
      },
      checkpoint,
      metadata,
      parentConfig,
      pendingWrites: await this.loadPendingWrites(threadId, row.checkpoint_ns, row.checkpoint_id),
    }
  }

  /**
   * Fetches one checkpoint tuple. The config's `checkpoint_id` is used when
   * set; otherwise the checkpoint with the greatest id in the namespace is
   * returned. Returns `undefined` when there is no thread id or no matching
   * row.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    const threadId = getThreadId(config)
    if (!threadId) return undefined

    const checkpointNs = getCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)

    let row: CheckpointRow | undefined
    if (checkpointId) {
      row = this.db.prepare(
        `SELECT *
         FROM langgraph_checkpoints
         WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
      ).get(threadId, checkpointNs, checkpointId) as CheckpointRow | undefined
    } else {
      row = this.db.prepare(
        `SELECT *
         FROM langgraph_checkpoints
         WHERE thread_id = ? AND checkpoint_ns = ?
         ORDER BY checkpoint_id DESC
         LIMIT 1`
      ).get(threadId, checkpointNs) as CheckpointRow | undefined
    }

    if (!row) return undefined

    const metadata = await this.deserializeValue<CheckpointMetadata>(row.metadata_type ?? 'json', row.metadata)
    return this.rowToTuple(row, threadId, metadata)
  }

  /**
   * Yields a thread's checkpoints, newest id first.
   *
   * Optional narrowing:
   * - config `checkpoint_ns` restricts to one namespace;
   * - config `checkpoint_id` selects one checkpoint;
   * - `options.before` keeps only ids lower than the given one;
   * - `options.filter` keeps only checkpoints whose metadata fields strictly
   *   equal the given values;
   * - `options.limit` caps the number of results. A non-positive limit yields
   *   nothing.
   */
  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions,
  ): AsyncGenerator<CheckpointTuple> {
    const threadId = getThreadId(config)
    if (!threadId) return

    const limit = options?.limit
    if (limit !== undefined && limit <= 0) return

    const checkpointNs = getOptionalCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)
    const filterEntries = Object.entries(options?.filter ?? {})
    const beforeId = options?.before?.configurable?.checkpoint_id

    let query = `
      SELECT *
      FROM langgraph_checkpoints
      WHERE thread_id = ?
    `
    const params: Array<string | number> = [threadId]

    if (checkpointNs !== undefined) {
      query += ` AND checkpoint_ns = ?`
      params.push(checkpointNs)
    }

    if (checkpointId) {
      query += ` AND checkpoint_id = ?`
      params.push(checkpointId)
    }

    if (beforeId) {
      query += ` AND checkpoint_id < ?`
      params.push(beforeId)
    }

    query += ` ORDER BY checkpoint_id DESC`

    // The metadata filter is evaluated in JS after deserialization, so the row
    // cap can only be pushed into SQL when no filter is present.
    if (limit !== undefined && filterEntries.length === 0 && Number.isInteger(limit)) {
      query += ` LIMIT ?`
      params.push(limit)
    }

    const rows = this.db.prepare(query).all(...params) as CheckpointRow[]
    let yielded = 0

    for (const row of rows) {
      // Checked before yielding so the limit is never exceeded.
      if (limit !== undefined && yielded >= limit) break

      const metadata = await this.deserializeValue<CheckpointMetadata>(row.metadata_type ?? 'json', row.metadata)
      if (filterEntries.length > 0) {
        const fields = (metadata ?? {}) as Record<string, unknown>
        if (!filterEntries.every(([key, value]) => fields[key] === value)) continue
      }

      yield await this.rowToTuple(row, threadId, metadata)
      yielded += 1
    }
  }

  /**
   * Stores (or replaces) a checkpoint and its metadata. The config's
   * `checkpoint_id`, when set, is recorded as the parent of the new
   * checkpoint.
   *
   * @returns A config addressing the stored checkpoint.
   * @throws Error when `configurable.thread_id` is missing.
   */
  async put(
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    newVersions: Record<string, number | string>,
  ): Promise<RunnableConfig> {
    const threadId = getThreadId(config)
    const checkpointNs = getCheckpointNs(config)
    const parentCheckpointId = getConfiguredCheckpointId(config)

    if (!threadId) {
      throw new Error('Failed to put checkpoint. Missing required configurable.thread_id.')
    }
    // Unused: the whole checkpoint (including channel_versions) is serialized below.
    void newVersions

    const preparedCheckpoint = copyCheckpoint(checkpoint)
    const [
      [checkpointType, serializedCheckpoint],
      [metadataType, serializedMetadata],
    ] = await Promise.all([
      this.serde.dumpsTyped(preparedCheckpoint),
      this.serde.dumpsTyped(metadata),
    ])

    // created_at is stored in milliseconds since the epoch, taken from the
    // checkpoint's own timestamp when it parses, otherwise from "now".
    const parsedTs = Date.parse(preparedCheckpoint.ts)
    const createdAt = Number.isFinite(parsedTs) ? parsedTs : Date.now()

    this.db.prepare(`
      INSERT OR REPLACE INTO langgraph_checkpoints
        (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata_type, metadata, created_at)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    `).run(
      threadId,
      checkpointNs,
      preparedCheckpoint.id,
      parentCheckpointId || null,
      checkpointType,
      Buffer.from(serializedCheckpoint),
      metadataType,
      Buffer.from(serializedMetadata),
      createdAt,
    )

    return {
      configurable: {
        thread_id: threadId,
        checkpoint_ns: checkpointNs,
        checkpoint_id: preparedCheckpoint.id,
      },
    }
  }

  /**
   * Records the writes produced by `taskId` against the checkpoint addressed
   * by `config`, in a single transaction.
   *
   * Each write's index comes from WRITES_IDX_MAP when its channel is listed
   * there, otherwise from its position in `writes`. Writes with a
   * non-negative index are skipped if a row already exists for that
   * (task, index). Writes with a negative index always overwrite.
   *
   * @throws Error when `thread_id` or `checkpoint_id` is missing from the config.
   */
  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string,
  ): Promise<void> {
    const threadId = getThreadId(config)
    const checkpointNs = getCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)

    if (!threadId) {
      throw new Error('Failed to put writes. Missing required configurable.thread_id.')
    }
    if (!checkpointId) {
      throw new Error('Failed to put writes. Missing required configurable.checkpoint_id.')
    }

    // Serialize everything up front; the synchronous transaction below cannot await.
    const serializedWrites = await Promise.all(writes.map(async ([channel, value], idx) => {
      const [type, serializedValue] = await this.serde.dumpsTyped(value)
      const writeIdx = WRITES_IDX_MAP[channel as string] ?? idx
      return {
        channel: channel as string,
        idx: writeIdx,
        type,
        value: Buffer.from(serializedValue),
      }
    }))

    const getExisting = this.db.prepare(
      `SELECT 1
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? AND task_id = ? AND idx = ?`
    )
    const upsert = this.db.prepare(`
      INSERT OR REPLACE INTO langgraph_writes
        (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `)

    const tx = this.db.transaction((items: typeof serializedWrites) => {
      for (const item of items) {
        if (item.idx >= 0) {
          const existing = getExisting.get(threadId, checkpointNs, checkpointId, taskId, item.idx)
          if (existing) continue
        }
        upsert.run(
          threadId,
          checkpointNs,
          checkpointId,
          taskId,
          item.idx,
          item.channel,
          item.type,
          item.value,
        )
      }
    })

    tx(serializedWrites)
  }

  /** Deletes every checkpoint and pending write of a thread (all namespaces) atomically. */
  async deleteThread(threadId: string): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ?`).run(threadId)
      this.db.prepare(`DELETE FROM langgraph_writes WHERE thread_id = ?`).run(threadId)
    })()
  }

  /**
   * Deletes one checkpoint id of a thread (in every namespace) together with
   * its pending writes, atomically.
   */
  async deleteCheckpoint(threadId: string, checkpointId: string): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ? AND checkpoint_id = ?`).run(threadId, checkpointId)
      this.db.prepare(`DELETE FROM langgraph_writes WHERE thread_id = ? AND checkpoint_id = ?`).run(threadId, checkpointId)
    })()
  }

  /**
   * Deletes the thread's checkpoints with `created_at > timestamp`. It then
   * removes every pending write of the thread that no longer has a
   * checkpoint with the same (namespace, id). Both steps run in one
   * transaction.
   *
   * NOTE(review): put() stores created_at in milliseconds since the epoch,
   * so `timestamp` is presumably expected in milliseconds — confirm with callers.
   */
  async deleteCheckpointsAfter(threadId: string, timestamp: number): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ? AND created_at > ?`).run(threadId, timestamp)
      this.db.prepare(`
        DELETE FROM langgraph_writes
        WHERE thread_id = ?
          AND NOT EXISTS (
            SELECT 1
            FROM langgraph_checkpoints AS c
            WHERE c.thread_id = langgraph_writes.thread_id
              AND c.checkpoint_ns = langgraph_writes.checkpoint_ns
              AND c.checkpoint_id = langgraph_writes.checkpoint_id
          )
      `).run(threadId)
    })()
  }
}
461  
/** Lazily created, module-wide saver instance. */
let sharedSaver: SqliteCheckpointSaver | undefined

/**
 * Returns the process-wide SqliteCheckpointSaver. On first call it is
 * constructed with the default database path; later calls return the same
 * instance.
 */
export function getCheckpointSaver(): SqliteCheckpointSaver {
  sharedSaver ??= new SqliteCheckpointSaver()
  return sharedSaver
}