641 lines
14 KiB
Go
641 lines
14 KiB
Go
package db
|
|
|
|
import (
	"crypto/sha256"
	"crypto/subtle"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"

	_ "modernc.org/sqlite"

	"git.nakama.town/fmartingr/butterrobot/internal/model"
)
|
|
|
|
// Sentinel errors returned by the database layer.
var (
	// ErrNotFound reports that the requested record does not exist.
	ErrNotFound = errors.New("record not found")

	// ErrDuplicated reports that an equivalent record is already stored.
	ErrDuplicated = errors.New("record already exists")
)
|
|
|
|
// Database is the persistence layer for channels, channel plugins and users.
// All of its methods operate through the wrapped database/sql handle.
type Database struct {
	// db is the underlying SQL handle; it is set by New.
	db *sql.DB
}
|
|
|
|
// New creates a new Database instance
|
|
func New(dbPath string) (*Database, error) {
|
|
// Open database connection
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Initialize database
|
|
if err := initDatabase(db); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Database{db: db}, nil
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func (d *Database) Close() error {
|
|
return d.db.Close()
|
|
}
|
|
|
|
// GetChannelByID retrieves a channel by ID
|
|
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
platform string
|
|
platformChannelID string
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
|
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
WHERE platform = ? AND platform_channel_id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, platform, platformChannelID)
|
|
|
|
var (
|
|
id int64
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// CreateChannel creates a new channel
|
|
func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) {
|
|
// Convert channelRaw to JSON
|
|
channelRawJSON, err := json.Marshal(channelRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Insert channel
|
|
query := `
|
|
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// UpdateChannel updates a channel's enabled status
|
|
func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
|
query := `
|
|
UPDATE channels
|
|
SET enabled = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, enabled, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannel deletes a channel
|
|
func (d *Database) DeleteChannel(id int64) error {
|
|
// First delete all channel plugins
|
|
if err := d.DeleteChannelPluginsByChannel(id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Then delete the channel
|
|
query := `
|
|
DELETE FROM channels
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
// GetChannelPlugins retrieves all plugins for a channel
|
|
func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
|
|
query := `
|
|
SELECT id, channel_id, plugin_id, enabled, config
|
|
FROM channel_plugin
|
|
WHERE channel_id = ?
|
|
`
|
|
|
|
rows, err := d.db.Query(query, channelID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var plugins []*model.ChannelPlugin
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
channelID int64
|
|
pluginID string
|
|
enabled bool
|
|
configJSON string
|
|
)
|
|
|
|
if err := rows.Scan(&id, &channelID, &pluginID, &enabled, &configJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse config JSON
|
|
var config map[string]interface{}
|
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
plugin := &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}
|
|
|
|
plugins = append(plugins, plugin)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(plugins) == 0 {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
return plugins, nil
|
|
}
|
|
|
|
// GetChannelPluginByID retrieves a channel plugin by ID
|
|
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
|
|
query := `
|
|
SELECT id, channel_id, plugin_id, enabled, config
|
|
FROM channel_plugin
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
channelID int64
|
|
pluginID string
|
|
enabled bool
|
|
configJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &channelID, &pluginID, &enabled, &configJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse config JSON
|
|
var config map[string]interface{}
|
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}, nil
|
|
}
|
|
|
|
// CreateChannelPlugin creates a new channel plugin
|
|
func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) {
|
|
// Check if plugin already exists for this channel
|
|
query := `
|
|
SELECT COUNT(*)
|
|
FROM channel_plugin
|
|
WHERE channel_id = ? AND plugin_id = ?
|
|
`
|
|
|
|
var count int
|
|
err := d.db.QueryRow(query, channelID, pluginID).Scan(&count)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if count > 0 {
|
|
return nil, ErrDuplicated
|
|
}
|
|
|
|
// Convert config to JSON
|
|
configJSON, err := json.Marshal(config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Insert channel plugin
|
|
insertQuery := `
|
|
INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(insertQuery, channelID, pluginID, enabled, string(configJSON))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateChannelPlugin updates a channel plugin's enabled status
|
|
func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
|
|
query := `
|
|
UPDATE channel_plugin
|
|
SET enabled = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, enabled, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannelPlugin deletes a channel plugin
|
|
func (d *Database) DeleteChannelPlugin(id int64) error {
|
|
query := `
|
|
DELETE FROM channel_plugin
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannelPluginsByChannel deletes all plugins for a channel
|
|
func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
|
query := `
|
|
DELETE FROM channel_plugin
|
|
WHERE channel_id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, channelID)
|
|
return err
|
|
}
|
|
|
|
// GetAllChannels retrieves all channels
|
|
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
`
|
|
|
|
rows, err := d.db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var channels []*model.Channel
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
platform string
|
|
platformChannelID string
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
continue // Skip this channel if plugins can't be retrieved
|
|
}
|
|
|
|
if plugins != nil {
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
}
|
|
|
|
channels = append(channels, channel)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(channels) == 0 {
|
|
channels = make([]*model.Channel, 0)
|
|
}
|
|
|
|
return channels, nil
|
|
}
|
|
|
|
// GetUserByID retrieves a user by ID
|
|
func (d *Database) GetUserByID(id int64) (*model.User, error) {
|
|
query := `
|
|
SELECT id, username, password
|
|
FROM users
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
username string
|
|
password string
|
|
)
|
|
|
|
err := row.Scan(&id, &username, &password)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: username,
|
|
Password: password,
|
|
}, nil
|
|
}
|
|
|
|
// CreateUser creates a new user
|
|
func (d *Database) CreateUser(username, password string) (*model.User, error) {
|
|
// Hash password
|
|
hashedPassword := hashPassword(password)
|
|
|
|
// Insert user
|
|
query := `
|
|
INSERT INTO users (username, password)
|
|
VALUES (?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(query, username, hashedPassword)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: username,
|
|
Password: hashedPassword,
|
|
}, nil
|
|
}
|
|
|
|
// CheckCredentials checks if the username and password are valid
|
|
func (d *Database) CheckCredentials(username, password string) (*model.User, error) {
|
|
query := `
|
|
SELECT id, username, password
|
|
FROM users
|
|
WHERE username = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, username)
|
|
|
|
var (
|
|
id int64
|
|
dbUsername string
|
|
dbPassword string
|
|
)
|
|
|
|
err := row.Scan(&id, &dbUsername, &dbPassword)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check password
|
|
hashedPassword := hashPassword(password)
|
|
if dbPassword != hashedPassword {
|
|
return nil, errors.New("invalid credentials")
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: dbUsername,
|
|
Password: dbPassword,
|
|
}, nil
|
|
}
|
|
|
|
// hashPassword returns the lowercase hexadecimal SHA-256 digest of password.
//
// NOTE: this is an unsalted, fast hash kept for demonstration purposes; a
// real deployment should use a dedicated password hashing scheme such as
// bcrypt.
func hashPassword(password string) string {
	digest := sha256.Sum256([]byte(password))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// Initialize database tables
|
|
func initDatabase(db *sql.DB) error {
|
|
// Create channels table
|
|
_, err := db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS channels (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
platform TEXT NOT NULL,
|
|
platform_channel_id TEXT NOT NULL,
|
|
enabled BOOLEAN NOT NULL DEFAULT 0,
|
|
channel_raw TEXT NOT NULL,
|
|
UNIQUE(platform, platform_channel_id)
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create channel_plugin table
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS channel_plugin (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
channel_id INTEGER NOT NULL,
|
|
plugin_id TEXT NOT NULL,
|
|
enabled BOOLEAN NOT NULL DEFAULT 0,
|
|
config TEXT NOT NULL DEFAULT '{}',
|
|
UNIQUE(channel_id, plugin_id),
|
|
FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create users table
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT NOT NULL UNIQUE,
|
|
password TEXT NOT NULL
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create default admin user if it doesn't exist
|
|
var count int
|
|
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if count == 0 {
|
|
hashedPassword := hashPassword("admin")
|
|
_, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|