/ internal / agent / deferred.go
deferred.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"fmt"
  7  	"sort"
  8  	"strings"
  9  
 10  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 11  )
 12  
 13  // toolSearchTool is a meta-tool that loads full schemas for deferred tools on demand.
 14  // Defined in the agent package to avoid an import cycle with internal/tools.
 15  type toolSearchTool struct {
 16  	registry *ToolRegistry
 17  	deferred map[string]bool
 18  }
 19  
 20  // newToolSearchTool creates a tool_search scoped to the given deferred tool names.
 21  func newToolSearchTool(reg *ToolRegistry, deferred map[string]bool) *toolSearchTool {
 22  	return &toolSearchTool{registry: reg, deferred: deferred}
 23  }
 24  
 25  func (t *toolSearchTool) Info() ToolInfo {
 26  	return ToolInfo{
 27  		Name:        "tool_search",
 28  		Description: "Load deferred tool schemas so you can call them in this same request. After calling tool_search, immediately continue the task using the loaded tools — do not stop or ask the user to proceed. Use \"select:name1,name2\" for exact lookup or a keyword to search by name/description.",
 29  		Parameters: map[string]any{
 30  			"type": "object",
 31  			"properties": map[string]any{
 32  				"query": map[string]any{
 33  					"type":        "string",
 34  					"description": "Either \"select:name1,name2\" for exact match or a keyword to search deferred tools.",
 35  				},
 36  			},
 37  		},
 38  		Required: []string{"query"},
 39  	}
 40  }
 41  
 42  func (t *toolSearchTool) RequiresApproval() bool     { return false }
 43  func (t *toolSearchTool) IsReadOnlyCall(string) bool { return true }
 44  
 45  func (t *toolSearchTool) Run(_ context.Context, argsJSON string) (ToolResult, error) {
 46  	var args struct {
 47  		Query string `json:"query"`
 48  	}
 49  	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
 50  		return ValidationError("invalid arguments: " + err.Error()), nil
 51  	}
 52  	if args.Query == "" {
 53  		return ValidationError("query is required"), nil
 54  	}
 55  
 56  	var matched []string
 57  
 58  	if strings.HasPrefix(args.Query, "select:") {
 59  		names := strings.Split(strings.TrimPrefix(args.Query, "select:"), ",")
 60  		for _, name := range names {
 61  			name = strings.TrimSpace(name)
 62  			if name != "" && t.deferred[name] {
 63  				matched = append(matched, name)
 64  			}
 65  		}
 66  	} else {
 67  		query := strings.ToLower(args.Query)
 68  		for name := range t.deferred {
 69  			tool, ok := t.registry.Get(name)
 70  			if !ok {
 71  				continue
 72  			}
 73  			info := tool.Info()
 74  			if strings.Contains(strings.ToLower(info.Name), query) ||
 75  				strings.Contains(strings.ToLower(info.Description), query) {
 76  				matched = append(matched, name)
 77  			}
 78  		}
 79  		sort.Strings(matched)
 80  	}
 81  	matched = expandDeferredFamilyCore(t.registry, t.deferred, matched)
 82  
 83  	// Build structured tool_reference blocks for the new protocol path.
 84  	// Zero matches → zero blocks (loop.go falls back to the Content string).
 85  	var blocks []client.ContentBlock
 86  	for _, name := range matched {
 87  		blocks = append(blocks, client.ContentBlock{
 88  			Type:     "tool_reference",
 89  			ToolName: name,
 90  		})
 91  	}
 92  
 93  	// Legacy Content string: preserved as the fallback path for non-supporting
 94  	// backends (Ollama, pre-3.1 shannon-cloud gateway). Contains the LOADED:
 95  	// header + full schema JSON so the model can still discover tools when
 96  	// the tool_reference protocol is unavailable.
 97  	var sb strings.Builder
 98  	sb.WriteString("LOADED:")
 99  	sb.WriteString(strings.Join(matched, ","))
