mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 00:37:36 +00:00
782 lines
22 KiB
Go
782 lines
22 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/utils"
|
|
"github.com/rcourtman/pulse-go-rewrite/pkg/audit"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// cspNonceKey is the unexported context key type under which the per-request
// CSP nonce is stored.
type cspNonceKey struct{}

// CSPNonceFromContext looks up the CSP nonce carried by ctx. It returns ""
// when no nonce was stored (e.g. dev mode) or the stored value is not a string.
func CSPNonceFromContext(ctx context.Context) string {
	// A failed type assertion yields the zero value, which is the "" we want.
	nonce, _ := ctx.Value(cspNonceKey{}).(string)
	return nonce
}
|
|
|
|
// generateCSPNonce returns a 16-byte (128-bit) base64-encoded cryptographic nonce.
|
|
func generateCSPNonce() string {
|
|
b := make([]byte, 16)
|
|
if _, err := rand.Read(b); err != nil {
|
|
// crypto/rand should never fail; if it does, fall back to empty which
|
|
// will leave the CSP without a nonce (equivalent to current behaviour).
|
|
log.Error().Err(err).Msg("CSP nonce generation failed — falling back to unsafe-inline")
|
|
return ""
|
|
}
|
|
return base64.StdEncoding.EncodeToString(b)
|
|
}
|
|
|
|
// Security improvements for Pulse
|
|
|
|
// generateCSRFToken creates a new CSRF token for a session
|
|
func generateCSRFToken(sessionID string) string {
|
|
return GetCSRFStore().GenerateCSRFToken(sessionID)
|
|
}
|
|
|
|
// validateCSRFToken checks if a CSRF token is valid for a session
|
|
func validateCSRFToken(sessionID, token string) bool {
|
|
return GetCSRFStore().ValidateCSRFToken(sessionID, token)
|
|
}
|
|
|
|
// CheckCSRF validates CSRF token for state-changing requests
|
|
func CheckCSRF(w http.ResponseWriter, r *http.Request) bool {
|
|
// Skip CSRF check for safe methods
|
|
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
|
return true
|
|
}
|
|
|
|
// Skip CSRF for API token auth (API clients don't have sessions)
|
|
if r.Header.Get("X-API-Token") != "" {
|
|
log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: API token auth")
|
|
return true
|
|
}
|
|
|
|
// Skip CSRF only for explicit non-session auth schemes.
|
|
if authHeader := strings.TrimSpace(r.Header.Get("Authorization")); authHeader != "" {
|
|
lower := strings.ToLower(authHeader)
|
|
if strings.HasPrefix(lower, "basic ") {
|
|
log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: Basic auth header present")
|
|
return true
|
|
}
|
|
if strings.HasPrefix(lower, "bearer ") {
|
|
log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: Bearer auth header present")
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Get session from cookie
|
|
cookie, err := readSessionCookie(r)
|
|
if err != nil {
|
|
// No session cookie means no CSRF check needed
|
|
// (either no auth configured or using basic auth which doesn't use sessions)
|
|
log.Debug().Str("path", r.URL.Path).Msg("CSRF check skipped: no session cookie")
|
|
return true
|
|
}
|
|
|
|
// Get CSRF token from header or form
|
|
csrfToken := r.Header.Get("X-CSRF-Token")
|
|
if csrfToken == "" {
|
|
csrfToken = r.FormValue("csrf_token")
|
|
}
|
|
|
|
// Log CSRF validation attempt for debugging
|
|
log.Debug().
|
|
Str("path", r.URL.Path).
|
|
Str("method", r.Method).
|
|
Str("session", safePrefixForLog(cookie.Value, 8)+"...").
|
|
Bool("has_csrf_token", csrfToken != "").
|
|
Msg("CSRF validation attempt")
|
|
|
|
// No CSRF token means request is not eligible for mutation
|
|
if csrfToken == "" {
|
|
log.Warn().
|
|
Str("path", r.URL.Path).
|
|
Str("session", safePrefixForLog(cookie.Value, 8)+"...").
|
|
Msg("Missing CSRF token")
|
|
clearCSRFCookie(w, r)
|
|
if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
|
|
w.Header().Set("X-CSRF-Token", newToken)
|
|
log.Debug().Str("new_token", safePrefixForLog(newToken, 8)+"...").Msg("Issued new CSRF token after missing")
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Check if the CSRF token validates
|
|
if !validateCSRFToken(cookie.Value, csrfToken) {
|
|
log.Warn().
|
|
Str("path", r.URL.Path).
|
|
Str("session", safePrefixForLog(cookie.Value, 8)+"...").
|
|
Str("provided_token", safePrefixForLog(csrfToken, 8)+"...").
|
|
Msg("Invalid CSRF token")
|
|
clearCSRFCookie(w, r)
|
|
if newToken := issueNewCSRFCookie(w, r, cookie.Value); newToken != "" {
|
|
w.Header().Set("X-CSRF-Token", newToken)
|
|
log.Debug().Str("new_token", safePrefixForLog(newToken, 8)+"...").Msg("Issued new CSRF token after invalid")
|
|
}
|
|
return false
|
|
}
|
|
|
|
log.Debug().
|
|
Str("path", r.URL.Path).
|
|
Str("session", safePrefixForLog(cookie.Value, 8)+"...").
|
|
Msg("CSRF validation successful")
|
|
return true
|
|
}
|
|
|
|
func clearCSRFCookie(w http.ResponseWriter, r *http.Request) {
|
|
if w == nil {
|
|
return
|
|
}
|
|
var secure bool
|
|
var sameSite http.SameSite
|
|
if r != nil {
|
|
secure, sameSite = getCookieSettings(r)
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: CookieNameCSRF,
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: false,
|
|
Secure: secure,
|
|
SameSite: sameSite,
|
|
})
|
|
}
|
|
|
|
func issueNewCSRFCookie(w http.ResponseWriter, r *http.Request, sessionID string) string {
|
|
if w == nil || r == nil {
|
|
return ""
|
|
}
|
|
if strings.TrimSpace(sessionID) == "" {
|
|
return ""
|
|
}
|
|
|
|
newToken := generateCSRFToken(sessionID)
|
|
secure, sameSite := getCookieSettings(r)
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: CookieNameCSRF,
|
|
Value: newToken,
|
|
Path: "/",
|
|
Secure: secure,
|
|
SameSite: sameSite,
|
|
MaxAge: 86400,
|
|
})
|
|
return newToken
|
|
}
|
|
|
|
// Rate Limiting - using existing RateLimiter from ratelimit.go
|
|
var (
|
|
// Auth endpoints: 10 attempts per minute
|
|
authLimiter = NewRateLimiter(10, 1*time.Minute)
|
|
)
|
|
|
|
// GetClientIP extracts the client IP from the request
|
|
func GetClientIP(r *http.Request) string {
|
|
rawRemoteIP := extractRemoteIP(r.RemoteAddr)
|
|
if rawRemoteIP == "" {
|
|
return ""
|
|
}
|
|
|
|
// Only trust proxy headers when the immediate peer is trusted.
|
|
if isTrustedProxyIP(rawRemoteIP) {
|
|
if forwarded := firstValidForwardedIP(r.Header.Get("X-Forwarded-For")); forwarded != "" {
|
|
return forwarded
|
|
}
|
|
|
|
if realIP := strings.TrimSpace(strings.Trim(r.Header.Get("X-Real-IP"), "[]")); realIP != "" && net.ParseIP(realIP) != nil {
|
|
return realIP
|
|
}
|
|
}
|
|
|
|
return rawRemoteIP
|
|
}
|
|
|
|
// Failed login tracking.

