diff --git a/.woodpecker/ci.yml b/.woodpecker/ci.yml index 4353088..5b32d48 100644 --- a/.woodpecker/ci.yml +++ b/.woodpecker/ci.yml @@ -3,7 +3,7 @@ when: - push - pull_request branch: - - main + - master steps: format: diff --git a/.woodpecker/release.yml b/.woodpecker/release.yml index 3630566..39dbf65 100644 --- a/.woodpecker/release.yml +++ b/.woodpecker/release.yml @@ -1,6 +1,6 @@ when: - event: tag - branch: main + branch: master steps: - name: Release diff --git a/README.md b/README.md index 214afa6..920d087 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,6 @@ # Butter Robot -| Stable | Master | -| --- | --- | -| ![Build stable tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20stable%20tag%20docker%20image/badge.svg?branch=stable) | ![Build latest tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20latest%20tag%20docker%20image/badge.svg?branch=master) | -| ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=stable) | ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=master) | +![Status badge](https://woodpecker.local.fmartingr.dev/api/badges/5/status.svg) Go framework to create bots for several platforms. @@ -13,7 +10,7 @@ Go framework to create bots for several platforms. ## Features -- Support for multiple chat platforms (Slack, Telegram) +- Support for multiple chat platforms (Slack (untested!), Telegram) - Plugin system for easy extension - Admin interface for managing channels and plugins - Message queue for asynchronous processing diff --git a/docs/creating-a-plugin.md b/docs/creating-a-plugin.md index 469491a..b8e4a78 100644 --- a/docs/creating-a-plugin.md +++ b/docs/creating-a-plugin.md @@ -7,6 +7,7 @@ ButterRobot organizes plugins into different categories: - **Development**: Utility plugins like `ping` - **Fun**: Entertainment plugins like dice rolling, coin flipping - **Social**: Social media related plugins like URL transformers/expanders +- **Security**: Moderation and protection features like domain blocking When creating a new plugin, consider which category it fits into and place it in the appropriate directory. @@ -59,6 +60,91 @@ func (p *MarcoPlugin) OnMessage(msg *model.Message, config map[string]interface{ } ``` +### Configuration-Enabled Plugin + +This plugin requires configuration to be set in the admin interface. It demonstrates how to create plugins that need channel-specific configuration: + +```go +package security + +import ( + "fmt" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains +type DomainBlockPlugin struct { + plugin.BasePlugin +} + +// New creates a new DomainBlockPlugin instance +func New() *DomainBlockPlugin { + return &DomainBlockPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "security.domainblock", + Name: "Domain Blocker", + Help: "Blocks messages containing links from configured domains", + ConfigRequired: true, // Mark this plugin as requiring configuration + }, + } +} + +// OnMessage processes incoming messages +func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { + // Get blocked domains from config + blockedDomainsStr, ok := config["blocked_domains"].(string) + if !ok || blockedDomainsStr == "" { + return nil // No blocked domains configured + } + + // Split and clean blocked domains + blockedDomains := strings.Split(blockedDomainsStr, ",") + for i, domain := range blockedDomains { + blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + } + + // Extract domains from message + urlRegex := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) + matches := urlRegex.FindAllStringSubmatch(msg.Text, -1) + + // Check if any extracted domains are blocked + for _, match := range matches { + if len(match) < 2 { + continue + } + + domain := strings.ToLower(match[1]) + + for _, blockedDomain := range blockedDomains { + if blockedDomain == "" { + continue + } + + if strings.HasSuffix(domain, blockedDomain) || domain == blockedDomain { + // Domain is blocked, create warning message + response := &model.Message{ + Text: fmt.Sprintf("⚠️ Message contained a link to blocked domain: %s", blockedDomain), + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + return []*model.Message{response} + } + } + } + + return nil +} + +func init() { + plugin.Register(New()) +} +``` + ### Advanced Example: URL Transformer This more complex plugin transforms URLs, useful for improving media embedding in chat platforms: @@ -143,6 +229,36 @@ func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interf } ``` +## Enabling Configuration for Plugins + +To indicate that your plugin requires configuration: + +1. Set `ConfigRequired: true` in the BasePlugin struct: + ```go + BasePlugin: plugin.BasePlugin{ + ID: "myplugin.id", + Name: "Plugin Name", + Help: "Help text", + ConfigRequired: true, + }, + ``` + +2. Access the configuration in the OnMessage method: + ```go + func (p *MyPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { + // Extract configuration values + configValue, ok := config["some_config_key"].(string) + if !ok || configValue == "" { + // Handle missing or empty configuration + return nil + } + + // Use the configuration... + } + ``` + +3. The admin interface will show a "Configure" button for plugins that require configuration. + ## Registering Plugins To use the plugin, register it in your application: @@ -161,3 +277,11 @@ func (a *App) Run() error { // ... } ``` + +Alternatively, you can register your plugin in its init() function: + +```go +func init() { + plugin.Register(New()) +} +``` diff --git a/docs/plugins.md b/docs/plugins.md index 84578e5..25df16c 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -14,6 +14,10 @@ - Remind Me: Reply to a message with `!remindme ` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time. +### Security + +- Domain Blocker: Blocks messages containing links from specified domains. Configure it per channel with a comma-separated list of domains to block. When a message contains a link matching any of the blocked domains, the bot will notify that the message contained a blocked domain. This plugin requires configuration through the admin interface. + ### Social Media - Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms. diff --git a/internal/admin/admin.go b/internal/admin/admin.go index c2a78ca..2b41820 100644 --- a/internal/admin/admin.go +++ b/internal/admin/admin.go @@ -98,6 +98,7 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin { "channel_detail.html", "plugin_list.html", "channel_plugins_list.html", + "channel_plugin_config.html", } for _, tf := range templateFiles { @@ -143,6 +144,7 @@ func (a *Admin) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/admin/channels", a.handleChannelList) mux.HandleFunc("/admin/channels/", a.handleChannelDetail) mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList) + mux.HandleFunc("/admin/channelplugins/config/", a.handleChannelPluginConfig) mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete) } @@ -194,7 +196,7 @@ func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string, } // Map internal categories to Bootstrap alert classes - alertClass := category + var alertClass string switch category { case "success": alertClass = "success" @@ -249,17 +251,6 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag return messages } -// requireLogin middleware checks if the user is logged in -func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if !a.isLoggedIn(r) { - http.Redirect(w, r, "/admin/login", http.StatusSeeOther) - return - } - next(w, r) - } -} - // render renders a template with the given data func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) { // Add current user data @@ -334,7 +325,10 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) { // Set session expiration session.Options.MaxAge = 3600 * 24 * 7 // 1 week - session.Save(r, w) + err = session.Save(r, w) + if err != nil { + fmt.Printf("Error saving session: %v\n", err) + } a.addFlash(w, r, "You were logged in", "success") @@ -636,6 +630,96 @@ func (a *Admin) handleChannelPluginList(w http.ResponseWriter, r *http.Request) }) } +// handleChannelPluginConfig handles the channel plugin configuration route +func (a *Admin) handleChannelPluginConfig(w http.ResponseWriter, r *http.Request) { + // Check if user is logged in + if !a.isLoggedIn(r) { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + // Extract channel plugin ID from path + path := r.URL.Path + channelPluginID := strings.TrimPrefix(path, "/admin/channelplugins/config/") + + // Convert channel plugin ID to int64 + id, err := strconv.ParseInt(channelPluginID, 10, 64) + if err != nil { + http.Error(w, "Invalid channel plugin ID", http.StatusBadRequest) + return + } + + // Get the channel plugin + channelPlugin, err := a.db.GetChannelPluginByID(id) + if err != nil { + http.Error(w, "Channel plugin not found", http.StatusNotFound) + return + } + + // Get the plugin + p, err := plugin.Get(channelPlugin.PluginID) + if err != nil { + http.Error(w, "Plugin not found", http.StatusNotFound) + return + } + + // Handle form submission + if r.Method == http.MethodPost { + // Parse form + if err := r.ParseForm(); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Create config map from form values + config := make(map[string]interface{}) + + // Process form values based on plugin type + if channelPlugin.PluginID == "security.domainblock" { + // Get blocked domains from form + blockedDomains := r.FormValue("blocked_domains") + config["blocked_domains"] = blockedDomains + } else { + // Generic handling for other plugins + for key, values := range r.Form { + if key == "form_submitted" { + continue + } + if len(values) == 1 { + config[key] = values[0] + } else { + config[key] = values + } + } + } + + // Update plugin configuration + if err := a.db.UpdateChannelPluginConfig(id, config); err != nil { + http.Error(w, "Failed to update plugin configuration", http.StatusInternalServerError) + return + } + + // Get the channel to redirect back to the channel detail page + channel, err := a.db.GetChannelByID(channelPlugin.ChannelID) + if err != nil { + a.addFlash(w, r, "Plugin configuration updated", "success") + http.Redirect(w, r, "/admin/channelplugins", http.StatusSeeOther) + return + } + + a.addFlash(w, r, "Plugin configuration updated", "success") + http.Redirect(w, r, fmt.Sprintf("/admin/channels/%d", channel.ID), http.StatusSeeOther) + return + } + + // Render template + a.render(w, r, "channel_plugin_config.html", TemplateData{ + Title: "Configure Plugin: " + p.GetName(), + ChannelPlugin: channelPlugin, + Plugins: map[string]model.Plugin{channelPlugin.PluginID: p}, + }) +} + // handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) { // Check if user is logged in diff --git a/internal/admin/templates/channel_detail.html b/internal/admin/templates/channel_detail.html index 764d7b1..78909df 100644 --- a/internal/admin/templates/channel_detail.html +++ b/internal/admin/templates/channel_detail.html @@ -68,6 +68,10 @@ {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} + {{$plugin := index $.Plugins $pluginID}} + {{if $plugin.RequiresConfig}} + Configure + {{end}}
diff --git a/internal/admin/templates/channel_plugin_config.html b/internal/admin/templates/channel_plugin_config.html new file mode 100644 index 0000000..decf1a2 --- /dev/null +++ b/internal/admin/templates/channel_plugin_config.html @@ -0,0 +1,37 @@ +{{define "content"}} +
+
+
+
+

