pulse/license upgrade safety hardening

This commit is contained in:
rcourtman 2026-03-07 15:13:09 +00:00
parent a6f6f66078
commit d6e8bffaeb
4 changed files with 330 additions and 15 deletions

View file

@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"net/http"
"sync"
"time"
@ -20,6 +21,7 @@ type LicenseHandlers struct {
mtPersistence *config.MultiTenantPersistence
legacyPersistence *config.ConfigPersistence
services sync.Map // map[string]*license.Service
loadIssues sync.Map // map[string]string
configDir string // Base config dir, though we use mtPersistence for tenants
auditOnce sync.Once
}
@ -62,11 +64,15 @@ func (h *LicenseHandlers) getTenantComponents(ctx context.Context) (*license.Ser
// Try to load existing license
if persistence != nil {
persisted, err := persistence.LoadWithMetadata()
if err == nil && persisted.LicenseKey != "" {
lic, err := service.Activate(persisted.LicenseKey)
if err != nil {
h.setLoadIssue(orgID, err)
} else if persisted.LicenseKey != "" {
lic, err := service.ActivatePersisted(persisted.LicenseKey)
if err != nil {
h.setLoadIssue(orgID, err)
log.Warn().Str("org_id", orgID).Err(err).Msg("Failed to load saved license")
} else {
h.clearLoadIssue(orgID)
if persisted.GracePeriodEnd != nil && lic != nil {
gracePeriodEnd := time.Unix(*persisted.GracePeriodEnd, 0)
lic.GracePeriodEnd = &gracePeriodEnd
@ -78,6 +84,8 @@ func (h *LicenseHandlers) getTenantComponents(ctx context.Context) (*license.Ser
// Since audit logger is global, we do this once.
h.initAuditLoggerIfLicensed(service, persistence)
}
} else {
h.clearLoadIssue(orgID)
}
}
@ -146,6 +154,37 @@ func (h *LicenseHandlers) Service(ctx context.Context) *license.Service {
return svc
}
func (h *LicenseHandlers) setLoadIssue(orgID string, err error) {
if err == nil {
h.clearLoadIssue(orgID)
return
}
h.loadIssues.Store(orgID, err.Error())
}
func (h *LicenseHandlers) clearLoadIssue(orgID string) {
h.loadIssues.Delete(orgID)
}
func (h *LicenseHandlers) getLoadIssue(orgID string) string {
if v, ok := h.loadIssues.Load(orgID); ok {
if issue, ok := v.(string); ok {
return issue
}
}
return ""
}
func (h *LicenseHandlers) effectiveState(ctx context.Context, service *license.Service) (license.LicenseState, string) {
orgID := GetOrgID(ctx)
loadIssue := h.getLoadIssue(orgID)
if loadIssue != "" && service.Current() == nil {
return license.LicenseStateCorrupt, loadIssue
}
state, _ := service.GetLicenseState()
return state, ""
}
// HandleLicenseStatus handles GET /api/license/status
// Returns the current license status.
func (h *LicenseHandlers) HandleLicenseStatus(w http.ResponseWriter, r *http.Request) {
@ -162,6 +201,9 @@ func (h *LicenseHandlers) HandleLicenseStatus(w http.ResponseWriter, r *http.Req
}
status := service.Status()
state, loadIssue := h.effectiveState(r.Context(), service)
status.State = string(state)
status.LoadError = loadIssue
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
@ -189,7 +231,7 @@ func (h *LicenseHandlers) HandleLicenseFeatures(w http.ResponseWriter, r *http.R
return
}
state, _ := service.GetLicenseState()
state, _ := h.effectiveState(r.Context(), service)
response := LicenseFeaturesResponse{
LicenseStatus: string(state),
Features: map[string]bool{
@ -265,6 +307,7 @@ func (h *LicenseHandlers) HandleActivateLicense(w http.ResponseWriter, r *http.R
return
}
orgID := GetOrgID(r.Context())
lic, err := service.Activate(req.LicenseKey)
if err != nil {
log.Warn().Err(err).Msg("Failed to activate license")
@ -279,16 +322,37 @@ func (h *LicenseHandlers) HandleActivateLicense(w http.ResponseWriter, r *http.R
}
// Persist the license with grace period if applicable
if persistence != nil {
var gracePeriodEnd *int64
if lic.GracePeriodEnd != nil {
ts := lic.GracePeriodEnd.Unix()
gracePeriodEnd = &ts
}
if err := persistence.SaveWithGracePeriod(req.LicenseKey, gracePeriodEnd); err != nil {
log.Warn().Err(err).Msg("Failed to persist license, it won't survive restarts")
}
if persistence == nil {
service.Clear()
persistErr := errors.New("license persistence unavailable")
h.setLoadIssue(orgID, persistErr)
log.Error().Err(persistErr).Msg("Failed to persist license activation")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(ActivateLicenseResponse{
Success: false,
Message: "License could not be persisted",
})
return
}
var gracePeriodEnd *int64
if lic.GracePeriodEnd != nil {
ts := lic.GracePeriodEnd.Unix()
gracePeriodEnd = &ts
}
if err := persistence.SaveWithGracePeriod(req.LicenseKey, gracePeriodEnd); err != nil {
service.Clear()
h.setLoadIssue(orgID, err)
log.Error().Err(err).Msg("Failed to persist license activation")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(ActivateLicenseResponse{
Success: false,
Message: "License could not be persisted",
})
return
}
h.clearLoadIssue(orgID)
log.Info().
Str("email", lic.Claims.Email).
@ -324,6 +388,7 @@ func (h *LicenseHandlers) HandleClearLicense(w http.ResponseWriter, r *http.Requ
}
service.Clear()
h.clearLoadIssue(GetOrgID(r.Context()))
// Clear from persistence
if persistence != nil {

View file

@ -3,9 +3,12 @@ package api
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
@ -14,6 +17,11 @@ import (
)
func createTestHandler(t *testing.T) *LicenseHandlers {
handler, _ := createTestHandlerWithDir(t)
return handler
}
func createTestHandlerWithDir(t *testing.T) (*LicenseHandlers, string) {
tempDir := t.TempDir()
mtp := config.NewMultiTenantPersistence(tempDir)
// Ensure default persistence exists
@ -21,7 +29,19 @@ func createTestHandler(t *testing.T) *LicenseHandlers {
if err != nil {
t.Fatalf("Failed to initialize default persistence: %v", err)
}
return NewLicenseHandlers(mtp)
return NewLicenseHandlers(mtp), tempDir
}
func makeLicenseKeyForClaims(t *testing.T, claims license.Claims) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"EdDSA","typ":"JWT"}`))
payloadBytes, err := json.Marshal(claims)
if err != nil {
t.Fatalf("failed to marshal test claims: %v", err)
}
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
return header + "." + payload + ".fake-sig"
}
func TestLicenseHandlers_FallbackToLegacyPersistence(t *testing.T) {
@ -153,6 +173,39 @@ func TestHandleLicenseFeatures_WithActiveLicense(t *testing.T) {
}
}
func TestHandleLicenseFeatures_CorruptPersistedLicense(t *testing.T) {
handler, tempDir := createTestHandlerWithDir(t)
licensePath := filepath.Join(tempDir, license.LicenseFileName)
if err := os.WriteFile(licensePath, []byte("%%%not-base64%%%"), 0600); err != nil {
t.Fatalf("failed to write corrupt persisted license: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/license/features", nil)
rec := httptest.NewRecorder()
handler.HandleLicenseFeatures(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
}
var resp licenseFeaturesResponse
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.LicenseStatus != string(license.LicenseStateCorrupt) {
t.Fatalf("expected license_status %q, got %q", license.LicenseStateCorrupt, resp.LicenseStatus)
}
if resp.Features[license.FeatureAIPatrol] != true {
t.Fatalf("expected free-tier feature %q to remain enabled", license.FeatureAIPatrol)
}
if resp.Features[license.FeatureAIAutoFix] {
t.Fatalf("expected Pro-only feature %q to be disabled", license.FeatureAIAutoFix)
}
}
// ========================================
// HandleLicenseStatus tests
// ========================================
@ -192,6 +245,9 @@ func TestHandleLicenseStatus_NoLicense(t *testing.T) {
if resp.Valid {
t.Fatalf("expected Valid=false for no license")
}
if resp.State != string(license.LicenseStateNone) {
t.Fatalf("expected state %q, got %q", license.LicenseStateNone, resp.State)
}
if resp.Tier != license.TierFree {
t.Fatalf("expected tier %q, got %q", license.TierFree, resp.Tier)
}
@ -228,6 +284,9 @@ func TestHandleLicenseStatus_WithActiveLicense(t *testing.T) {
if !resp.Valid {
t.Fatalf("expected Valid=true for active license")
}
if resp.State != string(license.LicenseStateActive) {
t.Fatalf("expected state %q, got %q", license.LicenseStateActive, resp.State)
}
if resp.Email != "test@example.com" {
t.Fatalf("expected email %q, got %q", "test@example.com", resp.Email)
}
@ -236,6 +295,96 @@ func TestHandleLicenseStatus_WithActiveLicense(t *testing.T) {
}
}
func TestHandleLicenseStatus_CorruptPersistedLicense(t *testing.T) {
handler, tempDir := createTestHandlerWithDir(t)
licensePath := filepath.Join(tempDir, license.LicenseFileName)
if err := os.WriteFile(licensePath, []byte("%%%not-base64%%%"), 0600); err != nil {
t.Fatalf("failed to write corrupt persisted license: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/license/status", nil)
rec := httptest.NewRecorder()
handler.HandleLicenseStatus(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
}
var resp license.LicenseStatus
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.Valid {
t.Fatalf("expected Valid=false for corrupt persisted license")
}
if resp.State != string(license.LicenseStateCorrupt) {
t.Fatalf("expected state %q, got %q", license.LicenseStateCorrupt, resp.State)
}
if resp.LoadError == "" {
t.Fatalf("expected load_error to be set for corrupt persisted license")
}
if resp.Tier != license.TierFree {
t.Fatalf("expected tier %q, got %q", license.TierFree, resp.Tier)
}
}
func TestHandleLicenseStatus_ExpiredPersistedLicense(t *testing.T) {
t.Setenv("PULSE_LICENSE_DEV_MODE", "true")
handler := createTestHandler(t)
persistence, err := handler.getPersistenceForOrg("default")
if err != nil {
t.Fatalf("failed to get persistence: %v", err)
}
expiredKey := makeLicenseKeyForClaims(t, license.Claims{
LicenseID: "test-expired-persisted",
Email: "expired@example.com",
Tier: license.TierPro,
IssuedAt: time.Now().Add(-40 * 24 * time.Hour).Unix(),
ExpiresAt: time.Now().Add(-10 * 24 * time.Hour).Unix(),
})
if err := persistence.SaveWithGracePeriod(expiredKey, nil); err != nil {
t.Fatalf("failed to persist expired license: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/license/status", nil)
rec := httptest.NewRecorder()
handler.HandleLicenseStatus(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
}
var resp license.LicenseStatus
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.Valid {
t.Fatalf("expected Valid=false for expired persisted license")
}
if resp.State != string(license.LicenseStateExpired) {
t.Fatalf("expected state %q, got %q", license.LicenseStateExpired, resp.State)
}
if resp.Email != "expired@example.com" {
t.Fatalf("expected email %q, got %q", "expired@example.com", resp.Email)
}
if resp.Tier != license.TierPro {
t.Fatalf("expected tier %q, got %q", license.TierPro, resp.Tier)
}
if resp.ExpiresAt == nil {
t.Fatalf("expected expires_at to be reported")
}
if resp.LoadError != "" {
t.Fatalf("expected load_error to be empty, got %q", resp.LoadError)
}
}
// ========================================
// HandleActivateLicense tests
// ========================================
@ -359,6 +508,40 @@ func TestHandleActivateLicense_ValidKey(t *testing.T) {
}
}
func TestHandleActivateLicense_PersistenceUnavailableClearsRuntimeState(t *testing.T) {
t.Setenv("PULSE_LICENSE_DEV_MODE", "true")
handler := NewLicenseHandlers(nil)
licenseKey, err := license.GenerateLicenseForTesting("pro@example.com", license.TierPro, 24*time.Hour)
if err != nil {
t.Fatalf("failed to generate test license: %v", err)
}
body, _ := json.Marshal(map[string]string{"license_key": licenseKey})
req := httptest.NewRequest(http.MethodPost, "/api/license/activate", bytes.NewReader(body))
rec := httptest.NewRecorder()
handler.HandleActivateLicense(rec, req)
if rec.Code != http.StatusInternalServerError {
t.Fatalf("expected status %d, got %d: %s", http.StatusInternalServerError, rec.Code, rec.Body.String())
}
var resp ActivateLicenseResponse
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp.Success {
t.Fatalf("expected Success=false when persistence fails")
}
if resp.Message != "License could not be persisted" {
t.Fatalf("expected message %q, got %q", "License could not be persisted", resp.Message)
}
if handler.Service(context.Background()).Current() != nil {
t.Fatalf("expected runtime license to be cleared after persistence failure")
}
}
// ========================================
// HandleClearLicense tests
// ========================================

View file

@ -270,6 +270,7 @@ type LicenseState string
const (
LicenseStateNone LicenseState = "none"
LicenseStateActive LicenseState = "active"
LicenseStateCorrupt LicenseState = "corrupt"
LicenseStateExpired LicenseState = "expired"
LicenseStateGracePeriod LicenseState = "grace_period"
)
@ -299,6 +300,27 @@ func (s *Service) GetLicenseState() (LicenseState, *License) {
return LicenseStateActive, s.license
}
// ActivatePersisted restores a previously persisted license key.
// Unlike Activate, it accepts licenses that are expired past grace so the
// service can still report an explicit expired state after restart.
func (s *Service) ActivatePersisted(licenseKey string) (*License, error) {
license, err := LoadPersistedLicense(licenseKey)
if err != nil {
return nil, err
}
s.mu.Lock()
s.license = license
cb := s.onLicenseChange
s.mu.Unlock()
if cb != nil {
cb(license)
}
return license, nil
}
// GetLicenseStateString returns the current license state as string and whether features are available
// This implements the LicenseChecker interface for the AI service
func (s *Service) GetLicenseStateString() (string, bool) {
@ -371,6 +393,7 @@ func (s *Service) Status() *LicenseStatus {
// LicenseStatus is the JSON response for license status API.
type LicenseStatus struct {
Valid bool `json:"valid"`
State string `json:"state,omitempty"`
Tier Tier `json:"tier"`
Email string `json:"email,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"`
@ -381,10 +404,22 @@ type LicenseStatus struct {
MaxGuests int `json:"max_guests,omitempty"`
InGracePeriod bool `json:"in_grace_period,omitempty"`
GracePeriodEnd *string `json:"grace_period_end,omitempty"`
LoadError string `json:"load_error,omitempty"`
}
// ValidateLicense validates a license key and returns the license if valid.
func ValidateLicense(licenseKey string) (*License, error) {
return validateLicense(licenseKey, false)
}
// LoadPersistedLicense validates a persisted license key but does not reject
// licenses that are expired past grace. This lets startup/reporting preserve an
// explicit expired state instead of collapsing to "none" after restart.
func LoadPersistedLicense(licenseKey string) (*License, error) {
return validateLicense(licenseKey, true)
}
func validateLicense(licenseKey string, allowExpired bool) (*License, error) {
// Trim whitespace
licenseKey = strings.TrimSpace(licenseKey)
if licenseKey == "" {
@ -460,12 +495,12 @@ func ValidateLicense(licenseKey string) (*License, error) {
// Grace period: 7 days after expiration
gracePeriodDuration := 7 * 24 * time.Hour
gracePeriodEnd := expirationTime.Add(gracePeriodDuration)
license.GracePeriodEnd = &gracePeriodEnd
if time.Now().Before(gracePeriodEnd) {
// Within grace period - allow activation but mark as in grace period
license.GracePeriodEnd = &gracePeriodEnd
// License is still valid during grace period
} else {
} else if !allowExpired {
// Past grace period - reject
return nil, fmt.Errorf("%w: expired on %s (grace period ended %s)",
ErrExpiredLicense,

View file

@ -311,6 +311,38 @@ func TestValidateLicense_ExpiredPastGrace(t *testing.T) {
}
}
func TestLoadPersistedLicense_ExpiredPastGrace(t *testing.T) {
os.Setenv("PULSE_LICENSE_DEV_MODE", "true")
defer os.Unsetenv("PULSE_LICENSE_DEV_MODE")
claims := Claims{
LicenseID: "test-expired-persisted",
Email: "persisted@pulse.test",
Tier: TierPro,
IssuedAt: time.Now().Add(-40 * 24 * time.Hour).Unix(),
ExpiresAt: time.Now().Add(-10 * 24 * time.Hour).Unix(),
}
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"EdDSA","typ":"JWT"}`))
payloadBytes, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
key := header + "." + payload + ".fake-sig"
lic, err := LoadPersistedLicense(key)
if err != nil {
t.Fatalf("expected persisted expired license to load, got error: %v", err)
}
if lic.Claims.Email != claims.Email {
t.Fatalf("expected email %q, got %q", claims.Email, lic.Claims.Email)
}
if lic.GracePeriodEnd == nil {
t.Fatal("expected GracePeriodEnd to be populated for expired persisted license")
}
if time.Now().Before(*lic.GracePeriodEnd) {
t.Fatal("expected GracePeriodEnd to be in the past for license past grace")
}
}
func TestLicenseStatus(t *testing.T) {
service := NewService()