// internal/agent/usage.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"sync"
  6  
  7  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  8  )
  9  
// usageEmitKey is the unexported context-value key used by the agent loop to
// expose a per-run usage sink to tools. Tools that report their own cost
// (gateway tools calling xAI/Grok, SerpAPI, etc.) call EmitUsage with the
// converted TurnUsage; the handler attached to the loop receives it like any
// other OnUsage event. An empty struct key guarantees no collision with
// context values set by other packages.
type usageEmitKey struct{}
 15  
 16  // WithUsageEmit returns a context carrying a usage emitter. Callers (typically
 17  // the agent loop at tool dispatch time) wrap their ctx with this before
 18  // handing it to a tool's Run method.
 19  func WithUsageEmit(ctx context.Context, emit func(TurnUsage)) context.Context {
 20  	if emit == nil {
 21  		return ctx
 22  	}
 23  	return context.WithValue(ctx, usageEmitKey{}, emit)
 24  }
 25  
 26  // EmitUsage forwards a usage report through the ctx-attached emitter.
 27  // Safe to call from any tool — no-op if the ctx has no emitter.
 28  func EmitUsage(ctx context.Context, u TurnUsage) {
 29  	if emit, ok := ctx.Value(usageEmitKey{}).(func(TurnUsage)); ok && emit != nil {
 30  		emit(u)
 31  	}
 32  }
 33  
// UsageProvider is the optional interface a handler can satisfy to expose
// its accumulated usage. Callers (daemon runner, CLI, TUI) type-assert and
// read Usage() at end-of-run for persistence/display. Returns the combined
// LLM + tool breakdown so callers can report each independently.
type UsageProvider interface {
	// Usage returns a snapshot of everything accumulated so far.
	Usage() AccumulatedUsage
}
 41  
// AccumulatedUsage is the combined snapshot returned by UsageAccumulator.
// LLM and tool billing are tracked separately so callers can report the
// token breakdown without mixing model tokens with gateway tool synthetic
// counts (e.g. SERP tools' 7500-token-per-query billing abstraction).
//
// The invariant input_tokens+output_tokens == total_tokens holds on LLM
// only; ToolCostUSD/ToolTokens are additive on top for "total spend"
// summaries but should never be folded into the LLM token fields.
type AccumulatedUsage struct {
	LLM         TurnUsage // model-only: input/output/cache tokens, LLMCalls, Model
	ToolCalls   int       // count of gateway-tool emissions (tools that billed)
	ToolTokens  int       // sum of gateway-tool reported tokens (may be synthetic)
	ToolCostUSD float64   // sum of gateway-tool cost_usd
}
 56  
 57  // TotalCostUSD returns the combined LLM + tool cost.
 58  func (a AccumulatedUsage) TotalCostUSD() float64 {
 59  	return a.LLM.CostUSD + a.ToolCostUSD
 60  }
 61  
// UsageAccumulator is a thread-safe collector that handlers embed to
// aggregate TurnUsage events across a run/session. It separates LLM
// events (agent loop + cloud_delegate nested calls, signalled by
// LLMCalls > 0) from gateway tool billing events (server.go emissions,
// signalled by LLMCalls == 0) so the caller can report each independently.
//
// Typical flow:
//  1. Handler embeds an UsageAccumulator (value or pointer).
//  2. Handler's OnUsage(u TurnUsage) calls accumulator.Add(u).
//  3. Caller queries Snapshot() at end-of-run and persists it.
//
// The zero value is ready to use. Contains a sync.Mutex, so the struct
// must not be copied after first use; all methods take a pointer receiver.
type UsageAccumulator struct {
	mu          sync.Mutex // guards every field below
	llm         TurnUsage  // model-side running totals (LLMCalls > 0 emissions)
	toolCalls   int        // number of tool-only emissions received
	toolTokens  int        // running sum of tool-reported (possibly synthetic) tokens
	toolCostUSD float64    // running sum of tool-reported cost_usd
}
 81  
 82  // LLMUsageDelta converts a provider usage payload into the normalized LLM-side
 83  // TurnUsage delta used by handlers, session persistence, and cloud_delegate.
 84  func LLMUsageDelta(u client.Usage, model string) TurnUsage {
 85  	u = u.Normalized()
 86  	return TurnUsage{
 87  		InputTokens:           u.InputTokens,
 88  		OutputTokens:          u.OutputTokens,
 89  		TotalTokens:           u.TotalTokens,
 90  		CostUSD:               u.CostUSD,
 91  		LLMCalls:              1,
 92  		Model:                 model,
 93  		CacheReadTokens:       u.CacheReadTokens,
 94  		CacheCreationTokens:   u.CacheCreationTokens,
 95  		CacheCreation5mTokens: u.CacheCreation5mTokens,
 96  		CacheCreation1hTokens: u.CacheCreation1hTokens,
 97  	}
 98  }
 99  
