223 lines
No EOL
6.1 KiB
Go
223 lines
No EOL
6.1 KiB
Go
package migration
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"sort"
|
|
"time"
|
|
)
|
|
|
|
// Migration is a single versioned schema change, as stored by Register.
type Migration struct {
	Version     int    // ordering key; Migrate applies lower versions first
	Description string // human-readable label printed while the migration runs

	// Up applies the change and Down reverts it. Both are handed the
	// *sql.DB itself rather than a transaction.
	Up, Down func(db *sql.DB) error
}
|
|
|
|
// Migrations is a collection of registered migrations
|
|
var Migrations = make(map[int]Migration)
|
|
|
|
// Register adds a migration to the list of available migrations
|
|
func Register(version int, description string, up, down func(db *sql.DB) error) {
|
|
if _, exists := Migrations[version]; exists {
|
|
panic(fmt.Sprintf("migration version %d already exists", version))
|
|
}
|
|
|
|
Migrations[version] = Migration{
|
|
Version: version,
|
|
Description: description,
|
|
Up: up,
|
|
Down: down,
|
|
}
|
|
}
|
|
|
|
// EnsureMigrationTable creates the schema_migrations bookkeeping table
// (version INTEGER PRIMARY KEY, applied_at TIMESTAMP NOT NULL) when it is
// missing. Because the statement uses IF NOT EXISTS, calling it again on a
// database that already has the table is expected to be a no-op.
func EnsureMigrationTable(db *sql.DB) error {
	const createTable = `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL
		)
	`
	if _, err := db.Exec(createTable); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// GetAppliedMigrations returns every version recorded in schema_migrations,
// in ascending order (the query sorts by version). An empty table yields a
// nil slice. If iteration fails, the versions read so far are returned
// together with the error from rows.Err. An error from closing the result set
// is printed to stdout rather than returned.
func GetAppliedMigrations(db *sql.DB) ([]int, error) {
	rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			fmt.Printf("Error closing rows: %v\n", closeErr)
		}
	}()

	var out []int
	for rows.Next() {
		var v int
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, scanErr
		}
		out = append(out, v)
	}
	return out, rows.Err()
}
|
|
|
|
// IsApplied reports whether schema_migrations contains a row for version,
// by counting matching rows. Any query or scan error is returned together
// with false.
func IsApplied(db *sql.DB, version int) (bool, error) {
	row := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version)

	var matches int
	if err := row.Scan(&matches); err != nil {
		return false, err
	}
	return matches > 0, nil
}
|
|
|
|
// MarkAsApplied inserts a schema_migrations row for version, stamping
// applied_at with the current local time from time.Now. If the table was
// created by EnsureMigrationTable, version is its primary key, so recording
// the same version twice is expected to fail.
//
// NOTE(review): the `?` placeholders suit drivers such as SQLite or MySQL;
// PostgreSQL drivers expect `$1`-style placeholders — confirm the target driver.
func MarkAsApplied(db *sql.DB, version int) error {
	appliedAt := time.Now()
	_, err := db.Exec("INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", version, appliedAt)
	return err
}
|
|
|
|
// RemoveApplied deletes the schema_migrations row for version, if any.
// Deleting a version that was never recorded is not treated as an error here;
// only the error from Exec is returned.
func RemoveApplied(db *sql.DB, version int) error {
	const deleteVersion = "DELETE FROM schema_migrations WHERE version = ?"
	if _, err := db.Exec(deleteVersion, version); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// Migrate runs pending migrations up to the latest version
|
|
func Migrate(db *sql.DB) error {
|
|
// Ensure migration table exists
|
|
if err := EnsureMigrationTable(db); err != nil {
|
|
return fmt.Errorf("failed to create migration table: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
applied, err := GetAppliedMigrations(db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Create a map of applied migrations for quick lookup
|
|
appliedMap := make(map[int]bool)
|
|
for _, version := range applied {
|
|
appliedMap[version] = true
|
|
}
|
|
|
|
// Get all migration versions and sort them
|
|
var versions []int
|
|
for version := range Migrations {
|
|
versions = append(versions, version)
|
|
}
|
|
sort.Ints(versions)
|
|
|
|
// Apply each pending migration
|
|
for _, version := range versions {
|
|
if !appliedMap[version] {
|
|
migration := Migrations[version]
|
|
fmt.Printf("Applying migration %d: %s...\n", version, migration.Description)
|
|
|
|
// Start transaction for the migration
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction for migration %d: %w", version, err)
|
|
}
|
|
|
|
// Apply the migration
|
|
if err := migration.Up(db); err != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
}
|
|
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
|
}
|
|
|
|
// Mark as applied
|
|
if _, err := tx.Exec(
|
|
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
|
version, time.Now(),
|
|
); err != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
}
|
|
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
|
}
|
|
|
|
// Commit the transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit migration %d: %w", version, err)
|
|
}
|
|
|
|
fmt.Printf("Migration %d applied successfully\n", version)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateDown rolls back migrations down to the specified version
|
|
// If version is -1, it will roll back all migrations
|
|
func MigrateDown(db *sql.DB, targetVersion int) error {
|
|
// Ensure migration table exists
|
|
if err := EnsureMigrationTable(db); err != nil {
|
|
return fmt.Errorf("failed to create migration table: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
applied, err := GetAppliedMigrations(db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Sort in descending order to roll back newest first
|
|
sort.Sort(sort.Reverse(sort.IntSlice(applied)))
|
|
|
|
// Roll back each migration until target version
|
|
for _, version := range applied {
|
|
if targetVersion == -1 || version > targetVersion {
|
|
migration, exists := Migrations[version]
|
|
if !exists {
|
|
return fmt.Errorf("migration %d is applied but not found in codebase", version)
|
|
}
|
|
|
|
fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description)
|
|
|
|
// Start transaction for the rollback
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction for rollback %d: %w", version, err)
|
|
}
|
|
|
|
// Apply the down migration
|
|
if err := migration.Down(db); err != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
}
|
|
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
|
}
|
|
|
|
// Remove from applied list
|
|
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
}
|
|
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
|
}
|
|
|
|
// Commit the transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit rollback %d: %w", version, err)
|
|
}
|
|
|
|
fmt.Printf("Migration %d rolled back successfully\n", version)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
} |