/ cmd / cli.go
cli.go
  1  package cmd
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"flag"
  7  	"fmt"
  8  	"io"
  9  	"os"
 10  	"sort"
 11  	"strings"
 12  
 13  	"codeberg.org/goern/forgejo-mcp/v2/operation"
 14  	flagPkg "codeberg.org/goern/forgejo-mcp/v2/pkg/flag"
 15  
 16  	"github.com/mark3labs/mcp-go/mcp"
 17  	"github.com/mark3labs/mcp-go/server"
 18  )
 19  
// cliMode reports whether the process is running in CLI mode, i.e. whether
// --cli was detected in os.Args.
// NOTE(review): nothing in this section assigns it; presumably start-up code
// sets it from hasCLIFlag — confirm.
var cliMode bool
 22  
 23  // hasCLIFlag checks os.Args for --cli before flag.Parse() runs.
 24  func hasCLIFlag() bool {
 25  	for _, arg := range os.Args[1:] {
 26  		if arg == "--cli" {
 27  			return true
 28  		}
 29  	}
 30  	return false
 31  }
 32  
 33  // toolDomains maps tool names to their domain for grouped listing.
 34  // Built by registering tools per domain and tracking which names appear.
 35  var toolDomains = map[string]string{}
 36  
 37  // registerToolsWithDomains registers all tools and builds the domain mapping.
 38  func registerToolsWithDomains(s *server.MCPServer) {
 39  	beforeNames := toolNames(s)
 40  
 41  	type domainReg struct {
 42  		name string
 43  		fn   func(*server.MCPServer)
 44  	}
 45  	domains := []domainReg{
 46  		{"user", operation.RegisterUserTool},
 47  		{"repo", operation.RegisterRepoTool},
 48  		{"issue", operation.RegisterIssueTool},
 49  		{"pull", operation.RegisterPullTool},
 50  		{"pull", operation.RegisterPullReviewTool},
 51  		{"search", operation.RegisterSearchTool},
 52  		{"version", operation.RegisterVersionTool},
 53  		{"actions", operation.RegisterActionsTool},
 54  		{"org", operation.RegisterOrgTool},
 55  	}
 56  
 57  	for _, d := range domains {
 58  		d.fn(s)
 59  		afterNames := toolNames(s)
 60  		for _, name := range afterNames {
 61  			if !contains(beforeNames, name) {
 62  				toolDomains[name] = d.name
 63  			}
 64  		}
 65  		beforeNames = afterNames
 66  	}
 67  }
 68  
 69  func toolNames(s *server.MCPServer) []string {
 70  	tools := s.ListTools()
 71  	names := make([]string, 0, len(tools))
 72  	for name := range tools {
 73  		names = append(names, name)
 74  	}
 75  	return names
 76  }
 77  
 78  func contains(ss []string, s string) bool {
 79  	for _, v := range ss {
 80  		if v == s {
 81  			return true
 82  		}
 83  	}
 84  	return false
 85  }
 86  
 87  // RunCLI is the entry point for --cli mode.
 88  func RunCLI(version string) {
 89  	// Parse CLI-specific flags using a separate FlagSet.
 90  	fs := flag.NewFlagSet("cli", flag.ExitOnError)
 91  	argsFlag := fs.String("args", "", "JSON arguments for tool invocation")
 92  	outputFlag := fs.String("output", "", "Output format: json or text")
 93  	helpFlag := fs.Bool("help", false, "Show tool parameter help")
 94  
 95  	// Find the positional command (first non-flag arg after --cli).
 96  	// os.Args has been filtered by init() to remove --cli and preceding flags.
 97  	cliArgs := cliArgsToParse()
 98  	if len(cliArgs) == 0 {
 99  		fmt.Fprintln(os.Stderr, "Usage: forgejo-mcp --cli <command> [options]")
100  		fmt.Fprintln(os.Stderr, "Commands: list, <tool-name>")
101  		fmt.Fprintln(os.Stderr, "Options: --args '{json}', --output=json|text, --help")
102  		os.Exit(1)
103  	}
104  
105  	command := cliArgs[0]
106  	_ = fs.Parse(cliArgs[1:])
107  
108  	// Build the MCPServer and register tools with domain tracking.
109  	flagPkg.Version = version
110  	mcpSrv := server.NewMCPServer("Forgejo MCP Server", version, server.WithLogging())
111  	registerToolsWithDomains(mcpSrv)
112  
113  	switch command {
114  	case "list":
115  		outputMode := *outputFlag
116  		if outputMode == "" {
117  			outputMode = "text"
118  		}
119  		if err := cliList(mcpSrv, outputMode); err != nil {
120  			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
121  			os.Exit(1)
122  		}
123  	default:
124  		if *helpFlag {
125  			if err := cliHelp(mcpSrv, command); err != nil {
126  				fmt.Fprintf(os.Stderr, "Error: %v\n", err)
127  				os.Exit(1)
128  			}
129  			return
130  		}
131  
132  		outputMode := *outputFlag
133  		if outputMode == "" {
134  			outputMode = "json"
135  		}
136  
137  		argsJSON, err := resolveArgs(*argsFlag)
138  		if err != nil {
139  			fmt.Fprintf(os.Stderr, "Error reading arguments: %v\n", err)
140  			os.Exit(1)
141  		}
142  
143  		if err := cliExec(mcpSrv, command, argsJSON, outputMode); err != nil {
144  			fmt.Fprintf(os.Stderr, "Error: %v\n", err)
145  			os.Exit(1)
146  		}
147  	}
148  }
149  
// cliArgsToParse returns the portion of os.Args that follows the first
// occurrence of "--cli". The result is empty when --cli is the last argument
// and nil when --cli does not appear at all.
func cliArgsToParse() []string {
	for idx := 0; idx < len(os.Args); idx++ {
		if os.Args[idx] != "--cli" {
			continue
		}
		return os.Args[idx+1:]
	}
	return nil
}
159  
// resolveArgs returns the JSON argument payload for a tool invocation.
//
// Sources are consulted in this order:
//  1. A non-empty argsFlag (the --args value) is returned verbatim.
//  2. Otherwise, if stdin is not a character device (for example a pipe or a
//     redirected file), stdin is read to EOF and its content is returned with
//     surrounding whitespace trimmed.
//  3. Otherwise "{}" is returned.
//
// Blank piped input (empty or whitespace only, such as the single newline
// produced by `echo |`) counts as "no arguments" and yields "{}" instead of
// text that would later fail JSON parsing.
//
// An error is returned only when reading stdin fails.
func resolveArgs(argsFlag string) (string, error) {
	if argsFlag != "" {
		return argsFlag, nil
	}

	// Best effort: if stdin cannot be inspected, behave as if nothing was
	// piped in rather than failing the whole invocation.
	stat, err := os.Stdin.Stat()
	if err != nil {
		return "{}", nil
	}

	// A character device is typically an interactive terminal; reading from
	// it would block waiting for input, so only non-device stdin is read.
	if stat.Mode()&os.ModeCharDevice != 0 {
		return "{}", nil
	}

	data, err := io.ReadAll(os.Stdin)
	if err != nil {
		return "", fmt.Errorf("reading stdin: %w", err)
	}
	if payload := strings.TrimSpace(string(data)); payload != "" {
		return payload, nil
	}

	return "{}", nil
}
184  
185  // cliList prints all registered tools.
186  func cliList(s *server.MCPServer, outputMode string) error {
187  	tools := s.ListTools()
188  	if tools == nil {
189  		fmt.Println("No tools registered.")
190  		return nil
191  	}
192  
193  	type toolInfo struct {
194  		Name        string `json:"name"`
195  		Description string `json:"description"`
196  		Domain      string `json:"domain"`
197  	}
198  
199  	// Build sorted list.
200  	var infos []toolInfo
201  	for name, st := range tools {
202  		domain := toolDomains[name]
203  		if domain == "" {
204  			domain = "other"
205  		}
206  		infos = append(infos, toolInfo{
207  			Name:        name,
208  			Description: st.Tool.Description,
209  			Domain:      domain,
210  		})
211  	}
212  	sort.Slice(infos, func(i, j int) bool {
213  		if infos[i].Domain != infos[j].Domain {
214  			return infos[i].Domain < infos[j].Domain
215  		}
216  		return infos[i].Name < infos[j].Name
217  	})
218  
219  	if outputMode == "json" {
220  		enc := json.NewEncoder(os.Stdout)
221  		enc.SetIndent("", "  ")
222  		return enc.Encode(infos)
223  	}
224  
225  	// Text mode: grouped by domain.
226  	grouped := map[string][]toolInfo{}
227  	domainOrder := []string{}
228  	for _, info := range infos {
229  		if _, exists := grouped[info.Domain]; !exists {
230  			domainOrder = append(domainOrder, info.Domain)
231  		}
232  		grouped[info.Domain] = append(grouped[info.Domain], info)
233  	}
234  
235  	for _, domain := range domainOrder {
236  		fmt.Printf("\n%s:\n", strings.ToUpper(domain))
237  		for _, info := range grouped[domain] {
238  			fmt.Printf("  %-40s %s\n", info.Name, info.Description)
239  		}
240  	}
241  	fmt.Println()
242  
243  	return nil
244  }
245  
246  // cliHelp prints the parameter schema for a tool.
247  func cliHelp(s *server.MCPServer, toolName string) error {
248  	st := s.GetTool(toolName)
249  	if st == nil {
250  		return fmt.Errorf("unknown tool: %s", toolName)
251  	}
252  
253  	fmt.Printf("Tool: %s\n", st.Tool.Name)
254  	if st.Tool.Description != "" {
255  		fmt.Printf("Description: %s\n", st.Tool.Description)
256  	}
257  	fmt.Println()
258  
259  	props := st.Tool.InputSchema.Properties
260  	required := st.Tool.InputSchema.Required
261  
262  	if len(props) == 0 {
263  		fmt.Println("No parameters.")
264  		return nil
265  	}
266  
267  	fmt.Println("Parameters:")
268  	// Sort parameter names for consistent output.
269  	names := make([]string, 0, len(props))
270  	for name := range props {
271  		names = append(names, name)
272  	}
273  	sort.Strings(names)
274  
275  	requiredSet := map[string]bool{}
276  	for _, r := range required {
277  		requiredSet[r] = true
278  	}
279  
280  	for _, name := range names {
281  		prop := props[name]
282  		reqStr := "optional"
283  		if requiredSet[name] {
284  			reqStr = "required"
285  		}
286  
287  		// Property is stored as map[string]any.
288  		propMap, ok := prop.(map[string]any)
289  		if !ok {
290  			fmt.Printf("  %-20s (%s)\n", name, reqStr)
291  			continue
292  		}
293  
294  		typStr, _ := propMap["type"].(string)
295  		desc, _ := propMap["description"].(string)
296  
297  		fmt.Printf("  %-20s %-10s %-10s %s\n", name, typStr, reqStr, desc)
298  	}
299  
300  	return nil
301  }
302  
303  // cliExec invokes a tool handler and prints the result.
304  func cliExec(s *server.MCPServer, toolName, argsJSON, outputMode string) error {
305  	st := s.GetTool(toolName)
306  	if st == nil {
307  		return fmt.Errorf("unknown tool: %s\nRun 'forgejo-mcp --cli list' to see available tools", toolName)
308  	}
309  
310  	// Parse JSON arguments.
311  	var args map[string]any
312  	if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
313  		return fmt.Errorf("invalid JSON arguments: %w", err)
314  	}
315  
316  	// Construct CallToolRequest.
317  	req := mcp.CallToolRequest{
318  		Params: mcp.CallToolParams{
319  			Name:      toolName,
320  			Arguments: args,
321  		},
322  	}
323  
324  	// Call the handler.
325  	result, err := st.Handler(context.Background(), req)
326  	if err != nil {
327  		return fmt.Errorf("tool execution failed: %w", err)
328  	}
329  
330  	// Check IsError flag.
331  	if result.IsError {
332  		if outputMode == "json" {
333  			enc := json.NewEncoder(os.Stderr)
334  			enc.SetIndent("", "  ")
335  			_ = enc.Encode(result.Content)
336  		} else {
337  			for _, c := range result.Content {
338  				if tc, ok := c.(mcp.TextContent); ok {
339  					fmt.Fprintln(os.Stderr, tc.Text)
340  				}
341  			}
342  		}
343  		os.Exit(1)
344  	}
345  
346  	// Output result.
347  	if outputMode == "json" {
348  		enc := json.NewEncoder(os.Stdout)
349  		enc.SetIndent("", "  ")
350  		return enc.Encode(result.Content)
351  	}
352  
353  	// Text mode: print text content line by line.
354  	for _, c := range result.Content {
355  		if tc, ok := c.(mcp.TextContent); ok {
356  			fmt.Println(tc.Text)
357  		}
358  	}
359  
360  	return nil
361  }