diff --git a/internal/api/rate_limit_config_test.go b/internal/api/rate_limit_config_test.go index 9dc5ff93b..4ed15812f 100644 --- a/internal/api/rate_limit_config_test.go +++ b/internal/api/rate_limit_config_test.go @@ -7,6 +7,373 @@ import ( "testing" ) +func TestGetRateLimiterForEndpoint(t *testing.T) { + // Ensure rate limiters are initialized + InitializeRateLimiters() + + tests := []struct { + name string + path string + method string + wantLimiterPtr **RateLimiter + wantLimiterNm string + }{ + // Authentication endpoints + { + name: "login endpoint POST", + path: "/api/login", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + { + name: "login endpoint GET", + path: "/api/login", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + { + name: "logout endpoint", + path: "/api/logout", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + { + name: "change password endpoint", + path: "/api/security/change-password", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + { + name: "generic auth endpoint", + path: "/api/auth/something", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + { + name: "login with uppercase path still matches (lowercased)", + path: "/API/LOGIN", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.AuthEndpoints, + wantLimiterNm: "AuthEndpoints", + }, + + // Recovery endpoints + { + name: "recovery endpoint POST", + path: "/api/security/recovery", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.RecoveryEndpoints, + wantLimiterNm: "RecoveryEndpoints", + }, + { + name: "recovery endpoint GET", + path: "/api/security/recovery/status", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.RecoveryEndpoints, + wantLimiterNm: "RecoveryEndpoints", + }, + + // Export/Import endpoints + { + name: "config export endpoint", + path: "/api/config/export", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.ExportEndpoints, + wantLimiterNm: "ExportEndpoints", + }, + { + name: "config import endpoint", + path: "/api/config/import", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.ExportEndpoints, + wantLimiterNm: "ExportEndpoints", + }, + + // Configuration endpoints (write operations) + { + name: "config nodes POST", + path: "/api/config/nodes", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + { + name: "config nodes PUT", + path: "/api/config/nodes/123", + method: http.MethodPut, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + { + name: "config nodes DELETE", + path: "/api/config/nodes/123", + method: http.MethodDelete, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + { + name: "config system POST", + path: "/api/config/system", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + { + name: "config webhooks POST", + path: "/api/config/webhooks", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + { + name: "config alerts POST", + path: "/api/config/alerts", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.ConfigEndpoints, + wantLimiterNm: "ConfigEndpoints", + }, + + // Configuration endpoints (read operations - higher limits) + { + name: "config nodes GET", + path: "/api/config/nodes", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "config system GET", + path: "/api/config/system", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "discover GET", + path: "/api/discover", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "security status GET", + path: "/api/security/status", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + + // Update endpoints + { + name: "updates check GET", + path: "/api/updates", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.UpdateEndpoints, + wantLimiterNm: "UpdateEndpoints", + }, + { + name: "updates status GET", + path: "/api/updates/status", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.UpdateEndpoints, + wantLimiterNm: "UpdateEndpoints", + }, + + // WebSocket endpoints + { + name: "ws endpoint", + path: "/ws", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.WebSocketEndpoints, + wantLimiterNm: "WebSocketEndpoints", + }, + { + name: "ws subpath", + path: "/ws/events", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.WebSocketEndpoints, + wantLimiterNm: "WebSocketEndpoints", + }, + + // Public endpoints + { + name: "health endpoint", + path: "/api/health", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "version endpoint", + path: "/api/version", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "validate bootstrap token", + path: "/api/security/validate-bootstrap-token", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "temperature proxy authorized nodes", + path: "/api/temperature-proxy/authorized-nodes", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + { + name: "metrics endpoint", + path: "/metrics", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.PublicEndpoints, + wantLimiterNm: "PublicEndpoints", + }, + + // General API (default) + { + name: "guests endpoint", + path: "/api/guests", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.GeneralAPI, + wantLimiterNm: "GeneralAPI", + }, + { + name: "nodes endpoint", + path: "/api/nodes", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.GeneralAPI, + wantLimiterNm: "GeneralAPI", + }, + { + name: "unknown api endpoint", + path: "/api/unknown/endpoint", + method: http.MethodGet, + wantLimiterPtr: &globalRateLimitConfig.GeneralAPI, + wantLimiterNm: "GeneralAPI", + }, + { + name: "containers endpoint POST", + path: "/api/containers/start", + method: http.MethodPost, + wantLimiterPtr: &globalRateLimitConfig.GeneralAPI, + wantLimiterNm: "GeneralAPI", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetRateLimiterForEndpoint(tt.path, tt.method) + want := *tt.wantLimiterPtr + if got != want { + t.Errorf("GetRateLimiterForEndpoint(%q, %q) returned %s limiter, want %s", + tt.path, tt.method, identifyLimiter(got), tt.wantLimiterNm) + } + }) + } +} + +// identifyLimiter returns a string name for a rate limiter pointer for test output +func identifyLimiter(rl *RateLimiter) string { + if rl == nil { + return "nil" + } + if globalRateLimitConfig == nil { + return "unknown (config nil)" + } + switch rl { + case globalRateLimitConfig.AuthEndpoints: + return "AuthEndpoints" + case globalRateLimitConfig.ConfigEndpoints: + return "ConfigEndpoints" + case globalRateLimitConfig.ExportEndpoints: + return "ExportEndpoints" + case globalRateLimitConfig.RecoveryEndpoints: + return "RecoveryEndpoints" + case globalRateLimitConfig.UpdateEndpoints: + return "UpdateEndpoints" + case globalRateLimitConfig.WebSocketEndpoints: + return "WebSocketEndpoints" + case globalRateLimitConfig.GeneralAPI: + return "GeneralAPI" + case globalRateLimitConfig.PublicEndpoints: + return "PublicEndpoints" + default: + return "unknown" + } +} + +func TestGetRateLimiterForEndpoint_PriorityOrder(t *testing.T) { + // Test that more specific patterns match before general ones + InitializeRateLimiters() + + tests := []struct { + name string + path string + method string + wantLimiterNm string + reason string + }{ + { + name: "recovery takes priority over security status", + path: "/api/security/recovery", + method: http.MethodPost, + wantLimiterNm: "RecoveryEndpoints", + reason: "recovery paths should match before generic security paths", + }, + { + name: "export takes priority over config read", + path: "/api/config/export", + method: http.MethodGet, + wantLimiterNm: "ExportEndpoints", + reason: "export paths should match before generic config read paths", + }, + { + name: "auth path takes priority over other paths", + path: "/api/auth/callback", + method: http.MethodGet, + wantLimiterNm: "AuthEndpoints", + reason: "auth paths should be rate limited strictly", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetRateLimiterForEndpoint(tt.path, tt.method) + gotName := identifyLimiter(got) + if gotName != tt.wantLimiterNm { + t.Errorf("GetRateLimiterForEndpoint(%q, %q) = %s, want %s; %s", + tt.path, tt.method, gotName, tt.wantLimiterNm, tt.reason) + } + }) + } +} + +func TestGetRateLimiterForEndpoint_InitializesIfNeeded(t *testing.T) { + // Save and restore global state + saved := globalRateLimitConfig + globalRateLimitConfig = nil + t.Cleanup(func() { + globalRateLimitConfig = saved + }) + + // Call GetRateLimiterForEndpoint with nil config - should initialize + got := GetRateLimiterForEndpoint("/api/login", http.MethodPost) + if got == nil { + t.Fatal("GetRateLimiterForEndpoint returned nil after initialization") + } + if globalRateLimitConfig == nil { + t.Fatal("globalRateLimitConfig should have been initialized") + } +} + func TestUniversalRateLimitMiddleware_HeaderFormat(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK)