/ cmd / willow.go
willow.go
  1  // SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2  //
  3  // SPDX-License-Identifier: Apache-2.0
  4  
  5  package main
  6  
import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"git.sr.ht/~amolith/willow/project"
	"git.sr.ht/~amolith/willow/ws"

	"github.com/BurntSushi/toml"
	flag "github.com/spf13/pflag"
)
 23  
 24  type (
 25  	Config struct {
 26  		Server server
 27  		DBConn string
 28  		// TODO: Make cache location configurable
 29  		// CacheLocation string
 30  		FetchInterval int
 31  	}
 32  
 33  	server struct {
 34  		Listen string
 35  	}
 36  )
 37  
 38  var (
 39  	flagConfig          = flag.StringP("config", "c", "config.toml", "Path to config file")
 40  	flagAddUser         = flag.StringP("add", "a", "", "Username of account to add")
 41  	flagDeleteUser      = flag.StringP("deleteuser", "d", "", "Username of account to delete")
 42  	flagCheckAuthorised = flag.StringP("validatecredentials", "V", "", "Username of account to check")
 43  	flagListUsers       = flag.BoolP("listusers", "l", false, "List all users")
 44  	flagShowVersion     = flag.BoolP("version", "v", false, "Print Willow's version")
 45  	version             = ""
 46  	config              Config
 47  	req                 = make(chan struct{})
 48  	res                 = make(chan []project.Project)
 49  	manualRefresh       = make(chan struct{})
 50  )
 51  
 52  func main() {
 53  	flag.Parse()
 54  
 55  	if *flagShowVersion {
 56  		fmt.Println(version)
 57  		os.Exit(0)
 58  	}
 59  
 60  	err := checkConfig()
 61  	if err != nil {
 62  		log.Fatalln(err)
 63  	}
 64  
 65  	fmt.Println("Opening database at", config.DBConn)
 66  
 67  	dbConn, err := db.Open(config.DBConn)
 68  	if err != nil {
 69  		fmt.Println("Error opening database:", err)
 70  		os.Exit(1)
 71  	}
 72  
 73  	fmt.Println("Checking whether database needs initialising")
 74  	err = db.InitialiseDatabase(dbConn)
 75  	if err != nil {
 76  		fmt.Println("Error initialising database:", err)
 77  		os.Exit(1)
 78  	}
 79  	fmt.Println("Checking whether there are pending migrations")
 80  	err = db.Migrate(dbConn)
 81  	if err != nil {
 82  		fmt.Println("Error migrating database schema:", err)
 83  		os.Exit(1)
 84  	}
 85  
 86  	if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
 87  		createUser(dbConn, *flagAddUser)
 88  		os.Exit(0)
 89  	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) > 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
 90  		deleteUser(dbConn, *flagDeleteUser)
 91  		os.Exit(0)
 92  	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && *flagListUsers && len(*flagCheckAuthorised) == 0 {
 93  		listUsers(dbConn)
 94  		os.Exit(0)
 95  	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) > 0 {
 96  		checkAuthorised(dbConn, *flagCheckAuthorised)
 97  		os.Exit(0)
 98  	}
 99  
100  	mu := sync.Mutex{}
101  
102  	fmt.Println("Starting refresh loop")
103  	go project.RefreshLoop(dbConn, &mu, config.FetchInterval, &manualRefresh, &req, &res)
104  
105  	wsHandler := ws.Handler{
106  		DbConn:        dbConn,
107  		Req:           &req,
108  		Res:           &res,
109  		ManualRefresh: &manualRefresh,
110  		Mu:            &mu,
111  		Version:       &version,
112  	}
113  
114  	mux := http.NewServeMux()
115  	mux.HandleFunc("/static/", ws.StaticHandler)
116  	mux.HandleFunc("/new", wsHandler.NewHandler)
117  	mux.HandleFunc("/login", wsHandler.LoginHandler)
118  	mux.HandleFunc("/logout", wsHandler.LogoutHandler)
119  	mux.HandleFunc("/", wsHandler.RootHandler)
120  
121  	httpServer := &http.Server{
122  		Addr:    config.Server.Listen,
123  		Handler: mux,
124  	}
125  
126  	fmt.Println("Starting web server on", config.Server.Listen)
127  	if err := httpServer.ListenAndServe(); errors.Is(err, http.ErrServerClosed) {
128  		fmt.Println("Web server closed")
129  		os.Exit(0)
130  	} else {
131  		fmt.Println(err)
132  		os.Exit(1)
133  	}
134  }
135  
136  func checkConfig() error {
137  	defaultDBConn := "willow.sqlite"
138  	defaultFetchInterval := 3600
139  	defaultListen := "127.0.0.1:1313"
140  
141  	defaultConfig := fmt.Sprintf(`# Path to SQLite database
142  DBConn = "%s"
143  # How often to fetch new releases in seconds
144  ## Minimum is %ds to avoid rate limits and unintentional abuse
145  FetchInterval = %d
146  
147  [Server]
148  # Address to listen on
149  Listen = "%s"`, defaultDBConn, defaultFetchInterval, defaultFetchInterval, defaultListen)
150  
151  	file, err := os.Open(*flagConfig)
152  	if err != nil {
153  		if os.IsNotExist(err) {
154  			file, err = os.Create(*flagConfig)
155  			if err != nil {
156  				return err
157  			}
158  			defer file.Close()
159  
160  			_, err = file.WriteString(defaultConfig)
161  			if err != nil {
162  				return err
163  			}
164  
165  			fmt.Println("Config file created at", *flagConfig)
166  			fmt.Println("Please edit it and restart the server")
167  			os.Exit(0)
168  		} else {
169  			return err
170  		}
171  	}
172  	defer file.Close()
173  
174  	_, err = toml.DecodeFile(*flagConfig, &config)
175  	if err != nil {
176  		return err
177  	}
178  
179  	if config.FetchInterval < defaultFetchInterval {
180  		fmt.Println("Fetch interval is set to", strconv.Itoa(config.FetchInterval), "seconds, but the minimum is", defaultFetchInterval, "seconds, using", strconv.Itoa(defaultFetchInterval)+"s")
181  		config.FetchInterval = defaultFetchInterval
182  	}
183  
184  	if config.Server.Listen == "" {
185  		fmt.Println("No listen address specified, using", defaultListen)
186  		config.Server.Listen = defaultListen
187  	}
188  
189  	if config.DBConn == "" {
190  		fmt.Println("No SQLite path specified, using \"" + defaultDBConn + "\"")
191  		config.DBConn = defaultDBConn
192  	}
193  
194  	return nil
195  }