feat: allow enabling all plugins into a channel
This commit is contained in:
parent
899ac49336
commit
3b09a9dd47
10 changed files with 915 additions and 17 deletions
|
@ -56,7 +56,7 @@ func (d *Database) Close() error {
|
|||
// GetChannelByID retrieves a channel by ID
|
||||
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||
FROM channels
|
||||
WHERE id = ?
|
||||
`
|
||||
|
@ -67,10 +67,11 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
|||
platform string
|
||||
platformChannelID string
|
||||
enabled bool
|
||||
enableAllPlugins bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
@ -90,6 +91,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
|||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
EnableAllPlugins: enableAllPlugins,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
@ -110,7 +112,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
|||
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
||||
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||
FROM channels
|
||||
WHERE platform = ? AND platform_channel_id = ?
|
||||
`
|
||||
|
@ -118,12 +120,13 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
|||
row := d.db.QueryRow(query, platform, platformChannelID)
|
||||
|
||||
var (
|
||||
id int64
|
||||
enabled bool
|
||||
channelRawJSON string
|
||||
id int64
|
||||
enabled bool
|
||||
enableAllPlugins bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
@ -143,6 +146,7 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
|||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
EnableAllPlugins: enableAllPlugins,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
@ -170,11 +174,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
|||
|
||||
// Insert channel
|
||||
query := `
|
||||
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
||||
VALUES (?, ?, ?, ?)
|
||||
INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
||||
result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -191,6 +195,7 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
|||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
EnableAllPlugins: false,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
@ -210,6 +215,18 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status
|
||||
func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error {
|
||||
query := `
|
||||
UPDATE channels
|
||||
SET enable_all_plugins = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, enableAllPlugins, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteChannel deletes a channel
|
||||
func (d *Database) DeleteChannel(id int64) error {
|
||||
// First delete all channel plugins
|
||||
|
@ -456,7 +473,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
|||
// GetAllChannels retrieves all channels
|
||||
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||
query := `
|
||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||
FROM channels
|
||||
`
|
||||
|
||||
|
@ -478,10 +495,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
|||
platform string
|
||||
platformChannelID string
|
||||
enabled bool
|
||||
enableAllPlugins bool
|
||||
channelRawJSON string
|
||||
)
|
||||
|
||||
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
||||
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -497,6 +515,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
|||
Platform: platform,
|
||||
PlatformChannelID: platformChannelID,
|
||||
Enabled: enabled,
|
||||
EnableAllPlugins: enableAllPlugins,
|
||||
ChannelRaw: channelRaw,
|
||||
Plugins: make(map[string]*model.ChannelPlugin),
|
||||
}
|
||||
|
|
203
internal/db/db_test.go
Normal file
203
internal/db/db_test.go
Normal file
|
@ -0,0 +1,203 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
)
|
||||
|
||||
func TestEnableAllPlugins(t *testing.T) {
|
||||
// Create temporary database for testing with unique name
|
||||
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
|
||||
database, err := New(dbFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = database.Close()
|
||||
// Clean up test database file
|
||||
_ = os.Remove(dbFile)
|
||||
}()
|
||||
|
||||
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
|
||||
channelRaw := map[string]interface{}{
|
||||
"name": "test-channel",
|
||||
}
|
||||
|
||||
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
if channel.EnableAllPlugins {
|
||||
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
|
||||
}
|
||||
|
||||
// Verify it's also false when retrieved from database
|
||||
retrieved, err := database.GetChannelByID(channel.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.EnableAllPlugins {
|
||||
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
|
||||
// Create a channel
|
||||
channelRaw := map[string]interface{}{
|
||||
"name": "test-channel-2",
|
||||
}
|
||||
|
||||
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
// Update EnableAllPlugins to true
|
||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify
|
||||
retrieved, err := database.GetChannelByID(channel.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||
}
|
||||
|
||||
if !retrieved.EnableAllPlugins {
|
||||
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
|
||||
}
|
||||
|
||||
// Update back to false
|
||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify again
|
||||
retrieved, err = database.GetChannelByID(channel.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.EnableAllPlugins {
|
||||
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
|
||||
// Create a channel
|
||||
channelRaw := map[string]interface{}{
|
||||
"name": "test-channel-3",
|
||||
}
|
||||
|
||||
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
// Enable all plugins
|
||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve by platform
|
||||
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve channel by platform: %v", err)
|
||||
}
|
||||
|
||||
if !retrieved.EnableAllPlugins {
|
||||
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
|
||||
// Create multiple channels with different EnableAllPlugins settings
|
||||
channelRaw1 := map[string]interface{}{"name": "channel-1"}
|
||||
channelRaw2 := map[string]interface{}{"name": "channel-2"}
|
||||
|
||||
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel1: %v", err)
|
||||
}
|
||||
|
||||
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel2: %v", err)
|
||||
}
|
||||
|
||||
// Enable all plugins for channel2 only
|
||||
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
|
||||
}
|
||||
|
||||
// Get all channels
|
||||
channels, err := database.GetAllChannels()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get all channels: %v", err)
|
||||
}
|
||||
|
||||
// Find our test channels
|
||||
var foundChannel1, foundChannel2 *model.Channel
|
||||
for _, ch := range channels {
|
||||
if ch.ID == channel1.ID {
|
||||
foundChannel1 = ch
|
||||
}
|
||||
if ch.ID == channel2.ID {
|
||||
foundChannel2 = ch
|
||||
}
|
||||
}
|
||||
|
||||
if foundChannel1 == nil {
|
||||
t.Fatalf("Channel1 not found in GetAllChannels result")
|
||||
}
|
||||
if foundChannel2 == nil {
|
||||
t.Fatalf("Channel2 not found in GetAllChannels result")
|
||||
}
|
||||
|
||||
if foundChannel1.EnableAllPlugins {
|
||||
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
|
||||
}
|
||||
if !foundChannel2.EnableAllPlugins {
|
||||
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Migration applied correctly", func(t *testing.T) {
|
||||
// Test that we can create a channel and the enable_all_plugins column exists
|
||||
// This implicitly tests that migration 4 was applied correctly
|
||||
channelRaw := map[string]interface{}{
|
||||
"name": "migration-test-channel",
|
||||
}
|
||||
|
||||
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel after migration: %v", err)
|
||||
}
|
||||
|
||||
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
|
||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
|
||||
}
|
||||
|
||||
// Verify the value was set correctly
|
||||
retrieved, err := database.GetChannelByID(channel.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||
}
|
||||
|
||||
if !retrieved.EnableAllPlugins {
|
||||
t.Errorf("EnableAllPlugins should be true after update")
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue