package api

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/utils"
	"github.com/rs/zerolog/log"
)

// Security improvements for Pulse

// generateCSRFToken creates a new CSRF token for a session by delegating to
// the CSRF store.
func generateCSRFToken(sessionID string) string {
	return GetCSRFStore().GenerateCSRFToken(sessionID)
}

// validateCSRFToken reports whether token is valid for the given session
// according to the CSRF store.
func validateCSRFToken(sessionID, token string) bool {
	return GetCSRFStore().ValidateCSRFToken(sessionID, token)
}

// CheckCSRF validates the CSRF token for state-changing requests.
//
// It returns true when the request may proceed. Safe methods (GET, HEAD,
// OPTIONS), requests carrying an X-API-Token or Authorization header, and
// requests without a pulse_session cookie are exempt. Otherwise the token is
// read from the X-CSRF-Token header, falling back to the csrf_token form
// value, and validated against the session cookie value.
//
// When the token is missing or invalid, the pulse_csrf cookie is cleared, a
// replacement token is issued (cookie plus X-CSRF-Token response header) if
// one can be generated, and false is returned.
func CheckCSRF(w http.ResponseWriter, r *http.Request) bool {
	// Skip CSRF check for safe methods
	switch r.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
		return true
	}

	// Skip CSRF for API token auth (API clients don't have sessions)
	if r.Header.Get("X-API-Token") != "" {
		return true
	}

	// Skip CSRF for Basic Auth (doesn't use sessions, not vulnerable to CSRF)
	if r.Header.Get("Authorization") != "" {
		return true
	}

	// Get session from cookie
	cookie, err := r.Cookie("pulse_session")
	if err != nil {
		// No session cookie means no CSRF check needed
		// (either no auth configured or using basic auth which doesn't use sessions)
		return true
	}

	// logPrefix shortens a secret for logging. Both the session cookie and the
	// CSRF token are supplied by the client and may be shorter than the prefix
	// length, so the slice must be guarded to avoid an out-of-range panic.
	logPrefix := func(s string) string {
		const prefixLen = 8
		if len(s) > prefixLen {
			s = s[:prefixLen]
		}
		return s + "..."
	}

	// reject rotates the CSRF token so the client can retry, then fails the check.
	reject := func() bool {
		clearCSRFCookie(w)
		if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
			w.Header().Set("X-CSRF-Token", newToken)
		}
		return false
	}

	// Get CSRF token from header or form
	csrfToken := r.Header.Get("X-CSRF-Token")
	if csrfToken == "" {
		csrfToken = r.FormValue("csrf_token")
	}

	// No CSRF token means request is not eligible for mutation
	if csrfToken == "" {
		log.Warn().
			Str("path", r.URL.Path).
			Str("session", logPrefix(cookie.Value)).
			Msg("Missing CSRF token")
		return reject()
	}

	// Check if the CSRF token validates
	if !validateCSRFToken(cookie.Value, csrfToken) {
		log.Warn().
			Str("path", r.URL.Path).
			Str("session", logPrefix(cookie.Value)).
			Str("provided_token", logPrefix(csrfToken)).
			Msg("Invalid CSRF token")
		return reject()
	}

	return true
}

// clearCSRFCookie expires the pulse_csrf cookie on the response. It is a no-op
// when w is nil.
func clearCSRFCookie(w http.ResponseWriter) {
	if w == nil {
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "pulse_csrf",
		Value:    "",
		Path:     "/",
		MaxAge:   -1, // negative MaxAge tells the client to delete the cookie
		HttpOnly: false,
	})
}

// issueNewCSRFCookie generates a fresh CSRF token for sessionID, stores it in
// the pulse_csrf cookie (24 hour lifetime) and returns it. It returns "" without
// touching the response when w or r is nil or sessionID is blank.
func issueNewCSRFCookie(w http.ResponseWriter, r *http.Request, sessionID string) string {
	if w == nil || r == nil {
		return ""
	}
	if strings.TrimSpace(sessionID) == "" {
		return ""
	}

	newToken := generateCSRFToken(sessionID)
	secure, sameSite := getCookieSettings(r)

	http.SetCookie(w, &http.Cookie{
		Name:     "pulse_csrf",
		Value:    newToken,
		Path:     "/",
		Secure:   secure,
		SameSite: sameSite,
		MaxAge:   86400,
	})

	return newToken
}

// Rate Limiting - using existing RateLimiter from ratelimit.go
var (
	// Auth endpoints: 10 attempts per minute
	authLimiter = NewRateLimiter(10, 1*time.Minute)

	// General API: 500 requests per minute (increased for metadata endpoints)
	apiLimiter = NewRateLimiter(500, 1*time.Minute)
)

