/ internal / agent / usage_test.go
usage_test.go
 1  package agent
 2  
 3  import (
 4  	"testing"
 5  
 6  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 7  )
 8  
 9  func TestLLMUsageDelta_NormalizesSplitCacheCreation(t *testing.T) {
10  	delta := LLMUsageDelta(client.Usage{
11  		InputTokens:           120,
12  		OutputTokens:          30,
13  		CacheReadTokens:       40,
14  		CacheCreation5mTokens: 100,
15  		CacheCreation1hTokens: 200,
16  	}, "claude-test")
17  
18  	if delta.TotalTokens != 150 {
19  		t.Fatalf("expected total tokens 150, got %d", delta.TotalTokens)
20  	}
21  	if delta.CacheCreationTokens != 300 {
22  		t.Fatalf("expected legacy cache creation total 300, got %d", delta.CacheCreationTokens)
23  	}
24  	if delta.CacheCreation5mTokens != 100 || delta.CacheCreation1hTokens != 200 {
25  		t.Fatalf("expected split cache creation 100/200, got %d/%d", delta.CacheCreation5mTokens, delta.CacheCreation1hTokens)
26  	}
27  	if delta.Model != "claude-test" {
28  		t.Fatalf("expected model claude-test, got %q", delta.Model)
29  	}
30  	if delta.LLMCalls != 1 {
31  		t.Fatalf("expected 1 LLM call, got %d", delta.LLMCalls)
32  	}
33  }
34  
35  func TestUsageAccumulator_AccumulatesSplitCacheCreation(t *testing.T) {
36  	var acc UsageAccumulator
37  	acc.Add(LLMUsageDelta(client.Usage{
38  		InputTokens:           90,
39  		OutputTokens:          10,
40  		CacheCreation5mTokens: 25,
41  		CacheCreation1hTokens: 75,
42  	}, "claude-test"))
43  
44  	snap := acc.Snapshot()
45  	if snap.LLM.CacheCreationTokens != 100 {
46  		t.Fatalf("expected legacy cache creation total 100, got %d", snap.LLM.CacheCreationTokens)
47  	}
48  	if snap.LLM.CacheCreation5mTokens != 25 || snap.LLM.CacheCreation1hTokens != 75 {
49  		t.Fatalf("expected split cache creation 25/75, got %d/%d", snap.LLM.CacheCreation5mTokens, snap.LLM.CacheCreation1hTokens)
50  	}
51  	if snap.LLM.TotalTokens != 100 {
52  		t.Fatalf("expected total tokens 100, got %d", snap.LLM.TotalTokens)
53  	}
54  }
55  
56  func TestTotalPromptTokens(t *testing.T) {
57  	tests := []struct {
58  		name string
59  		u    client.Usage
60  		want int
61  	}{
62  		{
63  			name: "only non-cached input",
64  			u:    client.Usage{InputTokens: 1200},
65  			want: 1200,
66  		},
67  		{
68  			name: "warm cache: small input, large cache read",
69  			u:    client.Usage{InputTokens: 500, CacheReadTokens: 90000},
70  			want: 90500,
71  		},
72  		{
73  			name: "cache miss filled a new cache: small input, large cache creation",
74  			u:    client.Usage{InputTokens: 300, CacheCreationTokens: 45000},
75  			want: 45300,
76  		},
77  		{
78  			name: "mixed: all three populated",
79  			u:    client.Usage{InputTokens: 800, CacheReadTokens: 60000, CacheCreationTokens: 5000},
80  			want: 65800,
81  		},
82  		{
83  			name: "zero usage",
84  			u:    client.Usage{},
85  			want: 0,
86  		},
87  	}
88  	for _, tt := range tests {
89  		t.Run(tt.name, func(t *testing.T) {
90  			got := totalPromptTokens(tt.u)
91  			if got != tt.want {
92  				t.Errorf("totalPromptTokens(%+v) = %d, want %d", tt.u, got, tt.want)
93  			}
94  		})
95  	}
96  }