Configure Plugin: {{(index .Plugins .ChannelPlugin.PluginID).GetName}}

+
+
+ + + {{if eq .ChannelPlugin.PluginID "security.domainblock"}} +
+ + +
+ Enter comma-separated list of domains to block (e.g., example.com, evil.org). + Messages containing links to these domains will be blocked. +
+
+ {{else}} +
+ This plugin doesn't have specific configuration fields implemented yet. +
+ {{end}} + + + +
+
+
+
+{{end}} diff --git a/internal/admin/templates/channel_plugins_list.html b/internal/admin/templates/channel_plugins_list.html index b57c60e..485150b 100644 --- a/internal/admin/templates/channel_plugins_list.html +++ b/internal/admin/templates/channel_plugins_list.html @@ -38,6 +38,10 @@ {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} + {{$plugin := index $.Plugins $pluginID}} + {{if $plugin.ConfigRequired}} + Configure + {{end}}
@@ -90,4 +94,4 @@ -{{end}} \ No newline at end of file +{{end}} diff --git a/internal/app/app.go b/internal/app/app.go index 1d878ab..becd5ea 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -20,9 +20,11 @@ import ( "git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/platform" "git.nakama.town/fmartingr/butterrobot/internal/plugin" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock" "git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" "git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" "git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace" "git.nakama.town/fmartingr/butterrobot/internal/plugin/social" "git.nakama.town/fmartingr/butterrobot/internal/queue" ) @@ -87,10 +89,9 @@ func (a *App) Run() error { plugin.Register(fun.NewLoquito()) plugin.Register(social.NewTwitterExpander()) plugin.Register(social.NewInstagramExpander()) - - // Register reminder plugin - reminderPlugin := reminder.New(a.db) - plugin.Register(reminderPlugin) + plugin.Register(reminder.New(a.db)) + plugin.Register(domainblock.New()) + plugin.Register(searchreplace.New()) // Initialize routes a.initializeRoutes() @@ -152,7 +153,9 @@ func (a *App) initializeRoutes() { a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{}) + if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } }) // Platform webhook endpoints @@ -175,7 +178,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { if _, err := platform.Get(platformName); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}) + if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } return } @@ -184,7 +189,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}) + if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } return } @@ -200,7 +207,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { // Respond with success w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{}) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } } // extractPlatformName extracts the platform name from the URL path @@ -296,19 +305,39 @@ func (a *App) handleMessage(item queue.Item) { continue } - // Process message - responses := p.OnMessage(message, channelPlugin.Config) + // Process message and get actions + actions := p.OnMessage(message, channelPlugin.Config) - // Send responses + // Get platform for processing actions platform, err := platform.Get(item.Platform) if err != nil { a.logger.Error("Error getting platform", "error", err) continue } - for _, response := range responses { - if err := platform.SendMessage(response); err != nil { - a.logger.Error("Error sending message", "error", err) + // Process each action + for _, action := range actions { + switch action.Type { + case model.ActionSendMessage: + // Send a message + if action.Message != nil { + if err := platform.SendMessage(action.Message); err != nil { + a.logger.Error("Error sending message", "error", err) + } + } else { + a.logger.Error("Send message action with nil message") + } + + case model.ActionDeleteMessage: + // Delete a message using direct DeleteMessage call + if err := platform.DeleteMessage(action.Chat, action.MessageID); err != nil { + a.logger.Error("Error deleting message", "error", err, "message_id", action.MessageID) + } else { + a.logger.Info("Message deleted", "message_id", action.MessageID) + } + + default: + a.logger.Error("Unknown action type", "type", action.Type) } } } diff --git a/internal/db/db.go b/internal/db/db.go index b71b543..0da285e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -234,7 +234,11 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var plugins []*model.ChannelPlugin @@ -382,6 +386,24 @@ func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error { return err } +// UpdateChannelPluginConfig updates a channel plugin's configuration +func (d *Database) UpdateChannelPluginConfig(id int64, config map[string]interface{}) error { + // Convert config to JSON + configJSON, err := json.Marshal(config) + if err != nil { + return err + } + + query := ` + UPDATE channel_plugin + SET config = ? + WHERE id = ? + ` + + _, err = d.db.Exec(query, string(configJSON), id) + return err +} + // DeleteChannelPlugin deletes a channel plugin func (d *Database) DeleteChannelPlugin(id int64) error { query := ` @@ -415,7 +437,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var channels []*model.Channel @@ -454,10 +480,9 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { continue // Skip this channel if plugins can't be retrieved } - if plugins != nil { - for _, plugin := range plugins { - channel.Plugins[plugin.PluginID] = plugin - } + // Add plugins to channel + for _, plugin := range plugins { + channel.Plugins[plugin.PluginID] = plugin } channels = append(channels, channel) @@ -646,7 +671,11 @@ func (d *Database) GetPendingReminders() ([]*model.Reminder, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var reminders []*model.Reminder diff --git a/internal/migration/migration.go b/internal/migration/migration.go index dec4ff5..63da5d8 100644 --- a/internal/migration/migration.go +++ b/internal/migration/migration.go @@ -49,7 +49,11 @@ func GetAppliedMigrations(db *sql.DB) ([]int, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var versions []int for rows.Next() { @@ -128,7 +132,9 @@ func Migrate(db *sql.DB) error { // Apply the migration if err := migration.Up(db); err != nil { - tx.Rollback() + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } return fmt.Errorf("failed to apply migration %d: %w", version, err) } @@ -137,7 +143,9 @@ func Migrate(db *sql.DB) error { "INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", version, time.Now(), ); err != nil { - tx.Rollback() + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } return fmt.Errorf("failed to mark migration %d as applied: %w", version, err) } @@ -188,13 +196,17 @@ func MigrateDown(db *sql.DB, targetVersion int) error { // Apply the down migration if err := migration.Down(db); err != nil { - tx.Rollback() + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } return fmt.Errorf("failed to roll back migration %d: %w", version, err) } // Remove from applied list if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil { - tx.Rollback() + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err) } diff --git a/internal/model/message.go b/internal/model/message.go index e6f86f6..26ec5da 100644 --- a/internal/model/message.go +++ b/internal/model/message.go @@ -4,6 +4,26 @@ import ( "time" ) +// ActionType defines the type of action to perform +type ActionType string + +const ( + // ActionSendMessage is for sending a message to the chat + ActionSendMessage ActionType = "send_message" + // ActionDeleteMessage is for deleting a message from the chat + ActionDeleteMessage ActionType = "delete_message" +) + +// MessageAction represents an action to be performed on the platform +type MessageAction struct { + Type ActionType + Message *Message // For send_message + MessageID string // For delete_message + Chat string // Chat where the action happens + Channel *Channel // Channel reference + Raw map[string]interface{} // Additional data for the action +} + // Message represents a chat message type Message struct { Text string @@ -75,7 +95,7 @@ type ChannelPlugin struct { ChannelID int64 PluginID string Enabled bool - Config map[string]interface{} + Config map[string]any } // User represents an admin user diff --git a/internal/model/platform.go b/internal/model/platform.go index 01318eb..7d49ad3 100644 --- a/internal/model/platform.go +++ b/internal/model/platform.go @@ -43,4 +43,7 @@ type Platform interface { // SendMessage sends a message through the platform SendMessage(msg *Message) error + + // DeleteMessage deletes a message from the platform + DeleteMessage(channel string, messageID string) error } diff --git a/internal/model/plugin.go b/internal/model/plugin.go index 9f2b34a..03e4f96 100644 --- a/internal/model/plugin.go +++ b/internal/model/plugin.go @@ -23,6 +23,6 @@ type Plugin interface { // RequiresConfig indicates if the plugin requires configuration RequiresConfig() bool - // OnMessage processes an incoming message and returns response messages - OnMessage(msg *Message, config map[string]interface{}) []*Message + // OnMessage processes an incoming message and returns platform actions + OnMessage(msg *Message, config map[string]interface{}) []*MessageAction } diff --git a/internal/platform/slack/slack.go b/internal/platform/slack/slack.go index 3683ada..2ca7bef 100644 --- a/internal/platform/slack/slack.go +++ b/internal/platform/slack/slack.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "net/http" "strings" "time" @@ -37,11 +37,15 @@ func (s *SlackPlatform) Init(_ *config.Config) error { // ParseIncomingMessage parses an incoming Slack message func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) { // Read request body - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return nil, err } - defer r.Body.Close() + defer func() { + if err := r.Body.Close(); err != nil { + fmt.Printf("Error closing request body: %v\n", err) + } + }() // Parse JSON var requestData map[string]interface{} @@ -163,6 +167,12 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { return errors.New("bot token not configured") } + // Check for delete message action + if msg.Raw != nil && msg.Raw["action"] == "delete" { + // This is a request to delete a message + return s.deleteMessage(msg) + } + // Prepare payload payload := map[string]interface{}{ "channel": msg.Chat, @@ -194,7 +204,11 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { if err != nil { return err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() // Check response if resp.StatusCode != http.StatusOK { @@ -204,6 +218,63 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { return nil } +// DeleteMessage deletes a message on Slack +func (s *SlackPlatform) DeleteMessage(channel string, messageID string) error { + // Prepare payload for chat.delete API + payload := map[string]interface{}{ + "channel": channel, + "ts": messageID, // In Slack, the ts (timestamp) is the message ID + } + + // Convert payload to JSON + data, err := json.Marshal(payload) + if err != nil { + return err + } + + // Send HTTP request to chat.delete endpoint + req, err := http.NewRequest("POST", "https://slack.com/api/chat.delete", strings.NewReader(string(data))) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.config.BotOAuthAccessToken)) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() + + // Check response + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("slack API error: %d - %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// deleteMessage is a legacy method that uses the Raw message approach +func (s *SlackPlatform) deleteMessage(msg *model.Message) error { + // Get message ID to delete + messageID, ok := msg.Raw["message_id"] + if !ok { + return fmt.Errorf("no message ID provided for deletion") + } + + // Convert to string if needed + messageIDStr := fmt.Sprintf("%v", messageID) + + return s.DeleteMessage(msg.Chat, messageIDStr) +} + // Helper function to parse int64 func parseInt64(s string) (int64, error) { var n int64 diff --git a/internal/platform/telegram/telegram.go b/internal/platform/telegram/telegram.go index 6c9a2b3..8da4995 100644 --- a/internal/platform/telegram/telegram.go +++ b/internal/platform/telegram/telegram.go @@ -62,7 +62,11 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error { t.log.Error("Failed to set webhook", "error", err) return fmt.Errorf("failed to set webhook: %w", err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) @@ -85,7 +89,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message t.log.Error("Failed to read request body", "error", err) return nil, err } - defer r.Body.Close() + defer func() { + if err := r.Body.Close(); err != nil { + t.log.Error("Error closing request body", "error", err) + } + }() // Parse JSON var update struct { @@ -209,6 +217,13 @@ func (t *TelegramPlatform) ParseChannelFromMessage(body []byte) (map[string]any, // SendMessage sends a message to Telegram func (t *TelegramPlatform) SendMessage(msg *model.Message) error { + // Check for delete message action (legacy method) + if msg.Raw != nil && msg.Raw["action"] == "delete" { + // This is a request to delete a message using the legacy method + return t.deleteMessage(msg) + } + + // Regular message sending // Convert chat ID to int64 chatID, err := strconv.ParseInt(msg.Chat, 10, 64) if err != nil { @@ -251,7 +266,11 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { t.log.Error("Failed to send message", "error", err) return err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() // Check response if resp.StatusCode != http.StatusOK { @@ -264,3 +283,88 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { t.log.Debug("Message sent successfully") return nil } + +// DeleteMessage deletes a message on Telegram +func (t *TelegramPlatform) DeleteMessage(channel string, messageID string) error { + // Convert chat ID to int64 + chatID, err := strconv.ParseInt(channel, 10, 64) + if err != nil { + t.log.Error("Invalid chat ID for message deletion", "chat_id", channel, "error", err) + return err + } + + // Convert message ID to integer + msgID, err := strconv.Atoi(messageID) + if err != nil { + t.log.Error("Invalid message ID for deletion", "message_id", messageID, "error", err) + return err + } + + // Prepare payload for deleteMessage API + payload := map[string]interface{}{ + "chat_id": chatID, + "message_id": msgID, + } + + t.log.Debug("Deleting message on Telegram", "chat_id", chatID, "message_id", msgID) + + // Convert payload to JSON + data, err := json.Marshal(payload) + if err != nil { + t.log.Error("Failed to marshal delete message payload", "error", err) + return err + } + + // Send HTTP request to deleteMessage endpoint + resp, err := http.Post( + t.apiURL+"/deleteMessage", + "application/json", + bytes.NewBuffer(data), + ) + if err != nil { + t.log.Error("Failed to delete message", "error", err) + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() + + // Check response + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + errMsg := string(bodyBytes) + t.log.Error("Telegram API error when deleting message", "status", resp.StatusCode, "response", errMsg) + return fmt.Errorf("telegram API error when deleting message: %d - %s", resp.StatusCode, errMsg) + } + + t.log.Debug("Message deleted successfully") + return nil +} + +// deleteMessage is a legacy method that uses the Raw message approach +func (t *TelegramPlatform) deleteMessage(msg *model.Message) error { + // Get message ID to delete + messageIDInterface, ok := msg.Raw["message_id"] + if !ok { + t.log.Error("No message ID provided for deletion") + return fmt.Errorf("no message ID provided for deletion") + } + + // Convert message ID to string + var messageIDStr string + switch v := messageIDInterface.(type) { + case string: + messageIDStr = v + case int: + messageIDStr = strconv.Itoa(v) + case float64: + messageIDStr = strconv.Itoa(int(v)) + default: + t.log.Error("Invalid message ID type for deletion", "type", fmt.Sprintf("%T", messageIDInterface)) + return fmt.Errorf("invalid message ID type for deletion") + } + + return t.DeleteMessage(msg.Chat, messageIDStr) +} diff --git a/internal/plugin/domainblock/domainblock.go b/internal/plugin/domainblock/domainblock.go new file mode 100644 index 0000000..5a44c49 --- /dev/null +++ b/internal/plugin/domainblock/domainblock.go @@ -0,0 +1,132 @@ +package domainblock + +import ( + "fmt" + "net/url" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains +type DomainBlockPlugin struct { + plugin.BasePlugin +} + +// Debug helper to check if RequiresConfig is working +func (p *DomainBlockPlugin) RequiresConfig() bool { + return true +} + +// New creates a new DomainBlockPlugin instance +func New() *DomainBlockPlugin { + return &DomainBlockPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "security.domainblock", + Name: "Domain Blocker", + Help: "Blocks messages containing links from configured domains", + ConfigRequired: true, + }, + } +} + +// extractDomains extracts domains from a message text +func extractDomains(text string) []string { + // URL regex pattern + urlPattern := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) + matches := urlPattern.FindAllStringSubmatch(text, -1) + + domains := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + + // Try to parse the URL to extract the domain + urlStr := match[0] + parsedURL, err := url.Parse(urlStr) + if err != nil { + continue + } + + // Extract the domain (host) from the URL + domain := parsedURL.Host + // Remove port if present + if i := strings.IndexByte(domain, ':'); i >= 0 { + domain = domain[:i] + } + + domains = append(domains, strings.ToLower(domain)) + } + + return domains +} + +// OnMessage processes incoming messages +func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Skip messages from bots + if msg.FromBot { + return nil + } + + // Get blocked domains from config + blockedDomainsStr, ok := config["blocked_domains"].(string) + if !ok || blockedDomainsStr == "" { + return nil // No blocked domains configured + } + + // Split and clean blocked domains + blockedDomains := strings.Split(blockedDomainsStr, ",") + for i, domain := range blockedDomains { + blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + } + + // Extract domains from message + messageDomains := extractDomains(msg.Text) + if len(messageDomains) == 0 { + return nil // No domains in message + } + + // Check if any domains in the message are blocked + for _, msgDomain := range messageDomains { + for _, blockedDomain := range blockedDomains { + if blockedDomain == "" { + continue + } + + if strings.HasSuffix(msgDomain, blockedDomain) || msgDomain == blockedDomain { + // Domain is blocked, create actions + + // 1. Create a delete message action + deleteAction := &model.MessageAction{ + Type: model.ActionDeleteMessage, + MessageID: msg.ID, + Chat: msg.Chat, + Channel: msg.Channel, + } + + // 2. Create a notification message action + notificationMsg := &model.Message{ + Text: fmt.Sprintf("I don't like links from %s 🙈", blockedDomain), + Chat: msg.Chat, + Channel: msg.Channel, + } + + sendAction := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: notificationMsg, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{deleteAction, sendAction} + } + } + } + + return nil +} + +// Plugin is registered in app.go, not using init() diff --git a/internal/plugin/domainblock/domainblock_test.go b/internal/plugin/domainblock/domainblock_test.go new file mode 100644 index 0000000..1d65964 --- /dev/null +++ b/internal/plugin/domainblock/domainblock_test.go @@ -0,0 +1,140 @@ +package domainblock + +import ( + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +func TestExtractDomains(t *testing.T) { + tests := []struct { + name string + text string + expected []string + }{ + { + name: "No URLs", + text: "Hello, world!", + expected: []string{}, + }, + { + name: "Single URL", + text: "Check out https://example.com for more info", + expected: []string{"example.com"}, + }, + { + name: "Multiple URLs", + text: "Check out https://example.com and http://test.example.org for more info", + expected: []string{"example.com", "test.example.org"}, + }, + { + name: "URL with path", + text: "Check out https://example.com/path/to/resource", + expected: []string{"example.com"}, + }, + { + name: "URL with port", + text: "Check out https://example.com:8080/path/to/resource", + expected: []string{"example.com"}, + }, + { + name: "URL with subdomain", + text: "Check out https://sub.example.com", + expected: []string{"sub.example.com"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + domains := extractDomains(test.text) + + if len(domains) != len(test.expected) { + t.Errorf("Expected %d domains, got %d", len(test.expected), len(domains)) + return + } + + for i, domain := range domains { + if domain != test.expected[i] { + t.Errorf("Expected domain %s, got %s", test.expected[i], domain) + } + } + }) + } +} + +func TestOnMessage(t *testing.T) { + plugin := New() + + tests := []struct { + name string + text string + blockedDomains string + expectBlocked bool + }{ + { + name: "No blocked domains", + text: "Check out https://example.com", + blockedDomains: "", + expectBlocked: false, + }, + { + name: "No matching domain", + text: "Check out https://example.com", + blockedDomains: "bad.com, evil.org", + expectBlocked: false, + }, + { + name: "Matching domain", + text: "Check out https://example.com", + blockedDomains: "example.com, evil.org", + expectBlocked: true, + }, + { + name: "Matching subdomain", + text: "Check out https://sub.example.com", + blockedDomains: "example.com", + expectBlocked: true, + }, + { + name: "Multiple domains, one matching", + text: "Check out https://example.com and https://good.org", + blockedDomains: "bad.com, example.com", + expectBlocked: true, + }, + { + name: "Spaces in blocked domains list", + text: "Check out https://example.com", + blockedDomains: "bad.com, example.com , evil.org", + expectBlocked: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := &model.Message{ + Text: test.text, + Chat: "test-chat", + ID: "test-id", + Channel: &model.Channel{ + ID: 1, + }, + } + + config := map[string]interface{}{ + "blocked_domains": test.blockedDomains, + } + + responses := plugin.OnMessage(msg, config) + + if test.expectBlocked { + if len(responses) == 0 { + t.Errorf("Expected message to be blocked, but it wasn't") + } + } else { + if len(responses) > 0 { + t.Errorf("Expected message not to be blocked, but it was") + } + } + }) + } +} diff --git a/internal/plugin/fun/coin.go b/internal/plugin/fun/coin.go index 8e12a8d..bd083d1 100644 --- a/internal/plugin/fun/coin.go +++ b/internal/plugin/fun/coin.go @@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin { } // OnMessage handles incoming messages -func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { return nil } @@ -46,5 +46,12 @@ func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/fun/dice.go b/internal/plugin/fun/dice.go index 00fc7cc..8b13edb 100644 --- a/internal/plugin/fun/dice.go +++ b/internal/plugin/fun/dice.go @@ -32,7 +32,7 @@ func NewDice() *DicePlugin { } // OnMessage handles incoming messages -func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { return nil } @@ -62,7 +62,14 @@ func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } // rollDice parses a dice formula string and returns the result @@ -107,9 +114,10 @@ func (p *DicePlugin) rollDice(formula string) (int, error) { return 0, fmt.Errorf("invalid modifier") } - if matches[3] == "+" { + switch matches[3] { + case "+": total += modifier - } else if matches[3] == "-" { + case "-": total -= modifier } } diff --git a/internal/plugin/fun/loquito.go b/internal/plugin/fun/loquito.go index 7b0ea43..fef78bd 100644 --- a/internal/plugin/fun/loquito.go +++ b/internal/plugin/fun/loquito.go @@ -24,7 +24,7 @@ func NewLoquito() *LoquitoPlugin { } // OnMessage handles incoming messages -func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { return nil } @@ -36,5 +36,12 @@ func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interfac Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/ping/ping.go b/internal/plugin/ping/ping.go index b09caaf..3dacf6f 100644 --- a/internal/plugin/ping/ping.go +++ b/internal/plugin/ping/ping.go @@ -24,11 +24,12 @@ func New() *PingPlugin { } // OnMessage handles incoming messages -func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { return nil } + // Create the response message response := &model.Message{ Text: "pong", Chat: msg.Chat, @@ -36,5 +37,13 @@ func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + // Create an action to send the message + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go index 69da2c2..eb3789f 100644 --- a/internal/plugin/plugin.go +++ b/internal/plugin/plugin.go @@ -1,6 +1,7 @@ package plugin import ( + "maps" "sync" "git.nakama.town/fmartingr/butterrobot/internal/model" @@ -41,9 +42,7 @@ func GetAvailablePlugins() map[string]model.Plugin { // Create a copy to avoid race conditions result := make(map[string]model.Plugin, len(plugins)) - for id, plugin := range plugins { - result[id] = plugin - } + maps.Copy(result, plugins) return result } @@ -77,6 +76,6 @@ func (p *BasePlugin) RequiresConfig() bool { } // OnMessage is the default implementation that does nothing -func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { return nil } diff --git a/internal/plugin/reminder/reminder.go b/internal/plugin/reminder/reminder.go index 6d7c1aa..029c8d9 100644 --- a/internal/plugin/reminder/reminder.go +++ b/internal/plugin/reminder/reminder.go @@ -41,17 +41,10 @@ func New(creator ReminderCreator) *Reminder { } // OnMessage processes incoming messages -func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { // Only process replies to messages if msg.ReplyTo == "" { - return []*model.Message{ - { - Text: "Please reply to a message with `!remindme ` to set a reminder.", - Chat: msg.Chat, - Channel: msg.Channel, - ReplyTo: msg.ID, - }, - } + return nil } // Check if the message is a reminder command @@ -63,15 +56,22 @@ func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) // Parse the duration amount, err := strconv.Atoi(match[1]) if err != nil { - return []*model.Message{ + errorMsg := &model.Message{ + Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ { - Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Type: model.ActionSendMessage, + Message: errorMsg, Chat: msg.Chat, Channel: msg.Channel, - Author: "bot", - FromBot: true, - Date: time.Now(), - ReplyTo: msg.ID, }, } } @@ -93,15 +93,22 @@ func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) case "s": duration = time.Duration(amount) * time.Second default: - return []*model.Message{ + errorMsg := &model.Message{ + Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ { - Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Type: model.ActionSendMessage, + Message: errorMsg, Chat: msg.Chat, Channel: msg.Channel, - Author: "bot", - FromBot: true, - Date: time.Now(), - ReplyTo: msg.ID, }, } } @@ -134,15 +141,22 @@ func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) ) if err != nil { - return []*model.Message{ + errorMsg := &model.Message{ + Text: fmt.Sprintf("Failed to create reminder: %v", err), + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ { - Text: fmt.Sprintf("Failed to create reminder: %v", err), + Type: model.ActionSendMessage, + Message: errorMsg, Chat: msg.Chat, Channel: msg.Channel, - Author: "bot", - FromBot: true, - Date: time.Now(), - ReplyTo: msg.ID, }, } } @@ -164,15 +178,23 @@ func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount) } - return []*model.Message{ + // Create confirmation message + confirmMsg := &model.Message{ + Text: confirmText, + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ { - Text: confirmText, + Type: model.ActionSendMessage, + Message: confirmMsg, Chat: msg.Chat, Channel: msg.Channel, - Author: "bot", - FromBot: true, - Date: time.Now(), - ReplyTo: msg.ID, }, } } diff --git a/internal/plugin/reminder/reminder_test.go b/internal/plugin/reminder/reminder_test.go index b76fd2f..8e611ce 100644 --- a/internal/plugin/reminder/reminder_test.go +++ b/internal/plugin/reminder/reminder_test.go @@ -142,14 +142,25 @@ func TestReminderOnMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { initialCount := len(creator.reminders) - responses := plugin.OnMessage(tt.message, nil) + actions := plugin.OnMessage(tt.message, nil) - if tt.expectResponse && len(responses) == 0 { - t.Errorf("Expected response, but got none") + if tt.expectResponse && len(actions) == 0 { + t.Errorf("Expected response action, but got none") } - if !tt.expectResponse && len(responses) > 0 { - t.Errorf("Expected no response, but got %d", len(responses)) + if !tt.expectResponse && len(actions) > 0 { + t.Errorf("Expected no actions, but got %d", len(actions)) + } + + // Verify action type is correct when actions are returned + if len(actions) > 0 { + if actions[0].Type != model.ActionSendMessage { + t.Errorf("Expected action type to be %s, but got %s", model.ActionSendMessage, actions[0].Type) + } + + if actions[0].Message == nil { + t.Errorf("Expected message in action to not be nil") + } } if tt.expectReminder && len(creator.reminders) != initialCount+1 { @@ -161,4 +172,4 @@ func TestReminderOnMessage(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/plugin/searchreplace/README.md b/internal/plugin/searchreplace/README.md new file mode 100644 index 0000000..c7b7786 --- /dev/null +++ b/internal/plugin/searchreplace/README.md @@ -0,0 +1,50 @@ +# Search and Replace Plugin + +This plugin allows users to perform search and replace operations on messages by replying to a message with a search/replace command. + +## Usage + +To use the plugin, reply to any message with a command in the following format: + +``` +s/search/replace/[flags] +``` + +Where: +- `search` is the text you want to find (case-sensitive by default) +- `replace` is the text you want to substitute in place of the search term +- `flags` (optional) control the behavior of the replacement + +### Supported Flags + +- `g` - Global: Replace all occurrences of the search term (without this flag, only the first occurrence is replaced) +- `i` - Case insensitive: Match regardless of case +- `n` - Treat search pattern as a regular expression (advanced users) + +### Examples + +1. Basic replacement (replaces first occurrence): + ``` + s/hello/hi/ + ``` + +2. Global replacement (replaces all occurrences): + ``` + s/hello/hi/g + ``` + +3. Case-insensitive replacement: + ``` + s/Hello/hi/i + ``` + +4. Combined flags (global and case-insensitive): + ``` + s/hello/hi/gi + ``` + +## Limitations + +- The plugin can only access the text content of the original message +- Regular expression support is available with the `n` flag, but should be used carefully as invalid regex patterns will cause errors +- The plugin does not modify the original message; it creates a new message with the replaced text \ No newline at end of file diff --git a/internal/plugin/searchreplace/searchreplace.go b/internal/plugin/searchreplace/searchreplace.go new file mode 100644 index 0000000..876e880 --- /dev/null +++ b/internal/plugin/searchreplace/searchreplace.go @@ -0,0 +1,182 @@ +package searchreplace + +import ( + "fmt" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// Regex pattern for search and replace operations: s/search/replace/[flags] +var searchReplacePattern = regexp.MustCompile(`^s/([^/]*)/([^/]*)(?:/([gimnsuy]*))?$`) + +// SearchReplacePlugin is a plugin for performing search and replace operations on messages +type SearchReplacePlugin struct { + plugin.BasePlugin +} + +// New creates a new SearchReplacePlugin instance +func New() *SearchReplacePlugin { + return &SearchReplacePlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "util.searchreplace", + Name: "Search and Replace", + Help: "Reply to a message with a search and replace pattern (s/search/replace/[flags]) to create a modified message. " + + "Supported flags: g (global), i (case insensitive)", + }, + } +} + +// OnMessage handles incoming messages +func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Only process replies to messages + if msg.ReplyTo == "" { + return nil + } + + // Check if the message matches the search/replace pattern + match := searchReplacePattern.FindStringSubmatch(strings.TrimSpace(msg.Text)) + if match == nil { + return nil + } + + // Get the original message text from the reply_to_message structure in Telegram messages + var originalText string + + // For Telegram messages + if msgData, ok := msg.Raw["message"].(map[string]interface{}); ok { + if replyMsg, ok := msgData["reply_to_message"].(map[string]interface{}); ok { + if text, ok := replyMsg["text"].(string); ok { + originalText = text + } + } + } + + // Generic fallback for other platforms or if the above method fails + if originalText == "" && msg.Raw["original_message"] != nil { + if original, ok := msg.Raw["original_message"].(map[string]interface{}); ok { + if text, ok := original["text"].(string); ok { + originalText = text + } + } + } + + if originalText == "" { + // If we couldn't find the original message text, inform the user + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: &model.Message{ + Text: "Sorry, I couldn't find the original message text to perform the replacement.", + Chat: msg.Chat, + Channel: msg.Channel, + ReplyTo: msg.ID, + }, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Extract search pattern, replacement and flags + searchPattern := match[1] + replacement := match[2] + flags := "" + if len(match) > 3 { + flags = match[3] + } + + // Process the replacement + result, err := p.performReplacement(originalText, searchPattern, replacement, flags) + if err != nil { + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: &model.Message{ + Text: fmt.Sprintf("Error performing replacement: %s", err.Error()), + Chat: msg.Chat, + Channel: msg.Channel, + ReplyTo: msg.ID, + }, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Only send a response if the text actually changed + if result == originalText { + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: &model.Message{ + Text: "No changes were made to the original message.", + Chat: msg.Chat, + Channel: msg.Channel, + ReplyTo: msg.ID, + }, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Create a response with the modified text + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: &model.Message{ + Text: result, + Chat: msg.Chat, + Channel: msg.Channel, + ReplyTo: msg.ReplyTo, // Reply to the original message + }, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } +} + +// performReplacement performs the search and replace operation on the given text +func (p *SearchReplacePlugin) performReplacement(text, search, replace, flags string) (string, error) { + // Process flags + globalReplace := strings.Contains(flags, "g") + caseInsensitive := strings.Contains(flags, "i") + + // Create the regex pattern + pattern := search + regexFlags := "" + if caseInsensitive { + regexFlags += "(?i)" + } + + // Escape special characters if we're not in a regular expression + if !strings.Contains(flags, "n") { + pattern = regexp.QuoteMeta(pattern) + } + + // Compile the regex + reg, err := regexp.Compile(regexFlags + pattern) + if err != nil { + return "", fmt.Errorf("invalid search pattern: %v", err) + } + + // Perform the replacement + var result string + if globalReplace { + result = reg.ReplaceAllString(text, replace) + } else { + // For non-global replace, only replace the first occurrence + indices := reg.FindStringIndex(text) + if indices == nil { + // No match found + return text, nil + } + + result = text[:indices[0]] + replace + text[indices[1]:] + } + + return result, nil +} diff --git a/internal/plugin/searchreplace/searchreplace_test.go b/internal/plugin/searchreplace/searchreplace_test.go new file mode 100644 index 0000000..415610c --- /dev/null +++ b/internal/plugin/searchreplace/searchreplace_test.go @@ -0,0 +1,216 @@ +package searchreplace + +import ( + "testing" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +func TestSearchReplace(t *testing.T) { + // Create plugin instance + p := New() + + // Test cases + tests := []struct { + name string + command string + originalText string + expectedResult string + expectActions bool + }{ + { + name: "Simple replacement", + command: "s/hello/world/", + originalText: "hello everyone", + expectedResult: "world everyone", + expectActions: true, + }, + { + name: "Case-insensitive replacement", + command: "s/HELLO/world/i", + originalText: "Hello everyone", + expectedResult: "world everyone", + expectActions: true, + }, + { + name: "Global replacement", + command: "s/a/X/g", + originalText: "banana", + expectedResult: "bXnXnX", + expectActions: true, + }, + { + name: "No change", + command: "s/nothing/something/", + originalText: "test message", + expectedResult: "test message", + expectActions: true, // We send a "no changes" message + }, + { + name: "Not a search/replace command", + command: "hello", + originalText: "test message", + expectedResult: "", + expectActions: false, + }, + { + name: "Invalid pattern", + command: "s/(/)/", + originalText: "test message", + expectedResult: "error", + expectActions: true, // We send an error message + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create message + msg := &model.Message{ + Text: tc.command, + Chat: "test-chat", + ReplyTo: "original-message-id", + Date: time.Now(), + Channel: &model.Channel{ + Platform: "test", + }, + Raw: map[string]interface{}{ + "message": map[string]interface{}{ + "reply_to_message": map[string]interface{}{ + "text": tc.originalText, + }, + }, + }, + } + + // Process message + actions := p.OnMessage(msg, nil) + + // Check results + if tc.expectActions { + if len(actions) == 0 { + t.Fatalf("Expected actions but got none") + } + + action := actions[0] + if action.Type != model.ActionSendMessage { + t.Fatalf("Expected send message action but got %v", action.Type) + } + + if tc.expectedResult == "error" { + // Just checking that we got an error message + if action.Message == nil || action.Message.Text == "" { + t.Fatalf("Expected error message but got empty message") + } + } else if tc.originalText == tc.expectedResult { + // Check if we got the "no changes" message + if action.Message == nil || action.Message.Text != "No changes were made to the original message." { + t.Fatalf("Expected 'no changes' message but got: %s", action.Message.Text) + } + } else { + // Check actual replacement result + if action.Message == nil || action.Message.Text != tc.expectedResult { + t.Fatalf("Expected result: %s, got: %s", tc.expectedResult, action.Message.Text) + } + } + } else if len(actions) > 0 { + t.Fatalf("Expected no actions but got %d", len(actions)) + } + }) + } +} + +func TestPerformReplacement(t *testing.T) { + p := New() + + // Test cases for the performReplacement function + tests := []struct { + name string + text string + search string + replace string + flags string + expected string + expectErr bool + }{ + { + name: "Simple replacement", + text: "Hello World", + search: "Hello", + replace: "Hi", + flags: "", + expected: "Hi World", + expectErr: false, + }, + { + name: "Case insensitive", + text: "Hello World", + search: "hello", + replace: "Hi", + flags: "i", + expected: "Hi World", + expectErr: false, + }, + { + name: "Global replacement", + text: "one two one two", + search: "one", + replace: "1", + flags: "g", + expected: "1 two 1 two", + expectErr: false, + }, + { + name: "No match", + text: "Hello World", + search: "Goodbye", + replace: "Hi", + flags: "", + expected: "Hello World", + expectErr: false, + }, + { + name: "Invalid regex", + text: "Hello World", + search: "(", + replace: "Hi", + flags: "n", // treat as regex + expected: "", + expectErr: true, + }, + { + name: "Escape special chars by default", + text: "Hello (World)", + search: "(World)", + replace: "[Earth]", + flags: "", + expected: "Hello [Earth]", + expectErr: false, + }, + { + name: "Regex mode with n flag", + text: "Hello (World)", + search: "\\(World\\)", + replace: "[Earth]", + flags: "n", + expected: "Hello [Earth]", + expectErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := p.performReplacement(tc.text, tc.search, tc.replace, tc.flags) + + if tc.expectErr { + if err == nil { + t.Fatalf("Expected error but got none") + } + } else if err != nil { + t.Fatalf("Unexpected error: %v", err) + } else if result != tc.expected { + t.Fatalf("Expected result: %s, got: %s", tc.expected, result) + } + }) + } +} diff --git a/internal/plugin/social/instagram.go b/internal/plugin/social/instagram.go index a4f758a..0b4ff55 100644 --- a/internal/plugin/social/instagram.go +++ b/internal/plugin/social/instagram.go @@ -26,7 +26,7 @@ func NewInstagramExpander() *InstagramExpander { } // OnMessage handles incoming messages -func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { // Skip empty messages if strings.TrimSpace(msg.Text) == "" { return nil @@ -48,14 +48,16 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte parsedURL, err := url.Parse(link) if err != nil { // If parsing fails, just do the simple replacement - link = strings.Replace(link, "instagram.com", "ddinstagram.com", 1) + return link + } + + // Ensure we don't change links that already come from ddinstagram.com + if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" { return link } // Change the host - if strings.Contains(parsedURL.Host, "instagram.com") { - parsedURL.Host = strings.Replace(parsedURL.Host, "instagram.com", "ddinstagram.com", 1) - } + parsedURL.Host = "d.ddinstagram.com" // Remove query parameters parsedURL.RawQuery = "" @@ -72,5 +74,12 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/social/twitter.go b/internal/plugin/social/twitter.go index 837b6c9..865f421 100644 --- a/internal/plugin/social/twitter.go +++ b/internal/plugin/social/twitter.go @@ -26,7 +26,7 @@ func NewTwitterExpander() *TwitterExpander { } // OnMessage handles incoming messages -func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { // Skip empty messages if strings.TrimSpace(msg.Text) == "" { return nil @@ -75,5 +75,12 @@ func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interf Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} }