Pulse/internal/api/rate_limit_config.go
rcourtman 7599915b8f refactor(api): remove sensor proxy config from API handlers
- config_handlers.go: remove proxy configuration endpoints
- system_settings.go: remove proxy-related settings
- rate_limit_config.go: update rate limit configuration
- Update related test files
2026-01-21 12:02:46 +00:00

204 lines
6.8 KiB
Go

package api
import (
"net/http"
"strconv"
"strings"
"time"
)
// EndpointRateLimitConfig defines rate limiting configuration for different endpoint categories
type EndpointRateLimitConfig struct {
AuthEndpoints *RateLimiter // Login, logout, password change
ConfigEndpoints *RateLimiter // Node configuration changes
ExportEndpoints *RateLimiter // Export/import operations
RecoveryEndpoints *RateLimiter // Recovery operations
UpdateEndpoints *RateLimiter // Update checks and operations
WebSocketEndpoints *RateLimiter // WebSocket connections
GeneralAPI *RateLimiter // General API calls
PublicEndpoints *RateLimiter // Public endpoints (health, version)
}
var globalRateLimitConfig *EndpointRateLimitConfig
// InitializeRateLimiters sets up rate limiters for all endpoint categories
func InitializeRateLimiters() {
globalRateLimitConfig = &EndpointRateLimitConfig{
// Authentication endpoints: strict limits to prevent brute force
AuthEndpoints: NewRateLimiter(10, 1*time.Minute), // 10 attempts per minute
// Configuration changes: moderate limits
ConfigEndpoints: NewRateLimiter(30, 1*time.Minute), // 30 changes per minute
// Export/import: very strict limits
ExportEndpoints: NewRateLimiter(5, 5*time.Minute), // 5 exports per 5 minutes
// Recovery operations: extremely strict
RecoveryEndpoints: NewRateLimiter(3, 10*time.Minute), // 3 attempts per 10 minutes
// Update operations: allow frequent polling during updates
UpdateEndpoints: NewRateLimiter(60, 1*time.Minute), // 60 checks per minute (modal polls every 2s)
// WebSocket connections: per-connection limits
WebSocketEndpoints: NewRateLimiter(30, 1*time.Minute), // 30 new connections per minute
// General API: higher limits for normal operations
GeneralAPI: NewRateLimiter(500, 1*time.Minute), // 500 requests per minute
// Public endpoints: very high limits (health checks, etc.)
PublicEndpoints: NewRateLimiter(1000, 1*time.Minute), // 1000 requests per minute
}
}
// GetRateLimiterForEndpoint returns the appropriate rate limiter for a given endpoint
func GetRateLimiterForEndpoint(path string, method string) *RateLimiter {
if globalRateLimitConfig == nil {
InitializeRateLimiters()
}
// Normalize path
path = strings.ToLower(path)
// Authentication endpoints
if strings.Contains(path, "/api/login") ||
strings.Contains(path, "/api/logout") ||
strings.Contains(path, "/api/security/change-password") ||
strings.Contains(path, "/api/auth") {
return globalRateLimitConfig.AuthEndpoints
}
// Recovery endpoints
if strings.Contains(path, "/api/security/recovery") {
return globalRateLimitConfig.RecoveryEndpoints
}
// Export/Import endpoints
if strings.Contains(path, "/api/config/export") ||
strings.Contains(path, "/api/config/import") {
return globalRateLimitConfig.ExportEndpoints
}
// Configuration endpoints (write operations only)
if method != "GET" && (strings.Contains(path, "/api/config/nodes") ||
strings.Contains(path, "/api/config/system") ||
strings.Contains(path, "/api/config/webhooks") ||
strings.Contains(path, "/api/config/alerts")) {
return globalRateLimitConfig.ConfigEndpoints
}
// Configuration read endpoints get higher limits to prevent UI issues
if method == "GET" && (strings.Contains(path, "/api/config/") ||
strings.Contains(path, "/api/discover") ||
strings.Contains(path, "/api/security/status")) {
return globalRateLimitConfig.PublicEndpoints // Use higher limit for reads
}
// Update endpoints
if strings.Contains(path, "/api/updates") {
return globalRateLimitConfig.UpdateEndpoints
}
// WebSocket endpoints
if strings.Contains(path, "/ws") {
return globalRateLimitConfig.WebSocketEndpoints
}
// Public endpoints (no auth required)
if strings.Contains(path, "/api/health") ||
strings.Contains(path, "/api/version") ||
strings.Contains(path, "/api/security/status") ||
strings.Contains(path, "/api/security/validate-bootstrap-token") ||
strings.Contains(path, "/metrics") {
return globalRateLimitConfig.PublicEndpoints
}
// Default to general API rate limiter
return globalRateLimitConfig.GeneralAPI
}
// UniversalRateLimitMiddleware applies appropriate rate limiting to all endpoints
func UniversalRateLimitMiddleware(next http.Handler) http.Handler {
// Initialize rate limiters if not already done
if globalRateLimitConfig == nil {
InitializeRateLimiters()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip rate limiting for static assets
if !strings.HasPrefix(r.URL.Path, "/api") && !strings.HasPrefix(r.URL.Path, "/ws") {
next.ServeHTTP(w, r)
return
}
// Skip rate limiting for real-time data endpoints that are polled frequently
// These endpoints are essential for UI functionality and should not be rate limited
skipPaths := []string{
"/api/state", // Real-time state updates
"/api/guests/metadata", // Guest metadata (polled frequently)
"/ws", // WebSocket connections (authenticated, reconnections are normal)
}
for _, path := range skipPaths {
if strings.Contains(r.URL.Path, path) {
next.ServeHTTP(w, r)
return
}
}
// Extract client IP
ip := GetClientIP(r)
// Skip rate limiting only for direct loopback requests (no proxy headers)
if (ip == "127.0.0.1" || ip == "::1" || ip == "localhost") && isDirectLoopbackRequest(r) {
next.ServeHTTP(w, r)
return
}
// Get appropriate rate limiter for this endpoint
limiter := GetRateLimiterForEndpoint(r.URL.Path, r.Method)
// Check rate limit
if !limiter.Allow(ip) {
// Add retry-after header matching the limiter's actual window
w.Header().Set("Retry-After", strconv.Itoa(int(limiter.window.Seconds())))
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(limiter.limit))
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", time.Now().Add(limiter.window).Format(time.RFC3339))
http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
return
}
// Continue to next handler
next.ServeHTTP(w, r)
})
}
// ResetRateLimitForIP resets rate limit counters for a specific IP (use carefully)
func ResetRateLimitForIP(ip string) {
if globalRateLimitConfig == nil {
return
}
// Reset for all rate limiters
limiters := []*RateLimiter{
globalRateLimitConfig.AuthEndpoints,
globalRateLimitConfig.ConfigEndpoints,
globalRateLimitConfig.ExportEndpoints,
globalRateLimitConfig.RecoveryEndpoints,
globalRateLimitConfig.UpdateEndpoints,
globalRateLimitConfig.WebSocketEndpoints,
globalRateLimitConfig.GeneralAPI,
globalRateLimitConfig.PublicEndpoints,
}
for _, limiter := range limiters {
limiter.Reset(ip)
}
}
// Reset clears rate limit history for a specific IP
func (rl *RateLimiter) Reset(ip string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.attempts, ip)
}