// Pulse/internal/api/cloud_handoff_handlers.go
//
// 499 lines
// 12 KiB
// Go

package api
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rcourtman/pulse-go-rewrite/internal/models"
_ "modernc.org/sqlite"
)
// cloudHandoffIssuer is the "iss" claim every handoff JWT must carry; tokens
// naming any other issuer fail verification in HandleHandoffExchange.
const cloudHandoffIssuer = "pulse-cloud-control-plane"

// Owner-only permission bits for the handoff secrets directory and for the
// replay-database files created inside it.
const (
	handoffPrivateDirPerm  = 0o700
	handoffPrivateFilePerm = 0o600
)
// cloudHandoffClaims is the claim set of a control-plane handoff JWT. The
// standard registered claims (iss, sub, aud, exp, jti, ...) come from the
// embedded jwt.RegisteredClaims; HandleHandoffExchange additionally requires
// exp, jti (ID), sub (Subject) and a non-blank email to be present.
type cloudHandoffClaims struct {
// AccountID is the control-plane account ID; it is echoed in the JSON
// response of the exchange but not otherwise validated here.
AccountID string `json:"account_id"`
// Email identifies the user. It is trimmed and lowercased before use as the
// organization member ID and session user name.
Email string `json:"email"`
// Role is the control-plane account role; it is mapped to an organization
// role when membership is ensured.
Role string `json:"role"`
jwt.RegisteredClaims
}
// jtiReplayStore is a SQLite-backed record of consumed handoff token IDs
// (jti). It lets the exchange endpoint refuse a token presented twice.
// Construct it with only configDir set; everything else is populated lazily.
type jtiReplayStore struct {
	// configDir is the Pulse configuration directory; the database file lives
	// in its "secrets" subdirectory.
	configDir string

	// Lazily-initialized state, filled in exactly once by init().
	once    sync.Once
	db      *sql.DB
	initErr error

	// mu serializes the statements issued against db by checkAndStore/delete.
	mu sync.Mutex
}
// deleteExpiredHandoffJTIQuery purges replay records whose expires_at (Unix
// seconds, as written by checkAndStore) is at or before the bound parameter.
// INDEXED BY pins the statement to the expires_at index created in
// jtiReplayStore.init; SQLite fails to prepare the statement (rather than
// falling back to a table scan) if that index is missing.
const deleteExpiredHandoffJTIQuery = `
DELETE FROM handoff_jti INDEXED BY idx_handoff_jti_expires_at
WHERE expires_at <= ?`
// init lazily opens (creating if needed) the replay database at
// <configDir>/secrets/handoff_jti.db. The work runs at most once per store;
// any failure is remembered in s.initErr and reported by every later caller.
//
// The secrets directory is created/forced to handoffPrivateDirPerm, and the
// database file plus its WAL and shared-memory side files (when they exist) are
// forced to handoffPrivateFilePerm.
func (s *jtiReplayStore) init() {
	s.once.Do(func() {
		// Check for emptiness BEFORE filepath.Clean: Clean("") returns ".",
		// which would silently place the store in the process's working dir.
		if strings.TrimSpace(s.configDir) == "" {
			s.initErr = fmt.Errorf("configDir is required")
			return
		}
		dir := filepath.Clean(s.configDir)

		secretsDir := filepath.Join(dir, "secrets")
		if err := os.MkdirAll(secretsDir, handoffPrivateDirPerm); err != nil {
			s.initErr = fmt.Errorf("create handoff secrets dir: %w", err)
			return
		}
		// MkdirAll does not change the mode of an existing directory (and the
		// umask applies on creation), so enforce it explicitly.
		if err := os.Chmod(secretsDir, handoffPrivateDirPerm); err != nil {
			s.initErr = fmt.Errorf("chmod handoff secrets dir: %w", err)
			return
		}

		dbPath := filepath.Join(secretsDir, "handoff_jti.db")
		dsn := dbPath + "?" + url.Values{
			"_pragma": []string{
				"busy_timeout(30000)",
				"journal_mode(WAL)",
				"synchronous(NORMAL)",
			},
		}.Encode()
		db, err := sql.Open("sqlite", dsn)
		if err != nil {
			s.initErr = fmt.Errorf("open handoff jti db: %w", err)
			return
		}
		// A single long-lived connection: all access is additionally
		// serialized by s.mu in checkAndStore/delete.
		db.SetMaxOpenConns(1)
		db.SetMaxIdleConns(1)
		db.SetConnMaxLifetime(0)

		schema := `
CREATE TABLE IF NOT EXISTS handoff_jti (
jti TEXT PRIMARY KEY,
expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_handoff_jti_expires_at ON handoff_jti(expires_at);
`
		if _, err := db.Exec(schema); err != nil {
			_ = db.Close()
			s.initErr = fmt.Errorf("init handoff jti schema: %w", err)
			return
		}
		for _, path := range []string{dbPath, dbPath + "-wal", dbPath + "-shm"} {
			if err := hardenPrivateFile(path, handoffPrivateFilePerm); err != nil {
				_ = db.Close()
				s.initErr = fmt.Errorf("harden handoff jti file permissions: %w", err)
				return
			}
		}
		s.db = db
	})
}
// hardenPrivateFile forces the permission bits of path to exactly mode.
// A path that does not exist is not an error (there is nothing to harden),
// and chmod is skipped when the permissions already match.
func hardenPrivateFile(path string, mode os.FileMode) error {
	fi, statErr := os.Stat(path)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return nil
	case statErr != nil:
		return statErr
	case fi.Mode().Perm() == mode:
		return nil
	default:
		return os.Chmod(path, mode)
	}
}
// checkAndStore records jti as consumed until expiresAt and reports whether
// this call was the one that recorded it. (false, nil) means the jti is
// already present, i.e. the token is being replayed. Rows whose expiry is at or
// before the current time are purged first, so an expired jti may be stored
// again. The jti is trimmed and must not be blank.
func (s *jtiReplayStore) checkAndStore(jti string, expiresAt time.Time) (stored bool, err error) {
	s.init()
	switch {
	case s.initErr != nil:
		return false, s.initErr
	case s.db == nil:
		return false, fmt.Errorf("handoff jti store not initialized")
	}

	key := strings.TrimSpace(jti)
	if key == "" {
		return false, fmt.Errorf("jti is required")
	}
	expiry := expiresAt.UTC().Unix()

	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now().UTC().Unix()
	if _, purgeErr := s.db.Exec(deleteExpiredHandoffJTIQuery, cutoff); purgeErr != nil {
		return false, fmt.Errorf("cleanup handoff jti: %w", purgeErr)
	}
	if _, insertErr := s.db.Exec(`INSERT INTO handoff_jti (jti, expires_at) VALUES (?, ?)`, key, expiry); insertErr != nil {
		// A primary-key conflict is the replay signal, not a failure.
		if isSQLiteUniqueViolation(insertErr) {
			return false, nil
		}
		return false, fmt.Errorf("store handoff jti: %w", insertErr)
	}
	return true, nil
}
// delete removes a recorded jti (trimmed; must not be blank) so that it is no
// longer treated as consumed. Deleting an unknown jti is not an error.
func (s *jtiReplayStore) delete(jti string) error {
	s.init()
	switch {
	case s.initErr != nil:
		return s.initErr
	case s.db == nil:
		return fmt.Errorf("handoff jti store not initialized")
	}

	key := strings.TrimSpace(jti)
	if key == "" {
		return fmt.Errorf("jti is required")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	_, execErr := s.db.Exec(`DELETE FROM handoff_jti WHERE jti = ?`, key)
	if execErr != nil {
		return fmt.Errorf("delete handoff jti: %w", execErr)
	}
	return nil
}
// isSQLiteUniqueViolation reports whether err is SQLite's UNIQUE / PRIMARY KEY
// constraint error. checkAndStore treats that error as "jti already recorded".
//
// The driver is imported only for its side effect, so the match is done on
// the error text. Only "UNIQUE constraint failed" is accepted. SQLite reports
// PRIMARY KEY conflicts with that same message. The bare "constraint failed"
// text is also present in NOT NULL, CHECK and FOREIGN KEY errors, and treating
// those as replays would mask genuine storage failures as a 401.
func isSQLiteUniqueViolation(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "UNIQUE constraint failed")
}
func tenantIDFromRequest(r *http.Request) string {
if v := strings.TrimSpace(os.Getenv("PULSE_TENANT_ID")); v != "" {
if isValidOrganizationID(v) {
return v
}
return ""
}
if hostedModeEnabledFromEnv() {
if v := tenantIDFromPublicURL(strings.TrimSpace(os.Getenv("PULSE_PUBLIC_URL"))); v != "" {
return v
}
}
if hostedRuntimeTenantID := tenantIDFromHostedProxyRequest(r); hostedRuntimeTenantID != "" {
return hostedRuntimeTenantID
}
if r == nil {
return ""
}
peerIP := extractRemoteIP(r.RemoteAddr)
trustedProxy := isTrustedProxyIP(peerIP)
rawHost := ""
if trustedProxy {
rawHost = firstForwardedValue(r.Header.Get("X-Forwarded-Host"))
}
if rawHost == "" {
// Only trust direct Host for loopback requests (local development/tests).
if ip := net.ParseIP(peerIP); ip != nil && ip.IsLoopback() {
rawHost = strings.TrimSpace(r.Host)
}
}
if rawHost == "" {
return ""
}
_, host := sanitizeForwardedHost(rawHost)
if host == "" {
return ""
}
tenantID := host
// Host is expected to be "<tenant-id>.<baseDomain>".
if i := strings.IndexByte(host, '.'); i > 0 {
tenantID = host[:i]
}
if !isValidOrganizationID(tenantID) {
return ""
}
return tenantID
}
// tenantIDFromHostedProxyRequest derives the tenant ID from the request's Host
// header, interpreted as a hosted tenant URL. It only applies in hosted mode
// and returns "" for a nil request, when hosted mode is off, or when Host does
// not look like a tenant URL.
func tenantIDFromHostedProxyRequest(r *http.Request) string {
	if r != nil && hostedModeEnabledFromEnv() {
		return tenantIDFromPublicURL(strings.TrimSpace(r.Host))
	}
	return ""
}
// tenantIDFromPublicURL extracts the tenant ID from a hosted tenant URL of the
// form "<tenant-id>.<base-domain>". publicURL may omit the scheme (https is
// assumed) and may include a port.
//
// It returns "" when:
//   - the input is empty or cannot be parsed;
//   - the host is an IP literal;
//   - the host has fewer than four dot-separated labels;
//   - the leading label is not a valid organization ID.
func tenantIDFromPublicURL(publicURL string) string {
	if publicURL == "" {
		return ""
	}
	if !strings.Contains(publicURL, "://") {
		publicURL = "https://" + publicURL
	}
	parsed, err := url.Parse(publicURL)
	if err != nil {
		return ""
	}
	host := strings.TrimSpace(parsed.Hostname())
	if host == "" {
		return ""
	}
	// An IPv4 literal such as "10.1.2.3" has exactly three dots and would
	// otherwise pass the label-count check below, yielding tenant "10".
	// An IP address never names a tenant.
	if net.ParseIP(host) != nil {
		return ""
	}
	// Hosted tenant URLs follow "<tenant-id>.<base-domain>" and must have a
	// distinct tenant label ahead of the shared cloud domain.
	if strings.Count(host, ".") < 3 {
		return ""
	}
	tenantID := host
	if i := strings.IndexByte(host, '.'); i > 0 {
		tenantID = host[:i]
	}
	if !isValidOrganizationID(tenantID) {
		return ""
	}
	return tenantID
}
// normalizeHandoffEmail canonicalizes a handoff email for comparison and
// storage by trimming surrounding whitespace and lowercasing it.
func normalizeHandoffEmail(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
// HandleHandoffExchange verifies a control-plane-minted handoff JWT, records its
// jti to prevent replay, then creates a tenant session and redirects to the app.
//
// Route (wiring happens elsewhere): POST /api/cloud/handoff/exchange
//
// The token is read from the "token" form value. It must satisfy all of:
//   - HS256-signed with the key in <configDir>/secrets/handoff.key;
//   - issued by cloudHandoffIssuer;
//   - audience set to the tenant resolved from the request;
//   - carries exp, jti, sub and a non-blank email.
//
// A missing key file yields 404 (handoff not provisioned). An empty key file
// yields 500.
//
// If the caller requests JSON (`Accept: application/json` or `?format=json`),
// this returns a success payload instead of redirecting. Otherwise the client
// is sent to "/" with 303 See Other, so the follow-up request is a GET. A 307
// would make the browser re-POST the form body, including the token, to "/".
func HandleHandoffExchange(configDir string) http.HandlerFunc {
	configDir = filepath.Clean(configDir)
	InitPersistentAuthStores(configDir)
	keyPath := filepath.Join(configDir, "secrets", "handoff.key")
	replay := &jtiReplayStore{configDir: configDir}
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		// The key is read on every request, so no restart is needed after it
		// is written or replaced.
		key, err := os.ReadFile(keyPath)
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		// An empty or whitespace-only HMAC key is effectively public: anyone
		// can compute a valid signature with it. Refuse to verify against it.
		if strings.TrimSpace(string(key)) == "" {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		tenantID := tenantIDFromRequest(r)
		if tenantID == "" {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		if err := r.ParseForm(); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		tokenStr := strings.TrimSpace(r.FormValue("token"))
		if tokenStr == "" {
			http.Error(w, "missing token", http.StatusBadRequest)
			return
		}
		var claims cloudHandoffClaims
		parsed, err := jwt.ParseWithClaims(
			tokenStr,
			&claims,
			func(t *jwt.Token) (any, error) { return key, nil },
			jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
			jwt.WithIssuer(cloudHandoffIssuer),
			jwt.WithAudience(tenantID),
		)
		if err != nil || parsed == nil || !parsed.Valid {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}
		// exp is mandatory: it bounds how long the jti stays in the replay store.
		if claims.ExpiresAt == nil {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}
		claims.Email = normalizeHandoffEmail(claims.Email)
		if strings.TrimSpace(claims.ID) == "" || strings.TrimSpace(claims.Subject) == "" || claims.Email == "" {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}
		ok, err := replay.checkAndStore(claims.ID, claims.ExpiresAt.Time)
		if err != nil {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		if !ok {
			http.Error(w, "replayed token", http.StatusUnauthorized)
			return
		}
		// The jti is now consumed. If a server-side failure occurs before a
		// session is issued, release it so a retry with the same token is not
		// rejected as a replay. Best effort: if the delete fails, the only cost
		// is that the user must restart the handoff.
		releaseJTI := func() { _ = replay.delete(claims.ID) }
		if err := ensureHandoffOrganizationMembership(configDir, tenantID, claims.Email, claims.Role); err != nil {
			releaseJTI()
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		// Invalidate any pre-existing session to prevent session fixation attacks.
		InvalidateOldSessionFromRequest(r)
		sessionToken := generateSessionToken()
		if sessionToken == "" {
			releaseJTI()
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		userAgent := r.Header.Get("User-Agent")
		clientIP := GetClientIP(r)
		sessionDuration := 24 * time.Hour
		GetSessionStore().CreateSession(sessionToken, sessionDuration, userAgent, clientIP, claims.Email)
		TrackUserSession(claims.Email, sessionToken)
		csrfToken := generateCSRFToken(sessionToken)
		isSecure, sameSitePolicy := getCookieSettings(r)
		cookieMaxAge := int(sessionDuration.Seconds())
		// The session cookie is HttpOnly; the CSRF and org-ID cookies are not.
		http.SetCookie(w, &http.Cookie{
			Name:     sessionCookieName(isSecure),
			Value:    sessionToken,
			Path:     "/",
			HttpOnly: true,
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     CookieNameCSRF,
			Value:    csrfToken,
			Path:     "/",
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     CookieNameOrgID,
			Value:    tenantID,
			Path:     "/",
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})
		if strings.Contains(r.Header.Get("Accept"), "application/json") || r.URL.Query().Get("format") == "json" {
			resp := map[string]any{
				"ok":         true,
				"tenant_id":  tenantID,
				"account_id": claims.AccountID,
				"user_id":    claims.Subject,
				"email":      claims.Email,
				"role":       claims.Role,
				"jti":        claims.ID,
				"exp":        claims.ExpiresAt.Time.UTC().Format(time.RFC3339),
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			// Headers and status are already sent; an encode error cannot be
			// reported to the client.
			_ = json.NewEncoder(w).Encode(resp)
			return
		}
		// 303 turns the form POST into a GET of "/" (POST-redirect-GET).
		http.Redirect(w, r, "/", http.StatusSeeOther)
	}
}
// ensureHandoffOrganizationMembership makes sure email is a member of the
// tenant's organization and persists the organization afterwards.
//
// Organization defaults:
//   - A nil organization from LoadOrganization is created in memory.
//   - A blank ID or DisplayName becomes tenantID.
//   - A status that normalizes to "" becomes active; any other status is
//     stored in normalized form.
//   - A zero CreatedAt becomes now.
//
// The user's desired organization role is derived from the account role. An
// existing member (matched case-insensitively, ignoring surrounding spaces)
// keeps its current role when OrganizationRoleAtLeast reports it already
// covers the desired role. Otherwise the role is replaced, so presumably roles
// are only ever raised. Missing AddedAt/AddedBy values are back-filled.
// If the user ends up wanting the owner role and the organization has no
// owner, the user becomes the owner.
//
// NOTE(review): this is a load-modify-save cycle with no locking visible here.
// Concurrent handoffs for the same tenant may race; confirm whether the
// persistence layer serializes writers.
func ensureHandoffOrganizationMembership(configDir, tenantID, email, role string) error {
	persistence := config.NewMultiTenantPersistence(configDir)
	org, err := persistence.LoadOrganization(tenantID)
	if err != nil {
		return fmt.Errorf("load tenant organization %s: %w", tenantID, err)
	}
	now := time.Now().UTC()
	if org == nil {
		org = &models.Organization{}
	}
	if strings.TrimSpace(org.ID) == "" {
		org.ID = tenantID
	}
	if strings.TrimSpace(org.DisplayName) == "" {
		org.DisplayName = tenantID
	}
	if status := models.NormalizeOrgStatus(org.Status); status == "" {
		org.Status = models.OrgStatusActive
	} else {
		org.Status = status
	}
	if org.CreatedAt.IsZero() {
		org.CreatedAt = now
	}

	// Members added or back-filled here are attributed to the org owner,
	// falling back to the user themselves when no owner is recorded yet.
	defaultAddedBy := strings.TrimSpace(org.OwnerUserID)
	if defaultAddedBy == "" {
		defaultAddedBy = email
	}
	desiredRole := models.OrganizationRoleFromAccountRole(role)

	existing := -1
	for i := range org.Members {
		if strings.EqualFold(strings.TrimSpace(org.Members[i].UserID), email) {
			existing = i
			break
		}
	}

	if existing < 0 {
		org.Members = append(org.Members, models.OrganizationMember{
			UserID:  email,
			Role:    desiredRole,
			AddedAt: now,
			AddedBy: defaultAddedBy,
		})
	} else {
		member := &org.Members[existing]
		member.UserID = email
		if !models.OrganizationRoleAtLeast(member.Role, desiredRole) {
			member.Role = desiredRole
		}
		if member.AddedAt.IsZero() {
			member.AddedAt = now
		}
		if strings.TrimSpace(member.AddedBy) == "" {
			member.AddedBy = defaultAddedBy
		}
	}

	if desiredRole == models.OrgRoleOwner && strings.TrimSpace(org.OwnerUserID) == "" {
		org.OwnerUserID = email
	}
	return persistence.SaveOrganization(org)
}