Pulse/internal/api/oidc_handlers.go
2026-04-01 16:00:04 +01:00

615 lines
19 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
internalauth "github.com/rcourtman/pulse-go-rewrite/pkg/auth"
"github.com/rs/zerolog/log"
)
// sanitizeOIDCReturnTo accepts only site-relative return paths.
//
// Surrounding whitespace is trimmed. The result is "" (meaning "no return
// path") when the input is empty, does not begin with "/", begins with "//" or
// "/\" (protocol-relative forms browsers treat as another host), fails
// url.Parse, or parses as absolute or with a host. Otherwise the trimmed input
// is returned unchanged.
func sanitizeOIDCReturnTo(raw string) string {
	candidate := strings.TrimSpace(raw)
	switch {
	case !strings.HasPrefix(candidate, "/"):
		// Also covers the empty string.
		return ""
	case strings.HasPrefix(candidate, "//"), strings.HasPrefix(candidate, "/\\"):
		return ""
	}
	if u, err := url.Parse(candidate); err != nil || u.IsAbs() || u.Host != "" {
		return ""
	}
	return candidate
}
// redirectOIDCError answers with a 302 back into the app. The target is
// returnTo (sanitized by buildLocalRedirectTarget, defaulting to "/") with
// oidc=error and oidc_error=<code> set; an empty code is omitted.
func (r *Router) redirectOIDCError(w http.ResponseWriter, req *http.Request, returnTo string, code string) {
	params := map[string]string{
		"oidc_error": code,
		"oidc":       "error",
	}
	target := buildLocalRedirectTarget(returnTo, params)
	http.Redirect(w, req, target, http.StatusFound)
}
// addQueryParam returns path with the query parameter key set to value,
// replacing any existing values for key. An empty path is treated as "/".
// The whole query is re-encoded (keys sorted). If path cannot be parsed as a
// URL it is returned unchanged.
func addQueryParam(path, key, value string) string {
	base := path
	if base == "" {
		base = "/"
	}
	u, err := url.Parse(base)
	if err != nil {
		return base
	}
	values := u.Query()
	values.Set(key, value)
	u.RawQuery = values.Encode()
	return u.String()
}
// buildLocalRedirectTarget builds a same-site redirect target.
//
// returnTo is passed through sanitizeOIDCReturnTo; an empty or rejected value
// becomes "/". Every queryParams entry whose key and value are both non-empty
// is set on the target's query, replacing existing values for that key. When
// at least one parameter is applied, the query is re-encoded with sorted keys.
// When none is applied, the query is left as parsed.
func buildLocalRedirectTarget(returnTo string, queryParams map[string]string) string {
	base := sanitizeOIDCReturnTo(returnTo)
	if base == "" {
		base = "/"
	}

	u, err := url.Parse(base)
	if err != nil || u.IsAbs() || u.Host != "" {
		// Defensive: sanitizeOIDCReturnTo should already have rejected these.
		u = &url.URL{Path: "/"}
	} else if u.Path == "" {
		u.Path = "/"
	}

	query := u.Query()
	applied := 0
	for key, value := range queryParams {
		if key != "" && value != "" {
			query.Set(key, value)
			applied++
		}
	}
	if applied > 0 {
		u.RawQuery = query.Encode()
	}
	return u.String()
}
// extractStringClaim returns the claim named key as a single trimmed string.
//
// For a string claim, that string is returned. For a []string claim, the first
// element is returned. For a []interface{} claim, the first element that is a
// string is returned. An empty key, a missing claim, an empty list, or any
// other type yields "".
func extractStringClaim(claims map[string]any, key string) string {
	if key == "" {
		return ""
	}
	// A missing key yields a nil interface, which matches none of the cases.
	switch v := claims[key].(type) {
	case string:
		return strings.TrimSpace(v)
	case []string:
		if len(v) == 0 {
			return ""
		}
		return strings.TrimSpace(v[0])
	case []any:
		for _, item := range v {
			if s, isString := item.(string); isString {
				return strings.TrimSpace(s)
			}
		}
	}
	return ""
}
// extractStringSliceClaim returns the claim named key as a list of strings.
//
// Accepted shapes:
//   - string: split on commas and/or spaces, with empty fields dropped
//   - []interface{}: the string elements in order; non-strings are skipped
//   - []string: returned as-is (not copied, not trimmed)
//
// It returns nil for an empty key, a missing claim, or any other type.
func extractStringSliceClaim(claims map[string]any, key string) []string {
	if key == "" {
		return nil
	}
	switch raw := claims[key].(type) {
	case string:
		isSeparator := func(c rune) bool { return c == ',' || c == ' ' }
		return strings.FieldsFunc(raw, isSeparator)
	case []interface{}:
		result := make([]string, 0, len(raw))
		for i := range raw {
			if s, isString := raw[i].(string); isString {
				result = append(result, s)
			}
		}
		return result
	case []string:
		return raw
	}
	return nil
}
// matchesValue reports whether candidate equals any entry of allowed after
// both are trimmed and lower-cased. An empty (or all-whitespace) candidate
// never matches.
func matchesValue(candidate string, allowed []string) bool {
	want := strings.ToLower(strings.TrimSpace(candidate))
	if want == "" {
		return false
	}
	for i := range allowed {
		if want == strings.ToLower(strings.TrimSpace(allowed[i])) {
			return true
		}
	}
	return false
}
// matchesDomain reports whether the domain of email (the text after its last
// "@", lower-cased) equals one of the allowed domains. Allowed entries are
// trimmed, lower-cased and stripped of surrounding "@" characters, so both
// "example.com" and "@example.com" work. Blank emails, emails without "@",
// and emails ending in "@" never match.
func matchesDomain(email string, allowed []string) bool {
	normalizedEmail := strings.ToLower(strings.TrimSpace(email))
	at := strings.LastIndexByte(normalizedEmail, '@')
	if at < 0 {
		return false
	}
	domain := normalizedEmail[at+1:]
	if domain == "" {
		return false
	}
	for i := range allowed {
		want := strings.ToLower(strings.Trim(strings.TrimSpace(allowed[i]), "@"))
		if want != "" && want == domain {
			return true
		}
	}
	return false
}
// intersects reports whether any entry of values equals any entry of allowed,
// comparing case-insensitively after trimming surrounding whitespace.
//
// Entries that are empty after trimming are ignored. A blank group in a token
// therefore never satisfies a blank or whitespace-only allow-list entry, which
// matches the behaviour of matchesValue and matchesDomain (neither ever
// matches an empty string). An allow-list made only of blank entries matches
// nothing.
func intersects(values []string, allowed []string) bool {
	if len(values) == 0 || len(allowed) == 0 {
		return false
	}
	allowedSet := make(map[string]struct{}, len(allowed))
	for _, item := range allowed {
		if normalized := strings.ToLower(strings.TrimSpace(item)); normalized != "" {
			allowedSet[normalized] = struct{}{}
		}
	}
	if len(allowedSet) == 0 {
		return false
	}
	for _, val := range values {
		// The set never contains "", so blank values cannot match.
		if _, ok := allowedSet[strings.ToLower(strings.TrimSpace(val))]; ok {
			return true
		}
	}
	return false
}
// InitializeOIDCProviders initializes every enabled SSO provider of type OIDC
// that has an OIDC section configured. It is a no-op when no SSO config is
// loaded. Providers are initialized with an empty redirect URL; the login and
// callback handlers re-initialize with a request-derived URL when needed.
// A failing provider is logged and skipped, so the returned error is always nil.
func (r *Router) InitializeOIDCProviders(ctx context.Context) error {
	if r.ssoConfig == nil {
		return nil
	}
	for _, candidate := range r.ssoConfig.Providers {
		if candidate.Type != config.SSOProviderTypeOIDC || !candidate.Enabled || candidate.OIDC == nil {
			continue
		}
		// Take a per-iteration copy because its address is handed to the manager.
		provider := candidate
		err := r.oidcManager.InitializeProvider(ctx, provider.ID, &provider, "")
		if err == nil {
			continue
		}
		// One misconfigured provider must not block the remaining ones.
		log.Error().
			Err(err).
			Str("provider_id", provider.ID).
			Msg("Failed to initialize SSO OIDC provider")
	}
	return nil
}
// ssoProviderToOIDCConfig converts an SSO multi-provider config into the legacy
// OIDCConfig that NewOIDCService expects.
//
// The result is always Enabled. The issuer, client, logout, claim and CA
// settings come from provider.OIDC; the group/domain/email restrictions and
// role mappings come from provider itself. Scopes default to
// openid/profile/email when none are configured.
// provider.OIDC must be non-nil (callers check this before calling).
func ssoProviderToOIDCConfig(provider *config.SSOProvider, redirectURL string) *config.OIDCConfig {
	src := provider.OIDC
	cfg := &config.OIDCConfig{
		Enabled:     true,
		RedirectURL: redirectURL,

		IssuerURL:     src.IssuerURL,
		ClientID:      src.ClientID,
		ClientSecret:  src.ClientSecret,
		LogoutURL:     src.LogoutURL,
		Scopes:        src.Scopes,
		UsernameClaim: src.UsernameClaim,
		EmailClaim:    src.EmailClaim,
		CABundle:      src.CABundle,

		GroupsClaim:       provider.GroupsClaim,
		AllowedGroups:     provider.AllowedGroups,
		AllowedDomains:    provider.AllowedDomains,
		AllowedEmails:     provider.AllowedEmails,
		GroupRoleMappings: provider.GroupRoleMappings,
	}
	if len(cfg.Scopes) == 0 {
		cfg.Scopes = []string{"openid", "profile", "email"}
	}
	return cfg
}
// extractOIDCProviderID extracts the provider ID from an OIDC endpoint path.
// Expected paths: /api/oidc/{providerID}/login, /api/oidc/{providerID}/callback
//
// The path must consist of exactly those four segments, optionally followed by
// a single trailing slash. Paths with extra segments (for example
// /api/oidc/{id}/login/extra) or a different endpoint yield "". The returned ID
// is not validated here; callers run validateProviderID on it.
func extractOIDCProviderID(urlPath, endpoint string) string {
	parts := strings.Split(strings.TrimPrefix(urlPath, "/"), "/")
	// parts: ["api", "oidc", "{providerID}", "{endpoint}"] (+ "" for a trailing slash)
	if len(parts) == 5 && parts[4] == "" {
		parts = parts[:4]
	}
	if len(parts) != 4 || parts[0] != "api" || parts[1] != "oidc" || parts[3] != endpoint {
		return ""
	}
	return parts[2]
}
// buildSSOOIDCCallbackURL constructs the callback URL for a multi-provider OIDC flow.
// The path includes the provider ID: /api/oidc/{providerID}/callback
func buildSSOOIDCCallbackURL(req *http.Request, providerID string, configuredURL string) string {
if configured := strings.TrimSpace(configuredURL); configured != "" {
return configured
}
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
peerIP := extractRemoteIP(req.RemoteAddr)
trustedProxy := isTrustedProxyIP(peerIP)
if trustedProxy {
if proto := firstForwardedValue(req.Header.Get("X-Forwarded-Proto")); proto != "" {
scheme = proto
} else if proto := firstForwardedValue(req.Header.Get("X-Forwarded-Scheme")); proto != "" {
scheme = proto
}
}
scheme = strings.ToLower(strings.TrimSpace(scheme))
switch scheme {
case "https", "http":
default:
if req.TLS != nil {
scheme = "https"
} else {
scheme = "http"
}
}
rawHost := ""
if trustedProxy {
rawHost = firstForwardedValue(req.Header.Get("X-Forwarded-Host"))
}
if rawHost == "" {
rawHost = req.Host
}
host, _ := sanitizeForwardedHost(rawHost)
if host == "" {
host = req.Host
}
return fmt.Sprintf("%s://%s/api/oidc/%s/callback", scheme, host, providerID)
}
// handleSSOOIDCLogin handles login for a multi-provider SSO OIDC provider.
// Path: /api/oidc/{providerID}/login
//
// GET redirects the browser to the provider's authorization URL. POST accepts
// an optional JSON body {"returnTo": "/path"} and responds with
// {"authorizationUrl": "..."} so the frontend can navigate itself. The return
// path (the returnTo query parameter for GET) is sanitized with
// sanitizeOIDCReturnTo.
//
// Method, provider-ID and provider-lookup failures are written as JSON errors.
// Initialization and state-creation failures redirect to "/" with oidc=error
// for GET, and are written as JSON errors for POST.
//
// A cached OIDC service is re-initialized when its config (notably the
// request-derived redirect URL) no longer matches the provider settings.
func (r *Router) handleSSOOIDCLogin(w http.ResponseWriter, req *http.Request) {
	// Upper bound for the POST body. This endpoint is used before a session
	// exists, and the only field read is a short return path, so larger bodies
	// are rejected instead of being buffered by the JSON decoder.
	const maxLoginBodyBytes = 64 << 10

	if req.Method != http.MethodGet && req.Method != http.MethodPost {
		writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Only GET or POST is allowed", nil)
		return
	}
	providerID := extractOIDCProviderID(req.URL.Path, "login")
	if providerID == "" {
		writeErrorResponse(w, http.StatusBadRequest, "invalid_provider", "Provider ID is required", nil)
		return
	}
	if !validateProviderID(providerID) {
		writeErrorResponse(w, http.StatusBadRequest, "invalid_provider", "Invalid provider ID format", nil)
		return
	}
	provider := r.getSSOProvider(providerID)
	if provider == nil || provider.Type != config.SSOProviderTypeOIDC || !provider.Enabled {
		writeErrorResponse(w, http.StatusNotFound, "provider_not_found", "OIDC provider not found or not enabled", nil)
		return
	}
	if provider.OIDC == nil {
		writeErrorResponse(w, http.StatusBadRequest, "invalid_config", "OIDC configuration is missing", nil)
		return
	}
	redirectURL := buildSSOOIDCCallbackURL(req, providerID, provider.OIDC.RedirectURL)
	service := r.oidcManager.GetService(providerID)
	if service != nil {
		// Check if cached service config still matches (especially redirect URL which
		// is unknown at startup and may change behind different reverse proxies).
		expectedCfg := ssoProviderToOIDCConfig(provider, redirectURL)
		if !service.Matches(expectedCfg) {
			service = nil // force re-initialization with correct config
		}
	}
	if service == nil {
		if err := r.oidcManager.InitializeProvider(req.Context(), providerID, provider, redirectURL); err != nil {
			log.Error().Err(err).Str("provider_id", providerID).Str("issuer", provider.OIDC.IssuerURL).Msg("Failed to initialize SSO OIDC provider")
			if req.Method == http.MethodGet {
				r.redirectOIDCError(w, req, "/", "oidc_init_failed")
				return
			}
			writeErrorResponse(w, http.StatusInternalServerError, "oidc_init_failed", "OIDC provider is unavailable", nil)
			return
		}
		service = r.oidcManager.GetService(providerID)
	}
	if service == nil {
		log.Error().Str("provider_id", providerID).Msg("SSO OIDC service unavailable after initialization")
		if req.Method == http.MethodGet {
			r.redirectOIDCError(w, req, "/", "oidc_init_failed")
			return
		}
		writeErrorResponse(w, http.StatusInternalServerError, "oidc_init_failed", "OIDC provider is unavailable", nil)
		return
	}

	var returnTo string
	if req.Method == http.MethodPost {
		// Oversized bodies surface as a decode error and are answered with 400.
		req.Body = http.MaxBytesReader(w, req.Body, maxLoginBodyBytes)
		var payload struct {
			ReturnTo string `json:"returnTo"`
		}
		// An empty body (io.EOF) is allowed and simply means "no returnTo".
		if err := json.NewDecoder(req.Body).Decode(&payload); err != nil && err != io.EOF {
			writeErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid request payload", nil)
			return
		}
		returnTo = sanitizeOIDCReturnTo(payload.ReturnTo)
	} else {
		returnTo = sanitizeOIDCReturnTo(req.URL.Query().Get("returnTo"))
	}

	state, entry, err := service.newStateEntry(providerID, returnTo)
	if err != nil {
		log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to create OIDC state entry")
		if req.Method == http.MethodGet {
			r.redirectOIDCError(w, req, "/", "oidc_state_error")
			return
		}
		writeErrorResponse(w, http.StatusInternalServerError, "oidc_state_error", "Unable to start OIDC login", nil)
		return
	}
	authURL := service.authCodeURL(state, entry)
	LogAuditEventForTenant(GetOrgID(req.Context()), "sso_oidc_login_initiated", "", GetClientIP(req), req.URL.Path, true, "Provider: "+providerID)
	if req.Method == http.MethodGet {
		http.Redirect(w, req, authURL, http.StatusFound)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]interface{}{
		"authorizationUrl": authURL,
	}); err != nil {
		// Headers are already sent; all that can be done is to record it.
		log.Warn().Err(err).Str("provider_id", providerID).Msg("Failed to write OIDC login response")
	}
}
// oidcProviderErrorCode reduces the "error" query parameter of an OIDC
// callback to a value that is safe to reflect into the UI redirect and the
// audit log. Anyone can request the callback URL with an arbitrary ?error=...,
// so only short tokens made of [a-z0-9_] are passed through. That is the shape
// of the RFC 6749 / OpenID Connect error codes such as "access_denied" or
// "login_required". Anything else becomes "provider_error".
func oidcProviderErrorCode(raw string) string {
	const maxLen = 64
	if raw == "" || len(raw) > maxLen {
		return "provider_error"
	}
	for i := 0; i < len(raw); i++ {
		c := raw[i]
		if (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '_' {
			return "provider_error"
		}
	}
	return raw
}

// handleSSOOIDCCallback handles the callback for a multi-provider SSO OIDC provider.
// Path: /api/oidc/{providerID}/callback
//
// It performs these steps in order:
//  1. Resolves and (re)initializes the provider's OIDC service.
//  2. Consumes the state entry created at login and checks that its provider ID
//     matches the path.
//  3. Exchanges the code and verifies the ID token and its nonce.
//  4. Applies the provider's email, domain and group restrictions.
//  5. Syncs RBAC roles from GroupRoleMappings.
//  6. Establishes a session and redirects to the stored return path with
//     oidc=success.
//
// Method and provider-ID failures are JSON errors. Every later failure
// redirects into the app with oidc=error and an oidc_error code.
func (r *Router) handleSSOOIDCCallback(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Only GET is allowed", nil)
		return
	}
	providerID := extractOIDCProviderID(req.URL.Path, "callback")
	if providerID == "" {
		writeErrorResponse(w, http.StatusBadRequest, "invalid_provider", "Provider ID is required", nil)
		return
	}
	if !validateProviderID(providerID) {
		writeErrorResponse(w, http.StatusBadRequest, "invalid_provider", "Invalid provider ID format", nil)
		return
	}
	provider := r.getSSOProvider(providerID)
	if provider == nil || provider.Type != config.SSOProviderTypeOIDC || !provider.Enabled {
		r.redirectOIDCError(w, req, "/", "provider_not_found")
		return
	}
	if provider.OIDC == nil {
		r.redirectOIDCError(w, req, "/", "invalid_config")
		return
	}

	// Re-initialize the cached service if its config (notably the
	// request-derived redirect URL) no longer matches the provider settings.
	redirectURL := buildSSOOIDCCallbackURL(req, providerID, provider.OIDC.RedirectURL)
	service := r.oidcManager.GetService(providerID)
	if service != nil && !service.Matches(ssoProviderToOIDCConfig(provider, redirectURL)) {
		service = nil
	}
	if service == nil {
		if err := r.oidcManager.InitializeProvider(req.Context(), providerID, provider, redirectURL); err != nil {
			log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to initialize SSO OIDC provider for callback")
			r.redirectOIDCError(w, req, "/", "oidc_init_failed")
			return
		}
		service = r.oidcManager.GetService(providerID)
	}
	if service == nil {
		log.Error().Str("provider_id", providerID).Msg("SSO OIDC service unavailable after initialization for callback")
		r.redirectOIDCError(w, req, "/", "oidc_init_failed")
		return
	}

	// auditLogin records an "sso_oidc_login" audit event for this request.
	auditLogin := func(user string, success bool, details string) {
		LogAuditEventForTenant(GetOrgID(req.Context()), "sso_oidc_login", user, GetClientIP(req), req.URL.Path, success, details)
	}

	query := req.URL.Query()
	if errParam := query.Get("error"); errParam != "" {
		// errParam is attacker-controllable. It is logged as a structured
		// field, but only the sanitized code reaches the audit log and the UI.
		errCode := oidcProviderErrorCode(errParam)
		log.Warn().Str("error", errParam).Str("provider_id", providerID).Msg("OIDC provider returned error")
		auditLogin("", false, "Provider error: "+errCode)
		r.redirectOIDCError(w, req, "", errCode)
		return
	}
	state := query.Get("state")
	if state == "" {
		auditLogin("", false, "Missing state parameter")
		r.redirectOIDCError(w, req, "", "missing_state")
		return
	}
	entry, ok := service.consumeState(state)
	if !ok {
		auditLogin("", false, "Invalid or expired state")
		r.redirectOIDCError(w, req, "", "invalid_state")
		return
	}
	// Safety check: verify the provider ID in the state matches the callback path.
	// SSO flow always stores providerID in state; reject any mismatch including empty.
	if entry.ProviderID != providerID {
		log.Warn().Str("state_provider", entry.ProviderID).Str("path_provider", providerID).Msg("OIDC provider ID mismatch between state and callback path")
		auditLogin("", false, "Provider ID mismatch")
		r.redirectOIDCError(w, req, entry.ReturnTo, "provider_mismatch")
		return
	}
	code := query.Get("code")
	if code == "" {
		auditLogin("", false, "Missing authorization code")
		r.redirectOIDCError(w, req, entry.ReturnTo, "missing_code")
		return
	}

	// Bound the token exchange and ID-token verification.
	ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
	defer cancel()
	ctx = service.contextWithHTTPClient(ctx)
	token, err := service.exchangeCode(ctx, code, entry)
	if err != nil {
		log.Error().Err(err).Str("provider_id", providerID).Msg("OIDC code exchange failed")
		auditLogin("", false, "Code exchange failed: "+err.Error())
		r.redirectOIDCError(w, req, entry.ReturnTo, "exchange_failed")
		return
	}
	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok || rawIDToken == "" {
		auditLogin("", false, "Missing ID token")
		r.redirectOIDCError(w, req, entry.ReturnTo, "missing_id_token")
		return
	}
	idToken, err := service.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		errorCode := "invalid_id_token"
		if strings.Contains(err.Error(), "unexpected signature algorithm") {
			errorCode = "invalid_signature_alg"
		}
		log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to verify ID token")
		auditLogin("", false, "ID token verification failed: "+err.Error())
		r.redirectOIDCError(w, req, entry.ReturnTo, errorCode)
		return
	}
	// Verify nonce matches what we sent in the authorization request
	if idToken.Nonce != entry.Nonce {
		log.Warn().Str("provider_id", providerID).Msg("OIDC nonce mismatch — possible token replay")
		auditLogin("", false, "Nonce mismatch")
		r.redirectOIDCError(w, req, entry.ReturnTo, "nonce_mismatch")
		return
	}
	claims := make(map[string]any)
	if err := idToken.Claims(&claims); err != nil {
		log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to parse ID token claims")
		auditLogin("", false, "Failed to parse ID token claims")
		r.redirectOIDCError(w, req, entry.ReturnTo, "invalid_claims")
		return
	}

	// Use SSO provider config for claim names
	usernameClaim := provider.OIDC.UsernameClaim
	if usernameClaim == "" {
		usernameClaim = "preferred_username"
	}
	emailClaim := provider.OIDC.EmailClaim
	if emailClaim == "" {
		emailClaim = "email"
	}
	username := extractStringClaim(claims, usernameClaim)
	email := extractStringClaim(claims, emailClaim)
	// Fallback chain for the username: email, then "name", then the subject.
	if username == "" {
		username = email
	}
	if username == "" {
		username = extractStringClaim(claims, "name")
	}
	if username == "" {
		username = idToken.Subject
	}

	// Apply access restrictions from SSO provider config.
	// NOTE(review): the email claim is trusted as-is and "email_verified" is not
	// consulted. If an IdP lets users set unverified addresses, the email and
	// domain restrictions below could be bypassed; consider rejecting tokens
	// with email_verified=false.
	if len(provider.AllowedEmails) > 0 && !matchesValue(email, provider.AllowedEmails) {
		auditLogin(email, false, "Email not permitted")
		r.redirectOIDCError(w, req, entry.ReturnTo, "email_restricted")
		return
	}
	if len(provider.AllowedDomains) > 0 && !matchesDomain(email, provider.AllowedDomains) {
		auditLogin(email, false, "Email domain restricted")
		r.redirectOIDCError(w, req, entry.ReturnTo, "domain_restricted")
		return
	}
	groups := extractStringSliceClaim(claims, provider.GroupsClaim)
	if len(provider.AllowedGroups) > 0 && !intersects(groups, provider.AllowedGroups) {
		auditLogin(username, false, "Group restriction failed")
		r.redirectOIDCError(w, req, entry.ReturnTo, "group_restricted")
		return
	}

	// RBAC: Always call UpdateUserRoles so user appears in Users list.
	// Role mapping uses exact (case-sensitive, untrimmed) group names.
	if authManager := internalauth.GetManager(); authManager != nil {
		var rolesToAssign []string
		seenRoles := make(map[string]bool)
		for _, group := range groups {
			if roleID, ok := provider.GroupRoleMappings[group]; ok && !seenRoles[roleID] {
				rolesToAssign = append(rolesToAssign, roleID)
				seenRoles[roleID] = true
			}
		}
		if err := authManager.UpdateUserRoles(username, rolesToAssign); err != nil {
			log.Error().Err(err).Str("user", username).Msg("Failed to update SSO OIDC user roles")
		} else if len(rolesToAssign) > 0 {
			LogAuditEventForTenant(GetOrgID(req.Context()), "sso_oidc_role_assignment", username, GetClientIP(req), req.URL.Path, true, "Auto-assigned roles: "+strings.Join(rolesToAssign, ", "))
		}
	}

	// Store OIDC tokens for session refresh (only when a refresh token was issued).
	var oidcTokens *OIDCTokenInfo
	if token.RefreshToken != "" {
		oidcTokens = &OIDCTokenInfo{
			RefreshToken:   token.RefreshToken,
			AccessTokenExp: token.Expiry,
			Issuer:         provider.OIDC.IssuerURL,
			ClientID:       provider.OIDC.ClientID,
		}
	}
	if err := r.establishOIDCSession(w, req, username, oidcTokens); err != nil {
		log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to establish session after SSO OIDC login")
		r.redirectOIDCError(w, req, entry.ReturnTo, "session_failed")
		return
	}
	auditLogin(username, true, "SSO OIDC login success via provider: "+providerID)
	http.Redirect(w, req, buildLocalRedirectTarget(entry.ReturnTo, map[string]string{
		"oidc": "success",
	}), http.StatusFound)
}