/ db / users.go
users.go
 1  // SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
 2  //
 3  // SPDX-License-Identifier: Apache-2.0
 4  
 5  package db
 6  
 7  import (
 8  	"database/sql"
 9  	"time"
10  )
11  
12  // DeleteUser deletes specific user from the database and returns an error if it
13  // fails
14  func DeleteUser(db *sql.DB, user string) error {
15  	mutex.Lock()
16  	defer mutex.Unlock()
17  	_, err := db.Exec("DELETE FROM users WHERE username = ?", user)
18  	return err
19  }
20  
21  // CreateUser creates a new user in the database and returns an error if it fails
22  func CreateUser(db *sql.DB, username, hash, salt string) error {
23  	mutex.Lock()
24  	defer mutex.Unlock()
25  	_, err := db.Exec("INSERT INTO users (username, hash, salt) VALUES (?, ?, ?)", username, hash, salt)
26  	return err
27  }
28  
// GetUser looks up username in the users table and returns its hash and salt
// columns, in that order.
//
// If no row matches, the error is sql.ErrNoRows and both strings are empty.
// Other query or scan errors are returned unchanged alongside whatever values
// Scan had already written. Unlike the write helpers in this file, it does not
// take the mutex.
func GetUser(db *sql.DB, username string) (string, string, error) {
	row := db.QueryRow("SELECT hash, salt FROM users WHERE username = ?", username)

	var h, s string
	if err := row.Scan(&h, &s); err != nil {
		return h, s, err
	}
	return h, s, nil
}
36  
// GetUsers returns the username of every row in the users table.
//
// The returned slice is nil when the table has no rows. If the query fails,
// a row cannot be scanned, or iteration over the result set stops because of
// an error, that error is returned together with a nil slice. A truncated
// list is never reported as a success.
func GetUsers(db *sql.DB) ([]string, error) {
	rows, err := db.Query("SELECT username FROM users")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []string
	for rows.Next() {
		var user string
		if err := rows.Scan(&user); err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	// rows.Next returns false both at the normal end of the result set and
	// when iteration fails (driver error, dropped connection, ...). Only
	// rows.Err distinguishes the two, so without this check a partial list
	// would be returned with a nil error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
58  
// GetSession looks up the session whose token equals session and returns the
// associated username together with its expiry time.
//
// The expires column is read as a string and parsed with time.RFC3339. If no
// row matches (sql.ErrNoRows), the scan fails, or the stored timestamp does
// not parse, the error is returned with an empty username and a zero
// time.Time. The expiry is returned as stored; this function does not compare
// it against the current time.
func GetSession(db *sql.DB, session string) (string, time.Time, error) {
	var (
		user      string
		rawExpiry string
	)
	row := db.QueryRow("SELECT username, expires FROM sessions WHERE token = ?", session)
	if err := row.Scan(&user, &rawExpiry); err != nil {
		return "", time.Time{}, err
	}

	when, err := time.Parse(time.RFC3339, rawExpiry)
	if err != nil {
		return "", time.Time{}, err
	}
	return user, when, nil
}
75  
76  // InvalidateSession invalidates a session by setting the expiration date to the
77  // provided time.
78  func InvalidateSession(db *sql.DB, session string, expiry time.Time) error {
79  	mutex.Lock()
80  	defer mutex.Unlock()
81  	_, err := db.Exec("UPDATE sessions SET expires = ? WHERE token = ?", expiry.Format(time.RFC3339), session)
82  	return err
83  }
84  
85  // CreateSession creates a new session in the database and returns an error if
86  // it fails
87  func CreateSession(db *sql.DB, username, token string, expiry time.Time) error {
88  	mutex.Lock()
89  	defer mutex.Unlock()
90  	_, err := db.Exec("INSERT INTO sessions (token, username, expires) VALUES (?, ?, ?)", token, username, expiry.Format(time.RFC3339))
91  	return err
92  }