mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-29 12:00:13 +00:00
Router & Middleware: - Add auth context middleware for user/token extraction - Add tenant middleware with authorization checking - Refactor middleware chain ordering for proper isolation - Add router helpers for common patterns Authentication & SSO: - Enhance auth with tenant-aware context - Update OIDC, SAML, and SSO handlers for multi-tenant - Add RBAC handler improvements - Add security enhancements New Test Coverage: - API foundation tests - Auth and authorization tests - Router state and general tests - SSO handler CRUD tests - WebSocket isolation tests - Resource handler tests
1102 lines
33 KiB
Go
1102 lines
33 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/crewjam/saml"
|
|
"github.com/google/uuid"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// Input limits enforced on SSO provider API requests.
const (
	// maxProviderIDLength is the longest provider ID accepted, in bytes.
	maxProviderIDLength = 64
	// maxProviderNameLength is the longest provider name kept, in bytes.
	maxProviderNameLength = 128
	// maxURLLength is the longest URL string accepted, in bytes.
	maxURLLength = 2048
	// maxRequestBodySize caps how much of a request body is read (1 MiB).
	maxRequestBodySize = 1024 * 1024
)
|
|
|
|
// providerIDRegex validates provider IDs (alphanumeric, hyphens, underscores)
|
|
var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}[a-zA-Z0-9]?$`)
|
|
|
|
// validateProviderID checks if a provider ID is safe and valid
|
|
func validateProviderID(id string) bool {
|
|
if id == "" || len(id) > maxProviderIDLength {
|
|
return false
|
|
}
|
|
return providerIDRegex.MatchString(id)
|
|
}
|
|
|
|
// sanitizeProviderName sanitizes a provider name
|
|
func sanitizeProviderName(name string) string {
|
|
name = strings.TrimSpace(name)
|
|
if len(name) > maxProviderNameLength {
|
|
name = name[:maxProviderNameLength]
|
|
}
|
|
// Remove control characters
|
|
name = strings.Map(func(r rune) rune {
|
|
if r < 32 || r == 127 {
|
|
return -1
|
|
}
|
|
return r
|
|
}, name)
|
|
return name
|
|
}
|
|
|
|
// validateURL checks if a URL is valid and uses an allowed scheme
|
|
func validateURL(urlStr string, allowedSchemes []string) bool {
|
|
if urlStr == "" || len(urlStr) > maxURLLength {
|
|
return false
|
|
}
|
|
parsed, err := url.ParseRequestURI(urlStr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
for _, scheme := range allowedSchemes {
|
|
if strings.EqualFold(parsed.Scheme, scheme) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// SSOProviderResponse is the JSON shape of a single SSO provider as returned
// by the API. Type-specific fields are omitted when empty.
type SSOProviderResponse struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	Type        string `json:"type"`
	Enabled     bool   `json:"enabled"`
	DisplayName string `json:"displayName,omitempty"`
	IconURL     string `json:"iconUrl,omitempty"`
	Priority    int    `json:"priority"`

	// Fields populated for OIDC providers.
	OIDCIssuerURL string `json:"oidcIssuerUrl,omitempty"`
	OIDCClientID  string `json:"oidcClientId,omitempty"`
	// OIDCClientSecretSet signals that a client secret exists without
	// exposing the secret itself.
	OIDCClientSecretSet bool `json:"oidcClientSecretSet,omitempty"`

	// Fields populated for SAML providers.
	SAMLIDPEntityID string `json:"samlIdpEntityId,omitempty"`
	SAMLSPEntityID  string `json:"samlSpEntityId,omitempty"`
	SAMLMetadataURL string `json:"samlMetadataUrl,omitempty"`
	SAMLACSUrl      string `json:"samlAcsUrl,omitempty"`

	// Access restrictions shared by both provider types.
	AllowedGroups  []string `json:"allowedGroups,omitempty"`
	AllowedDomains []string `json:"allowedDomains,omitempty"`
	AllowedEmails  []string `json:"allowedEmails,omitempty"`
}

// SSOProvidersListResponse is the JSON shape of the provider listing: every
// provider plus the configuration-wide default and multi-provider settings.
type SSOProvidersListResponse struct {
	Providers              []SSOProviderResponse `json:"providers"`
	DefaultProviderID      string                `json:"defaultProviderId,omitempty"`
	AllowMultipleProviders bool                  `json:"allowMultipleProviders"`
}
|
|
|
|
// handleSSOProviders handles listing and creating SSO providers
|
|
func (r *Router) handleSSOProviders(w http.ResponseWriter, req *http.Request) {
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
r.handleListSSOProviders(w, req)
|
|
case http.MethodPost:
|
|
r.handleCreateSSOProvider(w, req)
|
|
default:
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
}
|
|
}
|
|
|
|
// handleSSOProvider handles getting, updating, and deleting a specific SSO provider
|
|
func (r *Router) handleSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
// Extract provider ID from path: /api/security/sso/providers/{id}
|
|
providerID := strings.TrimPrefix(req.URL.Path, "/api/security/sso/providers/")
|
|
providerID = strings.TrimSuffix(providerID, "/")
|
|
|
|
if providerID == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "missing_id", "Provider ID is required", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate provider ID format to prevent injection attacks
|
|
if !validateProviderID(providerID) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_id", "Invalid provider ID format", nil)
|
|
return
|
|
}
|
|
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
r.handleGetSSOProvider(w, req, providerID)
|
|
case http.MethodPut:
|
|
r.handleUpdateSSOProvider(w, req, providerID)
|
|
case http.MethodDelete:
|
|
r.handleDeleteSSOProvider(w, req, providerID)
|
|
default:
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
}
|
|
}
|
|
|
|
func (r *Router) handleListSSOProviders(w http.ResponseWriter, req *http.Request) {
|
|
if r.ssoConfig == nil {
|
|
r.ssoConfig = config.NewSSOConfig()
|
|
}
|
|
|
|
response := SSOProvidersListResponse{
|
|
Providers: make([]SSOProviderResponse, 0),
|
|
DefaultProviderID: r.ssoConfig.DefaultProviderID,
|
|
AllowMultipleProviders: r.ssoConfig.AllowMultipleProviders,
|
|
}
|
|
|
|
for _, p := range r.ssoConfig.Providers {
|
|
response.Providers = append(response.Providers, providerToResponse(&p, r.config.PublicURL))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func (r *Router) handleGetSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
provider := r.ssoConfig.GetProvider(providerID)
|
|
if provider == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(providerToResponse(provider, r.config.PublicURL))
|
|
}
|
|
|
|
func (r *Router) handleCreateSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var provider config.SSOProvider
|
|
if err := json.Unmarshal(body, &provider); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Generate ID if not provided
|
|
if provider.ID == "" {
|
|
provider.ID = uuid.NewString()
|
|
}
|
|
|
|
// Security: Validate provider ID format
|
|
if !validateProviderID(provider.ID) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid provider ID format", nil)
|
|
return
|
|
}
|
|
|
|
// Sanitize provider name
|
|
provider.Name = sanitizeProviderName(provider.Name)
|
|
provider.DisplayName = sanitizeProviderName(provider.DisplayName)
|
|
|
|
// Validate provider
|
|
if provider.Name == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider name is required", nil)
|
|
return
|
|
}
|
|
|
|
if provider.Type != config.SSOProviderTypeOIDC && provider.Type != config.SSOProviderTypeSAML {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider type must be 'oidc' or 'saml'", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate OIDC configuration
|
|
if provider.Type == config.SSOProviderTypeOIDC {
|
|
if provider.OIDC == nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "OIDC configuration is required", nil)
|
|
return
|
|
}
|
|
if provider.OIDC.IssuerURL != "" && !validateURL(provider.OIDC.IssuerURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC issuer URL", nil)
|
|
return
|
|
}
|
|
if provider.OIDC.RedirectURL != "" && !validateURL(provider.OIDC.RedirectURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC redirect URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate SAML configuration
|
|
if provider.Type == config.SSOProviderTypeSAML {
|
|
if provider.SAML == nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "SAML configuration is required", nil)
|
|
return
|
|
}
|
|
if provider.SAML.IDPMetadataURL != "" && !validateURL(provider.SAML.IDPMetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML metadata URL", nil)
|
|
return
|
|
}
|
|
if provider.SAML.IDPSSOURL != "" && !validateURL(provider.SAML.IDPSSOURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML SSO URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate icon URL if provided
|
|
if provider.IconURL != "" && !validateURL(provider.IconURL, []string{"https", "http", "data"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid icon URL", nil)
|
|
return
|
|
}
|
|
|
|
if r.ssoConfig == nil {
|
|
r.ssoConfig = config.NewSSOConfig()
|
|
}
|
|
|
|
// Check for duplicate ID
|
|
if r.ssoConfig.GetProvider(provider.ID) != nil {
|
|
writeErrorResponse(w, http.StatusConflict, "duplicate_id", "Provider with this ID already exists", nil)
|
|
return
|
|
}
|
|
|
|
// Add provider
|
|
if err := r.ssoConfig.AddProvider(provider); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "add_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Remove the provider we just added since persistence failed
|
|
r.ssoConfig.RemoveProvider(provider.ID)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Initialize SAML provider if applicable
|
|
if provider.Type == config.SSOProviderTypeSAML && provider.Enabled && provider.SAML != nil {
|
|
if err := r.samlManager.InitializeProvider(req.Context(), provider.ID, provider.SAML); err != nil {
|
|
log.Warn().Err(err).Str("provider_id", provider.ID).Msg("Failed to initialize SAML provider (will retry on first use)")
|
|
}
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_created", "", GetClientIP(req), req.URL.Path, true, "Created provider: "+provider.Name)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(providerToResponse(&provider, r.config.PublicURL))
|
|
}
|
|
|
|
func (r *Router) handleUpdateSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
existing := r.ssoConfig.GetProvider(providerID)
|
|
if existing == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var updated config.SSOProvider
|
|
if err := json.Unmarshal(body, &updated); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Ensure ID matches
|
|
updated.ID = providerID
|
|
|
|
// Sanitize inputs
|
|
updated.Name = sanitizeProviderName(updated.Name)
|
|
updated.DisplayName = sanitizeProviderName(updated.DisplayName)
|
|
|
|
if updated.Name == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provider name is required", nil)
|
|
return
|
|
}
|
|
|
|
// Security: Validate URLs for OIDC
|
|
if updated.Type == config.SSOProviderTypeOIDC && updated.OIDC != nil {
|
|
if updated.OIDC.IssuerURL != "" && !validateURL(updated.OIDC.IssuerURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC issuer URL", nil)
|
|
return
|
|
}
|
|
if updated.OIDC.RedirectURL != "" && !validateURL(updated.OIDC.RedirectURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid OIDC redirect URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate URLs for SAML
|
|
if updated.Type == config.SSOProviderTypeSAML && updated.SAML != nil {
|
|
if updated.SAML.IDPMetadataURL != "" && !validateURL(updated.SAML.IDPMetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML metadata URL", nil)
|
|
return
|
|
}
|
|
if updated.SAML.IDPSSOURL != "" && !validateURL(updated.SAML.IDPSSOURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid SAML SSO URL", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Security: Validate icon URL if provided
|
|
if updated.IconURL != "" && !validateURL(updated.IconURL, []string{"https", "http", "data"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid icon URL", nil)
|
|
return
|
|
}
|
|
|
|
// Preserve secrets if not provided in update
|
|
if updated.Type == config.SSOProviderTypeOIDC && updated.OIDC != nil && existing.OIDC != nil {
|
|
if updated.OIDC.ClientSecret == "" && existing.OIDC.ClientSecretSet {
|
|
updated.OIDC.ClientSecret = existing.OIDC.ClientSecret
|
|
updated.OIDC.ClientSecretSet = true
|
|
}
|
|
}
|
|
|
|
// Update provider
|
|
if err := r.ssoConfig.UpdateProvider(updated); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "update_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Revert to existing provider
|
|
r.ssoConfig.UpdateProvider(*existing)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Re-initialize SAML provider if applicable
|
|
if updated.Type == config.SSOProviderTypeSAML && updated.SAML != nil {
|
|
if updated.Enabled {
|
|
if err := r.samlManager.InitializeProvider(req.Context(), updated.ID, updated.SAML); err != nil {
|
|
log.Warn().Err(err).Str("provider_id", updated.ID).Msg("Failed to re-initialize SAML provider")
|
|
}
|
|
} else {
|
|
r.samlManager.RemoveProvider(updated.ID)
|
|
}
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_updated", "", GetClientIP(req), req.URL.Path, true, "Updated provider: "+updated.Name)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(providerToResponse(&updated, r.config.PublicURL))
|
|
}
|
|
|
|
func (r *Router) handleDeleteSSOProvider(w http.ResponseWriter, req *http.Request, providerID string) {
|
|
if r.ssoConfig == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
existing := r.ssoConfig.GetProvider(providerID)
|
|
if existing == nil {
|
|
writeErrorResponse(w, http.StatusNotFound, "not_found", "Provider not found", nil)
|
|
return
|
|
}
|
|
|
|
// Remove provider
|
|
if err := r.ssoConfig.RemoveProvider(providerID); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "remove_error", err.Error(), nil)
|
|
return
|
|
}
|
|
|
|
// Persist configuration
|
|
if err := r.saveSSOConfig(); err != nil {
|
|
log.Error().Err(err).Msg("Failed to persist SSO configuration")
|
|
// Re-add the provider since persistence failed
|
|
r.ssoConfig.AddProvider(*existing)
|
|
writeErrorResponse(w, http.StatusInternalServerError, "save_error", "Failed to save configuration", nil)
|
|
return
|
|
}
|
|
|
|
// Remove SAML service if applicable
|
|
if existing.Type == config.SSOProviderTypeSAML {
|
|
r.samlManager.RemoveProvider(providerID)
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_deleted", "", GetClientIP(req), req.URL.Path, true, "Deleted provider: "+existing.Name)
|
|
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (r *Router) saveSSOConfig() error {
|
|
if r.persistence == nil {
|
|
return nil
|
|
}
|
|
return r.persistence.SaveSSOConfig(r.ssoConfig)
|
|
}
|
|
|
|
func providerToResponse(p *config.SSOProvider, publicURL string) SSOProviderResponse {
|
|
resp := SSOProviderResponse{
|
|
ID: p.ID,
|
|
Name: p.Name,
|
|
Type: string(p.Type),
|
|
Enabled: p.Enabled,
|
|
DisplayName: p.DisplayName,
|
|
IconURL: p.IconURL,
|
|
Priority: p.Priority,
|
|
AllowedGroups: p.AllowedGroups,
|
|
AllowedDomains: p.AllowedDomains,
|
|
AllowedEmails: p.AllowedEmails,
|
|
}
|
|
|
|
if resp.DisplayName == "" {
|
|
resp.DisplayName = p.Name
|
|
}
|
|
|
|
baseURL := publicURL
|
|
if baseURL == "" {
|
|
baseURL = "http://localhost:7655"
|
|
}
|
|
|
|
if p.Type == config.SSOProviderTypeOIDC && p.OIDC != nil {
|
|
resp.OIDCIssuerURL = p.OIDC.IssuerURL
|
|
resp.OIDCClientID = p.OIDC.ClientID
|
|
resp.OIDCClientSecretSet = p.OIDC.ClientSecretSet || p.OIDC.ClientSecret != ""
|
|
}
|
|
|
|
if p.Type == config.SSOProviderTypeSAML && p.SAML != nil {
|
|
resp.SAMLIDPEntityID = p.SAML.IDPEntityID
|
|
if resp.SAMLIDPEntityID == "" {
|
|
resp.SAMLIDPEntityID = p.SAML.IDPIssuer
|
|
}
|
|
resp.SAMLSPEntityID = p.SAML.SPEntityID
|
|
if resp.SAMLSPEntityID == "" {
|
|
resp.SAMLSPEntityID = baseURL + "/saml/" + p.ID
|
|
}
|
|
resp.SAMLMetadataURL = baseURL + "/api/saml/" + p.ID + "/metadata"
|
|
resp.SAMLACSUrl = baseURL + "/api/saml/" + p.ID + "/acs"
|
|
}
|
|
|
|
return resp
|
|
}
|
|
|
|
// ============================================================================
|
|
// SSO Provider Connection Testing
|
|
// ============================================================================
|
|
|
|
// Limits applied to the SSO connection-test and metadata-preview endpoints.
const (
	// maxTestRequestBodySize caps how much of a request body is read (32 KiB).
	maxTestRequestBodySize = 32 << 10
	// testConnectionTimeout bounds outbound requests made while testing.
	testConnectionTimeout = 30 * time.Second
)
|
|
|
|
// SSOTestRequest is the JSON body of a connection-test request. Type selects
// which of the two configuration sections is evaluated.
type SSOTestRequest struct {
	// Type is either "saml" or "oidc".
	Type string          `json:"type"`
	SAML *SAMLTestConfig `json:"saml,omitempty"`
	OIDC *OIDCTestConfig `json:"oidc,omitempty"`
}

// SAMLTestConfig carries the SAML identity-provider settings to test: a
// metadata URL, inline metadata XML, or a manually entered SSO URL.
type SAMLTestConfig struct {
	IDPMetadataURL string `json:"idpMetadataUrl,omitempty"`
	IDPMetadataXML string `json:"idpMetadataXml,omitempty"`
	IDPSSOURL      string `json:"idpSsoUrl,omitempty"`
	IDPCertificate string `json:"idpCertificate,omitempty"`
}

// OIDCTestConfig carries the OIDC settings to test.
type OIDCTestConfig struct {
	IssuerURL string `json:"issuerUrl"`
	ClientID  string `json:"clientId,omitempty"`
}
|
|
|
|
// SSOTestResponse is the JSON result of a connection test.
type SSOTestResponse struct {
	Success bool   `json:"success"`
	Message string `json:"message"`
	// Error holds an error code or underlying error text on failure.
	Error   string          `json:"error,omitempty"`
	Details *SSOTestDetails `json:"details,omitempty"`
}

// SSOTestDetails describes what was discovered about the tested provider.
type SSOTestDetails struct {
	Type         string            `json:"type"`
	EntityID     string            `json:"entityId,omitempty"`
	SSOURL       string            `json:"ssoUrl,omitempty"`
	SLOURL       string            `json:"sloUrl,omitempty"`
	Certificates []CertificateInfo `json:"certificates,omitempty"`

	// Fields filled in from an OIDC discovery document.
	TokenEndpoint    string   `json:"tokenEndpoint,omitempty"`
	UserinfoEndpoint string   `json:"userinfoEndpoint,omitempty"`
	JwksURI          string   `json:"jwksUri,omitempty"`
	SupportedScopes  []string `json:"supportedScopes,omitempty"`
}

// CertificateInfo summarizes one X.509 certificate: who it identifies, who
// issued it, its validity window, and whether it has expired.
type CertificateInfo struct {
	Subject   string    `json:"subject"`
	Issuer    string    `json:"issuer"`
	NotBefore time.Time `json:"notBefore"`
	NotAfter  time.Time `json:"notAfter"`
	IsExpired bool      `json:"isExpired"`
}
|
|
|
|
// handleTestSSOProvider tests an SSO provider configuration
|
|
func (r *Router) handleTestSSOProvider(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
clientIP := GetClientIP(req)
|
|
if !authLimiter.Allow(clientIP) {
|
|
writeErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxTestRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var testReq SSOTestRequest
|
|
if err := json.Unmarshal(body, &testReq); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
// Validate request
|
|
if testReq.Type != "saml" && testReq.Type != "oidc" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Type must be 'saml' or 'oidc'", nil)
|
|
return
|
|
}
|
|
|
|
var response SSOTestResponse
|
|
|
|
ctx, cancel := context.WithTimeout(req.Context(), testConnectionTimeout)
|
|
defer cancel()
|
|
|
|
switch testReq.Type {
|
|
case "saml":
|
|
response = r.testSAMLConnection(ctx, testReq.SAML)
|
|
case "oidc":
|
|
response = r.testOIDCConnection(ctx, testReq.OIDC)
|
|
}
|
|
|
|
LogAuditEventForTenant(GetOrgID(req.Context()), "sso_provider_test", "", clientIP, req.URL.Path, response.Success,
|
|
"Tested "+testReq.Type+" provider connection")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if response.Success {
|
|
w.WriteHeader(http.StatusOK)
|
|
} else {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func (r *Router) testSAMLConnection(ctx context.Context, cfg *SAMLTestConfig) SSOTestResponse {
|
|
if cfg == nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "SAML configuration is required",
|
|
Error: "missing_config",
|
|
}
|
|
}
|
|
|
|
// Need at least one source of metadata
|
|
if cfg.IDPMetadataURL == "" && cfg.IDPMetadataXML == "" && cfg.IDPSSOURL == "" {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Provide IdP Metadata URL, XML, or SSO URL",
|
|
Error: "missing_metadata",
|
|
}
|
|
}
|
|
|
|
// Validate URLs
|
|
if cfg.IDPMetadataURL != "" && !validateURL(cfg.IDPMetadataURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid metadata URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
|
|
var metadata *saml.EntityDescriptor
|
|
var rawXML []byte
|
|
var err error
|
|
|
|
httpClient := newTestHTTPClient()
|
|
|
|
if cfg.IDPMetadataURL != "" {
|
|
rawXML, metadata, err = fetchSAMLMetadataFromURL(ctx, httpClient, cfg.IDPMetadataURL)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to fetch metadata from URL",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
} else if cfg.IDPMetadataXML != "" {
|
|
rawXML = []byte(cfg.IDPMetadataXML)
|
|
metadata, err = parseSAMLMetadataXML(rawXML)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to parse metadata XML",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
} else {
|
|
// Manual configuration - just validate the SSO URL
|
|
if !validateURL(cfg.IDPSSOURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid SSO URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "SSO URL is valid (manual configuration)",
|
|
Details: &SSOTestDetails{
|
|
Type: "saml",
|
|
SSOURL: cfg.IDPSSOURL,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Extract details from metadata
|
|
details := &SSOTestDetails{
|
|
Type: "saml",
|
|
EntityID: metadata.EntityID,
|
|
}
|
|
|
|
// Extract SSO URL
|
|
if len(metadata.IDPSSODescriptors) > 0 {
|
|
idpDesc := metadata.IDPSSODescriptors[0]
|
|
for _, sso := range idpDesc.SingleSignOnServices {
|
|
if sso.Binding == saml.HTTPPostBinding || sso.Binding == saml.HTTPRedirectBinding {
|
|
details.SSOURL = sso.Location
|
|
break
|
|
}
|
|
}
|
|
// Extract SLO URL
|
|
for _, slo := range idpDesc.SingleLogoutServices {
|
|
details.SLOURL = slo.Location
|
|
break
|
|
}
|
|
// Extract certificates
|
|
for _, kd := range idpDesc.KeyDescriptors {
|
|
if kd.Use == "signing" || kd.Use == "" {
|
|
for _, x509Cert := range kd.KeyInfo.X509Data.X509Certificates {
|
|
certInfo := extractCertificateInfo(x509Cert.Data)
|
|
if certInfo != nil {
|
|
details.Certificates = append(details.Certificates, *certInfo)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
_ = rawXML // Used for metadata preview endpoint
|
|
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "SAML metadata validated successfully",
|
|
Details: details,
|
|
}
|
|
}
|
|
|
|
func (r *Router) testOIDCConnection(ctx context.Context, cfg *OIDCTestConfig) SSOTestResponse {
|
|
if cfg == nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "OIDC configuration is required",
|
|
Error: "missing_config",
|
|
}
|
|
}
|
|
|
|
if cfg.IssuerURL == "" {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Issuer URL is required",
|
|
Error: "missing_issuer",
|
|
}
|
|
}
|
|
|
|
if !validateURL(cfg.IssuerURL, []string{"https", "http"}) {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Invalid issuer URL format",
|
|
Error: "invalid_url",
|
|
}
|
|
}
|
|
|
|
// Fetch OIDC discovery document
|
|
discoveryURL := strings.TrimRight(cfg.IssuerURL, "/") + "/.well-known/openid-configuration"
|
|
|
|
httpClient := newTestHTTPClient()
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to create discovery request",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to fetch OIDC discovery document",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "OIDC discovery returned non-200 status",
|
|
Error: resp.Status,
|
|
}
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to read discovery response",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
var discovery struct {
|
|
Issuer string `json:"issuer"`
|
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
|
TokenEndpoint string `json:"token_endpoint"`
|
|
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
|
JwksURI string `json:"jwks_uri"`
|
|
ScopesSupported []string `json:"scopes_supported"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &discovery); err != nil {
|
|
return SSOTestResponse{
|
|
Success: false,
|
|
Message: "Failed to parse discovery document",
|
|
Error: err.Error(),
|
|
}
|
|
}
|
|
|
|
// Validate issuer matches
|
|
if discovery.Issuer != cfg.IssuerURL && discovery.Issuer != strings.TrimRight(cfg.IssuerURL, "/") {
|
|
log.Warn().
|
|
Str("expected", cfg.IssuerURL).
|
|
Str("actual", discovery.Issuer).
|
|
Msg("OIDC issuer mismatch - this may cause token validation issues")
|
|
}
|
|
|
|
details := &SSOTestDetails{
|
|
Type: "oidc",
|
|
EntityID: discovery.Issuer,
|
|
TokenEndpoint: discovery.TokenEndpoint,
|
|
UserinfoEndpoint: discovery.UserinfoEndpoint,
|
|
JwksURI: discovery.JwksURI,
|
|
SupportedScopes: discovery.ScopesSupported,
|
|
}
|
|
|
|
return SSOTestResponse{
|
|
Success: true,
|
|
Message: "OIDC discovery successful",
|
|
Details: details,
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SSO Metadata Preview
|
|
// ============================================================================
|
|
|
|
// MetadataPreviewRequest represents a request to preview IdP metadata
|
|
type MetadataPreviewRequest struct {
|
|
Type string `json:"type"` // "saml"
|
|
MetadataURL string `json:"metadataUrl,omitempty"`
|
|
MetadataXML string `json:"metadataXml,omitempty"`
|
|
}
|
|
|
|
// MetadataPreviewResponse contains the metadata preview
|
|
type MetadataPreviewResponse struct {
|
|
XML string `json:"xml"`
|
|
Parsed *ParsedMetadataInfo `json:"parsed"`
|
|
}
|
|
|
|
// ParsedMetadataInfo contains parsed metadata information
|
|
type ParsedMetadataInfo struct {
|
|
EntityID string `json:"entityId"`
|
|
SSOURL string `json:"ssoUrl,omitempty"`
|
|
SLOURL string `json:"sloUrl,omitempty"`
|
|
Certificates []CertificateInfo `json:"certificates,omitempty"`
|
|
NameIDFormats []string `json:"nameIdFormats,omitempty"`
|
|
}
|
|
|
|
// handleMetadataPreview fetches and displays IdP metadata
|
|
func (r *Router) handleMetadataPreview(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method != http.MethodPost {
|
|
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed", nil)
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
clientIP := GetClientIP(req)
|
|
if !authLimiter.Allow(clientIP) {
|
|
writeErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests", nil)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(req.Body, maxTestRequestBodySize))
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "read_error", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
var previewReq MetadataPreviewRequest
|
|
if err := json.Unmarshal(body, &previewReq); err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_json", "Invalid JSON payload", nil)
|
|
return
|
|
}
|
|
|
|
if previewReq.Type != "saml" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Only SAML metadata preview is supported", nil)
|
|
return
|
|
}
|
|
|
|
if previewReq.MetadataURL == "" && previewReq.MetadataXML == "" {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Provide either metadataUrl or metadataXml", nil)
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(req.Context(), testConnectionTimeout)
|
|
defer cancel()
|
|
|
|
var rawXML []byte
|
|
var metadata *saml.EntityDescriptor
|
|
|
|
httpClient := newTestHTTPClient()
|
|
|
|
if previewReq.MetadataURL != "" {
|
|
if !validateURL(previewReq.MetadataURL, []string{"https", "http"}) {
|
|
writeErrorResponse(w, http.StatusBadRequest, "validation_error", "Invalid metadata URL", nil)
|
|
return
|
|
}
|
|
rawXML, metadata, err = fetchSAMLMetadataFromURL(ctx, httpClient, previewReq.MetadataURL)
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "fetch_error", "Failed to fetch metadata: "+err.Error(), nil)
|
|
return
|
|
}
|
|
} else {
|
|
rawXML = []byte(previewReq.MetadataXML)
|
|
metadata, err = parseSAMLMetadataXML(rawXML)
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "parse_error", "Failed to parse metadata: "+err.Error(), nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Build parsed info
|
|
parsed := &ParsedMetadataInfo{
|
|
EntityID: metadata.EntityID,
|
|
}
|
|
|
|
if len(metadata.IDPSSODescriptors) > 0 {
|
|
idpDesc := metadata.IDPSSODescriptors[0]
|
|
|
|
// Extract SSO URL
|
|
for _, sso := range idpDesc.SingleSignOnServices {
|
|
if sso.Binding == saml.HTTPPostBinding || sso.Binding == saml.HTTPRedirectBinding {
|
|
parsed.SSOURL = sso.Location
|
|
break
|
|
}
|
|
}
|
|
|
|
// Extract SLO URL
|
|
for _, slo := range idpDesc.SingleLogoutServices {
|
|
parsed.SLOURL = slo.Location
|
|
break
|
|
}
|
|
|
|
// Extract NameID formats
|
|
for _, nid := range idpDesc.NameIDFormats {
|
|
parsed.NameIDFormats = append(parsed.NameIDFormats, string(nid))
|
|
}
|
|
|
|
// Extract certificates
|
|
for _, kd := range idpDesc.KeyDescriptors {
|
|
if kd.Use == "signing" || kd.Use == "" {
|
|
for _, x509Cert := range kd.KeyInfo.X509Data.X509Certificates {
|
|
certInfo := extractCertificateInfo(x509Cert.Data)
|
|
if certInfo != nil {
|
|
parsed.Certificates = append(parsed.Certificates, *certInfo)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Format XML for display
|
|
formattedXML := formatXML(rawXML)
|
|
|
|
response := MetadataPreviewResponse{
|
|
XML: formattedXML,
|
|
Parsed: parsed,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
// ============================================================================
|
|
// Helper Functions
|
|
// ============================================================================
|
|
|
|
func newTestHTTPClient() *http.Client {
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.TLSClientConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
return &http.Client{
|
|
Transport: transport,
|
|
Timeout: testConnectionTimeout,
|
|
}
|
|
}
|
|
|
|
func fetchSAMLMetadataFromURL(ctx context.Context, client *http.Client, metadataURL string) ([]byte, *saml.EntityDescriptor, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, nil, fmt.Errorf("metadata request returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
metadata, err := parseSAMLMetadataXML(body)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return body, metadata, nil
|
|
}
|
|
|
|
func parseSAMLMetadataXML(data []byte) (*saml.EntityDescriptor, error) {
|
|
var metadata saml.EntityDescriptor
|
|
if err := xml.Unmarshal(data, &metadata); err != nil {
|
|
// Try parsing as EntitiesDescriptor
|
|
var entities saml.EntitiesDescriptor
|
|
if err2 := xml.Unmarshal(data, &entities); err2 != nil {
|
|
return nil, fmt.Errorf("failed to parse metadata: %v", err)
|
|
}
|
|
if len(entities.EntityDescriptors) == 0 {
|
|
return nil, fmt.Errorf("no entity descriptors found in metadata")
|
|
}
|
|
metadata = entities.EntityDescriptors[0]
|
|
}
|
|
return &metadata, nil
|
|
}
|
|
|
|
func extractCertificateInfo(certData string) *CertificateInfo {
|
|
// Remove whitespace and decode base64
|
|
certData = strings.ReplaceAll(certData, "\n", "")
|
|
certData = strings.ReplaceAll(certData, "\r", "")
|
|
certData = strings.ReplaceAll(certData, " ", "")
|
|
|
|
// Try to decode as PEM first
|
|
var derBytes []byte
|
|
block, _ := pem.Decode([]byte(certData))
|
|
if block != nil {
|
|
derBytes = block.Bytes
|
|
} else {
|
|
// Assume it's base64 encoded DER
|
|
var err error
|
|
derBytes, err = base64Decode(certData)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(derBytes)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return &CertificateInfo{
|
|
Subject: cert.Subject.String(),
|
|
Issuer: cert.Issuer.String(),
|
|
NotBefore: cert.NotBefore,
|
|
NotAfter: cert.NotAfter,
|
|
IsExpired: time.Now().After(cert.NotAfter),
|
|
}
|
|
}
|
|
|
|
// base64Decode decodes s as padded standard base64 and, if that fails, as
// padded URL-safe base64. When both fail, the result and error of the
// URL-safe attempt are returned.
func base64Decode(s string) ([]byte, error) {
	if out, err := base64.StdEncoding.DecodeString(s); err == nil {
		return out, nil
	}
	// Standard alphabet rejected the input; fall back to the URL-safe one.
	return base64.URLEncoding.DecodeString(s)
}
|
|
|
|
// formatXML returns data re-encoded as indented XML for display.
//
// The document is streamed token by token from a decoder into an indenting
// encoder. If the input is empty, is not well-formed, or cannot be
// re-encoded, the original text is returned unchanged, so the caller always
// gets either the complete reformatted document or the untouched input.
func formatXML(data []byte) string {
	original := string(data)

	var buf strings.Builder
	decoder := xml.NewDecoder(strings.NewReader(original))
	encoder := xml.NewEncoder(&buf)
	encoder.Indent("", " ")

	for {
		token, err := decoder.Token()
		if err == io.EOF {
			break
		}
		if err != nil {
			// A syntax error part-way through would otherwise leave a
			// truncated document in buf; fall back to the original.
			return original
		}
		if err := encoder.EncodeToken(token); err != nil {
			// Fall back to original
			return original
		}
	}

	if err := encoder.Flush(); err != nil {
		return original
	}
	if buf.Len() == 0 {
		return original
	}
	return buf.String()
}
|