// GetClientIP extracts the client IP from the request.
//
// The address of the immediate peer (r.RemoteAddr) is returned unless that
// peer is a trusted proxy, in which case the first valid address in
// X-Forwarded-For, or else a valid X-Real-IP, is preferred. It returns "" when
// r.RemoteAddr is empty.
func GetClientIP(r *http.Request) string {
	rawRemoteIP := extractRemoteIP(r.RemoteAddr)
	if rawRemoteIP == "" {
		return ""
	}

	// Only trust proxy headers when the immediate peer is trusted.
	if !isTrustedProxyIP(rawRemoteIP) {
		return rawRemoteIP
	}

	if forwarded := firstValidForwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" {
		return forwarded
	}

	realIP := strings.TrimSpace(strings.Trim(r.Header.Get("X-Real-IP"), "[]"))
	if realIP != "" && net.ParseIP(realIP) != nil {
		return realIP
	}

	return rawRemoteIP
}

// Failed Login Tracking

// FailedLogin records the failed-login state kept for one identifier.
type FailedLogin struct {
	Count       int
	LastAttempt time.Time
	LockedUntil time.Time
}

var (
	failedLogins      = make(map[string]*FailedLogin)
	failedMu          sync.RWMutex
	maxFailedAttempts = 5
	lockoutDuration   = 15 * time.Minute

	trustedProxyOnce  sync.Once
	trustedProxyCIDRs []*net.IPNet
)

// loadTrustedProxyCIDRs populates trustedProxyCIDRs from the comma-separated
// PULSE_TRUSTED_PROXY_CIDRS environment variable. Entries containing "/" are
// parsed as CIDRs; other entries are parsed as single IPs and stored as
// /32 (IPv4) or /128 (IPv6) networks. Invalid entries are logged and skipped.
func loadTrustedProxyCIDRs() {
	raw := utils.GetenvTrim("PULSE_TRUSTED_PROXY_CIDRS")
	if raw == "" {
		return
	}

	for _, entry := range strings.Split(raw, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}

		if strings.Contains(entry, "/") {
			_, network, parseErr := net.ParseCIDR(entry)
			if parseErr != nil {
				log.Warn().
					Str("cidr", entry).
					Err(parseErr).
					Msg("Ignoring invalid CIDR in PULSE_TRUSTED_PROXY_CIDRS")
				continue
			}
			network.IP = network.IP.Mask(network.Mask)
			trustedProxyCIDRs = append(trustedProxyCIDRs, network)
			continue
		}

		ip := net.ParseIP(entry)
		if ip == nil {
			log.Warn().
				Str("value", entry).
				Msg("Ignoring invalid IP in PULSE_TRUSTED_PROXY_CIDRS")
			continue
		}

		// A bare address becomes a single-host network.
		bits := 32
		if ip.To4() == nil {
			bits = 128
		}
		mask := net.CIDRMask(bits, bits)
		trustedProxyCIDRs = append(trustedProxyCIDRs, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
	}
}

// extractRemoteIP returns the host part of remoteAddr with any surrounding
// square brackets removed. When remoteAddr is not in host:port form the whole
// value (minus brackets) is returned. An empty input yields "".
func extractRemoteIP(remoteAddr string) string {
	if remoteAddr == "" {
		return ""
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return strings.Trim(remoteAddr, "[]")
	}
	return strings.Trim(host, "[]")
}

// firstValidForwardedIP returns the first comma-separated element of header
// that parses as an IP address, or "" when there is none.
func firstValidForwardedIP(header string) string {
	if header == "" {
		return ""
	}

	for _, part := range strings.Split(header, ",") {
		candidate := strings.TrimSpace(strings.Trim(part, "[]"))
		if candidate != "" && net.ParseIP(candidate) != nil {
			return candidate
		}
	}

	return ""
}

// isTrustedProxyIP reports whether ipStr falls inside one of the networks
// configured through PULSE_TRUSTED_PROXY_CIDRS. The configuration is loaded
// once, on first use. Unparseable input or an empty configuration yields false.
func isTrustedProxyIP(ipStr string) bool {
	ipStr = strings.TrimSpace(strings.Trim(ipStr, "[]"))
	if ipStr == "" {
		return false
	}

	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}

	trustedProxyOnce.Do(loadTrustedProxyCIDRs)

	for _, network := range trustedProxyCIDRs {
		if network.Contains(ip) {
			return true
		}
	}

	return false
}

