// langgraph-checkpoint.ts
import type { RunnableConfig } from '@langchain/core/runnables'
import {
  BaseCheckpointSaver,
  WRITES_IDX_MAP,
  copyCheckpoint,
  maxChannelVersion,
  TASKS,
} from '@langchain/langgraph-checkpoint'
import type {
  Checkpoint,
  CheckpointListOptions,
  CheckpointMetadata,
  CheckpointPendingWrite,
  CheckpointTuple,
  PendingWrite,
} from '@langchain/langgraph-checkpoint'
import Database from 'better-sqlite3'
import path from 'path'
import { DATA_DIR } from './data-dir'

/** Default location of the SQLite database file used by this module. */
const DB_PATH = path.join(DATA_DIR, 'swarmclaw.db')

/**
 * Opens a better-sqlite3 connection in WAL journal mode with a 5000 ms busy
 * timeout, so a connection that finds the database locked waits before
 * failing instead of erroring immediately.
 *
 * @param dbPath - Database file to open; defaults to `DB_PATH`.
 */
function getDb(dbPath = DB_PATH): Database.Database {
  const db = new Database(dbPath)
  db.pragma('journal_mode = WAL')
  db.pragma('busy_timeout = 5000')
  return db
}

/**
 * Returns true when `table` already has a column named `column`.
 *
 * `table` is interpolated directly into the SQL text, so only pass trusted,
 * internal table names.
 */
function hasColumn(db: Database.Database, table: string, column: string): boolean {
  const rows = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>
  return rows.some((row) => row.name === column)
}

/**
 * Creates the checkpoint and pending-write tables when missing, and adds the
 * `metadata_type` column to existing `langgraph_checkpoints` tables that
 * lack it.
 *
 * NOTE(review): the `created_at` default is `unixepoch()` (seconds), while
 * `SqliteCheckpointSaver.put` always writes milliseconds explicitly, so the
 * default is never used by this file. Confirm before relying on it elsewhere.
 */
function ensureSchema(db: Database.Database): void {
  db.exec(`
    CREATE TABLE IF NOT EXISTS langgraph_checkpoints (
      thread_id TEXT NOT NULL,
      checkpoint_ns TEXT NOT NULL DEFAULT '',
      checkpoint_id TEXT NOT NULL,
      parent_checkpoint_id TEXT,
      type TEXT NOT NULL DEFAULT 'json',
      checkpoint BLOB NOT NULL,
      metadata BLOB NOT NULL DEFAULT '{}',
      created_at INTEGER NOT NULL DEFAULT (unixepoch()),
      PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
    )
  `)
  db.exec(`
    CREATE TABLE IF NOT EXISTS langgraph_writes (
      thread_id TEXT NOT NULL,
      checkpoint_ns TEXT NOT NULL DEFAULT '',
      checkpoint_id TEXT NOT NULL,
      task_id TEXT NOT NULL,
      idx INTEGER NOT NULL,
      channel TEXT NOT NULL,
      type TEXT NOT NULL DEFAULT 'json',
      value BLOB,
      PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
    )
  `)
  if (!hasColumn(db, 'langgraph_checkpoints', 'metadata_type')) {
    try {
      db.exec(`ALTER TABLE langgraph_checkpoints ADD COLUMN metadata_type TEXT NOT NULL DEFAULT 'json'`)
    } catch (err: unknown) {
      // Tolerate "duplicate column" from concurrent build workers
      if (!(err instanceof Error) || !err.message.includes('duplicate column')) throw err
    }
  }
}

// Create/migrate the schema once at module load on a short-lived connection.
// The connection is closed even when schema setup throws.
const initDb = getDb()
try {
  ensureSchema(initDb)
} finally {
  initDb.close()
}

/**
 * Returns `configurable.thread_id`, or `''` when it is absent or falsy.
 * Callers treat `''` as "no thread".
 */
function getThreadId(config: RunnableConfig): string {
  return (config.configurable?.thread_id as string) || ''
}

/**
 * Returns `configurable.checkpoint_ns` when it is a string (including `''`),
 * otherwise `undefined`. `list` uses `undefined` to mean "all namespaces".
 */
