mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-06 16:16:26 +00:00
499 lines
12 KiB
Go
499 lines
12 KiB
Go
package api
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// Constants shared by the cloud handoff exchange and its replay store.
const (
	// cloudHandoffIssuer is the only accepted "iss" claim on handoff JWTs.
	cloudHandoffIssuer = "pulse-cloud-control-plane"

	// handoffPrivateDirPerm is applied to the secrets directory (owner-only rwx).
	handoffPrivateDirPerm = 0o700
	// handoffPrivateFilePerm is applied to the replay database files (owner-only rw).
	handoffPrivateFilePerm = 0o600
)
|
|
|
|
type cloudHandoffClaims struct {
|
|
AccountID string `json:"account_id"`
|
|
Email string `json:"email"`
|
|
Role string `json:"role"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
type jtiReplayStore struct {
|
|
once sync.Once
|
|
db *sql.DB
|
|
mu sync.Mutex
|
|
|
|
configDir string
|
|
initErr error
|
|
}
|
|
|
|
// deleteExpiredHandoffJTIQuery purges replay records whose expires_at (Unix
// seconds) is at or before the bound parameter. INDEXED BY pins the query to
// idx_handoff_jti_expires_at; SQLite fails to prepare the statement if that
// index does not exist.
const deleteExpiredHandoffJTIQuery = `
DELETE FROM handoff_jti INDEXED BY idx_handoff_jti_expires_at
WHERE expires_at <= ?`
|
|
|
|
// init opens, creating it if needed, the replay database at
// <configDir>/secrets/handoff_jti.db. It runs at most once per store.
//
// On success s.db is set. On failure s.initErr records the reason and s.db
// stays nil. Because of sync.Once, a failed initialization is not retried.
func (s *jtiReplayStore) init() {
	s.once.Do(func() {
		// Validate the raw value before cleaning it: filepath.Clean("") returns
		// ".", so a check after Clean could never fire. An empty configDir would
		// then silently place the database under the working directory.
		if strings.TrimSpace(s.configDir) == "" {
			s.initErr = fmt.Errorf("configDir is required")
			return
		}
		dir := filepath.Clean(s.configDir)

		secretsDir := filepath.Join(dir, "secrets")
		if err := os.MkdirAll(secretsDir, handoffPrivateDirPerm); err != nil {
			s.initErr = fmt.Errorf("create handoff secrets dir: %w", err)
			return
		}
		// MkdirAll does not change the mode of an existing directory (and is
		// subject to the umask), so tighten the permissions explicitly.
		if err := os.Chmod(secretsDir, handoffPrivateDirPerm); err != nil {
			s.initErr = fmt.Errorf("chmod handoff secrets dir: %w", err)
			return
		}

		dbPath := filepath.Join(secretsDir, "handoff_jti.db")
		dsn := dbPath + "?" + url.Values{
			"_pragma": []string{
				"busy_timeout(30000)",
				"journal_mode(WAL)",
				"synchronous(NORMAL)",
			},
		}.Encode()

		db, err := sql.Open("sqlite", dsn)
		if err != nil {
			s.initErr = fmt.Errorf("open handoff jti db: %w", err)
			return
		}
		// Cap the pool at a single connection that is never recycled.
		db.SetMaxOpenConns(1)
		db.SetMaxIdleConns(1)
		db.SetConnMaxLifetime(0)

		schema := `
CREATE TABLE IF NOT EXISTS handoff_jti (
	jti TEXT PRIMARY KEY,
	expires_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_handoff_jti_expires_at ON handoff_jti(expires_at);
`
		if _, err := db.Exec(schema); err != nil {
			_ = db.Close()
			s.initErr = fmt.Errorf("init handoff jti schema: %w", err)
			return
		}

		// Restrict the database and its WAL/shared-memory sidecars; files that
		// do not exist yet are skipped by hardenPrivateFile.
		for _, path := range []string{dbPath, dbPath + "-wal", dbPath + "-shm"} {
			if err := hardenPrivateFile(path, handoffPrivateFilePerm); err != nil {
				_ = db.Close()
				s.initErr = fmt.Errorf("harden handoff jti file permissions: %w", err)
				return
			}
		}

		s.db = db
	})
}
|
|
|
|
// hardenPrivateFile sets the permission bits of the file at path to mode.
//
// A path that does not exist is not an error, and nothing is changed when the
// permissions already equal mode. Any other stat or chmod error is returned.
func hardenPrivateFile(path string, mode os.FileMode) error {
	fi, statErr := os.Stat(path)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return nil
	case statErr != nil:
		return statErr
	case fi.Mode().Perm() != mode:
		return os.Chmod(path, mode)
	default:
		return nil
	}
}
|
|
|
|
// checkAndStore records jti as consumed until expiresAt.
//
// Results:
//   - (true, nil): jti was new and is now recorded.
//   - (false, nil): jti is already recorded (the insert hit the primary key),
//     i.e. a replay.
//   - (false, err): initialization failed, jti is blank, or a database
//     statement failed.
//
// Before inserting, rows whose expires_at is at or before the current Unix
// time are purged. Statements are serialized by s.mu.
func (s *jtiReplayStore) checkAndStore(jti string, expiresAt time.Time) (stored bool, err error) {
	s.init()
	switch {
	case s.initErr != nil:
		return false, s.initErr
	case s.db == nil:
		return false, fmt.Errorf("handoff jti store not initialized")
	}

	key := strings.TrimSpace(jti)
	if key == "" {
		return false, fmt.Errorf("jti is required")
	}
	expiry := expiresAt.UTC().Unix()

	s.mu.Lock()
	defer s.mu.Unlock()

	// Drop expired entries first so the table only holds still-valid jtis.
	if _, cleanupErr := s.db.Exec(deleteExpiredHandoffJTIQuery, time.Now().UTC().Unix()); cleanupErr != nil {
		return false, fmt.Errorf("cleanup handoff jti: %w", cleanupErr)
	}

	if _, err = s.db.Exec(`INSERT INTO handoff_jti (jti, expires_at) VALUES (?, ?)`, key, expiry); err == nil {
		return true, nil
	}
	if isSQLiteUniqueViolation(err) {
		return false, nil
	}
	return false, fmt.Errorf("store handoff jti: %w", err)
}
|
|
|
|
// delete removes jti from the replay store if it is present. Removing a jti
// that was never stored is not an error. An error is returned when the store
// cannot be initialized, jti is blank, or the DELETE statement fails.
func (s *jtiReplayStore) delete(jti string) error {
	s.init()
	if initErr := s.initErr; initErr != nil {
		return initErr
	}
	if s.db == nil {
		return fmt.Errorf("handoff jti store not initialized")
	}

	key := strings.TrimSpace(jti)
	if len(key) == 0 {
		return fmt.Errorf("jti is required")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	_, execErr := s.db.Exec(`DELETE FROM handoff_jti WHERE jti = ?`, key)
	if execErr != nil {
		return fmt.Errorf("delete handoff jti: %w", execErr)
	}
	return nil
}
|
|
|
|
func isSQLiteUniqueViolation(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
s := err.Error()
|
|
return strings.Contains(s, "UNIQUE constraint failed") || strings.Contains(s, "constraint failed")
|
|
}
|
|
|
|
func tenantIDFromRequest(r *http.Request) string {
|
|
if v := strings.TrimSpace(os.Getenv("PULSE_TENANT_ID")); v != "" {
|
|
if isValidOrganizationID(v) {
|
|
return v
|
|
}
|
|
return ""
|
|
}
|
|
if hostedModeEnabledFromEnv() {
|
|
if v := tenantIDFromPublicURL(strings.TrimSpace(os.Getenv("PULSE_PUBLIC_URL"))); v != "" {
|
|
return v
|
|
}
|
|
}
|
|
if hostedRuntimeTenantID := tenantIDFromHostedProxyRequest(r); hostedRuntimeTenantID != "" {
|
|
return hostedRuntimeTenantID
|
|
}
|
|
if r == nil {
|
|
return ""
|
|
}
|
|
|
|
peerIP := extractRemoteIP(r.RemoteAddr)
|
|
trustedProxy := isTrustedProxyIP(peerIP)
|
|
|
|
rawHost := ""
|
|
if trustedProxy {
|
|
rawHost = firstForwardedValue(r.Header.Get("X-Forwarded-Host"))
|
|
}
|
|
if rawHost == "" {
|
|
// Only trust direct Host for loopback requests (local development/tests).
|
|
if ip := net.ParseIP(peerIP); ip != nil && ip.IsLoopback() {
|
|
rawHost = strings.TrimSpace(r.Host)
|
|
}
|
|
}
|
|
if rawHost == "" {
|
|
return ""
|
|
}
|
|
|
|
_, host := sanitizeForwardedHost(rawHost)
|
|
if host == "" {
|
|
return ""
|
|
}
|
|
|
|
tenantID := host
|
|
// Host is expected to be "<tenant-id>.<baseDomain>".
|
|
if i := strings.IndexByte(host, '.'); i > 0 {
|
|
tenantID = host[:i]
|
|
}
|
|
if !isValidOrganizationID(tenantID) {
|
|
return ""
|
|
}
|
|
return tenantID
|
|
}
|
|
|
|
// tenantIDFromHostedProxyRequest derives the tenant ID from r.Host using the
// rules of tenantIDFromPublicURL. It returns "" when r is nil, when hosted mode
// is disabled, or when the Host yields no tenant.
func tenantIDFromHostedProxyRequest(r *http.Request) string {
	if r != nil && hostedModeEnabledFromEnv() {
		return tenantIDFromPublicURL(strings.TrimSpace(r.Host))
	}
	return ""
}
|
|
|
|
// tenantIDFromPublicURL extracts the tenant ID from a hosted tenant URL of the
// form "<tenant-id>.<base-domain>". publicURL may be a full URL or a bare host
// (optionally with a port); "https://" is assumed when no scheme is present.
//
// It returns "" when publicURL is empty or unparseable, when the host is an IP
// literal, when the host contains fewer than three dots, or when the leading
// label is not a valid organization ID.
func tenantIDFromPublicURL(publicURL string) string {
	if publicURL == "" {
		return ""
	}
	if !strings.Contains(publicURL, "://") {
		publicURL = "https://" + publicURL
	}
	parsed, err := url.Parse(publicURL)
	if err != nil {
		return ""
	}
	host := strings.TrimSpace(parsed.Hostname())
	if host == "" {
		return ""
	}
	// An IPv4 literal such as "10.0.0.12" has three dots but no tenant label.
	// Without this check, its first octet would be taken as a tenant ID.
	if net.ParseIP(host) != nil {
		return ""
	}
	// Hosted tenant URLs follow "<tenant-id>.<base-domain>" and must have a
	// distinct tenant label ahead of the shared cloud domain.
	if strings.Count(host, ".") < 3 {
		return ""
	}
	tenantID := host
	if i := strings.IndexByte(host, '.'); i > 0 {
		tenantID = host[:i]
	}
	if !isValidOrganizationID(tenantID) {
		return ""
	}
	return tenantID
}
|
|
|
|
// normalizeHandoffEmail returns email with surrounding whitespace removed and
// all letters lower-cased, so that addresses compare and store consistently.
func normalizeHandoffEmail(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
|
|
|
|
// HandleHandoffExchange verifies a control-plane-minted handoff JWT, records its
// jti to prevent replay, then creates a tenant session and redirects to the app.
//
// Route (wiring happens elsewhere): POST /api/cloud/handoff/exchange
//
// The token is read from the "token" form value. It must:
//   - be signed with HS256 using the key in <configDir>/secrets/handoff.key;
//   - have issuer cloudHandoffIssuer;
//   - have an audience equal to the tenant resolved for the request;
//   - carry exp, jti, sub and email claims.
//
// Error responses:
//   - 405 for non-POST requests.
//   - 404 when the key file does not exist.
//   - 400 when the form cannot be parsed or the token is missing.
//   - 401 for invalid or replayed tokens.
//   - 500 when the key is unreadable or empty, when no tenant can be resolved,
//     or when persistence or session creation fails.
//
// On success the session, CSRF and org-ID cookies are set. If the caller
// requests JSON (`Accept: application/json` or `?format=json`), this returns a
// success payload. Otherwise it answers 303 See Other to "/", so the browser
// follows up with a GET instead of re-submitting the form (and its token).
func HandleHandoffExchange(configDir string) http.HandlerFunc {
	configDir = filepath.Clean(configDir)
	InitPersistentAuthStores(configDir)
	keyPath := filepath.Join(configDir, "secrets", "handoff.key")
	replay := &jtiReplayStore{configDir: configDir}

	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// The key file is read on every request, so changes to it take effect
		// without rebuilding the handler.
		key, err := os.ReadFile(keyPath)
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		// An empty or whitespace-only key makes the HMAC secret trivially
		// guessable, so anyone could sign an acceptable token. Fail closed.
		if strings.TrimSpace(string(key)) == "" {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}

		// The resolved tenant is the audience the token must be addressed to.
		tenantID := tenantIDFromRequest(r)
		if tenantID == "" {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}

		if err := r.ParseForm(); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		tokenStr := strings.TrimSpace(r.FormValue("token"))
		if tokenStr == "" {
			http.Error(w, "missing token", http.StatusBadRequest)
			return
		}

		var claims cloudHandoffClaims
		parsed, err := jwt.ParseWithClaims(
			tokenStr,
			&claims,
			func(t *jwt.Token) (any, error) { return key, nil },
			// Only HS256 is accepted, whatever "alg" the token header claims.
			jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
			jwt.WithIssuer(cloudHandoffIssuer),
			jwt.WithAudience(tenantID),
		)
		if err != nil || parsed == nil || !parsed.Valid {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}

		// exp is mandatory: it determines how long the jti is remembered.
		if claims.ExpiresAt == nil {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}
		claims.Email = normalizeHandoffEmail(claims.Email)
		if strings.TrimSpace(claims.ID) == "" || strings.TrimSpace(claims.Subject) == "" || claims.Email == "" {
			http.Error(w, "invalid token", http.StatusUnauthorized)
			return
		}

		// Record the jti; a token whose jti is already stored is a replay.
		ok, err := replay.checkAndStore(claims.ID, claims.ExpiresAt.Time)
		if err != nil {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
		if !ok {
			http.Error(w, "replayed token", http.StatusUnauthorized)
			return
		}

		if err := ensureHandoffOrganizationMembership(configDir, tenantID, claims.Email, claims.Role); err != nil {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}

		// Invalidate any pre-existing session to prevent session fixation attacks.
		InvalidateOldSessionFromRequest(r)

		sessionToken := generateSessionToken()
		if sessionToken == "" {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}

		userAgent := r.Header.Get("User-Agent")
		clientIP := GetClientIP(r)
		sessionDuration := 24 * time.Hour
		GetSessionStore().CreateSession(sessionToken, sessionDuration, userAgent, clientIP, claims.Email)
		TrackUserSession(claims.Email, sessionToken)

		csrfToken := generateCSRFToken(sessionToken)
		isSecure, sameSitePolicy := getCookieSettings(r)
		cookieMaxAge := int(sessionDuration.Seconds())

		http.SetCookie(w, &http.Cookie{
			Name:     sessionCookieName(isSecure),
			Value:    sessionToken,
			Path:     "/",
			HttpOnly: true,
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     CookieNameCSRF,
			Value:    csrfToken,
			Path:     "/",
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     CookieNameOrgID,
			Value:    tenantID,
			Path:     "/",
			Secure:   isSecure,
			SameSite: sameSitePolicy,
			MaxAge:   cookieMaxAge,
		})

		if strings.Contains(r.Header.Get("Accept"), "application/json") || r.URL.Query().Get("format") == "json" {
			resp := map[string]any{
				"ok":         true,
				"tenant_id":  tenantID,
				"account_id": claims.AccountID,
				"user_id":    claims.Subject,
				"email":      claims.Email,
				"role":       claims.Role,
				"jti":        claims.ID,
				"exp":        claims.ExpiresAt.Time.UTC().Format(time.RFC3339),
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			// Headers are already sent; an encode failure cannot change the status.
			_ = json.NewEncoder(w).Encode(resp)
			return
		}

		// 303 makes the browser follow up with a GET. A 307 would make it
		// re-POST the form, token included, to "/".
		http.Redirect(w, r, "/", http.StatusSeeOther)
	}
}
|
|
|
|
// ensureHandoffOrganizationMembership loads the organization tenantID from
// configDir (starting from an empty one if none is returned) and saves it back.
//
// Before saving, it makes sure email is a member with at least the
// organization role mapped from the account role. It also fills in a missing
// ID, display name, status (default active) and creation time. An existing
// member's role is raised if needed but never lowered. If the mapped role is
// owner and the organization has no owner yet, email becomes the owner.
//
// Members are matched case-insensitively on their trimmed user ID, and the
// stored user ID is rewritten to email. NOTE(review): the handoff handler
// passes a normalized (trimmed, lower-cased) address; other callers should
// do the same.
func ensureHandoffOrganizationMembership(configDir, tenantID, email, role string) error {
	persistence := config.NewMultiTenantPersistence(configDir)
	org, err := persistence.LoadOrganization(tenantID)
	if err != nil {
		return fmt.Errorf("load tenant organization %s: %w", tenantID, err)
	}
	if org == nil {
		org = &models.Organization{}
	}

	now := time.Now().UTC()

	if strings.TrimSpace(org.ID) == "" {
		org.ID = tenantID
	}
	if strings.TrimSpace(org.DisplayName) == "" {
		org.DisplayName = tenantID
	}
	status := models.NormalizeOrgStatus(org.Status)
	if status == "" {
		status = models.OrgStatusActive
	}
	org.Status = status
	if org.CreatedAt.IsZero() {
		org.CreatedAt = now
	}

	desiredRole := models.OrganizationRoleFromAccountRole(role)

	// Attribution for new or unattributed memberships: the current owner, or
	// the user themself when no owner is recorded.
	defaultAddedBy := strings.TrimSpace(org.OwnerUserID)
	if defaultAddedBy == "" {
		defaultAddedBy = email
	}

	memberIdx := -1
	for i := range org.Members {
		if strings.EqualFold(strings.TrimSpace(org.Members[i].UserID), email) {
			memberIdx = i
			break
		}
	}

	if memberIdx < 0 {
		org.Members = append(org.Members, models.OrganizationMember{
			UserID:  email,
			Role:    desiredRole,
			AddedAt: now,
			AddedBy: defaultAddedBy,
		})
	} else {
		member := &org.Members[memberIdx]
		member.UserID = email
		if !models.OrganizationRoleAtLeast(member.Role, desiredRole) {
			member.Role = desiredRole
		}
		if member.AddedAt.IsZero() {
			member.AddedAt = now
		}
		if strings.TrimSpace(member.AddedBy) == "" {
			member.AddedBy = defaultAddedBy
		}
	}

	if desiredRole == models.OrgRoleOwner && strings.TrimSpace(org.OwnerUserID) == "" {
		org.OwnerUserID = email
	}

	return persistence.SaveOrganization(org)
}
|