// trigram_test.go
1 package session 2 3 import ( 4 "database/sql" 5 "os" 6 "strings" 7 "testing" 8 "time" 9 10 "github.com/Kocoro-lab/ShanClaw/internal/client" 11 ) 12 13 func openRawDBImpl(path string) (*sql.DB, error) { 14 db, err := sql.Open("sqlite", path) 15 if err != nil { 16 return nil, err 17 } 18 db.SetMaxOpenConns(1) 19 return db, nil 20 } 21 22 func writeFile(path, content string) error { 23 return os.WriteFile(path, []byte(content), 0644) 24 } 25 26 // seed populates an index with a controlled CJK / JP / EN / mixed corpus. 27 func seed(t *testing.T) *Index { 28 t.Helper() 29 dir := t.TempDir() 30 idx, err := OpenIndex(dir) 31 if err != nil { 32 t.Fatalf("OpenIndex: %v", err) 33 } 34 t.Cleanup(func() { idx.Close() }) 35 36 now := time.Now().Truncate(time.Second) 37 sessions := []*Session{ 38 {ID: "zh-1", Title: "登录", CreatedAt: now, UpdatedAt: now, 39 Messages: []client.Message{ 40 {Role: "user", Content: client.NewTextContent("帮我实现登录接口,使用 OAuth2 协议完成授权流程")}, 41 {Role: "assistant", Content: client.NewTextContent("好的,JWT 签发 token")}, 42 }}, 43 {ID: "zh-2", Title: "修复", CreatedAt: now, UpdatedAt: now, 44 Messages: []client.Message{ 45 {Role: "user", Content: client.NewTextContent("生产环境部署失败,需要修复 nginx 配置")}, 46 }}, 47 {ID: "zh-3", Title: "机器学习", CreatedAt: now, UpdatedAt: now, 48 Messages: []client.Message{ 49 {Role: "user", Content: client.NewTextContent("机器学习的原理和应用,深度学习神经网络")}, 50 }}, 51 {ID: "ja-1", Title: "機械学習", CreatedAt: now, UpdatedAt: now, 52 Messages: []client.Message{ 53 {Role: "user", Content: client.NewTextContent("機械学習のログイン機能を実装してください")}, 54 {Role: "assistant", Content: client.NewTextContent("実装が完了しました")}, 55 }}, 56 {ID: "en-1", Title: "server", CreatedAt: now, UpdatedAt: now, 57 Messages: []client.Message{ 58 {Role: "user", Content: client.NewTextContent("the server is running on port 8080 with multiple connections")}, 59 {Role: "assistant", Content: client.NewTextContent("several programs deployed, all connections stable")}, 60 }}, 61 {ID: "mix-1", Title: 
"mixed", CreatedAt: now, UpdatedAt: now, 62 Messages: []client.Message{ 63 {Role: "user", Content: client.NewTextContent("debug 登录接口 failed on port 8080")}, 64 }}, 65 } 66 for _, s := range sessions { 67 if err := idx.UpsertSession(s); err != nil { 68 t.Fatal(err) 69 } 70 } 71 return idx 72 } 73 74 func TestTrigram_2CharCJK_LikeFallback(t *testing.T) { 75 idx := seed(t) 76 cases := map[string][]string{ 77 "登录": {"zh-1", "mix-1"}, 78 "接口": {"zh-1", "mix-1"}, 79 "修复": {"zh-2"}, 80 "部署": {"zh-2"}, 81 "実装": {"ja-1"}, 82 } 83 for q, wantSessions := range cases { 84 t.Run(q, func(t *testing.T) { 85 res, err := idx.Search(q, 10) 86 if err != nil { 87 t.Fatalf("Search(%q): %v", q, err) 88 } 89 got := make(map[string]bool) 90 for _, r := range res { 91 got[r.SessionID] = true 92 } 93 for _, want := range wantSessions { 94 if !got[want] { 95 t.Errorf("expected hit in %s for %q, got %v", want, q, keys(got)) 96 } 97 } 98 if len(res) > 0 && !strings.Contains(res[0].Snippet, ">>>") { 99 t.Errorf("expected highlight in snippet for %q, got %q", q, res[0].Snippet) 100 } 101 }) 102 } 103 } 104 105 func TestTrigram_3CharCJK(t *testing.T) { 106 idx := seed(t) 107 cases := map[string]string{ 108 "机器学习": "zh-3", 109 "登录接口": "zh-1", 110 "機械学習": "ja-1", 111 } 112 for q, wantID := range cases { 113 t.Run(q, func(t *testing.T) { 114 res, err := idx.Search(q, 5) 115 if err != nil { 116 t.Fatalf("Search: %v", err) 117 } 118 if len(res) == 0 { 119 t.Fatalf("expected hits for %q", q) 120 } 121 found := false 122 for _, r := range res { 123 if r.SessionID == wantID { 124 found = true 125 } 126 } 127 if !found { 128 t.Errorf("expected %s in hits for %q, got results %+v", wantID, q, res) 129 } 130 }) 131 } 132 } 133 134 func TestTrigram_JapaneseMixed(t *testing.T) { 135 idx := seed(t) 136 // kana + kanji compound. 
137 res, err := idx.Search("ログイン機能", 5) 138 if err != nil { 139 t.Fatalf("Search: %v", err) 140 } 141 if len(res) == 0 { 142 t.Fatal("expected hit for ログイン機能") 143 } 144 if res[0].SessionID != "ja-1" { 145 t.Errorf("expected ja-1, got %s", res[0].SessionID) 146 } 147 } 148 149 func TestTrigram_QuotedPhrase(t *testing.T) { 150 idx := seed(t) 151 // Quoted CJK phrase should match adjacent occurrence. 152 res, err := idx.Search(`"登录接口"`, 5) 153 if err != nil { 154 t.Fatalf("Search: %v", err) 155 } 156 if len(res) == 0 { 157 t.Fatal("expected hits for quoted CJK phrase") 158 } 159 // Quoted EN phrase. 160 res, err = idx.Search(`"port 8080"`, 5) 161 if err != nil { 162 t.Fatalf("Search: %v", err) 163 } 164 if len(res) == 0 { 165 t.Fatal("expected hits for quoted EN phrase") 166 } 167 } 168 169 func TestTrigram_EnglishSubstring(t *testing.T) { 170 idx := seed(t) 171 // Trigram provides substring match (not porter stemming). 'run' matches 172 // 'running' because 'run' is a trigram of 'running'. 173 cases := []string{"run", "program", "deploy", "connection"} 174 for _, q := range cases { 175 t.Run(q, func(t *testing.T) { 176 res, err := idx.Search(q, 5) 177 if err != nil { 178 t.Fatalf("Search: %v", err) 179 } 180 if len(res) == 0 { 181 t.Errorf("expected hits for %q", q) 182 } 183 // Native snippet() should highlight. 
184 if len(res) > 0 && !strings.Contains(res[0].Snippet, ">>>") { 185 t.Errorf("expected highlight for %q, got %q", q, res[0].Snippet) 186 } 187 }) 188 } 189 } 190 191 func TestTrigram_MixedLatinCJK(t *testing.T) { 192 idx := seed(t) 193 res, err := idx.Search("OAuth2", 5) 194 if err != nil { 195 t.Fatalf("Search: %v", err) 196 } 197 if len(res) == 0 { 198 t.Fatal("expected hits for OAuth2") 199 } 200 if res[0].SessionID != "zh-1" { 201 t.Errorf("expected zh-1, got %s", res[0].SessionID) 202 } 203 } 204 205 func TestTrigram_VersionGateRebuild(t *testing.T) { 206 dir := t.TempDir() 207 idx1, err := OpenIndex(dir) 208 if err != nil { 209 t.Fatal(err) 210 } 211 now := time.Now().Truncate(time.Second) 212 if err := idx1.UpsertSession(&Session{ 213 ID: "v-1", Title: "v", CreatedAt: now, UpdatedAt: now, 214 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 test")}}, 215 }); err != nil { 216 t.Fatal(err) 217 } 218 idx1.Close() 219 220 // Roll version back to simulate upgrade-from-porter. 221 idx2raw, err := OpenIndex(dir) 222 if err != nil { 223 t.Fatal(err) 224 } 225 if _, err := idx2raw.db.Exec(`PRAGMA user_version = 1`); err != nil { 226 t.Fatal(err) 227 } 228 idx2raw.Close() 229 230 idx3, err := OpenIndex(dir) 231 if err != nil { 232 t.Fatal(err) 233 } 234 defer idx3.Close() 235 if !idx3.NeedsRebuild() { 236 t.Error("expected NeedsRebuild true after version rollback") 237 } 238 // messages table was dropped and re-created empty — FTS should return nothing. 239 var n int 240 if err := idx3.db.QueryRow(`SELECT COUNT(*) FROM messages`).Scan(&n); err != nil { 241 t.Fatal(err) 242 } 243 if n != 0 { 244 t.Errorf("expected empty messages table after version-gate drop, got %d rows", n) 245 } 246 } 247 248 // TestTrigram_VersionGateThroughNewStore guards against a regression where 249 // OpenIndex drops stale FTS tables on version change but the NewStore 250 // auto-rebuild trigger misses the signal, leaving search permanently empty. 
251 func TestTrigram_VersionGateThroughNewStore(t *testing.T) { 252 dir := t.TempDir() 253 254 s1 := NewStore(dir) 255 now := time.Now().Truncate(time.Second) 256 if err := s1.Save(&Session{ 257 ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now, 258 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed")}}, 259 }); err != nil { 260 t.Fatal(err) 261 } 262 s1.Close() 263 264 // Simulate an older tokenizer version on disk. 265 idx, err := OpenIndex(dir) 266 if err != nil { 267 t.Fatal(err) 268 } 269 if _, err := idx.db.Exec(`PRAGMA user_version = 1`); err != nil { 270 t.Fatal(err) 271 } 272 idx.Close() 273 274 // Reopen via NewStore — the real application flow. Must rebuild. 275 s2 := NewStore(dir) 276 defer s2.Close() 277 res, err := s2.Search("登录接口", 5) 278 if err != nil { 279 t.Fatal(err) 280 } 281 if len(res) == 0 { 282 t.Error("expected rebuild to restore searchability after version change via NewStore") 283 } 284 } 285 286 // TestTrigram_MixedShortAndLongTerms guards against silently dropping short 287 // CJK terms in mixed queries. FTS5 trigram ignores terms <3 chars, so without 288 // per-term fallback analysis, `登录 failed` would match rows that only contain 289 // "failed" with no 登录 at all. 
290 func TestTrigram_MixedShortAndLongTerms(t *testing.T) { 291 dir := t.TempDir() 292 idx, err := OpenIndex(dir) 293 if err != nil { 294 t.Fatal(err) 295 } 296 defer idx.Close() 297 298 now := time.Now().Truncate(time.Second) 299 if err := idx.UpsertSession(&Session{ 300 ID: "target", Title: "t", CreatedAt: now, UpdatedAt: now, 301 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed on port 8080")}}, 302 }); err != nil { 303 t.Fatal(err) 304 } 305 if err := idx.UpsertSession(&Session{ 306 ID: "noise", Title: "n", CreatedAt: now, UpdatedAt: now, 307 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("the build failed with no CJK")}}, 308 }); err != nil { 309 t.Fatal(err) 310 } 311 312 res, err := idx.Search("登录 failed", 10) 313 if err != nil { 314 t.Fatal(err) 315 } 316 for _, r := range res { 317 if r.SessionID == "noise" { 318 t.Errorf("mixed query `登录 failed` incorrectly matched 'noise' session with no 登录") 319 } 320 } 321 if len(res) == 0 { 322 t.Error("expected at least the 'target' session to match") 323 } 324 } 325 326 // TestTrigram_QuotedShortCJK ensures `"登录"` works the same as `登录`. 
327 func TestTrigram_QuotedShortCJK(t *testing.T) { 328 dir := t.TempDir() 329 idx, err := OpenIndex(dir) 330 if err != nil { 331 t.Fatal(err) 332 } 333 defer idx.Close() 334 335 now := time.Now().Truncate(time.Second) 336 if err := idx.UpsertSession(&Session{ 337 ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now, 338 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("帮我实现登录接口")}}, 339 }); err != nil { 340 t.Fatal(err) 341 } 342 unquoted, _ := idx.Search("登录", 5) 343 quoted, _ := idx.Search(`"登录"`, 5) 344 if len(quoted) != len(unquoted) { 345 t.Errorf("quoted short CJK should match like unquoted: quoted=%d unquoted=%d", len(quoted), len(unquoted)) 346 } 347 } 348 349 // TestTrigram_UpgradeFromMainStored0 covers the real upgrade path: existing 350 // users on current main have a porter-tokenized FTS and user_version=0 (main 351 // never stamped it). The migration must drop the stale FTS and trigger a 352 // rebuild so trigram semantics (substring match) take effect. 353 func TestTrigram_UpgradeFromMainStored0(t *testing.T) { 354 dir := t.TempDir() 355 356 // Construct a pre-trigram DB directly: porter+unicode61 FTS, no user_version. 
357 raw, err := openRawDBImpl(dir + "/sessions.db") 358 if err != nil { 359 t.Fatal(err) 360 } 361 if _, err := raw.Exec(` 362 PRAGMA journal_mode=WAL; 363 CREATE TABLE sessions (id TEXT PRIMARY KEY, title TEXT NOT NULL DEFAULT '', 364 cwd TEXT NOT NULL DEFAULT '', created_at DATETIME NOT NULL, 365 updated_at DATETIME NOT NULL, msg_count INTEGER NOT NULL DEFAULT 0); 366 CREATE TABLE messages (rowid INTEGER PRIMARY KEY AUTOINCREMENT, 367 session_id TEXT NOT NULL, msg_index INTEGER NOT NULL, role TEXT NOT NULL, 368 content TEXT NOT NULL, UNIQUE(session_id, msg_index)); 369 CREATE VIRTUAL TABLE messages_fts USING fts5(content, content=messages, 370 content_rowid=rowid, tokenize='porter unicode61'); 371 `); err != nil { 372 t.Fatal(err) 373 } 374 now := time.Now().Truncate(time.Second).Format(time.RFC3339Nano) 375 raw.Exec(`INSERT INTO sessions (id,title,created_at,updated_at,msg_count) VALUES ('s','t',?,?,1)`, now, now) 376 raw.Exec(`INSERT INTO messages (session_id,msg_index,role,content) VALUES ('s',0,'user','the nginx configuration file')`) 377 raw.Exec(`INSERT INTO messages_fts(rowid,content) VALUES (1, 'the nginx configuration file')`) 378 raw.Close() 379 380 // Matching JSON so the rebuild path can reseed. 381 json := `{"id":"s","title":"t","created_at":"` + now + `","updated_at":"` + now + 382 `","schema_version":1,"messages":[{"role":"user","content":"the nginx configuration file"}]}` 383 if err := writeFile(dir+"/s.json", json); err != nil { 384 t.Fatal(err) 385 } 386 387 // Reopen via NewStore — real upgrade flow. 388 s := NewStore(dir) 389 defer s.Close() 390 391 // `figur` is a mid-word substring that only a trigram index matches. 
392 res, err := s.Search("figur", 5) 393 if err != nil { 394 t.Fatal(err) 395 } 396 if len(res) == 0 { 397 t.Error("upgrade from main (user_version=0, porter FTS) did not migrate to trigram — substring query failed") 398 } 399 } 400 401 // TestTrigram_OperatorQueryWithShortCJK: boolean operators combined with 402 // short CJK terms cannot be faithfully expressed via LIKE intersection, and 403 // silently degrading to trigram MATCH drops the short term. We reject the 404 // query instead of returning wrong results. 405 func TestTrigram_OperatorQueryWithShortCJK(t *testing.T) { 406 dir := t.TempDir() 407 idx, err := OpenIndex(dir) 408 if err != nil { 409 t.Fatal(err) 410 } 411 defer idx.Close() 412 413 now := time.Now().Truncate(time.Second) 414 if err := idx.UpsertSession(&Session{ 415 ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now, 416 Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed on port 8080")}}, 417 }); err != nil { 418 t.Fatal(err) 419 } 420 _, err = idx.Search("登录 AND failed", 5) 421 if err == nil { 422 t.Error("expected error for operator query with short CJK term, got nil") 423 } 424 } 425 426 // TestLikeSnippet_TurkishCaseExpansion guards against the latent bug where 427 // strings.ToLower("İ") becomes "i\u0307" (2 → 3 bytes), which would make the 428 // byte-offset-based snippet machinery slice mid-rune on surrounding text. 429 func TestLikeSnippet_TurkishCaseExpansion(t *testing.T) { 430 // "İ" appears before the match; naive byte-offset code could mis-slice. 431 content := "Proje İstanbul 登录 sonuç" 432 snip := likeSnippet(content, []string{"登录"}) 433 if !strings.Contains(snip, ">>>登录<<<") { 434 t.Errorf("expected >>>登录<<< highlight, got %q", snip) 435 } 436 } 437 438 // TestLikeSnippet_EarliestMatch centres the snippet on whichever term matches 439 // earliest so multi-term queries surface the relevant context instead of 440 // always anchoring on the first term. 
441 func TestLikeSnippet_EarliestMatch(t *testing.T) { 442 content := "failed at startup before 登录 was ever tried" 443 // Query terms: 登录 (later) and failed (earlier). Snippet should centre 444 // on 'failed' since it appears first in content. 445 snip := likeSnippet(content, []string{"登录", "failed"}) 446 if !strings.Contains(snip, ">>>failed<<<") { 447 t.Errorf("expected snippet centred on earliest term 'failed', got %q", snip) 448 } 449 } 450 451 func keys(m map[string]bool) []string { 452 out := make([]string, 0, len(m)) 453 for k := range m { 454 out = append(out, k) 455 } 456 return out 457 }