memory_test.go
1 package tools 2 3 import ( 4 "context" 5 "encoding/json" 6 "strings" 7 "testing" 8 9 "github.com/Kocoro-lab/ShanClaw/internal/memory" 10 ) 11 12 type fakeFallback struct { 13 snippet string 14 hits []any 15 gotQuery string 16 gotLimit int 17 snippetQuery string 18 } 19 20 func (f *fakeFallback) SessionKeyword(_ context.Context, q string, limit int) ([]any, error) { 21 f.gotQuery = q 22 f.gotLimit = limit 23 return f.hits, nil 24 } 25 func (f *fakeFallback) MemoryFileSnippet(_ context.Context, q string) (string, error) { 26 f.snippetQuery = q 27 return f.snippet, nil 28 } 29 30 func TestMemoryTool_FallbackWhenNoService(t *testing.T) { 31 tool := &MemoryTool{Fallback: &fakeFallback{snippet: "memory.md note"}} 32 res, err := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 33 if err != nil { 34 t.Fatalf("err=%v", err) 35 } 36 if res.IsError { 37 t.Fatalf("res.IsError=true: %s", res.Content) 38 } 39 var body map[string]any 40 if err := json.Unmarshal([]byte(res.Content), &body); err != nil { 41 t.Fatalf("decode tool result: %v\n%s", err, res.Content) 42 } 43 if body["source"] != "fallback" { 44 t.Fatalf("source=%v want fallback", body["source"]) 45 } 46 if body["evidence_quality"] != "text_search" { 47 t.Fatalf("evidence_quality=%v", body["evidence_quality"]) 48 } 49 } 50 51 func TestMemoryTool_RejectsEmptyAnchorMentions(t *testing.T) { 52 tool := &MemoryTool{Fallback: &fakeFallback{}} 53 res, _ := tool.Run(context.Background(), `{"anchor_mentions":[]}`) 54 if !res.IsError || !strings.Contains(res.Content, "anchor_mentions") { 55 t.Fatalf("res=%+v", res) 56 } 57 } 58 59 func TestMemoryTool_RejectsMissingAnchorMentions(t *testing.T) { 60 tool := &MemoryTool{Fallback: &fakeFallback{}} 61 res, _ := tool.Run(context.Background(), `{}`) 62 if !res.IsError { 63 t.Fatalf("expected error: %s", res.Content) 64 } 65 } 66 67 func TestMemoryTool_RejectsMalformedJSON(t *testing.T) { 68 tool := &MemoryTool{Fallback: &fakeFallback{}} 69 res, _ := tool.Run(context.Background(), `{not json`) 70 if !res.IsError { 71 t.Fatalf("expected error: %s", res.Content) 72 } 73 } 74 75 type stubQuerier struct { 76 status memory.ServiceStatus 77 env *memory.ResponseEnvelope 78 class memory.ErrorClass 79 err error 80 callN int 81 seq []memory.ErrorClass // optional: deliver i-th class on i-th call 82 } 83 84 func (s *stubQuerier) Status() memory.ServiceStatus { return s.status } 85 func (s *stubQuerier) Query(_ context.Context, _ memory.QueryIntent) (*memory.ResponseEnvelope, memory.ErrorClass, error) { 86 s.callN++ 87 if len(s.seq) > 0 { 88 c := s.seq[0] 89 s.seq = s.seq[1:] 90 return s.env, c, nil 91 } 92 return s.env, s.class, s.err 93 } 94 95 func TestMemoryTool_ClassOK(t *testing.T) { 96 env := &memory.ResponseEnvelope{ 97 Reason: "ok", 98 BundleVersion: "0.4.0", 99 Candidates: []memory.QueryCandidate{ 100 {Value: "v", Score: 0.9, Evidence: "observed", SupportingEventIDs: []string{"e1"}}, 101 }, 102 } 103 tool := &MemoryTool{ 104 Service: &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassOK}, 105 Fallback: &fakeFallback{}, 106 } 107 res, err := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 108 if err != nil || res.IsError { 109 t.Fatalf("res=%+v err=%v", res, err) 110 } 111 var body map[string]any 112 json.Unmarshal([]byte(res.Content), &body) 113 if body["source"] != "memory_sidecar" { 114 t.Fatalf("source=%v", body["source"]) 115 } 116 if body["evidence_quality"] != "structured" { 117 t.Fatalf("evidence_quality=%v", body["evidence_quality"]) 118 } 119 if body["bundle_version"] != "0.4.0" { 120 t.Fatalf("bundle_version=%v", body["bundle_version"]) 121 } 122 cands, _ := body["candidates"].([]any) 123 if len(cands) != 1 { 124 t.Fatalf("candidates=%+v", cands) 125 } 126 } 127 128 func TestMemoryTool_ClassDegraded(t *testing.T) { 129 env := &memory.ResponseEnvelope{ 130 Reason: "degraded", 131 BundleVersion: "0.4.0", 132 Candidates: []memory.QueryCandidate{{Value: "v", Evidence: "observed"}}, 133 } 134 tool := &MemoryTool{ 135 Service: &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassOK}, 136 Fallback: &fakeFallback{}, 137 } 138 res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 139 var body map[string]any 140 json.Unmarshal([]byte(res.Content), &body) 141 if body["evidence_quality"] != "structured_degraded" { 142 t.Fatalf("evidence_quality=%v", body["evidence_quality"]) 143 } 144 warnings, _ := body["warnings"].([]any) 145 if len(warnings) == 0 { 146 t.Fatal("expected degraded warning") 147 } 148 w0, _ := warnings[0].(map[string]any) 149 if msg, _ := w0["message"].(string); !strings.Contains(msg, "degraded") { 150 t.Fatalf("first warning should be the degraded notice; got %+v", w0) 151 } 152 } 153 154 func TestMemoryTool_ClassRetryable_OneRetryThenFallback(t *testing.T) { 155 sq := &stubQuerier{status: memory.StatusReady, seq: []memory.ErrorClass{memory.ClassRetryable, memory.ClassRetryable}} 156 tool := &MemoryTool{Service: sq, Fallback: &fakeFallback{}} 157 res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 158 var body map[string]any 159 json.Unmarshal([]byte(res.Content), &body) 160 if body["source"] != "fallback_after_retry" { 161 t.Fatalf("source=%v", body["source"]) 162 } 163 if body["fallback_reason"] != "retryable_failed" { 164 t.Fatalf("fallback_reason=%v", body["fallback_reason"]) 165 } 166 if sq.callN != 2 { 167 t.Fatalf("expected 2 calls, got %d", sq.callN) 168 } 169 } 170 171 func TestMemoryTool_ClassPermanent_SurfacesDiagnostics(t *testing.T) { 172 env := &memory.ResponseEnvelope{ 173 Reason: "error", 174 Error: &memory.ErrorObject{ 175 Code: "validation_error", 176 Message: "bad mode", 177 Details: map[string]any{"sub_code": "schema_validation"}, 178 }, 179 } 180 tool := &MemoryTool{ 181 Service: &stubQuerier{status: memory.StatusReady, env: env, class: memory.ClassPermanent}, 182 Fallback: &fakeFallback{}, 183 } 184 res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 185 if !res.IsError { 186 t.Fatal("permanent should be IsError") 187 } 188 var body map[string]any 189 json.Unmarshal([]byte(res.Content), &body) 190 warnings, _ := body["warnings"].([]any) 191 if len(warnings) == 0 { 192 t.Fatal("expected warnings with sub_code") 193 } 194 w0, _ := warnings[0].(map[string]any) 195 if w0["sub_code"] != "schema_validation" { 196 t.Fatalf("w0=%+v", w0) 197 } 198 } 199 200 func TestMemoryTool_ClassUnavailable_FallsBack(t *testing.T) { 201 tool := &MemoryTool{ 202 Service: &stubQuerier{status: memory.StatusReady, class: memory.ClassUnavailable}, 203 Fallback: &fakeFallback{}, 204 } 205 res, _ := tool.Run(context.Background(), `{"anchor_mentions":["x"]}`) 206 var body map[string]any 207 json.Unmarshal([]byte(res.Content), &body) 208 if body["source"] != "fallback" { 209 t.Fatalf("source=%v", body["source"]) 210 } 211 if body["fallback_reason"] != "service_unavailable" { 212 t.Fatalf("fallback_reason=%v", body["fallback_reason"]) 213 } 214 } 215 216 func TestMemoryTool_Info(t *testing.T) { 217 tool := &MemoryTool{Fallback: &fakeFallback{}} 218 info := tool.Info() 219 if info.Name != "memory_recall" { 220 t.Fatalf("name=%q want memory_recall", info.Name) 221 } 222 if !tool.IsReadOnlyCall("") { 223 t.Fatal("memory_recall must be read-only") 224 } 225 if tool.RequiresApproval() { 226 t.Fatal("memory_recall must not require approval") 227 } 228 // Required field declared. 229 found := false 230 for _, r := range info.Required { 231 if r == "anchor_mentions" { 232 found = true 233 } 234 } 235 if !found { 236 t.Fatal("anchor_mentions must be in Required") 237 } 238 } 239 240 // TestMemoryTool_FallbackInvokesProvider locks the contract that fallback() 241 // actually delegates to FallbackQuery.SessionKeyword + MemoryFileSnippet 242 // (regression: earlier the fallback returned an empty envelope and the 243 // provider plumbing was dead code). 244 func TestMemoryTool_FallbackInvokesProvider(t *testing.T) { 245 fb := &fakeFallback{ 246 snippet: "MEMORY.md hit line", 247 hits: []any{map[string]any{"id": "sess1", "snippet": "from session_search"}}, 248 } 249 tool := &MemoryTool{Fallback: fb} 250 res, _ := tool.Run(context.Background(), `{"anchor_mentions":["foo","bar"],"result_limit":7}`) 251 if res.IsError { 252 t.Fatalf("unexpected error: %s", res.Content) 253 } 254 if fb.gotQuery != "foo bar" { 255 t.Fatalf("session keyword query = %q want %q", fb.gotQuery, "foo bar") 256 } 257 if fb.gotLimit != 7 { 258 t.Fatalf("session keyword limit = %d want 7", fb.gotLimit) 259 } 260 if fb.snippetQuery != "foo bar" { 261 t.Fatalf("memory snippet query = %q want %q", fb.snippetQuery, "foo bar") 262 } 263 var body map[string]any 264 if err := json.Unmarshal([]byte(res.Content), &body); err != nil { 265 t.Fatal(err) 266 } 267 cands, _ := body["candidates"].([]any) 268 if len(cands) != 2 { 269 t.Fatalf("expected 2 fallback candidates (1 session hit + 1 MEMORY.md snippet), got %d: %+v", len(cands), cands) 270 } 271 }