651 lines
14 KiB
Go
651 lines
14 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
_ "modernc.org/sqlite"
|
|
|
|
"git.nakama.town/fmartingr/butterrobot/internal/migration"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
)
|
|
|
|
// ErrNotFound is the sentinel returned when a lookup matches no record.
var ErrNotFound = errors.New("record not found")

// ErrDuplicated is the sentinel returned when the record being created
// already exists.
var ErrDuplicated = errors.New("record already exists")
|
|
|
|
// Database is the persistence layer of the package: every query method
// defined on it runs against the wrapped *sql.DB handle.
type Database struct {
	db *sql.DB // underlying database handle; set by New
}
|
|
|
|
// New creates a new Database instance
|
|
func New(dbPath string) (*Database, error) {
|
|
// Open database connection
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Initialize database
|
|
if err := initDatabase(db); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Database{db: db}, nil
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func (d *Database) Close() error {
|
|
return d.db.Close()
|
|
}
|
|
|
|
// GetChannelByID retrieves a channel by ID
|
|
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
platform string
|
|
platformChannelID string
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
|
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
WHERE platform = ? AND platform_channel_id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, platform, platformChannelID)
|
|
|
|
var (
|
|
id int64
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
return nil, err
|
|
}
|
|
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// CreateChannel creates a new channel
|
|
func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) {
|
|
// Convert channelRaw to JSON
|
|
channelRawJSON, err := json.Marshal(channelRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Insert channel
|
|
query := `
|
|
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
return channel, nil
|
|
}
|
|
|
|
// UpdateChannel updates a channel's enabled status
|
|
func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
|
query := `
|
|
UPDATE channels
|
|
SET enabled = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, enabled, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannel deletes a channel
|
|
func (d *Database) DeleteChannel(id int64) error {
|
|
// First delete all channel plugins
|
|
if err := d.DeleteChannelPluginsByChannel(id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Then delete the channel
|
|
query := `
|
|
DELETE FROM channels
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
// GetChannelPlugins retrieves all plugins for a channel
|
|
func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
|
|
query := `
|
|
SELECT id, channel_id, plugin_id, enabled, config
|
|
FROM channel_plugin
|
|
WHERE channel_id = ?
|
|
`
|
|
|
|
rows, err := d.db.Query(query, channelID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var plugins []*model.ChannelPlugin
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
channelID int64
|
|
pluginID string
|
|
enabled bool
|
|
configJSON string
|
|
)
|
|
|
|
if err := rows.Scan(&id, &channelID, &pluginID, &enabled, &configJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse config JSON
|
|
var config map[string]interface{}
|
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
plugin := &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}
|
|
|
|
plugins = append(plugins, plugin)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(plugins) == 0 {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
return plugins, nil
|
|
}
|
|
|
|
// GetChannelPluginByID retrieves a channel plugin by ID
|
|
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
|
|
query := `
|
|
SELECT id, channel_id, plugin_id, enabled, config
|
|
FROM channel_plugin
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
channelID int64
|
|
pluginID string
|
|
enabled bool
|
|
configJSON string
|
|
)
|
|
|
|
err := row.Scan(&id, &channelID, &pluginID, &enabled, &configJSON)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse config JSON
|
|
var config map[string]interface{}
|
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}, nil
|
|
}
|
|
|
|
// CreateChannelPlugin creates a new channel plugin
|
|
func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) {
|
|
// Check if plugin already exists for this channel
|
|
query := `
|
|
SELECT COUNT(*)
|
|
FROM channel_plugin
|
|
WHERE channel_id = ? AND plugin_id = ?
|
|
`
|
|
|
|
var count int
|
|
err := d.db.QueryRow(query, channelID, pluginID).Scan(&count)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if count > 0 {
|
|
return nil, ErrDuplicated
|
|
}
|
|
|
|
// Convert config to JSON
|
|
configJSON, err := json.Marshal(config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Insert channel plugin
|
|
insertQuery := `
|
|
INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(insertQuery, channelID, pluginID, enabled, string(configJSON))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.ChannelPlugin{
|
|
ID: id,
|
|
ChannelID: channelID,
|
|
PluginID: pluginID,
|
|
Enabled: enabled,
|
|
Config: config,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateChannelPlugin updates a channel plugin's enabled status
|
|
func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
|
|
query := `
|
|
UPDATE channel_plugin
|
|
SET enabled = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, enabled, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannelPlugin deletes a channel plugin
|
|
func (d *Database) DeleteChannelPlugin(id int64) error {
|
|
query := `
|
|
DELETE FROM channel_plugin
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
// DeleteChannelPluginsByChannel deletes all plugins for a channel
|
|
func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
|
query := `
|
|
DELETE FROM channel_plugin
|
|
WHERE channel_id = ?
|
|
`
|
|
|
|
_, err := d.db.Exec(query, channelID)
|
|
return err
|
|
}
|
|
|
|
// GetAllChannels retrieves all channels
|
|
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
|
query := `
|
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
|
FROM channels
|
|
`
|
|
|
|
rows, err := d.db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var channels []*model.Channel
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
platform string
|
|
platformChannelID string
|
|
enabled bool
|
|
channelRawJSON string
|
|
)
|
|
|
|
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse channel_raw JSON
|
|
var channelRaw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create channel
|
|
channel := &model.Channel{
|
|
ID: id,
|
|
Platform: platform,
|
|
PlatformChannelID: platformChannelID,
|
|
Enabled: enabled,
|
|
ChannelRaw: channelRaw,
|
|
Plugins: make(map[string]*model.ChannelPlugin),
|
|
}
|
|
|
|
// Get channel plugins
|
|
plugins, err := d.GetChannelPlugins(id)
|
|
if err != nil && err != ErrNotFound {
|
|
continue // Skip this channel if plugins can't be retrieved
|
|
}
|
|
|
|
if plugins != nil {
|
|
for _, plugin := range plugins {
|
|
channel.Plugins[plugin.PluginID] = plugin
|
|
}
|
|
}
|
|
|
|
channels = append(channels, channel)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(channels) == 0 {
|
|
channels = make([]*model.Channel, 0)
|
|
}
|
|
|
|
return channels, nil
|
|
}
|
|
|
|
// GetUserByID retrieves a user by ID
|
|
func (d *Database) GetUserByID(id int64) (*model.User, error) {
|
|
query := `
|
|
SELECT id, username, password
|
|
FROM users
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, id)
|
|
|
|
var (
|
|
username string
|
|
password string
|
|
)
|
|
|
|
err := row.Scan(&id, &username, &password)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: username,
|
|
Password: password,
|
|
}, nil
|
|
}
|
|
|
|
// CreateUser creates a new user
|
|
func (d *Database) CreateUser(username, password string) (*model.User, error) {
|
|
// Hash password
|
|
hashedPassword, err := hashPassword(password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Insert user
|
|
query := `
|
|
INSERT INTO users (username, password)
|
|
VALUES (?, ?)
|
|
`
|
|
|
|
result, err := d.db.Exec(query, username, hashedPassword)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get inserted ID
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: username,
|
|
Password: hashedPassword,
|
|
}, nil
|
|
}
|
|
|
|
// CheckCredentials checks if the username and password are valid
|
|
func (d *Database) CheckCredentials(username, password string) (*model.User, error) {
|
|
query := `
|
|
SELECT id, username, password
|
|
FROM users
|
|
WHERE username = ?
|
|
`
|
|
|
|
row := d.db.QueryRow(query, username)
|
|
|
|
var (
|
|
id int64
|
|
dbUsername string
|
|
dbPassword string
|
|
)
|
|
|
|
err := row.Scan(&id, &dbUsername, &dbPassword)
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check password with bcrypt
|
|
err = bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password))
|
|
if err != nil {
|
|
return nil, errors.New("invalid credentials")
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Username: dbUsername,
|
|
Password: dbPassword,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateUserPassword updates a user's password
|
|
func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
|
|
// Hash the new password
|
|
hashedPassword, err := hashPassword(newPassword)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the user's password
|
|
query := `
|
|
UPDATE users
|
|
SET password = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err = d.db.Exec(query, hashedPassword, userID)
|
|
return err
|
|
}
|
|
|
|
// Helper function to hash password
|
|
func hashPassword(password string) (string, error) {
|
|
// Use bcrypt for secure password hashing
|
|
// The cost parameter is the computational cost, higher is more secure but slower
|
|
// Recommended minimum is 12
|
|
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(hashedBytes), nil
|
|
}
|
|
|
|
// Initialize database tables
|
|
func initDatabase(db *sql.DB) error {
|
|
// Ensure migration table exists
|
|
if err := migration.EnsureMigrationTable(db); err != nil {
|
|
return fmt.Errorf("failed to create migration table: %w", err)
|
|
}
|
|
|
|
// Get applied migrations
|
|
applied, err := migration.GetAppliedMigrations(db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
|
}
|
|
|
|
// Get all migration versions
|
|
allMigrations := make([]int, 0, len(migration.Migrations))
|
|
for version := range migration.Migrations {
|
|
allMigrations = append(allMigrations, version)
|
|
}
|
|
|
|
// Create a map of applied migrations for quick lookup
|
|
appliedMap := make(map[int]bool)
|
|
for _, version := range applied {
|
|
appliedMap[version] = true
|
|
}
|
|
|
|
// Count pending migrations
|
|
pendingCount := 0
|
|
for _, version := range allMigrations {
|
|
if !appliedMap[version] {
|
|
pendingCount++
|
|
}
|
|
}
|
|
|
|
// Run migrations if needed
|
|
if pendingCount > 0 {
|
|
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
|
if err := migration.Migrate(db); err != nil {
|
|
return fmt.Errorf("migration failed: %w", err)
|
|
}
|
|
fmt.Println("Database migrations completed successfully.")
|
|
} else {
|
|
fmt.Println("Database schema is up to date.")
|
|
}
|
|
|
|
return nil
|
|
}
|