/ internal / session / trigram_test.go
trigram_test.go
  1  package session
  2  
  3  import (
  4  	"database/sql"
  5  	"os"
  6  	"strings"
  7  	"testing"
  8  	"time"
  9  
 10  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 11  )
 12  
 13  func openRawDBImpl(path string) (*sql.DB, error) {
 14  	db, err := sql.Open("sqlite", path)
 15  	if err != nil {
 16  		return nil, err
 17  	}
 18  	db.SetMaxOpenConns(1)
 19  	return db, nil
 20  }
 21  
 22  func writeFile(path, content string) error {
 23  	return os.WriteFile(path, []byte(content), 0644)
 24  }
 25  
 26  // seed populates an index with a controlled CJK / JP / EN / mixed corpus.
 27  func seed(t *testing.T) *Index {
 28  	t.Helper()
 29  	dir := t.TempDir()
 30  	idx, err := OpenIndex(dir)
 31  	if err != nil {
 32  		t.Fatalf("OpenIndex: %v", err)
 33  	}
 34  	t.Cleanup(func() { idx.Close() })
 35  
 36  	now := time.Now().Truncate(time.Second)
 37  	sessions := []*Session{
 38  		{ID: "zh-1", Title: "登录", CreatedAt: now, UpdatedAt: now,
 39  			Messages: []client.Message{
 40  				{Role: "user", Content: client.NewTextContent("帮我实现登录接口,使用 OAuth2 协议完成授权流程")},
 41  				{Role: "assistant", Content: client.NewTextContent("好的,JWT 签发 token")},
 42  			}},
 43  		{ID: "zh-2", Title: "修复", CreatedAt: now, UpdatedAt: now,
 44  			Messages: []client.Message{
 45  				{Role: "user", Content: client.NewTextContent("生产环境部署失败,需要修复 nginx 配置")},
 46  			}},
 47  		{ID: "zh-3", Title: "机器学习", CreatedAt: now, UpdatedAt: now,
 48  			Messages: []client.Message{
 49  				{Role: "user", Content: client.NewTextContent("机器学习的原理和应用,深度学习神经网络")},
 50  			}},
 51  		{ID: "ja-1", Title: "機械学習", CreatedAt: now, UpdatedAt: now,
 52  			Messages: []client.Message{
 53  				{Role: "user", Content: client.NewTextContent("機械学習のログイン機能を実装してください")},
 54  				{Role: "assistant", Content: client.NewTextContent("実装が完了しました")},
 55  			}},
 56  		{ID: "en-1", Title: "server", CreatedAt: now, UpdatedAt: now,
 57  			Messages: []client.Message{
 58  				{Role: "user", Content: client.NewTextContent("the server is running on port 8080 with multiple connections")},
 59  				{Role: "assistant", Content: client.NewTextContent("several programs deployed, all connections stable")},
 60  			}},
 61  		{ID: "mix-1", Title: "mixed", CreatedAt: now, UpdatedAt: now,
 62  			Messages: []client.Message{
 63  				{Role: "user", Content: client.NewTextContent("debug 登录接口 failed on port 8080")},
 64  			}},
 65  	}
 66  	for _, s := range sessions {
 67  		if err := idx.UpsertSession(s); err != nil {
 68  			t.Fatal(err)
 69  		}
 70  	}
 71  	return idx
 72  }
 73  
 74  func TestTrigram_2CharCJK_LikeFallback(t *testing.T) {
 75  	idx := seed(t)
 76  	cases := map[string][]string{
 77  		"登录": {"zh-1", "mix-1"},
 78  		"接口": {"zh-1", "mix-1"},
 79  		"修复": {"zh-2"},
 80  		"部署": {"zh-2"},
 81  		"実装": {"ja-1"},
 82  	}
 83  	for q, wantSessions := range cases {
 84  		t.Run(q, func(t *testing.T) {
 85  			res, err := idx.Search(q, 10)
 86  			if err != nil {
 87  				t.Fatalf("Search(%q): %v", q, err)
 88  			}
 89  			got := make(map[string]bool)
 90  			for _, r := range res {
 91  				got[r.SessionID] = true
 92  			}
 93  			for _, want := range wantSessions {
 94  				if !got[want] {
 95  					t.Errorf("expected hit in %s for %q, got %v", want, q, keys(got))
 96  				}
 97  			}
 98  			if len(res) > 0 && !strings.Contains(res[0].Snippet, ">>>") {
 99  				t.Errorf("expected highlight in snippet for %q, got %q", q, res[0].Snippet)
100  			}
101  		})
102  	}
103  }
104  
105  func TestTrigram_3CharCJK(t *testing.T) {
106  	idx := seed(t)
107  	cases := map[string]string{
108  		"机器学习": "zh-3",
109  		"登录接口": "zh-1",
110  		"機械学習": "ja-1",
111  	}
112  	for q, wantID := range cases {
113  		t.Run(q, func(t *testing.T) {
114  			res, err := idx.Search(q, 5)
115  			if err != nil {
116  				t.Fatalf("Search: %v", err)
117  			}
118  			if len(res) == 0 {
119  				t.Fatalf("expected hits for %q", q)
120  			}
121  			found := false
122  			for _, r := range res {
123  				if r.SessionID == wantID {
124  					found = true
125  				}
126  			}
127  			if !found {
128  				t.Errorf("expected %s in hits for %q, got results %+v", wantID, q, res)
129  			}
130  		})
131  	}
132  }
133  
134  func TestTrigram_JapaneseMixed(t *testing.T) {
135  	idx := seed(t)
136  	// kana + kanji compound.
137  	res, err := idx.Search("ログイン機能", 5)
138  	if err != nil {
139  		t.Fatalf("Search: %v", err)
140  	}
141  	if len(res) == 0 {
142  		t.Fatal("expected hit for ログイン機能")
143  	}
144  	if res[0].SessionID != "ja-1" {
145  		t.Errorf("expected ja-1, got %s", res[0].SessionID)
146  	}
147  }
148  
149  func TestTrigram_QuotedPhrase(t *testing.T) {
150  	idx := seed(t)
151  	// Quoted CJK phrase should match adjacent occurrence.
152  	res, err := idx.Search(`"登录接口"`, 5)
153  	if err != nil {
154  		t.Fatalf("Search: %v", err)
155  	}
156  	if len(res) == 0 {
157  		t.Fatal("expected hits for quoted CJK phrase")
158  	}
159  	// Quoted EN phrase.
160  	res, err = idx.Search(`"port 8080"`, 5)
161  	if err != nil {
162  		t.Fatalf("Search: %v", err)
163  	}
164  	if len(res) == 0 {
165  		t.Fatal("expected hits for quoted EN phrase")
166  	}
167  }
168  
169  func TestTrigram_EnglishSubstring(t *testing.T) {
170  	idx := seed(t)
171  	// Trigram provides substring match (not porter stemming). 'run' matches
172  	// 'running' because 'run' is a trigram of 'running'.
173  	cases := []string{"run", "program", "deploy", "connection"}
174  	for _, q := range cases {
175  		t.Run(q, func(t *testing.T) {
176  			res, err := idx.Search(q, 5)
177  			if err != nil {
178  				t.Fatalf("Search: %v", err)
179  			}
180  			if len(res) == 0 {
181  				t.Errorf("expected hits for %q", q)
182  			}
183  			// Native snippet() should highlight.
184  			if len(res) > 0 && !strings.Contains(res[0].Snippet, ">>>") {
185  				t.Errorf("expected highlight for %q, got %q", q, res[0].Snippet)
186  			}
187  		})
188  	}
189  }
190  
191  func TestTrigram_MixedLatinCJK(t *testing.T) {
192  	idx := seed(t)
193  	res, err := idx.Search("OAuth2", 5)
194  	if err != nil {
195  		t.Fatalf("Search: %v", err)
196  	}
197  	if len(res) == 0 {
198  		t.Fatal("expected hits for OAuth2")
199  	}
200  	if res[0].SessionID != "zh-1" {
201  		t.Errorf("expected zh-1, got %s", res[0].SessionID)
202  	}
203  }
204  
205  func TestTrigram_VersionGateRebuild(t *testing.T) {
206  	dir := t.TempDir()
207  	idx1, err := OpenIndex(dir)
208  	if err != nil {
209  		t.Fatal(err)
210  	}
211  	now := time.Now().Truncate(time.Second)
212  	if err := idx1.UpsertSession(&Session{
213  		ID: "v-1", Title: "v", CreatedAt: now, UpdatedAt: now,
214  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 test")}},
215  	}); err != nil {
216  		t.Fatal(err)
217  	}
218  	idx1.Close()
219  
220  	// Roll version back to simulate upgrade-from-porter.
221  	idx2raw, err := OpenIndex(dir)
222  	if err != nil {
223  		t.Fatal(err)
224  	}
225  	if _, err := idx2raw.db.Exec(`PRAGMA user_version = 1`); err != nil {
226  		t.Fatal(err)
227  	}
228  	idx2raw.Close()
229  
230  	idx3, err := OpenIndex(dir)
231  	if err != nil {
232  		t.Fatal(err)
233  	}
234  	defer idx3.Close()
235  	if !idx3.NeedsRebuild() {
236  		t.Error("expected NeedsRebuild true after version rollback")
237  	}
238  	// messages table was dropped and re-created empty — FTS should return nothing.
239  	var n int
240  	if err := idx3.db.QueryRow(`SELECT COUNT(*) FROM messages`).Scan(&n); err != nil {
241  		t.Fatal(err)
242  	}
243  	if n != 0 {
244  		t.Errorf("expected empty messages table after version-gate drop, got %d rows", n)
245  	}
246  }
247  
248  // TestTrigram_VersionGateThroughNewStore guards against a regression where
249  // OpenIndex drops stale FTS tables on version change but the NewStore
250  // auto-rebuild trigger misses the signal, leaving search permanently empty.
251  func TestTrigram_VersionGateThroughNewStore(t *testing.T) {
252  	dir := t.TempDir()
253  
254  	s1 := NewStore(dir)
255  	now := time.Now().Truncate(time.Second)
256  	if err := s1.Save(&Session{
257  		ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now,
258  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed")}},
259  	}); err != nil {
260  		t.Fatal(err)
261  	}
262  	s1.Close()
263  
264  	// Simulate an older tokenizer version on disk.
265  	idx, err := OpenIndex(dir)
266  	if err != nil {
267  		t.Fatal(err)
268  	}
269  	if _, err := idx.db.Exec(`PRAGMA user_version = 1`); err != nil {
270  		t.Fatal(err)
271  	}
272  	idx.Close()
273  
274  	// Reopen via NewStore — the real application flow. Must rebuild.
275  	s2 := NewStore(dir)
276  	defer s2.Close()
277  	res, err := s2.Search("登录接口", 5)
278  	if err != nil {
279  		t.Fatal(err)
280  	}
281  	if len(res) == 0 {
282  		t.Error("expected rebuild to restore searchability after version change via NewStore")
283  	}
284  }
285  
286  // TestTrigram_MixedShortAndLongTerms guards against silently dropping short
287  // CJK terms in mixed queries. FTS5 trigram ignores terms <3 chars, so without
288  // per-term fallback analysis, `登录 failed` would match rows that only contain
289  // "failed" with no 登录 at all.
290  func TestTrigram_MixedShortAndLongTerms(t *testing.T) {
291  	dir := t.TempDir()
292  	idx, err := OpenIndex(dir)
293  	if err != nil {
294  		t.Fatal(err)
295  	}
296  	defer idx.Close()
297  
298  	now := time.Now().Truncate(time.Second)
299  	if err := idx.UpsertSession(&Session{
300  		ID: "target", Title: "t", CreatedAt: now, UpdatedAt: now,
301  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed on port 8080")}},
302  	}); err != nil {
303  		t.Fatal(err)
304  	}
305  	if err := idx.UpsertSession(&Session{
306  		ID: "noise", Title: "n", CreatedAt: now, UpdatedAt: now,
307  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("the build failed with no CJK")}},
308  	}); err != nil {
309  		t.Fatal(err)
310  	}
311  
312  	res, err := idx.Search("登录 failed", 10)
313  	if err != nil {
314  		t.Fatal(err)
315  	}
316  	for _, r := range res {
317  		if r.SessionID == "noise" {
318  			t.Errorf("mixed query `登录 failed` incorrectly matched 'noise' session with no 登录")
319  		}
320  	}
321  	if len(res) == 0 {
322  		t.Error("expected at least the 'target' session to match")
323  	}
324  }
325  
326  // TestTrigram_QuotedShortCJK ensures `"登录"` works the same as `登录`.
327  func TestTrigram_QuotedShortCJK(t *testing.T) {
328  	dir := t.TempDir()
329  	idx, err := OpenIndex(dir)
330  	if err != nil {
331  		t.Fatal(err)
332  	}
333  	defer idx.Close()
334  
335  	now := time.Now().Truncate(time.Second)
336  	if err := idx.UpsertSession(&Session{
337  		ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now,
338  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("帮我实现登录接口")}},
339  	}); err != nil {
340  		t.Fatal(err)
341  	}
342  	unquoted, _ := idx.Search("登录", 5)
343  	quoted, _ := idx.Search(`"登录"`, 5)
344  	if len(quoted) != len(unquoted) {
345  		t.Errorf("quoted short CJK should match like unquoted: quoted=%d unquoted=%d", len(quoted), len(unquoted))
346  	}
347  }
348  
349  // TestTrigram_UpgradeFromMainStored0 covers the real upgrade path: existing
350  // users on current main have a porter-tokenized FTS and user_version=0 (main
351  // never stamped it). The migration must drop the stale FTS and trigger a
352  // rebuild so trigram semantics (substring match) take effect.
353  func TestTrigram_UpgradeFromMainStored0(t *testing.T) {
354  	dir := t.TempDir()
355  
356  	// Construct a pre-trigram DB directly: porter+unicode61 FTS, no user_version.
357  	raw, err := openRawDBImpl(dir + "/sessions.db")
358  	if err != nil {
359  		t.Fatal(err)
360  	}
361  	if _, err := raw.Exec(`
362  PRAGMA journal_mode=WAL;
363  CREATE TABLE sessions (id TEXT PRIMARY KEY, title TEXT NOT NULL DEFAULT '',
364      cwd TEXT NOT NULL DEFAULT '', created_at DATETIME NOT NULL,
365      updated_at DATETIME NOT NULL, msg_count INTEGER NOT NULL DEFAULT 0);
366  CREATE TABLE messages (rowid INTEGER PRIMARY KEY AUTOINCREMENT,
367      session_id TEXT NOT NULL, msg_index INTEGER NOT NULL, role TEXT NOT NULL,
368      content TEXT NOT NULL, UNIQUE(session_id, msg_index));
369  CREATE VIRTUAL TABLE messages_fts USING fts5(content, content=messages,
370      content_rowid=rowid, tokenize='porter unicode61');
371  `); err != nil {
372  		t.Fatal(err)
373  	}
374  	now := time.Now().Truncate(time.Second).Format(time.RFC3339Nano)
375  	raw.Exec(`INSERT INTO sessions (id,title,created_at,updated_at,msg_count) VALUES ('s','t',?,?,1)`, now, now)
376  	raw.Exec(`INSERT INTO messages (session_id,msg_index,role,content) VALUES ('s',0,'user','the nginx configuration file')`)
377  	raw.Exec(`INSERT INTO messages_fts(rowid,content) VALUES (1, 'the nginx configuration file')`)
378  	raw.Close()
379  
380  	// Matching JSON so the rebuild path can reseed.
381  	json := `{"id":"s","title":"t","created_at":"` + now + `","updated_at":"` + now +
382  		`","schema_version":1,"messages":[{"role":"user","content":"the nginx configuration file"}]}`
383  	if err := writeFile(dir+"/s.json", json); err != nil {
384  		t.Fatal(err)
385  	}
386  
387  	// Reopen via NewStore — real upgrade flow.
388  	s := NewStore(dir)
389  	defer s.Close()
390  
391  	// `figur` is a mid-word substring that only a trigram index matches.
392  	res, err := s.Search("figur", 5)
393  	if err != nil {
394  		t.Fatal(err)
395  	}
396  	if len(res) == 0 {
397  		t.Error("upgrade from main (user_version=0, porter FTS) did not migrate to trigram — substring query failed")
398  	}
399  }
400  
401  // TestTrigram_OperatorQueryWithShortCJK: boolean operators combined with
402  // short CJK terms cannot be faithfully expressed via LIKE intersection, and
403  // silently degrading to trigram MATCH drops the short term. We reject the
404  // query instead of returning wrong results.
405  func TestTrigram_OperatorQueryWithShortCJK(t *testing.T) {
406  	dir := t.TempDir()
407  	idx, err := OpenIndex(dir)
408  	if err != nil {
409  		t.Fatal(err)
410  	}
411  	defer idx.Close()
412  
413  	now := time.Now().Truncate(time.Second)
414  	if err := idx.UpsertSession(&Session{
415  		ID: "s", Title: "t", CreatedAt: now, UpdatedAt: now,
416  		Messages: []client.Message{{Role: "user", Content: client.NewTextContent("登录接口 failed on port 8080")}},
417  	}); err != nil {
418  		t.Fatal(err)
419  	}
420  	_, err = idx.Search("登录 AND failed", 5)
421  	if err == nil {
422  		t.Error("expected error for operator query with short CJK term, got nil")
423  	}
424  }
425  
426  // TestLikeSnippet_TurkishCaseExpansion guards against the latent bug where
427  // strings.ToLower("İ") becomes "i\u0307" (2 → 3 bytes), which would make the
428  // byte-offset-based snippet machinery slice mid-rune on surrounding text.
429  func TestLikeSnippet_TurkishCaseExpansion(t *testing.T) {
430  	// "İ" appears before the match; naive byte-offset code could mis-slice.
431  	content := "Proje İstanbul 登录 sonuç"
432  	snip := likeSnippet(content, []string{"登录"})
433  	if !strings.Contains(snip, ">>>登录<<<") {
434  		t.Errorf("expected >>>登录<<< highlight, got %q", snip)
435  	}
436  }
437  
438  // TestLikeSnippet_EarliestMatch centres the snippet on whichever term matches
439  // earliest so multi-term queries surface the relevant context instead of
440  // always anchoring on the first term.
441  func TestLikeSnippet_EarliestMatch(t *testing.T) {
442  	content := "failed at startup before 登录 was ever tried"
443  	// Query terms: 登录 (later) and failed (earlier). Snippet should centre
444  	// on 'failed' since it appears first in content.
445  	snip := likeSnippet(content, []string{"登录", "failed"})
446  	if !strings.Contains(snip, ">>>failed<<<") {
447  		t.Errorf("expected snippet centred on earliest term 'failed', got %q", snip)
448  	}
449  }
450  
// keys returns the keys of m as a slice, in map iteration order (which Go
// does not define). The result is never nil, even for an empty or nil map.
func keys(m map[string]bool) []string {
	names := make([]string, len(m))
	next := 0
	for name := range m {
		names[next] = name
		next++
	}
	return names
}