Compare commits
No commits in common. "master" and "v0.2.2" have entirely different histories.
51 changed files with 151 additions and 4901 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -5,12 +5,9 @@ __pycache__
|
||||||
*.cert
|
*.cert
|
||||||
.env-local
|
.env-local
|
||||||
.coverage
|
.coverage
|
||||||
coverage.out
|
|
||||||
|
|
||||||
dist
|
dist
|
||||||
bin
|
bin
|
||||||
# Butterrobot
|
# Butterrobot
|
||||||
*.sqlite*
|
*.sqlite*
|
||||||
butterrobot.db*
|
butterrobot.db
|
||||||
/butterrobot
|
|
||||||
*_test.db*
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ when:
|
||||||
- push
|
- push
|
||||||
- pull_request
|
- pull_request
|
||||||
branch:
|
branch:
|
||||||
- master
|
- main
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
format:
|
format:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
when:
|
when:
|
||||||
- event: tag
|
- event: tag
|
||||||
branch: master
|
branch: main
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Release
|
- name: Release
|
||||||
|
|
29
CLAUDE.md
29
CLAUDE.md
|
@ -1,29 +0,0 @@
|
||||||
# Claude Code Instructions
|
|
||||||
|
|
||||||
## Plugin Development Workflow
|
|
||||||
|
|
||||||
When creating, modifying, or removing plugins:
|
|
||||||
|
|
||||||
1. **Always update the plugin documentation** in `docs/plugins.md` after any plugin changes
|
|
||||||
2. Ensure the documentation includes:
|
|
||||||
- Plugin name and category (Development, Fun and entertainment, Utility, Security, Social Media)
|
|
||||||
- Brief description of functionality
|
|
||||||
- Usage instructions with examples
|
|
||||||
- Any configuration requirements
|
|
||||||
3. **For plugins with configuration options:**
|
|
||||||
- Set `ConfigRequired: true` in the plugin's BasePlugin struct
|
|
||||||
- Add corresponding HTML form fields in `internal/admin/templates/channel_plugin_config.html`
|
|
||||||
- Use conditional template logic: `{{else if eq .ChannelPlugin.PluginID "plugin.id"}}`
|
|
||||||
- Include proper form labels, help text, and value binding
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
**CRITICAL**: After making ANY changes to code files, you MUST run these commands in order:
|
|
||||||
|
|
||||||
1. **Format code**: `make format` - Format all code according to project standards
|
|
||||||
2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues")
|
|
||||||
3. **Run tests**: `make test` - Run all tests to ensure functionality works
|
|
||||||
4. Verify documentation accuracy
|
|
||||||
5. Ensure all examples work as described
|
|
||||||
|
|
||||||
**These commands are MANDATORY after every code change, no exceptions.**
|
|
|
@ -1,6 +1,9 @@
|
||||||
# Butter Robot
|
# Butter Robot
|
||||||
|
|
||||||

|
| Stable | Master |
|
||||||
|
| --- | --- |
|
||||||
|
|  |  |
|
||||||
|
|  |  |
|
||||||
|
|
||||||
Go framework to create bots for several platforms.
|
Go framework to create bots for several platforms.
|
||||||
|
|
||||||
|
@ -10,7 +13,7 @@ Go framework to create bots for several platforms.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Support for multiple chat platforms (Slack (untested!), Telegram)
|
- Support for multiple chat platforms (Slack, Telegram)
|
||||||
- Plugin system for easy extension
|
- Plugin system for easy extension
|
||||||
- Admin interface for managing channels and plugins
|
- Admin interface for managing channels and plugins
|
||||||
- Message queue for asynchronous processing
|
- Message queue for asynchronous processing
|
||||||
|
|
|
@ -1,19 +1,6 @@
|
||||||
# Creating a Plugin
|
# Creating a Plugin
|
||||||
|
|
||||||
## Plugin Categories
|
## Example
|
||||||
|
|
||||||
ButterRobot organizes plugins into different categories:
|
|
||||||
|
|
||||||
- **Development**: Utility plugins like `ping`
|
|
||||||
- **Fun**: Entertainment plugins like dice rolling, coin flipping
|
|
||||||
- **Social**: Social media related plugins like URL transformers/expanders
|
|
||||||
- **Security**: Moderation and protection features like domain blocking
|
|
||||||
|
|
||||||
When creating a new plugin, consider which category it fits into and place it in the appropriate directory.
|
|
||||||
|
|
||||||
## Plugin Examples
|
|
||||||
|
|
||||||
### Basic Example: Marco Polo
|
|
||||||
|
|
||||||
This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_:
|
This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_:
|
||||||
|
|
||||||
|
@ -60,207 +47,6 @@ func (p *MarcoPlugin) OnMessage(msg *model.Message, config map[string]interface{
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Configuration-Enabled Plugin
|
|
||||||
|
|
||||||
This plugin requires configuration to be set in the admin interface. It demonstrates how to create plugins that need channel-specific configuration:
|
|
||||||
|
|
||||||
```go
|
|
||||||
package security
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains
|
|
||||||
type DomainBlockPlugin struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new DomainBlockPlugin instance
|
|
||||||
func New() *DomainBlockPlugin {
|
|
||||||
return &DomainBlockPlugin{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "security.domainblock",
|
|
||||||
Name: "Domain Blocker",
|
|
||||||
Help: "Blocks messages containing links from configured domains",
|
|
||||||
ConfigRequired: true, // Mark this plugin as requiring configuration
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage processes incoming messages
|
|
||||||
func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
|
||||||
// Get blocked domains from config
|
|
||||||
blockedDomainsStr, ok := config["blocked_domains"].(string)
|
|
||||||
if !ok || blockedDomainsStr == "" {
|
|
||||||
return nil // No blocked domains configured
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split and clean blocked domains
|
|
||||||
blockedDomains := strings.Split(blockedDomainsStr, ",")
|
|
||||||
for i, domain := range blockedDomains {
|
|
||||||
blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract domains from message
|
|
||||||
urlRegex := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`)
|
|
||||||
matches := urlRegex.FindAllStringSubmatch(msg.Text, -1)
|
|
||||||
|
|
||||||
// Check if any extracted domains are blocked
|
|
||||||
for _, match := range matches {
|
|
||||||
if len(match) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
domain := strings.ToLower(match[1])
|
|
||||||
|
|
||||||
for _, blockedDomain := range blockedDomains {
|
|
||||||
if blockedDomain == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(domain, blockedDomain) || domain == blockedDomain {
|
|
||||||
// Domain is blocked, create warning message
|
|
||||||
response := &model.Message{
|
|
||||||
Text: fmt.Sprintf("⚠️ Message contained a link to blocked domain: %s", blockedDomain),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
return []*model.Message{response}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
plugin.Register(New())
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Advanced Example: URL Transformer
|
|
||||||
|
|
||||||
This more complex plugin transforms URLs, useful for improving media embedding in chat platforms:
|
|
||||||
|
|
||||||
```go
|
|
||||||
package social
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TwitterExpander transforms twitter.com links to fxtwitter.com links
|
|
||||||
type TwitterExpander struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new TwitterExpander instance
|
|
||||||
func NewTwitter() *TwitterExpander {
|
|
||||||
return &TwitterExpander{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "social.twitter",
|
|
||||||
Name: "Twitter Link Expander",
|
|
||||||
Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
|
||||||
// Skip empty messages
|
|
||||||
if strings.TrimSpace(msg.Text) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regex to match twitter.com links
|
|
||||||
twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`)
|
|
||||||
|
|
||||||
// Check if the message contains a Twitter link
|
|
||||||
if !twitterRegex.MatchString(msg.Text) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transform the URL
|
|
||||||
transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
|
|
||||||
// Parse the URL
|
|
||||||
parsedURL, err := url.Parse(link)
|
|
||||||
if err != nil {
|
|
||||||
// If parsing fails, just do the simple replacement
|
|
||||||
link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1)
|
|
||||||
link = strings.Replace(link, "x.com", "fxtwitter.com", 1)
|
|
||||||
return link
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the host
|
|
||||||
if strings.Contains(parsedURL.Host, "twitter.com") {
|
|
||||||
parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1)
|
|
||||||
} else if strings.Contains(parsedURL.Host, "x.com") {
|
|
||||||
parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove query parameters
|
|
||||||
parsedURL.RawQuery = ""
|
|
||||||
|
|
||||||
// Return the cleaned URL
|
|
||||||
return parsedURL.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create response message
|
|
||||||
response := &model.Message{
|
|
||||||
Text: transformed,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.Message{response}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Enabling Configuration for Plugins
|
|
||||||
|
|
||||||
To indicate that your plugin requires configuration:
|
|
||||||
|
|
||||||
1. Set `ConfigRequired: true` in the BasePlugin struct:
|
|
||||||
```go
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "myplugin.id",
|
|
||||||
Name: "Plugin Name",
|
|
||||||
Help: "Help text",
|
|
||||||
ConfigRequired: true,
|
|
||||||
},
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Access the configuration in the OnMessage method:
|
|
||||||
```go
|
|
||||||
func (p *MyPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
|
||||||
// Extract configuration values
|
|
||||||
configValue, ok := config["some_config_key"].(string)
|
|
||||||
if !ok || configValue == "" {
|
|
||||||
// Handle missing or empty configuration
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the configuration...
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
3. The admin interface will show a "Configure" button for plugins that require configuration.
|
|
||||||
|
|
||||||
## Registering Plugins
|
|
||||||
|
|
||||||
To use the plugin, register it in your application:
|
To use the plugin, register it in your application:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
@ -269,19 +55,8 @@ func (a *App) Run() error {
|
||||||
// ...
|
// ...
|
||||||
|
|
||||||
// Register plugins
|
// Register plugins
|
||||||
plugin.Register(ping.New()) // Development plugin
|
plugin.Register(myplugin.New())
|
||||||
plugin.Register(fun.NewCoin()) // Fun plugin
|
|
||||||
plugin.Register(social.NewTwitter()) // Social media plugin
|
|
||||||
plugin.Register(myplugin.New()) // Your custom plugin
|
|
||||||
|
|
||||||
// ...
|
// ...
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, you can register your plugin in its init() function:
|
|
||||||
|
|
||||||
```go
|
|
||||||
func init() {
|
|
||||||
plugin.Register(New())
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
|
@ -9,19 +9,3 @@
|
||||||
- Lo quito: What happens when you say _"lo quito"_...? (Spanish pun)
|
- Lo quito: What happens when you say _"lo quito"_...? (Spanish pun)
|
||||||
- Dice: Put `!dice` and wathever roll you want to perform.
|
- Dice: Put `!dice` and wathever roll you want to perform.
|
||||||
- Coin: Flip a coin and get heads or tails.
|
- Coin: Flip a coin and get heads or tails.
|
||||||
- How Long To Beat: Get game completion times from HowLongToBeat.com using `!hltb <game name>`
|
|
||||||
|
|
||||||
### Utility
|
|
||||||
|
|
||||||
- Help: Shows available commands when you type `!help`. Lists all enabled plugins for the current channel organized by category with their descriptions and usage instructions.
|
|
||||||
- Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time.
|
|
||||||
- Search and Replace: Reply to any message with `s/search/replace/[flags]` to perform text substitution. Supports flags: `g` (global), `i` (case insensitive), `n` (regex pattern). Example: `s/hello/hi/gi` replaces all occurrences of "hello" with "hi" case-insensitively.
|
|
||||||
|
|
||||||
### Security
|
|
||||||
|
|
||||||
- Domain Blocker: Blocks messages containing links from specified domains. Configure it per channel with a comma-separated list of domains to block. When a message contains a link matching any of the blocked domains, the bot will notify that the message contained a blocked domain. This plugin requires configuration through the admin interface.
|
|
||||||
|
|
||||||
### Social Media
|
|
||||||
|
|
||||||
- Twitter Link Expander: Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: fxtwitter.com).
|
|
||||||
- Instagram Link Expander: Automatically converts instagram.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: ddinstagram.com).
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -6,7 +6,6 @@ require (
|
||||||
github.com/gorilla/sessions v1.4.0
|
github.com/gorilla/sessions v1.4.0
|
||||||
golang.org/x/crypto v0.37.0
|
golang.org/x/crypto v0.37.0
|
||||||
golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df
|
golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df
|
||||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
|
||||||
modernc.org/sqlite v1.37.0
|
modernc.org/sqlite v1.37.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +16,7 @@ require (
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
||||||
golang.org/x/sys v0.32.0 // indirect
|
golang.org/x/sys v0.32.0 // indirect
|
||||||
modernc.org/libc v1.63.0 // indirect
|
modernc.org/libc v1.63.0 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed templates/*.html templates/plugins/*.html
|
//go:embed templates/*.html
|
||||||
var templateFS embed.FS
|
var templateFS embed.FS
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -46,7 +46,6 @@ type TemplateData struct {
|
||||||
Channels []*model.Channel
|
Channels []*model.Channel
|
||||||
Channel *model.Channel
|
Channel *model.Channel
|
||||||
ChannelPlugin *model.ChannelPlugin
|
ChannelPlugin *model.ChannelPlugin
|
||||||
Version string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Admin represents the admin interface
|
// Admin represents the admin interface
|
||||||
|
@ -56,11 +55,10 @@ type Admin struct {
|
||||||
store *sessions.CookieStore
|
store *sessions.CookieStore
|
||||||
templates map[string]*template.Template
|
templates map[string]*template.Template
|
||||||
baseTemplate *template.Template
|
baseTemplate *template.Template
|
||||||
version string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Admin instance
|
// New creates a new Admin instance
|
||||||
func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
func New(cfg *config.Config, database *db.Database) *Admin {
|
||||||
// Create session store with appropriate options
|
// Create session store with appropriate options
|
||||||
store := sessions.NewCookieStore([]byte(cfg.SecretKey))
|
store := sessions.NewCookieStore([]byte(cfg.SecretKey))
|
||||||
store.Options = &sessions.Options{
|
store.Options = &sessions.Options{
|
||||||
|
@ -90,7 +88,7 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse and register all templates
|
// Parse and register all templates
|
||||||
mainTemplateFiles := []string{
|
templateFiles := []string{
|
||||||
"index.html",
|
"index.html",
|
||||||
"login.html",
|
"login.html",
|
||||||
"change_password.html",
|
"change_password.html",
|
||||||
|
@ -98,48 +96,27 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
||||||
"channel_detail.html",
|
"channel_detail.html",
|
||||||
"plugin_list.html",
|
"plugin_list.html",
|
||||||
"channel_plugins_list.html",
|
"channel_plugins_list.html",
|
||||||
"channel_plugin_config.html",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pluginTemplateFiles := []string{
|
for _, tf := range templateFiles {
|
||||||
"plugins/security.domainblock.html",
|
|
||||||
"plugins/social.instagram.html",
|
|
||||||
"plugins/social.twitter.html",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tf := range mainTemplateFiles {
|
|
||||||
// Read template content from embedded filesystem
|
// Read template content from embedded filesystem
|
||||||
content, err := templateFS.ReadFile("templates/" + tf)
|
content, err := templateFS.ReadFile("templates/" + tf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a clone of the base template
|
// Create a clone of the base template
|
||||||
t, err := baseTemplate.Clone()
|
t, err := baseTemplate.Clone()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the template content
|
// Parse the template content
|
||||||
t, err = t.Parse(string(content))
|
t, err = t.Parse(string(content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is the channel_plugin_config template, also parse plugin templates
|
|
||||||
if tf == "channel_plugin_config.html" {
|
|
||||||
for _, pluginTf := range pluginTemplateFiles {
|
|
||||||
pluginContent, err := templateFS.ReadFile("templates/" + pluginTf)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
t, err = t.Parse(string(pluginContent))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
templates[tf] = t
|
templates[tf] = t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,7 +126,6 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
||||||
store: store,
|
store: store,
|
||||||
templates: templates,
|
templates: templates,
|
||||||
baseTemplate: baseTemplate,
|
baseTemplate: baseTemplate,
|
||||||
version: version,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,7 +140,6 @@ func (a *Admin) RegisterRoutes(mux *http.ServeMux) {
|
||||||
mux.HandleFunc("/admin/channels", a.handleChannelList)
|
mux.HandleFunc("/admin/channels", a.handleChannelList)
|
||||||
mux.HandleFunc("/admin/channels/", a.handleChannelDetail)
|
mux.HandleFunc("/admin/channels/", a.handleChannelDetail)
|
||||||
mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList)
|
mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList)
|
||||||
mux.HandleFunc("/admin/channelplugins/config/", a.handleChannelPluginConfig)
|
|
||||||
mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete)
|
mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,7 +191,7 @@ func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map internal categories to Bootstrap alert classes
|
// Map internal categories to Bootstrap alert classes
|
||||||
var alertClass string
|
alertClass := category
|
||||||
switch category {
|
switch category {
|
||||||
case "success":
|
case "success":
|
||||||
alertClass = "success"
|
alertClass = "success"
|
||||||
|
@ -271,6 +246,17 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requireLogin middleware checks if the user is logged in
|
||||||
|
func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !a.isLoggedIn(r) {
|
||||||
|
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// render renders a template with the given data
|
// render renders a template with the given data
|
||||||
func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) {
|
func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) {
|
||||||
// Add current user data
|
// Add current user data
|
||||||
|
@ -278,7 +264,6 @@ func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName stri
|
||||||
data.LoggedIn = a.isLoggedIn(r)
|
data.LoggedIn = a.isLoggedIn(r)
|
||||||
data.Path = r.URL.Path
|
data.Path = r.URL.Path
|
||||||
data.Flash = a.getFlashes(w, r)
|
data.Flash = a.getFlashes(w, r)
|
||||||
data.Version = a.version
|
|
||||||
|
|
||||||
// Get template
|
// Get template
|
||||||
tmpl, ok := a.templates[templateName]
|
tmpl, ok := a.templates[templateName]
|
||||||
|
@ -345,10 +330,7 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Set session expiration
|
// Set session expiration
|
||||||
session.Options.MaxAge = 3600 * 24 * 7 // 1 week
|
session.Options.MaxAge = 3600 * 24 * 7 // 1 week
|
||||||
err = session.Save(r, w)
|
session.Save(r, w)
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Error saving session: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addFlash(w, r, "You were logged in", "success")
|
a.addFlash(w, r, "You were logged in", "success")
|
||||||
|
|
||||||
|
@ -376,7 +358,7 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Values = make(map[interface{}]interface{})
|
session.Values = make(map[interface{}]interface{})
|
||||||
session.Options.MaxAge = -1 // Delete session
|
session.Options.MaxAge = -1 // Delete session
|
||||||
err = session.Save(r, w)
|
err = session.Save(r, w)
|
||||||
|
@ -564,13 +546,6 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update enable_all_plugins
|
|
||||||
enableAllPlugins := r.FormValue("enable_all_plugins") == "true"
|
|
||||||
if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil {
|
|
||||||
http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addFlash(w, r, "Channel updated", "success")
|
a.addFlash(w, r, "Channel updated", "success")
|
||||||
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
|
http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
|
@ -657,96 +632,6 @@ func (a *Admin) handleChannelPluginList(w http.ResponseWriter, r *http.Request)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleChannelPluginConfig handles the channel plugin configuration route
|
|
||||||
func (a *Admin) handleChannelPluginConfig(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Check if user is logged in
|
|
||||||
if !a.isLoggedIn(r) {
|
|
||||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract channel plugin ID from path
|
|
||||||
path := r.URL.Path
|
|
||||||
channelPluginID := strings.TrimPrefix(path, "/admin/channelplugins/config/")
|
|
||||||
|
|
||||||
// Convert channel plugin ID to int64
|
|
||||||
id, err := strconv.ParseInt(channelPluginID, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid channel plugin ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the channel plugin
|
|
||||||
channelPlugin, err := a.db.GetChannelPluginByID(id)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Channel plugin not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the plugin
|
|
||||||
p, err := plugin.Get(channelPlugin.PluginID)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Plugin not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle form submission
|
|
||||||
if r.Method == http.MethodPost {
|
|
||||||
// Parse form
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create config map from form values
|
|
||||||
config := make(map[string]interface{})
|
|
||||||
|
|
||||||
// Process form values based on plugin type
|
|
||||||
if channelPlugin.PluginID == "security.domainblock" {
|
|
||||||
// Get blocked domains from form
|
|
||||||
blockedDomains := r.FormValue("blocked_domains")
|
|
||||||
config["blocked_domains"] = blockedDomains
|
|
||||||
} else {
|
|
||||||
// Generic handling for other plugins
|
|
||||||
for key, values := range r.Form {
|
|
||||||
if key == "form_submitted" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(values) == 1 {
|
|
||||||
config[key] = values[0]
|
|
||||||
} else {
|
|
||||||
config[key] = values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update plugin configuration
|
|
||||||
if err := a.db.UpdateChannelPluginConfig(id, config); err != nil {
|
|
||||||
http.Error(w, "Failed to update plugin configuration", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the channel to redirect back to the channel detail page
|
|
||||||
channel, err := a.db.GetChannelByID(channelPlugin.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
a.addFlash(w, r, "Plugin configuration updated", "success")
|
|
||||||
http.Redirect(w, r, "/admin/channelplugins", http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a.addFlash(w, r, "Plugin configuration updated", "success")
|
|
||||||
http.Redirect(w, r, fmt.Sprintf("/admin/channels/%d", channel.ID), http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render template
|
|
||||||
a.render(w, r, "channel_plugin_config.html", TemplateData{
|
|
||||||
Title: "Configure Plugin: " + p.GetName(),
|
|
||||||
ChannelPlugin: channelPlugin,
|
|
||||||
Plugins: map[string]model.Plugin{channelPlugin.PluginID: p},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route
|
// handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route
|
||||||
func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) {
|
func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
// Check if user is logged in
|
// Check if user is logged in
|
||||||
|
|
|
@ -117,19 +117,6 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<footer class="footer footer-transparent d-print-none">
|
|
||||||
<div class="container-xl">
|
|
||||||
<div class="row text-center align-items-center flex-row-reverse">
|
|
||||||
<div class="col-12 col-lg-auto mt-3 mt-lg-0">
|
|
||||||
<ul class="list-inline list-inline-dots mb-0">
|
|
||||||
<li class="list-inline-item">
|
|
||||||
ButterRobot {{if .Version}}v{{.Version}}{{else}}(development){{end}}
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</footer>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script>
|
<script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script>
|
||||||
|
|
|
@ -27,15 +27,6 @@
|
||||||
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
|
<!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked -->
|
||||||
<input type="hidden" name="form_submitted" value="true">
|
<input type="hidden" name="form_submitted" value="true">
|
||||||
</div>
|
</div>
|
||||||
<div class="mb-3">
|
|
||||||
<label class="form-check form-switch">
|
|
||||||
<input class="form-check-input" type="checkbox" name="enable_all_plugins" value="true" {{if .Channel.EnableAllPlugins}}checked{{end}}>
|
|
||||||
<span class="form-check-label">Enable All Plugins</span>
|
|
||||||
</label>
|
|
||||||
<div>
|
|
||||||
When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored.
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="form-footer">
|
<div class="form-footer">
|
||||||
<button type="submit" class="btn btn-primary">Save</button>
|
<button type="submit" class="btn btn-primary">Save</button>
|
||||||
<a href="/admin/channels" class="btn btn-link">Back to Channels</a>
|
<a href="/admin/channels" class="btn btn-link">Back to Channels</a>
|
||||||
|
@ -77,10 +68,6 @@
|
||||||
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
|
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
{{$plugin := index $.Plugins $pluginID}}
|
|
||||||
{{if $plugin.RequiresConfig}}
|
|
||||||
<a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a>
|
|
||||||
{{end}}
|
|
||||||
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
|
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
|
||||||
<button type="submit" class="btn btn-danger btn-sm"
|
<button type="submit" class="btn btn-danger btn-sm"
|
||||||
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
|
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
|
||||||
|
@ -124,4 +111,4 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{{end}}
|
{{end}}
|
|
@ -1,32 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="row">
|
|
||||||
<div class="col-md-12">
|
|
||||||
<div class="card">
|
|
||||||
<div class="card-header">
|
|
||||||
<h3 class="card-title">Configure Plugin: {{(index .Plugins .ChannelPlugin.PluginID).GetName}}</h3>
|
|
||||||
</div>
|
|
||||||
<div class="card-body">
|
|
||||||
<form method="post">
|
|
||||||
<!-- Plugin configuration fields -->
|
|
||||||
{{if eq .ChannelPlugin.PluginID "security.domainblock"}}
|
|
||||||
{{template "plugins/security.domainblock.html" .}}
|
|
||||||
{{else if eq .ChannelPlugin.PluginID "social.instagram"}}
|
|
||||||
{{template "plugins/social.instagram.html" .}}
|
|
||||||
{{else if eq .ChannelPlugin.PluginID "social.twitter"}}
|
|
||||||
{{template "plugins/social.twitter.html" .}}
|
|
||||||
{{else}}
|
|
||||||
<div class="alert alert-warning">
|
|
||||||
This plugin doesn't have specific configuration fields implemented yet.
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
<div class="form-footer">
|
|
||||||
<button type="submit" class="btn btn-primary">Save Configuration</button>
|
|
||||||
<a href="/admin/channels/{{.ChannelPlugin.ChannelID}}" class="btn btn-secondary">Cancel</a>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
|
@ -38,10 +38,6 @@
|
||||||
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
|
{{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}}
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
{{$plugin := index $.Plugins $pluginID}}
|
|
||||||
{{if $plugin.ConfigRequired}}
|
|
||||||
<a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a>
|
|
||||||
{{end}}
|
|
||||||
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
|
<form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline">
|
||||||
<button type="submit" class="btn btn-danger btn-sm"
|
<button type="submit" class="btn btn-danger btn-sm"
|
||||||
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
|
onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button>
|
||||||
|
@ -94,4 +90,4 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{{end}}
|
{{end}}
|
|
@ -1,12 +0,0 @@
|
||||||
{{define "plugins/security.domainblock.html"}}
|
|
||||||
<div class="mb-3">
|
|
||||||
<label class="form-label">Blocked Domains</label>
|
|
||||||
<input type="text" class="form-control" name="blocked_domains"
|
|
||||||
value="{{with .ChannelPlugin.Config}}{{index . "blocked_domains"}}{{end}}"
|
|
||||||
placeholder="example.com, evil.org, ads.com">
|
|
||||||
<div class="form-text text-muted">
|
|
||||||
Enter comma-separated list of domains to block (e.g., example.com, evil.org).
|
|
||||||
Messages containing links to these domains will be blocked.
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
|
@ -1,11 +0,0 @@
|
||||||
{{define "plugins/social.instagram.html"}}
|
|
||||||
<div class="mb-3">
|
|
||||||
<label class="form-label">Replacement Domain</label>
|
|
||||||
<input type="text" class="form-control" name="domain"
|
|
||||||
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
|
|
||||||
placeholder="ddinstagram.com">
|
|
||||||
<div class="form-text text-muted">
|
|
||||||
Enter the domain to replace instagram.com links with. Default is ddinstagram.com if left empty.
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
|
@ -1,11 +0,0 @@
|
||||||
{{define "plugins/social.twitter.html"}}
|
|
||||||
<div class="mb-3">
|
|
||||||
<label class="form-label">Replacement Domain</label>
|
|
||||||
<input type="text" class="form-control" name="domain"
|
|
||||||
value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}"
|
|
||||||
placeholder="fxtwitter.com">
|
|
||||||
<div class="form-text text-muted">
|
|
||||||
Enter the domain to replace twitter.com and x.com links with. Default is fxtwitter.com if left empty.
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
|
@ -9,37 +9,28 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime/debug"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/cache"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/help"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
// App represents the application
|
// App represents the application
|
||||||
type App struct {
|
type App struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
db *db.Database
|
db *db.Database
|
||||||
router *http.ServeMux
|
router *http.ServeMux
|
||||||
queue *queue.Queue
|
queue *queue.Queue
|
||||||
admin *admin.Admin
|
admin *admin.Admin
|
||||||
version string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new App instance
|
// New creates a new App instance
|
||||||
|
@ -56,24 +47,16 @@ func New(cfg *config.Config, logger *slog.Logger) (*App, error) {
|
||||||
// Initialize message queue
|
// Initialize message queue
|
||||||
messageQueue := queue.New(logger)
|
messageQueue := queue.New(logger)
|
||||||
|
|
||||||
// Get version information
|
|
||||||
version := ""
|
|
||||||
info, ok := debug.ReadBuildInfo()
|
|
||||||
if ok {
|
|
||||||
version = info.Main.Version
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize admin interface
|
// Initialize admin interface
|
||||||
adminInterface := admin.New(cfg, database, version)
|
adminInterface := admin.New(cfg, database)
|
||||||
|
|
||||||
return &App{
|
return &App{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
db: database,
|
db: database,
|
||||||
router: router,
|
router: router,
|
||||||
queue: messageQueue,
|
queue: messageQueue,
|
||||||
admin: adminInterface,
|
admin: adminInterface,
|
||||||
version: version,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,13 +72,6 @@ func (a *App) Run() error {
|
||||||
plugin.Register(fun.NewCoin())
|
plugin.Register(fun.NewCoin())
|
||||||
plugin.Register(fun.NewDice())
|
plugin.Register(fun.NewDice())
|
||||||
plugin.Register(fun.NewLoquito())
|
plugin.Register(fun.NewLoquito())
|
||||||
plugin.Register(fun.NewHLTB())
|
|
||||||
plugin.Register(social.NewTwitterExpander())
|
|
||||||
plugin.Register(social.NewInstagramExpander())
|
|
||||||
plugin.Register(reminder.New(a.db))
|
|
||||||
plugin.Register(domainblock.New())
|
|
||||||
plugin.Register(searchreplace.New())
|
|
||||||
plugin.Register(help.New(a.db))
|
|
||||||
|
|
||||||
// Initialize routes
|
// Initialize routes
|
||||||
a.initializeRoutes()
|
a.initializeRoutes()
|
||||||
|
@ -103,12 +79,6 @@ func (a *App) Run() error {
|
||||||
// Start message queue worker
|
// Start message queue worker
|
||||||
a.queue.Start(a.handleMessage)
|
a.queue.Start(a.handleMessage)
|
||||||
|
|
||||||
// Start reminder scheduler
|
|
||||||
a.queue.StartReminderScheduler(a.handleReminder)
|
|
||||||
|
|
||||||
// Start cache cleanup scheduler
|
|
||||||
go a.startCacheCleanup()
|
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
addr := fmt.Sprintf(":%s", a.config.Port)
|
addr := fmt.Sprintf(":%s", a.config.Port)
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
@ -154,29 +124,13 @@ func (a *App) Run() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// startCacheCleanup runs periodic cache cleanup
|
|
||||||
func (a *App) startCacheCleanup() {
|
|
||||||
ticker := time.NewTicker(time.Hour) // Clean up every hour
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
if err := a.db.CacheCleanup(); err != nil {
|
|
||||||
a.logger.Error("Cache cleanup failed", "error", err)
|
|
||||||
} else {
|
|
||||||
a.logger.Debug("Cache cleanup completed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize HTTP routes
|
// Initialize HTTP routes
|
||||||
func (a *App) initializeRoutes() {
|
func (a *App) initializeRoutes() {
|
||||||
// Health check endpoint
|
// Health check endpoint
|
||||||
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil {
|
json.NewEncoder(w).Encode(map[string]interface{}{})
|
||||||
a.logger.Error("Error encoding response", "error", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Platform webhook endpoints
|
// Platform webhook endpoints
|
||||||
|
@ -199,9 +153,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
if _, err := platform.Get(platformName); err != nil {
|
if _, err := platform.Get(platformName); err != nil {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil {
|
json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"})
|
||||||
a.logger.Error("Error encoding response", "error", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,9 +162,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil {
|
json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"})
|
||||||
a.logger.Error("Error encoding response", "error", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,9 +178,7 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
// Respond with success
|
// Respond with success
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil {
|
json.NewEncoder(w).Encode(map[string]any{})
|
||||||
a.logger.Error("Error encoding response", "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPlatformName extracts the platform name from the URL path
|
// extractPlatformName extracts the platform name from the URL path
|
||||||
|
@ -314,21 +262,11 @@ func (a *App) handleMessage(item queue.Item) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message with plugins
|
// Process message with plugins
|
||||||
var pluginsToProcess []string
|
for pluginID, channelPlugin := range channel.Plugins {
|
||||||
|
if !channel.HasEnabledPlugin(pluginID) {
|
||||||
if channel.EnableAllPlugins {
|
continue
|
||||||
// If EnableAllPlugins is true, process all registered plugins
|
|
||||||
pluginsToProcess = plugin.GetAvailablePluginIDs()
|
|
||||||
} else {
|
|
||||||
// Otherwise, process only explicitly enabled plugins
|
|
||||||
for pluginID := range channel.Plugins {
|
|
||||||
if channel.HasEnabledPlugin(pluginID) {
|
|
||||||
pluginsToProcess = append(pluginsToProcess, pluginID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, pluginID := range pluginsToProcess {
|
|
||||||
// Get plugin
|
// Get plugin
|
||||||
p, err := plugin.Get(pluginID)
|
p, err := plugin.Get(pluginID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -336,121 +274,20 @@ func (a *App) handleMessage(item queue.Item) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured)
|
// Process message
|
||||||
var config map[string]interface{}
|
responses := p.OnMessage(message, channelPlugin.Config)
|
||||||
if channelPlugin, exists := channel.Plugins[pluginID]; exists {
|
|
||||||
config = channelPlugin.Config
|
|
||||||
} else {
|
|
||||||
config = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create cache instance for this plugin
|
// Send responses
|
||||||
pluginCache := cache.New(a.db, pluginID)
|
|
||||||
|
|
||||||
// Process message and get actions
|
|
||||||
actions := p.OnMessage(message, config, pluginCache)
|
|
||||||
|
|
||||||
// Get platform for processing actions
|
|
||||||
platform, err := platform.Get(item.Platform)
|
platform, err := platform.Get(item.Platform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error("Error getting platform", "error", err)
|
a.logger.Error("Error getting platform", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each action
|
for _, response := range responses {
|
||||||
for _, action := range actions {
|
if err := platform.SendMessage(response); err != nil {
|
||||||
switch action.Type {
|
a.logger.Error("Error sending message", "error", err)
|
||||||
case model.ActionSendMessage:
|
|
||||||
// Send a message
|
|
||||||
if action.Message != nil {
|
|
||||||
if err := platform.SendMessage(action.Message); err != nil {
|
|
||||||
a.logger.Error("Error sending message", "error", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
a.logger.Error("Send message action with nil message")
|
|
||||||
}
|
|
||||||
|
|
||||||
case model.ActionDeleteMessage:
|
|
||||||
// Delete a message using direct DeleteMessage call
|
|
||||||
if err := platform.DeleteMessage(action.Chat, action.MessageID); err != nil {
|
|
||||||
a.logger.Error("Error deleting message", "error", err, "message_id", action.MessageID)
|
|
||||||
} else {
|
|
||||||
a.logger.Info("Message deleted", "message_id", action.MessageID)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
a.logger.Error("Unknown action type", "type", action.Type)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleReminder handles reminder processing
|
|
||||||
func (a *App) handleReminder(reminder *model.Reminder) {
|
|
||||||
// When called with nil, it means we should check for pending reminders
|
|
||||||
if reminder == nil {
|
|
||||||
// Get pending reminders
|
|
||||||
reminders, err := a.db.GetPendingReminders()
|
|
||||||
if err != nil {
|
|
||||||
a.logger.Error("Error getting pending reminders", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process each reminder
|
|
||||||
for _, r := range reminders {
|
|
||||||
a.processReminder(r)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, process the specific reminder
|
|
||||||
a.processReminder(reminder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// processReminder processes an individual reminder
|
|
||||||
func (a *App) processReminder(reminder *model.Reminder) {
|
|
||||||
a.logger.Info("Processing reminder",
|
|
||||||
"id", reminder.ID,
|
|
||||||
"platform", reminder.Platform,
|
|
||||||
"channel", reminder.ChannelID,
|
|
||||||
"trigger_at", reminder.TriggerAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the platform handler
|
|
||||||
p, err := platform.Get(reminder.Platform)
|
|
||||||
if err != nil {
|
|
||||||
a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the channel
|
|
||||||
channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
a.logger.Error("Error getting channel for reminder", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the reminder message
|
|
||||||
reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username)
|
|
||||||
|
|
||||||
message := &model.Message{
|
|
||||||
Text: reminderText,
|
|
||||||
Chat: reminder.ChannelID,
|
|
||||||
Channel: channel,
|
|
||||||
Author: "bot",
|
|
||||||
FromBot: true,
|
|
||||||
Date: time.Now(),
|
|
||||||
ReplyTo: reminder.ReplyToID, // Reply to the original message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the reminder message
|
|
||||||
if err := p.SendMessage(message); err != nil {
|
|
||||||
a.logger.Error("Error sending reminder", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the reminder as processed
|
|
||||||
if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil {
|
|
||||||
a.logger.Error("Error marking reminder as processed", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
83
internal/cache/cache.go
vendored
83
internal/cache/cache.go
vendored
|
@ -1,83 +0,0 @@
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache provides a plugin-friendly interface to the cache system
|
|
||||||
type Cache struct {
|
|
||||||
db *db.Database
|
|
||||||
pluginID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new Cache instance for a specific plugin
|
|
||||||
func New(database *db.Database, pluginID string) *Cache {
|
|
||||||
return &Cache{
|
|
||||||
db: database,
|
|
||||||
pluginID: pluginID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves a value from the cache
|
|
||||||
func (c *Cache) Get(key string, destination interface{}) error {
|
|
||||||
// Create prefixed key
|
|
||||||
fullKey := c.createKey(key)
|
|
||||||
|
|
||||||
// Get from database
|
|
||||||
value, err := c.db.CacheGet(fullKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal JSON into destination
|
|
||||||
return json.Unmarshal([]byte(value), destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set stores a value in the cache with optional expiration
|
|
||||||
func (c *Cache) Set(key string, value interface{}, expiration *time.Time) error {
|
|
||||||
// Create prefixed key
|
|
||||||
fullKey := c.createKey(key)
|
|
||||||
|
|
||||||
// Marshal value to JSON
|
|
||||||
jsonValue, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal cache value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store in database
|
|
||||||
return c.db.CacheSet(fullKey, string(jsonValue), expiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWithTTL stores a value in the cache with a time-to-live duration
|
|
||||||
func (c *Cache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
|
|
||||||
expiration := time.Now().Add(ttl)
|
|
||||||
return c.Set(key, value, &expiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete removes a value from the cache
|
|
||||||
func (c *Cache) Delete(key string) error {
|
|
||||||
fullKey := c.createKey(key)
|
|
||||||
return c.db.CacheDelete(fullKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exists checks if a key exists in the cache
|
|
||||||
func (c *Cache) Exists(key string) (bool, error) {
|
|
||||||
fullKey := c.createKey(key)
|
|
||||||
_, err := c.db.CacheGet(fullKey)
|
|
||||||
if err == db.ErrNotFound {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createKey creates a prefixed cache key
|
|
||||||
func (c *Cache) createKey(key string) string {
|
|
||||||
return fmt.Sprintf("%s_%s", c.pluginID, key)
|
|
||||||
}
|
|
176
internal/cache/cache_test.go
vendored
176
internal/cache/cache_test.go
vendored
|
@ -1,176 +0,0 @@
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCache(t *testing.T) {
|
|
||||||
// Create temporary database for testing with unique name
|
|
||||||
dbFile := fmt.Sprintf("test_cache_%d.db", time.Now().UnixNano())
|
|
||||||
database, err := db.New(dbFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = database.Close()
|
|
||||||
// Clean up test database file
|
|
||||||
_ = os.Remove(dbFile)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Create cache instance
|
|
||||||
cache := New(database, "test.plugin")
|
|
||||||
|
|
||||||
// Test data
|
|
||||||
testKey := "test_key"
|
|
||||||
testValue := map[string]interface{}{
|
|
||||||
"name": "Test Game",
|
|
||||||
"time": 42,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test Set and Get
|
|
||||||
t.Run("Set and Get", func(t *testing.T) {
|
|
||||||
err := cache.Set(testKey, testValue, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var retrieved map[string]interface{}
|
|
||||||
err = cache.Get(testKey, &retrieved)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to get cache value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieved["name"] != testValue["name"] {
|
|
||||||
t.Errorf("Expected name %v, got %v", testValue["name"], retrieved["name"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(retrieved["time"].(float64)) != testValue["time"].(int) {
|
|
||||||
t.Errorf("Expected time %v, got %v", testValue["time"], retrieved["time"])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test SetWithTTL and expiration
|
|
||||||
t.Run("SetWithTTL and expiration", func(t *testing.T) {
|
|
||||||
expiredKey := "expired_key"
|
|
||||||
|
|
||||||
// Set with very short TTL
|
|
||||||
err := cache.SetWithTTL(expiredKey, testValue, time.Millisecond)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache value with TTL: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for expiration
|
|
||||||
time.Sleep(2 * time.Millisecond)
|
|
||||||
|
|
||||||
// Try to get - should fail
|
|
||||||
var retrieved map[string]interface{}
|
|
||||||
err = cache.Get(expiredKey, &retrieved)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected cache miss for expired key, but got value")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test Exists
|
|
||||||
t.Run("Exists", func(t *testing.T) {
|
|
||||||
existsKey := "exists_key"
|
|
||||||
|
|
||||||
// Make sure the key doesn't exist initially by deleting it
|
|
||||||
_ = cache.Delete(existsKey)
|
|
||||||
|
|
||||||
// Should not exist initially
|
|
||||||
exists, err := cache.Exists(existsKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to check if key exists: %v", err)
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
t.Errorf("Expected key to not exist, but it does")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set value
|
|
||||||
err = cache.Set(existsKey, testValue, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should exist now
|
|
||||||
exists, err = cache.Exists(existsKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to check if key exists: %v", err)
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
t.Errorf("Expected key to exist, but it doesn't")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test Delete
|
|
||||||
t.Run("Delete", func(t *testing.T) {
|
|
||||||
deleteKey := "delete_key"
|
|
||||||
|
|
||||||
// Set value
|
|
||||||
err := cache.Set(deleteKey, testValue, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete value
|
|
||||||
err = cache.Delete(deleteKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to delete cache value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not exist anymore
|
|
||||||
var retrieved map[string]interface{}
|
|
||||||
err = cache.Get(deleteKey, &retrieved)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected cache miss for deleted key, but got value")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test plugin ID prefixing
|
|
||||||
t.Run("Plugin ID prefixing", func(t *testing.T) {
|
|
||||||
cache1 := New(database, "plugin1")
|
|
||||||
cache2 := New(database, "plugin2")
|
|
||||||
|
|
||||||
sameKey := "same_key"
|
|
||||||
value1 := "value1"
|
|
||||||
value2 := "value2"
|
|
||||||
|
|
||||||
// Set same key in both caches
|
|
||||||
err := cache1.Set(sameKey, value1, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache1 value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cache2.Set(sameKey, value2, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to set cache2 value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve from both caches
|
|
||||||
var retrieved1, retrieved2 string
|
|
||||||
|
|
||||||
err = cache1.Get(sameKey, &retrieved1)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to get cache1 value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cache2.Get(sameKey, &retrieved2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to get cache2 value: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Values should be different due to plugin ID prefixing
|
|
||||||
if retrieved1 != value1 {
|
|
||||||
t.Errorf("Expected cache1 value %v, got %v", value1, retrieved1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieved2 != value2 {
|
|
||||||
t.Errorf("Expected cache2 value %v, got %v", value2, retrieved2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
|
@ -35,11 +34,6 @@ func New(dbPath string) (*Database, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure SQLite for better reliability
|
|
||||||
if err := configureSQLite(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize database
|
// Initialize database
|
||||||
if err := initDatabase(db); err != nil {
|
if err := initDatabase(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -56,7 +50,7 @@ func (d *Database) Close() error {
|
||||||
// GetChannelByID retrieves a channel by ID
|
// GetChannelByID retrieves a channel by ID
|
||||||
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
`
|
`
|
||||||
|
@ -67,11 +61,10 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
platform string
|
platform string
|
||||||
platformChannelID string
|
platformChannelID string
|
||||||
enabled bool
|
enabled bool
|
||||||
enableAllPlugins bool
|
|
||||||
channelRawJSON string
|
channelRawJSON string
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
@ -91,7 +84,6 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
EnableAllPlugins: enableAllPlugins,
|
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -112,7 +104,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) {
|
||||||
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
// GetChannelByPlatform retrieves a channel by platform and platform channel ID
|
||||||
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
WHERE platform = ? AND platform_channel_id = ?
|
WHERE platform = ? AND platform_channel_id = ?
|
||||||
`
|
`
|
||||||
|
@ -120,13 +112,12 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
||||||
row := d.db.QueryRow(query, platform, platformChannelID)
|
row := d.db.QueryRow(query, platform, platformChannelID)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id int64
|
id int64
|
||||||
enabled bool
|
enabled bool
|
||||||
enableAllPlugins bool
|
channelRawJSON string
|
||||||
channelRawJSON string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON)
|
err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
@ -146,7 +137,6 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
EnableAllPlugins: enableAllPlugins,
|
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -174,11 +164,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
||||||
|
|
||||||
// Insert channel
|
// Insert channel
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw)
|
INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON))
|
result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -195,7 +185,6 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
EnableAllPlugins: false,
|
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -215,18 +204,6 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status
|
|
||||||
func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error {
|
|
||||||
query := `
|
|
||||||
UPDATE channels
|
|
||||||
SET enable_all_plugins = ?
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err := d.db.Exec(query, enableAllPlugins, id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteChannel deletes a channel
|
// DeleteChannel deletes a channel
|
||||||
func (d *Database) DeleteChannel(id int64) error {
|
func (d *Database) DeleteChannel(id int64) error {
|
||||||
// First delete all channel plugins
|
// First delete all channel plugins
|
||||||
|
@ -256,11 +233,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer rows.Close()
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing rows: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var plugins []*model.ChannelPlugin
|
var plugins []*model.ChannelPlugin
|
||||||
|
|
||||||
|
@ -278,7 +251,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse config JSON
|
// Parse config JSON
|
||||||
var config map[string]any
|
var config map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
if err := json.Unmarshal([]byte(configJSON), &config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -305,28 +278,6 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
|
||||||
return plugins, nil
|
return plugins, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChannelPluginsFromPlatformID retrieves all plugins for a channel by platform and platform channel ID
|
|
||||||
func (d *Database) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
|
|
||||||
// First, get the channel ID by platform and platform channel ID
|
|
||||||
query := `
|
|
||||||
SELECT id
|
|
||||||
FROM channels
|
|
||||||
WHERE platform = ? AND platform_channel_id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
var channelID int64
|
|
||||||
err := d.db.QueryRow(query, platform, platformChannelID).Scan(&channelID)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now get the plugins for this channel
|
|
||||||
return d.GetChannelPlugins(channelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetChannelPluginByID retrieves a channel plugin by ID
|
// GetChannelPluginByID retrieves a channel plugin by ID
|
||||||
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
|
func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) {
|
||||||
query := `
|
query := `
|
||||||
|
@ -430,24 +381,6 @@ func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChannelPluginConfig updates a channel plugin's configuration
|
|
||||||
func (d *Database) UpdateChannelPluginConfig(id int64, config map[string]interface{}) error {
|
|
||||||
// Convert config to JSON
|
|
||||||
configJSON, err := json.Marshal(config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
query := `
|
|
||||||
UPDATE channel_plugin
|
|
||||||
SET config = ?
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err = d.db.Exec(query, string(configJSON), id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteChannelPlugin deletes a channel plugin
|
// DeleteChannelPlugin deletes a channel plugin
|
||||||
func (d *Database) DeleteChannelPlugin(id int64) error {
|
func (d *Database) DeleteChannelPlugin(id int64) error {
|
||||||
query := `
|
query := `
|
||||||
|
@ -473,7 +406,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error {
|
||||||
// GetAllChannels retrieves all channels
|
// GetAllChannels retrieves all channels
|
||||||
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw
|
SELECT id, platform, platform_channel_id, enabled, channel_raw
|
||||||
FROM channels
|
FROM channels
|
||||||
`
|
`
|
||||||
|
|
||||||
|
@ -481,11 +414,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer rows.Close()
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing rows: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var channels []*model.Channel
|
var channels []*model.Channel
|
||||||
|
|
||||||
|
@ -495,11 +424,10 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
platform string
|
platform string
|
||||||
platformChannelID string
|
platformChannelID string
|
||||||
enabled bool
|
enabled bool
|
||||||
enableAllPlugins bool
|
|
||||||
channelRawJSON string
|
channelRawJSON string
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil {
|
if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -515,7 +443,6 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
PlatformChannelID: platformChannelID,
|
PlatformChannelID: platformChannelID,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
EnableAllPlugins: enableAllPlugins,
|
|
||||||
ChannelRaw: channelRaw,
|
ChannelRaw: channelRaw,
|
||||||
Plugins: make(map[string]*model.ChannelPlugin),
|
Plugins: make(map[string]*model.ChannelPlugin),
|
||||||
}
|
}
|
||||||
|
@ -526,9 +453,10 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
continue // Skip this channel if plugins can't be retrieved
|
continue // Skip this channel if plugins can't be retrieved
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add plugins to channel
|
if plugins != nil {
|
||||||
for _, plugin := range plugins {
|
for _, plugin := range plugins {
|
||||||
channel.Plugins[plugin.PluginID] = plugin
|
channel.Plugins[plugin.PluginID] = plugin
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
channels = append(channels, channel)
|
channels = append(channels, channel)
|
||||||
|
@ -663,124 +591,6 @@ func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateReminder creates a new reminder
|
|
||||||
func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
|
||||||
query := `
|
|
||||||
INSERT INTO reminders (
|
|
||||||
platform, channel_id, message_id, reply_to_id,
|
|
||||||
user_id, username, created_at, trigger_at,
|
|
||||||
content, processed
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)
|
|
||||||
`
|
|
||||||
|
|
||||||
createdAt := time.Now()
|
|
||||||
result, err := d.db.Exec(
|
|
||||||
query,
|
|
||||||
platform, channelID, messageID, replyToID,
|
|
||||||
userID, username, createdAt, triggerAt,
|
|
||||||
content,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &model.Reminder{
|
|
||||||
ID: id,
|
|
||||||
Platform: platform,
|
|
||||||
ChannelID: channelID,
|
|
||||||
MessageID: messageID,
|
|
||||||
ReplyToID: replyToID,
|
|
||||||
UserID: userID,
|
|
||||||
Username: username,
|
|
||||||
CreatedAt: createdAt,
|
|
||||||
TriggerAt: triggerAt,
|
|
||||||
Content: content,
|
|
||||||
Processed: false,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingReminders gets all pending reminders that need to be processed
|
|
||||||
func (d *Database) GetPendingReminders() ([]*model.Reminder, error) {
|
|
||||||
query := `
|
|
||||||
SELECT id, platform, channel_id, message_id, reply_to_id,
|
|
||||||
user_id, username, created_at, trigger_at, content, processed
|
|
||||||
FROM reminders
|
|
||||||
WHERE processed = 0 AND trigger_at <= ?
|
|
||||||
`
|
|
||||||
|
|
||||||
rows, err := d.db.Query(query, time.Now())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing rows: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var reminders []*model.Reminder
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
id int64
|
|
||||||
platform, channelID, messageID, replyToID string
|
|
||||||
userID, username, content string
|
|
||||||
createdAt, triggerAt time.Time
|
|
||||||
processed bool
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&id, &platform, &channelID, &messageID, &replyToID,
|
|
||||||
&userID, &username, &createdAt, &triggerAt, &content, &processed,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reminder := &model.Reminder{
|
|
||||||
ID: id,
|
|
||||||
Platform: platform,
|
|
||||||
ChannelID: channelID,
|
|
||||||
MessageID: messageID,
|
|
||||||
ReplyToID: replyToID,
|
|
||||||
UserID: userID,
|
|
||||||
Username: username,
|
|
||||||
CreatedAt: createdAt,
|
|
||||||
TriggerAt: triggerAt,
|
|
||||||
Content: content,
|
|
||||||
Processed: processed,
|
|
||||||
}
|
|
||||||
|
|
||||||
reminders = append(reminders, reminder)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reminders) == 0 {
|
|
||||||
return make([]*model.Reminder, 0), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return reminders, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkReminderAsProcessed marks a reminder as processed
|
|
||||||
func (d *Database) MarkReminderAsProcessed(id int64) error {
|
|
||||||
query := `
|
|
||||||
UPDATE reminders
|
|
||||||
SET processed = 1
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err := d.db.Exec(query, id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to hash password
|
// Helper function to hash password
|
||||||
func hashPassword(password string) (string, error) {
|
func hashPassword(password string) (string, error) {
|
||||||
// Use bcrypt for secure password hashing
|
// Use bcrypt for secure password hashing
|
||||||
|
@ -799,25 +609,25 @@ func initDatabase(db *sql.DB) error {
|
||||||
if err := migration.EnsureMigrationTable(db); err != nil {
|
if err := migration.EnsureMigrationTable(db); err != nil {
|
||||||
return fmt.Errorf("failed to create migration table: %w", err)
|
return fmt.Errorf("failed to create migration table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get applied migrations
|
// Get applied migrations
|
||||||
applied, err := migration.GetAppliedMigrations(db)
|
applied, err := migration.GetAppliedMigrations(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all migration versions
|
// Get all migration versions
|
||||||
allMigrations := make([]int, 0, len(migration.Migrations))
|
allMigrations := make([]int, 0, len(migration.Migrations))
|
||||||
for version := range migration.Migrations {
|
for version := range migration.Migrations {
|
||||||
allMigrations = append(allMigrations, version)
|
allMigrations = append(allMigrations, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map of applied migrations for quick lookup
|
// Create a map of applied migrations for quick lookup
|
||||||
appliedMap := make(map[int]bool)
|
appliedMap := make(map[int]bool)
|
||||||
for _, version := range applied {
|
for _, version := range applied {
|
||||||
appliedMap[version] = true
|
appliedMap[version] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count pending migrations
|
// Count pending migrations
|
||||||
pendingCount := 0
|
pendingCount := 0
|
||||||
for _, version := range allMigrations {
|
for _, version := range allMigrations {
|
||||||
|
@ -825,7 +635,7 @@ func initDatabase(db *sql.DB) error {
|
||||||
pendingCount++
|
pendingCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run migrations if needed
|
// Run migrations if needed
|
||||||
if pendingCount > 0 {
|
if pendingCount > 0 {
|
||||||
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
||||||
|
@ -836,85 +646,6 @@ func initDatabase(db *sql.DB) error {
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("Database schema is up to date.")
|
fmt.Println("Database schema is up to date.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure SQLite for better reliability
|
|
||||||
func configureSQLite(db *sql.DB) error {
|
|
||||||
pragmas := []string{
|
|
||||||
// Enable Write-Ahead Logging for better concurrency and crash recovery
|
|
||||||
"PRAGMA journal_mode = WAL",
|
|
||||||
// Set 5-second timeout when database is locked by another connection
|
|
||||||
"PRAGMA busy_timeout = 5000",
|
|
||||||
// Balance between safety and performance for disk writes
|
|
||||||
"PRAGMA synchronous = NORMAL",
|
|
||||||
// Set large cache size (1GB) for better read performance
|
|
||||||
"PRAGMA cache_size = 1000000000",
|
|
||||||
// Enable foreign key constraint enforcement
|
|
||||||
"PRAGMA foreign_keys = true",
|
|
||||||
// Store temporary tables and indices in memory for speed
|
|
||||||
"PRAGMA temp_store = memory",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pragma := range pragmas {
|
|
||||||
if _, err := db.Exec(pragma); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute %s: %w", pragma, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheGet retrieves a value from the cache
|
|
||||||
func (d *Database) CacheGet(key string) (string, error) {
|
|
||||||
query := `
|
|
||||||
SELECT value
|
|
||||||
FROM cache
|
|
||||||
WHERE key = ? AND (expires_at IS NULL OR expires_at > ?)
|
|
||||||
`
|
|
||||||
|
|
||||||
var value string
|
|
||||||
err := d.db.QueryRow(query, key, time.Now()).Scan(&value)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return "", ErrNotFound
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheSet stores a value in the cache with optional expiration
|
|
||||||
func (d *Database) CacheSet(key, value string, expiration *time.Time) error {
|
|
||||||
query := `
|
|
||||||
INSERT OR REPLACE INTO cache (key, value, expires_at, updated_at)
|
|
||||||
VALUES (?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err := d.db.Exec(query, key, value, expiration, time.Now())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheDelete removes a value from the cache
|
|
||||||
func (d *Database) CacheDelete(key string) error {
|
|
||||||
query := `
|
|
||||||
DELETE FROM cache
|
|
||||||
WHERE key = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err := d.db.Exec(query, key)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheCleanup removes expired cache entries
|
|
||||||
func (d *Database) CacheCleanup() error {
|
|
||||||
query := `
|
|
||||||
DELETE FROM cache
|
|
||||||
WHERE expires_at IS NOT NULL AND expires_at <= ?
|
|
||||||
`
|
|
||||||
|
|
||||||
_, err := d.db.Exec(query, time.Now())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,203 +0,0 @@
|
||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEnableAllPlugins(t *testing.T) {
|
|
||||||
// Create temporary database for testing with unique name
|
|
||||||
dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano())
|
|
||||||
database, err := New(dbFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = database.Close()
|
|
||||||
// Clean up test database file
|
|
||||||
_ = os.Remove(dbFile)
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) {
|
|
||||||
channelRaw := map[string]interface{}{
|
|
||||||
"name": "test-channel",
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := database.CreateChannel("telegram", "123456", true, channelRaw)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if channel.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected EnableAllPlugins to be false by default, got true")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify it's also false when retrieved from database
|
|
||||||
retrieved, err := database.GetChannelByID(channel.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieved.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) {
|
|
||||||
// Create a channel
|
|
||||||
channelRaw := map[string]interface{}{
|
|
||||||
"name": "test-channel-2",
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := database.CreateChannel("telegram", "123457", true, channelRaw)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update EnableAllPlugins to true
|
|
||||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve and verify
|
|
||||||
retrieved, err := database.GetChannelByID(channel.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !retrieved.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected EnableAllPlugins to be true after update, got false")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update back to false
|
|
||||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve and verify again
|
|
||||||
retrieved, err = database.GetChannelByID(channel.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieved.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected EnableAllPlugins to be false after second update, got true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) {
|
|
||||||
// Create a channel
|
|
||||||
channelRaw := map[string]interface{}{
|
|
||||||
"name": "test-channel-3",
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := database.CreateChannel("slack", "C123456", true, channelRaw)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable all plugins
|
|
||||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to update EnableAllPlugins: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve by platform
|
|
||||||
retrieved, err := database.GetChannelByPlatform("slack", "C123456")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to retrieve channel by platform: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !retrieved.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) {
|
|
||||||
// Create multiple channels with different EnableAllPlugins settings
|
|
||||||
channelRaw1 := map[string]interface{}{"name": "channel-1"}
|
|
||||||
channelRaw2 := map[string]interface{}{"name": "channel-2"}
|
|
||||||
|
|
||||||
channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel1: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel2: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable all plugins for channel2 only
|
|
||||||
err = database.UpdateChannelEnableAllPlugins(channel2.ID, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all channels
|
|
||||||
channels, err := database.GetAllChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get all channels: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find our test channels
|
|
||||||
var foundChannel1, foundChannel2 *model.Channel
|
|
||||||
for _, ch := range channels {
|
|
||||||
if ch.ID == channel1.ID {
|
|
||||||
foundChannel1 = ch
|
|
||||||
}
|
|
||||||
if ch.ID == channel2.ID {
|
|
||||||
foundChannel2 = ch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if foundChannel1 == nil {
|
|
||||||
t.Fatalf("Channel1 not found in GetAllChannels result")
|
|
||||||
}
|
|
||||||
if foundChannel2 == nil {
|
|
||||||
t.Fatalf("Channel2 not found in GetAllChannels result")
|
|
||||||
}
|
|
||||||
|
|
||||||
if foundChannel1.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected channel1 EnableAllPlugins to be false, got true")
|
|
||||||
}
|
|
||||||
if !foundChannel2.EnableAllPlugins {
|
|
||||||
t.Errorf("Expected channel2 EnableAllPlugins to be true, got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Migration applied correctly", func(t *testing.T) {
|
|
||||||
// Test that we can create a channel and the enable_all_plugins column exists
|
|
||||||
// This implicitly tests that migration 4 was applied correctly
|
|
||||||
channelRaw := map[string]interface{}{
|
|
||||||
"name": "migration-test-channel",
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create channel after migration: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to update EnableAllPlugins - this would fail if the column doesn't exist
|
|
||||||
err = database.UpdateChannelEnableAllPlugins(channel.ID, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the value was set correctly
|
|
||||||
retrieved, err := database.GetChannelByID(channel.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to retrieve channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !retrieved.EnableAllPlugins {
|
|
||||||
t.Errorf("EnableAllPlugins should be true after update")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -49,11 +49,7 @@ func GetAppliedMigrations(db *sql.DB) ([]int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer rows.Close()
|
||||||
if err := rows.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing rows: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var versions []int
|
var versions []int
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
@ -132,9 +128,7 @@ func Migrate(db *sql.DB) error {
|
||||||
|
|
||||||
// Apply the migration
|
// Apply the migration
|
||||||
if err := migration.Up(db); err != nil {
|
if err := migration.Up(db); err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
tx.Rollback()
|
||||||
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,9 +137,7 @@ func Migrate(db *sql.DB) error {
|
||||||
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
||||||
version, time.Now(),
|
version, time.Now(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
tx.Rollback()
|
||||||
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,17 +188,13 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
|
||||||
|
|
||||||
// Apply the down migration
|
// Apply the down migration
|
||||||
if err := migration.Down(db); err != nil {
|
if err := migration.Down(db); err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
tx.Rollback()
|
||||||
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove from applied list
|
// Remove from applied list
|
||||||
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
tx.Rollback()
|
||||||
fmt.Printf("Error rolling back transaction: %v\n", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,4 +208,4 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
|
@ -8,9 +8,6 @@ import (
|
||||||
func init() {
|
func init() {
|
||||||
// Register migrations
|
// Register migrations
|
||||||
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
||||||
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
|
|
||||||
Register(3, "Add cache table", migrateCacheUp, migrateCacheDown)
|
|
||||||
Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initial schema creation with bcrypt passwords - version 1
|
// Initial schema creation with bcrypt passwords - version 1
|
||||||
|
@ -63,14 +60,14 @@ func migrateInitialSchemaUp(db *sql.DB) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if users table is empty before inserting
|
// Check if users table is empty before inserting
|
||||||
var count int
|
var count int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
"INSERT INTO users (username, password) VALUES (?, ?)",
|
"INSERT INTO users (username, password) VALUES (?, ?)",
|
||||||
|
@ -102,113 +99,4 @@ func migrateInitialSchemaDown(db *sql.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add reminders table - version 2
|
|
||||||
func migrateRemindersUp(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS reminders (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
platform TEXT NOT NULL,
|
|
||||||
channel_id TEXT NOT NULL,
|
|
||||||
message_id TEXT NOT NULL,
|
|
||||||
reply_to_id TEXT NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
username TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP NOT NULL,
|
|
||||||
trigger_at TIMESTAMP NOT NULL,
|
|
||||||
content TEXT NOT NULL,
|
|
||||||
processed BOOLEAN NOT NULL DEFAULT 0
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func migrateRemindersDown(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`DROP TABLE IF EXISTS reminders`)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add cache table - version 3
|
|
||||||
func migrateCacheUp(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS cache (
|
|
||||||
key TEXT PRIMARY KEY,
|
|
||||||
value TEXT NOT NULL,
|
|
||||||
expires_at TIMESTAMP,
|
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create index on expires_at for efficient cleanup
|
|
||||||
_, err = db.Exec(`
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_cache_expires_at ON cache(expires_at)
|
|
||||||
`)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func migrateCacheDown(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`DROP TABLE IF EXISTS cache`)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add enable_all_plugins column to channels table - version 4
|
|
||||||
func migrateEnableAllPluginsUp(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`
|
|
||||||
ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0
|
|
||||||
`)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func migrateEnableAllPluginsDown(db *sql.DB) error {
|
|
||||||
// SQLite doesn't support DROP COLUMN, so we need to recreate the table
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = tx.Rollback() // Ignore rollback errors
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Create backup table
|
|
||||||
_, err = tx.Exec(`
|
|
||||||
CREATE TABLE channels_backup (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
platform TEXT NOT NULL,
|
|
||||||
platform_channel_id TEXT NOT NULL,
|
|
||||||
enabled BOOLEAN NOT NULL DEFAULT 0,
|
|
||||||
channel_raw TEXT NOT NULL,
|
|
||||||
UNIQUE(platform, platform_channel_id)
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy data excluding enable_all_plugins column
|
|
||||||
_, err = tx.Exec(`
|
|
||||||
INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw)
|
|
||||||
SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Drop original table
|
|
||||||
_, err = tx.Exec(`DROP TABLE channels`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename backup table
|
|
||||||
_, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
|
@ -4,57 +4,31 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ActionType defines the type of action to perform
|
|
||||||
type ActionType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ActionSendMessage is for sending a message to the chat
|
|
||||||
ActionSendMessage ActionType = "send_message"
|
|
||||||
// ActionDeleteMessage is for deleting a message from the chat
|
|
||||||
ActionDeleteMessage ActionType = "delete_message"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MessageAction represents an action to be performed on the platform
|
|
||||||
type MessageAction struct {
|
|
||||||
Type ActionType
|
|
||||||
Message *Message // For send_message
|
|
||||||
MessageID string // For delete_message
|
|
||||||
Chat string // Chat where the action happens
|
|
||||||
Channel *Channel // Channel reference
|
|
||||||
Raw map[string]interface{} // Additional data for the action
|
|
||||||
}
|
|
||||||
|
|
||||||
// Message represents a chat message
|
// Message represents a chat message
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Text string
|
Text string
|
||||||
Chat string
|
Chat string
|
||||||
Channel *Channel
|
Channel *Channel
|
||||||
Author string
|
Author string
|
||||||
FromBot bool
|
FromBot bool
|
||||||
Date time.Time
|
Date time.Time
|
||||||
ID string
|
ID string
|
||||||
ReplyTo string
|
ReplyTo string
|
||||||
Raw map[string]interface{}
|
Raw map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channel represents a chat channel
|
// Channel represents a chat channel
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
ID int64
|
ID int64
|
||||||
Platform string
|
Platform string
|
||||||
PlatformChannelID string
|
PlatformChannelID string
|
||||||
ChannelRaw map[string]interface{}
|
ChannelRaw map[string]interface{}
|
||||||
Enabled bool
|
Enabled bool
|
||||||
EnableAllPlugins bool
|
Plugins map[string]*ChannelPlugin
|
||||||
Plugins map[string]*ChannelPlugin
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
||||||
func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
||||||
// If EnableAllPlugins is true, all plugins are considered enabled
|
|
||||||
if c.EnableAllPlugins {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
plugin, exists := c.Plugins[pluginID]
|
plugin, exists := c.Plugins[pluginID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
|
@ -66,18 +40,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
||||||
func (c *Channel) ChannelName() string {
|
func (c *Channel) ChannelName() string {
|
||||||
// In a real implementation, this would use the platform-specific
|
// In a real implementation, this would use the platform-specific
|
||||||
// ParseChannelNameFromRaw function
|
// ParseChannelNameFromRaw function
|
||||||
|
|
||||||
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name
|
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name
|
||||||
// Check if ChannelRaw has a name field
|
// Check if ChannelRaw has a name field
|
||||||
if c.ChannelRaw == nil {
|
if c.ChannelRaw == nil {
|
||||||
return c.PlatformChannelID
|
return c.PlatformChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check common name fields in ChannelRaw
|
// Check common name fields in ChannelRaw
|
||||||
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
|
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for nested objects like "chat" (used by Telegram)
|
// Check for nested objects like "chat" (used by Telegram)
|
||||||
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
|
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
|
||||||
// Try different fields in order of preference
|
// Try different fields in order of preference
|
||||||
|
@ -91,7 +65,7 @@ func (c *Channel) ChannelName() string {
|
||||||
return firstName
|
return firstName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.PlatformChannelID
|
return c.PlatformChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +75,7 @@ type ChannelPlugin struct {
|
||||||
ChannelID int64
|
ChannelID int64
|
||||||
PluginID string
|
PluginID string
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Config map[string]any
|
Config map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// User represents an admin user
|
// User represents an admin user
|
||||||
|
@ -109,19 +83,4 @@ type User struct {
|
||||||
ID int64
|
ID int64
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reminder represents a scheduled reminder
|
|
||||||
type Reminder struct {
|
|
||||||
ID int64
|
|
||||||
Platform string
|
|
||||||
ChannelID string
|
|
||||||
MessageID string
|
|
||||||
ReplyToID string
|
|
||||||
UserID string
|
|
||||||
Username string
|
|
||||||
CreatedAt time.Time
|
|
||||||
TriggerAt time.Time
|
|
||||||
Content string
|
|
||||||
Processed bool
|
|
||||||
}
|
|
|
@ -1,234 +0,0 @@
|
||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestChannel_HasEnabledPlugin(t *testing.T) {
|
|
||||||
t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: false,
|
|
||||||
Plugins: make(map[string]*ChannelPlugin),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Plugin not in map should return false
|
|
||||||
result := channel.HasEnabledPlugin("nonexistent.plugin")
|
|
||||||
if result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: false,
|
|
||||||
Plugins: map[string]*ChannelPlugin{
|
|
||||||
"test.plugin": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "test.plugin",
|
|
||||||
Enabled: false,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disabled plugin should return false
|
|
||||||
result := channel.HasEnabledPlugin("test.plugin")
|
|
||||||
if result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: false,
|
|
||||||
Plugins: map[string]*ChannelPlugin{
|
|
||||||
"test.plugin": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "test.plugin",
|
|
||||||
Enabled: true,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enabled plugin should return true
|
|
||||||
result := channel.HasEnabledPlugin("test.plugin")
|
|
||||||
if !result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: make(map[string]*ChannelPlugin),
|
|
||||||
}
|
|
||||||
|
|
||||||
// When EnableAllPlugins is true, any plugin should be considered enabled
|
|
||||||
result := channel.HasEnabledPlugin("nonexistent.plugin")
|
|
||||||
if !result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: map[string]*ChannelPlugin{
|
|
||||||
"test.plugin": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "test.plugin",
|
|
||||||
Enabled: false,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// When EnableAllPlugins is true, even disabled plugins should be considered enabled
|
|
||||||
result := channel.HasEnabledPlugin("test.plugin")
|
|
||||||
if !result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: map[string]*ChannelPlugin{
|
|
||||||
"test.plugin": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "test.plugin",
|
|
||||||
Enabled: true,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// When EnableAllPlugins is true, enabled plugins should also return true
|
|
||||||
result := channel.HasEnabledPlugin("test.plugin")
|
|
||||||
if !result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: map[string]*ChannelPlugin{
|
|
||||||
"plugin1": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "plugin1",
|
|
||||||
Enabled: true,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
"plugin2": {
|
|
||||||
ID: 2,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "plugin2",
|
|
||||||
Enabled: false,
|
|
||||||
Config: make(map[string]any),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// All plugins should be enabled when EnableAllPlugins is true
|
|
||||||
testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"}
|
|
||||||
for _, pluginID := range testCases {
|
|
||||||
result := channel.HasEnabledPlugin(pluginID)
|
|
||||||
if !result {
|
|
||||||
t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChannelName(t *testing.T) {
|
|
||||||
t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
PlatformChannelID: "test-id",
|
|
||||||
ChannelRaw: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := channel.ChannelName()
|
|
||||||
if result != "test-id" {
|
|
||||||
t.Errorf("Expected channel name to be 'test-id', got '%s'", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Returns name from ChannelRaw when available", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
PlatformChannelID: "test-id",
|
|
||||||
ChannelRaw: map[string]interface{}{
|
|
||||||
"name": "Test Channel",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := channel.ChannelName()
|
|
||||||
if result != "Test Channel" {
|
|
||||||
t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
PlatformChannelID: "test-id",
|
|
||||||
ChannelRaw: map[string]interface{}{
|
|
||||||
"chat": map[string]interface{}{
|
|
||||||
"title": "Telegram Group",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := channel.ChannelName()
|
|
||||||
if result != "Telegram Group" {
|
|
||||||
t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) {
|
|
||||||
channel := &Channel{
|
|
||||||
PlatformChannelID: "fallback-id",
|
|
||||||
ChannelRaw: map[string]interface{}{
|
|
||||||
"other_field": "value",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := channel.ChannelName()
|
|
||||||
if result != "fallback-id" {
|
|
||||||
t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -43,7 +43,4 @@ type Platform interface {
|
||||||
|
|
||||||
// SendMessage sends a message through the platform
|
// SendMessage sends a message through the platform
|
||||||
SendMessage(msg *Message) error
|
SendMessage(msg *Message) error
|
||||||
|
|
||||||
// DeleteMessage deletes a message from the platform
|
|
||||||
DeleteMessage(channel string, messageID string) error
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,18 +2,8 @@ package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheInterface defines the cache interface available to plugins
|
|
||||||
type CacheInterface interface {
|
|
||||||
Get(key string, destination interface{}) error
|
|
||||||
Set(key string, value interface{}, expiration *time.Time) error
|
|
||||||
SetWithTTL(key string, value interface{}, ttl time.Duration) error
|
|
||||||
Delete(key string) error
|
|
||||||
Exists(key string) (bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrPluginNotFound is returned when a requested plugin doesn't exist
|
// ErrPluginNotFound is returned when a requested plugin doesn't exist
|
||||||
ErrPluginNotFound = errors.New("plugin not found")
|
ErrPluginNotFound = errors.New("plugin not found")
|
||||||
|
@ -23,16 +13,16 @@ var (
|
||||||
type Plugin interface {
|
type Plugin interface {
|
||||||
// GetID returns the plugin ID
|
// GetID returns the plugin ID
|
||||||
GetID() string
|
GetID() string
|
||||||
|
|
||||||
// GetName returns the plugin name
|
// GetName returns the plugin name
|
||||||
GetName() string
|
GetName() string
|
||||||
|
|
||||||
// GetHelp returns the plugin help text
|
// GetHelp returns the plugin help text
|
||||||
GetHelp() string
|
GetHelp() string
|
||||||
|
|
||||||
// RequiresConfig indicates if the plugin requires configuration
|
// RequiresConfig indicates if the plugin requires configuration
|
||||||
RequiresConfig() bool
|
RequiresConfig() bool
|
||||||
|
|
||||||
// OnMessage processes an incoming message and returns platform actions
|
// OnMessage processes an incoming message and returns response messages
|
||||||
OnMessage(msg *Message, config map[string]interface{}, cache CacheInterface) []*MessageAction
|
OnMessage(msg *Message, config map[string]interface{}) []*Message
|
||||||
}
|
}
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -37,15 +37,11 @@ func (s *SlackPlatform) Init(_ *config.Config) error {
|
||||||
// ParseIncomingMessage parses an incoming Slack message
|
// ParseIncomingMessage parses an incoming Slack message
|
||||||
func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) {
|
func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) {
|
||||||
// Read request body
|
// Read request body
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer r.Body.Close()
|
||||||
if err := r.Body.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing request body: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var requestData map[string]interface{}
|
var requestData map[string]interface{}
|
||||||
|
@ -167,12 +163,6 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
|
||||||
return errors.New("bot token not configured")
|
return errors.New("bot token not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for delete message action
|
|
||||||
if msg.Raw != nil && msg.Raw["action"] == "delete" {
|
|
||||||
// This is a request to delete a message
|
|
||||||
return s.deleteMessage(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare payload
|
// Prepare payload
|
||||||
payload := map[string]interface{}{
|
payload := map[string]interface{}{
|
||||||
"channel": msg.Chat,
|
"channel": msg.Chat,
|
||||||
|
@ -204,11 +194,7 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer resp.Body.Close()
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing response body: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
@ -218,63 +204,6 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMessage deletes a message on Slack
|
|
||||||
func (s *SlackPlatform) DeleteMessage(channel string, messageID string) error {
|
|
||||||
// Prepare payload for chat.delete API
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"channel": channel,
|
|
||||||
"ts": messageID, // In Slack, the ts (timestamp) is the message ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert payload to JSON
|
|
||||||
data, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send HTTP request to chat.delete endpoint
|
|
||||||
req, err := http.NewRequest("POST", "https://slack.com/api/chat.delete", strings.NewReader(string(data)))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.config.BotOAuthAccessToken))
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
fmt.Printf("Error closing response body: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
return fmt.Errorf("slack API error: %d - %s", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteMessage is a legacy method that uses the Raw message approach
|
|
||||||
func (s *SlackPlatform) deleteMessage(msg *model.Message) error {
|
|
||||||
// Get message ID to delete
|
|
||||||
messageID, ok := msg.Raw["message_id"]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no message ID provided for deletion")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to string if needed
|
|
||||||
messageIDStr := fmt.Sprintf("%v", messageID)
|
|
||||||
|
|
||||||
return s.DeleteMessage(msg.Chat, messageIDStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to parse int64
|
// Helper function to parse int64
|
||||||
func parseInt64(s string) (int64, error) {
|
func parseInt64(s string) (int64, error) {
|
||||||
var n int64
|
var n int64
|
||||||
|
|
|
@ -62,11 +62,7 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error {
|
||||||
t.log.Error("Failed to set webhook", "error", err)
|
t.log.Error("Failed to set webhook", "error", err)
|
||||||
return fmt.Errorf("failed to set webhook: %w", err)
|
return fmt.Errorf("failed to set webhook: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer resp.Body.Close()
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
t.log.Error("Error closing response body", "error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
@ -89,11 +85,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
t.log.Error("Failed to read request body", "error", err)
|
t.log.Error("Failed to read request body", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer r.Body.Close()
|
||||||
if err := r.Body.Close(); err != nil {
|
|
||||||
t.log.Error("Error closing request body", "error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var update struct {
|
var update struct {
|
||||||
|
@ -111,11 +103,8 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
Title string `json:"title,omitempty"`
|
Title string `json:"title,omitempty"`
|
||||||
Username string `json:"username,omitempty"`
|
Username string `json:"username,omitempty"`
|
||||||
} `json:"chat"`
|
} `json:"chat"`
|
||||||
Date int `json:"date"`
|
Date int `json:"date"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
ReplyToMessage struct {
|
|
||||||
MessageID int `json:"message_id"`
|
|
||||||
} `json:"reply_to_message"`
|
|
||||||
} `json:"message"`
|
} `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +128,6 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
FromBot: update.Message.From.IsBot,
|
FromBot: update.Message.From.IsBot,
|
||||||
Date: time.Unix(int64(update.Message.Date), 0),
|
Date: time.Unix(int64(update.Message.Date), 0),
|
||||||
ID: strconv.Itoa(update.Message.MessageID),
|
ID: strconv.Itoa(update.Message.MessageID),
|
||||||
ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID),
|
|
||||||
Raw: raw,
|
Raw: raw,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,13 +205,6 @@ func (t *TelegramPlatform) ParseChannelFromMessage(body []byte) (map[string]any,
|
||||||
|
|
||||||
// SendMessage sends a message to Telegram
|
// SendMessage sends a message to Telegram
|
||||||
func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
// Check for delete message action (legacy method)
|
|
||||||
if msg.Raw != nil && msg.Raw["action"] == "delete" {
|
|
||||||
// This is a request to delete a message using the legacy method
|
|
||||||
return t.deleteMessage(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regular message sending
|
|
||||||
// Convert chat ID to int64
|
// Convert chat ID to int64
|
||||||
chatID, err := strconv.ParseInt(msg.Chat, 10, 64)
|
chatID, err := strconv.ParseInt(msg.Chat, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -237,15 +218,6 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
"text": msg.Text,
|
"text": msg.Text,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set parse_mode based on plugin preference or default to empty string
|
|
||||||
if msg.Raw != nil && msg.Raw["parse_mode"] != nil {
|
|
||||||
// Plugin explicitly set parse_mode
|
|
||||||
payload["parse_mode"] = msg.Raw["parse_mode"]
|
|
||||||
} else {
|
|
||||||
// Default to empty string (no formatting)
|
|
||||||
payload["parse_mode"] = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add reply if needed
|
// Add reply if needed
|
||||||
if msg.ReplyTo != "" {
|
if msg.ReplyTo != "" {
|
||||||
replyToID, err := strconv.Atoi(msg.ReplyTo)
|
replyToID, err := strconv.Atoi(msg.ReplyTo)
|
||||||
|
@ -275,11 +247,7 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
t.log.Error("Failed to send message", "error", err)
|
t.log.Error("Failed to send message", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer resp.Body.Close()
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
t.log.Error("Error closing response body", "error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
@ -291,89 +259,4 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
|
|
||||||
t.log.Debug("Message sent successfully")
|
t.log.Debug("Message sent successfully")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMessage deletes a message on Telegram
|
|
||||||
func (t *TelegramPlatform) DeleteMessage(channel string, messageID string) error {
|
|
||||||
// Convert chat ID to int64
|
|
||||||
chatID, err := strconv.ParseInt(channel, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
t.log.Error("Invalid chat ID for message deletion", "chat_id", channel, "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert message ID to integer
|
|
||||||
msgID, err := strconv.Atoi(messageID)
|
|
||||||
if err != nil {
|
|
||||||
t.log.Error("Invalid message ID for deletion", "message_id", messageID, "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare payload for deleteMessage API
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"chat_id": chatID,
|
|
||||||
"message_id": msgID,
|
|
||||||
}
|
|
||||||
|
|
||||||
t.log.Debug("Deleting message on Telegram", "chat_id", chatID, "message_id", msgID)
|
|
||||||
|
|
||||||
// Convert payload to JSON
|
|
||||||
data, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
t.log.Error("Failed to marshal delete message payload", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send HTTP request to deleteMessage endpoint
|
|
||||||
resp, err := http.Post(
|
|
||||||
t.apiURL+"/deleteMessage",
|
|
||||||
"application/json",
|
|
||||||
bytes.NewBuffer(data),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.log.Error("Failed to delete message", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
t.log.Error("Error closing response body", "error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
||||||
errMsg := string(bodyBytes)
|
|
||||||
t.log.Error("Telegram API error when deleting message", "status", resp.StatusCode, "response", errMsg)
|
|
||||||
return fmt.Errorf("telegram API error when deleting message: %d - %s", resp.StatusCode, errMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.log.Debug("Message deleted successfully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteMessage is a legacy method that uses the Raw message approach
|
|
||||||
func (t *TelegramPlatform) deleteMessage(msg *model.Message) error {
|
|
||||||
// Get message ID to delete
|
|
||||||
messageIDInterface, ok := msg.Raw["message_id"]
|
|
||||||
if !ok {
|
|
||||||
t.log.Error("No message ID provided for deletion")
|
|
||||||
return fmt.Errorf("no message ID provided for deletion")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert message ID to string
|
|
||||||
var messageIDStr string
|
|
||||||
switch v := messageIDInterface.(type) {
|
|
||||||
case string:
|
|
||||||
messageIDStr = v
|
|
||||||
case int:
|
|
||||||
messageIDStr = strconv.Itoa(v)
|
|
||||||
case float64:
|
|
||||||
messageIDStr = strconv.Itoa(int(v))
|
|
||||||
default:
|
|
||||||
t.log.Error("Invalid message ID type for deletion", "type", fmt.Sprintf("%T", messageIDInterface))
|
|
||||||
return fmt.Errorf("invalid message ID type for deletion")
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.DeleteMessage(msg.Chat, messageIDStr)
|
|
||||||
}
|
|
|
@ -1,132 +0,0 @@
|
||||||
package domainblock
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DomainBlockPlugin is a plugin that blocks messages containing links from specific domains
|
|
||||||
type DomainBlockPlugin struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug helper to check if RequiresConfig is working
|
|
||||||
func (p *DomainBlockPlugin) RequiresConfig() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new DomainBlockPlugin instance
|
|
||||||
func New() *DomainBlockPlugin {
|
|
||||||
return &DomainBlockPlugin{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "security.domainblock",
|
|
||||||
Name: "Domain Blocker",
|
|
||||||
Help: "Blocks messages containing links from configured domains",
|
|
||||||
ConfigRequired: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractDomains extracts domains from a message text
|
|
||||||
func extractDomains(text string) []string {
|
|
||||||
// URL regex pattern
|
|
||||||
urlPattern := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`)
|
|
||||||
matches := urlPattern.FindAllStringSubmatch(text, -1)
|
|
||||||
|
|
||||||
domains := make([]string, 0, len(matches))
|
|
||||||
for _, match := range matches {
|
|
||||||
if len(match) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse the URL to extract the domain
|
|
||||||
urlStr := match[0]
|
|
||||||
parsedURL, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the domain (host) from the URL
|
|
||||||
domain := parsedURL.Host
|
|
||||||
// Remove port if present
|
|
||||||
if i := strings.IndexByte(domain, ':'); i >= 0 {
|
|
||||||
domain = domain[:i]
|
|
||||||
}
|
|
||||||
|
|
||||||
domains = append(domains, strings.ToLower(domain))
|
|
||||||
}
|
|
||||||
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage processes incoming messages
|
|
||||||
func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Skip messages from bots
|
|
||||||
if msg.FromBot {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get blocked domains from config
|
|
||||||
blockedDomainsStr, ok := config["blocked_domains"].(string)
|
|
||||||
if !ok || blockedDomainsStr == "" {
|
|
||||||
return nil // No blocked domains configured
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split and clean blocked domains
|
|
||||||
blockedDomains := strings.Split(blockedDomainsStr, ",")
|
|
||||||
for i, domain := range blockedDomains {
|
|
||||||
blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract domains from message
|
|
||||||
messageDomains := extractDomains(msg.Text)
|
|
||||||
if len(messageDomains) == 0 {
|
|
||||||
return nil // No domains in message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any domains in the message are blocked
|
|
||||||
for _, msgDomain := range messageDomains {
|
|
||||||
for _, blockedDomain := range blockedDomains {
|
|
||||||
if blockedDomain == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(msgDomain, blockedDomain) || msgDomain == blockedDomain {
|
|
||||||
// Domain is blocked, create actions
|
|
||||||
|
|
||||||
// 1. Create a delete message action
|
|
||||||
deleteAction := &model.MessageAction{
|
|
||||||
Type: model.ActionDeleteMessage,
|
|
||||||
MessageID: msg.ID,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Create a notification message action
|
|
||||||
notificationMsg := &model.Message{
|
|
||||||
Text: fmt.Sprintf("I don't like links from %s 🙈", blockedDomain),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
sendAction := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: notificationMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{deleteAction, sendAction}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Plugin is registered in app.go, not using init()
|
|
|
@ -1,142 +0,0 @@
|
||||||
package domainblock
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestExtractDomains(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
text string
|
|
||||||
expected []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "No URLs",
|
|
||||||
text: "Hello, world!",
|
|
||||||
expected: []string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Single URL",
|
|
||||||
text: "Check out https://example.com for more info",
|
|
||||||
expected: []string{"example.com"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Multiple URLs",
|
|
||||||
text: "Check out https://example.com and http://test.example.org for more info",
|
|
||||||
expected: []string{"example.com", "test.example.org"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL with path",
|
|
||||||
text: "Check out https://example.com/path/to/resource",
|
|
||||||
expected: []string{"example.com"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL with port",
|
|
||||||
text: "Check out https://example.com:8080/path/to/resource",
|
|
||||||
expected: []string{"example.com"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL with subdomain",
|
|
||||||
text: "Check out https://sub.example.com",
|
|
||||||
expected: []string{"sub.example.com"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
domains := extractDomains(test.text)
|
|
||||||
|
|
||||||
if len(domains) != len(test.expected) {
|
|
||||||
t.Errorf("Expected %d domains, got %d", len(test.expected), len(domains))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, domain := range domains {
|
|
||||||
if domain != test.expected[i] {
|
|
||||||
t.Errorf("Expected domain %s, got %s", test.expected[i], domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOnMessage(t *testing.T) {
|
|
||||||
plugin := New()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
text string
|
|
||||||
blockedDomains string
|
|
||||||
expectBlocked bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "No blocked domains",
|
|
||||||
text: "Check out https://example.com",
|
|
||||||
blockedDomains: "",
|
|
||||||
expectBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No matching domain",
|
|
||||||
text: "Check out https://example.com",
|
|
||||||
blockedDomains: "bad.com, evil.org",
|
|
||||||
expectBlocked: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Matching domain",
|
|
||||||
text: "Check out https://example.com",
|
|
||||||
blockedDomains: "example.com, evil.org",
|
|
||||||
expectBlocked: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Matching subdomain",
|
|
||||||
text: "Check out https://sub.example.com",
|
|
||||||
blockedDomains: "example.com",
|
|
||||||
expectBlocked: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Multiple domains, one matching",
|
|
||||||
text: "Check out https://example.com and https://good.org",
|
|
||||||
blockedDomains: "bad.com, example.com",
|
|
||||||
expectBlocked: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Spaces in blocked domains list",
|
|
||||||
text: "Check out https://example.com",
|
|
||||||
blockedDomains: "bad.com, example.com , evil.org",
|
|
||||||
expectBlocked: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
msg := &model.Message{
|
|
||||||
Text: test.text,
|
|
||||||
Chat: "test-chat",
|
|
||||||
ID: "test-id",
|
|
||||||
Channel: &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config := map[string]interface{}{
|
|
||||||
"blocked_domains": test.blockedDomains,
|
|
||||||
}
|
|
||||||
|
|
||||||
mockCache := &testutil.MockCache{}
|
|
||||||
responses := plugin.OnMessage(msg, config, mockCache)
|
|
||||||
|
|
||||||
if test.expectBlocked {
|
|
||||||
if len(responses) == 0 {
|
|
||||||
t.Errorf("Expected message to be blocked, but it wasn't")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(responses) > 0 {
|
|
||||||
t.Errorf("Expected message not to be blocked, but it was")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -29,7 +29,7 @@ func NewCoin() *CoinPlugin {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
// OnMessage handles incoming messages
|
||||||
func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") {
|
if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -46,12 +46,5 @@ func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
action := &model.MessageAction{
|
return []*model.Message{response}
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ func NewDice() *DicePlugin {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
// OnMessage handles incoming messages
|
||||||
func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") {
|
if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -62,14 +62,7 @@ func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
action := &model.MessageAction{
|
return []*model.Message{response}
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// rollDice parses a dice formula string and returns the result
|
// rollDice parses a dice formula string and returns the result
|
||||||
|
@ -114,10 +107,9 @@ func (p *DicePlugin) rollDice(formula string) (int, error) {
|
||||||
return 0, fmt.Errorf("invalid modifier")
|
return 0, fmt.Errorf("invalid modifier")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch matches[3] {
|
if matches[3] == "+" {
|
||||||
case "+":
|
|
||||||
total += modifier
|
total += modifier
|
||||||
case "-":
|
} else if matches[3] == "-" {
|
||||||
total -= modifier
|
total -= modifier
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,394 +0,0 @@
|
||||||
package fun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// HLTBPlugin searches HowLongToBeat for game completion times
|
|
||||||
type HLTBPlugin struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// HLTBSearchRequest represents the search request payload
|
|
||||||
type HLTBSearchRequest struct {
|
|
||||||
SearchType string `json:"searchType"`
|
|
||||||
SearchTerms []string `json:"searchTerms"`
|
|
||||||
SearchPage int `json:"searchPage"`
|
|
||||||
Size int `json:"size"`
|
|
||||||
SearchOptions map[string]interface{} `json:"searchOptions"`
|
|
||||||
UseCache bool `json:"useCache"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// HLTBGame represents a game from HowLongToBeat
|
|
||||||
type HLTBGame struct {
|
|
||||||
ID int `json:"game_id"`
|
|
||||||
Name string `json:"game_name"`
|
|
||||||
GameAlias string `json:"game_alias"`
|
|
||||||
GameImage string `json:"game_image"`
|
|
||||||
CompMain int `json:"comp_main"`
|
|
||||||
CompPlus int `json:"comp_plus"`
|
|
||||||
CompComplete int `json:"comp_complete"`
|
|
||||||
CompAll int `json:"comp_all"`
|
|
||||||
InvestedCo int `json:"invested_co"`
|
|
||||||
InvestedMp int `json:"invested_mp"`
|
|
||||||
CountComp int `json:"count_comp"`
|
|
||||||
CountSpeedruns int `json:"count_speedruns"`
|
|
||||||
CountBacklog int `json:"count_backlog"`
|
|
||||||
CountReview int `json:"count_review"`
|
|
||||||
ReviewScore int `json:"review_score"`
|
|
||||||
CountPlaying int `json:"count_playing"`
|
|
||||||
CountRetired int `json:"count_retired"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// HLTBSearchResponse represents the search response
|
|
||||||
type HLTBSearchResponse struct {
|
|
||||||
Color string `json:"color"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Count int `json:"count"`
|
|
||||||
Pagecurrent int `json:"pagecurrent"`
|
|
||||||
Pagesize int `json:"pagesize"`
|
|
||||||
Pagetotal int `json:"pagetotal"`
|
|
||||||
SearchTerm string `json:"searchTerm"`
|
|
||||||
SearchResults []HLTBGame `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHLTB creates a new HLTBPlugin instance
|
|
||||||
func NewHLTB() *HLTBPlugin {
|
|
||||||
return &HLTBPlugin{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "fun.hltb",
|
|
||||||
Name: "How Long To Beat",
|
|
||||||
Help: "Get game completion times from HowLongToBeat.com using `!hltb <game name>`",
|
|
||||||
},
|
|
||||||
httpClient: &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Check if message starts with !hltb
|
|
||||||
text := strings.TrimSpace(msg.Text)
|
|
||||||
if !strings.HasPrefix(text, "!hltb ") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract game name
|
|
||||||
gameName := strings.TrimSpace(text[6:]) // Remove "!hltb "
|
|
||||||
if gameName == "" {
|
|
||||||
return p.createErrorResponse(msg, "Please provide a game name. Usage: !hltb <game name>")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache first
|
|
||||||
var games []HLTBGame
|
|
||||||
var err error
|
|
||||||
cacheKey := strings.ToLower(gameName)
|
|
||||||
|
|
||||||
err = cache.Get(cacheKey, &games)
|
|
||||||
if err != nil || len(games) == 0 {
|
|
||||||
// Cache miss - search for the game
|
|
||||||
games, err = p.searchGame(gameName)
|
|
||||||
if err != nil {
|
|
||||||
return p.createErrorResponse(msg, fmt.Sprintf("Error searching for game: %s", err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(games) == 0 {
|
|
||||||
return p.createErrorResponse(msg, fmt.Sprintf("No results found for '%s'", gameName))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache the results for 1 hour
|
|
||||||
err = cache.SetWithTTL(cacheKey, games, time.Hour)
|
|
||||||
if err != nil {
|
|
||||||
// Log cache error but don't fail the request
|
|
||||||
fmt.Printf("Warning: Failed to cache HLTB results: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the first result
|
|
||||||
game := games[0]
|
|
||||||
|
|
||||||
// Format the response
|
|
||||||
response := p.formatGameInfo(game)
|
|
||||||
|
|
||||||
// Create response message with game cover if available
|
|
||||||
responseMsg := &model.Message{
|
|
||||||
Text: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set parse mode for markdown formatting
|
|
||||||
if responseMsg.Raw == nil {
|
|
||||||
responseMsg.Raw = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
responseMsg.Raw["parse_mode"] = "Markdown"
|
|
||||||
|
|
||||||
// Add game cover as attachment if available
|
|
||||||
if game.GameImage != "" {
|
|
||||||
imageURL := p.getFullImageURL(game.GameImage)
|
|
||||||
responseMsg.Raw["image_url"] = imageURL
|
|
||||||
}
|
|
||||||
|
|
||||||
action := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: responseMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
|
||||||
|
|
||||||
// searchGame searches for a game on HowLongToBeat
|
|
||||||
func (p *HLTBPlugin) searchGame(gameName string) ([]HLTBGame, error) {
|
|
||||||
// Split search terms by words
|
|
||||||
searchTerms := strings.Fields(gameName)
|
|
||||||
|
|
||||||
// Prepare search request
|
|
||||||
searchRequest := HLTBSearchRequest{
|
|
||||||
SearchType: "games",
|
|
||||||
SearchTerms: searchTerms,
|
|
||||||
SearchPage: 1,
|
|
||||||
Size: 20,
|
|
||||||
SearchOptions: map[string]interface{}{
|
|
||||||
"games": map[string]interface{}{
|
|
||||||
"userId": 0,
|
|
||||||
"platform": "",
|
|
||||||
"sortCategory": "popular",
|
|
||||||
"rangeCategory": "main",
|
|
||||||
"rangeTime": map[string]interface{}{
|
|
||||||
"min": nil,
|
|
||||||
"max": nil,
|
|
||||||
},
|
|
||||||
"gameplay": map[string]interface{}{
|
|
||||||
"perspective": "",
|
|
||||||
"flow": "",
|
|
||||||
"genre": "",
|
|
||||||
"difficulty": "",
|
|
||||||
},
|
|
||||||
"rangeYear": map[string]interface{}{
|
|
||||||
"min": "",
|
|
||||||
"max": "",
|
|
||||||
},
|
|
||||||
"modifier": "",
|
|
||||||
},
|
|
||||||
"users": map[string]interface{}{
|
|
||||||
"sortCategory": "postcount",
|
|
||||||
},
|
|
||||||
"lists": map[string]interface{}{
|
|
||||||
"sortCategory": "follows",
|
|
||||||
},
|
|
||||||
"filter": "",
|
|
||||||
"sort": 0,
|
|
||||||
"randomizer": 0,
|
|
||||||
},
|
|
||||||
UseCache: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to JSON
|
|
||||||
jsonData, err := json.Marshal(searchRequest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal search request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The API endpoint appears to have changed to use dynamic tokens
|
|
||||||
// Try to get the seek token first, fallback to basic search
|
|
||||||
seekToken, err := p.getSeekToken()
|
|
||||||
if err != nil {
|
|
||||||
// Fallback to old endpoint
|
|
||||||
seekToken = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var apiURL string
|
|
||||||
if seekToken != "" {
|
|
||||||
apiURL = fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", seekToken)
|
|
||||||
} else {
|
|
||||||
apiURL = "https://howlongtobeat.com/api/search"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create HTTP request
|
|
||||||
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set headers to match the working curl request
|
|
||||||
req.Header.Set("Accept", "*/*")
|
|
||||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
|
||||||
req.Header.Set("Cache-Control", "no-cache")
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Origin", "https://howlongtobeat.com")
|
|
||||||
req.Header.Set("Pragma", "no-cache")
|
|
||||||
req.Header.Set("Referer", "https://howlongtobeat.com")
|
|
||||||
req.Header.Set("Sec-Fetch-Dest", "empty")
|
|
||||||
req.Header.Set("Sec-Fetch-Mode", "cors")
|
|
||||||
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
|
||||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
|
|
||||||
|
|
||||||
// Send request
|
|
||||||
resp, err := p.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("API returned status code: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read response body
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse response
|
|
||||||
var searchResponse HLTBSearchResponse
|
|
||||||
if err := json.Unmarshal(body, &searchResponse); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return searchResponse.SearchResults, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatGameInfo formats game information for display
|
|
||||||
func (p *HLTBPlugin) formatGameInfo(game HLTBGame) string {
|
|
||||||
var response strings.Builder
|
|
||||||
|
|
||||||
response.WriteString(fmt.Sprintf("🎮 **%s**\n\n", game.Name))
|
|
||||||
|
|
||||||
// Format completion times
|
|
||||||
if game.CompMain > 0 {
|
|
||||||
response.WriteString(fmt.Sprintf("📖 **Main Story:** %s\n", p.formatTime(game.CompMain)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if game.CompPlus > 0 {
|
|
||||||
response.WriteString(fmt.Sprintf("➕ **Main + Extras:** %s\n", p.formatTime(game.CompPlus)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if game.CompComplete > 0 {
|
|
||||||
response.WriteString(fmt.Sprintf("💯 **Completionist:** %s\n", p.formatTime(game.CompComplete)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if game.CompAll > 0 {
|
|
||||||
response.WriteString(fmt.Sprintf("🎯 **All Styles:** %s\n", p.formatTime(game.CompAll)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add review score if available
|
|
||||||
if game.ReviewScore > 0 {
|
|
||||||
response.WriteString(fmt.Sprintf("\n⭐ **User Score:** %d/100", game.ReviewScore))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add source attribution
|
|
||||||
response.WriteString("\n\n*Source: HowLongToBeat.com*")
|
|
||||||
|
|
||||||
return response.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatTime converts seconds to a readable time format
|
|
||||||
func (p *HLTBPlugin) formatTime(seconds int) string {
|
|
||||||
if seconds <= 0 {
|
|
||||||
return "N/A"
|
|
||||||
}
|
|
||||||
|
|
||||||
hours := float64(seconds) / 3600.0
|
|
||||||
|
|
||||||
if hours < 1 {
|
|
||||||
minutes := seconds / 60
|
|
||||||
return fmt.Sprintf("%d minutes", minutes)
|
|
||||||
} else if hours < 2 {
|
|
||||||
return fmt.Sprintf("%.1f hour", hours)
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%.1f hours", hours)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFullImageURL constructs the full image URL
|
|
||||||
func (p *HLTBPlugin) getFullImageURL(imagePath string) string {
|
|
||||||
if imagePath == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove leading slash if present
|
|
||||||
imagePath = strings.TrimPrefix(imagePath, "/")
|
|
||||||
|
|
||||||
return fmt.Sprintf("https://howlongtobeat.com/games/%s", imagePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSeekToken attempts to retrieve the seek token from HowLongToBeat
|
|
||||||
func (p *HLTBPlugin) getSeekToken() (string, error) {
|
|
||||||
// Try to extract the seek token from the main page
|
|
||||||
req, err := http.NewRequest("GET", "https://howlongtobeat.com", nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to create token request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36")
|
|
||||||
|
|
||||||
resp, err := p.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to fetch token: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to read token response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for patterns that might contain the token
|
|
||||||
patterns := []string{
|
|
||||||
`/api/seek/([a-f0-9]+)`,
|
|
||||||
`"seek/([a-f0-9]+)"`,
|
|
||||||
`seek/([a-f0-9]{12,})`,
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyStr := string(body)
|
|
||||||
for _, pattern := range patterns {
|
|
||||||
re := regexp.MustCompile(pattern)
|
|
||||||
matches := re.FindStringSubmatch(bodyStr)
|
|
||||||
if len(matches) > 1 {
|
|
||||||
return matches[1], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we can't extract a token, return the known working one as fallback
|
|
||||||
return "d4b2e330db04dbf3", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createErrorResponse creates an error response message
|
|
||||||
func (p *HLTBPlugin) createErrorResponse(msg *model.Message, errorText string) []*model.MessageAction {
|
|
||||||
response := &model.Message{
|
|
||||||
Text: fmt.Sprintf("❌ %s", errorText),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
action := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
|
|
@ -1,131 +0,0 @@
|
||||||
package fun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHLTBPlugin_OnMessage(t *testing.T) {
|
|
||||||
plugin := NewHLTB()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
messageText string
|
|
||||||
shouldRespond bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "responds to !hltb command",
|
|
||||||
messageText: "!hltb The Witcher 3",
|
|
||||||
shouldRespond: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores non-hltb messages",
|
|
||||||
messageText: "hello world",
|
|
||||||
shouldRespond: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores !hltb without game name",
|
|
||||||
messageText: "!hltb",
|
|
||||||
shouldRespond: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores !hltb with only spaces",
|
|
||||||
messageText: "!hltb ",
|
|
||||||
shouldRespond: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores similar but incorrect commands",
|
|
||||||
messageText: "hltb The Witcher 3",
|
|
||||||
shouldRespond: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
msg := &model.Message{
|
|
||||||
Text: tt.messageText,
|
|
||||||
Chat: "test-chat",
|
|
||||||
Channel: &model.Channel{ID: 1},
|
|
||||||
Author: "test-user",
|
|
||||||
}
|
|
||||||
|
|
||||||
mockCache := &testutil.MockCache{}
|
|
||||||
actions := plugin.OnMessage(msg, make(map[string]interface{}), mockCache)
|
|
||||||
|
|
||||||
if tt.shouldRespond && len(actions) == 0 {
|
|
||||||
t.Errorf("Expected plugin to respond to '%s', but it didn't", tt.messageText)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.shouldRespond && len(actions) > 0 {
|
|
||||||
t.Errorf("Expected plugin to not respond to '%s', but it did", tt.messageText)
|
|
||||||
}
|
|
||||||
|
|
||||||
// For messages that should respond, verify the response structure
|
|
||||||
if tt.shouldRespond && len(actions) > 0 {
|
|
||||||
action := actions[0]
|
|
||||||
if action.Type != model.ActionSendMessage {
|
|
||||||
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message == nil {
|
|
||||||
t.Error("Expected action to have a message")
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message != nil && action.Message.ReplyTo != msg.ID {
|
|
||||||
t.Error("Expected response to reply to original message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHLTBPlugin_formatTime(t *testing.T) {
|
|
||||||
plugin := NewHLTB()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
seconds int
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{0, "N/A"},
|
|
||||||
{-1, "N/A"},
|
|
||||||
{1800, "30 minutes"}, // 30 minutes
|
|
||||||
{3600, "1.0 hour"}, // 1 hour
|
|
||||||
{7200, "2.0 hours"}, // 2 hours
|
|
||||||
{10800, "3.0 hours"}, // 3 hours
|
|
||||||
{36000, "10.0 hours"}, // 10 hours
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.expected, func(t *testing.T) {
|
|
||||||
result := plugin.formatTime(tt.seconds)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("formatTime(%d) = %s, want %s", tt.seconds, result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHLTBPlugin_getFullImageURL(t *testing.T) {
|
|
||||||
plugin := NewHLTB()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
imagePath string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"", ""},
|
|
||||||
{"game.jpg", "https://howlongtobeat.com/games/game.jpg"},
|
|
||||||
{"/game.jpg", "https://howlongtobeat.com/games/game.jpg"},
|
|
||||||
{"folder/game.png", "https://howlongtobeat.com/games/folder/game.png"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.imagePath, func(t *testing.T) {
|
|
||||||
result := plugin.getFullImageURL(tt.imagePath)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("getFullImageURL(%s) = %s, want %s", tt.imagePath, result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -23,13 +23,8 @@ func NewLoquito() *LoquitoPlugin {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHelp returns the plugin help text
|
|
||||||
func (p *LoquitoPlugin) GetHelp() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
// OnMessage handles incoming messages
|
||||||
func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
if !strings.Contains(strings.ToLower(msg.Text), "lo quito") {
|
if !strings.Contains(strings.ToLower(msg.Text), "lo quito") {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -41,12 +36,5 @@ func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interfac
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
action := &model.MessageAction{
|
return []*model.Message{response}
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,166 +0,0 @@
|
||||||
package help
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
"golang.org/x/exp/slog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChannelPluginGetter is an interface for getting channel plugins
|
|
||||||
type ChannelPluginGetter interface {
|
|
||||||
GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error)
|
|
||||||
GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HelpPlugin provides help information about available commands
|
|
||||||
type HelpPlugin struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
db ChannelPluginGetter
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new HelpPlugin instance
|
|
||||||
func New(db ChannelPluginGetter) *HelpPlugin {
|
|
||||||
return &HelpPlugin{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "utility.help",
|
|
||||||
Name: "Help",
|
|
||||||
Help: "Shows available commands when you type '!help'",
|
|
||||||
},
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Check if message is the help command
|
|
||||||
if !strings.EqualFold(strings.TrimSpace(msg.Text), "!help") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get channel plugins from database using platform and platform channel ID
|
|
||||||
channelPlugins, err := p.db.GetChannelPluginsFromPlatformID(msg.Channel.Platform, msg.Channel.PlatformChannelID)
|
|
||||||
if err != nil && err != db.ErrNotFound {
|
|
||||||
slog.Error("Failed to get channel plugins", slog.Any("err", err))
|
|
||||||
return []*model.MessageAction{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no plugins found, initialize empty slice
|
|
||||||
if err == db.ErrNotFound {
|
|
||||||
channelPlugins = []*model.ChannelPlugin{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all available plugins
|
|
||||||
availablePlugins := plugin.GetAvailablePlugins()
|
|
||||||
|
|
||||||
// Filter to only enabled plugins for this channel
|
|
||||||
enabledPlugins := make(map[string]model.Plugin)
|
|
||||||
for _, channelPlugin := range channelPlugins {
|
|
||||||
if channelPlugin.Enabled {
|
|
||||||
if availablePlugin, exists := availablePlugins[channelPlugin.PluginID]; exists {
|
|
||||||
enabledPlugins[channelPlugin.PluginID] = availablePlugin
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no plugins are enabled, return a message
|
|
||||||
if len(enabledPlugins) == 0 {
|
|
||||||
response := &model.Message{
|
|
||||||
Text: "No plugins are currently enabled for this channel.",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Raw: map[string]interface{}{"parse_mode": "Markdown"},
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group plugins by category
|
|
||||||
categories := map[string][]model.Plugin{
|
|
||||||
"Development": {},
|
|
||||||
"Fun and Entertainment": {},
|
|
||||||
"Utility": {},
|
|
||||||
"Security": {},
|
|
||||||
"Social Media": {},
|
|
||||||
"Other": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Categorize plugins based on their ID prefix
|
|
||||||
for _, p := range enabledPlugins {
|
|
||||||
category := p.GetID()
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(category, "dev."):
|
|
||||||
categories["Development"] = append(categories["Development"], p)
|
|
||||||
case strings.HasPrefix(category, "fun."):
|
|
||||||
categories["Fun and Entertainment"] = append(categories["Fun and Entertainment"], p)
|
|
||||||
case strings.HasPrefix(category, "util.") || strings.HasPrefix(category, "reminder.") || strings.HasPrefix(category, "utility."):
|
|
||||||
categories["Utility"] = append(categories["Utility"], p)
|
|
||||||
case strings.HasPrefix(category, "security."):
|
|
||||||
categories["Security"] = append(categories["Security"], p)
|
|
||||||
case strings.HasPrefix(category, "social."):
|
|
||||||
categories["Social Media"] = append(categories["Social Media"], p)
|
|
||||||
default:
|
|
||||||
categories["Other"] = append(categories["Other"], p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the help message
|
|
||||||
var helpText strings.Builder
|
|
||||||
helpText.WriteString("🤖 **Available Commands**\n\n")
|
|
||||||
|
|
||||||
// Sort category names for consistent output
|
|
||||||
categoryOrder := []string{"Development", "Fun and Entertainment", "Utility", "Security", "Social Media", "Other"}
|
|
||||||
|
|
||||||
for _, categoryName := range categoryOrder {
|
|
||||||
pluginList := categories[categoryName]
|
|
||||||
if len(pluginList) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort plugins within category by name
|
|
||||||
sort.Slice(pluginList, func(i, j int) bool {
|
|
||||||
return pluginList[i].GetName() < pluginList[j].GetName()
|
|
||||||
})
|
|
||||||
|
|
||||||
helpText.WriteString(fmt.Sprintf("**%s:**\n", categoryName))
|
|
||||||
for _, p := range pluginList {
|
|
||||||
if p.GetHelp() == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
helpText.WriteString(fmt.Sprintf("• **%s** - %s\n", p.GetName(), p.GetHelp()))
|
|
||||||
}
|
|
||||||
helpText.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add footer
|
|
||||||
helpText.WriteString("_Use the specific commands or triggers mentioned above to interact with the bot._")
|
|
||||||
|
|
||||||
response := &model.Message{
|
|
||||||
Text: helpText.String(),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Raw: map[string]interface{}{"parse_mode": "Markdown"},
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,206 +0,0 @@
|
||||||
package help
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockPlugin implements the Plugin interface for testing
|
|
||||||
type MockPlugin struct {
|
|
||||||
id string
|
|
||||||
name string
|
|
||||||
help string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPlugin) GetID() string { return m.id }
|
|
||||||
func (m *MockPlugin) GetName() string { return m.name }
|
|
||||||
func (m *MockPlugin) GetHelp() string { return m.help }
|
|
||||||
func (m *MockPlugin) RequiresConfig() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
func (m *MockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockDatabase implements the ChannelPluginGetter interface for testing
|
|
||||||
type MockDatabase struct {
|
|
||||||
channelPlugins map[int64][]*model.ChannelPlugin
|
|
||||||
platformChannelPlugins map[string][]*model.ChannelPlugin // key: "platform:platformChannelID"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockDatabase) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) {
|
|
||||||
if plugins, exists := m.channelPlugins[channelID]; exists {
|
|
||||||
return plugins, nil
|
|
||||||
}
|
|
||||||
return nil, db.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockDatabase) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) {
|
|
||||||
key := platform + ":" + platformChannelID
|
|
||||||
if plugins, exists := m.platformChannelPlugins[key]; exists {
|
|
||||||
return plugins, nil
|
|
||||||
}
|
|
||||||
return nil, db.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHelpPlugin_OnMessage(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
messageText string
|
|
||||||
enabledPlugins map[string]*MockPlugin
|
|
||||||
expectResponse bool
|
|
||||||
expectNoPlugins bool
|
|
||||||
expectCategories []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "responds to !help command",
|
|
||||||
messageText: "!help",
|
|
||||||
enabledPlugins: map[string]*MockPlugin{
|
|
||||||
"dev.ping": {
|
|
||||||
id: "dev.ping",
|
|
||||||
name: "Ping",
|
|
||||||
help: "Responds to 'ping' with 'pong'",
|
|
||||||
},
|
|
||||||
"fun.dice": {
|
|
||||||
id: "fun.dice",
|
|
||||||
name: "Dice Roller",
|
|
||||||
help: "Rolls dice when you type '!dice [formula]'",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectCategories: []string{"Development", "Fun and Entertainment"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores non-help messages",
|
|
||||||
messageText: "hello world",
|
|
||||||
enabledPlugins: map[string]*MockPlugin{},
|
|
||||||
expectResponse: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ignores case variation",
|
|
||||||
messageText: "!HELP",
|
|
||||||
enabledPlugins: map[string]*MockPlugin{},
|
|
||||||
expectResponse: true,
|
|
||||||
expectNoPlugins: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "handles no enabled plugins",
|
|
||||||
messageText: "!help",
|
|
||||||
enabledPlugins: map[string]*MockPlugin{},
|
|
||||||
expectResponse: true,
|
|
||||||
expectNoPlugins: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Create mock database
|
|
||||||
mockDB := &MockDatabase{
|
|
||||||
channelPlugins: make(map[int64][]*model.ChannelPlugin),
|
|
||||||
platformChannelPlugins: make(map[string][]*model.ChannelPlugin),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup channel plugins in mock database
|
|
||||||
var channelPluginList []*model.ChannelPlugin
|
|
||||||
pluginCounter := int64(1)
|
|
||||||
for pluginID := range tt.enabledPlugins {
|
|
||||||
channelPluginList = append(channelPluginList, &model.ChannelPlugin{
|
|
||||||
ID: pluginCounter,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: pluginID,
|
|
||||||
Enabled: true,
|
|
||||||
Config: make(map[string]interface{}),
|
|
||||||
})
|
|
||||||
pluginCounter++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up both mapping approaches for the test
|
|
||||||
mockDB.channelPlugins[1] = channelPluginList
|
|
||||||
mockDB.platformChannelPlugins["test:test-channel"] = channelPluginList
|
|
||||||
|
|
||||||
// Create help plugin
|
|
||||||
p := New(mockDB)
|
|
||||||
|
|
||||||
// Create mock channel
|
|
||||||
channel := &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "test",
|
|
||||||
PlatformChannelID: "test-channel",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test message
|
|
||||||
msg := &model.Message{
|
|
||||||
ID: "test-msg",
|
|
||||||
Text: tt.messageText,
|
|
||||||
Chat: "test-chat",
|
|
||||||
Channel: channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock the plugin registry
|
|
||||||
originalRegistry := plugin.GetAvailablePlugins()
|
|
||||||
|
|
||||||
// Override the registry for this test
|
|
||||||
plugin.ClearRegistry()
|
|
||||||
for _, mockPlugin := range tt.enabledPlugins {
|
|
||||||
plugin.Register(mockPlugin)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call OnMessage
|
|
||||||
actions := p.OnMessage(msg, map[string]interface{}{}, nil)
|
|
||||||
|
|
||||||
// Restore original registry
|
|
||||||
plugin.ClearRegistry()
|
|
||||||
for _, p := range originalRegistry {
|
|
||||||
plugin.Register(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.expectResponse {
|
|
||||||
if len(actions) != 0 {
|
|
||||||
t.Errorf("Expected no response, but got %d actions", len(actions))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(actions) != 1 {
|
|
||||||
t.Errorf("Expected 1 action, got %d", len(actions))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
action := actions[0]
|
|
||||||
if action.Type != model.ActionSendMessage {
|
|
||||||
t.Errorf("Expected ActionSendMessage, got %v", action.Type)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
responseText := action.Message.Text
|
|
||||||
|
|
||||||
if tt.expectNoPlugins {
|
|
||||||
if !strings.Contains(responseText, "No plugins are currently enabled") {
|
|
||||||
t.Errorf("Expected 'no plugins' message, got: %s", responseText)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that expected categories appear in response
|
|
||||||
for _, category := range tt.expectCategories {
|
|
||||||
if !strings.Contains(responseText, "**"+category+":**") {
|
|
||||||
t.Errorf("Expected category '%s' in response, got: %s", category, responseText)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that plugin names and help text appear
|
|
||||||
for _, mockPlugin := range tt.enabledPlugins {
|
|
||||||
if !strings.Contains(responseText, mockPlugin.GetName()) {
|
|
||||||
t.Errorf("Expected plugin name '%s' in response", mockPlugin.GetName())
|
|
||||||
}
|
|
||||||
if !strings.Contains(responseText, mockPlugin.GetHelp()) {
|
|
||||||
t.Errorf("Expected plugin help '%s' in response", mockPlugin.GetHelp())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -24,12 +24,11 @@ func New() *PingPlugin {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
// OnMessage handles incoming messages
|
||||||
func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") {
|
if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the response message
|
|
||||||
response := &model.Message{
|
response := &model.Message{
|
||||||
Text: "pong",
|
Text: "pong",
|
||||||
Chat: msg.Chat,
|
Chat: msg.Chat,
|
||||||
|
@ -37,13 +36,5 @@ func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}
|
||||||
Channel: msg.Channel,
|
Channel: msg.Channel,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an action to send the message
|
return []*model.Message{response}
|
||||||
action := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package plugin
|
package plugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"maps"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
@ -42,31 +41,13 @@ func GetAvailablePlugins() map[string]model.Plugin {
|
||||||
|
|
||||||
// Create a copy to avoid race conditions
|
// Create a copy to avoid race conditions
|
||||||
result := make(map[string]model.Plugin, len(plugins))
|
result := make(map[string]model.Plugin, len(plugins))
|
||||||
maps.Copy(result, plugins)
|
for id, plugin := range plugins {
|
||||||
|
result[id] = plugin
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAvailablePluginIDs returns a slice of all registered plugin IDs
|
|
||||||
func GetAvailablePluginIDs() []string {
|
|
||||||
pluginsMu.RLock()
|
|
||||||
defer pluginsMu.RUnlock()
|
|
||||||
|
|
||||||
result := make([]string, 0, len(plugins))
|
|
||||||
for pluginID := range plugins {
|
|
||||||
result = append(result, pluginID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearRegistry clears all registered plugins (for testing)
|
|
||||||
func ClearRegistry() {
|
|
||||||
pluginsMu.Lock()
|
|
||||||
defer pluginsMu.Unlock()
|
|
||||||
plugins = make(map[string]model.Plugin)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BasePlugin provides a common base for plugins
|
// BasePlugin provides a common base for plugins
|
||||||
type BasePlugin struct {
|
type BasePlugin struct {
|
||||||
ID string
|
ID string
|
||||||
|
@ -96,6 +77,6 @@ func (p *BasePlugin) RequiresConfig() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnMessage is the default implementation that does nothing
|
// OnMessage is the default implementation that does nothing
|
||||||
func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,331 +0,0 @@
|
||||||
package plugin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Mock plugin for testing
|
|
||||||
type testPlugin struct {
|
|
||||||
BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: &model.Message{
|
|
||||||
Text: "test response",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetAvailablePluginIDs(t *testing.T) {
|
|
||||||
// Clear registry before test
|
|
||||||
ClearRegistry()
|
|
||||||
|
|
||||||
// Register test plugins
|
|
||||||
testPlugin1 := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "test.plugin1",
|
|
||||||
Name: "Test Plugin 1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testPlugin2 := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "test.plugin2",
|
|
||||||
Name: "Test Plugin 2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
Register(testPlugin1)
|
|
||||||
Register(testPlugin2)
|
|
||||||
|
|
||||||
// Test GetAvailablePluginIDs
|
|
||||||
pluginIDs := GetAvailablePluginIDs()
|
|
||||||
|
|
||||||
if len(pluginIDs) != 2 {
|
|
||||||
t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that both plugin IDs are present
|
|
||||||
found1, found2 := false, false
|
|
||||||
for _, id := range pluginIDs {
|
|
||||||
if id == "test.plugin1" {
|
|
||||||
found1 = true
|
|
||||||
}
|
|
||||||
if id == "test.plugin2" {
|
|
||||||
found2 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found1 {
|
|
||||||
t.Errorf("Expected to find test.plugin1 in plugin IDs")
|
|
||||||
}
|
|
||||||
if !found2 {
|
|
||||||
t.Errorf("Expected to find test.plugin2 in plugin IDs")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnableAllPluginsProcessingLogic(t *testing.T) {
|
|
||||||
// Clear registry before test
|
|
||||||
ClearRegistry()
|
|
||||||
|
|
||||||
// Register test plugins
|
|
||||||
testPlugin1 := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "ping",
|
|
||||||
Name: "Ping Plugin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testPlugin2 := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "echo",
|
|
||||||
Name: "Echo Plugin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testPlugin3 := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "help",
|
|
||||||
Name: "Help Plugin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
Register(testPlugin1)
|
|
||||||
Register(testPlugin2)
|
|
||||||
Register(testPlugin3)
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) {
|
|
||||||
// Create a channel with EnableAllPlugins = false and only some plugins enabled
|
|
||||||
channel := &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: false,
|
|
||||||
Plugins: map[string]*model.ChannelPlugin{
|
|
||||||
"ping": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "ping",
|
|
||||||
Enabled: true,
|
|
||||||
Config: map[string]interface{}{"key": "value"},
|
|
||||||
},
|
|
||||||
"echo": {
|
|
||||||
ID: 2,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "echo",
|
|
||||||
Enabled: false, // Disabled
|
|
||||||
Config: map[string]interface{}{},
|
|
||||||
},
|
|
||||||
// help plugin not configured
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the plugin processing logic from handleMessage
|
|
||||||
var pluginsToProcess []string
|
|
||||||
|
|
||||||
if channel.EnableAllPlugins {
|
|
||||||
pluginsToProcess = GetAvailablePluginIDs()
|
|
||||||
} else {
|
|
||||||
for pluginID := range channel.Plugins {
|
|
||||||
if channel.HasEnabledPlugin(pluginID) {
|
|
||||||
pluginsToProcess = append(pluginsToProcess, pluginID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should only have "ping" since echo is disabled and help is not configured
|
|
||||||
if len(pluginsToProcess) != 1 {
|
|
||||||
t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" {
|
|
||||||
t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) {
|
|
||||||
// Create a channel with EnableAllPlugins = true
|
|
||||||
channel := &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: map[string]*model.ChannelPlugin{
|
|
||||||
"ping": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "ping",
|
|
||||||
Enabled: true,
|
|
||||||
Config: map[string]interface{}{"key": "value"},
|
|
||||||
},
|
|
||||||
"echo": {
|
|
||||||
ID: 2,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "echo",
|
|
||||||
Enabled: false, // Disabled, but should still be processed
|
|
||||||
Config: map[string]interface{}{},
|
|
||||||
},
|
|
||||||
// help plugin not configured, but should still be processed
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate the plugin processing logic from handleMessage
|
|
||||||
var pluginsToProcess []string
|
|
||||||
|
|
||||||
if channel.EnableAllPlugins {
|
|
||||||
pluginsToProcess = GetAvailablePluginIDs()
|
|
||||||
} else {
|
|
||||||
for pluginID := range channel.Plugins {
|
|
||||||
if channel.HasEnabledPlugin(pluginID) {
|
|
||||||
pluginsToProcess = append(pluginsToProcess, pluginID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should have all 3 registered plugins
|
|
||||||
if len(pluginsToProcess) != 3 {
|
|
||||||
t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that all plugins are included
|
|
||||||
expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false}
|
|
||||||
for _, pluginID := range pluginsToProcess {
|
|
||||||
if _, exists := expectedPlugins[pluginID]; exists {
|
|
||||||
expectedPlugins[pluginID] = true
|
|
||||||
} else {
|
|
||||||
t.Errorf("Unexpected plugin in processing list: %s", pluginID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for pluginID, found := range expectedPlugins {
|
|
||||||
if !found {
|
|
||||||
t.Errorf("Expected plugin %s to be in processing list", pluginID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Plugin configuration handling", func(t *testing.T) {
|
|
||||||
// Test the configuration logic from handleMessage
|
|
||||||
channel := &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "123456",
|
|
||||||
Enabled: true,
|
|
||||||
EnableAllPlugins: true,
|
|
||||||
Plugins: map[string]*model.ChannelPlugin{
|
|
||||||
"ping": {
|
|
||||||
ID: 1,
|
|
||||||
ChannelID: 1,
|
|
||||||
PluginID: "ping",
|
|
||||||
Enabled: true,
|
|
||||||
Config: map[string]interface{}{"configured": "value"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
pluginID string
|
|
||||||
expectedConfig map[string]interface{}
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
pluginID: "ping",
|
|
||||||
expectedConfig: map[string]interface{}{"configured": "value"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
pluginID: "echo", // Not explicitly configured
|
|
||||||
expectedConfig: map[string]interface{}{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
// Simulate the config retrieval logic from handleMessage
|
|
||||||
var config map[string]interface{}
|
|
||||||
if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists {
|
|
||||||
config = channelPlugin.Config
|
|
||||||
} else {
|
|
||||||
config = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(config) != len(tc.expectedConfig) {
|
|
||||||
t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config))
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, expectedValue := range tc.expectedConfig {
|
|
||||||
if actualValue, exists := config[key]; !exists || actualValue != expectedValue {
|
|
||||||
t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPluginRegistry(t *testing.T) {
|
|
||||||
// Clear registry before test
|
|
||||||
ClearRegistry()
|
|
||||||
|
|
||||||
testPlugin := &testPlugin{
|
|
||||||
BasePlugin: BasePlugin{
|
|
||||||
ID: "test.registry",
|
|
||||||
Name: "Test Registry Plugin",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("Register and Get plugin", func(t *testing.T) {
|
|
||||||
Register(testPlugin)
|
|
||||||
|
|
||||||
retrieved, err := Get("test.registry")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to get registered plugin: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieved.GetID() != "test.registry" {
|
|
||||||
t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Get nonexistent plugin", func(t *testing.T) {
|
|
||||||
_, err := Get("nonexistent.plugin")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error when getting nonexistent plugin, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != model.ErrPluginNotFound {
|
|
||||||
t.Errorf("Expected ErrPluginNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("GetAvailablePlugins", func(t *testing.T) {
|
|
||||||
plugins := GetAvailablePlugins()
|
|
||||||
|
|
||||||
if len(plugins) != 1 {
|
|
||||||
t.Errorf("Expected 1 plugin in registry, got %d", len(plugins))
|
|
||||||
}
|
|
||||||
|
|
||||||
if plugin, exists := plugins["test.registry"]; !exists {
|
|
||||||
t.Errorf("Expected to find test.registry in available plugins")
|
|
||||||
} else if plugin.GetID() != "test.registry" {
|
|
||||||
t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ClearRegistry", func(t *testing.T) {
|
|
||||||
ClearRegistry()
|
|
||||||
|
|
||||||
plugins := GetAvailablePlugins()
|
|
||||||
if len(plugins) != 0 {
|
|
||||||
t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := Get("test.registry")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error when getting plugin after clearing registry, got nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,200 +0,0 @@
|
||||||
package reminder
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Duration regex patterns to match reminders
|
|
||||||
var (
|
|
||||||
remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// ReminderCreator is an interface for creating reminders
|
|
||||||
type ReminderCreator interface {
|
|
||||||
CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reminder is a plugin that sets reminders for messages
|
|
||||||
type Reminder struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
creator ReminderCreator
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new Reminder plugin
|
|
||||||
func New(creator ReminderCreator) *Reminder {
|
|
||||||
return &Reminder{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "reminder.remindme",
|
|
||||||
Name: "Remind Me",
|
|
||||||
Help: "Reply to a message with `!remindme <duration>` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).",
|
|
||||||
ConfigRequired: false,
|
|
||||||
},
|
|
||||||
creator: creator,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage processes incoming messages
|
|
||||||
func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Only process replies to messages
|
|
||||||
if msg.ReplyTo == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the message is a reminder command
|
|
||||||
match := remindMePattern.FindStringSubmatch(msg.Text)
|
|
||||||
if match == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the duration
|
|
||||||
amount, err := strconv.Atoi(match[1])
|
|
||||||
if err != nil {
|
|
||||||
errorMsg := &model.Message{
|
|
||||||
Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Author: "bot",
|
|
||||||
FromBot: true,
|
|
||||||
Date: time.Now(),
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: errorMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the trigger time
|
|
||||||
var duration time.Duration
|
|
||||||
unit := match[2]
|
|
||||||
switch strings.ToLower(unit) {
|
|
||||||
case "y":
|
|
||||||
duration = time.Duration(amount) * 365 * 24 * time.Hour
|
|
||||||
case "mo":
|
|
||||||
duration = time.Duration(amount) * 30 * 24 * time.Hour
|
|
||||||
case "d":
|
|
||||||
duration = time.Duration(amount) * 24 * time.Hour
|
|
||||||
case "h":
|
|
||||||
duration = time.Duration(amount) * time.Hour
|
|
||||||
case "m":
|
|
||||||
duration = time.Duration(amount) * time.Minute
|
|
||||||
case "s":
|
|
||||||
duration = time.Duration(amount) * time.Second
|
|
||||||
default:
|
|
||||||
errorMsg := &model.Message{
|
|
||||||
Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Author: "bot",
|
|
||||||
FromBot: true,
|
|
||||||
Date: time.Now(),
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: errorMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
triggerAt := time.Now().Add(duration)
|
|
||||||
|
|
||||||
// Determine the username for the reminder
|
|
||||||
username := msg.Author
|
|
||||||
if username == "" {
|
|
||||||
// Try to extract username from message raw data
|
|
||||||
if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok {
|
|
||||||
if name, ok := authorData["username"].(string); ok {
|
|
||||||
username = name
|
|
||||||
} else if name, ok := authorData["name"].(string); ok {
|
|
||||||
username = name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the reminder
|
|
||||||
_, err = r.creator.CreateReminder(
|
|
||||||
msg.Channel.Platform,
|
|
||||||
msg.Chat,
|
|
||||||
msg.ID,
|
|
||||||
msg.ReplyTo,
|
|
||||||
msg.Author,
|
|
||||||
username,
|
|
||||||
"", // No additional content for now
|
|
||||||
triggerAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
errorMsg := &model.Message{
|
|
||||||
Text: fmt.Sprintf("Failed to create reminder: %v", err),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Author: "bot",
|
|
||||||
FromBot: true,
|
|
||||||
Date: time.Now(),
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: errorMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format the acknowledgment message
|
|
||||||
var confirmText string
|
|
||||||
switch strings.ToLower(unit) {
|
|
||||||
case "y":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04"))
|
|
||||||
case "mo":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
|
||||||
case "d":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
|
||||||
case "h":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04"))
|
|
||||||
case "m":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04"))
|
|
||||||
case "s":
|
|
||||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create confirmation message
|
|
||||||
confirmMsg := &model.Message{
|
|
||||||
Text: confirmText,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
Author: "bot",
|
|
||||||
FromBot: true,
|
|
||||||
Date: time.Now(),
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: confirmMsg,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,177 +0,0 @@
|
||||||
package reminder
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockCreator is a mock implementation of ReminderCreator for testing
|
|
||||||
type MockCreator struct {
|
|
||||||
reminders []*model.Reminder
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
|
||||||
reminder := &model.Reminder{
|
|
||||||
ID: int64(len(m.reminders) + 1),
|
|
||||||
Platform: platform,
|
|
||||||
ChannelID: channelID,
|
|
||||||
MessageID: messageID,
|
|
||||||
ReplyToID: replyToID,
|
|
||||||
UserID: userID,
|
|
||||||
Username: username,
|
|
||||||
Content: content,
|
|
||||||
TriggerAt: triggerAt,
|
|
||||||
}
|
|
||||||
m.reminders = append(m.reminders, reminder)
|
|
||||||
return reminder, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReminderOnMessage(t *testing.T) {
|
|
||||||
creator := &MockCreator{reminders: make([]*model.Reminder, 0)}
|
|
||||||
plugin := New(creator)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
message *model.Message
|
|
||||||
expectResponse bool
|
|
||||||
expectReminder bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - years",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 1y",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - months",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 3mo",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - days",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 2d",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - hours",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 5h",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - minutes",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 30m",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid reminder command - seconds",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 60s",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: true,
|
|
||||||
expectReminder: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Not a reply",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme 2d",
|
|
||||||
ReplyTo: "",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: false,
|
|
||||||
expectReminder: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Not a reminder command",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "hello world",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: false,
|
|
||||||
expectReminder: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid duration format",
|
|
||||||
message: &model.Message{
|
|
||||||
Text: "!remindme abc",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Author: "testuser",
|
|
||||||
Channel: &model.Channel{Platform: "test"},
|
|
||||||
},
|
|
||||||
expectResponse: false,
|
|
||||||
expectReminder: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
initialCount := len(creator.reminders)
|
|
||||||
mockCache := &testutil.MockCache{}
|
|
||||||
actions := plugin.OnMessage(tt.message, nil, mockCache)
|
|
||||||
|
|
||||||
if tt.expectResponse && len(actions) == 0 {
|
|
||||||
t.Errorf("Expected response action, but got none")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.expectResponse && len(actions) > 0 {
|
|
||||||
t.Errorf("Expected no actions, but got %d", len(actions))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify action type is correct when actions are returned
|
|
||||||
if len(actions) > 0 {
|
|
||||||
if actions[0].Type != model.ActionSendMessage {
|
|
||||||
t.Errorf("Expected action type to be %s, but got %s", model.ActionSendMessage, actions[0].Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if actions[0].Message == nil {
|
|
||||||
t.Errorf("Expected message in action to not be nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.expectReminder && len(creator.reminders) != initialCount+1 {
|
|
||||||
t.Errorf("Expected reminder to be created, but it wasn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.expectReminder && len(creator.reminders) != initialCount {
|
|
||||||
t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
# Search and Replace Plugin
|
|
||||||
|
|
||||||
This plugin allows users to perform search and replace operations on messages by replying to a message with a search/replace command.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To use the plugin, reply to any message with a command in the following format:
|
|
||||||
|
|
||||||
```
|
|
||||||
s/search/replace/[flags]
|
|
||||||
```
|
|
||||||
|
|
||||||
Where:
|
|
||||||
- `search` is the text you want to find (case-sensitive by default)
|
|
||||||
- `replace` is the text you want to substitute in place of the search term
|
|
||||||
- `flags` (optional) control the behavior of the replacement
|
|
||||||
|
|
||||||
### Supported Flags
|
|
||||||
|
|
||||||
- `g` - Global: Replace all occurrences of the search term (without this flag, only the first occurrence is replaced)
|
|
||||||
- `i` - Case insensitive: Match regardless of case
|
|
||||||
- `n` - Treat search pattern as a regular expression (advanced users)
|
|
||||||
|
|
||||||
### Examples
|
|
||||||
|
|
||||||
1. Basic replacement (replaces first occurrence):
|
|
||||||
```
|
|
||||||
s/hello/hi/
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Global replacement (replaces all occurrences):
|
|
||||||
```
|
|
||||||
s/hello/hi/g
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Case-insensitive replacement:
|
|
||||||
```
|
|
||||||
s/Hello/hi/i
|
|
||||||
```
|
|
||||||
|
|
||||||
4. Combined flags (global and case-insensitive):
|
|
||||||
```
|
|
||||||
s/hello/hi/gi
|
|
||||||
```
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- The plugin can only access the text content of the original message
|
|
||||||
- Regular expression support is available with the `n` flag, but should be used carefully as invalid regex patterns will cause errors
|
|
||||||
- The plugin does not modify the original message; it creates a new message with the replaced text
|
|
|
@ -1,182 +0,0 @@
|
||||||
package searchreplace
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Regex pattern for search and replace operations: s/search/replace/[flags]
|
|
||||||
var searchReplacePattern = regexp.MustCompile(`^s/([^/]*)/([^/]*)(?:/([gimnsuy]*))?$`)
|
|
||||||
|
|
||||||
// SearchReplacePlugin is a plugin for performing search and replace operations on messages
|
|
||||||
type SearchReplacePlugin struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new SearchReplacePlugin instance
|
|
||||||
func New() *SearchReplacePlugin {
|
|
||||||
return &SearchReplacePlugin{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "util.searchreplace",
|
|
||||||
Name: "Search and Replace",
|
|
||||||
Help: "Reply to a message with a search and replace pattern (`s/search/replace/[flags]`) to create a modified message. " +
|
|
||||||
"Supported flags: g (global), i (case insensitive)",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Only process replies to messages
|
|
||||||
if msg.ReplyTo == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the message matches the search/replace pattern
|
|
||||||
match := searchReplacePattern.FindStringSubmatch(strings.TrimSpace(msg.Text))
|
|
||||||
if match == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the original message text from the reply_to_message structure in Telegram messages
|
|
||||||
var originalText string
|
|
||||||
|
|
||||||
// For Telegram messages
|
|
||||||
if msgData, ok := msg.Raw["message"].(map[string]interface{}); ok {
|
|
||||||
if replyMsg, ok := msgData["reply_to_message"].(map[string]interface{}); ok {
|
|
||||||
if text, ok := replyMsg["text"].(string); ok {
|
|
||||||
originalText = text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generic fallback for other platforms or if the above method fails
|
|
||||||
if originalText == "" && msg.Raw["original_message"] != nil {
|
|
||||||
if original, ok := msg.Raw["original_message"].(map[string]interface{}); ok {
|
|
||||||
if text, ok := original["text"].(string); ok {
|
|
||||||
originalText = text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if originalText == "" {
|
|
||||||
// If we couldn't find the original message text, inform the user
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: &model.Message{
|
|
||||||
Text: "Sorry, I couldn't find the original message text to perform the replacement.",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
},
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract search pattern, replacement and flags
|
|
||||||
searchPattern := match[1]
|
|
||||||
replacement := match[2]
|
|
||||||
flags := ""
|
|
||||||
if len(match) > 3 {
|
|
||||||
flags = match[3]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the replacement
|
|
||||||
result, err := p.performReplacement(originalText, searchPattern, replacement, flags)
|
|
||||||
if err != nil {
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: &model.Message{
|
|
||||||
Text: fmt.Sprintf("Error performing replacement: %s", err.Error()),
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
},
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only send a response if the text actually changed
|
|
||||||
if result == originalText {
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: &model.Message{
|
|
||||||
Text: "No changes were made to the original message.",
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
},
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a response with the modified text
|
|
||||||
return []*model.MessageAction{
|
|
||||||
{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: &model.Message{
|
|
||||||
Text: result,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
ReplyTo: msg.ReplyTo, // Reply to the original message
|
|
||||||
},
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// performReplacement performs the search and replace operation on the given text
|
|
||||||
func (p *SearchReplacePlugin) performReplacement(text, search, replace, flags string) (string, error) {
|
|
||||||
// Process flags
|
|
||||||
globalReplace := strings.Contains(flags, "g")
|
|
||||||
caseInsensitive := strings.Contains(flags, "i")
|
|
||||||
|
|
||||||
// Create the regex pattern
|
|
||||||
pattern := search
|
|
||||||
regexFlags := ""
|
|
||||||
if caseInsensitive {
|
|
||||||
regexFlags += "(?i)"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Escape special characters if we're not in a regular expression
|
|
||||||
if !strings.Contains(flags, "n") {
|
|
||||||
pattern = regexp.QuoteMeta(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile the regex
|
|
||||||
reg, err := regexp.Compile(regexFlags + pattern)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("invalid search pattern: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform the replacement
|
|
||||||
var result string
|
|
||||||
if globalReplace {
|
|
||||||
result = reg.ReplaceAllString(text, replace)
|
|
||||||
} else {
|
|
||||||
// For non-global replace, only replace the first occurrence
|
|
||||||
indices := reg.FindStringIndex(text)
|
|
||||||
if indices == nil {
|
|
||||||
// No match found
|
|
||||||
return text, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result = text[:indices[0]] + replace + text[indices[1]:]
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
|
@ -1,218 +0,0 @@
|
||||||
package searchreplace
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/testutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSearchReplace(t *testing.T) {
|
|
||||||
// Create plugin instance
|
|
||||||
p := New()
|
|
||||||
|
|
||||||
// Test cases
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
command string
|
|
||||||
originalText string
|
|
||||||
expectedResult string
|
|
||||||
expectActions bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Simple replacement",
|
|
||||||
command: "s/hello/world/",
|
|
||||||
originalText: "hello everyone",
|
|
||||||
expectedResult: "world everyone",
|
|
||||||
expectActions: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Case-insensitive replacement",
|
|
||||||
command: "s/HELLO/world/i",
|
|
||||||
originalText: "Hello everyone",
|
|
||||||
expectedResult: "world everyone",
|
|
||||||
expectActions: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Global replacement",
|
|
||||||
command: "s/a/X/g",
|
|
||||||
originalText: "banana",
|
|
||||||
expectedResult: "bXnXnX",
|
|
||||||
expectActions: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No change",
|
|
||||||
command: "s/nothing/something/",
|
|
||||||
originalText: "test message",
|
|
||||||
expectedResult: "test message",
|
|
||||||
expectActions: true, // We send a "no changes" message
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Not a search/replace command",
|
|
||||||
command: "hello",
|
|
||||||
originalText: "test message",
|
|
||||||
expectedResult: "",
|
|
||||||
expectActions: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid pattern",
|
|
||||||
command: "s/(/)/",
|
|
||||||
originalText: "test message",
|
|
||||||
expectedResult: "error",
|
|
||||||
expectActions: true, // We send an error message
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
// Create message
|
|
||||||
msg := &model.Message{
|
|
||||||
Text: tc.command,
|
|
||||||
Chat: "test-chat",
|
|
||||||
ReplyTo: "original-message-id",
|
|
||||||
Date: time.Now(),
|
|
||||||
Channel: &model.Channel{
|
|
||||||
Platform: "test",
|
|
||||||
},
|
|
||||||
Raw: map[string]interface{}{
|
|
||||||
"message": map[string]interface{}{
|
|
||||||
"reply_to_message": map[string]interface{}{
|
|
||||||
"text": tc.originalText,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process message
|
|
||||||
mockCache := &testutil.MockCache{}
|
|
||||||
actions := p.OnMessage(msg, nil, mockCache)
|
|
||||||
|
|
||||||
// Check results
|
|
||||||
if tc.expectActions {
|
|
||||||
if len(actions) == 0 {
|
|
||||||
t.Fatalf("Expected actions but got none")
|
|
||||||
}
|
|
||||||
|
|
||||||
action := actions[0]
|
|
||||||
if action.Type != model.ActionSendMessage {
|
|
||||||
t.Fatalf("Expected send message action but got %v", action.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.expectedResult == "error" {
|
|
||||||
// Just checking that we got an error message
|
|
||||||
if action.Message == nil || action.Message.Text == "" {
|
|
||||||
t.Fatalf("Expected error message but got empty message")
|
|
||||||
}
|
|
||||||
} else if tc.originalText == tc.expectedResult {
|
|
||||||
// Check if we got the "no changes" message
|
|
||||||
if action.Message == nil || action.Message.Text != "No changes were made to the original message." {
|
|
||||||
t.Fatalf("Expected 'no changes' message but got: %s", action.Message.Text)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Check actual replacement result
|
|
||||||
if action.Message == nil || action.Message.Text != tc.expectedResult {
|
|
||||||
t.Fatalf("Expected result: %s, got: %s", tc.expectedResult, action.Message.Text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if len(actions) > 0 {
|
|
||||||
t.Fatalf("Expected no actions but got %d", len(actions))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPerformReplacement(t *testing.T) {
|
|
||||||
p := New()
|
|
||||||
|
|
||||||
// Test cases for the performReplacement function
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
text string
|
|
||||||
search string
|
|
||||||
replace string
|
|
||||||
flags string
|
|
||||||
expected string
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Simple replacement",
|
|
||||||
text: "Hello World",
|
|
||||||
search: "Hello",
|
|
||||||
replace: "Hi",
|
|
||||||
flags: "",
|
|
||||||
expected: "Hi World",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Case insensitive",
|
|
||||||
text: "Hello World",
|
|
||||||
search: "hello",
|
|
||||||
replace: "Hi",
|
|
||||||
flags: "i",
|
|
||||||
expected: "Hi World",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Global replacement",
|
|
||||||
text: "one two one two",
|
|
||||||
search: "one",
|
|
||||||
replace: "1",
|
|
||||||
flags: "g",
|
|
||||||
expected: "1 two 1 two",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No match",
|
|
||||||
text: "Hello World",
|
|
||||||
search: "Goodbye",
|
|
||||||
replace: "Hi",
|
|
||||||
flags: "",
|
|
||||||
expected: "Hello World",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid regex",
|
|
||||||
text: "Hello World",
|
|
||||||
search: "(",
|
|
||||||
replace: "Hi",
|
|
||||||
flags: "n", // treat as regex
|
|
||||||
expected: "",
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Escape special chars by default",
|
|
||||||
text: "Hello (World)",
|
|
||||||
search: "(World)",
|
|
||||||
replace: "[Earth]",
|
|
||||||
flags: "",
|
|
||||||
expected: "Hello [Earth]",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Regex mode with n flag",
|
|
||||||
text: "Hello (World)",
|
|
||||||
search: "\\(World\\)",
|
|
||||||
replace: "[Earth]",
|
|
||||||
flags: "n",
|
|
||||||
expected: "Hello [Earth]",
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result, err := p.performReplacement(tc.text, tc.search, tc.replace, tc.flags)
|
|
||||||
|
|
||||||
if tc.expectErr {
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Expected error but got none")
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
} else if result != tc.expected {
|
|
||||||
t.Fatalf("Expected result: %s, got: %s", tc.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,92 +0,0 @@
|
||||||
package social
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// InstagramExpander transforms instagram.com links to ddinstagram.com links
|
|
||||||
type InstagramExpander struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new InstagramExpander instance
|
|
||||||
func NewInstagramExpander() *InstagramExpander {
|
|
||||||
return &InstagramExpander{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "social.instagram",
|
|
||||||
Name: "Instagram Link Expander",
|
|
||||||
Help: "Automatically converts instagram.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: ddinstagram.com)",
|
|
||||||
ConfigRequired: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Skip empty messages
|
|
||||||
if strings.TrimSpace(msg.Text) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get replacement domain from config, default to ddinstagram.com
|
|
||||||
replacementDomain := "ddinstagram.com"
|
|
||||||
if domain, ok := config["domain"].(string); ok && domain != "" {
|
|
||||||
replacementDomain = domain
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regex to match instagram.com links
|
|
||||||
// Match both http://instagram.com and https://instagram.com formats
|
|
||||||
// Also match www.instagram.com
|
|
||||||
instagramRegex := regexp.MustCompile(`https?://(www\.)?(instagram\.com)/[^\s]+`)
|
|
||||||
|
|
||||||
// Check if the message contains an Instagram link
|
|
||||||
if !instagramRegex.MatchString(msg.Text) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace instagram.com with configured domain in the message and clean query parameters
|
|
||||||
transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
|
|
||||||
// Parse the URL
|
|
||||||
parsedURL, err := url.Parse(link)
|
|
||||||
if err != nil {
|
|
||||||
// If parsing fails, just do the simple replacement
|
|
||||||
return link
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we don't change links that already come from the replacement domain
|
|
||||||
if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" {
|
|
||||||
return link
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the host to the configured domain
|
|
||||||
parsedURL.Host = replacementDomain
|
|
||||||
|
|
||||||
// Remove query parameters
|
|
||||||
parsedURL.RawQuery = ""
|
|
||||||
|
|
||||||
// Return the cleaned URL
|
|
||||||
return parsedURL.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create response message
|
|
||||||
response := &model.Message{
|
|
||||||
Text: transformed,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
action := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
|
|
@ -1,88 +0,0 @@
|
||||||
package social
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TwitterExpander transforms twitter.com links to fxtwitter.com links
|
|
||||||
type TwitterExpander struct {
|
|
||||||
plugin.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new TwitterExpander instance
|
|
||||||
func NewTwitterExpander() *TwitterExpander {
|
|
||||||
return &TwitterExpander{
|
|
||||||
BasePlugin: plugin.BasePlugin{
|
|
||||||
ID: "social.twitter",
|
|
||||||
Name: "Twitter Link Expander",
|
|
||||||
Help: "Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: fxtwitter.com)",
|
|
||||||
ConfigRequired: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnMessage handles incoming messages
|
|
||||||
func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction {
|
|
||||||
// Skip empty messages
|
|
||||||
if strings.TrimSpace(msg.Text) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get replacement domain from config, default to fxtwitter.com
|
|
||||||
replacementDomain := "fxtwitter.com"
|
|
||||||
if domain, ok := config["domain"].(string); ok && domain != "" {
|
|
||||||
replacementDomain = domain
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regex to match twitter.com links
|
|
||||||
// Match both http://twitter.com and https://twitter.com formats
|
|
||||||
// Also match www.twitter.com
|
|
||||||
twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`)
|
|
||||||
|
|
||||||
// Check if the message contains a Twitter link
|
|
||||||
if !twitterRegex.MatchString(msg.Text) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace twitter.com/x.com with configured domain in the message and clean query parameters
|
|
||||||
transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string {
|
|
||||||
// Parse the URL
|
|
||||||
parsedURL, err := url.Parse(link)
|
|
||||||
if err != nil {
|
|
||||||
return link
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the host to the configured domain
|
|
||||||
if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") {
|
|
||||||
parsedURL.Host = replacementDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove query parameters
|
|
||||||
parsedURL.RawQuery = ""
|
|
||||||
|
|
||||||
// Return the cleaned URL
|
|
||||||
return parsedURL.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create response message
|
|
||||||
response := &model.Message{
|
|
||||||
Text: transformed,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
ReplyTo: msg.ID,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
action := &model.MessageAction{
|
|
||||||
Type: model.ActionSendMessage,
|
|
||||||
Message: response,
|
|
||||||
Chat: msg.Chat,
|
|
||||||
Channel: msg.Channel,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []*model.MessageAction{action}
|
|
||||||
}
|
|
|
@ -1,120 +0,0 @@
|
||||||
package social
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestTwitterExpander_OnMessage(t *testing.T) {
|
|
||||||
plugin := NewTwitterExpander()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
config map[string]interface{}
|
|
||||||
expected string
|
|
||||||
hasReply bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Twitter URL with default domain",
|
|
||||||
input: "https://twitter.com/user/status/123456789",
|
|
||||||
config: map[string]interface{}{},
|
|
||||||
expected: "https://fxtwitter.com/user/status/123456789",
|
|
||||||
hasReply: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X.com URL with custom domain",
|
|
||||||
input: "https://x.com/elonmusk/status/987654321",
|
|
||||||
config: map[string]interface{}{"domain": "vxtwitter.com"},
|
|
||||||
expected: "https://vxtwitter.com/elonmusk/status/987654321",
|
|
||||||
hasReply: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Twitter URL with tracking parameters",
|
|
||||||
input: "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20",
|
|
||||||
config: map[string]interface{}{},
|
|
||||||
expected: "https://fxtwitter.com/openai/status/555",
|
|
||||||
hasReply: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "www.twitter.com URL",
|
|
||||||
input: "https://www.twitter.com/user/status/789",
|
|
||||||
config: map[string]interface{}{"domain": "nitter.net"},
|
|
||||||
expected: "https://nitter.net/user/status/789",
|
|
||||||
hasReply: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Mixed text with Twitter URL",
|
|
||||||
input: "Check this out: https://twitter.com/user/status/123 amazing!",
|
|
||||||
config: map[string]interface{}{},
|
|
||||||
expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!",
|
|
||||||
hasReply: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No Twitter URLs",
|
|
||||||
input: "Just some regular text https://youtube.com/watch?v=abc",
|
|
||||||
config: map[string]interface{}{},
|
|
||||||
expected: "",
|
|
||||||
hasReply: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty message",
|
|
||||||
input: "",
|
|
||||||
config: map[string]interface{}{},
|
|
||||||
expected: "",
|
|
||||||
hasReply: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
msg := &model.Message{
|
|
||||||
ID: "test_msg",
|
|
||||||
Text: tt.input,
|
|
||||||
Chat: "test_chat",
|
|
||||||
Channel: &model.Channel{
|
|
||||||
ID: 1,
|
|
||||||
Platform: "telegram",
|
|
||||||
PlatformChannelID: "test_chat",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
actions := plugin.OnMessage(msg, tt.config, nil)
|
|
||||||
|
|
||||||
if !tt.hasReply {
|
|
||||||
if len(actions) != 0 {
|
|
||||||
t.Errorf("Expected no actions, got %d", len(actions))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(actions) != 1 {
|
|
||||||
t.Errorf("Expected 1 action, got %d", len(actions))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
action := actions[0]
|
|
||||||
if action.Type != model.ActionSendMessage {
|
|
||||||
t.Errorf("Expected ActionSendMessage, got %s", action.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message == nil {
|
|
||||||
t.Error("Expected message in action, got nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message.Text != tt.expected {
|
|
||||||
t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message.ReplyTo != msg.ID {
|
|
||||||
t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" {
|
|
||||||
t.Error("Expected parse_mode to be empty string to disable markdown parsing")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,9 +3,6 @@ package queue
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Item represents a queue item
|
// Item represents a queue item
|
||||||
|
@ -17,19 +14,14 @@ type Item struct {
|
||||||
// HandlerFunc defines a function that processes queue items
|
// HandlerFunc defines a function that processes queue items
|
||||||
type HandlerFunc func(item Item)
|
type HandlerFunc func(item Item)
|
||||||
|
|
||||||
// ReminderHandlerFunc defines a function that processes reminder items
|
|
||||||
type ReminderHandlerFunc func(reminder *model.Reminder)
|
|
||||||
|
|
||||||
// Queue represents a message queue
|
// Queue represents a message queue
|
||||||
type Queue struct {
|
type Queue struct {
|
||||||
items chan Item
|
items chan Item
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
running bool
|
running bool
|
||||||
runMutex sync.Mutex
|
runMutex sync.Mutex
|
||||||
reminderTicker *time.Ticker
|
|
||||||
reminderHandler ReminderHandlerFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Queue instance
|
// New creates a new Queue instance
|
||||||
|
@ -57,24 +49,6 @@ func (q *Queue) Start(handler HandlerFunc) {
|
||||||
go q.worker(handler)
|
go q.worker(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartReminderScheduler starts the reminder scheduler
|
|
||||||
func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) {
|
|
||||||
q.runMutex.Lock()
|
|
||||||
defer q.runMutex.Unlock()
|
|
||||||
|
|
||||||
if q.reminderTicker != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
q.reminderHandler = handler
|
|
||||||
|
|
||||||
// Check for reminders every minute
|
|
||||||
q.reminderTicker = time.NewTicker(1 * time.Minute)
|
|
||||||
|
|
||||||
q.wg.Add(1)
|
|
||||||
go q.reminderWorker()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops processing queue items
|
// Stop stops processing queue items
|
||||||
func (q *Queue) Stop() {
|
func (q *Queue) Stop() {
|
||||||
q.runMutex.Lock()
|
q.runMutex.Lock()
|
||||||
|
@ -85,12 +59,6 @@ func (q *Queue) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
q.running = false
|
q.running = false
|
||||||
|
|
||||||
// Stop reminder ticker if it exists
|
|
||||||
if q.reminderTicker != nil {
|
|
||||||
q.reminderTicker.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
close(q.quit)
|
close(q.quit)
|
||||||
q.wg.Wait()
|
q.wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -128,34 +96,4 @@ func (q *Queue) worker(handler HandlerFunc) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reminderWorker processes reminder items on a schedule
|
|
||||||
func (q *Queue) reminderWorker() {
|
|
||||||
defer q.wg.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-q.reminderTicker.C:
|
|
||||||
// This is triggered every minute to check for pending reminders
|
|
||||||
q.logger.Debug("Checking for pending reminders")
|
|
||||||
|
|
||||||
if q.reminderHandler != nil {
|
|
||||||
// The handler is responsible for fetching and processing reminders
|
|
||||||
func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
q.logger.Error("Panic in reminder worker", "error", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Call the handler with a nil reminder to indicate it should check the database
|
|
||||||
q.reminderHandler(nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
case <-q.quit:
|
|
||||||
// Quit worker
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
package testutil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockCache implements the CacheInterface for testing
|
|
||||||
type MockCache struct{}
|
|
||||||
|
|
||||||
func (m *MockCache) Get(key string, destination interface{}) error {
|
|
||||||
return errors.New("cache miss") // Always return cache miss for tests
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockCache) Set(key string, value interface{}, expiration *time.Time) error {
|
|
||||||
return nil // Always succeed for tests
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockCache) SetWithTTL(key string, value interface{}, ttl time.Duration) error {
|
|
||||||
return nil // Always succeed for tests
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockCache) Delete(key string) error {
|
|
||||||
return nil // Always succeed for tests
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockCache) Exists(key string) (bool, error) {
|
|
||||||
return false, nil // Always return false for tests
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue