This commit is contained in:
parent
9c78ea2d48
commit
7c684af8c3
79 changed files with 3594 additions and 3257 deletions
641
internal/db/db.go
Normal file
641
internal/db/db.go
Normal file
|
@ -0,0 +1,641 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound is returned when a record is not found
|
||||
ErrNotFound = errors.New("record not found")
|
||||
|
||||
// ErrDuplicated is returned when a record already exists
|
||||
ErrDuplicated = errors.New("record already exists")
|
||||
)
|
||||
|
||||
// Database handles database operations
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// New creates a new Database instance
|
||||
func New(dbPath string) (*Database, error) {
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
if err := initDatabase(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Database{db: db}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (d *Database) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
// GetChannelByID retrieves a channel by ID
|
||||
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
FROM channels
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
row := d.db.QueryRow(query, id)
|
||||
|
||||
var (
|
||||
platform string
|
||||
platformChannelID string
|
||||
enabled bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse channel_raw JSON
|
||||
var channelRaw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create channel
|
||||
channel := &model.Channel{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
||||
// Get channel plugins
|
||||
plugins, err := d.GetChannelPlugins(id)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, plugin := range plugins {
|
||||
channel.Plugins[plugin.PluginID] = plugin
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
||||
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
FROM channels
|
||||
WHERE platform = ? AND platform_channel_id = ?
|
||||
`
|
||||
|
||||
row := d.db.QueryRow(query, platform, platformChannelID)
|
||||
|
||||
var (
|
||||
id int64
|
||||
enabled bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse channel_raw JSON
|
||||
var channelRaw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create channel
|
||||
channel := &model.Channel{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
||||
// Get channel plugins
|
||||
plugins, err := d.GetChannelPlugins(id)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, plugin := range plugins {
|
||||
channel.Plugins[plugin.PluginID] = plugin
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// CreateChannel creates a new channel
|
||||
func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) {
|
||||
// Convert channelRaw to JSON
|
||||
channelRawJSON, err := json.Marshal(channelRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Insert channel
|
||||
query := `
|
||||
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get inserted ID
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create channel
|
||||
channel := &model.Channel{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// UpdateChannel updates a channel's enabled status
|
||||
func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
||||
query := `
|
||||
UPDATE channels
|
||||
SET enabled = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, enabled, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteChannel deletes a channel
|
||||
func (d *Database) DeleteChannel(id int64) error {
|
||||
// First delete all channel plugins
|
||||
if err := d.DeleteChannelPluginsByChannel(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then delete the channel
|
||||
query := `
|
||||
DELETE FROM channels
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChannelPlugins retrieves all plugins for a channel
|
||||
func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
|
||||
query := `
|
||||
SELECT id, channel_id, plugin_id, enabled, config
|
||||
FROM channel_plugin
|
||||
WHERE channel_id = ?
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, channelID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var plugins []*model.ChannelPlugin
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
channelID int64
|
||||
pluginID string
|
||||
enabled bool
|
||||
configJSON string
|
||||
)
|
||||
|
||||
if err := rows.Scan(&id, &channelID, &pluginID, &enabled, &configJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse config JSON
|
||||
var config map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plugin := &model.ChannelPlugin{
|
||||
ID: id,
|
||||
ChannelID: channelID,
|
||||
PluginID: pluginID,
|
||||
Enabled: enabled,
|
||||
Config: config,
|
||||
}
|
||||
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(plugins) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
// GetChannelPluginByID retrieves a channel plugin by ID
|
||||
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
|
||||
query := `
|
||||
SELECT id, channel_id, plugin_id, enabled, config
|
||||
FROM channel_plugin
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
row := d.db.QueryRow(query, id)
|
||||
|
||||
var (
|
||||
channelID int64
|
||||
pluginID string
|
||||
enabled bool
|
||||
configJSON string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &channelID, &pluginID, &enabled, &configJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse config JSON
|
||||
var config map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.ChannelPlugin{
|
||||
ID: id,
|
||||
ChannelID: channelID,
|
||||
PluginID: pluginID,
|
||||
Enabled: enabled,
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateChannelPlugin creates a new channel plugin
|
||||
func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) {
|
||||
// Check if plugin already exists for this channel
|
||||
query := `
|
||||
SELECT COUNT(*)
|
||||
FROM channel_plugin
|
||||
WHERE channel_id = ? AND plugin_id = ?
|
||||
`
|
||||
|
||||
var count int
|
||||
err := d.db.QueryRow(query, channelID, pluginID).Scan(&count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil, ErrDuplicated
|
||||
}
|
||||
|
||||
// Convert config to JSON
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Insert channel plugin
|
||||
insertQuery := `
|
||||
INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(insertQuery, channelID, pluginID, enabled, string(configJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get inserted ID
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.ChannelPlugin{
|
||||
ID: id,
|
||||
ChannelID: channelID,
|
||||
PluginID: pluginID,
|
||||
Enabled: enabled,
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateChannelPlugin updates a channel plugin's enabled status
|
||||
func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
|
||||
query := `
|
||||
UPDATE channel_plugin
|
||||
SET enabled = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, enabled, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteChannelPlugin deletes a channel plugin
|
||||
func (d *Database) DeleteChannelPlugin(id int64) error {
|
||||
query := `
|
||||
DELETE FROM channel_plugin
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteChannelPluginsByChannel deletes all plugins for a channel
|
||||
func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
||||
query := `
|
||||
DELETE FROM channel_plugin
|
||||
WHERE channel_id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, channelID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllChannels retrieves all channels
|
||||
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
FROM channels
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []*model.Channel
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
platform string
|
||||
platformChannelID string
|
||||
enabled bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse channel_raw JSON
|
||||
var channelRaw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create channel
|
||||
channel := &model.Channel{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
||||
// Get channel plugins
|
||||
plugins, err := d.GetChannelPlugins(id)
|
||||
if err != nil && err != ErrNotFound {
|
||||
continue // Skip this channel if plugins can't be retrieved
|
||||
}
|
||||
|
||||
if plugins != nil {
|
||||
for _, plugin := range plugins {
|
||||
channel.Plugins[plugin.PluginID] = plugin
|
||||
}
|
||||
}
|
||||
|
||||
channels = append(channels, channel)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(channels) == 0 {
|
||||
channels = make([]*model.Channel, 0)
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (d *Database) GetUserByID(id int64) (*model.User, error) {
|
||||
query := `
|
||||
SELECT id, username, password
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
row := d.db.QueryRow(query, id)
|
||||
|
||||
var (
|
||||
username string
|
||||
password string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &username, &password)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.User{
|
||||
ID: id,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user
|
||||
func (d *Database) CreateUser(username, password string) (*model.User, error) {
|
||||
// Hash password
|
||||
hashedPassword := hashPassword(password)
|
||||
|
||||
// Insert user
|
||||
query := `
|
||||
INSERT INTO users (username, password)
|
||||
VALUES (?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(query, username, hashedPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get inserted ID
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.User{
|
||||
ID: id,
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckCredentials checks if the username and password are valid
|
||||
func (d *Database) CheckCredentials(username, password string) (*model.User, error) {
|
||||
query := `
|
||||
SELECT id, username, password
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
`
|
||||
|
||||
row := d.db.QueryRow(query, username)
|
||||
|
||||
var (
|
||||
id int64
|
||||
dbUsername string
|
||||
dbPassword string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &dbUsername, &dbPassword)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check password
|
||||
hashedPassword := hashPassword(password)
|
||||
if dbPassword != hashedPassword {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
return &model.User{
|
||||
ID: id,
|
||||
Username: dbUsername,
|
||||
Password: dbPassword,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Helper function to hash password
|
||||
func hashPassword(password string) string {
|
||||
// In a real implementation, use a proper password hashing library like bcrypt
|
||||
// This is a simplified version for demonstration
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(password))
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// Initialize database tables
|
||||
func initDatabase(db *sql.DB) error {
|
||||
// Create channels table
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
platform TEXT NOT NULL,
|
||||
platform_channel_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
channel_raw TEXT NOT NULL,
|
||||
UNIQUE(platform, platform_channel_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create channel_plugin table
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS channel_plugin (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL,
|
||||
plugin_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
config TEXT NOT NULL DEFAULT '{}',
|
||||
UNIQUE(channel_id, plugin_id),
|
||||
FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create users table
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create default admin user if it doesn't exist
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
hashedPassword := hashPassword("admin")
|
||||
_, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue