/ internal / tools / mcp_tool_test.go
mcp_tool_test.go
  1  package tools
  2  
  3  import (
  4  	"context"
  5  	"io"
  6  	"sync/atomic"
  7  	"testing"
  8  	"time"
  9  
 10  	"github.com/Kocoro-lab/ShanClaw/internal/agent"
 11  	"github.com/Kocoro-lab/ShanClaw/internal/mcp"
 12  	mcpgo "github.com/mark3labs/mcp-go/mcp"
 13  )
 14  
 15  // --- Test 1: Disconnected → first call fails → on-demand reconnect → retry succeeds ---
 16  
 17  func TestMCPTool_Run_ReconnectOnDisconnected(t *testing.T) {
 18  	ctx, cancel := context.WithCancel(context.Background())
 19  	defer cancel()
 20  
 21  	// Set up a manager with config but NO client initially.
 22  	mgr := mcp.NewClientManager()
 23  	mgr.SeedConfig("playwright", mcp.MCPServerConfig{Command: "dummy"})
 24  
 25  	// Set up supervisor and start it — initial probe will fail (no client)
 26  	// → server enters StateDisconnected.
 27  	sup := mcp.NewSupervisor(mgr)
 28  	sup.Start(ctx)
 29  	defer sup.Stop()
 30  
 31  	// Let the initial probe cycle run and mark the server disconnected.
 32  	time.Sleep(100 * time.Millisecond)
 33  
 34  	h := sup.HealthFor("playwright")
 35  	if h.State != mcp.StateDisconnected {
 36  		t.Fatalf("expected disconnected after initial probe, got %v", h.State)
 37  	}
 38  
 39  	// Now inject a controllable client: CallTool fails once (io.EOF), then
 40  	// succeeds. ListTools always succeeds (so the transport probe works).
 41  	fake := &controllableCallToolClient{}
 42  	mgr.SeedClient("playwright", fake)
 43  
 44  	// Create MCPTool with supervisor for on-demand reconnect.
 45  	tool := mcpgo.Tool{Name: "browser_navigate"}
 46  	mt := NewMCPTool("playwright", tool, mgr)
 47  	mt.SetSupervisor(sup)
 48  
 49  	result, err := mt.Run(ctx, `{"url":"https://example.com"}`)
 50  	if err != nil {
 51  		t.Fatalf("unexpected error: %v", err)
 52  	}
 53  	if result.IsError {
 54  		t.Fatalf("expected success, got error result: %s", result.Content)
 55  	}
 56  
 57  	// Verify: first call failed (EOF), ProbeNow reconnected, second call succeeded.
 58  	calls := int(fake.callToolCount.Load())
 59  	if calls != 2 {
 60  		t.Errorf("expected 2 CallTool calls (fail + retry), got %d", calls)
 61  	}
 62  }
 63  
 64  // --- Test 2: No cache → disconnected server tools NOT injected ---
 65  
 66  func TestRebuildRegistryForHealth_DisconnectedNoCache(t *testing.T) {
 67  	baseline := agent.NewToolRegistry()
 68  	baseline.Register(&ThinkTool{})
 69  	baseline.Register(&BrowserTool{})
 70  
 71  	healthStates := map[string]mcp.ServerHealth{
 72  		"playwright": {State: mcp.StateDisconnected},
 73  	}
 74  
 75  	// Manager with no cached tools for the disconnected server.
 76  	mgr := mcp.NewClientManager()
 77  	// Deliberately NOT calling mgr.SeedToolCache("playwright", ...)
 78  
 79  	reg := RebuildRegistryForHealth(baseline, nil, nil, healthStates, mgr, nil)
 80  	if _, ok := reg.Get("browser_navigate"); ok {
 81  		t.Error("browser_navigate should NOT be in registry when cache is empty")
 82  	}
 83  	// Legacy browser should remain when no Playwright tools are present.
 84  	if _, ok := reg.Get("browser"); !ok {
 85  		t.Error("legacy browser should remain when no Playwright tools are present")
 86  	}
 87  }
 88  
 89  // --- Test 3: No supervisor → no reconnect, error returned directly ---
 90  
 91  func TestMCPTool_Run_NoSupervisor_NoReconnect(t *testing.T) {
 92  	mgr := mcp.NewClientManager()
 93  	// No client → CallTool will fail.
 94  
 95  	tool := mcpgo.Tool{Name: "browser_navigate"}
 96  	mt := NewMCPTool("playwright", tool, mgr)
 97  	// Deliberately NOT calling mt.SetSupervisor(...)
 98  
 99  	result, err := mt.Run(context.Background(), `{"url":"https://example.com"}`)
100  	if err != nil {
101  		t.Fatalf("unexpected error: %v", err)
102  	}
103  	if !result.IsError {
104  		t.Error("expected error result when server not connected and no supervisor")
105  	}
106  }
107  
108  func TestMCPTool_Run_PreflightsDedicatedChromeWhenAlreadyConnected(t *testing.T) {
109  	mgr := mcp.NewClientManager()
110  	mgr.SeedConfig("playwright", mcp.MCPServerConfig{
111  		Command: "dummy",
112  		Args:    []string{"--cdp-endpoint", "http://127.0.0.1:9223"},
113  	})
114  	mgr.SeedClient("playwright", &successCallToolClient{})
115  
116  	origEnsure := ensureChromeDebugPort
117  	origShouldPreflight := shouldPreflightChromeForTool
118  	t.Cleanup(func() {
119  		ensureChromeDebugPort = origEnsure
120  		shouldPreflightChromeForTool = origShouldPreflight
121  	})
122  
123  	var ensureCalls atomic.Int32
124  	ensureChromeDebugPort = func(port int) error {
125  		ensureCalls.Add(1)
126  		if port != 9223 {
127  			t.Fatalf("expected dedicated port 9223, got %d", port)
128  		}
129  		return nil
130  	}
131  	shouldPreflightChromeForTool = func(port int) bool {
132  		return port == 9223
133  	}
134  
135  	mt := NewMCPTool("playwright", mcpgo.Tool{Name: "browser_navigate"}, mgr)
136  	result, err := mt.Run(context.Background(), `{"url":"https://example.com"}`)
137  	if err != nil {
138  		t.Fatalf("unexpected error: %v", err)
139  	}
140  	if result.IsError {
141  		t.Fatalf("expected success, got error result: %s", result.Content)
142  	}
143  	if got := ensureCalls.Load(); got != 1 {
144  		t.Fatalf("expected 1 dedicated Chrome preflight, got %d", got)
145  	}
146  }
147  
148  // --- Fake MCP client with controllable CallTool ---
149  
150  // controllableCallToolClient is a minimal MCPClient where CallTool fails on the
151  // first call (io.EOF) and succeeds on subsequent calls. ListTools always succeeds
152  // so the supervisor's transport probe can mark the server healthy.
153  type controllableCallToolClient struct {
154  	callToolCount atomic.Int32
155  }
156  
157  type successCallToolClient struct{}
158  
159  func (c *controllableCallToolClient) Initialize(context.Context, mcpgo.InitializeRequest) (*mcpgo.InitializeResult, error) {
160  	return &mcpgo.InitializeResult{}, nil
161  }
162  func (c *successCallToolClient) Initialize(context.Context, mcpgo.InitializeRequest) (*mcpgo.InitializeResult, error) {
163  	return &mcpgo.InitializeResult{}, nil
164  }
165  func (c *controllableCallToolClient) Ping(context.Context) error { return nil }
166  func (c *successCallToolClient) Ping(context.Context) error      { return nil }
167  func (c *controllableCallToolClient) ListResourcesByPage(context.Context, mcpgo.ListResourcesRequest) (*mcpgo.ListResourcesResult, error) {
168  	return &mcpgo.ListResourcesResult{}, nil
169  }
170  func (c *successCallToolClient) ListResourcesByPage(context.Context, mcpgo.ListResourcesRequest) (*mcpgo.ListResourcesResult, error) {
171  	return &mcpgo.ListResourcesResult{}, nil
172  }
173  func (c *controllableCallToolClient) ListResources(context.Context, mcpgo.ListResourcesRequest) (*mcpgo.ListResourcesResult, error) {
174  	return &mcpgo.ListResourcesResult{}, nil
175  }
176  func (c *successCallToolClient) ListResources(context.Context, mcpgo.ListResourcesRequest) (*mcpgo.ListResourcesResult, error) {
177  	return &mcpgo.ListResourcesResult{}, nil
178  }
179  func (c *controllableCallToolClient) ListResourceTemplatesByPage(context.Context, mcpgo.ListResourceTemplatesRequest) (*mcpgo.ListResourceTemplatesResult, error) {
180  	return &mcpgo.ListResourceTemplatesResult{}, nil
181  }
182  func (c *successCallToolClient) ListResourceTemplatesByPage(context.Context, mcpgo.ListResourceTemplatesRequest) (*mcpgo.ListResourceTemplatesResult, error) {
183  	return &mcpgo.ListResourceTemplatesResult{}, nil
184  }
185  func (c *controllableCallToolClient) ListResourceTemplates(context.Context, mcpgo.ListResourceTemplatesRequest) (*mcpgo.ListResourceTemplatesResult, error) {
186  	return &mcpgo.ListResourceTemplatesResult{}, nil
187  }
188  func (c *successCallToolClient) ListResourceTemplates(context.Context, mcpgo.ListResourceTemplatesRequest) (*mcpgo.ListResourceTemplatesResult, error) {
189  	return &mcpgo.ListResourceTemplatesResult{}, nil
190  }
191  func (c *controllableCallToolClient) ReadResource(context.Context, mcpgo.ReadResourceRequest) (*mcpgo.ReadResourceResult, error) {
192  	return &mcpgo.ReadResourceResult{}, nil
193  }
194  func (c *successCallToolClient) ReadResource(context.Context, mcpgo.ReadResourceRequest) (*mcpgo.ReadResourceResult, error) {
195  	return &mcpgo.ReadResourceResult{}, nil
196  }
197  func (c *controllableCallToolClient) Subscribe(context.Context, mcpgo.SubscribeRequest) error {
198  	return nil
199  }
200  func (c *successCallToolClient) Subscribe(context.Context, mcpgo.SubscribeRequest) error {
201  	return nil
202  }
203  func (c *controllableCallToolClient) Unsubscribe(context.Context, mcpgo.UnsubscribeRequest) error {
204  	return nil
205  }
206  func (c *successCallToolClient) Unsubscribe(context.Context, mcpgo.UnsubscribeRequest) error {
207  	return nil
208  }
209  func (c *controllableCallToolClient) ListPromptsByPage(context.Context, mcpgo.ListPromptsRequest) (*mcpgo.ListPromptsResult, error) {
210  	return &mcpgo.ListPromptsResult{}, nil
211  }
212  func (c *successCallToolClient) ListPromptsByPage(context.Context, mcpgo.ListPromptsRequest) (*mcpgo.ListPromptsResult, error) {
213  	return &mcpgo.ListPromptsResult{}, nil
214  }
215  func (c *controllableCallToolClient) ListPrompts(context.Context, mcpgo.ListPromptsRequest) (*mcpgo.ListPromptsResult, error) {
216  	return &mcpgo.ListPromptsResult{}, nil
217  }
218  func (c *successCallToolClient) ListPrompts(context.Context, mcpgo.ListPromptsRequest) (*mcpgo.ListPromptsResult, error) {
219  	return &mcpgo.ListPromptsResult{}, nil
220  }
221  func (c *controllableCallToolClient) GetPrompt(context.Context, mcpgo.GetPromptRequest) (*mcpgo.GetPromptResult, error) {
222  	return &mcpgo.GetPromptResult{}, nil
223  }
224  func (c *successCallToolClient) GetPrompt(context.Context, mcpgo.GetPromptRequest) (*mcpgo.GetPromptResult, error) {
225  	return &mcpgo.GetPromptResult{}, nil
226  }
227  func (c *controllableCallToolClient) ListToolsByPage(context.Context, mcpgo.ListToolsRequest) (*mcpgo.ListToolsResult, error) {
228  	return &mcpgo.ListToolsResult{}, nil
229  }
230  func (c *successCallToolClient) ListToolsByPage(context.Context, mcpgo.ListToolsRequest) (*mcpgo.ListToolsResult, error) {
231  	return &mcpgo.ListToolsResult{}, nil
232  }
233  func (c *controllableCallToolClient) ListTools(context.Context, mcpgo.ListToolsRequest) (*mcpgo.ListToolsResult, error) {
234  	return &mcpgo.ListToolsResult{}, nil
235  }
236  func (c *successCallToolClient) ListTools(context.Context, mcpgo.ListToolsRequest) (*mcpgo.ListToolsResult, error) {
237  	return &mcpgo.ListToolsResult{}, nil
238  }
239  func (c *controllableCallToolClient) CallTool(_ context.Context, _ mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
240  	n := c.callToolCount.Add(1)
241  	if n == 1 {
242  		return nil, io.EOF // transport error → triggers reconnect path
243  	}
244  	return mcpgo.NewToolResultText("ok"), nil
245  }
246  func (c *successCallToolClient) CallTool(_ context.Context, _ mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
247  	return mcpgo.NewToolResultText("ok"), nil
248  }
249  func (c *controllableCallToolClient) SetLevel(context.Context, mcpgo.SetLevelRequest) error {
250  	return nil
251  }
252  func (c *successCallToolClient) SetLevel(context.Context, mcpgo.SetLevelRequest) error {
253  	return nil
254  }
255  func (c *controllableCallToolClient) Complete(context.Context, mcpgo.CompleteRequest) (*mcpgo.CompleteResult, error) {
256  	return &mcpgo.CompleteResult{}, nil
257  }
258  func (c *successCallToolClient) Complete(context.Context, mcpgo.CompleteRequest) (*mcpgo.CompleteResult, error) {
259  	return &mcpgo.CompleteResult{}, nil
260  }
261  func (c *controllableCallToolClient) Close() error { return nil }
262  func (c *successCallToolClient) Close() error      { return nil }
263  func (c *controllableCallToolClient) OnNotification(func(mcpgo.JSONRPCNotification)) {
264  }
265  func (c *successCallToolClient) OnNotification(func(mcpgo.JSONRPCNotification)) {}