Pulse/internal/api/hosted_entitlement_refresh.go

468 lines
14 KiB
Go

package api
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"math/rand/v2"
"net/http"
"strings"
"sync"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rs/zerolog/log"
)
// Scheduling bounds for the hosted entitlement lease refresh loop.
const (
	// hostedEntitlementRefreshDefaultInterval is what the backoff helper returns
	// when there have been no failures.
	hostedEntitlementRefreshDefaultInterval = 2 * time.Hour
	// Refresh intervals for a healthy lease (half its remaining lifetime) are
	// clamped into [Min, Max].
	hostedEntitlementRefreshMinInterval = 15 * time.Minute
	hostedEntitlementRefreshMaxInterval = 6 * time.Hour
	// Leases expiring within this window are refreshed promptly rather than on
	// the regular schedule.
	hostedEntitlementRefreshImmediateWindow = 30 * time.Minute
)

// Retry backoff bounds applied after consecutive refresh failures.
const (
	hostedEntitlementRefreshBackoffMin = 30 * time.Second
	hostedEntitlementRefreshBackoffMax = 30 * time.Minute
)

// hostedEntitlementRefreshJitter is the fractional spread (±20%) applied to
// regular refresh intervals.
const hostedEntitlementRefreshJitter = 0.2
// hostedEntitlementRefreshLoop tracks the background refresh goroutine for a
// single org. running and cancel are read and written while holding mu.
type hostedEntitlementRefreshLoop struct {
	mu sync.Mutex
	// running reports whether a refresh goroutine is currently marked active.
	running bool
	// cancel cancels the context handed to the refresh goroutine.
	cancel context.CancelFunc
	// wg lets a caller block until the refresh goroutine has exited.
	wg sync.WaitGroup
}
// hostedEntitlementRefreshError describes a failed hosted entitlement refresh
// request. When permanent is set, callers stop retrying and clear the stored
// entitlement state.
type hostedEntitlementRefreshError struct {
	statusCode int
	message    string
	permanent  bool
}

