Compare commits
	
		
			40 commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 248c42d609 | |||
| 5bec3b6a7c | |||
| 377b1723c3 | |||
| 60ceaffd82 | |||
| 3a5b5c216d | |||
| bdc797d5c1 | |||
| 0edf41c792 | |||
| 35c14ce8a8 | |||
| e0ff369cff | |||
| 368c45cd13 | |||
| 3b09a9dd47 | |||
| 899ac49336 | |||
| fc77c97547 | |||
| 3a4ba376e7 | |||
| bd9854676d | |||
| 4fc5ae63a1 | |||
| 3771d2de65 | |||
| c7fdb9fc6a | |||
| 1f80a22f4a | |||
| 1e0bc86b21 | |||
| 8fa74fd046 | |||
| d09b763aa7 | |||
| c53942ac53 | |||
| a9b4ad52cb | |||
| 4a154f16f9 | |||
| 8d188217e9 | |||
| fae6f35774 | |||
| 7dd02c0056 | |||
| c9edb57505 | |||
| 763a451251 | |||
| abcd3c3c44 | |||
| 323ea4e8cd | |||
| 72c6dd6982 | |||
| 21e4c434fd | |||
| a0f12efd65 | |||
| c920eb94a0 | |||
| e0ae0c2a0b | |||
| 6aedfc794f | |||
| ece8280358 | |||
| 84e5feeb81 | 
					 56 changed files with 5726 additions and 219 deletions
				
			
		
							
								
								
									
										5
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							|  | @ -5,9 +5,12 @@ __pycache__ | ||||||
| *.cert | *.cert | ||||||
| .env-local | .env-local | ||||||
| .coverage | .coverage | ||||||
|  | coverage.out | ||||||
| 
 | 
 | ||||||
| dist | dist | ||||||
| bin | bin | ||||||
| # Butterrobot | # Butterrobot | ||||||
| *.sqlite* | *.sqlite* | ||||||
| butterrobot.db | butterrobot.db* | ||||||
|  | /butterrobot | ||||||
|  | *_test.db* | ||||||
|  |  | ||||||
|  | @ -93,7 +93,7 @@ docker_manifests: | ||||||
| 
 | 
 | ||||||
| nfpms: | nfpms: | ||||||
|   - maintainer: Felipe Martin <me@fmartingr.com> |   - maintainer: Felipe Martin <me@fmartingr.com> | ||||||
|     description: SMTP server to forward messages to shoutrrr endpoints |     description: A chatbot server with customizable commands and triggers | ||||||
|     homepage: https://git.nakama.town/fmartingr/butterrobot |     homepage: https://git.nakama.town/fmartingr/butterrobot | ||||||
|     license: AGPL-3.0 |     license: AGPL-3.0 | ||||||
|     formats: |     formats: | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ when: | ||||||
|     - push |     - push | ||||||
|     - pull_request |     - pull_request | ||||||
|   branch: |   branch: | ||||||
|     - main |     - master | ||||||
| 
 | 
 | ||||||
| steps: | steps: | ||||||
|   format: |   format: | ||||||
|  |  | ||||||
|  | @ -1,6 +1,6 @@ | ||||||
| when: | when: | ||||||
|   - event: tag |   - event: tag | ||||||
|     branch: main |     branch: master | ||||||
| 
 | 
 | ||||||
| steps: | steps: | ||||||
|   - name: Release |   - name: Release | ||||||
|  | @ -13,4 +13,4 @@ steps: | ||||||
|       - "/var/run/docker.sock:/var/run/docker.sock" |       - "/var/run/docker.sock:/var/run/docker.sock" | ||||||
|     commands: |     commands: | ||||||
|       - docker login -u fmartingr -p $GITEA_TOKEN git.nakama.town |       - docker login -u fmartingr -p $GITEA_TOKEN git.nakama.town | ||||||
|       - goreleaser release --clean |       - goreleaser release --clean  --parallelism=2 | ||||||
|  |  | ||||||
							
								
								
									
										29
									
								
								CLAUDE.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								CLAUDE.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | ||||||
|  | # Claude Code Instructions | ||||||
|  | 
 | ||||||
|  | ## Plugin Development Workflow | ||||||
|  | 
 | ||||||
|  | When creating, modifying, or removing plugins: | ||||||
|  | 
 | ||||||
|  | 1. **Always update the plugin documentation** in `docs/plugins.md` after any plugin changes | ||||||
|  | 2. Ensure the documentation includes: | ||||||
|  |    - Plugin name and category (Development, Fun and entertainment, Utility, Security, Social Media) | ||||||
|  |    - Brief description of functionality | ||||||
|  |    - Usage instructions with examples | ||||||
|  |    - Any configuration requirements | ||||||
|  | 3. **For plugins with configuration options:** | ||||||
|  |    - Set `ConfigRequired: true` in the plugin's BasePlugin struct | ||||||
|  |    - Add corresponding HTML form fields in `internal/admin/templates/channel_plugin_config.html` | ||||||
|  |    - Use conditional template logic: `{{else if eq .ChannelPlugin.PluginID "plugin.id"}}` | ||||||
|  |    - Include proper form labels, help text, and value binding | ||||||
|  | 
 | ||||||
|  | ## Testing | ||||||
|  | 
 | ||||||
|  | **CRITICAL**: After making ANY changes to code files, you MUST run these commands in order: | ||||||
|  | 
 | ||||||
|  | 1. **Format code**: `make format` - Format all code according to project standards | ||||||
|  | 2. **Lint code**: `make lint` - Check code style and quality (must show "0 issues") | ||||||
|  | 3. **Run tests**: `make test` - Run all tests to ensure functionality works | ||||||
|  | 4. Verify documentation accuracy | ||||||
|  | 5. Ensure all examples work as described | ||||||
|  | 
 | ||||||
|  | **These commands are MANDATORY after every code change, no exceptions.** | ||||||
							
								
								
									
										13
									
								
								README.md
									
										
									
									
									
								
							
							
						
						
									
										13
									
								
								README.md
									
										
									
									
									
								
							|  | @ -1,9 +1,6 @@ | ||||||
| # Butter Robot | # Butter Robot | ||||||
| 
 | 
 | ||||||
| | Stable | Master | |  | ||||||
| | --- | --- | |  | ||||||
| |  |  | |  | ||||||
| |  |  | |  | ||||||
| 
 | 
 | ||||||
| Go framework to create bots for several platforms. | Go framework to create bots for several platforms. | ||||||
| 
 | 
 | ||||||
|  | @ -13,7 +10,7 @@ Go framework to create bots for several platforms. | ||||||
| 
 | 
 | ||||||
| ## Features | ## Features | ||||||
| 
 | 
 | ||||||
| - Support for multiple chat platforms (Slack, Telegram) | - Support for multiple chat platforms (Slack (untested!), Telegram) | ||||||
| - Plugin system for easy extension | - Plugin system for easy extension | ||||||
| - Admin interface for managing channels and plugins | - Admin interface for managing channels and plugins | ||||||
| - Message queue for asynchronous processing | - Message queue for asynchronous processing | ||||||
|  | @ -22,6 +19,12 @@ Go framework to create bots for several platforms. | ||||||
| 
 | 
 | ||||||
| [Go to documentation](./docs) | [Go to documentation](./docs) | ||||||
| 
 | 
 | ||||||
|  | ### Database Management | ||||||
|  | 
 | ||||||
|  | ButterRobot includes an automatic database migration system. Migrations are applied automatically when the application starts, ensuring your database schema is always up to date. | ||||||
|  | 
 | ||||||
|  | [Learn more about migrations](./docs/migrations.md) | ||||||
|  | 
 | ||||||
| ## Installation | ## Installation | ||||||
| 
 | 
 | ||||||
| ### From Source | ### From Source | ||||||
|  |  | ||||||
|  | @ -1,11 +1,15 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"log/slog" | 	"log/slog" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"runtime/debug" | ||||||
| 
 | 
 | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/app" | 	"git.nakama.town/fmartingr/butterrobot/internal/app" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/config" | 	"git.nakama.town/fmartingr/butterrobot/internal/config" | ||||||
|  | 
 | ||||||
|  | 	_ "golang.org/x/crypto/x509roots/fallback" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func main() { | func main() { | ||||||
|  | @ -19,15 +23,26 @@ func main() { | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Handle version command | ||||||
|  | 	if len(os.Args) > 1 && os.Args[1] == "version" { | ||||||
|  | 		info, ok := debug.ReadBuildInfo() | ||||||
|  | 		if ok { | ||||||
|  | 			fmt.Printf("ButterRobot version %s\n", info.Main.Version) | ||||||
|  | 		} else { | ||||||
|  | 			fmt.Println("ButterRobot. Can't determine build information.") | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Initialize and run application | 	// Initialize and run application | ||||||
| 	application, err := app.New(cfg, logger) | 	application, err := app.New(cfg, logger) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error("Failed to initialize application", "error", err) | 		logger.Error("Failed to initialize application", "error", err) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	} | 	} | ||||||
| 	 | 
 | ||||||
| 	if err := application.Run(); err != nil { | 	if err := application.Run(); err != nil { | ||||||
| 		logger.Error("Application error", "error", err) | 		logger.Error("Application error", "error", err) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,6 +1,19 @@ | ||||||
| # Creating a Plugin | # Creating a Plugin | ||||||
| 
 | 
 | ||||||
| ## Example | ## Plugin Categories | ||||||
|  | 
 | ||||||
|  | ButterRobot organizes plugins into different categories: | ||||||
|  | 
 | ||||||
|  | - **Development**: Utility plugins like `ping` | ||||||
|  | - **Fun**: Entertainment plugins like dice rolling, coin flipping | ||||||
|  | - **Social**: Social media related plugins like URL transformers/expanders | ||||||
|  | - **Security**: Moderation and protection features like domain blocking | ||||||
|  | 
 | ||||||
|  | When creating a new plugin, consider which category it fits into and place it in the appropriate directory. | ||||||
|  | 
 | ||||||
|  | ## Plugin Examples | ||||||
|  | 
 | ||||||
|  | ### Basic Example: Marco Polo | ||||||
| 
 | 
 | ||||||
| This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_: | This simple "Marco Polo" plugin will answer _Polo_ to the user that says _Marco_: | ||||||
| 
 | 
 | ||||||
|  | @ -47,6 +60,207 @@ func (p *MarcoPlugin) OnMessage(msg *model.Message, config map[string]interface{ | ||||||
| } | } | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
|  | ### Configuration-Enabled Plugin | ||||||
|  | 
 | ||||||
|  | This plugin requires configuration to be set in the admin interface. It demonstrates how to create plugins that need channel-specific configuration: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | package security | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // DomainBlockPlugin is a plugin that blocks messages containing links from specific domains | ||||||
|  | type DomainBlockPlugin struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new DomainBlockPlugin instance | ||||||
|  | func New() *DomainBlockPlugin { | ||||||
|  | 	return &DomainBlockPlugin{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:             "security.domainblock", | ||||||
|  | 			Name:           "Domain Blocker", | ||||||
|  | 			Help:           "Blocks messages containing links from configured domains", | ||||||
|  | 			ConfigRequired: true, // Mark this plugin as requiring configuration | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage processes incoming messages | ||||||
|  | func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | ||||||
|  | 	// Get blocked domains from config | ||||||
|  | 	blockedDomainsStr, ok := config["blocked_domains"].(string) | ||||||
|  | 	if !ok || blockedDomainsStr == "" { | ||||||
|  | 		return nil // No blocked domains configured | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Split and clean blocked domains | ||||||
|  | 	blockedDomains := strings.Split(blockedDomainsStr, ",") | ||||||
|  | 	for i, domain := range blockedDomains { | ||||||
|  | 		blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract domains from message | ||||||
|  | 	urlRegex := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) | ||||||
|  | 	matches := urlRegex.FindAllStringSubmatch(msg.Text, -1) | ||||||
|  | 	 | ||||||
|  | 	// Check if any extracted domains are blocked | ||||||
|  | 	for _, match := range matches { | ||||||
|  | 		if len(match) < 2 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		 | ||||||
|  | 		domain := strings.ToLower(match[1]) | ||||||
|  | 		 | ||||||
|  | 		for _, blockedDomain := range blockedDomains { | ||||||
|  | 			if blockedDomain == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			 | ||||||
|  | 			if strings.HasSuffix(domain, blockedDomain) || domain == blockedDomain { | ||||||
|  | 				// Domain is blocked, create warning message | ||||||
|  | 				response := &model.Message{ | ||||||
|  | 					Text:    fmt.Sprintf("⚠️ Message contained a link to blocked domain: %s", blockedDomain), | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					ReplyTo: msg.ID, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 				} | ||||||
|  | 				return []*model.Message{response} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	plugin.Register(New()) | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ### Advanced Example: URL Transformer | ||||||
|  | 
 | ||||||
|  | This more complex plugin transforms URLs, useful for improving media embedding in chat platforms: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | package social | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/url" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // TwitterExpander transforms twitter.com links to fxtwitter.com links | ||||||
|  | type TwitterExpander struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new TwitterExpander instance | ||||||
|  | func NewTwitter() *TwitterExpander { | ||||||
|  | 	return &TwitterExpander{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:   "social.twitter", | ||||||
|  | 			Name: "Twitter Link Expander", | ||||||
|  | 			Help: "Automatically converts twitter.com links to fxtwitter.com links and removes tracking parameters", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | ||||||
|  | 	// Skip empty messages | ||||||
|  | 	if strings.TrimSpace(msg.Text) == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Regex to match twitter.com links | ||||||
|  | 	twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`) | ||||||
|  | 
 | ||||||
|  | 	// Check if the message contains a Twitter link | ||||||
|  | 	if !twitterRegex.MatchString(msg.Text) { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Transform the URL | ||||||
|  | 	transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { | ||||||
|  | 		// Parse the URL | ||||||
|  | 		parsedURL, err := url.Parse(link) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// If parsing fails, just do the simple replacement | ||||||
|  | 			link = strings.Replace(link, "twitter.com", "fxtwitter.com", 1) | ||||||
|  | 			link = strings.Replace(link, "x.com", "fxtwitter.com", 1) | ||||||
|  | 			return link | ||||||
|  | 		} | ||||||
|  | 		 | ||||||
|  | 		// Change the host | ||||||
|  | 		if strings.Contains(parsedURL.Host, "twitter.com") { | ||||||
|  | 			parsedURL.Host = strings.Replace(parsedURL.Host, "twitter.com", "fxtwitter.com", 1) | ||||||
|  | 		} else if strings.Contains(parsedURL.Host, "x.com") { | ||||||
|  | 			parsedURL.Host = strings.Replace(parsedURL.Host, "x.com", "fxtwitter.com", 1) | ||||||
|  | 		} | ||||||
|  | 		 | ||||||
|  | 		// Remove query parameters | ||||||
|  | 		parsedURL.RawQuery = "" | ||||||
|  | 		 | ||||||
|  | 		// Return the cleaned URL | ||||||
|  | 		return parsedURL.String() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Create response message | ||||||
|  | 	response := &model.Message{ | ||||||
|  | 		Text:    transformed, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.Message{response} | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ## Enabling Configuration for Plugins | ||||||
|  | 
 | ||||||
|  | To indicate that your plugin requires configuration: | ||||||
|  | 
 | ||||||
|  | 1. Set `ConfigRequired: true` in the BasePlugin struct: | ||||||
|  |    ```go | ||||||
|  |    BasePlugin: plugin.BasePlugin{ | ||||||
|  |        ID:             "myplugin.id", | ||||||
|  |        Name:           "Plugin Name", | ||||||
|  |        Help:           "Help text", | ||||||
|  |        ConfigRequired: true, | ||||||
|  |    }, | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | 2. Access the configuration in the OnMessage method: | ||||||
|  |    ```go | ||||||
|  |    func (p *MyPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | ||||||
|  |        // Extract configuration values | ||||||
|  |        configValue, ok := config["some_config_key"].(string) | ||||||
|  |        if !ok || configValue == "" { | ||||||
|  |            // Handle missing or empty configuration | ||||||
|  |            return nil | ||||||
|  |        } | ||||||
|  |         | ||||||
|  |        // Use the configuration... | ||||||
|  |    } | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | 3. The admin interface will show a "Configure" button for plugins that require configuration. | ||||||
|  | 
 | ||||||
|  | ## Registering Plugins | ||||||
|  | 
 | ||||||
| To use the plugin, register it in your application: | To use the plugin, register it in your application: | ||||||
| 
 | 
 | ||||||
| ```go | ```go | ||||||
|  | @ -55,8 +269,19 @@ func (a *App) Run() error { | ||||||
|     // ... |     // ... | ||||||
| 
 | 
 | ||||||
|     // Register plugins |     // Register plugins | ||||||
|     plugin.Register(myplugin.New()) |     plugin.Register(ping.New())                 // Development plugin | ||||||
|  |     plugin.Register(fun.NewCoin())              // Fun plugin   | ||||||
|  |     plugin.Register(social.NewTwitter())        // Social media plugin | ||||||
|  |     plugin.Register(myplugin.New())             // Your custom plugin | ||||||
| 
 | 
 | ||||||
|     // ... |     // ... | ||||||
| } | } | ||||||
| ``` | ``` | ||||||
|  | 
 | ||||||
|  | Alternatively, you can register your plugin in its init() function: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | func init() { | ||||||
|  |     plugin.Register(New()) | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
							
								
								
									
										99
									
								
								docs/migrations.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								docs/migrations.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,99 @@ | ||||||
|  | # Database Migrations | ||||||
|  | 
 | ||||||
|  | ButterRobot uses a simple database migration system to manage database schema changes. This document explains how the migration system works and how to extend it. | ||||||
|  | 
 | ||||||
|  | ## Automatic Migrations | ||||||
|  | 
 | ||||||
|  | Migrations in ButterRobot are applied automatically when the application starts. This ensures your database schema is always up to date without requiring manual intervention. | ||||||
|  | 
 | ||||||
|  | The migration system: | ||||||
|  | 1. Checks which migrations have been applied | ||||||
|  | 2. Applies any pending migrations in sequential order | ||||||
|  | 3. Records each successful migration in the `schema_migrations` table | ||||||
|  | 
 | ||||||
|  | ## Initial State | ||||||
|  | 
 | ||||||
|  | The initial migration (version 1) sets up the database with the following: | ||||||
|  | 
 | ||||||
|  | - `channels` table for chat platforms | ||||||
|  | - `channel_plugin` table for plugins associated with channels | ||||||
|  | - `users` table for admin users with bcrypt password hashing | ||||||
|  | - Default admin user with username "admin" and password "admin" | ||||||
|  | 
 | ||||||
|  | This migration represents the current state of the database schema. It is not backwards compatible with previous versions of ButterRobot. | ||||||
|  | 
 | ||||||
|  | ## Creating New Migrations | ||||||
|  | 
 | ||||||
|  | To add a new migration, follow these steps: | ||||||
|  | 
 | ||||||
|  | 1. Open `/internal/migration/migrations.go` | ||||||
|  | 2. Add a new migration version in the `init()` function: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | Register(2, "Add example table", migrateAddExampleTableUp, migrateAddExampleTableDown) | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | 3. Implement the up and down functions for your migration: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | // Migration to add example table - version 2 | ||||||
|  | func migrateAddExampleTableUp(db *sql.DB) error { | ||||||
|  |     _, err := db.Exec(` | ||||||
|  |         CREATE TABLE IF NOT EXISTS example ( | ||||||
|  |             id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  |             name TEXT NOT NULL, | ||||||
|  |             created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP | ||||||
|  |         ) | ||||||
|  |     `) | ||||||
|  |     return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func migrateAddExampleTableDown(db *sql.DB) error { | ||||||
|  |     _, err := db.Exec(`DROP TABLE IF EXISTS example`) | ||||||
|  |     return err | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ## Migration Guidelines | ||||||
|  | 
 | ||||||
|  | 1. **Incremental Changes**: Each migration should make a small, focused change to the database schema. | ||||||
|  | 2. **Backward Compatibility**: Ensure migrations are backward compatible with existing code when possible. | ||||||
|  | 3. **Test Thoroughly**: Test both up and down migrations before deploying. | ||||||
|  | 4. **Document Changes**: Add comments explaining the purpose of each migration. | ||||||
|  | 5. **Version Numbers**: Use sequential version numbers for migrations. | ||||||
|  | 
 | ||||||
|  | ## How Migrations Work | ||||||
|  | 
 | ||||||
|  | The migration system tracks applied migrations in a `schema_migrations` table. When you run migrations, the system: | ||||||
|  | 
 | ||||||
|  | 1. Checks which migrations have been applied | ||||||
|  | 2. Applies any pending migrations in order | ||||||
|  | 3. Records each successful migration in the `schema_migrations` table | ||||||
|  | 
 | ||||||
|  | When rolling back, it performs the down migrations in reverse order. | ||||||
|  | 
 | ||||||
|  | ## In Code Usage | ||||||
|  | 
 | ||||||
|  | The application automatically runs pending migrations when starting up. This is done in the `initDatabase` function. | ||||||
|  | 
 | ||||||
|  | You can also programmatically work with migrations: | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | // Get database instance | ||||||
|  | database, err := db.New(cfg.DatabasePath) | ||||||
|  | if err != nil { | ||||||
|  |     // Handle error | ||||||
|  | } | ||||||
|  | defer database.Close() | ||||||
|  | 
 | ||||||
|  | // Run migrations | ||||||
|  | if err := database.MigrateUp(); err != nil { | ||||||
|  |     // Handle error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Check migration status | ||||||
|  | applied, pending, err := database.MigrationStatus() | ||||||
|  | if err != nil { | ||||||
|  |     // Handle error | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  | @ -9,3 +9,19 @@ | ||||||
| - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) | - Lo quito: What happens when you say _"lo quito"_...? (Spanish pun) | ||||||
| - Dice: Put `!dice` and wathever roll you want to perform. | - Dice: Put `!dice` and wathever roll you want to perform. | ||||||
| - Coin: Flip a coin and get heads or tails. | - Coin: Flip a coin and get heads or tails. | ||||||
|  | - How Long To Beat: Get game completion times from HowLongToBeat.com using `!hltb <game name>` | ||||||
|  | 
 | ||||||
|  | ### Utility | ||||||
|  | 
 | ||||||
|  | - Help: Shows available commands when you type `!help`. Lists all enabled plugins for the current channel organized by category with their descriptions and usage instructions. | ||||||
|  | - Remind Me: Reply to a message with `!remindme <duration>` to set a reminder. Supported duration units: y (years), mo (months), d (days), h (hours), m (minutes), s (seconds). Examples: `!remindme 1y` for 1 year, `!remindme 3mo` for 3 months, `!remindme 2d` for 2 days, `!remindme 3h` for 3 hours. The bot will mention you with a reminder after the specified time. | ||||||
|  | - Search and Replace: Reply to any message with `s/search/replace/[flags]` to perform text substitution. Supports flags: `g` (global), `i` (case insensitive), `n` (regex pattern). Example: `s/hello/hi/gi` replaces all occurrences of "hello" with "hi" case-insensitively. | ||||||
|  | 
 | ||||||
|  | ### Security | ||||||
|  | 
 | ||||||
|  | - Domain Blocker: Blocks messages containing links from specified domains. Configure it per channel with a comma-separated list of domains to block. When a message contains a link matching any of the blocked domains, the bot will notify that the message contained a blocked domain. This plugin requires configuration through the admin interface. | ||||||
|  | 
 | ||||||
|  | ### Social Media | ||||||
|  | 
 | ||||||
|  | - Twitter Link Expander: Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: fxtwitter.com). | ||||||
|  | - Instagram Link Expander: Automatically converts instagram.com links to alternative domain links and removes tracking parameters. This allows for better media embedding in chat platforms. Configure with `domain` option to set replacement domain (default: ddinstagram.com). | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -4,6 +4,9 @@ go 1.24 | ||||||
| 
 | 
 | ||||||
| require ( | require ( | ||||||
| 	github.com/gorilla/sessions v1.4.0 | 	github.com/gorilla/sessions v1.4.0 | ||||||
|  | 	golang.org/x/crypto v0.37.0 | ||||||
|  | 	golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df | ||||||
|  | 	golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 | ||||||
| 	modernc.org/sqlite v1.37.0 | 	modernc.org/sqlite v1.37.0 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -14,7 +17,6 @@ require ( | ||||||
| 	github.com/mattn/go-isatty v0.0.20 // indirect | 	github.com/mattn/go-isatty v0.0.20 // indirect | ||||||
| 	github.com/ncruces/go-strftime v0.1.9 // indirect | 	github.com/ncruces/go-strftime v0.1.9 // indirect | ||||||
| 	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect | 	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect | ||||||
| 	golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect |  | ||||||
| 	golang.org/x/sys v0.32.0 // indirect | 	golang.org/x/sys v0.32.0 // indirect | ||||||
| 	modernc.org/libc v1.63.0 // indirect | 	modernc.org/libc v1.63.0 // indirect | ||||||
| 	modernc.org/mathutil v1.7.1 // indirect | 	modernc.org/mathutil v1.7.1 // indirect | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								go.sum
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
										
									
									
									
								
							|  | @ -16,6 +16,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh | ||||||
| github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= | github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= | ||||||
| github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= | ||||||
| github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | ||||||
|  | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= | ||||||
|  | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= | ||||||
|  | golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df h1:SwgTucX8ajPE0La2ELpYOIs8jVMoCMpAvYB6mDqP9vk= | ||||||
|  | golang.org/x/crypto/x509roots/fallback v0.0.0-20250418111936-9c1aa6af88df/go.mod h1:lxN5T34bK4Z/i6cMaU7frUU57VkDXFD4Kamfl/cp9oU= | ||||||
| golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= | golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= | ||||||
| golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= | golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= | ||||||
| golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= | ||||||
|  |  | ||||||
|  | @ -2,6 +2,8 @@ package admin | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
|  | 	"encoding/gob" | ||||||
|  | 	"fmt" | ||||||
| 	"html/template" | 	"html/template" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | @ -14,7 +16,7 @@ import ( | ||||||
| 	"github.com/gorilla/sessions" | 	"github.com/gorilla/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| //go:embed templates/*.html | //go:embed templates/*.html templates/plugins/*.html | ||||||
| var templateFS embed.FS | var templateFS embed.FS | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -28,6 +30,11 @@ type FlashMessage struct { | ||||||
| 	Message  string | 	Message  string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func init() { | ||||||
|  | 	// Register the FlashMessage type with gob package for session serialization | ||||||
|  | 	gob.Register(FlashMessage{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // TemplateData holds data for rendering templates | // TemplateData holds data for rendering templates | ||||||
| type TemplateData struct { | type TemplateData struct { | ||||||
| 	User          *model.User | 	User          *model.User | ||||||
|  | @ -39,6 +46,7 @@ type TemplateData struct { | ||||||
| 	Channels      []*model.Channel | 	Channels      []*model.Channel | ||||||
| 	Channel       *model.Channel | 	Channel       *model.Channel | ||||||
| 	ChannelPlugin *model.ChannelPlugin | 	ChannelPlugin *model.ChannelPlugin | ||||||
|  | 	Version       string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Admin represents the admin interface | // Admin represents the admin interface | ||||||
|  | @ -48,12 +56,18 @@ type Admin struct { | ||||||
| 	store        *sessions.CookieStore | 	store        *sessions.CookieStore | ||||||
| 	templates    map[string]*template.Template | 	templates    map[string]*template.Template | ||||||
| 	baseTemplate *template.Template | 	baseTemplate *template.Template | ||||||
|  | 	version      string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New creates a new Admin instance | // New creates a new Admin instance | ||||||
| func New(cfg *config.Config, database *db.Database) *Admin { | func New(cfg *config.Config, database *db.Database, version string) *Admin { | ||||||
| 	// Create session store | 	// Create session store with appropriate options | ||||||
| 	store := sessions.NewCookieStore([]byte(cfg.SecretKey)) | 	store := sessions.NewCookieStore([]byte(cfg.SecretKey)) | ||||||
|  | 	store.Options = &sessions.Options{ | ||||||
|  | 		Path:     "/admin", | ||||||
|  | 		MaxAge:   3600 * 24 * 7, // 1 week | ||||||
|  | 		HttpOnly: true, | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Load templates | 	// Load templates | ||||||
| 	templates := make(map[string]*template.Template) | 	templates := make(map[string]*template.Template) | ||||||
|  | @ -76,34 +90,56 @@ func New(cfg *config.Config, database *db.Database) *Admin { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Parse and register all templates | 	// Parse and register all templates | ||||||
| 	templateFiles := []string{ | 	mainTemplateFiles := []string{ | ||||||
| 		"index.html", | 		"index.html", | ||||||
| 		"login.html", | 		"login.html", | ||||||
|  | 		"change_password.html", | ||||||
| 		"channel_list.html", | 		"channel_list.html", | ||||||
| 		"channel_detail.html", | 		"channel_detail.html", | ||||||
| 		"plugin_list.html", | 		"plugin_list.html", | ||||||
| 		"channel_plugins_list.html", | 		"channel_plugins_list.html", | ||||||
|  | 		"channel_plugin_config.html", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, tf := range templateFiles { | 	pluginTemplateFiles := []string{ | ||||||
|  | 		"plugins/security.domainblock.html", | ||||||
|  | 		"plugins/social.instagram.html", | ||||||
|  | 		"plugins/social.twitter.html", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tf := range mainTemplateFiles { | ||||||
| 		// Read template content from embedded filesystem | 		// Read template content from embedded filesystem | ||||||
| 		content, err := templateFS.ReadFile("templates/" + tf) | 		content, err := templateFS.ReadFile("templates/" + tf) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
| 		 | 
 | ||||||
| 		// Create a clone of the base template | 		// Create a clone of the base template | ||||||
| 		t, err := baseTemplate.Clone() | 		t, err := baseTemplate.Clone() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
| 		 | 
 | ||||||
| 		// Parse the template content | 		// Parse the template content | ||||||
| 		t, err = t.Parse(string(content)) | 		t, err = t.Parse(string(content)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
| 		 | 
 | ||||||
|  | 		// If this is the channel_plugin_config template, also parse plugin templates | ||||||
|  | 		if tf == "channel_plugin_config.html" { | ||||||
|  | 			for _, pluginTf := range pluginTemplateFiles { | ||||||
|  | 				pluginContent, err := templateFS.ReadFile("templates/" + pluginTf) | ||||||
|  | 				if err != nil { | ||||||
|  | 					panic(err) | ||||||
|  | 				} | ||||||
|  | 				t, err = t.Parse(string(pluginContent)) | ||||||
|  | 				if err != nil { | ||||||
|  | 					panic(err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		templates[tf] = t | 		templates[tf] = t | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -113,6 +149,7 @@ func New(cfg *config.Config, database *db.Database) *Admin { | ||||||
| 		store:        store, | 		store:        store, | ||||||
| 		templates:    templates, | 		templates:    templates, | ||||||
| 		baseTemplate: baseTemplate, | 		baseTemplate: baseTemplate, | ||||||
|  | 		version:      version, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -122,16 +159,22 @@ func (a *Admin) RegisterRoutes(mux *http.ServeMux) { | ||||||
| 	mux.HandleFunc("/admin/", a.handleIndex) | 	mux.HandleFunc("/admin/", a.handleIndex) | ||||||
| 	mux.HandleFunc("/admin/login", a.handleLogin) | 	mux.HandleFunc("/admin/login", a.handleLogin) | ||||||
| 	mux.HandleFunc("/admin/logout", a.handleLogout) | 	mux.HandleFunc("/admin/logout", a.handleLogout) | ||||||
|  | 	mux.HandleFunc("/admin/change-password", a.handleChangePassword) | ||||||
| 	mux.HandleFunc("/admin/plugins", a.handlePluginList) | 	mux.HandleFunc("/admin/plugins", a.handlePluginList) | ||||||
| 	mux.HandleFunc("/admin/channels", a.handleChannelList) | 	mux.HandleFunc("/admin/channels", a.handleChannelList) | ||||||
| 	mux.HandleFunc("/admin/channels/", a.handleChannelDetail) | 	mux.HandleFunc("/admin/channels/", a.handleChannelDetail) | ||||||
| 	mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList) | 	mux.HandleFunc("/admin/channelplugins", a.handleChannelPluginList) | ||||||
|  | 	mux.HandleFunc("/admin/channelplugins/config/", a.handleChannelPluginConfig) | ||||||
| 	mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete) | 	mux.HandleFunc("/admin/channelplugins/", a.handleChannelPluginDetailOrDelete) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getCurrentUser gets the current user from the session | // getCurrentUser gets the current user from the session | ||||||
| func (a *Admin) getCurrentUser(r *http.Request) *model.User { | func (a *Admin) getCurrentUser(r *http.Request) *model.User { | ||||||
| 	session, _ := a.store.Get(r, sessionKey) | 	session, err := a.store.Get(r, sessionKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error getting session for user retrieval: %v\n", err) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check if user is logged in | 	// Check if user is logged in | ||||||
| 	userID, ok := session.Values["user_id"].(int64) | 	userID, ok := session.Values["user_id"].(int64) | ||||||
|  | @ -142,6 +185,7 @@ func (a *Admin) getCurrentUser(r *http.Request) *model.User { | ||||||
| 	// Get user from database | 	// Get user from database | ||||||
| 	user, err := a.db.GetUserByID(userID) | 	user, err := a.db.GetUserByID(userID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error retrieving user from database: %v\n", err) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -150,32 +194,63 @@ func (a *Admin) getCurrentUser(r *http.Request) *model.User { | ||||||
| 
 | 
 | ||||||
| // isLoggedIn checks if the user is logged in | // isLoggedIn checks if the user is logged in | ||||||
| func (a *Admin) isLoggedIn(r *http.Request) bool { | func (a *Admin) isLoggedIn(r *http.Request) bool { | ||||||
| 	session, _ := a.store.Get(r, sessionKey) | 	session, err := a.store.Get(r, sessionKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error getting session for login check: %v\n", err) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
| 	return session.Values["logged_in"] == true | 	return session.Values["logged_in"] == true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // addFlash adds a flash message to the session | // addFlash adds a flash message to the session | ||||||
| func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string, category string) { | func (a *Admin) addFlash(w http.ResponseWriter, r *http.Request, message string, category string) { | ||||||
| 	session, _ := a.store.Get(r, sessionKey) | 	session, err := a.store.Get(r, sessionKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// If there's an error getting the session, create a new one | ||||||
|  | 		session = sessions.NewSession(a.store, sessionKey) | ||||||
|  | 		session.Options = &sessions.Options{ | ||||||
|  | 			Path:     "/admin", | ||||||
|  | 			MaxAge:   3600 * 24 * 7, // 1 week | ||||||
|  | 			HttpOnly: true, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Add flash message | 	// Map internal categories to Bootstrap alert classes | ||||||
| 	flashes := session.Flashes() | 	var alertClass string | ||||||
| 	if flashes == nil { | 	switch category { | ||||||
| 		flashes = make([]interface{}, 0) | 	case "success": | ||||||
|  | 		alertClass = "success" | ||||||
|  | 	case "danger": | ||||||
|  | 		alertClass = "danger" | ||||||
|  | 	case "warning": | ||||||
|  | 		alertClass = "warning" | ||||||
|  | 	case "info": | ||||||
|  | 		alertClass = "info" | ||||||
|  | 	default: | ||||||
|  | 		alertClass = "info" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	flash := FlashMessage{ | 	flash := FlashMessage{ | ||||||
| 		Category: category, | 		Category: alertClass, | ||||||
| 		Message:  message, | 		Message:  message, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	session.AddFlash(flash) | 	session.AddFlash(flash) | ||||||
| 	session.Save(r, w) | 	err = session.Save(r, w) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// Log the error or handle it appropriately | ||||||
|  | 		fmt.Printf("Error saving session: %v\n", err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getFlashes gets all flash messages from the session | // getFlashes gets all flash messages from the session | ||||||
| func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessage { | func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessage { | ||||||
| 	session, _ := a.store.Get(r, sessionKey) | 	session, err := a.store.Get(r, sessionKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// If there's an error getting the session, return an empty slice | ||||||
|  | 		fmt.Printf("Error getting session for flashes: %v\n", err) | ||||||
|  | 		return []FlashMessage{} | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Get flash messages | 	// Get flash messages | ||||||
| 	flashes := session.Flashes() | 	flashes := session.Flashes() | ||||||
|  | @ -188,22 +263,14 @@ func (a *Admin) getFlashes(w http.ResponseWriter, r *http.Request) []FlashMessag | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Save session to clear flashes | 	// Save session to clear flashes | ||||||
| 	session.Save(r, w) | 	err = session.Save(r, w) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error saving session after getting flashes: %v\n", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return messages | 	return messages | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // requireLogin middleware checks if the user is logged in |  | ||||||
| func (a *Admin) requireLogin(next http.HandlerFunc) http.HandlerFunc { |  | ||||||
| 	return func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 		if !a.isLoggedIn(r) { |  | ||||||
| 			http.Redirect(w, r, "/admin/login", http.StatusSeeOther) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		next(w, r) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // render renders a template with the given data | // render renders a template with the given data | ||||||
| func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) { | func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName string, data TemplateData) { | ||||||
| 	// Add current user data | 	// Add current user data | ||||||
|  | @ -211,6 +278,7 @@ func (a *Admin) render(w http.ResponseWriter, r *http.Request, templateName stri | ||||||
| 	data.LoggedIn = a.isLoggedIn(r) | 	data.LoggedIn = a.isLoggedIn(r) | ||||||
| 	data.Path = r.URL.Path | 	data.Path = r.URL.Path | ||||||
| 	data.Flash = a.getFlashes(w, r) | 	data.Flash = a.getFlashes(w, r) | ||||||
|  | 	data.Version = a.version | ||||||
| 
 | 
 | ||||||
| 	// Get template | 	// Get template | ||||||
| 	tmpl, ok := a.templates[templateName] | 	tmpl, ok := a.templates[templateName] | ||||||
|  | @ -277,7 +345,10 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) { | ||||||
| 
 | 
 | ||||||
| 		// Set session expiration | 		// Set session expiration | ||||||
| 		session.Options.MaxAge = 3600 * 24 * 7 // 1 week | 		session.Options.MaxAge = 3600 * 24 * 7 // 1 week | ||||||
| 		session.Save(r, w) | 		err = session.Save(r, w) | ||||||
|  | 		if err != nil { | ||||||
|  | 			fmt.Printf("Error saving session: %v\n", err) | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		a.addFlash(w, r, "You were logged in", "success") | 		a.addFlash(w, r, "You were logged in", "success") | ||||||
| 
 | 
 | ||||||
|  | @ -299,10 +370,19 @@ func (a *Admin) handleLogin(w http.ResponseWriter, r *http.Request) { | ||||||
| // handleLogout handles the logout route | // handleLogout handles the logout route | ||||||
| func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) { | func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// Clear session | 	// Clear session | ||||||
| 	session, _ := a.store.Get(r, sessionKey) | 	session, err := a.store.Get(r, sessionKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error getting session for logout: %v\n", err) | ||||||
|  | 		http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	session.Values = make(map[interface{}]interface{}) | 	session.Values = make(map[interface{}]interface{}) | ||||||
| 	session.Options.MaxAge = -1 // Delete session | 	session.Options.MaxAge = -1 // Delete session | ||||||
| 	session.Save(r, w) | 	err = session.Save(r, w) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Printf("Error saving session for logout: %v\n", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	a.addFlash(w, r, "You were logged out", "success") | 	a.addFlash(w, r, "You were logged out", "success") | ||||||
| 
 | 
 | ||||||
|  | @ -310,6 +390,74 @@ func (a *Admin) handleLogout(w http.ResponseWriter, r *http.Request) { | ||||||
| 	http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | 	http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // handleChangePassword handles the change password route | ||||||
|  | func (a *Admin) handleChangePassword(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	// Check if user is logged in | ||||||
|  | 	if !a.isLoggedIn(r) { | ||||||
|  | 		http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get current user | ||||||
|  | 	user := a.getCurrentUser(r) | ||||||
|  | 	if user == nil { | ||||||
|  | 		http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Handle form submission | ||||||
|  | 	if r.Method == http.MethodPost { | ||||||
|  | 		// Parse form | ||||||
|  | 		if err := r.ParseForm(); err != nil { | ||||||
|  | 			http.Error(w, "Bad request", http.StatusBadRequest) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get form values | ||||||
|  | 		currentPassword := r.FormValue("current_password") | ||||||
|  | 		newPassword := r.FormValue("new_password") | ||||||
|  | 		confirmPassword := r.FormValue("confirm_password") | ||||||
|  | 
 | ||||||
|  | 		// Validate current password | ||||||
|  | 		_, err := a.db.CheckCredentials(user.Username, currentPassword) | ||||||
|  | 		if err != nil { | ||||||
|  | 			a.addFlash(w, r, "Current password is incorrect", "danger") | ||||||
|  | 			http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Validate new password and confirmation | ||||||
|  | 		if newPassword == "" { | ||||||
|  | 			a.addFlash(w, r, "New password cannot be empty", "danger") | ||||||
|  | 			http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if newPassword != confirmPassword { | ||||||
|  | 			a.addFlash(w, r, "New passwords do not match", "danger") | ||||||
|  | 			http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Update password | ||||||
|  | 		if err := a.db.UpdateUserPassword(user.ID, newPassword); err != nil { | ||||||
|  | 			a.addFlash(w, r, "Failed to update password: "+err.Error(), "danger") | ||||||
|  | 			http.Redirect(w, r, "/admin/change-password", http.StatusSeeOther) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Success | ||||||
|  | 		a.addFlash(w, r, "Password changed successfully", "success") | ||||||
|  | 		http.Redirect(w, r, "/admin/", http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Render change password template | ||||||
|  | 	a.render(w, r, "change_password.html", TemplateData{ | ||||||
|  | 		Title: "Change Password", | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // handlePluginList handles the plugin list route | // handlePluginList handles the plugin list route | ||||||
| func (a *Admin) handlePluginList(w http.ResponseWriter, r *http.Request) { | func (a *Admin) handlePluginList(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// Check if user is logged in | 	// Check if user is logged in | ||||||
|  | @ -416,6 +564,13 @@ func (a *Admin) handleChannelDetail(w http.ResponseWriter, r *http.Request) { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | 			// Update enable_all_plugins | ||||||
|  | 			enableAllPlugins := r.FormValue("enable_all_plugins") == "true" | ||||||
|  | 			if err := a.db.UpdateChannelEnableAllPlugins(id, enableAllPlugins); err != nil { | ||||||
|  | 				http.Error(w, "Failed to update channel enable all plugins", http.StatusInternalServerError) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			a.addFlash(w, r, "Channel updated", "success") | 			a.addFlash(w, r, "Channel updated", "success") | ||||||
| 			http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther) | 			http.Redirect(w, r, "/admin/channels/"+channelID, http.StatusSeeOther) | ||||||
| 			return | 			return | ||||||
|  | @ -502,6 +657,96 @@ func (a *Admin) handleChannelPluginList(w http.ResponseWriter, r *http.Request) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // handleChannelPluginConfig handles the channel plugin configuration route | ||||||
|  | func (a *Admin) handleChannelPluginConfig(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	// Check if user is logged in | ||||||
|  | 	if !a.isLoggedIn(r) { | ||||||
|  | 		http.Redirect(w, r, "/admin/login", http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract channel plugin ID from path | ||||||
|  | 	path := r.URL.Path | ||||||
|  | 	channelPluginID := strings.TrimPrefix(path, "/admin/channelplugins/config/") | ||||||
|  | 
 | ||||||
|  | 	// Convert channel plugin ID to int64 | ||||||
|  | 	id, err := strconv.ParseInt(channelPluginID, 10, 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		http.Error(w, "Invalid channel plugin ID", http.StatusBadRequest) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the channel plugin | ||||||
|  | 	channelPlugin, err := a.db.GetChannelPluginByID(id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		http.Error(w, "Channel plugin not found", http.StatusNotFound) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the plugin | ||||||
|  | 	p, err := plugin.Get(channelPlugin.PluginID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		http.Error(w, "Plugin not found", http.StatusNotFound) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Handle form submission | ||||||
|  | 	if r.Method == http.MethodPost { | ||||||
|  | 		// Parse form | ||||||
|  | 		if err := r.ParseForm(); err != nil { | ||||||
|  | 			http.Error(w, "Bad request", http.StatusBadRequest) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Create config map from form values | ||||||
|  | 		config := make(map[string]interface{}) | ||||||
|  | 
 | ||||||
|  | 		// Process form values based on plugin type | ||||||
|  | 		if channelPlugin.PluginID == "security.domainblock" { | ||||||
|  | 			// Get blocked domains from form | ||||||
|  | 			blockedDomains := r.FormValue("blocked_domains") | ||||||
|  | 			config["blocked_domains"] = blockedDomains | ||||||
|  | 		} else { | ||||||
|  | 			// Generic handling for other plugins | ||||||
|  | 			for key, values := range r.Form { | ||||||
|  | 				if key == "form_submitted" { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 				if len(values) == 1 { | ||||||
|  | 					config[key] = values[0] | ||||||
|  | 				} else { | ||||||
|  | 					config[key] = values | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Update plugin configuration | ||||||
|  | 		if err := a.db.UpdateChannelPluginConfig(id, config); err != nil { | ||||||
|  | 			http.Error(w, "Failed to update plugin configuration", http.StatusInternalServerError) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get the channel to redirect back to the channel detail page | ||||||
|  | 		channel, err := a.db.GetChannelByID(channelPlugin.ChannelID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			a.addFlash(w, r, "Plugin configuration updated", "success") | ||||||
|  | 			http.Redirect(w, r, "/admin/channelplugins", http.StatusSeeOther) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		a.addFlash(w, r, "Plugin configuration updated", "success") | ||||||
|  | 		http.Redirect(w, r, fmt.Sprintf("/admin/channels/%d", channel.ID), http.StatusSeeOther) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Render template | ||||||
|  | 	a.render(w, r, "channel_plugin_config.html", TemplateData{ | ||||||
|  | 		Title:         "Configure Plugin: " + p.GetName(), | ||||||
|  | 		ChannelPlugin: channelPlugin, | ||||||
|  | 		Plugins:       map[string]model.Plugin{channelPlugin.PluginID: p}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route | // handleChannelPluginDetailOrDelete handles the channel plugin detail or delete route | ||||||
| func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) { | func (a *Admin) handleChannelPluginDetailOrDelete(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// Check if user is logged in | 	// Check if user is logged in | ||||||
|  |  | ||||||
|  | @ -28,8 +28,10 @@ | ||||||
|                             <a href="/admin/login">Log in</a> |                             <a href="/admin/login">Log in</a> | ||||||
|                             {{else}} |                             {{else}} | ||||||
|                             <div class="d-none d-xl-block pl-2"> |                             <div class="d-none d-xl-block pl-2"> | ||||||
|                                 <div>{{.User.Username}} - <a class="mt-1 small" |                                 <div>{{.User.Username}} -  | ||||||
|                                         href="/admin/logout">Log out</a></div> |                                     <a class="mt-1 small" href="/admin/change-password">Change Password</a> |  | ||||||
|  |                                     <a class="mt-1 small" href="/admin/logout">Log out</a> | ||||||
|  |                                 </div> | ||||||
|                             </div> |                             </div> | ||||||
|                             </a> |                             </a> | ||||||
|                             {{end}} |                             {{end}} | ||||||
|  | @ -100,14 +102,14 @@ | ||||||
|             {{end}} |             {{end}} | ||||||
|         </div> |         </div> | ||||||
| 
 | 
 | ||||||
|         {{range .Flash}} |         <div class="container-xl mt-3"> | ||||||
|         <div class="card"> |             {{range .Flash}} | ||||||
|             <div class="card-status-top bg-{{.Category}}"></div> |             <div class="alert alert-{{.Category}} alert-dismissible" role="alert"> | ||||||
|             <div class="card-body"> |                 {{.Message}} | ||||||
|                 <p>{{.Message}}</p> |                 <button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button> | ||||||
|             </div> |             </div> | ||||||
|  |             {{end}} | ||||||
|         </div> |         </div> | ||||||
|         {{end}} |  | ||||||
| 
 | 
 | ||||||
|         <div class="content"> |         <div class="content"> | ||||||
|             <div class="container-xl"> |             <div class="container-xl"> | ||||||
|  | @ -115,6 +117,19 @@ | ||||||
|             </div> |             </div> | ||||||
|         </div> |         </div> | ||||||
| 
 | 
 | ||||||
|  |         <footer class="footer footer-transparent d-print-none"> | ||||||
|  |             <div class="container-xl"> | ||||||
|  |                 <div class="row text-center align-items-center flex-row-reverse"> | ||||||
|  |                     <div class="col-12 col-lg-auto mt-3 mt-lg-0"> | ||||||
|  |                         <ul class="list-inline list-inline-dots mb-0"> | ||||||
|  |                             <li class="list-inline-item"> | ||||||
|  |                                 ButterRobot {{if .Version}}v{{.Version}}{{else}}(development){{end}} | ||||||
|  |                             </li> | ||||||
|  |                         </ul> | ||||||
|  |                     </div> | ||||||
|  |                 </div> | ||||||
|  |             </div> | ||||||
|  |         </footer> | ||||||
|     </div> |     </div> | ||||||
| 
 | 
 | ||||||
|     <script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script> |     <script src="https://unpkg.com/@tabler/core@latest/dist/js/tabler.min.js"></script> | ||||||
|  |  | ||||||
							
								
								
									
										30
									
								
								internal/admin/templates/change_password.html
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/admin/templates/change_password.html
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,30 @@ | ||||||
|  | {{define "content"}} | ||||||
|  | <div class="row justify-content-center"> | ||||||
|  |     <div class="col-md-6"> | ||||||
|  |         <div class="card"> | ||||||
|  |             <div class="card-header"> | ||||||
|  |                 <h3 class="card-title">Change Password</h3> | ||||||
|  |             </div> | ||||||
|  |             <div class="card-body"> | ||||||
|  |                 <form method="post" action="/admin/change-password"> | ||||||
|  |                     <div class="mb-3"> | ||||||
|  |                         <label class="form-label">Current Password</label> | ||||||
|  |                         <input type="password" name="current_password" class="form-control" placeholder="Current Password" required> | ||||||
|  |                     </div> | ||||||
|  |                     <div class="mb-3"> | ||||||
|  |                         <label class="form-label">New Password</label> | ||||||
|  |                         <input type="password" name="new_password" class="form-control" placeholder="New Password" required> | ||||||
|  |                     </div> | ||||||
|  |                     <div class="mb-3"> | ||||||
|  |                         <label class="form-label">Confirm New Password</label> | ||||||
|  |                         <input type="password" name="confirm_password" class="form-control" placeholder="Confirm New Password" required> | ||||||
|  |                     </div> | ||||||
|  |                     <div class="form-footer"> | ||||||
|  |                         <button type="submit" class="btn btn-primary">Change Password</button> | ||||||
|  |                     </div> | ||||||
|  |                 </form> | ||||||
|  |             </div> | ||||||
|  |         </div> | ||||||
|  |     </div> | ||||||
|  | </div> | ||||||
|  | {{end}} | ||||||
|  | @ -27,6 +27,15 @@ | ||||||
|                         <!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked --> |                         <!-- Add a hidden field to ensure a value is sent even when checkbox is unchecked --> | ||||||
|                         <input type="hidden" name="form_submitted" value="true"> |                         <input type="hidden" name="form_submitted" value="true"> | ||||||
|                     </div> |                     </div> | ||||||
|  |                     <div class="mb-3"> | ||||||
|  |                         <label class="form-check form-switch"> | ||||||
|  |                             <input class="form-check-input" type="checkbox" name="enable_all_plugins" value="true" {{if .Channel.EnableAllPlugins}}checked{{end}}> | ||||||
|  |                             <span class="form-check-label">Enable All Plugins</span> | ||||||
|  |                         </label> | ||||||
|  |                         <div> | ||||||
|  |                             When enabled, all registered plugins will be automatically enabled for this channel. Individual plugin settings will be ignored. | ||||||
|  |                         </div> | ||||||
|  |                     </div> | ||||||
|                     <div class="form-footer"> |                     <div class="form-footer"> | ||||||
|                         <button type="submit" class="btn btn-primary">Save</button> |                         <button type="submit" class="btn btn-primary">Save</button> | ||||||
|                         <a href="/admin/channels" class="btn btn-link">Back to Channels</a> |                         <a href="/admin/channels" class="btn btn-link">Back to Channels</a> | ||||||
|  | @ -68,6 +77,10 @@ | ||||||
|                                             {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} |                                             {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} | ||||||
|                                         </button> |                                         </button> | ||||||
|                                     </form> |                                     </form> | ||||||
|  |                                     {{$plugin := index $.Plugins $pluginID}} | ||||||
|  |                                     {{if $plugin.RequiresConfig}} | ||||||
|  |                                     <a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a> | ||||||
|  |                                     {{end}} | ||||||
|                                     <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> |                                     <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> | ||||||
|                                         <button type="submit" class="btn btn-danger btn-sm" |                                         <button type="submit" class="btn btn-danger btn-sm" | ||||||
|                                             onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> |                                             onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> | ||||||
|  | @ -111,4 +124,4 @@ | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
| </div> | </div> | ||||||
| {{end}} | {{end}} | ||||||
|  |  | ||||||
							
								
								
									
										32
									
								
								internal/admin/templates/channel_plugin_config.html
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								internal/admin/templates/channel_plugin_config.html
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,32 @@ | ||||||
|  | {{define "content"}} | ||||||
|  | <div class="row"> | ||||||
|  |     <div class="col-md-12"> | ||||||
|  |         <div class="card"> | ||||||
|  |             <div class="card-header"> | ||||||
|  |                 <h3 class="card-title">Configure Plugin: {{(index .Plugins .ChannelPlugin.PluginID).GetName}}</h3> | ||||||
|  |             </div> | ||||||
|  |             <div class="card-body"> | ||||||
|  |                 <form method="post"> | ||||||
|  |                     <!-- Plugin configuration fields --> | ||||||
|  |                     {{if eq .ChannelPlugin.PluginID "security.domainblock"}} | ||||||
|  |                         {{template "plugins/security.domainblock.html" .}} | ||||||
|  |                     {{else if eq .ChannelPlugin.PluginID "social.instagram"}} | ||||||
|  |                         {{template "plugins/social.instagram.html" .}} | ||||||
|  |                     {{else if eq .ChannelPlugin.PluginID "social.twitter"}} | ||||||
|  |                         {{template "plugins/social.twitter.html" .}} | ||||||
|  |                     {{else}} | ||||||
|  |                     <div class="alert alert-warning"> | ||||||
|  |                         This plugin doesn't have specific configuration fields implemented yet. | ||||||
|  |                     </div> | ||||||
|  |                     {{end}} | ||||||
|  | 
 | ||||||
|  |                     <div class="form-footer"> | ||||||
|  |                         <button type="submit" class="btn btn-primary">Save Configuration</button> | ||||||
|  |                         <a href="/admin/channels/{{.ChannelPlugin.ChannelID}}" class="btn btn-secondary">Cancel</a> | ||||||
|  |                     </div> | ||||||
|  |                 </form> | ||||||
|  |             </div> | ||||||
|  |         </div> | ||||||
|  |     </div> | ||||||
|  | </div> | ||||||
|  | {{end}} | ||||||
|  | @ -38,6 +38,10 @@ | ||||||
|                                                 {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} |                                                 {{if $channelPlugin.Enabled}}Disable{{else}}Enable{{end}} | ||||||
|                                             </button> |                                             </button> | ||||||
|                                         </form> |                                         </form> | ||||||
|  |                                         {{$plugin := index $.Plugins $pluginID}} | ||||||
|  |                                         {{if $plugin.ConfigRequired}} | ||||||
|  |                                             <a href="/admin/channelplugins/config/{{$channelPlugin.ID}}" class="btn btn-info btn-sm">Configure</a> | ||||||
|  |                                         {{end}} | ||||||
|                                         <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> |                                         <form method="post" action="/admin/channelplugins/{{$channelPlugin.ID}}/delete" class="d-inline"> | ||||||
|                                             <button type="submit" class="btn btn-danger btn-sm" |                                             <button type="submit" class="btn btn-danger btn-sm" | ||||||
|                                                 onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> |                                                 onclick="return confirm('Are you sure you want to remove this plugin?')">Remove</button> | ||||||
|  | @ -90,4 +94,4 @@ | ||||||
|         </div> |         </div> | ||||||
|     </div> |     </div> | ||||||
| </div> | </div> | ||||||
| {{end}} | {{end}} | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								internal/admin/templates/plugins/security.domainblock.html
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								internal/admin/templates/plugins/security.domainblock.html
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,12 @@ | ||||||
|  | {{define "plugins/security.domainblock.html"}} | ||||||
|  | <div class="mb-3"> | ||||||
|  |     <label class="form-label">Blocked Domains</label> | ||||||
|  |     <input type="text" class="form-control" name="blocked_domains" | ||||||
|  |            value="{{with .ChannelPlugin.Config}}{{index . "blocked_domains"}}{{end}}" | ||||||
|  |            placeholder="example.com, evil.org, ads.com"> | ||||||
|  |     <div class="form-text text-muted"> | ||||||
|  |         Enter comma-separated list of domains to block (e.g., example.com, evil.org). | ||||||
|  |         Messages containing links to these domains will be blocked. | ||||||
|  |     </div> | ||||||
|  | </div> | ||||||
|  | {{end}} | ||||||
							
								
								
									
										11
									
								
								internal/admin/templates/plugins/social.instagram.html
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/admin/templates/plugins/social.instagram.html
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | ||||||
|  | {{define "plugins/social.instagram.html"}} | ||||||
|  | <div class="mb-3"> | ||||||
|  |     <label class="form-label">Replacement Domain</label> | ||||||
|  |     <input type="text" class="form-control" name="domain" | ||||||
|  |            value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}" | ||||||
|  |            placeholder="ddinstagram.com"> | ||||||
|  |     <div class="form-text text-muted"> | ||||||
|  |         Enter the domain to replace instagram.com links with. Default is ddinstagram.com if left empty. | ||||||
|  |     </div> | ||||||
|  | </div> | ||||||
|  | {{end}} | ||||||
							
								
								
									
										11
									
								
								internal/admin/templates/plugins/social.twitter.html
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/admin/templates/plugins/social.twitter.html
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | ||||||
|  | {{define "plugins/social.twitter.html"}} | ||||||
|  | <div class="mb-3"> | ||||||
|  |     <label class="form-label">Replacement Domain</label> | ||||||
|  |     <input type="text" class="form-control" name="domain" | ||||||
|  |            value="{{with .ChannelPlugin.Config}}{{index . "domain"}}{{end}}" | ||||||
|  |            placeholder="fxtwitter.com"> | ||||||
|  |     <div class="form-text text-muted"> | ||||||
|  |         Enter the domain to replace twitter.com and x.com links with. Default is fxtwitter.com if left empty. | ||||||
|  |     </div> | ||||||
|  | </div> | ||||||
|  | {{end}} | ||||||
|  | @ -9,28 +9,37 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
|  | 	"runtime/debug" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/admin" | 	"git.nakama.town/fmartingr/butterrobot/internal/admin" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/cache" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/config" | 	"git.nakama.town/fmartingr/butterrobot/internal/config" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/db" | 	"git.nakama.town/fmartingr/butterrobot/internal/db" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/platform" | 	"git.nakama.town/fmartingr/butterrobot/internal/platform" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/domainblock" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/fun" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/help" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/ping" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/reminder" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/searchreplace" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin/social" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/queue" | 	"git.nakama.town/fmartingr/butterrobot/internal/queue" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // App represents the application | // App represents the application | ||||||
| type App struct { | type App struct { | ||||||
| 	config *config.Config | 	config  *config.Config | ||||||
| 	logger *slog.Logger | 	logger  *slog.Logger | ||||||
| 	db     *db.Database | 	db      *db.Database | ||||||
| 	router *http.ServeMux | 	router  *http.ServeMux | ||||||
| 	queue  *queue.Queue | 	queue   *queue.Queue | ||||||
| 	admin  *admin.Admin | 	admin   *admin.Admin | ||||||
|  | 	version string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New creates a new App instance | // New creates a new App instance | ||||||
|  | @ -47,16 +56,24 @@ func New(cfg *config.Config, logger *slog.Logger) (*App, error) { | ||||||
| 	// Initialize message queue | 	// Initialize message queue | ||||||
| 	messageQueue := queue.New(logger) | 	messageQueue := queue.New(logger) | ||||||
| 
 | 
 | ||||||
|  | 	// Get version information | ||||||
|  | 	version := "" | ||||||
|  | 	info, ok := debug.ReadBuildInfo() | ||||||
|  | 	if ok { | ||||||
|  | 		version = info.Main.Version | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Initialize admin interface | 	// Initialize admin interface | ||||||
| 	adminInterface := admin.New(cfg, database) | 	adminInterface := admin.New(cfg, database, version) | ||||||
| 
 | 
 | ||||||
| 	return &App{ | 	return &App{ | ||||||
| 		config: cfg, | 		config:  cfg, | ||||||
| 		logger: logger, | 		logger:  logger, | ||||||
| 		db:     database, | 		db:      database, | ||||||
| 		router: router, | 		router:  router, | ||||||
| 		queue:  messageQueue, | 		queue:   messageQueue, | ||||||
| 		admin:  adminInterface, | 		admin:   adminInterface, | ||||||
|  | 		version: version, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -72,6 +89,13 @@ func (a *App) Run() error { | ||||||
| 	plugin.Register(fun.NewCoin()) | 	plugin.Register(fun.NewCoin()) | ||||||
| 	plugin.Register(fun.NewDice()) | 	plugin.Register(fun.NewDice()) | ||||||
| 	plugin.Register(fun.NewLoquito()) | 	plugin.Register(fun.NewLoquito()) | ||||||
|  | 	plugin.Register(fun.NewHLTB()) | ||||||
|  | 	plugin.Register(social.NewTwitterExpander()) | ||||||
|  | 	plugin.Register(social.NewInstagramExpander()) | ||||||
|  | 	plugin.Register(reminder.New(a.db)) | ||||||
|  | 	plugin.Register(domainblock.New()) | ||||||
|  | 	plugin.Register(searchreplace.New()) | ||||||
|  | 	plugin.Register(help.New(a.db)) | ||||||
| 
 | 
 | ||||||
| 	// Initialize routes | 	// Initialize routes | ||||||
| 	a.initializeRoutes() | 	a.initializeRoutes() | ||||||
|  | @ -79,6 +103,12 @@ func (a *App) Run() error { | ||||||
| 	// Start message queue worker | 	// Start message queue worker | ||||||
| 	a.queue.Start(a.handleMessage) | 	a.queue.Start(a.handleMessage) | ||||||
| 
 | 
 | ||||||
|  | 	// Start reminder scheduler | ||||||
|  | 	a.queue.StartReminderScheduler(a.handleReminder) | ||||||
|  | 
 | ||||||
|  | 	// Start cache cleanup scheduler | ||||||
|  | 	go a.startCacheCleanup() | ||||||
|  | 
 | ||||||
| 	// Create server | 	// Create server | ||||||
| 	addr := fmt.Sprintf(":%s", a.config.Port) | 	addr := fmt.Sprintf(":%s", a.config.Port) | ||||||
| 	srv := &http.Server{ | 	srv := &http.Server{ | ||||||
|  | @ -124,13 +154,29 @@ func (a *App) Run() error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // startCacheCleanup runs periodic cache cleanup | ||||||
|  | func (a *App) startCacheCleanup() { | ||||||
|  | 	ticker := time.NewTicker(time.Hour) // Clean up every hour | ||||||
|  | 	defer ticker.Stop() | ||||||
|  | 
 | ||||||
|  | 	for range ticker.C { | ||||||
|  | 		if err := a.db.CacheCleanup(); err != nil { | ||||||
|  | 			a.logger.Error("Cache cleanup failed", "error", err) | ||||||
|  | 		} else { | ||||||
|  | 			a.logger.Debug("Cache cleanup completed") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Initialize HTTP routes | // Initialize HTTP routes | ||||||
| func (a *App) initializeRoutes() { | func (a *App) initializeRoutes() { | ||||||
| 	// Health check endpoint | 	// Health check endpoint | ||||||
| 	a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { | 	a.router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		w.Header().Set("Content-Type", "application/json") | 		w.Header().Set("Content-Type", "application/json") | ||||||
| 		w.WriteHeader(http.StatusOK) | 		w.WriteHeader(http.StatusOK) | ||||||
| 		json.NewEncoder(w).Encode(map[string]interface{}{}) | 		if err := json.NewEncoder(w).Encode(map[string]interface{}{}); err != nil { | ||||||
|  | 			a.logger.Error("Error encoding response", "error", err) | ||||||
|  | 		} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	// Platform webhook endpoints | 	// Platform webhook endpoints | ||||||
|  | @ -153,7 +199,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { | ||||||
| 	if _, err := platform.Get(platformName); err != nil { | 	if _, err := platform.Get(platformName); err != nil { | ||||||
| 		w.Header().Set("Content-Type", "application/json") | 		w.Header().Set("Content-Type", "application/json") | ||||||
| 		w.WriteHeader(http.StatusBadRequest) | 		w.WriteHeader(http.StatusBadRequest) | ||||||
| 		json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}) | 		if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unknown platform"}); err != nil { | ||||||
|  | 			a.logger.Error("Error encoding response", "error", err) | ||||||
|  | 		} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -162,7 +210,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		w.Header().Set("Content-Type", "application/json") | 		w.Header().Set("Content-Type", "application/json") | ||||||
| 		w.WriteHeader(http.StatusBadRequest) | 		w.WriteHeader(http.StatusBadRequest) | ||||||
| 		json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}) | 		if err := json.NewEncoder(w).Encode(map[string]string{"error": "Failed to read request body"}); err != nil { | ||||||
|  | 			a.logger.Error("Error encoding response", "error", err) | ||||||
|  | 		} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -178,7 +228,9 @@ func (a *App) handleIncomingWebhook(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// Respond with success | 	// Respond with success | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 	w.WriteHeader(http.StatusOK) | 	w.WriteHeader(http.StatusOK) | ||||||
| 	json.NewEncoder(w).Encode(map[string]any{}) | 	if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { | ||||||
|  | 		a.logger.Error("Error encoding response", "error", err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // extractPlatformName extracts the platform name from the URL path | // extractPlatformName extracts the platform name from the URL path | ||||||
|  | @ -262,11 +314,21 @@ func (a *App) handleMessage(item queue.Item) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Process message with plugins | 	// Process message with plugins | ||||||
| 	for pluginID, channelPlugin := range channel.Plugins { | 	var pluginsToProcess []string | ||||||
| 		if !channel.HasEnabledPlugin(pluginID) { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
|  | 	if channel.EnableAllPlugins { | ||||||
|  | 		// If EnableAllPlugins is true, process all registered plugins | ||||||
|  | 		pluginsToProcess = plugin.GetAvailablePluginIDs() | ||||||
|  | 	} else { | ||||||
|  | 		// Otherwise, process only explicitly enabled plugins | ||||||
|  | 		for pluginID := range channel.Plugins { | ||||||
|  | 			if channel.HasEnabledPlugin(pluginID) { | ||||||
|  | 				pluginsToProcess = append(pluginsToProcess, pluginID) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, pluginID := range pluginsToProcess { | ||||||
| 		// Get plugin | 		// Get plugin | ||||||
| 		p, err := plugin.Get(pluginID) | 		p, err := plugin.Get(pluginID) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | @ -274,20 +336,121 @@ func (a *App) handleMessage(item queue.Item) { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Process message | 		// Get plugin configuration (empty map if EnableAllPlugins and plugin not explicitly configured) | ||||||
| 		responses := p.OnMessage(message, channelPlugin.Config) | 		var config map[string]interface{} | ||||||
|  | 		if channelPlugin, exists := channel.Plugins[pluginID]; exists { | ||||||
|  | 			config = channelPlugin.Config | ||||||
|  | 		} else { | ||||||
|  | 			config = make(map[string]interface{}) | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		// Send responses | 		// Create cache instance for this plugin | ||||||
|  | 		pluginCache := cache.New(a.db, pluginID) | ||||||
|  | 
 | ||||||
|  | 		// Process message and get actions | ||||||
|  | 		actions := p.OnMessage(message, config, pluginCache) | ||||||
|  | 
 | ||||||
|  | 		// Get platform for processing actions | ||||||
| 		platform, err := platform.Get(item.Platform) | 		platform, err := platform.Get(item.Platform) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			a.logger.Error("Error getting platform", "error", err) | 			a.logger.Error("Error getting platform", "error", err) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, response := range responses { | 		// Process each action | ||||||
| 			if err := platform.SendMessage(response); err != nil { | 		for _, action := range actions { | ||||||
| 				a.logger.Error("Error sending message", "error", err) | 			switch action.Type { | ||||||
|  | 			case model.ActionSendMessage: | ||||||
|  | 				// Send a message | ||||||
|  | 				if action.Message != nil { | ||||||
|  | 					if err := platform.SendMessage(action.Message); err != nil { | ||||||
|  | 						a.logger.Error("Error sending message", "error", err) | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					a.logger.Error("Send message action with nil message") | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 			case model.ActionDeleteMessage: | ||||||
|  | 				// Delete a message using direct DeleteMessage call | ||||||
|  | 				if err := platform.DeleteMessage(action.Chat, action.MessageID); err != nil { | ||||||
|  | 					a.logger.Error("Error deleting message", "error", err, "message_id", action.MessageID) | ||||||
|  | 				} else { | ||||||
|  | 					a.logger.Info("Message deleted", "message_id", action.MessageID) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 			default: | ||||||
|  | 				a.logger.Error("Unknown action type", "type", action.Type) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // handleReminder handles reminder processing | ||||||
|  | func (a *App) handleReminder(reminder *model.Reminder) { | ||||||
|  | 	// When called with nil, it means we should check for pending reminders | ||||||
|  | 	if reminder == nil { | ||||||
|  | 		// Get pending reminders | ||||||
|  | 		reminders, err := a.db.GetPendingReminders() | ||||||
|  | 		if err != nil { | ||||||
|  | 			a.logger.Error("Error getting pending reminders", "error", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Process each reminder | ||||||
|  | 		for _, r := range reminders { | ||||||
|  | 			a.processReminder(r) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Otherwise, process the specific reminder | ||||||
|  | 	a.processReminder(reminder) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // processReminder processes an individual reminder | ||||||
|  | func (a *App) processReminder(reminder *model.Reminder) { | ||||||
|  | 	a.logger.Info("Processing reminder", | ||||||
|  | 		"id", reminder.ID, | ||||||
|  | 		"platform", reminder.Platform, | ||||||
|  | 		"channel", reminder.ChannelID, | ||||||
|  | 		"trigger_at", reminder.TriggerAt, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Get the platform handler | ||||||
|  | 	p, err := platform.Get(reminder.Platform) | ||||||
|  | 	if err != nil { | ||||||
|  | 		a.logger.Error("Error getting platform for reminder", "error", err, "platform", reminder.Platform) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the channel | ||||||
|  | 	channel, err := a.db.GetChannelByPlatform(reminder.Platform, reminder.ChannelID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		a.logger.Error("Error getting channel for reminder", "error", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create the reminder message | ||||||
|  | 	reminderText := fmt.Sprintf("@%s reminding you of this", reminder.Username) | ||||||
|  | 
 | ||||||
|  | 	message := &model.Message{ | ||||||
|  | 		Text:    reminderText, | ||||||
|  | 		Chat:    reminder.ChannelID, | ||||||
|  | 		Channel: channel, | ||||||
|  | 		Author:  "bot", | ||||||
|  | 		FromBot: true, | ||||||
|  | 		Date:    time.Now(), | ||||||
|  | 		ReplyTo: reminder.ReplyToID, // Reply to the original message | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Send the reminder message | ||||||
|  | 	if err := p.SendMessage(message); err != nil { | ||||||
|  | 		a.logger.Error("Error sending reminder", "error", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Mark the reminder as processed | ||||||
|  | 	if err := a.db.MarkReminderAsProcessed(reminder.ID); err != nil { | ||||||
|  | 		a.logger.Error("Error marking reminder as processed", "error", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										83
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,83 @@ | ||||||
|  | package cache | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/db" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Cache provides a plugin-friendly interface to the cache system | ||||||
|  | type Cache struct { | ||||||
|  | 	db       *db.Database | ||||||
|  | 	pluginID string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new Cache instance for a specific plugin | ||||||
|  | func New(database *db.Database, pluginID string) *Cache { | ||||||
|  | 	return &Cache{ | ||||||
|  | 		db:       database, | ||||||
|  | 		pluginID: pluginID, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Get retrieves a value from the cache | ||||||
|  | func (c *Cache) Get(key string, destination interface{}) error { | ||||||
|  | 	// Create prefixed key | ||||||
|  | 	fullKey := c.createKey(key) | ||||||
|  | 
 | ||||||
|  | 	// Get from database | ||||||
|  | 	value, err := c.db.CacheGet(fullKey) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Unmarshal JSON into destination | ||||||
|  | 	return json.Unmarshal([]byte(value), destination) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Set stores a value in the cache with optional expiration | ||||||
|  | func (c *Cache) Set(key string, value interface{}, expiration *time.Time) error { | ||||||
|  | 	// Create prefixed key | ||||||
|  | 	fullKey := c.createKey(key) | ||||||
|  | 
 | ||||||
|  | 	// Marshal value to JSON | ||||||
|  | 	jsonValue, err := json.Marshal(value) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to marshal cache value: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Store in database | ||||||
|  | 	return c.db.CacheSet(fullKey, string(jsonValue), expiration) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetWithTTL stores a value in the cache with a time-to-live duration | ||||||
|  | func (c *Cache) SetWithTTL(key string, value interface{}, ttl time.Duration) error { | ||||||
|  | 	expiration := time.Now().Add(ttl) | ||||||
|  | 	return c.Set(key, value, &expiration) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Delete removes a value from the cache | ||||||
|  | func (c *Cache) Delete(key string) error { | ||||||
|  | 	fullKey := c.createKey(key) | ||||||
|  | 	return c.db.CacheDelete(fullKey) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Exists checks if a key exists in the cache | ||||||
|  | func (c *Cache) Exists(key string) (bool, error) { | ||||||
|  | 	fullKey := c.createKey(key) | ||||||
|  | 	_, err := c.db.CacheGet(fullKey) | ||||||
|  | 	if err == db.ErrNotFound { | ||||||
|  | 		return false, nil | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // createKey creates a prefixed cache key | ||||||
|  | func (c *Cache) createKey(key string) string { | ||||||
|  | 	return fmt.Sprintf("%s_%s", c.pluginID, key) | ||||||
|  | } | ||||||
							
								
								
									
										176
									
								
								internal/cache/cache_test.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								internal/cache/cache_test.go
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,176 @@ | ||||||
|  | package cache | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/db" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestCache(t *testing.T) { | ||||||
|  | 	// Create temporary database for testing with unique name | ||||||
|  | 	dbFile := fmt.Sprintf("test_cache_%d.db", time.Now().UnixNano()) | ||||||
|  | 	database, err := db.New(dbFile) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to create test database: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = database.Close() | ||||||
|  | 		// Clean up test database file | ||||||
|  | 		_ = os.Remove(dbFile) | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Create cache instance | ||||||
|  | 	cache := New(database, "test.plugin") | ||||||
|  | 
 | ||||||
|  | 	// Test data | ||||||
|  | 	testKey := "test_key" | ||||||
|  | 	testValue := map[string]interface{}{ | ||||||
|  | 		"name": "Test Game", | ||||||
|  | 		"time": 42, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Test Set and Get | ||||||
|  | 	t.Run("Set and Get", func(t *testing.T) { | ||||||
|  | 		err := cache.Set(testKey, testValue, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var retrieved map[string]interface{} | ||||||
|  | 		err = cache.Get(testKey, &retrieved) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to get cache value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if retrieved["name"] != testValue["name"] { | ||||||
|  | 			t.Errorf("Expected name %v, got %v", testValue["name"], retrieved["name"]) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if int(retrieved["time"].(float64)) != testValue["time"].(int) { | ||||||
|  | 			t.Errorf("Expected time %v, got %v", testValue["time"], retrieved["time"]) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Test SetWithTTL and expiration | ||||||
|  | 	t.Run("SetWithTTL and expiration", func(t *testing.T) { | ||||||
|  | 		expiredKey := "expired_key" | ||||||
|  | 
 | ||||||
|  | 		// Set with very short TTL | ||||||
|  | 		err := cache.SetWithTTL(expiredKey, testValue, time.Millisecond) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache value with TTL: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Wait for expiration | ||||||
|  | 		time.Sleep(2 * time.Millisecond) | ||||||
|  | 
 | ||||||
|  | 		// Try to get - should fail | ||||||
|  | 		var retrieved map[string]interface{} | ||||||
|  | 		err = cache.Get(expiredKey, &retrieved) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Errorf("Expected cache miss for expired key, but got value") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Test Exists | ||||||
|  | 	t.Run("Exists", func(t *testing.T) { | ||||||
|  | 		existsKey := "exists_key" | ||||||
|  | 
 | ||||||
|  | 		// Make sure the key doesn't exist initially by deleting it | ||||||
|  | 		_ = cache.Delete(existsKey) | ||||||
|  | 
 | ||||||
|  | 		// Should not exist initially | ||||||
|  | 		exists, err := cache.Exists(existsKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to check if key exists: %v", err) | ||||||
|  | 		} | ||||||
|  | 		if exists { | ||||||
|  | 			t.Errorf("Expected key to not exist, but it does") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Set value | ||||||
|  | 		err = cache.Set(existsKey, testValue, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Should exist now | ||||||
|  | 		exists, err = cache.Exists(existsKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to check if key exists: %v", err) | ||||||
|  | 		} | ||||||
|  | 		if !exists { | ||||||
|  | 			t.Errorf("Expected key to exist, but it doesn't") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Test Delete | ||||||
|  | 	t.Run("Delete", func(t *testing.T) { | ||||||
|  | 		deleteKey := "delete_key" | ||||||
|  | 
 | ||||||
|  | 		// Set value | ||||||
|  | 		err := cache.Set(deleteKey, testValue, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Delete value | ||||||
|  | 		err = cache.Delete(deleteKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to delete cache value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Should not exist anymore | ||||||
|  | 		var retrieved map[string]interface{} | ||||||
|  | 		err = cache.Get(deleteKey, &retrieved) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Errorf("Expected cache miss for deleted key, but got value") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Test plugin ID prefixing | ||||||
|  | 	t.Run("Plugin ID prefixing", func(t *testing.T) { | ||||||
|  | 		cache1 := New(database, "plugin1") | ||||||
|  | 		cache2 := New(database, "plugin2") | ||||||
|  | 
 | ||||||
|  | 		sameKey := "same_key" | ||||||
|  | 		value1 := "value1" | ||||||
|  | 		value2 := "value2" | ||||||
|  | 
 | ||||||
|  | 		// Set same key in both caches | ||||||
|  | 		err := cache1.Set(sameKey, value1, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache1 value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		err = cache2.Set(sameKey, value2, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to set cache2 value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Retrieve from both caches | ||||||
|  | 		var retrieved1, retrieved2 string | ||||||
|  | 
 | ||||||
|  | 		err = cache1.Get(sameKey, &retrieved1) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to get cache1 value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		err = cache2.Get(sameKey, &retrieved2) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to get cache2 value: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Values should be different due to plugin ID prefixing | ||||||
|  | 		if retrieved1 != value1 { | ||||||
|  | 			t.Errorf("Expected cache1 value %v, got %v", value1, retrieved1) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if retrieved2 != value2 { | ||||||
|  | 			t.Errorf("Expected cache2 value %v, got %v", value2, retrieved2) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | @ -1,14 +1,16 @@ | ||||||
| package db | package db | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/sha256" |  | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"encoding/hex" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
| 	_ "modernc.org/sqlite" | 	_ "modernc.org/sqlite" | ||||||
| 
 | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/migration" | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/model" | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -33,6 +35,11 @@ func New(dbPath string) (*Database, error) { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Configure SQLite for better reliability | ||||||
|  | 	if err := configureSQLite(db); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Initialize database | 	// Initialize database | ||||||
| 	if err := initDatabase(db); err != nil { | 	if err := initDatabase(db); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -49,7 +56,7 @@ func (d *Database) Close() error { | ||||||
| // GetChannelByID retrieves a channel by ID | // GetChannelByID retrieves a channel by ID | ||||||
| func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { | func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { | ||||||
| 	query := ` | 	query := ` | ||||||
| 		SELECT id, platform, platform_channel_id, enabled, channel_raw | 		SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw | ||||||
| 		FROM channels | 		FROM channels | ||||||
| 		WHERE id = ? | 		WHERE id = ? | ||||||
| 	` | 	` | ||||||
|  | @ -60,10 +67,11 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { | ||||||
| 		platform          string | 		platform          string | ||||||
| 		platformChannelID string | 		platformChannelID string | ||||||
| 		enabled           bool | 		enabled           bool | ||||||
|  | 		enableAllPlugins  bool | ||||||
| 		channelRawJSON    string | 		channelRawJSON    string | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) | 	err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) | ||||||
| 	if err == sql.ErrNoRows { | 	if err == sql.ErrNoRows { | ||||||
| 		return nil, ErrNotFound | 		return nil, ErrNotFound | ||||||
| 	} | 	} | ||||||
|  | @ -83,6 +91,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { | ||||||
| 		Platform:          platform, | 		Platform:          platform, | ||||||
| 		PlatformChannelID: platformChannelID, | 		PlatformChannelID: platformChannelID, | ||||||
| 		Enabled:           enabled, | 		Enabled:           enabled, | ||||||
|  | 		EnableAllPlugins:  enableAllPlugins, | ||||||
| 		ChannelRaw:        channelRaw, | 		ChannelRaw:        channelRaw, | ||||||
| 		Plugins:           make(map[string]*model.ChannelPlugin), | 		Plugins:           make(map[string]*model.ChannelPlugin), | ||||||
| 	} | 	} | ||||||
|  | @ -103,7 +112,7 @@ func (d *Database) GetChannelByID(id int64) (*model.Channel, error) { | ||||||
| // GetChannelByPlatform retrieves a channel by platform and platform channel ID | // GetChannelByPlatform retrieves a channel by platform and platform channel ID | ||||||
| func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { | func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*model.Channel, error) { | ||||||
| 	query := ` | 	query := ` | ||||||
| 		SELECT id, platform, platform_channel_id, enabled, channel_raw | 		SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw | ||||||
| 		FROM channels | 		FROM channels | ||||||
| 		WHERE platform = ? AND platform_channel_id = ? | 		WHERE platform = ? AND platform_channel_id = ? | ||||||
| 	` | 	` | ||||||
|  | @ -111,12 +120,13 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo | ||||||
| 	row := d.db.QueryRow(query, platform, platformChannelID) | 	row := d.db.QueryRow(query, platform, platformChannelID) | ||||||
| 
 | 
 | ||||||
| 	var ( | 	var ( | ||||||
| 		id             int64 | 		id               int64 | ||||||
| 		enabled        bool | 		enabled          bool | ||||||
| 		channelRawJSON string | 		enableAllPlugins bool | ||||||
|  | 		channelRawJSON   string | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	err := row.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON) | 	err := row.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON) | ||||||
| 	if err == sql.ErrNoRows { | 	if err == sql.ErrNoRows { | ||||||
| 		return nil, ErrNotFound | 		return nil, ErrNotFound | ||||||
| 	} | 	} | ||||||
|  | @ -136,6 +146,7 @@ func (d *Database) GetChannelByPlatform(platform, platformChannelID string) (*mo | ||||||
| 		Platform:          platform, | 		Platform:          platform, | ||||||
| 		PlatformChannelID: platformChannelID, | 		PlatformChannelID: platformChannelID, | ||||||
| 		Enabled:           enabled, | 		Enabled:           enabled, | ||||||
|  | 		EnableAllPlugins:  enableAllPlugins, | ||||||
| 		ChannelRaw:        channelRaw, | 		ChannelRaw:        channelRaw, | ||||||
| 		Plugins:           make(map[string]*model.ChannelPlugin), | 		Plugins:           make(map[string]*model.ChannelPlugin), | ||||||
| 	} | 	} | ||||||
|  | @ -163,11 +174,11 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo | ||||||
| 
 | 
 | ||||||
| 	// Insert channel | 	// Insert channel | ||||||
| 	query := ` | 	query := ` | ||||||
| 		INSERT INTO channels (platform, platform_channel_id, enabled, channel_raw) | 		INSERT INTO channels (platform, platform_channel_id, enabled, enable_all_plugins, channel_raw) | ||||||
| 		VALUES (?, ?, ?, ?) | 		VALUES (?, ?, ?, ?, ?) | ||||||
| 	` | 	` | ||||||
| 
 | 
 | ||||||
| 	result, err := d.db.Exec(query, platform, platformChannelID, enabled, string(channelRawJSON)) | 	result, err := d.db.Exec(query, platform, platformChannelID, enabled, false, string(channelRawJSON)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -184,6 +195,7 @@ func (d *Database) CreateChannel(platform, platformChannelID string, enabled boo | ||||||
| 		Platform:          platform, | 		Platform:          platform, | ||||||
| 		PlatformChannelID: platformChannelID, | 		PlatformChannelID: platformChannelID, | ||||||
| 		Enabled:           enabled, | 		Enabled:           enabled, | ||||||
|  | 		EnableAllPlugins:  false, | ||||||
| 		ChannelRaw:        channelRaw, | 		ChannelRaw:        channelRaw, | ||||||
| 		Plugins:           make(map[string]*model.ChannelPlugin), | 		Plugins:           make(map[string]*model.ChannelPlugin), | ||||||
| 	} | 	} | ||||||
|  | @ -203,6 +215,18 @@ func (d *Database) UpdateChannel(id int64, enabled bool) error { | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // UpdateChannelEnableAllPlugins updates a channel's enable_all_plugins status | ||||||
|  | func (d *Database) UpdateChannelEnableAllPlugins(id int64, enableAllPlugins bool) error { | ||||||
|  | 	query := ` | ||||||
|  | 		UPDATE channels | ||||||
|  | 		SET enable_all_plugins = ? | ||||||
|  | 		WHERE id = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err := d.db.Exec(query, enableAllPlugins, id) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // DeleteChannel deletes a channel | // DeleteChannel deletes a channel | ||||||
| func (d *Database) DeleteChannel(id int64) error { | func (d *Database) DeleteChannel(id int64) error { | ||||||
| 	// First delete all channel plugins | 	// First delete all channel plugins | ||||||
|  | @ -232,7 +256,11 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer rows.Close() | 	defer func() { | ||||||
|  | 		if err := rows.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing rows: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	var plugins []*model.ChannelPlugin | 	var plugins []*model.ChannelPlugin | ||||||
| 
 | 
 | ||||||
|  | @ -250,7 +278,7 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Parse config JSON | 		// Parse config JSON | ||||||
| 		var config map[string]interface{} | 		var config map[string]any | ||||||
| 		if err := json.Unmarshal([]byte(configJSON), &config); err != nil { | 		if err := json.Unmarshal([]byte(configJSON), &config); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | @ -277,6 +305,28 @@ func (d *Database) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, e | ||||||
| 	return plugins, nil | 	return plugins, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetChannelPluginsFromPlatformID retrieves all plugins for a channel by platform and platform channel ID | ||||||
|  | func (d *Database) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) { | ||||||
|  | 	// First, get the channel ID by platform and platform channel ID | ||||||
|  | 	query := ` | ||||||
|  | 		SELECT id | ||||||
|  | 		FROM channels | ||||||
|  | 		WHERE platform = ? AND platform_channel_id = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	var channelID int64 | ||||||
|  | 	err := d.db.QueryRow(query, platform, platformChannelID).Scan(&channelID) | ||||||
|  | 	if err == sql.ErrNoRows { | ||||||
|  | 		return nil, ErrNotFound | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Now get the plugins for this channel | ||||||
|  | 	return d.GetChannelPlugins(channelID) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // GetChannelPluginByID retrieves a channel plugin by ID | // GetChannelPluginByID retrieves a channel plugin by ID | ||||||
| func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { | func (d *Database) GetChannelPluginByID(id int64) (*model.ChannelPlugin, error) { | ||||||
| 	query := ` | 	query := ` | ||||||
|  | @ -380,6 +430,24 @@ func (d *Database) UpdateChannelPlugin(id int64, enabled bool) error { | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // UpdateChannelPluginConfig updates a channel plugin's configuration | ||||||
|  | func (d *Database) UpdateChannelPluginConfig(id int64, config map[string]interface{}) error { | ||||||
|  | 	// Convert config to JSON | ||||||
|  | 	configJSON, err := json.Marshal(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	query := ` | ||||||
|  | 		UPDATE channel_plugin | ||||||
|  | 		SET config = ? | ||||||
|  | 		WHERE id = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err = d.db.Exec(query, string(configJSON), id) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // DeleteChannelPlugin deletes a channel plugin | // DeleteChannelPlugin deletes a channel plugin | ||||||
| func (d *Database) DeleteChannelPlugin(id int64) error { | func (d *Database) DeleteChannelPlugin(id int64) error { | ||||||
| 	query := ` | 	query := ` | ||||||
|  | @ -405,7 +473,7 @@ func (d *Database) DeleteChannelPluginsByChannel(channelID int64) error { | ||||||
| // GetAllChannels retrieves all channels | // GetAllChannels retrieves all channels | ||||||
| func (d *Database) GetAllChannels() ([]*model.Channel, error) { | func (d *Database) GetAllChannels() ([]*model.Channel, error) { | ||||||
| 	query := ` | 	query := ` | ||||||
| 		SELECT id, platform, platform_channel_id, enabled, channel_raw | 		SELECT id, platform, platform_channel_id, enabled, enable_all_plugins, channel_raw | ||||||
| 		FROM channels | 		FROM channels | ||||||
| 	` | 	` | ||||||
| 
 | 
 | ||||||
|  | @ -413,7 +481,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer rows.Close() | 	defer func() { | ||||||
|  | 		if err := rows.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing rows: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	var channels []*model.Channel | 	var channels []*model.Channel | ||||||
| 
 | 
 | ||||||
|  | @ -423,10 +495,11 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { | ||||||
| 			platform          string | 			platform          string | ||||||
| 			platformChannelID string | 			platformChannelID string | ||||||
| 			enabled           bool | 			enabled           bool | ||||||
|  | 			enableAllPlugins  bool | ||||||
| 			channelRawJSON    string | 			channelRawJSON    string | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &channelRawJSON); err != nil { | 		if err := rows.Scan(&id, &platform, &platformChannelID, &enabled, &enableAllPlugins, &channelRawJSON); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -442,6 +515,7 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { | ||||||
| 			Platform:          platform, | 			Platform:          platform, | ||||||
| 			PlatformChannelID: platformChannelID, | 			PlatformChannelID: platformChannelID, | ||||||
| 			Enabled:           enabled, | 			Enabled:           enabled, | ||||||
|  | 			EnableAllPlugins:  enableAllPlugins, | ||||||
| 			ChannelRaw:        channelRaw, | 			ChannelRaw:        channelRaw, | ||||||
| 			Plugins:           make(map[string]*model.ChannelPlugin), | 			Plugins:           make(map[string]*model.ChannelPlugin), | ||||||
| 		} | 		} | ||||||
|  | @ -452,10 +526,9 @@ func (d *Database) GetAllChannels() ([]*model.Channel, error) { | ||||||
| 			continue // Skip this channel if plugins can't be retrieved | 			continue // Skip this channel if plugins can't be retrieved | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if plugins != nil { | 		// Add plugins to channel | ||||||
| 			for _, plugin := range plugins { | 		for _, plugin := range plugins { | ||||||
| 				channel.Plugins[plugin.PluginID] = plugin | 			channel.Plugins[plugin.PluginID] = plugin | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		channels = append(channels, channel) | 		channels = append(channels, channel) | ||||||
|  | @ -505,7 +578,10 @@ func (d *Database) GetUserByID(id int64) (*model.User, error) { | ||||||
| // CreateUser creates a new user | // CreateUser creates a new user | ||||||
| func (d *Database) CreateUser(username, password string) (*model.User, error) { | func (d *Database) CreateUser(username, password string) (*model.User, error) { | ||||||
| 	// Hash password | 	// Hash password | ||||||
| 	hashedPassword := hashPassword(password) | 	hashedPassword, err := hashPassword(password) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Insert user | 	// Insert user | ||||||
| 	query := ` | 	query := ` | ||||||
|  | @ -555,9 +631,9 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check password | 	// Check password with bcrypt | ||||||
| 	hashedPassword := hashPassword(password) | 	err = bcrypt.CompareHashAndPassword([]byte(dbPassword), []byte(password)) | ||||||
| 	if dbPassword != hashedPassword { | 	if err != nil { | ||||||
| 		return nil, errors.New("invalid credentials") | 		return nil, errors.New("invalid credentials") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -568,74 +644,277 @@ func (d *Database) CheckCredentials(username, password string) (*model.User, err | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // UpdateUserPassword updates a user's password | ||||||
|  | func (d *Database) UpdateUserPassword(userID int64, newPassword string) error { | ||||||
|  | 	// Hash the new password | ||||||
|  | 	hashedPassword, err := hashPassword(newPassword) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Update the user's password | ||||||
|  | 	query := ` | ||||||
|  | 		UPDATE users | ||||||
|  | 		SET password = ? | ||||||
|  | 		WHERE id = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err = d.db.Exec(query, hashedPassword, userID) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CreateReminder creates a new reminder | ||||||
|  | func (d *Database) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) { | ||||||
|  | 	query := ` | ||||||
|  | 		INSERT INTO reminders ( | ||||||
|  | 			platform, channel_id, message_id, reply_to_id, | ||||||
|  | 			user_id, username, created_at, trigger_at, | ||||||
|  | 			content, processed | ||||||
|  | 		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0) | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	createdAt := time.Now() | ||||||
|  | 	result, err := d.db.Exec( | ||||||
|  | 		query, | ||||||
|  | 		platform, channelID, messageID, replyToID, | ||||||
|  | 		userID, username, createdAt, triggerAt, | ||||||
|  | 		content, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	id, err := result.LastInsertId() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &model.Reminder{ | ||||||
|  | 		ID:        id, | ||||||
|  | 		Platform:  platform, | ||||||
|  | 		ChannelID: channelID, | ||||||
|  | 		MessageID: messageID, | ||||||
|  | 		ReplyToID: replyToID, | ||||||
|  | 		UserID:    userID, | ||||||
|  | 		Username:  username, | ||||||
|  | 		CreatedAt: createdAt, | ||||||
|  | 		TriggerAt: triggerAt, | ||||||
|  | 		Content:   content, | ||||||
|  | 		Processed: false, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetPendingReminders gets all pending reminders that need to be processed | ||||||
|  | func (d *Database) GetPendingReminders() ([]*model.Reminder, error) { | ||||||
|  | 	query := ` | ||||||
|  | 		SELECT id, platform, channel_id, message_id, reply_to_id, | ||||||
|  | 			   user_id, username, created_at, trigger_at, content, processed | ||||||
|  | 		FROM reminders | ||||||
|  | 		WHERE processed = 0 AND trigger_at <= ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	rows, err := d.db.Query(query, time.Now()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		if err := rows.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing rows: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	var reminders []*model.Reminder | ||||||
|  | 
 | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		var ( | ||||||
|  | 			id                                        int64 | ||||||
|  | 			platform, channelID, messageID, replyToID string | ||||||
|  | 			userID, username, content                 string | ||||||
|  | 			createdAt, triggerAt                      time.Time | ||||||
|  | 			processed                                 bool | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		if err := rows.Scan( | ||||||
|  | 			&id, &platform, &channelID, &messageID, &replyToID, | ||||||
|  | 			&userID, &username, &createdAt, &triggerAt, &content, &processed, | ||||||
|  | 		); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		reminder := &model.Reminder{ | ||||||
|  | 			ID:        id, | ||||||
|  | 			Platform:  platform, | ||||||
|  | 			ChannelID: channelID, | ||||||
|  | 			MessageID: messageID, | ||||||
|  | 			ReplyToID: replyToID, | ||||||
|  | 			UserID:    userID, | ||||||
|  | 			Username:  username, | ||||||
|  | 			CreatedAt: createdAt, | ||||||
|  | 			TriggerAt: triggerAt, | ||||||
|  | 			Content:   content, | ||||||
|  | 			Processed: processed, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		reminders = append(reminders, reminder) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := rows.Err(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(reminders) == 0 { | ||||||
|  | 		return make([]*model.Reminder, 0), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return reminders, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MarkReminderAsProcessed marks a reminder as processed | ||||||
|  | func (d *Database) MarkReminderAsProcessed(id int64) error { | ||||||
|  | 	query := ` | ||||||
|  | 		UPDATE reminders | ||||||
|  | 		SET processed = 1 | ||||||
|  | 		WHERE id = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err := d.db.Exec(query, id) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Helper function to hash password | // Helper function to hash password | ||||||
| func hashPassword(password string) string { | func hashPassword(password string) (string, error) { | ||||||
| 	// In a real implementation, use a proper password hashing library like bcrypt | 	// Use bcrypt for secure password hashing | ||||||
| 	// This is a simplified version for demonstration | 	// The cost parameter is the computational cost, higher is more secure but slower | ||||||
| 	hasher := sha256.New() | 	// Recommended minimum is 12 | ||||||
| 	hasher.Write([]byte(password)) | 	hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), 12) | ||||||
| 	return hex.EncodeToString(hasher.Sum(nil)) | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return string(hashedBytes), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Initialize database tables | // Initialize database tables | ||||||
| func initDatabase(db *sql.DB) error { | func initDatabase(db *sql.DB) error { | ||||||
| 	// Create channels table | 	// Ensure migration table exists | ||||||
| 	_, err := db.Exec(` | 	if err := migration.EnsureMigrationTable(db); err != nil { | ||||||
| 		CREATE TABLE IF NOT EXISTS channels ( | 		return fmt.Errorf("failed to create migration table: %w", err) | ||||||
| 			id INTEGER PRIMARY KEY AUTOINCREMENT, |  | ||||||
| 			platform TEXT NOT NULL, |  | ||||||
| 			platform_channel_id TEXT NOT NULL, |  | ||||||
| 			enabled BOOLEAN NOT NULL DEFAULT 0, |  | ||||||
| 			channel_raw TEXT NOT NULL, |  | ||||||
| 			UNIQUE(platform, platform_channel_id) |  | ||||||
| 		) |  | ||||||
| 	`) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Create channel_plugin table | 	// Get applied migrations | ||||||
| 	_, err = db.Exec(` | 	applied, err := migration.GetAppliedMigrations(db) | ||||||
| 		CREATE TABLE IF NOT EXISTS channel_plugin ( |  | ||||||
| 			id INTEGER PRIMARY KEY AUTOINCREMENT, |  | ||||||
| 			channel_id INTEGER NOT NULL, |  | ||||||
| 			plugin_id TEXT NOT NULL, |  | ||||||
| 			enabled BOOLEAN NOT NULL DEFAULT 0, |  | ||||||
| 			config TEXT NOT NULL DEFAULT '{}', |  | ||||||
| 			UNIQUE(channel_id, plugin_id), |  | ||||||
| 			FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE |  | ||||||
| 		) |  | ||||||
| 	`) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return fmt.Errorf("failed to get applied migrations: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Create users table | 	// Get all migration versions | ||||||
| 	_, err = db.Exec(` | 	allMigrations := make([]int, 0, len(migration.Migrations)) | ||||||
| 		CREATE TABLE IF NOT EXISTS users ( | 	for version := range migration.Migrations { | ||||||
| 			id INTEGER PRIMARY KEY AUTOINCREMENT, | 		allMigrations = append(allMigrations, version) | ||||||
| 			username TEXT NOT NULL UNIQUE, |  | ||||||
| 			password TEXT NOT NULL |  | ||||||
| 		) |  | ||||||
| 	`) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Create default admin user if it doesn't exist | 	// Create a map of applied migrations for quick lookup | ||||||
| 	var count int | 	appliedMap := make(map[int]bool) | ||||||
| 	err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) | 	for _, version := range applied { | ||||||
| 	if err != nil { | 		appliedMap[version] = true | ||||||
| 		return err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if count == 0 { | 	// Count pending migrations | ||||||
| 		hashedPassword := hashPassword("admin") | 	pendingCount := 0 | ||||||
| 		_, err = db.Exec("INSERT INTO users (username, password) VALUES (?, ?)", "admin", hashedPassword) | 	for _, version := range allMigrations { | ||||||
| 		if err != nil { | 		if !appliedMap[version] { | ||||||
| 			return err | 			pendingCount++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Run migrations if needed | ||||||
|  | 	if pendingCount > 0 { | ||||||
|  | 		fmt.Printf("Running %d pending database migrations...\n", pendingCount) | ||||||
|  | 		if err := migration.Migrate(db); err != nil { | ||||||
|  | 			return fmt.Errorf("migration failed: %w", err) | ||||||
|  | 		} | ||||||
|  | 		fmt.Println("Database migrations completed successfully.") | ||||||
|  | 	} else { | ||||||
|  | 		fmt.Println("Database schema is up to date.") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Configure SQLite for better reliability | ||||||
|  | func configureSQLite(db *sql.DB) error { | ||||||
|  | 	pragmas := []string{ | ||||||
|  | 		// Enable Write-Ahead Logging for better concurrency and crash recovery | ||||||
|  | 		"PRAGMA journal_mode = WAL", | ||||||
|  | 		// Set 5-second timeout when database is locked by another connection | ||||||
|  | 		"PRAGMA busy_timeout = 5000", | ||||||
|  | 		// Balance between safety and performance for disk writes | ||||||
|  | 		"PRAGMA synchronous = NORMAL", | ||||||
|  | 		// Set large cache size (1GB) for better read performance | ||||||
|  | 		"PRAGMA cache_size = 1000000000", | ||||||
|  | 		// Enable foreign key constraint enforcement | ||||||
|  | 		"PRAGMA foreign_keys = true", | ||||||
|  | 		// Store temporary tables and indices in memory for speed | ||||||
|  | 		"PRAGMA temp_store = memory", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, pragma := range pragmas { | ||||||
|  | 		if _, err := db.Exec(pragma); err != nil { | ||||||
|  | 			return fmt.Errorf("failed to execute %s: %w", pragma, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // CacheGet retrieves a value from the cache | ||||||
|  | func (d *Database) CacheGet(key string) (string, error) { | ||||||
|  | 	query := ` | ||||||
|  | 		SELECT value | ||||||
|  | 		FROM cache | ||||||
|  | 		WHERE key = ? AND (expires_at IS NULL OR expires_at > ?) | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	var value string | ||||||
|  | 	err := d.db.QueryRow(query, key, time.Now()).Scan(&value) | ||||||
|  | 	if err == sql.ErrNoRows { | ||||||
|  | 		return "", ErrNotFound | ||||||
|  | 	} | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return value, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CacheSet stores a value in the cache with optional expiration | ||||||
|  | func (d *Database) CacheSet(key, value string, expiration *time.Time) error { | ||||||
|  | 	query := ` | ||||||
|  | 		INSERT OR REPLACE INTO cache (key, value, expires_at, updated_at) | ||||||
|  | 		VALUES (?, ?, ?, ?) | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err := d.db.Exec(query, key, value, expiration, time.Now()) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CacheDelete removes a value from the cache | ||||||
|  | func (d *Database) CacheDelete(key string) error { | ||||||
|  | 	query := ` | ||||||
|  | 		DELETE FROM cache | ||||||
|  | 		WHERE key = ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err := d.db.Exec(query, key) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CacheCleanup removes expired cache entries | ||||||
|  | func (d *Database) CacheCleanup() error { | ||||||
|  | 	query := ` | ||||||
|  | 		DELETE FROM cache | ||||||
|  | 		WHERE expires_at IS NOT NULL AND expires_at <= ? | ||||||
|  | 	` | ||||||
|  | 
 | ||||||
|  | 	_, err := d.db.Exec(query, time.Now()) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										203
									
								
								internal/db/db_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								internal/db/db_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,203 @@ | ||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestEnableAllPlugins(t *testing.T) { | ||||||
|  | 	// Create temporary database for testing with unique name | ||||||
|  | 	dbFile := fmt.Sprintf("test_db_%d.db", time.Now().UnixNano()) | ||||||
|  | 	database, err := New(dbFile) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to create test database: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = database.Close() | ||||||
|  | 		// Clean up test database file | ||||||
|  | 		_ = os.Remove(dbFile) | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	t.Run("CreateChannel with EnableAllPlugins default false", func(t *testing.T) { | ||||||
|  | 		channelRaw := map[string]interface{}{ | ||||||
|  | 			"name": "test-channel", | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		channel, err := database.CreateChannel("telegram", "123456", true, channelRaw) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if channel.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected EnableAllPlugins to be false by default, got true") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Verify it's also false when retrieved from database | ||||||
|  | 		retrieved, err := database.GetChannelByID(channel.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to retrieve channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if retrieved.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected EnableAllPlugins to be false when retrieved from DB, got true") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("UpdateChannelEnableAllPlugins", func(t *testing.T) { | ||||||
|  | 		// Create a channel | ||||||
|  | 		channelRaw := map[string]interface{}{ | ||||||
|  | 			"name": "test-channel-2", | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		channel, err := database.CreateChannel("telegram", "123457", true, channelRaw) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Update EnableAllPlugins to true | ||||||
|  | 		err = database.UpdateChannelEnableAllPlugins(channel.ID, true) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to update EnableAllPlugins: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Retrieve and verify | ||||||
|  | 		retrieved, err := database.GetChannelByID(channel.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to retrieve channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !retrieved.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected EnableAllPlugins to be true after update, got false") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Update back to false | ||||||
|  | 		err = database.UpdateChannelEnableAllPlugins(channel.ID, false) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to update EnableAllPlugins back to false: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Retrieve and verify again | ||||||
|  | 		retrieved, err = database.GetChannelByID(channel.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to retrieve channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if retrieved.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected EnableAllPlugins to be false after second update, got true") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("GetChannelByPlatform includes EnableAllPlugins", func(t *testing.T) { | ||||||
|  | 		// Create a channel | ||||||
|  | 		channelRaw := map[string]interface{}{ | ||||||
|  | 			"name": "test-channel-3", | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		channel, err := database.CreateChannel("slack", "C123456", true, channelRaw) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Enable all plugins | ||||||
|  | 		err = database.UpdateChannelEnableAllPlugins(channel.ID, true) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to update EnableAllPlugins: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Retrieve by platform | ||||||
|  | 		retrieved, err := database.GetChannelByPlatform("slack", "C123456") | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to retrieve channel by platform: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !retrieved.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected EnableAllPlugins to be true when retrieved by platform, got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("GetAllChannels includes EnableAllPlugins", func(t *testing.T) { | ||||||
|  | 		// Create multiple channels with different EnableAllPlugins settings | ||||||
|  | 		channelRaw1 := map[string]interface{}{"name": "channel-1"} | ||||||
|  | 		channelRaw2 := map[string]interface{}{"name": "channel-2"} | ||||||
|  | 
 | ||||||
|  | 		channel1, err := database.CreateChannel("platform1", "ch1", true, channelRaw1) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel1: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		channel2, err := database.CreateChannel("platform2", "ch2", true, channelRaw2) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel2: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Enable all plugins for channel2 only | ||||||
|  | 		err = database.UpdateChannelEnableAllPlugins(channel2.ID, true) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to update EnableAllPlugins for channel2: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get all channels | ||||||
|  | 		channels, err := database.GetAllChannels() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to get all channels: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Find our test channels | ||||||
|  | 		var foundChannel1, foundChannel2 *model.Channel | ||||||
|  | 		for _, ch := range channels { | ||||||
|  | 			if ch.ID == channel1.ID { | ||||||
|  | 				foundChannel1 = ch | ||||||
|  | 			} | ||||||
|  | 			if ch.ID == channel2.ID { | ||||||
|  | 				foundChannel2 = ch | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if foundChannel1 == nil { | ||||||
|  | 			t.Fatalf("Channel1 not found in GetAllChannels result") | ||||||
|  | 		} | ||||||
|  | 		if foundChannel2 == nil { | ||||||
|  | 			t.Fatalf("Channel2 not found in GetAllChannels result") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if foundChannel1.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected channel1 EnableAllPlugins to be false, got true") | ||||||
|  | 		} | ||||||
|  | 		if !foundChannel2.EnableAllPlugins { | ||||||
|  | 			t.Errorf("Expected channel2 EnableAllPlugins to be true, got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Migration applied correctly", func(t *testing.T) { | ||||||
|  | 		// Test that we can create a channel and the enable_all_plugins column exists | ||||||
|  | 		// This implicitly tests that migration 4 was applied correctly | ||||||
|  | 		channelRaw := map[string]interface{}{ | ||||||
|  | 			"name": "migration-test-channel", | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		channel, err := database.CreateChannel("test-platform", "migration-test", true, channelRaw) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to create channel after migration: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Try to update EnableAllPlugins - this would fail if the column doesn't exist | ||||||
|  | 		err = database.UpdateChannelEnableAllPlugins(channel.ID, true) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to update EnableAllPlugins - migration may not have been applied: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Verify the value was set correctly | ||||||
|  | 		retrieved, err := database.GetChannelByID(channel.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to retrieve channel: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !retrieved.EnableAllPlugins { | ||||||
|  | 			t.Errorf("EnableAllPlugins should be true after update") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										223
									
								
								internal/migration/migration.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										223
									
								
								internal/migration/migration.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,223 @@ | ||||||
|  | package migration | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 	"sort" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Migration represents a database migration | ||||||
|  | type Migration struct { | ||||||
|  | 	Version     int | ||||||
|  | 	Description string | ||||||
|  | 	Up          func(db *sql.DB) error | ||||||
|  | 	Down        func(db *sql.DB) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Migrations is a collection of registered migrations | ||||||
|  | var Migrations = make(map[int]Migration) | ||||||
|  | 
 | ||||||
|  | // Register adds a migration to the list of available migrations | ||||||
|  | func Register(version int, description string, up, down func(db *sql.DB) error) { | ||||||
|  | 	if _, exists := Migrations[version]; exists { | ||||||
|  | 		panic(fmt.Sprintf("migration version %d already exists", version)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	Migrations[version] = Migration{ | ||||||
|  | 		Version:     version, | ||||||
|  | 		Description: description, | ||||||
|  | 		Up:          up, | ||||||
|  | 		Down:        down, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // EnsureMigrationTable creates the migration table if it doesn't exist | ||||||
|  | func EnsureMigrationTable(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS schema_migrations ( | ||||||
|  | 			version INTEGER PRIMARY KEY, | ||||||
|  | 			applied_at TIMESTAMP NOT NULL | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetAppliedMigrations returns a list of applied migration versions | ||||||
|  | func GetAppliedMigrations(db *sql.DB) ([]int, error) { | ||||||
|  | 	rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		if err := rows.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing rows: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	var versions []int | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		var version int | ||||||
|  | 		if err := rows.Scan(&version); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		versions = append(versions, version) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return versions, rows.Err() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsApplied checks if a migration version has been applied | ||||||
|  | func IsApplied(db *sql.DB, version int) (bool, error) { | ||||||
|  | 	var count int | ||||||
|  | 	err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", version).Scan(&count) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	return count > 0, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MarkAsApplied marks a migration as applied | ||||||
|  | func MarkAsApplied(db *sql.DB, version int) error { | ||||||
|  | 	_, err := db.Exec( | ||||||
|  | 		"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", | ||||||
|  | 		version, time.Now(), | ||||||
|  | 	) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // RemoveApplied removes a migration from the applied list | ||||||
|  | func RemoveApplied(db *sql.DB, version int) error { | ||||||
|  | 	_, err := db.Exec("DELETE FROM schema_migrations WHERE version = ?", version) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Migrate runs pending migrations up to the latest version | ||||||
|  | func Migrate(db *sql.DB) error { | ||||||
|  | 	// Ensure migration table exists | ||||||
|  | 	if err := EnsureMigrationTable(db); err != nil { | ||||||
|  | 		return fmt.Errorf("failed to create migration table: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get applied migrations | ||||||
|  | 	applied, err := GetAppliedMigrations(db) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to get applied migrations: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create a map of applied migrations for quick lookup | ||||||
|  | 	appliedMap := make(map[int]bool) | ||||||
|  | 	for _, version := range applied { | ||||||
|  | 		appliedMap[version] = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get all migration versions and sort them | ||||||
|  | 	var versions []int | ||||||
|  | 	for version := range Migrations { | ||||||
|  | 		versions = append(versions, version) | ||||||
|  | 	} | ||||||
|  | 	sort.Ints(versions) | ||||||
|  | 
 | ||||||
|  | 	// Apply each pending migration | ||||||
|  | 	for _, version := range versions { | ||||||
|  | 		if !appliedMap[version] { | ||||||
|  | 			migration := Migrations[version] | ||||||
|  | 			fmt.Printf("Applying migration %d: %s...\n", version, migration.Description) | ||||||
|  | 
 | ||||||
|  | 			// Start transaction for the migration | ||||||
|  | 			tx, err := db.Begin() | ||||||
|  | 			if err != nil { | ||||||
|  | 				return fmt.Errorf("failed to begin transaction for migration %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Apply the migration | ||||||
|  | 			if err := migration.Up(db); err != nil { | ||||||
|  | 				if err := tx.Rollback(); err != nil { | ||||||
|  | 					fmt.Printf("Error rolling back transaction: %v\n", err) | ||||||
|  | 				} | ||||||
|  | 				return fmt.Errorf("failed to apply migration %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Mark as applied | ||||||
|  | 			if _, err := tx.Exec( | ||||||
|  | 				"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)", | ||||||
|  | 				version, time.Now(), | ||||||
|  | 			); err != nil { | ||||||
|  | 				if err := tx.Rollback(); err != nil { | ||||||
|  | 					fmt.Printf("Error rolling back transaction: %v\n", err) | ||||||
|  | 				} | ||||||
|  | 				return fmt.Errorf("failed to mark migration %d as applied: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Commit the transaction | ||||||
|  | 			if err := tx.Commit(); err != nil { | ||||||
|  | 				return fmt.Errorf("failed to commit migration %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			fmt.Printf("Migration %d applied successfully\n", version) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MigrateDown rolls back migrations down to the specified version | ||||||
|  | // If version is -1, it will roll back all migrations | ||||||
|  | func MigrateDown(db *sql.DB, targetVersion int) error { | ||||||
|  | 	// Ensure migration table exists | ||||||
|  | 	if err := EnsureMigrationTable(db); err != nil { | ||||||
|  | 		return fmt.Errorf("failed to create migration table: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get applied migrations | ||||||
|  | 	applied, err := GetAppliedMigrations(db) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to get applied migrations: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Sort in descending order to roll back newest first | ||||||
|  | 	sort.Sort(sort.Reverse(sort.IntSlice(applied))) | ||||||
|  | 
 | ||||||
|  | 	// Roll back each migration until target version | ||||||
|  | 	for _, version := range applied { | ||||||
|  | 		if targetVersion == -1 || version > targetVersion { | ||||||
|  | 			migration, exists := Migrations[version] | ||||||
|  | 			if !exists { | ||||||
|  | 				return fmt.Errorf("migration %d is applied but not found in codebase", version) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			fmt.Printf("Rolling back migration %d: %s...\n", version, migration.Description) | ||||||
|  | 
 | ||||||
|  | 			// Start transaction for the rollback | ||||||
|  | 			tx, err := db.Begin() | ||||||
|  | 			if err != nil { | ||||||
|  | 				return fmt.Errorf("failed to begin transaction for rollback %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Apply the down migration | ||||||
|  | 			if err := migration.Down(db); err != nil { | ||||||
|  | 				if err := tx.Rollback(); err != nil { | ||||||
|  | 					fmt.Printf("Error rolling back transaction: %v\n", err) | ||||||
|  | 				} | ||||||
|  | 				return fmt.Errorf("failed to roll back migration %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Remove from applied list | ||||||
|  | 			if _, err := tx.Exec("DELETE FROM schema_migrations WHERE version = ?", version); err != nil { | ||||||
|  | 				if err := tx.Rollback(); err != nil { | ||||||
|  | 					fmt.Printf("Error rolling back transaction: %v\n", err) | ||||||
|  | 				} | ||||||
|  | 				return fmt.Errorf("failed to remove migration %d from applied list: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Commit the transaction | ||||||
|  | 			if err := tx.Commit(); err != nil { | ||||||
|  | 				return fmt.Errorf("failed to commit rollback %d: %w", version, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			fmt.Printf("Migration %d rolled back successfully\n", version) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										214
									
								
								internal/migration/migrations.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								internal/migration/migrations.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,214 @@ | ||||||
|  | package migration | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	// Register migrations | ||||||
|  | 	Register(1, "Initial schema with bcrypt passwords", migrateInitialSchemaUp, migrateInitialSchemaDown) | ||||||
|  | 	Register(2, "Add reminders table", migrateRemindersUp, migrateRemindersDown) | ||||||
|  | 	Register(3, "Add cache table", migrateCacheUp, migrateCacheDown) | ||||||
|  | 	Register(4, "Add enable_all_plugins column to channels", migrateEnableAllPluginsUp, migrateEnableAllPluginsDown) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Initial schema creation with bcrypt passwords - version 1 | ||||||
|  | func migrateInitialSchemaUp(db *sql.DB) error { | ||||||
|  | 	// Create channels table | ||||||
|  | 	_, err := db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS channels ( | ||||||
|  | 			id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  | 			platform TEXT NOT NULL, | ||||||
|  | 			platform_channel_id TEXT NOT NULL, | ||||||
|  | 			enabled BOOLEAN NOT NULL DEFAULT 0, | ||||||
|  | 			channel_raw TEXT NOT NULL, | ||||||
|  | 			UNIQUE(platform, platform_channel_id) | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create channel_plugin table | ||||||
|  | 	_, err = db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS channel_plugin ( | ||||||
|  | 			id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  | 			channel_id INTEGER NOT NULL, | ||||||
|  | 			plugin_id TEXT NOT NULL, | ||||||
|  | 			enabled BOOLEAN NOT NULL DEFAULT 0, | ||||||
|  | 			config TEXT NOT NULL DEFAULT '{}', | ||||||
|  | 			UNIQUE(channel_id, plugin_id), | ||||||
|  | 			FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create users table with bcrypt passwords | ||||||
|  | 	_, err = db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS users ( | ||||||
|  | 			id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  | 			username TEXT NOT NULL UNIQUE, | ||||||
|  | 			password TEXT NOT NULL | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create default admin user with bcrypt password | ||||||
|  | 	hashedPassword, err := bcrypt.GenerateFromPassword([]byte("admin"), 12) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check if users table is empty before inserting | ||||||
|  | 	var count int | ||||||
|  | 	err = db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if count == 0 { | ||||||
|  | 		_, err = db.Exec( | ||||||
|  | 			"INSERT INTO users (username, password) VALUES (?, ?)", | ||||||
|  | 			"admin", string(hashedPassword), | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func migrateInitialSchemaDown(db *sql.DB) error { | ||||||
|  | 	// Drop tables in reverse order of dependencies | ||||||
|  | 	_, err := db.Exec(`DROP TABLE IF EXISTS channel_plugin`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err = db.Exec(`DROP TABLE IF EXISTS channels`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err = db.Exec(`DROP TABLE IF EXISTS users`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Add reminders table - version 2 | ||||||
|  | func migrateRemindersUp(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS reminders ( | ||||||
|  | 			id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  | 			platform TEXT NOT NULL, | ||||||
|  | 			channel_id TEXT NOT NULL, | ||||||
|  | 			message_id TEXT NOT NULL, | ||||||
|  | 			reply_to_id TEXT NOT NULL, | ||||||
|  | 			user_id TEXT NOT NULL, | ||||||
|  | 			username TEXT NOT NULL, | ||||||
|  | 			created_at TIMESTAMP NOT NULL, | ||||||
|  | 			trigger_at TIMESTAMP NOT NULL, | ||||||
|  | 			content TEXT NOT NULL, | ||||||
|  | 			processed BOOLEAN NOT NULL DEFAULT 0 | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func migrateRemindersDown(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(`DROP TABLE IF EXISTS reminders`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Add cache table - version 3 | ||||||
|  | func migrateCacheUp(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(` | ||||||
|  | 		CREATE TABLE IF NOT EXISTS cache ( | ||||||
|  | 			key TEXT PRIMARY KEY, | ||||||
|  | 			value TEXT NOT NULL, | ||||||
|  | 			expires_at TIMESTAMP, | ||||||
|  | 			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||||||
|  | 			updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create index on expires_at for efficient cleanup | ||||||
|  | 	_, err = db.Exec(` | ||||||
|  | 		CREATE INDEX IF NOT EXISTS idx_cache_expires_at ON cache(expires_at) | ||||||
|  | 	`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func migrateCacheDown(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(`DROP TABLE IF EXISTS cache`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Add enable_all_plugins column to channels table - version 4 | ||||||
|  | func migrateEnableAllPluginsUp(db *sql.DB) error { | ||||||
|  | 	_, err := db.Exec(` | ||||||
|  | 		ALTER TABLE channels ADD COLUMN enable_all_plugins BOOLEAN NOT NULL DEFAULT 0 | ||||||
|  | 	`) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func migrateEnableAllPluginsDown(db *sql.DB) error { | ||||||
|  | 	// SQLite doesn't support DROP COLUMN, so we need to recreate the table | ||||||
|  | 	tx, err := db.Begin() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = tx.Rollback() // Ignore rollback errors | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Create backup table | ||||||
|  | 	_, err = tx.Exec(` | ||||||
|  | 		CREATE TABLE channels_backup ( | ||||||
|  | 			id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  | 			platform TEXT NOT NULL, | ||||||
|  | 			platform_channel_id TEXT NOT NULL, | ||||||
|  | 			enabled BOOLEAN NOT NULL DEFAULT 0, | ||||||
|  | 			channel_raw TEXT NOT NULL, | ||||||
|  | 			UNIQUE(platform, platform_channel_id) | ||||||
|  | 		) | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Copy data excluding enable_all_plugins column | ||||||
|  | 	_, err = tx.Exec(` | ||||||
|  | 		INSERT INTO channels_backup (id, platform, platform_channel_id, enabled, channel_raw) | ||||||
|  | 		SELECT id, platform, platform_channel_id, enabled, channel_raw FROM channels | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Drop original table | ||||||
|  | 	_, err = tx.Exec(`DROP TABLE channels`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Rename backup table | ||||||
|  | 	_, err = tx.Exec(`ALTER TABLE channels_backup RENAME TO channels`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return tx.Commit() | ||||||
|  | } | ||||||
|  | @ -4,31 +4,57 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // ActionType defines the type of action to perform | ||||||
|  | type ActionType string | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	// ActionSendMessage is for sending a message to the chat | ||||||
|  | 	ActionSendMessage ActionType = "send_message" | ||||||
|  | 	// ActionDeleteMessage is for deleting a message from the chat | ||||||
|  | 	ActionDeleteMessage ActionType = "delete_message" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MessageAction represents an action to be performed on the platform | ||||||
|  | type MessageAction struct { | ||||||
|  | 	Type      ActionType | ||||||
|  | 	Message   *Message               // For send_message | ||||||
|  | 	MessageID string                 // For delete_message | ||||||
|  | 	Chat      string                 // Chat where the action happens | ||||||
|  | 	Channel   *Channel               // Channel reference | ||||||
|  | 	Raw       map[string]interface{} // Additional data for the action | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Message represents a chat message | // Message represents a chat message | ||||||
| type Message struct { | type Message struct { | ||||||
| 	Text     string | 	Text    string | ||||||
| 	Chat     string | 	Chat    string | ||||||
| 	Channel  *Channel | 	Channel *Channel | ||||||
| 	Author   string | 	Author  string | ||||||
| 	FromBot  bool | 	FromBot bool | ||||||
| 	Date     time.Time | 	Date    time.Time | ||||||
| 	ID       string | 	ID      string | ||||||
| 	ReplyTo  string | 	ReplyTo string | ||||||
| 	Raw      map[string]interface{} | 	Raw     map[string]interface{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Channel represents a chat channel | // Channel represents a chat channel | ||||||
| type Channel struct { | type Channel struct { | ||||||
| 	ID               int64 | 	ID                int64 | ||||||
| 	Platform         string | 	Platform          string | ||||||
| 	PlatformChannelID string | 	PlatformChannelID string | ||||||
| 	ChannelRaw       map[string]interface{} | 	ChannelRaw        map[string]interface{} | ||||||
| 	Enabled          bool | 	Enabled           bool | ||||||
| 	Plugins          map[string]*ChannelPlugin | 	EnableAllPlugins  bool | ||||||
|  | 	Plugins           map[string]*ChannelPlugin | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // HasEnabledPlugin checks if a plugin is enabled for this channel | // HasEnabledPlugin checks if a plugin is enabled for this channel | ||||||
| func (c *Channel) HasEnabledPlugin(pluginID string) bool { | func (c *Channel) HasEnabledPlugin(pluginID string) bool { | ||||||
|  | 	// If EnableAllPlugins is true, all plugins are considered enabled | ||||||
|  | 	if c.EnableAllPlugins { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	plugin, exists := c.Plugins[pluginID] | 	plugin, exists := c.Plugins[pluginID] | ||||||
| 	if !exists { | 	if !exists { | ||||||
| 		return false | 		return false | ||||||
|  | @ -40,18 +66,18 @@ func (c *Channel) HasEnabledPlugin(pluginID string) bool { | ||||||
| func (c *Channel) ChannelName() string { | func (c *Channel) ChannelName() string { | ||||||
| 	// In a real implementation, this would use the platform-specific | 	// In a real implementation, this would use the platform-specific | ||||||
| 	// ParseChannelNameFromRaw function | 	// ParseChannelNameFromRaw function | ||||||
| 	 | 
 | ||||||
| 	// For simplicity, we'll just use the PlatformChannelID if we can't extract a name | 	// For simplicity, we'll just use the PlatformChannelID if we can't extract a name | ||||||
| 	// Check if ChannelRaw has a name field | 	// Check if ChannelRaw has a name field | ||||||
| 	if c.ChannelRaw == nil { | 	if c.ChannelRaw == nil { | ||||||
| 		return c.PlatformChannelID | 		return c.PlatformChannelID | ||||||
| 	} | 	} | ||||||
| 	 | 
 | ||||||
| 	// Check common name fields in ChannelRaw | 	// Check common name fields in ChannelRaw | ||||||
| 	if name, ok := c.ChannelRaw["name"].(string); ok && name != "" { | 	if name, ok := c.ChannelRaw["name"].(string); ok && name != "" { | ||||||
| 		return name | 		return name | ||||||
| 	} | 	} | ||||||
| 	 | 
 | ||||||
| 	// Check for nested objects like "chat" (used by Telegram) | 	// Check for nested objects like "chat" (used by Telegram) | ||||||
| 	if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok { | 	if chat, ok := c.ChannelRaw["chat"].(map[string]interface{}); ok { | ||||||
| 		// Try different fields in order of preference | 		// Try different fields in order of preference | ||||||
|  | @ -65,7 +91,7 @@ func (c *Channel) ChannelName() string { | ||||||
| 			return firstName | 			return firstName | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	 | 
 | ||||||
| 	return c.PlatformChannelID | 	return c.PlatformChannelID | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -75,7 +101,7 @@ type ChannelPlugin struct { | ||||||
| 	ChannelID int64 | 	ChannelID int64 | ||||||
| 	PluginID  string | 	PluginID  string | ||||||
| 	Enabled   bool | 	Enabled   bool | ||||||
| 	Config    map[string]interface{} | 	Config    map[string]any | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // User represents an admin user | // User represents an admin user | ||||||
|  | @ -83,4 +109,19 @@ type User struct { | ||||||
| 	ID       int64 | 	ID       int64 | ||||||
| 	Username string | 	Username string | ||||||
| 	Password string | 	Password string | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // Reminder represents a scheduled reminder | ||||||
|  | type Reminder struct { | ||||||
|  | 	ID        int64 | ||||||
|  | 	Platform  string | ||||||
|  | 	ChannelID string | ||||||
|  | 	MessageID string | ||||||
|  | 	ReplyToID string | ||||||
|  | 	UserID    string | ||||||
|  | 	Username  string | ||||||
|  | 	CreatedAt time.Time | ||||||
|  | 	TriggerAt time.Time | ||||||
|  | 	Content   string | ||||||
|  | 	Processed bool | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										234
									
								
								internal/model/message_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								internal/model/message_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,234 @@ | ||||||
|  | package model | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestChannel_HasEnabledPlugin(t *testing.T) { | ||||||
|  | 	t.Run("EnableAllPlugins false - plugin not in map", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  false, | ||||||
|  | 			Plugins:           make(map[string]*ChannelPlugin), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Plugin not in map should return false | ||||||
|  | 		result := channel.HasEnabledPlugin("nonexistent.plugin") | ||||||
|  | 		if result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return false for nonexistent plugin, got true") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins false - plugin disabled", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  false, | ||||||
|  | 			Plugins: map[string]*ChannelPlugin{ | ||||||
|  | 				"test.plugin": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "test.plugin", | ||||||
|  | 					Enabled:   false, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Disabled plugin should return false | ||||||
|  | 		result := channel.HasEnabledPlugin("test.plugin") | ||||||
|  | 		if result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return false for disabled plugin, got true") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins false - plugin enabled", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  false, | ||||||
|  | 			Plugins: map[string]*ChannelPlugin{ | ||||||
|  | 				"test.plugin": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "test.plugin", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Enabled plugin should return true | ||||||
|  | 		result := channel.HasEnabledPlugin("test.plugin") | ||||||
|  | 		if !result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return true for enabled plugin, got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins true - plugin not in map", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins:           make(map[string]*ChannelPlugin), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// When EnableAllPlugins is true, any plugin should be considered enabled | ||||||
|  | 		result := channel.HasEnabledPlugin("nonexistent.plugin") | ||||||
|  | 		if !result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins true - plugin disabled", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins: map[string]*ChannelPlugin{ | ||||||
|  | 				"test.plugin": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "test.plugin", | ||||||
|  | 					Enabled:   false, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// When EnableAllPlugins is true, even disabled plugins should be considered enabled | ||||||
|  | 		result := channel.HasEnabledPlugin("test.plugin") | ||||||
|  | 		if !result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true (even for disabled plugin), got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins true - plugin enabled", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins: map[string]*ChannelPlugin{ | ||||||
|  | 				"test.plugin": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "test.plugin", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// When EnableAllPlugins is true, enabled plugins should also return true | ||||||
|  | 		result := channel.HasEnabledPlugin("test.plugin") | ||||||
|  | 		if !result { | ||||||
|  | 			t.Errorf("Expected HasEnabledPlugin to return true when EnableAllPlugins is true, got false") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins true - multiple plugins", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins: map[string]*ChannelPlugin{ | ||||||
|  | 				"plugin1": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "plugin1", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 				"plugin2": { | ||||||
|  | 					ID:        2, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "plugin2", | ||||||
|  | 					Enabled:   false, | ||||||
|  | 					Config:    make(map[string]any), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// All plugins should be enabled when EnableAllPlugins is true | ||||||
|  | 		testCases := []string{"plugin1", "plugin2", "plugin3", "any.plugin"} | ||||||
|  | 		for _, pluginID := range testCases { | ||||||
|  | 			result := channel.HasEnabledPlugin(pluginID) | ||||||
|  | 			if !result { | ||||||
|  | 				t.Errorf("Expected HasEnabledPlugin('%s') to return true when EnableAllPlugins is true, got false", pluginID) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestChannelName(t *testing.T) { | ||||||
|  | 	t.Run("Returns PlatformChannelID when ChannelRaw is nil", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			PlatformChannelID: "test-id", | ||||||
|  | 			ChannelRaw:        nil, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		result := channel.ChannelName() | ||||||
|  | 		if result != "test-id" { | ||||||
|  | 			t.Errorf("Expected channel name to be 'test-id', got '%s'", result) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Returns name from ChannelRaw when available", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			PlatformChannelID: "test-id", | ||||||
|  | 			ChannelRaw: map[string]interface{}{ | ||||||
|  | 				"name": "Test Channel", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		result := channel.ChannelName() | ||||||
|  | 		if result != "Test Channel" { | ||||||
|  | 			t.Errorf("Expected channel name to be 'Test Channel', got '%s'", result) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Returns title from nested chat object (Telegram style)", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			PlatformChannelID: "test-id", | ||||||
|  | 			ChannelRaw: map[string]interface{}{ | ||||||
|  | 				"chat": map[string]interface{}{ | ||||||
|  | 					"title": "Telegram Group", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		result := channel.ChannelName() | ||||||
|  | 		if result != "Telegram Group" { | ||||||
|  | 			t.Errorf("Expected channel name to be 'Telegram Group', got '%s'", result) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Falls back to PlatformChannelID when no valid name found", func(t *testing.T) { | ||||||
|  | 		channel := &Channel{ | ||||||
|  | 			PlatformChannelID: "fallback-id", | ||||||
|  | 			ChannelRaw: map[string]interface{}{ | ||||||
|  | 				"other_field": "value", | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		result := channel.ChannelName() | ||||||
|  | 		if result != "fallback-id" { | ||||||
|  | 			t.Errorf("Expected channel name to fallback to 'fallback-id', got '%s'", result) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | @ -43,4 +43,7 @@ type Platform interface { | ||||||
| 
 | 
 | ||||||
| 	// SendMessage sends a message through the platform | 	// SendMessage sends a message through the platform | ||||||
| 	SendMessage(msg *Message) error | 	SendMessage(msg *Message) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteMessage deletes a message from the platform | ||||||
|  | 	DeleteMessage(channel string, messageID string) error | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -2,8 +2,18 @@ package model | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // CacheInterface defines the cache interface available to plugins | ||||||
|  | type CacheInterface interface { | ||||||
|  | 	Get(key string, destination interface{}) error | ||||||
|  | 	Set(key string, value interface{}, expiration *time.Time) error | ||||||
|  | 	SetWithTTL(key string, value interface{}, ttl time.Duration) error | ||||||
|  | 	Delete(key string) error | ||||||
|  | 	Exists(key string) (bool, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| var ( | var ( | ||||||
| 	// ErrPluginNotFound is returned when a requested plugin doesn't exist | 	// ErrPluginNotFound is returned when a requested plugin doesn't exist | ||||||
| 	ErrPluginNotFound = errors.New("plugin not found") | 	ErrPluginNotFound = errors.New("plugin not found") | ||||||
|  | @ -13,16 +23,16 @@ var ( | ||||||
| type Plugin interface { | type Plugin interface { | ||||||
| 	// GetID returns the plugin ID | 	// GetID returns the plugin ID | ||||||
| 	GetID() string | 	GetID() string | ||||||
| 	 | 
 | ||||||
| 	// GetName returns the plugin name | 	// GetName returns the plugin name | ||||||
| 	GetName() string | 	GetName() string | ||||||
| 	 | 
 | ||||||
| 	// GetHelp returns the plugin help text | 	// GetHelp returns the plugin help text | ||||||
| 	GetHelp() string | 	GetHelp() string | ||||||
| 	 | 
 | ||||||
| 	// RequiresConfig indicates if the plugin requires configuration | 	// RequiresConfig indicates if the plugin requires configuration | ||||||
| 	RequiresConfig() bool | 	RequiresConfig() bool | ||||||
| 	 | 
 | ||||||
| 	// OnMessage processes an incoming message and returns response messages | 	// OnMessage processes an incoming message and returns platform actions | ||||||
| 	OnMessage(msg *Message, config map[string]interface{}) []*Message | 	OnMessage(msg *Message, config map[string]interface{}, cache CacheInterface) []*MessageAction | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -37,11 +37,15 @@ func (s *SlackPlatform) Init(_ *config.Config) error { | ||||||
| // ParseIncomingMessage parses an incoming Slack message | // ParseIncomingMessage parses an incoming Slack message | ||||||
| func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) { | func (s *SlackPlatform) ParseIncomingMessage(r *http.Request) (*model.Message, error) { | ||||||
| 	// Read request body | 	// Read request body | ||||||
| 	body, err := ioutil.ReadAll(r.Body) | 	body, err := io.ReadAll(r.Body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer r.Body.Close() | 	defer func() { | ||||||
|  | 		if err := r.Body.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing request body: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	// Parse JSON | 	// Parse JSON | ||||||
| 	var requestData map[string]interface{} | 	var requestData map[string]interface{} | ||||||
|  | @ -163,6 +167,12 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { | ||||||
| 		return errors.New("bot token not configured") | 		return errors.New("bot token not configured") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Check for delete message action | ||||||
|  | 	if msg.Raw != nil && msg.Raw["action"] == "delete" { | ||||||
|  | 		// This is a request to delete a message | ||||||
|  | 		return s.deleteMessage(msg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Prepare payload | 	// Prepare payload | ||||||
| 	payload := map[string]interface{}{ | 	payload := map[string]interface{}{ | ||||||
| 		"channel": msg.Chat, | 		"channel": msg.Chat, | ||||||
|  | @ -194,7 +204,11 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer func() { | ||||||
|  | 		if err := resp.Body.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing response body: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	// Check response | 	// Check response | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | @ -204,6 +218,63 @@ func (s *SlackPlatform) SendMessage(msg *model.Message) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DeleteMessage deletes a message on Slack | ||||||
|  | func (s *SlackPlatform) DeleteMessage(channel string, messageID string) error { | ||||||
|  | 	// Prepare payload for chat.delete API | ||||||
|  | 	payload := map[string]interface{}{ | ||||||
|  | 		"channel": channel, | ||||||
|  | 		"ts":      messageID, // In Slack, the ts (timestamp) is the message ID | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Convert payload to JSON | ||||||
|  | 	data, err := json.Marshal(payload) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Send HTTP request to chat.delete endpoint | ||||||
|  | 	req, err := http.NewRequest("POST", "https://slack.com/api/chat.delete", strings.NewReader(string(data))) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req.Header.Set("Content-Type", "application/json") | ||||||
|  | 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", s.config.BotOAuthAccessToken)) | ||||||
|  | 
 | ||||||
|  | 	client := &http.Client{} | ||||||
|  | 	resp, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		if err := resp.Body.Close(); err != nil { | ||||||
|  | 			fmt.Printf("Error closing response body: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Check response | ||||||
|  | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | 		respBody, _ := io.ReadAll(resp.Body) | ||||||
|  | 		return fmt.Errorf("slack API error: %d - %s", resp.StatusCode, string(respBody)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // deleteMessage is a legacy method that uses the Raw message approach | ||||||
|  | func (s *SlackPlatform) deleteMessage(msg *model.Message) error { | ||||||
|  | 	// Get message ID to delete | ||||||
|  | 	messageID, ok := msg.Raw["message_id"] | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("no message ID provided for deletion") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Convert to string if needed | ||||||
|  | 	messageIDStr := fmt.Sprintf("%v", messageID) | ||||||
|  | 
 | ||||||
|  | 	return s.DeleteMessage(msg.Chat, messageIDStr) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Helper function to parse int64 | // Helper function to parse int64 | ||||||
| func parseInt64(s string) (int64, error) { | func parseInt64(s string) (int64, error) { | ||||||
| 	var n int64 | 	var n int64 | ||||||
|  |  | ||||||
|  | @ -62,7 +62,11 @@ func (t *TelegramPlatform) Init(cfg *config.Config) error { | ||||||
| 		t.log.Error("Failed to set webhook", "error", err) | 		t.log.Error("Failed to set webhook", "error", err) | ||||||
| 		return fmt.Errorf("failed to set webhook: %w", err) | 		return fmt.Errorf("failed to set webhook: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer func() { | ||||||
|  | 		if err := resp.Body.Close(); err != nil { | ||||||
|  | 			t.log.Error("Error closing response body", "error", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK { | ||||||
| 		bodyBytes, _ := io.ReadAll(resp.Body) | 		bodyBytes, _ := io.ReadAll(resp.Body) | ||||||
|  | @ -85,7 +89,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message | ||||||
| 		t.log.Error("Failed to read request body", "error", err) | 		t.log.Error("Failed to read request body", "error", err) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer r.Body.Close() | 	defer func() { | ||||||
|  | 		if err := r.Body.Close(); err != nil { | ||||||
|  | 			t.log.Error("Error closing request body", "error", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	// Parse JSON | 	// Parse JSON | ||||||
| 	var update struct { | 	var update struct { | ||||||
|  | @ -103,8 +111,11 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message | ||||||
| 				Title    string `json:"title,omitempty"` | 				Title    string `json:"title,omitempty"` | ||||||
| 				Username string `json:"username,omitempty"` | 				Username string `json:"username,omitempty"` | ||||||
| 			} `json:"chat"` | 			} `json:"chat"` | ||||||
| 			Date int    `json:"date"` | 			Date           int    `json:"date"` | ||||||
| 			Text string `json:"text"` | 			Text           string `json:"text"` | ||||||
|  | 			ReplyToMessage struct { | ||||||
|  | 				MessageID int `json:"message_id"` | ||||||
|  | 			} `json:"reply_to_message"` | ||||||
| 		} `json:"message"` | 		} `json:"message"` | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -128,6 +139,7 @@ func (t *TelegramPlatform) ParseIncomingMessage(r *http.Request) (*model.Message | ||||||
| 		FromBot: update.Message.From.IsBot, | 		FromBot: update.Message.From.IsBot, | ||||||
| 		Date:    time.Unix(int64(update.Message.Date), 0), | 		Date:    time.Unix(int64(update.Message.Date), 0), | ||||||
| 		ID:      strconv.Itoa(update.Message.MessageID), | 		ID:      strconv.Itoa(update.Message.MessageID), | ||||||
|  | 		ReplyTo: strconv.Itoa(update.Message.ReplyToMessage.MessageID), | ||||||
| 		Raw:     raw, | 		Raw:     raw, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -205,6 +217,13 @@ func (t *TelegramPlatform) ParseChannelFromMessage(body []byte) (map[string]any, | ||||||
| 
 | 
 | ||||||
| // SendMessage sends a message to Telegram | // SendMessage sends a message to Telegram | ||||||
| func (t *TelegramPlatform) SendMessage(msg *model.Message) error { | func (t *TelegramPlatform) SendMessage(msg *model.Message) error { | ||||||
|  | 	// Check for delete message action (legacy method) | ||||||
|  | 	if msg.Raw != nil && msg.Raw["action"] == "delete" { | ||||||
|  | 		// This is a request to delete a message using the legacy method | ||||||
|  | 		return t.deleteMessage(msg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Regular message sending | ||||||
| 	// Convert chat ID to int64 | 	// Convert chat ID to int64 | ||||||
| 	chatID, err := strconv.ParseInt(msg.Chat, 10, 64) | 	chatID, err := strconv.ParseInt(msg.Chat, 10, 64) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -218,6 +237,15 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { | ||||||
| 		"text":    msg.Text, | 		"text":    msg.Text, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Set parse_mode based on plugin preference or default to empty string | ||||||
|  | 	if msg.Raw != nil && msg.Raw["parse_mode"] != nil { | ||||||
|  | 		// Plugin explicitly set parse_mode | ||||||
|  | 		payload["parse_mode"] = msg.Raw["parse_mode"] | ||||||
|  | 	} else { | ||||||
|  | 		// Default to empty string (no formatting) | ||||||
|  | 		payload["parse_mode"] = "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Add reply if needed | 	// Add reply if needed | ||||||
| 	if msg.ReplyTo != "" { | 	if msg.ReplyTo != "" { | ||||||
| 		replyToID, err := strconv.Atoi(msg.ReplyTo) | 		replyToID, err := strconv.Atoi(msg.ReplyTo) | ||||||
|  | @ -247,7 +275,11 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { | ||||||
| 		t.log.Error("Failed to send message", "error", err) | 		t.log.Error("Failed to send message", "error", err) | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer func() { | ||||||
|  | 		if err := resp.Body.Close(); err != nil { | ||||||
|  | 			t.log.Error("Error closing response body", "error", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
| 
 | 
 | ||||||
| 	// Check response | 	// Check response | ||||||
| 	if resp.StatusCode != http.StatusOK { | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | @ -259,4 +291,89 @@ func (t *TelegramPlatform) SendMessage(msg *model.Message) error { | ||||||
| 
 | 
 | ||||||
| 	t.log.Debug("Message sent successfully") | 	t.log.Debug("Message sent successfully") | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // DeleteMessage deletes a message on Telegram | ||||||
|  | func (t *TelegramPlatform) DeleteMessage(channel string, messageID string) error { | ||||||
|  | 	// Convert chat ID to int64 | ||||||
|  | 	chatID, err := strconv.ParseInt(channel, 10, 64) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.log.Error("Invalid chat ID for message deletion", "chat_id", channel, "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Convert message ID to integer | ||||||
|  | 	msgID, err := strconv.Atoi(messageID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.log.Error("Invalid message ID for deletion", "message_id", messageID, "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Prepare payload for deleteMessage API | ||||||
|  | 	payload := map[string]interface{}{ | ||||||
|  | 		"chat_id":    chatID, | ||||||
|  | 		"message_id": msgID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	t.log.Debug("Deleting message on Telegram", "chat_id", chatID, "message_id", msgID) | ||||||
|  | 
 | ||||||
|  | 	// Convert payload to JSON | ||||||
|  | 	data, err := json.Marshal(payload) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.log.Error("Failed to marshal delete message payload", "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Send HTTP request to deleteMessage endpoint | ||||||
|  | 	resp, err := http.Post( | ||||||
|  | 		t.apiURL+"/deleteMessage", | ||||||
|  | 		"application/json", | ||||||
|  | 		bytes.NewBuffer(data), | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.log.Error("Failed to delete message", "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		if err := resp.Body.Close(); err != nil { | ||||||
|  | 			t.log.Error("Error closing response body", "error", err) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Check response | ||||||
|  | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | 		bodyBytes, _ := io.ReadAll(resp.Body) | ||||||
|  | 		errMsg := string(bodyBytes) | ||||||
|  | 		t.log.Error("Telegram API error when deleting message", "status", resp.StatusCode, "response", errMsg) | ||||||
|  | 		return fmt.Errorf("telegram API error when deleting message: %d - %s", resp.StatusCode, errMsg) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	t.log.Debug("Message deleted successfully") | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // deleteMessage is a legacy method that uses the Raw message approach | ||||||
|  | func (t *TelegramPlatform) deleteMessage(msg *model.Message) error { | ||||||
|  | 	// Get message ID to delete | ||||||
|  | 	messageIDInterface, ok := msg.Raw["message_id"] | ||||||
|  | 	if !ok { | ||||||
|  | 		t.log.Error("No message ID provided for deletion") | ||||||
|  | 		return fmt.Errorf("no message ID provided for deletion") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Convert message ID to string | ||||||
|  | 	var messageIDStr string | ||||||
|  | 	switch v := messageIDInterface.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		messageIDStr = v | ||||||
|  | 	case int: | ||||||
|  | 		messageIDStr = strconv.Itoa(v) | ||||||
|  | 	case float64: | ||||||
|  | 		messageIDStr = strconv.Itoa(int(v)) | ||||||
|  | 	default: | ||||||
|  | 		t.log.Error("Invalid message ID type for deletion", "type", fmt.Sprintf("%T", messageIDInterface)) | ||||||
|  | 		return fmt.Errorf("invalid message ID type for deletion") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return t.DeleteMessage(msg.Chat, messageIDStr) | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										132
									
								
								internal/plugin/domainblock/domainblock.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								internal/plugin/domainblock/domainblock.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,132 @@ | ||||||
|  | package domainblock | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/url" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // DomainBlockPlugin is a plugin that blocks messages containing links from specific domains | ||||||
|  | type DomainBlockPlugin struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Debug helper to check if RequiresConfig is working | ||||||
|  | func (p *DomainBlockPlugin) RequiresConfig() bool { | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new DomainBlockPlugin instance | ||||||
|  | func New() *DomainBlockPlugin { | ||||||
|  | 	return &DomainBlockPlugin{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:             "security.domainblock", | ||||||
|  | 			Name:           "Domain Blocker", | ||||||
|  | 			Help:           "Blocks messages containing links from configured domains", | ||||||
|  | 			ConfigRequired: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // extractDomains extracts domains from a message text | ||||||
|  | func extractDomains(text string) []string { | ||||||
|  | 	// URL regex pattern | ||||||
|  | 	urlPattern := regexp.MustCompile(`https?://([^\s/$.?#].[^\s]*)`) | ||||||
|  | 	matches := urlPattern.FindAllStringSubmatch(text, -1) | ||||||
|  | 
 | ||||||
|  | 	domains := make([]string, 0, len(matches)) | ||||||
|  | 	for _, match := range matches { | ||||||
|  | 		if len(match) < 2 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Try to parse the URL to extract the domain | ||||||
|  | 		urlStr := match[0] | ||||||
|  | 		parsedURL, err := url.Parse(urlStr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Extract the domain (host) from the URL | ||||||
|  | 		domain := parsedURL.Host | ||||||
|  | 		// Remove port if present | ||||||
|  | 		if i := strings.IndexByte(domain, ':'); i >= 0 { | ||||||
|  | 			domain = domain[:i] | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		domains = append(domains, strings.ToLower(domain)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return domains | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage processes incoming messages | ||||||
|  | func (p *DomainBlockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Skip messages from bots | ||||||
|  | 	if msg.FromBot { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get blocked domains from config | ||||||
|  | 	blockedDomainsStr, ok := config["blocked_domains"].(string) | ||||||
|  | 	if !ok || blockedDomainsStr == "" { | ||||||
|  | 		return nil // No blocked domains configured | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Split and clean blocked domains | ||||||
|  | 	blockedDomains := strings.Split(blockedDomainsStr, ",") | ||||||
|  | 	for i, domain := range blockedDomains { | ||||||
|  | 		blockedDomains[i] = strings.ToLower(strings.TrimSpace(domain)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract domains from message | ||||||
|  | 	messageDomains := extractDomains(msg.Text) | ||||||
|  | 	if len(messageDomains) == 0 { | ||||||
|  | 		return nil // No domains in message | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check if any domains in the message are blocked | ||||||
|  | 	for _, msgDomain := range messageDomains { | ||||||
|  | 		for _, blockedDomain := range blockedDomains { | ||||||
|  | 			if blockedDomain == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if strings.HasSuffix(msgDomain, blockedDomain) || msgDomain == blockedDomain { | ||||||
|  | 				// Domain is blocked, create actions | ||||||
|  | 
 | ||||||
|  | 				// 1. Create a delete message action | ||||||
|  | 				deleteAction := &model.MessageAction{ | ||||||
|  | 					Type:      model.ActionDeleteMessage, | ||||||
|  | 					MessageID: msg.ID, | ||||||
|  | 					Chat:      msg.Chat, | ||||||
|  | 					Channel:   msg.Channel, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// 2. Create a notification message action | ||||||
|  | 				notificationMsg := &model.Message{ | ||||||
|  | 					Text:    fmt.Sprintf("I don't like links from %s 🙈", blockedDomain), | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				sendAction := &model.MessageAction{ | ||||||
|  | 					Type:    model.ActionSendMessage, | ||||||
|  | 					Message: notificationMsg, | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				return []*model.MessageAction{deleteAction, sendAction} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Plugin is registered in app.go, not using init() | ||||||
							
								
								
									
										142
									
								
								internal/plugin/domainblock/domainblock_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								internal/plugin/domainblock/domainblock_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,142 @@ | ||||||
|  | package domainblock | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/testutil" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestExtractDomains(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name     string | ||||||
|  | 		text     string | ||||||
|  | 		expected []string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "No URLs", | ||||||
|  | 			text:     "Hello, world!", | ||||||
|  | 			expected: []string{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "Single URL", | ||||||
|  | 			text:     "Check out https://example.com for more info", | ||||||
|  | 			expected: []string{"example.com"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "Multiple URLs", | ||||||
|  | 			text:     "Check out https://example.com and http://test.example.org for more info", | ||||||
|  | 			expected: []string{"example.com", "test.example.org"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "URL with path", | ||||||
|  | 			text:     "Check out https://example.com/path/to/resource", | ||||||
|  | 			expected: []string{"example.com"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "URL with port", | ||||||
|  | 			text:     "Check out https://example.com:8080/path/to/resource", | ||||||
|  | 			expected: []string{"example.com"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "URL with subdomain", | ||||||
|  | 			text:     "Check out https://sub.example.com", | ||||||
|  | 			expected: []string{"sub.example.com"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		t.Run(test.name, func(t *testing.T) { | ||||||
|  | 			domains := extractDomains(test.text) | ||||||
|  | 
 | ||||||
|  | 			if len(domains) != len(test.expected) { | ||||||
|  | 				t.Errorf("Expected %d domains, got %d", len(test.expected), len(domains)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for i, domain := range domains { | ||||||
|  | 				if domain != test.expected[i] { | ||||||
|  | 					t.Errorf("Expected domain %s, got %s", test.expected[i], domain) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOnMessage(t *testing.T) { | ||||||
|  | 	plugin := New() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name           string | ||||||
|  | 		text           string | ||||||
|  | 		blockedDomains string | ||||||
|  | 		expectBlocked  bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:           "No blocked domains", | ||||||
|  | 			text:           "Check out https://example.com", | ||||||
|  | 			blockedDomains: "", | ||||||
|  | 			expectBlocked:  false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "No matching domain", | ||||||
|  | 			text:           "Check out https://example.com", | ||||||
|  | 			blockedDomains: "bad.com, evil.org", | ||||||
|  | 			expectBlocked:  false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Matching domain", | ||||||
|  | 			text:           "Check out https://example.com", | ||||||
|  | 			blockedDomains: "example.com, evil.org", | ||||||
|  | 			expectBlocked:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Matching subdomain", | ||||||
|  | 			text:           "Check out https://sub.example.com", | ||||||
|  | 			blockedDomains: "example.com", | ||||||
|  | 			expectBlocked:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Multiple domains, one matching", | ||||||
|  | 			text:           "Check out https://example.com and https://good.org", | ||||||
|  | 			blockedDomains: "bad.com, example.com", | ||||||
|  | 			expectBlocked:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Spaces in blocked domains list", | ||||||
|  | 			text:           "Check out https://example.com", | ||||||
|  | 			blockedDomains: "bad.com, example.com , evil.org", | ||||||
|  | 			expectBlocked:  true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		t.Run(test.name, func(t *testing.T) { | ||||||
|  | 			msg := &model.Message{ | ||||||
|  | 				Text: test.text, | ||||||
|  | 				Chat: "test-chat", | ||||||
|  | 				ID:   "test-id", | ||||||
|  | 				Channel: &model.Channel{ | ||||||
|  | 					ID: 1, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			config := map[string]interface{}{ | ||||||
|  | 				"blocked_domains": test.blockedDomains, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			mockCache := &testutil.MockCache{} | ||||||
|  | 			responses := plugin.OnMessage(msg, config, mockCache) | ||||||
|  | 
 | ||||||
|  | 			if test.expectBlocked { | ||||||
|  | 				if len(responses) == 0 { | ||||||
|  | 					t.Errorf("Expected message to be blocked, but it wasn't") | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				if len(responses) > 0 { | ||||||
|  | 					t.Errorf("Expected message not to be blocked, but it was") | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -29,7 +29,7 @@ func NewCoin() *CoinPlugin { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // OnMessage handles incoming messages | // OnMessage handles incoming messages | ||||||
| func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
| 	if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { | 	if !strings.Contains(strings.ToLower(msg.Text), "flip a coin") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  | @ -46,5 +46,12 @@ func (p *CoinPlugin) OnMessage(msg *model.Message, config map[string]interface{} | ||||||
| 		Channel: msg.Channel, | 		Channel: msg.Channel, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return []*model.Message{response} | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -32,7 +32,7 @@ func NewDice() *DicePlugin { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // OnMessage handles incoming messages | // OnMessage handles incoming messages | ||||||
| func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
| 	if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { | 	if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(msg.Text)), "!dice") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  | @ -62,7 +62,14 @@ func (p *DicePlugin) OnMessage(msg *model.Message, config map[string]interface{} | ||||||
| 		Channel: msg.Channel, | 		Channel: msg.Channel, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return []*model.Message{response} | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // rollDice parses a dice formula string and returns the result | // rollDice parses a dice formula string and returns the result | ||||||
|  | @ -107,9 +114,10 @@ func (p *DicePlugin) rollDice(formula string) (int, error) { | ||||||
| 			return 0, fmt.Errorf("invalid modifier") | 			return 0, fmt.Errorf("invalid modifier") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if matches[3] == "+" { | 		switch matches[3] { | ||||||
|  | 		case "+": | ||||||
| 			total += modifier | 			total += modifier | ||||||
| 		} else if matches[3] == "-" { | 		case "-": | ||||||
| 			total -= modifier | 			total -= modifier | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
							
								
								
									
										540
									
								
								internal/plugin/fun/hltb.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										540
									
								
								internal/plugin/fun/hltb.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,540 @@ | ||||||
|  | package fun | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // HLTBPlugin searches HowLongToBeat for game completion times | ||||||
|  | type HLTBPlugin struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | 	httpClient *http.Client | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HLTBGame represents a game from HowLongToBeat | ||||||
|  | type HLTBGame struct { | ||||||
|  | 	ID             int    `json:"game_id"` | ||||||
|  | 	Name           string `json:"game_name"` | ||||||
|  | 	GameAlias      string `json:"game_alias"` | ||||||
|  | 	GameImage      string `json:"game_image"` | ||||||
|  | 	CompMain       int    `json:"comp_main"` | ||||||
|  | 	CompPlus       int    `json:"comp_plus"` | ||||||
|  | 	CompComplete   int    `json:"comp_complete"` | ||||||
|  | 	CompAll        int    `json:"comp_all"` | ||||||
|  | 	InvestedCo     int    `json:"invested_co"` | ||||||
|  | 	InvestedMp     int    `json:"invested_mp"` | ||||||
|  | 	CountComp      int    `json:"count_comp"` | ||||||
|  | 	CountSpeedruns int    `json:"count_speedruns"` | ||||||
|  | 	CountBacklog   int    `json:"count_backlog"` | ||||||
|  | 	CountReview    int    `json:"count_review"` | ||||||
|  | 	ReviewScore    int    `json:"review_score"` | ||||||
|  | 	CountPlaying   int    `json:"count_playing"` | ||||||
|  | 	CountRetired   int    `json:"count_retired"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewHLTB creates a new HLTBPlugin instance | ||||||
|  | func NewHLTB() *HLTBPlugin { | ||||||
|  | 	return &HLTBPlugin{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:   "fun.hltb", | ||||||
|  | 			Name: "How Long To Beat", | ||||||
|  | 			Help: "Get game completion times from HowLongToBeat.com using `!hltb <game name>`", | ||||||
|  | 		}, | ||||||
|  | 		httpClient: &http.Client{ | ||||||
|  | 			Timeout: 10 * time.Second, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *HLTBPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Check if message starts with !hltb | ||||||
|  | 	text := strings.TrimSpace(msg.Text) | ||||||
|  | 	if !strings.HasPrefix(text, "!hltb ") { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract game name | ||||||
|  | 	gameName := strings.TrimSpace(text[6:]) // Remove "!hltb " | ||||||
|  | 	if gameName == "" { | ||||||
|  | 		return p.createErrorResponse(msg, "Please provide a game name. Usage: !hltb <game name>") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check cache first | ||||||
|  | 	var games []HLTBGame | ||||||
|  | 	var err error | ||||||
|  | 	cacheKey := strings.ToLower(gameName) | ||||||
|  | 
 | ||||||
|  | 	err = cache.Get(cacheKey, &games) | ||||||
|  | 	if err != nil || len(games) == 0 { | ||||||
|  | 		// Cache miss - search for the game | ||||||
|  | 		games, err = p.searchGame(gameName) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return p.createErrorResponse(msg, fmt.Sprintf("Error searching for game: %s", err.Error())) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(games) == 0 { | ||||||
|  | 			return p.createErrorResponse(msg, fmt.Sprintf("No results found for '%s'", gameName)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Cache the results for 1 hour | ||||||
|  | 		err = cache.SetWithTTL(cacheKey, games, time.Hour) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// Log cache error but don't fail the request | ||||||
|  | 			fmt.Printf("Warning: Failed to cache HLTB results: %v\n", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Use the first result | ||||||
|  | 	game := games[0] | ||||||
|  | 
 | ||||||
|  | 	// Format the response | ||||||
|  | 	response := p.formatGameInfo(game) | ||||||
|  | 
 | ||||||
|  | 	// Create response message with game cover if available | ||||||
|  | 	responseMsg := &model.Message{ | ||||||
|  | 		Text:    response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set parse mode for markdown formatting | ||||||
|  | 	if responseMsg.Raw == nil { | ||||||
|  | 		responseMsg.Raw = make(map[string]interface{}) | ||||||
|  | 	} | ||||||
|  | 	responseMsg.Raw["parse_mode"] = "Markdown" | ||||||
|  | 
 | ||||||
|  | 	// Add game cover as attachment if available | ||||||
|  | 	if game.GameImage != "" { | ||||||
|  | 		imageURL := p.getFullImageURL(game.GameImage) | ||||||
|  | 		responseMsg.Raw["image_url"] = imageURL | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: responseMsg, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // searchGame searches for a game on HowLongToBeat using the API | ||||||
|  | func (p *HLTBPlugin) searchGame(gameName string) ([]HLTBGame, error) { | ||||||
|  | 	// Only the seek token endpoint works now | ||||||
|  | 	return p.searchWithSeekToken(gameName) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // searchWithSeekToken attempts to search using the seek token approach | ||||||
|  | func (p *HLTBPlugin) searchWithSeekToken(gameName string) ([]HLTBGame, error) { | ||||||
|  | 	// Get the seek token from the main page | ||||||
|  | 	seekToken, err := p.getSeekToken() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to get seek token: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Split search terms by words | ||||||
|  | 	searchTerms := strings.Fields(gameName) | ||||||
|  | 
 | ||||||
|  | 	// Create search URL with seek token | ||||||
|  | 	searchURL := fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", seekToken) | ||||||
|  | 
 | ||||||
|  | 	// Prepare search request | ||||||
|  | 	searchRequest := map[string]interface{}{ | ||||||
|  | 		"searchType":  "games", | ||||||
|  | 		"searchTerms": searchTerms, | ||||||
|  | 		"searchPage":  1, | ||||||
|  | 		"size":        20, | ||||||
|  | 		"searchOptions": map[string]interface{}{ | ||||||
|  | 			"games": map[string]interface{}{ | ||||||
|  | 				"userId":        0, | ||||||
|  | 				"platform":      "", | ||||||
|  | 				"sortCategory":  "popular", | ||||||
|  | 				"rangeCategory": "main", | ||||||
|  | 				"rangeTime": map[string]interface{}{ | ||||||
|  | 					"min": nil, | ||||||
|  | 					"max": nil, | ||||||
|  | 				}, | ||||||
|  | 				"gameplay": map[string]interface{}{ | ||||||
|  | 					"perspective": "", | ||||||
|  | 					"flow":        "", | ||||||
|  | 					"genre":       "", | ||||||
|  | 					"difficulty":  "", | ||||||
|  | 				}, | ||||||
|  | 				"rangeYear": map[string]interface{}{ | ||||||
|  | 					"min": "", | ||||||
|  | 					"max": "", | ||||||
|  | 				}, | ||||||
|  | 				"modifier": "", | ||||||
|  | 			}, | ||||||
|  | 			"users": map[string]interface{}{ | ||||||
|  | 				"sortCategory": "postcount", | ||||||
|  | 			}, | ||||||
|  | 			"lists": map[string]interface{}{ | ||||||
|  | 				"sortCategory": "follows", | ||||||
|  | 			}, | ||||||
|  | 			"filter":     "", | ||||||
|  | 			"sort":       0, | ||||||
|  | 			"randomizer": 0, | ||||||
|  | 		}, | ||||||
|  | 		"useCache": true, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p.performAPISearch(searchURL, searchRequest) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // performAPISearch performs the actual API search request | ||||||
|  | func (p *HLTBPlugin) performAPISearch(searchURL string, searchRequest map[string]interface{}) ([]HLTBGame, error) { | ||||||
|  | 	// Convert to JSON | ||||||
|  | 	jsonData, err := json.Marshal(searchRequest) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to marshal search request: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create HTTP request | ||||||
|  | 	req, err := http.NewRequest("POST", searchURL, bytes.NewBuffer(jsonData)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create request: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set headers to match the working curl request | ||||||
|  | 	req.Header.Set("Accept", "*/*") | ||||||
|  | 	req.Header.Set("Accept-Language", "en-US,en;q=0.9") | ||||||
|  | 	req.Header.Set("Cache-Control", "no-cache") | ||||||
|  | 	req.Header.Set("Content-Type", "application/json") | ||||||
|  | 	req.Header.Set("Origin", "https://howlongtobeat.com") | ||||||
|  | 	req.Header.Set("Pragma", "no-cache") | ||||||
|  | 	req.Header.Set("Referer", "https://howlongtobeat.com/") | ||||||
|  | 	req.Header.Set("Sec-Fetch-Dest", "empty") | ||||||
|  | 	req.Header.Set("Sec-Fetch-Mode", "cors") | ||||||
|  | 	req.Header.Set("Sec-Fetch-Site", "same-origin") | ||||||
|  | 	req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36") | ||||||
|  | 
 | ||||||
|  | 	// Send request | ||||||
|  | 	resp, err := p.httpClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to send request: %w", err) | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = resp.Body.Close() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | 		return nil, fmt.Errorf("API returned status code: %d", resp.StatusCode) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read response body | ||||||
|  | 	body, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to read response: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Parse response | ||||||
|  | 	var searchResponse struct { | ||||||
|  | 		Color         string     `json:"color"` | ||||||
|  | 		Title         string     `json:"title"` | ||||||
|  | 		Category      string     `json:"category"` | ||||||
|  | 		Count         int        `json:"count"` | ||||||
|  | 		Pagecurrent   int        `json:"pagecurrent"` | ||||||
|  | 		Pagesize      int        `json:"pagesize"` | ||||||
|  | 		Pagetotal     int        `json:"pagetotal"` | ||||||
|  | 		SearchTerm    string     `json:"searchTerm"` | ||||||
|  | 		SearchResults []HLTBGame `json:"data"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := json.Unmarshal(body, &searchResponse); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to parse response: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return searchResponse.SearchResults, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // formatGameInfo formats game information for display | ||||||
|  | func (p *HLTBPlugin) formatGameInfo(game HLTBGame) string { | ||||||
|  | 	var response strings.Builder | ||||||
|  | 
 | ||||||
|  | 	response.WriteString(fmt.Sprintf("🎮 **%s**\n\n", game.Name)) | ||||||
|  | 
 | ||||||
|  | 	// Format completion times | ||||||
|  | 	if game.CompMain > 0 { | ||||||
|  | 		response.WriteString(fmt.Sprintf("📖 **Main Story:** %s\n", p.formatTime(game.CompMain))) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if game.CompPlus > 0 { | ||||||
|  | 		response.WriteString(fmt.Sprintf("➕ **Main + Extras:** %s\n", p.formatTime(game.CompPlus))) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if game.CompComplete > 0 { | ||||||
|  | 		response.WriteString(fmt.Sprintf("💯 **Completionist:** %s\n", p.formatTime(game.CompComplete))) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if game.CompAll > 0 { | ||||||
|  | 		response.WriteString(fmt.Sprintf("🎯 **All Styles:** %s\n", p.formatTime(game.CompAll))) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Add review score if available | ||||||
|  | 	if game.ReviewScore > 0 { | ||||||
|  | 		response.WriteString(fmt.Sprintf("\n⭐ **User Score:** %d/100", game.ReviewScore)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Add source attribution | ||||||
|  | 	response.WriteString("\n\n*Source: HowLongToBeat.com*") | ||||||
|  | 
 | ||||||
|  | 	return response.String() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // formatTime converts seconds to a readable time format | ||||||
|  | func (p *HLTBPlugin) formatTime(seconds int) string { | ||||||
|  | 	if seconds <= 0 { | ||||||
|  | 		return "N/A" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	hours := float64(seconds) / 3600.0 | ||||||
|  | 
 | ||||||
|  | 	if hours < 1 { | ||||||
|  | 		minutes := seconds / 60 | ||||||
|  | 		return fmt.Sprintf("%d minutes", minutes) | ||||||
|  | 	} else if hours < 2 { | ||||||
|  | 		return fmt.Sprintf("%.1f hour", hours) | ||||||
|  | 	} else { | ||||||
|  | 		return fmt.Sprintf("%.1f hours", hours) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getFullImageURL constructs the full image URL | ||||||
|  | func (p *HLTBPlugin) getFullImageURL(imagePath string) string { | ||||||
|  | 	if imagePath == "" { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Remove leading slash if present | ||||||
|  | 	imagePath = strings.TrimPrefix(imagePath, "/") | ||||||
|  | 
 | ||||||
|  | 	return fmt.Sprintf("https://howlongtobeat.com/games/%s", imagePath) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getSeekToken retrieves the seek token from HowLongToBeat | ||||||
|  | func (p *HLTBPlugin) getSeekToken() (string, error) { | ||||||
|  | 	// Get the main page to extract buildId | ||||||
|  | 	req, err := http.NewRequest("GET", "https://howlongtobeat.com", nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to create token request: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36") | ||||||
|  | 
 | ||||||
|  | 	resp, err := p.httpClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to fetch token: %w", err) | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = resp.Body.Close() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	body, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to read token response: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	bodyStr := string(body) | ||||||
|  | 
 | ||||||
|  | 	// First, try to find buildId in the __NEXT_DATA__ or page source | ||||||
|  | 	buildIdPatterns := []string{ | ||||||
|  | 		`"buildId":"([a-zA-Z0-9_-]+)"`, | ||||||
|  | 		`buildId":"([a-zA-Z0-9_-]+)"`, | ||||||
|  | 		`/_next/static/([a-zA-Z0-9_-]+)/_buildManifest`, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, pattern := range buildIdPatterns { | ||||||
|  | 		re := regexp.MustCompile(pattern) | ||||||
|  | 		matches := re.FindStringSubmatch(bodyStr) | ||||||
|  | 		if len(matches) > 1 { | ||||||
|  | 			buildId := matches[1] | ||||||
|  | 			// Now try to get the seek token from the JavaScript files using buildId | ||||||
|  | 			if token, err := p.getSeekTokenFromBuildId(buildId); err == nil { | ||||||
|  | 				return token, nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If we can't find buildId, look for direct seek token patterns | ||||||
|  | 	seekPatterns := []string{ | ||||||
|  | 		`/api/seek/([a-f0-9]{16})`, | ||||||
|  | 		`"seek/([a-f0-9]{16})"`, | ||||||
|  | 		`api/seek/([a-f0-9]{16})`, | ||||||
|  | 		`seek/([a-f0-9]{12,})`, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, pattern := range seekPatterns { | ||||||
|  | 		re := regexp.MustCompile(pattern) | ||||||
|  | 		matches := re.FindStringSubmatch(bodyStr) | ||||||
|  | 		if len(matches) > 1 { | ||||||
|  | 			return matches[1], nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Last resort: try multiple known working tokens | ||||||
|  | 	knownTokens := []string{ | ||||||
|  | 		"6e17f7a193ef3188", // From your curl example | ||||||
|  | 		"d4b2e330db04dbf3", // Common fallback | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, token := range knownTokens { | ||||||
|  | 		if p.testSeekToken(token) { | ||||||
|  | 			return token, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Generate a token as last resort | ||||||
|  | 	return p.generateSeekToken(), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getSeekTokenFromBuildId attempts to extract seek token from build-specific files | ||||||
|  | func (p *HLTBPlugin) getSeekTokenFromBuildId(buildId string) (string, error) { | ||||||
|  | 	// Common build file patterns where seek tokens might be stored | ||||||
|  | 	fileURLs := []string{ | ||||||
|  | 		fmt.Sprintf("https://howlongtobeat.com/_next/static/%s/_buildManifest.js", buildId), | ||||||
|  | 		fmt.Sprintf("https://howlongtobeat.com/_next/static/%s/_ssgManifest.js", buildId), | ||||||
|  | 		fmt.Sprintf("https://howlongtobeat.com/_next/static/chunks/pages/index-%s.js", buildId[:12]), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, fileURL := range fileURLs { | ||||||
|  | 		if token, err := p.extractSeekTokenFromFile(fileURL); err == nil && token != "" { | ||||||
|  | 			return token, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "", fmt.Errorf("no seek token found in build files") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // extractSeekTokenFromFile downloads and searches a file for seek token | ||||||
|  | func (p *HLTBPlugin) extractSeekTokenFromFile(fileURL string) (string, error) { | ||||||
|  | 	req, err := http.NewRequest("GET", fileURL, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36") | ||||||
|  | 
 | ||||||
|  | 	resp, err := p.httpClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		_ = resp.Body.Close() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	if resp.StatusCode != http.StatusOK { | ||||||
|  | 		return "", fmt.Errorf("failed to fetch file: %d", resp.StatusCode) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	body, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	bodyStr := string(body) | ||||||
|  | 	patterns := []string{ | ||||||
|  | 		`seek/([a-f0-9]{16})`, | ||||||
|  | 		`"([a-f0-9]{16})"`, | ||||||
|  | 		`'([a-f0-9]{16})'`, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, pattern := range patterns { | ||||||
|  | 		re := regexp.MustCompile(pattern) | ||||||
|  | 		matches := re.FindStringSubmatch(bodyStr) | ||||||
|  | 		if len(matches) > 1 { | ||||||
|  | 			return matches[1], nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return "", fmt.Errorf("no seek token found in file") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // testSeekToken tests if a seek token works by making a simple API call | ||||||
|  | func (p *HLTBPlugin) testSeekToken(token string) bool { | ||||||
|  | 	searchURL := fmt.Sprintf("https://howlongtobeat.com/api/seek/%s", token) | ||||||
|  | 	searchRequest := map[string]interface{}{ | ||||||
|  | 		"searchType":  "games", | ||||||
|  | 		"searchTerms": []string{"test"}, | ||||||
|  | 		"searchPage":  1, | ||||||
|  | 		"size":        1, | ||||||
|  | 		"searchOptions": map[string]interface{}{ | ||||||
|  | 			"games": map[string]interface{}{ | ||||||
|  | 				"userId":        0, | ||||||
|  | 				"platform":      "", | ||||||
|  | 				"sortCategory":  "popular", | ||||||
|  | 				"rangeCategory": "main", | ||||||
|  | 				"rangeTime": map[string]interface{}{ | ||||||
|  | 					"min": nil, | ||||||
|  | 					"max": nil, | ||||||
|  | 				}, | ||||||
|  | 				"gameplay": map[string]interface{}{ | ||||||
|  | 					"perspective": "", | ||||||
|  | 					"flow":        "", | ||||||
|  | 					"genre":       "", | ||||||
|  | 					"difficulty":  "", | ||||||
|  | 				}, | ||||||
|  | 				"rangeYear": map[string]interface{}{ | ||||||
|  | 					"min": "", | ||||||
|  | 					"max": "", | ||||||
|  | 				}, | ||||||
|  | 				"modifier": "", | ||||||
|  | 			}, | ||||||
|  | 			"users": map[string]interface{}{ | ||||||
|  | 				"sortCategory": "postcount", | ||||||
|  | 			}, | ||||||
|  | 			"lists": map[string]interface{}{ | ||||||
|  | 				"sortCategory": "follows", | ||||||
|  | 			}, | ||||||
|  | 			"filter":     "", | ||||||
|  | 			"sort":       0, | ||||||
|  | 			"randomizer": 0, | ||||||
|  | 		}, | ||||||
|  | 		"useCache": true, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Test the token with a simple search | ||||||
|  | 	if _, err := p.performAPISearch(searchURL, searchRequest); err == nil { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // generateSeekToken generates a seek token based on current time | ||||||
|  | func (p *HLTBPlugin) generateSeekToken() string { | ||||||
|  | 	// Use a simple hash-like approach with current timestamp | ||||||
|  | 	// This is a fallback approach since the real token generation is unknown | ||||||
|  | 	now := time.Now().Unix() | ||||||
|  | 	return fmt.Sprintf("%x", now%0xffffffff)[:16] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // createErrorResponse creates an error response message | ||||||
|  | func (p *HLTBPlugin) createErrorResponse(msg *model.Message, errorText string) []*model.MessageAction { | ||||||
|  | 	response := &model.Message{ | ||||||
|  | 		Text:    fmt.Sprintf("❌ %s", errorText), | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
|  | } | ||||||
							
								
								
									
										131
									
								
								internal/plugin/fun/hltb_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								internal/plugin/fun/hltb_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,131 @@ | ||||||
|  | package fun | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/testutil" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestHLTBPlugin_OnMessage(t *testing.T) { | ||||||
|  | 	plugin := NewHLTB() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name          string | ||||||
|  | 		messageText   string | ||||||
|  | 		shouldRespond bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:          "responds to !hltb command", | ||||||
|  | 			messageText:   "!hltb The Witcher 3", | ||||||
|  | 			shouldRespond: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "ignores non-hltb messages", | ||||||
|  | 			messageText:   "hello world", | ||||||
|  | 			shouldRespond: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "ignores !hltb without game name", | ||||||
|  | 			messageText:   "!hltb", | ||||||
|  | 			shouldRespond: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "ignores !hltb with only spaces", | ||||||
|  | 			messageText:   "!hltb   ", | ||||||
|  | 			shouldRespond: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "ignores similar but incorrect commands", | ||||||
|  | 			messageText:   "hltb The Witcher 3", | ||||||
|  | 			shouldRespond: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			msg := &model.Message{ | ||||||
|  | 				Text:    tt.messageText, | ||||||
|  | 				Chat:    "test-chat", | ||||||
|  | 				Channel: &model.Channel{ID: 1}, | ||||||
|  | 				Author:  "test-user", | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			mockCache := &testutil.MockCache{} | ||||||
|  | 			actions := plugin.OnMessage(msg, make(map[string]interface{}), mockCache) | ||||||
|  | 
 | ||||||
|  | 			if tt.shouldRespond && len(actions) == 0 { | ||||||
|  | 				t.Errorf("Expected plugin to respond to '%s', but it didn't", tt.messageText) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !tt.shouldRespond && len(actions) > 0 { | ||||||
|  | 				t.Errorf("Expected plugin to not respond to '%s', but it did", tt.messageText) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// For messages that should respond, verify the response structure | ||||||
|  | 			if tt.shouldRespond && len(actions) > 0 { | ||||||
|  | 				action := actions[0] | ||||||
|  | 				if action.Type != model.ActionSendMessage { | ||||||
|  | 					t.Errorf("Expected ActionSendMessage, got %s", action.Type) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if action.Message == nil { | ||||||
|  | 					t.Error("Expected action to have a message") | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if action.Message != nil && action.Message.ReplyTo != msg.ID { | ||||||
|  | 					t.Error("Expected response to reply to original message") | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestHLTBPlugin_formatTime(t *testing.T) { | ||||||
|  | 	plugin := NewHLTB() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		seconds  int | ||||||
|  | 		expected string | ||||||
|  | 	}{ | ||||||
|  | 		{0, "N/A"}, | ||||||
|  | 		{-1, "N/A"}, | ||||||
|  | 		{1800, "30 minutes"},  // 30 minutes | ||||||
|  | 		{3600, "1.0 hour"},    // 1 hour | ||||||
|  | 		{7200, "2.0 hours"},   // 2 hours | ||||||
|  | 		{10800, "3.0 hours"},  // 3 hours | ||||||
|  | 		{36000, "10.0 hours"}, // 10 hours | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.expected, func(t *testing.T) { | ||||||
|  | 			result := plugin.formatTime(tt.seconds) | ||||||
|  | 			if result != tt.expected { | ||||||
|  | 				t.Errorf("formatTime(%d) = %s, want %s", tt.seconds, result, tt.expected) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestHLTBPlugin_getFullImageURL(t *testing.T) { | ||||||
|  | 	plugin := NewHLTB() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		imagePath string | ||||||
|  | 		expected  string | ||||||
|  | 	}{ | ||||||
|  | 		{"", ""}, | ||||||
|  | 		{"game.jpg", "https://howlongtobeat.com/games/game.jpg"}, | ||||||
|  | 		{"/game.jpg", "https://howlongtobeat.com/games/game.jpg"}, | ||||||
|  | 		{"folder/game.png", "https://howlongtobeat.com/games/folder/game.png"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.imagePath, func(t *testing.T) { | ||||||
|  | 			result := plugin.getFullImageURL(tt.imagePath) | ||||||
|  | 			if result != tt.expected { | ||||||
|  | 				t.Errorf("getFullImageURL(%s) = %s, want %s", tt.imagePath, result, tt.expected) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -23,8 +23,13 @@ func NewLoquito() *LoquitoPlugin { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetHelp returns the plugin help text | ||||||
|  | func (p *LoquitoPlugin) GetHelp() string { | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // OnMessage handles incoming messages | // OnMessage handles incoming messages | ||||||
| func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
| 	if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { | 	if !strings.Contains(strings.ToLower(msg.Text), "lo quito") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  | @ -36,5 +41,12 @@ func (p *LoquitoPlugin) OnMessage(msg *model.Message, config map[string]interfac | ||||||
| 		Channel: msg.Channel, | 		Channel: msg.Channel, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return []*model.Message{response} | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										166
									
								
								internal/plugin/help/help.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								internal/plugin/help/help.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,166 @@ | ||||||
|  | package help | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"sort" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/db" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | 	"golang.org/x/exp/slog" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // ChannelPluginGetter is an interface for getting channel plugins | ||||||
|  | type ChannelPluginGetter interface { | ||||||
|  | 	GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) | ||||||
|  | 	GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HelpPlugin provides help information about available commands | ||||||
|  | type HelpPlugin struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | 	db ChannelPluginGetter | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new HelpPlugin instance | ||||||
|  | func New(db ChannelPluginGetter) *HelpPlugin { | ||||||
|  | 	return &HelpPlugin{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:   "utility.help", | ||||||
|  | 			Name: "Help", | ||||||
|  | 			Help: "Shows available commands when you type '!help'", | ||||||
|  | 		}, | ||||||
|  | 		db: db, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *HelpPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Check if message is the help command | ||||||
|  | 	if !strings.EqualFold(strings.TrimSpace(msg.Text), "!help") { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get channel plugins from database using platform and platform channel ID | ||||||
|  | 	channelPlugins, err := p.db.GetChannelPluginsFromPlatformID(msg.Channel.Platform, msg.Channel.PlatformChannelID) | ||||||
|  | 	if err != nil && err != db.ErrNotFound { | ||||||
|  | 		slog.Error("Failed to get channel plugins", slog.Any("err", err)) | ||||||
|  | 		return []*model.MessageAction{} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If no plugins found, initialize empty slice | ||||||
|  | 	if err == db.ErrNotFound { | ||||||
|  | 		channelPlugins = []*model.ChannelPlugin{} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get all available plugins | ||||||
|  | 	availablePlugins := plugin.GetAvailablePlugins() | ||||||
|  | 
 | ||||||
|  | 	// Filter to only enabled plugins for this channel | ||||||
|  | 	enabledPlugins := make(map[string]model.Plugin) | ||||||
|  | 	for _, channelPlugin := range channelPlugins { | ||||||
|  | 		if channelPlugin.Enabled { | ||||||
|  | 			if availablePlugin, exists := availablePlugins[channelPlugin.PluginID]; exists { | ||||||
|  | 				enabledPlugins[channelPlugin.PluginID] = availablePlugin | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If no plugins are enabled, return a message | ||||||
|  | 	if len(enabledPlugins) == 0 { | ||||||
|  | 		response := &model.Message{ | ||||||
|  | 			Text:    "No plugins are currently enabled for this channel.", | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			ReplyTo: msg.ID, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 			Raw:     map[string]interface{}{"parse_mode": "Markdown"}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type:    model.ActionSendMessage, | ||||||
|  | 				Message: response, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Group plugins by category | ||||||
|  | 	categories := map[string][]model.Plugin{ | ||||||
|  | 		"Development":           {}, | ||||||
|  | 		"Fun and Entertainment": {}, | ||||||
|  | 		"Utility":               {}, | ||||||
|  | 		"Security":              {}, | ||||||
|  | 		"Social Media":          {}, | ||||||
|  | 		"Other":                 {}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Categorize plugins based on their ID prefix | ||||||
|  | 	for _, p := range enabledPlugins { | ||||||
|  | 		category := p.GetID() | ||||||
|  | 		switch { | ||||||
|  | 		case strings.HasPrefix(category, "dev."): | ||||||
|  | 			categories["Development"] = append(categories["Development"], p) | ||||||
|  | 		case strings.HasPrefix(category, "fun."): | ||||||
|  | 			categories["Fun and Entertainment"] = append(categories["Fun and Entertainment"], p) | ||||||
|  | 		case strings.HasPrefix(category, "util.") || strings.HasPrefix(category, "reminder.") || strings.HasPrefix(category, "utility."): | ||||||
|  | 			categories["Utility"] = append(categories["Utility"], p) | ||||||
|  | 		case strings.HasPrefix(category, "security."): | ||||||
|  | 			categories["Security"] = append(categories["Security"], p) | ||||||
|  | 		case strings.HasPrefix(category, "social."): | ||||||
|  | 			categories["Social Media"] = append(categories["Social Media"], p) | ||||||
|  | 		default: | ||||||
|  | 			categories["Other"] = append(categories["Other"], p) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Build the help message | ||||||
|  | 	var helpText strings.Builder | ||||||
|  | 	helpText.WriteString("🤖 **Available Commands**\n\n") | ||||||
|  | 
 | ||||||
|  | 	// Sort category names for consistent output | ||||||
|  | 	categoryOrder := []string{"Development", "Fun and Entertainment", "Utility", "Security", "Social Media", "Other"} | ||||||
|  | 
 | ||||||
|  | 	for _, categoryName := range categoryOrder { | ||||||
|  | 		pluginList := categories[categoryName] | ||||||
|  | 		if len(pluginList) == 0 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Sort plugins within category by name | ||||||
|  | 		sort.Slice(pluginList, func(i, j int) bool { | ||||||
|  | 			return pluginList[i].GetName() < pluginList[j].GetName() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		helpText.WriteString(fmt.Sprintf("**%s:**\n", categoryName)) | ||||||
|  | 		for _, p := range pluginList { | ||||||
|  | 			if p.GetHelp() == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			helpText.WriteString(fmt.Sprintf("• **%s** - %s\n", p.GetName(), p.GetHelp())) | ||||||
|  | 		} | ||||||
|  | 		helpText.WriteString("\n") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Add footer | ||||||
|  | 	helpText.WriteString("_Use the specific commands or triggers mentioned above to interact with the bot._") | ||||||
|  | 
 | ||||||
|  | 	response := &model.Message{ | ||||||
|  | 		Text:    helpText.String(), | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 		Raw:     map[string]interface{}{"parse_mode": "Markdown"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{ | ||||||
|  | 		{ | ||||||
|  | 			Type:    model.ActionSendMessage, | ||||||
|  | 			Message: response, | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										206
									
								
								internal/plugin/help/help_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								internal/plugin/help/help_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,206 @@ | ||||||
|  | package help | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/db" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MockPlugin implements the Plugin interface for testing | ||||||
|  | type MockPlugin struct { | ||||||
|  | 	id   string | ||||||
|  | 	name string | ||||||
|  | 	help string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockPlugin) GetID() string   { return m.id } | ||||||
|  | func (m *MockPlugin) GetName() string { return m.name } | ||||||
|  | func (m *MockPlugin) GetHelp() string { return m.help } | ||||||
|  | func (m *MockPlugin) RequiresConfig() bool { | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | func (m *MockPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MockDatabase implements the ChannelPluginGetter interface for testing | ||||||
|  | type MockDatabase struct { | ||||||
|  | 	channelPlugins         map[int64][]*model.ChannelPlugin | ||||||
|  | 	platformChannelPlugins map[string][]*model.ChannelPlugin // key: "platform:platformChannelID" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockDatabase) GetChannelPlugins(channelID int64) ([]*model.ChannelPlugin, error) { | ||||||
|  | 	if plugins, exists := m.channelPlugins[channelID]; exists { | ||||||
|  | 		return plugins, nil | ||||||
|  | 	} | ||||||
|  | 	return nil, db.ErrNotFound | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockDatabase) GetChannelPluginsFromPlatformID(platform, platformChannelID string) ([]*model.ChannelPlugin, error) { | ||||||
|  | 	key := platform + ":" + platformChannelID | ||||||
|  | 	if plugins, exists := m.platformChannelPlugins[key]; exists { | ||||||
|  | 		return plugins, nil | ||||||
|  | 	} | ||||||
|  | 	return nil, db.ErrNotFound | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestHelpPlugin_OnMessage(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name             string | ||||||
|  | 		messageText      string | ||||||
|  | 		enabledPlugins   map[string]*MockPlugin | ||||||
|  | 		expectResponse   bool | ||||||
|  | 		expectNoPlugins  bool | ||||||
|  | 		expectCategories []string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:        "responds to !help command", | ||||||
|  | 			messageText: "!help", | ||||||
|  | 			enabledPlugins: map[string]*MockPlugin{ | ||||||
|  | 				"dev.ping": { | ||||||
|  | 					id:   "dev.ping", | ||||||
|  | 					name: "Ping", | ||||||
|  | 					help: "Responds to 'ping' with 'pong'", | ||||||
|  | 				}, | ||||||
|  | 				"fun.dice": { | ||||||
|  | 					id:   "fun.dice", | ||||||
|  | 					name: "Dice Roller", | ||||||
|  | 					help: "Rolls dice when you type '!dice [formula]'", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse:   true, | ||||||
|  | 			expectCategories: []string{"Development", "Fun and Entertainment"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "ignores non-help messages", | ||||||
|  | 			messageText:    "hello world", | ||||||
|  | 			enabledPlugins: map[string]*MockPlugin{}, | ||||||
|  | 			expectResponse: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:            "ignores case variation", | ||||||
|  | 			messageText:     "!HELP", | ||||||
|  | 			enabledPlugins:  map[string]*MockPlugin{}, | ||||||
|  | 			expectResponse:  true, | ||||||
|  | 			expectNoPlugins: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:            "handles no enabled plugins", | ||||||
|  | 			messageText:     "!help", | ||||||
|  | 			enabledPlugins:  map[string]*MockPlugin{}, | ||||||
|  | 			expectResponse:  true, | ||||||
|  | 			expectNoPlugins: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			// Create mock database | ||||||
|  | 			mockDB := &MockDatabase{ | ||||||
|  | 				channelPlugins:         make(map[int64][]*model.ChannelPlugin), | ||||||
|  | 				platformChannelPlugins: make(map[string][]*model.ChannelPlugin), | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Setup channel plugins in mock database | ||||||
|  | 			var channelPluginList []*model.ChannelPlugin | ||||||
|  | 			pluginCounter := int64(1) | ||||||
|  | 			for pluginID := range tt.enabledPlugins { | ||||||
|  | 				channelPluginList = append(channelPluginList, &model.ChannelPlugin{ | ||||||
|  | 					ID:        pluginCounter, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  pluginID, | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    make(map[string]interface{}), | ||||||
|  | 				}) | ||||||
|  | 				pluginCounter++ | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Set up both mapping approaches for the test | ||||||
|  | 			mockDB.channelPlugins[1] = channelPluginList | ||||||
|  | 			mockDB.platformChannelPlugins["test:test-channel"] = channelPluginList | ||||||
|  | 
 | ||||||
|  | 			// Create help plugin | ||||||
|  | 			p := New(mockDB) | ||||||
|  | 
 | ||||||
|  | 			// Create mock channel | ||||||
|  | 			channel := &model.Channel{ | ||||||
|  | 				ID:                1, | ||||||
|  | 				Platform:          "test", | ||||||
|  | 				PlatformChannelID: "test-channel", | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Create test message | ||||||
|  | 			msg := &model.Message{ | ||||||
|  | 				ID:      "test-msg", | ||||||
|  | 				Text:    tt.messageText, | ||||||
|  | 				Chat:    "test-chat", | ||||||
|  | 				Channel: channel, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Mock the plugin registry | ||||||
|  | 			originalRegistry := plugin.GetAvailablePlugins() | ||||||
|  | 
 | ||||||
|  | 			// Override the registry for this test | ||||||
|  | 			plugin.ClearRegistry() | ||||||
|  | 			for _, mockPlugin := range tt.enabledPlugins { | ||||||
|  | 				plugin.Register(mockPlugin) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Call OnMessage | ||||||
|  | 			actions := p.OnMessage(msg, map[string]interface{}{}, nil) | ||||||
|  | 
 | ||||||
|  | 			// Restore original registry | ||||||
|  | 			plugin.ClearRegistry() | ||||||
|  | 			for _, p := range originalRegistry { | ||||||
|  | 				plugin.Register(p) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !tt.expectResponse { | ||||||
|  | 				if len(actions) != 0 { | ||||||
|  | 					t.Errorf("Expected no response, but got %d actions", len(actions)) | ||||||
|  | 				} | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if len(actions) != 1 { | ||||||
|  | 				t.Errorf("Expected 1 action, got %d", len(actions)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			action := actions[0] | ||||||
|  | 			if action.Type != model.ActionSendMessage { | ||||||
|  | 				t.Errorf("Expected ActionSendMessage, got %v", action.Type) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			responseText := action.Message.Text | ||||||
|  | 
 | ||||||
|  | 			if tt.expectNoPlugins { | ||||||
|  | 				if !strings.Contains(responseText, "No plugins are currently enabled") { | ||||||
|  | 					t.Errorf("Expected 'no plugins' message, got: %s", responseText) | ||||||
|  | 				} | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check that expected categories appear in response | ||||||
|  | 			for _, category := range tt.expectCategories { | ||||||
|  | 				if !strings.Contains(responseText, "**"+category+":**") { | ||||||
|  | 					t.Errorf("Expected category '%s' in response, got: %s", category, responseText) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check that plugin names and help text appear | ||||||
|  | 			for _, mockPlugin := range tt.enabledPlugins { | ||||||
|  | 				if !strings.Contains(responseText, mockPlugin.GetName()) { | ||||||
|  | 					t.Errorf("Expected plugin name '%s' in response", mockPlugin.GetName()) | ||||||
|  | 				} | ||||||
|  | 				if !strings.Contains(responseText, mockPlugin.GetHelp()) { | ||||||
|  | 					t.Errorf("Expected plugin help '%s' in response", mockPlugin.GetHelp()) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -24,11 +24,12 @@ func New() *PingPlugin { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // OnMessage handles incoming messages | // OnMessage handles incoming messages | ||||||
| func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
| 	if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { | 	if !strings.EqualFold(strings.TrimSpace(msg.Text), "ping") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Create the response message | ||||||
| 	response := &model.Message{ | 	response := &model.Message{ | ||||||
| 		Text:    "pong", | 		Text:    "pong", | ||||||
| 		Chat:    msg.Chat, | 		Chat:    msg.Chat, | ||||||
|  | @ -36,5 +37,13 @@ func (p *PingPlugin) OnMessage(msg *model.Message, config map[string]interface{} | ||||||
| 		Channel: msg.Channel, | 		Channel: msg.Channel, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return []*model.Message{response} | 	// Create an action to send the message | ||||||
|  | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| package plugin | package plugin | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"maps" | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"git.nakama.town/fmartingr/butterrobot/internal/model" | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | @ -41,13 +42,31 @@ func GetAvailablePlugins() map[string]model.Plugin { | ||||||
| 
 | 
 | ||||||
| 	// Create a copy to avoid race conditions | 	// Create a copy to avoid race conditions | ||||||
| 	result := make(map[string]model.Plugin, len(plugins)) | 	result := make(map[string]model.Plugin, len(plugins)) | ||||||
| 	for id, plugin := range plugins { | 	maps.Copy(result, plugins) | ||||||
| 		result[id] = plugin | 
 | ||||||
|  | 	return result | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetAvailablePluginIDs returns a slice of all registered plugin IDs | ||||||
|  | func GetAvailablePluginIDs() []string { | ||||||
|  | 	pluginsMu.RLock() | ||||||
|  | 	defer pluginsMu.RUnlock() | ||||||
|  | 
 | ||||||
|  | 	result := make([]string, 0, len(plugins)) | ||||||
|  | 	for pluginID := range plugins { | ||||||
|  | 		result = append(result, pluginID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return result | 	return result | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ClearRegistry clears all registered plugins (for testing) | ||||||
|  | func ClearRegistry() { | ||||||
|  | 	pluginsMu.Lock() | ||||||
|  | 	defer pluginsMu.Unlock() | ||||||
|  | 	plugins = make(map[string]model.Plugin) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // BasePlugin provides a common base for plugins | // BasePlugin provides a common base for plugins | ||||||
| type BasePlugin struct { | type BasePlugin struct { | ||||||
| 	ID             string | 	ID             string | ||||||
|  | @ -77,6 +96,6 @@ func (p *BasePlugin) RequiresConfig() bool { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // OnMessage is the default implementation that does nothing | // OnMessage is the default implementation that does nothing | ||||||
| func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}) []*model.Message { | func (p *BasePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										331
									
								
								internal/plugin/plugin_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										331
									
								
								internal/plugin/plugin_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,331 @@ | ||||||
|  | package plugin | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Mock plugin for testing | ||||||
|  | type testPlugin struct { | ||||||
|  | 	BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *testPlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	return []*model.MessageAction{ | ||||||
|  | 		{ | ||||||
|  | 			Type: model.ActionSendMessage, | ||||||
|  | 			Message: &model.Message{ | ||||||
|  | 				Text:    "test response", | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGetAvailablePluginIDs(t *testing.T) { | ||||||
|  | 	// Clear registry before test | ||||||
|  | 	ClearRegistry() | ||||||
|  | 
 | ||||||
|  | 	// Register test plugins | ||||||
|  | 	testPlugin1 := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "test.plugin1", | ||||||
|  | 			Name: "Test Plugin 1", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	testPlugin2 := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "test.plugin2", | ||||||
|  | 			Name: "Test Plugin 2", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	Register(testPlugin1) | ||||||
|  | 	Register(testPlugin2) | ||||||
|  | 
 | ||||||
|  | 	// Test GetAvailablePluginIDs | ||||||
|  | 	pluginIDs := GetAvailablePluginIDs() | ||||||
|  | 
 | ||||||
|  | 	if len(pluginIDs) != 2 { | ||||||
|  | 		t.Errorf("Expected 2 plugin IDs, got %d", len(pluginIDs)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check that both plugin IDs are present | ||||||
|  | 	found1, found2 := false, false | ||||||
|  | 	for _, id := range pluginIDs { | ||||||
|  | 		if id == "test.plugin1" { | ||||||
|  | 			found1 = true | ||||||
|  | 		} | ||||||
|  | 		if id == "test.plugin2" { | ||||||
|  | 			found2 = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !found1 { | ||||||
|  | 		t.Errorf("Expected to find test.plugin1 in plugin IDs") | ||||||
|  | 	} | ||||||
|  | 	if !found2 { | ||||||
|  | 		t.Errorf("Expected to find test.plugin2 in plugin IDs") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestEnableAllPluginsProcessingLogic(t *testing.T) { | ||||||
|  | 	// Clear registry before test | ||||||
|  | 	ClearRegistry() | ||||||
|  | 
 | ||||||
|  | 	// Register test plugins | ||||||
|  | 	testPlugin1 := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "ping", | ||||||
|  | 			Name: "Ping Plugin", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	testPlugin2 := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "echo", | ||||||
|  | 			Name: "Echo Plugin", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	testPlugin3 := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "help", | ||||||
|  | 			Name: "Help Plugin", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	Register(testPlugin1) | ||||||
|  | 	Register(testPlugin2) | ||||||
|  | 	Register(testPlugin3) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins false - only explicitly enabled plugins", func(t *testing.T) { | ||||||
|  | 		// Create a channel with EnableAllPlugins = false and only some plugins enabled | ||||||
|  | 		channel := &model.Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  false, | ||||||
|  | 			Plugins: map[string]*model.ChannelPlugin{ | ||||||
|  | 				"ping": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "ping", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    map[string]interface{}{"key": "value"}, | ||||||
|  | 				}, | ||||||
|  | 				"echo": { | ||||||
|  | 					ID:        2, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "echo", | ||||||
|  | 					Enabled:   false, // Disabled | ||||||
|  | 					Config:    map[string]interface{}{}, | ||||||
|  | 				}, | ||||||
|  | 				// help plugin not configured | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Simulate the plugin processing logic from handleMessage | ||||||
|  | 		var pluginsToProcess []string | ||||||
|  | 
 | ||||||
|  | 		if channel.EnableAllPlugins { | ||||||
|  | 			pluginsToProcess = GetAvailablePluginIDs() | ||||||
|  | 		} else { | ||||||
|  | 			for pluginID := range channel.Plugins { | ||||||
|  | 				if channel.HasEnabledPlugin(pluginID) { | ||||||
|  | 					pluginsToProcess = append(pluginsToProcess, pluginID) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Should only have "ping" since echo is disabled and help is not configured | ||||||
|  | 		if len(pluginsToProcess) != 1 { | ||||||
|  | 			t.Errorf("Expected 1 plugin to process, got %d: %v", len(pluginsToProcess), pluginsToProcess) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(pluginsToProcess) > 0 && pluginsToProcess[0] != "ping" { | ||||||
|  | 			t.Errorf("Expected ping plugin to be processed, got %s", pluginsToProcess[0]) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EnableAllPlugins true - all registered plugins", func(t *testing.T) { | ||||||
|  | 		// Create a channel with EnableAllPlugins = true | ||||||
|  | 		channel := &model.Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins: map[string]*model.ChannelPlugin{ | ||||||
|  | 				"ping": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "ping", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    map[string]interface{}{"key": "value"}, | ||||||
|  | 				}, | ||||||
|  | 				"echo": { | ||||||
|  | 					ID:        2, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "echo", | ||||||
|  | 					Enabled:   false, // Disabled, but should still be processed | ||||||
|  | 					Config:    map[string]interface{}{}, | ||||||
|  | 				}, | ||||||
|  | 				// help plugin not configured, but should still be processed | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Simulate the plugin processing logic from handleMessage | ||||||
|  | 		var pluginsToProcess []string | ||||||
|  | 
 | ||||||
|  | 		if channel.EnableAllPlugins { | ||||||
|  | 			pluginsToProcess = GetAvailablePluginIDs() | ||||||
|  | 		} else { | ||||||
|  | 			for pluginID := range channel.Plugins { | ||||||
|  | 				if channel.HasEnabledPlugin(pluginID) { | ||||||
|  | 					pluginsToProcess = append(pluginsToProcess, pluginID) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Should have all 3 registered plugins | ||||||
|  | 		if len(pluginsToProcess) != 3 { | ||||||
|  | 			t.Errorf("Expected 3 plugins to process, got %d: %v", len(pluginsToProcess), pluginsToProcess) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Check that all plugins are included | ||||||
|  | 		expectedPlugins := map[string]bool{"ping": false, "echo": false, "help": false} | ||||||
|  | 		for _, pluginID := range pluginsToProcess { | ||||||
|  | 			if _, exists := expectedPlugins[pluginID]; exists { | ||||||
|  | 				expectedPlugins[pluginID] = true | ||||||
|  | 			} else { | ||||||
|  | 				t.Errorf("Unexpected plugin in processing list: %s", pluginID) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for pluginID, found := range expectedPlugins { | ||||||
|  | 			if !found { | ||||||
|  | 				t.Errorf("Expected plugin %s to be in processing list", pluginID) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Plugin configuration handling", func(t *testing.T) { | ||||||
|  | 		// Test the configuration logic from handleMessage | ||||||
|  | 		channel := &model.Channel{ | ||||||
|  | 			ID:                1, | ||||||
|  | 			Platform:          "telegram", | ||||||
|  | 			PlatformChannelID: "123456", | ||||||
|  | 			Enabled:           true, | ||||||
|  | 			EnableAllPlugins:  true, | ||||||
|  | 			Plugins: map[string]*model.ChannelPlugin{ | ||||||
|  | 				"ping": { | ||||||
|  | 					ID:        1, | ||||||
|  | 					ChannelID: 1, | ||||||
|  | 					PluginID:  "ping", | ||||||
|  | 					Enabled:   true, | ||||||
|  | 					Config:    map[string]interface{}{"configured": "value"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		testCases := []struct { | ||||||
|  | 			pluginID       string | ||||||
|  | 			expectedConfig map[string]interface{} | ||||||
|  | 		}{ | ||||||
|  | 			{ | ||||||
|  | 				pluginID:       "ping", | ||||||
|  | 				expectedConfig: map[string]interface{}{"configured": "value"}, | ||||||
|  | 			}, | ||||||
|  | 			{ | ||||||
|  | 				pluginID:       "echo", // Not explicitly configured | ||||||
|  | 				expectedConfig: map[string]interface{}{}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, tc := range testCases { | ||||||
|  | 			// Simulate the config retrieval logic from handleMessage | ||||||
|  | 			var config map[string]interface{} | ||||||
|  | 			if channelPlugin, exists := channel.Plugins[tc.pluginID]; exists { | ||||||
|  | 				config = channelPlugin.Config | ||||||
|  | 			} else { | ||||||
|  | 				config = make(map[string]interface{}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if len(config) != len(tc.expectedConfig) { | ||||||
|  | 				t.Errorf("Plugin %s: expected config length %d, got %d", tc.pluginID, len(tc.expectedConfig), len(config)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for key, expectedValue := range tc.expectedConfig { | ||||||
|  | 				if actualValue, exists := config[key]; !exists || actualValue != expectedValue { | ||||||
|  | 					t.Errorf("Plugin %s: expected config[%s] = %v, got %v", tc.pluginID, key, expectedValue, actualValue) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestPluginRegistry(t *testing.T) { | ||||||
|  | 	// Clear registry before test | ||||||
|  | 	ClearRegistry() | ||||||
|  | 
 | ||||||
|  | 	testPlugin := &testPlugin{ | ||||||
|  | 		BasePlugin: BasePlugin{ | ||||||
|  | 			ID:   "test.registry", | ||||||
|  | 			Name: "Test Registry Plugin", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	t.Run("Register and Get plugin", func(t *testing.T) { | ||||||
|  | 		Register(testPlugin) | ||||||
|  | 
 | ||||||
|  | 		retrieved, err := Get("test.registry") | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("Failed to get registered plugin: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if retrieved.GetID() != "test.registry" { | ||||||
|  | 			t.Errorf("Expected plugin ID 'test.registry', got '%s'", retrieved.GetID()) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Get nonexistent plugin", func(t *testing.T) { | ||||||
|  | 		_, err := Get("nonexistent.plugin") | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Errorf("Expected error when getting nonexistent plugin, got nil") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err != model.ErrPluginNotFound { | ||||||
|  | 			t.Errorf("Expected ErrPluginNotFound, got %v", err) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("GetAvailablePlugins", func(t *testing.T) { | ||||||
|  | 		plugins := GetAvailablePlugins() | ||||||
|  | 
 | ||||||
|  | 		if len(plugins) != 1 { | ||||||
|  | 			t.Errorf("Expected 1 plugin in registry, got %d", len(plugins)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if plugin, exists := plugins["test.registry"]; !exists { | ||||||
|  | 			t.Errorf("Expected to find test.registry in available plugins") | ||||||
|  | 		} else if plugin.GetID() != "test.registry" { | ||||||
|  | 			t.Errorf("Expected plugin ID 'test.registry', got '%s'", plugin.GetID()) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("ClearRegistry", func(t *testing.T) { | ||||||
|  | 		ClearRegistry() | ||||||
|  | 
 | ||||||
|  | 		plugins := GetAvailablePlugins() | ||||||
|  | 		if len(plugins) != 0 { | ||||||
|  | 			t.Errorf("Expected 0 plugins after clearing registry, got %d", len(plugins)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		_, err := Get("test.registry") | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Errorf("Expected error when getting plugin after clearing registry, got nil") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										200
									
								
								internal/plugin/reminder/reminder.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								internal/plugin/reminder/reminder.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,200 @@ | ||||||
|  | package reminder | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Duration regex patterns to match reminders | ||||||
|  | var ( | ||||||
|  | 	remindMePattern = regexp.MustCompile(`(?i)^!remindme\s(\d+)(y|mo|d|h|m|s)$`) | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // ReminderCreator is an interface for creating reminders | ||||||
|  | type ReminderCreator interface { | ||||||
|  | 	CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Reminder is a plugin that sets reminders for messages | ||||||
|  | type Reminder struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | 	creator ReminderCreator | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new Reminder plugin | ||||||
|  | func New(creator ReminderCreator) *Reminder { | ||||||
|  | 	return &Reminder{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:             "reminder.remindme", | ||||||
|  | 			Name:           "Remind Me", | ||||||
|  | 			Help:           "Reply to a message with `!remindme <duration>` to set a reminder (e.g., `!remindme 2d` for 2 days, `!remindme 1y` for 1 year).", | ||||||
|  | 			ConfigRequired: false, | ||||||
|  | 		}, | ||||||
|  | 		creator: creator, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage processes incoming messages | ||||||
|  | func (r *Reminder) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Only process replies to messages | ||||||
|  | 	if msg.ReplyTo == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check if the message is a reminder command | ||||||
|  | 	match := remindMePattern.FindStringSubmatch(msg.Text) | ||||||
|  | 	if match == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Parse the duration | ||||||
|  | 	amount, err := strconv.Atoi(match[1]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		errorMsg := &model.Message{ | ||||||
|  | 			Text:    "Invalid duration format. Please use a number followed by y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 			Author:  "bot", | ||||||
|  | 			FromBot: true, | ||||||
|  | 			Date:    time.Now(), | ||||||
|  | 			ReplyTo: msg.ID, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type:    model.ActionSendMessage, | ||||||
|  | 				Message: errorMsg, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Calculate the trigger time | ||||||
|  | 	var duration time.Duration | ||||||
|  | 	unit := match[2] | ||||||
|  | 	switch strings.ToLower(unit) { | ||||||
|  | 	case "y": | ||||||
|  | 		duration = time.Duration(amount) * 365 * 24 * time.Hour | ||||||
|  | 	case "mo": | ||||||
|  | 		duration = time.Duration(amount) * 30 * 24 * time.Hour | ||||||
|  | 	case "d": | ||||||
|  | 		duration = time.Duration(amount) * 24 * time.Hour | ||||||
|  | 	case "h": | ||||||
|  | 		duration = time.Duration(amount) * time.Hour | ||||||
|  | 	case "m": | ||||||
|  | 		duration = time.Duration(amount) * time.Minute | ||||||
|  | 	case "s": | ||||||
|  | 		duration = time.Duration(amount) * time.Second | ||||||
|  | 	default: | ||||||
|  | 		errorMsg := &model.Message{ | ||||||
|  | 			Text:    "Invalid duration unit. Please use y (years), mo (months), d (days), h (hours), m (minutes), or s (seconds).", | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 			Author:  "bot", | ||||||
|  | 			FromBot: true, | ||||||
|  | 			Date:    time.Now(), | ||||||
|  | 			ReplyTo: msg.ID, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type:    model.ActionSendMessage, | ||||||
|  | 				Message: errorMsg, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	triggerAt := time.Now().Add(duration) | ||||||
|  | 
 | ||||||
|  | 	// Determine the username for the reminder | ||||||
|  | 	username := msg.Author | ||||||
|  | 	if username == "" { | ||||||
|  | 		// Try to extract username from message raw data | ||||||
|  | 		if authorData, ok := msg.Raw["author"].(map[string]interface{}); ok { | ||||||
|  | 			if name, ok := authorData["username"].(string); ok { | ||||||
|  | 				username = name | ||||||
|  | 			} else if name, ok := authorData["name"].(string); ok { | ||||||
|  | 				username = name | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create the reminder | ||||||
|  | 	_, err = r.creator.CreateReminder( | ||||||
|  | 		msg.Channel.Platform, | ||||||
|  | 		msg.Chat, | ||||||
|  | 		msg.ID, | ||||||
|  | 		msg.ReplyTo, | ||||||
|  | 		msg.Author, | ||||||
|  | 		username, | ||||||
|  | 		"", // No additional content for now | ||||||
|  | 		triggerAt, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		errorMsg := &model.Message{ | ||||||
|  | 			Text:    fmt.Sprintf("Failed to create reminder: %v", err), | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 			Author:  "bot", | ||||||
|  | 			FromBot: true, | ||||||
|  | 			Date:    time.Now(), | ||||||
|  | 			ReplyTo: msg.ID, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type:    model.ActionSendMessage, | ||||||
|  | 				Message: errorMsg, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Format the acknowledgment message | ||||||
|  | 	var confirmText string | ||||||
|  | 	switch strings.ToLower(unit) { | ||||||
|  | 	case "y": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d year(s) on %s", amount, triggerAt.Format("Mon, Jan 2, 2006 at 15:04")) | ||||||
|  | 	case "mo": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d month(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04")) | ||||||
|  | 	case "d": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d day(s) on %s", amount, triggerAt.Format("Mon, Jan 2 at 15:04")) | ||||||
|  | 	case "h": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d hour(s) at %s", amount, triggerAt.Format("15:04")) | ||||||
|  | 	case "m": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d minute(s) at %s", amount, triggerAt.Format("15:04")) | ||||||
|  | 	case "s": | ||||||
|  | 		confirmText = fmt.Sprintf("I'll remind you about this message in %d second(s)", amount) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create confirmation message | ||||||
|  | 	confirmMsg := &model.Message{ | ||||||
|  | 		Text:    confirmText, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 		Author:  "bot", | ||||||
|  | 		FromBot: true, | ||||||
|  | 		Date:    time.Now(), | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{ | ||||||
|  | 		{ | ||||||
|  | 			Type:    model.ActionSendMessage, | ||||||
|  | 			Message: confirmMsg, | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										177
									
								
								internal/plugin/reminder/reminder_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								internal/plugin/reminder/reminder_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,177 @@ | ||||||
|  | package reminder | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/testutil" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MockCreator is a mock implementation of ReminderCreator for testing | ||||||
|  | type MockCreator struct { | ||||||
|  | 	reminders []*model.Reminder | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockCreator) CreateReminder(platform, channelID, messageID, replyToID, userID, username, content string, triggerAt time.Time) (*model.Reminder, error) { | ||||||
|  | 	reminder := &model.Reminder{ | ||||||
|  | 		ID:        int64(len(m.reminders) + 1), | ||||||
|  | 		Platform:  platform, | ||||||
|  | 		ChannelID: channelID, | ||||||
|  | 		MessageID: messageID, | ||||||
|  | 		ReplyToID: replyToID, | ||||||
|  | 		UserID:    userID, | ||||||
|  | 		Username:  username, | ||||||
|  | 		Content:   content, | ||||||
|  | 		TriggerAt: triggerAt, | ||||||
|  | 	} | ||||||
|  | 	m.reminders = append(m.reminders, reminder) | ||||||
|  | 	return reminder, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestReminderOnMessage(t *testing.T) { | ||||||
|  | 	creator := &MockCreator{reminders: make([]*model.Reminder, 0)} | ||||||
|  | 	plugin := New(creator) | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name           string | ||||||
|  | 		message        *model.Message | ||||||
|  | 		expectResponse bool | ||||||
|  | 		expectReminder bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - years", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 1y", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - months", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 3mo", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - days", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 2d", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - hours", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 5h", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - minutes", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 30m", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Valid reminder command - seconds", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 60s", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: true, | ||||||
|  | 			expectReminder: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Not a reply", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme 2d", | ||||||
|  | 				ReplyTo: "", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: false, | ||||||
|  | 			expectReminder: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Not a reminder command", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "hello world", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: false, | ||||||
|  | 			expectReminder: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Invalid duration format", | ||||||
|  | 			message: &model.Message{ | ||||||
|  | 				Text:    "!remindme abc", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Author:  "testuser", | ||||||
|  | 				Channel: &model.Channel{Platform: "test"}, | ||||||
|  | 			}, | ||||||
|  | 			expectResponse: false, | ||||||
|  | 			expectReminder: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			initialCount := len(creator.reminders) | ||||||
|  | 			mockCache := &testutil.MockCache{} | ||||||
|  | 			actions := plugin.OnMessage(tt.message, nil, mockCache) | ||||||
|  | 
 | ||||||
|  | 			if tt.expectResponse && len(actions) == 0 { | ||||||
|  | 				t.Errorf("Expected response action, but got none") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !tt.expectResponse && len(actions) > 0 { | ||||||
|  | 				t.Errorf("Expected no actions, but got %d", len(actions)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Verify action type is correct when actions are returned | ||||||
|  | 			if len(actions) > 0 { | ||||||
|  | 				if actions[0].Type != model.ActionSendMessage { | ||||||
|  | 					t.Errorf("Expected action type to be %s, but got %s", model.ActionSendMessage, actions[0].Type) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if actions[0].Message == nil { | ||||||
|  | 					t.Errorf("Expected message in action to not be nil") | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if tt.expectReminder && len(creator.reminders) != initialCount+1 { | ||||||
|  | 				t.Errorf("Expected reminder to be created, but it wasn't") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if !tt.expectReminder && len(creator.reminders) != initialCount { | ||||||
|  | 				t.Errorf("Expected no reminder to be created, but got %d", len(creator.reminders)-initialCount) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										50
									
								
								internal/plugin/searchreplace/README.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								internal/plugin/searchreplace/README.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,50 @@ | ||||||
|  | # Search and Replace Plugin | ||||||
|  | 
 | ||||||
|  | This plugin allows users to perform search and replace operations on messages by replying to a message with a search/replace command. | ||||||
|  | 
 | ||||||
|  | ## Usage | ||||||
|  | 
 | ||||||
|  | To use the plugin, reply to any message with a command in the following format: | ||||||
|  | 
 | ||||||
|  | ``` | ||||||
|  | s/search/replace/[flags] | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | Where: | ||||||
|  | - `search` is the text you want to find (case-sensitive by default) | ||||||
|  | - `replace` is the text you want to substitute in place of the search term | ||||||
|  | - `flags` (optional) control the behavior of the replacement | ||||||
|  | 
 | ||||||
|  | ### Supported Flags | ||||||
|  | 
 | ||||||
|  | - `g` - Global: Replace all occurrences of the search term (without this flag, only the first occurrence is replaced) | ||||||
|  | - `i` - Case insensitive: Match regardless of case | ||||||
|  | - `n` - Treat search pattern as a regular expression (advanced users) | ||||||
|  | 
 | ||||||
|  | ### Examples | ||||||
|  | 
 | ||||||
|  | 1. Basic replacement (replaces first occurrence): | ||||||
|  |    ``` | ||||||
|  |    s/hello/hi/ | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | 2. Global replacement (replaces all occurrences): | ||||||
|  |    ``` | ||||||
|  |    s/hello/hi/g | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | 3. Case-insensitive replacement: | ||||||
|  |    ``` | ||||||
|  |    s/Hello/hi/i | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | 4. Combined flags (global and case-insensitive): | ||||||
|  |    ``` | ||||||
|  |    s/hello/hi/gi | ||||||
|  |    ``` | ||||||
|  | 
 | ||||||
|  | ## Limitations | ||||||
|  | 
 | ||||||
|  | - The plugin can only access the text content of the original message | ||||||
|  | - Regular expression support is available with the `n` flag, but should be used carefully as invalid regex patterns will cause errors | ||||||
|  | - The plugin does not modify the original message; it creates a new message with the replaced text | ||||||
							
								
								
									
										182
									
								
								internal/plugin/searchreplace/searchreplace.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								internal/plugin/searchreplace/searchreplace.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,182 @@ | ||||||
|  | package searchreplace | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Regex pattern for search and replace operations: s/search/replace/[flags] | ||||||
|  | var searchReplacePattern = regexp.MustCompile(`^s/([^/]*)/([^/]*)(?:/([gimnsuy]*))?$`) | ||||||
|  | 
 | ||||||
|  | // SearchReplacePlugin is a plugin for performing search and replace operations on messages | ||||||
|  | type SearchReplacePlugin struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new SearchReplacePlugin instance | ||||||
|  | func New() *SearchReplacePlugin { | ||||||
|  | 	return &SearchReplacePlugin{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:   "util.searchreplace", | ||||||
|  | 			Name: "Search and Replace", | ||||||
|  | 			Help: "Reply to a message with a search and replace pattern (`s/search/replace/[flags]`) to create a modified message. " + | ||||||
|  | 				"Supported flags: g (global), i (case insensitive)", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *SearchReplacePlugin) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Only process replies to messages | ||||||
|  | 	if msg.ReplyTo == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check if the message matches the search/replace pattern | ||||||
|  | 	match := searchReplacePattern.FindStringSubmatch(strings.TrimSpace(msg.Text)) | ||||||
|  | 	if match == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the original message text from the reply_to_message structure in Telegram messages | ||||||
|  | 	var originalText string | ||||||
|  | 
 | ||||||
|  | 	// For Telegram messages | ||||||
|  | 	if msgData, ok := msg.Raw["message"].(map[string]interface{}); ok { | ||||||
|  | 		if replyMsg, ok := msgData["reply_to_message"].(map[string]interface{}); ok { | ||||||
|  | 			if text, ok := replyMsg["text"].(string); ok { | ||||||
|  | 				originalText = text | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Generic fallback for other platforms or if the above method fails | ||||||
|  | 	if originalText == "" && msg.Raw["original_message"] != nil { | ||||||
|  | 		if original, ok := msg.Raw["original_message"].(map[string]interface{}); ok { | ||||||
|  | 			if text, ok := original["text"].(string); ok { | ||||||
|  | 				originalText = text | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if originalText == "" { | ||||||
|  | 		// If we couldn't find the original message text, inform the user | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type: model.ActionSendMessage, | ||||||
|  | 				Message: &model.Message{ | ||||||
|  | 					Text:    "Sorry, I couldn't find the original message text to perform the replacement.", | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 					ReplyTo: msg.ID, | ||||||
|  | 				}, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract search pattern, replacement and flags | ||||||
|  | 	searchPattern := match[1] | ||||||
|  | 	replacement := match[2] | ||||||
|  | 	flags := "" | ||||||
|  | 	if len(match) > 3 { | ||||||
|  | 		flags = match[3] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Process the replacement | ||||||
|  | 	result, err := p.performReplacement(originalText, searchPattern, replacement, flags) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type: model.ActionSendMessage, | ||||||
|  | 				Message: &model.Message{ | ||||||
|  | 					Text:    fmt.Sprintf("Error performing replacement: %s", err.Error()), | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 					ReplyTo: msg.ID, | ||||||
|  | 				}, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Only send a response if the text actually changed | ||||||
|  | 	if result == originalText { | ||||||
|  | 		return []*model.MessageAction{ | ||||||
|  | 			{ | ||||||
|  | 				Type: model.ActionSendMessage, | ||||||
|  | 				Message: &model.Message{ | ||||||
|  | 					Text:    "No changes were made to the original message.", | ||||||
|  | 					Chat:    msg.Chat, | ||||||
|  | 					Channel: msg.Channel, | ||||||
|  | 					ReplyTo: msg.ID, | ||||||
|  | 				}, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create a response with the modified text | ||||||
|  | 	return []*model.MessageAction{ | ||||||
|  | 		{ | ||||||
|  | 			Type: model.ActionSendMessage, | ||||||
|  | 			Message: &model.Message{ | ||||||
|  | 				Text:    result, | ||||||
|  | 				Chat:    msg.Chat, | ||||||
|  | 				Channel: msg.Channel, | ||||||
|  | 				ReplyTo: msg.ReplyTo, // Reply to the original message | ||||||
|  | 			}, | ||||||
|  | 			Chat:    msg.Chat, | ||||||
|  | 			Channel: msg.Channel, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // performReplacement performs the search and replace operation on the given text | ||||||
|  | func (p *SearchReplacePlugin) performReplacement(text, search, replace, flags string) (string, error) { | ||||||
|  | 	// Process flags | ||||||
|  | 	globalReplace := strings.Contains(flags, "g") | ||||||
|  | 	caseInsensitive := strings.Contains(flags, "i") | ||||||
|  | 
 | ||||||
|  | 	// Create the regex pattern | ||||||
|  | 	pattern := search | ||||||
|  | 	regexFlags := "" | ||||||
|  | 	if caseInsensitive { | ||||||
|  | 		regexFlags += "(?i)" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Escape special characters if we're not in a regular expression | ||||||
|  | 	if !strings.Contains(flags, "n") { | ||||||
|  | 		pattern = regexp.QuoteMeta(pattern) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Compile the regex | ||||||
|  | 	reg, err := regexp.Compile(regexFlags + pattern) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("invalid search pattern: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Perform the replacement | ||||||
|  | 	var result string | ||||||
|  | 	if globalReplace { | ||||||
|  | 		result = reg.ReplaceAllString(text, replace) | ||||||
|  | 	} else { | ||||||
|  | 		// For non-global replace, only replace the first occurrence | ||||||
|  | 		indices := reg.FindStringIndex(text) | ||||||
|  | 		if indices == nil { | ||||||
|  | 			// No match found | ||||||
|  | 			return text, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		result = text[:indices[0]] + replace + text[indices[1]:] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return result, nil | ||||||
|  | } | ||||||
							
								
								
									
										218
									
								
								internal/plugin/searchreplace/searchreplace_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								internal/plugin/searchreplace/searchreplace_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,218 @@ | ||||||
|  | package searchreplace | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/testutil" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestSearchReplace(t *testing.T) { | ||||||
|  | 	// Create plugin instance | ||||||
|  | 	p := New() | ||||||
|  | 
 | ||||||
|  | 	// Test cases | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name           string | ||||||
|  | 		command        string | ||||||
|  | 		originalText   string | ||||||
|  | 		expectedResult string | ||||||
|  | 		expectActions  bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:           "Simple replacement", | ||||||
|  | 			command:        "s/hello/world/", | ||||||
|  | 			originalText:   "hello everyone", | ||||||
|  | 			expectedResult: "world everyone", | ||||||
|  | 			expectActions:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Case-insensitive replacement", | ||||||
|  | 			command:        "s/HELLO/world/i", | ||||||
|  | 			originalText:   "Hello everyone", | ||||||
|  | 			expectedResult: "world everyone", | ||||||
|  | 			expectActions:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Global replacement", | ||||||
|  | 			command:        "s/a/X/g", | ||||||
|  | 			originalText:   "banana", | ||||||
|  | 			expectedResult: "bXnXnX", | ||||||
|  | 			expectActions:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "No change", | ||||||
|  | 			command:        "s/nothing/something/", | ||||||
|  | 			originalText:   "test message", | ||||||
|  | 			expectedResult: "test message", | ||||||
|  | 			expectActions:  true, // We send a "no changes" message | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Not a search/replace command", | ||||||
|  | 			command:        "hello", | ||||||
|  | 			originalText:   "test message", | ||||||
|  | 			expectedResult: "", | ||||||
|  | 			expectActions:  false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "Invalid pattern", | ||||||
|  | 			command:        "s/(/)/", | ||||||
|  | 			originalText:   "test message", | ||||||
|  | 			expectedResult: "error", | ||||||
|  | 			expectActions:  true, // We send an error message | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			// Create message | ||||||
|  | 			msg := &model.Message{ | ||||||
|  | 				Text:    tc.command, | ||||||
|  | 				Chat:    "test-chat", | ||||||
|  | 				ReplyTo: "original-message-id", | ||||||
|  | 				Date:    time.Now(), | ||||||
|  | 				Channel: &model.Channel{ | ||||||
|  | 					Platform: "test", | ||||||
|  | 				}, | ||||||
|  | 				Raw: map[string]interface{}{ | ||||||
|  | 					"message": map[string]interface{}{ | ||||||
|  | 						"reply_to_message": map[string]interface{}{ | ||||||
|  | 							"text": tc.originalText, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Process message | ||||||
|  | 			mockCache := &testutil.MockCache{} | ||||||
|  | 			actions := p.OnMessage(msg, nil, mockCache) | ||||||
|  | 
 | ||||||
|  | 			// Check results | ||||||
|  | 			if tc.expectActions { | ||||||
|  | 				if len(actions) == 0 { | ||||||
|  | 					t.Fatalf("Expected actions but got none") | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				action := actions[0] | ||||||
|  | 				if action.Type != model.ActionSendMessage { | ||||||
|  | 					t.Fatalf("Expected send message action but got %v", action.Type) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if tc.expectedResult == "error" { | ||||||
|  | 					// Just checking that we got an error message | ||||||
|  | 					if action.Message == nil || action.Message.Text == "" { | ||||||
|  | 						t.Fatalf("Expected error message but got empty message") | ||||||
|  | 					} | ||||||
|  | 				} else if tc.originalText == tc.expectedResult { | ||||||
|  | 					// Check if we got the "no changes" message | ||||||
|  | 					if action.Message == nil || action.Message.Text != "No changes were made to the original message." { | ||||||
|  | 						t.Fatalf("Expected 'no changes' message but got: %s", action.Message.Text) | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					// Check actual replacement result | ||||||
|  | 					if action.Message == nil || action.Message.Text != tc.expectedResult { | ||||||
|  | 						t.Fatalf("Expected result: %s, got: %s", tc.expectedResult, action.Message.Text) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} else if len(actions) > 0 { | ||||||
|  | 				t.Fatalf("Expected no actions but got %d", len(actions)) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestPerformReplacement(t *testing.T) { | ||||||
|  | 	p := New() | ||||||
|  | 
 | ||||||
|  | 	// Test cases for the performReplacement function | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name      string | ||||||
|  | 		text      string | ||||||
|  | 		search    string | ||||||
|  | 		replace   string | ||||||
|  | 		flags     string | ||||||
|  | 		expected  string | ||||||
|  | 		expectErr bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:      "Simple replacement", | ||||||
|  | 			text:      "Hello World", | ||||||
|  | 			search:    "Hello", | ||||||
|  | 			replace:   "Hi", | ||||||
|  | 			flags:     "", | ||||||
|  | 			expected:  "Hi World", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "Case insensitive", | ||||||
|  | 			text:      "Hello World", | ||||||
|  | 			search:    "hello", | ||||||
|  | 			replace:   "Hi", | ||||||
|  | 			flags:     "i", | ||||||
|  | 			expected:  "Hi World", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "Global replacement", | ||||||
|  | 			text:      "one two one two", | ||||||
|  | 			search:    "one", | ||||||
|  | 			replace:   "1", | ||||||
|  | 			flags:     "g", | ||||||
|  | 			expected:  "1 two 1 two", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "No match", | ||||||
|  | 			text:      "Hello World", | ||||||
|  | 			search:    "Goodbye", | ||||||
|  | 			replace:   "Hi", | ||||||
|  | 			flags:     "", | ||||||
|  | 			expected:  "Hello World", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "Invalid regex", | ||||||
|  | 			text:      "Hello World", | ||||||
|  | 			search:    "(", | ||||||
|  | 			replace:   "Hi", | ||||||
|  | 			flags:     "n", // treat as regex | ||||||
|  | 			expected:  "", | ||||||
|  | 			expectErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "Escape special chars by default", | ||||||
|  | 			text:      "Hello (World)", | ||||||
|  | 			search:    "(World)", | ||||||
|  | 			replace:   "[Earth]", | ||||||
|  | 			flags:     "", | ||||||
|  | 			expected:  "Hello [Earth]", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "Regex mode with n flag", | ||||||
|  | 			text:      "Hello (World)", | ||||||
|  | 			search:    "\\(World\\)", | ||||||
|  | 			replace:   "[Earth]", | ||||||
|  | 			flags:     "n", | ||||||
|  | 			expected:  "Hello [Earth]", | ||||||
|  | 			expectErr: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			result, err := p.performReplacement(tc.text, tc.search, tc.replace, tc.flags) | ||||||
|  | 
 | ||||||
|  | 			if tc.expectErr { | ||||||
|  | 				if err == nil { | ||||||
|  | 					t.Fatalf("Expected error but got none") | ||||||
|  | 				} | ||||||
|  | 			} else if err != nil { | ||||||
|  | 				t.Fatalf("Unexpected error: %v", err) | ||||||
|  | 			} else if result != tc.expected { | ||||||
|  | 				t.Fatalf("Expected result: %s, got: %s", tc.expected, result) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										92
									
								
								internal/plugin/social/instagram.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								internal/plugin/social/instagram.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,92 @@ | ||||||
|  | package social | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/url" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // InstagramExpander transforms instagram.com links to ddinstagram.com links | ||||||
|  | type InstagramExpander struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new InstagramExpander instance | ||||||
|  | func NewInstagramExpander() *InstagramExpander { | ||||||
|  | 	return &InstagramExpander{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:             "social.instagram", | ||||||
|  | 			Name:           "Instagram Link Expander", | ||||||
|  | 			Help:           "Automatically converts instagram.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: ddinstagram.com)", | ||||||
|  | 			ConfigRequired: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *InstagramExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Skip empty messages | ||||||
|  | 	if strings.TrimSpace(msg.Text) == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get replacement domain from config, default to ddinstagram.com | ||||||
|  | 	replacementDomain := "ddinstagram.com" | ||||||
|  | 	if domain, ok := config["domain"].(string); ok && domain != "" { | ||||||
|  | 		replacementDomain = domain | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Regex to match instagram.com links | ||||||
|  | 	// Match both http://instagram.com and https://instagram.com formats | ||||||
|  | 	// Also match www.instagram.com | ||||||
|  | 	instagramRegex := regexp.MustCompile(`https?://(www\.)?(instagram\.com)/[^\s]+`) | ||||||
|  | 
 | ||||||
|  | 	// Check if the message contains an Instagram link | ||||||
|  | 	if !instagramRegex.MatchString(msg.Text) { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Replace instagram.com with configured domain in the message and clean query parameters | ||||||
|  | 	transformed := instagramRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { | ||||||
|  | 		// Parse the URL | ||||||
|  | 		parsedURL, err := url.Parse(link) | ||||||
|  | 		if err != nil { | ||||||
|  | 			// If parsing fails, just do the simple replacement | ||||||
|  | 			return link | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Ensure we don't change links that already come from the replacement domain | ||||||
|  | 		if parsedURL.Host != "instagram.com" && parsedURL.Host != "www.instagram.com" { | ||||||
|  | 			return link | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Change the host to the configured domain | ||||||
|  | 		parsedURL.Host = replacementDomain | ||||||
|  | 
 | ||||||
|  | 		// Remove query parameters | ||||||
|  | 		parsedURL.RawQuery = "" | ||||||
|  | 
 | ||||||
|  | 		// Return the cleaned URL | ||||||
|  | 		return parsedURL.String() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Create response message | ||||||
|  | 	response := &model.Message{ | ||||||
|  | 		Text:    transformed, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
|  | } | ||||||
							
								
								
									
										88
									
								
								internal/plugin/social/twitter.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								internal/plugin/social/twitter.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,88 @@ | ||||||
|  | package social | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/url" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/plugin" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // TwitterExpander transforms twitter.com links to fxtwitter.com links | ||||||
|  | type TwitterExpander struct { | ||||||
|  | 	plugin.BasePlugin | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New creates a new TwitterExpander instance | ||||||
|  | func NewTwitterExpander() *TwitterExpander { | ||||||
|  | 	return &TwitterExpander{ | ||||||
|  | 		BasePlugin: plugin.BasePlugin{ | ||||||
|  | 			ID:             "social.twitter", | ||||||
|  | 			Name:           "Twitter Link Expander", | ||||||
|  | 			Help:           "Automatically converts twitter.com and x.com links to alternative domain links and removes tracking parameters. Configure 'domain' option to set replacement domain (default: fxtwitter.com)", | ||||||
|  | 			ConfigRequired: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // OnMessage handles incoming messages | ||||||
|  | func (p *TwitterExpander) OnMessage(msg *model.Message, config map[string]interface{}, cache model.CacheInterface) []*model.MessageAction { | ||||||
|  | 	// Skip empty messages | ||||||
|  | 	if strings.TrimSpace(msg.Text) == "" { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get replacement domain from config, default to fxtwitter.com | ||||||
|  | 	replacementDomain := "fxtwitter.com" | ||||||
|  | 	if domain, ok := config["domain"].(string); ok && domain != "" { | ||||||
|  | 		replacementDomain = domain | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Regex to match twitter.com links | ||||||
|  | 	// Match both http://twitter.com and https://twitter.com formats | ||||||
|  | 	// Also match www.twitter.com | ||||||
|  | 	twitterRegex := regexp.MustCompile(`https?://(www\.)?(twitter\.com|x\.com)/[^\s]+`) | ||||||
|  | 
 | ||||||
|  | 	// Check if the message contains a Twitter link | ||||||
|  | 	if !twitterRegex.MatchString(msg.Text) { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Replace twitter.com/x.com with configured domain in the message and clean query parameters | ||||||
|  | 	transformed := twitterRegex.ReplaceAllStringFunc(msg.Text, func(link string) string { | ||||||
|  | 		// Parse the URL | ||||||
|  | 		parsedURL, err := url.Parse(link) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return link | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Change the host to the configured domain | ||||||
|  | 		if strings.Contains(parsedURL.Host, "twitter.com") || strings.Contains(parsedURL.Host, "x.com") { | ||||||
|  | 			parsedURL.Host = replacementDomain | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Remove query parameters | ||||||
|  | 		parsedURL.RawQuery = "" | ||||||
|  | 
 | ||||||
|  | 		// Return the cleaned URL | ||||||
|  | 		return parsedURL.String() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Create response message | ||||||
|  | 	response := &model.Message{ | ||||||
|  | 		Text:    transformed, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		ReplyTo: msg.ID, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	action := &model.MessageAction{ | ||||||
|  | 		Type:    model.ActionSendMessage, | ||||||
|  | 		Message: response, | ||||||
|  | 		Chat:    msg.Chat, | ||||||
|  | 		Channel: msg.Channel, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return []*model.MessageAction{action} | ||||||
|  | } | ||||||
							
								
								
									
										120
									
								
								internal/plugin/social/twitter_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								internal/plugin/social/twitter_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,120 @@ | ||||||
|  | package social | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestTwitterExpander_OnMessage(t *testing.T) { | ||||||
|  | 	plugin := NewTwitterExpander() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name     string | ||||||
|  | 		input    string | ||||||
|  | 		config   map[string]interface{} | ||||||
|  | 		expected string | ||||||
|  | 		hasReply bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "Twitter URL with default domain", | ||||||
|  | 			input:    "https://twitter.com/user/status/123456789", | ||||||
|  | 			config:   map[string]interface{}{}, | ||||||
|  | 			expected: "https://fxtwitter.com/user/status/123456789", | ||||||
|  | 			hasReply: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "X.com URL with custom domain", | ||||||
|  | 			input:    "https://x.com/elonmusk/status/987654321", | ||||||
|  | 			config:   map[string]interface{}{"domain": "vxtwitter.com"}, | ||||||
|  | 			expected: "https://vxtwitter.com/elonmusk/status/987654321", | ||||||
|  | 			hasReply: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "Twitter URL with tracking parameters", | ||||||
|  | 			input:    "https://twitter.com/openai/status/555?ref_src=twsrc%5Etfw&s=20", | ||||||
|  | 			config:   map[string]interface{}{}, | ||||||
|  | 			expected: "https://fxtwitter.com/openai/status/555", | ||||||
|  | 			hasReply: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "www.twitter.com URL", | ||||||
|  | 			input:    "https://www.twitter.com/user/status/789", | ||||||
|  | 			config:   map[string]interface{}{"domain": "nitter.net"}, | ||||||
|  | 			expected: "https://nitter.net/user/status/789", | ||||||
|  | 			hasReply: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "Mixed text with Twitter URL", | ||||||
|  | 			input:    "Check this out: https://twitter.com/user/status/123 amazing!", | ||||||
|  | 			config:   map[string]interface{}{}, | ||||||
|  | 			expected: "Check this out: https://fxtwitter.com/user/status/123 amazing!", | ||||||
|  | 			hasReply: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "No Twitter URLs", | ||||||
|  | 			input:    "Just some regular text https://youtube.com/watch?v=abc", | ||||||
|  | 			config:   map[string]interface{}{}, | ||||||
|  | 			expected: "", | ||||||
|  | 			hasReply: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "Empty message", | ||||||
|  | 			input:    "", | ||||||
|  | 			config:   map[string]interface{}{}, | ||||||
|  | 			expected: "", | ||||||
|  | 			hasReply: false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			msg := &model.Message{ | ||||||
|  | 				ID:   "test_msg", | ||||||
|  | 				Text: tt.input, | ||||||
|  | 				Chat: "test_chat", | ||||||
|  | 				Channel: &model.Channel{ | ||||||
|  | 					ID:                1, | ||||||
|  | 					Platform:          "telegram", | ||||||
|  | 					PlatformChannelID: "test_chat", | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			actions := plugin.OnMessage(msg, tt.config, nil) | ||||||
|  | 
 | ||||||
|  | 			if !tt.hasReply { | ||||||
|  | 				if len(actions) != 0 { | ||||||
|  | 					t.Errorf("Expected no actions, got %d", len(actions)) | ||||||
|  | 				} | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if len(actions) != 1 { | ||||||
|  | 				t.Errorf("Expected 1 action, got %d", len(actions)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			action := actions[0] | ||||||
|  | 			if action.Type != model.ActionSendMessage { | ||||||
|  | 				t.Errorf("Expected ActionSendMessage, got %s", action.Type) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if action.Message == nil { | ||||||
|  | 				t.Error("Expected message in action, got nil") | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if action.Message.Text != tt.expected { | ||||||
|  | 				t.Errorf("Expected '%s', got '%s'", tt.expected, action.Message.Text) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if action.Message.ReplyTo != msg.ID { | ||||||
|  | 				t.Errorf("Expected ReplyTo '%s', got '%s'", msg.ID, action.Message.ReplyTo) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if action.Message.Raw == nil || action.Message.Raw["parse_mode"] != "" { | ||||||
|  | 				t.Error("Expected parse_mode to be empty string to disable markdown parsing") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -3,6 +3,9 @@ package queue | ||||||
| import ( | import ( | ||||||
| 	"log/slog" | 	"log/slog" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"git.nakama.town/fmartingr/butterrobot/internal/model" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Item represents a queue item | // Item represents a queue item | ||||||
|  | @ -14,14 +17,19 @@ type Item struct { | ||||||
| // HandlerFunc defines a function that processes queue items | // HandlerFunc defines a function that processes queue items | ||||||
| type HandlerFunc func(item Item) | type HandlerFunc func(item Item) | ||||||
| 
 | 
 | ||||||
|  | // ReminderHandlerFunc defines a function that processes reminder items | ||||||
|  | type ReminderHandlerFunc func(reminder *model.Reminder) | ||||||
|  | 
 | ||||||
| // Queue represents a message queue | // Queue represents a message queue | ||||||
| type Queue struct { | type Queue struct { | ||||||
| 	items    chan Item | 	items           chan Item | ||||||
| 	wg       sync.WaitGroup | 	wg              sync.WaitGroup | ||||||
| 	quit     chan struct{} | 	quit            chan struct{} | ||||||
| 	logger   *slog.Logger | 	logger          *slog.Logger | ||||||
| 	running  bool | 	running         bool | ||||||
| 	runMutex sync.Mutex | 	runMutex        sync.Mutex | ||||||
|  | 	reminderTicker  *time.Ticker | ||||||
|  | 	reminderHandler ReminderHandlerFunc | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New creates a new Queue instance | // New creates a new Queue instance | ||||||
|  | @ -49,6 +57,24 @@ func (q *Queue) Start(handler HandlerFunc) { | ||||||
| 	go q.worker(handler) | 	go q.worker(handler) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // StartReminderScheduler starts the reminder scheduler | ||||||
|  | func (q *Queue) StartReminderScheduler(handler ReminderHandlerFunc) { | ||||||
|  | 	q.runMutex.Lock() | ||||||
|  | 	defer q.runMutex.Unlock() | ||||||
|  | 
 | ||||||
|  | 	if q.reminderTicker != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	q.reminderHandler = handler | ||||||
|  | 
 | ||||||
|  | 	// Check for reminders every minute | ||||||
|  | 	q.reminderTicker = time.NewTicker(1 * time.Minute) | ||||||
|  | 
 | ||||||
|  | 	q.wg.Add(1) | ||||||
|  | 	go q.reminderWorker() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Stop stops processing queue items | // Stop stops processing queue items | ||||||
| func (q *Queue) Stop() { | func (q *Queue) Stop() { | ||||||
| 	q.runMutex.Lock() | 	q.runMutex.Lock() | ||||||
|  | @ -59,6 +85,12 @@ func (q *Queue) Stop() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	q.running = false | 	q.running = false | ||||||
|  | 
 | ||||||
|  | 	// Stop reminder ticker if it exists | ||||||
|  | 	if q.reminderTicker != nil { | ||||||
|  | 		q.reminderTicker.Stop() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	close(q.quit) | 	close(q.quit) | ||||||
| 	q.wg.Wait() | 	q.wg.Wait() | ||||||
| } | } | ||||||
|  | @ -96,4 +128,34 @@ func (q *Queue) worker(handler HandlerFunc) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // reminderWorker processes reminder items on a schedule | ||||||
|  | func (q *Queue) reminderWorker() { | ||||||
|  | 	defer q.wg.Done() | ||||||
|  | 
 | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-q.reminderTicker.C: | ||||||
|  | 			// This is triggered every minute to check for pending reminders | ||||||
|  | 			q.logger.Debug("Checking for pending reminders") | ||||||
|  | 
 | ||||||
|  | 			if q.reminderHandler != nil { | ||||||
|  | 				// The handler is responsible for fetching and processing reminders | ||||||
|  | 				func() { | ||||||
|  | 					defer func() { | ||||||
|  | 						if r := recover(); r != nil { | ||||||
|  | 							q.logger.Error("Panic in reminder worker", "error", r) | ||||||
|  | 						} | ||||||
|  | 					}() | ||||||
|  | 
 | ||||||
|  | 					// Call the handler with a nil reminder to indicate it should check the database | ||||||
|  | 					q.reminderHandler(nil) | ||||||
|  | 				}() | ||||||
|  | 			} | ||||||
|  | 		case <-q.quit: | ||||||
|  | 			// Quit worker | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										29
									
								
								internal/testutil/mock_cache.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								internal/testutil/mock_cache.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | ||||||
|  | package testutil | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MockCache implements the CacheInterface for testing | ||||||
|  | type MockCache struct{} | ||||||
|  | 
 | ||||||
|  | func (m *MockCache) Get(key string, destination interface{}) error { | ||||||
|  | 	return errors.New("cache miss") // Always return cache miss for tests | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockCache) Set(key string, value interface{}, expiration *time.Time) error { | ||||||
|  | 	return nil // Always succeed for tests | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockCache) SetWithTTL(key string, value interface{}, ttl time.Duration) error { | ||||||
|  | 	return nil // Always succeed for tests | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockCache) Delete(key string) error { | ||||||
|  | 	return nil // Always succeed for tests | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *MockCache) Exists(key string) (bool, error) { | ||||||
|  | 	return false, nil // Always return false for tests | ||||||
|  | } | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue