package db

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"

	"golang.org/x/crypto/bcrypt"
	_ "modernc.org/sqlite"

	"git.nakama.town/fmartingr/butterrobot/internal/migration"
	"git.nakama.town/fmartingr/butterrobot/internal/model"
)

var (
	// ErrNotFound is returned when a record is not found.
	ErrNotFound = errors.New("record not found")
	// ErrDuplicated is returned when a record already exists.
	ErrDuplicated = errors.New("record already exists")
	// ErrInvalidCredentials is returned by CheckCredentials when the supplied
	// password does not match the stored hash.
	ErrInvalidCredentials = errors.New("invalid credentials")
)

// bcryptCost is the work factor used when hashing passwords. It is
// deliberately above bcrypt.DefaultCost (10): higher costs are slower to
// brute-force, but every login pays the same price to verify.
const bcryptCost = 12

// Database handles database operations on top of a SQLite connection pool.
type Database struct {
	db *sql.DB
}

// rowScanner is the method shared by *sql.Row and *sql.Rows, so single-row
// and multi-row queries can use the same decoding helpers.
type rowScanner interface {
	Scan(dest ...interface{}) error
}

// New opens the SQLite database at dbPath and applies any pending migrations.
// On failure the underlying connection pool is closed before returning.
func New(dbPath string) (*Database, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, err
	}

	if err := initDatabase(db); err != nil {
		// Don't leak the pool when the schema can't be initialised.
		_ = db.Close()
		return nil, err
	}

	return &Database{db: db}, nil
}

// Close closes the database connection.
func (d *Database) Close() error {
	return d.db.Close()
}

// withTx runs fn inside a transaction. It commits when fn returns nil and
// rolls back otherwise. fn's error is returned unwrapped, so sentinels such as
// ErrDuplicated still compare equal for callers.
func (d *Database) withTx(fn func(tx *sql.Tx) error) error {
	tx, err := d.db.Begin()
	if err != nil {
		return err
	}
	// Rollback is a no-op (returns sql.ErrTxDone) after a successful Commit,
	// and it also releases the transaction if fn panics.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}

// scanChannel decodes one (id, platform, platform_channel_id, enabled,
// channel_raw) row into a Channel with an empty Plugins map. channel_raw is
// parsed as JSON. sql.ErrNoRows is translated to ErrNotFound.
func scanChannel(s rowScanner) (*model.Channel, error) {
	var (
		id                int64
		platform          string
		platformChannelID string
		enabled           bool
		channelRawJSON    string
	)
	if err := s.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}

	var channelRaw map[string]interface{}
	if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
		return nil, err
	}

	return &model.Channel{
		ID:                id,
		Platform:          platform,
		PlatformChannelID: platformChannelID,
		Enabled:           enabled,
		ChannelRaw:        channelRaw,
		Plugins:           make(map[string]*model.ChannelPlugin),
	}, nil
}

// loadPlugins fills channel.Plugins, keyed by PluginID, from the
// channel_plugin table. A channel with no plugins is not an error.
func (d *Database) loadPlugins(channel *model.Channel) error {
	plugins, err := d.GetChannelPlugins(channel.ID)
	if err != nil && !errors.Is(err, ErrNotFound) {
		return err
	}
	for _, plugin := range plugins {
		channel.Plugins[plugin.PluginID] = plugin
	}
	return nil
}

// GetChannelByID retrieves a channel and its plugins by ID.
// It returns ErrNotFound if no channel has that ID.
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
	query := `
		SELECT id, platform, platform_channel_id, enabled, channel_raw
		FROM channels
		WHERE id = ?
	`
	channel, err := scanChannel(d.db.QueryRow(query, id))
	if err != nil {
		return nil, err
	}
	if err := d.loadPlugins(channel); err != nil {
		return nil, err
	}
	return channel, nil
}

// GetChannelByPlatform retrieves a channel and its plugins by platform and
// platform channel ID. It returns ErrNotFound if there is no match.
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
	query := `
		SELECT id, platform, platform_channel_id, enabled, channel_raw
		FROM channels
		WHERE platform = ? AND platform_channel_id = ?
	`
	channel, err := scanChannel(d.db.QueryRow(query, platform, platformChannelID))
	if err != nil {
		return nil, err
	}
	if err := d.loadPlugins(channel); err != nil {
		return nil, err
	}
	return channel, nil
}

// CreateChannel inserts a new channel and returns it with an empty Plugins
// map. channelRaw is stored as JSON.
func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) {
	channelRawJSON, err := json.Marshal(channelRaw)
	if err != nil {
		return nil, err
	}

	query := `
		INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
		VALUES (?, ?, ?, ?)
	`
	result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
	if err != nil {
		return nil, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	return &model.Channel{
		ID:                id,
		Platform:          platform,
		PlatformChannelID: platformChannelID,
		Enabled:           enabled,
		ChannelRaw:        channelRaw,
		Plugins:           make(map[string]*model.ChannelPlugin),
	}, nil
}

// UpdateChannel updates a channel's enabled status. Updating a missing ID is
// not reported as an error.
func (d *Database) UpdateChannel(id int64, enabled bool) error {
	query := `
		UPDATE channels
		SET enabled = ?
		WHERE id = ?
	`
	_, err := d.db.Exec(query, enabled, id)
	return err
}

// DeleteChannel deletes a channel together with all of its plugins. Both
// deletes run in one transaction, so a failure leaves neither applied.
func (d *Database) DeleteChannel(id int64) error {
	return d.withTx(func(tx *sql.Tx) error {
		if _, err := tx.Exec(`DELETE FROM channel_plugin WHERE channel_id = ?`, id); err != nil {
			return err
		}
		_, err := tx.Exec(`DELETE FROM channels WHERE id = ?`, id)
		return err
	})
}

// scanChannelPlugin decodes one (id, channel_id, plugin_id, enabled, config)
// row into a ChannelPlugin. config is parsed as JSON. sql.ErrNoRows is
// translated to ErrNotFound.
func scanChannelPlugin(s rowScanner) (*model.ChannelPlugin, error) {
	var (
		id         int64
		channelID  int64
		pluginID   string
		enabled    bool
		configJSON string
	)
	if err := s.Scan(&id, &channelID, &pluginID, &enabled, &configJSON); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}

	var config map[string]interface{}
	if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
		return nil, err
	}

	return &model.ChannelPlugin{
		ID:        id,
		ChannelID: channelID,
		PluginID:  pluginID,
		Enabled:   enabled,
		Config:    config,
	}, nil
}

// GetChannelPlugins retrieves all plugins for a channel.
// It returns ErrNotFound when the channel has no plugins.
func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
	query := `
		SELECT id, channel_id, plugin_id, enabled, config
		FROM channel_plugin
		WHERE channel_id = ?
	`
	rows, err := d.db.Query(query, channelID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var plugins []*model.ChannelPlugin
	for rows.Next() {
		plugin, err := scanChannelPlugin(rows)
		if err != nil {
			return nil, err
		}
		plugins = append(plugins, plugin)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(plugins) == 0 {
		return nil, ErrNotFound
	}
	return plugins, nil
}

// GetChannelPluginByID retrieves a channel plugin by ID.
// It returns ErrNotFound if no plugin has that ID.
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
	query := `
		SELECT id, channel_id, plugin_id, enabled, config
		FROM channel_plugin
		WHERE id = ?
	`
	return scanChannelPlugin(d.db.QueryRow(query, id))
}

// CreateChannelPlugin attaches pluginID to a channel. It returns
// ErrDuplicated if that plugin is already attached.
//
// The existence check and the insert run in a single transaction, so they
// commit or roll back together. NOTE(review): a UNIQUE(channel_id, plugin_id)
// constraint would be the authoritative guard against concurrent duplicates,
// and the schema is not visible here. Confirm.
func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) {
	var id int64
	err := d.withTx(func(tx *sql.Tx) error {
		checkQuery := `
			SELECT COUNT(*)
			FROM channel_plugin
			WHERE channel_id = ? AND plugin_id = ?
		`
		var count int
		if err := tx.QueryRow(checkQuery, channelID, pluginID).Scan(&count); err != nil {
			return err
		}
		if count > 0 {
			return ErrDuplicated
		}

		configJSON, err := json.Marshal(config)
		if err != nil {
			return err
		}

		insertQuery := `
			INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config)
			VALUES (?, ?, ?, ?)
		`
		result, err := tx.Exec(insertQuery, channelID, pluginID, enabled, string(configJSON))
		if err != nil {
			return err
		}

		id, err = result.LastInsertId()
		return err
	})
	if err != nil {
		return nil, err
	}

	return &model.ChannelPlugin{
		ID:        id,
		ChannelID: channelID,
		PluginID:  pluginID,
		Enabled:   enabled,
		Config:    config,
	}, nil
}

// UpdateChannelPlugin updates a channel plugin's enabled status. Updating a
// missing ID is not reported as an error.
func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
	query := `
		UPDATE channel_plugin
		SET enabled = ?
		WHERE id = ?
	`
	_, err := d.db.Exec(query, enabled, id)
	return err
}

// DeleteChannelPlugin deletes a channel plugin by ID.
func (d *Database) DeleteChannelPlugin(id int64) error {
	query := `
		DELETE FROM channel_plugin
		WHERE id = ?
	`
	_, err := d.db.Exec(query, id)
	return err
}

// DeleteChannelPluginsByChannel deletes all plugins for a channel.
func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
	query := `
		DELETE FROM channel_plugin
		WHERE channel_id = ?
	`
	_, err := d.db.Exec(query, channelID)
	return err
}

// queryAllChannels reads every row of the channels table, without plugins.
// The cursor is fully drained and closed before returning, so follow-up
// per-channel queries don't need a second pooled connection while it is open.
func (d *Database) queryAllChannels() ([]*model.Channel, error) {
	query := `
		SELECT id, platform, platform_channel_id, enabled, channel_raw
		FROM channels
	`
	rows, err := d.db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var channels []*model.Channel
	for rows.Next() {
		channel, err := scanChannel(rows)
		if err != nil {
			return nil, err
		}
		channels = append(channels, channel)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return channels, nil
}

// GetAllChannels retrieves all channels with their plugins. It always returns
// a non-nil slice on success. Channels whose plugins cannot be loaded are
// skipped (best effort) rather than failing the whole call.
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
	channels, err := d.queryAllChannels()
	if err != nil {
		return nil, err
	}

	result := make([]*model.Channel, 0, len(channels))
	for _, channel := range channels {
		if err := d.loadPlugins(channel); err != nil {
			continue // Skip this channel if plugins can't be retrieved
		}
		result = append(result, channel)
	}
	return result, nil
}

// GetUserByID retrieves a user by ID.
// It returns ErrNotFound if no user has that ID.
func (d *Database) GetUserByID(id int64) (*model.User, error) {
	query := `
		SELECT id, username, password
		FROM users
		WHERE id = ?
	`
	var username, password string
	if err := d.db.QueryRow(query, id).Scan(&id, &username, &password); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}

	return &model.User{
		ID:       id,
		Username: username,
		Password: password,
	}, nil
}

// CreateUser creates a new user. The password is stored as a bcrypt hash, and
// the returned User carries that hash, not the plaintext.
func (d *Database) CreateUser(username, password string) (*model.User, error) {
	hashedPassword, err := hashPassword(password)
	if err != nil {
		return nil, err
	}

	query := `
		INSERT INTO users (username, password)
		VALUES (?, ?)
	`
	result, err := d.db.Exec(query, username, hashedPassword)
	if err != nil {
		return nil, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	return &model.User{
		ID:       id,
		Username: username,
		Password: hashedPassword,
	}, nil
}

// CheckCredentials checks if the username and password are valid.
// It returns ErrNotFound for an unknown username, and ErrInvalidCredentials
// when the password does not match the stored bcrypt hash.
//
// NOTE(review): the distinct ErrNotFound, and the fact that no bcrypt work is
// done for unknown users, both let a caller tell "no such user" apart from
// "wrong password". Confirm that callers map both to the same response.
func (d *Database) CheckCredentials(username, password string) (*model.User, error) {
	query := `
		SELECT id, username, password
		FROM users
		WHERE username = ?
	`
	var (
		id         int64
		dbUsername string
		dbPassword string
	)
	if err := d.db.QueryRow(query, username).Scan(&id, &dbUsername, &dbPassword); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}

	if err := bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password)); err != nil {
		return nil, ErrInvalidCredentials
	}

	return &model.User{
		ID:       id,
		Username: dbUsername,
		Password: dbPassword,
	}, nil
}

// UpdateUserPassword replaces a user's password with a bcrypt hash of
// newPassword. Updating a missing user ID is not reported as an error.
func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
	hashedPassword, err := hashPassword(newPassword)
	if err != nil {
		return err
	}

	query := `
		UPDATE users
		SET password = ?
		WHERE id = ?
	`
	_, err = d.db.Exec(query, hashedPassword, userID)
	return err
}

// hashPassword returns the bcrypt hash of password at cost bcryptCost.
func hashPassword(password string) (string, error) {
	hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
	if err != nil {
		return "", err
	}
	return string(hashedBytes), nil
}

// initDatabase ensures the migration bookkeeping table exists, then runs
// migration.Migrate if any known migration version has not been applied yet.
// Progress is printed to stdout.
func initDatabase(db *sql.DB) error {
	if err := migration.EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := migration.GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	appliedSet := make(map[int]bool, len(applied))
	for _, version := range applied {
		appliedSet[version] = true
	}

	pendingCount := 0
	for version := range migration.Migrations {
		if !appliedSet[version] {
			pendingCount++
		}
	}

	if pendingCount == 0 {
		fmt.Println("Database schema is up to date.")
		return nil
	}

	fmt.Printf("Running %d pending database migrations...\n", pendingCount)
	if err := migration.Migrate(db); err != nil {
		return fmt.Errorf("migration failed: %w", err)
	}
	fmt.Println("Database migrations completed successfully.")
	return nil
}