package api

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/utils"
	"github.com/rcourtman/pulse-go-rewrite/pkg/audit"
	"github.com/rs/zerolog/log"
)

// Security improvements for Pulse

// generateCSRFToken creates a new CSRF token for a session using the shared
// CSRF store.
func generateCSRFToken(sessionID string) string {
	return GetCSRFStore().GenerateCSRFToken(sessionID)
}

// validateCSRFToken reports whether token is a valid CSRF token for sessionID
// according to the shared CSRF store.
func validateCSRFToken(sessionID, token string) bool {
	return GetCSRFStore().ValidateCSRFToken(sessionID, token)
}

// CheckCSRF validates the CSRF token for state-changing requests and reports
// whether the request may proceed.
//
// The check is skipped (returns true) for GET/HEAD/OPTIONS, for requests that
// carry an X-API-Token or Authorization header, and for requests without a
// pulse_session cookie. Otherwise the token is read from the X-CSRF-Token
// header, falling back to the csrf_token form value, and validated against the
// session.
//
// When the token is missing or invalid, CheckCSRF clears the pulse_csrf cookie,
// issues a fresh token (as a cookie and in the X-CSRF-Token response header)
// so the client can retry, and returns false. Writing the error response is
// left to the caller.
func CheckCSRF(w http.ResponseWriter, r *http.Request) bool {
	switch r.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
		// Safe methods are not subject to CSRF checks.
		return true
	}

	// API clients authenticate with a token rather than a session.
	if r.Header.Get("X-API-Token") != "" {
		log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: API token auth")
		return true
	}

	// Header-based auth (e.g. Basic Auth) does not use the session cookie.
	// NOTE(review): browsers can replay cached Basic Auth credentials on
	// cross-site requests; this skip assumes such clients are non-browser.
	if r.Header.Get("Authorization") != "" {
		log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: Basic Auth header present")
		return true
	}

	// No session cookie means there is no session to protect (either auth is
	// not configured or a sessionless scheme is in use).
	cookie, err := r.Cookie("pulse_session")
	if err != nil {
		log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: no session cookie")
		return true
	}

	csrfToken := r.Header.Get("X-CSRF-Token")
	if csrfToken == "" {
		csrfToken = r.FormValue("csrf_token")
	}

	sessionPrefix := safePrefixForLog(cookie.Value, 8) + "..."

	log.Debug().
		Str("path", r.URL.Path).
		Str("method", r.Method).
		Str("session", sessionPrefix).
		Bool("has_csrf_token", csrfToken != "").
		Msg("CSRF validation attempt")

	// reissue replaces the client's CSRF cookie with a fresh token after a
	// rejection so the client can retry; reason is "missing" or "invalid".
	reissue := func(reason string) {
		clearCSRFCookie(w, r)
		if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
			w.Header().Set("X-CSRF-Token", newToken)
			log.Debug().
				Str("new_token", safePrefixForLog(newToken, 8)+"...").
				Msg("Issued new CSRF token after " + reason)
		}
	}

	if csrfToken == "" {
		log.Warn().
			Str("path", r.URL.Path).
			Str("session", sessionPrefix).
			Msg("Missing CSRF token")
		reissue("missing")
		return false
	}

	if !validateCSRFToken(cookie.Value, csrfToken) {
		log.Warn().
			Str("path", r.URL.Path).
			Str("session", sessionPrefix).
			Str("provided_token", safePrefixForLog(csrfToken, 8)+"...").
			Msg("Invalid CSRF token")
		reissue("invalid")
		return false
	}

	log.Debug().
		Str("path", r.URL.Path).
		Str("session", sessionPrefix).
		Msg("CSRF validation successful")
	return true
}

// clearCSRFCookie expires the pulse_csrf cookie on the client. It is a no-op
// when w is nil.
func clearCSRFCookie(w http.ResponseWriter, r *http.Request) {
	if w == nil {
		return
	}
	secure, sameSite := getCookieSettings(r)
	http.SetCookie(w, &http.Cookie{
		Name:     "pulse_csrf",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: false,
		Secure:   secure,
		SameSite: sameSite,
	})
}

// issueNewCSRFCookie generates a new CSRF token for sessionID and sets it as
// the pulse_csrf cookie (valid for 86400 seconds). It returns the new token,
// or "" when w or r is nil or sessionID is blank.
func issueNewCSRFCookie(w http.ResponseWriter, r *http.Request, sessionID string) string {
	if w == nil || r == nil {
		return ""
	}
	if strings.TrimSpace(sessionID) == "" {
		return ""
	}
	newToken := generateCSRFToken(sessionID)
	secure, sameSite := getCookieSettings(r)
	http.SetCookie(w, &http.Cookie{
		Name:     "pulse_csrf",
		Value:    newToken,
		Path:     "/",
		Secure:   secure,
		SameSite: sameSite,
		MaxAge:   86400,
	})
	return newToken
}

// Rate Limiting - using existing RateLimiter from ratelimit.go
var (
	// Auth endpoints: 10 attempts per minute
	authLimiter = NewRateLimiter(10, 1*time.Minute)
)

// GetClientIP extracts the client IP from the request.
//
// Proxy headers are honoured only when the immediate peer (RemoteAddr) is a
// trusted proxy (PULSE_TRUSTED_PROXY_CIDRS). X-Forwarded-For is then walked
// from right to left, skipping trusted proxies, so entries a client prepends
// to the header cannot spoof its address. X-Real-IP is used as a fallback.
// Otherwise the peer address is returned; "" is returned when RemoteAddr is
// empty.
func GetClientIP(r *http.Request) string {
	rawRemoteIP := extractRemoteIP(r.RemoteAddr)
	if rawRemoteIP == "" {
		return ""
	}

	// Only trust proxy headers when the immediate peer is trusted.
	if isTrustedProxyIP(rawRemoteIP) {
		if forwarded := clientIPFromForwardedFor(r.Header.Get("X-Forwarded-For")); forwarded != "" {
			return forwarded
		}
		if realIP := strings.TrimSpace(strings.Trim(r.Header.Get("X-Real-IP"), "[]")); realIP != "" && net.ParseIP(realIP) != nil {
			return realIP
		}
	}
	return rawRemoteIP
}

// clientIPFromForwardedFor returns the rightmost valid X-Forwarded-For entry
// that is not a trusted proxy. Each proxy appends the address it received the
// request from, so the rightmost untrusted entry was written by a trusted
// proxy, whereas leftmost entries may be supplied by the client itself. If
// every valid entry is a trusted proxy, the leftmost valid entry is returned.
// "" means the header contains no valid IP.
func clientIPFromForwardedFor(header string) string {
	if header == "" {
		return ""
	}
	parts := strings.Split(header, ",")
	leftmostTrusted := ""
	for i := len(parts) - 1; i >= 0; i-- {
		candidate := strings.TrimSpace(strings.Trim(strings.TrimSpace(parts[i]), "[]"))
		if candidate == "" || net.ParseIP(candidate) == nil {
			continue
		}
		if !isTrustedProxyIP(candidate) {
			return candidate
		}
		leftmostTrusted = candidate
	}
	return leftmostTrusted
}

// Failed Login Tracking

// FailedLogin records failed login attempts for one identifier.
type FailedLogin struct {
	Count       int
	LastAttempt time.Time
	LockedUntil time.Time
}

var (
	failedLogins      = make(map[string]*FailedLogin)
	failedMu          sync.RWMutex
	maxFailedAttempts = 5
	lockoutDuration   = 15 * time.Minute
	trustedProxyOnce  sync.Once
	trustedProxyCIDRs []*net.IPNet
)

