Pulse/internal/config/billing_state.go
2026-03-18 16:06:30 +00:00

436 lines
16 KiB
Go

package config
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/crypto"
pkglicensing "github.com/rcourtman/pulse-go-rewrite/pkg/licensing"
"github.com/rs/zerolog/log"
)
// Ensure FileBillingStore satisfies the hosted entitlement BillingStore interface.
var _ pkglicensing.BillingStore = (*FileBillingStore)(nil)
// FileBillingStore persists billing state in per-org files under the data directory.
type FileBillingStore struct {
baseDataDir string
mu sync.RWMutex
}
// NewFileBillingStore creates a file-backed billing store rooted at baseDataDir.
func NewFileBillingStore(baseDataDir string) *FileBillingStore {
return &FileBillingStore{baseDataDir: baseDataDir}
}
// GetBillingState returns the current billing state for an org.
// Missing billing files are treated as "no state yet" and return (nil, nil).
// If the state has been tampered with (invalid HMAC), it is treated as nonexistent.
func (s *FileBillingStore) GetBillingState(orgID string) (*pkglicensing.BillingState, error) {
billingPath, err := s.billingStatePath(orgID)
if err != nil {
return nil, err
}
// Read file under read lock, then release before potential migration write.
s.mu.RLock()
data, err := os.ReadFile(billingPath)
s.mu.RUnlock()
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, nil
}
return nil, fmt.Errorf("read billing state for org %q: %w", orgID, err)
}
if len(data) == 0 {
return nil, nil
}
var state pkglicensing.BillingState
if err := json.Unmarshal(data, &state); err != nil {
return nil, fmt.Errorf("decode billing state for org %q: %w", orgID, err)
}
cryptoMgr, cryptoErr := s.billingCryptoManagerForSecrets(state.EntitlementJWT, state.EntitlementRefreshToken)
needsPersist := false
decodeSecret := func(fieldName, value string) string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return ""
}
if cryptoErr != nil {
log.Warn().
Err(cryptoErr).
Str("org_id", orgID).
Str("field", fieldName).
Msg("Dropping persisted billing secret because billing crypto is unavailable")
needsPersist = true
return ""
}
if decrypted, err := cryptoMgr.DecryptString(trimmed); err == nil {
return strings.TrimSpace(decrypted)
}
// Legacy billing.json files persisted hosted entitlement secrets in plaintext.
// Keep the plaintext value in memory for integrity verification, then rewrite
// the file through the canonical encrypted-at-rest path below.
needsPersist = true
return trimmed
}
state.EntitlementJWT = decodeSecret("entitlement_jwt", state.EntitlementJWT)
state.EntitlementRefreshToken = decodeSecret("entitlement_refresh_token", state.EntitlementRefreshToken)
// Integrity verification: derive HMAC key from .encryption.key.
// If the key is unavailable (new install, key not yet created), skip checks.
hmacKey, keyErr := s.loadHMACKey()
if keyErr == nil {
if state.Integrity == "" {
// Migration: pre-upgrade state without integrity. Compute and persist.
state.Integrity = billingIntegrity(&state, hmacKey)
needsPersist = true
} else if !verifyBillingIntegrity(&state, hmacKey) {
// Try migration chain: v6-pre-quickstart → pre-v6-legacy → tampered.
if verifyBillingIntegrityV6PreQuickstart(&state, hmacKey) {
// Valid v6 signature (before quickstart fields) — re-sign with current format.
state.CommercialMigration = nil
state.Integrity = billingIntegrity(&state, hmacKey)
needsPersist = true
} else if verifyBillingIntegrityLegacy(&state, hmacKey) {
// Valid legacy signature — re-sign with current format and persist.
// Strip fields that didn't exist in the legacy era to prevent
// injection via tampered JSON before the HMAC covered them.
state.OverflowGrantedAt = nil
state.CommercialMigration = nil
state.QuickstartCreditsGranted = false
state.QuickstartCreditsUsed = 0
state.QuickstartCreditsGrantedAt = nil
state.Integrity = billingIntegrity(&state, hmacKey)
needsPersist = true
} else {
// Tampered state — treat as nonexistent (free tier).
return nil, nil
}
}
}
if needsPersist {
if saveErr := s.SaveBillingState(orgID, &state); saveErr != nil {
log.Warn().
Err(saveErr).
Str("org_id", orgID).
Msg("Failed to persist canonical billing-state migration")
}
}
if strings.TrimSpace(state.EntitlementJWT) != "" {
resolved := pkglicensing.ResolveEntitlementLeaseBillingState(state, "", time.Now().UTC())
return pkglicensing.NormalizeBillingState(&resolved), nil
}
return pkglicensing.NormalizeBillingState(&state), nil
}
// SaveBillingState persists billing state for an org to billing.json.
func (s *FileBillingStore) SaveBillingState(orgID string, state *pkglicensing.BillingState) error {
if state == nil {
return errors.New("billing state is required")
}
normalized := pkglicensing.NormalizeBillingState(state)
*state = *normalized
persistedState := *normalized
integrityState := *normalized
cryptoMgr, cryptoErr := s.billingCryptoManagerForSecrets(normalized.EntitlementJWT, normalized.EntitlementRefreshToken)
encodeSecret := func(fieldName, value string) string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return ""
}
if cryptoErr != nil {
log.Warn().
Err(cryptoErr).
Str("org_id", orgID).
Str("field", fieldName).
Msg("Dropping billing secret because billing crypto is unavailable")
switch fieldName {
case "entitlement_jwt":
integrityState.EntitlementJWT = ""
case "entitlement_refresh_token":
integrityState.EntitlementRefreshToken = ""
}
return ""
}
encrypted, err := cryptoMgr.EncryptString(trimmed)
if err != nil {
log.Warn().
Err(err).
Str("org_id", orgID).
Str("field", fieldName).
Msg("Dropping billing secret because billing encryption failed")
switch fieldName {
case "entitlement_jwt":
integrityState.EntitlementJWT = ""
case "entitlement_refresh_token":
integrityState.EntitlementRefreshToken = ""
}
return ""
}
return encrypted
}
persistedState.EntitlementJWT = encodeSecret("entitlement_jwt", normalized.EntitlementJWT)
persistedState.EntitlementRefreshToken = encodeSecret("entitlement_refresh_token", normalized.EntitlementRefreshToken)
// Compute integrity HMAC if encryption key is available.
if hmacKey, err := s.loadHMACKey(); err == nil {
integrity := billingIntegrity(&integrityState, hmacKey)
state.Integrity = integrity
persistedState.Integrity = integrity
}
billingPath, err := s.billingStatePath(orgID)
if err != nil {
return fmt.Errorf("resolve billing state path for org %q: %w", orgID, err)
}
data, err := json.Marshal(&persistedState)
if err != nil {
return fmt.Errorf("encode billing state for org %q: %w", orgID, err)
}
s.mu.Lock()
defer s.mu.Unlock()
if err := os.MkdirAll(filepath.Dir(billingPath), 0o700); err != nil {
return fmt.Errorf("create billing directory for org %q: %w", orgID, err)
}
tmpPath := billingPath + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
return fmt.Errorf("write temp billing state for org %q: %w", orgID, err)
}
if err := os.Rename(tmpPath, billingPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) {
log.Warn().
Err(removeErr).
Str("tmp_path", tmpPath).
Str("org_id", orgID).
Msg("Failed to remove temporary billing state file after failed rename")
}
return fmt.Errorf("commit billing state for org %q: %w", orgID, err)
}
return nil
}
func (s *FileBillingStore) billingCryptoManager() (*crypto.CryptoManager, error) {
return crypto.NewCryptoManagerAt(s.resolveDataDir())
}
func (s *FileBillingStore) billingCryptoManagerForSecrets(values ...string) (*crypto.CryptoManager, error) {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return s.billingCryptoManager()
}
}
return nil, errors.New("billing crypto not required")
}
func (s *FileBillingStore) billingStatePath(orgID string) (string, error) {
orgID = strings.TrimSpace(orgID)
if !isValidOrgID(orgID) {
return "", fmt.Errorf("invalid organization ID: %s", orgID)
}
// Default org stores config at the root data dir for backward compatibility,
// so billing state for the default org must live alongside other root configs.
if orgID == "default" {
return filepath.Join(s.resolveDataDir(), "billing.json"), nil
}
return filepath.Join(s.resolveDataDir(), "orgs", orgID, "billing.json"), nil
}
func (s *FileBillingStore) resolveDataDir() string {
return ResolveRuntimeDataDir(s.baseDataDir)
}
// loadHMACKey derives a purpose-specific HMAC key from the .encryption.key file.
// Returns an error if the key file is missing or invalid (graceful degradation).
func (s *FileBillingStore) loadHMACKey() ([]byte, error) {
keyPath := filepath.Join(s.resolveDataDir(), ".encryption.key")
raw, err := os.ReadFile(keyPath)
if err != nil {
return nil, err
}
trimmed := strings.TrimSpace(string(raw))
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(trimmed)))
n, err := base64.StdEncoding.Decode(decoded, []byte(trimmed))
if err != nil || n != 32 {
return nil, fmt.Errorf("invalid encryption key")
}
// Domain-separated key: SHA256("pulse-billing-integrity-" || raw_key)
h := sha256.New()
h.Write([]byte("pulse-billing-integrity-"))
h.Write(decoded[:n])
return h.Sum(nil), nil
}
// billingIntegrityPayload contains only the entitlement-critical fields used
// for HMAC computation. Non-critical metadata (e.g. Stripe IDs) is excluded so
// adding informational fields to BillingState won't break existing signatures.
//
// IMPORTANT: When adding a new field to BillingState that gates entitlements or
// affects billing logic, add it here too. Existing on-disk signatures will
// auto-migrate on next read (see GetBillingState migration path).
type billingIntegrityPayload struct {
Capabilities []string `json:"capabilities"`
Limits map[string]int64 `json:"limits"`
EntitlementJWT string `json:"entitlement_jwt"`
EntitlementRefreshToken string `json:"entitlement_refresh_token"`
CommercialMigration *pkglicensing.CommercialMigrationStatus `json:"commercial_migration,omitempty"`
PlanVersion string `json:"plan_version"`
SubscriptionState pkglicensing.SubscriptionState `json:"subscription_state"`
TrialStartedAt *int64 `json:"trial_started_at"`
TrialEndsAt *int64 `json:"trial_ends_at"`
TrialExtendedAt *int64 `json:"trial_extended_at"`
OverflowGrantedAt *int64 `json:"overflow_granted_at"`
// Quickstart credits gate free hosted Patrol runs — must be HMAC-protected.
QuickstartCreditsGranted bool `json:"quickstart_credits_granted"`
QuickstartCreditsUsed int `json:"quickstart_credits_used"`
QuickstartCreditsGrantedAt *int64 `json:"quickstart_credits_granted_at"`
}
// billingIntegrity computes the HMAC-SHA256 over the critical billing fields.
func billingIntegrity(state *pkglicensing.BillingState, key []byte) string {
caps := make([]string, len(state.Capabilities))
copy(caps, state.Capabilities)
sort.Strings(caps)
// Clone and canonicalize limits: nil → empty map for deterministic JSON.
// Snapshot avoids aliasing; callers must not mutate state concurrently.
limits := make(map[string]int64, len(state.Limits))
for k, v := range state.Limits {
limits[k] = v
}
payload := billingIntegrityPayload{
Capabilities: caps,
Limits: limits,
EntitlementJWT: strings.TrimSpace(state.EntitlementJWT),
EntitlementRefreshToken: strings.TrimSpace(state.EntitlementRefreshToken),
CommercialMigration: pkglicensing.CloneCommercialMigrationStatus(state.CommercialMigration),
PlanVersion: state.PlanVersion,
SubscriptionState: state.SubscriptionState,
TrialStartedAt: state.TrialStartedAt,
TrialEndsAt: state.TrialEndsAt,
TrialExtendedAt: state.TrialExtendedAt,
OverflowGrantedAt: state.OverflowGrantedAt,
QuickstartCreditsGranted: state.QuickstartCreditsGranted,
QuickstartCreditsUsed: state.QuickstartCreditsUsed,
QuickstartCreditsGrantedAt: state.QuickstartCreditsGrantedAt,
}
data, _ := json.Marshal(payload) // struct marshal cannot fail
mac := hmac.New(sha256.New, key)
mac.Write(data)
return hex.EncodeToString(mac.Sum(nil))
}
// verifyBillingIntegrity checks whether the stored HMAC matches the computed one.
func verifyBillingIntegrity(state *pkglicensing.BillingState, key []byte) bool {
expected := billingIntegrity(state, key)
return hmac.Equal([]byte(expected), []byte(state.Integrity))
}
// billingIntegrityPayloadLegacy is the pre-v6 HMAC payload format (without Limits).
// Kept only for migration verification — new signatures always use billingIntegrityPayload.
type billingIntegrityPayloadLegacy struct {
Capabilities []string `json:"capabilities"`
PlanVersion string `json:"plan_version"`
SubscriptionState pkglicensing.SubscriptionState `json:"subscription_state"`
TrialStartedAt *int64 `json:"trial_started_at"`
TrialEndsAt *int64 `json:"trial_ends_at"`
TrialExtendedAt *int64 `json:"trial_extended_at"`
}
// billingIntegrityLegacy computes the legacy HMAC (without Limits) for migration checks.
func billingIntegrityLegacy(state *pkglicensing.BillingState, key []byte) string {
caps := make([]string, len(state.Capabilities))
copy(caps, state.Capabilities)
sort.Strings(caps)
payload := billingIntegrityPayloadLegacy{
Capabilities: caps,
PlanVersion: state.PlanVersion,
SubscriptionState: state.SubscriptionState,
TrialStartedAt: state.TrialStartedAt,
TrialEndsAt: state.TrialEndsAt,
TrialExtendedAt: state.TrialExtendedAt,
}
data, _ := json.Marshal(payload)
mac := hmac.New(sha256.New, key)
mac.Write(data)
return hex.EncodeToString(mac.Sum(nil))
}
// verifyBillingIntegrityLegacy checks if a stored HMAC was signed with the legacy format.
func verifyBillingIntegrityLegacy(state *pkglicensing.BillingState, key []byte) bool {
expected := billingIntegrityLegacy(state, key)
return hmac.Equal([]byte(expected), []byte(state.Integrity))
}
// billingIntegrityPayloadV6PreQuickstart is the v6 HMAC format before quickstart credits.
// Has Limits and OverflowGrantedAt but not the quickstart fields.
type billingIntegrityPayloadV6PreQuickstart struct {
Capabilities []string `json:"capabilities"`
Limits map[string]int64 `json:"limits"`
PlanVersion string `json:"plan_version"`
SubscriptionState pkglicensing.SubscriptionState `json:"subscription_state"`
TrialStartedAt *int64 `json:"trial_started_at"`
TrialEndsAt *int64 `json:"trial_ends_at"`
TrialExtendedAt *int64 `json:"trial_extended_at"`
OverflowGrantedAt *int64 `json:"overflow_granted_at"`
}
func billingIntegrityV6PreQuickstart(state *pkglicensing.BillingState, key []byte) string {
caps := make([]string, len(state.Capabilities))
copy(caps, state.Capabilities)
sort.Strings(caps)
limits := make(map[string]int64, len(state.Limits))
for k, v := range state.Limits {
limits[k] = v
}
payload := billingIntegrityPayloadV6PreQuickstart{
Capabilities: caps,
Limits: limits,
PlanVersion: state.PlanVersion,
SubscriptionState: state.SubscriptionState,
TrialStartedAt: state.TrialStartedAt,
TrialEndsAt: state.TrialEndsAt,
TrialExtendedAt: state.TrialExtendedAt,
OverflowGrantedAt: state.OverflowGrantedAt,
}
data, _ := json.Marshal(payload)
mac := hmac.New(sha256.New, key)
mac.Write(data)
return hex.EncodeToString(mac.Sum(nil))
}
func verifyBillingIntegrityV6PreQuickstart(state *pkglicensing.BillingState, key []byte) bool {
expected := billingIntegrityV6PreQuickstart(state, key)
return hmac.Equal([]byte(expected), []byte(state.Integrity))
}