Compare commits
5 commits
Author | SHA1 | Date | |
---|---|---|---|
c9edb57505 | |||
763a451251 | |||
abcd3c3c44 | |||
323ea4e8cd | |||
72c6dd6982 |
18 changed files with 774 additions and 95 deletions
|
@ -3,7 +3,7 @@ when:
|
||||||
- push
|
- push
|
||||||
- pull_request
|
- pull_request
|
||||||
branch:
|
branch:
|
||||||
- main
|
- master
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
format:
|
format:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
when:
|
when:
|
||||||
- event: tag
|
- event: tag
|
||||||
branch: main
|
branch: master
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Release
|
- name: Release
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
# Butter Robot
|
# Butter Robot
|
||||||
|
|
||||||
| Stable | Master |
|

|
||||||
| --- | --- |
|
|
||||||
|  |  |
|
|
||||||
|  |  |
|
|
||||||
|
|
||||||
Go framework to create bots for several platforms.
|
Go framework to create bots for several platforms.
|
||||||
|
|
||||||
|
@ -13,7 +10,7 @@ Go framework to create bots for several platforms.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Support for multiple chat platforms (Slack, Telegram)
|
- Support for multiple chat platforms (Slack (untested!), Telegram)
|
||||||
- Plugin system for easy extension
|
- Plugin system for easy extension
|
||||||
- Admin interface for managing channels and plugins
|
- Admin interface for managing channels and plugins
|
||||||
- Message queue for asynchronous processing
|
- Message queue for asynchronous processing
|
||||||
|
|
|
@ -10,6 +10,10 @@
|
||||||
- Dice: Put `!dice` and wathever roll you want to perform.
|
- Dice: Put `!dice` and wathever roll you want to perform.
|
||||||
- Coin: Flip a coin and get heads or tails.
|
- Coin: Flip a coin and get heads or tails.
|
||||||
|
|
||||||
|
### Utility
|
||||||
|
|
||||||
|
- Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time.
|
||||||
|
|
||||||
### Social Media
|
### Social Media
|
||||||
|
|
||||||
- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms.
|
- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms.
|
||||||
|
|
|
@ -106,19 +106,19 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a clone of the base template
|
// Create a clone of the base template
|
||||||
t, err := baseTemplate.Clone()
|
t, err := baseTemplate.Clone()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the template content
|
// Parse the template content
|
||||||
t, err = t.Parse(string(content))
|
t, err = t.Parse(string(content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
templates[tf] = t
|
templates[tf] = t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,7 +194,7 @@ func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map internal categories to Bootstrap alert classes
|
// Map internal categories to Bootstrap alert classes
|
||||||
alertClass := category
|
var alertClass string
|
||||||
switch category {
|
switch category {
|
||||||
case "success":
|
case "success":
|
||||||
alertClass = "success"
|
alertClass = "success"
|
||||||
|
@ -249,17 +249,6 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// requireLogin middleware checks if the user is logged in
|
|
||||||
func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !a.isLoggedIn(r) {
|
|
||||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// render renders a template with the given data
|
// render renders a template with the given data
|
||||||
func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) {
|
func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) {
|
||||||
// Add current user data
|
// Add current user data
|
||||||
|
@ -334,7 +323,10 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Set session expiration
|
// Set session expiration
|
||||||
session.Options.MaxAge = 3600 * 24 * 7 // 1 week
|
session.Options.MaxAge = 3600 * 24 * 7 // 1 week
|
||||||
session.Save(r, w)
|
err = session.Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error saving session: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
a.addFlash(w, r, "You were logged in", "success")
|
a.addFlash(w, r, "You were logged in", "success")
|
||||||
|
|
||||||
|
@ -362,7 +354,7 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Values = make(map[interface{}]interface{})
|
session.Values = make(map[interface{}]interface{})
|
||||||
session.Options.MaxAge = -1 // Delete session
|
session.Options.MaxAge = -1 // Delete session
|
||||||
err = session.Save(r, w)
|
err = session.Save(r, w)
|
||||||
|
|
|
@ -17,10 +17,12 @@ import (
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
|
||||||
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
||||||
)
|
)
|
||||||
|
@ -86,12 +88,19 @@ func (a *App) Run() error {
|
||||||
plugin.Register(social.NewTwitterExpander())
|
plugin.Register(social.NewTwitterExpander())
|
||||||
plugin.Register(social.NewInstagramExpander())
|
plugin.Register(social.NewInstagramExpander())
|
||||||
|
|
||||||
|
// Register reminder plugin
|
||||||
|
reminderPlugin := reminder.New(a.db)
|
||||||
|
plugin.Register(reminderPlugin)
|
||||||
|
|
||||||
// Initialize routes
|
// Initialize routes
|
||||||
a.initializeRoutes()
|
a.initializeRoutes()
|
||||||
|
|
||||||
// Start message queue worker
|
// Start message queue worker
|
||||||
a.queue.Start(a.handleMessage)
|
a.queue.Start(a.handleMessage)
|
||||||
|
|
||||||
|
// Start reminder scheduler
|
||||||
|
a.queue.StartReminderScheduler(a.handleReminder)
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
addr := fmt.Sprintf(":%s", a.config.Port)
|
addr := fmt.Sprintf(":%s", a.config.Port)
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
@ -143,7 +152,9 @@ func (a *App) initializeRoutes() {
|
||||||
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{})
|
if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil {
|
||||||
|
a.logger.Error("Error encoding response", "error", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Platform webhook endpoints
|
// Platform webhook endpoints
|
||||||
|
@ -166,7 +177,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
if _, err := platform.Get(platformName); err != nil {
|
if _, err := platform.Get(platformName); err != nil {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"})
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil {
|
||||||
|
a.logger.Error("Error encoding response", "error", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,7 +188,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"})
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil {
|
||||||
|
a.logger.Error("Error encoding response", "error", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,7 +206,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
||||||
// Respond with success
|
// Respond with success
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]any{})
|
if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil {
|
||||||
|
a.logger.Error("Error encoding response", "error", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPlatformName extracts the platform name from the URL path
|
// extractPlatformName extracts the platform name from the URL path
|
||||||
|
@ -304,3 +321,73 @@ func (a *App) handleMessage(item queue.Item) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleReminder handles reminder processing
|
||||||
|
func (a *App) handleReminder(reminder *model.Reminder) {
|
||||||
|
// When called with nil, it means we should check for pending reminders
|
||||||
|
if reminder == nil {
|
||||||
|
// Get pending reminders
|
||||||
|
reminders, err := a.db.GetPendingReminders()
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("Error getting pending reminders", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each reminder
|
||||||
|
for _, r := range reminders {
|
||||||
|
a.processReminder(r)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, process the specific reminder
|
||||||
|
a.processReminder(reminder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processReminder processes an individual reminder
|
||||||
|
func (a *App) processReminder(reminder *model.Reminder) {
|
||||||
|
a.logger.Info("Processing reminder",
|
||||||
|
"id", reminder.ID,
|
||||||
|
"platform", reminder.Platform,
|
||||||
|
"channel", reminder.ChannelID,
|
||||||
|
"trigger_at", reminder.TriggerAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get the platform handler
|
||||||
|
p, err := platform.Get(reminder.Platform)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the channel
|
||||||
|
channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error("Error getting channel for reminder", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the reminder message
|
||||||
|
reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username)
|
||||||
|
|
||||||
|
message := &model.Message{
|
||||||
|
Text: reminderText,
|
||||||
|
Chat: reminder.ChannelID,
|
||||||
|
Channel: channel,
|
||||||
|
Author: "bot",
|
||||||
|
FromBot: true,
|
||||||
|
Date: time.Now(),
|
||||||
|
ReplyTo: reminder.ReplyToID, // Reply to the original message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the reminder message
|
||||||
|
if err := p.SendMessage(message); err != nil {
|
||||||
|
a.logger.Error("Error sending reminder", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the reminder as processed
|
||||||
|
if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil {
|
||||||
|
a.logger.Error("Error marking reminder as processed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
|
@ -233,7 +234,11 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing rows: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var plugins []*model.ChannelPlugin
|
var plugins []*model.ChannelPlugin
|
||||||
|
|
||||||
|
@ -414,7 +419,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing rows: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var channels []*model.Channel
|
var channels []*model.Channel
|
||||||
|
|
||||||
|
@ -453,10 +462,9 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) {
|
||||||
continue // Skip this channel if plugins can't be retrieved
|
continue // Skip this channel if plugins can't be retrieved
|
||||||
}
|
}
|
||||||
|
|
||||||
if plugins != nil {
|
// Add plugins to channel
|
||||||
for _, plugin := range plugins {
|
for _, plugin := range plugins {
|
||||||
channel.Plugins[plugin.PluginID] = plugin
|
channel.Plugins[plugin.PluginID] = plugin
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channels = append(channels, channel)
|
channels = append(channels, channel)
|
||||||
|
@ -591,6 +599,124 @@ func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateReminder creates a new reminder
|
||||||
|
func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
||||||
|
query := `
|
||||||
|
INSERT INTO reminders (
|
||||||
|
platform, channel_id, message_id, reply_to_id,
|
||||||
|
user_id, username, created_at, trigger_at,
|
||||||
|
content, processed
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
|
`
|
||||||
|
|
||||||
|
createdAt := time.Now()
|
||||||
|
result, err := d.db.Exec(
|
||||||
|
query,
|
||||||
|
platform, channelID, messageID, replyToID,
|
||||||
|
userID, username, createdAt, triggerAt,
|
||||||
|
content,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.Reminder{
|
||||||
|
ID: id,
|
||||||
|
Platform: platform,
|
||||||
|
ChannelID: channelID,
|
||||||
|
MessageID: messageID,
|
||||||
|
ReplyToID: replyToID,
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
TriggerAt: triggerAt,
|
||||||
|
Content: content,
|
||||||
|
Processed: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingReminders gets all pending reminders that need to be processed
|
||||||
|
func (d *Database) GetPendingReminders() ([]*model.Reminder, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, platform, channel_id, message_id, reply_to_id,
|
||||||
|
user_id, username, created_at, trigger_at, content, processed
|
||||||
|
FROM reminders
|
||||||
|
WHERE processed = 0 AND trigger_at <= ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := d.db.Query(query, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing rows: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var reminders []*model.Reminder
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
id int64
|
||||||
|
platform, channelID, messageID, replyToID string
|
||||||
|
userID, username, content string
|
||||||
|
createdAt, triggerAt time.Time
|
||||||
|
processed bool
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&id, &platform, &channelID, &messageID, &replyToID,
|
||||||
|
&userID, &username, &createdAt, &triggerAt, &content, &processed,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reminder := &model.Reminder{
|
||||||
|
ID: id,
|
||||||
|
Platform: platform,
|
||||||
|
ChannelID: channelID,
|
||||||
|
MessageID: messageID,
|
||||||
|
ReplyToID: replyToID,
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
TriggerAt: triggerAt,
|
||||||
|
Content: content,
|
||||||
|
Processed: processed,
|
||||||
|
}
|
||||||
|
|
||||||
|
reminders = append(reminders, reminder)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(reminders) == 0 {
|
||||||
|
return make([]*model.Reminder, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return reminders, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkReminderAsProcessed marks a reminder as processed
|
||||||
|
func (d *Database) MarkReminderAsProcessed(id int64) error {
|
||||||
|
query := `
|
||||||
|
UPDATE reminders
|
||||||
|
SET processed = 1
|
||||||
|
WHERE id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := d.db.Exec(query, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to hash password
|
// Helper function to hash password
|
||||||
func hashPassword(password string) (string, error) {
|
func hashPassword(password string) (string, error) {
|
||||||
// Use bcrypt for secure password hashing
|
// Use bcrypt for secure password hashing
|
||||||
|
@ -609,25 +735,25 @@ func initDatabase(db *sql.DB) error {
|
||||||
if err := migration.EnsureMigrationTable(db); err != nil {
|
if err := migration.EnsureMigrationTable(db); err != nil {
|
||||||
return fmt.Errorf("failed to create migration table: %w", err)
|
return fmt.Errorf("failed to create migration table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get applied migrations
|
// Get applied migrations
|
||||||
applied, err := migration.GetAppliedMigrations(db)
|
applied, err := migration.GetAppliedMigrations(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all migration versions
|
// Get all migration versions
|
||||||
allMigrations := make([]int, 0, len(migration.Migrations))
|
allMigrations := make([]int, 0, len(migration.Migrations))
|
||||||
for version := range migration.Migrations {
|
for version := range migration.Migrations {
|
||||||
allMigrations = append(allMigrations, version)
|
allMigrations = append(allMigrations, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map of applied migrations for quick lookup
|
// Create a map of applied migrations for quick lookup
|
||||||
appliedMap := make(map[int]bool)
|
appliedMap := make(map[int]bool)
|
||||||
for _, version := range applied {
|
for _, version := range applied {
|
||||||
appliedMap[version] = true
|
appliedMap[version] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count pending migrations
|
// Count pending migrations
|
||||||
pendingCount := 0
|
pendingCount := 0
|
||||||
for _, version := range allMigrations {
|
for _, version := range allMigrations {
|
||||||
|
@ -635,7 +761,7 @@ func initDatabase(db *sql.DB) error {
|
||||||
pendingCount++
|
pendingCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run migrations if needed
|
// Run migrations if needed
|
||||||
if pendingCount > 0 {
|
if pendingCount > 0 {
|
||||||
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
||||||
|
@ -646,6 +772,6 @@ func initDatabase(db *sql.DB) error {
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("Database schema is up to date.")
|
fmt.Println("Database schema is up to date.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,11 @@ func GetAppliedMigrations(db *sql.DB) ([]int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing rows: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var versions []int
|
var versions []int
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
@ -128,7 +132,9 @@ func Migrate(db *sql.DB) error {
|
||||||
|
|
||||||
// Apply the migration
|
// Apply the migration
|
||||||
if err := migration.Up(db); err != nil {
|
if err := migration.Up(db); err != nil {
|
||||||
tx.Rollback()
|
if err := tx.Rollback(); err != nil {
|
||||||
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
return fmt.Errorf("failed to apply migration %d: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,7 +143,9 @@ func Migrate(db *sql.DB) error {
|
||||||
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
||||||
version, time.Now(),
|
version, time.Now(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
tx.Rollback()
|
if err := tx.Rollback(); err != nil {
|
||||||
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
return fmt.Errorf("failed to mark migration %d as applied: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,13 +196,17 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
|
||||||
|
|
||||||
// Apply the down migration
|
// Apply the down migration
|
||||||
if err := migration.Down(db); err != nil {
|
if err := migration.Down(db); err != nil {
|
||||||
tx.Rollback()
|
if err := tx.Rollback(); err != nil {
|
||||||
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
return fmt.Errorf("failed to roll back migration %d: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove from applied list
|
// Remove from applied list
|
||||||
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil {
|
||||||
tx.Rollback()
|
if err := tx.Rollback(); err != nil {
|
||||||
|
fmt.Printf("Error rolling back transaction: %v\n", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,4 +220,4 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
func init() {
|
func init() {
|
||||||
// Register migrations
|
// Register migrations
|
||||||
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
||||||
|
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initial schema creation with bcrypt passwords - version 1
|
// Initial schema creation with bcrypt passwords - version 1
|
||||||
|
@ -60,14 +61,14 @@ func migrateInitialSchemaUp(db *sql.DB) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if users table is empty before inserting
|
// Check if users table is empty before inserting
|
||||||
var count int
|
var count int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
_, err = db.Exec(
|
_, err = db.Exec(
|
||||||
"INSERT INTO users (username, password) VALUES (?, ?)",
|
"INSERT INTO users (username, password) VALUES (?, ?)",
|
||||||
|
@ -99,4 +100,29 @@ func migrateInitialSchemaDown(db *sql.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add reminders table - version 2
|
||||||
|
func migrateRemindersUp(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS reminders (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
platform TEXT NOT NULL,
|
||||||
|
channel_id TEXT NOT NULL,
|
||||||
|
message_id TEXT NOT NULL,
|
||||||
|
reply_to_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
username TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL,
|
||||||
|
trigger_at TIMESTAMP NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
processed BOOLEAN NOT NULL DEFAULT 0
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateRemindersDown(db *sql.DB) error {
|
||||||
|
_, err := db.Exec(`DROP TABLE IF EXISTS reminders`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -6,25 +6,25 @@ import (
|
||||||
|
|
||||||
// Message represents a chat message
|
// Message represents a chat message
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Text string
|
Text string
|
||||||
Chat string
|
Chat string
|
||||||
Channel *Channel
|
Channel *Channel
|
||||||
Author string
|
Author string
|
||||||
FromBot bool
|
FromBot bool
|
||||||
Date time.Time
|
Date time.Time
|
||||||
ID string
|
ID string
|
||||||
ReplyTo string
|
ReplyTo string
|
||||||
Raw map[string]interface{}
|
Raw map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channel represents a chat channel
|
// Channel represents a chat channel
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
ID int64
|
ID int64
|
||||||
Platform string
|
Platform string
|
||||||
PlatformChannelID string
|
PlatformChannelID string
|
||||||
ChannelRaw map[string]interface{}
|
ChannelRaw map[string]interface{}
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Plugins map[string]*ChannelPlugin
|
Plugins map[string]*ChannelPlugin
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
||||||
|
@ -40,18 +40,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
||||||
func (c *Channel) ChannelName() string {
|
func (c *Channel) ChannelName() string {
|
||||||
// In a real implementation, this would use the platform-specific
|
// In a real implementation, this would use the platform-specific
|
||||||
// ParseChannelNameFromRaw function
|
// ParseChannelNameFromRaw function
|
||||||
|
|
||||||
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name
|
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name
|
||||||
// Check if ChannelRaw has a name field
|
// Check if ChannelRaw has a name field
|
||||||
if c.ChannelRaw == nil {
|
if c.ChannelRaw == nil {
|
||||||
return c.PlatformChannelID
|
return c.PlatformChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check common name fields in ChannelRaw
|
// Check common name fields in ChannelRaw
|
||||||
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
|
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for nested objects like "chat" (used by Telegram)
|
// Check for nested objects like "chat" (used by Telegram)
|
||||||
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
|
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
|
||||||
// Try different fields in order of preference
|
// Try different fields in order of preference
|
||||||
|
@ -65,7 +65,7 @@ func (c *Channel) ChannelName() string {
|
||||||
return firstName
|
return firstName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.PlatformChannelID
|
return c.PlatformChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,4 +83,19 @@ type User struct {
|
||||||
ID int64
|
ID int64
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reminder represents a scheduled reminder
|
||||||
|
type Reminder struct {
|
||||||
|
ID int64
|
||||||
|
Platform string
|
||||||
|
ChannelID string
|
||||||
|
MessageID string
|
||||||
|
ReplyToID string
|
||||||
|
UserID string
|
||||||
|
Username string
|
||||||
|
CreatedAt time.Time
|
||||||
|
TriggerAt time.Time
|
||||||
|
Content string
|
||||||
|
Processed bool
|
||||||
|
}
|
||||||
|
|
|
@ -13,16 +13,16 @@ var (
|
||||||
type Plugin interface {
|
type Plugin interface {
|
||||||
// GetID returns the plugin ID
|
// GetID returns the plugin ID
|
||||||
GetID() string
|
GetID() string
|
||||||
|
|
||||||
// GetName returns the plugin name
|
// GetName returns the plugin name
|
||||||
GetName() string
|
GetName() string
|
||||||
|
|
||||||
// GetHelp returns the plugin help text
|
// GetHelp returns the plugin help text
|
||||||
GetHelp() string
|
GetHelp() string
|
||||||
|
|
||||||
// RequiresConfig indicates if the plugin requires configuration
|
// RequiresConfig indicates if the plugin requires configuration
|
||||||
RequiresConfig() bool
|
RequiresConfig() bool
|
||||||
|
|
||||||
// OnMessage processes an incoming message and returns response messages
|
// OnMessage processes an incoming message and returns response messages
|
||||||
OnMessage(msg *Message, config map[string]interface{}) []*Message
|
OnMessage(msg *Message, config map[string]interface{}) []*Message
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -37,11 +37,15 @@ func (s *SlackPlatform) Init(_ *config.Config) error {
|
||||||
// ParseIncomingMessage parses an incoming Slack message
|
// ParseIncomingMessage parses an incoming Slack message
|
||||||
func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) {
|
func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) {
|
||||||
// Read request body
|
// Read request body
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
defer func() {
|
||||||
|
if err := r.Body.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing request body: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var requestData map[string]interface{}
|
var requestData map[string]interface{}
|
||||||
|
@ -194,7 +198,11 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
fmt.Printf("Error closing response body: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
|
|
@ -62,7 +62,11 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error {
|
||||||
t.log.Error("Failed to set webhook", "error", err)
|
t.log.Error("Failed to set webhook", "error", err)
|
||||||
return fmt.Errorf("failed to set webhook: %w", err)
|
return fmt.Errorf("failed to set webhook: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
t.log.Error("Error closing response body", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
@ -85,7 +89,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
t.log.Error("Failed to read request body", "error", err)
|
t.log.Error("Failed to read request body", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
defer func() {
|
||||||
|
if err := r.Body.Close(); err != nil {
|
||||||
|
t.log.Error("Error closing request body", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var update struct {
|
var update struct {
|
||||||
|
@ -103,8 +111,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
Title string `json:"title,omitempty"`
|
Title string `json:"title,omitempty"`
|
||||||
Username string `json:"username,omitempty"`
|
Username string `json:"username,omitempty"`
|
||||||
} `json:"chat"`
|
} `json:"chat"`
|
||||||
Date int `json:"date"`
|
Date int `json:"date"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
|
ReplyToMessage struct {
|
||||||
|
MessageID int `json:"message_id"`
|
||||||
|
} `json:"reply_to_message"`
|
||||||
} `json:"message"`
|
} `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,6 +139,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
||||||
FromBot: update.Message.From.IsBot,
|
FromBot: update.Message.From.IsBot,
|
||||||
Date: time.Unix(int64(update.Message.Date), 0),
|
Date: time.Unix(int64(update.Message.Date), 0),
|
||||||
ID: strconv.Itoa(update.Message.MessageID),
|
ID: strconv.Itoa(update.Message.MessageID),
|
||||||
|
ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID),
|
||||||
Raw: raw,
|
Raw: raw,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +259,11 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
t.log.Error("Failed to send message", "error", err)
|
t.log.Error("Failed to send message", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
t.log.Error("Error closing response body", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
@ -259,4 +275,4 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
||||||
|
|
||||||
t.log.Debug("Message sent successfully")
|
t.log.Debug("Message sent successfully")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,9 +107,10 @@ func (p *DicePlugin) rollDice(formula string) (int, error) {
|
||||||
return 0, fmt.Errorf("invalid modifier")
|
return 0, fmt.Errorf("invalid modifier")
|
||||||
}
|
}
|
||||||
|
|
||||||
if matches[3] == "+" {
|
switch matches[3] {
|
||||||
|
case "+":
|
||||||
total += modifier
|
total += modifier
|
||||||
} else if matches[3] == "-" {
|
case "-":
|
||||||
total -= modifier
|
total -= modifier
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
171
internal/plugin/reminder/reminder.go
Normal file
171
internal/plugin/reminder/reminder.go
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
package reminder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Duration regex patterns to match reminders
|
||||||
|
var (
|
||||||
|
remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReminderCreator is an interface for creating reminders
|
||||||
|
type ReminderCreator interface {
|
||||||
|
CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reminder is a plugin that sets reminders for messages
|
||||||
|
type Reminder struct {
|
||||||
|
plugin.BasePlugin
|
||||||
|
creator ReminderCreator
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Reminder plugin
|
||||||
|
func New(creator ReminderCreator) *Reminder {
|
||||||
|
return &Reminder{
|
||||||
|
BasePlugin: plugin.BasePlugin{
|
||||||
|
ID: "reminder.remindme",
|
||||||
|
Name: "Remind Me",
|
||||||
|
Help: "Reply to a message with `!remindme <duration>` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).",
|
||||||
|
ConfigRequired: false,
|
||||||
|
},
|
||||||
|
creator: creator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessage processes incoming messages
|
||||||
|
func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||||
|
// Only process replies to messages
|
||||||
|
if msg.ReplyTo == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the message is a reminder command
|
||||||
|
match := remindMePattern.FindStringSubmatch(msg.Text)
|
||||||
|
if match == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the duration
|
||||||
|
amount, err := strconv.Atoi(match[1])
|
||||||
|
if err != nil {
|
||||||
|
return []*model.Message{
|
||||||
|
{
|
||||||
|
Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
||||||
|
Chat: msg.Chat,
|
||||||
|
Channel: msg.Channel,
|
||||||
|
Author: "bot",
|
||||||
|
FromBot: true,
|
||||||
|
Date: time.Now(),
|
||||||
|
ReplyTo: msg.ID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the trigger time
|
||||||
|
var duration time.Duration
|
||||||
|
unit := match[2]
|
||||||
|
switch strings.ToLower(unit) {
|
||||||
|
case "y":
|
||||||
|
duration = time.Duration(amount) * 365 * 24 * time.Hour
|
||||||
|
case "mo":
|
||||||
|
duration = time.Duration(amount) * 30 * 24 * time.Hour
|
||||||
|
case "d":
|
||||||
|
duration = time.Duration(amount) * 24 * time.Hour
|
||||||
|
case "h":
|
||||||
|
duration = time.Duration(amount) * time.Hour
|
||||||
|
case "m":
|
||||||
|
duration = time.Duration(amount) * time.Minute
|
||||||
|
case "s":
|
||||||
|
duration = time.Duration(amount) * time.Second
|
||||||
|
default:
|
||||||
|
return []*model.Message{
|
||||||
|
{
|
||||||
|
Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
||||||
|
Chat: msg.Chat,
|
||||||
|
Channel: msg.Channel,
|
||||||
|
Author: "bot",
|
||||||
|
FromBot: true,
|
||||||
|
Date: time.Now(),
|
||||||
|
ReplyTo: msg.ID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
triggerAt := time.Now().Add(duration)
|
||||||
|
|
||||||
|
// Determine the username for the reminder
|
||||||
|
username := msg.Author
|
||||||
|
if username == "" {
|
||||||
|
// Try to extract username from message raw data
|
||||||
|
if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok {
|
||||||
|
if name, ok := authorData["username"].(string); ok {
|
||||||
|
username = name
|
||||||
|
} else if name, ok := authorData["name"].(string); ok {
|
||||||
|
username = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the reminder
|
||||||
|
_, err = r.creator.CreateReminder(
|
||||||
|
msg.Channel.Platform,
|
||||||
|
msg.Chat,
|
||||||
|
msg.ID,
|
||||||
|
msg.ReplyTo,
|
||||||
|
msg.Author,
|
||||||
|
username,
|
||||||
|
"", // No additional content for now
|
||||||
|
triggerAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return []*model.Message{
|
||||||
|
{
|
||||||
|
Text: fmt.Sprintf("Failed to create reminder: %v", err),
|
||||||
|
Chat: msg.Chat,
|
||||||
|
Channel: msg.Channel,
|
||||||
|
Author: "bot",
|
||||||
|
FromBot: true,
|
||||||
|
Date: time.Now(),
|
||||||
|
ReplyTo: msg.ID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the acknowledgment message
|
||||||
|
var confirmText string
|
||||||
|
switch strings.ToLower(unit) {
|
||||||
|
case "y":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04"))
|
||||||
|
case "mo":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
||||||
|
case "d":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
||||||
|
case "h":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04"))
|
||||||
|
case "m":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04"))
|
||||||
|
case "s":
|
||||||
|
confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*model.Message{
|
||||||
|
{
|
||||||
|
Text: confirmText,
|
||||||
|
Chat: msg.Chat,
|
||||||
|
Channel: msg.Channel,
|
||||||
|
Author: "bot",
|
||||||
|
FromBot: true,
|
||||||
|
Date: time.Now(),
|
||||||
|
ReplyTo: msg.ID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
164
internal/plugin/reminder/reminder_test.go
Normal file
164
internal/plugin/reminder/reminder_test.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package reminder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCreator is a mock implementation of ReminderCreator for testing
|
||||||
|
type MockCreator struct {
|
||||||
|
reminders []*model.Reminder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
||||||
|
reminder := &model.Reminder{
|
||||||
|
ID: int64(len(m.reminders) + 1),
|
||||||
|
Platform: platform,
|
||||||
|
ChannelID: channelID,
|
||||||
|
MessageID: messageID,
|
||||||
|
ReplyToID: replyToID,
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
Content: content,
|
||||||
|
TriggerAt: triggerAt,
|
||||||
|
}
|
||||||
|
m.reminders = append(m.reminders, reminder)
|
||||||
|
return reminder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReminderOnMessage(t *testing.T) {
|
||||||
|
creator := &MockCreator{reminders: make([]*model.Reminder, 0)}
|
||||||
|
plugin := New(creator)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
message *model.Message
|
||||||
|
expectResponse bool
|
||||||
|
expectReminder bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - years",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 1y",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - months",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 3mo",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - days",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 2d",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - hours",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 5h",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - minutes",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 30m",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid reminder command - seconds",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 60s",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: true,
|
||||||
|
expectReminder: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not a reply",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme 2d",
|
||||||
|
ReplyTo: "",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: false,
|
||||||
|
expectReminder: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not a reminder command",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "hello world",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: false,
|
||||||
|
expectReminder: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid duration format",
|
||||||
|
message: &model.Message{
|
||||||
|
Text: "!remindme abc",
|
||||||
|
ReplyTo: "original-message-id",
|
||||||
|
Author: "testuser",
|
||||||
|
Channel: &model.Channel{Platform: "test"},
|
||||||
|
},
|
||||||
|
expectResponse: false,
|
||||||
|
expectReminder: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
initialCount := len(creator.reminders)
|
||||||
|
responses := plugin.OnMessage(tt.message, nil)
|
||||||
|
|
||||||
|
if tt.expectResponse && len(responses) == 0 {
|
||||||
|
t.Errorf("Expected response, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectResponse && len(responses) > 0 {
|
||||||
|
t.Errorf("Expected no response, but got %d", len(responses))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectReminder && len(creator.reminders) != initialCount+1 {
|
||||||
|
t.Errorf("Expected reminder to be created, but it wasn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectReminder && len(creator.reminders) != initialCount {
|
||||||
|
t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -53,9 +53,7 @@ func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]inte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the host
|
// Change the host
|
||||||
if strings.Contains(parsedURL.Host, "instagram.com") {
|
parsedURL.Host = strings.Replace(parsedURL.Host, "instagram.com", "ddinstagram.com", 1)
|
||||||
parsedURL.Host = strings.Replace(parsedURL.Host, "instagram.com", "ddinstagram.com", 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove query parameters
|
// Remove query parameters
|
||||||
parsedURL.RawQuery = ""
|
parsedURL.RawQuery = ""
|
||||||
|
|
|
@ -3,6 +3,9 @@ package queue
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Item represents a queue item
|
// Item represents a queue item
|
||||||
|
@ -14,14 +17,19 @@ type Item struct {
|
||||||
// HandlerFunc defines a function that processes queue items
|
// HandlerFunc defines a function that processes queue items
|
||||||
type HandlerFunc func(item Item)
|
type HandlerFunc func(item Item)
|
||||||
|
|
||||||
|
// ReminderHandlerFunc defines a function that processes reminder items
|
||||||
|
type ReminderHandlerFunc func(reminder *model.Reminder)
|
||||||
|
|
||||||
// Queue represents a message queue
|
// Queue represents a message queue
|
||||||
type Queue struct {
|
type Queue struct {
|
||||||
items chan Item
|
items chan Item
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
running bool
|
running bool
|
||||||
runMutex sync.Mutex
|
runMutex sync.Mutex
|
||||||
|
reminderTicker *time.Ticker
|
||||||
|
reminderHandler ReminderHandlerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Queue instance
|
// New creates a new Queue instance
|
||||||
|
@ -49,6 +57,24 @@ func (q *Queue) Start(handler HandlerFunc) {
|
||||||
go q.worker(handler)
|
go q.worker(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartReminderScheduler starts the reminder scheduler
|
||||||
|
func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) {
|
||||||
|
q.runMutex.Lock()
|
||||||
|
defer q.runMutex.Unlock()
|
||||||
|
|
||||||
|
if q.reminderTicker != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q.reminderHandler = handler
|
||||||
|
|
||||||
|
// Check for reminders every minute
|
||||||
|
q.reminderTicker = time.NewTicker(1 * time.Minute)
|
||||||
|
|
||||||
|
q.wg.Add(1)
|
||||||
|
go q.reminderWorker()
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops processing queue items
|
// Stop stops processing queue items
|
||||||
func (q *Queue) Stop() {
|
func (q *Queue) Stop() {
|
||||||
q.runMutex.Lock()
|
q.runMutex.Lock()
|
||||||
|
@ -59,6 +85,12 @@ func (q *Queue) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
q.running = false
|
q.running = false
|
||||||
|
|
||||||
|
// Stop reminder ticker if it exists
|
||||||
|
if q.reminderTicker != nil {
|
||||||
|
q.reminderTicker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
close(q.quit)
|
close(q.quit)
|
||||||
q.wg.Wait()
|
q.wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -96,4 +128,34 @@ func (q *Queue) worker(handler HandlerFunc) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reminderWorker processes reminder items on a schedule
|
||||||
|
func (q *Queue) reminderWorker() {
|
||||||
|
defer q.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-q.reminderTicker.C:
|
||||||
|
// This is triggered every minute to check for pending reminders
|
||||||
|
q.logger.Debug("Checking for pending reminders")
|
||||||
|
|
||||||
|
if q.reminderHandler != nil {
|
||||||
|
// The handler is responsible for fetching and processing reminders
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
q.logger.Error("Panic in reminder worker", "error", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call the handler with a nil reminder to indicate it should check the database
|
||||||
|
q.reminderHandler(nil)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
case <-q.quit:
|
||||||
|
// Quit worker
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue