feat: db migrations, encrypted passwords
All checks were successful
ci/woodpecker/tag/release Pipeline was successful
All checks were successful
ci/woodpecker/tag/release Pipeline was successful
This commit is contained in:
parent
84e5feeb81
commit
ece8280358
8 changed files with 490 additions and 65 deletions
|
@ -1,14 +1,15 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/migration"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
)
|
||||
|
||||
|
@ -505,7 +506,10 @@ func (d *Database) GetUserByID(id int64) (*model.User, error) {
|
|||
// CreateUser creates a new user
|
||||
func (d *Database) CreateUser(username, password string) (*model.User, error) {
|
||||
// Hash password
|
||||
hashedPassword := hashPassword(password)
|
||||
hashedPassword, err := hashPassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Insert user
|
||||
query := `
|
||||
|
@ -555,9 +559,9 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Check password
|
||||
hashedPassword := hashPassword(password)
|
||||
if dbPassword != hashedPassword {
|
||||
// Check password with bcrypt
|
||||
err = bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password))
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
|
@ -569,73 +573,60 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err
|
|||
}
|
||||
|
||||
// Helper function to hash password
|
||||
func hashPassword(password string) string {
|
||||
// In a real implementation, use a proper password hashing library like bcrypt
|
||||
// This is a simplified version for demonstration
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(password))
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
func hashPassword(password string) (string, error) {
|
||||
// Use bcrypt for secure password hashing
|
||||
// The cost parameter is the computational cost, higher is more secure but slower
|
||||
// Recommended minimum is 12
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// Initialize database tables
|
||||
func initDatabase(db *sql.DB) error {
|
||||
// Create channels table
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
platform TEXT NOT NULL,
|
||||
platform_channel_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
channel_raw TEXT NOT NULL,
|
||||
UNIQUE(platform, platform_channel_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
// Ensure migration table exists
|
||||
if err := migration.EnsureMigrationTable(db); err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
|
||||
// Create channel_plugin table
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channel_plugin (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL,
|
||||
plugin_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
config TEXT NOT NULL DEFAULT '{}',
|
||||
UNIQUE(channel_id, plugin_id),
|
||||
FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
|
||||
// Get applied migrations
|
||||
applied, err := migration.GetAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
// Create users table
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Get all migration versions
|
||||
allMigrations := make([]int, 0, len(migration.Migrations))
|
||||
for version := range migration.Migrations {
|
||||
allMigrations = append(allMigrations, version)
|
||||
}
|
||||
|
||||
// Create default admin user if it doesn't exist
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Create a map of applied migrations for quick lookup
|
||||
appliedMap := make(map[int]bool)
|
||||
for _, version := range applied {
|
||||
appliedMap[version] = true
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
hashedPassword := hashPassword("admin")
|
||||
_, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Count pending migrations
|
||||
pendingCount := 0
|
||||
for _, version := range allMigrations {
|
||||
if !appliedMap[version] {
|
||||
pendingCount++
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Run migrations if needed
|
||||
if pendingCount > 0 {
|
||||
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
||||
if err := migration.Migrate(db); err != nil {
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
fmt.Println("Database migrations completed successfully.")
|
||||
} else {
|
||||
fmt.Println("Database schema is up to date.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
211
internal/migration/migration.go
Normal file
211
internal/migration/migration.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Migration represents a database migration
|
||||
type Migration struct {
|
||||
Version int
|
||||
Description string
|
||||
Up func(db *sql.DB) error
|
||||
Down func(db *sql.DB) error
|
||||
}
|
||||
|
||||
// Migrations is a collection of registered migrations
|
||||
var Migrations = make(map[int]Migration)
|
||||
|
||||
// Register adds a migration to the list of available migrations
|
||||
func Register(version int, description string, up, down func(db *sql.DB) error) {
|
||||
if _, exists := Migrations[version]; exists {
|
||||
panic(fmt.Sprintf("migration version %d already exists", version))
|
||||
}
|
||||
|
||||
Migrations[version] = Migration{
|
||||
Version: version,
|
||||
Description: description,
|
||||
Up: up,
|
||||
Down: down,
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureMigrationTable creates the migration table if it doesn't exist
|
||||
func EnsureMigrationTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TIMESTAMP NOT NULL
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAppliedMigrations returns a list of applied migration versions
|
||||
func GetAppliedMigrations(db *sql.DB) ([]int, error) {
|
||||
rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var versions []int
|
||||
for rows.Next() {
|
||||
var version int
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
versions = append(versions, version)
|
||||
}
|
||||
|
||||
return versions, rows.Err()
|
||||
}
|
||||
|
||||
// IsApplied checks if a migration version has been applied
|
||||
func IsApplied(db *sql.DB, version int) (bool, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// MarkAsApplied marks a migration as applied
|
||||
func MarkAsApplied(db *sql.DB, version int) error {
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
||||
version, time.Now(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveApplied removes a migration from the applied list
|
||||
func RemoveApplied(db *sql.DB, version int) error {
|
||||
_, err := db.Exec("DELETE FROM schema_migrations WHERE version = ?", version)
|
||||
return err
|
||||
}
|
||||
|
||||
// Migrate runs pending migrations up to the latest version
|
||||
func Migrate(db *sql.DB) error {
|
||||
// Ensure migration table exists
|
||||
if err := EnsureMigrationTable(db); err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
|
||||
// Get applied migrations
|
||||
applied, err := GetAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
// Create a map of applied migrations for quick lookup
|
||||
appliedMap := make(map[int]bool)
|
||||
for _, version := range applied {
|
||||
appliedMap[version] = true
|
||||
}
|
||||
|
||||
// Get all migration versions and sort them
|
||||
var versions []int
|
||||
for version := range Migrations {
|
||||
versions = append(versions, version)
|
||||
}
|
||||
sort.Ints(versions)
|
||||
|
||||
// Apply each pending migration
|
||||
for _, version := range versions {
|
||||
if !appliedMap[version] {
|
||||
migration := Migrations[version]
|
||||
fmt.Printf("Applying migration %d: %s...\n", version, migration.Description)
|
||||
|
||||
// Start transaction for the migration
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction for migration %d: %w", version, err)
|
||||
}
|
||||
|
||||
// Apply the migration
|
||||
if err := migration.Up(db); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
||||
}
|
||||
|
||||
// Mark as applied
|
||||
if _, err := tx.Exec(
|
||||
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
||||
version, time.Now(),
|
||||
); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit migration %d: %w", version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Migration %d applied successfully\n", version)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateDown rolls back migrations down to the specified version
|
||||
// If version is -1, it will roll back all migrations
|
||||
func MigrateDown(db *sql.DB, targetVersion int) error {
|
||||
// Ensure migration table exists
|
||||
if err := EnsureMigrationTable(db); err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
|
||||
// Get applied migrations
|
||||
applied, err := GetAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
// Sort in descending order to roll back newest first
|
||||
sort.Sort(sort.Reverse(sort.IntSlice(applied)))
|
||||
|
||||
// Roll back each migration until target version
|
||||
for _, version := range applied {
|
||||
if targetVersion == -1 || version > targetVersion {
|
||||
migration, exists := Migrations[version]
|
||||
if !exists {
|
||||
return fmt.Errorf("migration %d is applied but not found in codebase", version)
|
||||
}
|
||||
|
||||
fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description)
|
||||
|
||||
// Start transaction for the rollback
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction for rollback %d: %w", version, err)
|
||||
}
|
||||
|
||||
// Apply the down migration
|
||||
if err := migration.Down(db); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
||||
}
|
||||
|
||||
// Remove from applied list
|
||||
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit rollback %d: %w", version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Migration %d rolled back successfully\n", version)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
102
internal/migration/migrations.go
Normal file
102
internal/migration/migrations.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register migrations
|
||||
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
||||
}
|
||||
|
||||
// Initial schema creation with bcrypt passwords - version 1
|
||||
func migrateInitialSchemaUp(db *sql.DB) error {
|
||||
// Create channels table
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
platform TEXT NOT NULL,
|
||||
platform_channel_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
channel_raw TEXT NOT NULL,
|
||||
UNIQUE(platform, platform_channel_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create channel_plugin table
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channel_plugin (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL,
|
||||
plugin_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
config TEXT NOT NULL DEFAULT '{}',
|
||||
UNIQUE(channel_id, plugin_id),
|
||||
FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create users table with bcrypt passwords
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create default admin user with bcrypt password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte("admin"), 12)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if users table is empty before inserting
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO users (username, password) VALUES (?, ?)",
|
||||
"admin", string(hashedPassword),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateInitialSchemaDown(db *sql.DB) error {
|
||||
// Drop tables in reverse order of dependencies
|
||||
_, err := db.Exec(`DROP TABLE IF EXISTS channel_plugin`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`DROP TABLE IF EXISTS channels`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`DROP TABLE IF EXISTS users`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue