393 lines
No EOL
9.7 KiB
Go
393 lines
No EOL
9.7 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"runtime/debug"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.nakama.town/fmartingr/butterrobot/internal/admin"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/config"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/db"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/model"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/platform"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/plugin/social"
|
|
"git.nakama.town/fmartingr/butterrobot/internal/queue"
|
|
)
|
|
|
|
// App represents the application
|
|
type App struct {
|
|
config *config.Config
|
|
logger *slog.Logger
|
|
db *db.Database
|
|
router *http.ServeMux
|
|
queue *queue.Queue
|
|
admin *admin.Admin
|
|
version string
|
|
}
|
|
|
|
// New creates a new App instance
|
|
func New(cfg *config.Config, logger *slog.Logger) (*App, error) {
|
|
// Initialize router
|
|
router := http.NewServeMux()
|
|
|
|
// Initialize database
|
|
database, err := db.New(cfg.DatabasePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
|
}
|
|
|
|
// Initialize message queue
|
|
messageQueue := queue.New(logger)
|
|
|
|
// Get version information
|
|
version := ""
|
|
info, ok := debug.ReadBuildInfo()
|
|
if ok {
|
|
version = info.Main.Version
|
|
}
|
|
|
|
// Initialize admin interface
|
|
adminInterface := admin.New(cfg, database, version)
|
|
|
|
return &App{
|
|
config: cfg,
|
|
logger: logger,
|
|
db: database,
|
|
router: router,
|
|
queue: messageQueue,
|
|
admin: adminInterface,
|
|
version: version,
|
|
}, nil
|
|
}
|
|
|
|
// Run starts the application
|
|
func (a *App) Run() error {
|
|
// Initialize platforms
|
|
if err := platform.InitializePlatforms(a.config); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register built-in plugins
|
|
plugin.Register(ping.New())
|
|
plugin.Register(fun.NewCoin())
|
|
plugin.Register(fun.NewDice())
|
|
plugin.Register(fun.NewLoquito())
|
|
plugin.Register(social.NewTwitterExpander())
|
|
plugin.Register(social.NewInstagramExpander())
|
|
|
|
// Register reminder plugin
|
|
reminderPlugin := reminder.New(a.db)
|
|
plugin.Register(reminderPlugin)
|
|
|
|
// Initialize routes
|
|
a.initializeRoutes()
|
|
|
|
// Start message queue worker
|
|
a.queue.Start(a.handleMessage)
|
|
|
|
// Start reminder scheduler
|
|
a.queue.StartReminderScheduler(a.handleReminder)
|
|
|
|
// Create server
|
|
addr := fmt.Sprintf(":%s", a.config.Port)
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: a.router,
|
|
}
|
|
|
|
// Start server in a goroutine
|
|
go func() {
|
|
a.logger.Info("Server starting on", "addr", addr)
|
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
a.logger.Error("Server error", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
}()
|
|
|
|
// Wait for interrupt signal
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
<-quit
|
|
|
|
a.logger.Info("Shutting down server...")
|
|
|
|
// Create shutdown context with timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Shutdown server
|
|
if err := srv.Shutdown(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Stop message queue
|
|
a.queue.Stop()
|
|
|
|
// Close database connection
|
|
if err := a.db.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
a.logger.Info("Server stopped")
|
|
|
|
return nil
|
|
}
|
|
|
|
// Initialize HTTP routes
|
|
func (a *App) initializeRoutes() {
|
|
// Health check endpoint
|
|
a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil {
|
|
a.logger.Error("Error encoding response", "error", err)
|
|
}
|
|
})
|
|
|
|
// Platform webhook endpoints
|
|
for name := range platform.GetAvailablePlatforms() {
|
|
a.logger.Info("Registering webhook endpoint for platform", "platform", name)
|
|
platformName := name // Create a copy to avoid closure issues
|
|
a.router.HandleFunc("/"+platformName+"/incoming/", a.handleIncomingWebhook)
|
|
}
|
|
|
|
// Register admin routes
|
|
a.admin.RegisterRoutes(a.router)
|
|
}
|
|
|
|
// Handle incoming webhook
|
|
func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) {
|
|
// Extract platform name from path
|
|
platformName := extractPlatformName(r.URL.Path)
|
|
|
|
// Check if platform exists
|
|
if _, err := platform.Get(platformName); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil {
|
|
a.logger.Error("Error encoding response", "error", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Read request body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil {
|
|
a.logger.Error("Error encoding response", "error", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Queue message for processing
|
|
a.queue.Add(queue.Item{
|
|
Platform: platformName,
|
|
Request: map[string]any{
|
|
"path": r.URL.Path,
|
|
"json": json.RawMessage(body),
|
|
},
|
|
})
|
|
|
|
// Respond with success
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil {
|
|
a.logger.Error("Error encoding response", "error", err)
|
|
}
|
|
}
|
|
|
|
// extractPlatformName returns the first segment of a URL path. For webhook
// routes of the form "/<platform>/incoming/..." this is the platform name.
// At most one leading slash is ignored. An empty path, or one that starts
// with an empty segment, yields "".
//
// Examples:
//
//	"/telegram/incoming/<token>" -> "telegram"
//	"/slack/incoming/"           -> "slack"
//	""                           -> ""
func extractPlatformName(path string) string {
	name, _, _ := strings.Cut(strings.TrimPrefix(path, "/"), "/")
	return name
}
|
|
|
|
// Handle message processing
|
|
func (a *App) handleMessage(item queue.Item) {
|
|
// Get platform
|
|
p, err := platform.Get(item.Platform)
|
|
if err != nil {
|
|
a.logger.Error("Error getting platform", "error", err)
|
|
return
|
|
}
|
|
|
|
// Create a new request with the body
|
|
bodyJSON, ok := item.Request["json"].(json.RawMessage)
|
|
if !ok {
|
|
a.logger.Error("Invalid JSON in request")
|
|
return
|
|
}
|
|
|
|
reqPath, ok := item.Request["path"].(string)
|
|
if !ok {
|
|
a.logger.Error("Invalid path in request")
|
|
return
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", reqPath, strings.NewReader(string(bodyJSON)))
|
|
if err != nil {
|
|
a.logger.Error("Error creating request", "error", err)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// Parse message
|
|
message, err := p.ParseIncomingMessage(req)
|
|
if err != nil {
|
|
a.logger.Error("Error parsing message", "error", err)
|
|
return
|
|
}
|
|
|
|
// Skip if message is from a bot
|
|
if message == nil || message.FromBot {
|
|
return
|
|
}
|
|
|
|
// Get or create channel
|
|
channel, err := a.db.GetChannelByPlatform(item.Platform, message.Chat)
|
|
if err == db.ErrNotFound {
|
|
channel, err = a.db.CreateChannel(item.Platform, message.Chat, false, message.Channel.ChannelRaw)
|
|
if err != nil {
|
|
a.logger.Error("Error creating channel", "error", err)
|
|
return
|
|
}
|
|
} else if err != nil {
|
|
a.logger.Error("Error getting channel", "error", err)
|
|
return
|
|
}
|
|
|
|
// Skip if channel is disabled
|
|
if !channel.Enabled {
|
|
return
|
|
}
|
|
|
|
// Process message with plugins
|
|
for pluginID, channelPlugin := range channel.Plugins {
|
|
if !channel.HasEnabledPlugin(pluginID) {
|
|
continue
|
|
}
|
|
|
|
// Get plugin
|
|
p, err := plugin.Get(pluginID)
|
|
if err != nil {
|
|
a.logger.Error("Error getting plugin", "error", err)
|
|
continue
|
|
}
|
|
|
|
// Process message
|
|
responses := p.OnMessage(message, channelPlugin.Config)
|
|
|
|
// Send responses
|
|
platform, err := platform.Get(item.Platform)
|
|
if err != nil {
|
|
a.logger.Error("Error getting platform", "error", err)
|
|
continue
|
|
}
|
|
|
|
for _, response := range responses {
|
|
if err := platform.SendMessage(response); err != nil {
|
|
a.logger.Error("Error sending message", "error", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleReminder handles reminder processing
|
|
func (a *App) handleReminder(reminder *model.Reminder) {
|
|
// When called with nil, it means we should check for pending reminders
|
|
if reminder == nil {
|
|
// Get pending reminders
|
|
reminders, err := a.db.GetPendingReminders()
|
|
if err != nil {
|
|
a.logger.Error("Error getting pending reminders", "error", err)
|
|
return
|
|
}
|
|
|
|
// Process each reminder
|
|
for _, r := range reminders {
|
|
a.processReminder(r)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Otherwise, process the specific reminder
|
|
a.processReminder(reminder)
|
|
}
|
|
|
|
// processReminder processes an individual reminder
|
|
func (a *App) processReminder(reminder *model.Reminder) {
|
|
a.logger.Info("Processing reminder",
|
|
"id", reminder.ID,
|
|
"platform", reminder.Platform,
|
|
"channel", reminder.ChannelID,
|
|
"trigger_at", reminder.TriggerAt,
|
|
)
|
|
|
|
// Get the platform handler
|
|
p, err := platform.Get(reminder.Platform)
|
|
if err != nil {
|
|
a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform)
|
|
return
|
|
}
|
|
|
|
// Get the channel
|
|
channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID)
|
|
if err != nil {
|
|
a.logger.Error("Error getting channel for reminder", "error", err)
|
|
return
|
|
}
|
|
|
|
// Create the reminder message
|
|
reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username)
|
|
|
|
message := &model.Message{
|
|
Text: reminderText,
|
|
Chat: reminder.ChannelID,
|
|
Channel: channel,
|
|
Author: "bot",
|
|
FromBot: true,
|
|
Date: time.Now(),
|
|
ReplyTo: reminder.ReplyToID, // Reply to the original message
|
|
}
|
|
|
|
// Send the reminder message
|
|
if err := p.SendMessage(message); err != nil {
|
|
a.logger.Error("Error sending reminder", "error", err)
|
|
return
|
|
}
|
|
|
|
// Mark the reminder as processed
|
|
if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil {
|
|
a.logger.Error("Error marking reminder as processed", "error", err)
|
|
}
|
|
} |