diff --git a/.goreleaser.yml b/.goreleaser.yml index c89e189..a3836e9 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -93,7 +93,7 @@ docker_manifests: nfpms: - maintainer: Felipe Martin - description: SMTP server to forward messages to shoutrrr endpoints + description: A chatbot server with customizable commands and triggers homepage: https://git.nakama.town/fmartingr/butterrobot license: AGPL-3.0 formats: diff --git a/.woodpecker/ci.yml b/.woodpecker/ci.yml index 4353088..5b32d48 100644 --- a/.woodpecker/ci.yml +++ b/.woodpecker/ci.yml @@ -3,7 +3,7 @@ when: - push - pull_request branch: - - main + - master steps: format: diff --git a/.woodpecker/release.yml b/.woodpecker/release.yml index b24eb15..39dbf65 100644 --- a/.woodpecker/release.yml +++ b/.woodpecker/release.yml @@ -1,6 +1,6 @@ when: - event: tag - branch: main + branch: master steps: - name: Release @@ -13,4 +13,4 @@ steps: - "/var/run/docker.sock:/var/run/docker.sock" commands: - docker login -u fmartingr -p $GITEA_TOKEN git.nakama.town - - goreleaser release --clean + - goreleaser release --clean --parallelism=2 diff --git a/README.md b/README.md index 36ec708..920d087 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,6 @@ # Butter Robot -| Stable | Master | -| --- | --- | -| ![Build stable tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20stable%20tag%20docker%20image/badge.svg?branch=stable) | ![Build latest tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20latest%20tag%20docker%20image/badge.svg?branch=master) | -| ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=stable) | ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=master) | +![Status badge](https://woodpecker.local.fmartingr.dev/api/badges/5/status.svg) Go framework to create bots for several platforms. @@ -13,7 +10,7 @@ Go framework to create bots for several platforms. ## Features -- Support for multiple chat platforms (Slack, Telegram) +- Support for multiple chat platforms (Slack (untested!), Telegram) - Plugin system for easy extension - Admin interface for managing channels and plugins - Message queue for asynchronous processing @@ -22,6 +19,12 @@ Go framework to create bots for several platforms. [Go to documentation](./docs) +### Database Management + +ButterRobot includes an automatic database migration system. Migrations are applied automatically when the application starts, ensuring your database schema is always up to date. + +[Learn more about migrations](./docs/migrations.md) + ## Installation ### From Source diff --git a/cmd/butterrobot/main.go b/cmd/butterrobot/main.go index 5cf57f9..3bc56cb 100644 --- a/cmd/butterrobot/main.go +++ b/cmd/butterrobot/main.go @@ -1,11 +1,15 @@ package main import ( + "fmt" "log/slog" "os" + "runtime/debug" "git.nakama.town/fmartingr/butterrobot/internal/app" "git.nakama.town/fmartingr/butterrobot/internal/config" + + _ "golang.org/x/crypto/x509roots/fallback" ) func main() { @@ -19,15 +23,26 @@ func main() { os.Exit(1) } + // Handle version command + if len(os.Args) > 1 && os.Args[1] == "version" { + info, ok := debug.ReadBuildInfo() + if ok { + fmt.Printf("ButterRobot version %s\n", info.Main.Version) + } else { + fmt.Println("ButterRobot. Can't determine build information.") + } + return + } + // Initialize and run application application, err := app.New(cfg, logger) if err != nil { logger.Error("Failed to initialize application", "error", err) os.Exit(1) } - + if err := application.Run(); err != nil { logger.Error("Application error", "error", err) os.Exit(1) } -} \ No newline at end of file +} diff --git a/docs/creating-a-plugin.md b/docs/creating-a-plugin.md index 945d03c..b8e4a78 100644 --- a/docs/creating-a-plugin.md +++ b/docs/creating-a-plugin.md @@ -1,6 +1,19 @@ # Creating a Plugin -## Example +## Plugin Categories + +ButterRobot organizes plugins into different categories: + +- **Development**: Utility plugins like `ping` +- **Fun**: Entertainment plugins like dice rolling, coin flipping +- **Social**: Social media related plugins like URL transformers/expanders +- **Security**: Moderation and protection features like domain blocking + +When creating a new plugin, consider which category it fits into and place it in the appropriate directory. + +## Plugin Examples + +### Basic Example: Marco Polo This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_: @@ -47,6 +60,207 @@ func (p *MarcoPlugin) OnMessage(msg *model.Message, config map[string]interface{ } ``` +### Configuration-Enabled Plugin + +This plugin requires configuration to be set in the admin interface. It demonstrates how to create plugins that need channel-specific configuration: + +```go +package security + +import ( + "fmt" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains +type DomainBlockPlugin struct { + plugin.BasePlugin +} + +// New creates a new DomainBlockPlugin instance +func New() *DomainBlockPlugin { + return &DomainBlockPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "security.domainblock", + Name: "Domain Blocker", + Help: "Blocks messages containing links from configured domains", + ConfigRequired: true, // Mark this plugin as requiring configuration + }, + } +} + +// OnMessage processes incoming messages +func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { + // Get blocked domains from config + blockedDomainsStr, ok := config["blocked_domains"].(string) + if !ok || blockedDomainsStr == "" { + return nil // No blocked domains configured + } + + // Split and clean blocked domains + blockedDomains := strings.Split(blockedDomainsStr, ",") + for i, domain := range blockedDomains { + blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + } + + // Extract domains from message + urlRegex := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) + matches := urlRegex.FindAllStringSubmatch(msg.Text, -1) + + // Check if any extracted domains are blocked + for _, match := range matches { + if len(match) < 2 { + continue + } + + domain := strings.ToLower(match[1]) + + for _, blockedDomain := range blockedDomains { + if blockedDomain == "" { + continue + } + + if strings.HasSuffix(domain, blockedDomain) || domain == blockedDomain { + // Domain is blocked, create warning message + response := &model.Message{ + Text: fmt.Sprintf("⚠️ Message contained a link to blocked domain: %s", blockedDomain), + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + return []*model.Message{response} + } + } + } + + return nil +} + +func init() { + plugin.Register(New()) +} +``` + +### Advanced Example: URL Transformer + +This more complex plugin transforms URLs, useful for improving media embedding in chat platforms: + +```go +package social + +import ( + "net/url" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// TwitterExpander transforms twitter.com links to fxtwitter.com links +type TwitterExpander struct { + plugin.BasePlugin +} + +// New creates a new TwitterExpander instance +func NewTwitter() *TwitterExpander { + return &TwitterExpander{ + BasePlugin: plugin.BasePlugin{ + ID: "social.twitter", + Name: "Twitter Link Expander", + Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters", + }, + } +} + +// OnMessage handles incoming messages +func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { + // Skip empty messages + if strings.TrimSpace(msg.Text) == "" { + return nil + } + + // Regex to match twitter.com links + twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`) + + // Check if the message contains a Twitter link + if !twitterRegex.MatchString(msg.Text) { + return nil + } + + // Transform the URL + transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { + // Parse the URL + parsedURL, err := url.Parse(link) + if err != nil { + // If parsing fails, just do the simple replacement + link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1) + link = strings.Replace(link, "x.com", "fxtwitter.com", 1) + return link + } + + // Change the host + if strings.Contains(parsedURL.Host, "twitter.com") { + parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1) + } else if strings.Contains(parsedURL.Host, "x.com") { + parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1) + } + + // Remove query parameters + parsedURL.RawQuery = "" + + // Return the cleaned URL + return parsedURL.String() + }) + + // Create response message + response := &model.Message{ + Text: transformed, + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + + return []*model.Message{response} +} +``` + +## Enabling Configuration for Plugins + +To indicate that your plugin requires configuration: + +1. Set `ConfigRequired: true` in the BasePlugin struct: + ```go + BasePlugin: plugin.BasePlugin{ + ID: "myplugin.id", + Name: "Plugin Name", + Help: "Help text", + ConfigRequired: true, + }, + ``` + +2. Access the configuration in the OnMessage method: + ```go + func (p *MyPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { + // Extract configuration values + configValue, ok := config["some_config_key"].(string) + if !ok || configValue == "" { + // Handle missing or empty configuration + return nil + } + + // Use the configuration... + } + ``` + +3. The admin interface will show a "Configure" button for plugins that require configuration. + +## Registering Plugins + To use the plugin, register it in your application: ```go @@ -55,8 +269,19 @@ func (a *App) Run() error { // ... // Register plugins - plugin.Register(myplugin.New()) + plugin.Register(ping.New()) // Development plugin + plugin.Register(fun.NewCoin()) // Fun plugin + plugin.Register(social.NewTwitter()) // Social media plugin + plugin.Register(myplugin.New()) // Your custom plugin // ... } ``` + +Alternatively, you can register your plugin in its init() function: + +```go +func init() { + plugin.Register(New()) +} +``` diff --git a/docs/migrations.md b/docs/migrations.md new file mode 100644 index 0000000..65fcd99 --- /dev/null +++ b/docs/migrations.md @@ -0,0 +1,99 @@ +# Database Migrations + +ButterRobot uses a simple database migration system to manage database schema changes. This document explains how the migration system works and how to extend it. + +## Automatic Migrations + +Migrations in ButterRobot are applied automatically when the application starts. This ensures your database schema is always up to date without requiring manual intervention. + +The migration system: +1. Checks which migrations have been applied +2. Applies any pending migrations in sequential order +3. Records each successful migration in the `schema_migrations` table + +## Initial State + +The initial migration (version 1) sets up the database with the following: + +- `channels` table for chat platforms +- `channel_plugin` table for plugins associated with channels +- `users` table for admin users with bcrypt password hashing +- Default admin user with username "admin" and password "admin" + +This migration represents the current state of the database schema. It is not backwards compatible with previous versions of ButterRobot. + +## Creating New Migrations + +To add a new migration, follow these steps: + +1. Open `/internal/migration/migrations.go` +2. Add a new migration version in the `init()` function: + +```go +Register(2, "Add example table", migrateAddExampleTableUp, migrateAddExampleTableDown) +``` + +3. Implement the up and down functions for your migration: + +```go +// Migration to add example table - version 2 +func migrateAddExampleTableUp(db *sql.DB) error { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS example ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + `) + return err +} + +func migrateAddExampleTableDown(db *sql.DB) error { + _, err := db.Exec(`DROP TABLE IF EXISTS example`) + return err +} +``` + +## Migration Guidelines + +1. **Incremental Changes**: Each migration should make a small, focused change to the database schema. +2. **Backward Compatibility**: Ensure migrations are backward compatible with existing code when possible. +3. **Test Thoroughly**: Test both up and down migrations before deploying. +4. **Document Changes**: Add comments explaining the purpose of each migration. +5. **Version Numbers**: Use sequential version numbers for migrations. + +## How Migrations Work + +The migration system tracks applied migrations in a `schema_migrations` table. When you run migrations, the system: + +1. Checks which migrations have been applied +2. Applies any pending migrations in order +3. Records each successful migration in the `schema_migrations` table + +When rolling back, it performs the down migrations in reverse order. + +## In Code Usage + +The application automatically runs pending migrations when starting up. This is done in the `initDatabase` function. + +You can also programmatically work with migrations: + +```go +// Get database instance +database, err := db.New(cfg.DatabasePath) +if err != nil { + // Handle error +} +defer database.Close() + +// Run migrations +if err := database.MigrateUp(); err != nil { + // Handle error +} + +// Check migration status +applied, pending, err := database.MigrationStatus() +if err != nil { + // Handle error +} +``` \ No newline at end of file diff --git a/docs/plugins.md b/docs/plugins.md index 11e3d16..25df16c 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -9,3 +9,16 @@ - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) - Dice: Put `!dice` and wathever roll you want to perform. - Coin: Flip a coin and get heads or tails. + +### Utility + +- Remind Me: Reply to a message with `!remindme ` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time. + +### Security + +- Domain Blocker: Blocks messages containing links from specified domains. Configure it per channel with a comma-separated list of domains to block. When a message contains a link matching any of the blocked domains, the bot will notify that the message contained a blocked domain. This plugin requires configuration through the admin interface. + +### Social Media + +- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms. +- Instagram Link Expander: Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters. This allows for better media embedding in chat platforms. diff --git a/go.mod b/go.mod index ab85fc8..cd1bee5 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.24 require ( github.com/gorilla/sessions v1.4.0 + golang.org/x/crypto v0.37.0 + golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df modernc.org/sqlite v1.37.0 ) diff --git a/go.sum b/go.sum index 248cd40..00c4a3c 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df h1:SwgTucX8ajPE0La2ELpYOIs8jVMoCMpAvYB6mDqP9vk= +golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df/go.mod h1:lxN5T34bK4Z/i6cMaU7frUU57VkDXFD4Kamfl/cp9oU= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= diff --git a/internal/admin/admin.go b/internal/admin/admin.go index d590995..2b41820 100644 --- a/internal/admin/admin.go +++ b/internal/admin/admin.go @@ -2,6 +2,8 @@ package admin import ( "embed" + "encoding/gob" + "fmt" "html/template" "net/http" "strconv" @@ -28,6 +30,11 @@ type FlashMessage struct { Message string } +func init() { + // Register the FlashMessage type with gob package for session serialization + gob.Register(FlashMessage{}) +} + // TemplateData holds data for rendering templates type TemplateData struct { User *model.User @@ -39,6 +46,7 @@ type TemplateData struct { Channels []*model.Channel Channel *model.Channel ChannelPlugin *model.ChannelPlugin + Version string } // Admin represents the admin interface @@ -48,12 +56,18 @@ type Admin struct { store *sessions.CookieStore templates map[string]*template.Template baseTemplate *template.Template + version string } // New creates a new Admin instance -func New(cfg *config.Config, database *db.Database) *Admin { - // Create session store +func New(cfg *config.Config, database *db.Database, version string) *Admin { + // Create session store with appropriate options store := sessions.NewCookieStore([]byte(cfg.SecretKey)) + store.Options = &sessions.Options{ + Path: "/admin", + MaxAge: 3600 * 24 * 7, // 1 week + HttpOnly: true, + } // Load templates templates := make(map[string]*template.Template) @@ -79,10 +93,12 @@ func New(cfg *config.Config, database *db.Database) *Admin { templateFiles := []string{ "index.html", "login.html", + "change_password.html", "channel_list.html", "channel_detail.html", "plugin_list.html", "channel_plugins_list.html", + "channel_plugin_config.html", } for _, tf := range templateFiles { @@ -91,19 +107,19 @@ func New(cfg *config.Config, database *db.Database) *Admin { if err != nil { panic(err) } - + // Create a clone of the base template t, err := baseTemplate.Clone() if err != nil { panic(err) } - + // Parse the template content t, err = t.Parse(string(content)) if err != nil { panic(err) } - + templates[tf] = t } @@ -113,6 +129,7 @@ func New(cfg *config.Config, database *db.Database) *Admin { store: store, templates: templates, baseTemplate: baseTemplate, + version: version, } } @@ -122,16 +139,22 @@ func (a *Admin) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/admin/", a.handleIndex) mux.HandleFunc("/admin/login", a.handleLogin) mux.HandleFunc("/admin/logout", a.handleLogout) + mux.HandleFunc("/admin/change-password", a.handleChangePassword) mux.HandleFunc("/admin/plugins", a.handlePluginList) mux.HandleFunc("/admin/channels", a.handleChannelList) mux.HandleFunc("/admin/channels/", a.handleChannelDetail) mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList) + mux.HandleFunc("/admin/channelplugins/config/", a.handleChannelPluginConfig) mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete) } // getCurrentUser gets the current user from the session func (a *Admin) getCurrentUser(r *http.Request) *model.User { - session, _ := a.store.Get(r, sessionKey) + session, err := a.store.Get(r, sessionKey) + if err != nil { + fmt.Printf("Error getting session for user retrieval: %v\n", err) + return nil + } // Check if user is logged in userID, ok := session.Values["user_id"].(int64) @@ -142,6 +165,7 @@ func (a *Admin) getCurrentUser(r *http.Request) *model.User { // Get user from database user, err := a.db.GetUserByID(userID) if err != nil { + fmt.Printf("Error retrieving user from database: %v\n", err) return nil } @@ -150,32 +174,63 @@ func (a *Admin) getCurrentUser(r *http.Request) *model.User { // isLoggedIn checks if the user is logged in func (a *Admin) isLoggedIn(r *http.Request) bool { - session, _ := a.store.Get(r, sessionKey) + session, err := a.store.Get(r, sessionKey) + if err != nil { + fmt.Printf("Error getting session for login check: %v\n", err) + return false + } return session.Values["logged_in"] == true } // addFlash adds a flash message to the session func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string, category string) { - session, _ := a.store.Get(r, sessionKey) + session, err := a.store.Get(r, sessionKey) + if err != nil { + // If there's an error getting the session, create a new one + session = sessions.NewSession(a.store, sessionKey) + session.Options = &sessions.Options{ + Path: "/admin", + MaxAge: 3600 * 24 * 7, // 1 week + HttpOnly: true, + } + } - // Add flash message - flashes := session.Flashes() - if flashes == nil { - flashes = make([]interface{}, 0) + // Map internal categories to Bootstrap alert classes + var alertClass string + switch category { + case "success": + alertClass = "success" + case "danger": + alertClass = "danger" + case "warning": + alertClass = "warning" + case "info": + alertClass = "info" + default: + alertClass = "info" } flash := FlashMessage{ - Category: category, + Category: alertClass, Message: message, } session.AddFlash(flash) - session.Save(r, w) + err = session.Save(r, w) + if err != nil { + // Log the error or handle it appropriately + fmt.Printf("Error saving session: %v\n", err) + } } // getFlashes gets all flash messages from the session func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessage { - session, _ := a.store.Get(r, sessionKey) + session, err := a.store.Get(r, sessionKey) + if err != nil { + // If there's an error getting the session, return an empty slice + fmt.Printf("Error getting session for flashes: %v\n", err) + return []FlashMessage{} + } // Get flash messages flashes := session.Flashes() @@ -188,22 +243,14 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag } // Save session to clear flashes - session.Save(r, w) + err = session.Save(r, w) + if err != nil { + fmt.Printf("Error saving session after getting flashes: %v\n", err) + } return messages } -// requireLogin middleware checks if the user is logged in -func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if !a.isLoggedIn(r) { - http.Redirect(w, r, "/admin/login", http.StatusSeeOther) - return - } - next(w, r) - } -} - // render renders a template with the given data func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) { // Add current user data @@ -211,6 +258,7 @@ func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName stri data.LoggedIn = a.isLoggedIn(r) data.Path = r.URL.Path data.Flash = a.getFlashes(w, r) + data.Version = a.version // Get template tmpl, ok := a.templates[templateName] @@ -277,7 +325,10 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) { // Set session expiration session.Options.MaxAge = 3600 * 24 * 7 // 1 week - session.Save(r, w) + err = session.Save(r, w) + if err != nil { + fmt.Printf("Error saving session: %v\n", err) + } a.addFlash(w, r, "You were logged in", "success") @@ -299,10 +350,19 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) { // handleLogout handles the logout route func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) { // Clear session - session, _ := a.store.Get(r, sessionKey) + session, err := a.store.Get(r, sessionKey) + if err != nil { + fmt.Printf("Error getting session for logout: %v\n", err) + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + session.Values = make(map[interface{}]interface{}) session.Options.MaxAge = -1 // Delete session - session.Save(r, w) + err = session.Save(r, w) + if err != nil { + fmt.Printf("Error saving session for logout: %v\n", err) + } a.addFlash(w, r, "You were logged out", "success") @@ -310,6 +370,74 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/admin/login", http.StatusSeeOther) } +// handleChangePassword handles the change password route +func (a *Admin) handleChangePassword(w http.ResponseWriter, r *http.Request) { + // Check if user is logged in + if !a.isLoggedIn(r) { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + // Get current user + user := a.getCurrentUser(r) + if user == nil { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + // Handle form submission + if r.Method == http.MethodPost { + // Parse form + if err := r.ParseForm(); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Get form values + currentPassword := r.FormValue("current_password") + newPassword := r.FormValue("new_password") + confirmPassword := r.FormValue("confirm_password") + + // Validate current password + _, err := a.db.CheckCredentials(user.Username, currentPassword) + if err != nil { + a.addFlash(w, r, "Current password is incorrect", "danger") + http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) + return + } + + // Validate new password and confirmation + if newPassword == "" { + a.addFlash(w, r, "New password cannot be empty", "danger") + http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) + return + } + + if newPassword != confirmPassword { + a.addFlash(w, r, "New passwords do not match", "danger") + http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) + return + } + + // Update password + if err := a.db.UpdateUserPassword(user.ID, newPassword); err != nil { + a.addFlash(w, r, "Failed to update password: "+err.Error(), "danger") + http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) + return + } + + // Success + a.addFlash(w, r, "Password changed successfully", "success") + http.Redirect(w, r, "/admin/", http.StatusSeeOther) + return + } + + // Render change password template + a.render(w, r, "change_password.html", TemplateData{ + Title: "Change Password", + }) +} + // handlePluginList handles the plugin list route func (a *Admin) handlePluginList(w http.ResponseWriter, r *http.Request) { // Check if user is logged in @@ -502,6 +630,96 @@ func (a *Admin) handleChannelPluginList(w http.ResponseWriter, r *http.Request) }) } +// handleChannelPluginConfig handles the channel plugin configuration route +func (a *Admin) handleChannelPluginConfig(w http.ResponseWriter, r *http.Request) { + // Check if user is logged in + if !a.isLoggedIn(r) { + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + + // Extract channel plugin ID from path + path := r.URL.Path + channelPluginID := strings.TrimPrefix(path, "/admin/channelplugins/config/") + + // Convert channel plugin ID to int64 + id, err := strconv.ParseInt(channelPluginID, 10, 64) + if err != nil { + http.Error(w, "Invalid channel plugin ID", http.StatusBadRequest) + return + } + + // Get the channel plugin + channelPlugin, err := a.db.GetChannelPluginByID(id) + if err != nil { + http.Error(w, "Channel plugin not found", http.StatusNotFound) + return + } + + // Get the plugin + p, err := plugin.Get(channelPlugin.PluginID) + if err != nil { + http.Error(w, "Plugin not found", http.StatusNotFound) + return + } + + // Handle form submission + if r.Method == http.MethodPost { + // Parse form + if err := r.ParseForm(); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Create config map from form values + config := make(map[string]interface{}) + + // Process form values based on plugin type + if channelPlugin.PluginID == "security.domainblock" { + // Get blocked domains from form + blockedDomains := r.FormValue("blocked_domains") + config["blocked_domains"] = blockedDomains + } else { + // Generic handling for other plugins + for key, values := range r.Form { + if key == "form_submitted" { + continue + } + if len(values) == 1 { + config[key] = values[0] + } else { + config[key] = values + } + } + } + + // Update plugin configuration + if err := a.db.UpdateChannelPluginConfig(id, config); err != nil { + http.Error(w, "Failed to update plugin configuration", http.StatusInternalServerError) + return + } + + // Get the channel to redirect back to the channel detail page + channel, err := a.db.GetChannelByID(channelPlugin.ChannelID) + if err != nil { + a.addFlash(w, r, "Plugin configuration updated", "success") + http.Redirect(w, r, "/admin/channelplugins", http.StatusSeeOther) + return + } + + a.addFlash(w, r, "Plugin configuration updated", "success") + http.Redirect(w, r, fmt.Sprintf("/admin/channels/%d", channel.ID), http.StatusSeeOther) + return + } + + // Render template + a.render(w, r, "channel_plugin_config.html", TemplateData{ + Title: "Configure Plugin: " + p.GetName(), + ChannelPlugin: channelPlugin, + Plugins: map[string]model.Plugin{channelPlugin.PluginID: p}, + }) +} + // handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) { // Check if user is logged in diff --git a/internal/admin/templates/_base.html b/internal/admin/templates/_base.html index d056ab5..3ebdf85 100644 --- a/internal/admin/templates/_base.html +++ b/internal/admin/templates/_base.html @@ -28,8 +28,10 @@ Log in {{else}}
-
{{.User.Username}} - Log out
+
{{.User.Username}} - + Change Password | + Log out +
{{end}} @@ -100,14 +102,14 @@ {{end}} - {{range .Flash}} -
-
-
-