// FailedLogin is the per-identifier record of failed login attempts.
type FailedLogin struct {
	// Count is the number of failures recorded for the identifier.
	Count int
	// LastAttempt is the time of the most recent recorded failure.
	LastAttempt time.Time
	// LockedUntil is the end of the current lockout; zero if never locked.
	LockedUntil time.Time
}

// Failed-login state; failedMu guards failedLogins.
var (
	failedMu     sync.RWMutex
	failedLogins = map[string]*FailedLogin{}
)

// Lockout policy: failures needed to trigger a lockout and how long it lasts.
var (
	maxFailedAttempts = 5
	lockoutDuration   = 15 * time.Minute
)

// Trusted proxy networks, loaded once on first use via trustedProxyOnce.
var (
	trustedProxyOnce  sync.Once
	trustedProxyCIDRs []*net.IPNet
)
|
|
|
|
func loadTrustedProxyCIDRs() {
|
|
raw := utils.GetenvTrim("PULSE_TRUSTED_PROXY_CIDRS")
|
|
if raw == "" {
|
|
return
|
|
}
|
|
|
|
for _, entry := range strings.Split(raw, ",") {
|
|
entry = strings.TrimSpace(entry)
|
|
if entry == "" {
|
|
continue
|
|
}
|
|
|
|
if strings.Contains(entry, "/") {
|
|
_, network, parseErr := net.ParseCIDR(entry)
|
|
if parseErr == nil {
|
|
network.IP = network.IP.Mask(network.Mask)
|
|
trustedProxyCIDRs = append(trustedProxyCIDRs, network)
|
|
continue
|
|
}
|
|
log.Warn().
|
|
Str("cidr", entry).
|
|
Err(parseErr).
|
|
Msg("Ignoring invalid CIDR in PULSE_TRUSTED_PROXY_CIDRS")
|
|
continue
|
|
}
|
|
|
|
ip := net.ParseIP(entry)
|
|
if ip == nil {
|
|
log.Warn().
|
|
Str("value", entry).
|
|
Msg("Ignoring invalid IP in PULSE_TRUSTED_PROXY_CIDRS")
|
|
continue
|
|
}
|
|
|
|
bits := 32
|
|
if ip.To4() == nil {
|
|
bits = 128
|
|
}
|
|
mask := net.CIDRMask(bits, bits)
|
|
network := &net.IPNet{IP: ip.Mask(mask), Mask: mask}
|
|
trustedProxyCIDRs = append(trustedProxyCIDRs, network)
|
|
}
|
|
}
|
|
|
|
// extractRemoteIP returns the host part of remoteAddr with any surrounding
// square brackets removed. When remoteAddr is not in host:port form the whole
// value is treated as the host. An empty input yields "".
func extractRemoteIP(remoteAddr string) string {
	if remoteAddr == "" {
		return ""
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// No port present (or otherwise unsplittable): use the input as-is.
		host = remoteAddr
	}
	return strings.Trim(host, "[]")
}
|
|
|
|
// firstValidForwardedIP scans a comma-separated X-Forwarded-For style header
// from left to right and returns the first entry that parses as an IP address,
// or "" when there is none. Entries may be padded with whitespace and/or
// wrapped in square brackets (e.g. "[2001:db8::1]").
func firstValidForwardedIP(header string) string {
	if header == "" {
		return ""
	}

	for _, part := range strings.Split(header, ",") {
		// Strip whitespace first: entries after a ", " separator start with a
		// space, which would otherwise shield a leading "[" from removal.
		candidate := strings.Trim(strings.TrimSpace(part), "[]")
		// Also accept whitespace inside the brackets.
		candidate = strings.TrimSpace(candidate)
		if candidate == "" {
			continue
		}

		if net.ParseIP(candidate) != nil {
			return candidate
		}
	}
	return ""
}
|
|
|
|
// IsTrustedProxyIP reports whether ipStr belongs to a CIDR in PULSE_TRUSTED_PROXY_CIDRS.
|
|
// Exported so the websocket hub can gate X-Forwarded-* trust on the same list.
|
|
func IsTrustedProxyIP(ipStr string) bool {
|
|
return isTrustedProxyIP(ipStr)
|
|
}
|
|
|
|
func isTrustedProxyIP(ipStr string) bool {
|
|
ipStr = strings.TrimSpace(strings.Trim(ipStr, "[]"))
|
|
if ipStr == "" {
|
|
return false
|
|
}
|
|
ip := net.ParseIP(ipStr)
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
|
|
trustedProxyOnce.Do(loadTrustedProxyCIDRs)
|
|
if len(trustedProxyCIDRs) == 0 {
|
|
return false
|
|
}
|
|
for _, network := range trustedProxyCIDRs {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isPrivateIP(ip string) bool {
|
|
host := extractRemoteIP(ip)
|
|
if host == "" {
|
|
return false
|
|
}
|
|
|
|
parsedIP := net.ParseIP(host)
|
|
if parsedIP == nil {
|
|
return false
|
|
}
|
|
|
|
if parsedIP.IsLoopback() ||
|
|
parsedIP.IsLinkLocalUnicast() ||
|
|
parsedIP.IsLinkLocalMulticast() {
|
|
return true
|
|
}
|
|
|
|
privateRanges := []string{
|
|
"10.0.0.0/8",
|
|
"172.16.0.0/12",
|
|
"192.168.0.0/16",
|
|
"127.0.0.0/8",
|
|
"::1/128",
|
|
"fc00::/7",
|
|
"fe80::/10",
|
|
}
|
|
|
|
for _, cidr := range privateRanges {
|
|
_, network, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if network.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isTrustedNetwork(ip string, trustedNetworks []string) bool {
|
|
if len(trustedNetworks) == 0 {
|
|
return isPrivateIP(ip)
|
|
}
|
|
|
|
host := extractRemoteIP(ip)
|
|
if host == "" {
|
|
return false
|
|
}
|
|
|
|
parsedIP := net.ParseIP(host)
|
|
if parsedIP == nil {
|
|
return false
|
|
}
|
|
|
|
for _, cidr := range trustedNetworks {
|
|
_, network, err := net.ParseCIDR(strings.TrimSpace(cidr))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if network.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func init() {
|
|
// Periodically purge expired failedLogins entries to prevent unbounded
|
|
// memory growth from brute-force attempts with many distinct identifiers.
|
|
go func() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
cleanupExpiredFailedLogins()
|
|
}
|
|
}()
|
|
}
|
|
|
|
func cleanupExpiredFailedLogins() {
|
|
now := time.Now()
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
for id, f := range failedLogins {
|
|
// Remove entries whose lockout has expired (or were never locked) and
|
|
// haven't had a new attempt in 2x the lockout window.
|
|
if now.After(f.LockedUntil) && now.Sub(f.LastAttempt) > 2*lockoutDuration {
|
|
delete(failedLogins, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
// RecordFailedLogin tracks failed login attempts
|
|
func RecordFailedLogin(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
|
|
failed, exists := failedLogins[identifier]
|
|
if !exists {
|
|
failed = &FailedLogin{}
|
|
failedLogins[identifier] = failed
|
|
}
|
|
|
|
failed.Count++
|
|
failed.LastAttempt = time.Now()
|
|
|
|
if failed.Count >= maxFailedAttempts {
|
|
failed.LockedUntil = time.Now().Add(lockoutDuration)
|
|
log.Warn().
|
|
Str("identifier", identifier).
|
|
Int("attempts", failed.Count).
|
|
Time("locked_until", failed.LockedUntil).
|
|
Msg("Account locked due to failed login attempts")
|
|
}
|
|
}
|
|
|
|
// ClearFailedLogins resets failed login counter on successful login
|
|
func ClearFailedLogins(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
delete(failedLogins, identifier)
|
|
}
|
|
|
|
// GetLockoutInfo returns lockout information for an identifier
|
|
func GetLockoutInfo(identifier string) (attempts int, lockedUntil time.Time, isLocked bool) {
|
|
failedMu.RLock()
|
|
defer failedMu.RUnlock()
|
|
|
|
failed, exists := failedLogins[identifier]
|
|
if !exists {
|
|
return 0, time.Time{}, false
|
|
}
|
|
|
|
// Check if lockout has expired
|
|
if time.Now().After(failed.LockedUntil) && failed.Count >= maxFailedAttempts {
|
|
// Lockout expired, treat as no attempts
|
|
return 0, time.Time{}, false
|
|
}
|
|
|
|
isLocked = failed.Count >= maxFailedAttempts && time.Now().Before(failed.LockedUntil)
|
|
return failed.Count, failed.LockedUntil, isLocked
|
|
}
|
|
|
|
// ResetLockout manually resets lockout for an identifier (admin function)
|
|
func ResetLockout(identifier string) {
|
|
failedMu.Lock()
|
|
defer failedMu.Unlock()
|
|
delete(failedLogins, identifier)
|
|
|
|
log.Info().
|
|
Str("identifier", identifier).
|
|
Msg("Lockout manually reset")
|
|
}
|
|
|
|
// SecurityHeadersWithConfig applies security headers with embedding configuration.
|
|
// When devMode is false, a per-request cryptographic nonce is generated and used
|
|
// in script-src / style-src instead of 'unsafe-inline'/'unsafe-eval'.
|
|
// When devMode is true, 'unsafe-inline' and 'unsafe-eval' are kept so Vite HMR works.
|
|
func SecurityHeadersWithConfig(next http.Handler, allowEmbedding bool, allowedOrigins string, devMode bool) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Generate a per-request nonce in production mode and store it in context
|
|
// so downstream handlers (e.g. serveIndexWithNonce) can inject it into HTML.
|
|
var nonce string
|
|
if !devMode {
|
|
nonce = generateCSPNonce()
|
|
if nonce != "" {
|
|
ctx := context.WithValue(r.Context(), cspNonceKey{}, nonce)
|
|
r = r.WithContext(ctx)
|
|
}
|
|
}
|
|
|
|
// Configure clickjacking protection based on embedding settings
|
|
if allowEmbedding {
|
|
// When embedding is allowed, don't set X-Frame-Options header
|
|
// frame-ancestors CSP directive controls allowed embed origins below
|
|
// Security note: User explicitly enabled this for iframe embedding
|
|
} else {
|
|
// Deny all embedding when not explicitly allowed
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
}
|
|
|
|
// Prevent MIME type sniffing
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
|
|
// Disable legacy XSS auditor — it is removed from modern browsers and
|
|
// can introduce vulnerabilities in older ones. CSP provides XSS protection.
|
|
w.Header().Set("X-XSS-Protection", "0")
|
|
|
|
// Build Content Security Policy.
|
|
// In production, use nonce-based directives — browsers that support nonces
|
|
// (CSP Level 2+) automatically ignore 'unsafe-inline' when a nonce is present,
|
|
// so we omit 'unsafe-inline' entirely.
|
|
// In dev mode, keep 'unsafe-inline'/'unsafe-eval' for Vite HMR scripts.
|
|
var scriptSrc, styleSrc string
|
|
if devMode {
|
|
scriptSrc = "script-src 'self' 'unsafe-inline' 'unsafe-eval'"
|
|
styleSrc = "style-src 'self' 'unsafe-inline'"
|
|
} else {
|
|
if nonce != "" {
|
|
scriptSrc = "script-src 'self' 'nonce-" + nonce + "'"
|
|
styleSrc = "style-src 'self' 'nonce-" + nonce + "'"
|
|
} else {
|
|
// Fallback if nonce generation failed — keep unsafe-inline so the
|
|
// app still works, though with weaker CSP protection.
|
|
scriptSrc = "script-src 'self' 'unsafe-inline' 'unsafe-eval'"
|
|
styleSrc = "style-src 'self' 'unsafe-inline'"
|
|
}
|
|
}
|
|
|
|
cspDirectives := []string{
|
|
"default-src 'self'",
|
|
scriptSrc,
|
|
styleSrc,
|
|
"img-src 'self' data: blob:",
|
|
"connect-src 'self' ws: wss:", // WebSocket support
|
|
"font-src 'self' data:",
|
|
}
|
|
|
|
// Add frame-ancestors based on embedding settings
|
|
if allowEmbedding {
|
|
if allowedOrigins != "" {
|
|
// Parse comma-separated origins and add them to frame-ancestors
|
|
origins := strings.Split(allowedOrigins, ",")
|
|
frameAncestors := "frame-ancestors 'self'"
|
|
for _, origin := range origins {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin != "" {
|
|
frameAncestors += " " + origin
|
|
}
|
|
}
|
|
cspDirectives = append(cspDirectives, frameAncestors)
|
|
} else {
|
|
// Default to self-only when embedding is enabled but no specific origins configured.
|
|
// This prevents clickjacking while still allowing same-origin iframes.
|
|
cspDirectives = append(cspDirectives, "frame-ancestors 'self'")
|
|
}
|
|
} else {
|
|
// Deny all embedding
|
|
cspDirectives = append(cspDirectives, "frame-ancestors 'none'")
|
|
}
|
|
|
|
// Upgrade HTTP sub-resource requests to HTTPS when the page is served over HTTPS
|
|
if shouldSetHSTS(r) {
|
|
cspDirectives = append(cspDirectives, "upgrade-insecure-requests")
|
|
}
|
|
|
|
w.Header().Set("Content-Security-Policy", strings.Join(cspDirectives, "; "))
|
|
|
|
// Referrer Policy
|
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Permissions Policy (formerly Feature Policy)
|
|
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), accelerometer=()")
|
|
|
|
// Enable HSTS only for requests known to be HTTPS.
|
|
// Forwarded proto is trusted only when the direct peer is a trusted proxy.
|
|
if shouldSetHSTS(r) {
|
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func shouldSetHSTS(r *http.Request) bool {
|
|
if r == nil {
|
|
return false
|
|
}
|
|
if r.TLS != nil {
|
|
return true
|
|
}
|
|
|
|
peerIP := extractRemoteIP(r.RemoteAddr)
|
|
if !isTrustedProxyIP(peerIP) {
|
|
return false
|
|
}
|
|
|
|
proto := strings.ToLower(firstForwardedValue(r.Header.Get("X-Forwarded-Proto")))
|
|
if proto == "" {
|
|
proto = strings.ToLower(firstForwardedValue(r.Header.Get("X-Forwarded-Scheme")))
|
|
}
|
|
return proto == "https"
|
|
}
|
|
|
|
// LogAuditEvent logs security-relevant events using the audit package.
|
|
// This function delegates to the configured audit.Logger, allowing enterprise
|
|
// versions to provide persistent storage and signing.
|
|
// For tenant-aware logging, use LogAuditEventForTenant instead.
|
|
func LogAuditEvent(event string, user string, ip string, path string, success bool, details string) {
|
|
audit.Log(event, user, ip, path, success, details)
|
|
}
|
|
|
|
// LogAuditEventForTenant logs security-relevant events to a tenant-specific audit log.
|
|
// Uses the TenantLoggerManager to route events to the correct tenant's audit database.
|
|
func LogAuditEventForTenant(orgID, event, user, ip, path string, success bool, details string) {
|
|
manager := GetTenantAuditManager()
|
|
if manager == nil {
|
|
// Fall back to global logger
|
|
audit.Log(event, user, ip, path, success, details)
|
|
return
|
|
}
|
|
if err := manager.Log(orgID, event, user, ip, path, success, details); err != nil {
|
|
// If tenant logging fails, fall back to global logger
|
|
audit.Log(event, user, ip, path, success, details)
|
|
}
|
|
}
|
|
|
|
// global tenant audit manager
|
|
var (
|
|
tenantAuditManager *audit.TenantLoggerManager
|
|
tenantAuditManagerMu sync.RWMutex
|
|
)
|
|
|
|
// SetTenantAuditManager sets the global tenant audit manager.
|
|
func SetTenantAuditManager(manager *audit.TenantLoggerManager) {
|
|
tenantAuditManagerMu.Lock()
|
|
defer tenantAuditManagerMu.Unlock()
|
|
tenantAuditManager = manager
|
|
}
|
|
|
|
// GetTenantAuditManager returns the global tenant audit manager.
|
|
func GetTenantAuditManager() *audit.TenantLoggerManager {
|
|
tenantAuditManagerMu.RLock()
|
|
defer tenantAuditManagerMu.RUnlock()
|
|
return tenantAuditManager
|
|
}
|
|
|
|
// Session management improvements.

// maxSessionsPerUser limits concurrent sessions per user to prevent session
// accumulation.
const maxSessionsPerUser = 10

var (
	// sessionsMu guards allSessions.
	sessionsMu sync.RWMutex
	// allSessions maps user -> tracked session IDs, in the order they were added.
	allSessions = map[string][]string{}
)
|
|
|
|
// TrackUserSession tracks which sessions belong to which user.
|
|
// When the limit is exceeded, the oldest sessions are evicted.
|
|
func TrackUserSession(user, sessionID string) {
|
|
sessionsMu.Lock()
|
|
defer sessionsMu.Unlock()
|
|
|
|
if allSessions[user] == nil {
|
|
allSessions[user] = []string{}
|
|
}
|
|
|
|
// Add the new session
|
|
allSessions[user] = append(allSessions[user], sessionID)
|
|
|
|
// If near the limit, prune stale session IDs (already deleted via logout/
|
|
// rotation/expiry) before evicting valid sessions.
|
|
if len(allSessions[user]) > maxSessionsPerUser {
|
|
store := GetSessionStore()
|
|
alive := make([]string, 0, len(allSessions[user]))
|
|
for _, sid := range allSessions[user] {
|
|
if store.ValidateSession(sid) {
|
|
alive = append(alive, sid)
|
|
}
|
|
}
|
|
allSessions[user] = alive
|
|
|
|
// After pruning stale entries, evict oldest if still over the limit
|
|
if len(allSessions[user]) > maxSessionsPerUser {
|
|
excess := allSessions[user][:len(allSessions[user])-maxSessionsPerUser]
|
|
for _, oldSID := range excess {
|
|
store.DeleteSession(oldSID)
|
|
GetCSRFStore().DeleteCSRFToken(oldSID)
|
|
}
|
|
allSessions[user] = allSessions[user][len(allSessions[user])-maxSessionsPerUser:]
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetSessionUsername returns the username associated with a session ID
|
|
func GetSessionUsername(sessionID string) string {
|
|
// First check in-memory map
|
|
sessionsMu.RLock()
|
|
for user, sessions := range allSessions {
|
|
for _, sid := range sessions {
|
|
if sid == sessionID {
|
|
sessionsMu.RUnlock()
|
|
return user
|
|
}
|
|
}
|
|
}
|
|
sessionsMu.RUnlock()
|
|
|
|
// Fall back to persisted username in session store (survives restarts)
|
|
if session := GetSessionStore().GetSession(sessionID); session != nil && session.Username != "" {
|
|
// Re-populate in-memory map for faster future lookups
|
|
TrackUserSession(session.Username, sessionID)
|
|
return session.Username
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// InvalidateUserSessions invalidates all sessions for a user (e.g., on password change)
|
|
func InvalidateUserSessions(user string) {
|
|
sessionsMu.Lock()
|
|
defer sessionsMu.Unlock()
|
|
|
|
sessionIDs := allSessions[user]
|
|
for _, sid := range sessionIDs {
|
|
// Delete from persistent session store
|
|
GetSessionStore().DeleteSession(sid)
|
|
|
|
// Delete CSRF tokens
|
|
GetCSRFStore().DeleteCSRFToken(sid)
|
|
}
|
|
|
|
delete(allSessions, user)
|
|
|
|
log.Info().
|
|
Str("user", user).
|
|
Int("sessions_invalidated", len(sessionIDs)).
|
|
Msg("Invalidated all user sessions")
|
|
}
|
|
|
|
// UntrackUserSession removes all occurrences of a session from a user's session list
|
|
// (used for single session logout, not password change which clears all)
|
|
func UntrackUserSession(user, sessionID string) {
|
|
sessionsMu.Lock()
|
|
defer sessionsMu.Unlock()
|
|
|
|
sessions := allSessions[user]
|
|
filtered := sessions[:0]
|
|
for _, sid := range sessions {
|
|
if sid != sessionID {
|
|
filtered = append(filtered, sid)
|
|
}
|
|
}
|
|
allSessions[user] = filtered
|
|
}
|
|
|
|
// InvalidateOldSessionFromRequest destroys any pre-existing session cookie to
|
|
// prevent session fixation attacks. Call this before creating a new session.
|
|
// It deletes the session from the persistent store, its CSRF token, and
|
|
// removes it from the in-memory user session tracking map.
|
|
func InvalidateOldSessionFromRequest(r *http.Request) {
|
|
cookie, err := readSessionCookie(r)
|
|
if err != nil || cookie.Value == "" {
|
|
return
|
|
}
|
|
oldToken := cookie.Value
|
|
|
|
// Remove from persistent store
|
|
GetSessionStore().DeleteSession(oldToken)
|
|
GetCSRFStore().DeleteCSRFToken(oldToken)
|
|
|
|
// Remove from in-memory tracking so GetSessionUsername won't resolve it
|
|
user := GetSessionUsername(oldToken)
|
|
if user != "" {
|
|
UntrackUserSession(user, oldToken)
|
|
}
|
|
}
|