mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-13 15:28:38 +00:00
- Add SAML session fields (ProviderID, NameID, SessionIndex) to SessionData and sessionPersisted structs for persistence - Add CreateSAMLSession method to store SAML-authenticated sessions - Add GetSAMLSessionInfo method to retrieve SAML session data - Update establishSAMLSession to properly store SAML info instead of delegating to OIDC session creation - Implement getSAMLSessionInfo to retrieve session info for SLO This enables proper SAML Single Logout by storing the NameID and SessionIndex from the SAML assertion, which are required to construct valid LogoutRequest messages to the IdP.
444 lines
13 KiB
Go
444 lines
13 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// SessionStore handles persistent session storage
|
|
type SessionStore struct {
|
|
sessions map[string]*SessionData
|
|
mu sync.RWMutex
|
|
dataPath string
|
|
saveTicker *time.Ticker
|
|
stopChan chan bool
|
|
}
|
|
|
|
// sessionHash maps a session token to the key used for both the in-memory
// map and sessions.json: the lowercase hex SHA-256 digest of the token.
// Keying by digest means raw session tokens are never written to disk.
func sessionHash(token string) string {
	digest := sha256.New()
	digest.Write([]byte(token)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
// sessionPersisted is one element of the JSON array written to
// sessions.json. It mirrors SessionData plus the map key. Field order is
// significant: it determines the order of keys in the written JSON.
type sessionPersisted struct {
	// Key is sessionHash(token); the raw token is never persisted.
	Key              string        `json:"key"`
	Username         string        `json:"username,omitempty"`
	ExpiresAt        time.Time     `json:"expires_at"`
	CreatedAt        time.Time     `json:"created_at"`
	UserAgent        string        `json:"user_agent,omitempty"`
	IP               string        `json:"ip,omitempty"`
	OriginalDuration time.Duration `json:"original_duration,omitempty"`

	// OIDC refresh-token state, copied to/from the matching SessionData
	// fields. Note: omitempty has no effect on a time.Time field, so
	// oidc_access_token_exp is always written.
	OIDCRefreshToken   string    `json:"oidc_refresh_token,omitempty"`
	OIDCAccessTokenExp time.Time `json:"oidc_access_token_exp,omitempty"`
	OIDCIssuer         string    `json:"oidc_issuer,omitempty"`
	OIDCClientID       string    `json:"oidc_client_id,omitempty"`
	// Never serialized (json:"-") and not populated by saveUnsafe/load.
	OIDCTokenRefreshing bool `json:"-"`

	// SAML Single Logout state, copied to/from the matching SessionData fields.
	SAMLProviderID   string `json:"saml_provider_id,omitempty"`
	SAMLNameID       string `json:"saml_name_id,omitempty"`
	SAMLSessionIndex string `json:"saml_session_index,omitempty"`
}
|
|
|
|
// SessionData is the in-memory state of one authenticated session. It is
// also decoded directly from the legacy sessions.json format (a JSON object
// keyed by raw token), hence the JSON tags.
type SessionData struct {
	Username         string        `json:"username,omitempty"`          // authenticated user name
	ExpiresAt        time.Time     `json:"expires_at"`                  // session is invalid after this instant
	CreatedAt        time.Time     `json:"created_at"`                  // when the session was created
	UserAgent        string        `json:"user_agent,omitempty"`        // User-Agent recorded at creation
	IP               string        `json:"ip,omitempty"`                // client IP recorded at creation
	OriginalDuration time.Duration `json:"original_duration,omitempty"` // lifetime used to slide ExpiresAt forward; 0 disables sliding

	// OIDC token state, filled by CreateOIDCSession and updated by
	// UpdateOIDCTokens.
	// NOTE(review): this field used to be documented as "encrypted at rest",
	// but saveUnsafe persists it verbatim; any encryption must happen before
	// the value is stored here — confirm with callers.
	OIDCRefreshToken    string    `json:"oidc_refresh_token,omitempty"`
	OIDCAccessTokenExp  time.Time `json:"oidc_access_token_exp,omitempty"` // access-token expiry
	OIDCIssuer          string    `json:"oidc_issuer,omitempty"`           // IdP issuer URL
	OIDCClientID        string    `json:"oidc_client_id,omitempty"`        // OIDC client ID
	OIDCTokenRefreshing bool      `json:"-"`                               // in-memory only; flags a refresh in flight to prevent concurrent attempts

	// SAML Single Logout state, filled by CreateSAMLSession and read back
	// by GetSAMLSessionInfo.
	SAMLProviderID   string `json:"saml_provider_id,omitempty"`   // SAML IdP provider ID
	SAMLNameID       string `json:"saml_name_id,omitempty"`       // NameID from the assertion
	SAMLSessionIndex string `json:"saml_session_index,omitempty"` // SessionIndex from the assertion
}
|
|
|
|
// NewSessionStore creates a new persistent session store
|
|
func NewSessionStore(dataPath string) *SessionStore {
|
|
store := &SessionStore{
|
|
sessions: make(map[string]*SessionData),
|
|
dataPath: dataPath,
|
|
stopChan: make(chan bool),
|
|
}
|
|
|
|
// Load existing sessions from disk
|
|
store.load()
|
|
|
|
// Start periodic save and cleanup
|
|
store.saveTicker = time.NewTicker(5 * time.Minute)
|
|
go store.backgroundWorker()
|
|
|
|
return store
|
|
}
|
|
|
|
// backgroundWorker handles periodic saves and cleanup
|
|
func (s *SessionStore) backgroundWorker() {
|
|
for {
|
|
select {
|
|
case <-s.saveTicker.C:
|
|
s.cleanup()
|
|
s.save()
|
|
case <-s.stopChan:
|
|
s.save()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new session
|
|
func (s *SessionStore) CreateSession(token string, duration time.Duration, userAgent, ip, username string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
s.sessions[key] = &SessionData{
|
|
Username: username,
|
|
ExpiresAt: time.Now().Add(duration),
|
|
CreatedAt: time.Now(),
|
|
UserAgent: userAgent,
|
|
IP: ip,
|
|
OriginalDuration: duration,
|
|
}
|
|
|
|
// Save immediately for important operations
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// OIDCTokenInfo carries the OAuth2/OIDC token details that CreateOIDCSession
// copies onto a newly created session.
type OIDCTokenInfo struct {
	RefreshToken   string    // refresh token issued by the IdP
	AccessTokenExp time.Time // expiry of the current access token
	Issuer         string    // IdP issuer URL
	ClientID       string    // OIDC client ID
}
|
|
|
|
// CreateOIDCSession creates a new session with OIDC token information
|
|
func (s *SessionStore) CreateOIDCSession(token string, duration time.Duration, userAgent, ip, username string, oidc *OIDCTokenInfo) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
session := &SessionData{
|
|
Username: username,
|
|
ExpiresAt: time.Now().Add(duration),
|
|
CreatedAt: time.Now(),
|
|
UserAgent: userAgent,
|
|
IP: ip,
|
|
OriginalDuration: duration,
|
|
}
|
|
|
|
if oidc != nil {
|
|
session.OIDCRefreshToken = oidc.RefreshToken
|
|
session.OIDCAccessTokenExp = oidc.AccessTokenExp
|
|
session.OIDCIssuer = oidc.Issuer
|
|
session.OIDCClientID = oidc.ClientID
|
|
}
|
|
|
|
s.sessions[key] = session
|
|
|
|
// Save immediately for important operations
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// SAMLTokenInfo holds the assertion details that CreateSAMLSession stores on a
// session and GetSAMLSessionInfo hands back for SAML Single Logout.
type SAMLTokenInfo struct {
	ProviderID   string // SAML IdP provider ID
	NameID       string // NameID from the SAML assertion
	SessionIndex string // SessionIndex from the SAML assertion
}
|
|
|
|
// CreateSAMLSession creates a new session with SAML session information
|
|
func (s *SessionStore) CreateSAMLSession(token string, duration time.Duration, userAgent, ip, username string, saml *SAMLTokenInfo) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
session := &SessionData{
|
|
Username: username,
|
|
ExpiresAt: time.Now().Add(duration),
|
|
CreatedAt: time.Now(),
|
|
UserAgent: userAgent,
|
|
IP: ip,
|
|
OriginalDuration: duration,
|
|
}
|
|
|
|
if saml != nil {
|
|
session.SAMLProviderID = saml.ProviderID
|
|
session.SAMLNameID = saml.NameID
|
|
session.SAMLSessionIndex = saml.SessionIndex
|
|
}
|
|
|
|
s.sessions[key] = session
|
|
|
|
// Save immediately for important operations
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// GetSAMLSessionInfo returns SAML-specific session info for the given token
|
|
func (s *SessionStore) GetSAMLSessionInfo(token string) *SAMLTokenInfo {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
session, exists := s.sessions[sessionHash(token)]
|
|
if !exists || session.SAMLProviderID == "" {
|
|
return nil
|
|
}
|
|
|
|
return &SAMLTokenInfo{
|
|
ProviderID: session.SAMLProviderID,
|
|
NameID: session.SAMLNameID,
|
|
SessionIndex: session.SAMLSessionIndex,
|
|
}
|
|
}
|
|
|
|
// GetSession returns a copy of the session data for the given token
|
|
func (s *SessionStore) GetSession(token string) *SessionData {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
session, exists := s.sessions[sessionHash(token)]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
// Return a copy to avoid race conditions
|
|
copy := *session
|
|
return ©
|
|
}
|
|
|
|
// UpdateOIDCTokens updates the OIDC tokens for a session after a successful refresh
|
|
func (s *SessionStore) UpdateOIDCTokens(token string, refreshToken string, accessTokenExp time.Time) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
session, exists := s.sessions[key]
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
session.OIDCRefreshToken = refreshToken
|
|
session.OIDCAccessTokenExp = accessTokenExp
|
|
session.OIDCTokenRefreshing = false
|
|
|
|
// Also extend the session expiry since the token is still valid
|
|
if session.OriginalDuration > 0 {
|
|
session.ExpiresAt = time.Now().Add(session.OriginalDuration)
|
|
}
|
|
|
|
// Save immediately after token refresh
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// InvalidateSession removes a session (used when OIDC refresh fails)
|
|
func (s *SessionStore) InvalidateSession(token string) {
|
|
s.DeleteSession(token)
|
|
}
|
|
|
|
// SetTokenRefreshing marks a session as currently refreshing tokens
|
|
func (s *SessionStore) SetTokenRefreshing(token string, refreshing bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
if session, exists := s.sessions[key]; exists {
|
|
session.OIDCTokenRefreshing = refreshing
|
|
}
|
|
}
|
|
|
|
// ValidateSession checks if a session is valid
|
|
func (s *SessionStore) ValidateSession(token string) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
session, exists := s.sessions[sessionHash(token)]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
return time.Now().Before(session.ExpiresAt)
|
|
}
|
|
|
|
// ValidateAndExtendSession checks if a session is valid and extends it (sliding expiration)
|
|
func (s *SessionStore) ValidateAndExtendSession(token string) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
key := sessionHash(token)
|
|
session, exists := s.sessions[key]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
now := time.Now()
|
|
if now.After(session.ExpiresAt) {
|
|
return false
|
|
}
|
|
|
|
// Extend session using the original duration (sliding window)
|
|
if session.OriginalDuration > 0 {
|
|
session.ExpiresAt = now.Add(session.OriginalDuration)
|
|
// Note: We don't save immediately for performance, background worker will save periodically
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// DeleteSession removes a session
|
|
func (s *SessionStore) DeleteSession(token string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
delete(s.sessions, sessionHash(token))
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// cleanup removes expired sessions
|
|
func (s *SessionStore) cleanup() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for key, session := range s.sessions {
|
|
if now.After(session.ExpiresAt) {
|
|
delete(s.sessions, key)
|
|
log.Debug().Str("sessionKey", safePrefixForLog(key, 8)+"...").Msg("Cleaned up expired session")
|
|
}
|
|
}
|
|
}
|
|
|
|
// save persists sessions to disk
|
|
func (s *SessionStore) save() {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
s.saveUnsafe()
|
|
}
|
|
|
|
// saveUnsafe saves without locking (caller must hold lock)
|
|
func (s *SessionStore) saveUnsafe() {
|
|
sessionsFile := filepath.Join(s.dataPath, "sessions.json")
|
|
|
|
// Create directory if it doesn't exist
|
|
if err := os.MkdirAll(s.dataPath, 0700); err != nil {
|
|
log.Error().Err(err).Msg("Failed to create sessions directory")
|
|
return
|
|
}
|
|
|
|
// Marshal sessions
|
|
persisted := make([]sessionPersisted, 0, len(s.sessions))
|
|
for key, session := range s.sessions {
|
|
persisted = append(persisted, sessionPersisted{
|
|
Key: key,
|
|
Username: session.Username,
|
|
ExpiresAt: session.ExpiresAt,
|
|
CreatedAt: session.CreatedAt,
|
|
UserAgent: session.UserAgent,
|
|
IP: session.IP,
|
|
OriginalDuration: session.OriginalDuration,
|
|
OIDCRefreshToken: session.OIDCRefreshToken,
|
|
OIDCAccessTokenExp: session.OIDCAccessTokenExp,
|
|
OIDCIssuer: session.OIDCIssuer,
|
|
OIDCClientID: session.OIDCClientID,
|
|
SAMLProviderID: session.SAMLProviderID,
|
|
SAMLNameID: session.SAMLNameID,
|
|
SAMLSessionIndex: session.SAMLSessionIndex,
|
|
})
|
|
}
|
|
|
|
data, err := json.Marshal(persisted)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to marshal sessions")
|
|
return
|
|
}
|
|
|
|
// Write to temporary file first
|
|
tmpFile := sessionsFile + ".tmp"
|
|
if err := os.WriteFile(tmpFile, data, 0600); err != nil {
|
|
log.Error().Err(err).Msg("Failed to write sessions file")
|
|
return
|
|
}
|
|
|
|
// Atomic rename
|
|
if err := os.Rename(tmpFile, sessionsFile); err != nil {
|
|
log.Error().Err(err).Msg("Failed to rename sessions file")
|
|
return
|
|
}
|
|
|
|
log.Debug().Int("count", len(s.sessions)).Msg("Sessions saved to disk")
|
|
}
|
|
|
|
// load reads sessions from disk
|
|
func (s *SessionStore) load() {
|
|
sessionsFile := filepath.Join(s.dataPath, "sessions.json")
|
|
|
|
data, err := os.ReadFile(sessionsFile)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
log.Error().Err(err).Msg("Failed to read sessions file")
|
|
}
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
s.sessions = make(map[string]*SessionData)
|
|
|
|
var persisted []sessionPersisted
|
|
if err := json.Unmarshal(data, &persisted); err == nil {
|
|
for _, entry := range persisted {
|
|
if now.After(entry.ExpiresAt) {
|
|
continue
|
|
}
|
|
s.sessions[entry.Key] = &SessionData{
|
|
Username: entry.Username,
|
|
ExpiresAt: entry.ExpiresAt,
|
|
CreatedAt: entry.CreatedAt,
|
|
UserAgent: entry.UserAgent,
|
|
IP: entry.IP,
|
|
OriginalDuration: entry.OriginalDuration,
|
|
OIDCRefreshToken: entry.OIDCRefreshToken,
|
|
OIDCAccessTokenExp: entry.OIDCAccessTokenExp,
|
|
OIDCIssuer: entry.OIDCIssuer,
|
|
OIDCClientID: entry.OIDCClientID,
|
|
SAMLProviderID: entry.SAMLProviderID,
|
|
SAMLNameID: entry.SAMLNameID,
|
|
SAMLSessionIndex: entry.SAMLSessionIndex,
|
|
}
|
|
}
|
|
log.Info().Int("loaded", len(s.sessions)).Int("total", len(persisted)).Msg("Sessions loaded from disk (hashed format)")
|
|
return
|
|
}
|
|
|
|
// Legacy map format fallback (keys stored as raw tokens)
|
|
var legacy map[string]*SessionData
|
|
if err := json.Unmarshal(data, &legacy); err != nil {
|
|
log.Error().Err(err).Msg("Failed to unmarshal legacy sessions")
|
|
return
|
|
}
|
|
|
|
loaded := 0
|
|
for token, session := range legacy {
|
|
if now.After(session.ExpiresAt) {
|
|
continue
|
|
}
|
|
s.sessions[sessionHash(token)] = session
|
|
loaded++
|
|
}
|
|
|
|
log.Info().
|
|
Int("loaded", loaded).
|
|
Int("total", len(legacy)).
|
|
Msg("Sessions loaded from disk (legacy format migrated)")
|
|
}
|