/ db / migrations.go
migrations.go
  1  // SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
  2  //
  3  // SPDX-License-Identifier: Apache-2.0
  4  
  5  package db
  6  
  7  import (
  8  	"context"
  9  	"database/sql"
 10  	_ "embed"
 11  	"fmt"
 12  )
 13  
 14  type migration struct {
 15  	upQuery   string
 16  	downQuery string
 17  	postHook  func(*sql.Tx) error
 18  }
 19  
// SQL text for the file-backed migrations. Each variable is filled at build
// time by its go:embed directive with the contents of the named file under
// the package's sql/ directory.
var (
	// migration1Up is the "up" SQL for migration 1 (add project IDs).
	//go:embed sql/1_add_project_ids.up.sql
	migration1Up string
	// migration1Down is the "down" SQL that reverts migration 1.
	//go:embed sql/1_add_project_ids.down.sql
	migration1Down string
	// migration2Up is the "up" SQL for migration 2 (swap project URL for ID).
	//go:embed sql/2_swap_project_url_for_id.up.sql
	migration2Up string
	// migration2Down is the "down" SQL that reverts migration 2.
	//go:embed sql/2_swap_project_url_for_id.down.sql
	migration2Down string
)
 30  
// migrations lists every schema migration in the order it is applied. The
// array index doubles as the schema version: runMigration records the index
// of the migration it applied in schema_migrations, and Migrate resumes from
// the recorded version plus one. Reordering existing entries would therefore
// change which migrations an already-migrated database runs.
var migrations = [...]migration{
	// 0 bootstraps version tracking by creating the schema_migrations table
	// and inserting its only row.
	0: {
		upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
		INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
		downQuery: `DROP TABLE schema_migrations;`,
	},
	// 1 runs the embedded add-project-IDs SQL, then the
	// generateAndInsertProjectIDs hook (defined elsewhere in the package).
	1: {
		upQuery:   migration1Up,
		downQuery: migration1Down,
		postHook:  generateAndInsertProjectIDs,
	},
	// 2 runs the embedded swap-project-URL-for-ID SQL; it has no hook.
	2: {
		upQuery:   migration2Up,
		downQuery: migration2Down,
	},
	// 3 has no SQL in either direction; it only runs the correctProjectIDs
	// hook (defined elsewhere in the package), so undoing it changes nothing
	// but the recorded version.
	3: {
		postHook: correctProjectIDs,
	},
}
 50  
 51  // Migrate runs all pending migrations
 52  func Migrate(db *sql.DB) error {
 53  	version := getSchemaVersion(db)
 54  	for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
 55  		if err := runMigration(db, nextMigration); err != nil {
 56  			return fmt.Errorf("migrations failed: %w", err)
 57  		}
 58  		if version := getSchemaVersion(db); version != nextMigration {
 59  			return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
 60  		}
 61  	}
 62  	return nil
 63  }
 64  
 65  // runMigration runs a single migration inside a transaction, updates the schema
 66  // version and commits the transaction if successful, and rolls back the
 67  // transaction if unsuccessful.
 68  func runMigration(db *sql.DB, migrationIdx int) (err error) {
 69  	current := migrations[migrationIdx]
 70  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
 71  	if err != nil {
 72  		return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
 73  	}
 74  	defer func() {
 75  		if err == nil {
 76  			err = tx.Commit()
 77  		}
 78  		if err != nil {
 79  			if rbErr := tx.Rollback(); rbErr != nil {
 80  				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
 81  			}
 82  		}
 83  	}()
 84  	if len(current.upQuery) > 0 {
 85  		if _, err := tx.Exec(current.upQuery); err != nil {
 86  			return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
 87  		}
 88  	}
 89  	if current.postHook != nil {
 90  		if err := current.postHook(tx); err != nil {
 91  			return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
 92  		}
 93  	}
 94  	return updateSchemaVersion(tx, migrationIdx)
 95  }
 96  
 97  // undoMigration rolls the single most recent migration back inside a
 98  // transaction, updates the schema version and commits the transaction if
 99  // successful, and rolls back the transaction if unsuccessful.
100  //
101  //lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
102  func undoMigration(db *sql.DB, migrationIdx int) (err error) {
103  	current := migrations[migrationIdx]
104  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
105  	if err != nil {
106  		return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
107  	}
108  	defer func() {
109  		if err == nil {
110  			err = tx.Commit()
111  		}
112  		if err != nil {
113  			if rbErr := tx.Rollback(); rbErr != nil {
114  				err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
115  			}
116  		}
117  	}()
118  	if len(current.downQuery) > 0 {
119  		if _, err := tx.Exec(current.downQuery); err != nil {
120  			return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
121  		}
122  	}
123  	return updateSchemaVersion(tx, migrationIdx-1)
124  }
125  
// getSchemaVersion reads the version stored in schema_migrations. Any failure
// to read it (for example because the table has not been created yet) is
// reported as -1, the version that precedes migration 0.
func getSchemaVersion(db *sql.DB) int {
	const query = `SELECT version FROM schema_migrations LIMIT 1;`
	var stored int
	if err := db.QueryRowContext(context.Background(), query).Scan(&stored); err != nil {
		return -1
	}
	return stored
}
135  
// updateSchemaVersion records version as the current schema version within
// tx. Negative versions are ignored and reported as success.
func updateSchemaVersion(tx *sql.Tx, version int) error {
	// A negative version means the schema_migrations table does not exist at
	// that point, so there is nothing to update.
	if version < 0 {
		return nil
	}
	const stmt = `UPDATE schema_migrations SET version = @version;`
	if _, err := tx.Exec(stmt, sql.Named("version", version)); err != nil {
		return err
	}
	return nil
}