{{.Message}}

+
+ {{range .Flash}} + + {{end}}
- {{end}}
@@ -115,6 +117,19 @@
+
+
+
+
+
    +
  • + ButterRobot {{if .Version}}v{{.Version}}{{else}}(development){{end}} +
  • +
+
+
+
+
diff --git a/internal/admin/templates/change_password.html b/internal/admin/templates/change_password.html new file mode 100644 index 0000000..eed3dc5 --- /dev/null +++ b/internal/admin/templates/change_password.html @@ -0,0 +1,30 @@ +{{define "content"}} +
+
+
+
+

Change Password

+
+
+
+
+ + +
+
+ + +
+
+ + +
+ +
+
+
+
+
+{{end}} \ No newline at end of file diff --git a/internal/admin/templates/channel_detail.html b/internal/admin/templates/channel_detail.html index 764d7b1..78909df 100644 --- a/internal/admin/templates/channel_detail.html +++ b/internal/admin/templates/channel_detail.html @@ -68,6 +68,10 @@ {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} + {{$plugin := index $.Plugins $pluginID}} + {{if $plugin.RequiresConfig}} + Configure + {{end}}
diff --git a/internal/admin/templates/channel_plugin_config.html b/internal/admin/templates/channel_plugin_config.html new file mode 100644 index 0000000..decf1a2 --- /dev/null +++ b/internal/admin/templates/channel_plugin_config.html @@ -0,0 +1,37 @@ +{{define "content"}} +
+
+
+
+

Configure Plugin: {{(index .Plugins .ChannelPlugin.PluginID).GetName}}

+
+
+ + + {{if eq .ChannelPlugin.PluginID "security.domainblock"}} +
+ + +
+ Enter comma-separated list of domains to block (e.g., example.com, evil.org). + Messages containing links to these domains will be blocked. +
+
+ {{else}} +
+ This plugin doesn't have specific configuration fields implemented yet. +
+ {{end}} + + + +
+
+
+
+{{end}} diff --git a/internal/admin/templates/channel_plugins_list.html b/internal/admin/templates/channel_plugins_list.html index b57c60e..485150b 100644 --- a/internal/admin/templates/channel_plugins_list.html +++ b/internal/admin/templates/channel_plugins_list.html @@ -38,6 +38,10 @@ {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} + {{$plugin := index $.Plugins $pluginID}} + {{if $plugin.ConfigRequired}} + Configure + {{end}}
@@ -90,4 +94,4 @@
-{{end}} \ No newline at end of file +{{end}} diff --git a/internal/app/app.go b/internal/app/app.go index 8d4ffcd..4614325 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "os/signal" + "runtime/debug" "strings" "syscall" "time" @@ -16,21 +17,26 @@ import ( "git.nakama.town/fmartingr/butterrobot/internal/admin" "git.nakama.town/fmartingr/butterrobot/internal/config" "git.nakama.town/fmartingr/butterrobot/internal/db" + "git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/platform" "git.nakama.town/fmartingr/butterrobot/internal/plugin" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock" "git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" "git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder" + "git.nakama.town/fmartingr/butterrobot/internal/plugin/social" "git.nakama.town/fmartingr/butterrobot/internal/queue" ) // App represents the application type App struct { - config *config.Config - logger *slog.Logger - db *db.Database - router *http.ServeMux - queue *queue.Queue - admin *admin.Admin + config *config.Config + logger *slog.Logger + db *db.Database + router *http.ServeMux + queue *queue.Queue + admin *admin.Admin + version string } // New creates a new App instance @@ -47,16 +53,24 @@ func New(cfg *config.Config, logger *slog.Logger) (*App, error) { // Initialize message queue messageQueue := queue.New(logger) + // Get version information + version := "" + info, ok := debug.ReadBuildInfo() + if ok { + version = info.Main.Version + } + // Initialize admin interface - adminInterface := admin.New(cfg, database) + adminInterface := admin.New(cfg, database, version) return &App{ - config: cfg, - logger: logger, - db: database, - router: router, - queue: messageQueue, - admin: adminInterface, + config: cfg, + logger: logger, + db: database, + router: router, + queue: messageQueue, + admin: adminInterface, + version: version, }, nil } @@ -72,6 +86,10 @@ func (a *App) Run() error { plugin.Register(fun.NewCoin()) plugin.Register(fun.NewDice()) plugin.Register(fun.NewLoquito()) + plugin.Register(social.NewTwitterExpander()) + plugin.Register(social.NewInstagramExpander()) + plugin.Register(reminder.New(a.db)) + plugin.Register(domainblock.New()) // Initialize routes a.initializeRoutes() @@ -79,6 +97,9 @@ func (a *App) Run() error { // Start message queue worker a.queue.Start(a.handleMessage) + // Start reminder scheduler + a.queue.StartReminderScheduler(a.handleReminder) + // Create server addr := fmt.Sprintf(":%s", a.config.Port) srv := &http.Server{ @@ -130,7 +151,9 @@ func (a *App) initializeRoutes() { a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]interface{}{}) + if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } }) // Platform webhook endpoints @@ -153,7 +176,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { if _, err := platform.Get(platformName); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}) + if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } return } @@ -162,7 +187,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}) + if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } return } @@ -178,7 +205,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { // Respond with success w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{}) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + a.logger.Error("Error encoding response", "error", err) + } } // extractPlatformName extracts the platform name from the URL path @@ -274,20 +303,110 @@ func (a *App) handleMessage(item queue.Item) { continue } - // Process message - responses := p.OnMessage(message, channelPlugin.Config) + // Process message and get actions + actions := p.OnMessage(message, channelPlugin.Config) - // Send responses + // Get platform for processing actions platform, err := platform.Get(item.Platform) if err != nil { a.logger.Error("Error getting platform", "error", err) continue } - for _, response := range responses { - if err := platform.SendMessage(response); err != nil { - a.logger.Error("Error sending message", "error", err) + // Process each action + for _, action := range actions { + switch action.Type { + case model.ActionSendMessage: + // Send a message + if action.Message != nil { + if err := platform.SendMessage(action.Message); err != nil { + a.logger.Error("Error sending message", "error", err) + } + } else { + a.logger.Error("Send message action with nil message") + } + + case model.ActionDeleteMessage: + // Delete a message using direct DeleteMessage call + if err := platform.DeleteMessage(action.Chat, action.MessageID); err != nil { + a.logger.Error("Error deleting message", "error", err, "message_id", action.MessageID) + } else { + a.logger.Info("Message deleted", "message_id", action.MessageID) + } + + default: + a.logger.Error("Unknown action type", "type", action.Type) } } } } + +// handleReminder handles reminder processing +func (a *App) handleReminder(reminder *model.Reminder) { + // When called with nil, it means we should check for pending reminders + if reminder == nil { + // Get pending reminders + reminders, err := a.db.GetPendingReminders() + if err != nil { + a.logger.Error("Error getting pending reminders", "error", err) + return + } + + // Process each reminder + for _, r := range reminders { + a.processReminder(r) + } + return + } + + // Otherwise, process the specific reminder + a.processReminder(reminder) +} + +// processReminder processes an individual reminder +func (a *App) processReminder(reminder *model.Reminder) { + a.logger.Info("Processing reminder", + "id", reminder.ID, + "platform", reminder.Platform, + "channel", reminder.ChannelID, + "trigger_at", reminder.TriggerAt, + ) + + // Get the platform handler + p, err := platform.Get(reminder.Platform) + if err != nil { + a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform) + return + } + + // Get the channel + channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID) + if err != nil { + a.logger.Error("Error getting channel for reminder", "error", err) + return + } + + // Create the reminder message + reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username) + + message := &model.Message{ + Text: reminderText, + Chat: reminder.ChannelID, + Channel: channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: reminder.ReplyToID, // Reply to the original message + } + + // Send the reminder message + if err := p.SendMessage(message); err != nil { + a.logger.Error("Error sending reminder", "error", err) + return + } + + // Mark the reminder as processed + if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil { + a.logger.Error("Error marking reminder as processed", "error", err) + } +} diff --git a/internal/db/db.go b/internal/db/db.go index e288bb3..0da285e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,14 +1,16 @@ package db import ( - "crypto/sha256" "database/sql" - "encoding/hex" "encoding/json" "errors" + "fmt" + "time" + "golang.org/x/crypto/bcrypt" _ "modernc.org/sqlite" + "git.nakama.town/fmartingr/butterrobot/internal/migration" "git.nakama.town/fmartingr/butterrobot/internal/model" ) @@ -232,7 +234,11 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var plugins []*model.ChannelPlugin @@ -380,6 +386,24 @@ func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error { return err } +// UpdateChannelPluginConfig updates a channel plugin's configuration +func (d *Database) UpdateChannelPluginConfig(id int64, config map[string]interface{}) error { + // Convert config to JSON + configJSON, err := json.Marshal(config) + if err != nil { + return err + } + + query := ` + UPDATE channel_plugin + SET config = ? + WHERE id = ? + ` + + _, err = d.db.Exec(query, string(configJSON), id) + return err +} + // DeleteChannelPlugin deletes a channel plugin func (d *Database) DeleteChannelPlugin(id int64) error { query := ` @@ -413,7 +437,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() var channels []*model.Channel @@ -452,10 +480,9 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { continue // Skip this channel if plugins can't be retrieved } - if plugins != nil { - for _, plugin := range plugins { - channel.Plugins[plugin.PluginID] = plugin - } + // Add plugins to channel + for _, plugin := range plugins { + channel.Plugins[plugin.PluginID] = plugin } channels = append(channels, channel) @@ -505,7 +532,10 @@ func (d *Database) GetUserByID(id int64) (*model.User, error) { // CreateUser creates a new user func (d *Database) CreateUser(username, password string) (*model.User, error) { // Hash password - hashedPassword := hashPassword(password) + hashedPassword, err := hashPassword(password) + if err != nil { + return nil, err + } // Insert user query := ` @@ -555,9 +585,9 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err return nil, err } - // Check password - hashedPassword := hashPassword(password) - if dbPassword != hashedPassword { + // Check password with bcrypt + err = bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password)) + if err != nil { return nil, errors.New("invalid credentials") } @@ -568,74 +598,198 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err }, nil } +// UpdateUserPassword updates a user's password +func (d *Database) UpdateUserPassword(userID int64, newPassword string) error { + // Hash the new password + hashedPassword, err := hashPassword(newPassword) + if err != nil { + return err + } + + // Update the user's password + query := ` + UPDATE users + SET password = ? + WHERE id = ? + ` + + _, err = d.db.Exec(query, hashedPassword, userID) + return err +} + +// CreateReminder creates a new reminder +func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) { + query := ` + INSERT INTO reminders ( + platform, channel_id, message_id, reply_to_id, + user_id, username, created_at, trigger_at, + content, processed + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0) + ` + + createdAt := time.Now() + result, err := d.db.Exec( + query, + platform, channelID, messageID, replyToID, + userID, username, createdAt, triggerAt, + content, + ) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + return &model.Reminder{ + ID: id, + Platform: platform, + ChannelID: channelID, + MessageID: messageID, + ReplyToID: replyToID, + UserID: userID, + Username: username, + CreatedAt: createdAt, + TriggerAt: triggerAt, + Content: content, + Processed: false, + }, nil +} + +// GetPendingReminders gets all pending reminders that need to be processed +func (d *Database) GetPendingReminders() ([]*model.Reminder, error) { + query := ` + SELECT id, platform, channel_id, message_id, reply_to_id, + user_id, username, created_at, trigger_at, content, processed + FROM reminders + WHERE processed = 0 AND trigger_at <= ? + ` + + rows, err := d.db.Query(query, time.Now()) + if err != nil { + return nil, err + } + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() + + var reminders []*model.Reminder + + for rows.Next() { + var ( + id int64 + platform, channelID, messageID, replyToID string + userID, username, content string + createdAt, triggerAt time.Time + processed bool + ) + + if err := rows.Scan( + &id, &platform, &channelID, &messageID, &replyToID, + &userID, &username, &createdAt, &triggerAt, &content, &processed, + ); err != nil { + return nil, err + } + + reminder := &model.Reminder{ + ID: id, + Platform: platform, + ChannelID: channelID, + MessageID: messageID, + ReplyToID: replyToID, + UserID: userID, + Username: username, + CreatedAt: createdAt, + TriggerAt: triggerAt, + Content: content, + Processed: processed, + } + + reminders = append(reminders, reminder) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + if len(reminders) == 0 { + return make([]*model.Reminder, 0), nil + } + + return reminders, nil +} + +// MarkReminderAsProcessed marks a reminder as processed +func (d *Database) MarkReminderAsProcessed(id int64) error { + query := ` + UPDATE reminders + SET processed = 1 + WHERE id = ? + ` + + _, err := d.db.Exec(query, id) + return err +} + // Helper function to hash password -func hashPassword(password string) string { - // In a real implementation, use a proper password hashing library like bcrypt - // This is a simplified version for demonstration - hasher := sha256.New() - hasher.Write([]byte(password)) - return hex.EncodeToString(hasher.Sum(nil)) +func hashPassword(password string) (string, error) { + // Use bcrypt for secure password hashing + // The cost parameter is the computational cost, higher is more secure but slower + // Recommended minimum is 12 + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), 12) + if err != nil { + return "", err + } + return string(hashedBytes), nil } // Initialize database tables func initDatabase(db *sql.DB) error { - // Create channels table - _, err := db.Exec(` - CREATE TABLE IF NOT EXISTS channels ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - platform TEXT NOT NULL, - platform_channel_id TEXT NOT NULL, - enabled BOOLEAN NOT NULL DEFAULT 0, - channel_raw TEXT NOT NULL, - UNIQUE(platform, platform_channel_id) - ) - `) - if err != nil { - return err + // Ensure migration table exists + if err := migration.EnsureMigrationTable(db); err != nil { + return fmt.Errorf("failed to create migration table: %w", err) } - // Create channel_plugin table - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS channel_plugin ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - channel_id INTEGER NOT NULL, - plugin_id TEXT NOT NULL, - enabled BOOLEAN NOT NULL DEFAULT 0, - config TEXT NOT NULL DEFAULT '{}', - UNIQUE(channel_id, plugin_id), - FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE - ) - `) + // Get applied migrations + applied, err := migration.GetAppliedMigrations(db) if err != nil { - return err + return fmt.Errorf("failed to get applied migrations: %w", err) } - // Create users table - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - password TEXT NOT NULL - ) - `) - if err != nil { - return err + // Get all migration versions + allMigrations := make([]int, 0, len(migration.Migrations)) + for version := range migration.Migrations { + allMigrations = append(allMigrations, version) } - // Create default admin user if it doesn't exist - var count int - err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) - if err != nil { - return err + // Create a map of applied migrations for quick lookup + appliedMap := make(map[int]bool) + for _, version := range applied { + appliedMap[version] = true } - if count == 0 { - hashedPassword := hashPassword("admin") - _, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword) - if err != nil { - return err + // Count pending migrations + pendingCount := 0 + for _, version := range allMigrations { + if !appliedMap[version] { + pendingCount++ } } + // Run migrations if needed + if pendingCount > 0 { + fmt.Printf("Running %d pending database migrations...\n", pendingCount) + if err := migration.Migrate(db); err != nil { + return fmt.Errorf("migration failed: %w", err) + } + fmt.Println("Database migrations completed successfully.") + } else { + fmt.Println("Database schema is up to date.") + } + return nil } diff --git a/internal/migration/migration.go b/internal/migration/migration.go new file mode 100644 index 0000000..63da5d8 --- /dev/null +++ b/internal/migration/migration.go @@ -0,0 +1,223 @@ +package migration + +import ( + "database/sql" + "fmt" + "sort" + "time" +) + +// Migration represents a database migration +type Migration struct { + Version int + Description string + Up func(db *sql.DB) error + Down func(db *sql.DB) error +} + +// Migrations is a collection of registered migrations +var Migrations = make(map[int]Migration) + +// Register adds a migration to the list of available migrations +func Register(version int, description string, up, down func(db *sql.DB) error) { + if _, exists := Migrations[version]; exists { + panic(fmt.Sprintf("migration version %d already exists", version)) + } + + Migrations[version] = Migration{ + Version: version, + Description: description, + Up: up, + Down: down, + } +} + +// EnsureMigrationTable creates the migration table if it doesn't exist +func EnsureMigrationTable(db *sql.DB) error { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP NOT NULL + ) + `) + return err +} + +// GetAppliedMigrations returns a list of applied migration versions +func GetAppliedMigrations(db *sql.DB) ([]int, error) { + rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") + if err != nil { + return nil, err + } + defer func() { + if err := rows.Close(); err != nil { + fmt.Printf("Error closing rows: %v\n", err) + } + }() + + var versions []int + for rows.Next() { + var version int + if err := rows.Scan(&version); err != nil { + return nil, err + } + versions = append(versions, version) + } + + return versions, rows.Err() +} + +// IsApplied checks if a migration version has been applied +func IsApplied(db *sql.DB, version int) (bool, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +// MarkAsApplied marks a migration as applied +func MarkAsApplied(db *sql.DB, version int) error { + _, err := db.Exec( + "INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", + version, time.Now(), + ) + return err +} + +// RemoveApplied removes a migration from the applied list +func RemoveApplied(db *sql.DB, version int) error { + _, err := db.Exec("DELETE FROM schema_migrations WHERE version = ?", version) + return err +} + +// Migrate runs pending migrations up to the latest version +func Migrate(db *sql.DB) error { + // Ensure migration table exists + if err := EnsureMigrationTable(db); err != nil { + return fmt.Errorf("failed to create migration table: %w", err) + } + + // Get applied migrations + applied, err := GetAppliedMigrations(db) + if err != nil { + return fmt.Errorf("failed to get applied migrations: %w", err) + } + + // Create a map of applied migrations for quick lookup + appliedMap := make(map[int]bool) + for _, version := range applied { + appliedMap[version] = true + } + + // Get all migration versions and sort them + var versions []int + for version := range Migrations { + versions = append(versions, version) + } + sort.Ints(versions) + + // Apply each pending migration + for _, version := range versions { + if !appliedMap[version] { + migration := Migrations[version] + fmt.Printf("Applying migration %d: %s...\n", version, migration.Description) + + // Start transaction for the migration + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction for migration %d: %w", version, err) + } + + // Apply the migration + if err := migration.Up(db); err != nil { + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } + return fmt.Errorf("failed to apply migration %d: %w", version, err) + } + + // Mark as applied + if _, err := tx.Exec( + "INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", + version, time.Now(), + ); err != nil { + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } + return fmt.Errorf("failed to mark migration %d as applied: %w", version, err) + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d: %w", version, err) + } + + fmt.Printf("Migration %d applied successfully\n", version) + } + } + + return nil +} + +// MigrateDown rolls back migrations down to the specified version +// If version is -1, it will roll back all migrations +func MigrateDown(db *sql.DB, targetVersion int) error { + // Ensure migration table exists + if err := EnsureMigrationTable(db); err != nil { + return fmt.Errorf("failed to create migration table: %w", err) + } + + // Get applied migrations + applied, err := GetAppliedMigrations(db) + if err != nil { + return fmt.Errorf("failed to get applied migrations: %w", err) + } + + // Sort in descending order to roll back newest first + sort.Sort(sort.Reverse(sort.IntSlice(applied))) + + // Roll back each migration until target version + for _, version := range applied { + if targetVersion == -1 || version > targetVersion { + migration, exists := Migrations[version] + if !exists { + return fmt.Errorf("migration %d is applied but not found in codebase", version) + } + + fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description) + + // Start transaction for the rollback + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction for rollback %d: %w", version, err) + } + + // Apply the down migration + if err := migration.Down(db); err != nil { + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } + return fmt.Errorf("failed to roll back migration %d: %w", version, err) + } + + // Remove from applied list + if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil { + if err := tx.Rollback(); err != nil { + fmt.Printf("Error rolling back transaction: %v\n", err) + } + return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err) + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit rollback %d: %w", version, err) + } + + fmt.Printf("Migration %d rolled back successfully\n", version) + } + } + + return nil +} diff --git a/internal/migration/migrations.go b/internal/migration/migrations.go new file mode 100644 index 0000000..8db229b --- /dev/null +++ b/internal/migration/migrations.go @@ -0,0 +1,128 @@ +package migration + +import ( + "database/sql" + "golang.org/x/crypto/bcrypt" +) + +func init() { + // Register migrations + Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown) + Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown) +} + +// Initial schema creation with bcrypt passwords - version 1 +func migrateInitialSchemaUp(db *sql.DB) error { + // Create channels table + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS channels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + platform TEXT NOT NULL, + platform_channel_id TEXT NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT 0, + channel_raw TEXT NOT NULL, + UNIQUE(platform, platform_channel_id) + ) + `) + if err != nil { + return err + } + + // Create channel_plugin table + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS channel_plugin ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL, + plugin_id TEXT NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT 0, + config TEXT NOT NULL DEFAULT '{}', + UNIQUE(channel_id, plugin_id), + FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + + // Create users table with bcrypt passwords + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password TEXT NOT NULL + ) + `) + if err != nil { + return err + } + + // Create default admin user with bcrypt password + hashedPassword, err := bcrypt.GenerateFromPassword([]byte("admin"), 12) + if err != nil { + return err + } + + // Check if users table is empty before inserting + var count int + err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) + if err != nil { + return err + } + + if count == 0 { + _, err = db.Exec( + "INSERT INTO users (username, password) VALUES (?, ?)", + "admin", string(hashedPassword), + ) + if err != nil { + return err + } + } + + return nil +} + +func migrateInitialSchemaDown(db *sql.DB) error { + // Drop tables in reverse order of dependencies + _, err := db.Exec(`DROP TABLE IF EXISTS channel_plugin`) + if err != nil { + return err + } + + _, err = db.Exec(`DROP TABLE IF EXISTS channels`) + if err != nil { + return err + } + + _, err = db.Exec(`DROP TABLE IF EXISTS users`) + if err != nil { + return err + } + + return nil +} + +// Add reminders table - version 2 +func migrateRemindersUp(db *sql.DB) error { + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS reminders ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + platform TEXT NOT NULL, + channel_id TEXT NOT NULL, + message_id TEXT NOT NULL, + reply_to_id TEXT NOT NULL, + user_id TEXT NOT NULL, + username TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + trigger_at TIMESTAMP NOT NULL, + content TEXT NOT NULL, + processed BOOLEAN NOT NULL DEFAULT 0 + ) + `) + return err +} + +func migrateRemindersDown(db *sql.DB) error { + _, err := db.Exec(`DROP TABLE IF EXISTS reminders`) + return err +} diff --git a/internal/model/message.go b/internal/model/message.go index fe8c5e4..26ec5da 100644 --- a/internal/model/message.go +++ b/internal/model/message.go @@ -4,27 +4,47 @@ import ( "time" ) +// ActionType defines the type of action to perform +type ActionType string + +const ( + // ActionSendMessage is for sending a message to the chat + ActionSendMessage ActionType = "send_message" + // ActionDeleteMessage is for deleting a message from the chat + ActionDeleteMessage ActionType = "delete_message" +) + +// MessageAction represents an action to be performed on the platform +type MessageAction struct { + Type ActionType + Message *Message // For send_message + MessageID string // For delete_message + Chat string // Chat where the action happens + Channel *Channel // Channel reference + Raw map[string]interface{} // Additional data for the action +} + // Message represents a chat message type Message struct { - Text string - Chat string - Channel *Channel - Author string - FromBot bool - Date time.Time - ID string - ReplyTo string - Raw map[string]interface{} + Text string + Chat string + Channel *Channel + Author string + FromBot bool + Date time.Time + ID string + ReplyTo string + Raw map[string]interface{} } // Channel represents a chat channel type Channel struct { - ID int64 - Platform string + ID int64 + Platform string PlatformChannelID string - ChannelRaw map[string]interface{} - Enabled bool - Plugins map[string]*ChannelPlugin + ChannelRaw map[string]interface{} + Enabled bool + Plugins map[string]*ChannelPlugin } // HasEnabledPlugin checks if a plugin is enabled for this channel @@ -40,18 +60,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool { func (c *Channel) ChannelName() string { // In a real implementation, this would use the platform-specific // ParseChannelNameFromRaw function - + // For simplicity, we'll just use the PlatformChannelID if we can't extract a name // Check if ChannelRaw has a name field if c.ChannelRaw == nil { return c.PlatformChannelID } - + // Check common name fields in ChannelRaw if name, ok := c.ChannelRaw["name"].(string); ok && name != "" { return name } - + // Check for nested objects like "chat" (used by Telegram) if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok { // Try different fields in order of preference @@ -65,7 +85,7 @@ func (c *Channel) ChannelName() string { return firstName } } - + return c.PlatformChannelID } @@ -75,7 +95,7 @@ type ChannelPlugin struct { ChannelID int64 PluginID string Enabled bool - Config map[string]interface{} + Config map[string]any } // User represents an admin user @@ -83,4 +103,19 @@ type User struct { ID int64 Username string Password string -} \ No newline at end of file +} + +// Reminder represents a scheduled reminder +type Reminder struct { + ID int64 + Platform string + ChannelID string + MessageID string + ReplyToID string + UserID string + Username string + CreatedAt time.Time + TriggerAt time.Time + Content string + Processed bool +} diff --git a/internal/model/platform.go b/internal/model/platform.go index 01318eb..7d49ad3 100644 --- a/internal/model/platform.go +++ b/internal/model/platform.go @@ -43,4 +43,7 @@ type Platform interface { // SendMessage sends a message through the platform SendMessage(msg *Message) error + + // DeleteMessage deletes a message from the platform + DeleteMessage(channel string, messageID string) error } diff --git a/internal/model/plugin.go b/internal/model/plugin.go index ffc3c2f..03e4f96 100644 --- a/internal/model/plugin.go +++ b/internal/model/plugin.go @@ -13,16 +13,16 @@ var ( type Plugin interface { // GetID returns the plugin ID GetID() string - + // GetName returns the plugin name GetName() string - + // GetHelp returns the plugin help text GetHelp() string - + // RequiresConfig indicates if the plugin requires configuration RequiresConfig() bool - - // OnMessage processes an incoming message and returns response messages - OnMessage(msg *Message, config map[string]interface{}) []*Message -} \ No newline at end of file + + // OnMessage processes an incoming message and returns platform actions + OnMessage(msg *Message, config map[string]interface{}) []*MessageAction +} diff --git a/internal/platform/slack/slack.go b/internal/platform/slack/slack.go index 3683ada..2ca7bef 100644 --- a/internal/platform/slack/slack.go +++ b/internal/platform/slack/slack.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "net/http" "strings" "time" @@ -37,11 +37,15 @@ func (s *SlackPlatform) Init(_ *config.Config) error { // ParseIncomingMessage parses an incoming Slack message func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) { // Read request body - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return nil, err } - defer r.Body.Close() + defer func() { + if err := r.Body.Close(); err != nil { + fmt.Printf("Error closing request body: %v\n", err) + } + }() // Parse JSON var requestData map[string]interface{} @@ -163,6 +167,12 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { return errors.New("bot token not configured") } + // Check for delete message action + if msg.Raw != nil && msg.Raw["action"] == "delete" { + // This is a request to delete a message + return s.deleteMessage(msg) + } + // Prepare payload payload := map[string]interface{}{ "channel": msg.Chat, @@ -194,7 +204,11 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { if err != nil { return err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() // Check response if resp.StatusCode != http.StatusOK { @@ -204,6 +218,63 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { return nil } +// DeleteMessage deletes a message on Slack +func (s *SlackPlatform) DeleteMessage(channel string, messageID string) error { + // Prepare payload for chat.delete API + payload := map[string]interface{}{ + "channel": channel, + "ts": messageID, // In Slack, the ts (timestamp) is the message ID + } + + // Convert payload to JSON + data, err := json.Marshal(payload) + if err != nil { + return err + } + + // Send HTTP request to chat.delete endpoint + req, err := http.NewRequest("POST", "https://slack.com/api/chat.delete", strings.NewReader(string(data))) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.config.BotOAuthAccessToken)) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Printf("Error closing response body: %v\n", err) + } + }() + + // Check response + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("slack API error: %d - %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// deleteMessage is a legacy method that uses the Raw message approach +func (s *SlackPlatform) deleteMessage(msg *model.Message) error { + // Get message ID to delete + messageID, ok := msg.Raw["message_id"] + if !ok { + return fmt.Errorf("no message ID provided for deletion") + } + + // Convert to string if needed + messageIDStr := fmt.Sprintf("%v", messageID) + + return s.DeleteMessage(msg.Chat, messageIDStr) +} + // Helper function to parse int64 func parseInt64(s string) (int64, error) { var n int64 diff --git a/internal/platform/telegram/telegram.go b/internal/platform/telegram/telegram.go index a9ff2db..8da4995 100644 --- a/internal/platform/telegram/telegram.go +++ b/internal/platform/telegram/telegram.go @@ -62,7 +62,11 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error { t.log.Error("Failed to set webhook", "error", err) return fmt.Errorf("failed to set webhook: %w", err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) @@ -85,7 +89,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message t.log.Error("Failed to read request body", "error", err) return nil, err } - defer r.Body.Close() + defer func() { + if err := r.Body.Close(); err != nil { + t.log.Error("Error closing request body", "error", err) + } + }() // Parse JSON var update struct { @@ -103,8 +111,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message Title string `json:"title,omitempty"` Username string `json:"username,omitempty"` } `json:"chat"` - Date int `json:"date"` - Text string `json:"text"` + Date int `json:"date"` + Text string `json:"text"` + ReplyToMessage struct { + MessageID int `json:"message_id"` + } `json:"reply_to_message"` } `json:"message"` } @@ -128,6 +139,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message FromBot: update.Message.From.IsBot, Date: time.Unix(int64(update.Message.Date), 0), ID: strconv.Itoa(update.Message.MessageID), + ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID), Raw: raw, } @@ -205,6 +217,13 @@ func (t *TelegramPlatform) ParseChannelFromMessage(body []byte) (map[string]any, // SendMessage sends a message to Telegram func (t *TelegramPlatform) SendMessage(msg *model.Message) error { + // Check for delete message action (legacy method) + if msg.Raw != nil && msg.Raw["action"] == "delete" { + // This is a request to delete a message using the legacy method + return t.deleteMessage(msg) + } + + // Regular message sending // Convert chat ID to int64 chatID, err := strconv.ParseInt(msg.Chat, 10, 64) if err != nil { @@ -247,7 +266,11 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { t.log.Error("Failed to send message", "error", err) return err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() // Check response if resp.StatusCode != http.StatusOK { @@ -259,4 +282,89 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { t.log.Debug("Message sent successfully") return nil -} \ No newline at end of file +} + +// DeleteMessage deletes a message on Telegram +func (t *TelegramPlatform) DeleteMessage(channel string, messageID string) error { + // Convert chat ID to int64 + chatID, err := strconv.ParseInt(channel, 10, 64) + if err != nil { + t.log.Error("Invalid chat ID for message deletion", "chat_id", channel, "error", err) + return err + } + + // Convert message ID to integer + msgID, err := strconv.Atoi(messageID) + if err != nil { + t.log.Error("Invalid message ID for deletion", "message_id", messageID, "error", err) + return err + } + + // Prepare payload for deleteMessage API + payload := map[string]interface{}{ + "chat_id": chatID, + "message_id": msgID, + } + + t.log.Debug("Deleting message on Telegram", "chat_id", chatID, "message_id", msgID) + + // Convert payload to JSON + data, err := json.Marshal(payload) + if err != nil { + t.log.Error("Failed to marshal delete message payload", "error", err) + return err + } + + // Send HTTP request to deleteMessage endpoint + resp, err := http.Post( + t.apiURL+"/deleteMessage", + "application/json", + bytes.NewBuffer(data), + ) + if err != nil { + t.log.Error("Failed to delete message", "error", err) + return err + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.log.Error("Error closing response body", "error", err) + } + }() + + // Check response + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + errMsg := string(bodyBytes) + t.log.Error("Telegram API error when deleting message", "status", resp.StatusCode, "response", errMsg) + return fmt.Errorf("telegram API error when deleting message: %d - %s", resp.StatusCode, errMsg) + } + + t.log.Debug("Message deleted successfully") + return nil +} + +// deleteMessage is a legacy method that uses the Raw message approach +func (t *TelegramPlatform) deleteMessage(msg *model.Message) error { + // Get message ID to delete + messageIDInterface, ok := msg.Raw["message_id"] + if !ok { + t.log.Error("No message ID provided for deletion") + return fmt.Errorf("no message ID provided for deletion") + } + + // Convert message ID to string + var messageIDStr string + switch v := messageIDInterface.(type) { + case string: + messageIDStr = v + case int: + messageIDStr = strconv.Itoa(v) + case float64: + messageIDStr = strconv.Itoa(int(v)) + default: + t.log.Error("Invalid message ID type for deletion", "type", fmt.Sprintf("%T", messageIDInterface)) + return fmt.Errorf("invalid message ID type for deletion") + } + + return t.DeleteMessage(msg.Chat, messageIDStr) +} diff --git a/internal/plugin/domainblock/domainblock.go b/internal/plugin/domainblock/domainblock.go new file mode 100644 index 0000000..5a44c49 --- /dev/null +++ b/internal/plugin/domainblock/domainblock.go @@ -0,0 +1,132 @@ +package domainblock + +import ( + "fmt" + "net/url" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains +type DomainBlockPlugin struct { + plugin.BasePlugin +} + +// Debug helper to check if RequiresConfig is working +func (p *DomainBlockPlugin) RequiresConfig() bool { + return true +} + +// New creates a new DomainBlockPlugin instance +func New() *DomainBlockPlugin { + return &DomainBlockPlugin{ + BasePlugin: plugin.BasePlugin{ + ID: "security.domainblock", + Name: "Domain Blocker", + Help: "Blocks messages containing links from configured domains", + ConfigRequired: true, + }, + } +} + +// extractDomains extracts domains from a message text +func extractDomains(text string) []string { + // URL regex pattern + urlPattern := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) + matches := urlPattern.FindAllStringSubmatch(text, -1) + + domains := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + + // Try to parse the URL to extract the domain + urlStr := match[0] + parsedURL, err := url.Parse(urlStr) + if err != nil { + continue + } + + // Extract the domain (host) from the URL + domain := parsedURL.Host + // Remove port if present + if i := strings.IndexByte(domain, ':'); i >= 0 { + domain = domain[:i] + } + + domains = append(domains, strings.ToLower(domain)) + } + + return domains +} + +// OnMessage processes incoming messages +func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Skip messages from bots + if msg.FromBot { + return nil + } + + // Get blocked domains from config + blockedDomainsStr, ok := config["blocked_domains"].(string) + if !ok || blockedDomainsStr == "" { + return nil // No blocked domains configured + } + + // Split and clean blocked domains + blockedDomains := strings.Split(blockedDomainsStr, ",") + for i, domain := range blockedDomains { + blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + } + + // Extract domains from message + messageDomains := extractDomains(msg.Text) + if len(messageDomains) == 0 { + return nil // No domains in message + } + + // Check if any domains in the message are blocked + for _, msgDomain := range messageDomains { + for _, blockedDomain := range blockedDomains { + if blockedDomain == "" { + continue + } + + if strings.HasSuffix(msgDomain, blockedDomain) || msgDomain == blockedDomain { + // Domain is blocked, create actions + + // 1. Create a delete message action + deleteAction := &model.MessageAction{ + Type: model.ActionDeleteMessage, + MessageID: msg.ID, + Chat: msg.Chat, + Channel: msg.Channel, + } + + // 2. Create a notification message action + notificationMsg := &model.Message{ + Text: fmt.Sprintf("I don't like links from %s 🙈", blockedDomain), + Chat: msg.Chat, + Channel: msg.Channel, + } + + sendAction := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: notificationMsg, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{deleteAction, sendAction} + } + } + } + + return nil +} + +// Plugin is registered in app.go, not using init() diff --git a/internal/plugin/domainblock/domainblock_test.go b/internal/plugin/domainblock/domainblock_test.go new file mode 100644 index 0000000..69cd8b8 --- /dev/null +++ b/internal/plugin/domainblock/domainblock_test.go @@ -0,0 +1,140 @@ +package domainblock + +import ( + "testing" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +func TestExtractDomains(t *testing.T) { + tests := []struct { + name string + text string + expected []string + }{ + { + name: "No URLs", + text: "Hello, world!", + expected: []string{}, + }, + { + name: "Single URL", + text: "Check out https://example.com for more info", + expected: []string{"example.com"}, + }, + { + name: "Multiple URLs", + text: "Check out https://example.com and http://test.example.org for more info", + expected: []string{"example.com", "test.example.org"}, + }, + { + name: "URL with path", + text: "Check out https://example.com/path/to/resource", + expected: []string{"example.com"}, + }, + { + name: "URL with port", + text: "Check out https://example.com:8080/path/to/resource", + expected: []string{"example.com"}, + }, + { + name: "URL with subdomain", + text: "Check out https://sub.example.com", + expected: []string{"sub.example.com"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + domains := extractDomains(test.text) + + if len(domains) != len(test.expected) { + t.Errorf("Expected %d domains, got %d", len(test.expected), len(domains)) + return + } + + for i, domain := range domains { + if domain != test.expected[i] { + t.Errorf("Expected domain %s, got %s", test.expected[i], domain) + } + } + }) + } +} + +func TestOnMessage(t *testing.T) { + plugin := New() + + tests := []struct { + name string + text string + blockedDomains string + expectBlocked bool + }{ + { + name: "No blocked domains", + text: "Check out https://example.com", + blockedDomains: "", + expectBlocked: false, + }, + { + name: "No matching domain", + text: "Check out https://example.com", + blockedDomains: "bad.com, evil.org", + expectBlocked: false, + }, + { + name: "Matching domain", + text: "Check out https://example.com", + blockedDomains: "example.com, evil.org", + expectBlocked: true, + }, + { + name: "Matching subdomain", + text: "Check out https://sub.example.com", + blockedDomains: "example.com", + expectBlocked: true, + }, + { + name: "Multiple domains, one matching", + text: "Check out https://example.com and https://good.org", + blockedDomains: "bad.com, example.com", + expectBlocked: true, + }, + { + name: "Spaces in blocked domains list", + text: "Check out https://example.com", + blockedDomains: "bad.com, example.com , evil.org", + expectBlocked: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := &model.Message{ + Text: test.text, + Chat: "test-chat", + ID: "test-id", + Channel: &model.Channel{ + ID: 1, + }, + } + + config := map[string]interface{}{ + "blocked_domains": test.blockedDomains, + } + + responses := plugin.OnMessage(msg, config) + + if test.expectBlocked { + if responses == nil || len(responses) == 0 { + t.Errorf("Expected message to be blocked, but it wasn't") + } + } else { + if responses != nil && len(responses) > 0 { + t.Errorf("Expected message not to be blocked, but it was") + } + } + }) + } +} diff --git a/internal/plugin/fun/coin.go b/internal/plugin/fun/coin.go index 8e12a8d..bd083d1 100644 --- a/internal/plugin/fun/coin.go +++ b/internal/plugin/fun/coin.go @@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin { } // OnMessage handles incoming messages -func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { return nil } @@ -46,5 +46,12 @@ func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/fun/dice.go b/internal/plugin/fun/dice.go index 00fc7cc..8b13edb 100644 --- a/internal/plugin/fun/dice.go +++ b/internal/plugin/fun/dice.go @@ -32,7 +32,7 @@ func NewDice() *DicePlugin { } // OnMessage handles incoming messages -func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { return nil } @@ -62,7 +62,14 @@ func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } // rollDice parses a dice formula string and returns the result @@ -107,9 +114,10 @@ func (p *DicePlugin) rollDice(formula string) (int, error) { return 0, fmt.Errorf("invalid modifier") } - if matches[3] == "+" { + switch matches[3] { + case "+": total += modifier - } else if matches[3] == "-" { + case "-": total -= modifier } } diff --git a/internal/plugin/fun/loquito.go b/internal/plugin/fun/loquito.go index 7b0ea43..fef78bd 100644 --- a/internal/plugin/fun/loquito.go +++ b/internal/plugin/fun/loquito.go @@ -24,7 +24,7 @@ func NewLoquito() *LoquitoPlugin { } // OnMessage handles incoming messages -func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { return nil } @@ -36,5 +36,12 @@ func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interfac Channel: msg.Channel, } - return []*model.Message{response} + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/ping/ping.go b/internal/plugin/ping/ping.go index b09caaf..3dacf6f 100644 --- a/internal/plugin/ping/ping.go +++ b/internal/plugin/ping/ping.go @@ -24,11 +24,12 @@ func New() *PingPlugin { } // OnMessage handles incoming messages -func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { return nil } + // Create the response message response := &model.Message{ Text: "pong", Chat: msg.Chat, @@ -36,5 +37,13 @@ func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{} Channel: msg.Channel, } - return []*model.Message{response} + // Create an action to send the message + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} } diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go index 69da2c2..eb3789f 100644 --- a/internal/plugin/plugin.go +++ b/internal/plugin/plugin.go @@ -1,6 +1,7 @@ package plugin import ( + "maps" "sync" "git.nakama.town/fmartingr/butterrobot/internal/model" @@ -41,9 +42,7 @@ func GetAvailablePlugins() map[string]model.Plugin { // Create a copy to avoid race conditions result := make(map[string]model.Plugin, len(plugins)) - for id, plugin := range plugins { - result[id] = plugin - } + maps.Copy(result, plugins) return result } @@ -77,6 +76,6 @@ func (p *BasePlugin) RequiresConfig() bool { } // OnMessage is the default implementation that does nothing -func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { +func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { return nil } diff --git a/internal/plugin/reminder/reminder.go b/internal/plugin/reminder/reminder.go new file mode 100644 index 0000000..029c8d9 --- /dev/null +++ b/internal/plugin/reminder/reminder.go @@ -0,0 +1,200 @@ +package reminder + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// Duration regex patterns to match reminders +var ( + remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`) +) + +// ReminderCreator is an interface for creating reminders +type ReminderCreator interface { + CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) +} + +// Reminder is a plugin that sets reminders for messages +type Reminder struct { + plugin.BasePlugin + creator ReminderCreator +} + +// New creates a new Reminder plugin +func New(creator ReminderCreator) *Reminder { + return &Reminder{ + BasePlugin: plugin.BasePlugin{ + ID: "reminder.remindme", + Name: "Remind Me", + Help: "Reply to a message with `!remindme ` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).", + ConfigRequired: false, + }, + creator: creator, + } +} + +// OnMessage processes incoming messages +func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Only process replies to messages + if msg.ReplyTo == "" { + return nil + } + + // Check if the message is a reminder command + match := remindMePattern.FindStringSubmatch(msg.Text) + if match == nil { + return nil + } + + // Parse the duration + amount, err := strconv.Atoi(match[1]) + if err != nil { + errorMsg := &model.Message{ + Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: errorMsg, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Calculate the trigger time + var duration time.Duration + unit := match[2] + switch strings.ToLower(unit) { + case "y": + duration = time.Duration(amount) * 365 * 24 * time.Hour + case "mo": + duration = time.Duration(amount) * 30 * 24 * time.Hour + case "d": + duration = time.Duration(amount) * 24 * time.Hour + case "h": + duration = time.Duration(amount) * time.Hour + case "m": + duration = time.Duration(amount) * time.Minute + case "s": + duration = time.Duration(amount) * time.Second + default: + errorMsg := &model.Message{ + Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: errorMsg, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + triggerAt := time.Now().Add(duration) + + // Determine the username for the reminder + username := msg.Author + if username == "" { + // Try to extract username from message raw data + if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok { + if name, ok := authorData["username"].(string); ok { + username = name + } else if name, ok := authorData["name"].(string); ok { + username = name + } + } + } + + // Create the reminder + _, err = r.creator.CreateReminder( + msg.Channel.Platform, + msg.Chat, + msg.ID, + msg.ReplyTo, + msg.Author, + username, + "", // No additional content for now + triggerAt, + ) + + if err != nil { + errorMsg := &model.Message{ + Text: fmt.Sprintf("Failed to create reminder: %v", err), + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: errorMsg, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } + } + + // Format the acknowledgment message + var confirmText string + switch strings.ToLower(unit) { + case "y": + confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04")) + case "mo": + confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04")) + case "d": + confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04")) + case "h": + confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04")) + case "m": + confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04")) + case "s": + confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount) + } + + // Create confirmation message + confirmMsg := &model.Message{ + Text: confirmText, + Chat: msg.Chat, + Channel: msg.Channel, + Author: "bot", + FromBot: true, + Date: time.Now(), + ReplyTo: msg.ID, + } + + return []*model.MessageAction{ + { + Type: model.ActionSendMessage, + Message: confirmMsg, + Chat: msg.Chat, + Channel: msg.Channel, + }, + } +} diff --git a/internal/plugin/reminder/reminder_test.go b/internal/plugin/reminder/reminder_test.go new file mode 100644 index 0000000..8e611ce --- /dev/null +++ b/internal/plugin/reminder/reminder_test.go @@ -0,0 +1,175 @@ +package reminder + +import ( + "testing" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" +) + +// MockCreator is a mock implementation of ReminderCreator for testing +type MockCreator struct { + reminders []*model.Reminder +} + +func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) { + reminder := &model.Reminder{ + ID: int64(len(m.reminders) + 1), + Platform: platform, + ChannelID: channelID, + MessageID: messageID, + ReplyToID: replyToID, + UserID: userID, + Username: username, + Content: content, + TriggerAt: triggerAt, + } + m.reminders = append(m.reminders, reminder) + return reminder, nil +} + +func TestReminderOnMessage(t *testing.T) { + creator := &MockCreator{reminders: make([]*model.Reminder, 0)} + plugin := New(creator) + + tests := []struct { + name string + message *model.Message + expectResponse bool + expectReminder bool + }{ + { + name: "Valid reminder command - years", + message: &model.Message{ + Text: "!remindme 1y", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Valid reminder command - months", + message: &model.Message{ + Text: "!remindme 3mo", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Valid reminder command - days", + message: &model.Message{ + Text: "!remindme 2d", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Valid reminder command - hours", + message: &model.Message{ + Text: "!remindme 5h", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Valid reminder command - minutes", + message: &model.Message{ + Text: "!remindme 30m", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Valid reminder command - seconds", + message: &model.Message{ + Text: "!remindme 60s", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: true, + expectReminder: true, + }, + { + name: "Not a reply", + message: &model.Message{ + Text: "!remindme 2d", + ReplyTo: "", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: false, + expectReminder: false, + }, + { + name: "Not a reminder command", + message: &model.Message{ + Text: "hello world", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: false, + expectReminder: false, + }, + { + name: "Invalid duration format", + message: &model.Message{ + Text: "!remindme abc", + ReplyTo: "original-message-id", + Author: "testuser", + Channel: &model.Channel{Platform: "test"}, + }, + expectResponse: false, + expectReminder: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + initialCount := len(creator.reminders) + actions := plugin.OnMessage(tt.message, nil) + + if tt.expectResponse && len(actions) == 0 { + t.Errorf("Expected response action, but got none") + } + + if !tt.expectResponse && len(actions) > 0 { + t.Errorf("Expected no actions, but got %d", len(actions)) + } + + // Verify action type is correct when actions are returned + if len(actions) > 0 { + if actions[0].Type != model.ActionSendMessage { + t.Errorf("Expected action type to be %s, but got %s", model.ActionSendMessage, actions[0].Type) + } + + if actions[0].Message == nil { + t.Errorf("Expected message in action to not be nil") + } + } + + if tt.expectReminder && len(creator.reminders) != initialCount+1 { + t.Errorf("Expected reminder to be created, but it wasn't") + } + + if !tt.expectReminder && len(creator.reminders) != initialCount { + t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount) + } + }) + } +} diff --git a/internal/plugin/social/instagram.go b/internal/plugin/social/instagram.go new file mode 100644 index 0000000..d05bd30 --- /dev/null +++ b/internal/plugin/social/instagram.go @@ -0,0 +1,81 @@ +package social + +import ( + "net/url" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// InstagramExpander transforms instagram.com links to ddinstagram.com links +type InstagramExpander struct { + plugin.BasePlugin +} + +// New creates a new InstagramExpander instance +func NewInstagramExpander() *InstagramExpander { + return &InstagramExpander{ + BasePlugin: plugin.BasePlugin{ + ID: "social.instagram", + Name: "Instagram Link Expander", + Help: "Automatically converts instagram.com links to ddinstagram.com links and removes tracking parameters", + }, + } +} + +// OnMessage handles incoming messages +func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Skip empty messages + if strings.TrimSpace(msg.Text) == "" { + return nil + } + + // Regex to match instagram.com links + // Match both http://instagram.com and https://instagram.com formats + // Also match www.instagram.com + instagramRegex := regexp.MustCompile(`https?://(www\.)?(instagram\.com)/[^\s]+`) + + // Check if the message contains an Instagram link + if !instagramRegex.MatchString(msg.Text) { + return nil + } + + // Replace instagram.com with ddinstagram.com in the message and clean query parameters + transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { + // Parse the URL + parsedURL, err := url.Parse(link) + if err != nil { + // If parsing fails, just do the simple replacement + link = strings.Replace(link, "instagram.com", "ddinstagram.com", 1) + return link + } + + // Change the host + parsedURL.Host = strings.Replace(parsedURL.Host, "instagram.com", "ddinstagram.com", 1) + + // Remove query parameters + parsedURL.RawQuery = "" + + // Return the cleaned URL + return parsedURL.String() + }) + + // Create response message + response := &model.Message{ + Text: transformed, + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} +} diff --git a/internal/plugin/social/twitter.go b/internal/plugin/social/twitter.go new file mode 100644 index 0000000..865f421 --- /dev/null +++ b/internal/plugin/social/twitter.go @@ -0,0 +1,86 @@ +package social + +import ( + "net/url" + "regexp" + "strings" + + "git.nakama.town/fmartingr/butterrobot/internal/model" + "git.nakama.town/fmartingr/butterrobot/internal/plugin" +) + +// TwitterExpander transforms twitter.com links to fxtwitter.com links +type TwitterExpander struct { + plugin.BasePlugin +} + +// New creates a new TwitterExpander instance +func NewTwitterExpander() *TwitterExpander { + return &TwitterExpander{ + BasePlugin: plugin.BasePlugin{ + ID: "social.twitter", + Name: "Twitter Link Expander", + Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters", + }, + } +} + +// OnMessage handles incoming messages +func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.MessageAction { + // Skip empty messages + if strings.TrimSpace(msg.Text) == "" { + return nil + } + + // Regex to match twitter.com links + // Match both http://twitter.com and https://twitter.com formats + // Also match www.twitter.com + twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`) + + // Check if the message contains a Twitter link + if !twitterRegex.MatchString(msg.Text) { + return nil + } + + // Replace twitter.com with fxtwitter.com in the message and clean query parameters + transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { + // Parse the URL + parsedURL, err := url.Parse(link) + if err != nil { + // If parsing fails, just do the simple replacement + link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1) + link = strings.Replace(link, "x.com", "fxtwitter.com", 1) + return link + } + + // Change the host + if strings.Contains(parsedURL.Host, "twitter.com") { + parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1) + } else if strings.Contains(parsedURL.Host, "x.com") { + parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1) + } + + // Remove query parameters + parsedURL.RawQuery = "" + + // Return the cleaned URL + return parsedURL.String() + }) + + // Create response message + response := &model.Message{ + Text: transformed, + Chat: msg.Chat, + ReplyTo: msg.ID, + Channel: msg.Channel, + } + + action := &model.MessageAction{ + Type: model.ActionSendMessage, + Message: response, + Chat: msg.Chat, + Channel: msg.Channel, + } + + return []*model.MessageAction{action} +} diff --git a/internal/queue/queue.go b/internal/queue/queue.go index 668bf60..692816e 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -3,6 +3,9 @@ package queue import ( "log/slog" "sync" + "time" + + "git.nakama.town/fmartingr/butterrobot/internal/model" ) // Item represents a queue item @@ -14,14 +17,19 @@ type Item struct { // HandlerFunc defines a function that processes queue items type HandlerFunc func(item Item) +// ReminderHandlerFunc defines a function that processes reminder items +type ReminderHandlerFunc func(reminder *model.Reminder) + // Queue represents a message queue type Queue struct { - items chan Item - wg sync.WaitGroup - quit chan struct{} - logger *slog.Logger - running bool - runMutex sync.Mutex + items chan Item + wg sync.WaitGroup + quit chan struct{} + logger *slog.Logger + running bool + runMutex sync.Mutex + reminderTicker *time.Ticker + reminderHandler ReminderHandlerFunc } // New creates a new Queue instance @@ -49,6 +57,24 @@ func (q *Queue) Start(handler HandlerFunc) { go q.worker(handler) } +// StartReminderScheduler starts the reminder scheduler +func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) { + q.runMutex.Lock() + defer q.runMutex.Unlock() + + if q.reminderTicker != nil { + return + } + + q.reminderHandler = handler + + // Check for reminders every minute + q.reminderTicker = time.NewTicker(1 * time.Minute) + + q.wg.Add(1) + go q.reminderWorker() +} + // Stop stops processing queue items func (q *Queue) Stop() { q.runMutex.Lock() @@ -59,6 +85,12 @@ func (q *Queue) Stop() { } q.running = false + + // Stop reminder ticker if it exists + if q.reminderTicker != nil { + q.reminderTicker.Stop() + } + close(q.quit) q.wg.Wait() } @@ -96,4 +128,34 @@ func (q *Queue) worker(handler HandlerFunc) { return } } -} \ No newline at end of file +} + +// reminderWorker processes reminder items on a schedule +func (q *Queue) reminderWorker() { + defer q.wg.Done() + + for { + select { + case <-q.reminderTicker.C: + // This is triggered every minute to check for pending reminders + q.logger.Debug("Checking for pending reminders") + + if q.reminderHandler != nil { + // The handler is responsible for fetching and processing reminders + func() { + defer func() { + if r := recover(); r != nil { + q.logger.Error("Panic in reminder worker", "error", r) + } + }() + + // Call the handler with a nil reminder to indicate it should check the database + q.reminderHandler(nil) + }() + } + case <-q.quit: + // Quit worker + return + } + } +}