migrations.go
1 // SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com> 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package db 6 7 import ( 8 "context" 9 "database/sql" 10 _ "embed" 11 "fmt" 12 ) 13 14 type migration struct { 15 upQuery string 16 downQuery string 17 postHook func(*sql.Tx) error 18 } 19 20 var ( 21 //go:embed sql/1_add_project_ids.up.sql 22 migration1Up string 23 //go:embed sql/1_add_project_ids.down.sql 24 migration1Down string 25 //go:embed sql/2_swap_project_url_for_id.up.sql 26 migration2Up string 27 //go:embed sql/2_swap_project_url_for_id.down.sql 28 migration2Down string 29 ) 30 31 var migrations = [...]migration{ 32 0: { 33 upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool); 34 INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`, 35 downQuery: `DROP TABLE schema_migrations;`, 36 }, 37 1: { 38 upQuery: migration1Up, 39 downQuery: migration1Down, 40 postHook: generateAndInsertProjectIDs, 41 }, 42 2: { 43 upQuery: migration2Up, 44 downQuery: migration2Down, 45 }, 46 3: { 47 postHook: correctProjectIDs, 48 }, 49 } 50 51 // Migrate runs all pending migrations 52 func Migrate(db *sql.DB) error { 53 version := getSchemaVersion(db) 54 for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ { 55 if err := runMigration(db, nextMigration); err != nil { 56 return fmt.Errorf("migrations failed: %w", err) 57 } 58 if version := getSchemaVersion(db); version != nextMigration { 59 return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version) 60 } 61 } 62 return nil 63 } 64 65 // runMigration runs a single migration inside a transaction, updates the schema 66 // version and commits the transaction if successful, and rolls back the 67 // transaction if unsuccessful. 68 func runMigration(db *sql.DB, migrationIdx int) (err error) { 69 current := migrations[migrationIdx] 70 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 71 if err != nil { 72 return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err) 73 } 74 defer func() { 75 if err == nil { 76 err = tx.Commit() 77 } 78 if err != nil { 79 if rbErr := tx.Rollback(); rbErr != nil { 80 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) 81 } 82 } 83 }() 84 if len(current.upQuery) > 0 { 85 if _, err := tx.Exec(current.upQuery); err != nil { 86 return fmt.Errorf("failed running migration %d: %w", migrationIdx, err) 87 } 88 } 89 if current.postHook != nil { 90 if err := current.postHook(tx); err != nil { 91 return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err) 92 } 93 } 94 return updateSchemaVersion(tx, migrationIdx) 95 } 96 97 // undoMigration rolls the single most recent migration back inside a 98 // transaction, updates the schema version and commits the transaction if 99 // successful, and rolls back the transaction if unsuccessful. 100 // 101 //lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34) 102 func undoMigration(db *sql.DB, migrationIdx int) (err error) { 103 current := migrations[migrationIdx] 104 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 105 if err != nil { 106 return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err) 107 } 108 defer func() { 109 if err == nil { 110 err = tx.Commit() 111 } 112 if err != nil { 113 if rbErr := tx.Rollback(); rbErr != nil { 114 err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) 115 } 116 } 117 }() 118 if len(current.downQuery) > 0 { 119 if _, err := tx.Exec(current.downQuery); err != nil { 120 return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err) 121 } 122 } 123 return updateSchemaVersion(tx, migrationIdx-1) 124 } 125 126 // getSchemaVersion returns the schema version from the database 127 func getSchemaVersion(db *sql.DB) int { 128 row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`) 129 var version int 130 if err := row.Scan(&version); err != nil { 131 version = -1 132 } 133 return version 134 } 135 136 // updateSchemaVersion sets the version to the provided int 137 func updateSchemaVersion(tx *sql.Tx, version int) error { 138 if version < 0 { 139 // Do not try to use the schema_migrations table in a schema version where it doesn't exist 140 return nil 141 } 142 _, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version)) 143 return err 144 }