package api import ( "crypto/sha256" "encoding/hex" "encoding/json" "os" "path/filepath" "sync" "time" "github.com/rcourtman/pulse-go-rewrite/internal/crypto" "github.com/rs/zerolog/log" ) // SessionStore handles persistent session storage type SessionStore struct { sessions map[string]*SessionData mu sync.RWMutex dataPath string saveTicker *time.Ticker stopChan chan bool crypto *crypto.CryptoManager } func sessionHash(token string) string { sum := sha256.Sum256([]byte(token)) return hex.EncodeToString(sum[:]) } type sessionPersisted struct { Key string `json:"key"` Username string `json:"username,omitempty"` ExpiresAt time.Time `json:"expires_at"` CreatedAt time.Time `json:"created_at"` UserAgent string `json:"user_agent,omitempty"` IP string `json:"ip,omitempty"` OriginalDuration time.Duration `json:"original_duration,omitempty"` // OIDC token fields for refresh token support OIDCRefreshToken string `json:"oidc_refresh_token,omitempty"` OIDCAccessTokenExp time.Time `json:"oidc_access_token_exp,omitempty"` OIDCIssuer string `json:"oidc_issuer,omitempty"` OIDCClientID string `json:"oidc_client_id,omitempty"` OIDCTokenRefreshing bool `json:"-"` // transient, not persisted // SAML session fields for Single Logout (SLO) support SAMLProviderID string `json:"saml_provider_id,omitempty"` SAMLNameID string `json:"saml_name_id,omitempty"` SAMLSessionIndex string `json:"saml_session_index,omitempty"` } // SessionData represents a user session type SessionData struct { Username string `json:"username,omitempty"` // The authenticated user ExpiresAt time.Time `json:"expires_at"` CreatedAt time.Time `json:"created_at"` UserAgent string `json:"user_agent,omitempty"` IP string `json:"ip,omitempty"` OriginalDuration time.Duration `json:"original_duration,omitempty"` // Track original duration for sliding expiration // OIDC token fields for refresh token support OIDCRefreshToken string `json:"oidc_refresh_token,omitempty"` // Encrypted at rest OIDCAccessTokenExp time.Time `json:"oidc_access_token_exp,omitempty"` // When the access token expires OIDCIssuer string `json:"oidc_issuer,omitempty"` // IdP issuer URL OIDCClientID string `json:"oidc_client_id,omitempty"` // OIDC client ID OIDCTokenRefreshing bool `json:"-"` // Prevents concurrent refresh attempts // SAML session fields for Single Logout (SLO) support SAMLProviderID string `json:"saml_provider_id,omitempty"` // SAML IdP provider ID SAMLNameID string `json:"saml_name_id,omitempty"` // SAML NameID from assertion SAMLSessionIndex string `json:"saml_session_index,omitempty"` // SAML SessionIndex for SLO } // NewSessionStore creates a new persistent session store func NewSessionStore(dataPath string) *SessionStore { cm, err := crypto.NewCryptoManagerAt(dataPath) if err != nil { log.Error().Err(err).Msg("Failed to initialize crypto manager for session store") } store := &SessionStore{ sessions: make(map[string]*SessionData), dataPath: dataPath, stopChan: make(chan bool), crypto: cm, } // Load existing sessions from disk store.load() // Start periodic save and cleanup store.saveTicker = time.NewTicker(5 * time.Minute) go store.backgroundWorker() return store } // backgroundWorker handles periodic saves and cleanup func (s *SessionStore) backgroundWorker() { for { select { case <-s.saveTicker.C: s.cleanup() s.save() case <-s.stopChan: s.save() return } } } // CreateSession creates a new session func (s *SessionStore) CreateSession(token string, duration time.Duration, userAgent, ip, username string) { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) s.sessions[key] = &SessionData{ Username: username, ExpiresAt: time.Now().Add(duration), CreatedAt: time.Now(), UserAgent: userAgent, IP: ip, OriginalDuration: duration, } // Save immediately for important operations s.saveUnsafe() } // OIDCTokenInfo contains OAuth2 token information from the IdP type OIDCTokenInfo struct { RefreshToken string AccessTokenExp time.Time Issuer string ClientID string } // CreateOIDCSession creates a new session with OIDC token information func (s *SessionStore) CreateOIDCSession(token string, duration time.Duration, userAgent, ip, username string, oidc *OIDCTokenInfo) { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) session := &SessionData{ Username: username, ExpiresAt: time.Now().Add(duration), CreatedAt: time.Now(), UserAgent: userAgent, IP: ip, OriginalDuration: duration, } if oidc != nil { session.OIDCRefreshToken = oidc.RefreshToken session.OIDCAccessTokenExp = oidc.AccessTokenExp session.OIDCIssuer = oidc.Issuer session.OIDCClientID = oidc.ClientID } s.sessions[key] = session // Save immediately for important operations s.saveUnsafe() } // SAMLTokenInfo contains SAML session information for Single Logout support type SAMLTokenInfo struct { ProviderID string // SAML IdP provider ID NameID string // SAML NameID from assertion SessionIndex string // SAML SessionIndex for SLO } // CreateSAMLSession creates a new session with SAML session information func (s *SessionStore) CreateSAMLSession(token string, duration time.Duration, userAgent, ip, username string, saml *SAMLTokenInfo) { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) session := &SessionData{ Username: username, ExpiresAt: time.Now().Add(duration), CreatedAt: time.Now(), UserAgent: userAgent, IP: ip, OriginalDuration: duration, } if saml != nil { session.SAMLProviderID = saml.ProviderID session.SAMLNameID = saml.NameID session.SAMLSessionIndex = saml.SessionIndex } s.sessions[key] = session // Save immediately for important operations s.saveUnsafe() } // GetSAMLSessionInfo returns SAML-specific session info for the given token func (s *SessionStore) GetSAMLSessionInfo(token string) *SAMLTokenInfo { s.mu.RLock() defer s.mu.RUnlock() session, exists := s.sessions[sessionHash(token)] if !exists || session.SAMLProviderID == "" { return nil } return &SAMLTokenInfo{ ProviderID: session.SAMLProviderID, NameID: session.SAMLNameID, SessionIndex: session.SAMLSessionIndex, } } // GetSession returns a copy of the session data for the given token func (s *SessionStore) GetSession(token string) *SessionData { s.mu.RLock() defer s.mu.RUnlock() session, exists := s.sessions[sessionHash(token)] if !exists { return nil } // Return a copy to avoid race conditions copy := *session return © } // UpdateOIDCTokens updates the OIDC tokens for a session after a successful refresh func (s *SessionStore) UpdateOIDCTokens(token string, refreshToken string, accessTokenExp time.Time) { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) session, exists := s.sessions[key] if !exists { return } session.OIDCRefreshToken = refreshToken session.OIDCAccessTokenExp = accessTokenExp session.OIDCTokenRefreshing = false // Also extend the session expiry since the token is still valid if session.OriginalDuration > 0 { session.ExpiresAt = time.Now().Add(session.OriginalDuration) } // Save immediately after token refresh s.saveUnsafe() } // InvalidateSession removes a session (used when OIDC refresh fails) func (s *SessionStore) InvalidateSession(token string) { s.DeleteSession(token) } // SetTokenRefreshing marks a session as currently refreshing tokens func (s *SessionStore) SetTokenRefreshing(token string, refreshing bool) { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) if session, exists := s.sessions[key]; exists { session.OIDCTokenRefreshing = refreshing } } // ValidateSession checks if a session is valid func (s *SessionStore) ValidateSession(token string) bool { s.mu.RLock() defer s.mu.RUnlock() session, exists := s.sessions[sessionHash(token)] if !exists { return false } return time.Now().Before(session.ExpiresAt) } // ValidateAndExtendSession checks if a session is valid and extends it (sliding expiration) func (s *SessionStore) ValidateAndExtendSession(token string) bool { s.mu.Lock() defer s.mu.Unlock() key := sessionHash(token) session, exists := s.sessions[key] if !exists { return false } now := time.Now() if now.After(session.ExpiresAt) { return false } // Extend session using the original duration (sliding window) if session.OriginalDuration > 0 { session.ExpiresAt = now.Add(session.OriginalDuration) // Note: We don't save immediately for performance, background worker will save periodically } return true } // DeleteSession removes a session func (s *SessionStore) DeleteSession(token string) { s.mu.Lock() defer s.mu.Unlock() delete(s.sessions, sessionHash(token)) s.saveUnsafe() } // cleanup removes expired sessions func (s *SessionStore) cleanup() { s.mu.Lock() defer s.mu.Unlock() now := time.Now() for key, session := range s.sessions { if now.After(session.ExpiresAt) { delete(s.sessions, key) log.Debug().Str("sessionKey", safePrefixForLog(key, 8)+"...").Msg("Cleaned up expired session") } } } // save persists sessions to disk func (s *SessionStore) save() { s.mu.RLock() defer s.mu.RUnlock() s.saveUnsafe() } // saveUnsafe saves without locking (caller must hold lock) func (s *SessionStore) saveUnsafe() { sessionsFile := filepath.Join(s.dataPath, "sessions.json") // Create directory if it doesn't exist if err := os.MkdirAll(s.dataPath, 0700); err != nil { log.Error().Err(err).Msg("Failed to create sessions directory") return } // Marshal sessions persisted := make([]sessionPersisted, 0, len(s.sessions)) for key, session := range s.sessions { refreshToken := session.OIDCRefreshToken // Encrypt refresh token if crypto is available and token exists if refreshToken != "" && s.crypto != nil { if encrypted, err := s.crypto.EncryptString(refreshToken); err == nil { refreshToken = encrypted } else { log.Error().Err(err).Msg("Failed to encrypt refresh token") // Don't persist if encryption fails to prevent leak refreshToken = "" } } persisted = append(persisted, sessionPersisted{ Key: key, Username: session.Username, ExpiresAt: session.ExpiresAt, CreatedAt: session.CreatedAt, UserAgent: session.UserAgent, IP: session.IP, OriginalDuration: session.OriginalDuration, OIDCRefreshToken: refreshToken, OIDCAccessTokenExp: session.OIDCAccessTokenExp, OIDCIssuer: session.OIDCIssuer, OIDCClientID: session.OIDCClientID, SAMLProviderID: session.SAMLProviderID, SAMLNameID: session.SAMLNameID, SAMLSessionIndex: session.SAMLSessionIndex, }) } data, err := json.Marshal(persisted) if err != nil { log.Error().Err(err).Msg("Failed to marshal sessions") return } // Write to temporary file first tmpFile := sessionsFile + ".tmp" if err := os.WriteFile(tmpFile, data, 0600); err != nil { log.Error().Err(err).Msg("Failed to write sessions file") return } // Atomic rename if err := os.Rename(tmpFile, sessionsFile); err != nil { log.Error().Err(err).Msg("Failed to rename sessions file") return } log.Debug().Int("count", len(s.sessions)).Msg("Sessions saved to disk") } // load reads sessions from disk func (s *SessionStore) load() { sessionsFile := filepath.Join(s.dataPath, "sessions.json") data, err := os.ReadFile(sessionsFile) if err != nil { if !os.IsNotExist(err) { log.Error().Err(err).Msg("Failed to read sessions file") } return } now := time.Now() s.sessions = make(map[string]*SessionData) var persisted []sessionPersisted if err := json.Unmarshal(data, &persisted); err == nil { for _, entry := range persisted { if now.After(entry.ExpiresAt) { continue } refreshToken := entry.OIDCRefreshToken // Decrypt refresh token if needed (handles migration from plaintext) if refreshToken != "" && s.crypto != nil { if decrypted, err := s.crypto.DecryptString(refreshToken); err == nil { refreshToken = decrypted } // If decryption fails, assume it's legacy plaintext and leave as is } s.sessions[entry.Key] = &SessionData{ Username: entry.Username, ExpiresAt: entry.ExpiresAt, CreatedAt: entry.CreatedAt, UserAgent: entry.UserAgent, IP: entry.IP, OriginalDuration: entry.OriginalDuration, OIDCRefreshToken: refreshToken, OIDCAccessTokenExp: entry.OIDCAccessTokenExp, OIDCIssuer: entry.OIDCIssuer, OIDCClientID: entry.OIDCClientID, SAMLProviderID: entry.SAMLProviderID, SAMLNameID: entry.SAMLNameID, SAMLSessionIndex: entry.SAMLSessionIndex, } } log.Info().Int("loaded", len(s.sessions)).Int("total", len(persisted)).Msg("Sessions loaded from disk (hashed format)") return } // Legacy map format fallback (keys stored as raw tokens) var legacy map[string]*SessionData if err := json.Unmarshal(data, &legacy); err != nil { log.Error().Err(err).Msg("Failed to unmarshal legacy sessions") return } loaded := 0 for token, session := range legacy { if now.After(session.ExpiresAt) { continue } s.sessions[sessionHash(token)] = session loaded++ } log.Info(). Int("loaded", loaded). Int("total", len(legacy)). Msg("Sessions loaded from disk (legacy format migrated)") }