This commit is contained in:
parent
21e4c434fd
commit
72c6dd6982
12 changed files with 695 additions and 48 deletions
|
@ -10,6 +10,10 @@
|
|||
- Dice: Put `!dice` and wathever roll you want to perform.
|
||||
- Coin: Flip a coin and get heads or tails.
|
||||
|
||||
### Utility
|
||||
|
||||
- Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time.
|
||||
|
||||
### Social Media
|
||||
|
||||
- Twitter Link Expander: Automatically converts twitter.com and x.com links to fxtwitter.com links and removes tracking parameters. This allows for better media embedding in chat platforms.
|
||||
|
|
|
@ -106,19 +106,19 @@ func New(cfg *config.Config, database *db.Database, version string) *Admin {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
// Create a clone of the base template
|
||||
t, err := baseTemplate.Clone()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
// Parse the template content
|
||||
t, err = t.Parse(string(content))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
templates[tf] = t
|
||||
}
|
||||
|
||||
|
@ -362,7 +362,7 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
session.Options.MaxAge = -1 // Delete session
|
||||
err = session.Save(r, w)
|
||||
|
|
|
@ -17,10 +17,12 @@ import (
|
|||
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
||||
)
|
||||
|
@ -86,12 +88,19 @@ func (a *App) Run() error {
|
|||
plugin.Register(social.NewTwitterExpander())
|
||||
plugin.Register(social.NewInstagramExpander())
|
||||
|
||||
// Register reminder plugin
|
||||
reminderPlugin := reminder.New(a.db)
|
||||
plugin.Register(reminderPlugin)
|
||||
|
||||
// Initialize routes
|
||||
a.initializeRoutes()
|
||||
|
||||
// Start message queue worker
|
||||
a.queue.Start(a.handleMessage)
|
||||
|
||||
// Start reminder scheduler
|
||||
a.queue.StartReminderScheduler(a.handleReminder)
|
||||
|
||||
// Create server
|
||||
addr := fmt.Sprintf(":%s", a.config.Port)
|
||||
srv := &http.Server{
|
||||
|
@ -304,3 +313,73 @@ func (a *App) handleMessage(item queue.Item) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleReminder handles reminder processing
|
||||
func (a *App) handleReminder(reminder *model.Reminder) {
|
||||
// When called with nil, it means we should check for pending reminders
|
||||
if reminder == nil {
|
||||
// Get pending reminders
|
||||
reminders, err := a.db.GetPendingReminders()
|
||||
if err != nil {
|
||||
a.logger.Error("Error getting pending reminders", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Process each reminder
|
||||
for _, r := range reminders {
|
||||
a.processReminder(r)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, process the specific reminder
|
||||
a.processReminder(reminder)
|
||||
}
|
||||
|
||||
// processReminder processes an individual reminder
|
||||
func (a *App) processReminder(reminder *model.Reminder) {
|
||||
a.logger.Info("Processing reminder",
|
||||
"id", reminder.ID,
|
||||
"platform", reminder.Platform,
|
||||
"channel", reminder.ChannelID,
|
||||
"trigger_at", reminder.TriggerAt,
|
||||
)
|
||||
|
||||
// Get the platform handler
|
||||
p, err := platform.Get(reminder.Platform)
|
||||
if err != nil {
|
||||
a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the channel
|
||||
channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID)
|
||||
if err != nil {
|
||||
a.logger.Error("Error getting channel for reminder", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create the reminder message
|
||||
reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username)
|
||||
|
||||
message := &model.Message{
|
||||
Text: reminderText,
|
||||
Chat: reminder.ChannelID,
|
||||
Channel: channel,
|
||||
Author: "bot",
|
||||
FromBot: true,
|
||||
Date: time.Now(),
|
||||
ReplyTo: reminder.ReplyToID, // Reply to the original message
|
||||
}
|
||||
|
||||
// Send the reminder message
|
||||
if err := p.SendMessage(message); err != nil {
|
||||
a.logger.Error("Error sending reminder", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the reminder as processed
|
||||
if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil {
|
||||
a.logger.Error("Error marking reminder as processed", "error", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
|
@ -591,6 +592,120 @@ func (d *Database) UpdateUserPassword(userID int64, newPassword string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// CreateReminder creates a new reminder
|
||||
func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
||||
query := `
|
||||
INSERT INTO reminders (
|
||||
platform, channel_id, message_id, reply_to_id,
|
||||
user_id, username, created_at, trigger_at,
|
||||
content, processed
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||
`
|
||||
|
||||
createdAt := time.Now()
|
||||
result, err := d.db.Exec(
|
||||
query,
|
||||
platform, channelID, messageID, replyToID,
|
||||
userID, username, createdAt, triggerAt,
|
||||
content,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.Reminder{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
ChannelID: channelID,
|
||||
MessageID: messageID,
|
||||
ReplyToID: replyToID,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
CreatedAt: createdAt,
|
||||
TriggerAt: triggerAt,
|
||||
Content: content,
|
||||
Processed: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPendingReminders gets all pending reminders that need to be processed
|
||||
func (d *Database) GetPendingReminders() ([]*model.Reminder, error) {
|
||||
query := `
|
||||
SELECT id, platform, channel_id, message_id, reply_to_id,
|
||||
user_id, username, created_at, trigger_at, content, processed
|
||||
FROM reminders
|
||||
WHERE processed = 0 AND trigger_at <= ?
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reminders []*model.Reminder
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
platform, channelID, messageID, replyToID string
|
||||
userID, username, content string
|
||||
createdAt, triggerAt time.Time
|
||||
processed bool
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&id, &platform, &channelID, &messageID, &replyToID,
|
||||
&userID, &username, &createdAt, &triggerAt, &content, &processed,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reminder := &model.Reminder{
|
||||
ID: id,
|
||||
Platform: platform,
|
||||
ChannelID: channelID,
|
||||
MessageID: messageID,
|
||||
ReplyToID: replyToID,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
CreatedAt: createdAt,
|
||||
TriggerAt: triggerAt,
|
||||
Content: content,
|
||||
Processed: processed,
|
||||
}
|
||||
|
||||
reminders = append(reminders, reminder)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(reminders) == 0 {
|
||||
return make([]*model.Reminder, 0), nil
|
||||
}
|
||||
|
||||
return reminders, nil
|
||||
}
|
||||
|
||||
// MarkReminderAsProcessed marks a reminder as processed
|
||||
func (d *Database) MarkReminderAsProcessed(id int64) error {
|
||||
query := `
|
||||
UPDATE reminders
|
||||
SET processed = 1
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := d.db.Exec(query, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// Helper function to hash password
|
||||
func hashPassword(password string) (string, error) {
|
||||
// Use bcrypt for secure password hashing
|
||||
|
@ -609,25 +724,25 @@ func initDatabase(db *sql.DB) error {
|
|||
if err := migration.EnsureMigrationTable(db); err != nil {
|
||||
return fmt.Errorf("failed to create migration table: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Get applied migrations
|
||||
applied, err := migration.GetAppliedMigrations(db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get applied migrations: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Get all migration versions
|
||||
allMigrations := make([]int, 0, len(migration.Migrations))
|
||||
for version := range migration.Migrations {
|
||||
allMigrations = append(allMigrations, version)
|
||||
}
|
||||
|
||||
|
||||
// Create a map of applied migrations for quick lookup
|
||||
appliedMap := make(map[int]bool)
|
||||
for _, version := range applied {
|
||||
appliedMap[version] = true
|
||||
}
|
||||
|
||||
|
||||
// Count pending migrations
|
||||
pendingCount := 0
|
||||
for _, version := range allMigrations {
|
||||
|
@ -635,7 +750,7 @@ func initDatabase(db *sql.DB) error {
|
|||
pendingCount++
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Run migrations if needed
|
||||
if pendingCount > 0 {
|
||||
fmt.Printf("Running %d pending database migrations...\n", pendingCount)
|
||||
|
@ -646,6 +761,6 @@ func initDatabase(db *sql.DB) error {
|
|||
} else {
|
||||
fmt.Println("Database schema is up to date.")
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -208,4 +208,4 @@ func MigrateDown(db *sql.DB, targetVersion int) error {
|
|||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
func init() {
|
||||
// Register migrations
|
||||
Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown)
|
||||
Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown)
|
||||
}
|
||||
|
||||
// Initial schema creation with bcrypt passwords - version 1
|
||||
|
@ -60,14 +61,14 @@ func migrateInitialSchemaUp(db *sql.DB) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
// Check if users table is empty before inserting
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
if count == 0 {
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO users (username, password) VALUES (?, ?)",
|
||||
|
@ -99,4 +100,29 @@ func migrateInitialSchemaDown(db *sql.DB) error {
|
|||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add reminders table - version 2
|
||||
func migrateRemindersUp(db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS reminders (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
platform TEXT NOT NULL,
|
||||
channel_id TEXT NOT NULL,
|
||||
message_id TEXT NOT NULL,
|
||||
reply_to_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
trigger_at TIMESTAMP NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
processed BOOLEAN NOT NULL DEFAULT 0
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
func migrateRemindersDown(db *sql.DB) error {
|
||||
_, err := db.Exec(`DROP TABLE IF EXISTS reminders`)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,25 +6,25 @@ import (
|
|||
|
||||
// Message represents a chat message
|
||||
type Message struct {
|
||||
Text string
|
||||
Chat string
|
||||
Channel *Channel
|
||||
Author string
|
||||
FromBot bool
|
||||
Date time.Time
|
||||
ID string
|
||||
ReplyTo string
|
||||
Raw map[string]interface{}
|
||||
Text string
|
||||
Chat string
|
||||
Channel *Channel
|
||||
Author string
|
||||
FromBot bool
|
||||
Date time.Time
|
||||
ID string
|
||||
ReplyTo string
|
||||
Raw map[string]interface{}
|
||||
}
|
||||
|
||||
// Channel represents a chat channel
|
||||
type Channel struct {
|
||||
ID int64
|
||||
Platform string
|
||||
ID int64
|
||||
Platform string
|
||||
PlatformChannelID string
|
||||
ChannelRaw map[string]interface{}
|
||||
Enabled bool
|
||||
Plugins map[string]*ChannelPlugin
|
||||
ChannelRaw map[string]interface{}
|
||||
Enabled bool
|
||||
Plugins map[string]*ChannelPlugin
|
||||
}
|
||||
|
||||
// HasEnabledPlugin checks if a plugin is enabled for this channel
|
||||
|
@ -40,18 +40,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool {
|
|||
func (c *Channel) ChannelName() string {
|
||||
// In a real implementation, this would use the platform-specific
|
||||
// ParseChannelNameFromRaw function
|
||||
|
||||
|
||||
// For simplicity, we'll just use the PlatformChannelID if we can't extract a name
|
||||
// Check if ChannelRaw has a name field
|
||||
if c.ChannelRaw == nil {
|
||||
return c.PlatformChannelID
|
||||
}
|
||||
|
||||
|
||||
// Check common name fields in ChannelRaw
|
||||
if name, ok := c.ChannelRaw["name"].(string); ok && name != "" {
|
||||
return name
|
||||
}
|
||||
|
||||
|
||||
// Check for nested objects like "chat" (used by Telegram)
|
||||
if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok {
|
||||
// Try different fields in order of preference
|
||||
|
@ -65,7 +65,7 @@ func (c *Channel) ChannelName() string {
|
|||
return firstName
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return c.PlatformChannelID
|
||||
}
|
||||
|
||||
|
@ -83,4 +83,19 @@ type User struct {
|
|||
ID int64
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
}
|
||||
|
||||
// Reminder represents a scheduled reminder
|
||||
type Reminder struct {
|
||||
ID int64
|
||||
Platform string
|
||||
ChannelID string
|
||||
MessageID string
|
||||
ReplyToID string
|
||||
UserID string
|
||||
Username string
|
||||
CreatedAt time.Time
|
||||
TriggerAt time.Time
|
||||
Content string
|
||||
Processed bool
|
||||
}
|
||||
|
|
|
@ -13,16 +13,16 @@ var (
|
|||
type Plugin interface {
|
||||
// GetID returns the plugin ID
|
||||
GetID() string
|
||||
|
||||
|
||||
// GetName returns the plugin name
|
||||
GetName() string
|
||||
|
||||
|
||||
// GetHelp returns the plugin help text
|
||||
GetHelp() string
|
||||
|
||||
|
||||
// RequiresConfig indicates if the plugin requires configuration
|
||||
RequiresConfig() bool
|
||||
|
||||
|
||||
// OnMessage processes an incoming message and returns response messages
|
||||
OnMessage(msg *Message, config map[string]interface{}) []*Message
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,8 +103,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
|||
Title string `json:"title,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
} `json:"chat"`
|
||||
Date int `json:"date"`
|
||||
Text string `json:"text"`
|
||||
Date int `json:"date"`
|
||||
Text string `json:"text"`
|
||||
ReplyToMessage struct {
|
||||
MessageID int `json:"message_id"`
|
||||
} `json:"reply_to_message"`
|
||||
} `json:"message"`
|
||||
}
|
||||
|
||||
|
@ -128,6 +131,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message
|
|||
FromBot: update.Message.From.IsBot,
|
||||
Date: time.Unix(int64(update.Message.Date), 0),
|
||||
ID: strconv.Itoa(update.Message.MessageID),
|
||||
ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID),
|
||||
Raw: raw,
|
||||
}
|
||||
|
||||
|
@ -259,4 +263,4 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error {
|
|||
|
||||
t.log.Debug("Message sent successfully")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
178
internal/plugin/reminder/reminder.go
Normal file
178
internal/plugin/reminder/reminder.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
package reminder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
||||
)
|
||||
|
||||
// Duration regex patterns to match reminders
|
||||
var (
|
||||
remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`)
|
||||
)
|
||||
|
||||
// ReminderCreator is an interface for creating reminders
|
||||
type ReminderCreator interface {
|
||||
CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error)
|
||||
}
|
||||
|
||||
// Reminder is a plugin that sets reminders for messages
|
||||
type Reminder struct {
|
||||
plugin.BasePlugin
|
||||
creator ReminderCreator
|
||||
}
|
||||
|
||||
// New creates a new Reminder plugin
|
||||
func New(creator ReminderCreator) *Reminder {
|
||||
return &Reminder{
|
||||
BasePlugin: plugin.BasePlugin{
|
||||
ID: "reminder.remindme",
|
||||
Name: "Remind Me",
|
||||
Help: "Reply to a message with `!remindme <duration>` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).",
|
||||
ConfigRequired: false,
|
||||
},
|
||||
creator: creator,
|
||||
}
|
||||
}
|
||||
|
||||
// OnMessage processes incoming messages
|
||||
func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message {
|
||||
// Only process replies to messages
|
||||
if msg.ReplyTo == "" {
|
||||
return []*model.Message{
|
||||
{
|
||||
Text: "Please reply to a message with `!remindme <duration>` to set a reminder.",
|
||||
Chat: msg.Chat,
|
||||
Channel: msg.Channel,
|
||||
ReplyTo: msg.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the message is a reminder command
|
||||
match := remindMePattern.FindStringSubmatch(msg.Text)
|
||||
if match == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the duration
|
||||
amount, err := strconv.Atoi(match[1])
|
||||
if err != nil {
|
||||
return []*model.Message{
|
||||
{
|
||||
Text: "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
||||
Chat: msg.Chat,
|
||||
Channel: msg.Channel,
|
||||
Author: "bot",
|
||||
FromBot: true,
|
||||
Date: time.Now(),
|
||||
ReplyTo: msg.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate the trigger time
|
||||
var duration time.Duration
|
||||
unit := match[2]
|
||||
switch strings.ToLower(unit) {
|
||||
case "y":
|
||||
duration = time.Duration(amount) * 365 * 24 * time.Hour
|
||||
case "mo":
|
||||
duration = time.Duration(amount) * 30 * 24 * time.Hour
|
||||
case "d":
|
||||
duration = time.Duration(amount) * 24 * time.Hour
|
||||
case "h":
|
||||
duration = time.Duration(amount) * time.Hour
|
||||
case "m":
|
||||
duration = time.Duration(amount) * time.Minute
|
||||
case "s":
|
||||
duration = time.Duration(amount) * time.Second
|
||||
default:
|
||||
return []*model.Message{
|
||||
{
|
||||
Text: "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).",
|
||||
Chat: msg.Chat,
|
||||
Channel: msg.Channel,
|
||||
Author: "bot",
|
||||
FromBot: true,
|
||||
Date: time.Now(),
|
||||
ReplyTo: msg.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
triggerAt := time.Now().Add(duration)
|
||||
|
||||
// Determine the username for the reminder
|
||||
username := msg.Author
|
||||
if username == "" {
|
||||
// Try to extract username from message raw data
|
||||
if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok {
|
||||
if name, ok := authorData["username"].(string); ok {
|
||||
username = name
|
||||
} else if name, ok := authorData["name"].(string); ok {
|
||||
username = name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the reminder
|
||||
_, err = r.creator.CreateReminder(
|
||||
msg.Channel.Platform,
|
||||
msg.Chat,
|
||||
msg.ID,
|
||||
msg.ReplyTo,
|
||||
msg.Author,
|
||||
username,
|
||||
"", // No additional content for now
|
||||
triggerAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return []*model.Message{
|
||||
{
|
||||
Text: fmt.Sprintf("Failed to create reminder: %v", err),
|
||||
Chat: msg.Chat,
|
||||
Channel: msg.Channel,
|
||||
Author: "bot",
|
||||
FromBot: true,
|
||||
Date: time.Now(),
|
||||
ReplyTo: msg.ID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format the acknowledgment message
|
||||
var confirmText string
|
||||
switch strings.ToLower(unit) {
|
||||
case "y":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04"))
|
||||
case "mo":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
||||
case "d":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04"))
|
||||
case "h":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04"))
|
||||
case "m":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04"))
|
||||
case "s":
|
||||
confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount)
|
||||
}
|
||||
|
||||
return []*model.Message{
|
||||
{
|
||||
Text: confirmText,
|
||||
Chat: msg.Chat,
|
||||
Channel: msg.Channel,
|
||||
Author: "bot",
|
||||
FromBot: true,
|
||||
Date: time.Now(),
|
||||
ReplyTo: msg.ID,
|
||||
},
|
||||
}
|
||||
}
|
164
internal/plugin/reminder/reminder_test.go
Normal file
164
internal/plugin/reminder/reminder_test.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package reminder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
)
|
||||
|
||||
// MockCreator is a mock implementation of ReminderCreator for testing
|
||||
type MockCreator struct {
|
||||
reminders []*model.Reminder
|
||||
}
|
||||
|
||||
func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) {
|
||||
reminder := &model.Reminder{
|
||||
ID: int64(len(m.reminders) + 1),
|
||||
Platform: platform,
|
||||
ChannelID: channelID,
|
||||
MessageID: messageID,
|
||||
ReplyToID: replyToID,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Content: content,
|
||||
TriggerAt: triggerAt,
|
||||
}
|
||||
m.reminders = append(m.reminders, reminder)
|
||||
return reminder, nil
|
||||
}
|
||||
|
||||
func TestReminderOnMessage(t *testing.T) {
|
||||
creator := &MockCreator{reminders: make([]*model.Reminder, 0)}
|
||||
plugin := New(creator)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *model.Message
|
||||
expectResponse bool
|
||||
expectReminder bool
|
||||
}{
|
||||
{
|
||||
name: "Valid reminder command - years",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 1y",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Valid reminder command - months",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 3mo",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Valid reminder command - days",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 2d",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Valid reminder command - hours",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 5h",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Valid reminder command - minutes",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 30m",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Valid reminder command - seconds",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 60s",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: true,
|
||||
expectReminder: true,
|
||||
},
|
||||
{
|
||||
name: "Not a reply",
|
||||
message: &model.Message{
|
||||
Text: "!remindme 2d",
|
||||
ReplyTo: "",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: false,
|
||||
expectReminder: false,
|
||||
},
|
||||
{
|
||||
name: "Not a reminder command",
|
||||
message: &model.Message{
|
||||
Text: "hello world",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: false,
|
||||
expectReminder: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid duration format",
|
||||
message: &model.Message{
|
||||
Text: "!remindme abc",
|
||||
ReplyTo: "original-message-id",
|
||||
Author: "testuser",
|
||||
Channel: &model.Channel{Platform: "test"},
|
||||
},
|
||||
expectResponse: false,
|
||||
expectReminder: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
initialCount := len(creator.reminders)
|
||||
responses := plugin.OnMessage(tt.message, nil)
|
||||
|
||||
if tt.expectResponse && len(responses) == 0 {
|
||||
t.Errorf("Expected response, but got none")
|
||||
}
|
||||
|
||||
if !tt.expectResponse && len(responses) > 0 {
|
||||
t.Errorf("Expected no response, but got %d", len(responses))
|
||||
}
|
||||
|
||||
if tt.expectReminder && len(creator.reminders) != initialCount+1 {
|
||||
t.Errorf("Expected reminder to be created, but it wasn't")
|
||||
}
|
||||
|
||||
if !tt.expectReminder && len(creator.reminders) != initialCount {
|
||||
t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,6 +3,9 @@ package queue
|
|||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
||||
)
|
||||
|
||||
// Item represents a queue item
|
||||
|
@ -14,14 +17,19 @@ type Item struct {
|
|||
// HandlerFunc defines a function that processes queue items
|
||||
type HandlerFunc func(item Item)
|
||||
|
||||
// ReminderHandlerFunc defines a function that processes reminder items
|
||||
type ReminderHandlerFunc func(reminder *model.Reminder)
|
||||
|
||||
// Queue represents a message queue
|
||||
type Queue struct {
|
||||
items chan Item
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
logger *slog.Logger
|
||||
running bool
|
||||
runMutex sync.Mutex
|
||||
items chan Item
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
logger *slog.Logger
|
||||
running bool
|
||||
runMutex sync.Mutex
|
||||
reminderTicker *time.Ticker
|
||||
reminderHandler ReminderHandlerFunc
|
||||
}
|
||||
|
||||
// New creates a new Queue instance
|
||||
|
@ -49,6 +57,24 @@ func (q *Queue) Start(handler HandlerFunc) {
|
|||
go q.worker(handler)
|
||||
}
|
||||
|
||||
// StartReminderScheduler starts the reminder scheduler
|
||||
func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) {
|
||||
q.runMutex.Lock()
|
||||
defer q.runMutex.Unlock()
|
||||
|
||||
if q.reminderTicker != nil {
|
||||
return
|
||||
}
|
||||
|
||||
q.reminderHandler = handler
|
||||
|
||||
// Check for reminders every minute
|
||||
q.reminderTicker = time.NewTicker(1 * time.Minute)
|
||||
|
||||
q.wg.Add(1)
|
||||
go q.reminderWorker()
|
||||
}
|
||||
|
||||
// Stop stops processing queue items
|
||||
func (q *Queue) Stop() {
|
||||
q.runMutex.Lock()
|
||||
|
@ -59,6 +85,12 @@ func (q *Queue) Stop() {
|
|||
}
|
||||
|
||||
q.running = false
|
||||
|
||||
// Stop reminder ticker if it exists
|
||||
if q.reminderTicker != nil {
|
||||
q.reminderTicker.Stop()
|
||||
}
|
||||
|
||||
close(q.quit)
|
||||
q.wg.Wait()
|
||||
}
|
||||
|
@ -96,4 +128,34 @@ func (q *Queue) worker(handler HandlerFunc) {
|
|||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reminderWorker processes reminder items on a schedule
|
||||
func (q *Queue) reminderWorker() {
|
||||
defer q.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-q.reminderTicker.C:
|
||||
// This is triggered every minute to check for pending reminders
|
||||
q.logger.Debug("Checking for pending reminders")
|
||||
|
||||
if q.reminderHandler != nil {
|
||||
// The handler is responsible for fetching and processing reminders
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
q.logger.Error("Panic in reminder worker", "error", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Call the handler with a nil reminder to indicate it should check the database
|
||||
q.reminderHandler(nil)
|
||||
}()
|
||||
}
|
||||
case <-q.quit:
|
||||
// Quit worker
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue