/ operation / pull / pull.go
pull.go
  1  package pull
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"strconv"
  7  
  8  	"codeberg.org/goern/forgejo-mcp/v2/operation/params"
  9  	"codeberg.org/goern/forgejo-mcp/v2/pkg/forgejo"
 10  	"codeberg.org/goern/forgejo-mcp/v2/pkg/log"
 11  	"codeberg.org/goern/forgejo-mcp/v2/pkg/ptr"
 12  	"codeberg.org/goern/forgejo-mcp/v2/pkg/to"
 13  
 14  	forgejo_sdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v3"
 15  	"github.com/mark3labs/mcp-go/mcp"
 16  	"github.com/mark3labs/mcp-go/server"
 17  )
 18  
 19  const (
 20  	GetPullRequestByIndexToolName  = "get_pull_request_by_index"
 21  	ListRepoPullRequestsToolName   = "list_repo_pull_requests"
 22  	CreatePullRequestToolName      = "create_pull_request"
 23  	UpdatePullRequestToolName      = "update_pull_request"
 24  	ListPullReviewsToolName        = "list_pull_reviews"
 25  	GetPullReviewToolName          = "get_pull_review"
 26  	ListPullReviewCommentsToolName = "list_pull_review_comments"
 27  	MergePullRequestToolName       = "merge_pull_request"
 28  	ListPullRequestFilesToolName   = "list_pull_request_files"
 29  	GetPullRequestDiffToolName     = "get_pull_request_diff"
 30  )
 31  
 32  var (
 33  	GetPullRequestByIndexTool = mcp.NewTool(
 34  		GetPullRequestByIndexToolName,
 35  		mcp.WithDescription("Get pull request by index"),
 36  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 37  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 38  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
 39  	)
 40  
 41  	ListRepoPullRequestsTool = mcp.NewTool(
 42  		ListRepoPullRequestsToolName,
 43  		mcp.WithDescription("List repo pull requests"),
 44  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 45  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 46  		mcp.WithString("state", mcp.Description("State (open|closed|all)"), mcp.DefaultString("open")),
 47  		mcp.WithString("sort", mcp.Description("Sort (oldest|recentupdate|leastupdate|mostcomment)")),
 48  		mcp.WithString("milestone", mcp.Description(params.Milestone)),
 49  		mcp.WithString("labels", mcp.Description(params.Labels)),
 50  		mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)),
 51  		mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(20)),
 52  	)
 53  
 54  	CreatePullRequestTool = mcp.NewTool(
 55  		CreatePullRequestToolName,
 56  		mcp.WithDescription("Create pull request"),
 57  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 58  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 59  		mcp.WithString("head", mcp.Required(), mcp.Description(params.Head)),
 60  		mcp.WithString("base", mcp.Required(), mcp.Description(params.Base)),
 61  		mcp.WithString("title", mcp.Required(), mcp.Description(params.Title)),
 62  		mcp.WithString("body", mcp.Description(params.Body)),
 63  	)
 64  
 65  	UpdatePullRequestTool = mcp.NewTool(
 66  		UpdatePullRequestToolName,
 67  		mcp.WithDescription("Update pull request"),
 68  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 69  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 70  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
 71  		mcp.WithString("title", mcp.Description(params.Title)),
 72  		mcp.WithString("body", mcp.Description(params.Body)),
 73  		mcp.WithString("base", mcp.Description(params.Base)),
 74  		mcp.WithString("assignee", mcp.Description("Assignee username")),
 75  		mcp.WithString("milestone", mcp.Description(params.Milestone)),
 76  	)
 77  
 78  	ListPullReviewsTool = mcp.NewTool(
 79  		ListPullReviewsToolName,
 80  		mcp.WithDescription("List reviews for a pull request"),
 81  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 82  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 83  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
 84  		mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)),
 85  		mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(20)),
 86  	)
 87  
 88  	GetPullReviewTool = mcp.NewTool(
 89  		GetPullReviewToolName,
 90  		mcp.WithDescription("Get a specific pull request review"),
 91  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
 92  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
 93  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
 94  		mcp.WithNumber("id", mcp.Required(), mcp.Description("Review ID")),
 95  	)
 96  
 97  	ListPullReviewCommentsTool = mcp.NewTool(
 98  		ListPullReviewCommentsToolName,
 99  		mcp.WithDescription("List comments on a pull request review"),
100  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
101  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
102  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
103  		mcp.WithNumber("id", mcp.Required(), mcp.Description("Review ID")),
104  	)
105  
106  	ListPullRequestFilesTool = mcp.NewTool(
107  		ListPullRequestFilesToolName,
108  		mcp.WithDescription("List changed files in a pull request"),
109  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
110  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
111  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
112  		mcp.WithNumber("page", mcp.Description(params.Page), mcp.DefaultNumber(1)),
113  		mcp.WithNumber("limit", mcp.Description(params.Limit), mcp.DefaultNumber(50)),
114  	)
115  
116  	GetPullRequestDiffTool = mcp.NewTool(
117  		GetPullRequestDiffToolName,
118  		mcp.WithDescription("Get the diff of a pull request"),
119  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
120  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
121  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
122  	)
123  
124  	MergePullRequestTool = mcp.NewTool(
125  		MergePullRequestToolName,
126  		mcp.WithDescription("Merge a pull request"),
127  		mcp.WithString("owner", mcp.Required(), mcp.Description(params.Owner)),
128  		mcp.WithString("repo", mcp.Required(), mcp.Description(params.Repo)),
129  		mcp.WithNumber("index", mcp.Required(), mcp.Description(params.PRIndex)),
130  		mcp.WithString("style", mcp.Required(), mcp.Description("Merge style (merge, rebase, rebase-merge, squash)")),
131  		mcp.WithString("title", mcp.Description("Merge commit title")),
132  		mcp.WithString("message", mcp.Description("Merge commit message")),
133  		mcp.WithBoolean("delete_branch_after_merge", mcp.Description("Delete head branch after merge")),
134  		mcp.WithBoolean("force_merge", mcp.Description("Force merge even if checks have not passed")),
135  		mcp.WithBoolean("merge_when_checks_succeed", mcp.Description("Schedule merge for when all checks succeed")),
136  	)
137  )
138  
139  func RegisterTool(s *server.MCPServer) {
140  	s.AddTool(GetPullRequestByIndexTool, GetPullRequestByIndexFn)
141  	s.AddTool(ListRepoPullRequestsTool, ListRepoPullRequestsFn)
142  	s.AddTool(CreatePullRequestTool, CreatePullRequestFn)
143  	s.AddTool(UpdatePullRequestTool, UpdatePullRequestFn)
144  	s.AddTool(ListPullReviewsTool, ListPullReviewsFn)
145  	s.AddTool(GetPullReviewTool, GetPullReviewFn)
146  	s.AddTool(ListPullReviewCommentsTool, ListPullReviewCommentsFn)
147  	s.AddTool(MergePullRequestTool, MergePullRequestFn)
148  	s.AddTool(ListPullRequestFilesTool, ListPullRequestFilesFn)
149  	s.AddTool(GetPullRequestDiffTool, GetPullRequestDiffFn)
150  }
151  
152  func GetPullRequestByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
153  	log.Debugf("Called GetPullRequestByIndexFn")
154  	owner, _ := req.GetArguments()["owner"].(string)
155  	repo, _ := req.GetArguments()["repo"].(string)
156  	index, _ := to.Float64(req.GetArguments()["index"])
157  
158  	pr, _, err := forgejo.Client().GetPullRequest(owner, repo, int64(index))
159  	if err != nil {
160  		return to.ErrorResult(fmt.Errorf("get pull request err: %v", err))
161  	}
162  	return to.TextResult(pr)
163  }
164  
165  func ListRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
166  	log.Debugf("Called ListRepoPullRequestsFn")
167  	owner, _ := req.GetArguments()["owner"].(string)
168  	repo, _ := req.GetArguments()["repo"].(string)
169  	state, ok := req.GetArguments()["state"].(string)
170  	if !ok {
171  		state = "open"
172  	}
173  	sort, _ := req.GetArguments()["sort"].(string)
174  	page, _ := to.Float64(req.GetArguments()["page"])
175  	if !ok {
176  		page = 1
177  	}
178  	limit, _ := to.Float64(req.GetArguments()["limit"])
179  	if !ok {
180  		limit = 20
181  	}
182  
183  	// Convert milestone from string to int64 if provided
184  	// Note: Not using milestoneID since it's not supported in the current Forgejo SDK
185  
186  	// Labels - not used directly in query per API, will be handled in the API call
187  
188  	opt := forgejo_sdk.ListPullRequestsOptions{
189  		State: forgejo_sdk.StateType(state),
190  		Sort:  sort,
191  		ListOptions: forgejo_sdk.ListOptions{
192  			Page:     int(page),
193  			PageSize: int(limit),
194  		},
195  	}
196  
197  	// Only set milestone if provided and valid
198  	// Note: Not using milestone as it's not supported in the current Forgejo SDK
199  
200  	prs, _, err := forgejo.Client().ListRepoPullRequests(owner, repo, opt)
201  	if err != nil {
202  		return to.ErrorResult(fmt.Errorf("get pull request list err: %v", err))
203  	}
204  	return to.TextResult(prs)
205  }
206  
207  func CreatePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
208  	log.Debugf("Called CreatePullRequestFn")
209  	owner, _ := req.GetArguments()["owner"].(string)
210  	repo, _ := req.GetArguments()["repo"].(string)
211  	head, _ := req.GetArguments()["head"].(string)
212  	base, _ := req.GetArguments()["base"].(string)
213  	title, _ := req.GetArguments()["title"].(string)
214  	body, _ := req.GetArguments()["body"].(string)
215  
216  	opt := forgejo_sdk.CreatePullRequestOption{
217  		Head:  head,
218  		Base:  base,
219  		Title: title,
220  		Body:  body,
221  	}
222  	pr, _, err := forgejo.Client().CreatePullRequest(owner, repo, opt)
223  	if err != nil {
224  		return to.ErrorResult(fmt.Errorf("create pull request err: %v", err))
225  	}
226  	return to.TextResult(pr)
227  }
228  
229  func UpdatePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
230  	log.Debugf("Called UpdatePullRequestFn")
231  	owner, _ := req.GetArguments()["owner"].(string)
232  	repo, _ := req.GetArguments()["repo"].(string)
233  	index, _ := to.Float64(req.GetArguments()["index"])
234  	title, _ := req.GetArguments()["title"].(string)
235  	body, _ := req.GetArguments()["body"].(string)
236  	base, _ := req.GetArguments()["base"].(string)
237  	assignee, _ := req.GetArguments()["assignee"].(string)
238  	milestone, _ := req.GetArguments()["milestone"].(string)
239  
240  	opt := forgejo_sdk.EditPullRequestOption{}
241  
242  	if title != "" {
243  		opt.Title = title
244  	}
245  	if body != "" {
246  		opt.Body = ptr.To(body)
247  	}
248  	if base != "" {
249  		opt.Base = base
250  	}
251  	if assignee != "" {
252  		opt.Assignee = assignee
253  	}
254  	if milestone != "" {
255  		milestoneID, err := strconv.ParseInt(milestone, 10, 64)
256  		if err != nil {
257  			return to.ErrorResult(fmt.Errorf("invalid milestone ID: %v", err))
258  		}
259  		opt.Milestone = milestoneID
260  	}
261  
262  	pr, _, err := forgejo.Client().EditPullRequest(owner, repo, int64(index), opt)
263  	if err != nil {
264  		return to.ErrorResult(fmt.Errorf("update pull request err: %v", err))
265  	}
266  	return to.TextResult(pr)
267  }
268  
269  func ListPullReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
270  	log.Debugf("Called ListPullReviewsFn")
271  	owner, _ := req.GetArguments()["owner"].(string)
272  	repo, _ := req.GetArguments()["repo"].(string)
273  	index, _ := to.Float64(req.GetArguments()["index"])
274  	page, _ := to.Float64(req.GetArguments()["page"])
275  	if page == 0 {
276  		page = 1
277  	}
278  	limit, _ := to.Float64(req.GetArguments()["limit"])
279  	if limit == 0 {
280  		limit = 20
281  	}
282  
283  	opt := forgejo_sdk.ListPullReviewsOptions{
284  		ListOptions: forgejo_sdk.ListOptions{
285  			Page:     int(page),
286  			PageSize: int(limit),
287  		},
288  	}
289  
290  	reviews, _, err := forgejo.Client().ListPullReviews(owner, repo, int64(index), opt)
291  	if err != nil {
292  		return to.ErrorResult(fmt.Errorf("list pull reviews err: %v", err))
293  	}
294  	return to.TextResult(reviews)
295  }
296  
297  func GetPullReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
298  	log.Debugf("Called GetPullReviewFn")
299  	owner, _ := req.GetArguments()["owner"].(string)
300  	repo, _ := req.GetArguments()["repo"].(string)
301  	index, _ := to.Float64(req.GetArguments()["index"])
302  	id, _ := to.Float64(req.GetArguments()["id"])
303  
304  	review, _, err := forgejo.Client().GetPullReview(owner, repo, int64(index), int64(id))
305  	if err != nil {
306  		return to.ErrorResult(fmt.Errorf("get pull review err: %v", err))
307  	}
308  	return to.TextResult(review)
309  }
310  
311  func MergePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
312  	log.Debugf("Called MergePullRequestFn")
313  	owner, _ := req.GetArguments()["owner"].(string)
314  	repo, _ := req.GetArguments()["repo"].(string)
315  	index, _ := to.Float64(req.GetArguments()["index"])
316  	style, _ := req.GetArguments()["style"].(string)
317  	title, _ := req.GetArguments()["title"].(string)
318  	message, _ := req.GetArguments()["message"].(string)
319  	deleteBranch, _ := req.GetArguments()["delete_branch_after_merge"].(bool)
320  	forceMerge, _ := req.GetArguments()["force_merge"].(bool)
321  	mergeWhenChecks, _ := req.GetArguments()["merge_when_checks_succeed"].(bool)
322  
323  	opt := forgejo_sdk.MergePullRequestOption{
324  		Style:                  forgejo_sdk.MergeStyle(style),
325  		DeleteBranchAfterMerge: deleteBranch,
326  		ForceMerge:             forceMerge,
327  		MergeWhenChecksSucceed: mergeWhenChecks,
328  	}
329  
330  	if title != "" {
331  		opt.Title = title
332  	}
333  	if message != "" {
334  		opt.Message = message
335  	}
336  
337  	_, _, err := forgejo.Client().MergePullRequest(owner, repo, int64(index), opt)
338  	if err != nil {
339  		return to.ErrorResult(fmt.Errorf("merge pull request err: %v", err))
340  	}
341  
342  	result := "Pull request merged successfully"
343  	if mergeWhenChecks {
344  		result = "Pull request scheduled to merge when all checks succeed"
345  	}
346  	return &mcp.CallToolResult{
347  		Content: []mcp.Content{mcp.NewTextContent(result)},
348  	}, nil
349  }
350  
351  func ListPullRequestFilesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
352  	log.Debugf("Called ListPullRequestFilesFn")
353  	owner, _ := req.GetArguments()["owner"].(string)
354  	repo, _ := req.GetArguments()["repo"].(string)
355  	index, _ := to.Float64(req.GetArguments()["index"])
356  	page, _ := to.Float64(req.GetArguments()["page"])
357  	if page == 0 {
358  		page = 1
359  	}
360  	limit, _ := to.Float64(req.GetArguments()["limit"])
361  	if limit == 0 {
362  		limit = 50
363  	}
364  
365  	opt := forgejo_sdk.ListPullRequestFilesOptions{
366  		ListOptions: forgejo_sdk.ListOptions{
367  			Page:     int(page),
368  			PageSize: int(limit),
369  		},
370  	}
371  
372  	files, _, err := forgejo.Client().ListPullRequestFiles(owner, repo, int64(index), opt)
373  	if err != nil {
374  		return to.ErrorResult(fmt.Errorf("list pull request files err: %v", err))
375  	}
376  	return to.TextResult(files)
377  }
378  
379  func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
380  	log.Debugf("Called GetPullRequestDiffFn")
381  	owner, _ := req.GetArguments()["owner"].(string)
382  	repo, _ := req.GetArguments()["repo"].(string)
383  	index, _ := to.Float64(req.GetArguments()["index"])
384  
385  	diff, _, err := forgejo.Client().GetPullRequestDiff(owner, repo, int64(index), forgejo_sdk.PullRequestDiffOptions{})
386  	if err != nil {
387  		return to.ErrorResult(fmt.Errorf("get pull request diff err: %v", err))
388  	}
389  	return &mcp.CallToolResult{
390  		Content: []mcp.Content{mcp.NewTextContent(string(diff))},
391  	}, nil
392  }
393  
394  func ListPullReviewCommentsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
395  	log.Debugf("Called ListPullReviewCommentsFn")
396  	owner, _ := req.GetArguments()["owner"].(string)
397  	repo, _ := req.GetArguments()["repo"].(string)
398  	index, _ := to.Float64(req.GetArguments()["index"])
399  	id, _ := to.Float64(req.GetArguments()["id"])
400  
401  	comments, _, err := forgejo.Client().ListPullReviewComments(owner, repo, int64(index), int64(id))
402  	if err != nil {
403  		return to.ErrorResult(fmt.Errorf("list pull review comments err: %v", err))
404  	}
405  	return to.TextResult(comments)
406  }