function getOptionalCheckpointNs(config: RunnableConfig): string | undefined {
  const raw = config.configurable?.checkpoint_ns
  return typeof raw === 'string' ? raw : undefined
}

/** Returns `configurable.checkpoint_ns`, or `fallback` (default `''`) when it is not a string. */
function getCheckpointNs(config: RunnableConfig, fallback = ''): string {
  return getOptionalCheckpointNs(config) ?? fallback
}

/**
 * Returns `configurable.checkpoint_id` when it is a non-blank string (returned
 * untrimmed), otherwise `undefined`.
 */
function getConfiguredCheckpointId(config: RunnableConfig): string | undefined {
  const raw = config.configurable?.checkpoint_id
  return typeof raw === 'string' && raw.trim() ? raw : undefined
}

/**
 * Coerces a raw SQLite column value into the shape `serde.loadsTyped` accepts.
 *
 * `null`/`undefined` map to `undefined`. Strings and byte arrays pass through;
 * this includes Node `Buffer`, which is a `Uint8Array` subclass. Any other
 * value is stringified and encoded.
 */
function normalizeSerializedValue(value: unknown): Uint8Array | string | undefined {
  if (value == null) return undefined
  if (typeof value === 'string' || value instanceof Uint8Array) return value
  return Buffer.from(String(value))
}

/**
 * Parses a value as plain JSON. This is the fallback when typed
 * deserialization of a `'json'` value fails.
 */
function readLegacyJson<T>(value: Uint8Array | string | undefined): T {
  if (value == null) return undefined as T
  const text = typeof value === 'string' ? value : Buffer.from(value).toString()
  return JSON.parse(text) as T
}

/** Row shape of `langgraph_checkpoints` as returned by `SELECT *`. */
type CheckpointRow = {
  thread_id: string
  checkpoint_ns: string
  checkpoint_id: string
  parent_checkpoint_id: string | null
  type: string
  checkpoint: Buffer | string
  metadata_type?: string | null
  metadata: Buffer | string
  created_at: number
}

/** Columns of `langgraph_writes` needed to rebuild pending writes. */
type PendingWriteRow = {
  task_id: string
  idx: number
  channel: string
  type: string
  value: Buffer | string | null
}

/**
 * LangGraph checkpoint saver backed by a SQLite database (better-sqlite3).
 *
 * Checkpoints are stored in `langgraph_checkpoints`, and per-task pending
 * writes in `langgraph_writes`. Both are keyed by
 * `(thread_id, checkpoint_ns, checkpoint_id)`.
 */
export class SqliteCheckpointSaver extends BaseCheckpointSaver {
  private readonly db: Database.Database

  /** @param dbPath - Database file to open; defaults to `DB_PATH`. */
  constructor(dbPath = DB_PATH) {
    super()
    this.db = getDb(dbPath)
    ensureSchema(this.db)
  }

  /**
   * Deserializes a stored value with the saver's serde.
   *
   * A missing or blank `type` is treated as `'json'`. If typed decoding of a
   * `'json'` value throws, the value is parsed as plain JSON instead. Errors
   * for other types propagate. `null` values yield `undefined`.
   */
  private async deserializeValue<T>(type: string | undefined, value: unknown): Promise<T> {
    const normalized = normalizeSerializedValue(value)
    if (normalized == null) return undefined as T
    const serializationType = typeof type === 'string' && type.trim() ? type : 'json'
    try {
      return await this.serde.loadsTyped(serializationType, normalized) as T
    } catch (err) {
      if (serializationType !== 'json') throw err
      return readLegacyJson<T>(normalized)
    }
  }

  /** Deserializes a row's metadata, defaulting its serialization type to `'json'`. */
  private readMetadata(row: CheckpointRow): Promise<CheckpointMetadata> {
    return this.deserializeValue<CheckpointMetadata>(row.metadata_type ?? 'json', row.metadata)
  }

