Compare commits

...

20 commits

Author SHA1 Message Date
377b1723c3
fix: default parse mode to text
Some checks failed
ci/woodpecker/push/ci Pipeline failed
ci/woodpecker/tag/release Pipeline was successful
2025-06-24 08:10:56 +02:00
60ceaffd82
fix: enable all plugins help text
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2025-06-23 11:43:42 +02:00
3a5b5c216d
chore: try to ensure that code is checked after each session
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2025-06-23 11:35:30 +02:00
bdc797d5c1
chore: make format 2025-06-23 11:34:27 +02:00
0edf41c792
fix: markdown parse mode breaking some plugins
Some checks failed
ci/woodpecker/push/ci Pipeline failed
ci/woodpecker/tag/release Pipeline was successful
2025-06-23 11:32:34 +02:00
35c14ce8a8
chore: update CLAUDE.md 2025-06-23 11:20:21 +02:00
e0ff369cff
chore: make format
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2025-06-23 11:19:22 +02:00
368c45cd13
fix: twitter plugin replacement logic
Some checks failed
ci/woodpecker/push/ci Pipeline failed
ci/woodpecker/tag/release Pipeline was successful
2025-06-23 11:18:47 +02:00
3b09a9dd47
feat: allow enabling all plugins into a channel
Some checks failed
ci/woodpecker/push/ci Pipeline failed
ci/woodpecker/tag/release Pipeline was successful
2025-06-23 11:10:43 +02:00
899ac49336
chore: split plugin configuration templates
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/tag/release Pipeline was successful
2025-06-15 13:03:52 +02:00
fc77c97547
feat: add configuration options for instagram and twitter plugins 2025-06-15 12:17:54 +02:00
3a4ba376e7
chore: ignore all test db files
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/tag/release Pipeline was successful
2025-06-13 12:04:22 +02:00
bd9854676d
feat: added help command 2025-06-13 12:04:07 +02:00
4fc5ae63a1
chore: update ignore patterns for test files
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2025-06-13 10:48:53 +02:00
3771d2de65
docs: update CLAUDE.md 2025-06-13 10:47:53 +02:00
c7fdb9fc6a
docs: updated plugin docs 2025-06-13 10:45:34 +02:00
1f80a22f4a
chore: remove commited test_cache 2025-06-13 10:45:28 +02:00
1e0bc86b21
feat: improve sqlite database reliability
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/tag/release Pipeline was successful
2025-06-13 09:27:06 +02:00
8fa74fd046
fix: database tests for cache 2025-06-13 09:26:49 +02:00
d09b763aa7
feat: hltb plugin
Some checks failed
ci/woodpecker/push/ci Pipeline failed
2025-06-12 14:51:12 +02:00
40 changed files with 2564 additions and 73 deletions

5
.gitignore vendored
View file

@ -5,9 +5,12 @@ __pycache__
*.cert *.cert
.env-local .env-local
.coverage .coverage
coverage.out
dist dist
bin bin
# Butterrobot # Butterrobot
*.sqlite* *.sqlite*
butterrobot.db butterrobot.db*
/butterrobot
*_test.db*

29
CLAUDE.md Normal file
View file

@ -0,0 +1,29 @@
# Claude Code Instructions
## Plugin Development Workflow
When creating, modifying, or removing plugins:
1. **Always update the plugin documentation** in `docs/plugins.md` after any plugin changes
2. Ensure the documentation includes:
- Plugin name and category (Development, Fun and entertainment, Utility, Security, Social Media)
- Brief description of functionality
- Usage instructions with examples
- Any configuration requirements
3. **For plugins with configuration options:**
- Set `ConfigRequired: true` in the plugin's BasePlugin struct
- Add corresponding HTML form fields in `internal/admin/templates/channel_plugin_config.html`
- Use conditional template logic: `{{else if eq .ChannelPlugin.PluginID "plugin.id"}}`
- Include proper form labels, help text, and value binding
## Testing
**CRITICAL**: After making ANY changes to code files, you MUST run these commands in order:
1. **Format code**: `make format` - Format all code according to project standards
2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues")
3. **Run tests**: `make test` - Run all tests to ensure functionality works
4. Verify documentation accuracy
5. Ensure all examples work as described
**These commands are MANDATORY after every code change, no exceptions.**

View file

@ -9,10 +9,13 @@
- Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun)
- Dice: Put `!dice` and wathever roll you want to perform. - Dice: Put `!dice` and wathever roll you want to perform.
- Coin: Flip a coin and get heads or tails. - Coin: Flip a coin and get heads or tails.
- How Long To Beat: Get game completion times from HowLongToBeat.com using `!hltb <game name>`
### Utility ### Utility
- Help: Shows available commands when you type `!help`. Lists all enabled plugins for the current channel organized by category with their descriptions and usage instructions.
- Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time. - Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time.
- Search and Replace: Reply to any message with `s/search/replace/[flags]` to perform text substitution. Supports flags: `g` (global), `i` (case insensitive), `n` (regex pattern). Example: `s/hello/hi/gi` replaces all occurrences of "hello" with "hi" case-insensitively.
### Security ### Security
@ -20,5 +23,5 @@
### Social Media ### Social Media
- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms. - Twitter Link Expander: Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: fxtwitter.com).
- Instagram Link Expander: Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters. This allows for better media embedding in chat platforms. - Instagram Link Expander: Automatically converts instagram.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: ddinstagram.com).

2
go.mod
View file

@ -6,6 +6,7 @@ require (
github.com/gorilla/sessions v1.4.0 github.com/gorilla/sessions v1.4.0
golang.org/x/crypto v0.37.0 golang.org/x/crypto v0.37.0
golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
modernc.org/sqlite v1.37.0 modernc.org/sqlite v1.37.0
) )
@ -16,7 +17,6 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.32.0 // indirect
modernc.org/libc v1.63.0 // indirect modernc.org/libc v1.63.0 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View file

@ -16,7 +16,7 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
) )
//go:embed templates/*.html //go:embed templates/*.html templates/plugins/*.html
var templateFS embed.FS var templateFS embed.FS
const ( const (
@ -90,7 +90,7 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
} }
// Parse and register all templates // Parse and register all templates
templateFiles := []string{ mainTemplateFiles := []string{
"index.html", "index.html",
"login.html", "login.html",
"change_password.html", "change_password.html",
@ -101,7 +101,13 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
"channel_plugin_config.html", "channel_plugin_config.html",
} }
for _, tf := range templateFiles { pluginTemplateFiles := []string{
"plugins/security.domainblock.html",
"plugins/social.instagram.html",
"plugins/social.twitter.html",
}
for _, tf := range mainTemplateFiles {
// Read template content from embedded filesystem // Read template content from embedded filesystem
content, err := templateFS.ReadFile("templates/" + tf) content, err := templateFS.ReadFile("templates/" + tf)
if err != nil { if err != nil {
@ -120,6 +126,20 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
panic(err) panic(err)
} }
// If this is the channel_plugin_config template, also parse plugin templates
if tf == "channel_plugin_config.html" {
for _, pluginTf := range pluginTemplateFiles {
pluginContent, err := templateFS.ReadFile("templates/" + pluginTf)
if err != nil {
panic(err)
}
t, err = t.Parse(string(pluginContent))
if err != nil {
panic(err)
}
}
}
templates[tf] = t templates[tf] = t
} }
@ -544,6 +564,13 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) {
return return
} }
// Update enable_all_plugins
enableAllPlugins := r.FormValue("enable_all_plugins") == "true"
if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil {
http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError)
return
}
a.addFlash(w, r, "Channel updated", "success") a.addFlash(w, r, "Channel updated", "success")
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther) http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
return return

View file

@ -27,6 +27,15 @@
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked --> <!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
<input type="hidden" name="form_submitted" value="true"> <input type="hidden" name="form_submitted" value="true">
</div> </div>
<div class="mb-3">
<label class="form-check form-switch">
<input class="form-check-input" type="checkbox" name="enable_all_plugins" value="true" {{if .Channel.EnableAllPlugins}}checked{{end}}>
<span class="form-check-label">Enable All Plugins</span>
</label>
<div>
When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored.
</div>
</div>
<div class="form-footer"> <div class="form-footer">
<button type="submit" class="btn btn-primary">Save</button> <button type="submit" class="btn btn-primary">Save</button>
<a href="/admin/channels" class="btn btn-link">Back to Channels</a> <a href="/admin/channels" class="btn btn-link">Back to Channels</a>

View file

@ -9,16 +9,11 @@
<form method="post"> <form method="post">
<!-- Plugin configuration fields --> <!-- Plugin configuration fields -->
{{if eq .ChannelPlugin.PluginID "security.domainblock"}} {{if eq .ChannelPlugin.PluginID "security.domainblock"}}
<div class="mb-3"> {{template "plugins/security.domainblock.html" .}}
<label class="form-label">Blocked Domains</label> {{else if eq .ChannelPlugin.PluginID "social.instagram"}}
<input type="text" class="form-control" name="blocked_domains" {{template "plugins/social.instagram.html" .}}
value="{{with .ChannelPlugin.Config}}{{index . "blocked_domains"}}{{end}}" {{else if eq .ChannelPlugin.PluginID "social.twitter"}}
placeholder="example.com, evil.org, ads.com"> {{template "plugins/social.twitter.html" .}}
<div class="form-text text-muted">
Enter comma-separated list of domains to block (e.g., example.com, evil.org).
Messages containing links to these domains will be blocked.
</div>
</div>
{{else}} {{else}}
<div class="alert alert-warning"> <div class="alert alert-warning">
This plugin doesn't have specific configuration fields implemented yet. This plugin doesn't have specific configuration fields implemented yet.

View file

@ -0,0 +1,12 @@
{{define "plugins/security.domainblock.html"}}
<div class="mb-3">
<label class="form-label">Blocked Domains</label>
<input type="text" class="form-control" name="blocked_domains"
value="{{with .ChannelPlugin.Config}}{{index . "blocked_domains"}}{{end}}"
placeholder="example.com, evil.org, ads.com">
<div class="form-text text-muted">
Enter comma-separated list of domains to block (e.g., example.com, evil.org).
Messages containing links to these domains will be blocked.
</div>
</div>
{{end}}

View file

@ -0,0 +1,11 @@
{{define "plugins/social.instagram.html"}}
<div class="mb-3">
<label class="form-label">Replacement Domain</label>
<input type="text" class="form-control" name="domain"
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
placeholder="ddinstagram.com">
<div class="form-text text-muted">
Enter the domain to replace instagram.com links with. Default is ddinstagram.com if left empty.
</div>
</div>
{{end}}

View file

@ -0,0 +1,11 @@
{{define "plugins/social.twitter.html"}}
<div class="mb-3">
<label class="form-label">Replacement Domain</label>
<input type="text" class="form-control" name="domain"
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
placeholder="fxtwitter.com">
<div class="form-text text-muted">
Enter the domain to replace twitter.com and x.com links with. Default is fxtwitter.com if left empty.
</div>
</div>
{{end}}

View file

@ -15,6 +15,7 @@ import (
"time" "time"
"git.nakama.town/fmartingr/butterrobot/internal/admin" "git.nakama.town/fmartingr/butterrobot/internal/admin"
"git.nakama.town/fmartingr/butterrobot/internal/cache"
"git.nakama.town/fmartingr/butterrobot/internal/config" "git.nakama.town/fmartingr/butterrobot/internal/config"
"git.nakama.town/fmartingr/butterrobot/internal/db" "git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/model"
@ -22,6 +23,7 @@ import (
"git.nakama.town/fmartingr/butterrobot/internal/plugin" "git.nakama.town/fmartingr/butterrobot/internal/plugin"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock" "git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" "git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/help"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" "git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder" "git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace" "git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace"
@ -87,11 +89,13 @@ func (a *App) Run() error {
plugin.Register(fun.NewCoin()) plugin.Register(fun.NewCoin())
plugin.Register(fun.NewDice()) plugin.Register(fun.NewDice())
plugin.Register(fun.NewLoquito()) plugin.Register(fun.NewLoquito())
plugin.Register(fun.NewHLTB())
plugin.Register(social.NewTwitterExpander()) plugin.Register(social.NewTwitterExpander())
plugin.Register(social.NewInstagramExpander()) plugin.Register(social.NewInstagramExpander())
plugin.Register(reminder.New(a.db)) plugin.Register(reminder.New(a.db))
plugin.Register(domainblock.New()) plugin.Register(domainblock.New())
plugin.Register(searchreplace.New()) plugin.Register(searchreplace.New())
plugin.Register(help.New(a.db))
// Initialize routes // Initialize routes
a.initializeRoutes() a.initializeRoutes()
@ -102,6 +106,9 @@ func (a *App) Run() error {
// Start reminder scheduler // Start reminder scheduler
a.queue.StartReminderScheduler(a.handleReminder) a.queue.StartReminderScheduler(a.handleReminder)
// Start cache cleanup scheduler
go a.startCacheCleanup()
// Create server // Create server
addr := fmt.Sprintf(":%s", a.config.Port) addr := fmt.Sprintf(":%s", a.config.Port)
srv := &http.Server{ srv := &http.Server{
@ -147,6 +154,20 @@ func (a *App) Run() error {
return nil return nil
} }
// startCacheCleanup runs periodic cache cleanup
func (a *App) startCacheCleanup() {
ticker := time.NewTicker(time.Hour) // Clean up every hour
defer ticker.Stop()
for range ticker.C {
if err := a.db.CacheCleanup(); err != nil {
a.logger.Error("Cache cleanup failed", "error", err)
} else {
a.logger.Debug("Cache cleanup completed")
}
}
}
// Initialize HTTP routes // Initialize HTTP routes
func (a *App) initializeRoutes() { func (a *App) initializeRoutes() {
// Health check endpoint // Health check endpoint
@ -293,11 +314,21 @@ func (a *App) handleMessage(item queue.Item) {
} }
// Process message with plugins // Process message with plugins
for pluginID, channelPlugin := range channel.Plugins { var pluginsToProcess []string
if !channel.HasEnabledPlugin(pluginID) {
continue if channel.EnableAllPlugins {
// If EnableAllPlugins is true, process all registered plugins
pluginsToProcess = plugin.GetAvailablePluginIDs()
} else {
// Otherwise, process only explicitly enabled plugins
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
}
} }
for _, pluginID := range pluginsToProcess {
// Get plugin // Get plugin
p, err := plugin.Get(pluginID) p, err := plugin.Get(pluginID)
if err != nil { if err != nil {
@ -305,8 +336,19 @@ func (a *App) handleMessage(item queue.Item) {
continue continue
} }
// Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured)
var config map[string]interface{}
if channelPlugin, exists := channel.Plugins[pluginID]; exists {
config = channelPlugin.Config
} else {
config = make(map[string]interface{})
}
// Create cache instance for this plugin
pluginCache := cache.New(a.db, pluginID)
// Process message and get actions // Process message and get actions
actions := p.OnMessage(message, channelPlugin.Config) actions := p.OnMessage(message, config, pluginCache)
// Get platform for processing actions // Get platform for processing actions
platform, err := platform.Get(item.Platform) platform, err := platform.Get(item.Platform)

83
internal/cache/cache.go vendored Normal file
View file

@ -0,0 +1,83 @@
package cache
import (
"encoding/json"
"fmt"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/db"
)
// Cache provides a plugin-friendly interface to the cache system
type Cache struct {
db *db.Database
pluginID string
}
// New creates a new Cache instance for a specific plugin
func New(database *db.Database, pluginID string) *Cache {
return &Cache{
db: database,
pluginID: pluginID,
}
}
// Get retrieves a value from the cache
func (c *Cache) Get(key string, destination interface{}) error {
// Create prefixed key
fullKey := c.createKey(key)
// Get from database
value, err := c.db.CacheGet(fullKey)
if err != nil {
return err
}
// Unmarshal JSON into destination
return json.Unmarshal([]byte(value), destination)
}
// Set stores a value in the cache with optional expiration
func (c *Cache) Set(key string, value interface{}, expiration *time.Time) error {
// Create prefixed key
fullKey := c.createKey(key)
// Marshal value to JSON
jsonValue, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal cache value: %w", err)
}
// Store in database
return c.db.CacheSet(fullKey, string(jsonValue), expiration)
}
// SetWithTTL stores a value in the cache with a time-to-live duration
func (c *Cache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
expiration := time.Now().Add(ttl)
return c.Set(key, value, &expiration)
}
// Delete removes a value from the cache
func (c *Cache) Delete(key string) error {
fullKey := c.createKey(key)
return c.db.CacheDelete(fullKey)
}
// Exists checks if a key exists in the cache
func (c *Cache) Exists(key string) (bool, error) {
fullKey := c.createKey(key)
_, err := c.db.CacheGet(fullKey)
if err == db.ErrNotFound {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// createKey creates a prefixed cache key
func (c *Cache) createKey(key string) string {
return fmt.Sprintf("%s_%s", c.pluginID, key)
}

176
internal/cache/cache_test.go vendored Normal file
View file

@ -0,0 +1,176 @@
package cache
import (
"fmt"
"os"
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/db"
)
func TestCache(t *testing.T) {
// Create temporary database for testing with unique name
dbFile := fmt.Sprintf("test_cache_%d.db", time.Now().UnixNano())
database, err := db.New(dbFile)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer func() {
_ = database.Close()
// Clean up test database file
_ = os.Remove(dbFile)
}()
// Create cache instance
cache := New(database, "test.plugin")
// Test data
testKey := "test_key"
testValue := map[string]interface{}{
"name": "Test Game",
"time": 42,
}
// Test Set and Get
t.Run("Set and Get", func(t *testing.T) {
err := cache.Set(testKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
var retrieved map[string]interface{}
err = cache.Get(testKey, &retrieved)
if err != nil {
t.Errorf("Failed to get cache value: %v", err)
}
if retrieved["name"] != testValue["name"] {
t.Errorf("Expected name %v, got %v", testValue["name"], retrieved["name"])
}
if int(retrieved["time"].(float64)) != testValue["time"].(int) {
t.Errorf("Expected time %v, got %v", testValue["time"], retrieved["time"])
}
})
// Test SetWithTTL and expiration
t.Run("SetWithTTL and expiration", func(t *testing.T) {
expiredKey := "expired_key"
// Set with very short TTL
err := cache.SetWithTTL(expiredKey, testValue, time.Millisecond)
if err != nil {
t.Errorf("Failed to set cache value with TTL: %v", err)
}
// Wait for expiration
time.Sleep(2 * time.Millisecond)
// Try to get - should fail
var retrieved map[string]interface{}
err = cache.Get(expiredKey, &retrieved)
if err == nil {
t.Errorf("Expected cache miss for expired key, but got value")
}
})
// Test Exists
t.Run("Exists", func(t *testing.T) {
existsKey := "exists_key"
// Make sure the key doesn't exist initially by deleting it
_ = cache.Delete(existsKey)
// Should not exist initially
exists, err := cache.Exists(existsKey)
if err != nil {
t.Errorf("Failed to check if key exists: %v", err)
}
if exists {
t.Errorf("Expected key to not exist, but it does")
}
// Set value
err = cache.Set(existsKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
// Should exist now
exists, err = cache.Exists(existsKey)
if err != nil {
t.Errorf("Failed to check if key exists: %v", err)
}
if !exists {
t.Errorf("Expected key to exist, but it doesn't")
}
})
// Test Delete
t.Run("Delete", func(t *testing.T) {
deleteKey := "delete_key"
// Set value
err := cache.Set(deleteKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
// Delete value
err = cache.Delete(deleteKey)
if err != nil {
t.Errorf("Failed to delete cache value: %v", err)
}
// Should not exist anymore
var retrieved map[string]interface{}
err = cache.Get(deleteKey, &retrieved)
if err == nil {
t.Errorf("Expected cache miss for deleted key, but got value")
}
})
// Test plugin ID prefixing
t.Run("Plugin ID prefixing", func(t *testing.T) {
cache1 := New(database, "plugin1")
cache2 := New(database, "plugin2")
sameKey := "same_key"
value1 := "value1"
value2 := "value2"
// Set same key in both caches
err := cache1.Set(sameKey, value1, nil)
if err != nil {
t.Errorf("Failed to set cache1 value: %v", err)
}
err = cache2.Set(sameKey, value2, nil)
if err != nil {
t.Errorf("Failed to set cache2 value: %v", err)
}
// Retrieve from both caches
var retrieved1, retrieved2 string
err = cache1.Get(sameKey, &retrieved1)
if err != nil {
t.Errorf("Failed to get cache1 value: %v", err)
}
err = cache2.Get(sameKey, &retrieved2)
if err != nil {
t.Errorf("Failed to get cache2 value: %v", err)
}
// Values should be different due to plugin ID prefixing
if retrieved1 != value1 {
t.Errorf("Expected cache1 value %v, got %v", value1, retrieved1)
}
if retrieved2 != value2 {
t.Errorf("Expected cache2 value %v, got %v", value2, retrieved2)
}
})
}

View file

@ -35,6 +35,11 @@ func New(dbPath string) (*Database, error) {
return nil, err return nil, err
} }
// Configure SQLite for better reliability
if err := configureSQLite(db); err != nil {
return nil, err
}
// Initialize database // Initialize database
if err := initDatabase(db); err != nil { if err := initDatabase(db); err != nil {
return nil, err return nil, err
@ -51,7 +56,7 @@ func (d *Database) Close() error {
// GetChannelByID retrieves a channel by ID // GetChannelByID retrieves a channel by ID
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
FROM channels FROM channels
WHERE id = ? WHERE id = ?
` `
@ -62,10 +67,11 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
platform string platform string
platformChannelID string platformChannelID string
enabled bool enabled bool
enableAllPlugins bool
channelRawJSON string channelRawJSON string
) )
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, ErrNotFound return nil, ErrNotFound
} }
@ -85,6 +91,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -105,7 +112,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
// GetChannelByPlatform retrieves a channel by platform and platform channel ID // GetChannelByPlatform retrieves a channel by platform and platform channel ID
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
FROM channels FROM channels
WHERE platform = ? AND platform_channel_id = ? WHERE platform = ? AND platform_channel_id = ?
` `
@ -115,10 +122,11 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
var ( var (
id int64 id int64
enabled bool enabled bool
enableAllPlugins bool
channelRawJSON string channelRawJSON string
) )
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, ErrNotFound return nil, ErrNotFound
} }
@ -138,6 +146,7 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -165,11 +174,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
// Insert channel // Insert channel
query := ` query := `
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw) INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
` `
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON)) result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -186,6 +195,7 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: false,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -205,6 +215,18 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error {
return err return err
} }
// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status
func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error {
query := `
UPDATE channels
SET enable_all_plugins = ?
WHERE id = ?
`
_, err := d.db.Exec(query, enableAllPlugins, id)
return err
}
// DeleteChannel deletes a channel // DeleteChannel deletes a channel
func (d *Database) DeleteChannel(id int64) error { func (d *Database) DeleteChannel(id int64) error {
// First delete all channel plugins // First delete all channel plugins
@ -256,7 +278,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
} }
// Parse config JSON // Parse config JSON
var config map[string]interface{} var config map[string]any
if err := json.Unmarshal([]byte(configJSON), &config); err != nil { if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
return nil, err return nil, err
} }
@ -283,6 +305,28 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
return plugins, nil return plugins, nil
} }
// GetChannelPluginsFromPlatformID retrieves all plugins for a channel by platform and platform channel ID
func (d *Database) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
// First, get the channel ID by platform and platform channel ID
query := `
SELECT id
FROM channels
WHERE platform = ? AND platform_channel_id = ?
`
var channelID int64
err := d.db.QueryRow(query, platform, platformChannelID).Scan(&channelID)
if err == sql.ErrNoRows {
return nil, ErrNotFound
}
if err != nil {
return nil, err
}
// Now get the plugins for this channel
return d.GetChannelPlugins(channelID)
}
// GetChannelPluginByID retrieves a channel plugin by ID // GetChannelPluginByID retrieves a channel plugin by ID
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
query := ` query := `
@ -429,7 +473,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
// GetAllChannels retrieves all channels // GetAllChannels retrieves all channels
func (d *Database) GetAllChannels() ([]*model.Channel, error) { func (d *Database) GetAllChannels() ([]*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, channel_raw SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
FROM channels FROM channels
` `
@ -451,10 +495,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
platform string platform string
platformChannelID string platformChannelID string
enabled bool enabled bool
enableAllPlugins bool
channelRawJSON string channelRawJSON string
) )
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil { if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil {
return nil, err return nil, err
} }
@ -470,6 +515,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -793,3 +839,82 @@ func initDatabase(db *sql.DB) error {
return nil return nil
} }
// Configure SQLite for better reliability
func configureSQLite(db *sql.DB) error {
pragmas := []string{
// Enable Write-Ahead Logging for better concurrency and crash recovery
"PRAGMA journal_mode = WAL",
// Set 5-second timeout when database is locked by another connection
"PRAGMA busy_timeout = 5000",
// Balance between safety and performance for disk writes
"PRAGMA synchronous = NORMAL",
// Set large cache size (1GB) for better read performance
"PRAGMA cache_size = 1000000000",
// Enable foreign key constraint enforcement
"PRAGMA foreign_keys = true",
// Store temporary tables and indices in memory for speed
"PRAGMA temp_store = memory",
}
for _, pragma := range pragmas {
if _, err := db.Exec(pragma); err != nil {
return fmt.Errorf("failed to execute %s: %w", pragma, err)
}
}
return nil
}
// CacheGet retrieves a value from the cache
func (d *Database) CacheGet(key string) (string, error) {
query := `
SELECT value
FROM cache
WHERE key = ? AND (expires_at IS NULL OR expires_at > ?)
`
var value string
err := d.db.QueryRow(query, key, time.Now()).Scan(&value)
if err == sql.ErrNoRows {
return "", ErrNotFound
}
if err != nil {
return "", err
}
return value, nil
}
// CacheSet stores a value in the cache with optional expiration
func (d *Database) CacheSet(key, value string, expiration *time.Time) error {
query := `
INSERT OR REPLACE INTO cache (key, value, expires_at, updated_at)
VALUES (?, ?, ?, ?)
`
_, err := d.db.Exec(query, key, value, expiration, time.Now())
return err
}
// CacheDelete removes a value from the cache
func (d *Database) CacheDelete(key string) error {
query := `
DELETE FROM cache
WHERE key = ?
`
_, err := d.db.Exec(query, key)
return err
}
// CacheCleanup removes expired cache entries
func (d *Database) CacheCleanup() error {
query := `
DELETE FROM cache
WHERE expires_at IS NOT NULL AND expires_at <= ?
`
_, err := d.db.Exec(query, time.Now())
return err
}

203
internal/db/db_test.go Normal file
View file

@ -0,0 +1,203 @@
package db
import (
"fmt"
"os"
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
func TestEnableAllPlugins(t *testing.T) {
// Create temporary database for testing with unique name
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
database, err := New(dbFile)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer func() {
_ = database.Close()
// Clean up test database file
_ = os.Remove(dbFile)
}()
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
channelRaw := map[string]interface{}{
"name": "test-channel",
}
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
if channel.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
}
// Verify it's also false when retrieved from database
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
}
})
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
// Create a channel
channelRaw := map[string]interface{}{
"name": "test-channel-2",
}
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
// Update EnableAllPlugins to true
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
}
// Retrieve and verify
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
}
// Update back to false
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
}
// Retrieve and verify again
retrieved, err = database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
}
})
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
// Create a channel
channelRaw := map[string]interface{}{
"name": "test-channel-3",
}
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
// Enable all plugins
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
}
// Retrieve by platform
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
if err != nil {
t.Fatalf("Failed to retrieve channel by platform: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
}
})
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
// Create multiple channels with different EnableAllPlugins settings
channelRaw1 := map[string]interface{}{"name": "channel-1"}
channelRaw2 := map[string]interface{}{"name": "channel-2"}
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
if err != nil {
t.Fatalf("Failed to create channel1: %v", err)
}
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
if err != nil {
t.Fatalf("Failed to create channel2: %v", err)
}
// Enable all plugins for channel2 only
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
}
// Get all channels
channels, err := database.GetAllChannels()
if err != nil {
t.Fatalf("Failed to get all channels: %v", err)
}
// Find our test channels
var foundChannel1, foundChannel2 *model.Channel
for _, ch := range channels {
if ch.ID == channel1.ID {
foundChannel1 = ch
}
if ch.ID == channel2.ID {
foundChannel2 = ch
}
}
if foundChannel1 == nil {
t.Fatalf("Channel1 not found in GetAllChannels result")
}
if foundChannel2 == nil {
t.Fatalf("Channel2 not found in GetAllChannels result")
}
if foundChannel1.EnableAllPlugins {
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
}
if !foundChannel2.EnableAllPlugins {
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
}
})
t.Run("Migration applied correctly", func(t *testing.T) {
// Test that we can create a channel and the enable_all_plugins column exists
// This implicitly tests that migration 4 was applied correctly
channelRaw := map[string]interface{}{
"name": "migration-test-channel",
}
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel after migration: %v", err)
}
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
}
// Verify the value was set correctly
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("EnableAllPlugins should be true after update")
}
})
}

View file

