/ internal / tools / register_test.go
register_test.go
  1  package tools
  2  
  3  import (
  4  	"encoding/json"
  5  	"net/http"
  6  	"net/http/httptest"
  7  	"testing"
  8  
  9  	"github.com/Kocoro-lab/ShanClaw/internal/agent"
 10  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 11  	"github.com/Kocoro-lab/ShanClaw/internal/mcp"
 12  	mcpproto "github.com/mark3labs/mcp-go/mcp"
 13  )
 14  
 15  func TestRegisterAll_WithServerTools(t *testing.T) {
 16  	serverTools := []client.ServerToolSchema{
 17  		{Name: "web_search", Description: "Search the web"},
 18  		{Name: "getStockBars", Description: "Get stock price bars"},
 19  	}
 20  
 21  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 22  		w.Header().Set("Content-Type", "application/json")
 23  		json.NewEncoder(w).Encode(serverTools)
 24  	}))
 25  	defer server.Close()
 26  
 27  	gw := client.NewGatewayClient(server.URL, "")
 28  	reg, _, _, cleanup, err := RegisterAll(gw, nil)
 29  	defer cleanup()
 30  	if err != nil {
 31  		t.Fatalf("unexpected error: %v", err)
 32  	}
 33  
 34  	// Check local tools are registered
 35  	for _, name := range []string{"use_skill", "file_read", "file_write", "file_edit", "glob", "grep", "bash", "think", "directory_list", "http", "system_info", "clipboard", "notify", "process", "applescript", "accessibility", "ghostty", "browser", "screenshot", "computer", "wait_for", "schedule_create", "schedule_list", "schedule_update", "schedule_remove"} {
 36  		if _, ok := reg.Get(name); !ok {
 37  			t.Errorf("local tool %q not registered", name)
 38  		}
 39  	}
 40  
 41  	// Check server tools are registered
 42  	for _, name := range []string{"web_search", "getStockBars"} {
 43  		if _, ok := reg.Get(name); !ok {
 44  			t.Errorf("server tool %q not registered", name)
 45  		}
 46  	}
 47  
 48  	// Total: 26 local + 2 server = 28
 49  	schemas := reg.Schemas()
 50  	if len(schemas) != 28 {
 51  		t.Errorf("expected 28 tools, got %d", len(schemas))
 52  	}
 53  }
 54  
 55  func TestRegisterAll_ServerUnavailable(t *testing.T) {
 56  	// Point to a closed server
 57  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
 58  	server.Close()
 59  
 60  	gw := client.NewGatewayClient(server.URL, "")
 61  	reg, _, _, cleanup, err := RegisterAll(gw, nil)
 62  	defer cleanup()
 63  	if err == nil {
 64  		t.Error("expected warning error when server is unavailable")
 65  	}
 66  
 67  	// Local tools should still be registered
 68  	for _, name := range []string{"file_read", "bash", "glob"} {
 69  		if _, ok := reg.Get(name); !ok {
 70  			t.Errorf("local tool %q should still be registered", name)
 71  		}
 72  	}
 73  
 74  	schemas := reg.Schemas()
 75  	if len(schemas) != 26 {
 76  		t.Errorf("expected 26 local tools, got %d", len(schemas))
 77  	}
 78  }
 79  
 80  func TestRegisterAll_LocalPriority(t *testing.T) {
 81  	// Server returns a tool named "bash" — should be skipped
 82  	serverTools := []client.ServerToolSchema{
 83  		{Name: "bash", Description: "Server bash (should be skipped)"},
 84  		{Name: "web_search", Description: "Search the web"},
 85  	}
 86  
 87  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 88  		w.Header().Set("Content-Type", "application/json")
 89  		json.NewEncoder(w).Encode(serverTools)
 90  	}))
 91  	defer server.Close()
 92  
 93  	gw := client.NewGatewayClient(server.URL, "")
 94  	reg, _, _, cleanup, err := RegisterAll(gw, nil)
 95  	defer cleanup()
 96  	if err != nil {
 97  		t.Fatalf("unexpected error: %v", err)
 98  	}
 99  
100  	// "bash" should be the local BashTool, not the server one
101  	tool, ok := reg.Get("bash")
102  	if !ok {
103  		t.Fatal("bash tool not found")
104  	}
105  	if _, isServer := tool.(*ServerTool); isServer {
106  		t.Error("bash should be local tool, not server tool")
107  	}
108  
109  	// web_search should be server tool
110  	tool, ok = reg.Get("web_search")
111  	if !ok {
112  		t.Fatal("web_search tool not found")
113  	}
114  	if _, isServer := tool.(*ServerTool); !isServer {
115  		t.Error("web_search should be a server tool")
116  	}
117  
118  	// 26 local + 1 server (bash skipped) = 27
119  	schemas := reg.Schemas()
120  	if len(schemas) != 27 {
121  		t.Errorf("expected 27 tools, got %d", len(schemas))
122  	}
123  }
124  
125  func TestRegisterServerTools_AllowlistFiltering(t *testing.T) {
126  	serverTools := []client.ServerToolSchema{
127  		{Name: "web_search", Description: "Search the web"},
128  		{Name: "python_executor", Description: "Run Python in sandbox"},
129  		{Name: "calculator", Description: "Basic calculator"},
130  		{Name: "getStockBars", Description: "Get stock price bars"},
131  		{Name: "session_file_write", Description: "Write session file"},
132  		{Name: "some_future_tool", Description: "Unknown new tool"},
133  	}
134  
135  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136  		w.Header().Set("Content-Type", "application/json")
137  		json.NewEncoder(w).Encode(serverTools)
138  	}))
139  	defer server.Close()
140  
141  	gw := client.NewGatewayClient(server.URL, "")
142  	reg, _, _, cleanup, err := RegisterAll(gw, nil)
143  	defer cleanup()
144  	if err != nil {
145  		t.Fatalf("unexpected error: %v", err)
146  	}
147  
148  	// Allowlisted tools should be registered
149  	for _, name := range []string{"web_search", "getStockBars"} {
150  		if _, ok := reg.Get(name); !ok {
151  			t.Errorf("allowlisted tool %q should be registered", name)
152  		}
153  	}
154  
155  	// Non-allowlisted tools should be filtered out
156  	for _, name := range []string{"python_executor", "calculator", "session_file_write", "some_future_tool"} {
157  		if _, ok := reg.Get(name); ok {
158  			t.Errorf("non-allowlisted tool %q should NOT be registered", name)
159  		}
160  	}
161  }
162  
163  func TestExtractGatewayTools(t *testing.T) {
164  	reg := agent.NewToolRegistry()
165  	gw := client.NewGatewayClient("http://test", "key")
166  	reg.Register(NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw))
167  	reg.Register(&ThinkTool{})
168  	tools := ExtractGatewayTools(reg)
169  	if len(tools) != 1 {
170  		t.Fatalf("expected 1 gateway tool, got %d", len(tools))
171  	}
172  	if tools[0].Info().Name != "web_search" {
173  		t.Errorf("expected web_search, got %s", tools[0].Info().Name)
174  	}
175  }
176  
177  func TestExtractPostOverlays(t *testing.T) {
178  	baseline := agent.NewToolRegistry()
179  	baseline.Register(&ThinkTool{})
180  
181  	full := baseline.Clone()
182  	gw := client.NewGatewayClient("http://test", "key")
183  	full.Register(NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw))
184  	mgr := mcp.NewClientManager()
185  	full.Register(NewMCPTool("playwright", mcpproto.Tool{Name: "browser_navigate"}, mgr))
186  	full.Register(&NotifyTool{}) // a local overlay
187  
188  	overlays := ExtractPostOverlays(full, baseline)
189  	if len(overlays) != 1 {
190  		t.Fatalf("expected 1 overlay, got %d", len(overlays))
191  	}
192  	if overlays[0].Info().Name != "notify" {
193  		t.Errorf("expected notify, got %s", overlays[0].Info().Name)
194  	}
195  }
196  
197  func TestRebuildRegistryForHealth_PlaywrightHealthy(t *testing.T) {
198  	baseline := agent.NewToolRegistry()
199  	baseline.Register(&ThinkTool{})
200  	baseline.Register(&BrowserTool{})
201  
202  	healthStates := map[string]mcp.ServerHealth{
203  		"playwright": {State: mcp.StateHealthy},
204  	}
205  
206  	mgr := mcp.NewClientManager()
207  	mgr.SeedToolCache("playwright", []mcp.RemoteTool{
208  		{ServerName: "playwright", Tool: mcpproto.Tool{Name: "browser_navigate"}},
209  	})
210  
211  	reg := RebuildRegistryForHealth(baseline, nil, nil, healthStates, mgr, nil)
212  	if _, ok := reg.Get("browser"); ok {
213  		t.Error("legacy browser should be removed when Playwright is healthy")
214  	}
215  	if _, ok := reg.Get("browser_navigate"); !ok {
216  		t.Error("browser_navigate should be registered from healthy Playwright")
217  	}
218  }
219  
220  func TestRebuildRegistryForHealth_PlaywrightDisconnected(t *testing.T) {
221  	baseline := agent.NewToolRegistry()
222  	baseline.Register(&ThinkTool{})
223  	baseline.Register(&BrowserTool{})
224  
225  	healthStates := map[string]mcp.ServerHealth{
226  		"playwright": {State: mcp.StateDisconnected},
227  	}
228  
229  	mgr := mcp.NewClientManager()
230  	mgr.SeedToolCache("playwright", []mcp.RemoteTool{
231  		{ServerName: "playwright", Tool: mcpproto.Tool{Name: "browser_navigate"}},
232  	})
233  
234  	reg := RebuildRegistryForHealth(baseline, nil, nil, healthStates, mgr, nil)
235  	// Disconnected Playwright tools are included from cache for on-demand reconnect.
236  	if _, ok := reg.Get("browser_navigate"); !ok {
237  		t.Error("browser_navigate should be present from cache even when disconnected")
238  	}
239  	// Legacy browser is removed when Playwright tools are present (even disconnected).
240  	if _, ok := reg.Get("browser"); ok {
241  		t.Error("legacy browser should be removed when Playwright tools are present")
242  	}
243  }
244  
245  func TestRebuildRegistryForHealth_GatewayAndPostOverlays(t *testing.T) {
246  	baseline := agent.NewToolRegistry()
247  	baseline.Register(&ThinkTool{})
248  
249  	gw := client.NewGatewayClient("http://test", "key")
250  	gatewayOverlay := []agent.Tool{
251  		NewServerTool(client.ServerToolSchema{Name: "web_search", Description: "search"}, gw),
252  	}
253  	postOverlays := []agent.Tool{
254  		&NotifyTool{},
255  	}
256  
257  	reg := RebuildRegistryForHealth(baseline, gatewayOverlay, postOverlays, nil, nil, nil)
258  	if _, ok := reg.Get("think"); !ok {
259  		t.Error("baseline tool 'think' should be present")
260  	}
261  	if _, ok := reg.Get("web_search"); !ok {
262  		t.Error("gateway overlay 'web_search' should be present")
263  	}
264  	if _, ok := reg.Get("notify"); !ok {
265  		t.Error("post overlay 'notify' should be present")
266  	}
267  }