butterrobot/internal/migration/migration.go
Felipe M. 72c6dd6982
All checks were successful
ci/woodpecker/tag/release Pipeline was successful
feat: remindme plugin
2025-04-22 11:29:39 +02:00

211 lines
5.7 KiB
Go

package migration
import (
"database/sql"
"fmt"
"sort"
"time"
)
// Migration describes a single schema change: the version it is registered
// under, a human-readable description, and the functions that apply and
// revert it.
type Migration struct {
	// Version is the key the migration is stored under in Migrations.
	Version int
	// Description is the text printed while the migration is applied or rolled back.
	Description string
	// Up applies the migration and Down reverts it; both receive the database handle.
	Up, Down func(db *sql.DB) error
}
// Migrations holds every registered migration, keyed by its version number.
// It starts out empty and is filled through Register.
var Migrations = map[int]Migration{}
// Register stores a migration in Migrations under the given version.
//
// It panics when a migration with the same version has already been
// registered, so duplicate version numbers surface immediately.
func Register(version int, description string, up, down func(db *sql.DB) error) {
	_, taken := Migrations[version]
	if taken {
		panic(fmt.Sprintf("migration version %d already exists", version))
	}

	entry := Migration{Version: version, Description: description, Up: up, Down: down}
	Migrations[version] = entry
}
// EnsureMigrationTable creates the schema_migrations bookkeeping table when it
// is not already present. The table has one row per applied migration: its
// version (primary key) and the time it was applied.
//
// It returns whatever error the CREATE TABLE statement produces.
func EnsureMigrationTable(db *sql.DB) error {
	const createTable = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP NOT NULL
)
`
	if _, err := db.Exec(createTable); err != nil {
		return err
	}
	return nil
}
// GetAppliedMigrations reads the versions recorded in schema_migrations and
// returns them in ascending order.
//
// A query or scan failure yields a nil slice and the error. An error reported
// by the row iterator after the loop is returned together with the versions
// collected so far.
func GetAppliedMigrations(db *sql.DB) ([]int, error) {
	const query = "SELECT version FROM schema_migrations ORDER BY version"

	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var applied []int
	for rows.Next() {
		var v int
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, scanErr
		}
		applied = append(applied, v)
	}

	// rows.Err reports any failure that ended the iteration early.
	return applied, rows.Err()
}
// IsApplied reports whether schema_migrations contains a row for the given
// version. It returns false together with the error when the count query or
// its scan fails.
func IsApplied(db *sql.DB, version int) (bool, error) {
	row := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version)

	var matches int
	if err := row.Scan(&matches); err != nil {
		return false, err
	}
	return matches > 0, nil
}
// MarkAsApplied records the given version in schema_migrations, stamping the
// row with the current time. It returns the error from the INSERT, if any.
func MarkAsApplied(db *sql.DB, version int) error {
	const insert = "INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)"

	appliedAt := time.Now()
	if _, err := db.Exec(insert, version, appliedAt); err != nil {
		return err
	}
	return nil
}
// RemoveApplied deletes the row for the given version from schema_migrations.
// It returns the error from the DELETE, if any.
func RemoveApplied(db *sql.DB, version int) error {
	const remove = "DELETE FROM schema_migrations WHERE version = ?"

	if _, err := db.Exec(remove, version); err != nil {
		return err
	}
	return nil
}
// Migrate applies every registered migration that is not yet recorded in
// schema_migrations, in ascending version order.
//
// Each migration's Up function receives the *sql.DB handle, so its statements
// run on pool connections and cannot take part in a transaction opened here.
// Up is therefore run first and the bookkeeping row is written afterwards with
// a single INSERT. No transaction is held open while Up runs: an idle
// transaction would pin a pool connection and make Up block forever when the
// pool is limited to one connection.
//
// Migrate stops at the first failure and returns a wrapped error; migrations
// applied before the failure stay applied. If Up succeeds but recording the
// version fails, the migration is not marked as applied. A registered
// migration whose Up function is nil is reported as an error instead of
// causing a panic.
func Migrate(db *sql.DB) error {
	if err := EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Set of already-applied versions for constant-time lookup.
	appliedSet := make(map[int]struct{}, len(applied))
	for _, version := range applied {
		appliedSet[version] = struct{}{}
	}

	// Map iteration order is random, so collect the versions and sort them.
	versions := make([]int, 0, len(Migrations))
	for version := range Migrations {
		versions = append(versions, version)
	}
	sort.Ints(versions)

	for _, version := range versions {
		if _, done := appliedSet[version]; done {
			continue
		}

		migration := Migrations[version]
		if migration.Up == nil {
			return fmt.Errorf("migration %d has no up function", version)
		}

		fmt.Printf("Applying migration %d: %s...\n", version, migration.Description)

		if err := migration.Up(db); err != nil {
			return fmt.Errorf("failed to apply migration %d: %w", version, err)
		}

		// Record the version only after Up has succeeded.
		if _, err := db.Exec(
			"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
			version, time.Now(),
		); err != nil {
			return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
		}

		fmt.Printf("Migration %d applied successfully\n", version)
	}

	return nil
}
// MigrateDown rolls back applied migrations, newest first, until only
// versions less than or equal to targetVersion remain. A targetVersion of -1
// rolls back every applied migration.
//
// Each migration's Down function receives the *sql.DB handle, so its
// statements run on pool connections and cannot take part in a transaction
// opened here. Down is therefore run first and the bookkeeping row is removed
// afterwards with a single DELETE. No transaction is held open while Down
// runs: an idle transaction would pin a pool connection and make Down block
// forever when the pool is limited to one connection.
//
// MigrateDown stops at the first failure and returns a wrapped error;
// rollbacks completed before the failure stay in effect. It returns an error
// when an applied version has no registered migration, or when the registered
// migration has a nil Down function (instead of panicking).
func MigrateDown(db *sql.DB, targetVersion int) error {
	if err := EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Newest first, so later migrations are undone before the ones they build on.
	sort.Sort(sort.Reverse(sort.IntSlice(applied)))

	for _, version := range applied {
		// The list is descending: once the target is reached, every
		// remaining version is at or below it as well.
		if targetVersion != -1 && version <= targetVersion {
			break
		}

		migration, exists := Migrations[version]
		if !exists {
			return fmt.Errorf("migration %d is applied but not found in codebase", version)
		}
		if migration.Down == nil {
			return fmt.Errorf("migration %d has no down function", version)
		}

		fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description)

		if err := migration.Down(db); err != nil {
			return fmt.Errorf("failed to roll back migration %d: %w", version, err)
		}

		// Drop the bookkeeping row only after Down has succeeded.
		if _, err := db.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
			return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
		}

		fmt.Printf("Migration %d rolled back successfully\n", version)
	}

	return nil
}