// Pulse/internal/api/cloud_handoff_handlers_test.go
//
// 583 lines
// 18 KiB
// Go

package api
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
)
// signHandoffToken serializes claims as a JWT signed with HS256 using key and
// returns the compact token string. Signing failures abort the calling test.
func signHandoffToken(t *testing.T, key []byte, claims cloudHandoffClaims) string {
	t.Helper()

	compact, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
	if err != nil {
		t.Fatalf("SignedString() error = %v", err)
	}
	return compact
}
// makeExchangeRequest sends a form-encoded POST to the handoff exchange
// endpoint (JSON response variant, ?format=json) through handler and returns
// the recorded response. The request is addressed to host and appears to come
// from the loopback address 127.0.0.1:12345. An empty token omits the "token"
// form field entirely rather than sending it blank.
func makeExchangeRequest(t *testing.T, handler http.HandlerFunc, host, token string) *httptest.ResponseRecorder {
	t.Helper()

	values := make(url.Values)
	if len(token) > 0 {
		values.Set("token", token)
	}
	body := strings.NewReader(values.Encode())

	r := httptest.NewRequest(http.MethodPost, "/api/cloud/handoff/exchange?format=json", body)
	r.Host = host
	r.RemoteAddr = "127.0.0.1:12345"
	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, r)
	return w
}
// TestIsSQLiteUniqueViolation pins which error values isSQLiteUniqueViolation
// classifies as a uniqueness conflict. nil and unrelated errors must not match.
// Both the explicit "UNIQUE constraint failed" text and the bare
// "constraint failed" text must match.
func TestIsSQLiteUniqueViolation(t *testing.T) {
	cases := []struct {
		err  error
		want bool
		msg  string
	}{
		{err: nil, want: false, msg: "expected nil error to return false"},
		{err: errors.New("UNIQUE constraint failed: handoff_jti.jti"), want: true, msg: "expected UNIQUE constraint failure to return true"},
		{err: errors.New("constraint failed"), want: true, msg: "expected generic constraint failure to return true"},
		{err: errors.New("some other error"), want: false, msg: "expected unrelated error to return false"},
	}

	for _, tc := range cases {
		if got := isSQLiteUniqueViolation(tc.err); got != tc.want {
			t.Fatal(tc.msg)
		}
	}
}
// TestTenantIDFromRequest covers the ways tenantIDFromRequest can derive a
// tenant:
//   - the PULSE_TENANT_ID override;
//   - the Host header of a loopback request;
//   - X-Forwarded-Host from a trusted proxy;
//   - in hosted mode, the PULSE_PUBLIC_URL subdomain or the request Host.
//
// It also checks that the Host header of an untrusted, non-loopback request
// outside hosted mode yields no tenant.
func TestTenantIDFromRequest(t *testing.T) {
	tests := []struct {
		name          string
		hostedMode    string // PULSE_HOSTED_MODE
		trustedCIDRs  string // PULSE_TRUSTED_PROXY_CIDRS
		tenantEnv     string // PULSE_TENANT_ID
		publicURL     string // PULSE_PUBLIC_URL
		remoteAddr    string
		host          string
		forwardedHost string // X-Forwarded-Host; omitted when empty
		want          string
	}{
		{
			name:       "uses env var when present",
			tenantEnv:  "env-tenant",
			remoteAddr: "198.51.100.10:4567",
			host:       "host-tenant.example.com",
			want:       "env-tenant",
		},
		{
			name:       "extracts subdomain from loopback host",
			remoteAddr: "127.0.0.1:8080",
			host:       "tenant.example.com:8443",
			want:       "tenant",
		},
		{
			name:       "returns full loopback host when no dot exists",
			remoteAddr: "127.0.0.1:8080",
			host:       "localhost",
			want:       "localhost",
		},
		{
			name:          "extracts tenant from trusted proxy forwarded host",
			trustedCIDRs:  "203.0.113.0/24",
			remoteAddr:    "203.0.113.10:9443",
			host:          "ignored.example.com",
			forwardedHost: "proxy-tenant.example.com",
			want:          "proxy-tenant",
		},
		{
			name:       "falls back to hosted public url when tenant env is missing",
			hostedMode: "true",
			publicURL:  "https://tenant-from-public.cloud.pulserelay.pro",
			remoteAddr: "198.51.100.20:1234",
			host:       "untrusted.example.com",
			want:       "tenant-from-public",
		},
		{
			name:       "uses hosted request host when tenant env and public url are missing",
			hostedMode: "true",
			remoteAddr: "198.51.100.20:1234",
			host:       "tenant-from-host.cloud.pulserelay.pro",
			want:       "tenant-from-host",
		},
		{
			name:       "ignores untrusted remote host header",
			remoteAddr: "198.51.100.20:1234",
			host:       "tenant.example.com",
			want:       "",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// resetTrustedProxyConfig is called right after
			// PULSE_TRUSTED_PROXY_CIDRS changes, so it presumably caches
			// state derived from that variable. Cleanups run LIFO, so
			// registering this before t.Setenv makes it run after the
			// environment is restored. That keeps the state consistent
			// with the real env for later tests.
			t.Cleanup(func() { resetTrustedProxyConfig() })

			t.Setenv("PULSE_HOSTED_MODE", tc.hostedMode)
			t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", tc.trustedCIDRs)
			resetTrustedProxyConfig()
			t.Setenv("PULSE_TENANT_ID", tc.tenantEnv)
			t.Setenv("PULSE_PUBLIC_URL", tc.publicURL)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.remoteAddr
			req.Host = tc.host
			if tc.forwardedHost != "" {
				req.Header.Set("X-Forwarded-Host", tc.forwardedHost)
			}

			if got := tenantIDFromRequest(req); got != tc.want {
				t.Fatalf("tenantIDFromRequest() = %q, want %q", got, tc.want)
			}
		})
	}
}
func TestTenantIDFromPublicURL(t *testing.T) {
t.Run("extracts tenant from hosted public url", func(t *testing.T) {
if got := tenantIDFromPublicURL("https://tenant-a.cloud.pulserelay.pro"); got != "tenant-a" {
t.Fatalf("tenantIDFromPublicURL() = %q, want %q", got, "tenant-a")
}
})
t.Run("rejects invalid public url", func(t *testing.T) {
if got := tenantIDFromPublicURL("://not-a-url"); got != "" {
t.Fatalf("tenantIDFromPublicURL() = %q, want empty", got)
}
if got := tenantIDFromPublicURL("https://cloud.pulserelay.pro"); got != "" {
t.Fatalf("tenantIDFromPublicURL() = %q, want empty", got)
}
})
}
// TestJTIReplayStoreCheckAndStore verifies the replay-protection contract of
// jtiReplayStore.checkAndStore. The first call for a JTI stores it and reports
// true. A repeat call for the same JTI reports false without error. A
// whitespace-only JTI is rejected with an error.
func TestJTIReplayStoreCheckAndStore(t *testing.T) {
	store := &jtiReplayStore{configDir: t.TempDir()}
	// store.db is only populated once the store has been used (hence the nil
	// check). Close it before t.TempDir's own cleanup removes the directory,
	// which is registered earlier and so runs later. Otherwise the handle
	// leaks for the rest of the package run, and the removal fails on Windows
	// while the database file is still open.
	t.Cleanup(func() {
		if store.db != nil {
			_ = store.db.Close()
		}
	})

	expires := time.Now().Add(time.Hour)

	stored, err := store.checkAndStore("abc123", expires)
	if err != nil {
		t.Fatalf("first checkAndStore() error = %v", err)
	}
	if !stored {
		t.Fatal("first checkAndStore() = false, want true")
	}

	stored, err = store.checkAndStore("abc123", expires)
	if err != nil {
		t.Fatalf("second checkAndStore() error = %v", err)
	}
	if stored {
		t.Fatal("second checkAndStore() = true, want false")
	}

	if _, err = store.checkAndStore(" ", expires); err == nil {
		t.Fatal("expected empty jti to return error")
	}
}
// TestJTIReplayStoreDelete verifies that deleting a stored JTI makes it
// acceptable again: after delete, checkAndStore for the same JTI stores it
// anew and reports true.
func TestJTIReplayStoreDelete(t *testing.T) {
	store := &jtiReplayStore{configDir: t.TempDir()}
	// Close the database handle (set only after first use, hence the nil
	// check) before t.TempDir removes the directory. This avoids leaking the
	// handle and a failed removal of the open file on Windows.
	t.Cleanup(func() {
		if store.db != nil {
			_ = store.db.Close()
		}
	})

	expires := time.Now().Add(time.Hour)

	stored, err := store.checkAndStore("abc123", expires)
	if err != nil {
		t.Fatalf("checkAndStore() error = %v", err)
	}
	if !stored {
		t.Fatal("checkAndStore() = false, want true")
	}

	if err := store.delete("abc123"); err != nil {
		t.Fatalf("delete() error = %v", err)
	}

	stored, err = store.checkAndStore("abc123", expires)
	if err != nil {
		t.Fatalf("checkAndStore() after delete error = %v", err)
	}
	if !stored {
		t.Fatal("checkAndStore() after delete = false, want true")
	}
}
// TestJTIReplayStoreSecuresPermissionModes checks that the first write through
// jtiReplayStore creates <configDir>/secrets with handoffPrivateDirPerm and
// secrets/handoff_jti.db with handoffPrivateFilePerm. It is skipped on
// Windows, where POSIX permission bits do not apply.
func TestJTIReplayStoreSecuresPermissionModes(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("POSIX permission bits are not enforced on Windows")
	}

	configDir := t.TempDir()
	secretsDir := filepath.Join(configDir, "secrets")
	dbPath := filepath.Join(secretsDir, "handoff_jti.db")

	store := &jtiReplayStore{configDir: configDir}
	t.Cleanup(func() {
		// Release the database handle before the temp dir is removed.
		if db := store.db; db != nil {
			_ = db.Close()
		}
	})

	stored, err := store.checkAndStore("jti-perms", time.Now().Add(time.Hour))
	if err != nil {
		t.Fatalf("checkAndStore() error = %v", err)
	}
	if !stored {
		t.Fatal("expected first JTI insert to store")
	}

	info, err := os.Stat(secretsDir)
	if err != nil {
		t.Fatalf("stat secrets dir: %v", err)
	}
	if mode := info.Mode().Perm(); mode != handoffPrivateDirPerm {
		t.Fatalf("secrets dir mode = %#o, want %#o", mode, handoffPrivateDirPerm)
	}

	info, err = os.Stat(dbPath)
	if err != nil {
		t.Fatalf("stat handoff db: %v", err)
	}
	if mode := info.Mode().Perm(); mode != handoffPrivateFilePerm {
		t.Fatalf("handoff db mode = %#o, want %#o", mode, handoffPrivateFilePerm)
	}
}
// TestHandleHandoffExchange drives the JSON variant (?format=json) of the
// handoff exchange endpoint through its rejection paths and through
// successful exchanges. All subtests share one handler, config dir and
// session/CSRF store, so each token uses a distinct JTI: a JTI accepted once
// is rejected on replay.
func TestHandleHandoffExchange(t *testing.T) {
	// tenantIDFromRequest also consults PULSE_HOSTED_MODE and PULSE_PUBLIC_URL,
	// not just PULSE_TENANT_ID. Pin all three so the ambient environment
	// cannot change which tenant is derived from the Host header. They are set
	// before the handler is built in case construction reads any of them.
	t.Setenv("PULSE_TENANT_ID", "")
	t.Setenv("PULSE_HOSTED_MODE", "")
	t.Setenv("PULSE_PUBLIC_URL", "")

	key := []byte("test-handoff-key")
	configDir := t.TempDir()
	resetSessionStoreForTests()
	t.Cleanup(resetSessionStoreForTests)
	resetCSRFStoreForTests()
	t.Cleanup(resetCSRFStoreForTests)
	InitSessionStore(configDir)
	InitCSRFStore(configDir)

	// Tokens below are signed with the same key written to
	// <configDir>/secrets/handoff.key.
	secretsDir := filepath.Join(configDir, "secrets")
	if err := os.MkdirAll(secretsDir, 0o755); err != nil {
		t.Fatalf("MkdirAll() error = %v", err)
	}
	if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
		t.Fatalf("WriteFile() error = %v", err)
	}

	handler := HandleHandoffExchange(configDir)
	tenantID := "tenant-a"
	host := tenantID + ".example.com"

	t.Run("missing token returns bad request", func(t *testing.T) {
		rec := makeExchangeRequest(t, handler, host, "")
		if rec.Code != http.StatusBadRequest {
			t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusBadRequest, rec.Body.String())
		}
	})

	t.Run("invalid token returns unauthorized", func(t *testing.T) {
		rec := makeExchangeRequest(t, handler, host, "not-a-jwt")
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusUnauthorized, rec.Body.String())
		}
	})

	t.Run("missing tenant context returns internal error", func(t *testing.T) {
		// An empty Host, with no tenant env configured, leaves nothing to
		// derive a tenant from.
		rec := makeExchangeRequest(t, handler, "", "anything")
		if rec.Code != http.StatusInternalServerError {
			t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusInternalServerError, rec.Body.String())
		}
	})

	t.Run("missing exp claim returns unauthorized", func(t *testing.T) {
		token := signHandoffToken(t, key, cloudHandoffClaims{
			AccountID: "acct-1",
			Email:     "user@example.com",
			Role:      "admin",
			RegisteredClaims: jwt.RegisteredClaims{
				ID:        "jti-no-exp",
				Subject:   "user-1",
				Issuer:    cloudHandoffIssuer,
				Audience:  jwt.ClaimStrings{tenantID},
				ExpiresAt: nil,
			},
		})
		rec := makeExchangeRequest(t, handler, host, token)
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusUnauthorized, rec.Body.String())
		}
	})

	t.Run("successful exchange and replay rejection", func(t *testing.T) {
		token := signHandoffToken(t, key, cloudHandoffClaims{
			AccountID: "acct-123",
			Email:     "user@example.com",
			Role:      "owner",
			RegisteredClaims: jwt.RegisteredClaims{
				ID:        "jti-success",
				Subject:   "user-123",
				Issuer:    cloudHandoffIssuer,
				Audience:  jwt.ClaimStrings{tenantID},
				ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
			},
		})

		first := makeExchangeRequest(t, handler, host, token)
		if first.Code != http.StatusOK {
			t.Fatalf("status = %d, want %d: %s", first.Code, http.StatusOK, first.Body.String())
		}
		if got := first.Header().Get("Content-Type"); got != "application/json" {
			t.Fatalf("Content-Type = %q, want %q", got, "application/json")
		}

		var payload map[string]any
		if err := json.Unmarshal(first.Body.Bytes(), &payload); err != nil {
			t.Fatalf("response json decode error = %v", err)
		}
		if ok, _ := payload["ok"].(bool); !ok {
			t.Fatalf("payload ok = %v, want true", payload["ok"])
		}
		if got, _ := payload["tenant_id"].(string); got != tenantID {
			t.Fatalf("tenant_id = %q, want %q", got, tenantID)
		}
		if got, _ := payload["account_id"].(string); got != "acct-123" {
			t.Fatalf("account_id = %q, want %q", got, "acct-123")
		}

		// Presenting the same token (same JTI) again must be refused.
		second := makeExchangeRequest(t, handler, host, token)
		if second.Code != http.StatusUnauthorized {
			t.Fatalf("status = %d, want %d: %s", second.Code, http.StatusUnauthorized, second.Body.String())
		}
	})

	t.Run("normalizes mixed-case email before session creation", func(t *testing.T) {
		const wantEmail = "operator.owner+mixed@pulserelay.pro"

		token := signHandoffToken(t, key, cloudHandoffClaims{
			AccountID: "acct-123",
			Email:     "Operator.Owner+Mixed@PulseRelay.Pro",
			Role:      "owner",
			RegisteredClaims: jwt.RegisteredClaims{
				ID:        "jti-mixed-email",
				Subject:   "user-123",
				Issuer:    cloudHandoffIssuer,
				Audience:  jwt.ClaimStrings{tenantID},
				ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
			},
		})

		rec := makeExchangeRequest(t, handler, host, token)
		if rec.Code != http.StatusOK {
			t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
		}

		var payload map[string]any
		if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
			t.Fatalf("response json decode error = %v", err)
		}
		if got, _ := payload["email"].(string); got != wantEmail {
			t.Fatalf("email = %q, want %q", got, wantEmail)
		}

		var sessionCookie *http.Cookie
		for _, cookie := range rec.Result().Cookies() {
			if strings.HasPrefix(cookie.Name, "pulse_session") {
				sessionCookie = cookie
				break
			}
		}
		if sessionCookie == nil {
			t.Fatal("expected pulse_session cookie to be set")
		}

		session := GetSessionStore().GetSession(sessionCookie.Value)
		if session == nil {
			t.Fatal("expected session to exist")
		}
		if session.Username != wantEmail {
			t.Fatalf("session username = %q, want %q", session.Username, wantEmail)
		}
	})
}
// TestHandleHandoffExchangeBrowserFlowSetsSessionCookies exercises the
// browser variant of the exchange (no ?format=json). A valid token must
// produce a 307 redirect to "/" that sets non-empty pulse_session and
// pulse_csrf cookies, plus a pulse_org_id cookie naming the tenant.
func TestHandleHandoffExchangeBrowserFlowSetsSessionCookies(t *testing.T) {
	// The tenant must come from the Host header, because the token's audience
	// is that tenant. Clear the env vars that can override tenant resolution,
	// as the sibling exchange tests clear PULSE_TENANT_ID.
	t.Setenv("PULSE_TENANT_ID", "")
	t.Setenv("PULSE_HOSTED_MODE", "")
	t.Setenv("PULSE_PUBLIC_URL", "")

	key := []byte("test-handoff-key")
	configDir := t.TempDir()
	resetSessionStoreForTests()
	t.Cleanup(resetSessionStoreForTests)
	resetCSRFStoreForTests()
	t.Cleanup(resetCSRFStoreForTests)
	InitSessionStore(configDir)
	InitCSRFStore(configDir)

	secretsDir := filepath.Join(configDir, "secrets")
	if err := os.MkdirAll(secretsDir, 0o755); err != nil {
		t.Fatalf("MkdirAll() error = %v", err)
	}
	if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
		t.Fatalf("WriteFile() error = %v", err)
	}

	handler := HandleHandoffExchange(configDir)
	tenantID := "tenant-browser"
	token := signHandoffToken(t, key, cloudHandoffClaims{
		AccountID: "acct-browser",
		Email:     "browser@example.com",
		Role:      "owner",
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        "jti-browser",
			Subject:   "user-browser",
			Issuer:    cloudHandoffIssuer,
			Audience:  jwt.ClaimStrings{tenantID},
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
		},
	})

	// Built by hand rather than via makeExchangeRequest, which always adds
	// ?format=json.
	form := url.Values{}
	form.Set("token", token)
	req := httptest.NewRequest(http.MethodPost, "/api/cloud/handoff/exchange", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Host = tenantID + ".example.com"
	req.RemoteAddr = "127.0.0.1:12345"
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusTemporaryRedirect {
		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusTemporaryRedirect, rec.Body.String())
	}
	if loc := rec.Header().Get("Location"); loc != "/" {
		t.Fatalf("redirect location = %q, want %q", loc, "/")
	}

	var haveSession, haveCSRF, haveOrg bool
	for _, c := range rec.Result().Cookies() {
		switch c.Name {
		case "pulse_session":
			haveSession = c.Value != ""
		case "pulse_csrf":
			haveCSRF = c.Value != ""
		case "pulse_org_id":
			haveOrg = c.Value == tenantID
		}
	}
	if !haveSession {
		t.Fatal("expected pulse_session cookie")
	}
	if !haveCSRF {
		t.Fatal("expected pulse_csrf cookie")
	}
	if !haveOrg {
		t.Fatal("expected pulse_org_id cookie for tenant")
	}
}
// TestHandleHandoffExchangeEnsuresTenantOrganizationMembership checks a
// successful handoff into a tenant whose organization is already persisted
// with a different owner. Afterwards:
//   - OwnerUserID still names the original owner;
//   - the handed-off user (whose token carries role "owner") is a member with
//     models.OrgRoleOwner;
//   - the handed-off user can access the organization.
func TestHandleHandoffExchangeEnsuresTenantOrganizationMembership(t *testing.T) {
	// The tenant must be derived from the Host header to match the token's
	// audience. Clear the env vars that can override tenant resolution.
	t.Setenv("PULSE_TENANT_ID", "")
	t.Setenv("PULSE_HOSTED_MODE", "")
	t.Setenv("PULSE_PUBLIC_URL", "")

	key := []byte("test-handoff-key")
	configDir := t.TempDir()
	resetSessionStoreForTests()
	t.Cleanup(resetSessionStoreForTests)
	resetCSRFStoreForTests()
	t.Cleanup(resetCSRFStoreForTests)
	InitSessionStore(configDir)
	InitCSRFStore(configDir)

	secretsDir := filepath.Join(configDir, "secrets")
	if err := os.MkdirAll(secretsDir, 0o755); err != nil {
		t.Fatalf("MkdirAll() error = %v", err)
	}
	if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
		t.Fatalf("WriteFile() error = %v", err)
	}

	tenantID := "tenant-membership"
	mtp := config.NewMultiTenantPersistence(configDir)
	if err := mtp.SaveOrganization(&models.Organization{
		ID:          tenantID,
		DisplayName: "Membership Test",
		Status:      models.OrgStatusActive,
		CreatedAt:   time.Now().UTC(),
		OwnerUserID: "legacy-owner@example.com",
		Members: []models.OrganizationMember{
			{UserID: "legacy-owner@example.com", Role: models.OrgRoleOwner, AddedAt: time.Now().UTC()},
		},
	}); err != nil {
		t.Fatalf("SaveOrganization() error = %v", err)
	}

	handler := HandleHandoffExchange(configDir)
	token := signHandoffToken(t, key, cloudHandoffClaims{
		AccountID: "acct-membership",
		Email:     "courtmanr@gmail.com",
		Role:      "owner",
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        "jti-membership",
			Subject:   "user-membership",
			Issuer:    cloudHandoffIssuer,
			Audience:  jwt.ClaimStrings{tenantID},
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
		},
	})

	rec := makeExchangeRequest(t, handler, tenantID+".example.com", token)
	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	org, err := mtp.LoadOrganization(tenantID)
	if err != nil {
		t.Fatalf("LoadOrganization() error = %v", err)
	}
	// Guard against a (nil, nil) result so the test fails cleanly instead of
	// panicking on the field accesses below.
	if org == nil {
		t.Fatalf("LoadOrganization(%q) returned nil organization", tenantID)
	}
	if org.OwnerUserID != "legacy-owner@example.com" {
		t.Fatalf("OwnerUserID = %q, want %q", org.OwnerUserID, "legacy-owner@example.com")
	}
	if got := org.GetMemberRole("courtmanr@gmail.com"); got != models.OrgRoleOwner {
		t.Fatalf("role for courtmanr@gmail.com = %q, want %q", got, models.OrgRoleOwner)
	}
	if !org.CanUserAccess("courtmanr@gmail.com") {
		t.Fatal("expected handed-off user to have tenant organization access")
	}
}
func TestHandleHandoffExchangeKeyMissing(t *testing.T) {
handler := HandleHandoffExchange(t.TempDir())
t.Setenv("PULSE_TENANT_ID", "")
rec := makeExchangeRequest(t, handler, "tenant.example.com", "anything")
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}