package api

// NOTE(review): crypto/ecdsa and crypto/elliptic are not referenced by any
// code visible in this copy of the file; presumably they are used by tests
// that are missing here. Confirm against version control before pruning.
import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"math/big"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/rcourtman/pulse-go-rewrite/internal/config"
)

// generateTestCert builds a throwaway self-signed RSA certificate for tests.
//
// It returns the PEM-encoded certificate, the PEM-encoded PKCS#1 private key,
// and the private key itself. Any failure aborts the calling test through
// t.Fatalf, so callers never see a partial result.
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte, key *rsa.PrivateKey) {
	// Mark this function as a helper so failures are attributed to the caller.
	t.Helper()

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	// The certificate is valid from one hour ago until one hour from now and
	// is flagged as a CA.
	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		NotBefore:             time.Now().Add(-time.Hour),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// Passing the template as its own parent makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		t.Fatalf("create cert: %v", err)
	}

	certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
	return certPEM, keyPEM, priv
}

// TestParseIDPMetadataXML feeds parseIDPMetadataXML two documents and expects
// the parsed EntityID to be idp-1 for the first and idp-2 for the second.
//
// NOTE(review): both fixture literals below contain only a single space in
// this copy, which cannot yield those entity IDs. Presumably the XML markup
// was stripped by whatever produced this copy; restore the fixtures from
// version control rather than reconstructing them by hand.
func TestParseIDPMetadataXML(t *testing.T) {
	xml := ` `
	metadata, err := parseIDPMetadataXML([]byte(xml))
	if err != nil {
		t.Fatalf("parse metadata: %v", err)
	}
	if metadata.EntityID != "idp-1" {
		t.Fatalf("unexpected entity id: %s", metadata.EntityID)
	}

	// Second fixture; the variable name suggests the descriptor is wrapped in
	// an outer element. TODO confirm once the fixture text is restored.
	wrapped := ` `
	metadata, err = parseIDPMetadataXML([]byte(wrapped))
	if err != nil {
		t.Fatalf("parse wrapped metadata: %v", err)
	}
	if metadata.EntityID != "idp-2" {
		t.Fatalf("unexpected entity id: %s", metadata.EntityID)
	}

	// NOTE(review): the next physical line is kept exactly as found. It opens
	// a string literal that is never closed and then continues with the tail
	// of a config literal (IDPCertificate) and calls to NewSAMLService,
	// MakeAuthRequest and GetMetadata. It looks like the end of this test and
	// the func header plus opening statements of a following test were lost
	// together with the stripped markup. The statements cannot be re-split
	// safely without the missing text, so recover this span from version
	// control before relying on it.
	if _, err := parseIDPMetadataXML([]byte(" `, IDPCertificate: string(certPEM), } service, err := NewSAMLService(context.Background(), "idp", cfg, "http://localhost:8080") if err != nil { t.Fatalf("new service: %v", err) } url, err := service.MakeAuthRequest("") if err != nil || !strings.Contains(url, "SAMLRequest") { t.Fatalf("unexpected auth url: %v %s", err, url) } if _, err := service.GetMetadata(); err != nil { t.Fatalf("metadata error: %v", err) } logoutURL, err := 
	service.MakeLogoutRequest("user", "sess")
	if err != nil || !strings.Contains(logoutURL, "SAMLRequest") {
		t.Fatalf("unexpected logout url: %v %s", err, logoutURL)
	}

	// Replace the service with one constructed directly from an empty
	// provider config. Every operation below is expected to return an error;
	// per the failure messages, the first three because the sp is missing and
	// the refresh because no url is configured.
	service = &SAMLService{config: &config.SAMLProviderConfig{}}
	if _, err := service.MakeAuthRequest(""); err == nil {
		t.Fatal("expected error when sp missing")
	}
	if _, err := service.GetMetadata(); err == nil {
		t.Fatal("expected error when sp missing")
	}
	if _, err := service.MakeLogoutRequest("user", "sess"); err == nil {
		t.Fatal("expected error when sp missing")
	}
	if err := service.RefreshMetadata(context.Background()); err == nil {
		t.Fatal("expected refresh error without url")
	}
}

// TestFetchMetadataFromURL serves a fixed body from a local httptest server
// and checks that fetchIDPMetadataFromURL returns metadata whose EntityID is
// idp-url.
//
// NOTE(review): the served body is a single space in this copy, which cannot
// produce that entity ID. Presumably the XML fixture was stripped; restore it
// from version control.
func TestFetchMetadataFromURL(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		// The return values of Write are discarded here.
		w.Write([]byte(` `))
	}))
	defer server.Close()

	// The service is assembled by hand with the client returned by
	// newSAMLHTTPClient instead of going through NewSAMLService.
	cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
	service := &SAMLService{config: cfg, httpClient: newSAMLHTTPClient()}

	metadata, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL)
	if err != nil {
		t.Fatalf("fetch metadata: %v", err)
	}
	if metadata.EntityID != "idp-url" {
		t.Fatalf("unexpected entity id: %s", metadata.EntityID)
	}
}