checkpoint_test.go
1 package daemon 2 3 import ( 4 "testing" 5 "time" 6 7 "github.com/Kocoro-lab/ShanClaw/internal/agent" 8 "github.com/Kocoro-lab/ShanClaw/internal/client" 9 "github.com/Kocoro-lab/ShanClaw/internal/session" 10 ) 11 12 // usageStub fulfills agent.UsageProvider for applyTurnUsage tests. 13 type usageStub struct{ usage agent.AccumulatedUsage } 14 15 func (u *usageStub) Usage() agent.AccumulatedUsage { return u.usage } 16 17 // checkpointTestLoop exposes a way to inject run messages without a live 18 // agent loop, for unit-testing applyRunMessagesToSession's idempotency. 19 type checkpointTestLoop struct { 20 *agent.AgentLoop 21 msgs []client.Message 22 } 23 24 // We directly construct a real AgentLoop, then use its public 25 // RunMessages(). Since that getter reads from internal state only set 26 // inside Run(), we fall back to constructing a test harness below. 27 28 // Here we just exercise applyRunMessagesToSession directly with a hand- 29 // built session and fake loop-messages. The function is the idempotency 30 // linchpin, so it deserves direct coverage. 31 func TestApplyTurnMessages_Idempotent(t *testing.T) { 32 // Baseline: session with system + one pre-loop user message already. 33 sess := &session.Session{ 34 ID: "sess-1", 35 Messages: []client.Message{ 36 {Role: "system", Content: client.NewTextContent("system")}, 37 {Role: "user", Content: client.NewTextContent("hello")}, 38 }, 39 MessageMeta: []session.MessageMeta{ 40 {Source: "web"}, 41 {Source: "web", Timestamp: session.TimePtr(time.Now())}, 42 }, 43 } 44 base := captureTurnBaseline(sess, "web", true) 45 46 loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil) 47 48 // Round 1. 49 agent.SetRunMessagesForTest(loop, []client.Message{ 50 {Role: "user", Content: client.NewTextContent("hello")}, 51 {Role: "assistant", Content: client.NewTextContent("call tool")}, 52 {Role: "user", Content: client.NewTextContent("tool result")}, 53 }) 54 applyTurnMessages(sess, loop, base) 55 if got := len(sess.Messages); got != base.msgCount+2 { 56 t.Fatalf("round 1: want %d msgs, got %d", base.msgCount+2, got) 57 } 58 59 // Round 2. 60 agent.SetRunMessagesForTest(loop, []client.Message{ 61 {Role: "user", Content: client.NewTextContent("hello")}, 62 {Role: "assistant", Content: client.NewTextContent("call tool 1")}, 63 {Role: "user", Content: client.NewTextContent("result 1")}, 64 {Role: "assistant", Content: client.NewTextContent("call tool 2")}, 65 {Role: "user", Content: client.NewTextContent("result 2")}, 66 }) 67 applyTurnMessages(sess, loop, base) 68 if got := len(sess.Messages); got != base.msgCount+4 { 69 t.Fatalf("round 2: want %d msgs, got %d", base.msgCount+4, got) 70 } 71 72 // Round 3: compaction shrink. 73 agent.SetRunMessagesForTest(loop, []client.Message{ 74 {Role: "user", Content: client.NewTextContent("hello")}, 75 {Role: "assistant", Content: client.NewTextContent("compacted summary")}, 76 }) 77 applyTurnMessages(sess, loop, base) 78 if got := len(sess.Messages); got != base.msgCount+1 { 79 t.Fatalf("round 3 (compaction): want %d msgs, got %d", base.msgCount+1, got) 80 } 81 if len(sess.Messages) != len(sess.MessageMeta) { 82 t.Fatalf("meta drift: %d vs %d", len(sess.Messages), len(sess.MessageMeta)) 83 } 84 if sess.Messages[0].Role != "system" || sess.Messages[1].Role != "user" { 85 t.Fatalf("baseline corrupted") 86 } 87 } 88 89 // Regression for finding #1: a turn that produces a mid-turn checkpoint 90 // followed by a final save must end with ONE canonical transcript, not 91 // a duplicated one. Both paths share applyTurnMessages + the same baseline 92 // so iteration count is irrelevant. 93 func TestApplyTurnMessages_CheckpointThenFinalSave_NoDuplicate(t *testing.T) { 94 sess := &session.Session{ 95 Messages: []client.Message{ 96 {Role: "system", Content: client.NewTextContent("sys")}, 97 {Role: "user", Content: client.NewTextContent("hi")}, 98 }, 99 MessageMeta: []session.MessageMeta{{Source: "web"}, {Source: "web"}}, 100 } 101 base := captureTurnBaseline(sess, "web", true) 102 loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil) 103 104 // Simulate: tool batch completes → checkpoint fires. 105 agent.SetRunMessagesForTest(loop, []client.Message{ 106 {Role: "user", Content: client.NewTextContent("hi")}, 107 {Role: "assistant", Content: client.NewTextContent("[tool_use]")}, 108 {Role: "user", Content: client.NewTextContent("[tool_result]")}, 109 }) 110 applyTurnMessages(sess, loop, base) // mid-turn checkpoint 111 112 // Turn completes: final text appended to RunMessages. 113 agent.SetRunMessagesForTest(loop, []client.Message{ 114 {Role: "user", Content: client.NewTextContent("hi")}, 115 {Role: "assistant", Content: client.NewTextContent("[tool_use]")}, 116 {Role: "user", Content: client.NewTextContent("[tool_result]")}, 117 {Role: "assistant", Content: client.NewTextContent("final answer")}, 118 }) 119 applyTurnMessages(sess, loop, base) // final save 120 121 // Expected: baseline(2) + 3 post-user messages = 5. No duplicates. 122 if got := len(sess.Messages); got != 5 { 123 t.Fatalf("expected 5 messages (2 baseline + 3 turn), got %d", got) 124 } 125 // Check the sequence has exactly one [tool_use] and one [tool_result]. 126 var countToolUse, countToolResult, countFinal int 127 for _, m := range sess.Messages { 128 switch m.Content.Text() { 129 case "[tool_use]": 130 countToolUse++ 131 case "[tool_result]": 132 countToolResult++ 133 case "final answer": 134 countFinal++ 135 } 136 } 137 if countToolUse != 1 || countToolResult != 1 || countFinal != 1 { 138 t.Fatalf("duplicated transcript: tool_use=%d tool_result=%d final=%d", 139 countToolUse, countToolResult, countFinal) 140 } 141 } 142 143 // Regression for hard-error-after-checkpoint: a non-soft failure after 144 // one or more successful mid-turn checkpoints must NOT duplicate the 145 // transcript (checkpoint already persisted it) and must NOT double-count 146 // usage (additive AddUsage on top of already-folded usage was the bug). 147 // This test mirrors the runner's hard-error path inline. 148 func TestApplyTurnState_HardErrorAfterCheckpoint_NoDuplicate(t *testing.T) { 149 sess := &session.Session{ 150 Messages: []client.Message{ 151 {Role: "user", Content: client.NewTextContent("do thing")}, 152 }, 153 MessageMeta: []session.MessageMeta{{Source: "web"}}, 154 Usage: &session.UsageSummary{InputTokens: 100, LLMCalls: 1}, 155 } 156 base := captureTurnBaseline(sess, "web", true) 157 loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil) 158 up := &usageStub{usage: agent.AccumulatedUsage{ 159 LLM: agent.TurnUsage{InputTokens: 50, LLMCalls: 1}, 160 }} 161 162 // Step 1: mid-turn checkpoint after a successful tool batch. 163 agent.SetRunMessagesForTest(loop, []client.Message{ 164 {Role: "user", Content: client.NewTextContent("do thing")}, 165 {Role: "assistant", Content: client.NewTextContent("[tool_use]")}, 166 {Role: "user", Content: client.NewTextContent("[tool_result]")}, 167 }) 168 applyTurnMessages(sess, loop, base) 169 applyTurnUsage(sess, up, base) 170 // Sanity: 1 baseline + 2 turn msgs = 3. Usage: 100+50=150. 171 if len(sess.Messages) != 3 { 172 t.Fatalf("after checkpoint: want 3 msgs, got %d", len(sess.Messages)) 173 } 174 if sess.Usage.InputTokens != 150 { 175 t.Fatalf("after checkpoint: want 150 input tokens, got %d", sess.Usage.InputTokens) 176 } 177 178 // Step 2: hard error fires. The runner's hard-error path rebuilds 179 // from baseline + current RunMessages, appends a friendly error stub, 180 // then applies usage. The accumulator has grown slightly (e.g., one 181 // more failed LLM call). 182 up.usage.LLM.InputTokens = 70 // +20 since checkpoint 183 up.usage.LLM.LLMCalls = 2 184 applyTurnMessages(sess, loop, base) 185 sess.Messages = append(sess.Messages, 186 client.Message{Role: "assistant", Content: client.NewTextContent("Sorry, something failed.")}, 187 ) 188 sess.MessageMeta = append(sess.MessageMeta, 189 session.MessageMeta{Source: "web", Timestamp: session.TimePtr(time.Now())}, 190 ) 191 applyTurnUsage(sess, up, base) 192 193 // Expected: 1 baseline + 2 turn + 1 error stub = 4 total. No duplicates. 194 if len(sess.Messages) != 4 { 195 t.Fatalf("after hard error: want 4 msgs (1 baseline + 2 turn + 1 error), got %d", len(sess.Messages)) 196 } 197 // Usage: 100 baseline + 70 current = 170. NOT 100+50+70=220 (double-count). 198 if sess.Usage.InputTokens != 170 { 199 t.Fatalf("after hard error: want 170 input tokens (baseline+current), got %d (double-counted)", sess.Usage.InputTokens) 200 } 201 if sess.Usage.LLMCalls != 3 { 202 t.Fatalf("after hard error: want 3 LLMCalls (1 baseline + 2 current), got %d", sess.Usage.LLMCalls) 203 } 204 // Duplicate scan: exactly one tool_use and one tool_result. 205 var toolUse, toolResult, errStub int 206 for _, m := range sess.Messages { 207 switch m.Content.Text() { 208 case "[tool_use]": 209 toolUse++ 210 case "[tool_result]": 211 toolResult++ 212 case "Sorry, something failed.": 213 errStub++ 214 } 215 } 216 if toolUse != 1 || toolResult != 1 || errStub != 1 { 217 t.Fatalf("duplicate in hard-error path: tool_use=%d tool_result=%d err=%d", 218 toolUse, toolResult, errStub) 219 } 220 } 221 222 // Regression for finding #3: usage survives mid-turn checkpoint + final 223 // save without being double-counted. Baseline + current accumulator is 224 // the authoritative value at every save. 225 func TestApplyTurnUsage_IdempotentAcrossCheckpointAndFinalSave(t *testing.T) { 226 sess := &session.Session{Usage: &session.UsageSummary{ 227 InputTokens: 100, OutputTokens: 50, TotalTokens: 150, LLMCalls: 1, 228 }} 229 base := captureTurnBaseline(sess, "web", false) 230 up := &usageStub{usage: agent.AccumulatedUsage{ 231 LLM: agent.TurnUsage{InputTokens: 20, OutputTokens: 10, TotalTokens: 30, LLMCalls: 1}, 232 }} 233 234 // First call: mid-turn checkpoint after first LLM call. 235 applyTurnUsage(sess, up, base) 236 if sess.Usage.InputTokens != 120 || sess.Usage.OutputTokens != 60 || sess.Usage.LLMCalls != 2 { 237 t.Fatalf("after checkpoint: %+v", sess.Usage) 238 } 239 240 // Second call: accumulator grew (second LLM call). Final save uses 241 // the SAME baseline — must not double-count the first call. 242 up.usage = agent.AccumulatedUsage{ 243 LLM: agent.TurnUsage{InputTokens: 40, OutputTokens: 20, TotalTokens: 60, LLMCalls: 2}, 244 } 245 applyTurnUsage(sess, up, base) 246 // Expected: baseline(100/50/1) + current(40/20/2) = 140/70/3 247 if sess.Usage.InputTokens != 140 || sess.Usage.OutputTokens != 70 || sess.Usage.LLMCalls != 3 { 248 t.Fatalf("after final save (double-count regression): %+v", sess.Usage) 249 } 250 } 251 252 func TestSessionInProgress_FlagCycles(t *testing.T) { 253 sess := &session.Session{} 254 if sess.InProgress { 255 t.Fatal("fresh session should not be InProgress") 256 } 257 sess.InProgress = true 258 sess.InProgress = false 259 if sess.InProgress { 260 t.Fatal("toggle off didn't clear") 261 } 262 }