/ src / server / chat / memory-retrieval.ts
memory-retrieval.ts
  1  import type { MemoryItem } from '@/lib/shared/chat'
  2  import { getRuntimeConfig } from '@/server/config/runtime'
  3  import { getChatProvider } from '@/server/providers'
  4  import {
  5    listMemories,
  6    listMemoryEmbeddings,
  7    upsertMemoryEmbeddings,
  8  } from '@/server/storage/chat-store'
  9  
 10  interface ScoredMemory {
 11    memory: MemoryItem
 12    score: number
 13    source: 'semantic' | 'lexical'
 14  }
 15  
 16  export async function refreshMemoryEmbeddings(sessionId: string): Promise<void> {
 17    const memories = listMemories(sessionId)
 18    if (memories.length === 0) return
 19  
 20    const provider = getChatProvider()
 21    const inputTexts = memories.map(memoryToEmbeddingText)
 22  
 23    const result = await provider.embedTexts({ texts: inputTexts, sessionId })
 24    if (result.embeddings.length !== memories.length) {
 25      throw new Error('Embedding provider returned an unexpected number of vectors.')
 26    }
 27  
 28    upsertMemoryEmbeddings(
 29      sessionId,
 30      memories.map((memory, index) => ({
 31        memoryId: memory.id,
 32        provider: result.provider,
 33        model: result.model,
 34        vector: result.embeddings[index] ?? [],
 35      })),
 36    )
 37  }
 38  
 39  export async function selectRelevantMemories(
 40    sessionId: string,
 41    queryText: string,
 42  ): Promise<MemoryItem[]> {
 43    const config = getRuntimeConfig()
 44    const memories = listMemories(sessionId)
 45    if (memories.length === 0) return []
 46  
 47    const semantic = await trySemanticSelection(sessionId, queryText, memories)
 48    if (semantic.length > 0) {
 49      return semantic
 50        .filter((item) => item.score >= config.memorySimilarityThreshold)
 51        .slice(0, config.memorySemanticTopK)
 52        .map((item) => item.memory)
 53    }
 54  
 55    return lexicalFallback(memories, queryText, config.memorySemanticTopK)
 56  }
 57  
 58  async function trySemanticSelection(
 59    sessionId: string,
 60    queryText: string,
 61    memories: MemoryItem[],
 62  ): Promise<ScoredMemory[]> {
 63    const provider = getChatProvider()
 64    const query = queryText.trim()
 65    if (!query) return []
 66  
 67    try {
 68      const queryEmbeddingResult = await provider.embedTexts({ texts: [query], sessionId })
 69      const queryVector = queryEmbeddingResult.embeddings.at(0)
 70      if (!queryVector || queryVector.length === 0) return []
 71  
 72      const storedEmbeddings = listMemoryEmbeddings(sessionId)
 73      if (storedEmbeddings.length === 0) return []
 74  
 75      const memoriesById = new Map(memories.map((memory) => [memory.id, memory]))
 76      const scored: ScoredMemory[] = []
 77  
 78      for (const embedding of storedEmbeddings) {
 79        const memory = memoriesById.get(embedding.memoryId)
 80        if (!memory) continue
 81        if (embedding.vector.length !== queryVector.length) continue
 82  
 83        scored.push({
 84          memory,
 85          score: cosineSimilarity(queryVector, embedding.vector),
 86          source: 'semantic',
 87        })
 88      }
 89  
 90      return scored.sort((a, b) => b.score - a.score)
 91    } catch {
 92      return []
 93    }
 94  }
 95  
 96  function lexicalFallback(
 97    memories: MemoryItem[],
 98    queryText: string,
 99    limit: number,
100  ): MemoryItem[] {
101    const tokens = tokenize(queryText)
102    if (tokens.length === 0) {
103      return memories.slice(0, limit)
104    }
105  
106    const scored = memories
107      .map((memory) => {
108        const haystack = `${memory.key} ${memory.value}`.toLowerCase()
109        const score = tokens.reduce((sum, token) => sum + (haystack.includes(token) ? 1 : 0), 0)
110        return { memory, score }
111      })
112      .filter((item) => item.score > 0)
113      .sort((a, b) => b.score - a.score)
114      .slice(0, limit)
115      .map((item) => item.memory)
116  
117    return scored.length > 0 ? scored : memories.slice(0, limit)
118  }
119  
/**
 * Builds the text that is sent to the embedding provider for a memory:
 * its key, category and value joined with ' | '.
 */
function memoryToEmbeddingText(memory: MemoryItem): string {
  const { key, category, value } = memory
  const fields = [key, category, value]
  return fields.join(' | ')
}
123  
/**
 * Splits text into lower-cased search tokens.
 *
 * Tokens are maximal runs of Unicode letters, combining marks and digits, so
 * accented and non-Latin text is tokenized instead of being discarded. Every
 * other character (whitespace, punctuation, underscores, symbols) separates
 * tokens. Tokens shorter than two characters are dropped.
 *
 * @param text - Text to tokenize.
 * @returns Tokens in order of appearance; may be empty.
 */
function tokenize(text: string): string[] {
  return text
    .toLowerCase()
    .split(/[^\p{L}\p{M}\p{N}]+/u)
    .filter((token) => token.length >= 2)
}
130  
/**
 * Cosine similarity of two vectors.
 *
 * Iterates over the indices of `a`; a missing component in either vector is
 * treated as 0. Returns 0 when either accumulated squared norm is 0, which
 * includes the case of empty input.
 *
 * @param a - First vector; its length determines how many components are read.
 * @param b - Second vector.
 * @returns The dot product divided by the product of the vector norms, or 0.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  let dotProduct = 0
  let squaredNormA = 0
  let squaredNormB = 0

  for (const index of a.keys()) {
    const left = a[index] ?? 0
    const right = b[index] ?? 0
    dotProduct += left * right
    squaredNormA += left * left
    squaredNormB += right * right
  }

  if (squaredNormA !== 0 && squaredNormB !== 0) {
    return dotProduct / (Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB))
  }
  return 0
}