// isPrivateIP reports whether ip (optionally in host:port form) is a loopback,
// link-local, or private-range address. Unparseable input yields false.
func isPrivateIP(ip string) bool {
	host := extractRemoteIP(ip)
	if host == "" {
		return false
	}

	parsedIP := net.ParseIP(host)
	if parsedIP == nil {
		return false
	}

	if parsedIP.IsLoopback() || parsedIP.IsLinkLocalUnicast() || parsedIP.IsLinkLocalMulticast() {
		return true
	}

	privateRanges := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"127.0.0.0/8",
		"::1/128",
		"fc00::/7",
		"fe80::/10",
	}

	for _, cidr := range privateRanges {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			continue
		}
		if network.Contains(parsedIP) {
			return true
		}
	}

	return false
}

// isTrustedNetwork reports whether ip (optionally in host:port form) belongs
// to one of trustedNetworks, given as CIDR strings. When trustedNetworks is
// empty it falls back to isPrivateIP. Entries that fail to parse are skipped.
func isTrustedNetwork(ip string, trustedNetworks []string) bool {
	if len(trustedNetworks) == 0 {
		return isPrivateIP(ip)
	}

	host := extractRemoteIP(ip)
	if host == "" {
		return false
	}

	parsedIP := net.ParseIP(host)
	if parsedIP == nil {
		return false
	}

	for _, cidr := range trustedNetworks {
		_, network, err := net.ParseCIDR(strings.TrimSpace(cidr))
		if err != nil {
			continue
		}
		if network.Contains(parsedIP) {
			return true
		}
	}

	return false
}

// RecordFailedLogin tracks a failed login attempt for identifier. Once the
// attempt count reaches maxFailedAttempts the identifier is locked for
// lockoutDuration and a warning is logged.
//
// NOTE(review): the counter is not reset when a lockout expires, so the first
// failure after expiry re-locks immediately, while GetLockoutInfo reports an
// expired lockout as zero attempts. Confirm this is intended.
func RecordFailedLogin(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()

	failed, exists := failedLogins[identifier]
	if !exists {
		failed = &FailedLogin{}
		failedLogins[identifier] = failed
	}

	now := time.Now()
	failed.Count++
	failed.LastAttempt = now

	if failed.Count < maxFailedAttempts {
		return
	}

	failed.LockedUntil = now.Add(lockoutDuration)
	log.Warn().
		Str("identifier", identifier).
		Int("attempts", failed.Count).
		Time("locked_until", failed.LockedUntil).
		Msg("Account locked due to failed login attempts")
}

// ClearFailedLogins resets the failed login counter for identifier, e.g. on
// successful login.
func ClearFailedLogins(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()
	delete(failedLogins, identifier)
}

// IsLockedOut reports whether identifier is currently locked out: it has
// reached maxFailedAttempts and its lockout deadline has not yet passed.
func IsLockedOut(identifier string) bool {
	failedMu.RLock()
	defer failedMu.RUnlock()

	failed, exists := failedLogins[identifier]
	if !exists {
		return false
	}

	if time.Now().After(failed.LockedUntil) {
		// Lockout expired
		return false
	}

	return failed.Count >= maxFailedAttempts
}

// GetLockoutInfo returns lockout information for an identifier: the number of
// recorded failed attempts, the lockout deadline, and whether the lockout is
// currently active. Unknown identifiers and identifiers whose lockout has
// expired are reported as zero attempts, zero time, not locked.
func GetLockoutInfo(identifier string) (attempts int, lockedUntil time.Time, isLocked bool) {
	failedMu.RLock()
	defer failedMu.RUnlock()

	failed, exists := failedLogins[identifier]
	if !exists {
		return 0, time.Time{}, false
	}

	now := time.Now()
	reachedLimit := failed.Count >= maxFailedAttempts

	// Check if lockout has expired
	if reachedLimit && now.After(failed.LockedUntil) {
		// Lockout expired, treat as no attempts
		return 0, time.Time{}, false
	}

	isLocked = reachedLimit && now.Before(failed.LockedUntil)
	return failed.Count, failed.LockedUntil, isLocked
}

// ResetLockout manually resets lockout for an identifier (admin function) and
// logs the reset.
func ResetLockout(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()
	delete(failedLogins, identifier)

	log.Info().
		Str("identifier", identifier).
		Msg("Lockout manually reset")
}

// Security Headers Middleware

// SecurityHeaders wraps next with the default security headers, with iframe
// embedding denied.
func SecurityHeaders(next http.Handler) http.Handler {
	return SecurityHeadersWithConfig(next, false, "")
}