100  
101  	if len(matched) == 0 {
102  		sb.WriteString("\nNo matching deferred tools found.")
103  	} else {
104  		sb.WriteString("\nSchemas loaded. Call these tools now to continue the user's task — do not stop or describe what was loaded.")
105  		schemas := t.registry.FullSchemas(matched)
106  		for i, s := range schemas {
107  			schemaJSON, _ := json.MarshalIndent(s, "", "  ")
108  			sb.WriteString(fmt.Sprintf("\n\n## %s\n%s", matched[i], string(schemaJSON)))
109  		}
110  	}
111  
112  	return ToolResult{
113  		Content:       sb.String(),
114  		ContentBlocks: blocks,
115  	}, nil
116  }
117  
118  func expandDeferredFamilyCore(reg *ToolRegistry, deferred map[string]bool, matched []string) []string {
119  	if len(matched) == 0 {
120  		return nil
121  	}
122  
123  	selected := make(map[string]bool, len(matched))
124  	for _, name := range matched {
125  		if name != "" && deferred[name] {
126  			selected[name] = true
127  		}
128  		family := toolFamily(name)
129  		spec, ok := FamilyRegistry[family]
130  		if !ok {
131  			continue
132  		}
133  		for _, coreName := range spec.Core {
134  			if deferred[coreName] {
135  				selected[coreName] = true
136  			}
137  		}
138  	}
139  
140  	expanded := make([]string, 0, len(selected))
141  	for _, name := range reg.SortedNames() {
142  		if selected[name] {
143  			expanded = append(expanded, name)
144  		}
145  	}
146  	return expanded
147  }
148  
// parseLoadedHeader extracts the tool names listed on the first line of a
// tool_search result, which has the form "LOADED:name1,name2,...".
//
// Only the first line is inspected. Each name is trimmed of surrounding
// whitespace, including a trailing "\r" from CRLF input. Empty entries are
// dropped, so "LOADED: a, ,b" yields ["a" "b"]. Returns nil when content does
// not start with the header or the header lists no names.
func parseLoadedHeader(content string) []string {
	if !strings.HasPrefix(content, "LOADED:") {
		return nil
	}
	header, _, _ := strings.Cut(strings.TrimPrefix(content, "LOADED:"), "\n")

	var names []string
	for _, part := range strings.Split(header, ",") {
		if name := strings.TrimSpace(part); name != "" {
			names = append(names, name)
		}
	}
	return names
}
166  
167  // rebuildSchemas produces a deterministic tool schema list by iterating
168  // the registry's canonical source-aware order (SortedNames: local alpha →
169  // MCP alpha → gateway alpha) and including tools that are either in base
170  // or loaded. This preserves cache stability.
171  func rebuildSchemas(reg *ToolRegistry, baseSchemas []client.Tool, loaded map[string]client.Tool) []client.Tool {
172  	baseNames := make(map[string]bool, len(baseSchemas))
173  	for _, s := range baseSchemas {
174  		baseNames[schemaToolName(s)] = true
175  	}
176  
177  	result := make([]client.Tool, 0, len(baseSchemas)+len(loaded))
178  	for _, name := range reg.SortedNames() {
179  		if baseNames[name] {
180  			if t, ok := reg.Get(name); ok {
181  				result = append(result, buildToolSchema(t))
182  			}
183  		} else if s, ok := loaded[name]; ok {
184  			result = append(result, s)
185  		}
186  	}
187  	return result
188  }
189  
190  // liveToolNames returns tool names in the same order as the live schema list.
191  func liveToolNames(schemas []client.Tool) []string {
192  	names := make([]string, 0, len(schemas))
193  	for _, schema := range schemas {
194  		name := schemaToolName(schema)
195  		if name != "" {
196  			names = append(names, name)
197  		}
198  	}
199  	return names
200  }
201  
202  // schemaToolName extracts the tool name from a client.Tool.
203  func schemaToolName(t client.Tool) string {
204  	if t.Function.Name != "" {
205  		return t.Function.Name
206  	}
207  	return t.Name
208  }
209  
210  // buildLocalOnlySchemas returns sorted schemas for local tools only.
211  func buildLocalOnlySchemas(reg *ToolRegistry) []client.Tool {
212  	local, _, _ := reg.partitionBySource()
213  	sort.Strings(local)
214  	schemas := make([]client.Tool, 0, len(local))
215  	for _, name := range local {
216  		if t, ok := reg.Get(name); ok {
217  			schemas = append(schemas, buildToolSchema(t))
218  		}
219  	}
220  	return schemas
221  }
222  
223  // deferredToolNames returns the set of non-local tool names (MCP + gateway).
224  func deferredToolNames(reg *ToolRegistry) map[string]bool {
225  	_, mcp, gw := reg.partitionBySource()
226  	names := make(map[string]bool, len(mcp)+len(gw))
227  	for _, n := range mcp {
228  		names[n] = true
229  	}
230  	for _, n := range gw {
231  		names[n] = true
232  	}
233  	return names
234  }
235  
236  // preseedDeferredSchemas filters the session working set down to schemas that
237  // are still deferred in the current effective registry.
238  func preseedDeferredSchemas(ws *WorkingSet, deferred map[string]bool) map[string]client.Tool {
239  	loaded := make(map[string]client.Tool)
240  	if ws == nil || len(deferred) == 0 {
241  		return loaded
242  	}
243  	for name, schema := range ws.Schemas() {
244  		if deferred[name] {
245  			loaded[name] = schema
246  		}
247  	}
248  	return loaded
249  }
250  
251  // remainingDeferredNames removes already-warmed schemas from the deferred set.
252  func remainingDeferredNames(deferred map[string]bool, loaded map[string]client.Tool) map[string]bool {
253  	remaining := make(map[string]bool, len(deferred))
254  	for name := range deferred {
255  		if _, ok := loaded[name]; ok {
256  			continue
257  		}
258  		remaining[name] = true
259  	}
260  	return remaining
261  }
262  
// modelSupportsToolRef reports whether the configured model supports the
// defer_loading + tool_reference protocol: Claude Sonnet and Opus at major
// version 4 or later, per Anthropic tool-search docs. Haiku and pre-4 models
// are excluded. Non-Anthropic providers always fall back to the legacy
// rebuildSchemas path.
//
// Matching is case-insensitive and substring-based, so provider-decorated IDs
// such as "anthropic.claude-sonnet-4-20250514-v1:0" or
// "claude-opus-4@20250514" are recognized. Future generations (sonnet-6,
// opus-7, ...) are accepted without code changes.
func modelSupportsToolRef(modelID string) bool {
	const minMajor = 4
	m := strings.ToLower(modelID)
	if !strings.Contains(m, "claude") || strings.Contains(m, "haiku") {
		return false
	}
	return toolRefFamilyMajor(m, "sonnet-") >= minMajor ||
		toolRefFamilyMajor(m, "opus-") >= minMajor
}

