setup.go
1 package config 2 3 import ( 4 "bufio" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "net/http" 10 "net/url" 11 "os" 12 "strings" 13 "time" 14 15 "golang.org/x/term" 16 ) 17 18 const DefaultEndpoint = "https://api-dev.shannon.run" 19 20 const ( 21 hintCloud = "Get your API key at https://shannon.run" 22 hintLocal = "Running locally? See https://github.com/Kocoro-lab/Shannon for self-hosting docs." 23 ) 24 25 // NeedsSetup returns true if the config has no API key and the endpoint 26 // is not a local address (localhost/127.0.0.1 bypass auth). 27 // Ollama provider never needs gateway setup. 28 func NeedsSetup(cfg *Config) bool { 29 if cfg.Provider == "ollama" { 30 return cfg.Ollama.Model == "" // model required for ollama to be usable 31 } 32 if cfg.APIKey != "" { 33 return false 34 } 35 return !isLocalEndpoint(cfg.Endpoint) 36 } 37 38 // RunSetup runs the interactive setup flow, prompting the user for 39 // provider selection (Shannon Cloud or Ollama) and provider-specific config. 40 func RunSetup(cfg *Config, in io.Reader, out io.Writer) error { 41 reader := bufio.NewReader(in) 42 43 fmt.Fprintln(out, "Shannon CLI Setup") 44 fmt.Fprintln(out) 45 46 // Provider selection 47 fmt.Fprintln(out, "Choose your LLM provider:") 48 fmt.Fprintln(out, " 1) Shannon Cloud") 49 fmt.Fprintln(out, " 2) Local model (Ollama)") 50 fmt.Fprint(out, "Choice [1]: ") 51 choice, _ := reader.ReadString('\n') 52 choice = strings.TrimSpace(choice) 53 54 switch choice { 55 case "2": 56 if err := setupOllama(cfg, reader, out); err != nil { 57 return err 58 } 59 default: 60 if err := setupGateway(cfg, in, reader, out); err != nil { 61 return err 62 } 63 } 64 65 return saveSetup(cfg, out) 66 } 67 68 // setupGateway runs the gateway (Shannon Cloud) setup flow. 69 func setupGateway(cfg *Config, in io.Reader, reader *bufio.Reader, out io.Writer) error { 70 cfg.Provider = "gateway" 71 72 // Endpoint 73 defaultEP := cfg.Endpoint 74 if defaultEP == "" { 75 defaultEP = DefaultEndpoint 76 } 77 fmt.Fprintf(out, "API endpoint [%s]: ", defaultEP) 78 epInput, _ := reader.ReadString('\n') 79 epInput = strings.TrimSpace(epInput) 80 if epInput != "" { 81 cfg.Endpoint = epInput 82 } else { 83 cfg.Endpoint = defaultEP 84 } 85 86 // Contextual hint 87 if isLocalEndpoint(cfg.Endpoint) { 88 fmt.Fprintln(out, hintLocal) 89 } else { 90 fmt.Fprintln(out, hintCloud) 91 } 92 fmt.Fprintln(out) 93 94 // API key + health check with retry (max 3 attempts) 95 const maxAttempts = 3 96 for attempt := 0; attempt < maxAttempts; attempt++ { 97 // Prompt for key 98 if isLocalEndpoint(cfg.Endpoint) { 99 fmt.Fprint(out, "API key (optional for local, Enter to skip): ") 100 } else { 101 fmt.Fprint(out, "API key: ") 102 } 103 104 if f, ok := in.(*os.File); ok && term.IsTerminal(int(f.Fd())) { 105 keyBytes, err := term.ReadPassword(int(f.Fd())) 106 fmt.Fprintln(out) // newline after masked input 107 if err != nil { 108 fmt.Fprintf(out, "Error reading key: %v\n", err) 109 continue 110 } 111 cfg.APIKey = strings.TrimSpace(string(keyBytes)) 112 } else { 113 keyInput, _ := reader.ReadString('\n') 114 cfg.APIKey = strings.TrimSpace(keyInput) 115 } 116 117 // Health check 118 fmt.Fprint(out, "Testing connection... ") 119 if err := checkEndpointHealth(cfg.Endpoint, cfg.APIKey); err != nil { 120 fmt.Fprintf(out, "FAILED (%v)\n", err) 121 122 if attempt == maxAttempts-1 { 123 fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.") 124 break 125 } 126 fmt.Fprint(out, "Re-enter credentials? [Y/n]: ") 127 ans, _ := reader.ReadString('\n') 128 ans = strings.TrimSpace(strings.ToLower(ans)) 129 if ans == "n" || ans == "no" { 130 fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.") 131 break 132 } 133 continue 134 } 135 136 fmt.Fprintln(out, "OK") 137 break 138 } 139 140 return nil 141 } 142 143 // setupOllama runs the Ollama local model setup flow. 144 func setupOllama(cfg *Config, reader *bufio.Reader, out io.Writer) error { 145 cfg.Provider = "ollama" 146 147 // Endpoint 148 defaultEP := cfg.Ollama.Endpoint 149 if defaultEP == "" { 150 defaultEP = "http://localhost:11434" 151 } 152 fmt.Fprintf(out, "Ollama endpoint [%s]: ", defaultEP) 153 epInput, _ := reader.ReadString('\n') 154 epInput = strings.TrimSpace(epInput) 155 if epInput != "" { 156 cfg.Ollama.Endpoint = epInput 157 } else { 158 cfg.Ollama.Endpoint = defaultEP 159 } 160 161 // Health check 162 fmt.Fprint(out, "Checking Ollama... ") 163 if err := checkOllamaHealth(cfg.Ollama.Endpoint); err != nil { 164 fmt.Fprintf(out, "FAILED (%v)\n", err) 165 fmt.Fprintln(out, "Config saved anyway. Re-run 'shan --setup' to fix.") 166 return nil 167 } 168 fmt.Fprintln(out, "OK") 169 170 // Fetch and list models 171 models, err := fetchOllamaModels(cfg.Ollama.Endpoint) 172 if err != nil { 173 fmt.Fprintf(out, "Could not list models: %v\n", err) 174 fmt.Fprint(out, "Enter model name manually: ") 175 name, _ := reader.ReadString('\n') 176 cfg.Ollama.Model = strings.TrimSpace(name) 177 return nil 178 } 179 180 if len(models) == 0 { 181 fmt.Fprintln(out, "No models found. Pull a model first: ollama pull <model>") 182 fmt.Fprint(out, "Enter model name manually: ") 183 name, _ := reader.ReadString('\n') 184 cfg.Ollama.Model = strings.TrimSpace(name) 185 return nil 186 } 187 188 fmt.Fprintln(out, "Available models:") 189 for i, m := range models { 190 sizeGB := float64(m.Size) / 1e9 191 paramSize := m.Details.ParameterSize 192 if paramSize != "" { 193 fmt.Fprintf(out, " %d) %s (%s, %.1f GB)\n", i+1, m.Name, paramSize, sizeGB) 194 } else { 195 fmt.Fprintf(out, " %d) %s (%.1f GB)\n", i+1, m.Name, sizeGB) 196 } 197 } 198 fmt.Fprint(out, "Choose model [1]: ") 199 modelChoice, _ := reader.ReadString('\n') 200 modelChoice = strings.TrimSpace(modelChoice) 201 202 idx := 0 // default to first 203 if modelChoice != "" { 204 fmt.Sscanf(modelChoice, "%d", &idx) 205 idx-- // 1-based → 0-based 206 } 207 if idx < 0 || idx >= len(models) { 208 idx = 0 209 } 210 cfg.Ollama.Model = models[idx].Name 211 212 return nil 213 } 214 215 // ollamaModelInfo represents a model entry from the Ollama /api/tags response. 216 type ollamaModelInfo struct { 217 Name string `json:"name"` 218 Size int64 `json:"size"` 219 Details struct { 220 ParameterSize string `json:"parameter_size"` 221 } `json:"details"` 222 } 223 224 // checkOllamaHealth verifies that an Ollama server is reachable. 225 func checkOllamaHealth(endpoint string) error { 226 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 227 defer cancel() 228 229 base := strings.TrimSuffix(endpoint, "/") 230 req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/", nil) 231 if err != nil { 232 return err 233 } 234 235 resp, err := http.DefaultClient.Do(req) 236 if err != nil { 237 return fmt.Errorf("unreachable") 238 } 239 defer resp.Body.Close() 240 241 if resp.StatusCode != http.StatusOK { 242 return fmt.Errorf("status %d", resp.StatusCode) 243 } 244 return nil 245 } 246 247 // fetchOllamaModels retrieves the list of locally available models from Ollama. 248 func fetchOllamaModels(endpoint string) ([]ollamaModelInfo, error) { 249 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 250 defer cancel() 251 252 base := strings.TrimSuffix(endpoint, "/") 253 req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/api/tags", nil) 254 if err != nil { 255 return nil, err 256 } 257 258 resp, err := http.DefaultClient.Do(req) 259 if err != nil { 260 return nil, fmt.Errorf("unreachable") 261 } 262 defer resp.Body.Close() 263 264 if resp.StatusCode != http.StatusOK { 265 return nil, fmt.Errorf("status %d", resp.StatusCode) 266 } 267 268 var result struct { 269 Models []ollamaModelInfo `json:"models"` 270 } 271 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 272 return nil, fmt.Errorf("invalid response: %w", err) 273 } 274 return result.Models, nil 275 } 276 277 // saveSetup persists the config to disk. Tolerates Save errors (e.g. in tests 278 // where no real shannon dir exists) by printing a warning instead of failing. 279 func saveSetup(cfg *Config, out io.Writer) error { 280 if err := Save(cfg); err != nil { 281 // Save may fail in test environments (no real config dir) or due to 282 // permission issues. Print a warning so the user knows, but don't 283 // block setup — the in-memory config is still correct for this session. 284 fmt.Fprintf(out, "Warning: could not save config: %v\n", err) 285 } else if dir := ShannonDir(); dir != "" { 286 fmt.Fprintf(out, "Config saved to %s/config.yaml\n", dir) 287 } 288 fmt.Fprintln(out) 289 return nil 290 } 291 292 func isLocalEndpoint(endpoint string) bool { 293 u, err := url.Parse(endpoint) 294 if err != nil { 295 return false 296 } 297 host := strings.ToLower(u.Hostname()) 298 return host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "0.0.0.0" 299 } 300 301 func checkEndpointHealth(endpoint, apiKey string) error { 302 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 303 defer cancel() 304 305 base := strings.TrimSuffix(endpoint, "/") 306 req, err := http.NewRequestWithContext(ctx, http.MethodGet, base+"/health", nil) 307 if err != nil { 308 return err 309 } 310 if apiKey != "" { 311 req.Header.Set("X-API-Key", apiKey) 312 } 313 314 resp, err := http.DefaultClient.Do(req) 315 if err != nil { 316 return fmt.Errorf("unreachable") 317 } 318 defer resp.Body.Close() 319 320 if resp.StatusCode != http.StatusOK { 321 return fmt.Errorf("status %d", resp.StatusCode) 322 } 323 return nil 324 }