/ test / race / race_test.go
race_test.go
  1  // Package race_test reproduces the "concurrent map writes" panic from
  2  // https://codeberg.org/goern/forgejo-mcp/issues/76
  3  //
  4  // Run:  go test -race -count=10 -timeout 120s ./test/race/
  5  package race_test
  6  
  7  import (
  8  	"context"
  9  	"encoding/json"
 10  	"net/http"
 11  	"net/http/httptest"
 12  	"strings"
 13  	"sync"
 14  	"testing"
 15  
 16  	"codeberg.org/goern/forgejo-mcp/v2/operation"
 17  	flagPkg "codeberg.org/goern/forgejo-mcp/v2/pkg/flag"
 18  	"codeberg.org/goern/forgejo-mcp/v2/pkg/forgejo"
 19  
 20  	"github.com/mark3labs/mcp-go/mcp"
 21  	"github.com/mark3labs/mcp-go/server"
 22  )
 23  
 24  // fakeAPI is a package-level test server so the forgejo.Client() singleton
 25  // (initialized via sync.Once) always points to a live server.
 26  var (
 27  	fakeAPI  *httptest.Server
 28  	setupMu  sync.Once
 29  	mcpSrv   *server.MCPServer
 30  	allTools []string
 31  )
 32  
 33  func setup(t *testing.T) {
 34  	t.Helper()
 35  	setupMu.Do(func() {
 36  		fakeAPI = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 37  			w.Header().Set("Content-Type", "application/json")
 38  
 39  			// Return realistic responses for common endpoints.
 40  			switch {
 41  			case r.URL.Path == "/api/v1/user":
 42  				json.NewEncoder(w).Encode(map[string]interface{}{
 43  					"id": 1, "login": "testuser", "full_name": "Test User",
 44  					"email": "test@example.com", "avatar_url": "",
 45  				})
 46  			case r.URL.Path == "/api/v1/version":
 47  				json.NewEncoder(w).Encode(map[string]interface{}{
 48  					"version": "1.21.0",
 49  				})
 50  			case strings.Contains(r.URL.Path, "/issues/") && r.Method == http.MethodGet:
 51  				json.NewEncoder(w).Encode(map[string]interface{}{
 52  					"id": 1, "number": 1, "title": "test issue",
 53  					"state": "open", "body": "body",
 54  					"user": map[string]interface{}{"id": 1, "login": "testuser"},
 55  				})
 56  			case strings.Contains(r.URL.Path, "/pulls/") && r.Method == http.MethodGet:
 57  				json.NewEncoder(w).Encode(map[string]interface{}{
 58  					"id": 1, "number": 1, "title": "test pr",
 59  					"state": "open", "body": "body",
 60  					"user": map[string]interface{}{"id": 1, "login": "testuser"},
 61  				})
 62  			default:
 63  				if r.Method == http.MethodGet {
 64  					json.NewEncoder(w).Encode([]interface{}{})
 65  				} else {
 66  					json.NewEncoder(w).Encode(map[string]interface{}{
 67  						"id": 1, "number": 1, "title": "test",
 68  						"state": "open", "body": "test body",
 69  						"user": map[string]interface{}{"id": 1, "login": "testuser"},
 70  					})
 71  				}
 72  			}
 73  		}))
 74  
 75  		flagPkg.URL = fakeAPI.URL
 76  		flagPkg.Token = "fake-token-for-testing"
 77  		flagPkg.Debug = false
 78  
 79  		// Force the forgejo client singleton to initialize against our fake server.
 80  		_ = forgejo.Client()
 81  
 82  		mcpSrv = server.NewMCPServer("forgejo-mcp", "test", server.WithLogging())
 83  		operation.RegisterTool(mcpSrv)
 84  
 85  		tools := mcpSrv.ListTools()
 86  		allTools = make([]string, 0, len(tools))
 87  		for name := range tools {
 88  			allTools = append(allTools, name)
 89  		}
 90  	})
 91  }
 92  
 93  // minimalArgs returns the minimum required arguments for a tool.
 94  func minimalArgs(toolName string) map[string]any {
 95  	base := map[string]any{
 96  		"owner": "testowner",
 97  		"repo":  "testrepo",
 98  	}
 99  	switch toolName {
100  	case "get_issue_by_index", "update_issue", "issue_state_change", "add_issue_labels":
101  		base["index"] = float64(1)
102  	case "create_issue":
103  		base["title"] = "test issue"
104  	case "create_issue_comment":
105  		base["index"] = float64(1)
106  		base["body"] = "test comment"
107  	case "list_issue_comments":
108  		base["index"] = float64(1)
109  	case "get_issue_comment", "edit_issue_comment", "delete_issue_comment":
110  		base["comment_id"] = float64(1)
111  	case "get_pull_request", "merge_pull_request":
112  		base["index"] = float64(1)
113  	case "create_pull_request":
114  		base["title"] = "test pr"
115  		base["head"] = "feature"
116  		base["base"] = "main"
117  	case "get_pull_request_diff":
118  		base["index"] = float64(1)
119  	case "create_pull_request_review":
120  		base["index"] = float64(1)
121  		base["event"] = "COMMENT"
122  		base["body"] = "looks good"
123  	case "list_pull_request_reviews":
124  		base["index"] = float64(1)
125  	case "dismiss_pull_request_review":
126  		base["index"] = float64(1)
127  		base["review_id"] = float64(1)
128  	case "submit_pull_request_review":
129  		base["index"] = float64(1)
130  		base["review_id"] = float64(1)
131  		base["event"] = "COMMENT"
132  	case "get_file_content":
133  		base["filepath"] = "README.md"
134  	case "search_repos", "search_issues", "search_users":
135  		base["keyword"] = "test"
136  		delete(base, "owner")
137  		delete(base, "repo")
138  	case "search_org_teams":
139  		base["org"] = "testorg"
140  		delete(base, "owner")
141  		delete(base, "repo")
142  	case "get_notification_thread", "mark_notification_read":
143  		base["id"] = float64(1)
144  		delete(base, "owner")
145  		delete(base, "repo")
146  	case "mark_all_notifications_read", "check_notifications":
147  		delete(base, "owner")
148  		delete(base, "repo")
149  	case "list_repo_notifications", "mark_repo_notifications_read":
150  		// needs owner + repo (already set)
151  	case "get_my_user_info", "get_forgejo_version":
152  		delete(base, "owner")
153  		delete(base, "repo")
154  	case "create_branch":
155  		base["branch"] = "new-branch"
156  	case "fork_repo":
157  		// needs owner + repo (already set)
158  	case "list_repo_milestones", "list_repo_labels":
159  		// needs owner + repo (already set), defaults apply for page/limit/state
160  	}
161  	return base
162  }
163  
164  // TestConcurrentToolCalls invokes all registered MCP tool handlers
165  // concurrently from multiple goroutines, simulating the mcp-go worker pool.
166  func TestConcurrentToolCalls(t *testing.T) {
167  	setup(t)
168  	t.Logf("registered %d tools", len(allTools))
169  
170  	const concurrency = 20
171  	const iterations = 3
172  
173  	for iter := 0; iter < iterations; iter++ {
174  		var wg sync.WaitGroup
175  		for i := 0; i < concurrency; i++ {
176  			wg.Add(1)
177  			go func() {
178  				defer wg.Done()
179  				for _, toolName := range allTools {
180  					st := mcpSrv.GetTool(toolName)
181  					if st == nil {
182  						continue
183  					}
184  					req := mcp.CallToolRequest{
185  						Params: mcp.CallToolParams{
186  							Name:      toolName,
187  							Arguments: minimalArgs(toolName),
188  						},
189  					}
190  					// We don't care about errors — only panics / races.
191  					_, _ = st.Handler(context.Background(), req)
192  				}
193  			}()
194  		}
195  		wg.Wait()
196  	}
197  }
198  
199  // TestConcurrentSameToolRepeated hammers each tool individually from many goroutines.
200  func TestConcurrentSameToolRepeated(t *testing.T) {
201  	setup(t)
202  
203  	for _, name := range allTools {
204  		name := name
205  		t.Run(name, func(t *testing.T) {
206  			t.Parallel()
207  			st := mcpSrv.GetTool(name)
208  			if st == nil {
209  				t.Skip("tool not found")
210  			}
211  
212  			var wg sync.WaitGroup
213  			for i := 0; i < 50; i++ {
214  				wg.Add(1)
215  				go func() {
216  					defer wg.Done()
217  					for j := 0; j < 10; j++ {
218  						req := mcp.CallToolRequest{
219  							Params: mcp.CallToolParams{
220  								Name:      name,
221  								Arguments: minimalArgs(name),
222  							},
223  						}
224  						_, _ = st.Handler(context.Background(), req)
225  					}
226  				}()
227  			}
228  			wg.Wait()
229  		})
230  	}
231  }
232  
233  // TestConcurrentListAndRegister tests for races between listing tools
234  // and registering tools concurrently.
235  func TestConcurrentListAndRegister(t *testing.T) {
236  	srv := server.NewMCPServer("forgejo-mcp", "test", server.WithLogging())
237  
238  	var wg sync.WaitGroup
239  	wg.Add(1)
240  	go func() {
241  		defer wg.Done()
242  		operation.RegisterTool(srv)
243  	}()
244  
245  	for i := 0; i < 10; i++ {
246  		wg.Add(1)
247  		go func() {
248  			defer wg.Done()
249  			for j := 0; j < 100; j++ {
250  				_ = srv.ListTools()
251  			}
252  		}()
253  	}
254  	wg.Wait()
255  }
256  
// TestInitFlagParseBug records a known problem rather than testing any code.
// cmd.init() calls flag.Parse() on the global flag.CommandLine, which stops
// `go test ./cmd/` from working; this file lives in test/race/ to avoid that.
// The test is skipped instead of passing, so `go test -v` reports the open
// issue as SKIP with its explanation rather than as a misleading PASS.
func TestInitFlagParseBug(t *testing.T) {
	t.Skip("cmd.init() calls flag.Parse() on global CommandLine, " +
		"preventing 'go test ./cmd/' from running. " +
		"This test documents the issue (it lives in test/race/ to avoid it).")
}
264  
// TestNilResponseDeref documents a nil-pointer bug in tool handlers. Many
// handlers access resp.StatusCode BEFORE checking err != nil. When the
// forgejo client returns (nil, nil, err), this panics.
// Example: operation/user/user.go:44
//
//	user, resp, err := forgejo.Client().GetMyUserInfo()
//	forgejo.LogAPICall(ctx, "GET", "/user", duration, resp.StatusCode, err) // CRASH if resp==nil
//	if err != nil { ... }
//
// This is a separate bug but was discovered while investigating #76. The test
// is skipped instead of passing, so `go test -v` reports the open issue as
// SKIP with its explanation rather than as a misleading PASS.
func TestNilResponseDeref(t *testing.T) {
	t.Skip("Many tool handlers access resp.StatusCode before checking err. " +
		"If the API call returns (nil, nil, err), resp.StatusCode panics. " +
		"Fix: check err before accessing resp, or guard with 'if resp != nil'.")
}