100  // Add merges an incoming TurnUsage delta into the running total, routing
101  // tool-only emissions (LLMCalls == 0) to the separate tool counters so
102  // LLM token fields stay consistent (input+output == total).
103  func (a *UsageAccumulator) Add(u TurnUsage) {
104  	a.mu.Lock()
105  	defer a.mu.Unlock()
106  	if u.LLMCalls == 0 {
107  		// Tool-only emission: server.go reports gateway-tool billing here.
108  		// Keep it out of LLM token fields so total_tokens remains explainable
109  		// as input_tokens + output_tokens on the LLM side.
110  		a.toolCalls++
111  		a.toolTokens += u.TotalTokens
112  		a.toolCostUSD += u.CostUSD
113  		return
114  	}
115  	a.llm.InputTokens += u.InputTokens
116  	a.llm.OutputTokens += u.OutputTokens
117  	a.llm.TotalTokens += u.TotalTokens
118  	a.llm.CostUSD += u.CostUSD
119  	a.llm.CacheReadTokens += u.CacheReadTokens
120  	a.llm.CacheCreationTokens += u.CacheCreationTokens
121  	a.llm.CacheCreation5mTokens += u.CacheCreation5mTokens
122  	a.llm.CacheCreation1hTokens += u.CacheCreation1hTokens
123  	a.llm.LLMCalls += u.LLMCalls
124  	if u.Model != "" {
125  		a.llm.Model = u.Model
126  	}
127  }
128  
129  // Snapshot returns the current cumulative totals split into LLM and tool
130  // buckets. Safe to call from any goroutine.
131  func (a *UsageAccumulator) Snapshot() AccumulatedUsage {
132  	a.mu.Lock()
133  	defer a.mu.Unlock()
134  	return AccumulatedUsage{
135  		LLM:         a.llm,
136  		ToolCalls:   a.toolCalls,
137  		ToolTokens:  a.toolTokens,
138  		ToolCostUSD: a.toolCostUSD,
139  	}
140  }
141  
142  // totalPromptTokens returns the effective prompt size for compaction decisions.
143  // Anthropic's input_tokens excludes cached tokens (they're reported separately
144  // as cache_read_input_tokens / cache_creation_input_tokens but still count
145  // against the model's context window). Using raw input_tokens as the
146  // compaction gate means warm-cache sessions never trip the threshold.
147  func totalPromptTokens(u client.Usage) int {
148  	return u.InputTokens + u.CacheReadTokens + u.CacheCreationTokens
149  }
150  
151  // Reset clears accumulated totals. Use between independent runs in a
152  // long-lived handler (e.g. daemon per-message handler reuse).
153  func (a *UsageAccumulator) Reset() {
154  	a.mu.Lock()
155  	defer a.mu.Unlock()
156  	a.llm = TurnUsage{}
157  	a.toolCalls = 0
158  	a.toolTokens = 0
159  	a.toolCostUSD = 0
160  }