package api import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/rcourtman/pulse-go-rewrite/internal/config" internalauth "github.com/rcourtman/pulse-go-rewrite/pkg/auth" "github.com/rs/zerolog/log" ) func (r *Router) handleOIDCLogin(w http.ResponseWriter, req *http.Request) { // Support both GET (direct redirect) and POST (JSON response) // GET is preferred for browsers as it guarantees same-window navigation if req.Method != http.MethodGet && req.Method != http.MethodPost { writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Only GET or POST is allowed", nil) return } cfg := r.ensureOIDCConfig() if cfg == nil || !cfg.Enabled { if req.Method == http.MethodGet { // Redirect back to login with error instead of plain text r.redirectOIDCError(w, req, "/", "oidc_disabled") return } writeErrorResponse(w, http.StatusBadRequest, "oidc_disabled", "OIDC authentication is not enabled", nil) return } // Build redirect URL from request (respects X-Forwarded-* headers) redirectURL := buildRedirectURL(req, cfg.RedirectURL) service, err := r.getOIDCService(req.Context(), redirectURL) if err != nil { log.Error().Err(err).Str("issuer", cfg.IssuerURL).Msg("Failed to initialise OIDC service") if req.Method == http.MethodGet { // Redirect back to login with error instead of plain text r.redirectOIDCError(w, req, "/", "oidc_init_failed") return } writeErrorResponse(w, http.StatusInternalServerError, "oidc_init_failed", "OIDC provider is unavailable", nil) return } log.Debug().Str("issuer", cfg.IssuerURL).Str("client_id", cfg.ClientID).Msg("Starting OIDC login flow") var returnTo string if req.Method == http.MethodPost { var payload struct { ReturnTo string `json:"returnTo"` } if err := json.NewDecoder(req.Body).Decode(&payload); err != nil && err != io.EOF { writeErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid request payload", nil) return } returnTo = sanitizeOIDCReturnTo(payload.ReturnTo) } else { // GET: read returnTo from query param returnTo = sanitizeOIDCReturnTo(req.URL.Query().Get("returnTo")) } state, entry, err := service.newStateEntry(returnTo) if err != nil { log.Error().Err(err).Msg("Failed to create OIDC state entry") if req.Method == http.MethodGet { r.redirectOIDCError(w, req, "/", "oidc_state_error") return } writeErrorResponse(w, http.StatusInternalServerError, "oidc_state_error", "Unable to start OIDC login", nil) return } authURL := service.authCodeURL(state, entry) // GET: direct HTTP redirect (guarantees same-window navigation in all browsers) // POST: return JSON (for API clients/backwards compatibility) if req.Method == http.MethodGet { http.Redirect(w, req, authURL, http.StatusFound) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "authorizationUrl": authURL, }) } func (r *Router) handleOIDCCallback(w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Only GET is allowed", nil) return } cfg := r.ensureOIDCConfig() if cfg == nil || !cfg.Enabled { http.Error(w, "OIDC is not enabled", http.StatusNotFound) return } // Build redirect URL from request (respects X-Forwarded-* headers) redirectURL := buildRedirectURL(req, cfg.RedirectURL) service, err := r.getOIDCService(req.Context(), redirectURL) if err != nil { log.Error().Err(err).Str("issuer", cfg.IssuerURL).Msg("Failed to initialise OIDC service for callback") r.redirectOIDCError(w, req, "", "oidc_init_failed") return } log.Debug().Str("issuer", cfg.IssuerURL).Msg("Processing OIDC callback") query := req.URL.Query() if errParam := query.Get("error"); errParam != "" { log.Warn().Str("error", errParam).Msg("OIDC provider returned error") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Provider error: "+errParam) r.redirectOIDCError(w, req, "", errParam) return } state := query.Get("state") if state == "" { LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Missing state parameter") r.redirectOIDCError(w, req, "", "missing_state") return } entry, ok := service.consumeState(state) if !ok { LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Invalid or expired state") r.redirectOIDCError(w, req, "", "invalid_state") return } code := query.Get("code") if code == "" { LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Missing authorization code") r.redirectOIDCError(w, req, entry.ReturnTo, "missing_code") return } ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second) defer cancel() ctx = service.contextWithHTTPClient(ctx) token, err := service.exchangeCode(ctx, code, entry) if err != nil { log.Error().Err(err).Str("issuer", cfg.IssuerURL).Msg("OIDC code exchange failed") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Code exchange failed: "+err.Error()) r.redirectOIDCError(w, req, entry.ReturnTo, "exchange_failed") return } log.Debug().Msg("OIDC code exchange successful") rawIDToken, ok := token.Extra("id_token").(string) if !ok || rawIDToken == "" { LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Missing ID token") r.redirectOIDCError(w, req, entry.ReturnTo, "missing_id_token") return } // Verify the ID token idToken, err := service.verifier.Verify(ctx, rawIDToken) if err != nil { errorCode := "invalid_id_token" logMessage := "Failed to verify ID token - check issuer URL matches token issuer claim" if strings.Contains(err.Error(), "unexpected signature algorithm") { errorCode = "invalid_signature_alg" logMessage = "Failed to verify ID token - provider is issuing HS256 tokens, Pulse requires RS256" } log.Error(). Err(err). Str("issuer", cfg.IssuerURL). Str("client_id", cfg.ClientID). Str("redirect_url", cfg.RedirectURL). Msg(logMessage) LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "ID token verification failed: "+err.Error()) r.redirectOIDCError(w, req, entry.ReturnTo, errorCode) return } log.Debug().Str("subject", idToken.Subject).Msg("ID token verified successfully") claims := make(map[string]any) if err := idToken.Claims(&claims); err != nil { log.Error().Err(err).Msg("Failed to parse ID token claims") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", "", GetClientIP(req), req.URL.Path, false, "Invalid token claims") r.redirectOIDCError(w, req, entry.ReturnTo, "invalid_claims") return } username := extractStringClaim(claims, cfg.UsernameClaim) email := extractStringClaim(claims, cfg.EmailClaim) if username == "" { username = email } if username == "" { username = extractStringClaim(claims, "name") } if username == "" { username = idToken.Subject } log.Debug(). Str("username", username). Str("email", email). Str("subject", idToken.Subject). Str("username_claim", cfg.UsernameClaim). Str("email_claim", cfg.EmailClaim). Msg("Extracted user identity from claims") if len(cfg.AllowedEmails) > 0 && !matchesValue(email, cfg.AllowedEmails) { log.Debug().Str("email", email).Strs("allowed_emails", cfg.AllowedEmails).Msg("Email not in allowed list") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", email, GetClientIP(req), req.URL.Path, false, "Email not permitted") r.redirectOIDCError(w, req, entry.ReturnTo, "email_restricted") return } if len(cfg.AllowedDomains) > 0 && !matchesDomain(email, cfg.AllowedDomains) { log.Debug().Str("email", email).Strs("allowed_domains", cfg.AllowedDomains).Msg("Email domain not in allowed list") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", email, GetClientIP(req), req.URL.Path, false, "Email domain restricted") r.redirectOIDCError(w, req, entry.ReturnTo, "domain_restricted") return } if len(cfg.AllowedGroups) > 0 { groups := extractStringSliceClaim(claims, cfg.GroupsClaim) log.Debug(). Strs("user_groups", groups). Strs("allowed_groups", cfg.AllowedGroups). Str("groups_claim", cfg.GroupsClaim). Msg("Checking group membership") if !intersects(groups, cfg.AllowedGroups) { log.Debug().Msg("User not in any allowed groups") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", username, GetClientIP(req), req.URL.Path, false, "Group restriction failed") r.redirectOIDCError(w, req, entry.ReturnTo, "group_restricted") return } log.Debug().Msg("User group membership verified") } // RBAC Integration: Map OIDC groups to Pulse roles and ensure user is registered if authManager := internalauth.GetManager(); authManager != nil { groups := extractStringSliceClaim(claims, cfg.GroupsClaim) var rolesToAssign []string seenRoles := make(map[string]bool) for _, group := range groups { if roleID, ok := cfg.GroupRoleMappings[group]; ok { if !seenRoles[roleID] { rolesToAssign = append(rolesToAssign, roleID) seenRoles[roleID] = true } } } if len(rolesToAssign) > 0 { log.Info(). Str("user", username). Strs("mapped_roles", rolesToAssign). Msg("Auto-assigning roles based on OIDC group mapping") if err := authManager.UpdateUserRoles(username, rolesToAssign); err != nil { log.Error().Err(err).Str("user", username).Msg("Failed to auto-assign OIDC roles") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_role_assignment", username, GetClientIP(req), req.URL.Path, false, "Failed to auto-assign roles: "+strings.Join(rolesToAssign, ", ")) // We don't fail the login here, but log the error } else { LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_role_assignment", username, GetClientIP(req), req.URL.Path, true, "Auto-assigned roles: "+strings.Join(rolesToAssign, ", ")) } } else if _, exists := authManager.GetUserAssignment(username); !exists { // Ensure SSO user appears in the Users list even without role mappings _ = authManager.UpdateUserRoles(username, []string{}) } } // Prepare OIDC token info for session storage (enables refresh token support) var oidcTokens *OIDCTokenInfo if token.RefreshToken != "" { oidcTokens = &OIDCTokenInfo{ RefreshToken: token.RefreshToken, AccessTokenExp: token.Expiry, Issuer: cfg.IssuerURL, ClientID: cfg.ClientID, } log.Debug(). Time("access_token_expiry", token.Expiry). Bool("has_refresh_token", true). Msg("OIDC tokens will be stored for session refresh") } if err := r.establishOIDCSession(w, req, username, oidcTokens); err != nil { log.Error().Err(err).Msg("Failed to establish session after OIDC login") LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", username, GetClientIP(req), req.URL.Path, false, "Session creation failed") r.redirectOIDCError(w, req, entry.ReturnTo, "session_failed") return } LogAuditEventForTenant(GetOrgID(req.Context()), "oidc_login", username, GetClientIP(req), req.URL.Path, true, "OIDC login success") target := sanitizeOIDCReturnTo(entry.ReturnTo) if target == "" { target = "/" } target = addQueryParam(target, "oidc", "success") http.Redirect(w, req, target, http.StatusFound) } func (r *Router) getOIDCService(ctx context.Context, redirectURL string) (*OIDCService, error) { cfg := r.ensureOIDCConfig() if cfg == nil || !cfg.Enabled { return nil, errors.New("oidc disabled") } r.oidcMu.Lock() defer r.oidcMu.Unlock() // Create a config clone with the dynamic redirect URL cfgWithRedirect := cfg.Clone() cfgWithRedirect.RedirectURL = redirectURL if r.oidcService != nil && r.oidcService.Matches(cfgWithRedirect) { return r.oidcService, nil } service, err := NewOIDCService(ctx, cfgWithRedirect) if err != nil { return nil, err } r.oidcService = service return service, nil } func sanitizeOIDCReturnTo(raw string) string { trimmed := strings.TrimSpace(raw) if trimmed == "" { return "" } if !strings.HasPrefix(trimmed, "/") { return "" } if len(trimmed) > 1 && (trimmed[1] == '/' || trimmed[1] == '\\') { return "" } parsed, err := url.Parse(trimmed) if err != nil || parsed.IsAbs() || parsed.Host != "" { return "" } return trimmed } func (r *Router) redirectOIDCError(w http.ResponseWriter, req *http.Request, returnTo string, code string) { target := returnTo if target == "" { target = "/" } target = addQueryParam(target, "oidc", "error") if code != "" { target = addQueryParam(target, "oidc_error", code) } http.Redirect(w, req, target, http.StatusFound) } func addQueryParam(path, key, value string) string { if path == "" { path = "/" } u, err := url.Parse(path) if err != nil { return path } q := u.Query() q.Set(key, value) u.RawQuery = q.Encode() return u.String() } func extractStringClaim(claims map[string]any, key string) string { if key == "" { return "" } value, ok := claims[key] if !ok { return "" } switch v := value.(type) { case string: return strings.TrimSpace(v) case []string: if len(v) > 0 { return strings.TrimSpace(v[0]) } case []interface{}: for _, item := range v { if str, ok := item.(string); ok { return strings.TrimSpace(str) } } } return "" } func extractStringSliceClaim(claims map[string]any, key string) []string { if key == "" { return nil } value, ok := claims[key] if !ok { return nil } switch v := value.(type) { case []string: return v case []interface{}: out := make([]string, 0, len(v)) for _, item := range v { if str, ok := item.(string); ok { out = append(out, str) } } return out case string: // Split on commas or spaces parts := strings.FieldsFunc(v, func(r rune) bool { return r == ',' || r == ' ' }) return parts default: return nil } } func matchesValue(candidate string, allowed []string) bool { candidate = strings.ToLower(strings.TrimSpace(candidate)) if candidate == "" { return false } for _, item := range allowed { if strings.ToLower(strings.TrimSpace(item)) == candidate { return true } } return false } func matchesDomain(email string, allowed []string) bool { email = strings.ToLower(strings.TrimSpace(email)) if email == "" { return false } at := strings.LastIndex(email, "@") if at == -1 || at == len(email)-1 { return false } domain := email[at+1:] for _, item := range allowed { normalized := strings.ToLower(strings.Trim(strings.TrimSpace(item), "@")) if normalized != "" && domain == normalized { return true } } return false } func intersects(values []string, allowed []string) bool { if len(values) == 0 || len(allowed) == 0 { return false } allowedSet := make(map[string]struct{}, len(allowed)) for _, item := range allowed { allowedSet[strings.ToLower(strings.TrimSpace(item))] = struct{}{} } for _, val := range values { if _, ok := allowedSet[strings.ToLower(strings.TrimSpace(val))]; ok { return true } } return false } func (r *Router) ensureOIDCConfig() *config.OIDCConfig { if r.config.OIDC == nil { r.config.OIDC = config.NewOIDCConfig() r.config.OIDC.ApplyDefaults(r.config.PublicURL) } return r.config.OIDC } // buildRedirectURL constructs the OIDC redirect URL from the incoming request, // respecting X-Forwarded-* headers when behind a reverse proxy func buildRedirectURL(req *http.Request, configuredURL string) string { // If explicitly configured, use that if configured := strings.TrimSpace(configuredURL); configured != "" { return configured } // Build from request headers (respects reverse proxy headers) scheme := "http" if req.TLS != nil { scheme = "https" } peerIP := extractRemoteIP(req.RemoteAddr) trustedProxy := isTrustedProxyIP(peerIP) if trustedProxy { if proto := firstForwardedValue(req.Header.Get("X-Forwarded-Proto")); proto != "" { scheme = proto } else if proto := firstForwardedValue(req.Header.Get("X-Forwarded-Scheme")); proto != "" { scheme = proto } } scheme = strings.ToLower(strings.TrimSpace(scheme)) switch scheme { case "https", "http": default: if req.TLS != nil { scheme = "https" } else { scheme = "http" } } rawHost := "" if trustedProxy { rawHost = firstForwardedValue(req.Header.Get("X-Forwarded-Host")) } if rawHost == "" { rawHost = req.Host } host, _ := sanitizeForwardedHost(rawHost) if host == "" { host = req.Host } redirectURL := fmt.Sprintf("%s://%s%s", scheme, host, config.DefaultOIDCCallbackPath) log.Debug(). Str("scheme", scheme). Str("host", host). Str("x_forwarded_proto", req.Header.Get("X-Forwarded-Proto")). Str("x_forwarded_host", req.Header.Get("X-Forwarded-Host")). Bool("trusted_proxy", trustedProxy). Str("peer_ip", peerIP). Str("redirect_url", redirectURL). Bool("has_tls", req.TLS != nil). Msg("Built OIDC redirect URL from request") return redirectURL }