mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 00:37:36 +00:00
838 lines
26 KiB
Go
838 lines
26 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/stripe/stripe-go/v82"
|
|
"github.com/stripe/stripe-go/v82/webhook"
|
|
)
|
|
|
|
const stripeWebhookBodyLimit = 1024 * 1024 // 1MiB
|
|
|
|
var errStripeWebhookEventInFlight = errors.New("stripe webhook event is in-flight")
|
|
|
|
// StripeWebhookHandlers handles Stripe webhooks for hosted Cloud provisioning.
|
|
//
|
|
// SECURITY: Signature verification (ConstructEvent) is the authentication mechanism for this endpoint.
|
|
type StripeWebhookHandlers struct {
|
|
hostedMode bool
|
|
|
|
persistence *config.MultiTenantPersistence
|
|
rbacProvider HostedRBACProvider
|
|
magicLinks *MagicLinkService
|
|
publicURL func(*http.Request) string
|
|
|
|
billingStore *config.FileBillingStore
|
|
|
|
deduper *stripeWebhookDeduper
|
|
index *stripeCustomerOrgIndex
|
|
|
|
conversionRecorder *conversionRecorder
|
|
conversionHealth *conversionPipelineHealth
|
|
disableMetrics func() bool
|
|
|
|
now func() time.Time
|
|
}
|
|
|
|
// SetConversionRecorder wires the conversion event recorder for backend-emitted
|
|
// conversion events (checkout_completed on checkout.session.completed webhook).
|
|
func (h *StripeWebhookHandlers) SetConversionRecorder(rec *conversionRecorder, health *conversionPipelineHealth, disableAll ...func() bool) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
h.conversionRecorder = rec
|
|
h.conversionHealth = health
|
|
if len(disableAll) > 0 {
|
|
h.disableMetrics = disableAll[0]
|
|
}
|
|
}
|
|
|
|
// emitConversionEvent is a fire-and-forget helper that records a backend-emitted
|
|
// conversion event. Respects the DisableLocalUpgradeMetrics config flag.
|
|
// Errors are logged but never propagated to callers.
|
|
func (h *StripeWebhookHandlers) emitConversionEvent(orgID string, event conversionEvent) {
|
|
if h == nil || h.conversionRecorder == nil {
|
|
return
|
|
}
|
|
if h.disableMetrics != nil && h.disableMetrics() {
|
|
return
|
|
}
|
|
if orgID == "" {
|
|
orgID = "default"
|
|
}
|
|
event.OrgID = orgID
|
|
if event.Timestamp <= 0 {
|
|
event.Timestamp = time.Now().UnixMilli()
|
|
}
|
|
if event.IdempotencyKey == "" {
|
|
event.IdempotencyKey = fmt.Sprintf("backend:%s:%s:%s:%d", orgID, event.Type, event.Surface, event.Timestamp)
|
|
}
|
|
if err := h.conversionRecorder.Record(event); err != nil {
|
|
log.Warn().Err(err).Str("event_type", event.Type).Str("org_id", orgID).Msg("Failed to record backend conversion event")
|
|
} else {
|
|
recordConversionEventMetric(event.Type, event.Surface)
|
|
if h.conversionHealth != nil {
|
|
h.conversionHealth.RecordEvent(event.Type)
|
|
}
|
|
}
|
|
}
|
|
|
|
func NewStripeWebhookHandlers(
|
|
billingStore *config.FileBillingStore,
|
|
persistence *config.MultiTenantPersistence,
|
|
rbacProvider HostedRBACProvider,
|
|
magicLinks *MagicLinkService,
|
|
publicURL func(*http.Request) string,
|
|
hostedMode bool,
|
|
dataPath string,
|
|
) *StripeWebhookHandlers {
|
|
baseDir := resolvePulseDataDir(dataPath)
|
|
return &StripeWebhookHandlers{
|
|
hostedMode: hostedMode,
|
|
persistence: persistence,
|
|
rbacProvider: rbacProvider,
|
|
magicLinks: magicLinks,
|
|
publicURL: publicURL,
|
|
billingStore: billingStore,
|
|
deduper: newStripeWebhookDeduper(filepath.Join(baseDir, "stripe", "webhook-events")),
|
|
index: newStripeCustomerOrgIndex(filepath.Join(baseDir, "stripe", "customers")),
|
|
now: time.Now,
|
|
}
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) HandleStripeWebhook(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if !h.hostedMode {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if h.billingStore == nil || h.persistence == nil || h.rbacProvider == nil || h.deduper == nil || h.index == nil {
|
|
writeErrorResponse(w, http.StatusServiceUnavailable, "stripe_unavailable", "Stripe webhook handler is not configured", nil)
|
|
return
|
|
}
|
|
|
|
secret := strings.TrimSpace(os.Getenv("STRIPE_WEBHOOK_SECRET"))
|
|
if secret == "" {
|
|
writeErrorResponse(w, http.StatusServiceUnavailable, "stripe_unavailable", "Stripe webhook secret is not configured", nil)
|
|
return
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(w, r.Body, stripeWebhookBodyLimit)
|
|
payload, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", nil)
|
|
return
|
|
}
|
|
|
|
sigHeader := r.Header.Get("Stripe-Signature")
|
|
if strings.TrimSpace(sigHeader) == "" {
|
|
// Intentionally vague; missing signature is treated as invalid auth.
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_signature", "Invalid Stripe signature", nil)
|
|
return
|
|
}
|
|
|
|
event, err := webhook.ConstructEventWithOptions(payload, sigHeader, secret, webhook.ConstructEventOptions{
|
|
IgnoreAPIVersionMismatch: true,
|
|
})
|
|
if err != nil {
|
|
writeErrorResponse(w, http.StatusBadRequest, "invalid_signature", "Invalid Stripe signature", nil)
|
|
return
|
|
}
|
|
|
|
already, err := h.deduper.Do(event.ID, func() error {
|
|
return h.handleEvent(r.Context(), &event, r)
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, errStripeWebhookEventInFlight) {
|
|
log.Warn().
|
|
Str("event_id", event.ID).
|
|
Str("type", string(event.Type)).
|
|
Msg("Stripe webhook event is already in-flight; returning non-2xx so Stripe retries")
|
|
writeErrorResponse(w, http.StatusConflict, "stripe_in_flight", "Stripe webhook is being processed; retry later", nil)
|
|
return
|
|
}
|
|
log.Error().Err(err).Str("event_id", event.ID).Str("type", string(event.Type)).Msg("Stripe webhook processing failed")
|
|
writeErrorResponse(w, http.StatusInternalServerError, "stripe_processing_failed", "Failed to process Stripe webhook", nil)
|
|
return
|
|
}
|
|
|
|
if already {
|
|
// Stripe treats any 2xx as success; returning JSON helps local debugging.
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"received": true,
|
|
"status": "duplicate",
|
|
})
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"received": true,
|
|
"status": "processed",
|
|
})
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) handleEvent(ctx context.Context, event *stripe.Event, r *http.Request) error {
|
|
if event == nil {
|
|
return errors.New("stripe event is nil")
|
|
}
|
|
|
|
switch event.Type {
|
|
case "checkout.session.completed":
|
|
var session stripeCheckoutSession
|
|
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
|
return fmt.Errorf("decode checkout.session: %w", err)
|
|
}
|
|
return h.handleCheckoutSessionCompleted(ctx, session, r)
|
|
case "customer.subscription.updated":
|
|
var sub stripeSubscription
|
|
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
|
return fmt.Errorf("decode subscription: %w", err)
|
|
}
|
|
return h.handleSubscriptionUpdated(ctx, sub)
|
|
case "customer.subscription.deleted":
|
|
var sub stripeSubscription
|
|
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
|
return fmt.Errorf("decode subscription: %w", err)
|
|
}
|
|
return h.handleSubscriptionDeleted(ctx, sub)
|
|
default:
|
|
log.Info().Str("type", string(event.Type)).Str("event_id", event.ID).Msg("Stripe webhook ignored (unhandled type)")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
type stripeCheckoutSession struct {
|
|
ID string `json:"id"`
|
|
Mode string `json:"mode"`
|
|
Customer string `json:"customer"`
|
|
Subscription string `json:"subscription"`
|
|
CustomerEmail string `json:"customer_email"`
|
|
CustomerDetails stripeCustDetails `json:"customer_details"`
|
|
ClientReference string `json:"client_reference_id"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
SubscriptionData map[string]any `json:"subscription_data"`
|
|
}
|
|
|
|
type stripeCustDetails struct {
|
|
Email string `json:"email"`
|
|
}
|
|
|
|
type stripeSubscription struct {
|
|
ID string `json:"id"`
|
|
Customer string `json:"customer"`
|
|
Status string `json:"status"`
|
|
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
|
CurrentPeriodEnd int64 `json:"current_period_end"`
|
|
EndedAt int64 `json:"ended_at"`
|
|
CancellationReason string `json:"cancellation_reason"`
|
|
Items struct {
|
|
Data []struct {
|
|
Price struct {
|
|
ID string `json:"id"`
|
|
Product string `json:"product"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
} `json:"price"`
|
|
} `json:"data"`
|
|
} `json:"items"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) handleCheckoutSessionCompleted(ctx context.Context, session stripeCheckoutSession, r *http.Request) error {
|
|
// Expect subscription-mode sessions for Cloud.
|
|
if strings.TrimSpace(session.Customer) == "" {
|
|
return fmt.Errorf("checkout session missing customer")
|
|
}
|
|
|
|
// SECURITY: customer email is not a safe org identifier.
|
|
// If present, it's used only for best-effort post-checkout UX (magic link) and audit logs.
|
|
email := strings.ToLower(strings.TrimSpace(session.CustomerEmail))
|
|
if email == "" {
|
|
email = strings.ToLower(strings.TrimSpace(session.CustomerDetails.Email))
|
|
}
|
|
|
|
orgName := ""
|
|
if session.Metadata != nil {
|
|
orgName = strings.TrimSpace(session.Metadata["org_name"])
|
|
if orgName == "" {
|
|
orgName = strings.TrimSpace(session.Metadata["org"])
|
|
}
|
|
}
|
|
// Prefer existing mapping by customer ID; otherwise require server-owned linkage (metadata/client_reference_id).
|
|
orgID, ok, err := h.index.LookupOrgID(session.Customer)
|
|
if err != nil {
|
|
return fmt.Errorf("lookup org by customer id: %w", err)
|
|
}
|
|
orgResolvedBy := "customer_index"
|
|
if !ok {
|
|
orgID = ""
|
|
if session.Metadata != nil {
|
|
orgID = strings.TrimSpace(session.Metadata["org_id"])
|
|
}
|
|
if orgID == "" {
|
|
orgID = strings.TrimSpace(session.ClientReference)
|
|
}
|
|
orgResolvedBy = "session_linkage"
|
|
}
|
|
orgID = strings.TrimSpace(orgID)
|
|
if orgID == "" {
|
|
return fmt.Errorf(
|
|
"checkout session %q missing org linkage for customer %q",
|
|
strings.TrimSpace(session.ID),
|
|
strings.TrimSpace(session.Customer),
|
|
)
|
|
}
|
|
if !isValidOrganizationID(orgID) {
|
|
return fmt.Errorf(
|
|
"checkout session %q resolved invalid org %q via %s",
|
|
strings.TrimSpace(session.ID),
|
|
orgID,
|
|
orgResolvedBy,
|
|
)
|
|
}
|
|
|
|
// SECURITY: only provision into an org that already exists. Do not create tenants from webhook payloads.
|
|
org, err := h.persistence.LoadOrganizationStrict(orgID)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return fmt.Errorf(
|
|
"checkout session %q linked to missing org %q via %s",
|
|
strings.TrimSpace(session.ID),
|
|
orgID,
|
|
orgResolvedBy,
|
|
)
|
|
}
|
|
return fmt.Errorf("load org: %w", err)
|
|
}
|
|
if org == nil {
|
|
return fmt.Errorf("load org: empty org")
|
|
}
|
|
|
|
// Persist customer->org mapping once the linkage has been validated.
|
|
if !ok {
|
|
if err := h.index.Save(session.Customer, orgID); err != nil {
|
|
return fmt.Errorf("save customer index: %w", err)
|
|
}
|
|
}
|
|
|
|
planVersion := derivePlanVersion(session.Metadata, "")
|
|
limits, _ := limitsForCloudPlanFromLicensing(planVersion)
|
|
|
|
state := &billingState{
|
|
Capabilities: cloudCapabilitiesFromLicensing(),
|
|
Limits: limits,
|
|
MetersEnabled: []string{},
|
|
PlanVersion: planVersion,
|
|
SubscriptionState: subscriptionStateActiveValue,
|
|
StripeCustomerID: session.Customer,
|
|
StripeSubscriptionID: strings.TrimSpace(session.Subscription),
|
|
}
|
|
|
|
if err := h.billingStore.SaveBillingState(orgID, state); err != nil {
|
|
return fmt.Errorf("save billing state: %w", err)
|
|
}
|
|
|
|
// Best-effort: issue a magic link so the user can sign in quickly after checkout.
|
|
// (In dev/staging this is log-only; production should swap in a real emailer.)
|
|
if h.magicLinks != nil && email != "" {
|
|
// Only send a link to an existing org member/owner. Stripe customer email is user-controlled.
|
|
sendTo := ""
|
|
if strings.EqualFold(org.OwnerUserID, email) {
|
|
sendTo = org.OwnerUserID
|
|
} else {
|
|
for _, m := range org.Members {
|
|
if strings.EqualFold(m.UserID, email) {
|
|
sendTo = m.UserID
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if sendTo != "" && h.magicLinks.AllowRequest(sendTo) {
|
|
token, genErr := h.magicLinks.GenerateToken(sendTo, orgID)
|
|
if genErr == nil {
|
|
baseURL := ""
|
|
if h.publicURL != nil && r != nil {
|
|
baseURL = h.publicURL(r)
|
|
}
|
|
if baseURL != "" {
|
|
if sendErr := h.magicLinks.SendMagicLink(sendTo, orgID, token, baseURL); sendErr != nil {
|
|
log.Warn().Err(sendErr).Str("email", sendTo).Str("org_id", orgID).Msg("Stripe checkout: failed to send magic link")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
h.emitConversionEvent(orgID, conversionEvent{
|
|
Type: conversionEventCheckoutCompleted,
|
|
Surface: "stripe_webhook",
|
|
})
|
|
|
|
log.Info().
|
|
Str("org_id", orgID).
|
|
Str("email", email).
|
|
Str("org_name", orgName).
|
|
Str("customer_id", session.Customer).
|
|
Str("resolved_by", orgResolvedBy).
|
|
Msg("Stripe checkout.session.completed processed")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) handleSubscriptionUpdated(ctx context.Context, sub stripeSubscription) error {
|
|
customerID := strings.TrimSpace(sub.Customer)
|
|
if customerID == "" {
|
|
return fmt.Errorf("subscription missing customer")
|
|
}
|
|
|
|
orgID, ok, err := h.index.LookupOrgID(customerID)
|
|
if err != nil {
|
|
return fmt.Errorf("lookup org by customer id: %w", err)
|
|
}
|
|
if !ok {
|
|
// Backstop for older data: scan org billing files.
|
|
orgID, ok, err = h.scanOrgByStripeCustomerID(customerID)
|
|
if err != nil {
|
|
return fmt.Errorf("scan org by customer id: %w", err)
|
|
}
|
|
if ok {
|
|
if saveErr := h.index.Save(customerID, orgID); saveErr != nil {
|
|
log.Warn().
|
|
Err(saveErr).
|
|
Str("customer_id", customerID).
|
|
Str("org_id", orgID).
|
|
Msg("Stripe subscription.updated: failed to backfill customer org index")
|
|
}
|
|
}
|
|
}
|
|
if !ok {
|
|
log.Warn().Str("customer_id", customerID).Str("subscription_id", sub.ID).Msg("Stripe subscription.updated: org not found for customer")
|
|
return nil
|
|
}
|
|
|
|
before, err := h.billingStore.GetBillingState(orgID)
|
|
if err != nil {
|
|
return fmt.Errorf("load billing state: %w", err)
|
|
}
|
|
state := normalizeBillingState(before)
|
|
|
|
subState := mapStripeSubscriptionStatusToState(sub.Status)
|
|
state.SubscriptionState = subState
|
|
|
|
priceID := firstPriceID(sub)
|
|
state.StripePriceID = priceID
|
|
state.StripeCustomerID = customerID
|
|
state.StripeSubscriptionID = strings.TrimSpace(sub.ID)
|
|
state.PlanVersion = derivePlanVersion(sub.Metadata, priceID)
|
|
|
|
if shouldGrantPaidCapabilities(subState) {
|
|
state.Capabilities = cloudCapabilitiesFromLicensing()
|
|
limits, _ := limitsForCloudPlanFromLicensing(state.PlanVersion)
|
|
state.Limits = limits
|
|
} else {
|
|
state.Capabilities = []string{}
|
|
state.Limits = map[string]int64{}
|
|
}
|
|
|
|
if err := h.billingStore.SaveBillingState(orgID, state); err != nil {
|
|
return fmt.Errorf("save billing state: %w", err)
|
|
}
|
|
|
|
log.Info().
|
|
Str("org_id", orgID).
|
|
Str("customer_id", customerID).
|
|
Str("subscription_id", sub.ID).
|
|
Str("subscription_state", string(subState)).
|
|
Msg("Stripe customer.subscription.updated processed")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) handleSubscriptionDeleted(ctx context.Context, sub stripeSubscription) error {
|
|
customerID := strings.TrimSpace(sub.Customer)
|
|
if customerID == "" {
|
|
return fmt.Errorf("subscription missing customer")
|
|
}
|
|
|
|
orgID, ok, err := h.index.LookupOrgID(customerID)
|
|
if err != nil {
|
|
return fmt.Errorf("lookup org by customer id: %w", err)
|
|
}
|
|
if !ok {
|
|
orgID, ok, err = h.scanOrgByStripeCustomerID(customerID)
|
|
if err != nil {
|
|
return fmt.Errorf("scan org by customer id: %w", err)
|
|
}
|
|
if ok {
|
|
if saveErr := h.index.Save(customerID, orgID); saveErr != nil {
|
|
log.Warn().
|
|
Err(saveErr).
|
|
Str("customer_id", customerID).
|
|
Str("org_id", orgID).
|
|
Msg("Stripe subscription.deleted: failed to backfill customer org index")
|
|
}
|
|
}
|
|
}
|
|
if !ok {
|
|
log.Warn().Str("customer_id", customerID).Str("subscription_id", sub.ID).Msg("Stripe subscription.deleted: org not found for customer")
|
|
return nil
|
|
}
|
|
|
|
before, err := h.billingStore.GetBillingState(orgID)
|
|
if err != nil {
|
|
return fmt.Errorf("load billing state: %w", err)
|
|
}
|
|
state := normalizeBillingState(before)
|
|
|
|
// CRITICAL: revoke paid capabilities immediately on cancellation.
|
|
state.SubscriptionState = subscriptionStateCanceledValue
|
|
state.Capabilities = []string{}
|
|
state.Limits = map[string]int64{}
|
|
state.StripeCustomerID = customerID
|
|
state.StripeSubscriptionID = strings.TrimSpace(sub.ID)
|
|
|
|
if err := h.billingStore.SaveBillingState(orgID, state); err != nil {
|
|
return fmt.Errorf("save billing state: %w", err)
|
|
}
|
|
|
|
log.Info().
|
|
Str("org_id", orgID).
|
|
Str("customer_id", customerID).
|
|
Str("subscription_id", sub.ID).
|
|
Msg("Stripe customer.subscription.deleted processed (capabilities revoked)")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *StripeWebhookHandlers) scanOrgByStripeCustomerID(customerID string) (string, bool, error) {
|
|
orgs, err := h.persistence.ListOrganizations()
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
customerID = strings.TrimSpace(customerID)
|
|
if customerID == "" {
|
|
return "", false, nil
|
|
}
|
|
|
|
for _, org := range orgs {
|
|
if org == nil || strings.TrimSpace(org.ID) == "" {
|
|
continue
|
|
}
|
|
state, loadErr := h.billingStore.GetBillingState(org.ID)
|
|
if loadErr != nil || state == nil {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(state.StripeCustomerID) == customerID {
|
|
return org.ID, true, nil
|
|
}
|
|
}
|
|
|
|
return "", false, nil
|
|
}
|
|
|
|
func firstPriceID(sub stripeSubscription) string {
|
|
for _, item := range sub.Items.Data {
|
|
if strings.TrimSpace(item.Price.ID) != "" {
|
|
return strings.TrimSpace(item.Price.ID)
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func mapStripeSubscriptionStatusToState(status string) subscriptionState {
|
|
return mapStripeSubscriptionStatusToStateFromLicensing(status)
|
|
}
|
|
|
|
func shouldGrantPaidCapabilities(state subscriptionState) bool {
|
|
return shouldGrantPaidCapabilitiesFromLicensing(state)
|
|
}
|
|
|
|
func derivePlanVersion(metadata map[string]string, priceID string) string {
|
|
return deriveStripePlanVersionFromLicensing(metadata, priceID)
|
|
}
|
|
|
|
func resolvePulseDataDir(dataPath string) string {
|
|
return config.ResolveRuntimeDataDir(dataPath)
|
|
}
|
|
|
|
// stripeWebhookDeduper provides durable idempotency for Stripe webhook event IDs.
|
|
// Stripe retries webhooks; without a persistent dedupe store, retries can provision duplicate tenants.
|
|
type stripeWebhookDeduper struct {
|
|
dir string
|
|
lockTTL time.Duration
|
|
now func() time.Time
|
|
hashSalt []byte
|
|
}
|
|
|
|
func newStripeWebhookDeduper(dir string) *stripeWebhookDeduper {
|
|
return &stripeWebhookDeduper{
|
|
dir: dir,
|
|
lockTTL: 10 * time.Minute,
|
|
now: time.Now,
|
|
// Salt prevents event IDs from being used directly as filenames if they contain odd characters.
|
|
// (Event IDs are normally safe, but this keeps the filesystem contract tight.)
|
|
hashSalt: []byte("pulse-stripe-webhook-v1"),
|
|
}
|
|
}
|
|
|
|
func (d *stripeWebhookDeduper) Do(eventID string, fn func() error) (already bool, err error) {
|
|
if d == nil {
|
|
return false, errors.New("deduper is nil")
|
|
}
|
|
if strings.TrimSpace(eventID) == "" {
|
|
return false, errors.New("event id is required")
|
|
}
|
|
if fn == nil {
|
|
return false, errors.New("handler is required")
|
|
}
|
|
|
|
donePath := d.donePath(eventID)
|
|
if fileExists(donePath) {
|
|
return true, nil
|
|
}
|
|
|
|
lockPath := d.lockPath(eventID)
|
|
acquired, lockErr := d.acquireLock(lockPath)
|
|
if lockErr != nil {
|
|
return false, fmt.Errorf("acquire dedupe lock: %w", lockErr)
|
|
}
|
|
if !acquired {
|
|
// Another in-flight processor exists. If processing has already completed, treat as a duplicate.
|
|
// If the lock exists but the done file does not, we must return a non-2xx so Stripe retries later.
|
|
// (Otherwise a concurrent Stripe retry could stop retrying while the original attempt still fails.)
|
|
if fileExists(donePath) {
|
|
return true, nil
|
|
}
|
|
return false, errStripeWebhookEventInFlight
|
|
}
|
|
|
|
defer func() {
|
|
if rmErr := os.Remove(lockPath); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
log.Warn().Err(rmErr).Str("path", lockPath).Msg("Stripe dedupe: failed to remove lock file")
|
|
}
|
|
}()
|
|
|
|
if err := fn(); err != nil {
|
|
return false, fmt.Errorf("process webhook event: %w", err)
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Dir(donePath), 0o700); err != nil {
|
|
return false, fmt.Errorf("create dedupe dir: %w", err)
|
|
}
|
|
|
|
meta := map[string]any{
|
|
"handled_at": d.now().UTC().UnixMilli(),
|
|
}
|
|
data, err := json.Marshal(meta)
|
|
if err != nil {
|
|
return false, fmt.Errorf("marshal dedupe metadata: %w", err)
|
|
}
|
|
tmp := donePath + ".tmp"
|
|
if err := os.WriteFile(tmp, data, 0o600); err != nil {
|
|
return false, fmt.Errorf("write dedupe tmp: %w", err)
|
|
}
|
|
if err := os.Rename(tmp, donePath); err != nil {
|
|
if rmErr := os.Remove(tmp); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
return false, errors.Join(
|
|
fmt.Errorf("commit dedupe: %w", err),
|
|
fmt.Errorf("remove dedupe tmp %s: %w", tmp, rmErr),
|
|
)
|
|
}
|
|
return false, fmt.Errorf("commit dedupe: %w", err)
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func (d *stripeWebhookDeduper) acquireLock(lockPath string) (bool, error) {
|
|
if err := os.MkdirAll(filepath.Dir(lockPath), 0o700); err != nil {
|
|
return false, fmt.Errorf("create lock dir: %w", err)
|
|
}
|
|
|
|
f, err := os.OpenFile(lockPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
|
|
if err == nil {
|
|
if closeErr := f.Close(); closeErr != nil {
|
|
if rmErr := os.Remove(lockPath); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
return false, errors.Join(
|
|
fmt.Errorf("close lock file: %w", closeErr),
|
|
fmt.Errorf("cleanup lock file %s: %w", lockPath, rmErr),
|
|
)
|
|
}
|
|
return false, fmt.Errorf("close lock file: %w", closeErr)
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
if !errors.Is(err, os.ErrExist) {
|
|
return false, fmt.Errorf("create lock: %w", err)
|
|
}
|
|
|
|
// Break stale locks (e.g., process crash) so Stripe retries can succeed.
|
|
info, statErr := os.Stat(lockPath)
|
|
if statErr != nil && !errors.Is(statErr, os.ErrNotExist) {
|
|
return false, fmt.Errorf("stat lock file: %w", statErr)
|
|
}
|
|
if statErr == nil && d.now().Sub(info.ModTime()) > d.lockTTL {
|
|
if rmErr := os.Remove(lockPath); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
return false, fmt.Errorf("remove stale lock: %w", rmErr)
|
|
}
|
|
f, err := os.OpenFile(lockPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
|
|
if err == nil {
|
|
if closeErr := f.Close(); closeErr != nil {
|
|
if rmErr := os.Remove(lockPath); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
return false, errors.Join(
|
|
fmt.Errorf("close recreated lock file: %w", closeErr),
|
|
fmt.Errorf("cleanup recreated lock file %s: %w", lockPath, rmErr),
|
|
)
|
|
}
|
|
return false, fmt.Errorf("close recreated lock file: %w", closeErr)
|
|
}
|
|
return true, nil
|
|
}
|
|
if errors.Is(err, os.ErrExist) {
|
|
return false, nil
|
|
}
|
|
return false, fmt.Errorf("recreate lock: %w", err)
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func (d *stripeWebhookDeduper) donePath(eventID string) string {
|
|
return filepath.Join(d.dir, d.filenameForID(eventID)+".done")
|
|
}
|
|
|
|
func (d *stripeWebhookDeduper) lockPath(eventID string) string {
|
|
return filepath.Join(d.dir, d.filenameForID(eventID)+".lock")
|
|
}
|
|
|
|
func (d *stripeWebhookDeduper) filenameForID(id string) string {
|
|
// Use a deterministic HMAC so we never trust arbitrary IDs as filesystem paths.
|
|
mac := hmac.New(sha256.New, d.hashSalt)
|
|
_, _ = mac.Write([]byte(id))
|
|
return hex.EncodeToString(mac.Sum(nil))
|
|
}
|
|
|
|
type stripeCustomerOrgIndex struct {
|
|
dir string
|
|
}
|
|
|
|
func newStripeCustomerOrgIndex(dir string) *stripeCustomerOrgIndex {
|
|
return &stripeCustomerOrgIndex{dir: dir}
|
|
}
|
|
|
|
func (i *stripeCustomerOrgIndex) LookupOrgID(customerID string) (string, bool, error) {
|
|
if i == nil {
|
|
return "", false, errors.New("index is nil")
|
|
}
|
|
customerID = strings.TrimSpace(customerID)
|
|
if customerID == "" {
|
|
return "", false, nil
|
|
}
|
|
if !isSafeStripeID(customerID) {
|
|
return "", false, fmt.Errorf("invalid stripe customer id")
|
|
}
|
|
|
|
path := filepath.Join(i.dir, customerID+".json")
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return "", false, nil
|
|
}
|
|
return "", false, fmt.Errorf("read customer org index %s: %w", path, err)
|
|
}
|
|
var rec struct {
|
|
OrgID string `json:"org_id"`
|
|
}
|
|
if err := json.Unmarshal(data, &rec); err != nil {
|
|
return "", false, fmt.Errorf("decode customer org index %s: %w", path, err)
|
|
}
|
|
orgID := strings.TrimSpace(rec.OrgID)
|
|
if orgID == "" {
|
|
return "", false, nil
|
|
}
|
|
return orgID, true, nil
|
|
}
|
|
|
|
func (i *stripeCustomerOrgIndex) Save(customerID, orgID string) error {
|
|
if i == nil {
|
|
return errors.New("index is nil")
|
|
}
|
|
customerID = strings.TrimSpace(customerID)
|
|
orgID = strings.TrimSpace(orgID)
|
|
if customerID == "" || orgID == "" {
|
|
return fmt.Errorf("customerID and orgID are required")
|
|
}
|
|
if !isSafeStripeID(customerID) {
|
|
return fmt.Errorf("invalid stripe customer id")
|
|
}
|
|
if !isValidOrganizationID(orgID) {
|
|
return fmt.Errorf("invalid org id")
|
|
}
|
|
|
|
if err := os.MkdirAll(i.dir, 0o700); err != nil {
|
|
return fmt.Errorf("create customer org index directory: %w", err)
|
|
}
|
|
|
|
path := filepath.Join(i.dir, customerID+".json")
|
|
data, err := json.Marshal(map[string]any{
|
|
"org_id": orgID,
|
|
"updated_at": time.Now().UTC().UnixMilli(),
|
|
"customer_id": customerID,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("marshal customer org index entry: %w", err)
|
|
}
|
|
tmp := path + ".tmp"
|
|
if err := os.WriteFile(tmp, data, 0o600); err != nil {
|
|
return fmt.Errorf("write customer org index temp file: %w", err)
|
|
}
|
|
if err := os.Rename(tmp, path); err != nil {
|
|
if rmErr := os.Remove(tmp); rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
|
|
return errors.Join(
|
|
fmt.Errorf("commit customer org index: %w", err),
|
|
fmt.Errorf("remove customer org index temp file: %w", rmErr),
|
|
)
|
|
}
|
|
return fmt.Errorf("commit customer org index: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isSafeStripeID(id string) bool {
|
|
// Stripe IDs are typically like "cus_...", "sub_...", "evt_...".
|
|
// Keep this strict to avoid filesystem surprises.
|
|
if len(id) < 5 || len(id) > 128 {
|
|
return false
|
|
}
|
|
for i := 0; i < len(id); i++ {
|
|
c := id[i]
|
|
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-' {
|
|
continue
|
|
}
|
|
return false
|
|
}
|
|
if filepath.Base(id) != id {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// fileExists is defined in router.go (same package). Keep a single implementation
|
|
// to avoid duplicate symbol errors across the internal/api package.
|