// toolRefFamilyMajor returns the highest major version number that directly
// follows any occurrence of prefix in m (e.g. 4 for "opus-" in
// "claude-opus-4-1"), or -1 if there is none. Digit runs longer than two
// characters are ignored. Those are date stamps, as in legacy IDs like
// "claude-3-5-sonnet-20241022", not model generations.
func toolRefFamilyMajor(m, prefix string) int {
	best := -1
	for rest := m; ; {
		idx := strings.Index(rest, prefix)
		if idx < 0 {
			return best
		}
		rest = rest[idx+len(prefix):]

		n := 0
		for n < len(rest) && rest[n] >= '0' && rest[n] <= '9' {
			n++
		}
		if n == 0 || n > 2 {
			continue
		}
		v := int(rest[0] - '0')
		if n == 2 {
			v = v*10 + int(rest[1]-'0')
		}
		if v > best {
			best = v
		}
	}
}
281  
282  // hasAnyNonDeferred returns true if at least one tool in the slice is NOT deferred.
283  // Anthropic rejects requests where every tool has defer_loading: true (400 error).
284  // tool_search itself is always non-deferred (registered outside the defer set),
285  // so this invariant holds whenever deferred mode is active.
286  func hasAnyNonDeferred(tools []client.Tool) bool {
287  	for _, t := range tools {
288  		if !t.DeferLoading {
289  			return true
290  		}
291  	}
292  	return false
293  }
294  
295  // buildFullSchemasWithDefer emits the complete tools array (local + MCP + gateway)
296  // with defer_loading: true on the cold set. Anthropic strips deferred entries from
297  // the cache-key hash before caching, so tools[] stays byte-stable across sessions
298  // while retaining full input_schema for tool_search's BM25/regex matching.
299  //
300  // Caller is responsible for ensuring at least one tool (typically tool_search
301  // itself) is non-deferred — verify with hasAnyNonDeferred.
302  func buildFullSchemasWithDefer(reg *ToolRegistry, cold map[string]bool) []client.Tool {
303  	out := make([]client.Tool, 0)
304  	for _, name := range reg.SortedNames() {
305  		tool, ok := reg.Get(name)
306  		if !ok {
307  			continue
308  		}
309  		s := buildToolSchema(tool)
310  		if cold[name] {
311  			s.DeferLoading = true
312  		}
313  		out = append(out, s)
314  	}
315  	return out
316  }
317  
318  // deferredToolSummariesForNames returns sorted summaries for the named deferred tools.
319  func deferredToolSummariesForNames(reg *ToolRegistry, names map[string]bool) []ToolSummary {
320  	if len(names) == 0 {
321  		return nil
322  	}
323  
324  	all := make([]string, 0, len(names))
325  	for name := range names {
326  		all = append(all, name)
327  	}
328  	sort.Strings(all)
329  
330  	summaries := make([]ToolSummary, 0, len(all))
331  	for _, name := range all {
332  		if t, ok := reg.Get(name); ok {
333  			info := t.Info()
334  			summaries = append(summaries, ToolSummary{Name: info.Name, Description: info.Description})
335  		}
336  	}
337  	return summaries
338  }