mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-28 19:41:17 +00:00
438 lines
13 KiB
Go
438 lines
13 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
func TestOIDCServiceMatches(t *testing.T) {
|
|
svc := &OIDCService{
|
|
snapshot: oidcSnapshot{
|
|
issuer: "https://issuer.example.com",
|
|
clientID: "client-id",
|
|
clientSecret: "client-secret",
|
|
redirectURL: "https://pulse.example.com/callback",
|
|
scopes: []string{"openid", "email"},
|
|
caBundle: "",
|
|
caBundleHash: "",
|
|
},
|
|
}
|
|
|
|
cfg := &config.OIDCConfig{
|
|
Enabled: true,
|
|
IssuerURL: "https://issuer.example.com",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
RedirectURL: "https://pulse.example.com/callback",
|
|
Scopes: []string{"openid", "email"},
|
|
CABundle: "",
|
|
}
|
|
|
|
if !svc.Matches(cfg) {
|
|
t.Fatalf("expected config to match")
|
|
}
|
|
|
|
cfg.ClientID = "other"
|
|
if svc.Matches(cfg) {
|
|
t.Fatalf("expected config mismatch")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceStateEntryAndConsume(t *testing.T) {
|
|
svc := &OIDCService{stateStore: newOIDCStateStore()}
|
|
|
|
state, entry, err := svc.newStateEntry("/return")
|
|
if err != nil {
|
|
t.Fatalf("newStateEntry error: %v", err)
|
|
}
|
|
if state == "" || entry == nil {
|
|
t.Fatalf("expected state and entry")
|
|
}
|
|
if entry.ReturnTo != "/return" {
|
|
t.Fatalf("expected returnTo /return, got %q", entry.ReturnTo)
|
|
}
|
|
|
|
consumed, ok := svc.consumeState(state)
|
|
if !ok || consumed == nil {
|
|
t.Fatalf("expected to consume state")
|
|
}
|
|
if _, ok := svc.consumeState(state); ok {
|
|
t.Fatalf("expected state to be removed")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceAuthCodeURLIncludesPKCE(t *testing.T) {
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{AuthURL: "https://issuer.example.com/auth"}, ClientID: "client"},
|
|
}
|
|
entry := &oidcStateEntry{Nonce: "nonce", CodeChallenge: "challenge"}
|
|
url := svc.authCodeURL("state", entry)
|
|
if url == "" {
|
|
t.Fatalf("expected auth url")
|
|
}
|
|
if !strings.Contains(url, "code_challenge=challenge") {
|
|
t.Fatalf("expected code_challenge in url, got %q", url)
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceExchangeCode(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if r.Form.Get("code_verifier") == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprint(w, `{"access_token":"access","token_type":"Bearer","refresh_token":"refresh","expires_in":3600}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{
|
|
ClientID: "client",
|
|
ClientSecret: "secret",
|
|
Endpoint: oauth2.Endpoint{TokenURL: server.URL},
|
|
RedirectURL: "https://pulse.example.com/callback",
|
|
},
|
|
httpClient: server.Client(),
|
|
}
|
|
|
|
entry := &oidcStateEntry{CodeVerifier: "verifier"}
|
|
token, err := svc.exchangeCode(context.Background(), "code", entry)
|
|
if err != nil {
|
|
t.Fatalf("exchangeCode error: %v", err)
|
|
}
|
|
if token.AccessToken != "access" {
|
|
t.Fatalf("expected access token, got %q", token.AccessToken)
|
|
}
|
|
}
|
|
|
|
func TestHashCABundle(t *testing.T) {
|
|
if hash, err := hashCABundle(""); err != nil || hash != "" {
|
|
t.Fatalf("expected empty hash, got %q err=%v", hash, err)
|
|
}
|
|
|
|
file := t.TempDir() + "/ca.pem"
|
|
data := []byte("test-ca")
|
|
if err := os.WriteFile(file, data, 0o600); err != nil {
|
|
t.Fatalf("write file: %v", err)
|
|
}
|
|
|
|
hash, err := hashCABundle(file)
|
|
if err != nil {
|
|
t.Fatalf("hashCABundle error: %v", err)
|
|
}
|
|
sum := sha256.Sum256(data)
|
|
if hash != fmt.Sprintf("%x", sum[:]) {
|
|
t.Fatalf("unexpected hash %q", hash)
|
|
}
|
|
}
|
|
|
|
func TestOIDCStateStoreCleanupAndConsume(t *testing.T) {
|
|
store := &oidcStateStore{entries: make(map[string]*oidcStateEntry), stopCleanup: make(chan struct{})}
|
|
store.Put("expired", &oidcStateEntry{ExpiresAt: time.Now().Add(-time.Minute)})
|
|
store.Put("active", &oidcStateEntry{ExpiresAt: time.Now().Add(time.Minute)})
|
|
|
|
store.cleanup()
|
|
|
|
if _, ok := store.entries["expired"]; ok {
|
|
t.Fatalf("expected expired entry to be cleaned")
|
|
}
|
|
|
|
entry, ok := store.Consume("active")
|
|
if !ok || entry == nil {
|
|
t.Fatalf("expected to consume active entry")
|
|
}
|
|
if _, ok := store.Consume("active"); ok {
|
|
t.Fatalf("expected entry to be removed after consume")
|
|
}
|
|
}
|
|
|
|
func TestOIDCStateStoreStop(t *testing.T) {
|
|
store := &oidcStateStore{entries: make(map[string]*oidcStateEntry), stopCleanup: make(chan struct{})}
|
|
store.Stop()
|
|
select {
|
|
case <-store.stopCleanup:
|
|
default:
|
|
t.Fatalf("expected stop channel to be closed")
|
|
}
|
|
}
|
|
|
|
func TestGenerateRandomURLString(t *testing.T) {
|
|
val, err := generateRandomURLString(16)
|
|
if err != nil {
|
|
t.Fatalf("generateRandomURLString error: %v", err)
|
|
}
|
|
if val == "" {
|
|
t.Fatalf("expected non-empty value")
|
|
}
|
|
}
|
|
|
|
func TestGeneratePKCEPair(t *testing.T) {
|
|
verifier, challenge, err := generatePKCEPair()
|
|
if err != nil {
|
|
t.Fatalf("generatePKCEPair error: %v", err)
|
|
}
|
|
if verifier == "" || challenge == "" {
|
|
t.Fatalf("expected verifier and challenge")
|
|
}
|
|
|
|
hash := sha256.Sum256([]byte(verifier))
|
|
expected := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
if challenge != expected {
|
|
t.Fatalf("unexpected challenge %q", challenge)
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceConsumeStateExpired(t *testing.T) {
|
|
svc := &OIDCService{stateStore: newOIDCStateStore()}
|
|
svc.stateStore.Put("expired", &oidcStateEntry{ExpiresAt: time.Now().Add(-time.Minute)})
|
|
|
|
entry, ok := svc.consumeState("expired")
|
|
if ok || entry != nil {
|
|
t.Fatalf("expected expired entry to be rejected")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceContextWithHTTPClient(t *testing.T) {
|
|
client := &http.Client{}
|
|
svc := &OIDCService{httpClient: client}
|
|
|
|
ctx := svc.contextWithHTTPClient(context.Background())
|
|
if ctx == nil {
|
|
t.Fatalf("expected context")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceRefreshToken_NoToken(t *testing.T) {
|
|
svc := &OIDCService{}
|
|
if _, err := svc.RefreshToken(context.Background(), ""); err == nil {
|
|
t.Fatalf("expected error for empty refresh token")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceAuthCodeURL_NoPKCE(t *testing.T) {
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{AuthURL: "https://issuer.example.com/auth"}, ClientID: "client"},
|
|
}
|
|
entry := &oidcStateEntry{Nonce: "nonce"}
|
|
url := svc.authCodeURL("state", entry)
|
|
if url == "" {
|
|
t.Fatalf("expected auth url")
|
|
}
|
|
if strings.Contains(url, "code_challenge=") {
|
|
t.Fatalf("did not expect code_challenge in url")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceExchangeCode_Error(t *testing.T) {
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: "http://127.0.0.1:0"}, ClientID: "client", ClientSecret: "secret"},
|
|
}
|
|
entry := &oidcStateEntry{CodeVerifier: "verifier"}
|
|
_, err := svc.exchangeCode(context.Background(), "code", entry)
|
|
if err == nil {
|
|
t.Fatalf("expected exchangeCode error")
|
|
}
|
|
}
|
|
|
|
func TestOIDCStateStorePutConsumeExpired(t *testing.T) {
|
|
store := &oidcStateStore{entries: make(map[string]*oidcStateEntry), stopCleanup: make(chan struct{})}
|
|
store.Put("expired", &oidcStateEntry{ExpiresAt: time.Now().Add(-time.Second)})
|
|
|
|
if entry, ok := store.Consume("expired"); ok || entry != nil {
|
|
t.Fatalf("expected expired entry to be rejected")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceMatchesCABundleHash(t *testing.T) {
|
|
file := t.TempDir() + "/ca.pem"
|
|
data := []byte("test-ca")
|
|
if err := os.WriteFile(file, data, 0o600); err != nil {
|
|
t.Fatalf("write file: %v", err)
|
|
}
|
|
|
|
hash, err := hashCABundle(file)
|
|
if err != nil {
|
|
t.Fatalf("hashCABundle error: %v", err)
|
|
}
|
|
|
|
svc := &OIDCService{snapshot: oidcSnapshot{issuer: "iss", clientID: "id", clientSecret: "secret", redirectURL: "cb", scopes: []string{"openid"}, caBundle: file, caBundleHash: hash}}
|
|
cfg := &config.OIDCConfig{Enabled: true, IssuerURL: "iss", ClientID: "id", ClientSecret: "secret", RedirectURL: "cb", Scopes: []string{"openid"}, CABundle: file}
|
|
|
|
if !svc.Matches(cfg) {
|
|
t.Fatalf("expected CABundle hash to match")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceAuthCodeURLMatchesNonce(t *testing.T) {
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{AuthURL: "https://issuer.example.com/auth"}, ClientID: "client"},
|
|
}
|
|
entry := &oidcStateEntry{Nonce: "nonce"}
|
|
url := svc.authCodeURL("state", entry)
|
|
if !strings.Contains(url, "nonce=nonce") {
|
|
t.Fatalf("expected nonce in url, got %q", url)
|
|
}
|
|
}
|
|
|
|
func TestOIDCStateStoreConsumeUnknown(t *testing.T) {
|
|
store := &oidcStateStore{entries: make(map[string]*oidcStateEntry), stopCleanup: make(chan struct{})}
|
|
if entry, ok := store.Consume("missing"); ok || entry != nil {
|
|
t.Fatalf("expected missing entry to return false")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceNewStateEntryFields(t *testing.T) {
|
|
svc := &OIDCService{stateStore: newOIDCStateStore()}
|
|
_, entry, err := svc.newStateEntry("/return")
|
|
if err != nil {
|
|
t.Fatalf("newStateEntry error: %v", err)
|
|
}
|
|
if entry.Nonce == "" || entry.CodeVerifier == "" || entry.CodeChallenge == "" {
|
|
t.Fatalf("expected nonce and pkce fields")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceConsumeStateMissing(t *testing.T) {
|
|
svc := &OIDCService{stateStore: newOIDCStateStore()}
|
|
entry, ok := svc.consumeState("missing")
|
|
if ok || entry != nil {
|
|
t.Fatalf("expected missing state to return false")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceAuthCodeURLWithPKCEAndNonce(t *testing.T) {
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{AuthURL: "https://issuer.example.com/auth"}, ClientID: "client"},
|
|
}
|
|
entry := &oidcStateEntry{Nonce: "nonce", CodeChallenge: "challenge"}
|
|
url := svc.authCodeURL("state", entry)
|
|
if !strings.Contains(url, "code_challenge=challenge") || !strings.Contains(url, "nonce=nonce") {
|
|
t.Fatalf("unexpected auth url: %q", url)
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceContextWithHTTPClientNil(t *testing.T) {
|
|
svc := &OIDCService{}
|
|
ctx := svc.contextWithHTTPClient(context.Background())
|
|
if ctx == nil {
|
|
t.Fatalf("expected context")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceRefreshTokenKeepsOldRefresh(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprint(w, `{"access_token":"access","token_type":"Bearer","expires_in":3600}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: server.URL}, ClientID: "client", ClientSecret: "secret"},
|
|
httpClient: server.Client(),
|
|
}
|
|
|
|
result, err := svc.RefreshToken(context.Background(), "old-refresh")
|
|
if err != nil {
|
|
t.Fatalf("RefreshToken error: %v", err)
|
|
}
|
|
if result.RefreshToken != "old-refresh" {
|
|
t.Fatalf("expected old refresh token to be preserved")
|
|
}
|
|
if result.AccessToken == "" {
|
|
t.Fatalf("expected access token")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceRefreshTokenReplacesRefresh(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprint(w, `{"access_token":"access","refresh_token":"new-refresh","token_type":"Bearer","expires_in":3600}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: server.URL}, ClientID: "client", ClientSecret: "secret"},
|
|
httpClient: server.Client(),
|
|
}
|
|
|
|
result, err := svc.RefreshToken(context.Background(), "old-refresh")
|
|
if err != nil {
|
|
t.Fatalf("RefreshToken error: %v", err)
|
|
}
|
|
if result.RefreshToken != "new-refresh" {
|
|
t.Fatalf("expected refresh token to be replaced")
|
|
}
|
|
if result.AccessToken == "" {
|
|
t.Fatalf("expected access token")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceExchangeCodeMissingVerifier(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
if strings.Contains(string(body), "code_verifier") {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
fmt.Fprint(w, `{"access_token":"access","token_type":"Bearer","expires_in":3600}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
svc := &OIDCService{
|
|
oauth2Cfg: &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: server.URL}, ClientID: "client", ClientSecret: "secret"},
|
|
httpClient: server.Client(),
|
|
}
|
|
|
|
entry := &oidcStateEntry{}
|
|
_, err := svc.exchangeCode(context.Background(), "code", entry)
|
|
if err != nil {
|
|
t.Fatalf("exchangeCode error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceMatchesScopeMismatch(t *testing.T) {
|
|
svc := &OIDCService{snapshot: oidcSnapshot{issuer: "iss", clientID: "id", clientSecret: "secret", redirectURL: "cb", scopes: []string{"openid"}}}
|
|
cfg := &config.OIDCConfig{Enabled: true, IssuerURL: "iss", ClientID: "id", ClientSecret: "secret", RedirectURL: "cb", Scopes: []string{"openid", "email"}}
|
|
|
|
if svc.Matches(cfg) {
|
|
t.Fatalf("expected scope mismatch")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceMatchesNil(t *testing.T) {
|
|
var svc *OIDCService
|
|
if svc.Matches(&config.OIDCConfig{}) {
|
|
t.Fatalf("expected Matches to be false for nil service")
|
|
}
|
|
}
|
|
|
|
func TestOIDCServiceMatchesNilConfig(t *testing.T) {
|
|
svc := &OIDCService{}
|
|
if svc.Matches(nil) {
|
|
t.Fatalf("expected Matches to be false for nil config")
|
|
}
|
|
}
|
|
|
|
func TestGenerateRandomURLString_ErrorSize(t *testing.T) {
|
|
if _, err := generateRandomURLString(0); err != nil {
|
|
t.Fatalf("expected no error for size 0, got %v", err)
|
|
}
|
|
}
|