Compare commits
9 commits
Author | SHA1 | Date | |
---|---|---|---|
377b1723c3 | |||
60ceaffd82 | |||
3a5b5c216d | |||
bdc797d5c1 | |||
0edf41c792 | |||
35c14ce8a8 | |||
e0ff369cff | |||
368c45cd13 | |||
3b09a9dd47 |
16 changed files with 1063 additions and 35 deletions
10
CLAUDE.md
10
CLAUDE.md
|
@ -18,10 +18,12 @@ When creating, modifying, or removing plugins:
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
Before committing plugin changes:
|
**CRITICAL**: After making ANY changes to code files, you MUST run these commands in order:
|
||||||
|
|
||||||
1. Check files are properly formatted: Run `make format`
|
1. **Format code**: `make format` - Format all code according to project standards
|
||||||
2. Check code style and linting: Run `make lint`
|
2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues")
|
||||||
3. Test the plugin functionality: Run `make test`
|
3. **Run tests**: `make test` - Run all tests to ensure functionality works
|
||||||
4. Verify documentation accuracy
|
4. Verify documentation accuracy
|
||||||
5. Ensure all examples work as described
|
5. Ensure all examples work as described
|
||||||
|
|
||||||
|
**These commands are MANDATORY after every code change, no exceptions.**
|
||||||
|
|
|
@ -564,6 +564,13 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update enable_all_plugins
|
||||||
|
enableAllPlugins := r.FormValue("enable_all_plugins") == "true"
|
||||||
|
if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil {
|
||||||
|
http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
a.addFlash(w, r, "Channel updated", "success")
|
a.addFlash(w, r, "Channel updated", "success")
|
||||||
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
|
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
|
|
|
@ -27,6 +27,15 @@
|
||||||
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
|
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
|
||||||
<input type="hidden" name="form_submitted" value="true">
|
<input type="hidden" name="form_submitted" value="true">
|
||||||
</div>
|
</div>
|
||||||
|
<div class="mb-3">
|
||||||
|
<label class="form-check form-switch">
|
||||||
|
<input class="form-check-input" type="checkbox" name="enable_all_plugins" value="true" {{if .Channel.EnableAllPlugins}}checked{{end}}>
|
||||||
|
<span class="form-check-label">Enable All Plugins</span>
|
||||||
|
</label>
|
||||||
|
<div>
|
||||||
|
When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
<div class="form-footer">
|
<div class="form-footer">
|
||||||
<button type="submit" class="btn btn-primary">Save</button>
|
<button type="submit" class="btn btn-primary">Save</button>
|
||||||
<a href="/admin/channels" class="btn btn-link">Back to Channels</a>
|
<a href="/admin/channels" class="btn btn-link">Back to Channels</a>
|
||||||
|
@ -115,4 +124,4 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|
|
@ -314,11 +314,21 @@ func (a *App) handleMessage(item queue.Item) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message with plugins
|
// Process message with plugins
|
||||||
for pluginID, channelPlugin := range channel.Plugins {
|
var pluginsToProcess []string
|
||||||
if !channel.HasEnabledPlugin(pluginID) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if channel.EnableAllPlugins {
|
||||||
|
// If EnableAllPlugins is true, process all registered plugins
|
||||||
|
pluginsToProcess = plugin.GetAvailablePluginIDs()
|
||||||
|
} else {
|
||||||
|
// Otherwise, process only explicitly enabled plugins
|
||||||
|
for pluginID := range channel.Plugins {
|
||||||
|
if channel.HasEnabledPlugin(pluginID) {
|
||||||
|
pluginsToProcess = append(pluginsToProcess, pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pluginID := range pluginsToProcess {
|
||||||
// Get plugin
|
// Get plugin
|
||||||
p, err := plugin.Get(pluginID)
|
p, err := plugin.Get(pluginID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -326,11 +336,19 @@ func (a *App) handleMessage(item queue.Item) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured)
|
||||||
|
var config map[string]interface{}
|
||||||
|
if channelPlugin, exists := channel.Plugins[pluginID]; exists {
|
||||||
|
config = channelPlugin.Config
|
||||||
|
} else {
|
||||||
|
config = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
// Create cache instance for this plugin
|
// Create cache instance for this plugin
|
||||||
pluginCache := cache.New(a.db, pluginID)
|
pluginCache := cache.New(a.db, pluginID)
|
||||||
|
|
||||||
// Process message and get actions
|
// Process message and get actions
|
||||||
actions := p.OnMessage(message, channelPlugin.Config, pluginCache)
|
actions := p.OnMessage(message, config, pluginCache)
|
||||||
|
|
||||||
// Get platform for processing actions
|
// Get platform for processing actions
|
||||||
platform, err := platform.Get(item.Platform)
|
platform, err := platform.Get(item.Platform)
|
||||||
|
|
|
@ -56,7 +56,7 @@ func (d *Database) Close() error {
|
||||||
// GetChannelByID retrieves a channel by ID
|
// GetChannelByID retrieves a channel by ID
|
||||||
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
|
@ -67,10 +67,11 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
platform string
|
platform string
|
||||||
platformChannelID string
|
platformChannelID string
|
||||||
enabled bool
|
enabled bool
|
||||||
|
enableAllPlugins bool
|
||||||
channelRawJSON string
|
channelRawJSON string
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
@ -90,6 +91,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
|
EnableAllPlugins: enableAllPlugins,
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -110,7 +112,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
||||||
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
WHERE platform = ? AND platform_channel_id = ?
|
WHERE platform = ? AND platform_channel_id = ?
|
||||||
`
|
`
|
||||||
|
@ -118,12 +120,13 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
||||||
row := d.db.QueryRow(query, platform, platformChannelID)
|
row := d.db.QueryRow(query, platform, platformChannelID)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id int64
|
id int64
|
||||||
enabled bool
|
enabled bool
|
||||||
channelRawJSON string
|
enableAllPlugins bool
|
||||||
|
channelRawJSON string
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
@ -143,6 +146,7 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
|
EnableAllPlugins: enableAllPlugins,
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -170,11 +174,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
||||||
|
|
||||||
// Insert channel
|
// Insert channel
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -191,6 +195,7 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
|
EnableAllPlugins: false,
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -210,6 +215,18 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status
|
||||||
|
func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error {
|
||||||
|
query := `
|
||||||
|
UPDATE channels
|
||||||
|
SET enable_all_plugins = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := d.db.Exec(query, enableAllPlugins, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteChannel deletes a channel
|
// DeleteChannel deletes a channel
|
||||||
func (d *Database) DeleteChannel(id int64) error {
|
func (d *Database) DeleteChannel(id int64) error {
|
||||||
// First delete all channel plugins
|
// First delete all channel plugins
|
||||||
|
@ -456,7 +473,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
||||||
// GetAllChannels retrieves all channels
|
// GetAllChannels retrieves all channels
|
||||||
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
`
|
`
|
||||||
|
|
||||||
|
@ -478,10 +495,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
platform string
|
platform string
|
||||||
platformChannelID string
|
platformChannelID string
|
||||||
enabled bool
|
enabled bool
|
||||||
|
enableAllPlugins bool
|
||||||
channelRawJSON string
|
channelRawJSON string
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -497,6 +515,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
|
EnableAllPlugins: enableAllPlugins,
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
|
203
internal/db/db_test.go
Normal file
203
internal/db/db_test.go
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnableAllPlugins(t *testing.T) {
|
||||||
|
// Create temporary database for testing with unique name
|
||||||
|
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
|
||||||
|
database, err := New(dbFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = database.Close()
|
||||||
|
// Clean up test database file
|
||||||
|
_ = os.Remove(dbFile)
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
|
||||||
|
channelRaw := map[string]interface{}{
|
||||||
|
"name": "test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if channel.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's also false when retrieved from database
|
||||||
|
retrieved, err := database.GetChannelByID(channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
|
||||||
|
// Create a channel
|
||||||
|
channelRaw := map[string]interface{}{
|
||||||
|
"name": "test-channel-2",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update EnableAllPlugins to true
|
||||||
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve and verify
|
||||||
|
retrieved, err := database.GetChannelByID(channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !retrieved.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update back to false
|
||||||
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve and verify again
|
||||||
|
retrieved, err = database.GetChannelByID(channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
|
||||||
|
// Create a channel
|
||||||
|
channelRaw := map[string]interface{}{
|
||||||
|
"name": "test-channel-3",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable all plugins
|
||||||
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve by platform
|
||||||
|
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to retrieve channel by platform: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !retrieved.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
|
||||||
|
// Create multiple channels with different EnableAllPlugins settings
|
||||||
|
channelRaw1 := map[string]interface{}{"name": "channel-1"}
|
||||||
|
channelRaw2 := map[string]interface{}{"name": "channel-2"}
|
||||||
|
|
||||||
|
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable all plugins for channel2 only
|
||||||
|
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all channels
|
||||||
|
channels, err := database.GetAllChannels()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get all channels: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find our test channels
|
||||||
|
var foundChannel1, foundChannel2 *model.Channel
|
||||||
|
for _, ch := range channels {
|
||||||
|
if ch.ID == channel1.ID {
|
||||||
|
foundChannel1 = ch
|
||||||
|
}
|
||||||
|
if ch.ID == channel2.ID {
|
||||||
|
foundChannel2 = ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundChannel1 == nil {
|
||||||
|
t.Fatalf("Channel1 not found in GetAllChannels result")
|
||||||
|
}
|
||||||
|
if foundChannel2 == nil {
|
||||||
|
t.Fatalf("Channel2 not found in GetAllChannels result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundChannel1.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
|
||||||
|
}
|
||||||
|
if !foundChannel2.EnableAllPlugins {
|
||||||
|
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Migration applied correctly", func(t *testing.T) {
|
||||||
|
// Test that we can create a channel and the enable_all_plugins column exists
|
||||||
|
// This implicitly tests that migration 4 was applied correctly
|
||||||
|
channelRaw := map[string]interface{}{
|
||||||
|
"name": "migration-test-channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create channel after migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
|
||||||
|
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the value was set correctly
|
||||||
|
retrieved, err := database.GetChannelByID(channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to retrieve channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !retrieved.EnableAllPlugins {
|
||||||
|
t.Errorf("EnableAllPlugins should be true after update")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -10,6 +10,7 @@ func init() {
|
||||||
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
||||||
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
|
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
|
||||||
Register(3, "Add cache table", migrateCacheUp, migrateCacheDown)
|
Register(3, "Add cache table", migrateCacheUp, migrateCacheDown)
|
||||||
|
Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initial schema creation with bcrypt passwords - version 1
|
// Initial schema creation with bcrypt passwords - version 1
|
||||||
|
@ -154,3 +155,60 @@ func migrateCacheDown(db *sql.DB) error {
|
||||||
_, err := db.Exec(`DROP TABLE IF EXISTS cache`)
|
_, err := db.Exec(`DROP TABLE IF EXISTS cache`)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add enable_all_plugins column to channels table - version 4
|
||||||
|
func migrateEnableAllPluginsUp(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(`
|
||||||
|
ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0
|
||||||
|
`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateEnableAllPluginsDown(db *sql.DB) error {
|
||||||
|
// SQLite doesn't support DROP COLUMN, so we need to recreate the table
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = tx.Rollback() // Ignore rollback errors
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create backup table
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
CREATE TABLE channels_backup (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
platform TEXT NOT NULL,
|
||||||
|
platform_channel_id TEXT NOT NULL,
|
||||||
|
enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
channel_raw TEXT NOT NULL,
|
||||||
|
UNIQUE(platform, platform_channel_id)
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy data excluding enable_all_plugins column
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw)
|
||||||
|
SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop original table
|
||||||
|
_, err = tx.Exec(`DROP TABLE channels`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rename backup table
|
||||||
|
_, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
|
@ -44,11 +44,17 @@ type Channel struct {
|
||||||
PlatformChannelID string
|
PlatformChannelID string
|
||||||
ChannelRaw map[string]interface{}
|
ChannelRaw map[string]interface{}
|
||||||
Enabled bool
|
Enabled bool
|
||||||
|
EnableAllPlugins bool
|
||||||
Plugins map[string]*ChannelPlugin
|
Plugins map[string]*ChannelPlugin
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
||||||
func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
||||||
|
// If EnableAllPlugins is true, all plugins are considered enabled
|
||||||
|
if c.EnableAllPlugins {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
plugin, exists := c.Plugins[pluginID]
|
plugin, exists := c.Plugins[pluginID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
|
|
234
internal/model/message_test.go
Normal file
234
internal/model/message_test.go
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChannel_HasEnabledPlugin(t *testing.T) {
|
||||||
|
t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: false,
|
||||||
|
Plugins: make(map[string]*ChannelPlugin),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plugin not in map should return false
|
||||||
|
result := channel.HasEnabledPlugin("nonexistent.plugin")
|
||||||
|
if result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: false,
|
||||||
|
Plugins: map[string]*ChannelPlugin{
|
||||||
|
"test.plugin": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "test.plugin",
|
||||||
|
Enabled: false,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disabled plugin should return false
|
||||||
|
result := channel.HasEnabledPlugin("test.plugin")
|
||||||
|
if result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: false,
|
||||||
|
Plugins: map[string]*ChannelPlugin{
|
||||||
|
"test.plugin": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "test.plugin",
|
||||||
|
Enabled: true,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled plugin should return true
|
||||||
|
result := channel.HasEnabledPlugin("test.plugin")
|
||||||
|
if !result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: make(map[string]*ChannelPlugin),
|
||||||
|
}
|
||||||
|
|
||||||
|
// When EnableAllPlugins is true, any plugin should be considered enabled
|
||||||
|
result := channel.HasEnabledPlugin("nonexistent.plugin")
|
||||||
|
if !result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: map[string]*ChannelPlugin{
|
||||||
|
"test.plugin": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "test.plugin",
|
||||||
|
Enabled: false,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// When EnableAllPlugins is true, even disabled plugins should be considered enabled
|
||||||
|
result := channel.HasEnabledPlugin("test.plugin")
|
||||||
|
if !result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: map[string]*ChannelPlugin{
|
||||||
|
"test.plugin": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "test.plugin",
|
||||||
|
Enabled: true,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// When EnableAllPlugins is true, enabled plugins should also return true
|
||||||
|
result := channel.HasEnabledPlugin("test.plugin")
|
||||||
|
if !result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: map[string]*ChannelPlugin{
|
||||||
|
"plugin1": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "plugin1",
|
||||||
|
Enabled: true,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
"plugin2": {
|
||||||
|
ID: 2,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "plugin2",
|
||||||
|
Enabled: false,
|
||||||
|
Config: make(map[string]any),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// All plugins should be enabled when EnableAllPlugins is true
|
||||||
|
testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"}
|
||||||
|
for _, pluginID := range testCases {
|
||||||
|
result := channel.HasEnabledPlugin(pluginID)
|
||||||
|
if !result {
|
||||||
|
t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannelName(t *testing.T) {
|
||||||
|
t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
PlatformChannelID: "test-id",
|
||||||
|
ChannelRaw: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := channel.ChannelName()
|
||||||
|
if result != "test-id" {
|
||||||
|
t.Errorf("Expected channel name to be 'test-id', got '%s'", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns name from ChannelRaw when available", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
PlatformChannelID: "test-id",
|
||||||
|
ChannelRaw: map[string]interface{}{
|
||||||
|
"name": "Test Channel",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := channel.ChannelName()
|
||||||
|
if result != "Test Channel" {
|
||||||
|
t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
PlatformChannelID: "test-id",
|
||||||
|
ChannelRaw: map[string]interface{}{
|
||||||
|
"chat": map[string]interface{}{
|
||||||
|
"title": "Telegram Group",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := channel.ChannelName()
|
||||||
|
if result != "Telegram Group" {
|
||||||
|
t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
PlatformChannelID: "fallback-id",
|
||||||
|
ChannelRaw: map[string]interface{}{
|
||||||
|
"other_field": "value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := channel.ChannelName()
|
||||||
|
if result != "fallback-id" {
|
||||||
|
t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -233,9 +233,17 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
|
|
||||||
// Prepare payload
|
// Prepare payload
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
"chat_id": chatID,
|
"chat_id": chatID,
|
||||||
"text": msg.Text,
|
"text": msg.Text,
|
||||||
"parse_mode": "Markdown",
|
}
|
||||||
|
|
||||||
|
// Set parse_mode based on plugin preference or default to empty string
|
||||||
|
if msg.Raw != nil && msg.Raw["parse_mode"] != nil {
|
||||||
|
// Plugin explicitly set parse_mode
|
||||||
|
payload["parse_mode"] = msg.Raw["parse_mode"]
|
||||||
|
} else {
|
||||||
|
// Default to empty string (no formatting)
|
||||||
|
payload["parse_mode"] = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add reply if needed
|
// Add reply if needed
|
||||||
|
|
|
@ -131,12 +131,15 @@ func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set parse mode for markdown formatting
|
||||||
|
if responseMsg.Raw == nil {
|
||||||
|
responseMsg.Raw = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
responseMsg.Raw["parse_mode"] = "Markdown"
|
||||||
|
|
||||||
// Add game cover as attachment if available
|
// Add game cover as attachment if available
|
||||||
if game.GameImage != "" {
|
if game.GameImage != "" {
|
||||||
imageURL := p.getFullImageURL(game.GameImage)
|
imageURL := p.getFullImageURL(game.GameImage)
|
||||||
if responseMsg.Raw == nil {
|
|
||||||
responseMsg.Raw = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
responseMsg.Raw["image_url"] = imageURL
|
responseMsg.Raw["image_url"] = imageURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,7 @@ func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Chat: msg.Chat,
|
Chat: msg.Chat,
|
||||||
ReplyTo: msg.ID,
|
ReplyTo: msg.ID,
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
|
Raw: map[string]interface{}{"parse_mode": "Markdown"},
|
||||||
}
|
}
|
||||||
|
|
||||||
return []*model.MessageAction{
|
return []*model.MessageAction{
|
||||||
|
@ -151,6 +152,7 @@ func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Chat: msg.Chat,
|
Chat: msg.Chat,
|
||||||
ReplyTo: msg.ID,
|
ReplyTo: msg.ID,
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
|
Raw: map[string]interface{}{"parse_mode": "Markdown"},
|
||||||
}
|
}
|
||||||
|
|
||||||
return []*model.MessageAction{
|
return []*model.MessageAction{
|
||||||
|
|
|
@ -47,6 +47,19 @@ func GetAvailablePlugins() map[string]model.Plugin {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAvailablePluginIDs returns a slice of all registered plugin IDs
|
||||||
|
func GetAvailablePluginIDs() []string {
|
||||||
|
pluginsMu.RLock()
|
||||||
|
defer pluginsMu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]string, 0, len(plugins))
|
||||||
|
for pluginID := range plugins {
|
||||||
|
result = append(result, pluginID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// ClearRegistry clears all registered plugins (for testing)
|
// ClearRegistry clears all registered plugins (for testing)
|
||||||
func ClearRegistry() {
|
func ClearRegistry() {
|
||||||
pluginsMu.Lock()
|
pluginsMu.Lock()
|
||||||
|
|
331
internal/plugin/plugin_test.go
Normal file
331
internal/plugin/plugin_test.go
Normal file
|
@ -0,0 +1,331 @@
|
||||||
|
package plugin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mock plugin for testing
|
||||||
|
type testPlugin struct {
|
||||||
|
BasePlugin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
||||||
|
return []*model.MessageAction{
|
||||||
|
{
|
||||||
|
Type: model.ActionSendMessage,
|
||||||
|
Message: &model.Message{
|
||||||
|
Text: "test response",
|
||||||
|
Chat: msg.Chat,
|
||||||
|
Channel: msg.Channel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailablePluginIDs(t *testing.T) {
|
||||||
|
// Clear registry before test
|
||||||
|
ClearRegistry()
|
||||||
|
|
||||||
|
// Register test plugins
|
||||||
|
testPlugin1 := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "test.plugin1",
|
||||||
|
Name: "Test Plugin 1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testPlugin2 := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "test.plugin2",
|
||||||
|
Name: "Test Plugin 2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
Register(testPlugin1)
|
||||||
|
Register(testPlugin2)
|
||||||
|
|
||||||
|
// Test GetAvailablePluginIDs
|
||||||
|
pluginIDs := GetAvailablePluginIDs()
|
||||||
|
|
||||||
|
if len(pluginIDs) != 2 {
|
||||||
|
t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that both plugin IDs are present
|
||||||
|
found1, found2 := false, false
|
||||||
|
for _, id := range pluginIDs {
|
||||||
|
if id == "test.plugin1" {
|
||||||
|
found1 = true
|
||||||
|
}
|
||||||
|
if id == "test.plugin2" {
|
||||||
|
found2 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found1 {
|
||||||
|
t.Errorf("Expected to find test.plugin1 in plugin IDs")
|
||||||
|
}
|
||||||
|
if !found2 {
|
||||||
|
t.Errorf("Expected to find test.plugin2 in plugin IDs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnableAllPluginsProcessingLogic(t *testing.T) {
|
||||||
|
// Clear registry before test
|
||||||
|
ClearRegistry()
|
||||||
|
|
||||||
|
// Register test plugins
|
||||||
|
testPlugin1 := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "ping",
|
||||||
|
Name: "Ping Plugin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testPlugin2 := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "echo",
|
||||||
|
Name: "Echo Plugin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testPlugin3 := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "help",
|
||||||
|
Name: "Help Plugin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
Register(testPlugin1)
|
||||||
|
Register(testPlugin2)
|
||||||
|
Register(testPlugin3)
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) {
|
||||||
|
// Create a channel with EnableAllPlugins = false and only some plugins enabled
|
||||||
|
channel := &model.Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: false,
|
||||||
|
Plugins: map[string]*model.ChannelPlugin{
|
||||||
|
"ping": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "ping",
|
||||||
|
Enabled: true,
|
||||||
|
Config: map[string]interface{}{"key": "value"},
|
||||||
|
},
|
||||||
|
"echo": {
|
||||||
|
ID: 2,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "echo",
|
||||||
|
Enabled: false, // Disabled
|
||||||
|
Config: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
// help plugin not configured
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the plugin processing logic from handleMessage
|
||||||
|
var pluginsToProcess []string
|
||||||
|
|
||||||
|
if channel.EnableAllPlugins {
|
||||||
|
pluginsToProcess = GetAvailablePluginIDs()
|
||||||
|
} else {
|
||||||
|
for pluginID := range channel.Plugins {
|
||||||
|
if channel.HasEnabledPlugin(pluginID) {
|
||||||
|
pluginsToProcess = append(pluginsToProcess, pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should only have "ping" since echo is disabled and help is not configured
|
||||||
|
if len(pluginsToProcess) != 1 {
|
||||||
|
t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" {
|
||||||
|
t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) {
|
||||||
|
// Create a channel with EnableAllPlugins = true
|
||||||
|
channel := &model.Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: map[string]*model.ChannelPlugin{
|
||||||
|
"ping": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "ping",
|
||||||
|
Enabled: true,
|
||||||
|
Config: map[string]interface{}{"key": "value"},
|
||||||
|
},
|
||||||
|
"echo": {
|
||||||
|
ID: 2,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "echo",
|
||||||
|
Enabled: false, // Disabled, but should still be processed
|
||||||
|
Config: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
// help plugin not configured, but should still be processed
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the plugin processing logic from handleMessage
|
||||||
|
var pluginsToProcess []string
|
||||||
|
|
||||||
|
if channel.EnableAllPlugins {
|
||||||
|
pluginsToProcess = GetAvailablePluginIDs()
|
||||||
|
} else {
|
||||||
|
for pluginID := range channel.Plugins {
|
||||||
|
if channel.HasEnabledPlugin(pluginID) {
|
||||||
|
pluginsToProcess = append(pluginsToProcess, pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have all 3 registered plugins
|
||||||
|
if len(pluginsToProcess) != 3 {
|
||||||
|
t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all plugins are included
|
||||||
|
expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false}
|
||||||
|
for _, pluginID := range pluginsToProcess {
|
||||||
|
if _, exists := expectedPlugins[pluginID]; exists {
|
||||||
|
expectedPlugins[pluginID] = true
|
||||||
|
} else {
|
||||||
|
t.Errorf("Unexpected plugin in processing list: %s", pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for pluginID, found := range expectedPlugins {
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected plugin %s to be in processing list", pluginID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Plugin configuration handling", func(t *testing.T) {
|
||||||
|
// Test the configuration logic from handleMessage
|
||||||
|
channel := &model.Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "123456",
|
||||||
|
Enabled: true,
|
||||||
|
EnableAllPlugins: true,
|
||||||
|
Plugins: map[string]*model.ChannelPlugin{
|
||||||
|
"ping": {
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1,
|
||||||
|
PluginID: "ping",
|
||||||
|
Enabled: true,
|
||||||
|
Config: map[string]interface{}{"configured": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
pluginID string
|
||||||
|
expectedConfig map[string]interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
pluginID: "ping",
|
||||||
|
expectedConfig: map[string]interface{}{"configured": "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
pluginID: "echo", // Not explicitly configured
|
||||||
|
expectedConfig: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
// Simulate the config retrieval logic from handleMessage
|
||||||
|
var config map[string]interface{}
|
||||||
|
if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists {
|
||||||
|
config = channelPlugin.Config
|
||||||
|
} else {
|
||||||
|
config = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config) != len(tc.expectedConfig) {
|
||||||
|
t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, expectedValue := range tc.expectedConfig {
|
||||||
|
if actualValue, exists := config[key]; !exists || actualValue != expectedValue {
|
||||||
|
t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluginRegistry(t *testing.T) {
|
||||||
|
// Clear registry before test
|
||||||
|
ClearRegistry()
|
||||||
|
|
||||||
|
testPlugin := &testPlugin{
|
||||||
|
BasePlugin: BasePlugin{
|
||||||
|
ID: "test.registry",
|
||||||
|
Name: "Test Registry Plugin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Register and Get plugin", func(t *testing.T) {
|
||||||
|
Register(testPlugin)
|
||||||
|
|
||||||
|
retrieved, err := Get("test.registry")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get registered plugin: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.GetID() != "test.registry" {
|
||||||
|
t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Get nonexistent plugin", func(t *testing.T) {
|
||||||
|
_, err := Get("nonexistent.plugin")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error when getting nonexistent plugin, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != model.ErrPluginNotFound {
|
||||||
|
t.Errorf("Expected ErrPluginNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAvailablePlugins", func(t *testing.T) {
|
||||||
|
plugins := GetAvailablePlugins()
|
||||||
|
|
||||||
|
if len(plugins) != 1 {
|
||||||
|
t.Errorf("Expected 1 plugin in registry, got %d", len(plugins))
|
||||||
|
}
|
||||||
|
|
||||||
|
if plugin, exists := plugins["test.registry"]; !exists {
|
||||||
|
t.Errorf("Expected to find test.registry in available plugins")
|
||||||
|
} else if plugin.GetID() != "test.registry" {
|
||||||
|
t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ClearRegistry", func(t *testing.T) {
|
||||||
|
ClearRegistry()
|
||||||
|
|
||||||
|
plugins := GetAvailablePlugins()
|
||||||
|
if len(plugins) != 0 {
|
||||||
|
t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Get("test.registry")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error when getting plugin after clearing registry, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -54,17 +54,12 @@ func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interf
|
||||||
// Parse the URL
|
// Parse the URL
|
||||||
parsedURL, err := url.Parse(link)
|
parsedURL, err := url.Parse(link)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If parsing fails, just do the simple replacement
|
|
||||||
link = strings.Replace(link, "twitter.com", replacementDomain, 1)
|
|
||||||
link = strings.Replace(link, "x.com", replacementDomain, 1)
|
|
||||||
return link
|
return link
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the host to the configured domain
|
// Change the host to the configured domain
|
||||||
if strings.Contains(parsedURL.Host, "twitter.com") {
|
if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") {
|
||||||
parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", replacementDomain, 1)
|
parsedURL.Host = replacementDomain
|
||||||
} else if strings.Contains(parsedURL.Host, "x.com") {
|
|
||||||
parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", replacementDomain, 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove query parameters
|
// Remove query parameters
|
||||||
|
|
120
internal/plugin/social/twitter_test.go
Normal file
120
internal/plugin/social/twitter_test.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
package social
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTwitterExpander_OnMessage(t *testing.T) {
|
||||||
|
plugin := NewTwitterExpander()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
config map[string]interface{}
|
||||||
|
expected string
|
||||||
|
hasReply bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Twitter URL with default domain",
|
||||||
|
input: "https://twitter.com/user/status/123456789",
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expected: "https://fxtwitter.com/user/status/123456789",
|
||||||
|
hasReply: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X.com URL with custom domain",
|
||||||
|
input: "https://x.com/elonmusk/status/987654321",
|
||||||
|
config: map[string]interface{}{"domain": "vxtwitter.com"},
|
||||||
|
expected: "https://vxtwitter.com/elonmusk/status/987654321",
|
||||||
|
hasReply: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Twitter URL with tracking parameters",
|
||||||
|
input: "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20",
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expected: "https://fxtwitter.com/openai/status/555",
|
||||||
|
hasReply: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "www.twitter.com URL",
|
||||||
|
input: "https://www.twitter.com/user/status/789",
|
||||||
|
config: map[string]interface{}{"domain": "nitter.net"},
|
||||||
|
expected: "https://nitter.net/user/status/789",
|
||||||
|
hasReply: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed text with Twitter URL",
|
||||||
|
input: "Check this out: https://twitter.com/user/status/123 amazing!",
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!",
|
||||||
|
hasReply: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Twitter URLs",
|
||||||
|
input: "Just some regular text https://youtube.com/watch?v=abc",
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expected: "",
|
||||||
|
hasReply: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty message",
|
||||||
|
input: "",
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expected: "",
|
||||||
|
hasReply: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
msg := &model.Message{
|
||||||
|
ID: "test_msg",
|
||||||
|
Text: tt.input,
|
||||||
|
Chat: "test_chat",
|
||||||
|
Channel: &model.Channel{
|
||||||
|
ID: 1,
|
||||||
|
Platform: "telegram",
|
||||||
|
PlatformChannelID: "test_chat",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
actions := plugin.OnMessage(msg, tt.config, nil)
|
||||||
|
|
||||||
|
if !tt.hasReply {
|
||||||
|
if len(actions) != 0 {
|
||||||
|
t.Errorf("Expected no actions, got %d", len(actions))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actions) != 1 {
|
||||||
|
t.Errorf("Expected 1 action, got %d", len(actions))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
action := actions[0]
|
||||||
|
if action.Type != model.ActionSendMessage {
|
||||||
|
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.Message == nil {
|
||||||
|
t.Error("Expected message in action, got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.Message.Text != tt.expected {
|
||||||
|
t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.Message.ReplyTo != msg.ID {
|
||||||
|
t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" {
|
||||||
|
t.Error("Expected parse_mode to be empty string to disable markdown parsing")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue