Pulse/internal/api/rate_limit_config.go
rcourtman e894bc7b1d Fix recurring update issues (related to #671)
This commit addresses three recurring issues with the update system:

1. **Checksum mismatches (v4.27.0, v4.28.0):**
   - Root cause: Release process uploads checksums.txt first, but if artifacts
     are rebuilt after that upload, checksums become stale
   - Fix: Update RELEASE_CHECKLIST.md to REQUIRE running validate-release.sh
     before publishing (step 9, non-negotiable)
   - The validation script exists and catches these errors, but wasn't being
     enforced in the release process

2. **Duplicate error modals:**
   - Root cause: UpdateProgressModal rendered in both App.tsx
     (GlobalUpdateProgressWatcher) and UpdateBanner.tsx
   - Fix: Remove UpdateProgressModal from UpdateBanner.tsx
   - GlobalUpdateProgressWatcher automatically shows the modal when updates
     start, so the banner's modal is redundant

3. **Rate limiting too strict:**
   - Root cause: UpdateProgressModal polls /api/updates/status every 2 seconds
     (30 req/min), but rate limit was 20/min
   - Fix: Increase UpdateEndpoints rate limit from 20/min to 60/min
   - Allows modal to poll without hitting rate limits during updates

Apart from the duplicate modal (a duplicated render in the frontend), these were
manual process errors and configuration issues rather than code bugs.
Enforcing the validation script in the release checklist prevents future checksum mismatches.
2025-11-11 09:09:30 +00:00

202 lines
6.5 KiB
Go

package api
import (
"net/http"
"strconv"
"strings"
"time"
)
// EndpointRateLimitConfig groups one rate limiter per endpoint category.
// Each field holds its own *RateLimiter, so traffic in one category is tracked
// separately from traffic in the others.
type EndpointRateLimitConfig struct {
	// AuthEndpoints limits login, logout and password-change requests.
	AuthEndpoints *RateLimiter

	// ConfigEndpoints limits node configuration changes.
	ConfigEndpoints *RateLimiter

	// ExportEndpoints limits configuration export/import operations.
	ExportEndpoints *RateLimiter

	// RecoveryEndpoints limits recovery operations.
	RecoveryEndpoints *RateLimiter

	// UpdateEndpoints limits update checks and update operations.
	UpdateEndpoints *RateLimiter

	// WebSocketEndpoints limits new WebSocket connections.
	WebSocketEndpoints *RateLimiter

	// GeneralAPI is the fallback limiter for ordinary API calls.
	GeneralAPI *RateLimiter

	// PublicEndpoints limits public endpoints such as health and version.
	PublicEndpoints *RateLimiter
}
// globalRateLimitConfig is the package-wide limiter set. It stays nil until
// InitializeRateLimiters runs.
var globalRateLimitConfig *EndpointRateLimitConfig

// InitializeRateLimiters creates a fresh limiter for every endpoint category and
// installs the resulting set as globalRateLimitConfig, replacing any previous set.
// Each NewRateLimiter call takes (maximum requests, time window).
func InitializeRateLimiters() {
	const (
		perMinute  = time.Minute
		perFiveMin = 5 * time.Minute
		perTenMin  = 10 * time.Minute
	)

	cfg := &EndpointRateLimitConfig{}

	// Authentication: strict, to slow down brute-force attempts (10/min).
	cfg.AuthEndpoints = NewRateLimiter(10, perMinute)
	// Configuration changes: moderate (30/min).
	cfg.ConfigEndpoints = NewRateLimiter(30, perMinute)
	// Export/import: very strict (5 per 5 minutes).
	cfg.ExportEndpoints = NewRateLimiter(5, perFiveMin)
	// Recovery: extremely strict (3 per 10 minutes).
	cfg.RecoveryEndpoints = NewRateLimiter(3, perTenMin)
	// Updates: 60/min leaves headroom for the progress modal, which polls
	// every 2 seconds (30/min) while an update runs.
	cfg.UpdateEndpoints = NewRateLimiter(60, perMinute)
	// WebSocket: 30 new connections per minute.
	cfg.WebSocketEndpoints = NewRateLimiter(30, perMinute)
	// General API: generous limit for normal operations (500/min).
	cfg.GeneralAPI = NewRateLimiter(500, perMinute)
	// Public endpoints (health checks, etc.): very high limit (1000/min).
	cfg.PublicEndpoints = NewRateLimiter(1000, perMinute)

	globalRateLimitConfig = cfg
}
// GetRateLimiterForEndpoint selects the limiter that governs a request to path
// made with the given HTTP method.
//
// The path is lower-cased and tested for substrings in priority order; the
// first matching category wins, and GeneralAPI is the fallback. GET requests to
// configuration, discovery and security-status endpoints go to the high-limit
// PublicEndpoints bucket so UI reads are not throttled like writes. The limiters
// are created on first use if InitializeRateLimiters has not run yet.
func GetRateLimiterForEndpoint(path string, method string) *RateLimiter {
	if globalRateLimitConfig == nil {
		InitializeRateLimiters()
	}
	cfg := globalRateLimitConfig

	normalized := strings.ToLower(path)
	// matches reports whether the normalized path contains any of fragments.
	matches := func(fragments ...string) bool {
		for _, fragment := range fragments {
			if strings.Contains(normalized, fragment) {
				return true
			}
		}
		return false
	}
	isRead := method == "GET"

	switch {
	case matches("/api/login", "/api/logout", "/api/security/change-password", "/api/auth"):
		return cfg.AuthEndpoints
	case matches("/api/security/recovery"):
		return cfg.RecoveryEndpoints
	case matches("/api/config/export", "/api/config/import"):
		return cfg.ExportEndpoints
	case !isRead && matches("/api/config/nodes", "/api/config/system", "/api/config/webhooks", "/api/config/alerts"):
		// Configuration writes.
		return cfg.ConfigEndpoints
	case isRead && matches("/api/config/", "/api/discover", "/api/security/status"):
		// Configuration reads use the higher public limit to avoid UI issues.
		return cfg.PublicEndpoints
	case matches("/api/updates"):
		return cfg.UpdateEndpoints
	case matches("/ws"):
		return cfg.WebSocketEndpoints
	case matches("/api/health", "/api/version", "/api/security/status", "/metrics"):
		// Public endpoints that require no authentication.
		return cfg.PublicEndpoints
	default:
		return cfg.GeneralAPI
	}
}
// UniversalRateLimitMiddleware wraps next with per-client-IP rate limiting for
// API and WebSocket routes.
//
// A request bypasses rate limiting entirely when:
//   - its path starts with neither "/api" nor "/ws" (static assets);
//   - its path contains one of the frequently polled real-time endpoints
//     listed in skipPaths;
//   - the client IP is a loopback address and isDirectLoopbackRequest reports
//     a direct (non-proxied) request.
//
// Otherwise the limiter chosen by GetRateLimiterForEndpoint is consulted. When
// that limiter rejects the request, the response is 429 Too Many Requests with
// Retry-After and X-RateLimit-* headers derived from the limiter's own limit
// and window.
func UniversalRateLimitMiddleware(next http.Handler) http.Handler {
	// Initialize rate limiters if not already done.
	if globalRateLimitConfig == nil {
		InitializeRateLimiters()
	}

	// Real-time data endpoints the UI polls frequently. They are essential for
	// UI functionality and are never rate limited. The list never changes, so it
	// is built once here rather than on every request.
	skipPaths := []string{
		"/api/state",           // Real-time state updates
		"/api/guests/metadata", // Guest metadata (polled frequently)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip rate limiting for static assets.
		if !strings.HasPrefix(r.URL.Path, "/api") && !strings.HasPrefix(r.URL.Path, "/ws") {
			next.ServeHTTP(w, r)
			return
		}

		for _, path := range skipPaths {
			if strings.Contains(r.URL.Path, path) {
				next.ServeHTTP(w, r)
				return
			}
		}

		ip := GetClientIP(r)

		// Skip rate limiting only for direct loopback requests (no proxy headers).
		if (ip == "127.0.0.1" || ip == "::1" || ip == "localhost") && isDirectLoopbackRequest(r) {
			next.ServeHTTP(w, r)
			return
		}

		limiter := GetRateLimiterForEndpoint(r.URL.Path, r.Method)
		if !limiter.Allow(ip) {
			// Retry-After must be whole seconds. Take it from this limiter's
			// window (rounded up, minimum 1) so it agrees with X-RateLimit-Reset.
			// A fixed 60s would understate the wait for the 5- and 10-minute
			// export/recovery limiters.
			retryAfter := int64((limiter.window + time.Second - 1) / time.Second)
			if retryAfter < 1 {
				retryAfter = 1
			}
			w.Header().Set("Retry-After", strconv.FormatInt(retryAfter, 10))
			w.Header().Set("X-RateLimit-Limit", strconv.Itoa(limiter.limit))
			w.Header().Set("X-RateLimit-Remaining", "0")
			w.Header().Set("X-RateLimit-Reset", time.Now().Add(limiter.window).Format(time.RFC3339))
			http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
			return
		}

		// Continue to next handler.
		next.ServeHTTP(w, r)
	})
}
// ResetRateLimitForIP calls Reset(ip) on the limiter of every endpoint
// category, clearing that address's recorded history everywhere. Use with care.
// It is a no-op when the limiters have not been initialized yet.
func ResetRateLimitForIP(ip string) {
	cfg := globalRateLimitConfig
	if cfg == nil {
		return
	}

	for _, rl := range [...]*RateLimiter{
		cfg.AuthEndpoints,
		cfg.ConfigEndpoints,
		cfg.ExportEndpoints,
		cfg.RecoveryEndpoints,
		cfg.UpdateEndpoints,
		cfg.WebSocketEndpoints,
		cfg.GeneralAPI,
		cfg.PublicEndpoints,
	} {
		rl.Reset(ip)
	}
}
// Reset discards every attempt recorded for ip by removing its entry from
// rl.attempts. The removal happens while rl.mu is held.
func (rl *RateLimiter) Reset(ip string) {
	rl.mu.Lock()
	delete(rl.attempts, ip)
	rl.mu.Unlock()
}