package db import ( "crypto/sha256" "database/sql" "encoding/hex" "encoding/json" "errors" _ "modernc.org/sqlite" "git.nakama.town/fmartingr/butterrobot/internal/model" ) var ( // ErrNotFound is returned when a record is not found ErrNotFound = errors.New("record not found") // ErrDuplicated is returned when a record already exists ErrDuplicated = errors.New("record already exists") ) // Database handles database operations type Database struct { db *sql.DB } // New creates a new Database instance func New(dbPath string) (*Database, error) { // Open database connection db, err := sql.Open("sqlite", dbPath) if err != nil { return nil, err } // Initialize database if err := initDatabase(db); err != nil { return nil, err } return &Database{db: db}, nil } // Close closes the database connection func (d *Database) Close() error { return d.db.Close() } // GetChannelByID retrieves a channel by ID func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { query := ` SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels WHERE id = ? ` row := d.db.QueryRow(query, id) var ( platform string platformChannelID string enabled bool channelRawJSON string ) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) if err == sql.ErrNoRows { return nil, ErrNotFound } if err != nil { return nil, err } // Parse channel_raw JSON var channelRaw map[string]interface{} if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil { return nil, err } // Create channel channel := &model.Channel{ ID: id, Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } // Get channel plugins plugins, err := d.GetChannelPlugins(id) if err != nil && err != ErrNotFound { return nil, err } for _, plugin := range plugins { channel.Plugins[plugin.PluginID] = plugin } return channel, nil } // GetChannelByPlatform retrieves a channel by platform and platform channel ID func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { query := ` SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels WHERE platform = ? AND platform_channel_id = ? ` row := d.db.QueryRow(query, platform, platformChannelID) var ( id int64 enabled bool channelRawJSON string ) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) if err == sql.ErrNoRows { return nil, ErrNotFound } if err != nil { return nil, err } // Parse channel_raw JSON var channelRaw map[string]interface{} if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil { return nil, err } // Create channel channel := &model.Channel{ ID: id, Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } // Get channel plugins plugins, err := d.GetChannelPlugins(id) if err != nil && err != ErrNotFound { return nil, err } for _, plugin := range plugins { channel.Plugins[plugin.PluginID] = plugin } return channel, nil } // CreateChannel creates a new channel func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) { // Convert channelRaw to JSON channelRawJSON, err := json.Marshal(channelRaw) if err != nil { return nil, err } // Insert channel query := ` INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw) VALUES (?, ?, ?, ?) ` result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON)) if err != nil { return nil, err } // Get inserted ID id, err := result.LastInsertId() if err != nil { return nil, err } // Create channel channel := &model.Channel{ ID: id, Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } return channel, nil } // UpdateChannel updates a channel's enabled status func (d *Database) UpdateChannel(id int64, enabled bool) error { query := ` UPDATE channels SET enabled = ? WHERE id = ? ` _, err := d.db.Exec(query, enabled, id) return err } // DeleteChannel deletes a channel func (d *Database) DeleteChannel(id int64) error { // First delete all channel plugins if err := d.DeleteChannelPluginsByChannel(id); err != nil { return err } // Then delete the channel query := ` DELETE FROM channels WHERE id = ? ` _, err := d.db.Exec(query, id) return err } // GetChannelPlugins retrieves all plugins for a channel func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) { query := ` SELECT id, channel_id, plugin_id, enabled, config FROM channel_plugin WHERE channel_id = ? ` rows, err := d.db.Query(query, channelID) if err != nil { return nil, err } defer rows.Close() var plugins []*model.ChannelPlugin for rows.Next() { var ( id int64 channelID int64 pluginID string enabled bool configJSON string ) if err := rows.Scan(&id, &channelID, &pluginID, &enabled, &configJSON); err != nil { return nil, err } // Parse config JSON var config map[string]interface{} if err := json.Unmarshal([]byte(configJSON), &config); err != nil { return nil, err } plugin := &model.ChannelPlugin{ ID: id, ChannelID: channelID, PluginID: pluginID, Enabled: enabled, Config: config, } plugins = append(plugins, plugin) } if err := rows.Err(); err != nil { return nil, err } if len(plugins) == 0 { return nil, ErrNotFound } return plugins, nil } // GetChannelPluginByID retrieves a channel plugin by ID func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { query := ` SELECT id, channel_id, plugin_id, enabled, config FROM channel_plugin WHERE id = ? ` row := d.db.QueryRow(query, id) var ( channelID int64 pluginID string enabled bool configJSON string ) err := row.Scan(&id, &channelID, &pluginID, &enabled, &configJSON) if err == sql.ErrNoRows { return nil, ErrNotFound } if err != nil { return nil, err } // Parse config JSON var config map[string]interface{} if err := json.Unmarshal([]byte(configJSON), &config); err != nil { return nil, err } return &model.ChannelPlugin{ ID: id, ChannelID: channelID, PluginID: pluginID, Enabled: enabled, Config: config, }, nil } // CreateChannelPlugin creates a new channel plugin func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) { // Check if plugin already exists for this channel query := ` SELECT COUNT(*) FROM channel_plugin WHERE channel_id = ? AND plugin_id = ? ` var count int err := d.db.QueryRow(query, channelID, pluginID).Scan(&count) if err != nil { return nil, err } if count > 0 { return nil, ErrDuplicated } // Convert config to JSON configJSON, err := json.Marshal(config) if err != nil { return nil, err } // Insert channel plugin insertQuery := ` INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config) VALUES (?, ?, ?, ?) ` result, err := d.db.Exec(insertQuery, channelID, pluginID, enabled, string(configJSON)) if err != nil { return nil, err } // Get inserted ID id, err := result.LastInsertId() if err != nil { return nil, err } return &model.ChannelPlugin{ ID: id, ChannelID: channelID, PluginID: pluginID, Enabled: enabled, Config: config, }, nil } // UpdateChannelPlugin updates a channel plugin's enabled status func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error { query := ` UPDATE channel_plugin SET enabled = ? WHERE id = ? ` _, err := d.db.Exec(query, enabled, id) return err } // DeleteChannelPlugin deletes a channel plugin func (d *Database) DeleteChannelPlugin(id int64) error { query := ` DELETE FROM channel_plugin WHERE id = ? ` _, err := d.db.Exec(query, id) return err } // DeleteChannelPluginsByChannel deletes all plugins for a channel func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error { query := ` DELETE FROM channel_plugin WHERE channel_id = ? ` _, err := d.db.Exec(query, channelID) return err } // GetAllChannels retrieves all channels func (d *Database) GetAllChannels() ([]*model.Channel, error) { query := ` SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels ` rows, err := d.db.Query(query) if err != nil { return nil, err } defer rows.Close() var channels []*model.Channel for rows.Next() { var ( id int64 platform string platformChannelID string enabled bool channelRawJSON string ) if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil { return nil, err } // Parse channel_raw JSON var channelRaw map[string]interface{} if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil { return nil, err } // Create channel channel := &model.Channel{ ID: id, Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } // Get channel plugins plugins, err := d.GetChannelPlugins(id) if err != nil && err != ErrNotFound { continue // Skip this channel if plugins can't be retrieved } if plugins != nil { for _, plugin := range plugins { channel.Plugins[plugin.PluginID] = plugin } } channels = append(channels, channel) } if err := rows.Err(); err != nil { return nil, err } if len(channels) == 0 { channels = make([]*model.Channel, 0) } return channels, nil } // GetUserByID retrieves a user by ID func (d *Database) GetUserByID(id int64) (*model.User, error) { query := ` SELECT id, username, password FROM users WHERE id = ? ` row := d.db.QueryRow(query, id) var ( username string password string ) err := row.Scan(&id, &username, &password) if err == sql.ErrNoRows { return nil, ErrNotFound } if err != nil { return nil, err } return &model.User{ ID: id, Username: username, Password: password, }, nil } // CreateUser creates a new user func (d *Database) CreateUser(username, password string) (*model.User, error) { // Hash password hashedPassword := hashPassword(password) // Insert user query := ` INSERT INTO users (username, password) VALUES (?, ?) ` result, err := d.db.Exec(query, username, hashedPassword) if err != nil { return nil, err } // Get inserted ID id, err := result.LastInsertId() if err != nil { return nil, err } return &model.User{ ID: id, Username: username, Password: hashedPassword, }, nil } // CheckCredentials checks if the username and password are valid func (d *Database) CheckCredentials(username, password string) (*model.User, error) { query := ` SELECT id, username, password FROM users WHERE username = ? ` row := d.db.QueryRow(query, username) var ( id int64 dbUsername string dbPassword string ) err := row.Scan(&id, &dbUsername, &dbPassword) if err == sql.ErrNoRows { return nil, ErrNotFound } if err != nil { return nil, err } // Check password hashedPassword := hashPassword(password) if dbPassword != hashedPassword { return nil, errors.New("invalid credentials") } return &model.User{ ID: id, Username: dbUsername, Password: dbPassword, }, nil } // Helper function to hash password func hashPassword(password string) string { // In a real implementation, use a proper password hashing library like bcrypt // This is a simplified version for demonstration hasher := sha256.New() hasher.Write([]byte(password)) return hex.EncodeToString(hasher.Sum(nil)) } // Initialize database tables func initDatabase(db *sql.DB) error { // Create channels table _, err := db.Exec(` CREATE TABLE IF NOT EXISTS channels ( id INTEGER PRIMARY KEY AUTOINCREMENT, platform TEXT NOT NULL, platform_channel_id TEXT NOT NULL, enabled BOOLEAN NOT NULL DEFAULT 0, channel_raw TEXT NOT NULL, UNIQUE(platform, platform_channel_id) ) `) if err != nil { return err } // Create channel_plugin table _, err = db.Exec(` CREATE TABLE IF NOT EXISTS channel_plugin ( id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER NOT NULL, plugin_id TEXT NOT NULL, enabled BOOLEAN NOT NULL DEFAULT 0, config TEXT NOT NULL DEFAULT '{}', UNIQUE(channel_id, plugin_id), FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE ) `) if err != nil { return err } // Create users table _, err = db.Exec(` CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password TEXT NOT NULL ) `) if err != nil { return err } // Create default admin user if it doesn't exist var count int err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) if err != nil { return err } if count == 0 { hashedPassword := hashPassword("admin") _, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword) if err != nil { return err } } return nil }