package api
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
)
// generateTestCert creates a short-lived, self-signed RSA certificate for use
// in tests. It returns the certificate PEM ("CERTIFICATE" block), the private
// key PEM (PKCS#1 "RSA PRIVATE KEY" block) and the 2048-bit *rsa.PrivateKey
// itself. The certificate is valid from one hour before the call until one
// hour after it and has the CA flag set. Any failure aborts the calling test
// via t.Fatalf.
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte, key *rsa.PrivateKey) {
	t.Helper()

	signer, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	// The template is passed as both subject and parent, which makes the
	// resulting certificate self-signed by signer.
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		NotBefore:             time.Now().Add(-time.Hour),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &signer.PublicKey, signer)
	if err != nil {
		t.Fatalf("create cert: %v", err)
	}

	keyDER := x509.MarshalPKCS1PrivateKey(signer)
	return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}),
		pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyDER}),
		signer
}
// TestParseIDPMetadataXML checks parseIDPMetadataXML on two inputs: a metadata
// document parsed directly (expected entity ID "idp-1") and a "wrapped"
// metadata document (expected entity ID "idp-2").
//
// NOTE(review): in this copy the XML fixtures below are raw strings containing
// only a newline, so the entity-ID assertions cannot pass as written. The
// trailing parseIDPMetadataXML call is also cut off mid-statement (its string
// literal is never closed). The lines after it (NewSAMLService /
// MakeAuthRequest / GetMetadata / MakeLogoutRequest / RefreshMetadata checks)
// look like the tail of a separate test whose header is missing. Restore this
// region from version control before relying on it.
func TestParseIDPMetadataXML(t *testing.T) {
xml := `
`
// Case 1: metadata document parsed as-is; its entity ID must be "idp-1".
metadata, err := parseIDPMetadataXML([]byte(xml))
if err != nil {
t.Fatalf("parse metadata: %v", err)
}
if metadata.EntityID != "idp-1" {
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
}
// Case 2: a wrapped metadata document; its entity ID must be "idp-2".
// NOTE(review): the wrapper element is not visible here — presumably an outer
// descriptor element around the entity descriptor; confirm against the fixture.
wrapped := `
`
metadata, err = parseIDPMetadataXML([]byte(wrapped))
if err != nil {
t.Fatalf("parse wrapped metadata: %v", err)
}
if metadata.EntityID != "idp-2" {
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
}
if _, err := parseIDPMetadataXML([]byte("
`,
IDPCertificate: string(certPEM),
}
service, err := NewSAMLService(context.Background(), "idp", cfg, "http://localhost:8080")
if err != nil {
t.Fatalf("new service: %v", err)
}
url, err := service.MakeAuthRequest("")
if err != nil || !strings.Contains(url, "SAMLRequest") {
t.Fatalf("unexpected auth url: %v %s", err, url)
}
if _, err := service.GetMetadata(); err != nil {
t.Fatalf("metadata error: %v", err)
}
logoutURL, err := service.MakeLogoutRequest("user", "sess")
if err != nil || !strings.Contains(logoutURL, "SAMLRequest") {
t.Fatalf("unexpected logout url: %v %s", err, logoutURL)
}
service = &SAMLService{config: &config.SAMLProviderConfig{}}
if _, err := service.MakeAuthRequest(""); err == nil {
t.Fatal("expected error when sp missing")
}
if _, err := service.GetMetadata(); err == nil {
t.Fatal("expected error when sp missing")
}
if _, err := service.MakeLogoutRequest("user", "sess"); err == nil {
t.Fatal("expected error when sp missing")
}
if err := service.RefreshMetadata(context.Background()); err == nil {
t.Fatal("expected refresh error without url")
}
}
func TestFetchMetadataFromURL(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`
`))
}))
defer server.Close()
cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
service := &SAMLService{config: cfg, httpClient: newSAMLHTTPClient()}
metadata, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL)
if err != nil {
t.Fatalf("fetch metadata: %v", err)
}
if metadata.EntityID != "idp-url" {
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
}
}