/ internal / daemon / checkpoint_test.go
checkpoint_test.go
  1  package daemon
  2  
  3  import (
  4  	"testing"
  5  	"time"
  6  
  7  	"github.com/Kocoro-lab/ShanClaw/internal/agent"
  8  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  9  	"github.com/Kocoro-lab/ShanClaw/internal/session"
 10  )
 11  
 12  // usageStub fulfills agent.UsageProvider for applyTurnUsage tests.
 13  type usageStub struct{ usage agent.AccumulatedUsage }
 14  
 15  func (u *usageStub) Usage() agent.AccumulatedUsage { return u.usage }
 16  
 17  // checkpointTestLoop exposes a way to inject run messages without a live
 18  // agent loop, for unit-testing applyRunMessagesToSession's idempotency.
 19  type checkpointTestLoop struct {
 20  	*agent.AgentLoop
 21  	msgs []client.Message
 22  }
 23  
// The tests below build a real AgentLoop but never call Run(). Its public
// RunMessages() getter reads internal state that is only set inside Run(),
// so the tests that need run messages inject them with
// agent.SetRunMessagesForTest instead.
 27  
 28  // Here we just exercise applyRunMessagesToSession directly with a hand-
 29  // built session and fake loop-messages. The function is the idempotency
 30  // linchpin, so it deserves direct coverage.
 31  func TestApplyTurnMessages_Idempotent(t *testing.T) {
 32  	// Baseline: session with system + one pre-loop user message already.
 33  	sess := &session.Session{
 34  		ID: "sess-1",
 35  		Messages: []client.Message{
 36  			{Role: "system", Content: client.NewTextContent("system")},
 37  			{Role: "user", Content: client.NewTextContent("hello")},
 38  		},
 39  		MessageMeta: []session.MessageMeta{
 40  			{Source: "web"},
 41  			{Source: "web", Timestamp: session.TimePtr(time.Now())},
 42  		},
 43  	}
 44  	base := captureTurnBaseline(sess, "web", true)
 45  
 46  	loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil)
 47  
 48  	// Round 1.
 49  	agent.SetRunMessagesForTest(loop, []client.Message{
 50  		{Role: "user", Content: client.NewTextContent("hello")},
 51  		{Role: "assistant", Content: client.NewTextContent("call tool")},
 52  		{Role: "user", Content: client.NewTextContent("tool result")},
 53  	})
 54  	applyTurnMessages(sess, loop, base)
 55  	if got := len(sess.Messages); got != base.msgCount+2 {
 56  		t.Fatalf("round 1: want %d msgs, got %d", base.msgCount+2, got)
 57  	}
 58  
 59  	// Round 2.
 60  	agent.SetRunMessagesForTest(loop, []client.Message{
 61  		{Role: "user", Content: client.NewTextContent("hello")},
 62  		{Role: "assistant", Content: client.NewTextContent("call tool 1")},
 63  		{Role: "user", Content: client.NewTextContent("result 1")},
 64  		{Role: "assistant", Content: client.NewTextContent("call tool 2")},
 65  		{Role: "user", Content: client.NewTextContent("result 2")},
 66  	})
 67  	applyTurnMessages(sess, loop, base)
 68  	if got := len(sess.Messages); got != base.msgCount+4 {
 69  		t.Fatalf("round 2: want %d msgs, got %d", base.msgCount+4, got)
 70  	}
 71  
 72  	// Round 3: compaction shrink.
 73  	agent.SetRunMessagesForTest(loop, []client.Message{
 74  		{Role: "user", Content: client.NewTextContent("hello")},
 75  		{Role: "assistant", Content: client.NewTextContent("compacted summary")},
 76  	})
 77  	applyTurnMessages(sess, loop, base)
 78  	if got := len(sess.Messages); got != base.msgCount+1 {
 79  		t.Fatalf("round 3 (compaction): want %d msgs, got %d", base.msgCount+1, got)
 80  	}
 81  	if len(sess.Messages) != len(sess.MessageMeta) {
 82  		t.Fatalf("meta drift: %d vs %d", len(sess.Messages), len(sess.MessageMeta))
 83  	}
 84  	if sess.Messages[0].Role != "system" || sess.Messages[1].Role != "user" {
 85  		t.Fatalf("baseline corrupted")
 86  	}
 87  }
 88  
 89  // Regression for finding #1: a turn that produces a mid-turn checkpoint
 90  // followed by a final save must end with ONE canonical transcript, not
 91  // a duplicated one. Both paths share applyTurnMessages + the same baseline
 92  // so iteration count is irrelevant.
 93  func TestApplyTurnMessages_CheckpointThenFinalSave_NoDuplicate(t *testing.T) {
 94  	sess := &session.Session{
 95  		Messages: []client.Message{
 96  			{Role: "system", Content: client.NewTextContent("sys")},
 97  			{Role: "user", Content: client.NewTextContent("hi")},
 98  		},
 99  		MessageMeta: []session.MessageMeta{{Source: "web"}, {Source: "web"}},
100  	}
101  	base := captureTurnBaseline(sess, "web", true)
102  	loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil)
103  
104  	// Simulate: tool batch completes → checkpoint fires.
105  	agent.SetRunMessagesForTest(loop, []client.Message{
106  		{Role: "user", Content: client.NewTextContent("hi")},
107  		{Role: "assistant", Content: client.NewTextContent("[tool_use]")},
108  		{Role: "user", Content: client.NewTextContent("[tool_result]")},
109  	})
110  	applyTurnMessages(sess, loop, base) // mid-turn checkpoint
111  
112  	// Turn completes: final text appended to RunMessages.
113  	agent.SetRunMessagesForTest(loop, []client.Message{
114  		{Role: "user", Content: client.NewTextContent("hi")},
115  		{Role: "assistant", Content: client.NewTextContent("[tool_use]")},
116  		{Role: "user", Content: client.NewTextContent("[tool_result]")},
117  		{Role: "assistant", Content: client.NewTextContent("final answer")},
118  	})
119  	applyTurnMessages(sess, loop, base) // final save
120  
121  	// Expected: baseline(2) + 3 post-user messages = 5. No duplicates.
122  	if got := len(sess.Messages); got != 5 {
123  		t.Fatalf("expected 5 messages (2 baseline + 3 turn), got %d", got)
124  	}
125  	// Check the sequence has exactly one [tool_use] and one [tool_result].
126  	var countToolUse, countToolResult, countFinal int
127  	for _, m := range sess.Messages {
128  		switch m.Content.Text() {
129  		case "[tool_use]":
130  			countToolUse++
131  		case "[tool_result]":
132  			countToolResult++
133  		case "final answer":
134  			countFinal++
135  		}
136  	}
137  	if countToolUse != 1 || countToolResult != 1 || countFinal != 1 {
138  		t.Fatalf("duplicated transcript: tool_use=%d tool_result=%d final=%d",
139  			countToolUse, countToolResult, countFinal)
140  	}
141  }
142  
143  // Regression for hard-error-after-checkpoint: a non-soft failure after
144  // one or more successful mid-turn checkpoints must NOT duplicate the
145  // transcript (checkpoint already persisted it) and must NOT double-count
146  // usage (additive AddUsage on top of already-folded usage was the bug).
147  // This test mirrors the runner's hard-error path inline.
148  func TestApplyTurnState_HardErrorAfterCheckpoint_NoDuplicate(t *testing.T) {
149  	sess := &session.Session{
150  		Messages: []client.Message{
151  			{Role: "user", Content: client.NewTextContent("do thing")},
152  		},
153  		MessageMeta: []session.MessageMeta{{Source: "web"}},
154  		Usage:       &session.UsageSummary{InputTokens: 100, LLMCalls: 1},
155  	}
156  	base := captureTurnBaseline(sess, "web", true)
157  	loop := agent.NewAgentLoop(nil, agent.NewToolRegistry(), "m", "", 1, 1, 1, nil, nil, nil)
158  	up := &usageStub{usage: agent.AccumulatedUsage{
159  		LLM: agent.TurnUsage{InputTokens: 50, LLMCalls: 1},
160  	}}
161  
162  	// Step 1: mid-turn checkpoint after a successful tool batch.
163  	agent.SetRunMessagesForTest(loop, []client.Message{
164  		{Role: "user", Content: client.NewTextContent("do thing")},
165  		{Role: "assistant", Content: client.NewTextContent("[tool_use]")},
166  		{Role: "user", Content: client.NewTextContent("[tool_result]")},
167  	})
168  	applyTurnMessages(sess, loop, base)
169  	applyTurnUsage(sess, up, base)
170  	// Sanity: 1 baseline + 2 turn msgs = 3. Usage: 100+50=150.
171  	if len(sess.Messages) != 3 {
172  		t.Fatalf("after checkpoint: want 3 msgs, got %d", len(sess.Messages))
173  	}
174  	if sess.Usage.InputTokens != 150 {
175  		t.Fatalf("after checkpoint: want 150 input tokens, got %d", sess.Usage.InputTokens)
176  	}
177  
178  	// Step 2: hard error fires. The runner's hard-error path rebuilds
179  	// from baseline + current RunMessages, appends a friendly error stub,
180  	// then applies usage. The accumulator has grown slightly (e.g., one
181  	// more failed LLM call).
182  	up.usage.LLM.InputTokens = 70 // +20 since checkpoint
183  	up.usage.LLM.LLMCalls = 2
184  	applyTurnMessages(sess, loop, base)
185  	sess.Messages = append(sess.Messages,
186  		client.Message{Role: "assistant", Content: client.NewTextContent("Sorry, something failed.")},
187  	)
188  	sess.MessageMeta = append(sess.MessageMeta,
189  		session.MessageMeta{Source: "web", Timestamp: session.TimePtr(time.Now())},
190  	)
191  	applyTurnUsage(sess, up, base)
192  
193  	// Expected: 1 baseline + 2 turn + 1 error stub = 4 total. No duplicates.
194  	if len(sess.Messages) != 4 {
195  		t.Fatalf("after hard error: want 4 msgs (1 baseline + 2 turn + 1 error), got %d", len(sess.Messages))
196  	}
197  	// Usage: 100 baseline + 70 current = 170. NOT 100+50+70=220 (double-count).
198  	if sess.Usage.InputTokens != 170 {
199  		t.Fatalf("after hard error: want 170 input tokens (baseline+current), got %d (double-counted)", sess.Usage.InputTokens)
200  	}
201  	if sess.Usage.LLMCalls != 3 {
202  		t.Fatalf("after hard error: want 3 LLMCalls (1 baseline + 2 current), got %d", sess.Usage.LLMCalls)
203  	}
204  	// Duplicate scan: exactly one tool_use and one tool_result.
205  	var toolUse, toolResult, errStub int
206  	for _, m := range sess.Messages {
207  		switch m.Content.Text() {
208  		case "[tool_use]":
209  			toolUse++
210  		case "[tool_result]":
211  			toolResult++
212  		case "Sorry, something failed.":
213  			errStub++
214  		}
215  	}
216  	if toolUse != 1 || toolResult != 1 || errStub != 1 {
217  		t.Fatalf("duplicate in hard-error path: tool_use=%d tool_result=%d err=%d",
218  			toolUse, toolResult, errStub)
219  	}
220  }
221  
222  // Regression for finding #3: usage survives mid-turn checkpoint + final
223  // save without being double-counted. Baseline + current accumulator is
224  // the authoritative value at every save.
225  func TestApplyTurnUsage_IdempotentAcrossCheckpointAndFinalSave(t *testing.T) {
226  	sess := &session.Session{Usage: &session.UsageSummary{
227  		InputTokens: 100, OutputTokens: 50, TotalTokens: 150, LLMCalls: 1,
228  	}}
229  	base := captureTurnBaseline(sess, "web", false)
230  	up := &usageStub{usage: agent.AccumulatedUsage{
231  		LLM: agent.TurnUsage{InputTokens: 20, OutputTokens: 10, TotalTokens: 30, LLMCalls: 1},
232  	}}
233  
234  	// First call: mid-turn checkpoint after first LLM call.
235  	applyTurnUsage(sess, up, base)
236  	if sess.Usage.InputTokens != 120 || sess.Usage.OutputTokens != 60 || sess.Usage.LLMCalls != 2 {
237  		t.Fatalf("after checkpoint: %+v", sess.Usage)
238  	}
239  
240  	// Second call: accumulator grew (second LLM call). Final save uses
241  	// the SAME baseline — must not double-count the first call.
242  	up.usage = agent.AccumulatedUsage{
243  		LLM: agent.TurnUsage{InputTokens: 40, OutputTokens: 20, TotalTokens: 60, LLMCalls: 2},
244  	}
245  	applyTurnUsage(sess, up, base)
246  	// Expected: baseline(100/50/1) + current(40/20/2) = 140/70/3
247  	if sess.Usage.InputTokens != 140 || sess.Usage.OutputTokens != 70 || sess.Usage.LLMCalls != 3 {
248  		t.Fatalf("after final save (double-count regression): %+v", sess.Usage)
249  	}
250  }
251  
252  func TestSessionInProgress_FlagCycles(t *testing.T) {
253  	sess := &session.Session{}
254  	if sess.InProgress {
255  		t.Fatal("fresh session should not be InProgress")
256  	}
257  	sess.InProgress = true
258  	sess.InProgress = false
259  	if sess.InProgress {
260  		t.Fatal("toggle off didn't clear")
261  	}
262  }