race_test.go
1 // Package race_test reproduces the "concurrent map writes" panic from 2 // https://codeberg.org/goern/forgejo-mcp/issues/76 3 // 4 // Run: go test -race -count=10 -timeout 120s ./test/race/ 5 package race_test 6 7 import ( 8 "context" 9 "encoding/json" 10 "net/http" 11 "net/http/httptest" 12 "strings" 13 "sync" 14 "testing" 15 16 "codeberg.org/goern/forgejo-mcp/v2/operation" 17 flagPkg "codeberg.org/goern/forgejo-mcp/v2/pkg/flag" 18 "codeberg.org/goern/forgejo-mcp/v2/pkg/forgejo" 19 20 "github.com/mark3labs/mcp-go/mcp" 21 "github.com/mark3labs/mcp-go/server" 22 ) 23 24 // fakeAPI is a package-level test server so the forgejo.Client() singleton 25 // (initialized via sync.Once) always points to a live server. 26 var ( 27 fakeAPI *httptest.Server 28 setupMu sync.Once 29 mcpSrv *server.MCPServer 30 allTools []string 31 ) 32 33 func setup(t *testing.T) { 34 t.Helper() 35 setupMu.Do(func() { 36 fakeAPI = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 w.Header().Set("Content-Type", "application/json") 38 39 // Return realistic responses for common endpoints. 40 switch { 41 case r.URL.Path == "/api/v1/user": 42 json.NewEncoder(w).Encode(map[string]interface{}{ 43 "id": 1, "login": "testuser", "full_name": "Test User", 44 "email": "test@example.com", "avatar_url": "", 45 }) 46 case r.URL.Path == "/api/v1/version": 47 json.NewEncoder(w).Encode(map[string]interface{}{ 48 "version": "1.21.0", 49 }) 50 case strings.Contains(r.URL.Path, "/issues/") && r.Method == http.MethodGet: 51 json.NewEncoder(w).Encode(map[string]interface{}{ 52 "id": 1, "number": 1, "title": "test issue", 53 "state": "open", "body": "body", 54 "user": map[string]interface{}{"id": 1, "login": "testuser"}, 55 }) 56 case strings.Contains(r.URL.Path, "/pulls/") && r.Method == http.MethodGet: 57 json.NewEncoder(w).Encode(map[string]interface{}{ 58 "id": 1, "number": 1, "title": "test pr", 59 "state": "open", "body": "body", 60 "user": map[string]interface{}{"id": 1, "login": "testuser"}, 61 }) 62 default: 63 if r.Method == http.MethodGet { 64 json.NewEncoder(w).Encode([]interface{}{}) 65 } else { 66 json.NewEncoder(w).Encode(map[string]interface{}{ 67 "id": 1, "number": 1, "title": "test", 68 "state": "open", "body": "test body", 69 "user": map[string]interface{}{"id": 1, "login": "testuser"}, 70 }) 71 } 72 } 73 })) 74 75 flagPkg.URL = fakeAPI.URL 76 flagPkg.Token = "fake-token-for-testing" 77 flagPkg.Debug = false 78 79 // Force the forgejo client singleton to initialize against our fake server. 80 _ = forgejo.Client() 81 82 mcpSrv = server.NewMCPServer("forgejo-mcp", "test", server.WithLogging()) 83 operation.RegisterTool(mcpSrv) 84 85 tools := mcpSrv.ListTools() 86 allTools = make([]string, 0, len(tools)) 87 for name := range tools { 88 allTools = append(allTools, name) 89 } 90 }) 91 } 92 93 // minimalArgs returns the minimum required arguments for a tool. 94 func minimalArgs(toolName string) map[string]any { 95 base := map[string]any{ 96 "owner": "testowner", 97 "repo": "testrepo", 98 } 99 switch toolName { 100 case "get_issue_by_index", "update_issue", "issue_state_change", "add_issue_labels": 101 base["index"] = float64(1) 102 case "create_issue": 103 base["title"] = "test issue" 104 case "create_issue_comment": 105 base["index"] = float64(1) 106 base["body"] = "test comment" 107 case "list_issue_comments": 108 base["index"] = float64(1) 109 case "get_issue_comment", "edit_issue_comment", "delete_issue_comment": 110 base["comment_id"] = float64(1) 111 case "get_pull_request", "merge_pull_request": 112 base["index"] = float64(1) 113 case "create_pull_request": 114 base["title"] = "test pr" 115 base["head"] = "feature" 116 base["base"] = "main" 117 case "get_pull_request_diff": 118 base["index"] = float64(1) 119 case "create_pull_request_review": 120 base["index"] = float64(1) 121 base["event"] = "COMMENT" 122 base["body"] = "looks good" 123 case "list_pull_request_reviews": 124 base["index"] = float64(1) 125 case "dismiss_pull_request_review": 126 base["index"] = float64(1) 127 base["review_id"] = float64(1) 128 case "submit_pull_request_review": 129 base["index"] = float64(1) 130 base["review_id"] = float64(1) 131 base["event"] = "COMMENT" 132 case "get_file_content": 133 base["filepath"] = "README.md" 134 case "search_repos", "search_issues", "search_users": 135 base["keyword"] = "test" 136 delete(base, "owner") 137 delete(base, "repo") 138 case "search_org_teams": 139 base["org"] = "testorg" 140 delete(base, "owner") 141 delete(base, "repo") 142 case "get_notification_thread", "mark_notification_read": 143 base["id"] = float64(1) 144 delete(base, "owner") 145 delete(base, "repo") 146 case "mark_all_notifications_read", "check_notifications": 147 delete(base, "owner") 148 delete(base, "repo") 149 case "list_repo_notifications", "mark_repo_notifications_read": 150 // needs owner + repo (already set) 151 case "get_my_user_info", "get_forgejo_version": 152 delete(base, "owner") 153 delete(base, "repo") 154 case "create_branch": 155 base["branch"] = "new-branch" 156 case "fork_repo": 157 // needs owner + repo (already set) 158 case "list_repo_milestones", "list_repo_labels": 159 // needs owner + repo (already set), defaults apply for page/limit/state 160 } 161 return base 162 } 163 164 // TestConcurrentToolCalls invokes all registered MCP tool handlers 165 // concurrently from multiple goroutines, simulating the mcp-go worker pool. 166 func TestConcurrentToolCalls(t *testing.T) { 167 setup(t) 168 t.Logf("registered %d tools", len(allTools)) 169 170 const concurrency = 20 171 const iterations = 3 172 173 for iter := 0; iter < iterations; iter++ { 174 var wg sync.WaitGroup 175 for i := 0; i < concurrency; i++ { 176 wg.Add(1) 177 go func() { 178 defer wg.Done() 179 for _, toolName := range allTools { 180 st := mcpSrv.GetTool(toolName) 181 if st == nil { 182 continue 183 } 184 req := mcp.CallToolRequest{ 185 Params: mcp.CallToolParams{ 186 Name: toolName, 187 Arguments: minimalArgs(toolName), 188 }, 189 } 190 // We don't care about errors — only panics / races. 191 _, _ = st.Handler(context.Background(), req) 192 } 193 }() 194 } 195 wg.Wait() 196 } 197 } 198 199 // TestConcurrentSameToolRepeated hammers each tool individually from many goroutines. 200 func TestConcurrentSameToolRepeated(t *testing.T) { 201 setup(t) 202 203 for _, name := range allTools { 204 name := name 205 t.Run(name, func(t *testing.T) { 206 t.Parallel() 207 st := mcpSrv.GetTool(name) 208 if st == nil { 209 t.Skip("tool not found") 210 } 211 212 var wg sync.WaitGroup 213 for i := 0; i < 50; i++ { 214 wg.Add(1) 215 go func() { 216 defer wg.Done() 217 for j := 0; j < 10; j++ { 218 req := mcp.CallToolRequest{ 219 Params: mcp.CallToolParams{ 220 Name: name, 221 Arguments: minimalArgs(name), 222 }, 223 } 224 _, _ = st.Handler(context.Background(), req) 225 } 226 }() 227 } 228 wg.Wait() 229 }) 230 } 231 } 232 233 // TestConcurrentListAndRegister tests for races between listing tools 234 // and registering tools concurrently. 235 func TestConcurrentListAndRegister(t *testing.T) { 236 srv := server.NewMCPServer("forgejo-mcp", "test", server.WithLogging()) 237 238 var wg sync.WaitGroup 239 wg.Add(1) 240 go func() { 241 defer wg.Done() 242 operation.RegisterTool(srv) 243 }() 244 245 for i := 0; i < 10; i++ { 246 wg.Add(1) 247 go func() { 248 defer wg.Done() 249 for j := 0; j < 100; j++ { 250 _ = srv.ListTools() 251 } 252 }() 253 } 254 wg.Wait() 255 } 256 257 // TestInitFlagParseBug documents that cmd.init() calls flag.Parse() on the 258 // global flag.CommandLine, preventing `go test ./cmd/` from working. 259 func TestInitFlagParseBug(t *testing.T) { 260 t.Log("cmd.init() calls flag.Parse() on global CommandLine, " + 261 "preventing 'go test ./cmd/' from running. " + 262 "This test documents the issue (it lives in test/race/ to avoid it).") 263 } 264 265 // TestNilResponseDeref documents a nil-pointer bug in tool handlers. 266 // Many handlers access resp.StatusCode BEFORE checking err != nil. 267 // When the forgejo client returns (nil, nil, err), this panics. 268 // Example: operation/user/user.go:44 269 // 270 // user, resp, err := forgejo.Client().GetMyUserInfo() 271 // forgejo.LogAPICall(ctx, "GET", "/user", duration, resp.StatusCode, err) // CRASH if resp==nil 272 // if err != nil { ... } 273 // 274 // This is a separate bug but was discovered while investigating #76. 275 func TestNilResponseDeref(t *testing.T) { 276 t.Log("Many tool handlers access resp.StatusCode before checking err. " + 277 "If the API call returns (nil, nil, err), resp.StatusCode panics. " + 278 "Fix: check err before accessing resp, or guard with 'if resp != nil'.") 279 }