register_test.go
1 package tools 2 3 import ( 4 "encoding/json" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 "github.com/Kocoro-lab/ShanClaw/internal/agent" 10 "github.com/Kocoro-lab/ShanClaw/internal/client" 11 "github.com/Kocoro-lab/ShanClaw/internal/mcp" 12 mcpproto "github.com/mark3labs/mcp-go/mcp" 13 ) 14 15 func TestRegisterAll_WithServerTools(t *testing.T) { 16 serverTools := []client.ServerToolSchema{ 17 {Name: "web_search", Description: "Search the web"}, 18 {Name: "getStockBars", Description: "Get stock price bars"}, 19 } 20 21 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 w.Header().Set("Content-Type", "application/json") 23 json.NewEncoder(w).Encode(serverTools) 24 })) 25 defer server.Close() 26 27 gw := client.NewGatewayClient(server.URL, "") 28 reg, _, _, cleanup, err := RegisterAll(gw, nil) 29 defer cleanup() 30 if err != nil { 31 t.Fatalf("unexpected error: %v", err) 32 } 33 34 // Check local tools are registered 35 for _, name := range []string{"use_skill", "file_read", "file_write", "file_edit", "glob", "grep", "bash", "think", "directory_list", "http", "system_info", "clipboard", "notify", "process", "applescript", "accessibility", "ghostty", "browser", "screenshot", "computer", "wait_for", "schedule_create", "schedule_list", "schedule_update", "schedule_remove"} { 36 if _, ok := reg.Get(name); !ok { 37 t.Errorf("local tool %q not registered", name) 38 } 39 } 40 41 // Check server tools are registered 42 for _, name := range []string{"web_search", "getStockBars"} { 43 if _, ok := reg.Get(name); !ok { 44 t.Errorf("server tool %q not registered", name) 45 } 46 } 47 48 // Total: 26 local + 2 server = 28 49 schemas := reg.Schemas() 50 if len(schemas) != 28 { 51 t.Errorf("expected 28 tools, got %d", len(schemas)) 52 } 53 } 54 55 func TestRegisterAll_ServerUnavailable(t *testing.T) { 56 // Point to a closed server 57 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 58 server.Close() 59 60 gw := client.NewGatewayClient(server.URL, "") 61 reg, _, _, cleanup, err := RegisterAll(gw, nil) 62 defer cleanup() 63 if err == nil { 64 t.Error("expected warning error when server is unavailable") 65 } 66 67 // Local tools should still be registered 68 for _, name := range []string{"file_read", "bash", "glob"} { 69 if _, ok := reg.Get(name); !ok { 70 t.Errorf("local tool %q should still be registered", name) 71 } 72 } 73 74 schemas := reg.Schemas() 75 if len(schemas) != 26 { 76 t.Errorf("expected 26 local tools, got %d", len(schemas)) 77 } 78 } 79 80 func TestRegisterAll_LocalPriority(t *testing.T) { 81 // Server returns a tool named "bash" — should be skipped 82 serverTools := []client.ServerToolSchema{ 83 {Name: "bash", Description: "Server bash (should be skipped)"}, 84 {Name: "web_search", Description: "Search the web"}, 85 } 86 87 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 88 w.Header().Set("Content-Type", "application/json") 89 json.NewEncoder(w).Encode(serverTools) 90 })) 91 defer server.Close() 92 93 gw := client.NewGatewayClient(server.URL, "") 94 reg, _, _, cleanup, err := RegisterAll(gw, nil) 95 defer cleanup() 96 if err != nil { 97 t.Fatalf("unexpected error: %v", err) 98 } 99 100 // "bash" should be the local BashTool, not the server one 101 tool, ok := reg.Get("bash") 102 if !ok { 103 t.Fatal("bash tool not found") 104 } 105 if _, isServer := tool.(*ServerTool); isServer { 106 t.Error("bash should be local tool, not server tool") 107 } 108 109 // web_search should be server tool 110 tool, ok = reg.Get("web_search") 111 if !ok { 112 t.Fatal("web_search tool not found") 113 } 114 if _, isServer := tool.(*ServerTool); !isServer { 115 t.Error("web_search should be a server tool") 116 } 117 118 // 26 local + 1 server (bash skipped) = 27 119 schemas := reg.Schemas() 120 if len(schemas) != 27 { 121 t.Errorf("expected 27 tools, got %d", len(schemas)) 122 } 123 } 124 125 func TestRegisterServerTools_AllowlistFiltering(t *testing.T) { 126 serverTools := []client.ServerToolSchema{ 127 {Name: "web_search", Description: "Search the web"}, 128 {Name: "python_executor", Description: "Run Python in sandbox"}, 129 {Name: "calculator", Description: "Basic calculator"}, 130 {Name: "getStockBars", Description: "Get stock price bars"}, 131 {Name: "session_file_write", Description: "Write session file"}, 132 {Name: "some_future_tool", Description: "Unknown new tool"}, 133 } 134 135 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 136 w.Header().Set("Content-Type", "application/json") 137 json.NewEncoder(w).Encode(serverTools) 138 })) 139 defer server.Close() 140 141 gw := client.NewGatewayClient(server.URL, "") 142 reg, _, _, cleanup, err := RegisterAll(gw, nil) 143 defer cleanup() 144 if err != nil { 145 t.Fatalf("unexpected error: %v", err) 146 } 147 148 // Allowlisted tools should be registered 149 for _, name := range []string{"web_search", "getStockBars"} { 150 if _, ok := reg.Get(name); !ok { 151 t.Errorf("allowlisted tool %q should be registered", name) 152 } 153 } 154 155 // Non-allowlisted tools should be filtered out 156 for _, name := range []string{"python_executor", "calculator", "session_file_write", "some_future_tool"} { 157 if _, ok := reg.Get(name); ok { 158 t.Errorf("non-allowlisted tool %q should NOT be registered", name) 159 } 160 } 161 } 162 163 func TestExtractGatewayTools(t *testing.T) { 164 reg := agent.NewToolRegistry() 165 gw := client.NewGatewayClient("http://test", "key") 166 reg.Register(NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw)) 167 reg.Register(&ThinkTool{}) 168 tools := ExtractGatewayTools(reg) 169 if len(tools) != 1 { 170 t.Fatalf("expected 1 gateway tool, got %d", len(tools)) 171 } 172 if tools[0].Info().Name != "web_search" { 173 t.Errorf("expected web_search, got %s", tools[0].Info().Name) 174 } 175 } 176 177 func TestExtractPostOverlays(t *testing.T) { 178 baseline := agent.NewToolRegistry() 179 baseline.Register(&ThinkTool{}) 180 181 full := baseline.Clone() 182 gw := client.NewGatewayClient("http://test", "key") 183 full.Register(NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw)) 184 mgr := mcp.NewClientManager() 185 full.Register(NewMCPTool("playwright", mcpproto.Tool{Name: "browser_navigate"}, mgr)) 186 full.Register(&NotifyTool{}) // a local overlay 187 188 overlays := ExtractPostOverlays(full, baseline) 189 if len(overlays) != 1 { 190 t.Fatalf("expected 1 overlay, got %d", len(overlays)) 191 } 192 if overlays[0].Info().Name != "notify" { 193 t.Errorf("expected notify, got %s", overlays[0].Info().Name) 194 } 195 } 196 197 func TestRebuildRegistryForHealth_PlaywrightHealthy(t *testing.T) { 198 baseline := agent.NewToolRegistry() 199 baseline.Register(&ThinkTool{}) 200 baseline.Register(&BrowserTool{}) 201 202 healthStates := map[string]mcp.ServerHealth{ 203 "playwright": {State: mcp.StateHealthy}, 204 } 205 206 mgr := mcp.NewClientManager() 207 mgr.SeedToolCache("playwright", []mcp.RemoteTool{ 208 {ServerName: "playwright", Tool: mcpproto.Tool{Name: "browser_navigate"}}, 209 }) 210 211 reg := RebuildRegistryForHealth(baseline, nil, nil, healthStates, mgr, nil) 212 if _, ok := reg.Get("browser"); ok { 213 t.Error("legacy browser should be removed when Playwright is healthy") 214 } 215 if _, ok := reg.Get("browser_navigate"); !ok { 216 t.Error("browser_navigate should be registered from healthy Playwright") 217 } 218 } 219 220 func TestRebuildRegistryForHealth_PlaywrightDisconnected(t *testing.T) { 221 baseline := agent.NewToolRegistry() 222 baseline.Register(&ThinkTool{}) 223 baseline.Register(&BrowserTool{}) 224 225 healthStates := map[string]mcp.ServerHealth{ 226 "playwright": {State: mcp.StateDisconnected}, 227 } 228 229 mgr := mcp.NewClientManager() 230 mgr.SeedToolCache("playwright", []mcp.RemoteTool{ 231 {ServerName: "playwright", Tool: mcpproto.Tool{Name: "browser_navigate"}}, 232 }) 233 234 reg := RebuildRegistryForHealth(baseline, nil, nil, healthStates, mgr, nil) 235 // Disconnected Playwright tools are included from cache for on-demand reconnect. 236 if _, ok := reg.Get("browser_navigate"); !ok { 237 t.Error("browser_navigate should be present from cache even when disconnected") 238 } 239 // Legacy browser is removed when Playwright tools are present (even disconnected). 240 if _, ok := reg.Get("browser"); ok { 241 t.Error("legacy browser should be removed when Playwright tools are present") 242 } 243 } 244 245 func TestRebuildRegistryForHealth_GatewayAndPostOverlays(t *testing.T) { 246 baseline := agent.NewToolRegistry() 247 baseline.Register(&ThinkTool{}) 248 249 gw := client.NewGatewayClient("http://test", "key") 250 gatewayOverlay := []agent.Tool{ 251 NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw), 252 } 253 postOverlays := []agent.Tool{ 254 &NotifyTool{}, 255 } 256 257 reg := RebuildRegistryForHealth(baseline, gatewayOverlay, postOverlays, nil, nil, nil) 258 if _, ok := reg.Get("think"); !ok { 259 t.Error("baseline tool 'think' should be present") 260 } 261 if _, ok := reg.Get("web_search"); !ok { 262 t.Error("gateway overlay 'web_search' should be present") 263 } 264 if _, ok := reg.Get("notify"); !ok { 265 t.Error("post overlay 'notify' should be present") 266 } 267 }