// Error returns the explicit message when it is non-blank, otherwise a generic
// message that includes the HTTP status code. A nil receiver yields "".
func (e *hostedEntitlementRefreshError) Error() string {
	switch {
	case e == nil:
		return ""
	case strings.TrimSpace(e.message) == "":
		return fmt.Sprintf("hosted entitlement refresh failed with status %d", e.statusCode)
	default:
		return e.message
	}
}
// ensureHostedEntitlementRefreshForOrg reconciles the background lease refresh
// loop for orgID with the org's current billing state.
//
// Billing state is resolved to an effective org, which may differ from the
// requested org, and the loop is keyed by that effective org ID. Outcomes:
//   - The org's license service already has a current license, or no refresh
//     token is stored: loops under both org IDs are stopped.
//   - No loop is running and the stored lease needs an immediate refresh: one
//     synchronous refresh is attempted. A permanent failure stops loops under
//     both org IDs and returns. A transient failure is logged and the loop is
//     still started.
//   - Otherwise the loop is started for the effective org and any loop left
//     under the requested org ID (when different) is stopped.
func (h *LicenseHandlers) ensureHostedEntitlementRefreshForOrg(orgID string, service *licenseService) {
	if h == nil || h.mtPersistence == nil {
		return
	}
	requestedOrgID := normalizeHostedEntitlementOrgID(orgID)
	effectiveOrgID, state, err := h.loadHostedEntitlementRefreshTarget(requestedOrgID)
	if err != nil {
		log.Warn().Err(err).Str("org_id", requestedOrgID).Msg("Failed to load billing state for hosted entitlement refresh")
		return
	}

	// stopRequestedAlias stops any loop still registered under the requested
	// org ID when billing state resolved to a different (effective) org.
	stopRequestedAlias := func() {
		if effectiveOrgID != requestedOrgID {
			h.stopHostedEntitlementRefreshLoop(requestedOrgID)
		}
	}
	// stopAll stops loops registered under both the effective and requested IDs.
	stopAll := func() {
		h.stopHostedEntitlementRefreshLoop(effectiveOrgID)
		stopRequestedAlias()
	}

	// A current license supersedes the hosted lease, and without a refresh token
	// there is nothing to refresh.
	if (service != nil && service.Current() != nil) || !hasHostedEntitlementRefreshToken(state) {
		stopAll()
		return
	}

	loop := h.hostedEntitlementRefreshLoop(effectiveOrgID)
	if !loop.isRunning() && hostedEntitlementNeedsImmediateRefresh(state) {
		if _, permanent, err := h.refreshHostedEntitlementLeaseOnce(effectiveOrgID, service); err != nil {
			if permanent {
				log.Warn().Err(err).Str("org_id", effectiveOrgID).Msg("Permanent hosted entitlement refresh failure during initialization")
				// Also stop any loop under the requested org ID, matching the
				// other exit paths; otherwise it would keep running.
				stopAll()
				return
			}
			// Transient failure: still start the loop so it can try again.
			log.Warn().Err(err).Str("org_id", effectiveOrgID).Msg("Hosted entitlement refresh initialization failed")
		}
	}
	h.startHostedEntitlementRefreshLoop(effectiveOrgID)
	stopRequestedAlias()
}
// hasHostedEntitlementRefreshToken reports whether state is non-nil and carries
// a refresh token that is not blank after trimming whitespace.
func hasHostedEntitlementRefreshToken(state *billingState) bool {
	if state == nil {
		return false
	}
	return strings.TrimSpace(state.EntitlementRefreshToken) != ""
}
// loadHostedEntitlementRefreshTarget resolves the org whose billing state applies
// to orgID via config.LoadEffectiveEntitlementBillingState.
//
// It returns the normalized effective org ID and the normalized billing state.
// The state is nil when the loader returned none. On a load error it returns an
// empty org ID and the error. When persistence is unavailable it returns the
// normalized requested org ID with a nil state and no error.
func (h *LicenseHandlers) loadHostedEntitlementRefreshTarget(orgID string) (string, *billingState, error) {
	requested := normalizeHostedEntitlementOrgID(orgID)
	if h == nil || h.mtPersistence == nil {
		return requested, nil, nil
	}
	raw, resolvedOrgID, err := config.LoadEffectiveEntitlementBillingState(h.mtPersistence.BaseDataDir(), requested)
	if err != nil {
		return "", nil, err
	}
	resolvedOrgID = normalizeHostedEntitlementOrgID(resolvedOrgID)
	if raw == nil {
		return resolvedOrgID, nil, nil
	}
	return resolvedOrgID, normalizeBillingStateFromLicensing(raw), nil
}
// hostedEntitlementNeedsImmediateRefresh reports whether a refresh should run
// right away. The state must hold a refresh token. In addition, the stored lease
// must be unparsable, lack an expiry, or expire within
// hostedEntitlementRefreshImmediateWindow.
func hostedEntitlementNeedsImmediateRefresh(state *billingState) bool {
	if !hasHostedEntitlementRefreshToken(state) {
		return false
	}
	claims, err := hostedEntitlementLeaseClaimsFromState(state)
	switch {
	case err != nil, claims == nil, claims.ExpiresAt == nil:
		// Without a readable expiry, refresh defensively.
		return true
	default:
		return time.Until(claims.ExpiresAt.Time) <= hostedEntitlementRefreshImmediateWindow
	}
}
// hostedEntitlementLeaseClaimsFromState parses the lease JWT stored in state
// using the trial activation public key. It passes an empty third argument to
// the parser (presumably the expected instance host; confirm). It returns an
// error when state is nil or holds no lease token.
func hostedEntitlementLeaseClaimsFromState(state *billingState) (*entitlementLeaseClaimsModel, error) {
	var token string
	if state != nil {
		token = state.EntitlementJWT
	}
	if strings.TrimSpace(token) == "" {
		return nil, fmt.Errorf("entitlement lease token is required")
	}
	key, err := trialActivationPublicKeyFromLicensing()
	if err != nil {
		return nil, err
	}
	return parseEntitlementLeaseTokenFromLicensing(token, key, "")
}
// hostedEntitlementInstanceHost picks the instance host to send with a refresh.
// The host derived from h.cfg wins. Otherwise it uses the normalized host
// recorded in the stored lease claims. It returns "" when neither is available.
func (h *LicenseHandlers) hostedEntitlementInstanceHost(state *billingState) string {
	configured := entitlementExpectedInstanceHost(h.cfg)
	if configured != "" {
		return configured
	}
	if claims, err := hostedEntitlementLeaseClaimsFromState(state); err == nil && claims != nil {
		return normalizeHostForTrial(claims.InstanceHost)
	}
	return ""
}
// hostedEntitlementRefreshLoop returns the loop record registered for orgID.
// If none is registered, it creates and registers an empty one. It returns nil
// on a nil receiver.
func (h *LicenseHandlers) hostedEntitlementRefreshLoop(orgID string) *hostedEntitlementRefreshLoop {
	if h == nil {
		return nil
	}
	existing, found := h.hostedLeaseRefresh.Load(orgID)
	if found {
		if loop, isLoop := existing.(*hostedEntitlementRefreshLoop); isLoop {
			return loop
		}
	}
	fresh := &hostedEntitlementRefreshLoop{}
	// LoadOrStore keeps whichever record won a concurrent registration.
	stored, _ := h.hostedLeaseRefresh.LoadOrStore(orgID, fresh)
	if loop, isLoop := stored.(*hostedEntitlementRefreshLoop); isLoop {
		return loop
	}
	return fresh
}
// isRunning reports, while holding the loop's mutex, whether a refresh goroutine
// is marked active. A nil loop is never running.
func (l *hostedEntitlementRefreshLoop) isRunning() bool {
	if l == nil {
		return false
	}
	l.mu.Lock()
	running := l.running
	l.mu.Unlock()
	return running
}
// startHostedEntitlementRefreshLoop launches the background refresh goroutine
// for orgID unless one is already marked running.
//
// Each hostedEntitlementRefreshLoop record drives at most one goroutine. A record
// whose goroutine was launched (cancel != nil) but is no longer running is
// swapped for a fresh record instead of being reused. That state means the
// goroutine was stopped or exited on its own. Reusing the record would allow
// three failures:
//   - wg.Add could run while stopHostedEntitlementRefreshLoop is still blocked
//     in wg.Wait;
//   - the old goroutine's exit handler could clear running for the new one;
//   - stop's final unregister could orphan the new goroutine.
func (h *LicenseHandlers) startHostedEntitlementRefreshLoop(orgID string) {
	for {
		loop := h.hostedEntitlementRefreshLoop(orgID)
		if loop == nil {
			return
		}
		loop.mu.Lock()
		if loop.running {
			loop.mu.Unlock()
			return
		}
		if loop.cancel != nil {
			// Spent record: replace it only if it is still the registered one,
			// then retry with whatever record is registered now.
			loop.mu.Unlock()
			h.hostedLeaseRefresh.CompareAndSwap(orgID, loop, &hostedEntitlementRefreshLoop{})
			continue
		}
		ctx, cancel := context.WithCancel(context.Background())
		loop.cancel = cancel
		loop.running = true
		loop.wg.Add(1)
		loop.mu.Unlock()
		go func() {
			defer func() {
				// Release the context even when the loop exits on its own.
				cancel()
				loop.mu.Lock()
				loop.running = false
				loop.mu.Unlock()
				loop.wg.Done()
			}()
			h.runHostedEntitlementRefreshLoop(ctx, orgID)
		}()
		return
	}
}