@ -9,6 +9,8 @@ func init() {
// Register migrations // Register migrations
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown) Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown) Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
Register(3, "Add cache table", migrateCacheUp, migrateCacheDown)
Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown)
} }
// Initial schema creation with bcrypt passwords - version 1 // Initial schema creation with bcrypt passwords - version 1
@ -126,3 +128,87 @@ func migrateRemindersDown(db *sql.DB) error {
_, err := db.Exec(`DROP TABLE IF EXISTS reminders`) _, err := db.Exec(`DROP TABLE IF EXISTS reminders`)
return err return err
} }
// Add cache table - version 3
func migrateCacheUp(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS cache (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
expires_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// Create index on expires_at for efficient cleanup
_, err = db.Exec(`
CREATE INDEX IF NOT EXISTS idx_cache_expires_at ON cache(expires_at)
`)
return err
}
func migrateCacheDown(db *sql.DB) error {
_, err := db.Exec(`DROP TABLE IF EXISTS cache`)
return err
}
// Add enable_all_plugins column to channels table - version 4
func migrateEnableAllPluginsUp(db *sql.DB) error {
_, err := db.Exec(`
ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0
`)
return err
}
func migrateEnableAllPluginsDown(db *sql.DB) error {
// SQLite doesn't support DROP COLUMN, so we need to recreate the table
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
_ = tx.Rollback() // Ignore rollback errors
}()
// Create backup table
_, err = tx.Exec(`
CREATE TABLE channels_backup (
id INTEGER PRIMARY KEY AUTOINCREMENT,
platform TEXT NOT NULL,
platform_channel_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL DEFAULT 0,
channel_raw TEXT NOT NULL,
UNIQUE(platform, platform_channel_id)
)
`)
if err != nil {
return err
}
// Copy data excluding enable_all_plugins column
_, err = tx.Exec(`
INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw)
SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels
`)
if err != nil {
return err
}
// Drop original table
_, err = tx.Exec(`DROP TABLE channels`)
if err != nil {
return err
}
// Rename backup table
_, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`)
if err != nil {
return err
}
return tx.Commit()
}

View file

@ -44,11 +44,17 @@ type Channel struct {
PlatformChannelID string PlatformChannelID string
ChannelRaw map[string]interface{} ChannelRaw map[string]interface{}
Enabled bool Enabled bool
EnableAllPlugins bool
Plugins map[string]*ChannelPlugin Plugins map[string]*ChannelPlugin
} }
// HasEnabledPlugin checks if a plugin is enabled for this channel // HasEnabledPlugin checks if a plugin is enabled for this channel
func (c *Channel) HasEnabledPlugin(pluginID string) bool { func (c *Channel) HasEnabledPlugin(pluginID string) bool {
// If EnableAllPlugins is true, all plugins are considered enabled
if c.EnableAllPlugins {
return true
}
plugin, exists := c.Plugins[pluginID] plugin, exists := c.Plugins[pluginID]
if !exists { if !exists {
return false return false

View file

@ -0,0 +1,234 @@
package model
import (
"testing"
)
func TestChannel_HasEnabledPlugin(t *testing.T) {
t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: make(map[string]*ChannelPlugin),
}
// Plugin not in map should return false
result := channel.HasEnabledPlugin("nonexistent.plugin")
if result {
t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true")
}
})
t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: false,
Config: make(map[string]any),
},
},
}
// Disabled plugin should return false
result := channel.HasEnabledPlugin("test.plugin")
if result {
t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true")
}
})
t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: true,
Config: make(map[string]any),
},
},
}
// Enabled plugin should return true
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false")
}
})
t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: make(map[string]*ChannelPlugin),
}
// When EnableAllPlugins is true, any plugin should be considered enabled
result := channel.HasEnabledPlugin("nonexistent.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
}
})
t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: false,
Config: make(map[string]any),
},
},
}
// When EnableAllPlugins is true, even disabled plugins should be considered enabled
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false")
}
})
t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: true,
Config: make(map[string]any),
},
},
}
// When EnableAllPlugins is true, enabled plugins should also return true
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
}
})
t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"plugin1": {
ID: 1,
ChannelID: 1,
PluginID: "plugin1",
Enabled: true,
Config: make(map[string]any),
},
"plugin2": {
ID: 2,
ChannelID: 1,
PluginID: "plugin2",
Enabled: false,
Config: make(map[string]any),
},
},
}
// All plugins should be enabled when EnableAllPlugins is true
testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"}
for _, pluginID := range testCases {
result := channel.HasEnabledPlugin(pluginID)
if !result {
t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID)
}
}
})
}
func TestChannelName(t *testing.T) {
t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: nil,
}
result := channel.ChannelName()
if result != "test-id" {
t.Errorf("Expected channel name to be 'test-id', got '%s'", result)
}
})
t.Run("Returns name from ChannelRaw when available", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: map[string]interface{}{
"name": "Test Channel",
},
}
result := channel.ChannelName()
if result != "Test Channel" {
t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result)
}
})
t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: map[string]interface{}{
"chat": map[string]interface{}{
"title": "Telegram Group",
},
},
}
result := channel.ChannelName()
if result != "Telegram Group" {
t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result)
}
})
t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "fallback-id",
ChannelRaw: map[string]interface{}{
"other_field": "value",
},
}
result := channel.ChannelName()
if result != "fallback-id" {
t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result)
}
})
}

View file

@ -2,8 +2,18 @@ package model
import ( import (
"errors" "errors"
"time"
) )
// CacheInterface defines the cache interface available to plugins
type CacheInterface interface {
Get(key string, destination interface{}) error
Set(key string, value interface{}, expiration *time.Time) error
SetWithTTL(key string, value interface{}, ttl time.Duration) error
Delete(key string) error
Exists(key string) (bool, error)
}
var ( var (
// ErrPluginNotFound is returned when a requested plugin doesn't exist // ErrPluginNotFound is returned when a requested plugin doesn't exist
ErrPluginNotFound = errors.New("plugin not found") ErrPluginNotFound = errors.New("plugin not found")
@ -24,5 +34,5 @@ type Plugin interface {
RequiresConfig() bool RequiresConfig() bool
// OnMessage processes an incoming message and returns platform actions // OnMessage processes an incoming message and returns platform actions
OnMessage(msg *Message, config map[string]interface{}) []*MessageAction OnMessage(msg *Message, config map[string]interface{}, cache CacheInterface) []*MessageAction
} }

View file

@ -237,6 +237,15 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
"text": msg.Text, "text": msg.Text,
} }
// Set parse_mode based on plugin preference or default to empty string
if msg.Raw != nil && msg.Raw["parse_mode"] != nil {
// Plugin explicitly set parse_mode
payload["parse_mode"] = msg.Raw["parse_mode"]
} else {
// Default to empty string (no formatting)
payload["parse_mode"] = ""
}
// Add reply if needed // Add reply if needed
if msg.ReplyTo != "" { if msg.ReplyTo != "" {
replyToID, err := strconv.Atoi(msg.ReplyTo) replyToID, err := strconv.Atoi(msg.ReplyTo)

View file

@ -65,7 +65,7 @@ func extractDomains(text string) []string {
} }
// OnMessage processes incoming messages // OnMessage processes incoming messages
func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip messages from bots // Skip messages from bots
if msg.FromBot { if msg.FromBot {
return nil return nil

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
) )
func TestExtractDomains(t *testing.T) { func TestExtractDomains(t *testing.T) {
@ -124,7 +125,8 @@ func TestOnMessage(t *testing.T) {
"blocked_domains": test.blockedDomains, "blocked_domains": test.blockedDomains,
} }
responses := plugin.OnMessage(msg, config) mockCache := &testutil.MockCache{}
responses := plugin.OnMessage(msg, config, mockCache)
if test.expectBlocked { if test.expectBlocked {
if len(responses) == 0 { if len(responses) == 0 {

View file

@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") {
return nil return nil
} }

View file

@ -32,7 +32,7 @@ func NewDice() *DicePlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") {
return nil return nil
} }

394
internal/plugin/fun/hltb.go Normal file
View file

@ -0,0 +1,394 @@
package fun
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// HLTBPlugin searches HowLongToBeat for game completion times
type HLTBPlugin struct {
plugin.BasePlugin
httpClient *http.Client
}
// HLTBSearchRequest represents the search request payload
type HLTBSearchRequest struct {
SearchType string `json:"searchType"`
SearchTerms []string `json:"searchTerms"`
SearchPage int `json:"searchPage"`
Size int `json:"size"`
SearchOptions map[string]interface{} `json:"searchOptions"`
UseCache bool `json:"useCache"`
}
// HLTBGame represents a game from HowLongToBeat
type HLTBGame struct {
ID int `json:"game_id"`
Name string `json:"game_name"`
GameAlias string `json:"game_alias"`
GameImage string `json:"game_image"`
CompMain int `json:"comp_main"`
CompPlus int `json:"comp_plus"`
CompComplete int `json:"comp_complete"`
CompAll int `json:"comp_all"`
InvestedCo int `json:"invested_co"`
InvestedMp int `json:"invested_mp"`
CountComp int `json:"count_comp"`
CountSpeedruns int `json:"count_speedruns"`
CountBacklog int `json:"count_backlog"`
CountReview int `json:"count_review"`
ReviewScore int `json:"review_score"`
CountPlaying int `json:"count_playing"`
CountRetired int `json:"count_retired"`
}
// HLTBSearchResponse represents the search response
type HLTBSearchResponse struct {
Color string `json:"color"`
Title string `json:"title"`
Category string `json:"category"`
Count int `json:"count"`
Pagecurrent int `json:"pagecurrent"`
Pagesize int `json:"pagesize"`
Pagetotal int `json:"pagetotal"`
SearchTerm string `json:"searchTerm"`
SearchResults []HLTBGame `json:"data"`
}
// NewHLTB creates a new HLTBPlugin instance
func NewHLTB() *HLTBPlugin {
return &HLTBPlugin{
BasePlugin: plugin.BasePlugin{
ID: "fun.hltb",
Name: "How Long To Beat",
Help: "Get game completion times from HowLongToBeat.com using `!hltb <game name>`",
},
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// OnMessage handles incoming messages
func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Check if message starts with !hltb
text := strings.TrimSpace(msg.Text)
if !strings.HasPrefix(text, "!hltb ") {
return nil
}
// Extract game name
gameName := strings.TrimSpace(text[6:]) // Remove "!hltb "
if gameName == "" {
return p.createErrorResponse(msg, "Please provide a game name. Usage: !hltb <game name>")
}
// Check cache first
var games []HLTBGame
var err error
cacheKey := strings.ToLower(gameName)
err = cache.Get(cacheKey, &games)
if err != nil || len(games) == 0 {
// Cache miss - search for the game
games, err = p.searchGame(gameName)
if err != nil {
return p.createErrorResponse(msg, fmt.Sprintf("Error searching for game: %s", err.Error()))
}
if len(games) == 0 {
return p.createErrorResponse(msg, fmt.Sprintf("No results found for '%s'", gameName))
}
// Cache the results for 1 hour
err = cache.SetWithTTL(cacheKey, games, time.Hour)
if err != nil {
// Log cache error but don't fail the request
fmt.Printf("Warning: Failed to cache HLTB results: %v\n", err)
}
}
// Use the first result
game := games[0]
// Format the response
response := p.formatGameInfo(game)
// Create response message with game cover if available
responseMsg := &model.Message{
Text: response,
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
// Set parse mode for markdown formatting
if responseMsg.Raw == nil {
responseMsg.Raw = make(map[string]interface{})
}
responseMsg.Raw["parse_mode"] = "Markdown"
// Add game cover as attachment if available
if game.GameImage != "" {
imageURL := p.getFullImageURL(game.GameImage)
responseMsg.Raw["image_url"] = imageURL
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: responseMsg,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}
// searchGame searches for a game on HowLongToBeat
func (p *HLTBPlugin) searchGame(gameName string) ([]HLTBGame, error) {
// Split search terms by words
searchTerms := strings.Fields(gameName)
// Prepare search request
searchRequest := HLTBSearchRequest{
SearchType: "games",
SearchTerms: searchTerms,
SearchPage: 1,
Size: 20,
SearchOptions: map[string]interface{}{
"games": map[string]interface{}{
"userId": 0,
"platform": "",
"sortCategory": "popular",
"rangeCategory": "main",
"rangeTime": map[string]interface{}{
"min": nil,
"max": nil,
},
"gameplay": map[string]interface{}{
"perspective": "",
"flow": "",
"genre": "",
"difficulty": "",
},
"rangeYear": map[string]interface{}{
"min": "",
"max": "",
},
"modifier": "",
},
"users": map[string]interface{}{
"sortCategory": "postcount",
},
"lists": map[string]interface{}{
"sortCategory": "follows",
},
"filter": "",
"sort": 0,
"randomizer": 0,
},
UseCache: true,
}
// Convert to JSON
jsonData, err := json.Marshal(searchRequest)
if err != nil {
return nil, fmt.Errorf("failed to marshal search request: %w", err)
}
// The API endpoint appears to have changed to use dynamic tokens
// Try to get the seek token first, fallback to basic search
seekToken, err := p.getSeekToken()
if err != nil {
// Fallback to old endpoint
seekToken = ""
}
var apiURL string
if seekToken != "" {
apiURL = fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", seekToken)
} else {
apiURL = "https://howlongtobeat.com/api/search"
}
// Create HTTP request
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers to match the working curl request
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", "https://howlongtobeat.com")
req.Header.Set("Pragma", "no-cache")
req.Header.Set("Referer", "https://howlongtobeat.com")
req.Header.Set("Sec-Fetch-Dest", "empty")
req.Header.Set("Sec-Fetch-Mode", "cors")
req.Header.Set("Sec-Fetch-Site", "same-origin")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
// Send request
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status code: %d", resp.StatusCode)
}
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Parse response
var searchResponse HLTBSearchResponse
if err := json.Unmarshal(body, &searchResponse); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return searchResponse.SearchResults, nil
}
// formatGameInfo formats game information for display
func (p *HLTBPlugin) formatGameInfo(game HLTBGame) string {
var response strings.Builder
response.WriteString(fmt.Sprintf("🎮 **%s**\n\n", game.Name))
// Format completion times
if game.CompMain > 0 {
response.WriteString(fmt.Sprintf("📖 **Main Story:** %s\n", p.formatTime(game.CompMain)))
}
if game.CompPlus > 0 {
response.WriteString(fmt.Sprintf(" **Main + Extras:** %s\n", p.formatTime(game.CompPlus)))
}
if game.CompComplete > 0 {
response.WriteString(fmt.Sprintf("💯 **Completionist:** %s\n", p.formatTime(game.CompComplete)))
}
if game.CompAll > 0 {
response.WriteString(fmt.Sprintf("🎯 **All Styles:** %s\n", p.formatTime(game.CompAll)))
}
// Add review score if available
if game.ReviewScore > 0 {
response.WriteString(fmt.Sprintf("\n⭐ **User Score:** %d/100", game.ReviewScore))
}
// Add source attribution
response.WriteString("\n\n*Source: HowLongToBeat.com*")
return response.String()
}
// formatTime converts seconds to a readable time format
func (p *HLTBPlugin) formatTime(seconds int) string {
if seconds <= 0 {
return "N/A"
}
hours := float64(seconds) / 3600.0
if hours < 1 {
minutes := seconds / 60
return fmt.Sprintf("%d minutes", minutes)
} else if hours < 2 {
return fmt.Sprintf("%.1f hour", hours)
} else {
return fmt.Sprintf("%.1f hours", hours)
}
}
// getFullImageURL constructs the full image URL
func (p *HLTBPlugin) getFullImageURL(imagePath string) string {
if imagePath == "" {
return ""
}
// Remove leading slash if present
imagePath = strings.TrimPrefix(imagePath, "/")
return fmt.Sprintf("https://howlongtobeat.com/games/%s", imagePath)
}
// getSeekToken attempts to retrieve the seek token from HowLongToBeat
func (p *HLTBPlugin) getSeekToken() (string, error) {
// Try to extract the seek token from the main page
req, err := http.NewRequest("GET", "https://howlongtobeat.com", nil)
if err != nil {
return "", fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
resp, err := p.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to fetch token: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read token response: %w", err)
}
// Look for patterns that might contain the token
patterns := []string{
`/api/seek/([a-f0-9]+)`,
`"seek/([a-f0-9]+)"`,
`seek/([a-f0-9]{12,})`,
}
bodyStr := string(body)
for _, pattern := range patterns {
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(bodyStr)
if len(matches) > 1 {
return matches[1], nil
}
}
// If we can't extract a token, return the known working one as fallback
return "d4b2e330db04dbf3", nil
}
// createErrorResponse creates an error response message
func (p *HLTBPlugin) createErrorResponse(msg *model.Message, errorText string) []*model.MessageAction {
response := &model.Message{
Text: fmt.Sprintf("❌ %s", errorText),
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}

View file

@ -0,0 +1,131 @@
package fun
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
)
func TestHLTBPlugin_OnMessage(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
name string
messageText string
shouldRespond bool
}{
{
name: "responds to !hltb command",
messageText: "!hltb The Witcher 3",
shouldRespond: true,
},
{
name: "ignores non-hltb messages",
messageText: "hello world",
shouldRespond: false,
},
{
name: "ignores !hltb without game name",
messageText: "!hltb",
shouldRespond: false,
},
{
name: "ignores !hltb with only spaces",
messageText: "!hltb ",
shouldRespond: false,
},
{
name: "ignores similar but incorrect commands",
messageText: "hltb The Witcher 3",
shouldRespond: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &model.Message{
Text: tt.messageText,
Chat: "test-chat",
Channel: &model.Channel{ID: 1},
Author: "test-user",
}
mockCache := &testutil.MockCache{}
actions := plugin.OnMessage(msg, make(map[string]interface{}), mockCache)
if tt.shouldRespond && len(actions) == 0 {
t.Errorf("Expected plugin to respond to '%s', but it didn't", tt.messageText)
}
if !tt.shouldRespond && len(actions) > 0 {
t.Errorf("Expected plugin to not respond to '%s', but it did", tt.messageText)
}
// For messages that should respond, verify the response structure
if tt.shouldRespond && len(actions) > 0 {
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
}
if action.Message == nil {
t.Error("Expected action to have a message")
}
if action.Message != nil && action.Message.ReplyTo != msg.ID {
t.Error("Expected response to reply to original message")
}
}
})
}
}
func TestHLTBPlugin_formatTime(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
seconds int
expected string
}{
{0, "N/A"},
{-1, "N/A"},
{1800, "30 minutes"}, // 30 minutes
{3600, "1.0 hour"}, // 1 hour
{7200, "2.0 hours"}, // 2 hours
{10800, "3.0 hours"}, // 3 hours
{36000, "10.0 hours"}, // 10 hours
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := plugin.formatTime(tt.seconds)
if result != tt.expected {
t.Errorf("formatTime(%d) = %s, want %s", tt.seconds, result, tt.expected)
}
})
}
}
func TestHLTBPlugin_getFullImageURL(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
imagePath string
expected string
}{
{"", ""},
{"game.jpg", "https://howlongtobeat.com/games/game.jpg"},
{"/game.jpg", "https://howlongtobeat.com/games/game.jpg"},
{"folder/game.png", "https://howlongtobeat.com/games/folder/game.png"},
}
for _, tt := range tests {
t.Run(tt.imagePath, func(t *testing.T) {
result := plugin.getFullImageURL(tt.imagePath)
if result != tt.expected {
t.Errorf("getFullImageURL(%s) = %s, want %s", tt.imagePath, result, tt.expected)
}
})
}
}

View file

@ -23,8 +23,13 @@ func NewLoquito() *LoquitoPlugin {
} }
} }
// GetHelp returns the plugin help text
func (p *LoquitoPlugin) GetHelp() string {
return ""
}
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { if !strings.Contains(strings.ToLower(msg.Text), "lo quito") {
return nil return nil
} }

View file

@ -0,0 +1,166 @@
package help
import (
"fmt"
"sort"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
"golang.org/x/exp/slog"
)
// ChannelPluginGetter is an interface for getting channel plugins
type ChannelPluginGetter interface {
GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error)
GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error)
}
// HelpPlugin provides help information about available commands
type HelpPlugin struct {
plugin.BasePlugin
db ChannelPluginGetter
}
// New creates a new HelpPlugin instance
func New(db ChannelPluginGetter) *HelpPlugin {
return &HelpPlugin{
BasePlugin: plugin.BasePlugin{
ID: "utility.help",
Name: "Help",
Help: "Shows available commands when you type '!help'",
},
db: db,
}
}
// OnMessage handles incoming messages
func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Check if message is the help command
if !strings.EqualFold(strings.TrimSpace(msg.Text), "!help") {
return nil
}
// Get channel plugins from database using platform and platform channel ID
channelPlugins, err := p.db.GetChannelPluginsFromPlatformID(msg.Channel.Platform, msg.Channel.PlatformChannelID)
if err != nil && err != db.ErrNotFound {
slog.Error("Failed to get channel plugins", slog.Any("err", err))
return []*model.MessageAction{}
}
// If no plugins found, initialize empty slice
if err == db.ErrNotFound {
channelPlugins = []*model.ChannelPlugin{}
}
// Get all available plugins
availablePlugins := plugin.GetAvailablePlugins()
// Filter to only enabled plugins for this channel
enabledPlugins := make(map[string]model.Plugin)
for _, channelPlugin := range channelPlugins {
if channelPlugin.Enabled {
if availablePlugin, exists := availablePlugins[channelPlugin.PluginID]; exists {
enabledPlugins[channelPlugin.PluginID] = availablePlugin
}
}
}
// If no plugins are enabled, return a message
if len(enabledPlugins) == 0 {
response := &model.Message{
Text: "No plugins are currently enabled for this channel.",
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
Raw: map[string]interface{}{"parse_mode": "Markdown"},
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Group plugins by category
categories := map[string][]model.Plugin{
"Development": {},
"Fun and Entertainment": {},
"Utility": {},
"Security": {},
"Social Media": {},
"Other": {},
}
// Categorize plugins based on their ID prefix
for _, p := range enabledPlugins {
category := p.GetID()
switch {
case strings.HasPrefix(category, "dev."):
categories["Development"] = append(categories["Development"], p)
case strings.HasPrefix(category, "fun."):
categories["Fun and Entertainment"] = append(categories["Fun and Entertainment"], p)
case strings.HasPrefix(category, "util.") || strings.HasPrefix(category, "reminder.") || strings.HasPrefix(category, "utility."):
categories["Utility"] = append(categories["Utility"], p)
case strings.HasPrefix(category, "security."):
categories["Security"] = append(categories["Security"], p)
case strings.HasPrefix(category, "social."):
categories["Social Media"] = append(categories["Social Media"], p)
default:
categories["Other"] = append(categories["Other"], p)
}
}
// Build the help message
var helpText strings.Builder
helpText.WriteString("🤖 **Available Commands**\n\n")
// Sort category names for consistent output
categoryOrder := []string{"Development", "Fun and Entertainment", "Utility", "Security", "Social Media", "Other"}
for _, categoryName := range categoryOrder {
pluginList := categories[categoryName]
if len(pluginList) == 0 {
continue
}
// Sort plugins within category by name
sort.Slice(pluginList, func(i, j int) bool {
return pluginList[i].GetName() < pluginList[j].GetName()
})
helpText.WriteString(fmt.Sprintf("**%s:**\n", categoryName))
for _, p := range pluginList {
if p.GetHelp() == "" {
continue
}
helpText.WriteString(fmt.Sprintf("• **%s** - %s\n", p.GetName(), p.GetHelp()))
}
helpText.WriteString("\n")
}
// Add footer
helpText.WriteString("_Use the specific commands or triggers mentioned above to interact with the bot._")
response := &model.Message{
Text: helpText.String(),
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
Raw: map[string]interface{}{"parse_mode": "Markdown"},
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}

View file

@ -0,0 +1,206 @@
package help
import (
"strings"
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// MockPlugin implements the Plugin interface for testing
type MockPlugin struct {
id string
name string
help string
}
func (m *MockPlugin) GetID() string { return m.id }
func (m *MockPlugin) GetName() string { return m.name }
func (m *MockPlugin) GetHelp() string { return m.help }
func (m *MockPlugin) RequiresConfig() bool {
return false
}
func (m *MockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
return nil
}
// MockDatabase implements the ChannelPluginGetter interface for testing
type MockDatabase struct {
channelPlugins map[int64][]*model.ChannelPlugin
platformChannelPlugins map[string][]*model.ChannelPlugin // key: "platform:platformChannelID"
}
func (m *MockDatabase) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
if plugins, exists := m.channelPlugins[channelID]; exists {
return plugins, nil
}
return nil, db.ErrNotFound
}
func (m *MockDatabase) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
key := platform + ":" + platformChannelID
if plugins, exists := m.platformChannelPlugins[key]; exists {
return plugins, nil
}
return nil, db.ErrNotFound
}
func TestHelpPlugin_OnMessage(t *testing.T) {
tests := []struct {
name string
messageText string
enabledPlugins map[string]*MockPlugin
expectResponse bool
expectNoPlugins bool
expectCategories []string
}{
{
name: "responds to !help command",
messageText: "!help",
enabledPlugins: map[string]*MockPlugin{
"dev.ping": {
id: "dev.ping",
name: "Ping",
help: "Responds to 'ping' with 'pong'",
},
"fun.dice": {
id: "fun.dice",
name: "Dice Roller",
help: "Rolls dice when you type '!dice [formula]'",
},
},
expectResponse: true,
expectCategories: []string{"Development", "Fun and Entertainment"},
},
{
name: "ignores non-help messages",
messageText: "hello world",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: false,
},
{
name: "ignores case variation",
messageText: "!HELP",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: true,
expectNoPlugins: true,
},
{
name: "handles no enabled plugins",
messageText: "!help",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: true,
expectNoPlugins: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock database
mockDB := &MockDatabase{
channelPlugins: make(map[int64][]*model.ChannelPlugin),
platformChannelPlugins: make(map[string][]*model.ChannelPlugin),
}
// Setup channel plugins in mock database
var channelPluginList []*model.ChannelPlugin
pluginCounter := int64(1)
for pluginID := range tt.enabledPlugins {
channelPluginList = append(channelPluginList, &model.ChannelPlugin{
ID: pluginCounter,
ChannelID: 1,
PluginID: pluginID,
Enabled: true,
Config: make(map[string]interface{}),
})
pluginCounter++
}
// Set up both mapping approaches for the test
mockDB.channelPlugins[1] = channelPluginList
mockDB.platformChannelPlugins["test:test-channel"] = channelPluginList
// Create help plugin
p := New(mockDB)
// Create mock channel
channel := &model.Channel{
ID: 1,
Platform: "test",
PlatformChannelID: "test-channel",
}
// Create test message
msg := &model.Message{
ID: "test-msg",
Text: tt.messageText,
Chat: "test-chat",
Channel: channel,
}
// Mock the plugin registry
originalRegistry := plugin.GetAvailablePlugins()
// Override the registry for this test
plugin.ClearRegistry()
for _, mockPlugin := range tt.enabledPlugins {
plugin.Register(mockPlugin)
}
// Call OnMessage
actions := p.OnMessage(msg, map[string]interface{}{}, nil)
// Restore original registry
plugin.ClearRegistry()
for _, p := range originalRegistry {
plugin.Register(p)
}
if !tt.expectResponse {
if len(actions) != 0 {
t.Errorf("Expected no response, but got %d actions", len(actions))
}
return
}
if len(actions) != 1 {
t.Errorf("Expected 1 action, got %d", len(actions))
return
}
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %v", action.Type)
return
}
responseText := action.Message.Text
if tt.expectNoPlugins {
if !strings.Contains(responseText, "No plugins are currently enabled") {
t.Errorf("Expected 'no plugins' message, got: %s", responseText)
}
return
}
// Check that expected categories appear in response
for _, category := range tt.expectCategories {
if !strings.Contains(responseText, "**"+category+":**") {
t.Errorf("Expected category '%s' in response, got: %s", category, responseText)
}
}
// Check that plugin names and help text appear
for _, mockPlugin := range tt.enabledPlugins {
if !strings.Contains(responseText, mockPlugin.GetName()) {
t.Errorf("Expected plugin name '%s' in response", mockPlugin.GetName())
}
if !strings.Contains(responseText, mockPlugin.GetHelp()) {
t.Errorf("Expected plugin help '%s' in response", mockPlugin.GetHelp())
}
}
})
}
}

View file

@ -24,7 +24,7 @@ func New() *PingPlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") {
return nil return nil
} }

View file

@ -47,6 +47,26 @@ func GetAvailablePlugins() map[string]model.Plugin {
return result return result
} }
// GetAvailablePluginIDs returns a slice of all registered plugin IDs
func GetAvailablePluginIDs() []string {
pluginsMu.RLock()
defer pluginsMu.RUnlock()
result := make([]string, 0, len(plugins))
for pluginID := range plugins {
result = append(result, pluginID)
}
return result
}
// ClearRegistry clears all registered plugins (for testing)
func ClearRegistry() {
pluginsMu.Lock()
defer pluginsMu.Unlock()
plugins = make(map[string]model.Plugin)
}
// BasePlugin provides a common base for plugins // BasePlugin provides a common base for plugins
type BasePlugin struct { type BasePlugin struct {
ID string ID string
@ -76,6 +96,6 @@ func (p *BasePlugin) RequiresConfig() bool {
} }
// OnMessage is the default implementation that does nothing // OnMessage is the default implementation that does nothing
func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
return nil return nil
} }

View file

@ -0,0 +1,331 @@
package plugin
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
// Mock plugin for testing
type testPlugin struct {
BasePlugin
}
func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: "test response",
Chat: msg.Chat,
Channel: msg.Channel,
},
},
}
}
func TestGetAvailablePluginIDs(t *testing.T) {
// Clear registry before test
ClearRegistry()
// Register test plugins
testPlugin1 := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.plugin1",
Name: "Test Plugin 1",
},
}
testPlugin2 := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.plugin2",
Name: "Test Plugin 2",
},
}
Register(testPlugin1)
Register(testPlugin2)
// Test GetAvailablePluginIDs
pluginIDs := GetAvailablePluginIDs()
if len(pluginIDs) != 2 {
t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs))
}
// Check that both plugin IDs are present
found1, found2 := false, false
for _, id := range pluginIDs {
if id == "test.plugin1" {
found1 = true
}
if id == "test.plugin2" {
found2 = true
}
}
if !found1 {
t.Errorf("Expected to find test.plugin1 in plugin IDs")
}
if !found2 {
t.Errorf("Expected to find test.plugin2 in plugin IDs")
}
}
func TestEnableAllPluginsProcessingLogic(t *testing.T) {
// Clear registry before test
ClearRegistry()
// Register test plugins
testPlugin1 := &testPlugin{
BasePlugin: BasePlugin{
ID: "ping",
Name: "Ping Plugin",
},
}
testPlugin2 := &testPlugin{
BasePlugin: BasePlugin{
ID: "echo",
Name: "Echo Plugin",
},
}
testPlugin3 := &testPlugin{
BasePlugin: BasePlugin{
ID: "help",
Name: "Help Plugin",
},
}
Register(testPlugin1)
Register(testPlugin2)
Register(testPlugin3)
t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) {
// Create a channel with EnableAllPlugins = false and only some plugins enabled
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"key": "value"},
},
"echo": {
ID: 2,
ChannelID: 1,
PluginID: "echo",
Enabled: false, // Disabled
Config: map[string]interface{}{},
},
// help plugin not configured
},
}
// Simulate the plugin processing logic from handleMessage
var pluginsToProcess []string
if channel.EnableAllPlugins {
pluginsToProcess = GetAvailablePluginIDs()
} else {
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
}
}
// Should only have "ping" since echo is disabled and help is not configured
if len(pluginsToProcess) != 1 {
t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
}
if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" {
t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0])
}
})
t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) {
// Create a channel with EnableAllPlugins = true
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"key": "value"},
},
"echo": {
ID: 2,
ChannelID: 1,
PluginID: "echo",
Enabled: false, // Disabled, but should still be processed
Config: map[string]interface{}{},
},
// help plugin not configured, but should still be processed
},
}
// Simulate the plugin processing logic from handleMessage
var pluginsToProcess []string
if channel.EnableAllPlugins {
pluginsToProcess = GetAvailablePluginIDs()
} else {
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
}
}
// Should have all 3 registered plugins
if len(pluginsToProcess) != 3 {
t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
}
// Check that all plugins are included
expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false}
for _, pluginID := range pluginsToProcess {
if _, exists := expectedPlugins[pluginID]; exists {
expectedPlugins[pluginID] = true
} else {
t.Errorf("Unexpected plugin in processing list: %s", pluginID)
}
}
for pluginID, found := range expectedPlugins {
if !found {
t.Errorf("Expected plugin %s to be in processing list", pluginID)
}
}
})
t.Run("Plugin configuration handling", func(t *testing.T) {
// Test the configuration logic from handleMessage
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"configured": "value"},
},
},
}
testCases := []struct {
pluginID string
expectedConfig map[string]interface{}
}{
{
pluginID: "ping",
expectedConfig: map[string]interface{}{"configured": "value"},
},
{
pluginID: "echo", // Not explicitly configured
expectedConfig: map[string]interface{}{},
},
}
for _, tc := range testCases {
// Simulate the config retrieval logic from handleMessage
var config map[string]interface{}
if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists {
config = channelPlugin.Config
} else {
config = make(map[string]interface{})
}
if len(config) != len(tc.expectedConfig) {
t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config))
}
for key, expectedValue := range tc.expectedConfig {
if actualValue, exists := config[key]; !exists || actualValue != expectedValue {
t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue)
}
}
}
})
}
func TestPluginRegistry(t *testing.T) {
// Clear registry before test
ClearRegistry()
testPlugin := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.registry",
Name: "Test Registry Plugin",
},
}
t.Run("Register and Get plugin", func(t *testing.T) {
Register(testPlugin)
retrieved, err := Get("test.registry")
if err != nil {
t.Errorf("Failed to get registered plugin: %v", err)
}
if retrieved.GetID() != "test.registry" {
t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID())
}
})
t.Run("Get nonexistent plugin", func(t *testing.T) {
_, err := Get("nonexistent.plugin")
if err == nil {
t.Errorf("Expected error when getting nonexistent plugin, got nil")
}
if err != model.ErrPluginNotFound {
t.Errorf("Expected ErrPluginNotFound, got %v", err)
}
})
t.Run("GetAvailablePlugins", func(t *testing.T) {
plugins := GetAvailablePlugins()
if len(plugins) != 1 {
t.Errorf("Expected 1 plugin in registry, got %d", len(plugins))
}
if plugin, exists := plugins["test.registry"]; !exists {
t.Errorf("Expected to find test.registry in available plugins")
} else if plugin.GetID() != "test.registry" {
t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID())
}
})
t.Run("ClearRegistry", func(t *testing.T) {
ClearRegistry()
plugins := GetAvailablePlugins()
if len(plugins) != 0 {
t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins))
}
_, err := Get("test.registry")
if err == nil {
t.Errorf("Expected error when getting plugin after clearing registry, got nil")
}
})
}

View file

@ -41,7 +41,7 @@ func New(creator ReminderCreator) *Reminder {
} }
// OnMessage processes incoming messages // OnMessage processes incoming messages
func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Only process replies to messages // Only process replies to messages
if msg.ReplyTo == "" { if msg.ReplyTo == "" {
return nil return nil

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
) )
// MockCreator is a mock implementation of ReminderCreator for testing // MockCreator is a mock implementation of ReminderCreator for testing
@ -142,7 +143,8 @@ func TestReminderOnMessage(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
initialCount := len(creator.reminders) initialCount := len(creator.reminders)
actions := plugin.OnMessage(tt.message, nil) mockCache := &testutil.MockCache{}
actions := plugin.OnMessage(tt.message, nil, mockCache)
if tt.expectResponse && len(actions) == 0 { if tt.expectResponse && len(actions) == 0 {
t.Errorf("Expected response action, but got none") t.Errorf("Expected response action, but got none")

View file

@ -23,14 +23,14 @@ func New() *SearchReplacePlugin {
BasePlugin: plugin.BasePlugin{ BasePlugin: plugin.BasePlugin{
ID: "util.searchreplace", ID: "util.searchreplace",
Name: "Search and Replace", Name: "Search and Replace",
Help: "Reply to a message with a search and replace pattern (s/search/replace/[flags]) to create a modified message. " + Help: "Reply to a message with a search and replace pattern (`s/search/replace/[flags]`) to create a modified message. " +
"Supported flags: g (global), i (case insensitive)", "Supported flags: g (global), i (case insensitive)",
}, },
} }
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Only process replies to messages // Only process replies to messages
if msg.ReplyTo == "" { if msg.ReplyTo == "" {
return nil return nil

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
) )
func TestSearchReplace(t *testing.T) { func TestSearchReplace(t *testing.T) {
@ -84,7 +85,8 @@ func TestSearchReplace(t *testing.T) {
} }
// Process message // Process message
actions := p.OnMessage(msg, nil) mockCache := &testutil.MockCache{}
actions := p.OnMessage(msg, nil, mockCache)
// Check results // Check results
if tc.expectActions { if tc.expectActions {

View file

@ -20,18 +20,25 @@ func NewInstagramExpander() *InstagramExpander {
BasePlugin: plugin.BasePlugin{ BasePlugin: plugin.BasePlugin{
ID: "social.instagram", ID: "social.instagram",
Name: "Instagram Link Expander", Name: "Instagram Link Expander",
Help: "Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters", Help: "Automatically converts instagram.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: ddinstagram.com)",
ConfigRequired: true,
}, },
} }
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip empty messages // Skip empty messages
if strings.TrimSpace(msg.Text) == "" { if strings.TrimSpace(msg.Text) == "" {
return nil return nil
} }
// Get replacement domain from config, default to ddinstagram.com
replacementDomain := "ddinstagram.com"
if domain, ok := config["domain"].(string); ok && domain != "" {
replacementDomain = domain
}
// Regex to match instagram.com links // Regex to match instagram.com links
// Match both http://instagram.com and https://instagram.com formats // Match both http://instagram.com and https://instagram.com formats
// Also match www.instagram.com // Also match www.instagram.com
@ -42,7 +49,7 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte
return nil return nil
} }
// Replace instagram.com with ddinstagram.com in the message and clean query parameters // Replace instagram.com with configured domain in the message and clean query parameters
transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
// Parse the URL // Parse the URL
parsedURL, err := url.Parse(link) parsedURL, err := url.Parse(link)
@ -51,13 +58,13 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte
return link return link
} }
// Ensure we don't change links that already come from ddinstagram.com // Ensure we don't change links that already come from the replacement domain
if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" { if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" {
return link return link
} }
// Change the host // Change the host to the configured domain
parsedURL.Host = "d.ddinstagram.com" parsedURL.Host = replacementDomain
// Remove query parameters // Remove query parameters
parsedURL.RawQuery = "" parsedURL.RawQuery = ""

View file

@ -20,18 +20,25 @@ func NewTwitterExpander() *TwitterExpander {
BasePlugin: plugin.BasePlugin{ BasePlugin: plugin.BasePlugin{
ID: "social.twitter", ID: "social.twitter",
Name: "Twitter Link Expander", Name: "Twitter Link Expander",
Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters", Help: "Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: fxtwitter.com)",
ConfigRequired: true,
}, },
} }
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip empty messages // Skip empty messages
if strings.TrimSpace(msg.Text) == "" { if strings.TrimSpace(msg.Text) == "" {
return nil return nil
} }
// Get replacement domain from config, default to fxtwitter.com
replacementDomain := "fxtwitter.com"
if domain, ok := config["domain"].(string); ok && domain != "" {
replacementDomain = domain
}
// Regex to match twitter.com links // Regex to match twitter.com links
// Match both http://twitter.com and https://twitter.com formats // Match both http://twitter.com and https://twitter.com formats
// Also match www.twitter.com // Also match www.twitter.com
@ -42,22 +49,17 @@ func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interf
return nil return nil
} }
// Replace twitter.com with fxtwitter.com in the message and clean query parameters // Replace twitter.com/x.com with configured domain in the message and clean query parameters
transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
// Parse the URL // Parse the URL
parsedURL, err := url.Parse(link) parsedURL, err := url.Parse(link)
if err != nil { if err != nil {
// If parsing fails, just do the simple replacement
link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1)
link = strings.Replace(link, "x.com", "fxtwitter.com", 1)
return link return link
} }
// Change the host // Change the host to the configured domain
if strings.Contains(parsedURL.Host, "twitter.com") { if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") {
parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1) parsedURL.Host = replacementDomain
} else if strings.Contains(parsedURL.Host, "x.com") {
parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1)
} }
// Remove query parameters // Remove query parameters

View file

@ -0,0 +1,120 @@
package social
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
func TestTwitterExpander_OnMessage(t *testing.T) {
plugin := NewTwitterExpander()
tests := []struct {
name string
input string
config map[string]interface{}
expected string
hasReply bool
}{
{
name: "Twitter URL with default domain",
input: "https://twitter.com/user/status/123456789",
config: map[string]interface{}{},
expected: "https://fxtwitter.com/user/status/123456789",
hasReply: true,
},
{
name: "X.com URL with custom domain",
input: "https://x.com/elonmusk/status/987654321",
config: map[string]interface{}{"domain": "vxtwitter.com"},
expected: "https://vxtwitter.com/elonmusk/status/987654321",
hasReply: true,
},
{
name: "Twitter URL with tracking parameters",
input: "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20",
config: map[string]interface{}{},
expected: "https://fxtwitter.com/openai/status/555",
hasReply: true,
},
{
name: "www.twitter.com URL",
input: "https://www.twitter.com/user/status/789",
config: map[string]interface{}{"domain": "nitter.net"},
expected: "https://nitter.net/user/status/789",
hasReply: true,
},
{
name: "Mixed text with Twitter URL",
input: "Check this out: https://twitter.com/user/status/123 amazing!",
config: map[string]interface{}{},
expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!",
hasReply: true,
},
{
name: "No Twitter URLs",
input: "Just some regular text https://youtube.com/watch?v=abc",
config: map[string]interface{}{},
expected: "",
hasReply: false,
},
{
name: "Empty message",
input: "",
config: map[string]interface{}{},
expected: "",
hasReply: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &model.Message{
ID: "test_msg",
Text: tt.input,
Chat: "test_chat",
Channel: &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "test_chat",
},
}
actions := plugin.OnMessage(msg, tt.config, nil)
if !tt.hasReply {
if len(actions) != 0 {
t.Errorf("Expected no actions, got %d", len(actions))
}
return
}
if len(actions) != 1 {
t.Errorf("Expected 1 action, got %d", len(actions))
return
}
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
}
if action.Message == nil {
t.Error("Expected message in action, got nil")
return
}
if action.Message.Text != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text)
}
if action.Message.ReplyTo != msg.ID {
t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo)
}
if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" {
t.Error("Expected parse_mode to be empty string to disable markdown parsing")
}
})
}
}

View file

@ -0,0 +1,29 @@
package testutil
import (
"errors"
"time"
)
// MockCache implements the CacheInterface for testing
type MockCache struct{}
func (m *MockCache) Get(key string, destination interface{}) error {
return errors.New("cache miss") // Always return cache miss for tests
}
func (m *MockCache) Set(key string, value interface{}, expiration *time.Time) error {
return nil // Always succeed for tests
}
func (m *MockCache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
return nil // Always succeed for tests
}
func (m *MockCache) Delete(key string) error {
return nil // Always succeed for tests
}
func (m *MockCache) Exists(key string) (bool, error) {
return false, nil // Always return false for tests
}