Pulse/internal/api/oidc_handlers_more_test.go
rcourtman 43d7fffeef Test: add coverage for auth and security handlers
Add additional tests for OIDC, SAML, and tenant middleware to improve coverage of security-critical paths.
2026-02-02 22:02:11 +00:00

318 lines
9.9 KiB
Go

package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"golang.org/x/oauth2"
)
// newTestOIDCConfig returns an enabled OIDC configuration filled with fixed
// example values (issuer, client credentials, redirect URL, scopes and claim
// names) shared by the tests in this file.
func newTestOIDCConfig() *config.OIDCConfig {
	cfg := new(config.OIDCConfig)

	cfg.Enabled = true
	cfg.IssuerURL = "https://issuer.example.com"
	cfg.ClientID = "client-id"
	cfg.ClientSecret = "client-secret"
	cfg.RedirectURL = "https://app.example.com/api/oidc/callback"
	cfg.Scopes = []string{"openid", "email"}

	// Claim names used to map token claims onto user attributes.
	cfg.UsernameClaim = "preferred_username"
	cfg.EmailClaim = "email"
	cfg.GroupsClaim = "groups"

	return cfg
}
// newTestOIDCService assembles an OIDCService by hand from cfg, using authURL
// and tokenURL as the OAuth2 endpoint. The snapshot and the OAuth2 config each
// receive their own copy of cfg.Scopes, and the service gets a state store
// from newOIDCStateStore.
func newTestOIDCService(cfg *config.OIDCConfig, authURL, tokenURL string) *OIDCService {
	snap := oidcSnapshot{
		issuer:       cfg.IssuerURL,
		clientID:     cfg.ClientID,
		clientSecret: cfg.ClientSecret,
		redirectURL:  cfg.RedirectURL,
		// Copy so the snapshot does not alias cfg's backing array.
		scopes:   append([]string{}, cfg.Scopes...),
		caBundle: cfg.CABundle,
	}

	endpoint := oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}
	oauthCfg := &oauth2.Config{
		ClientID:     cfg.ClientID,
		ClientSecret: cfg.ClientSecret,
		RedirectURL:  cfg.RedirectURL,
		Endpoint:     endpoint,
		// Independent copy, separate from the snapshot's slice.
		Scopes: append([]string{}, cfg.Scopes...),
	}

	return &OIDCService{
		snapshot:   snap,
		oauth2Cfg:  oauthCfg,
		stateStore: newOIDCStateStore(),
	}
}
// newOIDCRouterWithService returns a Router wired to a hand-built OIDCService
// that uses the given authorization and token URLs, together with that
// service. It registers a test cleanup that stops the service's state store.
func newOIDCRouterWithService(t *testing.T, authURL, tokenURL string) (*Router, *OIDCService) {
	t.Helper()

	oidcCfg := newTestOIDCConfig()
	service := newTestOIDCService(oidcCfg, authURL, tokenURL)

	// Stop the state store when the calling test finishes.
	t.Cleanup(func() {
		if store := service.stateStore; store != nil {
			store.Stop()
		}
	})

	return &Router{
		config:      &config.Config{OIDC: oidcCfg},
		oidcService: service,
	}, service
}
// TestHandleOIDCLogin_MethodNotAllowed verifies that a PUT to the login
// endpoint is answered with 405 Method Not Allowed.
func TestHandleOIDCLogin_MethodNotAllowed(t *testing.T) {
	router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodPut, "/api/oidc/login", nil)
	router.handleOIDCLogin(recorder, request)

	if got, want := recorder.Code, http.StatusMethodNotAllowed; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// TestHandleOIDCLogin_InvalidJSON verifies that a POST with a malformed JSON
// body is rejected with 400 Bad Request and a JSON payload whose "code" field
// is "invalid_request".
func TestHandleOIDCLogin_InvalidJSON(t *testing.T) {
	router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	// A lone opening brace is incomplete JSON.
	malformed := strings.NewReader("{")
	recorder := httptest.NewRecorder()
	router.handleOIDCLogin(recorder, httptest.NewRequest(http.MethodPost, "/api/oidc/login", malformed))

	if got, want := recorder.Code, http.StatusBadRequest; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	var body map[string]interface{}
	if err := json.NewDecoder(recorder.Body).Decode(&body); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	if code := body["code"]; code != "invalid_request" {
		t.Fatalf("code = %v, want invalid_request", code)
	}
}
// TestHandleOIDCLogin_GetSuccess verifies that a GET login request with a
// returnTo query parameter redirects (302) to the configured authorization
// host with a non-empty state parameter, and that the service recognizes that
// state with the requested return path attached.
func TestHandleOIDCLogin_GetSuccess(t *testing.T) {
	router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/api/oidc/login?returnTo=/dashboard", nil)
	router.handleOIDCLogin(recorder, request)

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	redirect, err := url.Parse(recorder.Header().Get("Location"))
	if err != nil {
		t.Fatalf("parse location: %v", err)
	}
	if host := redirect.Host; host != "auth.example.com" {
		t.Fatalf("unexpected auth host: %q", host)
	}

	state := redirect.Query().Get("state")
	if state == "" {
		t.Fatalf("expected state param in redirect")
	}

	// The state handed to the provider must be one the service knows about.
	stored, found := svc.consumeState(state)
	if !found {
		t.Fatalf("expected state entry to be stored")
	}
	if returnTo := stored.ReturnTo; returnTo != "/dashboard" {
		t.Fatalf("returnTo = %q, want /dashboard", returnTo)
	}
}
// TestHandleOIDCLogin_PostSuccess verifies that a POST login request with a
// JSON returnTo value responds 200 with an authorizationUrl, that the URL
// carries a non-empty state parameter, and that the service recognizes that
// state with the requested return path attached.
func TestHandleOIDCLogin_PostSuccess(t *testing.T) {
	router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(
		http.MethodPost,
		"/api/oidc/login",
		strings.NewReader(`{"returnTo":"/home"}`),
	)
	router.handleOIDCLogin(recorder, request)

	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	// Only the authorization URL is of interest in the response body.
	var response struct {
		AuthorizationURL string `json:"authorizationUrl"`
	}
	if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	if response.AuthorizationURL == "" {
		t.Fatalf("expected authorizationUrl in response")
	}

	authURL, err := url.Parse(response.AuthorizationURL)
	if err != nil {
		t.Fatalf("parse authorizationUrl: %v", err)
	}

	state := authURL.Query().Get("state")
	if state == "" {
		t.Fatalf("expected state param in authorizationUrl")
	}

	stored, found := svc.consumeState(state)
	if !found {
		t.Fatalf("expected state entry to be stored")
	}
	if returnTo := stored.ReturnTo; returnTo != "/home" {
		t.Fatalf("returnTo = %q, want /home", returnTo)
	}
}
// TestGetOIDCService_ReturnsCachedService verifies that getOIDCService hands
// back the service already attached to the router when it is called with the
// configured redirect URL.
func TestGetOIDCService_ReturnsCachedService(t *testing.T) {
	oidcCfg := newTestOIDCConfig()
	cached := newTestOIDCService(oidcCfg, "https://auth.example.com/authorize", "https://token.example.com")
	defer cached.stateStore.Stop()

	router := &Router{
		config:      &config.Config{OIDC: oidcCfg},
		oidcService: cached,
	}

	resolved, err := router.getOIDCService(context.Background(), oidcCfg.RedirectURL)
	if err != nil {
		t.Fatalf("getOIDCService error: %v", err)
	}
	// Pointer identity: the very same instance must come back.
	if resolved != cached {
		t.Fatalf("expected cached service to be returned")
	}
}
// TestHandleOIDCCallback_MethodNotAllowed verifies that a POST to the callback
// endpoint is answered with 405 Method Not Allowed.
func TestHandleOIDCCallback_MethodNotAllowed(t *testing.T) {
	router := &Router{config: &config.Config{OIDC: newTestOIDCConfig()}}

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodPost, "/api/oidc/callback", nil)
	router.handleOIDCCallback(recorder, request)

	if got, want := recorder.Code, http.StatusMethodNotAllowed; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// TestHandleOIDCCallback_ErrorParam verifies that a callback carrying an
// error query parameter results in a 302 redirect whose location includes
// that error as oidc_error.
func TestHandleOIDCCallback_ErrorParam(t *testing.T) {
	router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	const target = "/api/oidc/callback?error=access_denied"
	recorder := httptest.NewRecorder()
	router.handleOIDCCallback(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	if redirect := recorder.Header().Get("Location"); !strings.Contains(redirect, "oidc_error=access_denied") {
		t.Fatalf("unexpected redirect location: %q", redirect)
	}
}
// TestHandleOIDCCallback_MissingState verifies that a callback with a code but
// no state parameter results in a 302 redirect flagged with
// oidc_error=missing_state.
func TestHandleOIDCCallback_MissingState(t *testing.T) {
	router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	const target = "/api/oidc/callback?code=abc"
	recorder := httptest.NewRecorder()
	router.handleOIDCCallback(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	if redirect := recorder.Header().Get("Location"); !strings.Contains(redirect, "oidc_error=missing_state") {
		t.Fatalf("unexpected redirect location: %q", redirect)
	}
}
// TestHandleOIDCCallback_InvalidState verifies that a callback whose state
// value was never issued by the service results in a 302 redirect flagged
// with oidc_error=invalid_state.
func TestHandleOIDCCallback_InvalidState(t *testing.T) {
	router, _ := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	const target = "/api/oidc/callback?state=invalid"
	recorder := httptest.NewRecorder()
	router.handleOIDCCallback(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	if redirect := recorder.Header().Get("Location"); !strings.Contains(redirect, "oidc_error=invalid_state") {
		t.Fatalf("unexpected redirect location: %q", redirect)
	}
}
// TestHandleOIDCCallback_MissingCode verifies that a callback presenting a
// state created via newStateEntry but no authorization code results in a 302
// redirect that starts with the return path given to newStateEntry and is
// flagged with oidc_error=missing_code.
func TestHandleOIDCCallback_MissingCode(t *testing.T) {
	router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", "")

	state, _, err := svc.newStateEntry("/dashboard")
	if err != nil {
		t.Fatalf("newStateEntry error: %v", err)
	}

	// Present the state only; the code parameter is deliberately absent.
	target := "/api/oidc/callback?state=" + url.QueryEscape(state)
	recorder := httptest.NewRecorder()
	router.handleOIDCCallback(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	redirect := recorder.Header().Get("Location")
	if !strings.Contains(redirect, "oidc_error=missing_code") {
		t.Fatalf("unexpected redirect location: %q", redirect)
	}
	if !strings.HasPrefix(redirect, "/dashboard") {
		t.Fatalf("expected redirect back to /dashboard, got %q", redirect)
	}
}
// TestHandleOIDCCallback_ExchangeFailed verifies that when the token endpoint
// answers the code exchange with a 500, the callback results in a 302
// redirect flagged with oidc_error=exchange_failed.
func TestHandleOIDCCallback_ExchangeFailed(t *testing.T) {
	// Stub token endpoint that always fails.
	failing := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	})
	tokenServer := httptest.NewServer(failing)
	defer tokenServer.Close()

	router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
	// Use the HTTP client supplied by the stub server.
	svc.httpClient = tokenServer.Client()

	state, _, err := svc.newStateEntry("/dashboard")
	if err != nil {
		t.Fatalf("newStateEntry error: %v", err)
	}

	target := "/api/oidc/callback?state=" + url.QueryEscape(state) + "&code=abc"
	recorder := httptest.NewRecorder()
	router.handleOIDCCallback(recorder, httptest.NewRequest(http.MethodGet, target, nil))

	if got, want := recorder.Code, http.StatusFound; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
	if redirect := recorder.Header().Get("Location"); !strings.Contains(redirect, "oidc_error=exchange_failed") {
		t.Fatalf("unexpected redirect location: %q", redirect)
	}
}
// TestHandleOIDCCallback_MissingIDToken verifies that when the token endpoint
// returns a token response containing no id_token, the callback results in a
// 302 redirect flagged with oidc_error=missing_id_token.
func TestHandleOIDCCallback_MissingIDToken(t *testing.T) {
	// Stub token endpoint: a JSON token response carrying an access token only.
	tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		const tokenJSON = `{"access_token":"access","token_type":"Bearer","expires_in":3600}`
		if _, err := w.Write([]byte(tokenJSON)); err != nil {
			// Errorf rather than Fatalf: this runs on the server's goroutine,
			// where FailNow-style calls are not permitted.
			t.Errorf("write token response: %v", err)
		}
	}))
	defer tokenServer.Close()

	router, svc := newOIDCRouterWithService(t, "https://auth.example.com/authorize", tokenServer.URL)
	// Use the HTTP client supplied by the stub server.
	svc.httpClient = tokenServer.Client()

	state, _, err := svc.newStateEntry("/dashboard")
	if err != nil {
		t.Fatalf("newStateEntry error: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback?state="+url.QueryEscape(state)+"&code=abc", nil)
	rec := httptest.NewRecorder()
	router.handleOIDCCallback(rec, req)

	if rec.Code != http.StatusFound {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusFound)
	}
	location := rec.Header().Get("Location")
	if !strings.Contains(location, "oidc_error=missing_id_token") {
		t.Fatalf("unexpected redirect location: %q", location)
	}
}