/ ws / ws.go
ws.go
  1  // SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2  //
  3  // SPDX-License-Identifier: Apache-2.0
  4  
  5  package ws
  6  
  7  import (
  8  	"database/sql"
  9  	"embed"
 10  	"fmt"
 11  	"io"
 12  	"net/http"
 13  	"net/url"
 14  	"strings"
 15  	"sync"
 16  	"text/template"
 17  	"time"
 18  
 19  	"git.sr.ht/~amolith/willow/project"
 20  	"git.sr.ht/~amolith/willow/users"
 21  	"github.com/microcosm-cc/bluemonday"
 22  )
 23  
// Handler bundles the shared dependencies used by the HTTP handlers in this
// package. Its methods use value receivers; every field is a pointer, so
// copies of Handler share the same underlying state.
type Handler struct {
	// DbConn is the database handle passed to the project and users packages.
	DbConn        *sql.DB
	// Req is not referenced by the handlers in this file; presumably it is
	// used elsewhere to request data from a background worker — TODO confirm.
	Req           *chan struct{}
	// ManualRefresh is handed to project.Track when a project is tracked.
	ManualRefresh *chan struct{}
	// Res is not referenced in this file; presumably the reply channel that
	// pairs with Req — TODO confirm.
	Res           *chan []project.Project
	// Mu is passed to project.GetProjectsWithReleases, project.GetReleases,
	// project.Untrack and project.Track.
	Mu            *sync.Mutex
	// Version is dereferenced into the data given to the page templates.
	Version       *string
}
 32  
// fs holds the embedded "static" directory. It supplies both the page
// templates parsed by the handlers and the files served by StaticHandler.
//
//go:embed static
var fs embed.FS

// A more permissive user-generated-content policy, currently unused:
// bmUGC    = bluemonday.UGCPolicy()