// loadTrustedProxyCIDRs parses PULSE_TRUSTED_PROXY_CIDRS, a comma-separated
// list of CIDRs and/or bare IPs, into trustedProxyCIDRs. Bare IPs become a
// single-host network (/32 for IPv4, /128 for IPv6). Invalid entries are
// logged and skipped.
func loadTrustedProxyCIDRs() {
	raw := utils.GetenvTrim("PULSE_TRUSTED_PROXY_CIDRS")
	if raw == "" {
		return
	}

	for _, entry := range strings.Split(raw, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}

		if strings.Contains(entry, "/") {
			_, network, parseErr := net.ParseCIDR(entry)
			if parseErr == nil {
				network.IP = network.IP.Mask(network.Mask)
				trustedProxyCIDRs = append(trustedProxyCIDRs, network)
				continue
			}
			log.Warn().
				Str("cidr", entry).
				Err(parseErr).
				Msg("Ignoring invalid CIDR in PULSE_TRUSTED_PROXY_CIDRS")
			continue
		}

		ip := net.ParseIP(entry)
		if ip == nil {
			log.Warn().
				Str("value", entry).
				Msg("Ignoring invalid IP in PULSE_TRUSTED_PROXY_CIDRS")
			continue
		}

		bits := 32
		if ip.To4() == nil {
			bits = 128
		}
		mask := net.CIDRMask(bits, bits)
		network := &net.IPNet{IP: ip.Mask(mask), Mask: mask}
		trustedProxyCIDRs = append(trustedProxyCIDRs, network)
	}
}

// extractRemoteIP strips the port (if any) and IPv6 brackets from a
// host:port style address. It returns "" for an empty input.
func extractRemoteIP(remoteAddr string) string {
	if remoteAddr == "" {
		return ""
	}
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return strings.Trim(host, "[]")
	}
	return strings.Trim(remoteAddr, "[]")
}

// firstValidForwardedIP returns the leftmost entry of a comma-separated
// X-Forwarded-For value that parses as an IP, or "" if none does.
func firstValidForwardedIP(header string) string {
	if header == "" {
		return ""
	}
	for _, part := range strings.Split(header, ",") {
		part = strings.TrimSpace(strings.Trim(part, "[]"))
		if part == "" {
			continue
		}
		if net.ParseIP(part) != nil {
			return part
		}
	}
	return ""
}

// isTrustedProxyIP reports whether ipStr falls inside one of the networks in
// PULSE_TRUSTED_PROXY_CIDRS. The list is loaded lazily, once. It returns false
// when no trusted proxies are configured or ipStr is not a valid IP.
func isTrustedProxyIP(ipStr string) bool {
	ipStr = strings.TrimSpace(strings.Trim(ipStr, "[]"))
	if ipStr == "" {
		return false
	}
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}

	trustedProxyOnce.Do(loadTrustedProxyCIDRs)
	if len(trustedProxyCIDRs) == 0 {
		return false
	}

	for _, network := range trustedProxyCIDRs {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}

// isPrivateIP reports whether ip (optionally host:port) is a loopback,
// link-local, or private (RFC 1918 / RFC 4193 unique-local) address.
func isPrivateIP(ip string) bool {
	host := extractRemoteIP(ip)
	if host == "" {
		return false
	}
	parsedIP := net.ParseIP(host)
	if parsedIP == nil {
		return false
	}
	// IsLoopback covers 127.0.0.0/8 and ::1; IsLinkLocalUnicast covers
	// 169.254.0.0/16 and fe80::/10; IsPrivate covers 10.0.0.0/8,
	// 172.16.0.0/12, 192.168.0.0/16 and fc00::/7.
	return parsedIP.IsLoopback() ||
		parsedIP.IsLinkLocalUnicast() ||
		parsedIP.IsLinkLocalMulticast() ||
		parsedIP.IsPrivate()
}

// isTrustedNetwork reports whether ip (optionally host:port) lies within one
// of trustedNetworks (CIDR strings; invalid entries are ignored). With an
// empty list it falls back to isPrivateIP.
func isTrustedNetwork(ip string, trustedNetworks []string) bool {
	if len(trustedNetworks) == 0 {
		return isPrivateIP(ip)
	}

	host := extractRemoteIP(ip)
	if host == "" {
		return false
	}
	parsedIP := net.ParseIP(host)
	if parsedIP == nil {
		return false
	}

	for _, cidr := range trustedNetworks {
		_, network, err := net.ParseCIDR(strings.TrimSpace(cidr))
		if err != nil {
			continue
		}
		if network.Contains(parsedIP) {
			return true
		}
	}
	return false
}

// RecordFailedLogin tracks a failed login attempt for identifier and locks it
// for lockoutDuration once maxFailedAttempts is reached. Failures while
// already locked extend the lockout. Once a lockout has expired, the next
// failure starts a fresh count, matching GetLockoutInfo, which reports an
// expired lockout as zero attempts.
func RecordFailedLogin(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()

	now := time.Now()
	failed, exists := failedLogins[identifier]
	if !exists || (failed.Count >= maxFailedAttempts && now.After(failed.LockedUntil)) {
		// First failure, or the previous lockout has expired: without this
		// reset a single failure after expiry would re-lock immediately.
		failed = &FailedLogin{}
		failedLogins[identifier] = failed
	}

	failed.Count++
	failed.LastAttempt = now

	if failed.Count >= maxFailedAttempts {
		failed.LockedUntil = now.Add(lockoutDuration)
		log.Warn().
			Str("identifier", identifier).
			Int("attempts", failed.Count).
			Time("locked_until", failed.LockedUntil).
			Msg("Account locked due to failed login attempts")
	}
}

// ClearFailedLogins resets the failed login counter on successful login.
func ClearFailedLogins(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()
	delete(failedLogins, identifier)
}

// GetLockoutInfo returns the failed-attempt count, lockout end time and
// locked state for identifier. An expired lockout is reported as zero
// attempts and not locked.
func GetLockoutInfo(identifier string) (attempts int, lockedUntil time.Time, isLocked bool) {
	failedMu.RLock()
	defer failedMu.RUnlock()

	failed, exists := failedLogins[identifier]
	if !exists {
		return 0, time.Time{}, false
	}

	now := time.Now()
	// Lockout expired: treat as no attempts.
	if now.After(failed.LockedUntil) && failed.Count >= maxFailedAttempts {
		return 0, time.Time{}, false
	}

	isLocked = failed.Count >= maxFailedAttempts && now.Before(failed.LockedUntil)
	return failed.Count, failed.LockedUntil, isLocked
}

// ResetLockout manually resets the lockout for an identifier (admin function).
func ResetLockout(identifier string) {
	failedMu.Lock()
	defer failedMu.Unlock()
	delete(failedLogins, identifier)

	log.Info().
		Str("identifier", identifier).
		Msg("Lockout manually reset")
}

// SecurityHeadersWithConfig wraps next so that every response carries the
// standard security headers. When allowEmbedding is false, framing is denied
// (X-Frame-Options: DENY and CSP frame-ancestors 'none'). When it is true,
// frame-ancestors is 'self' plus the comma-separated allowedOrigins, or * if
// allowedOrigins is empty.
func SecurityHeadersWithConfig(next http.Handler, allowEmbedding bool, allowedOrigins string) http.Handler {
	// The CSP depends only on configuration, so build it once instead of on
	// every request.
	cspDirectives := []string{
		"default-src 'self'",
		"script-src 'self' 'unsafe-inline' 'unsafe-eval'", // Needed for React
		"style-src 'self' 'unsafe-inline'",                // Needed for inline styles
		"img-src 'self' data: blob:",
		"connect-src 'self' ws: wss:", // WebSocket support
		"font-src 'self' data:",
	}

	switch {
	case !allowEmbedding:
		// Deny all embedding.
		cspDirectives = append(cspDirectives, "frame-ancestors 'none'")
	case allowedOrigins != "":
		ancestors := []string{"frame-ancestors 'self'"}
		for _, origin := range strings.Split(allowedOrigins, ",") {
			if origin = strings.TrimSpace(origin); origin != "" {
				ancestors = append(ancestors, origin)
			}
		}
		cspDirectives = append(cspDirectives, strings.Join(ancestors, " "))
	default:
		// Allow embedding from any origin (user explicitly enabled this).
		cspDirectives = append(cspDirectives, "frame-ancestors *")
	}
	csp := strings.Join(cspDirectives, "; ")

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// When embedding is allowed, X-Frame-Options is deliberately omitted
		// and framing is governed by CSP frame-ancestors alone.
		if !allowEmbedding {
			h.Set("X-Frame-Options", "DENY")
		}

		// Prevent MIME type sniffing.
		h.Set("X-Content-Type-Options", "nosniff")

		// NOTE(review): current OWASP/MDN guidance prefers "0" or omitting
		// this legacy header; kept unchanged for compatibility.
		h.Set("X-XSS-Protection", "1; mode=block")

		h.Set("Content-Security-Policy", csp)
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// Permissions Policy (formerly Feature Policy).
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")

		next.ServeHTTP(w, r)
	})
}

