diff --git a/internal/api/saml_handlers_success_test.go b/internal/api/saml_handlers_success_test.go index 6a6ada6fa..4fb5899f5 100644 --- a/internal/api/saml_handlers_success_test.go +++ b/internal/api/saml_handlers_success_test.go @@ -80,6 +80,54 @@ func TestHandleSAMLMetadata_SynchronizesConfiguredPublicURL(t *testing.T) { } } +func TestHandleSAMLMetadata_RebuildsPreviouslyInitializedRelativeMetadata(t *testing.T) { + provider := testSAMLProvider("okta", true) + router := &Router{ + config: &config.Config{PublicURL: "https://pulse.example.com"}, + samlManager: NewSAMLServiceManager(""), + ssoConfig: &config.SSOConfig{ + Providers: []config.SSOProvider{provider}, + }, + } + + if err := router.samlManager.InitializeProvider(context.Background(), provider.ID, provider.SAML); err != nil { + t.Fatalf("InitializeProvider: %v", err) + } + + service := router.samlManager.GetService(provider.ID) + if service == nil { + t.Fatal("expected initialized SAML service") + } + + staleMetadata, err := service.GetMetadata() + if err != nil { + t.Fatalf("GetMetadata before sync: %v", err) + } + staleBody := string(staleMetadata) + if !strings.Contains(staleBody, `entityID="/saml/okta"`) { + t.Fatalf("expected pre-sync relative entityID, got %q", staleBody) + } + if !strings.Contains(staleBody, `Location="/api/saml/okta/acs"`) { + t.Fatalf("expected pre-sync relative ACS URL, got %q", staleBody) + } + + req := httptest.NewRequest(http.MethodGet, "/api/saml/okta/metadata", nil) + rr := httptest.NewRecorder() + + router.handleSAMLMetadata(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } + body := rr.Body.String() + if !strings.Contains(body, `entityID="https://pulse.example.com/saml/okta"`) { + t.Fatalf("expected absolute entityID after sync, got %q", body) + } + if !strings.Contains(body, `Location="https://pulse.example.com/api/saml/okta/acs"`) { + t.Fatalf("expected absolute ACS URL after sync, got %q", body) + } +} + func TestHandleSAMLLogin_SuccessGetAndPost(t *testing.T) { router := newSAMLRouter(t, testSAMLProvider("okta", true))