Fix SAML public URLs and OIDC group role mappings

This commit is contained in:
rcourtman 2026-04-01 23:03:29 +01:00
parent a05b618257
commit cca27b697e
9 changed files with 243 additions and 50 deletions

View file

@ -201,8 +201,9 @@ func NewRouter(cfg *config.Config, monitor *monitoring.Monitor, mtMonitor *monit
auth.SetAdminUser(cfg.AuthUser)
}
// Initialize SAML manager (baseURL will be set dynamically on first use)
r.samlManager = NewSAMLServiceManager("")
// Initialize SAML manager with any configured public URL so startup-loaded providers
// build absolute metadata when a canonical public endpoint is already known.
r.samlManager = NewSAMLServiceManager(cfg.PublicURL)
r.initializeBootstrapToken()
@ -2163,6 +2164,11 @@ func (r *Router) SetMonitor(m *monitoring.Monitor) {
if mgr := m.GetNotificationManager(); mgr != nil {
mgr.SetPublicURL(url)
}
if r.samlManager != nil {
if err := r.samlManager.SetBaseURL(url); err != nil {
log.Warn().Err(err).Msg("Failed to synchronize SAML base URL")
}
}
}
// Inject resource store for polling optimization
if r.resourceHandlers != nil {
@ -4158,6 +4164,11 @@ func (r *Router) capturePublicURLFromRequest(req *http.Request) {
mgr.SetPublicURL(normalizedCandidate)
}
}
if r.samlManager != nil {
if err := r.samlManager.SetBaseURL(normalizedCandidate); err != nil {
log.Warn().Err(err).Msg("Failed to synchronize SAML base URL")
}
}
}
func firstForwardedValue(header string) string {

View file

@ -25,7 +25,7 @@ type SAMLServiceManager struct {
func NewSAMLServiceManager(baseURL string) *SAMLServiceManager {
return &SAMLServiceManager{
services: make(map[string]*SAMLService),
baseURL: baseURL,
baseURL: normalizeSAMLBaseURL(baseURL),
}
}
@ -36,6 +36,27 @@ func (m *SAMLServiceManager) GetService(providerID string) *SAMLService {
return m.services[providerID]
}
// SetBaseURL updates the manager base URL and refreshes all initialized services.
func (m *SAMLServiceManager) SetBaseURL(baseURL string) error {
normalized := normalizeSAMLBaseURL(baseURL)
m.mu.Lock()
m.baseURL = normalized
services := make([]*SAMLService, 0, len(m.services))
for _, service := range m.services {
services = append(services, service)
}
m.mu.Unlock()
for _, service := range services {
if err := service.SetBaseURL(normalized); err != nil {
return fmt.Errorf("refresh provider %s: %w", service.ProviderID(), err)
}
}
return nil
}
// InitializeProvider creates or updates a SAML service for a provider
func (m *SAMLServiceManager) InitializeProvider(ctx context.Context, providerID string, cfg *config.SAMLProviderConfig) error {
service, err := NewSAMLService(ctx, providerID, cfg, m.baseURL)
@ -61,6 +82,19 @@ func (m *SAMLServiceManager) RemoveProvider(providerID string) {
delete(m.services, providerID)
}
func (r *Router) syncSAMLBaseURL(req *http.Request) error {
if r == nil || r.samlManager == nil {
return nil
}
baseURL := normalizeSAMLBaseURL(r.resolvePublicURL(req))
if baseURL == "" {
return nil
}
return r.samlManager.SetBaseURL(baseURL)
}
// handleSAMLLogin initiates a SAML authentication flow
func (r *Router) handleSAMLLogin(w http.ResponseWriter, req *http.Request) {
providerID := extractSAMLProviderID(req.URL.Path, "login")
@ -81,6 +115,12 @@ func (r *Router) handleSAMLLogin(w http.ResponseWriter, req *http.Request) {
return
}
if err := r.syncSAMLBaseURL(req); err != nil {
log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to synchronize SAML provider base URL")
writeErrorResponse(w, http.StatusInternalServerError, "saml_init_failed", "Failed to initialize SAML provider", nil)
return
}
service := r.samlManager.GetService(providerID)
if service == nil {
// Try to initialize the provider
@ -150,6 +190,12 @@ func (r *Router) handleSAMLACS(w http.ResponseWriter, req *http.Request) {
return
}
if err := r.syncSAMLBaseURL(req); err != nil {
log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to synchronize SAML provider base URL")
r.redirectSAMLError(w, req, "", "saml_init_failed")
return
}
service := r.samlManager.GetService(providerID)
if service == nil {
r.redirectSAMLError(w, req, "", "provider_not_initialized")
@ -294,6 +340,12 @@ func (r *Router) handleSAMLMetadata(w http.ResponseWriter, req *http.Request) {
return
}
if err := r.syncSAMLBaseURL(req); err != nil {
log.Error().Err(err).Str("provider_id", providerID).Msg("Failed to synchronize SAML provider base URL")
writeErrorResponse(w, http.StatusInternalServerError, "saml_init_failed", "Failed to initialize SAML provider", nil)
return
}
service := r.samlManager.GetService(providerID)
if service == nil {
// Try to initialize the provider
@ -531,6 +583,12 @@ func (r *Router) InitializeSAMLProviders(ctx context.Context) error {
return nil
}
if publicURL := normalizeSAMLBaseURL(r.config.PublicURL); publicURL != "" {
if err := r.samlManager.SetBaseURL(publicURL); err != nil {
log.Warn().Err(err).Msg("Failed to synchronize SAML base URL before provider initialization")
}
}
for _, provider := range r.ssoConfig.Providers {
if provider.Type == config.SSOProviderTypeSAML && provider.Enabled && provider.SAML != nil {
if err := r.samlManager.InitializeProvider(ctx, provider.ID, provider.SAML); err != nil {

View file

@ -24,7 +24,13 @@ func newTestSAMLService(t *testing.T, providerID string, metadataXML string) *SA
func TestHandleSAMLACS_ProcessResponseError(t *testing.T) {
router := newSAMLRouter(t, testSAMLProvider("okta", true))
router.samlManager.services["okta"] = &SAMLService{}
metadataXML := `<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp.example.com/sso"/>
</IDPSSODescriptor>
</EntityDescriptor>`
router.samlManager.services["okta"] = newTestSAMLService(t, "okta", metadataXML)
req := httptest.NewRequest(http.MethodPost, "/api/saml/okta/acs", nil)
rec := httptest.NewRecorder()

View file

@ -54,6 +54,32 @@ func TestHandleSAMLMetadata_Success(t *testing.T) {
}
}
func TestHandleSAMLMetadata_SynchronizesConfiguredPublicURL(t *testing.T) {
router := &Router{
config: &config.Config{PublicURL: "https://pulse.example.com"},
samlManager: NewSAMLServiceManager(""),
ssoConfig: &config.SSOConfig{
Providers: []config.SSOProvider{testSAMLProvider("okta", true)},
},
}
req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/metadata", nil)
rr := httptest.NewRecorder()
router.handleSAMLMetadata(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, `entityID="https://pulse.example.com/saml/okta"`) {
t.Fatalf("expected absolute entityID in metadata, got %q", body)
}
if !strings.Contains(body, `Location="https://pulse.example.com/api/saml/okta/acs"`) {
t.Fatalf("expected absolute ACS URL in metadata, got %q", body)
}
}
func TestHandleSAMLLogin_SuccessGetAndPost(t *testing.T) {
router := newSAMLRouter(t, testSAMLProvider("okta", true))

View file

@ -47,6 +47,10 @@ type SAMLAuthResult struct {
Attributes map[string][]string
}
func normalizeSAMLBaseURL(baseURL string) string {
return strings.TrimRight(strings.TrimSpace(baseURL), "/")
}
// NewSAMLService creates a new SAML service for a provider
func NewSAMLService(ctx context.Context, providerID string, cfg *config.SAMLProviderConfig, baseURL string) (*SAMLService, error) {
if cfg == nil {
@ -56,7 +60,7 @@ func NewSAMLService(ctx context.Context, providerID string, cfg *config.SAMLProv
service := &SAMLService{
providerID: providerID,
config: cfg,
baseURL: strings.TrimRight(baseURL, "/"),
baseURL: normalizeSAMLBaseURL(baseURL),
httpClient: newSAMLHTTPClient(),
}
@ -268,6 +272,17 @@ func (s *SAMLService) addIDPCertificate(metadata *saml.EntityDescriptor) error {
}
func (s *SAMLService) initServiceProvider() error {
s.mu.Lock()
defer s.mu.Unlock()
return s.initServiceProviderLocked()
}
func (s *SAMLService) initServiceProviderLocked() error {
if s.idpMetadata == nil {
return errors.New("idp metadata not loaded")
}
// Build SP Entity ID
spEntityID := s.config.SPEntityID
if spEntityID == "" {
@ -336,6 +351,14 @@ func (s *SAMLService) initServiceProvider() error {
return nil
}
func (s *SAMLService) SetBaseURL(baseURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.baseURL = normalizeSAMLBaseURL(baseURL)
return s.initServiceProviderLocked()
}
func (s *SAMLService) loadSPCredentials() (*x509.Certificate, *rsa.PrivateKey, error) {
var certData, keyData []byte
var err error
@ -630,6 +653,14 @@ func (s *SAMLService) ProviderID() string {
return s.providerID
}
// GetBaseURL returns the normalized base URL used to build SP endpoints.
func (s *SAMLService) GetBaseURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.baseURL
}
// GetSPEntityID returns the Service Provider Entity ID
func (s *SAMLService) GetSPEntityID() string {
s.mu.RLock()

View file

@ -40,21 +40,22 @@ func (r *Router) handleUpdateOIDCConfig(w http.ResponseWriter, req *http.Request
}
var payload struct {
Enabled bool `json:"enabled"`
IssuerURL string `json:"issuerUrl"`
ClientID string `json:"clientId"`
ClientSecret *string `json:"clientSecret,omitempty"`
RedirectURL string `json:"redirectUrl"`
LogoutURL string `json:"logoutUrl"`
Scopes []string `json:"scopes"`
UsernameClaim string `json:"usernameClaim"`
EmailClaim string `json:"emailClaim"`
GroupsClaim string `json:"groupsClaim"`
AllowedGroups []string `json:"allowedGroups"`
AllowedDomains []string `json:"allowedDomains"`
AllowedEmails []string `json:"allowedEmails"`
ClearClientSecret bool `json:"clearClientSecret"`
CABundle *string `json:"caBundle"`
Enabled bool `json:"enabled"`
IssuerURL string `json:"issuerUrl"`
ClientID string `json:"clientId"`
ClientSecret *string `json:"clientSecret,omitempty"`
RedirectURL string `json:"redirectUrl"`
LogoutURL string `json:"logoutUrl"`
Scopes []string `json:"scopes"`
UsernameClaim string `json:"usernameClaim"`
EmailClaim string `json:"emailClaim"`
GroupsClaim string `json:"groupsClaim"`
AllowedGroups []string `json:"allowedGroups"`
AllowedDomains []string `json:"allowedDomains"`
AllowedEmails []string `json:"allowedEmails"`
GroupRoleMappings map[string]string `json:"groupRoleMappings"`
ClearClientSecret bool `json:"clearClientSecret"`
CABundle *string `json:"caBundle"`
}
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
@ -75,8 +76,11 @@ func (r *Router) handleUpdateOIDCConfig(w http.ResponseWriter, req *http.Request
AllowedGroups: append([]string{}, payload.AllowedGroups...),
AllowedDomains: append([]string{}, payload.AllowedDomains...),
AllowedEmails: append([]string{}, payload.AllowedEmails...),
CABundle: strings.TrimSpace(cfg.CABundle),
EnvOverrides: make(map[string]bool),
GroupRoleMappings: cloneStringMap(
payload.GroupRoleMappings,
),
CABundle: strings.TrimSpace(cfg.CABundle),
EnvOverrides: make(map[string]bool),
}
// Preserve existing secret unless explicitly changed.
@ -115,22 +119,23 @@ func (r *Router) handleUpdateOIDCConfig(w http.ResponseWriter, req *http.Request
}
type oidcResponse struct {
Enabled bool `json:"enabled"`
IssuerURL string `json:"issuerUrl"`
ClientID string `json:"clientId"`
RedirectURL string `json:"redirectUrl"`
LogoutURL string `json:"logoutUrl"`
Scopes []string `json:"scopes"`
UsernameClaim string `json:"usernameClaim"`
EmailClaim string `json:"emailClaim"`
GroupsClaim string `json:"groupsClaim"`
AllowedGroups []string `json:"allowedGroups"`
AllowedDomains []string `json:"allowedDomains"`
AllowedEmails []string `json:"allowedEmails"`
CABundle string `json:"caBundle"`
ClientSecretSet bool `json:"clientSecretSet"`
DefaultRedirect string `json:"defaultRedirect"`
EnvOverrides map[string]bool `json:"envOverrides,omitempty"`
Enabled bool `json:"enabled"`
IssuerURL string `json:"issuerUrl"`
ClientID string `json:"clientId"`
RedirectURL string `json:"redirectUrl"`
LogoutURL string `json:"logoutUrl"`
Scopes []string `json:"scopes"`
UsernameClaim string `json:"usernameClaim"`
EmailClaim string `json:"emailClaim"`
GroupsClaim string `json:"groupsClaim"`
AllowedGroups []string `json:"allowedGroups"`
AllowedDomains []string `json:"allowedDomains"`
AllowedEmails []string `json:"allowedEmails"`
GroupRoleMappings map[string]string `json:"groupRoleMappings"`
CABundle string `json:"caBundle"`
ClientSecretSet bool `json:"clientSecretSet"`
DefaultRedirect string `json:"defaultRedirect"`
EnvOverrides map[string]bool `json:"envOverrides,omitempty"`
}
func makeOIDCResponse(cfg *config.OIDCConfig, publicURL string) oidcResponse {
@ -140,18 +145,21 @@ func makeOIDCResponse(cfg *config.OIDCConfig, publicURL string) oidcResponse {
}
resp := oidcResponse{
Enabled: cfg.Enabled,
IssuerURL: cfg.IssuerURL,
ClientID: cfg.ClientID,
RedirectURL: cfg.RedirectURL,
LogoutURL: cfg.LogoutURL,
Scopes: append([]string{}, cfg.Scopes...),
UsernameClaim: cfg.UsernameClaim,
EmailClaim: cfg.EmailClaim,
GroupsClaim: cfg.GroupsClaim,
AllowedGroups: append([]string{}, cfg.AllowedGroups...),
AllowedDomains: append([]string{}, cfg.AllowedDomains...),
AllowedEmails: append([]string{}, cfg.AllowedEmails...),
Enabled: cfg.Enabled,
IssuerURL: cfg.IssuerURL,
ClientID: cfg.ClientID,
RedirectURL: cfg.RedirectURL,
LogoutURL: cfg.LogoutURL,
Scopes: append([]string{}, cfg.Scopes...),
UsernameClaim: cfg.UsernameClaim,
EmailClaim: cfg.EmailClaim,
GroupsClaim: cfg.GroupsClaim,
AllowedGroups: append([]string{}, cfg.AllowedGroups...),
AllowedDomains: append([]string{}, cfg.AllowedDomains...),
AllowedEmails: append([]string{}, cfg.AllowedEmails...),
GroupRoleMappings: cloneStringMap(
cfg.GroupRoleMappings,
),
CABundle: cfg.CABundle,
ClientSecretSet: cfg.ClientSecret != "",
DefaultRedirect: config.DefaultRedirectURL(publicURL),
@ -166,3 +174,25 @@ func makeOIDCResponse(cfg *config.OIDCConfig, publicURL string) oidcResponse {
return resp
}
func cloneStringMap(input map[string]string) map[string]string {
if len(input) == 0 {
return nil
}
cloned := make(map[string]string, len(input))
for key, value := range input {
trimmedKey := strings.TrimSpace(key)
trimmedValue := strings.TrimSpace(value)
if trimmedKey == "" || trimmedValue == "" {
continue
}
cloned[trimmedKey] = trimmedValue
}
if len(cloned) == 0 {
return nil
}
return cloned
}

View file

@ -23,6 +23,9 @@ func TestSecurityOIDCHandlers_GetConfig(t *testing.T) {
UsernameClaim: "sub",
EmailClaim: "email",
GroupsClaim: "groups",
GroupRoleMappings: map[string]string{
"group-uuid": "viewer",
},
},
}
router := &Router{config: cfg}
@ -49,6 +52,9 @@ func TestSecurityOIDCHandlers_GetConfig(t *testing.T) {
if !resp.ClientSecretSet {
t.Fatalf("expected client secret to be marked as set")
}
if got := resp.GroupRoleMappings["group-uuid"]; got != "viewer" {
t.Fatalf("expected group role mappings to be included, got %q", got)
}
}
func TestSecurityOIDCHandlers_UpdateSaveFailure(t *testing.T) {

View file

@ -180,3 +180,24 @@ func TestMakeOIDCResponse_SlicesCopied(t *testing.T) {
t.Error("response AllowedGroups should be a copy, not a reference")
}
}
func TestMakeOIDCResponse_GroupRoleMappingsCopied(t *testing.T) {
t.Parallel()
cfg := &config.OIDCConfig{
GroupRoleMappings: map[string]string{
"group-a": "viewer",
},
}
resp := makeOIDCResponse(cfg, "https://pulse.example.com")
if got := resp.GroupRoleMappings["group-a"]; got != "viewer" {
t.Fatalf("expected group role mapping to round trip, got %q", got)
}
cfg.GroupRoleMappings["group-a"] = "admin"
if got := resp.GroupRoleMappings["group-a"]; got != "viewer" {
t.Fatalf("expected response group role mappings to be copied, got %q", got)
}
}

View file

@ -34,6 +34,9 @@ func TestSaveOIDCConfig(t *testing.T) {
Enabled: true,
IssuerURL: "https://issuer.com",
ClientID: "client-id",
GroupRoleMappings: map[string]string{
"group-uuid": "viewer",
},
}
err = SaveOIDCConfig(settings)
@ -43,6 +46,7 @@ func TestSaveOIDCConfig(t *testing.T) {
loaded, err := p.LoadOIDCConfig()
require.NoError(t, err)
assert.Equal(t, settings.IssuerURL, loaded.IssuerURL)
assert.Equal(t, "viewer", loaded.GroupRoleMappings["group-uuid"])
}
func TestLoadHostMetadata_Wait(t *testing.T) {