// Audit Logging

// AuditEvent describes a security-relevant event.
type AuditEvent struct {
	Timestamp time.Time `json:"timestamp"`
	Event     string    `json:"event"`
	User      string    `json:"user,omitempty"`
	IP        string    `json:"ip"`
	Path      string    `json:"path,omitempty"`
	Success   bool      `json:"success"`
	Details   string    `json:"details,omitempty"`
}

// LogAuditEvent logs security-relevant events using the audit package.
// This function delegates to the configured audit.Logger, allowing enterprise
// versions to provide persistent storage and signing.
// For tenant-aware logging, use LogAuditEventForTenant instead.
func LogAuditEvent(event string, user string, ip string, path string, success bool, details string) {
	audit.Log(event, user, ip, path, success, details)
}

// LogAuditEventForTenant logs security-relevant events to a tenant-specific
// audit log via the tenant audit manager. It falls back to the global audit
// logger when no manager is configured or when tenant logging fails; the
// failure is logged rather than silently discarded.
func LogAuditEventForTenant(orgID, event, user, ip, path string, success bool, details string) {
	manager := GetTenantAuditManager()
	if manager == nil {
		audit.Log(event, user, ip, path, success, details)
		return
	}

	if err := manager.Log(orgID, event, user, ip, path, success, details); err != nil {
		log.Warn().
			Err(err).
			Str("org_id", orgID).
			Str("event", event).
			Msg("Tenant audit logging failed; falling back to global audit log")
		audit.Log(event, user, ip, path, success, details)
	}
}

// global tenant audit manager
var (
	tenantAuditManager   *audit.TenantLoggerManager
	tenantAuditManagerMu sync.RWMutex
)

// SetTenantAuditManager sets the global tenant audit manager.
func SetTenantAuditManager(manager *audit.TenantLoggerManager) {
	tenantAuditManagerMu.Lock()
	defer tenantAuditManagerMu.Unlock()
	tenantAuditManager = manager
}

// GetTenantAuditManager returns the global tenant audit manager.
func GetTenantAuditManager() *audit.TenantLoggerManager { tenantAuditManagerMu.RLock() defer tenantAuditManagerMu.RUnlock() return tenantAuditManager } // Session Management Improvements var ( allSessions = make(map[string][]string) // user -> []sessionIDs sessionsMu sync.RWMutex ) // TrackUserSession tracks which sessions belong to which user func TrackUserSession(user, sessionID string) { sessionsMu.Lock() defer sessionsMu.Unlock() if allSessions[user] == nil { allSessions[user] = []string{} } allSessions[user] = append(allSessions[user], sessionID) } // GetSessionUsername returns the username associated with a session ID func GetSessionUsername(sessionID string) string { // First check in-memory map sessionsMu.RLock() for user, sessions := range allSessions { for _, sid := range sessions { if sid == sessionID { sessionsMu.RUnlock() return user } } } sessionsMu.RUnlock() // Fall back to persisted username in session store (survives restarts) if session := GetSessionStore().GetSession(sessionID); session != nil && session.Username != "" { // Re-populate in-memory map for faster future lookups TrackUserSession(session.Username, sessionID) return session.Username } return "" } // InvalidateUserSessions invalidates all sessions for a user (e.g., on password change) func InvalidateUserSessions(user string) { sessionsMu.Lock() defer sessionsMu.Unlock() sessionIDs := allSessions[user] for _, sid := range sessionIDs { // Delete from persistent session store GetSessionStore().DeleteSession(sid) // Delete CSRF tokens GetCSRFStore().DeleteCSRFToken(sid) } delete(allSessions, user) log.Info(). Str("user", user). Int("sessions_invalidated", len(sessionIDs)). 
Msg("Invalidated all user sessions") } // UntrackUserSession removes a single session from a user's session list // (used for single session logout, not password change which clears all) func UntrackUserSession(user, sessionID string) { sessionsMu.Lock() defer sessionsMu.Unlock() sessions := allSessions[user] for i, sid := range sessions { if sid == sessionID { allSessions[user] = append(sessions[:i], sessions[i+1:]...) break } } }