mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 08:57:12 +00:00
583 lines
18 KiB
Go
583 lines
18 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
|
)
|
|
|
|
func signHandoffToken(t *testing.T, key []byte, claims cloudHandoffClaims) string {
|
|
t.Helper()
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signed, err := token.SignedString(key)
|
|
if err != nil {
|
|
t.Fatalf("SignedString() error = %v", err)
|
|
}
|
|
return signed
|
|
}
|
|
|
|
// makeExchangeRequest POSTs a form-encoded body to the JSON variant of the
// handoff exchange endpoint (?format=json) and returns the recorded response.
// The request Host is set to host and the client address is pinned to
// loopback. When token is empty, no "token" field is sent at all.
func makeExchangeRequest(t *testing.T, handler http.HandlerFunc, host, token string) *httptest.ResponseRecorder {
	t.Helper()

	values := make(url.Values)
	if token != "" {
		values.Set("token", token)
	}
	body := strings.NewReader(values.Encode())

	req := httptest.NewRequest(http.MethodPost, "/api/cloud/handoff/exchange?format=json", body)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Host = host
	req.RemoteAddr = "127.0.0.1:12345"

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	return recorder
}
|
|
|
|
func TestIsSQLiteUniqueViolation(t *testing.T) {
|
|
if isSQLiteUniqueViolation(nil) {
|
|
t.Fatal("expected nil error to return false")
|
|
}
|
|
if !isSQLiteUniqueViolation(errors.New("UNIQUE constraint failed: handoff_jti.jti")) {
|
|
t.Fatal("expected UNIQUE constraint failure to return true")
|
|
}
|
|
if !isSQLiteUniqueViolation(errors.New("constraint failed")) {
|
|
t.Fatal("expected generic constraint failure to return true")
|
|
}
|
|
if isSQLiteUniqueViolation(errors.New("some other error")) {
|
|
t.Fatal("expected unrelated error to return false")
|
|
}
|
|
}
|
|
|
|
func TestTenantIDFromRequest(t *testing.T) {
|
|
t.Run("uses env var when present", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "env-tenant")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Host = "host-tenant.example.com"
|
|
req.RemoteAddr = "198.51.100.10:4567"
|
|
|
|
if got := tenantIDFromRequest(req); got != "env-tenant" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "env-tenant")
|
|
}
|
|
})
|
|
|
|
t.Run("extracts subdomain from loopback host", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Host = "tenant.example.com:8443"
|
|
req.RemoteAddr = "127.0.0.1:8080"
|
|
|
|
if got := tenantIDFromRequest(req); got != "tenant" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "tenant")
|
|
}
|
|
})
|
|
|
|
t.Run("returns full loopback host when no dot exists", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Host = "localhost"
|
|
req.RemoteAddr = "127.0.0.1:8080"
|
|
|
|
if got := tenantIDFromRequest(req); got != "localhost" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "localhost")
|
|
}
|
|
})
|
|
|
|
t.Run("extracts tenant from trusted proxy forwarded host", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "203.0.113.0/24")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "203.0.113.10:9443"
|
|
req.Host = "ignored.example.com"
|
|
req.Header.Set("X-Forwarded-Host", "proxy-tenant.example.com")
|
|
|
|
if got := tenantIDFromRequest(req); got != "proxy-tenant" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "proxy-tenant")
|
|
}
|
|
})
|
|
|
|
t.Run("falls back to hosted public url when tenant env is missing", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "true")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "https://tenant-from-public.cloud.pulserelay.pro")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "198.51.100.20:1234"
|
|
req.Host = "untrusted.example.com"
|
|
|
|
if got := tenantIDFromRequest(req); got != "tenant-from-public" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "tenant-from-public")
|
|
}
|
|
})
|
|
|
|
t.Run("uses hosted request host when tenant env and public url are missing", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "true")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "198.51.100.20:1234"
|
|
req.Host = "tenant-from-host.cloud.pulserelay.pro"
|
|
|
|
if got := tenantIDFromRequest(req); got != "tenant-from-host" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want %q", got, "tenant-from-host")
|
|
}
|
|
})
|
|
|
|
t.Run("ignores untrusted remote host header", func(t *testing.T) {
|
|
t.Setenv("PULSE_HOSTED_MODE", "")
|
|
t.Setenv("PULSE_TRUSTED_PROXY_CIDRS", "")
|
|
resetTrustedProxyConfig()
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
t.Setenv("PULSE_PUBLIC_URL", "")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "198.51.100.20:1234"
|
|
req.Host = "tenant.example.com"
|
|
|
|
if got := tenantIDFromRequest(req); got != "" {
|
|
t.Fatalf("tenantIDFromRequest() = %q, want empty", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTenantIDFromPublicURL(t *testing.T) {
|
|
t.Run("extracts tenant from hosted public url", func(t *testing.T) {
|
|
if got := tenantIDFromPublicURL("https://tenant-a.cloud.pulserelay.pro"); got != "tenant-a" {
|
|
t.Fatalf("tenantIDFromPublicURL() = %q, want %q", got, "tenant-a")
|
|
}
|
|
})
|
|
|
|
t.Run("rejects invalid public url", func(t *testing.T) {
|
|
if got := tenantIDFromPublicURL("://not-a-url"); got != "" {
|
|
t.Fatalf("tenantIDFromPublicURL() = %q, want empty", got)
|
|
}
|
|
if got := tenantIDFromPublicURL("https://cloud.pulserelay.pro"); got != "" {
|
|
t.Fatalf("tenantIDFromPublicURL() = %q, want empty", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJTIReplayStoreCheckAndStore(t *testing.T) {
|
|
store := &jtiReplayStore{configDir: t.TempDir()}
|
|
expires := time.Now().Add(time.Hour)
|
|
|
|
stored, err := store.checkAndStore("abc123", expires)
|
|
if err != nil {
|
|
t.Fatalf("first checkAndStore() error = %v", err)
|
|
}
|
|
if !stored {
|
|
t.Fatal("first checkAndStore() = false, want true")
|
|
}
|
|
|
|
stored, err = store.checkAndStore("abc123", expires)
|
|
if err != nil {
|
|
t.Fatalf("second checkAndStore() error = %v", err)
|
|
}
|
|
if stored {
|
|
t.Fatal("second checkAndStore() = true, want false")
|
|
}
|
|
|
|
_, err = store.checkAndStore(" ", expires)
|
|
if err == nil {
|
|
t.Fatal("expected empty jti to return error")
|
|
}
|
|
}
|
|
|
|
func TestJTIReplayStoreDelete(t *testing.T) {
|
|
store := &jtiReplayStore{configDir: t.TempDir()}
|
|
expires := time.Now().Add(time.Hour)
|
|
|
|
stored, err := store.checkAndStore("abc123", expires)
|
|
if err != nil {
|
|
t.Fatalf("checkAndStore() error = %v", err)
|
|
}
|
|
if !stored {
|
|
t.Fatal("checkAndStore() = false, want true")
|
|
}
|
|
|
|
if err := store.delete("abc123"); err != nil {
|
|
t.Fatalf("delete() error = %v", err)
|
|
}
|
|
|
|
stored, err = store.checkAndStore("abc123", expires)
|
|
if err != nil {
|
|
t.Fatalf("checkAndStore() after delete error = %v", err)
|
|
}
|
|
if !stored {
|
|
t.Fatal("checkAndStore() after delete = false, want true")
|
|
}
|
|
}
|
|
|
|
func TestJTIReplayStoreSecuresPermissionModes(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("POSIX permission bits are not enforced on Windows")
|
|
}
|
|
|
|
configDir := t.TempDir()
|
|
store := &jtiReplayStore{configDir: configDir}
|
|
t.Cleanup(func() {
|
|
if store.db != nil {
|
|
_ = store.db.Close()
|
|
}
|
|
})
|
|
|
|
stored, err := store.checkAndStore("jti-perms", time.Now().Add(time.Hour))
|
|
if err != nil {
|
|
t.Fatalf("checkAndStore() error = %v", err)
|
|
}
|
|
if !stored {
|
|
t.Fatal("expected first JTI insert to store")
|
|
}
|
|
|
|
secretsDir := filepath.Join(configDir, "secrets")
|
|
dirInfo, err := os.Stat(secretsDir)
|
|
if err != nil {
|
|
t.Fatalf("stat secrets dir: %v", err)
|
|
}
|
|
if got := dirInfo.Mode().Perm(); got != handoffPrivateDirPerm {
|
|
t.Fatalf("secrets dir mode = %#o, want %#o", got, handoffPrivateDirPerm)
|
|
}
|
|
|
|
dbPath := filepath.Join(secretsDir, "handoff_jti.db")
|
|
dbInfo, err := os.Stat(dbPath)
|
|
if err != nil {
|
|
t.Fatalf("stat handoff db: %v", err)
|
|
}
|
|
if got := dbInfo.Mode().Perm(); got != handoffPrivateFilePerm {
|
|
t.Fatalf("handoff db mode = %#o, want %#o", got, handoffPrivateFilePerm)
|
|
}
|
|
}
|
|
|
|
func TestHandleHandoffExchange(t *testing.T) {
|
|
key := []byte("test-handoff-key")
|
|
configDir := t.TempDir()
|
|
resetSessionStoreForTests()
|
|
t.Cleanup(resetSessionStoreForTests)
|
|
resetCSRFStoreForTests()
|
|
t.Cleanup(resetCSRFStoreForTests)
|
|
InitSessionStore(configDir)
|
|
InitCSRFStore(configDir)
|
|
secretsDir := filepath.Join(configDir, "secrets")
|
|
if err := os.MkdirAll(secretsDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll() error = %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
handler := HandleHandoffExchange(configDir)
|
|
tenantID := "tenant-a"
|
|
host := tenantID + ".example.com"
|
|
|
|
t.Run("missing token returns bad request", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
rec := makeExchangeRequest(t, handler, host, "")
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
|
}
|
|
})
|
|
|
|
t.Run("invalid token returns unauthorized", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
rec := makeExchangeRequest(t, handler, host, "not-a-jwt")
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
t.Run("missing tenant context returns internal error", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
rec := makeExchangeRequest(t, handler, "", "anything")
|
|
if rec.Code != http.StatusInternalServerError {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
t.Run("missing exp claim returns unauthorized", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
token := signHandoffToken(t, key, cloudHandoffClaims{
|
|
AccountID: "acct-1",
|
|
Email: "user@example.com",
|
|
Role: "admin",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti-no-exp",
|
|
Subject: "user-1",
|
|
Issuer: cloudHandoffIssuer,
|
|
Audience: jwt.ClaimStrings{tenantID},
|
|
ExpiresAt: nil,
|
|
},
|
|
})
|
|
|
|
rec := makeExchangeRequest(t, handler, host, token)
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
t.Run("successful exchange and replay rejection", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
token := signHandoffToken(t, key, cloudHandoffClaims{
|
|
AccountID: "acct-123",
|
|
Email: "user@example.com",
|
|
Role: "owner",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti-success",
|
|
Subject: "user-123",
|
|
Issuer: cloudHandoffIssuer,
|
|
Audience: jwt.ClaimStrings{tenantID},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
})
|
|
|
|
first := makeExchangeRequest(t, handler, host, token)
|
|
if first.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", first.Code, http.StatusOK)
|
|
}
|
|
if got := first.Header().Get("Content-Type"); got != "application/json" {
|
|
t.Fatalf("Content-Type = %q, want %q", got, "application/json")
|
|
}
|
|
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(first.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("response json decode error = %v", err)
|
|
}
|
|
if ok, _ := payload["ok"].(bool); !ok {
|
|
t.Fatalf("payload ok = %v, want true", payload["ok"])
|
|
}
|
|
if got, _ := payload["tenant_id"].(string); got != tenantID {
|
|
t.Fatalf("tenant_id = %q, want %q", got, tenantID)
|
|
}
|
|
if got, _ := payload["account_id"].(string); got != "acct-123" {
|
|
t.Fatalf("account_id = %q, want %q", got, "acct-123")
|
|
}
|
|
|
|
second := makeExchangeRequest(t, handler, host, token)
|
|
if second.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status = %d, want %d", second.Code, http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
t.Run("normalizes mixed-case email before session creation", func(t *testing.T) {
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
token := signHandoffToken(t, key, cloudHandoffClaims{
|
|
AccountID: "acct-123",
|
|
Email: "Operator.Owner+Mixed@PulseRelay.Pro",
|
|
Role: "owner",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti-mixed-email",
|
|
Subject: "user-123",
|
|
Issuer: cloudHandoffIssuer,
|
|
Audience: jwt.ClaimStrings{tenantID},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
})
|
|
|
|
rec := makeExchangeRequest(t, handler, host, token)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
|
}
|
|
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
|
t.Fatalf("response json decode error = %v", err)
|
|
}
|
|
if got, _ := payload["email"].(string); got != "operator.owner+mixed@pulserelay.pro" {
|
|
t.Fatalf("email = %q, want %q", got, "operator.owner+mixed@pulserelay.pro")
|
|
}
|
|
|
|
var sessionCookie *http.Cookie
|
|
for _, cookie := range rec.Result().Cookies() {
|
|
if strings.HasPrefix(cookie.Name, "pulse_session") {
|
|
sessionCookie = cookie
|
|
break
|
|
}
|
|
}
|
|
if sessionCookie == nil {
|
|
t.Fatal("expected pulse_session cookie to be set")
|
|
}
|
|
|
|
session := GetSessionStore().GetSession(sessionCookie.Value)
|
|
if session == nil {
|
|
t.Fatal("expected session to exist")
|
|
}
|
|
if session.Username != "operator.owner+mixed@pulserelay.pro" {
|
|
t.Fatalf("session username = %q, want %q", session.Username, "operator.owner+mixed@pulserelay.pro")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHandleHandoffExchangeBrowserFlowSetsSessionCookies(t *testing.T) {
|
|
key := []byte("test-handoff-key")
|
|
configDir := t.TempDir()
|
|
resetSessionStoreForTests()
|
|
t.Cleanup(resetSessionStoreForTests)
|
|
resetCSRFStoreForTests()
|
|
t.Cleanup(resetCSRFStoreForTests)
|
|
InitSessionStore(configDir)
|
|
InitCSRFStore(configDir)
|
|
secretsDir := filepath.Join(configDir, "secrets")
|
|
if err := os.MkdirAll(secretsDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll() error = %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
handler := HandleHandoffExchange(configDir)
|
|
tenantID := "tenant-browser"
|
|
token := signHandoffToken(t, key, cloudHandoffClaims{
|
|
AccountID: "acct-browser",
|
|
Email: "browser@example.com",
|
|
Role: "owner",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti-browser",
|
|
Subject: "user-browser",
|
|
Issuer: cloudHandoffIssuer,
|
|
Audience: jwt.ClaimStrings{tenantID},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
})
|
|
|
|
form := url.Values{}
|
|
form.Set("token", token)
|
|
req := httptest.NewRequest(http.MethodPost, "/api/cloud/handoff/exchange", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Host = tenantID + ".example.com"
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusTemporaryRedirect {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusTemporaryRedirect)
|
|
}
|
|
if loc := rec.Header().Get("Location"); loc != "/" {
|
|
t.Fatalf("redirect location = %q, want %q", loc, "/")
|
|
}
|
|
|
|
cookies := rec.Result().Cookies()
|
|
var haveSession, haveCSRF, haveOrg bool
|
|
for _, c := range cookies {
|
|
switch c.Name {
|
|
case "pulse_session":
|
|
haveSession = c.Value != ""
|
|
case "pulse_csrf":
|
|
haveCSRF = c.Value != ""
|
|
case "pulse_org_id":
|
|
haveOrg = c.Value == tenantID
|
|
}
|
|
}
|
|
|
|
if !haveSession {
|
|
t.Fatal("expected pulse_session cookie")
|
|
}
|
|
if !haveCSRF {
|
|
t.Fatal("expected pulse_csrf cookie")
|
|
}
|
|
if !haveOrg {
|
|
t.Fatal("expected pulse_org_id cookie for tenant")
|
|
}
|
|
}
|
|
|
|
func TestHandleHandoffExchangeEnsuresTenantOrganizationMembership(t *testing.T) {
|
|
key := []byte("test-handoff-key")
|
|
configDir := t.TempDir()
|
|
resetSessionStoreForTests()
|
|
t.Cleanup(resetSessionStoreForTests)
|
|
resetCSRFStoreForTests()
|
|
t.Cleanup(resetCSRFStoreForTests)
|
|
InitSessionStore(configDir)
|
|
InitCSRFStore(configDir)
|
|
|
|
secretsDir := filepath.Join(configDir, "secrets")
|
|
if err := os.MkdirAll(secretsDir, 0o755); err != nil {
|
|
t.Fatalf("MkdirAll() error = %v", err)
|
|
}
|
|
if err := os.WriteFile(filepath.Join(secretsDir, "handoff.key"), key, 0o600); err != nil {
|
|
t.Fatalf("WriteFile() error = %v", err)
|
|
}
|
|
|
|
tenantID := "tenant-membership"
|
|
mtp := config.NewMultiTenantPersistence(configDir)
|
|
if err := mtp.SaveOrganization(&models.Organization{
|
|
ID: tenantID,
|
|
DisplayName: "Membership Test",
|
|
Status: models.OrgStatusActive,
|
|
CreatedAt: time.Now().UTC(),
|
|
OwnerUserID: "legacy-owner@example.com",
|
|
Members: []models.OrganizationMember{
|
|
{UserID: "legacy-owner@example.com", Role: models.OrgRoleOwner, AddedAt: time.Now().UTC()},
|
|
},
|
|
}); err != nil {
|
|
t.Fatalf("SaveOrganization() error = %v", err)
|
|
}
|
|
|
|
handler := HandleHandoffExchange(configDir)
|
|
token := signHandoffToken(t, key, cloudHandoffClaims{
|
|
AccountID: "acct-membership",
|
|
Email: "courtmanr@gmail.com",
|
|
Role: "owner",
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: "jti-membership",
|
|
Subject: "user-membership",
|
|
Issuer: cloudHandoffIssuer,
|
|
Audience: jwt.ClaimStrings{tenantID},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
})
|
|
|
|
rec := makeExchangeRequest(t, handler, tenantID+".example.com", token)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
|
|
}
|
|
|
|
org, err := mtp.LoadOrganization(tenantID)
|
|
if err != nil {
|
|
t.Fatalf("LoadOrganization() error = %v", err)
|
|
}
|
|
if org.OwnerUserID != "legacy-owner@example.com" {
|
|
t.Fatalf("OwnerUserID = %q, want %q", org.OwnerUserID, "legacy-owner@example.com")
|
|
}
|
|
if got := org.GetMemberRole("courtmanr@gmail.com"); got != models.OrgRoleOwner {
|
|
t.Fatalf("role for courtmanr@gmail.com = %q, want %q", got, models.OrgRoleOwner)
|
|
}
|
|
if !org.CanUserAccess("courtmanr@gmail.com") {
|
|
t.Fatal("expected handed-off user to have tenant organization access")
|
|
}
|
|
}
|
|
|
|
func TestHandleHandoffExchangeKeyMissing(t *testing.T) {
|
|
handler := HandleHandoffExchange(t.TempDir())
|
|
t.Setenv("PULSE_TENANT_ID", "")
|
|
|
|
rec := makeExchangeRequest(t, handler, "tenant.example.com", "anything")
|
|
if rec.Code != http.StatusNotFound {
|
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
|
}
|
|
}
|