Compare commits

..

No commits in common. "master" and "v0.2.2" have entirely different histories.

51 changed files with 151 additions and 4901 deletions

5
.gitignore vendored
View file

@ -5,12 +5,9 @@ __pycache__
*.cert *.cert
.env-local .env-local
.coverage .coverage
coverage.out
dist dist
bin bin
# Butterrobot # Butterrobot
*.sqlite* *.sqlite*
butterrobot.db* butterrobot.db
/butterrobot
*_test.db*

View file

@ -3,7 +3,7 @@ when:
- push - push
- pull_request - pull_request
branch: branch:
- master - main
steps: steps:
format: format:

View file

@ -1,6 +1,6 @@
when: when:
- event: tag - event: tag
branch: master branch: main
steps: steps:
- name: Release - name: Release

View file

@ -1,29 +0,0 @@
# Claude Code Instructions
## Plugin Development Workflow
When creating, modifying, or removing plugins:
1. **Always update the plugin documentation** in `docs/plugins.md` after any plugin changes
2. Ensure the documentation includes:
- Plugin name and category (Development, Fun and entertainment, Utility, Security, Social Media)
- Brief description of functionality
- Usage instructions with examples
- Any configuration requirements
3. **For plugins with configuration options:**
- Set `ConfigRequired: true` in the plugin's BasePlugin struct
- Add corresponding HTML form fields in `internal/admin/templates/channel_plugin_config.html`
- Use conditional template logic: `{{else if eq .ChannelPlugin.PluginID "plugin.id"}}`
- Include proper form labels, help text, and value binding
## Testing
**CRITICAL**: After making ANY changes to code files, you MUST run these commands in order:
1. **Format code**: `make format` - Format all code according to project standards
2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues")
3. **Run tests**: `make test` - Run all tests to ensure functionality works
4. Verify documentation accuracy
5. Ensure all examples work as described
**These commands are MANDATORY after every code change, no exceptions.**

View file

@ -1,6 +1,9 @@
# Butter Robot # Butter Robot
![Status badge](https://woodpecker.local.fmartingr.dev/api/badges/5/status.svg) | Stable | Master |
| --- | --- |
| ![Build stable tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20stable%20tag%20docker%20image/badge.svg?branch=stable) | ![Build latest tag docker image](https://git.nakama.town/fmartingr/butterrobot/workflows/Build%20latest%20tag%20docker%20image/badge.svg?branch=master) |
| ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=stable) | ![Test](https://git.nakama.town/fmartingr/butterrobot/workflows/Test/badge.svg?branch=master) |
Go framework to create bots for several platforms. Go framework to create bots for several platforms.
@ -10,7 +13,7 @@ Go framework to create bots for several platforms.
## Features ## Features
- Support for multiple chat platforms (Slack (untested!), Telegram) - Support for multiple chat platforms (Slack, Telegram)
- Plugin system for easy extension - Plugin system for easy extension
- Admin interface for managing channels and plugins - Admin interface for managing channels and plugins
- Message queue for asynchronous processing - Message queue for asynchronous processing

View file

@ -1,19 +1,6 @@
# Creating a Plugin # Creating a Plugin
## Plugin Categories ## Example
ButterRobot organizes plugins into different categories:
- **Development**: Utility plugins like `ping`
- **Fun**: Entertainment plugins like dice rolling, coin flipping
- **Social**: Social media related plugins like URL transformers/expanders
- **Security**: Moderation and protection features like domain blocking
When creating a new plugin, consider which category it fits into and place it in the appropriate directory.
## Plugin Examples
### Basic Example: Marco Polo
This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_: This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_:
@ -60,207 +47,6 @@ func (p *MarcoPlugin) OnMessage(msg *model.Message, config map[string]interface{
} }
``` ```
### Configuration-Enabled Plugin
This plugin requires configuration to be set in the admin interface. It demonstrates how to create plugins that need channel-specific configuration:
```go
package security
import (
"fmt"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains
type DomainBlockPlugin struct {
plugin.BasePlugin
}
// New creates a new DomainBlockPlugin instance
func New() *DomainBlockPlugin {
return &DomainBlockPlugin{
BasePlugin: plugin.BasePlugin{
ID: "security.domainblock",
Name: "Domain Blocker",
Help: "Blocks messages containing links from configured domains",
ConfigRequired: true, // Mark this plugin as requiring configuration
},
}
}
// OnMessage processes incoming messages
func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
// Get blocked domains from config
blockedDomainsStr, ok := config["blocked_domains"].(string)
if !ok || blockedDomainsStr == "" {
return nil // No blocked domains configured
}
// Split and clean blocked domains
blockedDomains := strings.Split(blockedDomainsStr, ",")
for i, domain := range blockedDomains {
blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain))
}
// Extract domains from message
urlRegex := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`)
matches := urlRegex.FindAllStringSubmatch(msg.Text, -1)
// Check if any extracted domains are blocked
for _, match := range matches {
if len(match) < 2 {
continue
}
domain := strings.ToLower(match[1])
for _, blockedDomain := range blockedDomains {
if blockedDomain == "" {
continue
}
if strings.HasSuffix(domain, blockedDomain) || domain == blockedDomain {
// Domain is blocked, create warning message
response := &model.Message{
Text: fmt.Sprintf("⚠️ Message contained a link to blocked domain: %s", blockedDomain),
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
return []*model.Message{response}
}
}
}
return nil
}
func init() {
plugin.Register(New())
}
```
### Advanced Example: URL Transformer
This more complex plugin transforms URLs, useful for improving media embedding in chat platforms:
```go
package social
import (
"net/url"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// TwitterExpander transforms twitter.com links to fxtwitter.com links
type TwitterExpander struct {
plugin.BasePlugin
}
// New creates a new TwitterExpander instance
func NewTwitter() *TwitterExpander {
return &TwitterExpander{
BasePlugin: plugin.BasePlugin{
ID: "social.twitter",
Name: "Twitter Link Expander",
Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters",
},
}
}
// OnMessage handles incoming messages
func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
// Skip empty messages
if strings.TrimSpace(msg.Text) == "" {
return nil
}
// Regex to match twitter.com links
twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`)
// Check if the message contains a Twitter link
if !twitterRegex.MatchString(msg.Text) {
return nil
}
// Transform the URL
transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
// Parse the URL
parsedURL, err := url.Parse(link)
if err != nil {
// If parsing fails, just do the simple replacement
link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1)
link = strings.Replace(link, "x.com", "fxtwitter.com", 1)
return link
}
// Change the host
if strings.Contains(parsedURL.Host, "twitter.com") {
parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1)
} else if strings.Contains(parsedURL.Host, "x.com") {
parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1)
}
// Remove query parameters
parsedURL.RawQuery = ""
// Return the cleaned URL
return parsedURL.String()
})
// Create response message
response := &model.Message{
Text: transformed,
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
return []*model.Message{response}
}
```
## Enabling Configuration for Plugins
To indicate that your plugin requires configuration:
1. Set `ConfigRequired: true` in the BasePlugin struct:
```go
BasePlugin: plugin.BasePlugin{
ID: "myplugin.id",
Name: "Plugin Name",
Help: "Help text",
ConfigRequired: true,
},
```
2. Access the configuration in the OnMessage method:
```go
func (p *MyPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
// Extract configuration values
configValue, ok := config["some_config_key"].(string)
if !ok || configValue == "" {
// Handle missing or empty configuration
return nil
}
// Use the configuration...
}
```
3. The admin interface will show a "Configure" button for plugins that require configuration.
## Registering Plugins
To use the plugin, register it in your application: To use the plugin, register it in your application:
```go ```go
@ -269,19 +55,8 @@ func (a *App) Run() error {
// ... // ...
// Register plugins // Register plugins
plugin.Register(ping.New()) // Development plugin plugin.Register(myplugin.New())
plugin.Register(fun.NewCoin()) // Fun plugin
plugin.Register(social.NewTwitter()) // Social media plugin
plugin.Register(myplugin.New()) // Your custom plugin
// ... // ...
} }
``` ```
Alternatively, you can register your plugin in its init() function:
```go
func init() {
plugin.Register(New())
}
```

View file

@ -9,19 +9,3 @@
- Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun)
- Dice: Put `!dice` and wathever roll you want to perform. - Dice: Put `!dice` and wathever roll you want to perform.
- Coin: Flip a coin and get heads or tails. - Coin: Flip a coin and get heads or tails.
- How Long To Beat: Get game completion times from HowLongToBeat.com using `!hltb <game name>`
### Utility
- Help: Shows available commands when you type `!help`. Lists all enabled plugins for the current channel organized by category with their descriptions and usage instructions.
- Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time.
- Search and Replace: Reply to any message with `s/search/replace/[flags]` to perform text substitution. Supports flags: `g` (global), `i` (case insensitive), `n` (regex pattern). Example: `s/hello/hi/gi` replaces all occurrences of "hello" with "hi" case-insensitively.
### Security
- Domain Blocker: Blocks messages containing links from specified domains. Configure it per channel with a comma-separated list of domains to block. When a message contains a link matching any of the blocked domains, the bot will notify that the message contained a blocked domain. This plugin requires configuration through the admin interface.
### Social Media
- Twitter Link Expander: Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: fxtwitter.com).
- Instagram Link Expander: Automatically converts instagram.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: ddinstagram.com).

2
go.mod
View file

@ -6,7 +6,6 @@ require (
github.com/gorilla/sessions v1.4.0 github.com/gorilla/sessions v1.4.0
golang.org/x/crypto v0.37.0 golang.org/x/crypto v0.37.0
golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
modernc.org/sqlite v1.37.0 modernc.org/sqlite v1.37.0
) )
@ -17,6 +16,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
golang.org/x/sys v0.32.0 // indirect golang.org/x/sys v0.32.0 // indirect
modernc.org/libc v1.63.0 // indirect modernc.org/libc v1.63.0 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View file

