// Pulse/internal/api/entitlement_handlers_test.go
//
// Source-viewer metadata from the original paste: 656 lines, 22 KiB, Go.

package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/rcourtman/pulse-go-rewrite/internal/config"
"github.com/rcourtman/pulse-go-rewrite/internal/license"
"github.com/rcourtman/pulse-go-rewrite/internal/license/entitlements"
pkglicensing "github.com/rcourtman/pulse-go-rewrite/pkg/licensing"
)
// containsCapability reports whether key is present in values.
// A nil or empty slice never contains anything.
func containsCapability(values []string, key string) bool {
	for i := 0; i < len(values); i++ {
		if values[i] == key {
			return true
		}
	}
	return false
}
// TestBuildEntitlementPayload_ActiveLicense checks the payload for a valid Pro license:
//   - the subscription is active;
//   - the capabilities equal the license features;
//   - the configured max_monitored_systems limit of 50 is reported;
//   - no upgrade reasons are offered.
func TestBuildEntitlementPayload_ActiveLicense(t *testing.T) {
	proFeatures := append([]string(nil), license.TierFeatures[license.TierPro]...)
	status := &license.LicenseStatus{
		Valid:               true,
		Tier:                license.TierPro,
		Features:            proFeatures,
		MaxMonitoredSystems: 50,
	}

	payload := buildEntitlementPayload(status, "")

	if got := payload.SubscriptionState; got != string(license.SubStateActive) {
		t.Fatalf("expected subscription_state %q, got %q", license.SubStateActive, got)
	}
	if !reflect.DeepEqual(payload.Capabilities, status.Features) {
		t.Fatalf("expected capabilities to match status features")
	}

	// Locate the first max_monitored_systems entry; -1 means "not found".
	agentIdx := -1
	for i := range payload.Limits {
		if payload.Limits[i].Key == "max_monitored_systems" {
			agentIdx = i
			break
		}
	}
	if agentIdx < 0 {
		t.Fatalf("expected max_monitored_systems limit in payload")
	}
	if got := payload.Limits[agentIdx].Limit; got != 50 {
		t.Fatalf("expected max_monitored_systems limit 50, got %d", got)
	}

	if n := len(payload.UpgradeReasons); n != 0 {
		t.Fatalf("expected no upgrade reasons for pro tier, got %d", n)
	}
}
func TestBuildEntitlementPayload_FreeTier(t *testing.T) {
status := &license.LicenseStatus{
Valid: true,
Tier: license.TierFree,
Features: append([]string(nil), license.TierFeatures[license.TierFree]...),
}
payload := buildEntitlementPayload(status, "")
// Upgrade reasons should cover every Pro feature not in Free.
proMinusFree := countProMinusFreeFeatures()
if len(payload.UpgradeReasons) != proMinusFree {
t.Fatalf("expected %d upgrade reasons for free tier, got %d", proMinusFree, len(payload.UpgradeReasons))
}
for _, reason := range payload.UpgradeReasons {
if reason.ActionURL == "" {
t.Fatalf("expected action_url for reason %q", reason.Key)
}
}
}
// TestBuildEntitlementPayloadWithUsage_CurrentValues checks how a usage snapshot
// is merged into a Pro payload:
//   - monitored-system and guest counts land in the matching limit entries;
//   - the monitored-system count is flagged as available;
//   - legacy connection counts pass through;
//   - no migration gap is reported.
func TestBuildEntitlementPayloadWithUsage_CurrentValues(t *testing.T) {
	status := &license.LicenseStatus{
		Valid:               true,
		Tier:                license.TierPro,
		Features:            append([]string(nil), license.TierFeatures[license.TierPro]...),
		MaxMonitoredSystems: 50,
		MaxGuests:           100,
	}
	usage := entitlementUsageSnapshot{
		MonitoredSystems:          12,
		MonitoredSystemsAvailable: true,
		Guests:                    44,
		LegacyConnections: pkglicensing.LegacyConnectionCounts{
			ProxmoxNodes:       2,
			DockerHosts:        1,
			KubernetesClusters: 3,
		},
	}
	payload := buildEntitlementPayloadWithUsage(status, "", usage, nil)

	// Scan every entry so that, for each key, the last occurrence wins.
	var agentLimit, guestLimit *LimitStatus
	for i := range payload.Limits {
		switch payload.Limits[i].Key {
		case "max_monitored_systems":
			agentLimit = &payload.Limits[i]
		case "max_guests":
			guestLimit = &payload.Limits[i]
		}
	}

	switch {
	case agentLimit == nil:
		t.Fatalf("expected max_monitored_systems limit")
	case guestLimit == nil:
		t.Fatalf("expected max_guests limit")
	}

	if got := agentLimit.Current; got != 12 {
		t.Fatalf("expected agent current 12, got %d", got)
	}
	if avail := agentLimit.CurrentAvailable; avail == nil || !*avail {
		t.Fatalf("expected agent current availability true, got %+v", avail)
	}
	if got := guestLimit.Current; got != 44 {
		t.Fatalf("expected guest current 44, got %d", got)
	}
	if got := payload.LegacyConnections.ProxmoxNodes; got != 2 {
		t.Fatalf("expected proxmox_nodes 2, got %d", got)
	}
	if payload.HasMigrationGap {
		t.Fatal("expected has_migration_gap=false under monitored-system counting")
	}
}
// TestBuildEntitlementPayloadWithUsage_MonitoredSystemUsageUnavailable passes an
// empty usage snapshot. It expects a single monitored-system limit whose current
// value stays 0 and whose availability flag is present and false.
func TestBuildEntitlementPayloadWithUsage_MonitoredSystemUsageUnavailable(t *testing.T) {
	status := &license.LicenseStatus{
		Valid:               true,
		Tier:                license.TierPro,
		Features:            append([]string(nil), license.TierFeatures[license.TierPro]...),
		MaxMonitoredSystems: 50,
	}
	limits := buildEntitlementPayloadWithUsage(status, "", entitlementUsageSnapshot{}, nil).Limits

	if len(limits) != 1 {
		t.Fatalf("expected one monitored-system limit, got %d", len(limits))
	}
	only := limits[0]
	if only.Current != 0 {
		t.Fatalf("expected unresolved current to remain 0, got %d", only.Current)
	}
	if avail := only.CurrentAvailable; avail == nil || *avail {
		t.Fatalf("expected unresolved current availability false, got %+v", avail)
	}
}
func TestBuildEntitlementPayload_Expired(t *testing.T) {
status := &license.LicenseStatus{
Valid: false,
InGracePeriod: false,
}
payload := buildEntitlementPayload(status, "")
if payload.SubscriptionState != string(license.SubStateExpired) {
t.Fatalf("expected subscription_state %q, got %q", license.SubStateExpired, payload.SubscriptionState)
}
}
// TestBuildEntitlementPayload_GracePeriod checks that a license flagged as in its
// grace period is reported with the grace subscription state.
func TestBuildEntitlementPayload_GracePeriod(t *testing.T) {
	inGrace := &license.LicenseStatus{Valid: true, InGracePeriod: true}

	if got := buildEntitlementPayload(inGrace, "").SubscriptionState; got != string(license.SubStateGrace) {
		t.Fatalf("expected subscription_state %q, got %q", license.SubStateGrace, got)
	}
}
func TestBuildEntitlementPayload_NilCapabilities(t *testing.T) {
status := &license.LicenseStatus{
Features: nil,
}
payload := buildEntitlementPayload(status, "")
if payload.Capabilities == nil {
t.Fatalf("expected capabilities to be an empty slice, got nil")
}
if len(payload.Capabilities) != 0 {
t.Fatalf("expected capabilities length 0, got %d", len(payload.Capabilities))
}
}
// TestBuildCommercialPosturePayloadWithUsage_CurrentValues builds the commercial
// posture for a Free-tier license whose usage (7) exceeds its limit (5). It expects:
//   - the free tier and an active subscription;
//   - at least one upgrade reason;
//   - the legacy connection counts, unchanged;
//   - no migration gap.
func TestBuildCommercialPosturePayloadWithUsage_CurrentValues(t *testing.T) {
	status := &license.LicenseStatus{
		Valid:               true,
		Tier:                license.TierFree,
		Features:            append([]string(nil), license.TierFeatures[license.TierFree]...),
		MaxMonitoredSystems: 5,
	}
	usage := entitlementUsageSnapshot{
		MonitoredSystems:          7,
		MonitoredSystemsAvailable: true,
		LegacyConnections: pkglicensing.LegacyConnectionCounts{
			ProxmoxNodes: 2,
			DockerHosts:  1,
		},
	}
	posture := buildCommercialPosturePayloadWithUsage(status, "", usage, nil)

	if got := posture.Tier; got != string(license.TierFree) {
		t.Fatalf("expected tier=%q, got %q", license.TierFree, got)
	}
	if got := posture.SubscriptionState; got != string(license.SubStateActive) {
		t.Fatalf("expected subscription_state=%q, got %q", license.SubStateActive, got)
	}
	if len(posture.UpgradeReasons) == 0 {
		t.Fatal("expected upgrade reasons for free-tier commercial posture")
	}

	legacy := posture.LegacyConnections
	if legacy.ProxmoxNodes != 2 || legacy.DockerHosts != 1 {
		t.Fatalf("expected legacy counts to be preserved, got %+v", legacy)
	}
	if posture.HasMigrationGap {
		t.Fatal("expected has_migration_gap=false under canonical monitored-system counting")
	}
}
// TestHandleCommercialPosture_ActiveLicense activates a generated Pro test license
// with PULSE_LICENSE_DEV_MODE=true, then calls the commercial-posture endpoint.
// It expects 200 OK, an active Pro posture, and none of the full
// entitlement-detail fields in the response body.
func TestHandleCommercialPosture_ActiveLicense(t *testing.T) {
	t.Setenv("PULSE_LICENSE_DEV_MODE", "true")
	handler := createTestHandler(t)

	licenseKey, err := pkglicensing.GenerateLicenseForTesting("owner@example.com", pkglicensing.TierPro, 24*time.Hour)
	if err != nil {
		t.Fatalf("GenerateLicenseForTesting: %v", err)
	}
	svc := handler.Service(context.Background())
	if _, err = svc.Activate(licenseKey); err != nil {
		t.Fatalf("Activate() error = %v", err)
	}

	rec := httptest.NewRecorder()
	handler.HandleCommercialPosture(rec, httptest.NewRequest(http.MethodGet, "/api/license/commercial-posture", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
	}

	// These JSON keys belong to the full entitlement payload only.
	leakedFields := []string{
		`"capabilities"`,
		`"limits"`,
		`"licensed_email"`,
		`"plan_version"`,
		`"max_history_days"`,
	}
	body := rec.Body.String()
	for _, field := range leakedFields {
		if strings.Contains(body, field) {
			t.Fatalf("commercial posture leaked field %s in %s", field, body)
		}
	}

	var payload CommercialPosturePayload
	if err := json.NewDecoder(strings.NewReader(body)).Decode(&payload); err != nil {
		t.Fatalf("failed to decode commercial posture payload: %v", err)
	}
	if got := payload.SubscriptionState; got != string(license.SubStateActive) {
		t.Fatalf("expected subscription_state %q, got %q", license.SubStateActive, got)
	}
	if got := payload.Tier; got != string(license.TierPro) {
		t.Fatalf("expected tier %q, got %q", license.TierPro, got)
	}
}
// TestLimitState pins the state label returned for a (current, limit) pair.
// Per these cases:
//   - 50 of 100 is "ok" and 90 of 100 is "warning";
//   - reaching or exceeding the limit is "enforced";
//   - a limit of 0 reports "ok".
func TestLimitState(t *testing.T) {
	type limitCase struct {
		name           string
		current, limit int64
		want           string
	}
	cases := []limitCase{
		{"ok_below_threshold", 50, 100, "ok"},
		{"warning_at_90_percent", 90, 100, "warning"},
		{"enforced_at_limit", 100, 100, "enforced"},
		{"enforced_above_limit", 110, 100, "enforced"},
		{"ok_unlimited", 50, 0, "ok"},
	}

	for i := range cases {
		c := cases[i] // per-iteration copy captured by the subtest closure
		t.Run(c.name, func(t *testing.T) {
			if got := limitState(c.current, c.limit); got != c.want {
				t.Fatalf("limitState(%d, %d) = %q, want %q", c.current, c.limit, got, c.want)
			}
		})
	}
}
func TestBuildEntitlementPayload_TrialState(t *testing.T) {
expiresAt := time.Now().Add(36 * time.Hour).UTC().Format(time.RFC3339)
status := &license.LicenseStatus{
Valid: true,
Tier: license.TierPro,
Features: append([]string(nil), license.TierFeatures[license.TierPro]...),
ExpiresAt: &expiresAt,
}
payload := buildEntitlementPayload(status, string(license.SubStateTrial))
if payload.SubscriptionState != string(license.SubStateTrial) {
t.Fatalf("expected subscription_state %q, got %q", license.SubStateTrial, payload.SubscriptionState)
}
if payload.TrialExpiresAt == nil {
t.Fatalf("expected trial_expires_at to be populated for trial state")
}
if payload.TrialDaysRemaining == nil {
t.Fatalf("expected trial_days_remaining to be populated for trial state")
}
if *payload.TrialDaysRemaining != 2 {
t.Fatalf("expected trial_days_remaining 2, got %d", *payload.TrialDaysRemaining)
}
}
func TestBuildEntitlementPayload_PreservesPlanVersionForSelfHostedJWT(t *testing.T) {
tests := []struct {
name string
planVersion string
}{
{name: "lifetime grandfathered", planVersion: "v5_lifetime_grandfathered"},
{name: "monthly grandfathered", planVersion: "v5_pro_monthly_grandfathered"},
{name: "annual grandfathered", planVersion: "v5_pro_annual_grandfathered"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
status := &license.LicenseStatus{
Valid: true,
Tier: license.TierPro,
PlanVersion: tc.planVersion,
Features: append([]string(nil), license.TierFeatures[license.TierPro]...),
}
payload := buildEntitlementPayload(status, string(license.SubStateActive))
if payload.PlanVersion != tc.planVersion {
t.Fatalf("plan_version=%q, want %q", payload.PlanVersion, tc.planVersion)
}
})
}
}
// TestEntitlementHandler_UsesEvaluatorWhenNoLicense seeds an org with a billing
// state and activates no license. It then checks that the entitlements endpoint
// reports that billing state:
//   - an active subscription and the "pro" plan version;
//   - both seeded AI capabilities;
//   - a max_monitored_systems limit of 5.
//
// It also asserts parity: every advertised capability must be reported as
// enabled by the tenant service's HasFeature.
func TestEntitlementHandler_UsesEvaluatorWhenNoLicense(t *testing.T) {
	baseDir := t.TempDir()
	mtp := config.NewMultiTenantPersistence(baseDir)
	orgID := "test-hosted-entitlements"
	if _, err := mtp.GetPersistence(orgID); err != nil {
		t.Fatalf("GetPersistence(%s) failed: %v", orgID, err)
	}
	store := config.NewFileBillingStore(baseDir)
	if err := store.SaveBillingState(orgID, &entitlements.BillingState{
		Capabilities: []string{
			license.FeatureAIPatrol,
			license.FeatureAIAutoFix,
		},
		Limits: map[string]int64{
			"max_monitored_systems": 5,
		},
		PlanVersion:       "pro",
		SubscriptionState: entitlements.SubStateActive,
	}); err != nil {
		t.Fatalf("SaveBillingState(%s) failed: %v", orgID, err)
	}

	// NOTE(review): the other handler tests pass false as the second argument.
	// true presumably selects hosted mode, matching the org ID — confirm.
	h := NewLicenseHandlers(mtp, true)
	req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil)
	req = req.WithContext(context.WithValue(req.Context(), OrgIDContextKey, orgID))
	rec := httptest.NewRecorder()
	h.HandleEntitlements(rec, req)
	if rec.Code != http.StatusOK {
		// Include the body, as the other handler tests do, so failures are diagnosable.
		t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var payload EntitlementPayload
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("unmarshal payload failed: %v", err)
	}
	if payload.SubscriptionState != string(license.SubStateActive) {
		t.Fatalf("subscription_state=%q, want %q", payload.SubscriptionState, license.SubStateActive)
	}
	if payload.PlanVersion != "pro" {
		t.Fatalf("plan_version=%q, want %q", payload.PlanVersion, "pro")
	}

	// Reuse the file-level helper instead of redefining an identical closure.
	if !containsCapability(payload.Capabilities, license.FeatureAIAutoFix) {
		t.Fatalf("expected capabilities to include %q, got %v", license.FeatureAIAutoFix, payload.Capabilities)
	}
	if !containsCapability(payload.Capabilities, license.FeatureAIPatrol) {
		t.Fatalf("expected capabilities to include %q, got %v", license.FeatureAIPatrol, payload.Capabilities)
	}

	var maxMonitoredSystems *LimitStatus
	for i := range payload.Limits {
		if payload.Limits[i].Key == "max_monitored_systems" {
			maxMonitoredSystems = &payload.Limits[i]
			break
		}
	}
	if maxMonitoredSystems == nil {
		t.Fatalf("expected max_monitored_systems limit in payload, got %v", payload.Limits)
	}
	if maxMonitoredSystems.Limit != 5 {
		t.Fatalf("max_monitored_systems.limit=%d, want %d", maxMonitoredSystems.Limit, 5)
	}

	// Parity: every advertised capability must be enforced by HasFeature.
	ctx := context.WithValue(context.Background(), OrgIDContextKey, orgID)
	svc, _, err := h.getTenantComponents(ctx)
	if err != nil {
		t.Fatalf("getTenantComponents failed: %v", err)
	}
	// Named "capability" rather than "cap" so the builtin cap is not shadowed.
	for _, capability := range payload.Capabilities {
		if !svc.HasFeature(capability) {
			t.Fatalf("parity mismatch: HasFeature(%q)=false but capability present in payload", capability)
		}
	}
}
func TestEntitlementHandler_TrialEligibility_FreshOrgAllowed(t *testing.T) {
baseDir := t.TempDir()
mtp := config.NewMultiTenantPersistence(baseDir)
h := NewLicenseHandlers(mtp, false)
ctx := context.WithValue(context.Background(), OrgIDContextKey, "default")
req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil).WithContext(ctx)
rec := httptest.NewRecorder()
h.HandleEntitlements(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload EntitlementPayload
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("unmarshal payload failed: %v", err)
}
if !payload.TrialEligible {
t.Fatalf("trial_eligible=%v, want true", payload.TrialEligible)
}
if payload.TrialEligibilityReason != "" {
t.Fatalf("trial_eligibility_reason=%q, want empty", payload.TrialEligibilityReason)
}
}
// TestEntitlementHandler_DevModeMirrorsFeatureGateCapabilities calls the
// entitlements endpoint with PULSE_DEV=true and PULSE_MULTI_TENANT_ENABLED set to
// the empty string. It expects:
//   - advanced reporting to be advertised;
//   - multi-tenant to be omitted;
//   - the non-runtime capabilities (multi-user, white-label, unlimited) to be omitted;
//   - no upgrade reasons;
//   - every advertised capability to be reported as enabled by the tenant
//     service's HasFeature.
func TestEntitlementHandler_DevModeMirrorsFeatureGateCapabilities(t *testing.T) {
	t.Setenv("PULSE_DEV", "true")
	t.Setenv("PULSE_MULTI_TENANT_ENABLED", "")
	baseDir := t.TempDir()
	mtp := config.NewMultiTenantPersistence(baseDir)
	h := NewLicenseHandlers(mtp, false)
	ctx := context.WithValue(context.Background(), OrgIDContextKey, "default")
	req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil).WithContext(ctx)
	rec := httptest.NewRecorder()
	h.HandleEntitlements(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var payload EntitlementPayload
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("unmarshal payload failed: %v", err)
	}
	if !containsCapability(payload.Capabilities, license.FeatureAdvancedReporting) {
		t.Fatalf("expected dev entitlements to include %q, got %v", license.FeatureAdvancedReporting, payload.Capabilities)
	}
	if containsCapability(payload.Capabilities, license.FeatureMultiTenant) {
		t.Fatalf("expected dev entitlements to omit %q while runtime flag is disabled, got %v", license.FeatureMultiTenant, payload.Capabilities)
	}
	for _, feature := range []string{
		license.FeatureMultiUser,
		license.FeatureWhiteLabel,
		license.FeatureUnlimited,
	} {
		if containsCapability(payload.Capabilities, feature) {
			t.Fatalf("expected dev entitlements to omit non-runtime capability %q, got %v", feature, payload.Capabilities)
		}
	}
	if len(payload.UpgradeReasons) != 0 {
		t.Fatalf("expected no upgrade reasons in dev mode, got %v", payload.UpgradeReasons)
	}

	// Parity: every advertised capability must be enforced by HasFeature.
	svc, _, err := h.getTenantComponents(ctx)
	if err != nil {
		t.Fatalf("getTenantComponents failed: %v", err)
	}
	// Named "capability" rather than "cap" so the builtin cap is not shadowed.
	for _, capability := range payload.Capabilities {
		if !svc.HasFeature(capability) {
			t.Fatalf("parity mismatch: HasFeature(%q)=false but capability present in payload", capability)
		}
	}
}
func TestEntitlementHandler_DevModeIncludesMultiTenantWhenRuntimeEnabled(t *testing.T) {
t.Setenv("PULSE_DEV", "true")
t.Setenv("PULSE_MULTI_TENANT_ENABLED", "true")
baseDir := t.TempDir()
mtp := config.NewMultiTenantPersistence(baseDir)
h := NewLicenseHandlers(mtp, false)
ctx := context.WithValue(context.Background(), OrgIDContextKey, "default")
req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil).WithContext(ctx)
rec := httptest.NewRecorder()
h.HandleEntitlements(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload EntitlementPayload
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("unmarshal payload failed: %v", err)
}
if !containsCapability(payload.Capabilities, license.FeatureMultiTenant) {
t.Fatalf("expected dev entitlements to include %q when runtime flag is enabled, got %v", license.FeatureMultiTenant, payload.Capabilities)
}
}
// TestEntitlementHandler_TrialEligibility_AlreadyUsedDenied stores an expired
// billing state with a trial that started 15 days ago and ended a day ago. It
// expects the entitlements endpoint to deny trial eligibility with the reason
// "already_used".
func TestEntitlementHandler_TrialEligibility_AlreadyUsedDenied(t *testing.T) {
	const orgID = "default"
	baseDir := t.TempDir()
	mtp := config.NewMultiTenantPersistence(baseDir)

	now := time.Now()
	trialStarted := now.Add(-15 * 24 * time.Hour).Unix()
	trialEnded := now.Add(-24 * time.Hour).Unix()
	expiredTrial := &entitlements.BillingState{
		Capabilities:      []string{},
		Limits:            map[string]int64{},
		MetersEnabled:     []string{},
		PlanVersion:       "trial",
		SubscriptionState: entitlements.SubStateExpired,
		TrialStartedAt:    &trialStarted,
		TrialEndsAt:       &trialEnded,
	}
	if err := config.NewFileBillingStore(baseDir).SaveBillingState(orgID, expiredTrial); err != nil {
		t.Fatalf("SaveBillingState: %v", err)
	}

	h := NewLicenseHandlers(mtp, false)
	orgCtx := context.WithValue(context.Background(), OrgIDContextKey, orgID)
	req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil).WithContext(orgCtx)
	rec := httptest.NewRecorder()
	h.HandleEntitlements(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var payload EntitlementPayload
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("unmarshal payload failed: %v", err)
	}
	if payload.TrialEligible {
		t.Fatalf("trial_eligible=%v, want false", payload.TrialEligible)
	}
	if got := payload.TrialEligibilityReason; got != "already_used" {
		t.Fatalf("trial_eligibility_reason=%q, want %q", got, "already_used")
	}
}
// TestEntitlementHandler_CommercialMigrationBlocksTrialEligibility stores an
// expired billing state carrying a pending v5-license commercial migration. It
// expects the entitlements endpoint to:
//   - echo the migration with state pending;
//   - deny trial eligibility with the reason "commercial_migration_pending".
func TestEntitlementHandler_CommercialMigrationBlocksTrialEligibility(t *testing.T) {
	const orgID = "default"
	baseDir := t.TempDir()
	mtp := config.NewMultiTenantPersistence(baseDir)

	pendingMigration := &pkglicensing.CommercialMigrationStatus{
		Source:            pkglicensing.CommercialMigrationSourceV5License,
		State:             pkglicensing.CommercialMigrationStatePending,
		Reason:            pkglicensing.CommercialMigrationReasonExchangeUnavailable,
		RecommendedAction: pkglicensing.CommercialMigrationActionRetryActivation,
	}
	billing := &entitlements.BillingState{
		Capabilities:        []string{},
		Limits:              map[string]int64{},
		MetersEnabled:       []string{},
		PlanVersion:         string(entitlements.SubStateExpired),
		SubscriptionState:   entitlements.SubStateExpired,
		CommercialMigration: pendingMigration,
	}
	if err := config.NewFileBillingStore(baseDir).SaveBillingState(orgID, billing); err != nil {
		t.Fatalf("SaveBillingState: %v", err)
	}

	h := NewLicenseHandlers(mtp, false)
	orgCtx := context.WithValue(context.Background(), OrgIDContextKey, orgID)
	req := httptest.NewRequest(http.MethodGet, "/api/license/entitlements", nil).WithContext(orgCtx)
	rec := httptest.NewRecorder()
	h.HandleEntitlements(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("status=%d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var payload EntitlementPayload
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("unmarshal payload failed: %v", err)
	}
	migration := payload.CommercialMigration
	if migration == nil {
		t.Fatal("expected commercial_migration payload")
	}
	if migration.State != pkglicensing.CommercialMigrationStatePending {
		t.Fatalf("commercial_migration.state=%q, want %q", migration.State, pkglicensing.CommercialMigrationStatePending)
	}
	if payload.TrialEligible {
		t.Fatalf("trial_eligible=%v, want false", payload.TrialEligible)
	}
	if got := payload.TrialEligibilityReason; got != "commercial_migration_pending" {
		t.Fatalf("trial_eligibility_reason=%q, want %q", got, "commercial_migration_pending")
	}
}
// countProMinusFreeFeatures returns how many entries of the Pro tier feature
// list do not appear in the Free tier feature list. If the Pro list repeats an
// entry, each occurrence is counted.
func countProMinusFreeFeatures() int {
	free := license.TierFeatures[license.TierFree]
	inFree := make(map[string]bool, len(free))
	for _, feature := range free {
		inFree[feature] = true
	}

	missing := 0
	for _, feature := range license.TierFeatures[license.TierPro] {
		if !inFree[feature] {
			missing++
		}
	}
	return missing
}