// stopHostedEntitlementRefreshLoop cancels the refresh goroutine registered for
// orgID, if one is running. It waits for that goroutine to exit and then
// unregisters the record.
//
// Only the record observed here is removed (CompareAndDelete). A fresh record
// registered concurrently by startHostedEntitlementRefreshLoop stays in place.
// Records are never restarted, so no wg.Add can race with the Wait below.
func (h *LicenseHandlers) stopHostedEntitlementRefreshLoop(orgID string) {
	if h == nil {
		return
	}
	value, ok := h.hostedLeaseRefresh.Load(orgID)
	if !ok {
		return
	}
	loop, ok := value.(*hostedEntitlementRefreshLoop)
	if !ok || loop == nil {
		h.hostedLeaseRefresh.Delete(orgID)
		return
	}
	loop.mu.Lock()
	wasRunning := loop.running
	cancel := loop.cancel
	// Mark the record spent right away, so a concurrent start replaces it
	// instead of waiting on a goroutine that is shutting down.
	loop.running = false
	loop.mu.Unlock()
	if wasRunning && cancel != nil {
		cancel()
		loop.wg.Wait()
	}
	h.hostedLeaseRefresh.CompareAndDelete(orgID, loop)
}
// runHostedEntitlementRefreshLoop is the body of the per-org refresh goroutine.
//
// Each iteration waits for the delay returned by
// nextHostedEntitlementRefreshInterval, then performs one refresh. After a
// failure that delay is a backoff delay. The loop returns when:
//   - ctx is cancelled;
//   - no interval can be computed;
//   - a refresh fails permanently;
//   - a refresh succeeds without refreshing anything.
func (h *LicenseHandlers) runHostedEntitlementRefreshLoop(ctx context.Context, orgID string) {
	consecutiveFailures := 0
	for {
		interval, ok := h.nextHostedEntitlementRefreshInterval(orgID, consecutiveFailures)
		if !ok {
			return
		}
		// An explicit timer is stopped as soon as the loop is cancelled. With
		// time.After, Go versions before 1.23 would keep an up-to-hours timer
		// alive after the loop is gone.
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
		// select picks randomly when both cases are ready; don't refresh after
		// cancellation.
		if ctx.Err() != nil {
			return
		}
		refreshed, permanent, err := h.refreshHostedEntitlementLeaseOnce(orgID, nil)
		if err != nil {
			if permanent {
				log.Warn().Err(err).Str("org_id", orgID).Msg("Stopping hosted entitlement refresh loop after permanent failure")
				return
			}
			consecutiveFailures++
			log.Warn().
				Err(err).
				Str("org_id", orgID).
				Int("consecutive_failures", consecutiveFailures).
				Dur("next_retry", hostedEntitlementRefreshBackoff(consecutiveFailures)).
				Msg("Hosted entitlement refresh failed")
			continue
		}
		if !refreshed {
			// Nothing to refresh (e.g. a license took over or no token is stored).
			return
		}
		consecutiveFailures = 0
	}
}
// nextHostedEntitlementRefreshInterval returns how long the refresh loop for
// orgID should wait before its next attempt. The boolean result is false when
// the loop should stop.
//
//   - After failures: hostedEntitlementRefreshBackoff(consecutiveFailures).
//   - No persistence, unreadable billing state, or no refresh token: stop.
//   - Lease unparsable, without expiry, or expiring within
//     hostedEntitlementRefreshImmediateWindow: one minute.
//   - Otherwise: half the remaining lease lifetime, clamped to
//     [hostedEntitlementRefreshMinInterval, hostedEntitlementRefreshMaxInterval],
//     then jittered by withHostedEntitlementRefreshJitter.
func (h *LicenseHandlers) nextHostedEntitlementRefreshInterval(orgID string, consecutiveFailures int) (time.Duration, bool) {
	if consecutiveFailures > 0 {
		return hostedEntitlementRefreshBackoff(consecutiveFailures), true
	}
	if h == nil || h.mtPersistence == nil {
		return 0, false
	}
	billingStore := config.NewFileBillingStore(h.mtPersistence.BaseDataDir())
	state, err := billingStore.GetBillingState(orgID)
	if err != nil {
		// Returning false ends the refresh loop. Log the read failure, otherwise
		// lease renewal would stop without any trace.
		log.Warn().Err(err).Str("org_id", orgID).Msg("Failed to load billing state for hosted entitlement refresh scheduling; stopping refresh loop")
		return 0, false
	}
	if !hasHostedEntitlementRefreshToken(state) {
		return 0, false
	}
	claims, err := hostedEntitlementLeaseClaimsFromState(state)
	if err != nil || claims == nil || claims.ExpiresAt == nil {
		return time.Minute, true
	}
	remaining := time.Until(claims.ExpiresAt.Time)
	if remaining <= hostedEntitlementRefreshImmediateWindow {
		return time.Minute, true
	}
	interval := min(max(remaining/2, hostedEntitlementRefreshMinInterval), hostedEntitlementRefreshMaxInterval)
	return withHostedEntitlementRefreshJitter(interval), true
}
// hostedEntitlementRefreshBackoff returns the retry delay after the given number
// of consecutive failures. The delay is hostedEntitlementRefreshBackoffMin,
// doubled for each failure beyond the first, capped at
// hostedEntitlementRefreshBackoffMax. The doubling exponent is limited to 10.
// Non-positive counts return hostedEntitlementRefreshDefaultInterval.
func hostedEntitlementRefreshBackoff(consecutiveFailures int) time.Duration {
	if consecutiveFailures < 1 {
		return hostedEntitlementRefreshDefaultInterval
	}
	shift := min(consecutiveFailures-1, 10)
	delay := hostedEntitlementRefreshBackoffMin << shift
	return min(delay, hostedEntitlementRefreshBackoffMax)
}
// withHostedEntitlementRefreshJitter shifts interval by a uniformly random
// offset in [-j, +j), where j = interval * hostedEntitlementRefreshJitter.
func withHostedEntitlementRefreshJitter(interval time.Duration) time.Duration {
	spread := float64(interval) * hostedEntitlementRefreshJitter
	factor := 2*rand.Float64() - 1 // in [-1, 1)
	return interval + time.Duration(factor*spread)
}
// refreshHostedEntitlementLeaseOnce performs one hosted entitlement lease
// refresh for orgID and persists the result.
//
// Returns (refreshed, permanent, err):
//   - (false, true, nil) when there is nothing to do: persistence is
//     unavailable, the org's license service already has a current license, or
//     no refresh token is stored. The refresh loop stops on !refreshed.
//   - (false, true, err) when no instance host can be determined, or when the
//     refresh request fails with a permanent *hostedEntitlementRefreshError. In
//     the latter case the stored hosted entitlement state is cleared and, if
//     the service has no current license, its evaluator is reset.
//   - (false, false, err) for failures the loop retries with backoff: state
//     load, other request errors, key load, lease verification, org mismatch,
//     or saving the new state.
//   - (true, false, nil) after a refreshed lease has been verified and saved.
//
// When service is nil, the service registered in h.services under orgID is
// used if present.
func (h *LicenseHandlers) refreshHostedEntitlementLeaseOnce(orgID string, service *licenseService) (bool, bool, error) {
	if h == nil || h.mtPersistence == nil {
		return false, true, nil
	}
	// The refresh loop passes a nil service; look up the registered one.
	if service == nil {
		if value, ok := h.services.Load(orgID); ok {
			if typed, ok := value.(*licenseService); ok {
				service = typed
			}
		}
	}
	// A current license takes precedence over the hosted lease.
	if service != nil && service.Current() != nil {
		return false, true, nil
	}
	billingStore := config.NewFileBillingStore(h.mtPersistence.BaseDataDir())
	effectiveOrgID, state, err := h.loadHostedEntitlementRefreshTarget(orgID)
	if err != nil {
		return false, false, fmt.Errorf("load hosted entitlement refresh target: %w", err)
	}
	if !hasHostedEntitlementRefreshToken(state) {
		return false, true, nil
	}
	instanceHost := h.hostedEntitlementInstanceHost(state)
	if instanceHost == "" {
		return false, true, fmt.Errorf("hosted entitlement instance host is unavailable")
	}
	response, err := h.requestHostedEntitlementLeaseRefresh(effectiveOrgID, instanceHost, state.EntitlementRefreshToken)
	if err != nil {
		var refreshErr *hostedEntitlementRefreshError
		// Permanent failures (as classified by the request helper) wipe the
		// stored lease and refresh token so they are not retried.
		if errors.As(err, &refreshErr) && refreshErr != nil && refreshErr.permanent {
			if clearErr := h.clearHostedEntitlementState(effectiveOrgID, billingStore); clearErr != nil {
				log.Warn().Err(clearErr).Str("org_id", effectiveOrgID).Msg("Failed to clear hosted entitlement state after permanent refresh failure")
			}
			// Drop the lease-backed evaluator and ask ensureEvaluatorForOrg to
			// install a replacement; its result is discarded.
			if service != nil && service.Current() == nil {
				service.SetEvaluator(nil)
				_ = h.ensureEvaluatorForOrg(effectiveOrgID, service)
			}
			return false, true, err
		}
		// Any other request error is treated as retryable.
		return false, false, err
	}
	publicKey, err := trialActivationPublicKeyFromLicensing()
	if err != nil {
		return false, false, fmt.Errorf("load entitlement verification key: %w", err)
	}
	// Verify the new token against the instance host and the current UTC time.
	leaseClaims, err := verifyEntitlementLeaseTokenFromLicensing(response.EntitlementJWT, publicKey, instanceHost, time.Now().UTC())
	if err != nil {
		return false, false, fmt.Errorf("verify refreshed entitlement lease: %w", err)
	}
	// Reject a lease issued for a different org than the one being refreshed.
	if normalizeHostedEntitlementOrgID(leaseClaims.OrgID) != normalizeHostedEntitlementOrgID(effectiveOrgID) {
		return false, false, fmt.Errorf("refreshed entitlement lease org mismatch")
	}
	// Start from the current state, store the new lease JWT, and reset the
	// cached plan fields and trial end markers.
	// NOTE(review): presumably these are re-derived from the lease JWT by the
	// evaluator; confirm.
	updated := normalizeBillingStateFromLicensing(state)
	updated.EntitlementJWT = strings.TrimSpace(response.EntitlementJWT)
	updated.Capabilities = []string{}
	updated.Limits = map[string]int64{}
	updated.MetersEnabled = []string{}
	updated.PlanVersion = ""
	updated.SubscriptionState = ""
	updated.TrialStartedAt = leaseClaims.TrialStartedAt
	updated.TrialEndsAt = nil
	updated.TrialExtendedAt = nil
	updated.GrantQuickstartCredits()
	if err := billingStore.SaveBillingState(effectiveOrgID, updated); err != nil {
		return false, false, fmt.Errorf("save refreshed entitlement lease: %w", err)
	}
	// Point the service at an evaluator backed by the freshly saved state.
	if service != nil && service.Current() == nil {
		service.SetEvaluator(newLicenseEvaluatorForBillingStoreFromLicensing(billingStore, effectiveOrgID, 0, instanceHost))
	}
	return true, false, nil
}
// clearHostedEntitlementState removes the hosted lease JWT and refresh token
// stored for orgID and marks the subscription as expired. It also resets plan
// data and trial end markers, then saves the state. It is a no-op when
// billingStore is nil or no billing state exists for orgID.
func (h *LicenseHandlers) clearHostedEntitlementState(orgID string, billingStore *config.FileBillingStore) error {
	if billingStore == nil {
		return nil
	}
	state, err := billingStore.GetBillingState(orgID)
	if err != nil || state == nil {
		return err
	}
	// Drop the lease and the credential used to renew it.
	state.EntitlementJWT = ""
	state.EntitlementRefreshToken = ""
	// Reset plan data and mark the subscription expired.
	state.Capabilities = []string{}
	state.Limits = map[string]int64{}
	state.MetersEnabled = []string{}
	state.SubscriptionState = subscriptionStateExpiredValue
	state.PlanVersion = string(subscriptionStateExpiredValue)
	state.TrialEndsAt = nil
	state.TrialExtendedAt = nil
	return billingStore.SaveBillingState(orgID, state)
}
// requestHostedEntitlementLeaseRefresh POSTs a JSON lease refresh request to the
// configured hosted entitlement refresh URL and returns the decoded response.
//
// Errors:
//   - A permanent *hostedEntitlementRefreshError when the handlers or the
//     refresh URL are unavailable, or when the endpoint answers 400, 401, 404
//     or 410.
//   - A non-permanent *hostedEntitlementRefreshError for other non-2xx
//     statuses.
//   - Plain errors for encoding, transport and decoding failures, and for a
//     response without an entitlement JWT.
//
// The HTTP client uses a 5-second timeout.
func (h *LicenseHandlers) requestHostedEntitlementLeaseRefresh(orgID, instanceHost, refreshToken string) (*hostedTrialLeaseRefreshResponse, error) {
	if h == nil {
		return nil, &hostedEntitlementRefreshError{permanent: true, message: "license handlers unavailable"}
	}
	endpoint := hostedEntitlementRefreshURLFromConfig(h.cfg)
	if endpoint == "" {
		return nil, &hostedEntitlementRefreshError{permanent: true, message: "hosted entitlement refresh URL unavailable"}
	}

	body := hostedTrialLeaseRefreshRequest{
		OrgID:                   normalizeHostedEntitlementOrgID(orgID),
		InstanceHost:            strings.TrimSpace(instanceHost),
		EntitlementRefreshToken: strings.TrimSpace(refreshToken),
	}
	encoded, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("marshal hosted entitlement refresh request: %w", err)
	}

	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("build hosted entitlement refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("post hosted entitlement refresh request: %w", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code > 299 {
		// These statuses mean retrying with the same token cannot succeed.
		var permanent bool
		switch code {
		case http.StatusBadRequest, http.StatusUnauthorized, http.StatusNotFound, http.StatusGone:
			permanent = true
		}
		return nil, &hostedEntitlementRefreshError{
			statusCode: code,
			message:    fmt.Sprintf("hosted entitlement refresh returned status %d", code),
			permanent:  permanent,
		}
	}

	out := new(hostedTrialLeaseRefreshResponse)
	if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
		return nil, fmt.Errorf("decode hosted entitlement refresh response: %w", err)
	}
	if strings.TrimSpace(out.EntitlementJWT) == "" {
		return nil, fmt.Errorf("hosted entitlement refresh response missing entitlement_jwt")
	}
	return out, nil
}
// normalizeHostedEntitlementOrgID trims surrounding whitespace from raw. A
// result that is empty after trimming becomes "default".
func normalizeHostedEntitlementOrgID(raw string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return "default"
}