@ -16,7 +16,7 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
) )
//go:embed templates/*.html templates/plugins/*.html //go:embed templates/*.html
var templateFS embed.FS var templateFS embed.FS
const ( const (
@ -46,7 +46,6 @@ type TemplateData struct {
Channels []*model.Channel Channels []*model.Channel
Channel *model.Channel Channel *model.Channel
ChannelPlugin *model.ChannelPlugin ChannelPlugin *model.ChannelPlugin
Version string
} }
// Admin represents the admin interface // Admin represents the admin interface
@ -56,11 +55,10 @@ type Admin struct {
store *sessions.CookieStore store *sessions.CookieStore
templates map[string]*template.Template templates map[string]*template.Template
baseTemplate *template.Template baseTemplate *template.Template
version string
} }
// New creates a new Admin instance // New creates a new Admin instance
func New(cfg *config.Config, database *db.Database, version string) *Admin { func New(cfg *config.Config, database *db.Database) *Admin {
// Create session store with appropriate options // Create session store with appropriate options
store := sessions.NewCookieStore([]byte(cfg.SecretKey)) store := sessions.NewCookieStore([]byte(cfg.SecretKey))
store.Options = &sessions.Options{ store.Options = &sessions.Options{
@ -90,7 +88,7 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
} }
// Parse and register all templates // Parse and register all templates
mainTemplateFiles := []string{ templateFiles := []string{
"index.html", "index.html",
"login.html", "login.html",
"change_password.html", "change_password.html",
@ -98,48 +96,27 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
"channel_detail.html", "channel_detail.html",
"plugin_list.html", "plugin_list.html",
"channel_plugins_list.html", "channel_plugins_list.html",
"channel_plugin_config.html",
} }
pluginTemplateFiles := []string{ for _, tf := range templateFiles {
"plugins/security.domainblock.html",
"plugins/social.instagram.html",
"plugins/social.twitter.html",
}
for _, tf := range mainTemplateFiles {
// Read template content from embedded filesystem // Read template content from embedded filesystem
content, err := templateFS.ReadFile("templates/" + tf) content, err := templateFS.ReadFile("templates/" + tf)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Create a clone of the base template // Create a clone of the base template
t, err := baseTemplate.Clone() t, err := baseTemplate.Clone()
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Parse the template content // Parse the template content
t, err = t.Parse(string(content)) t, err = t.Parse(string(content))
if err != nil { if err != nil {
panic(err) panic(err)
} }
// If this is the channel_plugin_config template, also parse plugin templates
if tf == "channel_plugin_config.html" {
for _, pluginTf := range pluginTemplateFiles {
pluginContent, err := templateFS.ReadFile("templates/" + pluginTf)
if err != nil {
panic(err)
}
t, err = t.Parse(string(pluginContent))
if err != nil {
panic(err)
}
}
}
templates[tf] = t templates[tf] = t
} }
@ -149,7 +126,6 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
store: store, store: store,
templates: templates, templates: templates,
baseTemplate: baseTemplate, baseTemplate: baseTemplate,
version: version,
} }
} }
@ -164,7 +140,6 @@ func (a *Admin) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/admin/channels", a.handleChannelList) mux.HandleFunc("/admin/channels", a.handleChannelList)
mux.HandleFunc("/admin/channels/", a.handleChannelDetail) mux.HandleFunc("/admin/channels/", a.handleChannelDetail)
mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList) mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList)
mux.HandleFunc("/admin/channelplugins/config/", a.handleChannelPluginConfig)
mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete) mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete)
} }
@ -216,7 +191,7 @@ func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string,
} }
// Map internal categories to Bootstrap alert classes // Map internal categories to Bootstrap alert classes
var alertClass string alertClass := category
switch category { switch category {
case "success": case "success":
alertClass = "success" alertClass = "success"
@ -271,6 +246,17 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag
return messages return messages
} }
// requireLogin middleware checks if the user is logged in
func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !a.isLoggedIn(r) {
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
return
}
next(w, r)
}
}
// render renders a template with the given data // render renders a template with the given data
func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) { func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) {
// Add current user data // Add current user data
@ -278,7 +264,6 @@ func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName stri
data.LoggedIn = a.isLoggedIn(r) data.LoggedIn = a.isLoggedIn(r)
data.Path = r.URL.Path data.Path = r.URL.Path
data.Flash = a.getFlashes(w, r) data.Flash = a.getFlashes(w, r)
data.Version = a.version
// Get template // Get template
tmpl, ok := a.templates[templateName] tmpl, ok := a.templates[templateName]
@ -345,10 +330,7 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) {
// Set session expiration // Set session expiration
session.Options.MaxAge = 3600 * 24 * 7 // 1 week session.Options.MaxAge = 3600 * 24 * 7 // 1 week
err = session.Save(r, w) session.Save(r, w)
if err != nil {
fmt.Printf("Error saving session: %v\n", err)
}
a.addFlash(w, r, "You were logged in", "success") a.addFlash(w, r, "You were logged in", "success")
@ -376,7 +358,7 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/admin/login", http.StatusSeeOther) http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
return return
} }
session.Values = make(map[interface{}]interface{}) session.Values = make(map[interface{}]interface{})
session.Options.MaxAge = -1 // Delete session session.Options.MaxAge = -1 // Delete session
err = session.Save(r, w) err = session.Save(r, w)
@ -564,13 +546,6 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) {
return return
} }
// Update enable_all_plugins
enableAllPlugins := r.FormValue("enable_all_plugins") == "true"
if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil {
http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError)
return
}
a.addFlash(w, r, "Channel updated", "success") a.addFlash(w, r, "Channel updated", "success")
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther) http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
return return
@ -657,96 +632,6 @@ func (a *Admin) handleChannelPluginList(w http.ResponseWriter, r *http.Request)
}) })
} }
// handleChannelPluginConfig handles the channel plugin configuration route
func (a *Admin) handleChannelPluginConfig(w http.ResponseWriter, r *http.Request) {
// Check if user is logged in
if !a.isLoggedIn(r) {
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
return
}
// Extract channel plugin ID from path
path := r.URL.Path
channelPluginID := strings.TrimPrefix(path, "/admin/channelplugins/config/")
// Convert channel plugin ID to int64
id, err := strconv.ParseInt(channelPluginID, 10, 64)
if err != nil {
http.Error(w, "Invalid channel plugin ID", http.StatusBadRequest)
return
}
// Get the channel plugin
channelPlugin, err := a.db.GetChannelPluginByID(id)
if err != nil {
http.Error(w, "Channel plugin not found", http.StatusNotFound)
return
}
// Get the plugin
p, err := plugin.Get(channelPlugin.PluginID)
if err != nil {
http.Error(w, "Plugin not found", http.StatusNotFound)
return
}
// Handle form submission
if r.Method == http.MethodPost {
// Parse form
if err := r.ParseForm(); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Create config map from form values
config := make(map[string]interface{})
// Process form values based on plugin type
if channelPlugin.PluginID == "security.domainblock" {
// Get blocked domains from form
blockedDomains := r.FormValue("blocked_domains")
config["blocked_domains"] = blockedDomains
} else {
// Generic handling for other plugins
for key, values := range r.Form {
if key == "form_submitted" {
continue
}
if len(values) == 1 {
config[key] = values[0]
} else {
config[key] = values
}
}
}
// Update plugin configuration
if err := a.db.UpdateChannelPluginConfig(id, config); err != nil {
http.Error(w, "Failed to update plugin configuration", http.StatusInternalServerError)
return
}
// Get the channel to redirect back to the channel detail page
channel, err := a.db.GetChannelByID(channelPlugin.ChannelID)
if err != nil {
a.addFlash(w, r, "Plugin configuration updated", "success")
http.Redirect(w, r, "/admin/channelplugins", http.StatusSeeOther)
return
}
a.addFlash(w, r, "Plugin configuration updated", "success")
http.Redirect(w, r, fmt.Sprintf("/admin/channels/%d", channel.ID), http.StatusSeeOther)
return
}
// Render template
a.render(w, r, "channel_plugin_config.html", TemplateData{
Title: "Configure Plugin: " + p.GetName(),
ChannelPlugin: channelPlugin,
Plugins: map[string]model.Plugin{channelPlugin.PluginID: p},
})
}
// handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route // handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route
func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) { func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) {
// Check if user is logged in // Check if user is logged in

View file

@ -117,19 +117,6 @@
</div> </div>
</div> </div>
<footer class="footer footer-transparent d-print-none">
<div class="container-xl">
<div class="row text-center align-items-center flex-row-reverse">
<div class="col-12 col-lg-auto mt-3 mt-lg-0">
<ul class="list-inline list-inline-dots mb-0">
<li class="list-inline-item">
ButterRobot {{if .Version}}v{{.Version}}{{else}}(development){{end}}
</li>
</ul>
</div>
</div>
</div>
</footer>
</div> </div>
<script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script> <script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script>

View file

@ -27,15 +27,6 @@
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked --> <!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
<input type="hidden" name="form_submitted" value="true"> <input type="hidden" name="form_submitted" value="true">
</div> </div>
<div class="mb-3">
<label class="form-check form-switch">
<input class="form-check-input" type="checkbox" name="enable_all_plugins" value="true" {{if .Channel.EnableAllPlugins}}checked{{end}}>
<span class="form-check-label">Enable All Plugins</span>
</label>
<div>
When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored.
</div>
</div>
<div class="form-footer"> <div class="form-footer">
<button type="submit" class="btn btn-primary">Save</button> <button type="submit" class="btn btn-primary">Save</button>
<a href="/admin/channels" class="btn btn-link">Back to Channels</a> <a href="/admin/channels" class="btn btn-link">Back to Channels</a>
@ -77,10 +68,6 @@
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
</button> </button>
</form> </form>
{{$plugin := index $.Plugins $pluginID}}
{{if $plugin.RequiresConfig}}
<a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a>
{{end}}
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
<button type="submit" class="btn btn-danger btn-sm" <button type="submit" class="btn btn-danger btn-sm"
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
@ -124,4 +111,4 @@
</div> </div>
</div> </div>
</div> </div>
{{end}} {{end}}

View file

@ -1,32 +0,0 @@
{{define "content"}}
<div class="row">
<div class="col-md-12">
<div class="card">
<div class="card-header">
<h3 class="card-title">Configure Plugin: {{(index .Plugins .ChannelPlugin.PluginID).GetName}}</h3>
</div>
<div class="card-body">
<form method="post">
<!-- Plugin configuration fields -->
{{if eq .ChannelPlugin.PluginID "security.domainblock"}}
{{template "plugins/security.domainblock.html" .}}
{{else if eq .ChannelPlugin.PluginID "social.instagram"}}
{{template "plugins/social.instagram.html" .}}
{{else if eq .ChannelPlugin.PluginID "social.twitter"}}
{{template "plugins/social.twitter.html" .}}
{{else}}
<div class="alert alert-warning">
This plugin doesn't have specific configuration fields implemented yet.
</div>
{{end}}
<div class="form-footer">
<button type="submit" class="btn btn-primary">Save Configuration</button>
<a href="/admin/channels/{{.ChannelPlugin.ChannelID}}" class="btn btn-secondary">Cancel</a>
</div>
</form>
</div>
</div>
</div>
</div>
{{end}}

View file

@ -38,10 +38,6 @@
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
</button> </button>
</form> </form>
{{$plugin := index $.Plugins $pluginID}}
{{if $plugin.ConfigRequired}}
<a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a>
{{end}}
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
<button type="submit" class="btn btn-danger btn-sm" <button type="submit" class="btn btn-danger btn-sm"
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
@ -94,4 +90,4 @@
</div> </div>
</div> </div>
</div> </div>
{{end}} {{end}}

View file

@ -1,12 +0,0 @@
{{define "plugins/security.domainblock.html"}}
<div class="mb-3">
<label class="form-label">Blocked Domains</label>
<input type="text" class="form-control" name="blocked_domains"
value="{{with .ChannelPlugin.Config}}{{index . "blocked_domains"}}{{end}}"
placeholder="example.com, evil.org, ads.com">
<div class="form-text text-muted">
Enter comma-separated list of domains to block (e.g., example.com, evil.org).
Messages containing links to these domains will be blocked.
</div>
</div>
{{end}}

View file

@ -1,11 +0,0 @@
{{define "plugins/social.instagram.html"}}
<div class="mb-3">
<label class="form-label">Replacement Domain</label>
<input type="text" class="form-control" name="domain"
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
placeholder="ddinstagram.com">
<div class="form-text text-muted">
Enter the domain to replace instagram.com links with. Default is ddinstagram.com if left empty.
</div>
</div>
{{end}}

View file

@ -1,11 +0,0 @@
{{define "plugins/social.twitter.html"}}
<div class="mb-3">
<label class="form-label">Replacement Domain</label>
<input type="text" class="form-control" name="domain"
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
placeholder="fxtwitter.com">
<div class="form-text text-muted">
Enter the domain to replace twitter.com and x.com links with. Default is fxtwitter.com if left empty.
</div>
</div>
{{end}}

View file

@ -9,37 +9,28 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"runtime/debug"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"git.nakama.town/fmartingr/butterrobot/internal/admin" "git.nakama.town/fmartingr/butterrobot/internal/admin"
"git.nakama.town/fmartingr/butterrobot/internal/cache"
"git.nakama.town/fmartingr/butterrobot/internal/config" "git.nakama.town/fmartingr/butterrobot/internal/config"
"git.nakama.town/fmartingr/butterrobot/internal/db" "git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/platform" "git.nakama.town/fmartingr/butterrobot/internal/platform"
"git.nakama.town/fmartingr/butterrobot/internal/plugin" "git.nakama.town/fmartingr/butterrobot/internal/plugin"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" "git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/help"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" "git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace"
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
"git.nakama.town/fmartingr/butterrobot/internal/queue" "git.nakama.town/fmartingr/butterrobot/internal/queue"
) )
// App represents the application // App represents the application
type App struct { type App struct {
config *config.Config config *config.Config
logger *slog.Logger logger *slog.Logger
db *db.Database db *db.Database
router *http.ServeMux router *http.ServeMux
queue *queue.Queue queue *queue.Queue
admin *admin.Admin admin *admin.Admin
version string
} }
// New creates a new App instance // New creates a new App instance
@ -56,24 +47,16 @@ func New(cfg *config.Config, logger *slog.Logger) (*App, error) {
// Initialize message queue // Initialize message queue
messageQueue := queue.New(logger) messageQueue := queue.New(logger)
// Get version information
version := ""
info, ok := debug.ReadBuildInfo()
if ok {
version = info.Main.Version
}
// Initialize admin interface // Initialize admin interface
adminInterface := admin.New(cfg, database, version) adminInterface := admin.New(cfg, database)
return &App{ return &App{
config: cfg, config: cfg,
logger: logger, logger: logger,
db: database, db: database,
router: router, router: router,
queue: messageQueue, queue: messageQueue,
admin: adminInterface, admin: adminInterface,
version: version,
}, nil }, nil
} }
@ -89,13 +72,6 @@ func (a *App) Run() error {
plugin.Register(fun.NewCoin()) plugin.Register(fun.NewCoin())
plugin.Register(fun.NewDice()) plugin.Register(fun.NewDice())
plugin.Register(fun.NewLoquito()) plugin.Register(fun.NewLoquito())
plugin.Register(fun.NewHLTB())
plugin.Register(social.NewTwitterExpander())
plugin.Register(social.NewInstagramExpander())
plugin.Register(reminder.New(a.db))
plugin.Register(domainblock.New())
plugin.Register(searchreplace.New())
plugin.Register(help.New(a.db))
// Initialize routes // Initialize routes
a.initializeRoutes() a.initializeRoutes()
@ -103,12 +79,6 @@ func (a *App) Run() error {
// Start message queue worker // Start message queue worker
a.queue.Start(a.handleMessage) a.queue.Start(a.handleMessage)
// Start reminder scheduler
a.queue.StartReminderScheduler(a.handleReminder)
// Start cache cleanup scheduler
go a.startCacheCleanup()
// Create server // Create server
addr := fmt.Sprintf(":%s", a.config.Port) addr := fmt.Sprintf(":%s", a.config.Port)
srv := &http.Server{ srv := &http.Server{
@ -154,29 +124,13 @@ func (a *App) Run() error {
return nil return nil
} }
// startCacheCleanup runs periodic cache cleanup
func (a *App) startCacheCleanup() {
ticker := time.NewTicker(time.Hour) // Clean up every hour
defer ticker.Stop()
for range ticker.C {
if err := a.db.CacheCleanup(); err != nil {
a.logger.Error("Cache cleanup failed", "error", err)
} else {
a.logger.Debug("Cache cleanup completed")
}
}
}
// Initialize HTTP routes // Initialize HTTP routes
func (a *App) initializeRoutes() { func (a *App) initializeRoutes() {
// Health check endpoint // Health check endpoint
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil { json.NewEncoder(w).Encode(map[string]interface{}{})
a.logger.Error("Error encoding response", "error", err)
}
}) })
// Platform webhook endpoints // Platform webhook endpoints
@ -199,9 +153,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
if _, err := platform.Get(platformName); err != nil { if _, err := platform.Get(platformName); err != nil {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil { json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"})
a.logger.Error("Error encoding response", "error", err)
}
return return
} }
@ -210,9 +162,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil { json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"})
a.logger.Error("Error encoding response", "error", err)
}
return return
} }
@ -228,9 +178,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
// Respond with success // Respond with success
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { json.NewEncoder(w).Encode(map[string]any{})
a.logger.Error("Error encoding response", "error", err)
}
} }
// extractPlatformName extracts the platform name from the URL path // extractPlatformName extracts the platform name from the URL path
@ -314,21 +262,11 @@ func (a *App) handleMessage(item queue.Item) {
} }
// Process message with plugins // Process message with plugins
var pluginsToProcess []string for pluginID, channelPlugin := range channel.Plugins {
if !channel.HasEnabledPlugin(pluginID) {
if channel.EnableAllPlugins { continue
// If EnableAllPlugins is true, process all registered plugins
pluginsToProcess = plugin.GetAvailablePluginIDs()
} else {
// Otherwise, process only explicitly enabled plugins
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
} }
}
for _, pluginID := range pluginsToProcess {
// Get plugin // Get plugin
p, err := plugin.Get(pluginID) p, err := plugin.Get(pluginID)
if err != nil { if err != nil {
@ -336,121 +274,20 @@ func (a *App) handleMessage(item queue.Item) {
continue continue
} }
// Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured) // Process message
var config map[string]interface{} responses := p.OnMessage(message, channelPlugin.Config)
if channelPlugin, exists := channel.Plugins[pluginID]; exists {
config = channelPlugin.Config
} else {
config = make(map[string]interface{})
}
// Create cache instance for this plugin // Send responses
pluginCache := cache.New(a.db, pluginID)
// Process message and get actions
actions := p.OnMessage(message, config, pluginCache)
// Get platform for processing actions
platform, err := platform.Get(item.Platform) platform, err := platform.Get(item.Platform)
if err != nil { if err != nil {
a.logger.Error("Error getting platform", "error", err) a.logger.Error("Error getting platform", "error", err)
continue continue
} }
// Process each action for _, response := range responses {
for _, action := range actions { if err := platform.SendMessage(response); err != nil {
switch action.Type { a.logger.Error("Error sending message", "error", err)
case model.ActionSendMessage:
// Send a message
if action.Message != nil {
if err := platform.SendMessage(action.Message); err != nil {
a.logger.Error("Error sending message", "error", err)
}
} else {
a.logger.Error("Send message action with nil message")
}
case model.ActionDeleteMessage:
// Delete a message using direct DeleteMessage call
if err := platform.DeleteMessage(action.Chat, action.MessageID); err != nil {
a.logger.Error("Error deleting message", "error", err, "message_id", action.MessageID)
} else {
a.logger.Info("Message deleted", "message_id", action.MessageID)
}
default:
a.logger.Error("Unknown action type", "type", action.Type)
} }
} }
} }
} }
// handleReminder handles reminder processing
func (a *App) handleReminder(reminder *model.Reminder) {
// When called with nil, it means we should check for pending reminders
if reminder == nil {
// Get pending reminders
reminders, err := a.db.GetPendingReminders()
if err != nil {
a.logger.Error("Error getting pending reminders", "error", err)
return
}
// Process each reminder
for _, r := range reminders {
a.processReminder(r)
}
return
}
// Otherwise, process the specific reminder
a.processReminder(reminder)
}
// processReminder processes an individual reminder
func (a *App) processReminder(reminder *model.Reminder) {
a.logger.Info("Processing reminder",
"id", reminder.ID,
"platform", reminder.Platform,
"channel", reminder.ChannelID,
"trigger_at", reminder.TriggerAt,
)
// Get the platform handler
p, err := platform.Get(reminder.Platform)
if err != nil {
a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform)
return
}
// Get the channel
channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID)
if err != nil {
a.logger.Error("Error getting channel for reminder", "error", err)
return
}
// Create the reminder message
reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username)
message := &model.Message{
Text: reminderText,
Chat: reminder.ChannelID,
Channel: channel,
Author: "bot",
FromBot: true,
Date: time.Now(),
ReplyTo: reminder.ReplyToID, // Reply to the original message
}
// Send the reminder message
if err := p.SendMessage(message); err != nil {
a.logger.Error("Error sending reminder", "error", err)
return
}
// Mark the reminder as processed
if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil {
a.logger.Error("Error marking reminder as processed", "error", err)
}
}

View file

@ -1,83 +0,0 @@
package cache
import (
"encoding/json"
"fmt"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/db"
)
// Cache provides a plugin-friendly interface to the cache system
type Cache struct {
db *db.Database
pluginID string
}
// New creates a new Cache instance for a specific plugin
func New(database *db.Database, pluginID string) *Cache {
return &Cache{
db: database,
pluginID: pluginID,
}
}
// Get retrieves a value from the cache
func (c *Cache) Get(key string, destination interface{}) error {
// Create prefixed key
fullKey := c.createKey(key)
// Get from database
value, err := c.db.CacheGet(fullKey)
if err != nil {
return err
}
// Unmarshal JSON into destination
return json.Unmarshal([]byte(value), destination)
}
// Set stores a value in the cache with optional expiration
func (c *Cache) Set(key string, value interface{}, expiration *time.Time) error {
// Create prefixed key
fullKey := c.createKey(key)
// Marshal value to JSON
jsonValue, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal cache value: %w", err)
}
// Store in database
return c.db.CacheSet(fullKey, string(jsonValue), expiration)
}
// SetWithTTL stores a value in the cache with a time-to-live duration
func (c *Cache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
expiration := time.Now().Add(ttl)
return c.Set(key, value, &expiration)
}
// Delete removes a value from the cache
func (c *Cache) Delete(key string) error {
fullKey := c.createKey(key)
return c.db.CacheDelete(fullKey)
}
// Exists checks if a key exists in the cache
func (c *Cache) Exists(key string) (bool, error) {
fullKey := c.createKey(key)
_, err := c.db.CacheGet(fullKey)
if err == db.ErrNotFound {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// createKey creates a prefixed cache key
func (c *Cache) createKey(key string) string {
return fmt.Sprintf("%s_%s", c.pluginID, key)
}

View file

@ -1,176 +0,0 @@
package cache
import (
"fmt"
"os"
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/db"
)
func TestCache(t *testing.T) {
// Create temporary database for testing with unique name
dbFile := fmt.Sprintf("test_cache_%d.db", time.Now().UnixNano())
database, err := db.New(dbFile)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer func() {
_ = database.Close()
// Clean up test database file
_ = os.Remove(dbFile)
}()
// Create cache instance
cache := New(database, "test.plugin")
// Test data
testKey := "test_key"
testValue := map[string]interface{}{
"name": "Test Game",
"time": 42,
}
// Test Set and Get
t.Run("Set and Get", func(t *testing.T) {
err := cache.Set(testKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
var retrieved map[string]interface{}
err = cache.Get(testKey, &retrieved)
if err != nil {
t.Errorf("Failed to get cache value: %v", err)
}
if retrieved["name"] != testValue["name"] {
t.Errorf("Expected name %v, got %v", testValue["name"], retrieved["name"])
}
if int(retrieved["time"].(float64)) != testValue["time"].(int) {
t.Errorf("Expected time %v, got %v", testValue["time"], retrieved["time"])
}
})
// Test SetWithTTL and expiration
t.Run("SetWithTTL and expiration", func(t *testing.T) {
expiredKey := "expired_key"
// Set with very short TTL
err := cache.SetWithTTL(expiredKey, testValue, time.Millisecond)
if err != nil {
t.Errorf("Failed to set cache value with TTL: %v", err)
}
// Wait for expiration
time.Sleep(2 * time.Millisecond)
// Try to get - should fail
var retrieved map[string]interface{}
err = cache.Get(expiredKey, &retrieved)
if err == nil {
t.Errorf("Expected cache miss for expired key, but got value")
}
})
// Test Exists
t.Run("Exists", func(t *testing.T) {
existsKey := "exists_key"
// Make sure the key doesn't exist initially by deleting it
_ = cache.Delete(existsKey)
// Should not exist initially
exists, err := cache.Exists(existsKey)
if err != nil {
t.Errorf("Failed to check if key exists: %v", err)
}
if exists {
t.Errorf("Expected key to not exist, but it does")
}
// Set value
err = cache.Set(existsKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
// Should exist now
exists, err = cache.Exists(existsKey)
if err != nil {
t.Errorf("Failed to check if key exists: %v", err)
}
if !exists {
t.Errorf("Expected key to exist, but it doesn't")
}
})
// Test Delete
t.Run("Delete", func(t *testing.T) {
deleteKey := "delete_key"
// Set value
err := cache.Set(deleteKey, testValue, nil)
if err != nil {
t.Errorf("Failed to set cache value: %v", err)
}
// Delete value
err = cache.Delete(deleteKey)
if err != nil {
t.Errorf("Failed to delete cache value: %v", err)
}
// Should not exist anymore
var retrieved map[string]interface{}
err = cache.Get(deleteKey, &retrieved)
if err == nil {
t.Errorf("Expected cache miss for deleted key, but got value")
}
})
// Test plugin ID prefixing
t.Run("Plugin ID prefixing", func(t *testing.T) {
cache1 := New(database, "plugin1")
cache2 := New(database, "plugin2")
sameKey := "same_key"
value1 := "value1"
value2 := "value2"
// Set same key in both caches
err := cache1.Set(sameKey, value1, nil)
if err != nil {
t.Errorf("Failed to set cache1 value: %v", err)
}
err = cache2.Set(sameKey, value2, nil)
if err != nil {
t.Errorf("Failed to set cache2 value: %v", err)
}
// Retrieve from both caches
var retrieved1, retrieved2 string
err = cache1.Get(sameKey, &retrieved1)
if err != nil {
t.Errorf("Failed to get cache1 value: %v", err)
}
err = cache2.Get(sameKey, &retrieved2)
if err != nil {
t.Errorf("Failed to get cache2 value: %v", err)
}
// Values should be different due to plugin ID prefixing
if retrieved1 != value1 {
t.Errorf("Expected cache1 value %v, got %v", value1, retrieved1)
}
if retrieved2 != value2 {
t.Errorf("Expected cache2 value %v, got %v", value2, retrieved2)
}
})
}

View file

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
@ -35,11 +34,6 @@ func New(dbPath string) (*Database, error) {
return nil, err return nil, err
} }
// Configure SQLite for better reliability
if err := configureSQLite(db); err != nil {
return nil, err
}
// Initialize database // Initialize database
if err := initDatabase(db); err != nil { if err := initDatabase(db); err != nil {
return nil, err return nil, err
@ -56,7 +50,7 @@ func (d *Database) Close() error {
// GetChannelByID retrieves a channel by ID // GetChannelByID retrieves a channel by ID
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels FROM channels
WHERE id = ? WHERE id = ?
` `
@ -67,11 +61,10 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
platform string platform string
platformChannelID string platformChannelID string
enabled bool enabled bool
enableAllPlugins bool
channelRawJSON string channelRawJSON string
) )
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, ErrNotFound return nil, ErrNotFound
} }
@ -91,7 +84,6 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -112,7 +104,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
// GetChannelByPlatform retrieves a channel by platform and platform channel ID // GetChannelByPlatform retrieves a channel by platform and platform channel ID
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels FROM channels
WHERE platform = ? AND platform_channel_id = ? WHERE platform = ? AND platform_channel_id = ?
` `
@ -120,13 +112,12 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
row := d.db.QueryRow(query, platform, platformChannelID) row := d.db.QueryRow(query, platform, platformChannelID)
var ( var (
id int64 id int64
enabled bool enabled bool
enableAllPlugins bool channelRawJSON string
channelRawJSON string
) )
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, ErrNotFound return nil, ErrNotFound
} }
@ -146,7 +137,6 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -174,11 +164,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
// Insert channel // Insert channel
query := ` query := `
INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw) INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?)
` `
result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON)) result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -195,7 +185,6 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: false,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -215,18 +204,6 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error {
return err return err
} }
// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status
func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error {
query := `
UPDATE channels
SET enable_all_plugins = ?
WHERE id = ?
`
_, err := d.db.Exec(query, enableAllPlugins, id)
return err
}
// DeleteChannel deletes a channel // DeleteChannel deletes a channel
func (d *Database) DeleteChannel(id int64) error { func (d *Database) DeleteChannel(id int64) error {
// First delete all channel plugins // First delete all channel plugins
@ -256,11 +233,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer rows.Close()
if err := rows.Close(); err != nil {
fmt.Printf("Error closing rows: %v\n", err)
}
}()
var plugins []*model.ChannelPlugin var plugins []*model.ChannelPlugin
@ -278,7 +251,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
} }
// Parse config JSON // Parse config JSON
var config map[string]any var config map[string]interface{}
if err := json.Unmarshal([]byte(configJSON), &config); err != nil { if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
return nil, err return nil, err
} }
@ -305,28 +278,6 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
return plugins, nil return plugins, nil
} }
// GetChannelPluginsFromPlatformID retrieves all plugins for a channel by platform and platform channel ID
func (d *Database) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
// First, get the channel ID by platform and platform channel ID
query := `
SELECT id
FROM channels
WHERE platform = ? AND platform_channel_id = ?
`
var channelID int64
err := d.db.QueryRow(query, platform, platformChannelID).Scan(&channelID)
if err == sql.ErrNoRows {
return nil, ErrNotFound
}
if err != nil {
return nil, err
}
// Now get the plugins for this channel
return d.GetChannelPlugins(channelID)
}
// GetChannelPluginByID retrieves a channel plugin by ID // GetChannelPluginByID retrieves a channel plugin by ID
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
query := ` query := `
@ -430,24 +381,6 @@ func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
return err return err
} }
// UpdateChannelPluginConfig updates a channel plugin's configuration
func (d *Database) UpdateChannelPluginConfig(id int64, config map[string]interface{}) error {
// Convert config to JSON
configJSON, err := json.Marshal(config)
if err != nil {
return err
}
query := `
UPDATE channel_plugin
SET config = ?
WHERE id = ?
`
_, err = d.db.Exec(query, string(configJSON), id)
return err
}
// DeleteChannelPlugin deletes a channel plugin // DeleteChannelPlugin deletes a channel plugin
func (d *Database) DeleteChannelPlugin(id int64) error { func (d *Database) DeleteChannelPlugin(id int64) error {
query := ` query := `
@ -473,7 +406,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
// GetAllChannels retrieves all channels // GetAllChannels retrieves all channels
func (d *Database) GetAllChannels() ([]*model.Channel, error) { func (d *Database) GetAllChannels() ([]*model.Channel, error) {
query := ` query := `
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw SELECT id, platform, platform_channel_id, enabled, channel_raw
FROM channels FROM channels
` `
@ -481,11 +414,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer rows.Close()
if err := rows.Close(); err != nil {
fmt.Printf("Error closing rows: %v\n", err)
}
}()
var channels []*model.Channel var channels []*model.Channel
@ -495,11 +424,10 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
platform string platform string
platformChannelID string platformChannelID string
enabled bool enabled bool
enableAllPlugins bool
channelRawJSON string channelRawJSON string
) )
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil { if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
return nil, err return nil, err
} }
@ -515,7 +443,6 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
Platform: platform, Platform: platform,
PlatformChannelID: platformChannelID, PlatformChannelID: platformChannelID,
Enabled: enabled, Enabled: enabled,
EnableAllPlugins: enableAllPlugins,
ChannelRaw: channelRaw, ChannelRaw: channelRaw,
Plugins: make(map[string]*model.ChannelPlugin), Plugins: make(map[string]*model.ChannelPlugin),
} }
@ -526,9 +453,10 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
continue // Skip this channel if plugins can't be retrieved continue // Skip this channel if plugins can't be retrieved
} }
// Add plugins to channel if plugins != nil {
for _, plugin := range plugins { for _, plugin := range plugins {
channel.Plugins[plugin.PluginID] = plugin channel.Plugins[plugin.PluginID] = plugin
}
} }
channels = append(channels, channel) channels = append(channels, channel)
@ -663,124 +591,6 @@ func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
return err return err
} }
// CreateReminder creates a new reminder
func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
query := `
INSERT INTO reminders (
platform, channel_id, message_id, reply_to_id,
user_id, username, created_at, trigger_at,
content, processed
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)
`
createdAt := time.Now()
result, err := d.db.Exec(
query,
platform, channelID, messageID, replyToID,
userID, username, createdAt, triggerAt,
content,
)
if err != nil {
return nil, err
}
id, err := result.LastInsertId()
if err != nil {
return nil, err
}
return &model.Reminder{
ID: id,
Platform: platform,
ChannelID: channelID,
MessageID: messageID,
ReplyToID: replyToID,
UserID: userID,
Username: username,
CreatedAt: createdAt,
TriggerAt: triggerAt,
Content: content,
Processed: false,
}, nil
}
// GetPendingReminders gets all pending reminders that need to be processed
func (d *Database) GetPendingReminders() ([]*model.Reminder, error) {
query := `
SELECT id, platform, channel_id, message_id, reply_to_id,
user_id, username, created_at, trigger_at, content, processed
FROM reminders
WHERE processed = 0 AND trigger_at <= ?
`
rows, err := d.db.Query(query, time.Now())
if err != nil {
return nil, err
}
defer func() {
if err := rows.Close(); err != nil {
fmt.Printf("Error closing rows: %v\n", err)
}
}()
var reminders []*model.Reminder
for rows.Next() {
var (
id int64
platform, channelID, messageID, replyToID string
userID, username, content string
createdAt, triggerAt time.Time
processed bool
)
if err := rows.Scan(
&id, &platform, &channelID, &messageID, &replyToID,
&userID, &username, &createdAt, &triggerAt, &content, &processed,
); err != nil {
return nil, err
}
reminder := &model.Reminder{
ID: id,
Platform: platform,
ChannelID: channelID,
MessageID: messageID,
ReplyToID: replyToID,
UserID: userID,
Username: username,
CreatedAt: createdAt,
TriggerAt: triggerAt,
Content: content,
Processed: processed,
}
reminders = append(reminders, reminder)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(reminders) == 0 {
return make([]*model.Reminder, 0), nil
}
return reminders, nil
}
// MarkReminderAsProcessed marks a reminder as processed
func (d *Database) MarkReminderAsProcessed(id int64) error {
query := `
UPDATE reminders
SET processed = 1
WHERE id = ?
`
_, err := d.db.Exec(query, id)
return err
}
// Helper function to hash password // Helper function to hash password
func hashPassword(password string) (string, error) { func hashPassword(password string) (string, error) {
// Use bcrypt for secure password hashing // Use bcrypt for secure password hashing
@ -799,25 +609,25 @@ func initDatabase(db *sql.DB) error {
if err := migration.EnsureMigrationTable(db); err != nil { if err := migration.EnsureMigrationTable(db); err != nil {
return fmt.Errorf("failed to create migration table: %w", err) return fmt.Errorf("failed to create migration table: %w", err)
} }
// Get applied migrations // Get applied migrations
applied, err := migration.GetAppliedMigrations(db) applied, err := migration.GetAppliedMigrations(db)
if err != nil { if err != nil {
return fmt.Errorf("failed to get applied migrations: %w", err) return fmt.Errorf("failed to get applied migrations: %w", err)
} }
// Get all migration versions // Get all migration versions
allMigrations := make([]int, 0, len(migration.Migrations)) allMigrations := make([]int, 0, len(migration.Migrations))
for version := range migration.Migrations { for version := range migration.Migrations {
allMigrations = append(allMigrations, version) allMigrations = append(allMigrations, version)
} }
// Create a map of applied migrations for quick lookup // Create a map of applied migrations for quick lookup
appliedMap := make(map[int]bool) appliedMap := make(map[int]bool)
for _, version := range applied { for _, version := range applied {
appliedMap[version] = true appliedMap[version] = true
} }
// Count pending migrations // Count pending migrations
pendingCount := 0 pendingCount := 0
for _, version := range allMigrations { for _, version := range allMigrations {
@ -825,7 +635,7 @@ func initDatabase(db *sql.DB) error {
pendingCount++ pendingCount++
} }
} }
// Run migrations if needed // Run migrations if needed
if pendingCount > 0 { if pendingCount > 0 {
fmt.Printf("Running %d pending database migrations...\n", pendingCount) fmt.Printf("Running %d pending database migrations...\n", pendingCount)
@ -836,85 +646,6 @@ func initDatabase(db *sql.DB) error {
} else { } else {
fmt.Println("Database schema is up to date.") fmt.Println("Database schema is up to date.")
} }
return nil return nil
} }
// Configure SQLite for better reliability
func configureSQLite(db *sql.DB) error {
pragmas := []string{
// Enable Write-Ahead Logging for better concurrency and crash recovery
"PRAGMA journal_mode = WAL",
// Set 5-second timeout when database is locked by another connection
"PRAGMA busy_timeout = 5000",
// Balance between safety and performance for disk writes
"PRAGMA synchronous = NORMAL",
// Set large cache size (1GB) for better read performance
"PRAGMA cache_size = 1000000000",
// Enable foreign key constraint enforcement
"PRAGMA foreign_keys = true",
// Store temporary tables and indices in memory for speed
"PRAGMA temp_store = memory",
}
for _, pragma := range pragmas {
if _, err := db.Exec(pragma); err != nil {
return fmt.Errorf("failed to execute %s: %w", pragma, err)
}
}
return nil
}
// CacheGet retrieves a value from the cache
func (d *Database) CacheGet(key string) (string, error) {
query := `
SELECT value
FROM cache
WHERE key = ? AND (expires_at IS NULL OR expires_at > ?)
`
var value string
err := d.db.QueryRow(query, key, time.Now()).Scan(&value)
if err == sql.ErrNoRows {
return "", ErrNotFound
}
if err != nil {
return "", err
}
return value, nil
}
// CacheSet stores a value in the cache with optional expiration
func (d *Database) CacheSet(key, value string, expiration *time.Time) error {
query := `
INSERT OR REPLACE INTO cache (key, value, expires_at, updated_at)
VALUES (?, ?, ?, ?)
`
_, err := d.db.Exec(query, key, value, expiration, time.Now())
return err
}
// CacheDelete removes a value from the cache
func (d *Database) CacheDelete(key string) error {
query := `
DELETE FROM cache
WHERE key = ?
`
_, err := d.db.Exec(query, key)
return err
}
// CacheCleanup removes expired cache entries
func (d *Database) CacheCleanup() error {
query := `
DELETE FROM cache
WHERE expires_at IS NOT NULL AND expires_at <= ?
`
_, err := d.db.Exec(query, time.Now())
return err
}

View file

@ -1,203 +0,0 @@
package db
import (
"fmt"
"os"
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
func TestEnableAllPlugins(t *testing.T) {
// Create temporary database for testing with unique name
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
database, err := New(dbFile)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer func() {
_ = database.Close()
// Clean up test database file
_ = os.Remove(dbFile)
}()
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
channelRaw := map[string]interface{}{
"name": "test-channel",
}
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
if channel.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
}
// Verify it's also false when retrieved from database
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
}
})
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
// Create a channel
channelRaw := map[string]interface{}{
"name": "test-channel-2",
}
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
// Update EnableAllPlugins to true
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
}
// Retrieve and verify
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
}
// Update back to false
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
}
// Retrieve and verify again
retrieved, err = database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
}
})
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
// Create a channel
channelRaw := map[string]interface{}{
"name": "test-channel-3",
}
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
// Enable all plugins
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
}
// Retrieve by platform
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
if err != nil {
t.Fatalf("Failed to retrieve channel by platform: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
}
})
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
// Create multiple channels with different EnableAllPlugins settings
channelRaw1 := map[string]interface{}{"name": "channel-1"}
channelRaw2 := map[string]interface{}{"name": "channel-2"}
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
if err != nil {
t.Fatalf("Failed to create channel1: %v", err)
}
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
if err != nil {
t.Fatalf("Failed to create channel2: %v", err)
}
// Enable all plugins for channel2 only
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
}
// Get all channels
channels, err := database.GetAllChannels()
if err != nil {
t.Fatalf("Failed to get all channels: %v", err)
}
// Find our test channels
var foundChannel1, foundChannel2 *model.Channel
for _, ch := range channels {
if ch.ID == channel1.ID {
foundChannel1 = ch
}
if ch.ID == channel2.ID {
foundChannel2 = ch
}
}
if foundChannel1 == nil {
t.Fatalf("Channel1 not found in GetAllChannels result")
}
if foundChannel2 == nil {
t.Fatalf("Channel2 not found in GetAllChannels result")
}
if foundChannel1.EnableAllPlugins {
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
}
if !foundChannel2.EnableAllPlugins {
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
}
})
t.Run("Migration applied correctly", func(t *testing.T) {
// Test that we can create a channel and the enable_all_plugins column exists
// This implicitly tests that migration 4 was applied correctly
channelRaw := map[string]interface{}{
"name": "migration-test-channel",
}
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
if err != nil {
t.Fatalf("Failed to create channel after migration: %v", err)
}
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
if err != nil {
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
}
// Verify the value was set correctly
retrieved, err := database.GetChannelByID(channel.ID)
if err != nil {
t.Fatalf("Failed to retrieve channel: %v", err)
}
if !retrieved.EnableAllPlugins {
t.Errorf("EnableAllPlugins should be true after update")
}
})
}

View file

@ -49,11 +49,7 @@ func GetAppliedMigrations(db *sql.DB) ([]int, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer rows.Close()
if err := rows.Close(); err != nil {
fmt.Printf("Error closing rows: %v\n", err)
}
}()
var versions []int var versions []int
for rows.Next() { for rows.Next() {
@ -132,9 +128,7 @@ func Migrate(db *sql.DB) error {
// Apply the migration // Apply the migration
if err := migration.Up(db); err != nil { if err := migration.Up(db); err != nil {
if err := tx.Rollback(); err != nil { tx.Rollback()
fmt.Printf("Error rolling back transaction: %v\n", err)
}
return fmt.Errorf("failed to apply migration %d: %w", version, err) return fmt.Errorf("failed to apply migration %d: %w", version, err)
} }
@ -143,9 +137,7 @@ func Migrate(db *sql.DB) error {
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", "INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
version, time.Now(), version, time.Now(),
); err != nil { ); err != nil {
if err := tx.Rollback(); err != nil { tx.Rollback()
fmt.Printf("Error rolling back transaction: %v\n", err)
}
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err) return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
} }
@ -196,17 +188,13 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
// Apply the down migration // Apply the down migration
if err := migration.Down(db); err != nil { if err := migration.Down(db); err != nil {
if err := tx.Rollback(); err != nil { tx.Rollback()
fmt.Printf("Error rolling back transaction: %v\n", err)
}
return fmt.Errorf("failed to roll back migration %d: %w", version, err) return fmt.Errorf("failed to roll back migration %d: %w", version, err)
} }
// Remove from applied list // Remove from applied list
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil { if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
if err := tx.Rollback(); err != nil { tx.Rollback()
fmt.Printf("Error rolling back transaction: %v\n", err)
}
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err) return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
} }
@ -220,4 +208,4 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
} }
return nil return nil
} }

View file

@ -8,9 +8,6 @@ import (
func init() { func init() {
// Register migrations // Register migrations
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown) Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
Register(3, "Add cache table", migrateCacheUp, migrateCacheDown)
Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown)
} }
// Initial schema creation with bcrypt passwords - version 1 // Initial schema creation with bcrypt passwords - version 1
@ -63,14 +60,14 @@ func migrateInitialSchemaUp(db *sql.DB) error {
if err != nil { if err != nil {
return err return err
} }
// Check if users table is empty before inserting // Check if users table is empty before inserting
var count int var count int
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
if err != nil { if err != nil {
return err return err
} }
if count == 0 { if count == 0 {
_, err = db.Exec( _, err = db.Exec(
"INSERT INTO users (username, password) VALUES (?, ?)", "INSERT INTO users (username, password) VALUES (?, ?)",
@ -102,113 +99,4 @@ func migrateInitialSchemaDown(db *sql.DB) error {
} }
return nil return nil
} }
// Add reminders table - version 2
func migrateRemindersUp(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS reminders (
id INTEGER PRIMARY KEY AUTOINCREMENT,
platform TEXT NOT NULL,
channel_id TEXT NOT NULL,
message_id TEXT NOT NULL,
reply_to_id TEXT NOT NULL,
user_id TEXT NOT NULL,
username TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
trigger_at TIMESTAMP NOT NULL,
content TEXT NOT NULL,
processed BOOLEAN NOT NULL DEFAULT 0
)
`)
return err
}
func migrateRemindersDown(db *sql.DB) error {
_, err := db.Exec(`DROP TABLE IF EXISTS reminders`)
return err
}
// Add cache table - version 3
func migrateCacheUp(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS cache (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
expires_at TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
// Create index on expires_at for efficient cleanup
_, err = db.Exec(`
CREATE INDEX IF NOT EXISTS idx_cache_expires_at ON cache(expires_at)
`)
return err
}
func migrateCacheDown(db *sql.DB) error {
_, err := db.Exec(`DROP TABLE IF EXISTS cache`)
return err
}
// Add enable_all_plugins column to channels table - version 4
func migrateEnableAllPluginsUp(db *sql.DB) error {
_, err := db.Exec(`
ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0
`)
return err
}
func migrateEnableAllPluginsDown(db *sql.DB) error {
// SQLite doesn't support DROP COLUMN, so we need to recreate the table
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
_ = tx.Rollback() // Ignore rollback errors
}()
// Create backup table
_, err = tx.Exec(`
CREATE TABLE channels_backup (
id INTEGER PRIMARY KEY AUTOINCREMENT,
platform TEXT NOT NULL,
platform_channel_id TEXT NOT NULL,
enabled BOOLEAN NOT NULL DEFAULT 0,
channel_raw TEXT NOT NULL,
UNIQUE(platform, platform_channel_id)
)
`)
if err != nil {
return err
}
// Copy data excluding enable_all_plugins column
_, err = tx.Exec(`
INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw)
SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels
`)
if err != nil {
return err
}
// Drop original table
_, err = tx.Exec(`DROP TABLE channels`)
if err != nil {
return err
}
// Rename backup table
_, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`)
if err != nil {
return err
}
return tx.Commit()
}

View file

@ -4,57 +4,31 @@ import (
"time" "time"
) )
// ActionType defines the type of action to perform
type ActionType string
const (
// ActionSendMessage is for sending a message to the chat
ActionSendMessage ActionType = "send_message"
// ActionDeleteMessage is for deleting a message from the chat
ActionDeleteMessage ActionType = "delete_message"
)
// MessageAction represents an action to be performed on the platform
type MessageAction struct {
Type ActionType
Message *Message // For send_message
MessageID string // For delete_message
Chat string // Chat where the action happens
Channel *Channel // Channel reference
Raw map[string]interface{} // Additional data for the action
}
// Message represents a chat message // Message represents a chat message
type Message struct { type Message struct {
Text string Text string
Chat string Chat string
Channel *Channel Channel *Channel
Author string Author string
FromBot bool FromBot bool
Date time.Time Date time.Time
ID string ID string
ReplyTo string ReplyTo string
Raw map[string]interface{} Raw map[string]interface{}
} }
// Channel represents a chat channel // Channel represents a chat channel
type Channel struct { type Channel struct {
ID int64 ID int64
Platform string Platform string
PlatformChannelID string PlatformChannelID string
ChannelRaw map[string]interface{} ChannelRaw map[string]interface{}
Enabled bool Enabled bool
EnableAllPlugins bool Plugins map[string]*ChannelPlugin
Plugins map[string]*ChannelPlugin
} }
// HasEnabledPlugin checks if a plugin is enabled for this channel // HasEnabledPlugin checks if a plugin is enabled for this channel
func (c *Channel) HasEnabledPlugin(pluginID string) bool { func (c *Channel) HasEnabledPlugin(pluginID string) bool {
// If EnableAllPlugins is true, all plugins are considered enabled
if c.EnableAllPlugins {
return true
}
plugin, exists := c.Plugins[pluginID] plugin, exists := c.Plugins[pluginID]
if !exists { if !exists {
return false return false
@ -66,18 +40,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool {
func (c *Channel) ChannelName() string { func (c *Channel) ChannelName() string {
// In a real implementation, this would use the platform-specific // In a real implementation, this would use the platform-specific
// ParseChannelNameFromRaw function // ParseChannelNameFromRaw function
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name // For simplicity, we'll just use the PlatformChannelID if we can't extract a name
// Check if ChannelRaw has a name field // Check if ChannelRaw has a name field
if c.ChannelRaw == nil { if c.ChannelRaw == nil {
return c.PlatformChannelID return c.PlatformChannelID
} }
// Check common name fields in ChannelRaw // Check common name fields in ChannelRaw
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" { if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
return name return name
} }
// Check for nested objects like "chat" (used by Telegram) // Check for nested objects like "chat" (used by Telegram)
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok { if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
// Try different fields in order of preference // Try different fields in order of preference
@ -91,7 +65,7 @@ func (c *Channel) ChannelName() string {
return firstName return firstName
} }
} }
return c.PlatformChannelID return c.PlatformChannelID
} }
@ -101,7 +75,7 @@ type ChannelPlugin struct {
ChannelID int64 ChannelID int64
PluginID string PluginID string
Enabled bool Enabled bool
Config map[string]any Config map[string]interface{}
} }
// User represents an admin user // User represents an admin user
@ -109,19 +83,4 @@ type User struct {
ID int64 ID int64
Username string Username string
Password string Password string
} }
// Reminder represents a scheduled reminder
type Reminder struct {
ID int64
Platform string
ChannelID string
MessageID string
ReplyToID string
UserID string
Username string
CreatedAt time.Time
TriggerAt time.Time
Content string
Processed bool
}

View file

@ -1,234 +0,0 @@
package model
import (
"testing"
)
func TestChannel_HasEnabledPlugin(t *testing.T) {
t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: make(map[string]*ChannelPlugin),
}
// Plugin not in map should return false
result := channel.HasEnabledPlugin("nonexistent.plugin")
if result {
t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true")
}
})
t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: false,
Config: make(map[string]any),
},
},
}
// Disabled plugin should return false
result := channel.HasEnabledPlugin("test.plugin")
if result {
t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true")
}
})
t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: true,
Config: make(map[string]any),
},
},
}
// Enabled plugin should return true
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false")
}
})
t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: make(map[string]*ChannelPlugin),
}
// When EnableAllPlugins is true, any plugin should be considered enabled
result := channel.HasEnabledPlugin("nonexistent.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
}
})
t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: false,
Config: make(map[string]any),
},
},
}
// When EnableAllPlugins is true, even disabled plugins should be considered enabled
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false")
}
})
t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"test.plugin": {
ID: 1,
ChannelID: 1,
PluginID: "test.plugin",
Enabled: true,
Config: make(map[string]any),
},
},
}
// When EnableAllPlugins is true, enabled plugins should also return true
result := channel.HasEnabledPlugin("test.plugin")
if !result {
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
}
})
t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) {
channel := &Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*ChannelPlugin{
"plugin1": {
ID: 1,
ChannelID: 1,
PluginID: "plugin1",
Enabled: true,
Config: make(map[string]any),
},
"plugin2": {
ID: 2,
ChannelID: 1,
PluginID: "plugin2",
Enabled: false,
Config: make(map[string]any),
},
},
}
// All plugins should be enabled when EnableAllPlugins is true
testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"}
for _, pluginID := range testCases {
result := channel.HasEnabledPlugin(pluginID)
if !result {
t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID)
}
}
})
}
func TestChannelName(t *testing.T) {
t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: nil,
}
result := channel.ChannelName()
if result != "test-id" {
t.Errorf("Expected channel name to be 'test-id', got '%s'", result)
}
})
t.Run("Returns name from ChannelRaw when available", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: map[string]interface{}{
"name": "Test Channel",
},
}
result := channel.ChannelName()
if result != "Test Channel" {
t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result)
}
})
t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "test-id",
ChannelRaw: map[string]interface{}{
"chat": map[string]interface{}{
"title": "Telegram Group",
},
},
}
result := channel.ChannelName()
if result != "Telegram Group" {
t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result)
}
})
t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) {
channel := &Channel{
PlatformChannelID: "fallback-id",
ChannelRaw: map[string]interface{}{
"other_field": "value",
},
}
result := channel.ChannelName()
if result != "fallback-id" {
t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result)
}
})
}

View file

@ -43,7 +43,4 @@ type Platform interface {
// SendMessage sends a message through the platform // SendMessage sends a message through the platform
SendMessage(msg *Message) error SendMessage(msg *Message) error
// DeleteMessage deletes a message from the platform
DeleteMessage(channel string, messageID string) error
} }

View file

@ -2,18 +2,8 @@ package model
import ( import (
"errors" "errors"
"time"
) )
// CacheInterface defines the cache interface available to plugins
type CacheInterface interface {
Get(key string, destination interface{}) error
Set(key string, value interface{}, expiration *time.Time) error
SetWithTTL(key string, value interface{}, ttl time.Duration) error
Delete(key string) error
Exists(key string) (bool, error)
}
var ( var (
// ErrPluginNotFound is returned when a requested plugin doesn't exist // ErrPluginNotFound is returned when a requested plugin doesn't exist
ErrPluginNotFound = errors.New("plugin not found") ErrPluginNotFound = errors.New("plugin not found")
@ -23,16 +13,16 @@ var (
type Plugin interface { type Plugin interface {
// GetID returns the plugin ID // GetID returns the plugin ID
GetID() string GetID() string
// GetName returns the plugin name // GetName returns the plugin name
GetName() string GetName() string
// GetHelp returns the plugin help text // GetHelp returns the plugin help text
GetHelp() string GetHelp() string
// RequiresConfig indicates if the plugin requires configuration // RequiresConfig indicates if the plugin requires configuration
RequiresConfig() bool RequiresConfig() bool
// OnMessage processes an incoming message and returns platform actions // OnMessage processes an incoming message and returns response messages
OnMessage(msg *Message, config map[string]interface{}, cache CacheInterface) []*MessageAction OnMessage(msg *Message, config map[string]interface{}) []*Message
} }

View file

@ -4,7 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -37,15 +37,11 @@ func (s *SlackPlatform) Init(_ *config.Config) error {
// ParseIncomingMessage parses an incoming Slack message // ParseIncomingMessage parses an incoming Slack message
func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) { func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) {
// Read request body // Read request body
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer r.Body.Close()
if err := r.Body.Close(); err != nil {
fmt.Printf("Error closing request body: %v\n", err)
}
}()
// Parse JSON // Parse JSON
var requestData map[string]interface{} var requestData map[string]interface{}
@ -167,12 +163,6 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
return errors.New("bot token not configured") return errors.New("bot token not configured")
} }
// Check for delete message action
if msg.Raw != nil && msg.Raw["action"] == "delete" {
// This is a request to delete a message
return s.deleteMessage(msg)
}
// Prepare payload // Prepare payload
payload := map[string]interface{}{ payload := map[string]interface{}{
"channel": msg.Chat, "channel": msg.Chat,
@ -204,11 +194,7 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
if err != nil { if err != nil {
return err return err
} }
defer func() { defer resp.Body.Close()
if err := resp.Body.Close(); err != nil {
fmt.Printf("Error closing response body: %v\n", err)
}
}()
// Check response // Check response
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
@ -218,63 +204,6 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
return nil return nil
} }
// DeleteMessage deletes a message on Slack
func (s *SlackPlatform) DeleteMessage(channel string, messageID string) error {
// Prepare payload for chat.delete API
payload := map[string]interface{}{
"channel": channel,
"ts": messageID, // In Slack, the ts (timestamp) is the message ID
}
// Convert payload to JSON
data, err := json.Marshal(payload)
if err != nil {
return err
}
// Send HTTP request to chat.delete endpoint
req, err := http.NewRequest("POST", "https://slack.com/api/chat.delete", strings.NewReader(string(data)))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.config.BotOAuthAccessToken))
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("Error closing response body: %v\n", err)
}
}()
// Check response
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("slack API error: %d - %s", resp.StatusCode, string(respBody))
}
return nil
}
// deleteMessage is a legacy method that uses the Raw message approach
func (s *SlackPlatform) deleteMessage(msg *model.Message) error {
// Get message ID to delete
messageID, ok := msg.Raw["message_id"]
if !ok {
return fmt.Errorf("no message ID provided for deletion")
}
// Convert to string if needed
messageIDStr := fmt.Sprintf("%v", messageID)
return s.DeleteMessage(msg.Chat, messageIDStr)
}
// Helper function to parse int64 // Helper function to parse int64
func parseInt64(s string) (int64, error) { func parseInt64(s string) (int64, error) {
var n int64 var n int64

View file

@ -62,11 +62,7 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error {
t.log.Error("Failed to set webhook", "error", err) t.log.Error("Failed to set webhook", "error", err)
return fmt.Errorf("failed to set webhook: %w", err) return fmt.Errorf("failed to set webhook: %w", err)
} }
defer func() { defer resp.Body.Close()
if err := resp.Body.Close(); err != nil {
t.log.Error("Error closing response body", "error", err)
}
}()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
@ -89,11 +85,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
t.log.Error("Failed to read request body", "error", err) t.log.Error("Failed to read request body", "error", err)
return nil, err return nil, err
} }
defer func() { defer r.Body.Close()
if err := r.Body.Close(); err != nil {
t.log.Error("Error closing request body", "error", err)
}
}()
// Parse JSON // Parse JSON
var update struct { var update struct {
@ -111,11 +103,8 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"`
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"`
} `json:"chat"` } `json:"chat"`
Date int `json:"date"` Date int `json:"date"`
Text string `json:"text"` Text string `json:"text"`
ReplyToMessage struct {
MessageID int `json:"message_id"`
} `json:"reply_to_message"`
} `json:"message"` } `json:"message"`
} }
@ -139,7 +128,6 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
FromBot: update.Message.From.IsBot, FromBot: update.Message.From.IsBot,
Date: time.Unix(int64(update.Message.Date), 0), Date: time.Unix(int64(update.Message.Date), 0),
ID: strconv.Itoa(update.Message.MessageID), ID: strconv.Itoa(update.Message.MessageID),
ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID),
Raw: raw, Raw: raw,
} }
@ -217,13 +205,6 @@ func (t *TelegramPlatform) ParseChannelFromMessage(body []byte) (map[string]any,
// SendMessage sends a message to Telegram // SendMessage sends a message to Telegram
func (t *TelegramPlatform) SendMessage(msg *model.Message) error { func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
// Check for delete message action (legacy method)
if msg.Raw != nil && msg.Raw["action"] == "delete" {
// This is a request to delete a message using the legacy method
return t.deleteMessage(msg)
}
// Regular message sending
// Convert chat ID to int64 // Convert chat ID to int64
chatID, err := strconv.ParseInt(msg.Chat, 10, 64) chatID, err := strconv.ParseInt(msg.Chat, 10, 64)
if err != nil { if err != nil {
@ -237,15 +218,6 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
"text": msg.Text, "text": msg.Text,
} }
// Set parse_mode based on plugin preference or default to empty string
if msg.Raw != nil && msg.Raw["parse_mode"] != nil {
// Plugin explicitly set parse_mode
payload["parse_mode"] = msg.Raw["parse_mode"]
} else {
// Default to empty string (no formatting)
payload["parse_mode"] = ""
}
// Add reply if needed // Add reply if needed
if msg.ReplyTo != "" { if msg.ReplyTo != "" {
replyToID, err := strconv.Atoi(msg.ReplyTo) replyToID, err := strconv.Atoi(msg.ReplyTo)
@ -275,11 +247,7 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
t.log.Error("Failed to send message", "error", err) t.log.Error("Failed to send message", "error", err)
return err return err
} }
defer func() { defer resp.Body.Close()
if err := resp.Body.Close(); err != nil {
t.log.Error("Error closing response body", "error", err)
}
}()
// Check response // Check response
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
@ -291,89 +259,4 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
t.log.Debug("Message sent successfully") t.log.Debug("Message sent successfully")
return nil return nil
} }
// DeleteMessage deletes a message on Telegram
func (t *TelegramPlatform) DeleteMessage(channel string, messageID string) error {
// Convert chat ID to int64
chatID, err := strconv.ParseInt(channel, 10, 64)
if err != nil {
t.log.Error("Invalid chat ID for message deletion", "chat_id", channel, "error", err)
return err
}
// Convert message ID to integer
msgID, err := strconv.Atoi(messageID)
if err != nil {
t.log.Error("Invalid message ID for deletion", "message_id", messageID, "error", err)
return err
}
// Prepare payload for deleteMessage API
payload := map[string]interface{}{
"chat_id": chatID,
"message_id": msgID,
}
t.log.Debug("Deleting message on Telegram", "chat_id", chatID, "message_id", msgID)
// Convert payload to JSON
data, err := json.Marshal(payload)
if err != nil {
t.log.Error("Failed to marshal delete message payload", "error", err)
return err
}
// Send HTTP request to deleteMessage endpoint
resp, err := http.Post(
t.apiURL+"/deleteMessage",
"application/json",
bytes.NewBuffer(data),
)
if err != nil {
t.log.Error("Failed to delete message", "error", err)
return err
}
defer func() {
if err := resp.Body.Close(); err != nil {
t.log.Error("Error closing response body", "error", err)
}
}()
// Check response
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
errMsg := string(bodyBytes)
t.log.Error("Telegram API error when deleting message", "status", resp.StatusCode, "response", errMsg)
return fmt.Errorf("telegram API error when deleting message: %d - %s", resp.StatusCode, errMsg)
}
t.log.Debug("Message deleted successfully")
return nil
}
// deleteMessage is a legacy method that uses the Raw message approach
func (t *TelegramPlatform) deleteMessage(msg *model.Message) error {
// Get message ID to delete
messageIDInterface, ok := msg.Raw["message_id"]
if !ok {
t.log.Error("No message ID provided for deletion")
return fmt.Errorf("no message ID provided for deletion")
}
// Convert message ID to string
var messageIDStr string
switch v := messageIDInterface.(type) {
case string:
messageIDStr = v
case int:
messageIDStr = strconv.Itoa(v)
case float64:
messageIDStr = strconv.Itoa(int(v))
default:
t.log.Error("Invalid message ID type for deletion", "type", fmt.Sprintf("%T", messageIDInterface))
return fmt.Errorf("invalid message ID type for deletion")
}
return t.DeleteMessage(msg.Chat, messageIDStr)
}

View file

@ -1,132 +0,0 @@
package domainblock
import (
"fmt"
"net/url"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains
type DomainBlockPlugin struct {
plugin.BasePlugin
}
// Debug helper to check if RequiresConfig is working
func (p *DomainBlockPlugin) RequiresConfig() bool {
return true
}
// New creates a new DomainBlockPlugin instance
func New() *DomainBlockPlugin {
return &DomainBlockPlugin{
BasePlugin: plugin.BasePlugin{
ID: "security.domainblock",
Name: "Domain Blocker",
Help: "Blocks messages containing links from configured domains",
ConfigRequired: true,
},
}
}
// extractDomains extracts domains from a message text
func extractDomains(text string) []string {
// URL regex pattern
urlPattern := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`)
matches := urlPattern.FindAllStringSubmatch(text, -1)
domains := make([]string, 0, len(matches))
for _, match := range matches {
if len(match) < 2 {
continue
}
// Try to parse the URL to extract the domain
urlStr := match[0]
parsedURL, err := url.Parse(urlStr)
if err != nil {
continue
}
// Extract the domain (host) from the URL
domain := parsedURL.Host
// Remove port if present
if i := strings.IndexByte(domain, ':'); i >= 0 {
domain = domain[:i]
}
domains = append(domains, strings.ToLower(domain))
}
return domains
}
// OnMessage processes incoming messages
func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip messages from bots
if msg.FromBot {
return nil
}
// Get blocked domains from config
blockedDomainsStr, ok := config["blocked_domains"].(string)
if !ok || blockedDomainsStr == "" {
return nil // No blocked domains configured
}
// Split and clean blocked domains
blockedDomains := strings.Split(blockedDomainsStr, ",")
for i, domain := range blockedDomains {
blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain))
}
// Extract domains from message
messageDomains := extractDomains(msg.Text)
if len(messageDomains) == 0 {
return nil // No domains in message
}
// Check if any domains in the message are blocked
for _, msgDomain := range messageDomains {
for _, blockedDomain := range blockedDomains {
if blockedDomain == "" {
continue
}
if strings.HasSuffix(msgDomain, blockedDomain) || msgDomain == blockedDomain {
// Domain is blocked, create actions
// 1. Create a delete message action
deleteAction := &model.MessageAction{
Type: model.ActionDeleteMessage,
MessageID: msg.ID,
Chat: msg.Chat,
Channel: msg.Channel,
}
// 2. Create a notification message action
notificationMsg := &model.Message{
Text: fmt.Sprintf("I don't like links from %s 🙈", blockedDomain),
Chat: msg.Chat,
Channel: msg.Channel,
}
sendAction := &model.MessageAction{
Type: model.ActionSendMessage,
Message: notificationMsg,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{deleteAction, sendAction}
}
}
}
return nil
}
// Plugin is registered in app.go, not using init()

View file

@ -1,142 +0,0 @@
package domainblock
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
)
func TestExtractDomains(t *testing.T) {
tests := []struct {
name string
text string
expected []string
}{
{
name: "No URLs",
text: "Hello, world!",
expected: []string{},
},
{
name: "Single URL",
text: "Check out https://example.com for more info",
expected: []string{"example.com"},
},
{
name: "Multiple URLs",
text: "Check out https://example.com and http://test.example.org for more info",
expected: []string{"example.com", "test.example.org"},
},
{
name: "URL with path",
text: "Check out https://example.com/path/to/resource",
expected: []string{"example.com"},
},
{
name: "URL with port",
text: "Check out https://example.com:8080/path/to/resource",
expected: []string{"example.com"},
},
{
name: "URL with subdomain",
text: "Check out https://sub.example.com",
expected: []string{"sub.example.com"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
domains := extractDomains(test.text)
if len(domains) != len(test.expected) {
t.Errorf("Expected %d domains, got %d", len(test.expected), len(domains))
return
}
for i, domain := range domains {
if domain != test.expected[i] {
t.Errorf("Expected domain %s, got %s", test.expected[i], domain)
}
}
})
}
}
func TestOnMessage(t *testing.T) {
plugin := New()
tests := []struct {
name string
text string
blockedDomains string
expectBlocked bool
}{
{
name: "No blocked domains",
text: "Check out https://example.com",
blockedDomains: "",
expectBlocked: false,
},
{
name: "No matching domain",
text: "Check out https://example.com",
blockedDomains: "bad.com, evil.org",
expectBlocked: false,
},
{
name: "Matching domain",
text: "Check out https://example.com",
blockedDomains: "example.com, evil.org",
expectBlocked: true,
},
{
name: "Matching subdomain",
text: "Check out https://sub.example.com",
blockedDomains: "example.com",
expectBlocked: true,
},
{
name: "Multiple domains, one matching",
text: "Check out https://example.com and https://good.org",
blockedDomains: "bad.com, example.com",
expectBlocked: true,
},
{
name: "Spaces in blocked domains list",
text: "Check out https://example.com",
blockedDomains: "bad.com, example.com , evil.org",
expectBlocked: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
msg := &model.Message{
Text: test.text,
Chat: "test-chat",
ID: "test-id",
Channel: &model.Channel{
ID: 1,
},
}
config := map[string]interface{}{
"blocked_domains": test.blockedDomains,
}
mockCache := &testutil.MockCache{}
responses := plugin.OnMessage(msg, config, mockCache)
if test.expectBlocked {
if len(responses) == 0 {
t.Errorf("Expected message to be blocked, but it wasn't")
}
} else {
if len(responses) > 0 {
t.Errorf("Expected message not to be blocked, but it was")
}
}
})
}
}

View file

@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") {
return nil return nil
} }
@ -46,12 +46,5 @@ func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}
Channel: msg.Channel, Channel: msg.Channel,
} }
action := &model.MessageAction{ return []*model.Message{response}
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
} }

View file

@ -32,7 +32,7 @@ func NewDice() *DicePlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") {
return nil return nil
} }
@ -62,14 +62,7 @@ func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}
Channel: msg.Channel, Channel: msg.Channel,
} }
action := &model.MessageAction{ return []*model.Message{response}
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
} }
// rollDice parses a dice formula string and returns the result // rollDice parses a dice formula string and returns the result
@ -114,10 +107,9 @@ func (p *DicePlugin) rollDice(formula string) (int, error) {
return 0, fmt.Errorf("invalid modifier") return 0, fmt.Errorf("invalid modifier")
} }
switch matches[3] { if matches[3] == "+" {
case "+":
total += modifier total += modifier
case "-": } else if matches[3] == "-" {
total -= modifier total -= modifier
} }
} }

View file

@ -1,394 +0,0 @@
package fun
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// HLTBPlugin searches HowLongToBeat for game completion times
type HLTBPlugin struct {
plugin.BasePlugin
httpClient *http.Client
}
// HLTBSearchRequest represents the search request payload
type HLTBSearchRequest struct {
SearchType string `json:"searchType"`
SearchTerms []string `json:"searchTerms"`
SearchPage int `json:"searchPage"`
Size int `json:"size"`
SearchOptions map[string]interface{} `json:"searchOptions"`
UseCache bool `json:"useCache"`
}
// HLTBGame represents a game from HowLongToBeat
type HLTBGame struct {
ID int `json:"game_id"`
Name string `json:"game_name"`
GameAlias string `json:"game_alias"`
GameImage string `json:"game_image"`
CompMain int `json:"comp_main"`
CompPlus int `json:"comp_plus"`
CompComplete int `json:"comp_complete"`
CompAll int `json:"comp_all"`
InvestedCo int `json:"invested_co"`
InvestedMp int `json:"invested_mp"`
CountComp int `json:"count_comp"`
CountSpeedruns int `json:"count_speedruns"`
CountBacklog int `json:"count_backlog"`
CountReview int `json:"count_review"`
ReviewScore int `json:"review_score"`
CountPlaying int `json:"count_playing"`
CountRetired int `json:"count_retired"`
}
// HLTBSearchResponse represents the search response
type HLTBSearchResponse struct {
Color string `json:"color"`
Title string `json:"title"`
Category string `json:"category"`
Count int `json:"count"`
Pagecurrent int `json:"pagecurrent"`
Pagesize int `json:"pagesize"`
Pagetotal int `json:"pagetotal"`
SearchTerm string `json:"searchTerm"`
SearchResults []HLTBGame `json:"data"`
}
// NewHLTB creates a new HLTBPlugin instance
func NewHLTB() *HLTBPlugin {
return &HLTBPlugin{
BasePlugin: plugin.BasePlugin{
ID: "fun.hltb",
Name: "How Long To Beat",
Help: "Get game completion times from HowLongToBeat.com using `!hltb <game name>`",
},
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// OnMessage handles incoming messages
func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Check if message starts with !hltb
text := strings.TrimSpace(msg.Text)
if !strings.HasPrefix(text, "!hltb ") {
return nil
}
// Extract game name
gameName := strings.TrimSpace(text[6:]) // Remove "!hltb "
if gameName == "" {
return p.createErrorResponse(msg, "Please provide a game name. Usage: !hltb <game name>")
}
// Check cache first
var games []HLTBGame
var err error
cacheKey := strings.ToLower(gameName)
err = cache.Get(cacheKey, &games)
if err != nil || len(games) == 0 {
// Cache miss - search for the game
games, err = p.searchGame(gameName)
if err != nil {
return p.createErrorResponse(msg, fmt.Sprintf("Error searching for game: %s", err.Error()))
}
if len(games) == 0 {
return p.createErrorResponse(msg, fmt.Sprintf("No results found for '%s'", gameName))
}
// Cache the results for 1 hour
err = cache.SetWithTTL(cacheKey, games, time.Hour)
if err != nil {
// Log cache error but don't fail the request
fmt.Printf("Warning: Failed to cache HLTB results: %v\n", err)
}
}
// Use the first result
game := games[0]
// Format the response
response := p.formatGameInfo(game)
// Create response message with game cover if available
responseMsg := &model.Message{
Text: response,
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
// Set parse mode for markdown formatting
if responseMsg.Raw == nil {
responseMsg.Raw = make(map[string]interface{})
}
responseMsg.Raw["parse_mode"] = "Markdown"
// Add game cover as attachment if available
if game.GameImage != "" {
imageURL := p.getFullImageURL(game.GameImage)
responseMsg.Raw["image_url"] = imageURL
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: responseMsg,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}
// searchGame searches for a game on HowLongToBeat
func (p *HLTBPlugin) searchGame(gameName string) ([]HLTBGame, error) {
// Split search terms by words
searchTerms := strings.Fields(gameName)
// Prepare search request
searchRequest := HLTBSearchRequest{
SearchType: "games",
SearchTerms: searchTerms,
SearchPage: 1,
Size: 20,
SearchOptions: map[string]interface{}{
"games": map[string]interface{}{
"userId": 0,
"platform": "",
"sortCategory": "popular",
"rangeCategory": "main",
"rangeTime": map[string]interface{}{
"min": nil,
"max": nil,
},
"gameplay": map[string]interface{}{
"perspective": "",
"flow": "",
"genre": "",
"difficulty": "",
},
"rangeYear": map[string]interface{}{
"min": "",
"max": "",
},
"modifier": "",
},
"users": map[string]interface{}{
"sortCategory": "postcount",
},
"lists": map[string]interface{}{
"sortCategory": "follows",
},
"filter": "",
"sort": 0,
"randomizer": 0,
},
UseCache: true,
}
// Convert to JSON
jsonData, err := json.Marshal(searchRequest)
if err != nil {
return nil, fmt.Errorf("failed to marshal search request: %w", err)
}
// The API endpoint appears to have changed to use dynamic tokens
// Try to get the seek token first, fallback to basic search
seekToken, err := p.getSeekToken()
if err != nil {
// Fallback to old endpoint
seekToken = ""
}
var apiURL string
if seekToken != "" {
apiURL = fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", seekToken)
} else {
apiURL = "https://howlongtobeat.com/api/search"
}
// Create HTTP request
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers to match the working curl request
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", "https://howlongtobeat.com")
req.Header.Set("Pragma", "no-cache")
req.Header.Set("Referer", "https://howlongtobeat.com")
req.Header.Set("Sec-Fetch-Dest", "empty")
req.Header.Set("Sec-Fetch-Mode", "cors")
req.Header.Set("Sec-Fetch-Site", "same-origin")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
// Send request
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status code: %d", resp.StatusCode)
}
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Parse response
var searchResponse HLTBSearchResponse
if err := json.Unmarshal(body, &searchResponse); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return searchResponse.SearchResults, nil
}
// formatGameInfo formats game information for display
func (p *HLTBPlugin) formatGameInfo(game HLTBGame) string {
var response strings.Builder
response.WriteString(fmt.Sprintf("🎮 **%s**\n\n", game.Name))
// Format completion times
if game.CompMain > 0 {
response.WriteString(fmt.Sprintf("📖 **Main Story:** %s\n", p.formatTime(game.CompMain)))
}
if game.CompPlus > 0 {
response.WriteString(fmt.Sprintf(" **Main + Extras:** %s\n", p.formatTime(game.CompPlus)))
}
if game.CompComplete > 0 {
response.WriteString(fmt.Sprintf("💯 **Completionist:** %s\n", p.formatTime(game.CompComplete)))
}
if game.CompAll > 0 {
response.WriteString(fmt.Sprintf("🎯 **All Styles:** %s\n", p.formatTime(game.CompAll)))
}
// Add review score if available
if game.ReviewScore > 0 {
response.WriteString(fmt.Sprintf("\n⭐ **User Score:** %d/100", game.ReviewScore))
}
// Add source attribution
response.WriteString("\n\n*Source: HowLongToBeat.com*")
return response.String()
}
// formatTime converts seconds to a readable time format
func (p *HLTBPlugin) formatTime(seconds int) string {
if seconds <= 0 {
return "N/A"
}
hours := float64(seconds) / 3600.0
if hours < 1 {
minutes := seconds / 60
return fmt.Sprintf("%d minutes", minutes)
} else if hours < 2 {
return fmt.Sprintf("%.1f hour", hours)
} else {
return fmt.Sprintf("%.1f hours", hours)
}
}
// getFullImageURL constructs the full image URL
func (p *HLTBPlugin) getFullImageURL(imagePath string) string {
if imagePath == "" {
return ""
}
// Remove leading slash if present
imagePath = strings.TrimPrefix(imagePath, "/")
return fmt.Sprintf("https://howlongtobeat.com/games/%s", imagePath)
}
// getSeekToken attempts to retrieve the seek token from HowLongToBeat
func (p *HLTBPlugin) getSeekToken() (string, error) {
// Try to extract the seek token from the main page
req, err := http.NewRequest("GET", "https://howlongtobeat.com", nil)
if err != nil {
return "", fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
resp, err := p.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to fetch token: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read token response: %w", err)
}
// Look for patterns that might contain the token
patterns := []string{
`/api/seek/([a-f0-9]+)`,
`"seek/([a-f0-9]+)"`,
`seek/([a-f0-9]{12,})`,
}
bodyStr := string(body)
for _, pattern := range patterns {
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(bodyStr)
if len(matches) > 1 {
return matches[1], nil
}
}
// If we can't extract a token, return the known working one as fallback
return "d4b2e330db04dbf3", nil
}
// createErrorResponse creates an error response message
func (p *HLTBPlugin) createErrorResponse(msg *model.Message, errorText string) []*model.MessageAction {
response := &model.Message{
Text: fmt.Sprintf("❌ %s", errorText),
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}

View file

@ -1,131 +0,0 @@
package fun
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
)
func TestHLTBPlugin_OnMessage(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
name string
messageText string
shouldRespond bool
}{
{
name: "responds to !hltb command",
messageText: "!hltb The Witcher 3",
shouldRespond: true,
},
{
name: "ignores non-hltb messages",
messageText: "hello world",
shouldRespond: false,
},
{
name: "ignores !hltb without game name",
messageText: "!hltb",
shouldRespond: false,
},
{
name: "ignores !hltb with only spaces",
messageText: "!hltb ",
shouldRespond: false,
},
{
name: "ignores similar but incorrect commands",
messageText: "hltb The Witcher 3",
shouldRespond: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &model.Message{
Text: tt.messageText,
Chat: "test-chat",
Channel: &model.Channel{ID: 1},
Author: "test-user",
}
mockCache := &testutil.MockCache{}
actions := plugin.OnMessage(msg, make(map[string]interface{}), mockCache)
if tt.shouldRespond && len(actions) == 0 {
t.Errorf("Expected plugin to respond to '%s', but it didn't", tt.messageText)
}
if !tt.shouldRespond && len(actions) > 0 {
t.Errorf("Expected plugin to not respond to '%s', but it did", tt.messageText)
}
// For messages that should respond, verify the response structure
if tt.shouldRespond && len(actions) > 0 {
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
}
if action.Message == nil {
t.Error("Expected action to have a message")
}
if action.Message != nil && action.Message.ReplyTo != msg.ID {
t.Error("Expected response to reply to original message")
}
}
})
}
}
func TestHLTBPlugin_formatTime(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
seconds int
expected string
}{
{0, "N/A"},
{-1, "N/A"},
{1800, "30 minutes"}, // 30 minutes
{3600, "1.0 hour"}, // 1 hour
{7200, "2.0 hours"}, // 2 hours
{10800, "3.0 hours"}, // 3 hours
{36000, "10.0 hours"}, // 10 hours
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := plugin.formatTime(tt.seconds)
if result != tt.expected {
t.Errorf("formatTime(%d) = %s, want %s", tt.seconds, result, tt.expected)
}
})
}
}
func TestHLTBPlugin_getFullImageURL(t *testing.T) {
plugin := NewHLTB()
tests := []struct {
imagePath string
expected string
}{
{"", ""},
{"game.jpg", "https://howlongtobeat.com/games/game.jpg"},
{"/game.jpg", "https://howlongtobeat.com/games/game.jpg"},
{"folder/game.png", "https://howlongtobeat.com/games/folder/game.png"},
}
for _, tt := range tests {
t.Run(tt.imagePath, func(t *testing.T) {
result := plugin.getFullImageURL(tt.imagePath)
if result != tt.expected {
t.Errorf("getFullImageURL(%s) = %s, want %s", tt.imagePath, result, tt.expected)
}
})
}
}

View file

@ -23,13 +23,8 @@ func NewLoquito() *LoquitoPlugin {
} }
} }
// GetHelp returns the plugin help text
func (p *LoquitoPlugin) GetHelp() string {
return ""
}
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { if !strings.Contains(strings.ToLower(msg.Text), "lo quito") {
return nil return nil
} }
@ -41,12 +36,5 @@ func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interfac
Channel: msg.Channel, Channel: msg.Channel,
} }
action := &model.MessageAction{ return []*model.Message{response}
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
} }

View file

@ -1,166 +0,0 @@
package help
import (
"fmt"
"sort"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
"golang.org/x/exp/slog"
)
// ChannelPluginGetter is an interface for getting channel plugins
type ChannelPluginGetter interface {
GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error)
GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error)
}
// HelpPlugin provides help information about available commands
type HelpPlugin struct {
plugin.BasePlugin
db ChannelPluginGetter
}
// New creates a new HelpPlugin instance
func New(db ChannelPluginGetter) *HelpPlugin {
return &HelpPlugin{
BasePlugin: plugin.BasePlugin{
ID: "utility.help",
Name: "Help",
Help: "Shows available commands when you type '!help'",
},
db: db,
}
}
// OnMessage handles incoming messages
func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Check if message is the help command
if !strings.EqualFold(strings.TrimSpace(msg.Text), "!help") {
return nil
}
// Get channel plugins from database using platform and platform channel ID
channelPlugins, err := p.db.GetChannelPluginsFromPlatformID(msg.Channel.Platform, msg.Channel.PlatformChannelID)
if err != nil && err != db.ErrNotFound {
slog.Error("Failed to get channel plugins", slog.Any("err", err))
return []*model.MessageAction{}
}
// If no plugins found, initialize empty slice
if err == db.ErrNotFound {
channelPlugins = []*model.ChannelPlugin{}
}
// Get all available plugins
availablePlugins := plugin.GetAvailablePlugins()
// Filter to only enabled plugins for this channel
enabledPlugins := make(map[string]model.Plugin)
for _, channelPlugin := range channelPlugins {
if channelPlugin.Enabled {
if availablePlugin, exists := availablePlugins[channelPlugin.PluginID]; exists {
enabledPlugins[channelPlugin.PluginID] = availablePlugin
}
}
}
// If no plugins are enabled, return a message
if len(enabledPlugins) == 0 {
response := &model.Message{
Text: "No plugins are currently enabled for this channel.",
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
Raw: map[string]interface{}{"parse_mode": "Markdown"},
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Group plugins by category
categories := map[string][]model.Plugin{
"Development": {},
"Fun and Entertainment": {},
"Utility": {},
"Security": {},
"Social Media": {},
"Other": {},
}
// Categorize plugins based on their ID prefix
for _, p := range enabledPlugins {
category := p.GetID()
switch {
case strings.HasPrefix(category, "dev."):
categories["Development"] = append(categories["Development"], p)
case strings.HasPrefix(category, "fun."):
categories["Fun and Entertainment"] = append(categories["Fun and Entertainment"], p)
case strings.HasPrefix(category, "util.") || strings.HasPrefix(category, "reminder.") || strings.HasPrefix(category, "utility."):
categories["Utility"] = append(categories["Utility"], p)
case strings.HasPrefix(category, "security."):
categories["Security"] = append(categories["Security"], p)
case strings.HasPrefix(category, "social."):
categories["Social Media"] = append(categories["Social Media"], p)
default:
categories["Other"] = append(categories["Other"], p)
}
}
// Build the help message
var helpText strings.Builder
helpText.WriteString("🤖 **Available Commands**\n\n")
// Sort category names for consistent output
categoryOrder := []string{"Development", "Fun and Entertainment", "Utility", "Security", "Social Media", "Other"}
for _, categoryName := range categoryOrder {
pluginList := categories[categoryName]
if len(pluginList) == 0 {
continue
}
// Sort plugins within category by name
sort.Slice(pluginList, func(i, j int) bool {
return pluginList[i].GetName() < pluginList[j].GetName()
})
helpText.WriteString(fmt.Sprintf("**%s:**\n", categoryName))
for _, p := range pluginList {
if p.GetHelp() == "" {
continue
}
helpText.WriteString(fmt.Sprintf("• **%s** - %s\n", p.GetName(), p.GetHelp()))
}
helpText.WriteString("\n")
}
// Add footer
helpText.WriteString("_Use the specific commands or triggers mentioned above to interact with the bot._")
response := &model.Message{
Text: helpText.String(),
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
Raw: map[string]interface{}{"parse_mode": "Markdown"},
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}

View file

@ -1,206 +0,0 @@
package help
import (
"strings"
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/db"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// MockPlugin implements the Plugin interface for testing
type MockPlugin struct {
id string
name string
help string
}
func (m *MockPlugin) GetID() string { return m.id }
func (m *MockPlugin) GetName() string { return m.name }
func (m *MockPlugin) GetHelp() string { return m.help }
func (m *MockPlugin) RequiresConfig() bool {
return false
}
func (m *MockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
return nil
}
// MockDatabase implements the ChannelPluginGetter interface for testing
type MockDatabase struct {
channelPlugins map[int64][]*model.ChannelPlugin
platformChannelPlugins map[string][]*model.ChannelPlugin // key: "platform:platformChannelID"
}
func (m *MockDatabase) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
if plugins, exists := m.channelPlugins[channelID]; exists {
return plugins, nil
}
return nil, db.ErrNotFound
}
func (m *MockDatabase) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
key := platform + ":" + platformChannelID
if plugins, exists := m.platformChannelPlugins[key]; exists {
return plugins, nil
}
return nil, db.ErrNotFound
}
func TestHelpPlugin_OnMessage(t *testing.T) {
tests := []struct {
name string
messageText string
enabledPlugins map[string]*MockPlugin
expectResponse bool
expectNoPlugins bool
expectCategories []string
}{
{
name: "responds to !help command",
messageText: "!help",
enabledPlugins: map[string]*MockPlugin{
"dev.ping": {
id: "dev.ping",
name: "Ping",
help: "Responds to 'ping' with 'pong'",
},
"fun.dice": {
id: "fun.dice",
name: "Dice Roller",
help: "Rolls dice when you type '!dice [formula]'",
},
},
expectResponse: true,
expectCategories: []string{"Development", "Fun and Entertainment"},
},
{
name: "ignores non-help messages",
messageText: "hello world",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: false,
},
{
name: "ignores case variation",
messageText: "!HELP",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: true,
expectNoPlugins: true,
},
{
name: "handles no enabled plugins",
messageText: "!help",
enabledPlugins: map[string]*MockPlugin{},
expectResponse: true,
expectNoPlugins: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock database
mockDB := &MockDatabase{
channelPlugins: make(map[int64][]*model.ChannelPlugin),
platformChannelPlugins: make(map[string][]*model.ChannelPlugin),
}
// Setup channel plugins in mock database
var channelPluginList []*model.ChannelPlugin
pluginCounter := int64(1)
for pluginID := range tt.enabledPlugins {
channelPluginList = append(channelPluginList, &model.ChannelPlugin{
ID: pluginCounter,
ChannelID: 1,
PluginID: pluginID,
Enabled: true,
Config: make(map[string]interface{}),
})
pluginCounter++
}
// Set up both mapping approaches for the test
mockDB.channelPlugins[1] = channelPluginList
mockDB.platformChannelPlugins["test:test-channel"] = channelPluginList
// Create help plugin
p := New(mockDB)
// Create mock channel
channel := &model.Channel{
ID: 1,
Platform: "test",
PlatformChannelID: "test-channel",
}
// Create test message
msg := &model.Message{
ID: "test-msg",
Text: tt.messageText,
Chat: "test-chat",
Channel: channel,
}
// Mock the plugin registry
originalRegistry := plugin.GetAvailablePlugins()
// Override the registry for this test
plugin.ClearRegistry()
for _, mockPlugin := range tt.enabledPlugins {
plugin.Register(mockPlugin)
}
// Call OnMessage
actions := p.OnMessage(msg, map[string]interface{}{}, nil)
// Restore original registry
plugin.ClearRegistry()
for _, p := range originalRegistry {
plugin.Register(p)
}
if !tt.expectResponse {
if len(actions) != 0 {
t.Errorf("Expected no response, but got %d actions", len(actions))
}
return
}
if len(actions) != 1 {
t.Errorf("Expected 1 action, got %d", len(actions))
return
}
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %v", action.Type)
return
}
responseText := action.Message.Text
if tt.expectNoPlugins {
if !strings.Contains(responseText, "No plugins are currently enabled") {
t.Errorf("Expected 'no plugins' message, got: %s", responseText)
}
return
}
// Check that expected categories appear in response
for _, category := range tt.expectCategories {
if !strings.Contains(responseText, "**"+category+":**") {
t.Errorf("Expected category '%s' in response, got: %s", category, responseText)
}
}
// Check that plugin names and help text appear
for _, mockPlugin := range tt.enabledPlugins {
if !strings.Contains(responseText, mockPlugin.GetName()) {
t.Errorf("Expected plugin name '%s' in response", mockPlugin.GetName())
}
if !strings.Contains(responseText, mockPlugin.GetHelp()) {
t.Errorf("Expected plugin help '%s' in response", mockPlugin.GetHelp())
}
}
})
}
}

View file

@ -24,12 +24,11 @@ func New() *PingPlugin {
} }
// OnMessage handles incoming messages // OnMessage handles incoming messages
func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") {
return nil return nil
} }
// Create the response message
response := &model.Message{ response := &model.Message{
Text: "pong", Text: "pong",
Chat: msg.Chat, Chat: msg.Chat,
@ -37,13 +36,5 @@ func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}
Channel: msg.Channel, Channel: msg.Channel,
} }
// Create an action to send the message return []*model.Message{response}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
} }

View file

@ -1,7 +1,6 @@
package plugin package plugin
import ( import (
"maps"
"sync" "sync"
"git.nakama.town/fmartingr/butterrobot/internal/model" "git.nakama.town/fmartingr/butterrobot/internal/model"
@ -42,31 +41,13 @@ func GetAvailablePlugins() map[string]model.Plugin {
// Create a copy to avoid race conditions // Create a copy to avoid race conditions
result := make(map[string]model.Plugin, len(plugins)) result := make(map[string]model.Plugin, len(plugins))
maps.Copy(result, plugins) for id, plugin := range plugins {
result[id] = plugin
return result
}
// GetAvailablePluginIDs returns a slice of all registered plugin IDs
func GetAvailablePluginIDs() []string {
pluginsMu.RLock()
defer pluginsMu.RUnlock()
result := make([]string, 0, len(plugins))
for pluginID := range plugins {
result = append(result, pluginID)
} }
return result return result
} }
// ClearRegistry clears all registered plugins (for testing)
func ClearRegistry() {
pluginsMu.Lock()
defer pluginsMu.Unlock()
plugins = make(map[string]model.Plugin)
}
// BasePlugin provides a common base for plugins // BasePlugin provides a common base for plugins
type BasePlugin struct { type BasePlugin struct {
ID string ID string
@ -96,6 +77,6 @@ func (p *BasePlugin) RequiresConfig() bool {
} }
// OnMessage is the default implementation that does nothing // OnMessage is the default implementation that does nothing
func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
return nil return nil
} }

View file

@ -1,331 +0,0 @@
package plugin
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
// Mock plugin for testing
type testPlugin struct {
BasePlugin
}
func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: "test response",
Chat: msg.Chat,
Channel: msg.Channel,
},
},
}
}
func TestGetAvailablePluginIDs(t *testing.T) {
// Clear registry before test
ClearRegistry()
// Register test plugins
testPlugin1 := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.plugin1",
Name: "Test Plugin 1",
},
}
testPlugin2 := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.plugin2",
Name: "Test Plugin 2",
},
}
Register(testPlugin1)
Register(testPlugin2)
// Test GetAvailablePluginIDs
pluginIDs := GetAvailablePluginIDs()
if len(pluginIDs) != 2 {
t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs))
}
// Check that both plugin IDs are present
found1, found2 := false, false
for _, id := range pluginIDs {
if id == "test.plugin1" {
found1 = true
}
if id == "test.plugin2" {
found2 = true
}
}
if !found1 {
t.Errorf("Expected to find test.plugin1 in plugin IDs")
}
if !found2 {
t.Errorf("Expected to find test.plugin2 in plugin IDs")
}
}
func TestEnableAllPluginsProcessingLogic(t *testing.T) {
// Clear registry before test
ClearRegistry()
// Register test plugins
testPlugin1 := &testPlugin{
BasePlugin: BasePlugin{
ID: "ping",
Name: "Ping Plugin",
},
}
testPlugin2 := &testPlugin{
BasePlugin: BasePlugin{
ID: "echo",
Name: "Echo Plugin",
},
}
testPlugin3 := &testPlugin{
BasePlugin: BasePlugin{
ID: "help",
Name: "Help Plugin",
},
}
Register(testPlugin1)
Register(testPlugin2)
Register(testPlugin3)
t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) {
// Create a channel with EnableAllPlugins = false and only some plugins enabled
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: false,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"key": "value"},
},
"echo": {
ID: 2,
ChannelID: 1,
PluginID: "echo",
Enabled: false, // Disabled
Config: map[string]interface{}{},
},
// help plugin not configured
},
}
// Simulate the plugin processing logic from handleMessage
var pluginsToProcess []string
if channel.EnableAllPlugins {
pluginsToProcess = GetAvailablePluginIDs()
} else {
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
}
}
// Should only have "ping" since echo is disabled and help is not configured
if len(pluginsToProcess) != 1 {
t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
}
if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" {
t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0])
}
})
t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) {
// Create a channel with EnableAllPlugins = true
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"key": "value"},
},
"echo": {
ID: 2,
ChannelID: 1,
PluginID: "echo",
Enabled: false, // Disabled, but should still be processed
Config: map[string]interface{}{},
},
// help plugin not configured, but should still be processed
},
}
// Simulate the plugin processing logic from handleMessage
var pluginsToProcess []string
if channel.EnableAllPlugins {
pluginsToProcess = GetAvailablePluginIDs()
} else {
for pluginID := range channel.Plugins {
if channel.HasEnabledPlugin(pluginID) {
pluginsToProcess = append(pluginsToProcess, pluginID)
}
}
}
// Should have all 3 registered plugins
if len(pluginsToProcess) != 3 {
t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
}
// Check that all plugins are included
expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false}
for _, pluginID := range pluginsToProcess {
if _, exists := expectedPlugins[pluginID]; exists {
expectedPlugins[pluginID] = true
} else {
t.Errorf("Unexpected plugin in processing list: %s", pluginID)
}
}
for pluginID, found := range expectedPlugins {
if !found {
t.Errorf("Expected plugin %s to be in processing list", pluginID)
}
}
})
t.Run("Plugin configuration handling", func(t *testing.T) {
// Test the configuration logic from handleMessage
channel := &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "123456",
Enabled: true,
EnableAllPlugins: true,
Plugins: map[string]*model.ChannelPlugin{
"ping": {
ID: 1,
ChannelID: 1,
PluginID: "ping",
Enabled: true,
Config: map[string]interface{}{"configured": "value"},
},
},
}
testCases := []struct {
pluginID string
expectedConfig map[string]interface{}
}{
{
pluginID: "ping",
expectedConfig: map[string]interface{}{"configured": "value"},
},
{
pluginID: "echo", // Not explicitly configured
expectedConfig: map[string]interface{}{},
},
}
for _, tc := range testCases {
// Simulate the config retrieval logic from handleMessage
var config map[string]interface{}
if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists {
config = channelPlugin.Config
} else {
config = make(map[string]interface{})
}
if len(config) != len(tc.expectedConfig) {
t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config))
}
for key, expectedValue := range tc.expectedConfig {
if actualValue, exists := config[key]; !exists || actualValue != expectedValue {
t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue)
}
}
}
})
}
func TestPluginRegistry(t *testing.T) {
// Clear registry before test
ClearRegistry()
testPlugin := &testPlugin{
BasePlugin: BasePlugin{
ID: "test.registry",
Name: "Test Registry Plugin",
},
}
t.Run("Register and Get plugin", func(t *testing.T) {
Register(testPlugin)
retrieved, err := Get("test.registry")
if err != nil {
t.Errorf("Failed to get registered plugin: %v", err)
}
if retrieved.GetID() != "test.registry" {
t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID())
}
})
t.Run("Get nonexistent plugin", func(t *testing.T) {
_, err := Get("nonexistent.plugin")
if err == nil {
t.Errorf("Expected error when getting nonexistent plugin, got nil")
}
if err != model.ErrPluginNotFound {
t.Errorf("Expected ErrPluginNotFound, got %v", err)
}
})
t.Run("GetAvailablePlugins", func(t *testing.T) {
plugins := GetAvailablePlugins()
if len(plugins) != 1 {
t.Errorf("Expected 1 plugin in registry, got %d", len(plugins))
}
if plugin, exists := plugins["test.registry"]; !exists {
t.Errorf("Expected to find test.registry in available plugins")
} else if plugin.GetID() != "test.registry" {
t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID())
}
})
t.Run("ClearRegistry", func(t *testing.T) {
ClearRegistry()
plugins := GetAvailablePlugins()
if len(plugins) != 0 {
t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins))
}
_, err := Get("test.registry")
if err == nil {
t.Errorf("Expected error when getting plugin after clearing registry, got nil")
}
})
}

View file

@ -1,200 +0,0 @@
package reminder
import (
"fmt"
"regexp"
"strconv"
"strings"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// Duration regex patterns to match reminders
var (
remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`)
)
// ReminderCreator is an interface for creating reminders
type ReminderCreator interface {
CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error)
}
// Reminder is a plugin that sets reminders for messages
type Reminder struct {
plugin.BasePlugin
creator ReminderCreator
}
// New creates a new Reminder plugin
func New(creator ReminderCreator) *Reminder {
return &Reminder{
BasePlugin: plugin.BasePlugin{
ID: "reminder.remindme",
Name: "Remind Me",
Help: "Reply to a message with `!remindme <duration>` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).",
ConfigRequired: false,
},
creator: creator,
}
}
// OnMessage processes incoming messages
func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Only process replies to messages
if msg.ReplyTo == "" {
return nil
}
// Check if the message is a reminder command
match := remindMePattern.FindStringSubmatch(msg.Text)
if match == nil {
return nil
}
// Parse the duration
amount, err := strconv.Atoi(match[1])
if err != nil {
errorMsg := &model.Message{
Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
Chat: msg.Chat,
Channel: msg.Channel,
Author: "bot",
FromBot: true,
Date: time.Now(),
ReplyTo: msg.ID,
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: errorMsg,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Calculate the trigger time
var duration time.Duration
unit := match[2]
switch strings.ToLower(unit) {
case "y":
duration = time.Duration(amount) * 365 * 24 * time.Hour
case "mo":
duration = time.Duration(amount) * 30 * 24 * time.Hour
case "d":
duration = time.Duration(amount) * 24 * time.Hour
case "h":
duration = time.Duration(amount) * time.Hour
case "m":
duration = time.Duration(amount) * time.Minute
case "s":
duration = time.Duration(amount) * time.Second
default:
errorMsg := &model.Message{
Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
Chat: msg.Chat,
Channel: msg.Channel,
Author: "bot",
FromBot: true,
Date: time.Now(),
ReplyTo: msg.ID,
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: errorMsg,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
triggerAt := time.Now().Add(duration)
// Determine the username for the reminder
username := msg.Author
if username == "" {
// Try to extract username from message raw data
if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok {
if name, ok := authorData["username"].(string); ok {
username = name
} else if name, ok := authorData["name"].(string); ok {
username = name
}
}
}
// Create the reminder
_, err = r.creator.CreateReminder(
msg.Channel.Platform,
msg.Chat,
msg.ID,
msg.ReplyTo,
msg.Author,
username,
"", // No additional content for now
triggerAt,
)
if err != nil {
errorMsg := &model.Message{
Text: fmt.Sprintf("Failed to create reminder: %v", err),
Chat: msg.Chat,
Channel: msg.Channel,
Author: "bot",
FromBot: true,
Date: time.Now(),
ReplyTo: msg.ID,
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: errorMsg,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Format the acknowledgment message
var confirmText string
switch strings.ToLower(unit) {
case "y":
confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04"))
case "mo":
confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
case "d":
confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
case "h":
confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04"))
case "m":
confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04"))
case "s":
confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount)
}
// Create confirmation message
confirmMsg := &model.Message{
Text: confirmText,
Chat: msg.Chat,
Channel: msg.Channel,
Author: "bot",
FromBot: true,
Date: time.Now(),
ReplyTo: msg.ID,
}
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: confirmMsg,
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}

View file

@ -1,177 +0,0 @@
package reminder
import (
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
)
// MockCreator is a mock implementation of ReminderCreator for testing
type MockCreator struct {
reminders []*model.Reminder
}
func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
reminder := &model.Reminder{
ID: int64(len(m.reminders) + 1),
Platform: platform,
ChannelID: channelID,
MessageID: messageID,
ReplyToID: replyToID,
UserID: userID,
Username: username,
Content: content,
TriggerAt: triggerAt,
}
m.reminders = append(m.reminders, reminder)
return reminder, nil
}
func TestReminderOnMessage(t *testing.T) {
creator := &MockCreator{reminders: make([]*model.Reminder, 0)}
plugin := New(creator)
tests := []struct {
name string
message *model.Message
expectResponse bool
expectReminder bool
}{
{
name: "Valid reminder command - years",
message: &model.Message{
Text: "!remindme 1y",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Valid reminder command - months",
message: &model.Message{
Text: "!remindme 3mo",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Valid reminder command - days",
message: &model.Message{
Text: "!remindme 2d",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Valid reminder command - hours",
message: &model.Message{
Text: "!remindme 5h",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Valid reminder command - minutes",
message: &model.Message{
Text: "!remindme 30m",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Valid reminder command - seconds",
message: &model.Message{
Text: "!remindme 60s",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: true,
expectReminder: true,
},
{
name: "Not a reply",
message: &model.Message{
Text: "!remindme 2d",
ReplyTo: "",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: false,
expectReminder: false,
},
{
name: "Not a reminder command",
message: &model.Message{
Text: "hello world",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: false,
expectReminder: false,
},
{
name: "Invalid duration format",
message: &model.Message{
Text: "!remindme abc",
ReplyTo: "original-message-id",
Author: "testuser",
Channel: &model.Channel{Platform: "test"},
},
expectResponse: false,
expectReminder: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
initialCount := len(creator.reminders)
mockCache := &testutil.MockCache{}
actions := plugin.OnMessage(tt.message, nil, mockCache)
if tt.expectResponse && len(actions) == 0 {
t.Errorf("Expected response action, but got none")
}
if !tt.expectResponse && len(actions) > 0 {
t.Errorf("Expected no actions, but got %d", len(actions))
}
// Verify action type is correct when actions are returned
if len(actions) > 0 {
if actions[0].Type != model.ActionSendMessage {
t.Errorf("Expected action type to be %s, but got %s", model.ActionSendMessage, actions[0].Type)
}
if actions[0].Message == nil {
t.Errorf("Expected message in action to not be nil")
}
}
if tt.expectReminder && len(creator.reminders) != initialCount+1 {
t.Errorf("Expected reminder to be created, but it wasn't")
}
if !tt.expectReminder && len(creator.reminders) != initialCount {
t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount)
}
})
}
}

View file

@ -1,50 +0,0 @@
# Search and Replace Plugin
This plugin allows users to perform search and replace operations on messages by replying to a message with a search/replace command.
## Usage
To use the plugin, reply to any message with a command in the following format:
```
s/search/replace/[flags]
```
Where:
- `search` is the text you want to find (case-sensitive by default)
- `replace` is the text you want to substitute in place of the search term
- `flags` (optional) control the behavior of the replacement
### Supported Flags
- `g` - Global: Replace all occurrences of the search term (without this flag, only the first occurrence is replaced)
- `i` - Case insensitive: Match regardless of case
- `n` - Treat search pattern as a regular expression (advanced users)
### Examples
1. Basic replacement (replaces first occurrence):
```
s/hello/hi/
```
2. Global replacement (replaces all occurrences):
```
s/hello/hi/g
```
3. Case-insensitive replacement:
```
s/Hello/hi/i
```
4. Combined flags (global and case-insensitive):
```
s/hello/hi/gi
```
## Limitations
- The plugin can only access the text content of the original message
- Regular expression support is available with the `n` flag, but should be used carefully as invalid regex patterns will cause errors
- The plugin does not modify the original message; it creates a new message with the replaced text

View file

@ -1,182 +0,0 @@
package searchreplace
import (
"fmt"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// Regex pattern for search and replace operations: s/search/replace/[flags]
var searchReplacePattern = regexp.MustCompile(`^s/([^/]*)/([^/]*)(?:/([gimnsuy]*))?$`)
// SearchReplacePlugin is a plugin for performing search and replace operations on messages
type SearchReplacePlugin struct {
plugin.BasePlugin
}
// New creates a new SearchReplacePlugin instance
func New() *SearchReplacePlugin {
return &SearchReplacePlugin{
BasePlugin: plugin.BasePlugin{
ID: "util.searchreplace",
Name: "Search and Replace",
Help: "Reply to a message with a search and replace pattern (`s/search/replace/[flags]`) to create a modified message. " +
"Supported flags: g (global), i (case insensitive)",
},
}
}
// OnMessage handles incoming messages
func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Only process replies to messages
if msg.ReplyTo == "" {
return nil
}
// Check if the message matches the search/replace pattern
match := searchReplacePattern.FindStringSubmatch(strings.TrimSpace(msg.Text))
if match == nil {
return nil
}
// Get the original message text from the reply_to_message structure in Telegram messages
var originalText string
// For Telegram messages
if msgData, ok := msg.Raw["message"].(map[string]interface{}); ok {
if replyMsg, ok := msgData["reply_to_message"].(map[string]interface{}); ok {
if text, ok := replyMsg["text"].(string); ok {
originalText = text
}
}
}
// Generic fallback for other platforms or if the above method fails
if originalText == "" && msg.Raw["original_message"] != nil {
if original, ok := msg.Raw["original_message"].(map[string]interface{}); ok {
if text, ok := original["text"].(string); ok {
originalText = text
}
}
}
if originalText == "" {
// If we couldn't find the original message text, inform the user
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: "Sorry, I couldn't find the original message text to perform the replacement.",
Chat: msg.Chat,
Channel: msg.Channel,
ReplyTo: msg.ID,
},
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Extract search pattern, replacement and flags
searchPattern := match[1]
replacement := match[2]
flags := ""
if len(match) > 3 {
flags = match[3]
}
// Process the replacement
result, err := p.performReplacement(originalText, searchPattern, replacement, flags)
if err != nil {
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: fmt.Sprintf("Error performing replacement: %s", err.Error()),
Chat: msg.Chat,
Channel: msg.Channel,
ReplyTo: msg.ID,
},
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Only send a response if the text actually changed
if result == originalText {
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: "No changes were made to the original message.",
Chat: msg.Chat,
Channel: msg.Channel,
ReplyTo: msg.ID,
},
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// Create a response with the modified text
return []*model.MessageAction{
{
Type: model.ActionSendMessage,
Message: &model.Message{
Text: result,
Chat: msg.Chat,
Channel: msg.Channel,
ReplyTo: msg.ReplyTo, // Reply to the original message
},
Chat: msg.Chat,
Channel: msg.Channel,
},
}
}
// performReplacement performs the search and replace operation on the given text
func (p *SearchReplacePlugin) performReplacement(text, search, replace, flags string) (string, error) {
// Process flags
globalReplace := strings.Contains(flags, "g")
caseInsensitive := strings.Contains(flags, "i")
// Create the regex pattern
pattern := search
regexFlags := ""
if caseInsensitive {
regexFlags += "(?i)"
}
// Escape special characters if we're not in a regular expression
if !strings.Contains(flags, "n") {
pattern = regexp.QuoteMeta(pattern)
}
// Compile the regex
reg, err := regexp.Compile(regexFlags + pattern)
if err != nil {
return "", fmt.Errorf("invalid search pattern: %v", err)
}
// Perform the replacement
var result string
if globalReplace {
result = reg.ReplaceAllString(text, replace)
} else {
// For non-global replace, only replace the first occurrence
indices := reg.FindStringIndex(text)
if indices == nil {
// No match found
return text, nil
}
result = text[:indices[0]] + replace + text[indices[1]:]
}
return result, nil
}

View file

@ -1,218 +0,0 @@
package searchreplace
import (
"testing"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
)
func TestSearchReplace(t *testing.T) {
// Create plugin instance
p := New()
// Test cases
tests := []struct {
name string
command string
originalText string
expectedResult string
expectActions bool
}{
{
name: "Simple replacement",
command: "s/hello/world/",
originalText: "hello everyone",
expectedResult: "world everyone",
expectActions: true,
},
{
name: "Case-insensitive replacement",
command: "s/HELLO/world/i",
originalText: "Hello everyone",
expectedResult: "world everyone",
expectActions: true,
},
{
name: "Global replacement",
command: "s/a/X/g",
originalText: "banana",
expectedResult: "bXnXnX",
expectActions: true,
},
{
name: "No change",
command: "s/nothing/something/",
originalText: "test message",
expectedResult: "test message",
expectActions: true, // We send a "no changes" message
},
{
name: "Not a search/replace command",
command: "hello",
originalText: "test message",
expectedResult: "",
expectActions: false,
},
{
name: "Invalid pattern",
command: "s/(/)/",
originalText: "test message",
expectedResult: "error",
expectActions: true, // We send an error message
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create message
msg := &model.Message{
Text: tc.command,
Chat: "test-chat",
ReplyTo: "original-message-id",
Date: time.Now(),
Channel: &model.Channel{
Platform: "test",
},
Raw: map[string]interface{}{
"message": map[string]interface{}{
"reply_to_message": map[string]interface{}{
"text": tc.originalText,
},
},
},
}
// Process message
mockCache := &testutil.MockCache{}
actions := p.OnMessage(msg, nil, mockCache)
// Check results
if tc.expectActions {
if len(actions) == 0 {
t.Fatalf("Expected actions but got none")
}
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Fatalf("Expected send message action but got %v", action.Type)
}
if tc.expectedResult == "error" {
// Just checking that we got an error message
if action.Message == nil || action.Message.Text == "" {
t.Fatalf("Expected error message but got empty message")
}
} else if tc.originalText == tc.expectedResult {
// Check if we got the "no changes" message
if action.Message == nil || action.Message.Text != "No changes were made to the original message." {
t.Fatalf("Expected 'no changes' message but got: %s", action.Message.Text)
}
} else {
// Check actual replacement result
if action.Message == nil || action.Message.Text != tc.expectedResult {
t.Fatalf("Expected result: %s, got: %s", tc.expectedResult, action.Message.Text)
}
}
} else if len(actions) > 0 {
t.Fatalf("Expected no actions but got %d", len(actions))
}
})
}
}
func TestPerformReplacement(t *testing.T) {
p := New()
// Test cases for the performReplacement function
tests := []struct {
name string
text string
search string
replace string
flags string
expected string
expectErr bool
}{
{
name: "Simple replacement",
text: "Hello World",
search: "Hello",
replace: "Hi",
flags: "",
expected: "Hi World",
expectErr: false,
},
{
name: "Case insensitive",
text: "Hello World",
search: "hello",
replace: "Hi",
flags: "i",
expected: "Hi World",
expectErr: false,
},
{
name: "Global replacement",
text: "one two one two",
search: "one",
replace: "1",
flags: "g",
expected: "1 two 1 two",
expectErr: false,
},
{
name: "No match",
text: "Hello World",
search: "Goodbye",
replace: "Hi",
flags: "",
expected: "Hello World",
expectErr: false,
},
{
name: "Invalid regex",
text: "Hello World",
search: "(",
replace: "Hi",
flags: "n", // treat as regex
expected: "",
expectErr: true,
},
{
name: "Escape special chars by default",
text: "Hello (World)",
search: "(World)",
replace: "[Earth]",
flags: "",
expected: "Hello [Earth]",
expectErr: false,
},
{
name: "Regex mode with n flag",
text: "Hello (World)",
search: "\\(World\\)",
replace: "[Earth]",
flags: "n",
expected: "Hello [Earth]",
expectErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result, err := p.performReplacement(tc.text, tc.search, tc.replace, tc.flags)
if tc.expectErr {
if err == nil {
t.Fatalf("Expected error but got none")
}
} else if err != nil {
t.Fatalf("Unexpected error: %v", err)
} else if result != tc.expected {
t.Fatalf("Expected result: %s, got: %s", tc.expected, result)
}
})
}
}

View file

@ -1,92 +0,0 @@
package social
import (
"net/url"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// InstagramExpander transforms instagram.com links to ddinstagram.com links
type InstagramExpander struct {
plugin.BasePlugin
}
// New creates a new InstagramExpander instance
func NewInstagramExpander() *InstagramExpander {
return &InstagramExpander{
BasePlugin: plugin.BasePlugin{
ID: "social.instagram",
Name: "Instagram Link Expander",
Help: "Automatically converts instagram.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: ddinstagram.com)",
ConfigRequired: true,
},
}
}
// OnMessage handles incoming messages
func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip empty messages
if strings.TrimSpace(msg.Text) == "" {
return nil
}
// Get replacement domain from config, default to ddinstagram.com
replacementDomain := "ddinstagram.com"
if domain, ok := config["domain"].(string); ok && domain != "" {
replacementDomain = domain
}
// Regex to match instagram.com links
// Match both http://instagram.com and https://instagram.com formats
// Also match www.instagram.com
instagramRegex := regexp.MustCompile(`https?://(www\.)?(instagram\.com)/[^\s]+`)
// Check if the message contains an Instagram link
if !instagramRegex.MatchString(msg.Text) {
return nil
}
// Replace instagram.com with configured domain in the message and clean query parameters
transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
// Parse the URL
parsedURL, err := url.Parse(link)
if err != nil {
// If parsing fails, just do the simple replacement
return link
}
// Ensure we don't change links that already come from the replacement domain
if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" {
return link
}
// Change the host to the configured domain
parsedURL.Host = replacementDomain
// Remove query parameters
parsedURL.RawQuery = ""
// Return the cleaned URL
return parsedURL.String()
})
// Create response message
response := &model.Message{
Text: transformed,
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}

View file

@ -1,88 +0,0 @@
package social
import (
"net/url"
"regexp"
"strings"
"git.nakama.town/fmartingr/butterrobot/internal/model"
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
)
// TwitterExpander transforms twitter.com links to fxtwitter.com links
type TwitterExpander struct {
plugin.BasePlugin
}
// New creates a new TwitterExpander instance
func NewTwitterExpander() *TwitterExpander {
return &TwitterExpander{
BasePlugin: plugin.BasePlugin{
ID: "social.twitter",
Name: "Twitter Link Expander",
Help: "Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: fxtwitter.com)",
ConfigRequired: true,
},
}
}
// OnMessage handles incoming messages
func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
// Skip empty messages
if strings.TrimSpace(msg.Text) == "" {
return nil
}
// Get replacement domain from config, default to fxtwitter.com
replacementDomain := "fxtwitter.com"
if domain, ok := config["domain"].(string); ok && domain != "" {
replacementDomain = domain
}
// Regex to match twitter.com links
// Match both http://twitter.com and https://twitter.com formats
// Also match www.twitter.com
twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`)
// Check if the message contains a Twitter link
if !twitterRegex.MatchString(msg.Text) {
return nil
}
// Replace twitter.com/x.com with configured domain in the message and clean query parameters
transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
// Parse the URL
parsedURL, err := url.Parse(link)
if err != nil {
return link
}
// Change the host to the configured domain
if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") {
parsedURL.Host = replacementDomain
}
// Remove query parameters
parsedURL.RawQuery = ""
// Return the cleaned URL
return parsedURL.String()
})
// Create response message
response := &model.Message{
Text: transformed,
Chat: msg.Chat,
ReplyTo: msg.ID,
Channel: msg.Channel,
}
action := &model.MessageAction{
Type: model.ActionSendMessage,
Message: response,
Chat: msg.Chat,
Channel: msg.Channel,
}
return []*model.MessageAction{action}
}

View file

@ -1,120 +0,0 @@
package social
import (
"testing"
"git.nakama.town/fmartingr/butterrobot/internal/model"
)
func TestTwitterExpander_OnMessage(t *testing.T) {
plugin := NewTwitterExpander()
tests := []struct {
name string
input string
config map[string]interface{}
expected string
hasReply bool
}{
{
name: "Twitter URL with default domain",
input: "https://twitter.com/user/status/123456789",
config: map[string]interface{}{},
expected: "https://fxtwitter.com/user/status/123456789",
hasReply: true,
},
{
name: "X.com URL with custom domain",
input: "https://x.com/elonmusk/status/987654321",
config: map[string]interface{}{"domain": "vxtwitter.com"},
expected: "https://vxtwitter.com/elonmusk/status/987654321",
hasReply: true,
},
{
name: "Twitter URL with tracking parameters",
input: "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20",
config: map[string]interface{}{},
expected: "https://fxtwitter.com/openai/status/555",
hasReply: true,
},
{
name: "www.twitter.com URL",
input: "https://www.twitter.com/user/status/789",
config: map[string]interface{}{"domain": "nitter.net"},
expected: "https://nitter.net/user/status/789",
hasReply: true,
},
{
name: "Mixed text with Twitter URL",
input: "Check this out: https://twitter.com/user/status/123 amazing!",
config: map[string]interface{}{},
expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!",
hasReply: true,
},
{
name: "No Twitter URLs",
input: "Just some regular text https://youtube.com/watch?v=abc",
config: map[string]interface{}{},
expected: "",
hasReply: false,
},
{
name: "Empty message",
input: "",
config: map[string]interface{}{},
expected: "",
hasReply: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &model.Message{
ID: "test_msg",
Text: tt.input,
Chat: "test_chat",
Channel: &model.Channel{
ID: 1,
Platform: "telegram",
PlatformChannelID: "test_chat",
},
}
actions := plugin.OnMessage(msg, tt.config, nil)
if !tt.hasReply {
if len(actions) != 0 {
t.Errorf("Expected no actions, got %d", len(actions))
}
return
}
if len(actions) != 1 {
t.Errorf("Expected 1 action, got %d", len(actions))
return
}
action := actions[0]
if action.Type != model.ActionSendMessage {
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
}
if action.Message == nil {
t.Error("Expected message in action, got nil")
return
}
if action.Message.Text != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text)
}
if action.Message.ReplyTo != msg.ID {
t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo)
}
if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" {
t.Error("Expected parse_mode to be empty string to disable markdown parsing")
}
})
}
}

View file

@ -3,9 +3,6 @@ package queue
import ( import (
"log/slog" "log/slog"
"sync" "sync"
"time"
"git.nakama.town/fmartingr/butterrobot/internal/model"
) )
// Item represents a queue item // Item represents a queue item
@ -17,19 +14,14 @@ type Item struct {
// HandlerFunc defines a function that processes queue items // HandlerFunc defines a function that processes queue items
type HandlerFunc func(item Item) type HandlerFunc func(item Item)
// ReminderHandlerFunc defines a function that processes reminder items
type ReminderHandlerFunc func(reminder *model.Reminder)
// Queue represents a message queue // Queue represents a message queue
type Queue struct { type Queue struct {
items chan Item items chan Item
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
logger *slog.Logger logger *slog.Logger
running bool running bool
runMutex sync.Mutex runMutex sync.Mutex
reminderTicker *time.Ticker
reminderHandler ReminderHandlerFunc
} }
// New creates a new Queue instance // New creates a new Queue instance
@ -57,24 +49,6 @@ func (q *Queue) Start(handler HandlerFunc) {
go q.worker(handler) go q.worker(handler)
} }
// StartReminderScheduler starts the reminder scheduler
func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) {
q.runMutex.Lock()
defer q.runMutex.Unlock()
if q.reminderTicker != nil {
return
}
q.reminderHandler = handler
// Check for reminders every minute
q.reminderTicker = time.NewTicker(1 * time.Minute)
q.wg.Add(1)
go q.reminderWorker()
}
// Stop stops processing queue items // Stop stops processing queue items
func (q *Queue) Stop() { func (q *Queue) Stop() {
q.runMutex.Lock() q.runMutex.Lock()
@ -85,12 +59,6 @@ func (q *Queue) Stop() {
} }
q.running = false q.running = false
// Stop reminder ticker if it exists
if q.reminderTicker != nil {
q.reminderTicker.Stop()
}
close(q.quit) close(q.quit)
q.wg.Wait() q.wg.Wait()
} }
@ -128,34 +96,4 @@ func (q *Queue) worker(handler HandlerFunc) {
return return
} }
} }
} }
// reminderWorker processes reminder items on a schedule
func (q *Queue) reminderWorker() {
defer q.wg.Done()
for {
select {
case <-q.reminderTicker.C:
// This is triggered every minute to check for pending reminders
q.logger.Debug("Checking for pending reminders")
if q.reminderHandler != nil {
// The handler is responsible for fetching and processing reminders
func() {
defer func() {
if r := recover(); r != nil {
q.logger.Error("Panic in reminder worker", "error", r)
}
}()
// Call the handler with a nil reminder to indicate it should check the database
q.reminderHandler(nil)
}()
}
case <-q.quit:
// Quit worker
return
}
}
}

View file

@ -1,29 +0,0 @@
package testutil
import (
"errors"
"time"
)
// MockCache implements the CacheInterface for testing
type MockCache struct{}
func (m *MockCache) Get(key string, destination interface{}) error {
return errors.New("cache miss") // Always return cache miss for tests
}
func (m *MockCache) Set(key string, value interface{}, expiration *time.Time) error {
return nil // Always succeed for tests
}
func (m *MockCache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
return nil // Always succeed for tests
}
func (m *MockCache) Delete(key string) error {
return nil // Always succeed for tests
}
func (m *MockCache) Exists(key string) (bool, error) {
return false, nil // Always return false for tests
}