butterrobot/internal/db/db.go
Felipe M. 72c6dd6982
All checks were successful
ci/woodpecker/tag/release Pipeline was successful
feat: remindme plugin
2025-04-22 11:29:39 +02:00

766 lines
16 KiB
Go

package db
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
"git.nakama.town/fmartingr/butterrobot/internal/migration"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
var (
// ErrNotFound is returned by lookups in this package when no matching
// row exists.
ErrNotFound = errors.New("record not found")
// ErrDuplicated is returned when the record being created already exists.
ErrDuplicated = errors.New("record already exists")
)
// Database wraps the *sql.DB handle used by every channel, plugin, user and
// reminder query in this package. Create one with New.
type Database struct {
// db is the underlying database handle.
db *sql.DB
}
// New opens the SQLite database at dbPath and prepares it for use.
//
// The handle is opened with the "sqlite" driver and then passed to
// initDatabase. If initialization fails, the freshly opened handle is closed
// before the error is returned so that a failed call does not leak it. On
// success the caller owns the returned Database and should call Close when
// done with it.
func New(dbPath string) (*Database, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, err
	}

	if err := initDatabase(db); err != nil {
		// The initialization error is what the caller needs to see; a
		// secondary failure while closing the handle would only obscure it.
		_ = db.Close()
		return nil, err
	}

	return &Database{db: db}, nil
}
// Close closes the underlying database handle and returns any error reported
// by it.
func (d *Database) Close() error {
return d.db.Close()
}
// GetChannelByID retrieves a channel, together with its plugins, by its
// database ID.
//
// It returns ErrNotFound when no channel row has the given ID. A channel
// without plugins is not an error: it is returned with an empty Plugins map.
// Any other query, scan or JSON decoding failure is returned unchanged.
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
	query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels
WHERE id = ?
`
	var (
		platform          string
		platformChannelID string
		enabled           bool
		channelRawJSON    string
	)
	err := d.db.QueryRow(query, id).Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
	// errors.Is keeps the sentinel recognizable even if it arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	// channel_raw is stored as a JSON document.
	var channelRaw map[string]interface{}
	if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
		return nil, err
	}

	channel := &model.Channel{
		ID:                id,
		Platform:          platform,
		PlatformChannelID: platformChannelID,
		Enabled:           enabled,
		ChannelRaw:        channelRaw,
		Plugins:           make(map[string]*model.ChannelPlugin),
	}

	// GetChannelPlugins reports ErrNotFound for a channel with no plugins;
	// that case simply leaves the Plugins map empty.
	plugins, err := d.GetChannelPlugins(id)
	if err != nil && !errors.Is(err, ErrNotFound) {
		return nil, err
	}
	for _, plugin := range plugins {
		channel.Plugins[plugin.PluginID] = plugin
	}

	return channel, nil
}
// GetChannelByPlatform retrieves a channel, together with its plugins, by the
// platform name and the platform's own channel identifier.
//
// It returns ErrNotFound when no channel row matches. A channel without
// plugins is not an error: it is returned with an empty Plugins map. Any other
// query, scan or JSON decoding failure is returned unchanged.
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
	query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels
WHERE platform = ? AND platform_channel_id = ?
`
	var (
		id             int64
		enabled        bool
		channelRawJSON string
	)
	err := d.db.QueryRow(query, platform, platformChannelID).Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
	// errors.Is keeps the sentinel recognizable even if it arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	// channel_raw is stored as a JSON document.
	var channelRaw map[string]interface{}
	if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
		return nil, err
	}

	channel := &model.Channel{
		ID:                id,
		Platform:          platform,
		PlatformChannelID: platformChannelID,
		Enabled:           enabled,
		ChannelRaw:        channelRaw,
		Plugins:           make(map[string]*model.ChannelPlugin),
	}

	// GetChannelPlugins reports ErrNotFound for a channel with no plugins;
	// that case simply leaves the Plugins map empty.
	plugins, err := d.GetChannelPlugins(id)
	if err != nil && !errors.Is(err, ErrNotFound) {
		return nil, err
	}
	for _, plugin := range plugins {
		channel.Plugins[plugin.PluginID] = plugin
	}

	return channel, nil
}
// CreateChannel inserts a new channel row and returns the stored channel.
//
// channelRaw is serialized to JSON for the channel_raw column. The returned
// channel carries the ID assigned by the database and an empty Plugins map.
func (d *Database) CreateChannel(platform, platformChannelID string, enabled bool, channelRaw map[string]interface{}) (*model.Channel, error) {
	rawJSON, err := json.Marshal(channelRaw)
	if err != nil {
		return nil, err
	}

	const insertChannel = `
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
VALUES (?, ?, ?, ?)
`
	res, err := d.db.Exec(insertChannel, platform, platformChannelID, enabled, string(rawJSON))
	if err != nil {
		return nil, err
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	return &model.Channel{
		ID:                newID,
		Platform:          platform,
		PlatformChannelID: platformChannelID,
		Enabled:           enabled,
		ChannelRaw:        channelRaw,
		Plugins:           make(map[string]*model.ChannelPlugin),
	}, nil
}
// UpdateChannel sets the enabled flag of the channel with the given ID.
//
// The number of affected rows is not inspected; only an error from executing
// the statement is returned.
func (d *Database) UpdateChannel(id int64, enabled bool) error {
	const updateEnabled = `
UPDATE channels
SET enabled = ?
WHERE id = ?
`
	if _, err := d.db.Exec(updateEnabled, enabled, id); err != nil {
		return err
	}
	return nil
}
// DeleteChannel removes a channel and all of its channel_plugin rows.
//
// Both deletes run inside a single transaction so that a failure of either
// statement leaves the database unchanged, instead of leaving a channel whose
// plugins have already been removed.
func (d *Database) DeleteChannel(id int64) error {
	tx, err := d.db.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit only reports that the transaction is
	// already finished, so its result is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	// Plugins go first, then the channel itself. This is the same statement
	// DeleteChannelPluginsByChannel runs, repeated here so that it executes on
	// the transaction.
	if _, err := tx.Exec(`
DELETE FROM channel_plugin
WHERE channel_id = ?
`, id); err != nil {
		return err
	}

	if _, err := tx.Exec(`
DELETE FROM channels
WHERE id = ?
`, id); err != nil {
		return err
	}

	return tx.Commit()
}
// GetChannelPlugins returns every channel_plugin row belonging to channelID,
// with each row's config column decoded from JSON.
//
// It returns ErrNotFound when the channel has no plugin rows.
func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
	const selectPlugins = `
SELECT id, channel_id, plugin_id, enabled, config
FROM channel_plugin
WHERE channel_id = ?
`
	rows, err := d.db.Query(selectPlugins, channelID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var found []*model.ChannelPlugin
	for rows.Next() {
		var (
			rowID        int64
			rowChannelID int64
			pluginID     string
			enabled      bool
			rawConfig    string
		)
		if err := rows.Scan(&rowID, &rowChannelID, &pluginID, &enabled, &rawConfig); err != nil {
			return nil, err
		}

		var cfg map[string]interface{}
		if err := json.Unmarshal([]byte(rawConfig), &cfg); err != nil {
			return nil, err
		}

		found = append(found, &model.ChannelPlugin{
			ID:        rowID,
			ChannelID: rowChannelID,
			PluginID:  pluginID,
			Enabled:   enabled,
			Config:    cfg,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	if len(found) > 0 {
		return found, nil
	}
	return nil, ErrNotFound
}
// GetChannelPluginByID retrieves a single channel_plugin row by its ID, with
// the config column decoded from JSON.
//
// It returns ErrNotFound when no row has the given ID. Any other query, scan
// or JSON decoding failure is returned unchanged.
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
	query := `
SELECT id, channel_id, plugin_id, enabled, config
FROM channel_plugin
WHERE id = ?
`
	var (
		channelID  int64
		pluginID   string
		enabled    bool
		configJSON string
	)
	err := d.db.QueryRow(query, id).Scan(&id, &channelID, &pluginID, &enabled, &configJSON)
	// errors.Is keeps the sentinel recognizable even if it arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	var config map[string]interface{}
	if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
		return nil, err
	}

	return &model.ChannelPlugin{
		ID:        id,
		ChannelID: channelID,
		PluginID:  pluginID,
		Enabled:   enabled,
		Config:    config,
	}, nil
}
// CreateChannelPlugin attaches a plugin to a channel and returns the stored
// row.
//
// It returns ErrDuplicated when a row for the same channelID and pluginID
// already exists. config is serialized to JSON for the config column.
//
// NOTE(review): the existence check and the insert are separate statements,
// so two concurrent calls could both pass the check — confirm whether the
// schema enforces uniqueness on (channel_id, plugin_id).
func (d *Database) CreateChannelPlugin(channelID int64, pluginID string, enabled bool, config map[string]interface{}) (*model.ChannelPlugin, error) {
	const countExisting = `
SELECT COUNT(*)
FROM channel_plugin
WHERE channel_id = ? AND plugin_id = ?
`
	var existing int
	if err := d.db.QueryRow(countExisting, channelID, pluginID).Scan(&existing); err != nil {
		return nil, err
	}
	if existing > 0 {
		return nil, ErrDuplicated
	}

	rawConfig, err := json.Marshal(config)
	if err != nil {
		return nil, err
	}

	const insertPlugin = `
INSERT INTO channel_plugin (channel_id, plugin_id, enabled, config)
VALUES (?, ?, ?, ?)
`
	res, err := d.db.Exec(insertPlugin, channelID, pluginID, enabled, string(rawConfig))
	if err != nil {
		return nil, err
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	created := &model.ChannelPlugin{
		ID:        newID,
		ChannelID: channelID,
		PluginID:  pluginID,
		Enabled:   enabled,
		Config:    config,
	}
	return created, nil
}
// UpdateChannelPlugin sets the enabled flag of the channel_plugin row with the
// given ID.
//
// The number of affected rows is not inspected; only an error from executing
// the statement is returned.
func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
	const updateEnabled = `
UPDATE channel_plugin
SET enabled = ?
WHERE id = ?
`
	if _, err := d.db.Exec(updateEnabled, enabled, id); err != nil {
		return err
	}
	return nil
}
// DeleteChannelPlugin removes the channel_plugin row with the given ID.
//
// The number of affected rows is not inspected; only an error from executing
// the statement is returned.
func (d *Database) DeleteChannelPlugin(id int64) error {
	const deleteByID = `
DELETE FROM channel_plugin
WHERE id = ?
`
	if _, err := d.db.Exec(deleteByID, id); err != nil {
		return err
	}
	return nil
}
// DeleteChannelPluginsByChannel removes every channel_plugin row that belongs
// to channelID.
//
// The number of affected rows is not inspected; only an error from executing
// the statement is returned.
func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
	const deleteByChannel = `
DELETE FROM channel_plugin
WHERE channel_id = ?
`
	if _, err := d.db.Exec(deleteByChannel, channelID); err != nil {
		return err
	}
	return nil
}
// GetAllChannels returns every channel together with its plugins.
//
// The channel rows are read to completion first and the plugins are loaded
// afterwards, so no per-channel query is issued while the outer result set is
// still open.
//
// A channel whose plugins cannot be loaded (for any reason other than having
// none) is left out of the result rather than failing the whole call. The
// returned slice is never nil; it is empty when there are no channels.
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
	query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels
`
	rows, err := d.db.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Phase 1: read and decode every channel row.
	var loaded []*model.Channel
	for rows.Next() {
		var (
			id                int64
			platform          string
			platformChannelID string
			enabled           bool
			channelRawJSON    string
		)
		if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
			return nil, err
		}

		// channel_raw is stored as a JSON document.
		var channelRaw map[string]interface{}
		if err := json.Unmarshal([]byte(channelRawJSON), &channelRaw); err != nil {
			return nil, err
		}

		loaded = append(loaded, &model.Channel{
			ID:                id,
			Platform:          platform,
			PlatformChannelID: platformChannelID,
			Enabled:           enabled,
			ChannelRaw:        channelRaw,
			Plugins:           make(map[string]*model.ChannelPlugin),
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Phase 2: attach plugins now that the channel result set is exhausted.
	channels := make([]*model.Channel, 0, len(loaded))
	for _, channel := range loaded {
		plugins, err := d.GetChannelPlugins(channel.ID)
		if err != nil && !errors.Is(err, ErrNotFound) {
			// Best effort: skip this channel if its plugins can't be retrieved.
			continue
		}
		for _, plugin := range plugins {
			channel.Plugins[plugin.PluginID] = plugin
		}
		channels = append(channels, channel)
	}

	return channels, nil
}
// GetUserByID retrieves a user by database ID.
//
// It returns ErrNotFound when no user row has the given ID. The Password field
// of the returned user holds the value stored in the password column.
func (d *Database) GetUserByID(id int64) (*model.User, error) {
	query := `
SELECT id, username, password
FROM users
WHERE id = ?
`
	var (
		username string
		password string
	)
	err := d.db.QueryRow(query, id).Scan(&id, &username, &password)
	// errors.Is keeps the sentinel recognizable even if it arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	return &model.User{
		ID:       id,
		Username: username,
		Password: password,
	}, nil
}
// CreateUser stores a new user and returns it.
//
// The plaintext password is hashed with hashPassword before it is written;
// the returned user's Password field holds that hash, not the plaintext.
func (d *Database) CreateUser(username, password string) (*model.User, error) {
	hash, err := hashPassword(password)
	if err != nil {
		return nil, err
	}

	const insertUser = `
INSERT INTO users (username, password)
VALUES (?, ?)
`
	res, err := d.db.Exec(insertUser, username, hash)
	if err != nil {
		return nil, err
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	created := &model.User{
		ID:       newID,
		Username: username,
		Password: hash,
	}
	return created, nil
}
// CheckCredentials looks up the user named username and verifies password
// against the stored bcrypt hash.
//
// It returns ErrNotFound when no such user exists and an "invalid credentials"
// error when the password does not match the stored hash. On success the
// matching user is returned with Password set to the stored hash.
func (d *Database) CheckCredentials(username, password string) (*model.User, error) {
	query := `
SELECT id, username, password
FROM users
WHERE username = ?
`
	var (
		id         int64
		dbUsername string
		dbPassword string
	)
	err := d.db.QueryRow(query, username).Scan(&id, &dbUsername, &dbPassword)
	// errors.Is keeps the sentinel recognizable even if it arrives wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}

	// Any comparison failure is reported to the caller as invalid credentials.
	if err := bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password)); err != nil {
		return nil, errors.New("invalid credentials")
	}

	return &model.User{
		ID:       id,
		Username: dbUsername,
		Password: dbPassword,
	}, nil
}
// UpdateUserPassword replaces the stored password of the user with the given
// ID by the hash of newPassword.
//
// The number of affected rows is not inspected; only a hashing error or an
// error from executing the statement is returned.
func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
	hash, err := hashPassword(newPassword)
	if err != nil {
		return err
	}

	const updatePassword = `
UPDATE users
SET password = ?
WHERE id = ?
`
	if _, err := d.db.Exec(updatePassword, hash, userID); err != nil {
		return err
	}
	return nil
}
// CreateReminder stores a new, unprocessed reminder and returns it.
//
// created_at is set to the current time at the moment of the call; the row is
// inserted with processed = 0. The returned reminder carries the ID assigned
// by the database.
func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
	const insertReminder = `
INSERT INTO reminders (
platform, channel_id, message_id, reply_to_id,
user_id, username, created_at, trigger_at,
content, processed
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)
`
	// Build the result up front; only the ID is unknown until the insert ran.
	reminder := &model.Reminder{
		Platform:  platform,
		ChannelID: channelID,
		MessageID: messageID,
		ReplyToID: replyToID,
		UserID:    userID,
		Username:  username,
		CreatedAt: time.Now(),
		TriggerAt: triggerAt,
		Content:   content,
		Processed: false,
	}

	res, err := d.db.Exec(
		insertReminder,
		reminder.Platform, reminder.ChannelID, reminder.MessageID, reminder.ReplyToID,
		reminder.UserID, reminder.Username, reminder.CreatedAt, reminder.TriggerAt,
		reminder.Content,
	)
	if err != nil {
		return nil, err
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	reminder.ID = newID

	return reminder, nil
}
// GetPendingReminders returns the reminders that are not yet processed and
// whose trigger_at is not later than the current time.
//
// The returned slice is never nil; it is empty when nothing is due.
func (d *Database) GetPendingReminders() ([]*model.Reminder, error) {
	const selectDue = `
SELECT id, platform, channel_id, message_id, reply_to_id,
user_id, username, created_at, trigger_at, content, processed
FROM reminders
WHERE processed = 0 AND trigger_at <= ?
`
	rows, err := d.db.Query(selectDue, time.Now())
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	due := make([]*model.Reminder, 0)
	for rows.Next() {
		var (
			rowID                                    int64
			platform, channel, message, replyTo      string
			user, name, text                         string
			created, trigger                         time.Time
			done                                     bool
		)
		err := rows.Scan(
			&rowID, &platform, &channel, &message, &replyTo,
			&user, &name, &created, &trigger, &text, &done,
		)
		if err != nil {
			return nil, err
		}

		due = append(due, &model.Reminder{
			ID:        rowID,
			Platform:  platform,
			ChannelID: channel,
			MessageID: message,
			ReplyToID: replyTo,
			UserID:    user,
			Username:  name,
			CreatedAt: created,
			TriggerAt: trigger,
			Content:   text,
			Processed: done,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return due, nil
}
// MarkReminderAsProcessed sets processed = 1 on the reminder with the given
// ID.
//
// The number of affected rows is not inspected; only an error from executing
// the statement is returned.
func (d *Database) MarkReminderAsProcessed(id int64) error {
	const markProcessed = `
UPDATE reminders
SET processed = 1
WHERE id = ?
`
	if _, err := d.db.Exec(markProcessed, id); err != nil {
		return err
	}
	return nil
}
// hashPassword returns the bcrypt hash of password as a string.
//
// A higher bcrypt cost makes hashing slower and therefore brute-forcing more
// expensive; this package uses a cost of 12.
func hashPassword(password string) (string, error) {
	const cost = 12

	digest, err := bcrypt.GenerateFromPassword([]byte(password), cost)
	if err != nil {
		return "", err
	}
	return string(digest), nil
}
// initDatabase makes sure the schema is current.
//
// It ensures the migration bookkeeping table exists, compares the versions in
// migration.Migrations with those already applied, and runs migration.Migrate
// when at least one version is still pending. Progress is printed to standard
// output.
func initDatabase(db *sql.DB) error {
	if err := migration.EnsureMigrationTable(db); err != nil {
		return fmt.Errorf("failed to create migration table: %w", err)
	}

	applied, err := migration.GetAppliedMigrations(db)
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Index the applied versions so each known migration is a single lookup.
	done := make(map[int]bool)
	for _, version := range applied {
		done[version] = true
	}

	pending := 0
	for version := range migration.Migrations {
		if !done[version] {
			pending++
		}
	}

	if pending == 0 {
		fmt.Println("Database schema is up to date.")
		return nil
	}

	fmt.Printf("Running %d pending database migrations...\n", pending)
	if err := migration.Migrate(db); err != nil {
		return fmt.Errorf("migration failed: %w", err)
	}
	fmt.Println("Database migrations completed successfully.")

	return nil
}