// bmStrict sanitises most user-supplied query and form values in this file.
// bluemonday's StrictPolicy strips all HTML elements and attributes.
var bmStrict = bluemonday.StrictPolicy()
 38  
 39  func (h Handler) RootHandler(w http.ResponseWriter, r *http.Request) {
 40  	if !h.isAuthorised(r) {
 41  		http.Redirect(w, r, "/login", http.StatusSeeOther)
 42  		return
 43  	}
 44  	projectsWithReleases, err := project.GetProjectsWithReleases(h.DbConn, h.Mu)
 45  	if err != nil {
 46  		fmt.Println(err)
 47  		w.WriteHeader(http.StatusInternalServerError)
 48  		_, err := w.Write([]byte("Internal Server Error"))
 49  		if err != nil {
 50  			fmt.Println(err)
 51  		}
 52  		return
 53  	}
 54  
 55  	data := struct {
 56  		Version     string
 57  		Projects    []project.Project
 58  		IsDashboard bool
 59  	}{
 60  		Version:     *h.Version,
 61  		Projects:    projectsWithReleases,
 62  		IsDashboard: true,
 63  	}
 64  
 65  	tmpl := template.Must(template.ParseFS(fs, "static/dashboard.html.tmpl", "static/head.html.tmpl", "static/header.html.tmpl", "static/footer.html.tmpl"))
 66  	if err := tmpl.Execute(w, data); err != nil {
 67  		fmt.Println(err)
 68  	}
 69  }
 70  
 71  func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
 72  	if !h.isAuthorised(r) {
 73  		http.Redirect(w, r, "/login", http.StatusSeeOther)
 74  		return
 75  	}
 76  	params := r.URL.Query()
 77  	action := bmStrict.Sanitize(params.Get("action"))
 78  	if r.Method == http.MethodGet {
 79  		if action == "" {
 80  			data := struct{ Version string }{Version: *h.Version}
 81  			tmpl := template.Must(template.ParseFS(fs, "static/new.html.tmpl", "static/head.html.tmpl", "static/header.html.tmpl", "static/footer.html.tmpl"))
 82  			if err := tmpl.Execute(w, data); err != nil {
 83  				fmt.Println(err)
 84  			}
 85  		} else if action != "delete" {
 86  			submittedURL := bmStrict.Sanitize(params.Get("url"))
 87  			if submittedURL == "" {
 88  				w.WriteHeader(http.StatusBadRequest)
 89  				_, err := w.Write([]byte("No URL provided"))
 90  				if err != nil {
 91  					fmt.Println(err)
 92  				}
 93  				return
 94  			}
 95  
 96  			forge := bmStrict.Sanitize(params.Get("forge"))
 97  			if forge == "" {
 98  				w.WriteHeader(http.StatusBadRequest)
 99  				_, err := w.Write([]byte("No forge provided"))
100  				if err != nil {
101  					fmt.Println(err)
102  				}
103  				return
104  			}
105  
106  			name := bmStrict.Sanitize(params.Get("name"))
107  			if name == "" {
108  				w.WriteHeader(http.StatusBadRequest)
109  				_, err := w.Write([]byte("No name provided"))
110  				if err != nil {
111  					fmt.Println(err)
112  				}
113  				return
114  			}
115  
116  			proj := project.Project{
117  				ID:    project.GenProjectID(submittedURL, name, forge),
118  				URL:   submittedURL,
119  				Name:  name,
120  				Forge: forge,
121  			}
122  
123  			proj, err := project.GetProject(h.DbConn, proj)
124  			if err != nil && err != sql.ErrNoRows {
125  				w.WriteHeader(http.StatusBadRequest)
126  				_, err := w.Write([]byte(fmt.Sprintf("Error getting project: %s", err)))
127  				if err != nil {
128  					fmt.Println(err)
129  				}
130  				return
131  			}
132  
133  			proj, err = project.GetReleases(h.DbConn, h.Mu, proj)
134  			if err != nil {
135  				w.WriteHeader(http.StatusBadRequest)
136  				_, err := w.Write([]byte(fmt.Sprintf("Error getting releases: %s", err)))
137  				if err != nil {
138  					fmt.Println(err)
139  				}
140  				return
141  			}
142  
143  			data := struct {
144  				Version string
145  				Project project.Project
146  			}{
147  				Version: *h.Version,
148  				Project: proj,
149  			}
150  
151  			tmpl := template.Must(template.ParseFS(fs, "static/select-release.html.tmpl", "static/head.html.tmpl", "static/header.html.tmpl", "static/footer.html.tmpl"))
152  			if err := tmpl.Execute(w, data); err != nil {
153  				fmt.Println(err)
154  			}
155  		} else if action == "delete" {
156  			submittedID := params.Get("id")
157  			if submittedID == "" {
158  				w.WriteHeader(http.StatusBadRequest)
159  				_, err := w.Write([]byte("No URL provided"))
160  				if err != nil {
161  					fmt.Println(err)
162  				}
163  				return
164  			}
165  
166  			project.Untrack(h.DbConn, h.Mu, submittedID)
167  			http.Redirect(w, r, "/", http.StatusSeeOther)
168  		}
169  	}
170  
171  	if r.Method == http.MethodPost {
172  		err := r.ParseForm()
173  		if err != nil {
174  			fmt.Println(err)
175  		}
176  		idValue := bmStrict.Sanitize(r.FormValue("id"))
177  		nameValue := bmStrict.Sanitize(r.FormValue("name"))
178  		urlValue := bmStrict.Sanitize(r.FormValue("url"))
179  		forgeValue := bmStrict.Sanitize(r.FormValue("forge"))
180  		releaseValue := bmStrict.Sanitize(r.FormValue("release"))
181  
182  		// If releaseValue is not empty, we're updating an existing project
183  		if idValue != "" && nameValue != "" && urlValue != "" && forgeValue != "" && releaseValue != "" {
184  			project.Track(h.DbConn, h.Mu, h.ManualRefresh, nameValue, urlValue, forgeValue, releaseValue)
185  			http.Redirect(w, r, "/", http.StatusSeeOther)
186  			return
187  		}
188  
189  		// If releaseValue is empty, we're creating a new project
190  		if idValue == "" && nameValue != "" && urlValue != "" && forgeValue != "" && releaseValue == "" {
191  			http.Redirect(w, r, "/new?action=yoink&name="+url.QueryEscape(nameValue)+"&url="+url.QueryEscape(urlValue)+"&forge="+url.QueryEscape(forgeValue), http.StatusSeeOther)
192  			return
193  		}
194  
195  		w.WriteHeader(http.StatusBadRequest)
196  		_, err = w.Write([]byte("No data provided"))
197  		if err != nil {
198  			fmt.Println(err)
199  		}
200  	}
201  }
202  
203  func (h Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
204  	if r.Method == http.MethodGet {
205  		if h.isAuthorised(r) {
206  			http.Redirect(w, r, "/", http.StatusSeeOther)
207  			return
208  		}
209  
210  		data := struct {
211  			Version string
212  		}{
213  			Version: *h.Version,
214  		}
215  		tmpl := template.Must(template.ParseFS(fs, "static/login.html.tmpl", "static/head.html.tmpl", "static/footer.html.tmpl"))
216  		if err := tmpl.Execute(w, data); err != nil {
217  			fmt.Println(err)
218  		}
219  	}
220  
221  	if r.Method == http.MethodPost {
222  		err := r.ParseForm()
223  		if err != nil {
224  			fmt.Println(err)
225  		}
226  		username := bmStrict.Sanitize(r.FormValue("username"))
227  		password := bmStrict.Sanitize(r.FormValue("password"))
228  
229  		if username == "" || password == "" {
230  			w.WriteHeader(http.StatusBadRequest)
231  			_, err := w.Write([]byte("No data provided"))
232  			if err != nil {
233  				fmt.Println(err)
234  			}
235  			return
236  		}
237  
238  		authorised, err := users.UserAuthorised(h.DbConn, username, password)
239  		if err != nil {
240  			w.WriteHeader(http.StatusBadRequest)
241  			_, err := w.Write([]byte(fmt.Sprintf("Error logging in: %s", err)))
242  			if err != nil {
243  				fmt.Println(err)
244  			}
245  			return
246  		}
247  
248  		if !authorised {
249  			w.WriteHeader(http.StatusUnauthorized)
250  			_, err := w.Write([]byte("Incorrect username or password"))
251  			if err != nil {
252  				fmt.Println(err)
253  			}
254  			return
255  		}
256  
257  		session, expiry, err := users.CreateSession(h.DbConn, username)
258  		if err != nil {
259  			w.WriteHeader(http.StatusBadRequest)
260  			_, err := w.Write([]byte(fmt.Sprintf("Error creating session: %s", err)))
261  			if err != nil {
262  				fmt.Println(err)
263  			}
264  			return
265  		}
266  
267  		maxAge := int(time.Until(expiry))
268  
269  		cookie := http.Cookie{
270  			Name:     "id",
271  			Value:    session,
272  			MaxAge:   maxAge,
273  			HttpOnly: true,
274  			SameSite: http.SameSiteStrictMode,
275  			Secure:   true,
276  		}
277  
278  		http.SetCookie(w, &cookie)
279  		http.Redirect(w, r, "/", http.StatusSeeOther)
280  	}
281  }
282  
283  func (h Handler) LogoutHandler(w http.ResponseWriter, r *http.Request) {
284  	cookie, err := r.Cookie("id")
285  	if err != nil {
286  		fmt.Println(err)
287  	}
288  
289  	err = users.InvalidateSession(h.DbConn, cookie.Value)
290  	if err != nil {
291  		fmt.Println(err)
292  		_, err = w.Write([]byte(fmt.Sprintf("Error logging out: %s", err)))
293  		if err != nil {
294  			fmt.Println(err)
295  		}
296  		return
297  	}
298  	cookie.MaxAge = -1
299  	http.SetCookie(w, cookie)
300  	http.Redirect(w, r, "/login", http.StatusSeeOther)
301  }
302  
303  // isAuthorised makes a database request to the sessions table to see if the
304  // user has a valid session cookie.
305  func (h Handler) isAuthorised(r *http.Request) bool {
306  	cookie, err := r.Cookie("id")
307  	if err != nil {
308  		return false
309  	}
310  
311  	authorised, err := users.SessionAuthorised(h.DbConn, cookie.Value)
312  	if err != nil {
313  		fmt.Println("Error checking session:", err)
314  		return false
315  	}
316  
317  	return authorised
318  }
319  
320  func StaticHandler(writer http.ResponseWriter, request *http.Request) {
321  	resource := strings.TrimPrefix(request.URL.Path, "/")
322  	if strings.HasSuffix(resource, ".css") {
323  		writer.Header().Set("Content-Type", "text/css")
324  	} else if strings.HasSuffix(resource, ".js") {
325  		writer.Header().Set("Content-Type", "text/javascript")
326  	}
327  	home, err := fs.ReadFile(resource)
328  	if err != nil {
329  		fmt.Println(err)
330  	}
331  	if _, err = io.Writer.Write(writer, home); err != nil {
332  		fmt.Println(err)
333  	}
334  }