/ internal / tools / memory_test.go
memory_test.go
  1  package tools
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"strings"
  7  	"testing"
  8  
  9  	"github.com/Kocoro-lab/ShanClaw/internal/memory"
 10  )
 11  
 12  type fakeFallback struct {
 13  	snippet      string
 14  	hits         []any
 15  	gotQuery     string
 16  	gotLimit     int
 17  	snippetQuery string
 18  }
 19  
 20  func (f *fakeFallback) SessionKeyword(_ context.Context, q string, limit int) ([]any, error) {
 21  	f.gotQuery = q
 22  	f.gotLimit = limit
 23  	return f.hits, nil
 24  }
 25  func (f *fakeFallback) MemoryFileSnippet(_ context.Context, q string) (string, error) {
 26  	f.snippetQuery = q
 27  	return f.snippet, nil
 28  }
 29  
 30  func TestMemoryTool_FallbackWhenNoService(t *testing.T) {
 31  	tool := &MemoryTool{Fallback: &fakeFallback{snippet: "memory.md note"}}
 32  	res, err := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
 33  	if err != nil {
 34  		t.Fatalf("err=%v", err)
 35  	}
 36  	if res.IsError {
 37  		t.Fatalf("res.IsError=true: %s", res.Content)
 38  	}
 39  	var body map[string]any
 40  	if err := json.Unmarshal([]byte(res.Content), &body); err != nil {
 41  		t.Fatalf("decode tool result: %v\n%s", err, res.Content)
 42  	}
 43  	if body["source"] != "fallback" {
 44  		t.Fatalf("source=%v want fallback", body["source"])
 45  	}
 46  	if body["evidence_quality"] != "text_search" {
 47  		t.Fatalf("evidence_quality=%v", body["evidence_quality"])
 48  	}
 49  }
 50  
 51  func TestMemoryTool_RejectsEmptyAnchorMentions(t *testing.T) {
 52  	tool := &MemoryTool{Fallback: &fakeFallback{}}
 53  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":[]}`)
 54  	if !res.IsError || !strings.Contains(res.Content, "anchor_mentions") {
 55  		t.Fatalf("res=%+v", res)
 56  	}
 57  }
 58  
 59  func TestMemoryTool_RejectsMissingAnchorMentions(t *testing.T) {
 60  	tool := &MemoryTool{Fallback: &fakeFallback{}}
 61  	res, _ := tool.Run(context.Background(), `{}`)
 62  	if !res.IsError {
 63  		t.Fatalf("expected error: %s", res.Content)
 64  	}
 65  }
 66  
 67  func TestMemoryTool_RejectsMalformedJSON(t *testing.T) {
 68  	tool := &MemoryTool{Fallback: &fakeFallback{}}
 69  	res, _ := tool.Run(context.Background(), `{not json`)
 70  	if !res.IsError {
 71  		t.Fatalf("expected error: %s", res.Content)
 72  	}
 73  }
 74  
 75  type stubQuerier struct {
 76  	status memory.ServiceStatus
 77  	env    *memory.ResponseEnvelope
 78  	class  memory.ErrorClass
 79  	err    error
 80  	callN  int
 81  	seq    []memory.ErrorClass // optional: deliver i-th class on i-th call
 82  }
 83  
 84  func (s *stubQuerier) Status() memory.ServiceStatus { return s.status }
 85  func (s *stubQuerier) Query(_ context.Context, _ memory.QueryIntent) (*memory.ResponseEnvelope, memory.ErrorClass, error) {
 86  	s.callN++
 87  	if len(s.seq) > 0 {
 88  		c := s.seq[0]
 89  		s.seq = s.seq[1:]
 90  		return s.env, c, nil
 91  	}
 92  	return s.env, s.class, s.err
 93  }
 94  
 95  func TestMemoryTool_ClassOK(t *testing.T) {
 96  	env := &memory.ResponseEnvelope{
 97  		Reason:        "ok",
 98  		BundleVersion: "0.4.0",
 99  		Candidates: []memory.QueryCandidate{
100  			{Value: "v", Score: 0.9, Evidence: "observed", SupportingEventIDs: []string{"e1"}},
101  		},
102  	}
103  	tool := &MemoryTool{
104  		Service:  &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassOK},
105  		Fallback: &fakeFallback{},
106  	}
107  	res, err := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
108  	if err != nil || res.IsError {
109  		t.Fatalf("res=%+v err=%v", res, err)
110  	}
111  	var body map[string]any
112  	json.Unmarshal([]byte(res.Content), &body)
113  	if body["source"] != "memory_sidecar" {
114  		t.Fatalf("source=%v", body["source"])
115  	}
116  	if body["evidence_quality"] != "structured" {
117  		t.Fatalf("evidence_quality=%v", body["evidence_quality"])
118  	}
119  	if body["bundle_version"] != "0.4.0" {
120  		t.Fatalf("bundle_version=%v", body["bundle_version"])
121  	}
122  	cands, _ := body["candidates"].([]any)
123  	if len(cands) != 1 {
124  		t.Fatalf("candidates=%+v", cands)
125  	}
126  }
127  
128  func TestMemoryTool_ClassDegraded(t *testing.T) {
129  	env := &memory.ResponseEnvelope{
130  		Reason:        "degraded",
131  		BundleVersion: "0.4.0",
132  		Candidates:    []memory.QueryCandidate{{Value: "v", Evidence: "observed"}},
133  	}
134  	tool := &MemoryTool{
135  		Service:  &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassOK},
136  		Fallback: &fakeFallback{},
137  	}
138  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
139  	var body map[string]any
140  	json.Unmarshal([]byte(res.Content), &body)
141  	if body["evidence_quality"] != "structured_degraded" {
142  		t.Fatalf("evidence_quality=%v", body["evidence_quality"])
143  	}
144  	warnings, _ := body["warnings"].([]any)
145  	if len(warnings) == 0 {
146  		t.Fatal("expected degraded warning")
147  	}
148  	w0, _ := warnings[0].(map[string]any)
149  	if msg, _ := w0["message"].(string); !strings.Contains(msg, "degraded") {
150  		t.Fatalf("first warning should be the degraded notice; got %+v", w0)
151  	}
152  }
153  
154  func TestMemoryTool_ClassRetryable_OneRetryThenFallback(t *testing.T) {
155  	sq := &stubQuerier{status: memory.StatusReady, seq: []memory.ErrorClass{memory.ClassRetryable, memory.ClassRetryable}}
156  	tool := &MemoryTool{Service: sq, Fallback: &fakeFallback{}}
157  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
158  	var body map[string]any
159  	json.Unmarshal([]byte(res.Content), &body)
160  	if body["source"] != "fallback_after_retry" {
161  		t.Fatalf("source=%v", body["source"])
162  	}
163  	if body["fallback_reason"] != "retryable_failed" {
164  		t.Fatalf("fallback_reason=%v", body["fallback_reason"])
165  	}
166  	if sq.callN != 2 {
167  		t.Fatalf("expected 2 calls, got %d", sq.callN)
168  	}
169  }
170  
171  func TestMemoryTool_ClassPermanent_SurfacesDiagnostics(t *testing.T) {
172  	env := &memory.ResponseEnvelope{
173  		Reason: "error",
174  		Error: &memory.ErrorObject{
175  			Code:    "validation_error",
176  			Message: "bad mode",
177  			Details: map[string]any{"sub_code": "schema_validation"},
178  		},
179  	}
180  	tool := &MemoryTool{
181  		Service:  &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassPermanent},
182  		Fallback: &fakeFallback{},
183  	}
184  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
185  	if !res.IsError {
186  		t.Fatal("permanent should be IsError")
187  	}
188  	var body map[string]any
189  	json.Unmarshal([]byte(res.Content), &body)
190  	warnings, _ := body["warnings"].([]any)
191  	if len(warnings) == 0 {
192  		t.Fatal("expected warnings with sub_code")
193  	}
194  	w0, _ := warnings[0].(map[string]any)
195  	if w0["sub_code"] != "schema_validation" {
196  		t.Fatalf("w0=%+v", w0)
197  	}
198  }
199  
200  func TestMemoryTool_ClassUnavailable_FallsBack(t *testing.T) {
201  	tool := &MemoryTool{
202  		Service:  &stubQuerier{status: memory.StatusReady, class: memory.ClassUnavailable},
203  		Fallback: &fakeFallback{},
204  	}
205  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`)
206  	var body map[string]any
207  	json.Unmarshal([]byte(res.Content), &body)
208  	if body["source"] != "fallback" {
209  		t.Fatalf("source=%v", body["source"])
210  	}
211  	if body["fallback_reason"] != "service_unavailable" {
212  		t.Fatalf("fallback_reason=%v", body["fallback_reason"])
213  	}
214  }
215  
216  func TestMemoryTool_Info(t *testing.T) {
217  	tool := &MemoryTool{Fallback: &fakeFallback{}}
218  	info := tool.Info()
219  	if info.Name != "memory_recall" {
220  		t.Fatalf("name=%q want memory_recall", info.Name)
221  	}
222  	if !tool.IsReadOnlyCall("") {
223  		t.Fatal("memory_recall must be read-only")
224  	}
225  	if tool.RequiresApproval() {
226  		t.Fatal("memory_recall must not require approval")
227  	}
228  	// Required field declared.
229  	found := false
230  	for _, r := range info.Required {
231  		if r == "anchor_mentions" {
232  			found = true
233  		}
234  	}
235  	if !found {
236  		t.Fatal("anchor_mentions must be in Required")
237  	}
238  }
239  
240  // TestMemoryTool_FallbackInvokesProvider locks the contract that fallback()
241  // actually delegates to FallbackQuery.SessionKeyword + MemoryFileSnippet
242  // (regression: earlier the fallback returned an empty envelope and the
243  // provider plumbing was dead code).
244  func TestMemoryTool_FallbackInvokesProvider(t *testing.T) {
245  	fb := &fakeFallback{
246  		snippet: "MEMORY.md hit line",
247  		hits:    []any{map[string]any{"id": "sess1", "snippet": "from session_search"}},
248  	}
249  	tool := &MemoryTool{Fallback: fb}
250  	res, _ := tool.Run(context.Background(), `{"anchor_mentions":["foo","bar"],"result_limit":7}`)
251  	if res.IsError {
252  		t.Fatalf("unexpected error: %s", res.Content)
253  	}
254  	if fb.gotQuery != "foo bar" {
255  		t.Fatalf("session keyword query = %q want %q", fb.gotQuery, "foo bar")
256  	}
257  	if fb.gotLimit != 7 {
258  		t.Fatalf("session keyword limit = %d want 7", fb.gotLimit)
259  	}
260  	if fb.snippetQuery != "foo bar" {
261  		t.Fatalf("memory snippet query = %q want %q", fb.snippetQuery, "foo bar")
262  	}
263  	var body map[string]any
264  	if err := json.Unmarshal([]byte(res.Content), &body); err != nil {
265  		t.Fatal(err)
266  	}
267  	cands, _ := body["candidates"].([]any)
268  	if len(cands) != 2 {
269  		t.Fatalf("expected 2 fallback candidates (1 session hit + 1 MEMORY.md snippet), got %d: %+v", len(cands), cands)
270  	}
271  }