Pulse/internal/api/session_store.go
rcourtman da6ee7b1a6 feat(sso): implement SAML session storage for Single Logout support
- Add SAML session fields (ProviderID, NameID, SessionIndex) to
  SessionData and sessionPersisted structs for persistence
- Add CreateSAMLSession method to store SAML-authenticated sessions
- Add GetSAMLSessionInfo method to retrieve SAML session data
- Update establishSAMLSession to properly store SAML info instead
  of delegating to OIDC session creation
- Implement getSAMLSessionInfo to retrieve session info for SLO

This enables proper SAML Single Logout by storing the NameID and
SessionIndex from the SAML assertion, which are required to construct
valid LogoutRequest messages to the IdP.
2026-01-12 16:37:07 +00:00

444 lines
13 KiB
Go

package api
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"github.com/rs/zerolog/log"
)
// SessionStore handles persistent session storage.
//
// Sessions are held in memory keyed by sessionHash(token) rather than by the
// raw token. They are written to <dataPath>/sessions.json in two situations:
//   - immediately after session creation, deletion and OIDC token updates;
//   - every five minutes by backgroundWorker, which also drops expired sessions.
type SessionStore struct {
	sessions   map[string]*SessionData // guarded by mu; key is sessionHash(token)
	mu         sync.RWMutex            // protects sessions
	dataPath   string                  // directory that holds sessions.json
	saveTicker *time.Ticker            // drives backgroundWorker's periodic cleanup + save
	stopChan   chan bool               // a send or close makes backgroundWorker save once and exit
}
// sessionHash derives the key under which a session token is stored: the
// lowercase hex encoding of the token's SHA-256 digest. The store keys its
// sessions by this value instead of by the raw token.
func sessionHash(token string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so the result is ignored.
	h.Write([]byte(token))
	return hex.EncodeToString(h.Sum(nil))
}
// sessionPersisted is the on-disk JSON form of one session in sessions.json.
// saveUnsafe writes a JSON array of these and load reads it back.
// Key is the hashed token (sessionHash); the raw token is never written here.
// Field order determines the key order of the JSON that saveUnsafe emits.
type sessionPersisted struct {
	Key              string        `json:"key"` // sessionHash(token)
	Username         string        `json:"username,omitempty"`
	ExpiresAt        time.Time     `json:"expires_at"`
	CreatedAt        time.Time     `json:"created_at"`
	UserAgent        string        `json:"user_agent,omitempty"`
	IP               string        `json:"ip,omitempty"`
	OriginalDuration time.Duration `json:"original_duration,omitempty"` // encoded as integer nanoseconds
	// OIDC token fields for refresh token support
	OIDCRefreshToken string `json:"oidc_refresh_token,omitempty"`
	// NOTE(review): omitempty has no effect on time.Time struct values; a zero
	// expiry is still written to disk.
	OIDCAccessTokenExp  time.Time `json:"oidc_access_token_exp,omitempty"`
	OIDCIssuer          string    `json:"oidc_issuer,omitempty"`
	OIDCClientID        string    `json:"oidc_client_id,omitempty"`
	OIDCTokenRefreshing bool      `json:"-"` // transient, not persisted; neither saveUnsafe nor load touches it
	// SAML session fields for Single Logout (SLO) support
	SAMLProviderID   string `json:"saml_provider_id,omitempty"`
	SAMLNameID       string `json:"saml_name_id,omitempty"`
	SAMLSessionIndex string `json:"saml_session_index,omitempty"`
}
// SessionData represents a user session.
//
// Instances live in SessionStore.sessions, keyed by sessionHash(token).
// These json tags are used when load decodes the legacy sessions.json map
// format. The current on-disk format uses sessionPersisted instead.
// NOTE(review): GetSession hands out copies of this struct, so callers may also
// serialize it with these tags — confirm before renaming them.
type SessionData struct {
	Username         string        `json:"username,omitempty"` // The authenticated user
	ExpiresAt        time.Time     `json:"expires_at"`         // Session expiry; moved forward by sliding expiration
	CreatedAt        time.Time     `json:"created_at"`
	UserAgent        string        `json:"user_agent,omitempty"`
	IP               string        `json:"ip,omitempty"`
	OriginalDuration time.Duration `json:"original_duration,omitempty"` // Track original duration for sliding expiration; 0 disables extension
	// OIDC token fields for refresh token support
	// NOTE(review): the previous comment claimed "Encrypted at rest", but saveUnsafe
	// writes this value to sessions.json unchanged. Presumably callers pass an
	// already-encrypted token — confirm.
	OIDCRefreshToken    string    `json:"oidc_refresh_token,omitempty"`
	OIDCAccessTokenExp  time.Time `json:"oidc_access_token_exp,omitempty"` // When the access token expires
	OIDCIssuer          string    `json:"oidc_issuer,omitempty"`           // IdP issuer URL
	OIDCClientID        string    `json:"oidc_client_id,omitempty"`        // OIDC client ID
	OIDCTokenRefreshing bool      `json:"-"`                               // In-memory flag set via SetTokenRefreshing to mark a refresh in progress; never persisted
	// SAML session fields for Single Logout (SLO) support
	SAMLProviderID   string `json:"saml_provider_id,omitempty"`   // SAML IdP provider ID; empty means not a SAML session
	SAMLNameID       string `json:"saml_name_id,omitempty"`       // SAML NameID from assertion
	SAMLSessionIndex string `json:"saml_session_index,omitempty"` // SAML SessionIndex for SLO
}
// NewSessionStore builds a SessionStore rooted at dataPath. It restores any
// non-expired sessions from dataPath/sessions.json, then starts a background
// goroutine that cleans up and persists the store every five minutes.
func NewSessionStore(dataPath string) *SessionStore {
	const persistInterval = 5 * time.Minute

	st := &SessionStore{
		dataPath: dataPath,
		sessions: make(map[string]*SessionData),
		stopChan: make(chan bool),
	}

	// Restore state before the worker goroutine can observe the map.
	st.load()

	st.saveTicker = time.NewTicker(persistInterval)
	go st.backgroundWorker()

	return st
}
// backgroundWorker loops until stopChan delivers a value or is closed.
// On every saveTicker tick it drops expired sessions and persists the rest.
// On stop it performs one final save, stops the ticker and returns.
func (s *SessionStore) backgroundWorker() {
	// Release the ticker when the worker exits; nothing reads from it after
	// that. Ticker.Stop may be called more than once, so this is harmless even
	// if the ticker is also stopped elsewhere.
	defer s.saveTicker.Stop()

	for {
		select {
		case <-s.saveTicker.C:
			s.cleanup()
			s.save()
		case <-s.stopChan:
			s.save()
			return
		}
	}
}
// CreateSession registers a new session for token that expires duration from
// now. userAgent, ip and username are recorded as given. duration is also
// kept as OriginalDuration for sliding expiration. The session is stored under
// sessionHash(token), and the store is persisted to disk immediately.
func (s *SessionStore) CreateSession(token string, duration time.Duration, userAgent, ip, username string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Take a single timestamp so ExpiresAt-CreatedAt is exactly duration.
	now := time.Now()
	s.sessions[sessionHash(token)] = &SessionData{
		Username:         username,
		ExpiresAt:        now.Add(duration),
		CreatedAt:        now,
		UserAgent:        userAgent,
		IP:               ip,
		OriginalDuration: duration,
	}

	// Save immediately for important operations
	s.saveUnsafe()
}
// OIDCTokenInfo contains OAuth2 token information from the IdP.
// CreateOIDCSession copies these fields into the new session's OIDC* fields.
type OIDCTokenInfo struct {
	RefreshToken   string    // stored as SessionData.OIDCRefreshToken
	AccessTokenExp time.Time // when the current access token expires
	Issuer         string    // IdP issuer URL
	ClientID       string    // OIDC client ID
}
// CreateOIDCSession creates a new session for token, as CreateSession does.
// It also records the OIDC token information from oidc on that session.
// A nil oidc yields a session without OIDC fields.
// The store is persisted to disk immediately.
func (s *SessionStore) CreateOIDCSession(token string, duration time.Duration, userAgent, ip, username string, oidc *OIDCTokenInfo) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Take a single timestamp so ExpiresAt-CreatedAt is exactly duration.
	now := time.Now()
	session := &SessionData{
		Username:         username,
		ExpiresAt:        now.Add(duration),
		CreatedAt:        now,
		UserAgent:        userAgent,
		IP:               ip,
		OriginalDuration: duration,
	}
	if oidc != nil {
		session.OIDCRefreshToken = oidc.RefreshToken
		session.OIDCAccessTokenExp = oidc.AccessTokenExp
		session.OIDCIssuer = oidc.Issuer
		session.OIDCClientID = oidc.ClientID
	}
	s.sessions[sessionHash(token)] = session

	// Save immediately for important operations
	s.saveUnsafe()
}
// SAMLTokenInfo contains SAML session information for Single Logout support.
// CreateSAMLSession stores these values on a session.
// GetSAMLSessionInfo returns them again, for sessions whose ProviderID is non-empty.
type SAMLTokenInfo struct {
	ProviderID   string // SAML IdP provider ID
	NameID       string // SAML NameID from assertion
	SessionIndex string // SAML SessionIndex for SLO
}
// CreateSAMLSession creates a new session for token, as CreateSession does.
// It also records the SAML identifiers from saml that Single Logout needs.
// A nil saml yields a session without SAML fields. GetSAMLSessionInfo then
// returns nil for it. The store is persisted to disk immediately.
func (s *SessionStore) CreateSAMLSession(token string, duration time.Duration, userAgent, ip, username string, saml *SAMLTokenInfo) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Take a single timestamp so ExpiresAt-CreatedAt is exactly duration.
	now := time.Now()
	session := &SessionData{
		Username:         username,
		ExpiresAt:        now.Add(duration),
		CreatedAt:        now,
		UserAgent:        userAgent,
		IP:               ip,
		OriginalDuration: duration,
	}
	if saml != nil {
		session.SAMLProviderID = saml.ProviderID
		session.SAMLNameID = saml.NameID
		session.SAMLSessionIndex = saml.SessionIndex
	}
	s.sessions[sessionHash(token)] = session

	// Save immediately for important operations
	s.saveUnsafe()
}
// GetSAMLSessionInfo looks up the SAML identifiers recorded for token's
// session. It returns nil in two cases:
//   - the token is unknown;
//   - the session has no SAML provider ID (for example, it was not created by
//     CreateSAMLSession).
//
// Session expiry is not checked.
func (s *SessionStore) GetSAMLSessionInfo(token string) *SAMLTokenInfo {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if sess, ok := s.sessions[sessionHash(token)]; ok && sess.SAMLProviderID != "" {
		return &SAMLTokenInfo{
			ProviderID:   sess.SAMLProviderID,
			NameID:       sess.SAMLNameID,
			SessionIndex: sess.SAMLSessionIndex,
		}
	}
	return nil
}
// GetSession returns a shallow copy of the session stored for token, or nil
// when no such session exists. Because a copy is returned, callers can read it
// without holding the store's lock. Session expiry is not checked.
func (s *SessionStore) GetSession(token string) *SessionData {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if sess, ok := s.sessions[sessionHash(token)]; ok {
		snapshot := *sess
		return &snapshot
	}
	return nil
}
// UpdateOIDCTokens records the outcome of a successful OIDC token refresh for
// token's session. It:
//   - stores the new refresh token and access-token expiry;
//   - clears the in-progress refresh flag;
//   - slides the session expiry forward by its original duration, when one is set;
//   - persists the store immediately.
//
// Unknown tokens are ignored.
func (s *SessionStore) UpdateOIDCTokens(token string, refreshToken string, accessTokenExp time.Time) {
	s.mu.Lock()
	defer s.mu.Unlock()

	sess, found := s.sessions[sessionHash(token)]
	if !found {
		return
	}

	sess.OIDCRefreshToken, sess.OIDCAccessTokenExp = refreshToken, accessTokenExp
	sess.OIDCTokenRefreshing = false

	// A fresh token means the user is still active, so extend the session too.
	if d := sess.OriginalDuration; d > 0 {
		sess.ExpiresAt = time.Now().Add(d)
	}

	s.saveUnsafe()
}
// InvalidateSession removes a session (used when OIDC refresh fails).
// It simply delegates to DeleteSession, so the removal is persisted to disk
// immediately.
func (s *SessionStore) InvalidateSession(token string) {
	s.DeleteSession(token)
}
// SetTokenRefreshing sets or clears the in-memory flag marking that an OIDC
// token refresh is in progress for token's session. The flag is not written to
// disk and this call does not trigger a save. Unknown tokens are ignored.
func (s *SessionStore) SetTokenRefreshing(token string, refreshing bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	sess, ok := s.sessions[sessionHash(token)]
	if !ok {
		return
	}
	sess.OIDCTokenRefreshing = refreshing
}
// ValidateSession reports whether token maps to a session whose expiry is
// still in the future. Unlike ValidateAndExtendSession, it never modifies the
// session.
func (s *SessionStore) ValidateSession(token string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	sess, ok := s.sessions[sessionHash(token)]
	return ok && time.Now().Before(sess.ExpiresAt)
}
// ValidateAndExtendSession reports whether token maps to an unexpired session.
// If it does, the session's expiry is pushed out to now+OriginalDuration
// (sliding expiration). Sessions with OriginalDuration == 0 are validated but
// not extended.
//
// A session counts as expired once now reaches ExpiresAt, the same boundary
// ValidateSession uses. Previously a session was still accepted (and extended)
// at exactly now == ExpiresAt.
func (s *SessionStore) ValidateAndExtendSession(token string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	session, exists := s.sessions[sessionHash(token)]
	if !exists {
		return false
	}

	now := time.Now()
	if !now.Before(session.ExpiresAt) {
		return false
	}

	// Extend session using the original duration (sliding window)
	if session.OriginalDuration > 0 {
		session.ExpiresAt = now.Add(session.OriginalDuration)
		// Note: We don't save immediately for performance, background worker will save periodically
	}
	return true
}
// DeleteSession removes the session stored for token, if any, and persists the
// store immediately. The file is rewritten even when the token was unknown.
func (s *SessionStore) DeleteSession(token string) {
	key := sessionHash(token)

	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.sessions, key)
	s.saveUnsafe()
}
// cleanup drops every session whose ExpiresAt lies strictly before the current
// time and logs a truncated key for each one. It does not persist the result;
// backgroundWorker calls save right after it.
func (s *SessionStore) cleanup() {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now()
	for key, sess := range s.sessions {
		if !cutoff.After(sess.ExpiresAt) {
			continue
		}
		// Deleting during range is safe for Go maps.
		delete(s.sessions, key)
		log.Debug().Str("sessionKey", safePrefixForLog(key, 8)+"...").Msg("Cleaned up expired session")
	}
}
// save persists sessions to disk.
//
// It holds the read lock while calling saveUnsafe. This keeps writers (which
// take mu exclusively) out, but other readers may run concurrently.
// NOTE(review): RLock does not serialize two overlapping save() calls; both
// would write sessions.json.tmp. Only backgroundWorker is seen calling save
// here, from a single goroutine — confirm there are no other callers.
func (s *SessionStore) save() {
	s.mu.RLock()
	defer s.mu.RUnlock()
	s.saveUnsafe()
}
// saveUnsafe writes every in-memory session to <dataPath>/sessions.json as a
// JSON array of sessionPersisted. The caller must hold mu (read or write); this
// method takes no lock itself.
//
// The data is first written to sessions.json.tmp with mode 0600, then renamed
// over sessions.json. On POSIX filesystems the rename replaces the file
// atomically. Errors are logged, not returned. On failure the temporary file is
// removed, so it does not linger on disk holding session data such as OIDC
// refresh tokens. The previous sessions.json is left as it was.
func (s *SessionStore) saveUnsafe() {
	sessionsFile := filepath.Join(s.dataPath, "sessions.json")

	// Create directory if it doesn't exist
	if err := os.MkdirAll(s.dataPath, 0700); err != nil {
		log.Error().Err(err).Msg("Failed to create sessions directory")
		return
	}

	// Marshal sessions
	persisted := make([]sessionPersisted, 0, len(s.sessions))
	for key, session := range s.sessions {
		persisted = append(persisted, sessionPersisted{
			Key:                key,
			Username:           session.Username,
			ExpiresAt:          session.ExpiresAt,
			CreatedAt:          session.CreatedAt,
			UserAgent:          session.UserAgent,
			IP:                 session.IP,
			OriginalDuration:   session.OriginalDuration,
			OIDCRefreshToken:   session.OIDCRefreshToken,
			OIDCAccessTokenExp: session.OIDCAccessTokenExp,
			OIDCIssuer:         session.OIDCIssuer,
			OIDCClientID:       session.OIDCClientID,
			SAMLProviderID:     session.SAMLProviderID,
			SAMLNameID:         session.SAMLNameID,
			SAMLSessionIndex:   session.SAMLSessionIndex,
		})
	}
	data, err := json.Marshal(persisted)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal sessions")
		return
	}

	// Write to temporary file first
	tmpFile := sessionsFile + ".tmp"
	if err := os.WriteFile(tmpFile, data, 0600); err != nil {
		log.Error().Err(err).Msg("Failed to write sessions file")
		// Best-effort: drop any partially written temp file.
		_ = os.Remove(tmpFile)
		return
	}

	// Atomic rename
	if err := os.Rename(tmpFile, sessionsFile); err != nil {
		log.Error().Err(err).Msg("Failed to rename sessions file")
		// Best-effort: don't leave a full copy of the session data behind.
		_ = os.Remove(tmpFile)
		return
	}

	log.Debug().Int("count", len(s.sessions)).Msg("Sessions saved to disk")
}
// load reads <dataPath>/sessions.json and replaces the in-memory session map
// with the file's non-expired entries. It takes no lock; NewSessionStore calls
// it before starting the background worker.
//
// Two on-disk formats are accepted:
//   - current: a JSON array of sessionPersisted whose Key is already
//     sessionHash(token);
//   - legacy: a JSON object mapping raw tokens to SessionData. These entries
//     are migrated by re-keying them with sessionHash. The file itself is
//     rewritten in the current format on the next save.
//
// A missing file is silently ignored. A read error is logged and leaves the
// map untouched. If neither format parses, the error is logged and the map is
// left empty.
func (s *SessionStore) load() {
	sessionsFile := filepath.Join(s.dataPath, "sessions.json")
	data, err := os.ReadFile(sessionsFile)
	if err != nil {
		if !os.IsNotExist(err) {
			log.Error().Err(err).Msg("Failed to read sessions file")
		}
		return
	}

	now := time.Now()
	s.sessions = make(map[string]*SessionData)

	// Current format: array of entries keyed by hashed token.
	var persisted []sessionPersisted
	if err := json.Unmarshal(data, &persisted); err == nil {
		for _, entry := range persisted {
			if now.After(entry.ExpiresAt) {
				continue
			}
			s.sessions[entry.Key] = &SessionData{
				Username:           entry.Username,
				ExpiresAt:          entry.ExpiresAt,
				CreatedAt:          entry.CreatedAt,
				UserAgent:          entry.UserAgent,
				IP:                 entry.IP,
				OriginalDuration:   entry.OriginalDuration,
				OIDCRefreshToken:   entry.OIDCRefreshToken,
				OIDCAccessTokenExp: entry.OIDCAccessTokenExp,
				OIDCIssuer:         entry.OIDCIssuer,
				OIDCClientID:       entry.OIDCClientID,
				SAMLProviderID:     entry.SAMLProviderID,
				SAMLNameID:         entry.SAMLNameID,
				SAMLSessionIndex:   entry.SAMLSessionIndex,
			}
		}
		log.Info().Int("loaded", len(s.sessions)).Int("total", len(persisted)).Msg("Sessions loaded from disk (hashed format)")
		return
	}

	// Legacy map format fallback (keys stored as raw tokens)
	var legacy map[string]*SessionData
	if err := json.Unmarshal(data, &legacy); err != nil {
		log.Error().Err(err).Msg("Failed to unmarshal legacy sessions")
		return
	}

	loaded := 0
	for token, session := range legacy {
		// A JSON `null` value decodes to a nil pointer. Skip it instead of
		// panicking on session.ExpiresAt during startup.
		if session == nil || now.After(session.ExpiresAt) {
			continue
		}
		s.sessions[sessionHash(token)] = session
		loaded++
	}

	log.Info().
		Int("loaded", loaded).
		Int("total", len(legacy)).
		Msg("Sessions loaded from disk (legacy format migrated)")
}