pull.go
1 package pull 2 3 import ( 4 "context" 5 "fmt" 6 "strconv" 7 8 "codeberg.org/goern/forgejo-mcp/v2/operation/params" 9 "codeberg.org/goern/forgejo-mcp/v2/pkg/forgejo" 10 "codeberg.org/goern/forgejo-mcp/v2/pkg/log" 11 "codeberg.org/goern/forgejo-mcp/v2/pkg/ptr" 12 "codeberg.org/goern/forgejo-mcp/v2/pkg/to" 13 14 forgejo_sdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v3" 15 "github.com/mark3labs/mcp-go/mcp" 16 "github.com/mark3labs/mcp-go/server" 17 ) 18 19 const ( 20 GetPullRequestByIndexToolName = "get_pull_request_by_index" 21 ListRepoPullRequestsToolName = "list_repo_pull_requests" 22 CreatePullRequestToolName = "create_pull_request" 23 UpdatePullRequestToolName = "update_pull_request" 24 ListPullReviewsToolName = "list_pull_reviews" 25 GetPullReviewToolName = "get_pull_review" 26 ListPullReviewCommentsToolName = "list_pull_review_comments" 27 MergePullRequestToolName = "merge_pull_request" 28 ListPullRequestFilesToolName = "list_pull_request_files" 29 GetPullRequestDiffToolName = "get_pull_request_diff" 30 ) 31 32 var ( 33 GetPullRequestByIndexTool = mcp.NewTool( 34 GetPullRequestByIndexToolName, 35 mcp.WithDescription("Get pull request by index"), 36 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 37 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 38 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 39 ) 40 41 ListRepoPullRequestsTool = mcp.NewTool( 42 ListRepoPullRequestsToolName, 43 mcp.WithDescription("List repo pull requests"), 44 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 45 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 46 mcp.WithString("state", mcp.Description("State (open|closed|all)"), mcp.DefaultString("open")), 47 mcp.WithString("sort", mcp.Description("Sort (oldest|recentupdate|leastupdate|mostcomment)")), 48 mcp.WithString("milestone", mcp.Description(params.Milestone)), 49 mcp.WithString("labels", mcp.Description(params.Labels)), 50 mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)), 51 mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(20)), 52 ) 53 54 CreatePullRequestTool = mcp.NewTool( 55 CreatePullRequestToolName, 56 mcp.WithDescription("Create pull request"), 57 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 58 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 59 mcp.WithString("head", mcp.Required(), mcp.Description(params.Head)), 60 mcp.WithString("base", mcp.Required(), mcp.Description(params.Base)), 61 mcp.WithString("title", mcp.Required(), mcp.Description(params.Title)), 62 mcp.WithString("body", mcp.Description(params.Body)), 63 ) 64 65 UpdatePullRequestTool = mcp.NewTool( 66 UpdatePullRequestToolName, 67 mcp.WithDescription("Update pull request"), 68 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 69 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 70 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 71 mcp.WithString("title", mcp.Description(params.Title)), 72 mcp.WithString("body", mcp.Description(params.Body)), 73 mcp.WithString("base", mcp.Description(params.Base)), 74 mcp.WithString("assignee", mcp.Description("Assignee username")), 75 mcp.WithString("milestone", mcp.Description(params.Milestone)), 76 ) 77 78 ListPullReviewsTool = mcp.NewTool( 79 ListPullReviewsToolName, 80 mcp.WithDescription("List reviews for a pull request"), 81 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 82 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 83 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 84 mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)), 85 mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(20)), 86 ) 87 88 GetPullReviewTool = mcp.NewTool( 89 GetPullReviewToolName, 90 mcp.WithDescription("Get a specific pull request review"), 91 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 92 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 93 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 94 mcp.WithNumber("id", mcp.Required(), mcp.Description("Review ID")), 95 ) 96 97 ListPullReviewCommentsTool = mcp.NewTool( 98 ListPullReviewCommentsToolName, 99 mcp.WithDescription("List comments on a pull request review"), 100 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 101 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 102 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 103 mcp.WithNumber("id", mcp.Required(), mcp.Description("Review ID")), 104 ) 105 106 ListPullRequestFilesTool = mcp.NewTool( 107 ListPullRequestFilesToolName, 108 mcp.WithDescription("List changed files in a pull request"), 109 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 110 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 111 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 112 mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)), 113 mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(50)), 114 ) 115 116 GetPullRequestDiffTool = mcp.NewTool( 117 GetPullRequestDiffToolName, 118 mcp.WithDescription("Get the diff of a pull request"), 119 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 120 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 121 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 122 ) 123 124 MergePullRequestTool = mcp.NewTool( 125 MergePullRequestToolName, 126 mcp.WithDescription("Merge a pull request"), 127 mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)), 128 mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)), 129 mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)), 130 mcp.WithString("style", mcp.Required(), mcp.Description("Merge style (merge, rebase, rebase-merge, squash)")), 131 mcp.WithString("title", mcp.Description("Merge commit title")), 132 mcp.WithString("message", mcp.Description("Merge commit message")), 133 mcp.WithBoolean("delete_branch_after_merge", mcp.Description("Delete head branch after merge")), 134 mcp.WithBoolean("force_merge", mcp.Description("Force merge even if checks have not passed")), 135 mcp.WithBoolean("merge_when_checks_succeed", mcp.Description("Schedule merge for when all checks succeed")), 136 ) 137 ) 138 139 func RegisterTool(s *server.MCPServer) { 140 s.AddTool(GetPullRequestByIndexTool, GetPullRequestByIndexFn) 141 s.AddTool(ListRepoPullRequestsTool, ListRepoPullRequestsFn) 142 s.AddTool(CreatePullRequestTool, CreatePullRequestFn) 143 s.AddTool(UpdatePullRequestTool, UpdatePullRequestFn) 144 s.AddTool(ListPullReviewsTool, ListPullReviewsFn) 145 s.AddTool(GetPullReviewTool, GetPullReviewFn) 146 s.AddTool(ListPullReviewCommentsTool, ListPullReviewCommentsFn) 147 s.AddTool(MergePullRequestTool, MergePullRequestFn) 148 s.AddTool(ListPullRequestFilesTool, ListPullRequestFilesFn) 149 s.AddTool(GetPullRequestDiffTool, GetPullRequestDiffFn) 150 } 151 152 func GetPullRequestByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 153 log.Debugf("Called GetPullRequestByIndexFn") 154 owner, _ := req.GetArguments()["owner"].(string) 155 repo, _ := req.GetArguments()["repo"].(string) 156 index, _ := to.Float64(req.GetArguments()["index"]) 157 158 pr, _, err := forgejo.Client().GetPullRequest(owner, repo, int64(index)) 159 if err != nil { 160 return to.ErrorResult(fmt.Errorf("get pull request err: %v", err)) 161 } 162 return to.TextResult(pr) 163 } 164 165 func ListRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 166 log.Debugf("Called ListRepoPullRequestsFn") 167 owner, _ := req.GetArguments()["owner"].(string) 168 repo, _ := req.GetArguments()["repo"].(string) 169 state, ok := req.GetArguments()["state"].(string) 170 if !ok { 171 state = "open" 172 } 173 sort, _ := req.GetArguments()["sort"].(string) 174 page, _ := to.Float64(req.GetArguments()["page"]) 175 if !ok { 176 page = 1 177 } 178 limit, _ := to.Float64(req.GetArguments()["limit"]) 179 if !ok { 180 limit = 20 181 } 182 183 // Convert milestone from string to int64 if provided 184 // Note: Not using milestoneID since it's not supported in the current Forgejo SDK 185 186 // Labels - not used directly in query per API, will be handled in the API call 187 188 opt := forgejo_sdk.ListPullRequestsOptions{ 189 State: forgejo_sdk.StateType(state), 190 Sort: sort, 191 ListOptions: forgejo_sdk.ListOptions{ 192 Page: int(page), 193 PageSize: int(limit), 194 }, 195 } 196 197 // Only set milestone if provided and valid 198 // Note: Not using milestone as it's not supported in the current Forgejo SDK 199 200 prs, _, err := forgejo.Client().ListRepoPullRequests(owner, repo, opt) 201 if err != nil { 202 return to.ErrorResult(fmt.Errorf("get pull request list err: %v", err)) 203 } 204 return to.TextResult(prs) 205 } 206 207 func CreatePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 208 log.Debugf("Called CreatePullRequestFn") 209 owner, _ := req.GetArguments()["owner"].(string) 210 repo, _ := req.GetArguments()["repo"].(string) 211 head, _ := req.GetArguments()["head"].(string) 212 base, _ := req.GetArguments()["base"].(string) 213 title, _ := req.GetArguments()["title"].(string) 214 body, _ := req.GetArguments()["body"].(string) 215 216 opt := forgejo_sdk.CreatePullRequestOption{ 217 Head: head, 218 Base: base, 219 Title: title, 220 Body: body, 221 } 222 pr, _, err := forgejo.Client().CreatePullRequest(owner, repo, opt) 223 if err != nil { 224 return to.ErrorResult(fmt.Errorf("create pull request err: %v", err)) 225 } 226 return to.TextResult(pr) 227 } 228 229 func UpdatePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 230 log.Debugf("Called UpdatePullRequestFn") 231 owner, _ := req.GetArguments()["owner"].(string) 232 repo, _ := req.GetArguments()["repo"].(string) 233 index, _ := to.Float64(req.GetArguments()["index"]) 234 title, _ := req.GetArguments()["title"].(string) 235 body, _ := req.GetArguments()["body"].(string) 236 base, _ := req.GetArguments()["base"].(string) 237 assignee, _ := req.GetArguments()["assignee"].(string) 238 milestone, _ := req.GetArguments()["milestone"].(string) 239 240 opt := forgejo_sdk.EditPullRequestOption{} 241 242 if title != "" { 243 opt.Title = title 244 } 245 if body != "" { 246 opt.Body = ptr.To(body) 247 } 248 if base != "" { 249 opt.Base = base 250 } 251 if assignee != "" { 252 opt.Assignee = assignee 253 } 254 if milestone != "" { 255 milestoneID, err := strconv.ParseInt(milestone, 10, 64) 256 if err != nil { 257 return to.ErrorResult(fmt.Errorf("invalid milestone ID: %v", err)) 258 } 259 opt.Milestone = milestoneID 260 } 261 262 pr, _, err := forgejo.Client().EditPullRequest(owner, repo, int64(index), opt) 263 if err != nil { 264 return to.ErrorResult(fmt.Errorf("update pull request err: %v", err)) 265 } 266 return to.TextResult(pr) 267 } 268 269 func ListPullReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 270 log.Debugf("Called ListPullReviewsFn") 271 owner, _ := req.GetArguments()["owner"].(string) 272 repo, _ := req.GetArguments()["repo"].(string) 273 index, _ := to.Float64(req.GetArguments()["index"]) 274 page, _ := to.Float64(req.GetArguments()["page"]) 275 if page == 0 { 276 page = 1 277 } 278 limit, _ := to.Float64(req.GetArguments()["limit"]) 279 if limit == 0 { 280 limit = 20 281 } 282 283 opt := forgejo_sdk.ListPullReviewsOptions{ 284 ListOptions: forgejo_sdk.ListOptions{ 285 Page: int(page), 286 PageSize: int(limit), 287 }, 288 } 289 290 reviews, _, err := forgejo.Client().ListPullReviews(owner, repo, int64(index), opt) 291 if err != nil { 292 return to.ErrorResult(fmt.Errorf("list pull reviews err: %v", err)) 293 } 294 return to.TextResult(reviews) 295 } 296 297 func GetPullReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 298 log.Debugf("Called GetPullReviewFn") 299 owner, _ := req.GetArguments()["owner"].(string) 300 repo, _ := req.GetArguments()["repo"].(string) 301 index, _ := to.Float64(req.GetArguments()["index"]) 302 id, _ := to.Float64(req.GetArguments()["id"]) 303 304 review, _, err := forgejo.Client().GetPullReview(owner, repo, int64(index), int64(id)) 305 if err != nil { 306 return to.ErrorResult(fmt.Errorf("get pull review err: %v", err)) 307 } 308 return to.TextResult(review) 309 } 310 311 func MergePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 312 log.Debugf("Called MergePullRequestFn") 313 owner, _ := req.GetArguments()["owner"].(string) 314 repo, _ := req.GetArguments()["repo"].(string) 315 index, _ := to.Float64(req.GetArguments()["index"]) 316 style, _ := req.GetArguments()["style"].(string) 317 title, _ := req.GetArguments()["title"].(string) 318 message, _ := req.GetArguments()["message"].(string) 319 deleteBranch, _ := req.GetArguments()["delete_branch_after_merge"].(bool) 320 forceMerge, _ := req.GetArguments()["force_merge"].(bool) 321 mergeWhenChecks, _ := req.GetArguments()["merge_when_checks_succeed"].(bool) 322 323 opt := forgejo_sdk.MergePullRequestOption{ 324 Style: forgejo_sdk.MergeStyle(style), 325 DeleteBranchAfterMerge: deleteBranch, 326 ForceMerge: forceMerge, 327 MergeWhenChecksSucceed: mergeWhenChecks, 328 } 329 330 if title != "" { 331 opt.Title = title 332 } 333 if message != "" { 334 opt.Message = message 335 } 336 337 _, _, err := forgejo.Client().MergePullRequest(owner, repo, int64(index), opt) 338 if err != nil { 339 return to.ErrorResult(fmt.Errorf("merge pull request err: %v", err)) 340 } 341 342 result := "Pull request merged successfully" 343 if mergeWhenChecks { 344 result = "Pull request scheduled to merge when all checks succeed" 345 } 346 return &mcp.CallToolResult{ 347 Content: []mcp.Content{mcp.NewTextContent(result)}, 348 }, nil 349 } 350 351 func ListPullRequestFilesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 352 log.Debugf("Called ListPullRequestFilesFn") 353 owner, _ := req.GetArguments()["owner"].(string) 354 repo, _ := req.GetArguments()["repo"].(string) 355 index, _ := to.Float64(req.GetArguments()["index"]) 356 page, _ := to.Float64(req.GetArguments()["page"]) 357 if page == 0 { 358 page = 1 359 } 360 limit, _ := to.Float64(req.GetArguments()["limit"]) 361 if limit == 0 { 362 limit = 50 363 } 364 365 opt := forgejo_sdk.ListPullRequestFilesOptions{ 366 ListOptions: forgejo_sdk.ListOptions{ 367 Page: int(page), 368 PageSize: int(limit), 369 }, 370 } 371 372 files, _, err := forgejo.Client().ListPullRequestFiles(owner, repo, int64(index), opt) 373 if err != nil { 374 return to.ErrorResult(fmt.Errorf("list pull request files err: %v", err)) 375 } 376 return to.TextResult(files) 377 } 378 379 func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 380 log.Debugf("Called GetPullRequestDiffFn") 381 owner, _ := req.GetArguments()["owner"].(string) 382 repo, _ := req.GetArguments()["repo"].(string) 383 index, _ := to.Float64(req.GetArguments()["index"]) 384 385 diff, _, err := forgejo.Client().GetPullRequestDiff(owner, repo, int64(index), forgejo_sdk.PullRequestDiffOptions{}) 386 if err != nil { 387 return to.ErrorResult(fmt.Errorf("get pull request diff err: %v", err)) 388 } 389 return &mcp.CallToolResult{ 390 Content: []mcp.Content{mcp.NewTextContent(string(diff))}, 391 }, nil 392 } 393 394 func ListPullReviewCommentsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 395 log.Debugf("Called ListPullReviewCommentsFn") 396 owner, _ := req.GetArguments()["owner"].(string) 397 repo, _ := req.GetArguments()["repo"].(string) 398 index, _ := to.Float64(req.GetArguments()["index"]) 399 id, _ := to.Float64(req.GetArguments()["id"]) 400 401 comments, _, err := forgejo.Client().ListPullReviewComments(owner, repo, int64(index), int64(id)) 402 if err != nil { 403 return to.ErrorResult(fmt.Errorf("list pull review comments err: %v", err)) 404 } 405 return to.TextResult(comments) 406 }