// memory-retrieval.ts
import type { MemoryItem } from '@/lib/shared/chat'
import { getRuntimeConfig } from '@/server/config/runtime'
import { getChatProvider } from '@/server/providers'
import {
  listMemories,
  listMemoryEmbeddings,
  upsertMemoryEmbeddings,
} from '@/server/storage/chat-store'

/** A memory paired with a relevance score and the strategy that produced the score. */
interface ScoredMemory {
  memory: MemoryItem
  score: number
  source: 'semantic' | 'lexical'
}

/**
 * Recomputes the embedding vector of every memory in a session and stores the result.
 *
 * Does nothing when the session has no memories. All memories are embedded in a single
 * provider call, and each vector is stored together with the provider and model names
 * reported by that call.
 *
 * @param sessionId - Session whose memories are re-embedded.
 * @throws Error when the provider returns a different number of vectors than texts sent.
 */
export async function refreshMemoryEmbeddings(sessionId: string): Promise<void> {
  const memories = listMemories(sessionId)
  if (memories.length === 0) return

  const provider = getChatProvider()
  const inputTexts = memories.map(memoryToEmbeddingText)

  const result = await provider.embedTexts({ texts: inputTexts, sessionId })
  if (result.embeddings.length !== memories.length) {
    throw new Error('Embedding provider returned an unexpected number of vectors.')
  }

  upsertMemoryEmbeddings(
    sessionId,
    memories.map((memory, index) => ({
      memoryId: memory.id,
      provider: result.provider,
      model: result.model,
      // Vectors are matched to memories by position in the request.
      vector: result.embeddings[index] ?? [],
    })),
  )
}

/**
 * Picks the memories of a session that are most relevant to a query.
 *
 * Semantic (embedding) ranking is tried first. If it yields any scored candidates, only
 * those at or above `memorySimilarityThreshold` are kept, capped at `memorySemanticTopK`;
 * this can legitimately produce an empty list. The lexical fallback is used only when
 * semantic ranking yields no candidates at all.
 *
 * @param sessionId - Session whose memories are searched.
 * @param queryText - Free-form text to match memories against.
 * @returns Selected memories, most relevant first.
 */
export async function selectRelevantMemories(
  sessionId: string,
  queryText: string,
): Promise<MemoryItem[]> {
  const config = getRuntimeConfig()
  const memories = listMemories(sessionId)
  if (memories.length === 0) return []

  const semantic = await trySemanticSelection(sessionId, queryText, memories)
  if (semantic.length > 0) {
    return semantic
      .filter((item) => item.score >= config.memorySimilarityThreshold)
      .slice(0, config.memorySemanticTopK)
      .map((item) => item.memory)
  }

  return lexicalFallback(memories, queryText, config.memorySemanticTopK)
}

/**
 * Scores memories by cosine similarity between the query embedding and stored embeddings.
 *
 * Best effort: any error raised inside the `try` block (storage lookup or provider call)
 * results in an empty list so that the caller can fall back to lexical matching.
 *
 * @param sessionId - Session whose stored embeddings are compared.
 * @param queryText - Query text; surrounding whitespace is trimmed before embedding.
 * @param memories - Memories eligible for selection; stored embeddings that reference
 *   any other memory id are ignored.
 * @returns Scored memories sorted by descending score, or an empty list when the query is
 *   blank, nothing is stored, the provider returns no usable vector, or an error occurs.
 */
async function trySemanticSelection(
  sessionId: string,
  queryText: string,
  memories: MemoryItem[],
): Promise<ScoredMemory[]> {
  const provider = getChatProvider()
  const query = queryText.trim()
  if (!query) return []

  try {
    // Check storage before calling the provider: with no stored vectors there is nothing
    // to compare against, so embedding the query would be a wasted round-trip.
    const storedEmbeddings = listMemoryEmbeddings(sessionId)
    if (storedEmbeddings.length === 0) return []

    const queryEmbeddingResult = await provider.embedTexts({ texts: [query], sessionId })
    const queryVector = queryEmbeddingResult.embeddings.at(0)
    if (!queryVector || queryVector.length === 0) return []

    const memoriesById = new Map(memories.map((memory) => [memory.id, memory]))
    const scored: ScoredMemory[] = []

    for (const embedding of storedEmbeddings) {
      const memory = memoriesById.get(embedding.memoryId)
      if (!memory) continue
      // Vectors of a different dimension cannot be compared with the query vector.
      // NOTE(review): only the dimension is checked here, not the provider/model that
      // produced the stored vector; confirm that is sufficient.
      if (embedding.vector.length !== queryVector.length) continue

      scored.push({
        memory,
        score: cosineSimilarity(queryVector, embedding.vector),
        source: 'semantic',
      })
    }

    return scored.sort((a, b) => b.score - a.score)
  } catch {
    // Deliberately swallowed: semantic selection is optional and the caller falls back
    // to lexical matching when this returns an empty list.
    return []
  }
}

/**
 * Ranks memories by how many distinct query tokens occur in their key or value.
 *
 * Matching is a case-insensitive substring test. Each distinct token counts at most once
 * per memory, so repeating a word in the query does not inflate scores.
 *
 * @param memories - Candidate memories.
 * @param queryText - Text to tokenize and match.
 * @param limit - Maximum number of memories to return.
 * @returns Up to `limit` matching memories, best match first. When the query has no
 *   usable tokens, or no memory matches, the first `limit` memories are returned as-is.
 */
function lexicalFallback(
  memories: MemoryItem[],
  queryText: string,
  limit: number,
): MemoryItem[] {
  const tokens = Array.from(new Set(tokenize(queryText)))
  if (tokens.length === 0) {
    return memories.slice(0, limit)
  }

  const scored = memories
    .map((memory) => {
      const haystack = `${memory.key} ${memory.value}`.toLowerCase()
      const score = tokens.reduce((sum, token) => sum + (haystack.includes(token) ? 1 : 0), 0)
      return { memory, score }
    })
    .filter((item) => item.score > 0)
    .sort((a, b) => b.score - a.score)
    .slice(0, limit)
    .map((item) => item.memory)

  return scored.length > 0 ? scored : memories.slice(0, limit)
}

/** Builds the text sent to the embedding provider: key, category and value joined by ` | `. */
function memoryToEmbeddingText(memory: MemoryItem): string {
  return [memory.key, memory.category, memory.value].join(' | ')
}

/**
 * Splits text into lowercase tokens of at least two characters.
 *
 * Any run of characters that are not Unicode letters, combining marks or numbers acts as
 * a separator, so non-ASCII text is tokenized as well. ASCII input yields the same tokens
 * as an ASCII-only `[^a-z0-9]+` split.
 */
function tokenize(text: string): string[] {
  return text
    .toLowerCase()
    .split(/[^\p{L}\p{M}\p{N}]+/u)
    .filter((token) => token.length >= 2)
}

/**
 * Cosine similarity of two vectors.
 *
 * Iterates over the length of `a`; missing entries are treated as 0. Returns 0 when
 * either vector has zero magnitude, which avoids dividing by zero.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0
  let normA = 0
  let normB = 0

  for (let i = 0; i < a.length; i++) {
    const av = a[i] ?? 0
    const bv = b[i] ?? 0
    dot += av * bv
    normA += av * av
    normB += bv * bv
  }

  if (normA === 0 || normB === 0) return 0
  return dot / (Math.sqrt(normA) * Math.sqrt(normB))
}