mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 03:20:11 +00:00
Add additional tests for OIDC, SAML, and tenant middleware to improve coverage of security-critical paths.
318 lines
9.9 KiB
Go
318 lines
9.9 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
func newTestOIDCConfig() *config.OIDCConfig {
|
|
return &config.OIDCConfig{
|
|
Enabled: true,
|
|
IssuerURL: "https://issuer.example.com",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
RedirectURL: "https://app.example.com/api/oidc/callback",
|
|
Scopes: []string{"openid", "email"},
|
|
UsernameClaim: "preferred_username",
|
|
EmailClaim: "email",
|
|
GroupsClaim: "groups",
|
|
}
|
|
}
|
|
|
|
func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService {
|
|
return &OIDCService{
|
|
snapshot: oidcSnapshot{
|
|
issuer: cfg.IssuerURL,
|
|
clientID: cfg.ClientID,
|
|
clientSecret: cfg.ClientSecret,
|
|
redirectURL: cfg.RedirectURL,
|
|
scopes: append([]string{}, cfg.Scopes...),
|
|
caBundle: cfg.CABundle,
|
|
},
|
|
oauth2Cfg: &oauth2.Config{
|
|
ClientID: cfg.ClientID,
|
|
ClientSecret: cfg.ClientSecret,
|
|
RedirectURL: cfg.RedirectURL,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: authURL,
|
|
TokenURL: tokenURL,
|
|
},
|
|
Scopes: append([]string{}, cfg.Scopes...),
|
|
},
|
|
stateStore: newOIDCStateStore(),
|
|
}
|
|
}
|
|
|
|
func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) {
|
|
t.Helper()
|
|
cfg := newTestOIDCConfig()
|
|
svc := newTestOIDCService(cfg, authURL, tokenURL)
|
|
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
|
|
t.Cleanup(func() {
|
|
if svc.stateStore != nil {
|
|
svc.stateStore.Stop()
|
|
}
|
|
})
|
|
return router, svc
|
|
}
|
|
|
|
func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) {
|
|
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
|
|
req := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCLogin(rec, req)
|
|
|
|
if rec.Code != http.StatusMethodNotAllowed {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCLogin_InvalidJSON(t *testing.T) {
|
|
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", strings.NewReader("{"))
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCLogin(rec, req)
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
|
}
|
|
var payload map[string]interface{}
|
|
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
if payload["code"] != "invalid_request" {
|
|
t.Fatalf("code = %v, want invalid_request", payload["code"])
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCLogin_GetSuccess(t *testing.T) {
|
|
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCLogin(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
u, err := url.Parse(location)
|
|
if err != nil {
|
|
t.Fatalf("parse location: %v", err)
|
|
}
|
|
if u.Host != "auth.example.com" {
|
|
t.Fatalf("unexpected auth host: %q", u.Host)
|
|
}
|
|
state := u.Query().Get("state")
|
|
if state == "" {
|
|
t.Fatalf("expected state param in redirect")
|
|
}
|
|
entry, ok := svc.consumeState(state)
|
|
if !ok {
|
|
t.Fatalf("expected state entry to be stored")
|
|
}
|
|
if entry.ReturnTo != "/dashboard" {
|
|
t.Fatalf("returnTo = %q, want /dashboard", entry.ReturnTo)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCLogin_PostSuccess(t *testing.T) {
|
|
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
body := strings.NewReader(`{"returnTo":"/home"}`)
|
|
req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", body)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCLogin(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
var payload struct {
|
|
AuthorizationURL string `json:"authorizationUrl"`
|
|
}
|
|
if err := json.NewDecoder(rec.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
if payload.AuthorizationURL == "" {
|
|
t.Fatalf("expected authorizationUrl in response")
|
|
}
|
|
u, err := url.Parse(payload.AuthorizationURL)
|
|
if err != nil {
|
|
t.Fatalf("parse authorizationUrl: %v", err)
|
|
}
|
|
state := u.Query().Get("state")
|
|
if state == "" {
|
|
t.Fatalf("expected state param in authorizationUrl")
|
|
}
|
|
entry, ok := svc.consumeState(state)
|
|
if !ok {
|
|
t.Fatalf("expected state entry to be stored")
|
|
}
|
|
if entry.ReturnTo != "/home" {
|
|
t.Fatalf("returnTo = %q, want /home", entry.ReturnTo)
|
|
}
|
|
}
|
|
|
|
func TestGetOIDCService_ReturnsCachedService(t *testing.T) {
|
|
cfg := newTestOIDCConfig()
|
|
svc := newTestOIDCService(cfg, "https://auth.example.com/authorize", "https://token.example.com")
|
|
router := &Router{config: &config.Config{OIDC: cfg}, oidcService: svc}
|
|
defer svc.stateStore.Stop()
|
|
|
|
got, err := router.getOIDCService(context.Background(), cfg.RedirectURL)
|
|
if err != nil {
|
|
t.Fatalf("getOIDCService error: %v", err)
|
|
}
|
|
if got != svc {
|
|
t.Fatalf("expected cached service to be returned")
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) {
|
|
router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}
|
|
req := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusMethodNotAllowed {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_ErrorParam(t *testing.T) {
|
|
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?error=access_denied", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=access_denied") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_MissingState(t *testing.T) {
|
|
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?code=abc", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=missing_state") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_InvalidState(t *testing.T) {
|
|
router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state=invalid", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=invalid_state") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_MissingCode(t *testing.T) {
|
|
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")
|
|
state, _, err := svc.newStateEntry("/dashboard")
|
|
if err != nil {
|
|
t.Fatalf("newStateEntry error: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state), nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=missing_code") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
if !strings.HasPrefix(location, "/dashboard") {
|
|
t.Fatalf("expected redirect back to /dashboard, got %q", location)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) {
|
|
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
defer tokenServer.Close()
|
|
|
|
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
|
|
svc.httpClient = tokenServer.Client()
|
|
state, _, err := svc.newStateEntry("/dashboard")
|
|
if err != nil {
|
|
t.Fatalf("newStateEntry error: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=exchange_failed") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestHandleOIDCCallback_MissingIDToken(t *testing.T) {
|
|
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(`{"access_token":"access","token_type":"Bearer","expires_in":3600}`))
|
|
}))
|
|
defer tokenServer.Close()
|
|
|
|
router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
|
|
svc.httpClient = tokenServer.Client()
|
|
state, _, err := svc.newStateEntry("/dashboard")
|
|
if err != nil {
|
|
t.Fatalf("newStateEntry error: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
router.handleOIDCCallback(rec, req)
|
|
|
|
if rec.Code != http.StatusFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
|
|
}
|
|
location := rec.Header().Get("Location")
|
|
if !strings.Contains(location, "oidc_error=missing_id_token") {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|