203 lines
5.7 KiB
Go
203 lines
5.7 KiB
Go
package db
|
|
|
|
import (
	"fmt"
	"os"
	"path/filepath"
	"testing"
	"time"

	"git.nakama.town/fmartingr/butterrobot/internal/model"
)
|
|
|
|
func TestEnableAllPlugins(t *testing.T) {
|
|
// Create temporary database for testing with unique name
|
|
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
|
|
database, err := New(dbFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test database: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = database.Close()
|
|
// Clean up test database file
|
|
_ = os.Remove(dbFile)
|
|
}()
|
|
|
|
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
|
|
channelRaw := map[string]interface{}{
|
|
"name": "test-channel",
|
|
}
|
|
|
|
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel: %v", err)
|
|
}
|
|
|
|
if channel.EnableAllPlugins {
|
|
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
|
|
}
|
|
|
|
// Verify it's also false when retrieved from database
|
|
retrieved, err := database.GetChannelByID(channel.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
}
|
|
|
|
if retrieved.EnableAllPlugins {
|
|
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
|
|
}
|
|
})
|
|
|
|
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
|
|
// Create a channel
|
|
channelRaw := map[string]interface{}{
|
|
"name": "test-channel-2",
|
|
}
|
|
|
|
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel: %v", err)
|
|
}
|
|
|
|
// Update EnableAllPlugins to true
|
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
|
}
|
|
|
|
// Retrieve and verify
|
|
retrieved, err := database.GetChannelByID(channel.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
}
|
|
|
|
if !retrieved.EnableAllPlugins {
|
|
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
|
|
}
|
|
|
|
// Update back to false
|
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
|
|
}
|
|
|
|
// Retrieve and verify again
|
|
retrieved, err = database.GetChannelByID(channel.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
}
|
|
|
|
if retrieved.EnableAllPlugins {
|
|
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
|
|
}
|
|
})
|
|
|
|
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
|
|
// Create a channel
|
|
channelRaw := map[string]interface{}{
|
|
"name": "test-channel-3",
|
|
}
|
|
|
|
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel: %v", err)
|
|
}
|
|
|
|
// Enable all plugins
|
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
|
}
|
|
|
|
// Retrieve by platform
|
|
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve channel by platform: %v", err)
|
|
}
|
|
|
|
if !retrieved.EnableAllPlugins {
|
|
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
|
|
}
|
|
})
|
|
|
|
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
|
|
// Create multiple channels with different EnableAllPlugins settings
|
|
channelRaw1 := map[string]interface{}{"name": "channel-1"}
|
|
channelRaw2 := map[string]interface{}{"name": "channel-2"}
|
|
|
|
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel1: %v", err)
|
|
}
|
|
|
|
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel2: %v", err)
|
|
}
|
|
|
|
// Enable all plugins for channel2 only
|
|
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
|
|
}
|
|
|
|
// Get all channels
|
|
channels, err := database.GetAllChannels()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get all channels: %v", err)
|
|
}
|
|
|
|
// Find our test channels
|
|
var foundChannel1, foundChannel2 *model.Channel
|
|
for _, ch := range channels {
|
|
if ch.ID == channel1.ID {
|
|
foundChannel1 = ch
|
|
}
|
|
if ch.ID == channel2.ID {
|
|
foundChannel2 = ch
|
|
}
|
|
}
|
|
|
|
if foundChannel1 == nil {
|
|
t.Fatalf("Channel1 not found in GetAllChannels result")
|
|
}
|
|
if foundChannel2 == nil {
|
|
t.Fatalf("Channel2 not found in GetAllChannels result")
|
|
}
|
|
|
|
if foundChannel1.EnableAllPlugins {
|
|
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
|
|
}
|
|
if !foundChannel2.EnableAllPlugins {
|
|
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
|
|
}
|
|
})
|
|
|
|
t.Run("Migration applied correctly", func(t *testing.T) {
|
|
// Test that we can create a channel and the enable_all_plugins column exists
|
|
// This implicitly tests that migration 4 was applied correctly
|
|
channelRaw := map[string]interface{}{
|
|
"name": "migration-test-channel",
|
|
}
|
|
|
|
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create channel after migration: %v", err)
|
|
}
|
|
|
|
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
|
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
|
|
}
|
|
|
|
// Verify the value was set correctly
|
|
retrieved, err := database.GetChannelByID(channel.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
}
|
|
|
|
if !retrieved.EnableAllPlugins {
|
|
t.Errorf("EnableAllPlugins should be true after update")
|
|
}
|
|
})
|
|
}
|