  /**
   * Loads the pending writes recorded against one checkpoint as
   * `[taskId, channel, value]` tuples, ordered by task id, then write index.
   */
  private async loadPendingWrites(
    threadId: string,
    checkpointNs: string,
    checkpointId: string,
  ): Promise<CheckpointPendingWrite[]> {
    const rows = this.db.prepare(
      `SELECT task_id, idx, channel, type, value
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?
       ORDER BY task_id ASC, idx ASC`
    ).all(threadId, checkpointNs, checkpointId) as PendingWriteRow[]

    return Promise.all(rows.map(async (row): Promise<CheckpointPendingWrite> => [
      row.task_id,
      row.channel,
      await this.deserializeValue(row.type, row.value),
    ]))
  }

  /**
   * For checkpoints with format version < 4, copies the `TASKS` writes stored
   * on the parent checkpoint into `checkpoint.channel_values[TASKS]`. It also
   * assigns `TASKS` a channel version: the max existing version, or the first
   * version if there are none. Mutates `checkpoint` in place. Does nothing
   * when `v >= 4` or when the parent has no `TASKS` writes.
   */
  private async migratePendingSends(
    checkpoint: Checkpoint,
    threadId: string,
    checkpointNs: string,
    parentCheckpointId: string,
  ): Promise<void> {
    if (checkpoint.v >= 4) return

    const rows = this.db.prepare(
      `SELECT type, value
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? AND channel = ?
       ORDER BY idx ASC`
    ).all(threadId, checkpointNs, parentCheckpointId, TASKS) as Array<{ type: string; value: Buffer | string | null }>

    if (!rows.length) return

    const pendingSends = await Promise.all(rows.map((row) => this.deserializeValue(row.type, row.value)))
    checkpoint.channel_values ??= {}
    checkpoint.channel_values[TASKS] = pendingSends
    checkpoint.channel_versions ??= {}
    checkpoint.channel_versions[TASKS] = Object.keys(checkpoint.channel_versions).length > 0
      ? maxChannelVersion(...Object.values(checkpoint.channel_versions))
      : this.getNextVersion(undefined)
  }

  /**
   * Builds a `CheckpointTuple` from a stored row. The row's checkpoint is
   * deserialized and migrated when needed, and the row's pending writes are
   * attached. `parentConfig` is set only when the row records a parent.
   * `metadata` is passed in because callers already deserialize it.
   */
  private async rowToTuple(
    row: CheckpointRow,
    threadId: string,
    metadata: CheckpointMetadata,
  ): Promise<CheckpointTuple> {
    const checkpoint = await this.deserializeValue<Checkpoint>(row.type, row.checkpoint)
    if (row.parent_checkpoint_id) {
      await this.migratePendingSends(checkpoint, threadId, row.checkpoint_ns, row.parent_checkpoint_id)
    }

    return {
      config: {
        configurable: {
          thread_id: threadId,
          checkpoint_ns: row.checkpoint_ns,
          checkpoint_id: row.checkpoint_id,
        },
      },
      checkpoint,
      metadata,
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: threadId,
              checkpoint_ns: row.checkpoint_ns,
              checkpoint_id: row.parent_checkpoint_id,
            },
          }
        : undefined,
      pendingWrites: await this.loadPendingWrites(threadId, row.checkpoint_ns, row.checkpoint_id),
    }
  }

  /**
   * Returns the checkpoint named by `config`. Without a `checkpoint_id`, it
   * returns the one with the greatest id in the thread/namespace. Returns
   * `undefined` when there is no thread id or no matching row.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    const threadId = getThreadId(config)
    if (!threadId) return undefined

    const checkpointNs = getCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)

    const row = (checkpointId
      ? this.db.prepare(
        `SELECT *
         FROM langgraph_checkpoints
         WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
      ).get(threadId, checkpointNs, checkpointId)
      : this.db.prepare(
        `SELECT *
         FROM langgraph_checkpoints
         WHERE thread_id = ? AND checkpoint_ns = ?
         ORDER BY checkpoint_id DESC
         LIMIT 1`
      ).get(threadId, checkpointNs)) as CheckpointRow | undefined

    if (!row) return undefined

    const metadata = await this.readMetadata(row)
    return this.rowToTuple(row, threadId, metadata)
  }

  /**
   * Yields the thread's checkpoints, newest `checkpoint_id` first.
   *
   * The namespace and `checkpoint_id` from `config` narrow the results when
   * present. `options.before` restricts results to ids below its
   * `checkpoint_id`. `options.filter` keeps only rows whose metadata entries
   * are strictly equal to the filter values. `options.limit` caps the number
   * of tuples yielded; a non-positive limit yields nothing.
   */
  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions,
  ): AsyncGenerator<CheckpointTuple> {
    const threadId = getThreadId(config)
    if (!threadId) return

    const limit = options?.limit
    // Check before yielding anything, so a zero/negative limit returns no tuples.
    if (limit !== undefined && limit <= 0) return

    const checkpointNs = getOptionalCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)
    const beforeId = options?.before?.configurable?.checkpoint_id
    const filterEntries = options?.filter ? Object.entries(options.filter) : []

    let query = `
      SELECT *
      FROM langgraph_checkpoints
      WHERE thread_id = ?
    `
    const params: Array<string | number> = [threadId]

    if (checkpointNs !== undefined) {
      query += ` AND checkpoint_ns = ?`
      params.push(checkpointNs)
    }

    if (checkpointId) {
      query += ` AND checkpoint_id = ?`
      params.push(checkpointId)
    }

    if (beforeId) {
      query += ` AND checkpoint_id < ?`
      params.push(beforeId)
    }

    query += ` ORDER BY checkpoint_id DESC`

    // Without a metadata filter every selected row is yielded, so the limit
    // can be applied in SQL instead of loading rows that would be discarded.
    if (filterEntries.length === 0 && limit !== undefined && Number.isSafeInteger(limit)) {
      query += ` LIMIT ?`
      params.push(limit)
    }

    const rows = this.db.prepare(query).all(...params) as CheckpointRow[]
    let yielded = 0

    for (const row of rows) {
      const metadata = await this.readMetadata(row)
      const matchesFilter = filterEntries.every(
        ([key, value]) => (metadata as Record<string, unknown>)[key] === value,
      )
      if (!matchesFilter) continue

      yield await this.rowToTuple(row, threadId, metadata)

      yielded += 1
      if (limit !== undefined && yielded >= limit) break
    }
  }

  /**
   * Stores (or replaces) a checkpoint. The `checkpoint_id` in `config`, if
   * any, is recorded as the new checkpoint's parent. `created_at` is the
   * checkpoint's `ts` in epoch milliseconds, or the current time when `ts`
   * does not parse.
   *
   * @returns A config pointing at the stored checkpoint.
   * @throws Error when `configurable.thread_id` is missing.
   */
  async put(
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
    newVersions: Record<string, number | string>,
  ): Promise<RunnableConfig> {
    const threadId = getThreadId(config)
    const checkpointNs = getCheckpointNs(config)
    const parentCheckpointId = getConfiguredCheckpointId(config)

    if (!threadId) {
      throw new Error('Failed to put checkpoint. Missing required configurable.thread_id.')
    }
    // The full checkpoint is stored, so per-channel version deltas are not needed.
    void newVersions

    const preparedCheckpoint = copyCheckpoint(checkpoint)
    const [
      [checkpointType, serializedCheckpoint],
      [metadataType, serializedMetadata],
    ] = await Promise.all([
      this.serde.dumpsTyped(preparedCheckpoint),
      this.serde.dumpsTyped(metadata),
    ])

    const parsedTs = Date.parse(preparedCheckpoint.ts)
    const createdAt = Number.isFinite(parsedTs) ? parsedTs : Date.now()

    this.db.prepare(`
      INSERT OR REPLACE INTO langgraph_checkpoints
        (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata_type, metadata, created_at)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    `).run(
      threadId,
      checkpointNs,
      preparedCheckpoint.id,
      parentCheckpointId || null,
      checkpointType,
      Buffer.from(serializedCheckpoint),
      metadataType,
      Buffer.from(serializedMetadata),
      createdAt,
    )

    return {
      configurable: {
        thread_id: threadId,
        checkpoint_ns: checkpointNs,
        checkpoint_id: preparedCheckpoint.id,
      },
    }
  }

  /**
   * Records a task's pending writes against the checkpoint named by `config`.
   *
   * Each write's index comes from `WRITES_IDX_MAP` for special channels,
   * otherwise from its position in `writes`. Writes with a non-negative index
   * are skipped if a row already exists at that key. Negative-index writes
   * always overwrite. All rows are written in one transaction.
   *
   * @throws Error when `thread_id` or `checkpoint_id` is missing.
   */
  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string,
  ): Promise<void> {
    const threadId = getThreadId(config)
    const checkpointNs = getCheckpointNs(config)
    const checkpointId = getConfiguredCheckpointId(config)

    if (!threadId) {
      throw new Error('Failed to put writes. Missing required configurable.thread_id.')
    }
    if (!checkpointId) {
      throw new Error('Failed to put writes. Missing required configurable.checkpoint_id.')
    }

    // Serialize up front: serde is async, but the better-sqlite3 transaction
    // callback below must run synchronously.
    const serializedWrites = await Promise.all(writes.map(async ([channel, value], idx) => {
      const [type, serializedValue] = await this.serde.dumpsTyped(value)
      const writeIdx = WRITES_IDX_MAP[channel as string] ?? idx
      return {
        channel: channel as string,
        idx: writeIdx,
        type,
        value: Buffer.from(serializedValue),
      }
    }))

    const getExisting = this.db.prepare(
      `SELECT 1
       FROM langgraph_writes
       WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? AND task_id = ? AND idx = ?`
    )
    const upsert = this.db.prepare(`
      INSERT OR REPLACE INTO langgraph_writes
        (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `)

    const tx = this.db.transaction((items: typeof serializedWrites) => {
      for (const item of items) {
        if (item.idx >= 0) {
          const existing = getExisting.get(threadId, checkpointNs, checkpointId, taskId, item.idx)
          if (existing) continue
        }
        upsert.run(
          threadId,
          checkpointNs,
          checkpointId,
          taskId,
          item.idx,
          item.channel,
          item.type,
          item.value,
        )
      }
    })

    tx(serializedWrites)
  }

  /** Atomically deletes every checkpoint and pending write for a thread, across all namespaces. */
  async deleteThread(threadId: string): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ?`).run(threadId)
      this.db.prepare(`DELETE FROM langgraph_writes WHERE thread_id = ?`).run(threadId)
    })()
  }

  /**
   * Atomically deletes one checkpoint id and its pending writes from a thread,
   * in every namespace.
   */
  async deleteCheckpoint(threadId: string, checkpointId: string): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ? AND checkpoint_id = ?`).run(threadId, checkpointId)
      this.db.prepare(`DELETE FROM langgraph_writes WHERE thread_id = ? AND checkpoint_id = ?`).run(threadId, checkpointId)
    })()
  }

  /**
   * Atomically deletes a thread's checkpoints whose `created_at` is strictly
   * greater than `timestamp`. It then removes every pending write in the
   * thread that no longer has a checkpoint with the same namespace and id.
   *
   * @param timestamp - Compared against `created_at`, which `put` writes in
   *   epoch milliseconds.
   */
  async deleteCheckpointsAfter(threadId: string, timestamp: number): Promise<void> {
    this.db.transaction(() => {
      this.db.prepare(`DELETE FROM langgraph_checkpoints WHERE thread_id = ? AND created_at > ?`).run(threadId, timestamp)
      this.db.prepare(`
        DELETE FROM langgraph_writes
        WHERE thread_id = ?
          AND NOT EXISTS (
            SELECT 1 FROM langgraph_checkpoints AS c
            WHERE c.thread_id = langgraph_writes.thread_id
              AND c.checkpoint_ns = langgraph_writes.checkpoint_ns
              AND c.checkpoint_id = langgraph_writes.checkpoint_id
          )
      `).run(threadId)
    })()
  }
}

let _saver: SqliteCheckpointSaver | undefined

/** Returns a lazily created module-level `SqliteCheckpointSaver` on the default database path. */
export function getCheckpointSaver(): SqliteCheckpointSaver {
  if (!_saver) _saver = new SqliteCheckpointSaver()
  return _saver
}