// internal/config/setup.go
  1  package config
  2  
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"golang.org/x/term"
)
 17  
 18  const DefaultEndpoint = "https://api-dev.shannon.run"
 19  
 20  const (
 21  	hintCloud = "Get your API key at https://shannon.run"
 22  	hintLocal = "Running locally? See https://github.com/Kocoro-lab/Shannon for self-hosting docs."
 23  )
 24  
 25  // NeedsSetup returns true if the config has no API key and the endpoint
 26  // is not a local address (localhost/127.0.0.1 bypass auth).
 27  // Ollama provider never needs gateway setup.
 28  func NeedsSetup(cfg *Config) bool {
 29  	if cfg.Provider == "ollama" {
 30  		return cfg.Ollama.Model == "" // model required for ollama to be usable
 31  	}
 32  	if cfg.APIKey != "" {
 33  		return false
 34  	}
 35  	return !isLocalEndpoint(cfg.Endpoint)
 36  }
 37  
 38  // RunSetup runs the interactive setup flow, prompting the user for
 39  // provider selection (Shannon Cloud or Ollama) and provider-specific config.
 40  func RunSetup(cfg *Config, in io.Reader, out io.Writer) error {
 41  	reader := bufio.NewReader(in)
 42  
 43  	fmt.Fprintln(out, "Shannon CLI Setup")
 44  	fmt.Fprintln(out)
 45  
 46  	// Provider selection
 47  	fmt.Fprintln(out, "Choose your LLM provider:")
 48  	fmt.Fprintln(out, "  1) Shannon Cloud")
 49  	fmt.Fprintln(out, "  2) Local model (Ollama)")
 50  	fmt.Fprint(out, "Choice [1]: ")
 51  	choice, _ := reader.ReadString('\n')
 52  	choice = strings.TrimSpace(choice)
 53  
 54  	switch choice {
 55  	case "2":
 56  		if err := setupOllama(cfg, reader, out); err != nil {
 57  			return err
 58  		}
 59  	default:
 60  		if err := setupGateway(cfg, in, reader, out); err != nil {
 61  			return err
 62  		}
 63  	}
 64  
 65  	return saveSetup(cfg, out)
 66  }
 67  
 68  // setupGateway runs the gateway (Shannon Cloud) setup flow.
 69  func setupGateway(cfg *Config, in io.Reader, reader *bufio.Reader, out io.Writer) error {
 70  	cfg.Provider = "gateway"
 71  
 72  	// Endpoint
 73  	defaultEP := cfg.Endpoint
 74  	if defaultEP == "" {
 75  		defaultEP = DefaultEndpoint
 76  	}
 77  	fmt.Fprintf(out, "API endpoint [%s]: ", defaultEP)
 78  	epInput, _ := reader.ReadString('\n')
 79  	epInput = strings.TrimSpace(epInput)
 80  	if epInput != "" {
 81  		cfg.Endpoint = epInput
 82  	} else {
 83  		cfg.Endpoint = defaultEP
 84  	}
 85  
 86  	// Contextual hint
 87  	if isLocalEndpoint(cfg.Endpoint) {
 88  		fmt.Fprintln(out, hintLocal)
 89  	} else {
 90  		fmt.Fprintln(out, hintCloud)
 91  	}
 92  	fmt.Fprintln(out)
 93  
 94  	// API key + health check with retry (max 3 attempts)
 95  	const maxAttempts = 3
 96  	for attempt := 0; attempt < maxAttempts; attempt++ {
 97  		// Prompt for key
 98  		if isLocalEndpoint(cfg.Endpoint) {
 99  			fmt.Fprint(out, "API key (optional for local, Enter to skip): ")
100  		} else {
101  			fmt.Fprint(out, "API key: ")
102  		}
103  
104  		if f, ok := in.(*os.File); ok && term.IsTerminal(int(f.Fd())) {
105  			keyBytes, err := term.ReadPassword(int(f.Fd()))
106  			fmt.Fprintln(out) // newline after masked input
107  			if err != nil {
108  				fmt.Fprintf(out, "Error reading key: %v\n", err)
109  				continue
110  			}
111  			cfg.APIKey = strings.TrimSpace(string(keyBytes))
112  		} else {
113  			keyInput, _ := reader.ReadString('\n')
114  			cfg.APIKey = strings.TrimSpace(keyInput)
115  		}
116  
117  		// Health check
118  		fmt.Fprint(out, "Testing connection... ")
119  		if err := checkEndpointHealth(cfg.Endpoint, cfg.APIKey); err != nil {
120  			fmt.Fprintf(out, "FAILED (%v)\n", err)
121  
122  			if attempt == maxAttempts-1 {
123  				fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.")
124  				break
125  			}
126  			fmt.Fprint(out, "Re-enter credentials? [Y/n]: ")
127  			ans, _ := reader.ReadString('\n')
128  			ans = strings.TrimSpace(strings.ToLower(ans))
129  			if ans == "n" || ans == "no" {
130  				fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.")
131  				break
132  			}
133  			continue
134  		}
135  
136  		fmt.Fprintln(out, "OK")
137  		break
138  	}
139  
140  	return nil
141  }
142  
143  // setupOllama runs the Ollama local model setup flow.
144  func setupOllama(cfg *Config, reader *bufio.Reader, out io.Writer) error {
145  	cfg.Provider = "ollama"
146  
147  	// Endpoint
148  	defaultEP := cfg.Ollama.Endpoint
149  	if defaultEP == "" {
150  		defaultEP = "http://localhost:11434"
151  	}
152  	fmt.Fprintf(out, "Ollama endpoint [%s]: ", defaultEP)
153  	epInput, _ := reader.ReadString('\n')
154  	epInput = strings.TrimSpace(epInput)
155  	if epInput != "" {
156  		cfg.Ollama.Endpoint = epInput
157  	} else {
158  		cfg.Ollama.Endpoint = defaultEP
159  	}
160  
161  	// Health check
162  	fmt.Fprint(out, "Checking Ollama... ")
163  	if err := checkOllamaHealth(cfg.Ollama.Endpoint); err != nil {
164  		fmt.Fprintf(out, "FAILED (%v)\n", err)
165  		fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.")
166  		return nil
167  	}
168  	fmt.Fprintln(out, "OK")
169  
170  	// Fetch and list models
171  	models, err := fetchOllamaModels(cfg.Ollama.Endpoint)
172  	if err != nil {
173  		fmt.Fprintf(out, "Could not list models: %v\n", err)
174  		fmt.Fprint(out, "Enter model name manually: ")
175  		name, _ := reader.ReadString('\n')
176  		cfg.Ollama.Model = strings.TrimSpace(name)
177  		return nil
178  	}
179  
180  	if len(models) == 0 {
181  		fmt.Fprintln(out, "No models found. Pull a model first: ollama pull <model>")
182  		fmt.Fprint(out, "Enter model name manually: ")
183  		name, _ := reader.ReadString('\n')
184  		cfg.Ollama.Model = strings.TrimSpace(name)
185  		return nil
186  	}
187  
188  	fmt.Fprintln(out, "Available models:")
189  	for i, m := range models {
190  		sizeGB := float64(m.Size) / 1e9
191  		paramSize := m.Details.ParameterSize
192  		if paramSize != "" {
193  			fmt.Fprintf(out, "  %d) %s (%s, %.1f GB)\n", i+1, m.Name, paramSize, sizeGB)
194  		} else {
195  			fmt.Fprintf(out, "  %d) %s (%.1f GB)\n", i+1, m.Name, sizeGB)
196  		}
197  	}
198  	fmt.Fprint(out, "Choose model [1]: ")
199  	modelChoice, _ := reader.ReadString('\n')
200  	modelChoice = strings.TrimSpace(modelChoice)
201  
202  	idx := 0 // default to first
203  	if modelChoice != "" {
204  		fmt.Sscanf(modelChoice, "%d", &idx)
205  		idx-- // 1-based → 0-based
206  	}
207  	if idx < 0 || idx >= len(models) {
208  		idx = 0
209  	}
210  	cfg.Ollama.Model = models[idx].Name
211  
212  	return nil
213  }
214  
// ollamaModelInfo is one element of the "models" array in the JSON returned
// by Ollama's GET /api/tags. Only the fields the setup flow displays or
// stores are decoded; anything else in the entry is ignored.
type ollamaModelInfo struct {
	// Name is the model name; setupOllama stores the chosen one in
	// cfg.Ollama.Model.
	Name string `json:"name"`
	// Size is divided by 1e9 for display in GB, i.e. it is treated as a
	// byte count.
	Size int64 `json:"size"`
	// Details.ParameterSize is an optional label shown next to the model
	// name in the selection list when non-empty.
	Details struct {
		ParameterSize string `json:"parameter_size"`
	} `json:"details"`
}
223  
// checkOllamaHealth verifies that an Ollama server is reachable at endpoint.
//
// It issues GET <endpoint>/ with a 5-second timeout and returns nil only
// for an HTTP 200 response. Possible errors:
//   - a transport failure is returned as "unreachable: <cause>", wrapping the
//     underlying error so callers can inspect it with errors.Is/errors.As
//     (for example, to detect a timeout);
//   - any other status is returned as "status <code>";
//   - a request that cannot be built is returned as-is.
func checkOllamaHealth(endpoint string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// TrimRight drops every trailing slash, so "http://host//" does not turn
	// into a request for "//".
	base := strings.TrimRight(endpoint, "/")
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/", nil)
	if err != nil {
		return err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("unreachable: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status %d", resp.StatusCode)
	}
	return nil
}
246  
247  // fetchOllamaModels retrieves the list of locally available models from Ollama.
248  func fetchOllamaModels(endpoint string) ([]ollamaModelInfo, error) {
249  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
250  	defer cancel()
251  
252  	base := strings.TrimSuffix(endpoint, "/")
253  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/api/tags", nil)
254  	if err != nil {
255  		return nil, err
256  	}
257  
258  	resp, err := http.DefaultClient.Do(req)
259  	if err != nil {
260  		return nil, fmt.Errorf("unreachable")
261  	}
262  	defer resp.Body.Close()
263  
264  	if resp.StatusCode != http.StatusOK {
265  		return nil, fmt.Errorf("status %d", resp.StatusCode)
266  	}
267  
268  	var result struct {
269  		Models []ollamaModelInfo `json:"models"`
270  	}
271  	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
272  		return nil, fmt.Errorf("invalid response: %w", err)
273  	}
274  	return result.Models, nil
275  }
276  
277  // saveSetup persists the config to disk. Tolerates Save errors (e.g. in tests
278  // where no real shannon dir exists) by printing a warning instead of failing.
279  func saveSetup(cfg *Config, out io.Writer) error {
280  	if err := Save(cfg); err != nil {
281  		// Save may fail in test environments (no real config dir) or due to
282  		// permission issues. Print a warning so the user knows, but don't
283  		// block setup — the in-memory config is still correct for this session.
284  		fmt.Fprintf(out, "Warning: could not save config: %v\n", err)
285  	} else if dir := ShannonDir(); dir != "" {
286  		fmt.Fprintf(out, "Config saved to %s/config.yaml\n", dir)
287  	}
288  	fmt.Fprintln(out)
289  	return nil
290  }
291  
// isLocalEndpoint reports whether the URL endpoint points at the local
// machine. The following hosts count as local:
//   - "localhost", in any letter case;
//   - any loopback IP: the whole 127.0.0.0/8 range and ::1 in any spelling,
//     including [0:0:0:0:0:0:0:1];
//   - an unspecified address: 0.0.0.0 or ::.
//
// Unparsable URLs and URLs without a host are not local. NeedsSetup and the
// gateway setup flow use this check to make the API key optional.
func isLocalEndpoint(endpoint string) bool {
	u, err := url.Parse(endpoint)
	if err != nil {
		return false
	}
	// Hostname strips the port and IPv6 brackets.
	host := strings.ToLower(u.Hostname())
	if host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
300  
// checkEndpointHealth probes a gateway's health route. It issues
// GET <endpoint>/health with a 5-second timeout. When apiKey is non-empty,
// it is sent in the X-API-Key header; otherwise no key header is sent.
//
// It returns nil only for an HTTP 200 response. Possible errors:
//   - a transport failure is returned as "unreachable: <cause>", wrapping the
//     underlying error so callers can inspect it with errors.Is/errors.As;
//   - any other status is returned as "status <code>";
//   - a request that cannot be built is returned as-is.
func checkEndpointHealth(endpoint, apiKey string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// TrimRight drops every trailing slash, so "http://host//" does not turn
	// into a request for "//health".
	base := strings.TrimRight(endpoint, "/")
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/health", nil)
	if err != nil {
		return err
	}
	if apiKey != "" {
		req.Header.Set("X-API-Key", apiKey)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("unreachable: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status %d", resp.StatusCode)
	}
	return nil
}