// Package migration keeps a registry of versioned schema migrations and
// applies or rolls them back against a database/sql handle, recording which
// versions have been applied in a schema_migrations table.
//
// NOTE(review): the bookkeeping queries use `?` placeholders. database/sql
// passes placeholder syntax through to the driver unchanged, so drivers that
// expect `$1`-style placeholders (e.g. PostgreSQL drivers) will reject them —
// confirm the target driver.
package migration

import (
	"database/sql"
	"fmt"
	"sort"
	"time"
)

// Migration represents a single database migration.
//
// Up and Down receive the *sql.DB itself, not a transaction, so whatever they
// execute is committed independently of the schema_migrations bookkeeping
// done by Migrate and MigrateDown.
type Migration struct {
	Version     int
	Description string
	Up          func(db *sql.DB) error
	Down        func(db *sql.DB) error
}

// Migrations is the registry of migrations added via Register, keyed by
// version. It is a plain map with no locking, so Register must not be called
// concurrently with itself or with Migrate/MigrateDown.
var Migrations = make(map[int]Migration)

// Register adds a migration to the registry.
//
// It panics if a migration with the same version is already registered.
// up or down may be nil; Migrate or MigrateDown respectively return an error
// when they reach a version whose function is missing.
func Register(version int, description string, up, down func(db *sql.DB) error) {
	if _, exists := Migrations[version]; exists {
		panic(fmt.Sprintf("migration version %d already exists", version))
	}
	Migrations[version] = Migration{
		Version:     version,
		Description: description,
		Up:          up,
		Down:        down,
	}
}

// EnsureMigrationTable creates the schema_migrations table if it does not
// already exist. The table holds one row per applied migration version.
func EnsureMigrationTable(db *sql.DB) error {
	_, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL
		)
	`)
	return err
}

// GetAppliedMigrations returns the versions recorded in schema_migrations,
// in ascending order.
func GetAppliedMigrations(db *sql.DB) ([]int, error) {
	rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var versions []int
	for rows.Next() {
		var version int
		if err := rows.Scan(&version); err != nil {
			return nil, err
		}
		versions = append(versions, version)
	}
	return versions, rows.Err()
}

// IsApplied reports whether version is recorded in schema_migrations.
func IsApplied(db *sql.DB, version int) (bool, error) {
	var count int
	err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// MarkAsApplied records version in schema_migrations with the current time.
// Inserting a version that is already recorded fails on the table's primary key.
func MarkAsApplied(db *sql.DB, version int) error {
	_, err := db.Exec(
		"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
		version, time.Now(),
	)
	return err
}

// RemoveApplied deletes version from schema_migrations.
func RemoveApplied(db *sql.DB, version int) error {
	_, err := db.Exec("DELETE FROM schema_migrations WHERE version = ?", version)
	return err
}

// Migrate applies, in ascending version order, every registered migration
// whose version is not yet recorded in schema_migrations.
//
// Each migration's Up runs directly on db, and its version is recorded only
// after Up succeeds. The two steps are not atomic: if recording fails, the
// changes made by Up remain in place. Migrate stops at the first error and
// returns it; migrations completed before that point stay applied. Progress
// messages are written to standard output.
func Migrate(db *sql.DB) error {
	if err := EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}
	appliedSet := make(map[int]bool, len(applied))
	for _, v := range applied {
		appliedSet[v] = true
	}

	// Map iteration order is random; sort so migrations run in version order.
	versions := make([]int, 0, len(Migrations))
	for v := range Migrations {
		versions = append(versions, v)
	}
	sort.Ints(versions)

	for _, version := range versions {
		if appliedSet[version] {
			continue
		}
		migration := Migrations[version]
		if migration.Up == nil {
			return fmt.Errorf("migration %d has no up function", version)
		}
		fmt.Printf("Applying migration %d: %s...\n", version, migration.Description)

		// Up runs on db, not in a transaction, so no *sql.Tx may be held open
		// while it runs: an open Tx pins a pooled connection, which deadlocks
		// when the pool is limited to one connection (db.SetMaxOpenConns(1)),
		// and it could never roll back Up's work anyway.
		if err := migration.Up(db); err != nil {
			return fmt.Errorf("failed to apply migration %d: %w", version, err)
		}
		if err := MarkAsApplied(db, version); err != nil {
			return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
		}
		fmt.Printf("Migration %d applied successfully\n", version)
	}
	return nil
}

// MigrateDown rolls back applied migrations, newest first, stopping once the
// next applied version is <= targetVersion. If targetVersion is -1, every
// applied migration is rolled back.
//
// It returns an error if an applied version has no registered migration or
// the migration has no Down function. Each Down runs directly on db, and its
// schema_migrations row is deleted only after Down succeeds; the two steps are
// not atomic. Progress messages are written to standard output.
func MigrateDown(db *sql.DB, targetVersion int) error {
	if err := EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}
	// Newest first.
	sort.Sort(sort.Reverse(sort.IntSlice(applied)))

	for _, version := range applied {
		// applied is descending, so once a version is at or below the target,
		// every remaining one is too.
		if targetVersion != -1 && version <= targetVersion {
			break
		}
		migration, exists := Migrations[version]
		if !exists {
			return fmt.Errorf("migration %d is applied but not found in codebase", version)
		}
		if migration.Down == nil {
			return fmt.Errorf("migration %d has no down function", version)
		}
		fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description)

		// As in Migrate: Down runs on db, so no Tx is held open around it.
		if err := migration.Down(db); err != nil {
			return fmt.Errorf("failed to roll back migration %d: %w", version, err)
		}
		if err := RemoveApplied(db, version); err != nil {
			return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
		}
		fmt.Printf("Migration %d rolled back successfully\n", version)
	}
	return nil
}