mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-29 20:10:21 +00:00
Two remaining issues from #1255 after the 5.1.10 fixes: 1. OIDC/SAML provider edit fields appeared blank because the GET endpoint returned a flattened response while the frontend reads nested oidc/saml objects. Now returns the full provider config with secrets redacted (client secret, SP private key). 2. SSO users didn't appear in Settings > Users because RBAC entries were only created when group-role mappings matched. Now ensures every SSO user is registered in RBAC on login, even without role mappings. Also fixes: SAML SP private key and certificate lost on edit (no preservation logic existed), OIDC client secret preservation hardened to check actual secret presence not just flag.
1143 lines
35 KiB
Go
1143 lines
35 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/crewjam/saml"
|
|
"github.com/google/uuid"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// Upper bounds applied to SSO provider input before it is accepted.
const (
	// maxProviderIDLength caps the byte length of a provider ID.
	maxProviderIDLength = 64
	// maxProviderNameLength caps the byte length of a provider (display) name.
	maxProviderNameLength = 128
	// maxURLLength caps the byte length of any URL passed to validateURL.
	maxURLLength = 2048
	// maxRequestBodySize caps how much of a create/update body is read (1 MiB).
	maxRequestBodySize = 1024 * 1024
)

// providerIDRegex describes a well-formed provider ID: an ASCII letter or
// digit, followed by up to 62 letters, digits, underscores or hyphens, with an
// optional closing letter or digit (64 characters at most).
var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}[a-zA-Z0-9]?$`)
|
|
|
|
// validateProviderID checks if a provider ID is safe and valid
|
|
func validateProviderID(id string) bool {
|
|
if id == "" || len(id) > maxProviderIDLength {
|
|
return false
|
|
}
|
|
return providerIDRegex.MatchString(id)
|
|
}
|
|
|
|
// sanitizeProviderName sanitizes a provider name
|
|
func sanitizeProviderName(name string) string {
|
|
name = strings.TrimSpace(name)
|
|
if len(name) > maxProviderNameLength {
|
|
name = name[:maxProviderNameLength]
|
|
}
|
|
// Remove control characters
|
|
name = strings.Map(func(r rune) rune {
|
|
if r < 32 || r == 127 {
|
|
return -1
|
|
}
|
|
return r
|
|
}, name)
|
|
return name
|
|
}
|
|
|
|
// validateURL checks if a URL is valid and uses an allowed scheme
|
|
func validateURL(urlStr string, allowedSchemes []string) bool {
|
|
if urlStr == "" || len(urlStr) > maxURLLength {
|
|
return false
|
|
}
|
|
parsed, err := url.ParseRequestURI(urlStr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
for _, scheme := range allowedSchemes {
|
|
if strings.EqualFold(parsed.Scheme, scheme) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// SSOProviderResponse is the flattened view of an SSO provider that
// providerToResponse builds for JSON responses. It has no field for the OIDC
// client secret; only a flag saying whether one is configured.
type SSOProviderResponse struct {
	// Identity and presentation.
	ID          string `json:"id"`
	Name        string `json:"name"`
	Type        string `json:"type"`
	Enabled     bool   `json:"enabled"`
	DisplayName string `json:"displayName,omitempty"`
	IconURL     string `json:"iconUrl,omitempty"`
	Priority    int    `json:"priority"`

	// OIDC fields, filled in only for OIDC providers.
	OIDCIssuerURL       string `json:"oidcIssuerUrl,omitempty"`
	OIDCClientID        string `json:"oidcClientId,omitempty"`
	OIDCClientSecretSet bool   `json:"oidcClientSecretSet,omitempty"`

	// SAML fields, filled in only for SAML providers.
	SAMLIDPEntityID string `json:"samlIdpEntityId,omitempty"`
	SAMLSPEntityID  string `json:"samlSpEntityId,omitempty"`
	SAMLMetadataURL string `json:"samlMetadataUrl,omitempty"`
	SAMLACSUrl      string `json:"samlAcsUrl,omitempty"`

	// Access restrictions shared by both provider types.
	AllowedGroups  []string `json:"allowedGroups,omitempty"`
	AllowedDomains []string `json:"allowedDomains,omitempty"`
	AllowedEmails  []string `json:"allowedEmails,omitempty"`
}
|
|
|
|
// SSOProvidersListResponse represents the list of SSO providers
|
|
type SSOProvidersListResponse struct {
|
|
Providers []SSOProviderResponse `json:"providers"`
|
|
DefaultProviderID string `json:"defaultProviderId,omitempty"`
|
|
AllowMultipleProviders bool `json:"allowMultipleProviders"`
|
|
}
|
|
|
|
// handleSSOProviders handles listing and creating SSO providers
|
|
func (r *Router) handleSSOProviders(w http.ResponseWriter, req *http.Request) {
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
r.handleListSSOProviders(w, req)
|
|
case http.MethodPost:
|
|
r.handleCreateSSOProvider(w, req)
|
|
default:
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
}
|
|
}
|
|
|
|
// handleSSOProvider handles getting, updating, and deleting a specific SSO provider
|
|
func (r *Router) handleSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
// Extract provider ID from path: /api/security/sso/providers/{id}
|
|
providerID := strings.TrimPrefix(req.URL.Path, "/api/security/sso/providers/")
|
|
providerID = strings.TrimSuffix(providerID, "/")
|
|
|
|
if providerID == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "missing_id", "Provider ID is required", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate provider ID format to prevent injection attacks
|
|
if !validateProviderID(providerID) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_id", "Invalid provider ID format", nil)
|
|
return
|
|
}
|
|
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
r.handleGetSSOProvider(w, req, providerID)
|
|
case http.MethodPut:
|
|
r.handleUpdateSSOProvider(w, req, providerID)
|
|
case http.MethodDelete:
|
|
r.handleDeleteSSOProvider(w, req, providerID)
|
|
default:
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
}
|
|
}
|
|
|
|
func (r *Router) handleListSSOProviders(w http.ResponseWriter, req *http.Request) {
|
|
if r.ssoConfig == nil {
|
|
r.ssoConfig = config.NewSSOConfig()
|
|
}
|
|
|
|
response := SSOProvidersListResponse{
|
|
Providers: make([]SSOProviderResponse, 0),
|
|
DefaultProviderID: r.ssoConfig.DefaultProviderID,
|
|
AllowMultipleProviders: r.ssoConfig.AllowMultipleProviders,
|
|
}
|
|
|
|
for _, p := range r.ssoConfig.Providers {
|
|
response.Providers = append(response.Providers, providerToResponse(&p, r.config.PublicURL))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func (r *Router) handleGetSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
provider := r.ssoConfig.GetProvider(providerID)
|
|
if provider == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
// Return the full provider config so the edit form can populate all fields.
|
|
// We make a shallow copy and redact sensitive secrets.
|
|
safe := *provider
|
|
if safe.OIDC != nil {
|
|
oidcCopy := *safe.OIDC
|
|
oidcCopy.ClientSecretSet = oidcCopy.ClientSecret != "" || oidcCopy.ClientSecretSet
|
|
oidcCopy.ClientSecret = "" // Never expose the secret
|
|
safe.OIDC = &oidcCopy
|
|
}
|
|
if safe.SAML != nil {
|
|
samlCopy := *safe.SAML
|
|
samlCopy.SPPrivateKey = "" // Never expose private key
|
|
safe.SAML = &samlCopy
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(safe)
|
|
}
|
|
|
|
func (r *Router) handleCreateSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var provider config.SSOProvider
|
|
if err := json.Unmarshal(body, &provider); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Generate ID if not provided
|
|
if provider.ID == "" {
|
|
provider.ID = uuid.NewString()
|
|
}
|
|
|
|
// Security: Validate provider ID format
|
|
if !validateProviderID(provider.ID) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid provider ID format", nil)
|
|
return
|
|
}
|
|
|
|
// Sanitize provider name
|
|
provider.Name = sanitizeProviderName(provider.Name)
|
|
provider.DisplayName = sanitizeProviderName(provider.DisplayName)
|
|
|
|
// Validate provider
|
|
if provider.Name == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider name is required", nil)
|
|
return
|
|
}
|
|
|
|
if provider.Type != config.SSOProviderTypeOIDC && provider.Type != config.SSOProviderTypeSAML {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider type must be 'oidc' or 'saml'", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate OIDC configuration
|
|
if provider.Type == config.SSOProviderTypeOIDC {
|
|
if provider.OIDC == nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "OIDC configuration is required", nil)
|
|
return
|
|
}
|
|
if provider.OIDC.IssuerURL != "" && !validateURL(provider.OIDC.IssuerURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC issuer URL", nil)
|
|
return
|
|
}
|
|
if provider.OIDC.RedirectURL != "" && !validateURL(provider.OIDC.RedirectURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC redirect URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate SAML configuration
|
|
if provider.Type == config.SSOProviderTypeSAML {
|
|
if provider.SAML == nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "SAML configuration is required", nil)
|
|
return
|
|
}
|
|
if provider.SAML.IDPMetadataURL != "" && !validateURL(provider.SAML.IDPMetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML metadata URL", nil)
|
|
return
|
|
}
|
|
if provider.SAML.IDPSSOURL != "" && !validateURL(provider.SAML.IDPSSOURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML SSO URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate icon URL if provided
|
|
if provider.IconURL != "" && !validateURL(provider.IconURL, []string{"https", "http", "data"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid icon URL", nil)
|
|
return
|
|
}
|
|
|
|
if r.ssoConfig == nil {
|
|
r.ssoConfig = config.NewSSOConfig()
|
|
}
|
|
|
|
// Check for duplicate ID
|
|
if r.ssoConfig.GetProvider(provider.ID) != nil {
|
|
writeErrorResponse(w, http.StatusConflict, "duplicate_id", "Provider with this ID already exists", nil)
|
|
return
|
|
}
|
|
|
|
// Track whether secrets are configured (so updates can preserve them)
|
|
if provider.OIDC != nil && provider.OIDC.ClientSecret != "" {
|
|
provider.OIDC.ClientSecretSet = true
|
|
}
|
|
|
|
// Add provider
|
|
if err := r.ssoConfig.AddProvider(provider); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "add_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Remove the provider we just added since persistence failed
|
|
r.ssoConfig.RemoveProvider(provider.ID)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Initialize SAML provider if applicable
|
|
if provider.Type == config.SSOProviderTypeSAML && provider.Enabled && provider.SAML != nil {
|
|
if err := r.samlManager.InitializeProvider(req.Context(), provider.ID, provider.SAML); err != nil {
|
|
log.Warn().Err(err).Str("provider_id", provider.ID).Msg("Failed to initialize SAML provider (will retry on first use)")
|
|
}
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_created", "", GetClientIP(req), req.URL.Path, true, "Created provider: "+provider.Name)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(providerToResponse(&provider, r.config.PublicURL))
|
|
}
|
|
|
|
func (r *Router) handleUpdateSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
existing := r.ssoConfig.GetProvider(providerID)
|
|
if existing == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var updated config.SSOProvider
|
|
if err := json.Unmarshal(body, &updated); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Ensure ID matches
|
|
updated.ID = providerID
|
|
|
|
// Sanitize inputs
|
|
updated.Name = sanitizeProviderName(updated.Name)
|
|
updated.DisplayName = sanitizeProviderName(updated.DisplayName)
|
|
|
|
if updated.Name == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider name is required", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate URLs for OIDC
|
|
if updated.Type == config.SSOProviderTypeOIDC && updated.OIDC != nil {
|
|
if updated.OIDC.IssuerURL != "" && !validateURL(updated.OIDC.IssuerURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC issuer URL", nil)
|
|
return
|
|
}
|
|
if updated.OIDC.RedirectURL != "" && !validateURL(updated.OIDC.RedirectURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC redirect URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate URLs for SAML
|
|
if updated.Type == config.SSOProviderTypeSAML && updated.SAML != nil {
|
|
if updated.SAML.IDPMetadataURL != "" && !validateURL(updated.SAML.IDPMetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML metadata URL", nil)
|
|
return
|
|
}
|
|
if updated.SAML.IDPSSOURL != "" && !validateURL(updated.SAML.IDPSSOURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML SSO URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate icon URL if provided
|
|
if updated.IconURL != "" && !validateURL(updated.IconURL, []string{"https", "http", "data"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid icon URL", nil)
|
|
return
|
|
}
|
|
|
|
// Preserve provider-specific configs when the update payload omits nested objects
|
|
// (e.g. when the frontend sends back the flat list-response shape).
|
|
if updated.OIDC == nil && existing.OIDC != nil {
|
|
oidcCopy := *existing.OIDC
|
|
updated.OIDC = &oidcCopy
|
|
}
|
|
if updated.SAML == nil && existing.SAML != nil {
|
|
samlCopy := *existing.SAML
|
|
updated.SAML = &samlCopy
|
|
}
|
|
|
|
// Preserve secrets if not provided in update
|
|
if updated.Type == config.SSOProviderTypeOIDC && updated.OIDC != nil && existing.OIDC != nil {
|
|
if updated.OIDC.ClientSecret == "" && (existing.OIDC.ClientSecretSet || existing.OIDC.ClientSecret != "") {
|
|
updated.OIDC.ClientSecret = existing.OIDC.ClientSecret
|
|
updated.OIDC.ClientSecretSet = true
|
|
} else if updated.OIDC.ClientSecret != "" {
|
|
updated.OIDC.ClientSecretSet = true
|
|
}
|
|
}
|
|
if updated.Type == config.SSOProviderTypeSAML && updated.SAML != nil && existing.SAML != nil {
|
|
if updated.SAML.SPPrivateKey == "" && existing.SAML.SPPrivateKey != "" {
|
|
updated.SAML.SPPrivateKey = existing.SAML.SPPrivateKey
|
|
}
|
|
if updated.SAML.SPCertificate == "" && existing.SAML.SPCertificate != "" {
|
|
updated.SAML.SPCertificate = existing.SAML.SPCertificate
|
|
}
|
|
}
|
|
|
|
// Update provider
|
|
if err := r.ssoConfig.UpdateProvider(updated); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "update_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Revert to existing provider
|
|
r.ssoConfig.UpdateProvider(*existing)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Re-initialize SAML provider if applicable
|
|
if updated.Type == config.SSOProviderTypeSAML && updated.SAML != nil {
|
|
if updated.Enabled {
|
|
if err := r.samlManager.InitializeProvider(req.Context(), updated.ID, updated.SAML); err != nil {
|
|
log.Warn().Err(err).Str("provider_id", updated.ID).Msg("Failed to re-initialize SAML provider")
|
|
}
|
|
} else {
|
|
r.samlManager.RemoveProvider(updated.ID)
|
|
}
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_updated", "", GetClientIP(req), req.URL.Path, true, "Updated provider: "+updated.Name)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(providerToResponse(&updated, r.config.PublicURL))
|
|
}
|
|
|
|
func (r *Router) handleDeleteSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
existing := r.ssoConfig.GetProvider(providerID)
|
|
if existing == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
// Remove provider
|
|
if err := r.ssoConfig.RemoveProvider(providerID); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "remove_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Re-add the provider since persistence failed
|
|
r.ssoConfig.AddProvider(*existing)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Remove SAML service if applicable
|
|
if existing.Type == config.SSOProviderTypeSAML {
|
|
r.samlManager.RemoveProvider(providerID)
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_deleted", "", GetClientIP(req), req.URL.Path, true, "Deleted provider: "+existing.Name)
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (r *Router) saveSSOConfig() error {
|
|
if r.persistence == nil {
|
|
return nil
|
|
}
|
|
return r.persistence.SaveSSOConfig(r.ssoConfig)
|
|
}
|
|
|
|
func providerToResponse(p *config.SSOProvider, publicURL string) SSOProviderResponse {
|
|
resp := SSOProviderResponse{
|
|
ID: p.ID,
|
|
Name: p.Name,
|
|
Type: string(p.Type),
|
|
Enabled: p.Enabled,
|
|
DisplayName: p.DisplayName,
|
|
IconURL: p.IconURL,
|
|
Priority: p.Priority,
|
|
AllowedGroups: p.AllowedGroups,
|
|
AllowedDomains: p.AllowedDomains,
|
|
AllowedEmails: p.AllowedEmails,
|
|
}
|
|
|
|
if resp.DisplayName == "" {
|
|
resp.DisplayName = p.Name
|
|
}
|
|
|
|
baseURL := publicURL
|
|
if baseURL == "" {
|
|
baseURL = "http://localhost:7655"
|
|
}
|
|
|
|
if p.Type == config.SSOProviderTypeOIDC && p.OIDC != nil {
|
|
resp.OIDCIssuerURL = p.OIDC.IssuerURL
|
|
resp.OIDCClientID = p.OIDC.ClientID
|
|
resp.OIDCClientSecretSet = p.OIDC.ClientSecretSet || p.OIDC.ClientSecret != ""
|
|
}
|
|
|
|
if p.Type == config.SSOProviderTypeSAML && p.SAML != nil {
|
|
resp.SAMLIDPEntityID = p.SAML.IDPEntityID
|
|
if resp.SAMLIDPEntityID == "" {
|
|
resp.SAMLIDPEntityID = p.SAML.IDPIssuer
|
|
}
|
|
resp.SAMLSPEntityID = p.SAML.SPEntityID
|
|
if resp.SAMLSPEntityID == "" {
|
|
resp.SAMLSPEntityID = baseURL + "/saml/" + p.ID
|
|
}
|
|
resp.SAMLMetadataURL = baseURL + "/api/saml/" + p.ID + "/metadata"
|
|
resp.SAMLACSUrl = baseURL + "/api/saml/" + p.ID + "/acs"
|
|
}
|
|
|
|
return resp
|
|
}
|
|
|
|
// ============================================================================
|
|
// SSO Provider Connection Testing
|
|
// ============================================================================
|
|
|
|
// Limits for the SSO connection-test and metadata-preview endpoints.
const (
	// maxTestRequestBodySize caps how much of a test/preview body is read (32 KiB).
	maxTestRequestBodySize = 32 << 10
	// testConnectionTimeout bounds the outbound work done for one test/preview request.
	testConnectionTimeout = 30 * time.Second
)
|
|
|
|
// SSOTestRequest represents a request to test SSO provider configuration
|
|
type SSOTestRequest struct {
|
|
Type string `json:"type"` // "saml" or "oidc"
|
|
SAML *SAMLTestConfig `json:"saml,omitempty"`
|
|
OIDC *OIDCTestConfig `json:"oidc,omitempty"`
|
|
}
|
|
|
|
// SAMLTestConfig carries the SAML settings to test. The test uses the metadata
// URL when present, otherwise the inline metadata XML, otherwise the SSO URL.
type SAMLTestConfig struct {
	// IDPMetadataURL is where the IdP metadata document is fetched from.
	IDPMetadataURL string `json:"idpMetadataUrl,omitempty"`
	// IDPMetadataXML is an IdP metadata document supplied inline.
	IDPMetadataXML string `json:"idpMetadataXml,omitempty"`
	// IDPSSOURL is the sign-on URL for a manually configured IdP.
	IDPSSOURL string `json:"idpSsoUrl,omitempty"`
	// IDPCertificate is accepted in the payload; the code in view does not read it.
	IDPCertificate string `json:"idpCertificate,omitempty"`
}
|
|
|
|
// OIDCTestConfig carries the OIDC settings to test.
type OIDCTestConfig struct {
	// IssuerURL is the issuer whose discovery document is fetched; required.
	IssuerURL string `json:"issuerUrl"`
	// ClientID is accepted in the payload; the discovery test does not read it.
	ClientID string `json:"clientId,omitempty"`
}
|
|
|
|
// SSOTestResponse represents the result of a connection test
|
|
type SSOTestResponse struct {
|
|
Success bool `json:"success"`
|
|
Message string `json:"message"`
|
|
Error string `json:"error,omitempty"`
|
|
Details *SSOTestDetails `json:"details,omitempty"`
|
|
}
|
|
|
|
// SSOTestDetails contains detailed information about the tested provider
|
|
type SSOTestDetails struct {
|
|
Type string `json:"type"`
|
|
EntityID string `json:"entityId,omitempty"`
|
|
SSOURL string `json:"ssoUrl,omitempty"`
|
|
SLOURL string `json:"sloUrl,omitempty"`
|
|
Certificates []CertificateInfo `json:"certificates,omitempty"`
|
|
// OIDC-specific
|
|
TokenEndpoint string `json:"tokenEndpoint,omitempty"`
|
|
UserinfoEndpoint string `json:"userinfoEndpoint,omitempty"`
|
|
JwksURI string `json:"jwksUri,omitempty"`
|
|
SupportedScopes []string `json:"supportedScopes,omitempty"`
|
|
}
|
|
|
|
// CertificateInfo summarizes one X.509 certificate for display.
type CertificateInfo struct {
	// Subject and Issuer are the certificate's subject and issuer names.
	Subject string `json:"subject"`
	Issuer  string `json:"issuer"`
	// NotBefore and NotAfter delimit the validity period.
	NotBefore time.Time `json:"notBefore"`
	NotAfter  time.Time `json:"notAfter"`
	// IsExpired flags a certificate as expired; it is set by whoever builds the value.
	IsExpired bool `json:"isExpired"`
}
|
|
|
|
// handleTestSSOProvider tests an SSO provider configuration
|
|
func (r *Router) handleTestSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
clientIP := GetClientIP(req)
|
|
if !authLimiter.Allow(clientIP) {
|
|
writeErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxTestRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var testReq SSOTestRequest
|
|
if err := json.Unmarshal(body, &testReq); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Validate request
|
|
if testReq.Type != "saml" && testReq.Type != "oidc" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Type must be 'saml' or 'oidc'", nil)
|
|
return
|
|
}
|
|
|
|
var response SSOTestResponse
|
|
|
|
ctx, cancel := context.WithTimeout(req.Context(), testConnectionTimeout)
|
|
defer cancel()
|
|
|
|
switch testReq.Type {
|
|
case "saml":
|
|
response = r.testSAMLConnection(ctx, testReq.SAML)
|
|
case "oidc":
|
|
response = r.testOIDCConnection(ctx, testReq.OIDC)
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_test", "", clientIP, req.URL.Path, response.Success,
|
|
"Tested "+testReq.Type+" provider connection")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if response.Success {
|
|
w.WriteHeader(http.StatusOK)
|
|
} else {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func (r *Router) testSAMLConnection(ctx context.Context, cfg *SAMLTestConfig) SSOTestResponse {
|
|
if cfg == nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "SAML configuration is required",
|
|
Error: "missing_config",
|
|
}
|
|
}
|
|
|
|
// Need at least one source of metadata
|
|
if cfg.IDPMetadataURL == "" && cfg.IDPMetadataXML == "" && cfg.IDPSSOURL == "" {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Provide IdP Metadata URL, XML, or SSO URL",
|
|
Error: "missing_metadata",
|
|
}
|
|
}
|
|
|
|
// Validate URLs
|
|
if cfg.IDPMetadataURL != "" && !validateURL(cfg.IDPMetadataURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid metadata URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
|
|
var metadata *saml.EntityDescriptor
|
|
var rawXML []byte
|
|
var err error
|
|
|
|
httpClient := newTestHTTPClient()
|
|
|
|
if cfg.IDPMetadataURL != "" {
|
|
rawXML, metadata, err = fetchSAMLMetadataFromURL(ctx, httpClient, cfg.IDPMetadataURL)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to fetch metadata from URL",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
} else if cfg.IDPMetadataXML != "" {
|
|
rawXML = []byte(cfg.IDPMetadataXML)
|
|
metadata, err = parseSAMLMetadataXML(rawXML)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to parse metadata XML",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
} else {
|
|
// Manual configuration - just validate the SSO URL
|
|
if !validateURL(cfg.IDPSSOURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid SSO URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "SSO URL is valid (manual configuration)",
|
|
Details: &SSOTestDetails{
|
|
Type: "saml",
|
|
SSOURL: cfg.IDPSSOURL,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Extract details from metadata
|
|
details := &SSOTestDetails{
|
|
Type: "saml",
|
|
EntityID: metadata.EntityID,
|
|
}
|
|
|
|
// Extract SSO URL
|
|
if len(metadata.IDPSSODescriptors) > 0 {
|
|
idpDesc := metadata.IDPSSODescriptors[0]
|
|
for _, sso := range idpDesc.SingleSignOnServices {
|
|
if sso.Binding == saml.HTTPPostBinding || sso.Binding == saml.HTTPRedirectBinding {
|
|
details.SSOURL = sso.Location
|
|
break
|
|
}
|
|
}
|
|
// Extract SLO URL
|
|
for _, slo := range idpDesc.SingleLogoutServices {
|
|
details.SLOURL = slo.Location
|
|
break
|
|
}
|
|
// Extract certificates
|
|
for _, kd := range idpDesc.KeyDescriptors {
|
|
if kd.Use == "signing" || kd.Use == "" {
|
|
for _, x509Cert := range kd.KeyInfo.X509Data.X509Certificates {
|
|
certInfo := extractCertificateInfo(x509Cert.Data)
|
|
if certInfo != nil {
|
|
details.Certificates = append(details.Certificates, *certInfo)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
_ = rawXML // Used for metadata preview endpoint
|
|
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "SAML metadata validated successfully",
|
|
Details: details,
|
|
}
|
|
}
|
|
|
|
func (r *Router) testOIDCConnection(ctx context.Context, cfg *OIDCTestConfig) SSOTestResponse {
|
|
if cfg == nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "OIDC configuration is required",
|
|
Error: "missing_config",
|
|
}
|
|
}
|
|
|
|
if cfg.IssuerURL == "" {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Issuer URL is required",
|
|
Error: "missing_issuer",
|
|
}
|
|
}
|
|
|
|
if !validateURL(cfg.IssuerURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid issuer URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
|
|
// Fetch OIDC discovery document
|
|
discoveryURL := strings.TrimRight(cfg.IssuerURL, "/") + "/.well-known/openid-configuration"
|
|
|
|
httpClient := newTestHTTPClient()
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to create discovery request",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to fetch OIDC discovery document",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "OIDC discovery returned non-200 status",
|
|
Error: resp.Status,
|
|
}
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to read discovery response",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
var discovery struct {
|
|
Issuer string `json:"issuer"`
|
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
|
TokenEndpoint string `json:"token_endpoint"`
|
|
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
|
JwksURI string `json:"jwks_uri"`
|
|
ScopesSupported []string `json:"scopes_supported"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &discovery); err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to parse discovery document",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
// Validate issuer matches
|
|
if discovery.Issuer != cfg.IssuerURL && discovery.Issuer != strings.TrimRight(cfg.IssuerURL, "/") {
|
|
log.Warn().
|
|
Str("expected", cfg.IssuerURL).
|
|
Str("actual", discovery.Issuer).
|
|
Msg("OIDC issuer mismatch - this may cause token validation issues")
|
|
}
|
|
|
|
details := &SSOTestDetails{
|
|
Type: "oidc",
|
|
EntityID: discovery.Issuer,
|
|
TokenEndpoint: discovery.TokenEndpoint,
|
|
UserinfoEndpoint: discovery.UserinfoEndpoint,
|
|
JwksURI: discovery.JwksURI,
|
|
SupportedScopes: discovery.ScopesSupported,
|
|
}
|
|
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "OIDC discovery successful",
|
|
Details: details,
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SSO Metadata Preview
|
|
// ============================================================================
|
|
|
|
// MetadataPreviewRequest is the JSON body accepted by the metadata preview
// endpoint. Exactly one source is used: the URL when present, else the XML.
type MetadataPreviewRequest struct {
	// Type must be "saml"; no other type is supported.
	Type string `json:"type"`
	// MetadataURL is where the IdP metadata document is fetched from.
	MetadataURL string `json:"metadataUrl,omitempty"`
	// MetadataXML is an IdP metadata document supplied inline.
	MetadataXML string `json:"metadataXml,omitempty"`
}
|
|
|
|
// MetadataPreviewResponse contains the metadata preview
|
|
type MetadataPreviewResponse struct {
|
|
XML string `json:"xml"`
|
|
Parsed *ParsedMetadataInfo `json:"parsed"`
|
|
}
|
|
|
|
// ParsedMetadataInfo contains parsed metadata information
|
|
type ParsedMetadataInfo struct {
|
|
EntityID string `json:"entityId"`
|
|
SSOURL string `json:"ssoUrl,omitempty"`
|
|
SLOURL string `json:"sloUrl,omitempty"`
|
|
Certificates []CertificateInfo `json:"certificates,omitempty"`
|
|
NameIDFormats []string `json:"nameIdFormats,omitempty"`
|
|
}
|
|
|
|
// handleMetadataPreview fetches and displays IdP metadata
|
|
func (r *Router) handleMetadataPreview(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
clientIP := GetClientIP(req)
|
|
if !authLimiter.Allow(clientIP) {
|
|
writeErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxTestRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var previewReq MetadataPreviewRequest
|
|
if err := json.Unmarshal(body, &previewReq); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
if previewReq.Type != "saml" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Only SAML metadata preview is supported", nil)
|
|
return
|
|
}
|
|
|
|
if previewReq.MetadataURL == "" && previewReq.MetadataXML == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provide either metadataUrl or metadataXml", nil)
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(req.Context(), testConnectionTimeout)
|
|
defer cancel()
|
|
|
|
var rawXML []byte
|
|
var metadata *saml.EntityDescriptor
|
|
|
|
httpClient := newTestHTTPClient()
|
|
|
|
if previewReq.MetadataURL != "" {
|
|
if !validateURL(previewReq.MetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid metadata URL", nil)
|
|
return
|
|
}
|
|
rawXML, metadata, err = fetchSAMLMetadataFromURL(ctx, httpClient, previewReq.MetadataURL)
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "fetch_error", "Failed to fetch metadata: "+err.Error(), nil)
|
|
return
|
|
}
|
|
} else {
|
|
rawXML = []byte(previewReq.MetadataXML)
|
|
metadata, err = parseSAMLMetadataXML(rawXML)
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "parse_error", "Failed to parse metadata: "+err.Error(), nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Build parsed info
|
|
parsed := &ParsedMetadataInfo{
|
|
EntityID: metadata.EntityID,
|
|
}
|
|
|
|
if len(metadata.IDPSSODescriptors) > 0 {
|
|
idpDesc := metadata.IDPSSODescriptors[0]
|
|
|
|
// Extract SSO URL
|
|
for _, sso := range idpDesc.SingleSignOnServices {
|
|
if sso.Binding == saml.HTTPPostBinding || sso.Binding == saml.HTTPRedirectBinding {
|
|
parsed.SSOURL = sso.Location
|
|
break
|
|
}
|
|
}
|
|
|
|
// Extract SLO URL
|
|
for _, slo := range idpDesc.SingleLogoutServices {
|
|
parsed.SLOURL = slo.Location
|
|
break
|
|
}
|
|
|
|
// Extract NameID formats
|
|
for _, nid := range idpDesc.NameIDFormats {
|
|
parsed.NameIDFormats = append(parsed.NameIDFormats, string(nid))
|
|
}
|
|
|
|
// Extract certificates
|
|
for _, kd := range idpDesc.KeyDescriptors {
|
|
if kd.Use == "signing" || kd.Use == "" {
|
|
for _, x509Cert := range kd.KeyInfo.X509Data.X509Certificates {
|
|
certInfo := extractCertificateInfo(x509Cert.Data)
|
|
if certInfo != nil {
|
|
parsed.Certificates = append(parsed.Certificates, *certInfo)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Format XML for display
|
|
formattedXML := formatXML(rawXML)
|
|
|
|
response := MetadataPreviewResponse{
|
|
XML: formattedXML,
|
|
Parsed: parsed,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
// ============================================================================
|
|
// Helper Functions
|
|
// ============================================================================
|
|
|
|
func newTestHTTPClient() *http.Client {
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.TLSClientConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
return &http.Client{
|
|
Transport: transport,
|
|
Timeout: testConnectionTimeout,
|
|
}
|
|
}
|
|
|
|
func fetchSAMLMetadataFromURL(ctx context.Context, client *http.Client, metadataURL string) ([]byte, *saml.EntityDescriptor, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, nil, fmt.Errorf("metadata request returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
metadata, err := parseSAMLMetadataXML(body)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return body, metadata, nil
|
|
}
|
|
|
|
func parseSAMLMetadataXML(data []byte) (*saml.EntityDescriptor, error) {
|
|
var metadata saml.EntityDescriptor
|
|
if err := xml.Unmarshal(data, &metadata); err != nil {
|
|
// Try parsing as EntitiesDescriptor
|
|
var entities saml.EntitiesDescriptor
|
|
if err2 := xml.Unmarshal(data, &entities); err2 != nil {
|
|
return nil, fmt.Errorf("failed to parse metadata: %v", err)
|
|
}
|
|
if len(entities.EntityDescriptors) == 0 {
|
|
return nil, fmt.Errorf("no entity descriptors found in metadata")
|
|
}
|
|
metadata = entities.EntityDescriptors[0]
|
|
}
|
|
return &metadata, nil
|
|
}
|
|
|
|
func extractCertificateInfo(certData string) *CertificateInfo {
|
|
// Remove whitespace and decode base64
|
|
certData = strings.ReplaceAll(certData, "\n", "")
|
|
certData = strings.ReplaceAll(certData, "\r", "")
|
|
certData = strings.ReplaceAll(certData, " ", "")
|
|
|
|
// Try to decode as PEM first
|
|
var derBytes []byte
|
|
block, _ := pem.Decode([]byte(certData))
|
|
if block != nil {
|
|
derBytes = block.Bytes
|
|
} else {
|
|
// Assume it's base64 encoded DER
|
|
var err error
|
|
derBytes, err = base64Decode(certData)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(derBytes)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return &CertificateInfo{
|
|
Subject: cert.Subject.String(),
|
|
Issuer: cert.Issuer.String(),
|
|
NotBefore: cert.NotBefore,
|
|
NotAfter: cert.NotAfter,
|
|
IsExpired: time.Now().After(cert.NotAfter),
|
|
}
|
|
}
|
|
|
|
// base64Decode decodes s using the standard base64 alphabet and, if that
// fails, retries with the URL-safe alphabet. When both attempts fail, the
// result of the URL-safe attempt is returned.
func base64Decode(s string) ([]byte, error) {
	if out, err := base64.StdEncoding.DecodeString(s); err == nil {
		return out, nil
	}
	// Standard alphabet rejected the input; fall back to URL-safe.
	return base64.URLEncoding.DecodeString(s)
}
|
|
|
|
// formatXML returns data re-indented for display.
//
// Element and attribute names keep the prefixes they have in the input, and
// whitespace-only text between tags is dropped so that input which is already
// indented is not indented twice. If data cannot be tokenised, has
// mismatched or unclosed tags, or produces no output, the original text is
// returned unchanged.
func formatXML(data []byte) string {
	var buf strings.Builder
	decoder := xml.NewDecoder(strings.NewReader(string(data)))
	encoder := xml.NewEncoder(&buf)
	encoder.Indent("", " ")

	// qualified folds a raw prefix back into the local name. RawToken reports
	// "md:Foo" as {Space: "md", Local: "Foo"}; handing that to the encoder
	// as-is would emit xmlns="md", so the prefixed form is written verbatim.
	qualified := func(n xml.Name) xml.Name {
		if n.Space == "" {
			return n
		}
		return xml.Name{Local: n.Space + ":" + n.Local}
	}

	depth := 0
	for {
		// RawToken leaves namespace prefixes untranslated; Token would
		// resolve them to URLs and the encoder would then invent new
		// prefixes and xmlns attributes.
		token, err := decoder.RawToken()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Fall back to original
			return string(data)
		}

		switch t := token.(type) {
		case xml.StartElement:
			t.Name = qualified(t.Name)
			for i := range t.Attr {
				t.Attr[i].Name = qualified(t.Attr[i].Name)
			}
			depth++
			token = t
		case xml.EndElement:
			t.Name = qualified(t.Name)
			depth--
			token = t
		case xml.CharData:
			// Existing indentation would otherwise be emitted on top of the
			// encoder's own.
			if strings.TrimSpace(string(t)) == "" {
				continue
			}
		}

		if err := encoder.EncodeToken(token); err != nil {
			// Fall back to original
			return string(data)
		}

		// The encoder does not break the line after a processing instruction
		// such as the XML declaration, so do it explicitly.
		if _, ok := token.(xml.ProcInst); ok {
			if err := encoder.EncodeToken(xml.CharData("\n")); err != nil {
				return string(data)
			}
		}
	}

	// RawToken does not check that tags are balanced; unclosed elements mean
	// the input was not well-formed.
	if depth != 0 {
		return string(data)
	}
	if err := encoder.Flush(); err != nil {
		return string(data)
	}

	if buf.Len() > 0 {
		return buf.String()
	}
	return string(data)
}
|