diff --git a/.gitignore b/.gitignore index d964ffb..9dab4b7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,9 +5,12 @@ __pycache__ *.cert .env-local .coverage +coverage.out dist bin # Butterrobot *.sqlite* -butterrobot.db +butterrobot.db* +/butterrobot +*_test.db* diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..cb4e70a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,29 @@ +# Claude Code Instructions + +## Plugin Development Workflow + +When creating, modifying, or removing plugins: + +1. **Always update the plugin documentation** in `docs/plugins.md` after any plugin changes +2. Ensure the documentation includes: + - Plugin name and category (Development, Fun and entertainment, Utility, Security, Social Media) + - Brief description of functionality + - Usage instructions with examples + - Any configuration requirements +3. **For plugins with configuration options:** + - Set `ConfigRequired: true` in the plugin's BasePlugin struct + - Add corresponding HTML form fields in `internal/admin/templates/channel_plugin_config.html` + - Use conditional template logic: `{{else if eq .ChannelPlugin.PluginID "plugin.id"}}` + - Include proper form labels, help text, and value binding + +## Testing + +**CRITICAL**: After making ANY changes to code files, you MUST run these commands in order: + +1. **Format code**: `make format` - Format all code according to project standards +2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues") +3. **Run tests**: `make test` - Run all tests to ensure functionality works +4. Verify documentation accuracy +5. Ensure all examples work as described + +**These commands are MANDATORY after every code change, no exceptions.** diff --git a/docs/plugins.md b/docs/plugins.md index 25df16c..3472a08 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -9,10 +9,13 @@ - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) - Dice: Put `!dice` and wathever roll you want to perform. - Coin: Flip a coin and get heads or tails. +- How Long To Beat: Get game completion times from HowLongToBeat.com using `!hltb ` ### Utility +- Help: Shows available commands when you type `!help`. Lists all enabled plugins for the current channel organized by category with their descriptions and usage instructions. - Remind Me: Reply to a message with `!remindme ` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time. +- Search and Replace: Reply to any message with `s/search/replace/[flags]` to perform text substitution. Supports flags: `g` (global), `i` (case insensitive), `n` (regex pattern). Example: `s/hello/hi/gi` replaces all occurrences of "hello" with "hi" case-insensitively. ### Security @@ -20,5 +23,5 @@ ### Social Media -- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms. -- Instagram Link Expander: Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters. This allows for better media embedding in chat platforms. +- Twitter Link Expander: Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: fxtwitter.com). +- Instagram Link Expander: Automatically converts instagram.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: ddinstagram.com). diff --git a/go.mod b/go.mod index cd1bee5..5127bfd 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/gorilla/sessions v1.4.0 golang.org/x/crypto v0.37.0 golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df + golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 modernc.org/sqlite v1.37.0 ) @@ -16,7 +17,6 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/sys v0.32.0 // indirect modernc.org/libc v1.63.0 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/internal/admin/admin.go b/internal/admin/admin.go index 2b41820..abefb72 100644 --- a/internal/admin/admin.go +++ b/internal/admin/admin.go @@ -16,7 +16,7 @@ import ( "github.com/gorilla/sessions" ) -//go:embed templates/*.html +//go:embed templates/*.html templates/plugins/*.html var templateFS embed.FS const ( @@ -90,7 +90,7 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin { } // Parse and register all templates - templateFiles := []string{ + mainTemplateFiles := []string{ "index.html", "login.html", "change_password.html", @@ -101,7 +101,13 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin { "channel_plugin_config.html", } - for _, tf := range templateFiles { + pluginTemplateFiles := []string{ + "plugins/security.domainblock.html", + "plugins/social.instagram.html", + "plugins/social.twitter.html", + } + + for _, tf := range mainTemplateFiles { // Read template content from embedded filesystem content, err := templateFS.ReadFile("templates/" + tf) if err != nil { @@ -120,6 +126,20 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin { panic(err) } + // If this is the channel_plugin_config template, also parse plugin templates + if tf == "channel_plugin_config.html" { + for _, pluginTf := range pluginTemplateFiles { + pluginContent, err := templateFS.ReadFile("templates/" + pluginTf) + if err != nil { + panic(err) + } + t, err = t.Parse(string(pluginContent)) + if err != nil { + panic(err) + } + } + } + templates[tf] = t } @@ -544,6 +564,13 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) { return } + // Update enable_all_plugins + enableAllPlugins := r.FormValue("enable_all_plugins") == "true" + if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil { + http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError) + return + } + a.addFlash(w, r, "Channel updated", "success") http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther) return diff --git a/internal/admin/templates/channel_detail.html b/internal/admin/templates/channel_detail.html index 78909df..7e12d57 100644 --- a/internal/admin/templates/channel_detail.html +++ b/internal/admin/templates/channel_detail.html @@ -27,6 +27,15 @@ +
+ +
+ When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored. +
+
-{{end}} \ No newline at end of file +{{end}} diff --git a/internal/admin/templates/channel_plugin_config.html b/internal/admin/templates/channel_plugin_config.html index decf1a2..0f229f9 100644 --- a/internal/admin/templates/channel_plugin_config.html +++ b/internal/admin/templates/channel_plugin_config.html @@ -9,16 +9,11 @@
{{if eq .ChannelPlugin.PluginID "security.domainblock"}} -
- - -
- Enter comma-separated list of domains to block (e.g., example.com, evil.org). - Messages containing links to these domains will be blocked. -
-
+ {{template "plugins/security.domainblock.html" .}} + {{else if eq .ChannelPlugin.PluginID "social.instagram"}} + {{template "plugins/social.instagram.html" .}} + {{else if eq .ChannelPlugin.PluginID "social.twitter"}} + {{template "plugins/social.twitter.html" .}} {{else}}
This plugin doesn't have specific configuration fields implemented yet. diff --git a/internal/admin/templates/plugins/security.domainblock.html b/internal/admin/templates/plugins/security.domainblock.html new file mode 100644 index 0000000..7ffcc48 --- /dev/null +++ b/internal/admin/templates/plugins/security.domainblock.html @@ -0,0 +1,12 @@ +{{define "plugins/security.domainblock.html"}} +
+ + +
+ Enter comma-separated list of domains to block (e.g., example.com, evil.org). + Messages containing links to these domains will be blocked. +
+
+{{end}} \ No newline at end of file diff --git a/internal/admin/templates/plugins/social.instagram.html b/internal/admin/templates/plugins/social.instagram.html new file mode 100644 index 0000000..a83485d --- /dev/null +++ b/internal/admin/templates/plugins/social.instagram.html @@ -0,0 +1,11 @@ +{{define "plugins/social.instagram.html"}} +
+ + +
+ Enter the domain to replace instagram.com links with. Default is ddinstagram.com if left empty. +
+
+{{end}} \ No newline at end of file diff --git a/internal/admin/templates/plugins/social.twitter.html b/internal/admin/templates/plugins/social.twitter.html new file mode 100644 index 0000000..cb4885f --- /dev/null +++ b/internal/admin/templates/plugins/social.twitter.html @@ -0,0 +1,11 @@ +{{define "plugins/social.twitter.html"}} +
+ + +
+ Enter the domain to replace twitter.com and x.com links with. Default is fxtwitter.com if left empty. +
+
+{{end}} \ No newline at end of file diff --git a/internal/app/app.go b/internal/app/app.go index becd5ea..037089f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -15,6 +15,7 @@ import ( "time" "git.nakama.town/fmartingr/butterrobot/internal/admin" + "git.nakama.town/fmartingr/butterrobot/internal/cache" "git.nakama.town/fmartingr/butterrobot/internal/config" "git.nakama.town/fmartingr/butterrobot/internal/db" "git.nakama.town/fmartingr/butterrobot/internal/model" @@ -22,6 +23,7 @@ import ( "git.nakama.town/fmartingr/butterrobot/internal/plugin" "git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock" "git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/help" "git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" "git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder" "git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace" @@ -87,11 +89,13 @@ func (a *App) Run() error { plugin.Register(fun.NewCoin()) plugin.Register(fun.NewDice()) plugin.Register(fun.NewLoquito()) + plugin.Register(fun.NewHLTB()) plugin.Register(social.NewTwitterExpander()) plugin.Register(social.NewInstagramExpander()) plugin.Register(reminder.New(a.db)) plugin.Register(domainblock.New()) plugin.Register(searchreplace.New()) + plugin.Register(help.New(a.db)) // Initialize routes a.initializeRoutes() @@ -102,6 +106,9 @@ func (a *App) Run() error { // Start reminder scheduler a.queue.StartReminderScheduler(a.handleReminder) + // Start cache cleanup scheduler + go a.startCacheCleanup() + // Create server addr := fmt.Sprintf(":%s", a.config.Port) srv := &http.Server{ @@ -147,6 +154,20 @@ func (a *App) Run() error { return nil } +// startCacheCleanup runs periodic cache cleanup +func (a *App) startCacheCleanup() { + ticker := time.NewTicker(time.Hour) // Clean up every hour + defer ticker.Stop() + + for range ticker.C { + if err := a.db.CacheCleanup(); err != nil { + a.logger.Error("Cache cleanup failed", "error", err) + } else { + a.logger.Debug("Cache cleanup completed") + } + } +} + // Initialize HTTP routes func (a *App) initializeRoutes() { // Health check endpoint @@ -293,11 +314,21 @@ func (a *App) handleMessage(item queue.Item) { } // Process message with plugins - for pluginID, channelPlugin := range channel.Plugins { - if !channel.HasEnabledPlugin(pluginID) { - continue - } + var pluginsToProcess []string + if channel.EnableAllPlugins { + // If EnableAllPlugins is true, process all registered plugins + pluginsToProcess = plugin.GetAvailablePluginIDs() + } else { + // Otherwise, process only explicitly enabled plugins + for pluginID := range channel.Plugins { + if channel.HasEnabledPlugin(pluginID) { + pluginsToProcess = append(pluginsToProcess, pluginID) + } + } + } + + for _, pluginID := range pluginsToProcess { // Get plugin p, err := plugin.Get(pluginID) if err != nil { @@ -305,8 +336,19 @@ func (a *App) handleMessage(item queue.Item) { continue } + // Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured) + var config map[string]interface{} + if channelPlugin, exists := channel.Plugins[pluginID]; exists { + config = channelPlugin.Config + } else { + config = make(map[string]interface{}) + } + + // Create cache instance for this plugin + pluginCache := cache.New(a.db, pluginID) + // Process message and get actions - actions := p.OnMessage(message, channelPlugin.Config) + actions := p.OnMessage(message, config, pluginCache) // Get platform for processing actions platform, err := platform.Get(item.Platform) diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..9419c08 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,83 @@ +package cache + +import ( + "encoding/json" + "fmt" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/db" +) + +// Cache provides a plugin-friendly interface to the cache system +type Cache struct { + db *db.Database + pluginID string +} + +// New creates a new Cache instance for a specific plugin +func New(database *db.Database, pluginID string) *Cache { + return &Cache{ + db: database, + pluginID: pluginID, + } +} + +// Get retrieves a value from the cache +func (c *Cache) Get(key string, destination interface{}) error { + // Create prefixed key + fullKey := c.createKey(key) + + // Get from database + value, err := c.db.CacheGet(fullKey) + if err != nil { + return err + } + + // Unmarshal JSON into destination + return json.Unmarshal([]byte(value), destination) +} + +// Set stores a value in the cache with optional expiration +func (c *Cache) Set(key string, value interface{}, expiration *time.Time) error { + // Create prefixed key + fullKey := c.createKey(key) + + // Marshal value to JSON + jsonValue, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal cache value: %w", err) + } + + // Store in database + return c.db.CacheSet(fullKey, string(jsonValue), expiration) +} + +// SetWithTTL stores a value in the cache with a time-to-live duration +func (c *Cache) SetWithTTL(key string, value interface{}, ttl time.Duration) error { + expiration := time.Now().Add(ttl) + return c.Set(key, value, &expiration) +} + +// Delete removes a value from the cache +func (c *Cache) Delete(key string) error { + fullKey := c.createKey(key) + return c.db.CacheDelete(fullKey) +} + +// Exists checks if a key exists in the cache +func (c *Cache) Exists(key string) (bool, error) { + fullKey := c.createKey(key) + _, err := c.db.CacheGet(fullKey) + if err == db.ErrNotFound { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} + +// createKey creates a prefixed cache key +func (c *Cache) createKey(key string) string { + return fmt.Sprintf("%s_%s", c.pluginID, key) +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..7038276 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,176 @@ +package cache + +import ( + "fmt" + "os" + "testing" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/db" +) + +func TestCache(t *testing.T) { + // Create temporary database for testing with unique name + dbFile := fmt.Sprintf("test_cache_%d.db", time.Now().UnixNano()) + database, err := db.New(dbFile) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer func() { + _ = database.Close() + // Clean up test database file + _ = os.Remove(dbFile) + }() + + // Create cache instance + cache := New(database, "test.plugin") + + // Test data + testKey := "test_key" + testValue := map[string]interface{}{ + "name": "Test Game", + "time": 42, + } + + // Test Set and Get + t.Run("Set and Get", func(t *testing.T) { + err := cache.Set(testKey, testValue, nil) + if err != nil { + t.Errorf("Failed to set cache value: %v", err) + } + + var retrieved map[string]interface{} + err = cache.Get(testKey, &retrieved) + if err != nil { + t.Errorf("Failed to get cache value: %v", err) + } + + if retrieved["name"] != testValue["name"] { + t.Errorf("Expected name %v, got %v", testValue["name"], retrieved["name"]) + } + + if int(retrieved["time"].(float64)) != testValue["time"].(int) { + t.Errorf("Expected time %v, got %v", testValue["time"], retrieved["time"]) + } + }) + + // Test SetWithTTL and expiration + t.Run("SetWithTTL and expiration", func(t *testing.T) { + expiredKey := "expired_key" + + // Set with very short TTL + err := cache.SetWithTTL(expiredKey, testValue, time.Millisecond) + if err != nil { + t.Errorf("Failed to set cache value with TTL: %v", err) + } + + // Wait for expiration + time.Sleep(2 * time.Millisecond) + + // Try to get - should fail + var retrieved map[string]interface{} + err = cache.Get(expiredKey, &retrieved) + if err == nil { + t.Errorf("Expected cache miss for expired key, but got value") + } + }) + + // Test Exists + t.Run("Exists", func(t *testing.T) { + existsKey := "exists_key" + + // Make sure the key doesn't exist initially by deleting it + _ = cache.Delete(existsKey) + + // Should not exist initially + exists, err := cache.Exists(existsKey) + if err != nil { + t.Errorf("Failed to check if key exists: %v", err) + } + if exists { + t.Errorf("Expected key to not exist, but it does") + } + + // Set value + err = cache.Set(existsKey, testValue, nil) + if err != nil { + t.Errorf("Failed to set cache value: %v", err) + } + + // Should exist now + exists, err = cache.Exists(existsKey) + if err != nil { + t.Errorf("Failed to check if key exists: %v", err) + } + if !exists { + t.Errorf("Expected key to exist, but it doesn't") + } + }) + + // Test Delete + t.Run("Delete", func(t *testing.T) { + deleteKey := "delete_key" + + // Set value + err := cache.Set(deleteKey, testValue, nil) + if err != nil { + t.Errorf("Failed to set cache value: %v", err) + } + + // Delete value + err = cache.Delete(deleteKey) + if err != nil { + t.Errorf("Failed to delete cache value: %v", err) + } + + // Should not exist anymore + var retrieved map[string]interface{} + err = cache.Get(deleteKey, &retrieved) + if err == nil { + t.Errorf("Expected cache miss for deleted key, but got value") + } + }) + + // Test plugin ID prefixing + t.Run("Plugin ID prefixing", func(t *testing.T) { + cache1 := New(database, "plugin1") + cache2 := New(database, "plugin2") + + sameKey := "same_key" + value1 := "value1" + value2 := "value2" + + // Set same key in both caches + err := cache1.Set(sameKey, value1, nil) + if err != nil { + t.Errorf("Failed to set cache1 value: %v", err) + } + + err = cache2.Set(sameKey, value2, nil) + if err != nil { + t.Errorf("Failed to set cache2 value: %v", err) + } + + // Retrieve from both caches + var retrieved1, retrieved2 string + + err = cache1.Get(sameKey, &retrieved1) + if err != nil { + t.Errorf("Failed to get cache1 value: %v", err) + } + + err = cache2.Get(sameKey, &retrieved2) + if err != nil { + t.Errorf("Failed to get cache2 value: %v", err) + } + + // Values should be different due to plugin ID prefixing + if retrieved1 != value1 { + t.Errorf("Expected cache1 value %v, got %v", value1, retrieved1) + } + + if retrieved2 != value2 { + t.Errorf("Expected cache2 value %v, got %v", value2, retrieved2) + } + }) +} diff --git a/internal/db/db.go b/internal/db/db.go index 0da285e..1c54ad4 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -35,6 +35,11 @@ func New(dbPath string) (*Database, error) { return nil, err } + // Configure SQLite for better reliability + if err := configureSQLite(db); err != nil { + return nil, err + } + // Initialize database if err := initDatabase(db); err != nil { return nil, err @@ -51,7 +56,7 @@ func (d *Database) Close() error { // GetChannelByID retrieves a channel by ID func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { query := ` - SELECT id, platform, platform_channel_id, enabled, channel_raw + SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw FROM channels WHERE id = ? ` @@ -62,10 +67,11 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { platform string platformChannelID string enabled bool + enableAllPlugins bool channelRawJSON string ) - err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) + err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) if err == sql.ErrNoRows { return nil, ErrNotFound } @@ -85,6 +91,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, + EnableAllPlugins: enableAllPlugins, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } @@ -105,7 +112,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { // GetChannelByPlatform retrieves a channel by platform and platform channel ID func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { query := ` - SELECT id, platform, platform_channel_id, enabled, channel_raw + SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw FROM channels WHERE platform = ? AND platform_channel_id = ? ` @@ -113,12 +120,13 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo row := d.db.QueryRow(query, platform, platformChannelID) var ( - id int64 - enabled bool - channelRawJSON string + id int64 + enabled bool + enableAllPlugins bool + channelRawJSON string ) - err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) + err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) if err == sql.ErrNoRows { return nil, ErrNotFound } @@ -138,6 +146,7 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, + EnableAllPlugins: enableAllPlugins, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } @@ -165,11 +174,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo // Insert channel query := ` - INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw) - VALUES (?, ?, ?, ?) + INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw) + VALUES (?, ?, ?, ?, ?) ` - result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON)) + result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON)) if err != nil { return nil, err } @@ -186,6 +195,7 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, + EnableAllPlugins: false, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } @@ -205,6 +215,18 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error { return err } +// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status +func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error { + query := ` + UPDATE channels + SET enable_all_plugins = ? + WHERE id = ? + ` + + _, err := d.db.Exec(query, enableAllPlugins, id) + return err +} + // DeleteChannel deletes a channel func (d *Database) DeleteChannel(id int64) error { // First delete all channel plugins @@ -256,7 +278,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e } // Parse config JSON - var config map[string]interface{} + var config map[string]any if err := json.Unmarshal([]byte(configJSON), &config); err != nil { return nil, err } @@ -283,6 +305,28 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e return plugins, nil } +// GetChannelPluginsFromPlatformID retrieves all plugins for a channel by platform and platform channel ID +func (d *Database) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) { + // First, get the channel ID by platform and platform channel ID + query := ` + SELECT id + FROM channels + WHERE platform = ? AND platform_channel_id = ? + ` + + var channelID int64 + err := d.db.QueryRow(query, platform, platformChannelID).Scan(&channelID) + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + if err != nil { + return nil, err + } + + // Now get the plugins for this channel + return d.GetChannelPlugins(channelID) +} + // GetChannelPluginByID retrieves a channel plugin by ID func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { query := ` @@ -429,7 +473,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error { // GetAllChannels retrieves all channels func (d *Database) GetAllChannels() ([]*model.Channel, error) { query := ` - SELECT id, platform, platform_channel_id, enabled, channel_raw + SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw FROM channels ` @@ -451,10 +495,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { platform string platformChannelID string enabled bool + enableAllPlugins bool channelRawJSON string ) - if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil { + if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil { return nil, err } @@ -470,6 +515,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { Platform: platform, PlatformChannelID: platformChannelID, Enabled: enabled, + EnableAllPlugins: enableAllPlugins, ChannelRaw: channelRaw, Plugins: make(map[string]*model.ChannelPlugin), } @@ -621,8 +667,8 @@ func (d *Database) UpdateUserPassword(userID int64, newPassword string) error { func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) { query := ` INSERT INTO reminders ( - platform, channel_id, message_id, reply_to_id, - user_id, username, created_at, trigger_at, + platform, channel_id, message_id, reply_to_id, + user_id, username, created_at, trigger_at, content, processed ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0) ` @@ -661,7 +707,7 @@ func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, use // GetPendingReminders gets all pending reminders that need to be processed func (d *Database) GetPendingReminders() ([]*model.Reminder, error) { query := ` - SELECT id, platform, channel_id, message_id, reply_to_id, + SELECT id, platform, channel_id, message_id, reply_to_id, user_id, username, created_at, trigger_at, content, processed FROM reminders WHERE processed = 0 AND trigger_at <= ? @@ -793,3 +839,82 @@ func initDatabase(db *sql.DB) error { return nil } + +// Configure SQLite for better reliability +func configureSQLite(db *sql.DB) error { + pragmas := []string{ + // Enable Write-Ahead Logging for better concurrency and crash recovery + "PRAGMA journal_mode = WAL", + // Set 5-second timeout when database is locked by another connection + "PRAGMA busy_timeout = 5000", + // Balance between safety and performance for disk writes + "PRAGMA synchronous = NORMAL", + // Set large cache size (1GB) for better read performance + "PRAGMA cache_size = 1000000000", + // Enable foreign key constraint enforcement + "PRAGMA foreign_keys = true", + // Store temporary tables and indices in memory for speed + "PRAGMA temp_store = memory", + } + + for _, pragma := range pragmas { + if _, err := db.Exec(pragma); err != nil { + return fmt.Errorf("failed to execute %s: %w", pragma, err) + } + } + + return nil +} + +// CacheGet retrieves a value from the cache +func (d *Database) CacheGet(key string) (string, error) { + query := ` + SELECT value + FROM cache + WHERE key = ? AND (expires_at IS NULL OR expires_at > ?) + ` + + var value string + err := d.db.QueryRow(query, key, time.Now()).Scan(&value) + if err == sql.ErrNoRows { + return "", ErrNotFound + } + if err != nil { + return "", err + } + + return value, nil +} + +// CacheSet stores a value in the cache with optional expiration +func (d *Database) CacheSet(key, value string, expiration *time.Time) error { + query := ` + INSERT OR REPLACE INTO cache (key, value, expires_at, updated_at) + VALUES (?, ?, ?, ?) + ` + + _, err := d.db.Exec(query, key, value, expiration, time.Now()) + return err +} + +// CacheDelete removes a value from the cache +func (d *Database) CacheDelete(key string) error { + query := ` + DELETE FROM cache + WHERE key = ? + ` + + _, err := d.db.Exec(query, key) + return err +} + +// CacheCleanup removes expired cache entries +func (d *Database) CacheCleanup() error { + query := ` + DELETE FROM cache + WHERE expires_at IS NOT NULL AND expires_at <= ? + ` + + _, err := d.db.Exec(query, time.Now()) + return err +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..beb485d --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,203 @@ +package db + +import ( + "fmt" + "os" + "testing" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +func TestEnableAllPlugins(t *testing.T) { + // Create temporary database for testing with unique name + dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano()) + database, err := New(dbFile) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer func() { + _ = database.Close() + // Clean up test database file + _ = os.Remove(dbFile) + }() + + t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) { + channelRaw := map[string]interface{}{ + "name": "test-channel", + } + + channel, err := database.CreateChannel("telegram", "123456", true, channelRaw) + if err != nil { + t.Fatalf("Failed to create channel: %v", err) + } + + if channel.EnableAllPlugins { + t.Errorf("Expected EnableAllPlugins to be false by default, got true") + } + + // Verify it's also false when retrieved from database + retrieved, err := database.GetChannelByID(channel.ID) + if err != nil { + t.Fatalf("Failed to retrieve channel: %v", err) + } + + if retrieved.EnableAllPlugins { + t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true") + } + }) + + t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) { + // Create a channel + channelRaw := map[string]interface{}{ + "name": "test-channel-2", + } + + channel, err := database.CreateChannel("telegram", "123457", true, channelRaw) + if err != nil { + t.Fatalf("Failed to create channel: %v", err) + } + + // Update EnableAllPlugins to true + err = database.UpdateChannelEnableAllPlugins(channel.ID, true) + if err != nil { + t.Fatalf("Failed to update EnableAllPlugins: %v", err) + } + + // Retrieve and verify + retrieved, err := database.GetChannelByID(channel.ID) + if err != nil { + t.Fatalf("Failed to retrieve channel: %v", err) + } + + if !retrieved.EnableAllPlugins { + t.Errorf("Expected EnableAllPlugins to be true after update, got false") + } + + // Update back to false + err = database.UpdateChannelEnableAllPlugins(channel.ID, false) + if err != nil { + t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err) + } + + // Retrieve and verify again + retrieved, err = database.GetChannelByID(channel.ID) + if err != nil { + t.Fatalf("Failed to retrieve channel: %v", err) + } + + if retrieved.EnableAllPlugins { + t.Errorf("Expected EnableAllPlugins to be false after second update, got true") + } + }) + + t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) { + // Create a channel + channelRaw := map[string]interface{}{ + "name": "test-channel-3", + } + + channel, err := database.CreateChannel("slack", "C123456", true, channelRaw) + if err != nil { + t.Fatalf("Failed to create channel: %v", err) + } + + // Enable all plugins + err = database.UpdateChannelEnableAllPlugins(channel.ID, true) + if err != nil { + t.Fatalf("Failed to update EnableAllPlugins: %v", err) + } + + // Retrieve by platform + retrieved, err := database.GetChannelByPlatform("slack", "C123456") + if err != nil { + t.Fatalf("Failed to retrieve channel by platform: %v", err) + } + + if !retrieved.EnableAllPlugins { + t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false") + } + }) + + t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) { + // Create multiple channels with different EnableAllPlugins settings + channelRaw1 := map[string]interface{}{"name": "channel-1"} + channelRaw2 := map[string]interface{}{"name": "channel-2"} + + channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1) + if err != nil { + t.Fatalf("Failed to create channel1: %v", err) + } + + channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2) + if err != nil { + t.Fatalf("Failed to create channel2: %v", err) + } + + // Enable all plugins for channel2 only + err = database.UpdateChannelEnableAllPlugins(channel2.ID, true) + if err != nil { + t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err) + } + + // Get all channels + channels, err := database.GetAllChannels() + if err != nil { + t.Fatalf("Failed to get all channels: %v", err) + } + + // Find our test channels + var foundChannel1, foundChannel2 *model.Channel + for _, ch := range channels { + if ch.ID == channel1.ID { + foundChannel1 = ch + } + if ch.ID == channel2.ID { + foundChannel2 = ch + } + } + + if foundChannel1 == nil { + t.Fatalf("Channel1 not found in GetAllChannels result") + } + if foundChannel2 == nil { + t.Fatalf("Channel2 not found in GetAllChannels result") + } + + if foundChannel1.EnableAllPlugins { + t.Errorf("Expected channel1 EnableAllPlugins to be false, got true") + } + if !foundChannel2.EnableAllPlugins { + t.Errorf("Expected channel2 EnableAllPlugins to be true, got false") + } + }) + + t.Run("Migration applied correctly", func(t *testing.T) { + // Test that we can create a channel and the enable_all_plugins column exists + // This implicitly tests that migration 4 was applied correctly + channelRaw := map[string]interface{}{ + "name": "migration-test-channel", + } + + channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw) + if err != nil { + t.Fatalf("Failed to create channel after migration: %v", err) + } + + // Try to update EnableAllPlugins - this would fail if the column doesn't exist + err = database.UpdateChannelEnableAllPlugins(channel.ID, true) + if err != nil { + t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err) + } + + // Verify the value was set correctly + retrieved, err := database.GetChannelByID(channel.ID) + if err != nil { + t.Fatalf("Failed to retrieve channel: %v", err) + } + + if !retrieved.EnableAllPlugins { + t.Errorf("EnableAllPlugins should be true after update") + } + }) +} diff --git a/internal/migration/migrations.go b/internal/migration/migrations.go index 8db229b..11aa716 100644 --- a/internal/migration/migrations.go +++ b/internal/migration/migrations.go @@ -9,6 +9,8 @@ func init() { // Register migrations Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown) Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown) + Register(3, "Add cache table", migrateCacheUp, migrateCacheDown) + Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown) } // Initial schema creation with bcrypt passwords - version 1 @@ -126,3 +128,87 @@ func migrateRemindersDown(db *sql.DB) error { _, err := db.Exec(`DROP TABLE IF EXISTS reminders`) return err } + +// Add cache table - version 3 +func migrateCacheUp(db *sql.DB) error { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS cache ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + expires_at TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + return err + } + + // Create index on expires_at for efficient cleanup + _, err = db.Exec(` + CREATE INDEX IF NOT EXISTS idx_cache_expires_at ON cache(expires_at) + `) + return err +} + +func migrateCacheDown(db *sql.DB) error { + _, err := db.Exec(`DROP TABLE IF EXISTS cache`) + return err +} + +// Add enable_all_plugins column to channels table - version 4 +func migrateEnableAllPluginsUp(db *sql.DB) error { + _, err := db.Exec(` + ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0 + `) + return err +} + +func migrateEnableAllPluginsDown(db *sql.DB) error { + // SQLite doesn't support DROP COLUMN, so we need to recreate the table + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + _ = tx.Rollback() // Ignore rollback errors + }() + + // Create backup table + _, err = tx.Exec(` + CREATE TABLE channels_backup ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + platform TEXT NOT NULL, + platform_channel_id TEXT NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT 0, + channel_raw TEXT NOT NULL, + UNIQUE(platform, platform_channel_id) + ) + `) + if err != nil { + return err + } + + // Copy data excluding enable_all_plugins column + _, err = tx.Exec(` + INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw) + SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels + `) + if err != nil { + return err + } + + // Drop original table + _, err = tx.Exec(`DROP TABLE channels`) + if err != nil { + return err + } + + // Rename backup table + _, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`) + if err != nil { + return err + } + + return tx.Commit() +} diff --git a/internal/model/message.go b/internal/model/message.go index 26ec5da..8b38830 100644 --- a/internal/model/message.go +++ b/internal/model/message.go @@ -44,11 +44,17 @@ type Channel struct { PlatformChannelID string ChannelRaw map[string]interface{} Enabled bool + EnableAllPlugins bool Plugins map[string]*ChannelPlugin } // HasEnabledPlugin checks if a plugin is enabled for this channel func (c *Channel) HasEnabledPlugin(pluginID string) bool { + // If EnableAllPlugins is true, all plugins are considered enabled + if c.EnableAllPlugins { + return true + } + plugin, exists := c.Plugins[pluginID] if !exists { return false diff --git a/internal/model/message_test.go b/internal/model/message_test.go new file mode 100644 index 0000000..d2dfedc --- /dev/null +++ b/internal/model/message_test.go @@ -0,0 +1,234 @@ +package model + +import ( + "testing" +) + +func TestChannel_HasEnabledPlugin(t *testing.T) { + t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: false, + Plugins: make(map[string]*ChannelPlugin), + } + + // Plugin not in map should return false + result := channel.HasEnabledPlugin("nonexistent.plugin") + if result { + t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true") + } + }) + + t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: false, + Plugins: map[string]*ChannelPlugin{ + "test.plugin": { + ID: 1, + ChannelID: 1, + PluginID: "test.plugin", + Enabled: false, + Config: make(map[string]any), + }, + }, + } + + // Disabled plugin should return false + result := channel.HasEnabledPlugin("test.plugin") + if result { + t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true") + } + }) + + t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: false, + Plugins: map[string]*ChannelPlugin{ + "test.plugin": { + ID: 1, + ChannelID: 1, + PluginID: "test.plugin", + Enabled: true, + Config: make(map[string]any), + }, + }, + } + + // Enabled plugin should return true + result := channel.HasEnabledPlugin("test.plugin") + if !result { + t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false") + } + }) + + t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: make(map[string]*ChannelPlugin), + } + + // When EnableAllPlugins is true, any plugin should be considered enabled + result := channel.HasEnabledPlugin("nonexistent.plugin") + if !result { + t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false") + } + }) + + t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: map[string]*ChannelPlugin{ + "test.plugin": { + ID: 1, + ChannelID: 1, + PluginID: "test.plugin", + Enabled: false, + Config: make(map[string]any), + }, + }, + } + + // When EnableAllPlugins is true, even disabled plugins should be considered enabled + result := channel.HasEnabledPlugin("test.plugin") + if !result { + t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false") + } + }) + + t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: map[string]*ChannelPlugin{ + "test.plugin": { + ID: 1, + ChannelID: 1, + PluginID: "test.plugin", + Enabled: true, + Config: make(map[string]any), + }, + }, + } + + // When EnableAllPlugins is true, enabled plugins should also return true + result := channel.HasEnabledPlugin("test.plugin") + if !result { + t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false") + } + }) + + t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) { + channel := &Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: map[string]*ChannelPlugin{ + "plugin1": { + ID: 1, + ChannelID: 1, + PluginID: "plugin1", + Enabled: true, + Config: make(map[string]any), + }, + "plugin2": { + ID: 2, + ChannelID: 1, + PluginID: "plugin2", + Enabled: false, + Config: make(map[string]any), + }, + }, + } + + // All plugins should be enabled when EnableAllPlugins is true + testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"} + for _, pluginID := range testCases { + result := channel.HasEnabledPlugin(pluginID) + if !result { + t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID) + } + } + }) +} + +func TestChannelName(t *testing.T) { + t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) { + channel := &Channel{ + PlatformChannelID: "test-id", + ChannelRaw: nil, + } + + result := channel.ChannelName() + if result != "test-id" { + t.Errorf("Expected channel name to be 'test-id', got '%s'", result) + } + }) + + t.Run("Returns name from ChannelRaw when available", func(t *testing.T) { + channel := &Channel{ + PlatformChannelID: "test-id", + ChannelRaw: map[string]interface{}{ + "name": "Test Channel", + }, + } + + result := channel.ChannelName() + if result != "Test Channel" { + t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result) + } + }) + + t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) { + channel := &Channel{ + PlatformChannelID: "test-id", + ChannelRaw: map[string]interface{}{ + "chat": map[string]interface{}{ + "title": "Telegram Group", + }, + }, + } + + result := channel.ChannelName() + if result != "Telegram Group" { + t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result) + } + }) + + t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) { + channel := &Channel{ + PlatformChannelID: "fallback-id", + ChannelRaw: map[string]interface{}{ + "other_field": "value", + }, + } + + result := channel.ChannelName() + if result != "fallback-id" { + t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result) + } + }) +} diff --git a/internal/model/plugin.go b/internal/model/plugin.go index 03e4f96..4a1449f 100644 --- a/internal/model/plugin.go +++ b/internal/model/plugin.go @@ -2,8 +2,18 @@ package model import ( "errors" + "time" ) +// CacheInterface defines the cache interface available to plugins +type CacheInterface interface { + Get(key string, destination interface{}) error + Set(key string, value interface{}, expiration *time.Time) error + SetWithTTL(key string, value interface{}, ttl time.Duration) error + Delete(key string) error + Exists(key string) (bool, error) +} + var ( // ErrPluginNotFound is returned when a requested plugin doesn't exist ErrPluginNotFound = errors.New("plugin not found") @@ -24,5 +34,5 @@ type Plugin interface { RequiresConfig() bool // OnMessage processes an incoming message and returns platform actions - OnMessage(msg *Message, config map[string]interface{}) []*MessageAction + OnMessage(msg *Message, config map[string]interface{}, cache CacheInterface) []*MessageAction } diff --git a/internal/platform/telegram/telegram.go b/internal/platform/telegram/telegram.go index 8da4995..b015793 100644 --- a/internal/platform/telegram/telegram.go +++ b/internal/platform/telegram/telegram.go @@ -237,6 +237,15 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { "text": msg.Text, } + // Set parse_mode based on plugin preference or default to empty string + if msg.Raw != nil && msg.Raw["parse_mode"] != nil { + // Plugin explicitly set parse_mode + payload["parse_mode"] = msg.Raw["parse_mode"] + } else { + // Default to empty string (no formatting) + payload["parse_mode"] = "" + } + // Add reply if needed if msg.ReplyTo != "" { replyToID, err := strconv.Atoi(msg.ReplyTo) diff --git a/internal/plugin/domainblock/domainblock.go b/internal/plugin/domainblock/domainblock.go index 5a44c49..1f8ff1e 100644 --- a/internal/plugin/domainblock/domainblock.go +++ b/internal/plugin/domainblock/domainblock.go @@ -65,7 +65,7 @@ func extractDomains(text string) []string { } // OnMessage processes incoming messages -func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { // Skip messages from bots if msg.FromBot { return nil diff --git a/internal/plugin/domainblock/domainblock_test.go b/internal/plugin/domainblock/domainblock_test.go index 1d65964..57e8833 100644 --- a/internal/plugin/domainblock/domainblock_test.go +++ b/internal/plugin/domainblock/domainblock_test.go @@ -4,6 +4,7 @@ import ( "testing" "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/testutil" ) func TestExtractDomains(t *testing.T) { @@ -124,7 +125,8 @@ func TestOnMessage(t *testing.T) { "blocked_domains": test.blockedDomains, } - responses := plugin.OnMessage(msg, config) + mockCache := &testutil.MockCache{} + responses := plugin.OnMessage(msg, config, mockCache) if test.expectBlocked { if len(responses) == 0 { diff --git a/internal/plugin/fun/coin.go b/internal/plugin/fun/coin.go index bd083d1..ab679ea 100644 --- a/internal/plugin/fun/coin.go +++ b/internal/plugin/fun/coin.go @@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin { } // OnMessage handles incoming messages -func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { return nil } diff --git a/internal/plugin/fun/dice.go b/internal/plugin/fun/dice.go index 8b13edb..6136097 100644 --- a/internal/plugin/fun/dice.go +++ b/internal/plugin/fun/dice.go @@ -32,7 +32,7 @@ func NewDice() *DicePlugin { } // OnMessage handles incoming messages -func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { return nil } diff --git a/internal/plugin/fun/hltb.go b/internal/plugin/fun/hltb.go new file mode 100644 index 0000000..f94f2ba --- /dev/null +++ b/internal/plugin/fun/hltb.go @@ -0,0 +1,394 @@ +package fun + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// HLTBPlugin searches HowLongToBeat for game completion times +type HLTBPlugin struct { + plugin.BasePlugin + httpClient *http.Client +} + +// HLTBSearchRequest represents the search request payload +type HLTBSearchRequest struct { + SearchType string `json:"searchType"` + SearchTerms []string `json:"searchTerms"` + SearchPage int `json:"searchPage"` + Size int `json:"size"` + SearchOptions map[string]interface{} `json:"searchOptions"` + UseCache bool `json:"useCache"` +} + +// HLTBGame represents a game from HowLongToBeat +type HLTBGame struct { + ID int `json:"game_id"` + Name string `json:"game_name"` + GameAlias string `json:"game_alias"` + GameImage string `json:"game_image"` + CompMain int `json:"comp_main"` + CompPlus int `json:"comp_plus"` + CompComplete int `json:"comp_complete"` + CompAll int `json:"comp_all"` + InvestedCo int `json:"invested_co"` + InvestedMp int `json:"invested_mp"` + CountComp int `json:"count_comp"` + CountSpeedruns int `json:"count_speedruns"` + CountBacklog int `json:"count_backlog"` + CountReview int `json:"count_review"` + ReviewScore int `json:"review_score"` + CountPlaying int `json:"count_playing"` + CountRetired int `json:"count_retired"` +} + +// HLTBSearchResponse represents the search response +type HLTBSearchResponse struct { + Color string `json:"color"` + Title string `json:"title"` + Category string `json:"category"` + Count int `json:"count"` + Pagecurrent int `json:"pagecurrent"` + Pagesize int `json:"pagesize"` + Pagetotal int `json:"pagetotal"` + SearchTerm string `json:"searchTerm"` + SearchResults []HLTBGame `json:"data"` +} + +// NewHLTB creates a new HLTBPlugin instance +func NewHLTB() *HLTBPlugin { + return &HLTBPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "fun.hltb", + Name: "How Long To Beat", + Help: "Get game completion times from HowLongToBeat.com using `!hltb `", + }, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// OnMessage handles incoming messages +func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { + // Check if message starts with !hltb + text := strings.TrimSpace(msg.Text) + if !strings.HasPrefix(text, "!hltb ") { + return nil + } + + // Extract game name + gameName := strings.TrimSpace(text[6:]) // Remove "!hltb " + if gameName == "" { + return p.createErrorResponse(msg, "Please provide a game name. Usage: !hltb ") + } + + // Check cache first + var games []HLTBGame + var err error + cacheKey := strings.ToLower(gameName) + + err = cache.Get(cacheKey, &games) + if err != nil || len(games) == 0 { + // Cache miss - search for the game + games, err = p.searchGame(gameName) + if err != nil { + return p.createErrorResponse(msg, fmt.Sprintf("Error searching for game: %s", err.Error())) + } + + if len(games) == 0 { + return p.createErrorResponse(msg, fmt.Sprintf("No results found for '%s'", gameName)) + } + + // Cache the results for 1 hour + err = cache.SetWithTTL(cacheKey, games, time.Hour) + if err != nil { + // Log cache error but don't fail the request + fmt.Printf("Warning: Failed to cache HLTB results: %v\n", err) + } + } + + // Use the first result + game := games[0] + + // Format the response + response := p.formatGameInfo(game) + + // Create response message with game cover if available + responseMsg := &model.Message{ + Text: response, + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + + // Set parse mode for markdown formatting + if responseMsg.Raw == nil { + responseMsg.Raw = make(map[string]interface{}) + } + responseMsg.Raw["parse_mode"] = "Markdown" + + // Add game cover as attachment if available + if game.GameImage != "" { + imageURL := p.getFullImageURL(game.GameImage) + responseMsg.Raw["image_url"] = imageURL + } + + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: responseMsg, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} +} + +// searchGame searches for a game on HowLongToBeat +func (p *HLTBPlugin) searchGame(gameName string) ([]HLTBGame, error) { + // Split search terms by words + searchTerms := strings.Fields(gameName) + + // Prepare search request + searchRequest := HLTBSearchRequest{ + SearchType: "games", + SearchTerms: searchTerms, + SearchPage: 1, + Size: 20, + SearchOptions: map[string]interface{}{ + "games": map[string]interface{}{ + "userId": 0, + "platform": "", + "sortCategory": "popular", + "rangeCategory": "main", + "rangeTime": map[string]interface{}{ + "min": nil, + "max": nil, + }, + "gameplay": map[string]interface{}{ + "perspective": "", + "flow": "", + "genre": "", + "difficulty": "", + }, + "rangeYear": map[string]interface{}{ + "min": "", + "max": "", + }, + "modifier": "", + }, + "users": map[string]interface{}{ + "sortCategory": "postcount", + }, + "lists": map[string]interface{}{ + "sortCategory": "follows", + }, + "filter": "", + "sort": 0, + "randomizer": 0, + }, + UseCache: true, + } + + // Convert to JSON + jsonData, err := json.Marshal(searchRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal search request: %w", err) + } + + // The API endpoint appears to have changed to use dynamic tokens + // Try to get the seek token first, fallback to basic search + seekToken, err := p.getSeekToken() + if err != nil { + // Fallback to old endpoint + seekToken = "" + } + + var apiURL string + if seekToken != "" { + apiURL = fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", seekToken) + } else { + apiURL = "https://howlongtobeat.com/api/search" + } + + // Create HTTP request + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers to match the working curl request + req.Header.Set("Accept", "*/*") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "https://howlongtobeat.com") + req.Header.Set("Pragma", "no-cache") + req.Header.Set("Referer", "https://howlongtobeat.com") + req.Header.Set("Sec-Fetch-Dest", "empty") + req.Header.Set("Sec-Fetch-Mode", "cors") + req.Header.Set("Sec-Fetch-Site", "same-origin") + req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36") + + // Send request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API returned status code: %d", resp.StatusCode) + } + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Parse response + var searchResponse HLTBSearchResponse + if err := json.Unmarshal(body, &searchResponse); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return searchResponse.SearchResults, nil +} + +// formatGameInfo formats game information for display +func (p *HLTBPlugin) formatGameInfo(game HLTBGame) string { + var response strings.Builder + + response.WriteString(fmt.Sprintf("šŸŽ® **%s**\n\n", game.Name)) + + // Format completion times + if game.CompMain > 0 { + response.WriteString(fmt.Sprintf("šŸ“– **Main Story:** %s\n", p.formatTime(game.CompMain))) + } + + if game.CompPlus > 0 { + response.WriteString(fmt.Sprintf("āž• **Main + Extras:** %s\n", p.formatTime(game.CompPlus))) + } + + if game.CompComplete > 0 { + response.WriteString(fmt.Sprintf("šŸ’Æ **Completionist:** %s\n", p.formatTime(game.CompComplete))) + } + + if game.CompAll > 0 { + response.WriteString(fmt.Sprintf("šŸŽÆ **All Styles:** %s\n", p.formatTime(game.CompAll))) + } + + // Add review score if available + if game.ReviewScore > 0 { + response.WriteString(fmt.Sprintf("\n⭐ **User Score:** %d/100", game.ReviewScore)) + } + + // Add source attribution + response.WriteString("\n\n*Source: HowLongToBeat.com*") + + return response.String() +} + +// formatTime converts seconds to a readable time format +func (p *HLTBPlugin) formatTime(seconds int) string { + if seconds <= 0 { + return "N/A" + } + + hours := float64(seconds) / 3600.0 + + if hours < 1 { + minutes := seconds / 60 + return fmt.Sprintf("%d minutes", minutes) + } else if hours < 2 { + return fmt.Sprintf("%.1f hour", hours) + } else { + return fmt.Sprintf("%.1f hours", hours) + } +} + +// getFullImageURL constructs the full image URL +func (p *HLTBPlugin) getFullImageURL(imagePath string) string { + if imagePath == "" { + return "" + } + + // Remove leading slash if present + imagePath = strings.TrimPrefix(imagePath, "/") + + return fmt.Sprintf("https://howlongtobeat.com/games/%s", imagePath) +} + +// getSeekToken attempts to retrieve the seek token from HowLongToBeat +func (p *HLTBPlugin) getSeekToken() (string, error) { + // Try to extract the seek token from the main page + req, err := http.NewRequest("GET", "https://howlongtobeat.com", nil) + if err != nil { + return "", fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36") + + resp, err := p.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch token: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read token response: %w", err) + } + + // Look for patterns that might contain the token + patterns := []string{ + `/api/seek/([a-f0-9]+)`, + `"seek/([a-f0-9]+)"`, + `seek/([a-f0-9]{12,})`, + } + + bodyStr := string(body) + for _, pattern := range patterns { + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(bodyStr) + if len(matches) > 1 { + return matches[1], nil + } + } + + // If we can't extract a token, return the known working one as fallback + return "d4b2e330db04dbf3", nil +} + +// createErrorResponse creates an error response message +func (p *HLTBPlugin) createErrorResponse(msg *model.Message, errorText string) []*model.MessageAction { + response := &model.Message{ + Text: fmt.Sprintf("āŒ %s", errorText), + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} +} diff --git a/internal/plugin/fun/hltb_test.go b/internal/plugin/fun/hltb_test.go new file mode 100644 index 0000000..62810e3 --- /dev/null +++ b/internal/plugin/fun/hltb_test.go @@ -0,0 +1,131 @@ +package fun + +import ( + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/testutil" +) + +func TestHLTBPlugin_OnMessage(t *testing.T) { + plugin := NewHLTB() + + tests := []struct { + name string + messageText string + shouldRespond bool + }{ + { + name: "responds to !hltb command", + messageText: "!hltb The Witcher 3", + shouldRespond: true, + }, + { + name: "ignores non-hltb messages", + messageText: "hello world", + shouldRespond: false, + }, + { + name: "ignores !hltb without game name", + messageText: "!hltb", + shouldRespond: false, + }, + { + name: "ignores !hltb with only spaces", + messageText: "!hltb ", + shouldRespond: false, + }, + { + name: "ignores similar but incorrect commands", + messageText: "hltb The Witcher 3", + shouldRespond: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &model.Message{ + Text: tt.messageText, + Chat: "test-chat", + Channel: &model.Channel{ID: 1}, + Author: "test-user", + } + + mockCache := &testutil.MockCache{} + actions := plugin.OnMessage(msg, make(map[string]interface{}), mockCache) + + if tt.shouldRespond && len(actions) == 0 { + t.Errorf("Expected plugin to respond to '%s', but it didn't", tt.messageText) + } + + if !tt.shouldRespond && len(actions) > 0 { + t.Errorf("Expected plugin to not respond to '%s', but it did", tt.messageText) + } + + // For messages that should respond, verify the response structure + if tt.shouldRespond && len(actions) > 0 { + action := actions[0] + if action.Type != model.ActionSendMessage { + t.Errorf("Expected ActionSendMessage, got %s", action.Type) + } + + if action.Message == nil { + t.Error("Expected action to have a message") + } + + if action.Message != nil && action.Message.ReplyTo != msg.ID { + t.Error("Expected response to reply to original message") + } + } + }) + } +} + +func TestHLTBPlugin_formatTime(t *testing.T) { + plugin := NewHLTB() + + tests := []struct { + seconds int + expected string + }{ + {0, "N/A"}, + {-1, "N/A"}, + {1800, "30 minutes"}, // 30 minutes + {3600, "1.0 hour"}, // 1 hour + {7200, "2.0 hours"}, // 2 hours + {10800, "3.0 hours"}, // 3 hours + {36000, "10.0 hours"}, // 10 hours + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := plugin.formatTime(tt.seconds) + if result != tt.expected { + t.Errorf("formatTime(%d) = %s, want %s", tt.seconds, result, tt.expected) + } + }) + } +} + +func TestHLTBPlugin_getFullImageURL(t *testing.T) { + plugin := NewHLTB() + + tests := []struct { + imagePath string + expected string + }{ + {"", ""}, + {"game.jpg", "https://howlongtobeat.com/games/game.jpg"}, + {"/game.jpg", "https://howlongtobeat.com/games/game.jpg"}, + {"folder/game.png", "https://howlongtobeat.com/games/folder/game.png"}, + } + + for _, tt := range tests { + t.Run(tt.imagePath, func(t *testing.T) { + result := plugin.getFullImageURL(tt.imagePath) + if result != tt.expected { + t.Errorf("getFullImageURL(%s) = %s, want %s", tt.imagePath, result, tt.expected) + } + }) + } +} diff --git a/internal/plugin/fun/loquito.go b/internal/plugin/fun/loquito.go index fef78bd..4b102f7 100644 --- a/internal/plugin/fun/loquito.go +++ b/internal/plugin/fun/loquito.go @@ -23,8 +23,13 @@ func NewLoquito() *LoquitoPlugin { } } +// GetHelp returns the plugin help text +func (p *LoquitoPlugin) GetHelp() string { + return "" +} + // OnMessage handles incoming messages -func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { return nil } diff --git a/internal/plugin/help/help.go b/internal/plugin/help/help.go new file mode 100644 index 0000000..4e6215a --- /dev/null +++ b/internal/plugin/help/help.go @@ -0,0 +1,166 @@ +package help + +import ( + "fmt" + "sort" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/db" + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" + "golang.org/x/exp/slog" +) + +// ChannelPluginGetter is an interface for getting channel plugins +type ChannelPluginGetter interface { + GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) + GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) +} + +// HelpPlugin provides help information about available commands +type HelpPlugin struct { + plugin.BasePlugin + db ChannelPluginGetter +} + +// New creates a new HelpPlugin instance +func New(db ChannelPluginGetter) *HelpPlugin { + return &HelpPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "utility.help", + Name: "Help", + Help: "Shows available commands when you type '!help'", + }, + db: db, + } +} + +// OnMessage handles incoming messages +func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { + // Check if message is the help command + if !strings.EqualFold(strings.TrimSpace(msg.Text), "!help") { + return nil + } + + // Get channel plugins from database using platform and platform channel ID + channelPlugins, err := p.db.GetChannelPluginsFromPlatformID(msg.Channel.Platform, msg.Channel.PlatformChannelID) + if err != nil && err != db.ErrNotFound { + slog.Error("Failed to get channel plugins", slog.Any("err", err)) + return []*model.MessageAction{} + } + + // If no plugins found, initialize empty slice + if err == db.ErrNotFound { + channelPlugins = []*model.ChannelPlugin{} + } + + // Get all available plugins + availablePlugins := plugin.GetAvailablePlugins() + + // Filter to only enabled plugins for this channel + enabledPlugins := make(map[string]model.Plugin) + for _, channelPlugin := range channelPlugins { + if channelPlugin.Enabled { + if availablePlugin, exists := availablePlugins[channelPlugin.PluginID]; exists { + enabledPlugins[channelPlugin.PluginID] = availablePlugin + } + } + } + + // If no plugins are enabled, return a message + if len(enabledPlugins) == 0 { + response := &model.Message{ + Text: "No plugins are currently enabled for this channel.", + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + Raw: map[string]interface{}{"parse_mode": "Markdown"}, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Group plugins by category + categories := map[string][]model.Plugin{ + "Development": {}, + "Fun and Entertainment": {}, + "Utility": {}, + "Security": {}, + "Social Media": {}, + "Other": {}, + } + + // Categorize plugins based on their ID prefix + for _, p := range enabledPlugins { + category := p.GetID() + switch { + case strings.HasPrefix(category, "dev."): + categories["Development"] = append(categories["Development"], p) + case strings.HasPrefix(category, "fun."): + categories["Fun and Entertainment"] = append(categories["Fun and Entertainment"], p) + case strings.HasPrefix(category, "util.") || strings.HasPrefix(category, "reminder.") || strings.HasPrefix(category, "utility."): + categories["Utility"] = append(categories["Utility"], p) + case strings.HasPrefix(category, "security."): + categories["Security"] = append(categories["Security"], p) + case strings.HasPrefix(category, "social."): + categories["Social Media"] = append(categories["Social Media"], p) + default: + categories["Other"] = append(categories["Other"], p) + } + } + + // Build the help message + var helpText strings.Builder + helpText.WriteString("šŸ¤– **Available Commands**\n\n") + + // Sort category names for consistent output + categoryOrder := []string{"Development", "Fun and Entertainment", "Utility", "Security", "Social Media", "Other"} + + for _, categoryName := range categoryOrder { + pluginList := categories[categoryName] + if len(pluginList) == 0 { + continue + } + + // Sort plugins within category by name + sort.Slice(pluginList, func(i, j int) bool { + return pluginList[i].GetName() < pluginList[j].GetName() + }) + + helpText.WriteString(fmt.Sprintf("**%s:**\n", categoryName)) + for _, p := range pluginList { + if p.GetHelp() == "" { + continue + } + helpText.WriteString(fmt.Sprintf("• **%s** - %s\n", p.GetName(), p.GetHelp())) + } + helpText.WriteString("\n") + } + + // Add footer + helpText.WriteString("_Use the specific commands or triggers mentioned above to interact with the bot._") + + response := &model.Message{ + Text: helpText.String(), + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + Raw: map[string]interface{}{"parse_mode": "Markdown"}, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } +} diff --git a/internal/plugin/help/help_test.go b/internal/plugin/help/help_test.go new file mode 100644 index 0000000..25d5376 --- /dev/null +++ b/internal/plugin/help/help_test.go @@ -0,0 +1,206 @@ +package help + +import ( + "strings" + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/db" + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// MockPlugin implements the Plugin interface for testing +type MockPlugin struct { + id string + name string + help string +} + +func (m *MockPlugin) GetID() string { return m.id } +func (m *MockPlugin) GetName() string { return m.name } +func (m *MockPlugin) GetHelp() string { return m.help } +func (m *MockPlugin) RequiresConfig() bool { + return false +} +func (m *MockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { + return nil +} + +// MockDatabase implements the ChannelPluginGetter interface for testing +type MockDatabase struct { + channelPlugins map[int64][]*model.ChannelPlugin + platformChannelPlugins map[string][]*model.ChannelPlugin // key: "platform:platformChannelID" +} + +func (m *MockDatabase) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) { + if plugins, exists := m.channelPlugins[channelID]; exists { + return plugins, nil + } + return nil, db.ErrNotFound +} + +func (m *MockDatabase) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) { + key := platform + ":" + platformChannelID + if plugins, exists := m.platformChannelPlugins[key]; exists { + return plugins, nil + } + return nil, db.ErrNotFound +} + +func TestHelpPlugin_OnMessage(t *testing.T) { + tests := []struct { + name string + messageText string + enabledPlugins map[string]*MockPlugin + expectResponse bool + expectNoPlugins bool + expectCategories []string + }{ + { + name: "responds to !help command", + messageText: "!help", + enabledPlugins: map[string]*MockPlugin{ + "dev.ping": { + id: "dev.ping", + name: "Ping", + help: "Responds to 'ping' with 'pong'", + }, + "fun.dice": { + id: "fun.dice", + name: "Dice Roller", + help: "Rolls dice when you type '!dice [formula]'", + }, + }, + expectResponse: true, + expectCategories: []string{"Development", "Fun and Entertainment"}, + }, + { + name: "ignores non-help messages", + messageText: "hello world", + enabledPlugins: map[string]*MockPlugin{}, + expectResponse: false, + }, + { + name: "ignores case variation", + messageText: "!HELP", + enabledPlugins: map[string]*MockPlugin{}, + expectResponse: true, + expectNoPlugins: true, + }, + { + name: "handles no enabled plugins", + messageText: "!help", + enabledPlugins: map[string]*MockPlugin{}, + expectResponse: true, + expectNoPlugins: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock database + mockDB := &MockDatabase{ + channelPlugins: make(map[int64][]*model.ChannelPlugin), + platformChannelPlugins: make(map[string][]*model.ChannelPlugin), + } + + // Setup channel plugins in mock database + var channelPluginList []*model.ChannelPlugin + pluginCounter := int64(1) + for pluginID := range tt.enabledPlugins { + channelPluginList = append(channelPluginList, &model.ChannelPlugin{ + ID: pluginCounter, + ChannelID: 1, + PluginID: pluginID, + Enabled: true, + Config: make(map[string]interface{}), + }) + pluginCounter++ + } + + // Set up both mapping approaches for the test + mockDB.channelPlugins[1] = channelPluginList + mockDB.platformChannelPlugins["test:test-channel"] = channelPluginList + + // Create help plugin + p := New(mockDB) + + // Create mock channel + channel := &model.Channel{ + ID: 1, + Platform: "test", + PlatformChannelID: "test-channel", + } + + // Create test message + msg := &model.Message{ + ID: "test-msg", + Text: tt.messageText, + Chat: "test-chat", + Channel: channel, + } + + // Mock the plugin registry + originalRegistry := plugin.GetAvailablePlugins() + + // Override the registry for this test + plugin.ClearRegistry() + for _, mockPlugin := range tt.enabledPlugins { + plugin.Register(mockPlugin) + } + + // Call OnMessage + actions := p.OnMessage(msg, map[string]interface{}{}, nil) + + // Restore original registry + plugin.ClearRegistry() + for _, p := range originalRegistry { + plugin.Register(p) + } + + if !tt.expectResponse { + if len(actions) != 0 { + t.Errorf("Expected no response, but got %d actions", len(actions)) + } + return + } + + if len(actions) != 1 { + t.Errorf("Expected 1 action, got %d", len(actions)) + return + } + + action := actions[0] + if action.Type != model.ActionSendMessage { + t.Errorf("Expected ActionSendMessage, got %v", action.Type) + return + } + + responseText := action.Message.Text + + if tt.expectNoPlugins { + if !strings.Contains(responseText, "No plugins are currently enabled") { + t.Errorf("Expected 'no plugins' message, got: %s", responseText) + } + return + } + + // Check that expected categories appear in response + for _, category := range tt.expectCategories { + if !strings.Contains(responseText, "**"+category+":**") { + t.Errorf("Expected category '%s' in response, got: %s", category, responseText) + } + } + + // Check that plugin names and help text appear + for _, mockPlugin := range tt.enabledPlugins { + if !strings.Contains(responseText, mockPlugin.GetName()) { + t.Errorf("Expected plugin name '%s' in response", mockPlugin.GetName()) + } + if !strings.Contains(responseText, mockPlugin.GetHelp()) { + t.Errorf("Expected plugin help '%s' in response", mockPlugin.GetHelp()) + } + } + }) + } +} diff --git a/internal/plugin/ping/ping.go b/internal/plugin/ping/ping.go index 3dacf6f..be0402c 100644 --- a/internal/plugin/ping/ping.go +++ b/internal/plugin/ping/ping.go @@ -24,7 +24,7 @@ func New() *PingPlugin { } // OnMessage handles incoming messages -func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { return nil } diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go index eb3789f..8f8413a 100644 --- a/internal/plugin/plugin.go +++ b/internal/plugin/plugin.go @@ -47,6 +47,26 @@ func GetAvailablePlugins() map[string]model.Plugin { return result } +// GetAvailablePluginIDs returns a slice of all registered plugin IDs +func GetAvailablePluginIDs() []string { + pluginsMu.RLock() + defer pluginsMu.RUnlock() + + result := make([]string, 0, len(plugins)) + for pluginID := range plugins { + result = append(result, pluginID) + } + + return result +} + +// ClearRegistry clears all registered plugins (for testing) +func ClearRegistry() { + pluginsMu.Lock() + defer pluginsMu.Unlock() + plugins = make(map[string]model.Plugin) +} + // BasePlugin provides a common base for plugins type BasePlugin struct { ID string @@ -76,6 +96,6 @@ func (p *BasePlugin) RequiresConfig() bool { } // OnMessage is the default implementation that does nothing -func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { return nil } diff --git a/internal/plugin/plugin_test.go b/internal/plugin/plugin_test.go new file mode 100644 index 0000000..0bfd207 --- /dev/null +++ b/internal/plugin/plugin_test.go @@ -0,0 +1,331 @@ +package plugin + +import ( + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +// Mock plugin for testing +type testPlugin struct { + BasePlugin +} + +func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: &model.Message{ + Text: "test response", + Chat: msg.Chat, + Channel: msg.Channel, + }, + }, + } +} + +func TestGetAvailablePluginIDs(t *testing.T) { + // Clear registry before test + ClearRegistry() + + // Register test plugins + testPlugin1 := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "test.plugin1", + Name: "Test Plugin 1", + }, + } + testPlugin2 := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "test.plugin2", + Name: "Test Plugin 2", + }, + } + + Register(testPlugin1) + Register(testPlugin2) + + // Test GetAvailablePluginIDs + pluginIDs := GetAvailablePluginIDs() + + if len(pluginIDs) != 2 { + t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs)) + } + + // Check that both plugin IDs are present + found1, found2 := false, false + for _, id := range pluginIDs { + if id == "test.plugin1" { + found1 = true + } + if id == "test.plugin2" { + found2 = true + } + } + + if !found1 { + t.Errorf("Expected to find test.plugin1 in plugin IDs") + } + if !found2 { + t.Errorf("Expected to find test.plugin2 in plugin IDs") + } +} + +func TestEnableAllPluginsProcessingLogic(t *testing.T) { + // Clear registry before test + ClearRegistry() + + // Register test plugins + testPlugin1 := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "ping", + Name: "Ping Plugin", + }, + } + testPlugin2 := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "echo", + Name: "Echo Plugin", + }, + } + testPlugin3 := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "help", + Name: "Help Plugin", + }, + } + + Register(testPlugin1) + Register(testPlugin2) + Register(testPlugin3) + + t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) { + // Create a channel with EnableAllPlugins = false and only some plugins enabled + channel := &model.Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: false, + Plugins: map[string]*model.ChannelPlugin{ + "ping": { + ID: 1, + ChannelID: 1, + PluginID: "ping", + Enabled: true, + Config: map[string]interface{}{"key": "value"}, + }, + "echo": { + ID: 2, + ChannelID: 1, + PluginID: "echo", + Enabled: false, // Disabled + Config: map[string]interface{}{}, + }, + // help plugin not configured + }, + } + + // Simulate the plugin processing logic from handleMessage + var pluginsToProcess []string + + if channel.EnableAllPlugins { + pluginsToProcess = GetAvailablePluginIDs() + } else { + for pluginID := range channel.Plugins { + if channel.HasEnabledPlugin(pluginID) { + pluginsToProcess = append(pluginsToProcess, pluginID) + } + } + } + + // Should only have "ping" since echo is disabled and help is not configured + if len(pluginsToProcess) != 1 { + t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess) + } + + if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" { + t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0]) + } + }) + + t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) { + // Create a channel with EnableAllPlugins = true + channel := &model.Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: map[string]*model.ChannelPlugin{ + "ping": { + ID: 1, + ChannelID: 1, + PluginID: "ping", + Enabled: true, + Config: map[string]interface{}{"key": "value"}, + }, + "echo": { + ID: 2, + ChannelID: 1, + PluginID: "echo", + Enabled: false, // Disabled, but should still be processed + Config: map[string]interface{}{}, + }, + // help plugin not configured, but should still be processed + }, + } + + // Simulate the plugin processing logic from handleMessage + var pluginsToProcess []string + + if channel.EnableAllPlugins { + pluginsToProcess = GetAvailablePluginIDs() + } else { + for pluginID := range channel.Plugins { + if channel.HasEnabledPlugin(pluginID) { + pluginsToProcess = append(pluginsToProcess, pluginID) + } + } + } + + // Should have all 3 registered plugins + if len(pluginsToProcess) != 3 { + t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess) + } + + // Check that all plugins are included + expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false} + for _, pluginID := range pluginsToProcess { + if _, exists := expectedPlugins[pluginID]; exists { + expectedPlugins[pluginID] = true + } else { + t.Errorf("Unexpected plugin in processing list: %s", pluginID) + } + } + + for pluginID, found := range expectedPlugins { + if !found { + t.Errorf("Expected plugin %s to be in processing list", pluginID) + } + } + }) + + t.Run("Plugin configuration handling", func(t *testing.T) { + // Test the configuration logic from handleMessage + channel := &model.Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "123456", + Enabled: true, + EnableAllPlugins: true, + Plugins: map[string]*model.ChannelPlugin{ + "ping": { + ID: 1, + ChannelID: 1, + PluginID: "ping", + Enabled: true, + Config: map[string]interface{}{"configured": "value"}, + }, + }, + } + + testCases := []struct { + pluginID string + expectedConfig map[string]interface{} + }{ + { + pluginID: "ping", + expectedConfig: map[string]interface{}{"configured": "value"}, + }, + { + pluginID: "echo", // Not explicitly configured + expectedConfig: map[string]interface{}{}, + }, + } + + for _, tc := range testCases { + // Simulate the config retrieval logic from handleMessage + var config map[string]interface{} + if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists { + config = channelPlugin.Config + } else { + config = make(map[string]interface{}) + } + + if len(config) != len(tc.expectedConfig) { + t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config)) + } + + for key, expectedValue := range tc.expectedConfig { + if actualValue, exists := config[key]; !exists || actualValue != expectedValue { + t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue) + } + } + } + }) +} + +func TestPluginRegistry(t *testing.T) { + // Clear registry before test + ClearRegistry() + + testPlugin := &testPlugin{ + BasePlugin: BasePlugin{ + ID: "test.registry", + Name: "Test Registry Plugin", + }, + } + + t.Run("Register and Get plugin", func(t *testing.T) { + Register(testPlugin) + + retrieved, err := Get("test.registry") + if err != nil { + t.Errorf("Failed to get registered plugin: %v", err) + } + + if retrieved.GetID() != "test.registry" { + t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID()) + } + }) + + t.Run("Get nonexistent plugin", func(t *testing.T) { + _, err := Get("nonexistent.plugin") + if err == nil { + t.Errorf("Expected error when getting nonexistent plugin, got nil") + } + + if err != model.ErrPluginNotFound { + t.Errorf("Expected ErrPluginNotFound, got %v", err) + } + }) + + t.Run("GetAvailablePlugins", func(t *testing.T) { + plugins := GetAvailablePlugins() + + if len(plugins) != 1 { + t.Errorf("Expected 1 plugin in registry, got %d", len(plugins)) + } + + if plugin, exists := plugins["test.registry"]; !exists { + t.Errorf("Expected to find test.registry in available plugins") + } else if plugin.GetID() != "test.registry" { + t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID()) + } + }) + + t.Run("ClearRegistry", func(t *testing.T) { + ClearRegistry() + + plugins := GetAvailablePlugins() + if len(plugins) != 0 { + t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins)) + } + + _, err := Get("test.registry") + if err == nil { + t.Errorf("Expected error when getting plugin after clearing registry, got nil") + } + }) +} diff --git a/internal/plugin/reminder/reminder.go b/internal/plugin/reminder/reminder.go index 029c8d9..bb21dbf 100644 --- a/internal/plugin/reminder/reminder.go +++ b/internal/plugin/reminder/reminder.go @@ -41,7 +41,7 @@ func New(creator ReminderCreator) *Reminder { } // OnMessage processes incoming messages -func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { // Only process replies to messages if msg.ReplyTo == "" { return nil diff --git a/internal/plugin/reminder/reminder_test.go b/internal/plugin/reminder/reminder_test.go index 8e611ce..f2c1d21 100644 --- a/internal/plugin/reminder/reminder_test.go +++ b/internal/plugin/reminder/reminder_test.go @@ -5,6 +5,7 @@ import ( "time" "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/testutil" ) // MockCreator is a mock implementation of ReminderCreator for testing @@ -142,7 +143,8 @@ func TestReminderOnMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { initialCount := len(creator.reminders) - actions := plugin.OnMessage(tt.message, nil) + mockCache := &testutil.MockCache{} + actions := plugin.OnMessage(tt.message, nil, mockCache) if tt.expectResponse && len(actions) == 0 { t.Errorf("Expected response action, but got none") diff --git a/internal/plugin/searchreplace/searchreplace.go b/internal/plugin/searchreplace/searchreplace.go index 876e880..d9cd04c 100644 --- a/internal/plugin/searchreplace/searchreplace.go +++ b/internal/plugin/searchreplace/searchreplace.go @@ -23,14 +23,14 @@ func New() *SearchReplacePlugin { BasePlugin: plugin.BasePlugin{ ID: "util.searchreplace", Name: "Search and Replace", - Help: "Reply to a message with a search and replace pattern (s/search/replace/[flags]) to create a modified message. " + + Help: "Reply to a message with a search and replace pattern (`s/search/replace/[flags]`) to create a modified message. " + "Supported flags: g (global), i (case insensitive)", }, } } // OnMessage handles incoming messages -func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { // Only process replies to messages if msg.ReplyTo == "" { return nil diff --git a/internal/plugin/searchreplace/searchreplace_test.go b/internal/plugin/searchreplace/searchreplace_test.go index 415610c..fa5cdf5 100644 --- a/internal/plugin/searchreplace/searchreplace_test.go +++ b/internal/plugin/searchreplace/searchreplace_test.go @@ -5,6 +5,7 @@ import ( "time" "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/testutil" ) func TestSearchReplace(t *testing.T) { @@ -84,7 +85,8 @@ func TestSearchReplace(t *testing.T) { } // Process message - actions := p.OnMessage(msg, nil) + mockCache := &testutil.MockCache{} + actions := p.OnMessage(msg, nil, mockCache) // Check results if tc.expectActions { diff --git a/internal/plugin/social/instagram.go b/internal/plugin/social/instagram.go index 0b4ff55..b423b45 100644 --- a/internal/plugin/social/instagram.go +++ b/internal/plugin/social/instagram.go @@ -18,20 +18,27 @@ type InstagramExpander struct { func NewInstagramExpander() *InstagramExpander { return &InstagramExpander{ BasePlugin: plugin.BasePlugin{ - ID: "social.instagram", - Name: "Instagram Link Expander", - Help: "Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters", + ID: "social.instagram", + Name: "Instagram Link Expander", + Help: "Automatically converts instagram.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: ddinstagram.com)", + ConfigRequired: true, }, } } // OnMessage handles incoming messages -func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { // Skip empty messages if strings.TrimSpace(msg.Text) == "" { return nil } + // Get replacement domain from config, default to ddinstagram.com + replacementDomain := "ddinstagram.com" + if domain, ok := config["domain"].(string); ok && domain != "" { + replacementDomain = domain + } + // Regex to match instagram.com links // Match both http://instagram.com and https://instagram.com formats // Also match www.instagram.com @@ -42,7 +49,7 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte return nil } - // Replace instagram.com with ddinstagram.com in the message and clean query parameters + // Replace instagram.com with configured domain in the message and clean query parameters transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { // Parse the URL parsedURL, err := url.Parse(link) @@ -51,13 +58,13 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte return link } - // Ensure we don't change links that already come from ddinstagram.com + // Ensure we don't change links that already come from the replacement domain if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" { return link } - // Change the host - parsedURL.Host = "d.ddinstagram.com" + // Change the host to the configured domain + parsedURL.Host = replacementDomain // Remove query parameters parsedURL.RawQuery = "" diff --git a/internal/plugin/social/twitter.go b/internal/plugin/social/twitter.go index 865f421..f2c6cc9 100644 --- a/internal/plugin/social/twitter.go +++ b/internal/plugin/social/twitter.go @@ -18,20 +18,27 @@ type TwitterExpander struct { func NewTwitterExpander() *TwitterExpander { return &TwitterExpander{ BasePlugin: plugin.BasePlugin{ - ID: "social.twitter", - Name: "Twitter Link Expander", - Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters", + ID: "social.twitter", + Name: "Twitter Link Expander", + Help: "Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: fxtwitter.com)", + ConfigRequired: true, }, } } // OnMessage handles incoming messages -func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { +func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { // Skip empty messages if strings.TrimSpace(msg.Text) == "" { return nil } + // Get replacement domain from config, default to fxtwitter.com + replacementDomain := "fxtwitter.com" + if domain, ok := config["domain"].(string); ok && domain != "" { + replacementDomain = domain + } + // Regex to match twitter.com links // Match both http://twitter.com and https://twitter.com formats // Also match www.twitter.com @@ -42,22 +49,17 @@ func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interf return nil } - // Replace twitter.com with fxtwitter.com in the message and clean query parameters + // Replace twitter.com/x.com with configured domain in the message and clean query parameters transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { // Parse the URL parsedURL, err := url.Parse(link) if err != nil { - // If parsing fails, just do the simple replacement - link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1) - link = strings.Replace(link, "x.com", "fxtwitter.com", 1) return link } - // Change the host - if strings.Contains(parsedURL.Host, "twitter.com") { - parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1) - } else if strings.Contains(parsedURL.Host, "x.com") { - parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1) + // Change the host to the configured domain + if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") { + parsedURL.Host = replacementDomain } // Remove query parameters diff --git a/internal/plugin/social/twitter_test.go b/internal/plugin/social/twitter_test.go new file mode 100644 index 0000000..c0e1681 --- /dev/null +++ b/internal/plugin/social/twitter_test.go @@ -0,0 +1,120 @@ +package social + +import ( + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +func TestTwitterExpander_OnMessage(t *testing.T) { + plugin := NewTwitterExpander() + + tests := []struct { + name string + input string + config map[string]interface{} + expected string + hasReply bool + }{ + { + name: "Twitter URL with default domain", + input: "https://twitter.com/user/status/123456789", + config: map[string]interface{}{}, + expected: "https://fxtwitter.com/user/status/123456789", + hasReply: true, + }, + { + name: "X.com URL with custom domain", + input: "https://x.com/elonmusk/status/987654321", + config: map[string]interface{}{"domain": "vxtwitter.com"}, + expected: "https://vxtwitter.com/elonmusk/status/987654321", + hasReply: true, + }, + { + name: "Twitter URL with tracking parameters", + input: "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20", + config: map[string]interface{}{}, + expected: "https://fxtwitter.com/openai/status/555", + hasReply: true, + }, + { + name: "www.twitter.com URL", + input: "https://www.twitter.com/user/status/789", + config: map[string]interface{}{"domain": "nitter.net"}, + expected: "https://nitter.net/user/status/789", + hasReply: true, + }, + { + name: "Mixed text with Twitter URL", + input: "Check this out: https://twitter.com/user/status/123 amazing!", + config: map[string]interface{}{}, + expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!", + hasReply: true, + }, + { + name: "No Twitter URLs", + input: "Just some regular text https://youtube.com/watch?v=abc", + config: map[string]interface{}{}, + expected: "", + hasReply: false, + }, + { + name: "Empty message", + input: "", + config: map[string]interface{}{}, + expected: "", + hasReply: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &model.Message{ + ID: "test_msg", + Text: tt.input, + Chat: "test_chat", + Channel: &model.Channel{ + ID: 1, + Platform: "telegram", + PlatformChannelID: "test_chat", + }, + } + + actions := plugin.OnMessage(msg, tt.config, nil) + + if !tt.hasReply { + if len(actions) != 0 { + t.Errorf("Expected no actions, got %d", len(actions)) + } + return + } + + if len(actions) != 1 { + t.Errorf("Expected 1 action, got %d", len(actions)) + return + } + + action := actions[0] + if action.Type != model.ActionSendMessage { + t.Errorf("Expected ActionSendMessage, got %s", action.Type) + } + + if action.Message == nil { + t.Error("Expected message in action, got nil") + return + } + + if action.Message.Text != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text) + } + + if action.Message.ReplyTo != msg.ID { + t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo) + } + + if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" { + t.Error("Expected parse_mode to be empty string to disable markdown parsing") + } + }) + } +} diff --git a/internal/testutil/mock_cache.go b/internal/testutil/mock_cache.go new file mode 100644 index 0000000..7ecb878 --- /dev/null +++ b/internal/testutil/mock_cache.go @@ -0,0 +1,29 @@ +package testutil + +import ( + "errors" + "time" +) + +// MockCache implements the CacheInterface for testing +type MockCache struct{} + +func (m *MockCache) Get(key string, destination interface{}) error { + return errors.New("cache miss") // Always return cache miss for tests +} + +func (m *MockCache) Set(key string, value interface{}, expiration *time.Time) error { + return nil // Always succeed for tests +} + +func (m *MockCache) SetWithTTL(key string, value interface{}, ttl time.Duration) error { + return nil // Always succeed for tests +} + +func (m *MockCache) Delete(key string) error { + return nil // Always succeed for tests +} + +func (m *MockCache) Exists(key string) (bool, error) { + return false, nil // Always return false for tests +}