// SecurityHeadersWithConfig applies security headers with embedding configuration.
//
// When allowEmbedding is false, X-Frame-Options: DENY is sent and the CSP
// carries frame-ancestors 'none'. When it is true, X-Frame-Options is omitted
// and frame-ancestors is 'self' plus the comma-separated allowedOrigins, or *
// when allowedOrigins is empty.
func SecurityHeadersWithConfig(next http.Handler, allowEmbedding bool, allowedOrigins string) http.Handler {
	// The policy depends only on the configuration, so it is assembled once
	// here instead of on every request.
	cspDirectives := []string{
		"default-src 'self'",
		"script-src 'self' 'unsafe-inline' 'unsafe-eval'", // Needed for React
		"style-src 'self' 'unsafe-inline'",                // Needed for inline styles
		"img-src 'self' data: blob:",
		"connect-src 'self' ws: wss:", // WebSocket support
		"font-src 'self' data:",
	}

	// Add frame-ancestors based on embedding settings
	switch {
	case !allowEmbedding:
		// Deny all embedding
		cspDirectives = append(cspDirectives, "frame-ancestors 'none'")
	case allowedOrigins != "":
		// Parse comma-separated origins and add them to frame-ancestors
		var frameAncestors strings.Builder
		frameAncestors.WriteString("frame-ancestors 'self'")
		for _, origin := range strings.Split(allowedOrigins, ",") {
			if origin = strings.TrimSpace(origin); origin != "" {
				frameAncestors.WriteByte(' ')
				frameAncestors.WriteString(origin)
			}
		}
		cspDirectives = append(cspDirectives, frameAncestors.String())
	default:
		// Allow embedding from any origin (user explicitly enabled this)
		cspDirectives = append(cspDirectives, "frame-ancestors *")
	}

	csp := strings.Join(cspDirectives, "; ")

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()

		// Configure clickjacking protection based on embedding settings.
		// When embedding is allowed the header is deliberately left unset so
		// that frame-ancestors alone governs which origins may embed.
		if !allowEmbedding {
			header.Set("X-Frame-Options", "DENY")
		}

		// Prevent MIME type sniffing
		header.Set("X-Content-Type-Options", "nosniff")

		// Enable XSS protection
		header.Set("X-XSS-Protection", "1; mode=block")

		header.Set("Content-Security-Policy", csp)

		// Referrer Policy
		header.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// Permissions Policy (formerly Feature Policy)
		header.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		next.ServeHTTP(w, r)
	})
}

// Audit Logging

// AuditEvent describes a security-relevant event in JSON-serializable form.
type AuditEvent struct {
	Timestamp time.Time `json:"timestamp"`
	Event     string    `json:"event"`
	User      string    `json:"user,omitempty"`
	IP        string    `json:"ip"`
	Path      string    `json:"path,omitempty"`
	Success   bool      `json:"success"`
	Details   string    `json:"details,omitempty"`
}

// LogAuditEvent logs security-relevant events. Successful events are logged at
// info level; failed events at warn level with a "FAILED" message suffix.
func LogAuditEvent(event string, user string, ip string, path string, success bool, details string) {
	// Pick the level constructor and message first so the field list is
	// written only once.
	newEntry, message := log.Warn, "Security audit event - FAILED"
	if success {
		newEntry, message = log.Info, "Security audit event"
	}

	newEntry().
		Str("event", event).
		Str("user", user).
		Str("ip", ip).
		Str("path", path).
		Str("details", details).
		Time("timestamp", time.Now()).
		Msg(message)
}

// Session Management Improvements
var (
	allSessions = make(map[string][]string) // user -> []sessionIDs
	sessionsMu  sync.RWMutex
)

// TrackUserSession tracks which sessions belong to which user by appending
// sessionID to the user's session list.
func TrackUserSession(user, sessionID string) {
	sessionsMu.Lock()
	defer sessionsMu.Unlock()

	// append handles a missing (nil) entry, so no explicit initialisation is needed.
	allSessions[user] = append(allSessions[user], sessionID)
}

// GetSessionUsername returns the username associated with a session ID, or ""
// when the session is not tracked.
func GetSessionUsername(sessionID string) string {
	sessionsMu.RLock()
	defer sessionsMu.RUnlock()

	for user, sessions := range allSessions {
		for _, sid := range sessions {
			if sid == sessionID {
				return user
			}
		}
	}
	return ""
}

// InvalidateUserSessions invalidates all sessions for a user (e.g., on password
// change): each tracked session and its CSRF token are deleted from their
// stores, and the user's tracking entry is removed.
func InvalidateUserSessions(user string) {
	sessionsMu.Lock()
	defer sessionsMu.Unlock()

	sessionIDs := allSessions[user]
	for _, sid := range sessionIDs {
		// Delete from persistent session store
		GetSessionStore().DeleteSession(sid)
		// Delete CSRF tokens
		GetCSRFStore().DeleteCSRFToken(sid)
	}

	delete(allSessions, user)

	log.Info().
		Str("user", user).
		Int("sessions_invalidated", len(sessionIDs)).
		Msg("Invalidated all user sessions")
}