mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 00:37:36 +00:00
475 lines
15 KiB
Go
475 lines
15 KiB
Go
package licensing
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"encoding/base64"
|
|
"errors"
|
|
"reflect"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
type mockBillingStore struct {
|
|
mu sync.Mutex
|
|
state *BillingState
|
|
err error
|
|
calls int
|
|
}
|
|
|
|
func (m *mockBillingStore) GetBillingState(_ string) (*BillingState, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.calls++
|
|
if m.err != nil {
|
|
return nil, m.err
|
|
}
|
|
if m.state == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
state := cloneBillingState(*m.state)
|
|
return &state, nil
|
|
}
|
|
|
|
func (m *mockBillingStore) setState(state *BillingState) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.state = state
|
|
}
|
|
|
|
func (m *mockBillingStore) setError(err error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.err = err
|
|
}
|
|
|
|
func (m *mockBillingStore) callCount() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.calls
|
|
}
|
|
|
|
func TestDatabaseSourceHappyPath(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"rbac", "relay"},
|
|
Limits: map[string]int64{"max_monitored_systems": 50},
|
|
MetersEnabled: []string{"active_agents"},
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac", "relay"}) {
|
|
t.Fatalf("expected capabilities %v, got %v", []string{"rbac", "relay"}, got)
|
|
}
|
|
if got := source.Limits(); !reflect.DeepEqual(got, map[string]int64{"max_monitored_systems": 50}) {
|
|
t.Fatalf("expected limits %v, got %v", map[string]int64{"max_monitored_systems": 50}, got)
|
|
}
|
|
if got := source.MetersEnabled(); !reflect.DeepEqual(got, []string{"active_agents"}) {
|
|
t.Fatalf("expected meters %v, got %v", []string{"active_agents"}, got)
|
|
}
|
|
if got := source.PlanVersion(); got != "pro-v2" {
|
|
t.Fatalf("expected plan_version %q, got %q", "pro-v2", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateActive {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateActive, got)
|
|
}
|
|
|
|
if store.callCount() != 1 {
|
|
t.Fatalf("expected store to be called once, got %d", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceLimits_MaxNodesMigration(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Limits: map[string]int64{"max_nodes": 25},
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
got := source.Limits()
|
|
if got["max_monitored_systems"] != 25 {
|
|
t.Fatalf("expected max_monitored_systems=25, got %d", got["max_monitored_systems"])
|
|
}
|
|
if _, hasOld := got["max_nodes"]; hasOld {
|
|
t.Fatal("expected max_nodes to be absent after migration")
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCanonicalizesCloudPlanVersionAndLimits(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: "cloud_v1",
|
|
Limits: map[string]int64{"max_monitored_systems": 999},
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.PlanVersion(); got != "cloud_starter" {
|
|
t.Fatalf("expected plan_version %q, got %q", "cloud_starter", got)
|
|
}
|
|
if got := source.Limits()["max_monitored_systems"]; got != 10 {
|
|
t.Fatalf("expected max_monitored_systems=%d, got %d", 10, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourcePreservesMissingPlanVersion(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: " ",
|
|
Limits: map[string]int64{"max_monitored_systems": 42},
|
|
SubscriptionState: SubscriptionState(" ACTIVE "),
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.PlanVersion(); got != "" {
|
|
t.Fatalf("expected missing plan_version to stay empty, got %q", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateActive {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateActive, got)
|
|
}
|
|
if got := source.Limits()["max_monitored_systems"]; got != 42 {
|
|
t.Fatalf("expected max_monitored_systems=%d, got %d", 42, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCacheHit(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
_ = source.PlanVersion()
|
|
_ = source.PlanVersion()
|
|
|
|
if store.callCount() != 1 {
|
|
t.Fatalf("expected cache hit to avoid second store call, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCacheMissRefresh(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.PlanVersion(); got != "pro-v1" {
|
|
t.Fatalf("expected initial plan_version %q, got %q", "pro-v1", got)
|
|
}
|
|
|
|
store.setState(&BillingState{
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
})
|
|
|
|
source.mu.Lock()
|
|
source.cacheTime = time.Now().Add(-2 * time.Hour)
|
|
source.mu.Unlock()
|
|
|
|
if got := source.PlanVersion(); got != "pro-v2" {
|
|
t.Fatalf("expected refreshed plan_version %q, got %q", "pro-v2", got)
|
|
}
|
|
|
|
if store.callCount() != 2 {
|
|
t.Fatalf("expected store refresh call, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceFailOpenWithStaleCache(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"rbac"},
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac"}) {
|
|
t.Fatalf("expected initial capabilities %v, got %v", []string{"rbac"}, got)
|
|
}
|
|
|
|
source.mu.Lock()
|
|
source.cacheTime = time.Now().Add(-2 * time.Hour)
|
|
source.mu.Unlock()
|
|
|
|
store.setError(errors.New("store unavailable"))
|
|
store.setState(&BillingState{
|
|
Capabilities: []string{"new_capability"},
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
})
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac"}) {
|
|
t.Fatalf("expected stale cached capabilities on failure, got %v", got)
|
|
}
|
|
|
|
if store.callCount() != 2 {
|
|
t.Fatalf("expected refresh attempt with failure, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceFailOpenWithNoCache(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
err: errors.New("store unavailable"),
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.Capabilities(); got != nil {
|
|
t.Fatalf("expected default nil capabilities, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil {
|
|
t.Fatalf("expected default nil limits, got %v", got)
|
|
}
|
|
if got := source.MetersEnabled(); got != nil {
|
|
t.Fatalf("expected default nil meters_enabled, got %v", got)
|
|
}
|
|
if got := source.PlanVersion(); got != "trial" {
|
|
t.Fatalf("expected default plan_version %q, got %q", "trial", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateTrial {
|
|
t.Fatalf("expected default subscription_state %q, got %q", SubStateTrial, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceTrialExpiryMarksExpiredAndStripsCapabilities(t *testing.T) {
|
|
expiredAt := time.Now().Add(-1 * time.Hour).Unix()
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"ai_autofix", "relay"},
|
|
Limits: map[string]int64{"max_monitored_systems": 50},
|
|
MetersEnabled: []string{"active_agents"},
|
|
PlanVersion: "trial",
|
|
SubscriptionState: SubStateTrial,
|
|
TrialStartedAt: ptrInt64(time.Now().Add(-15 * 24 * time.Hour).Unix()),
|
|
TrialEndsAt: &expiredAt,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.SubscriptionState(); got != SubStateExpired {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateExpired, got)
|
|
}
|
|
if got := source.Capabilities(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected capabilities to be stripped on expiry, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected limits to be stripped on expiry, got %v", got)
|
|
}
|
|
if got := source.MetersEnabled(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected meters_enabled to be stripped on expiry, got %v", got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCanceledCloudPlanFailsClosed(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"relay"},
|
|
Limits: map[string]int64{"max_monitored_systems": 999},
|
|
MetersEnabled: []string{"api_requests"},
|
|
PlanVersion: "cloud_starter",
|
|
SubscriptionState: SubStateCanceled,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.SubscriptionState(); got != SubStateCanceled {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateCanceled, got)
|
|
}
|
|
if got := source.Capabilities(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected capabilities to be stripped on cancellation, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected limits to be stripped on cancellation, got %v", got)
|
|
}
|
|
if got := source.MetersEnabled(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected meters_enabled to be stripped on cancellation, got %v", got)
|
|
}
|
|
if got := source.PlanVersion(); got != "cloud_starter" {
|
|
t.Fatalf("expected plan_version %q, got %q", "cloud_starter", got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceLeaseOnlyStateResolvesTrialEntitlement(t *testing.T) {
|
|
pub, priv, err := ed25519.GenerateKey(nil)
|
|
if err != nil {
|
|
t.Fatalf("GenerateKey: %v", err)
|
|
}
|
|
embeddedBefore := EmbeddedPublicKey
|
|
EmbeddedPublicKey = ""
|
|
t.Cleanup(func() { EmbeddedPublicKey = embeddedBefore })
|
|
t.Setenv(TrialActivationPublicKeyEnvVar, base64.StdEncoding.EncodeToString(pub))
|
|
|
|
now := time.Now().UTC()
|
|
trialState := BuildTrialBillingState(now, []string{"ai_autofix"})
|
|
token, err := SignEntitlementLeaseToken(priv, EntitlementLeaseClaims{
|
|
OrgID: "org-1",
|
|
InstanceHost: "pulse.example.com",
|
|
PlanVersion: trialState.PlanVersion,
|
|
SubscriptionState: trialState.SubscriptionState,
|
|
Capabilities: append([]string(nil), trialState.Capabilities...),
|
|
Limits: map[string]int64{"max_monitored_systems": 25},
|
|
TrialStartedAt: trialState.TrialStartedAt,
|
|
TrialEndsAt: trialState.TrialEndsAt,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(*trialState.TrialEndsAt, 0).UTC()),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SignEntitlementLeaseToken: %v", err)
|
|
}
|
|
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
EntitlementJWT: token,
|
|
TrialStartedAt: trialState.TrialStartedAt,
|
|
},
|
|
}
|
|
source := NewDatabaseSource(store, "org-1", time.Hour).WithExpectedInstanceHost("pulse.example.com")
|
|
|
|
if got := source.SubscriptionState(); got != SubStateTrial {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateTrial, got)
|
|
}
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"ai_autofix"}) {
|
|
t.Fatalf("expected capabilities %v, got %v", []string{"ai_autofix"}, got)
|
|
}
|
|
if got := source.Limits(); !reflect.DeepEqual(got, map[string]int64{"max_monitored_systems": 25}) {
|
|
t.Fatalf("expected limits %v, got %v", map[string]int64{"max_monitored_systems": 25}, got)
|
|
}
|
|
if got := source.TrialStartedAt(); got == nil || *got != *trialState.TrialStartedAt {
|
|
t.Fatalf("expected trial_started_at %v, got %v", trialState.TrialStartedAt, got)
|
|
}
|
|
if got := source.TrialEndsAt(); got == nil || *got != *trialState.TrialEndsAt {
|
|
t.Fatalf("expected trial_ends_at %v, got %v", trialState.TrialEndsAt, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceLeaseOnlyStatePreservesMissingPlanVersion(t *testing.T) {
|
|
pub, priv, err := ed25519.GenerateKey(nil)
|
|
if err != nil {
|
|
t.Fatalf("GenerateKey: %v", err)
|
|
}
|
|
embeddedBefore := EmbeddedPublicKey
|
|
EmbeddedPublicKey = ""
|
|
t.Cleanup(func() { EmbeddedPublicKey = embeddedBefore })
|
|
t.Setenv(TrialActivationPublicKeyEnvVar, base64.StdEncoding.EncodeToString(pub))
|
|
|
|
now := time.Now().UTC()
|
|
token, err := SignEntitlementLeaseToken(priv, EntitlementLeaseClaims{
|
|
OrgID: "org-1",
|
|
InstanceHost: "pulse.example.com",
|
|
PlanVersion: " ",
|
|
SubscriptionState: SubStateActive,
|
|
Limits: map[string]int64{"max_monitored_systems": 42},
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SignEntitlementLeaseToken: %v", err)
|
|
}
|
|
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
EntitlementJWT: token,
|
|
},
|
|
}
|
|
source := NewDatabaseSource(store, "org-1", time.Hour).WithExpectedInstanceHost("pulse.example.com")
|
|
|
|
if got := source.PlanVersion(); got != "" {
|
|
t.Fatalf("expected empty plan_version, got %q", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateActive {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateActive, got)
|
|
}
|
|
if got := source.Limits(); !reflect.DeepEqual(got, map[string]int64{"max_monitored_systems": 42}) {
|
|
t.Fatalf("expected limits %v, got %v", map[string]int64{"max_monitored_systems": 42}, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceLeaseHostMismatchFailsClosed(t *testing.T) {
|
|
pub, priv, err := ed25519.GenerateKey(nil)
|
|
if err != nil {
|
|
t.Fatalf("GenerateKey: %v", err)
|
|
}
|
|
embeddedBefore := EmbeddedPublicKey
|
|
EmbeddedPublicKey = ""
|
|
t.Cleanup(func() { EmbeddedPublicKey = embeddedBefore })
|
|
t.Setenv(TrialActivationPublicKeyEnvVar, base64.StdEncoding.EncodeToString(pub))
|
|
|
|
now := time.Now().UTC()
|
|
trialState := BuildTrialBillingState(now, []string{FeatureAIAutoFix})
|
|
token, err := SignEntitlementLeaseToken(priv, EntitlementLeaseClaims{
|
|
OrgID: "org-1",
|
|
InstanceHost: "pulse-a.example.com",
|
|
PlanVersion: trialState.PlanVersion,
|
|
SubscriptionState: trialState.SubscriptionState,
|
|
Capabilities: append([]string(nil), trialState.Capabilities...),
|
|
Limits: map[string]int64{"max_monitored_systems": 25},
|
|
TrialStartedAt: trialState.TrialStartedAt,
|
|
TrialEndsAt: trialState.TrialEndsAt,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(*trialState.TrialEndsAt, 0).UTC()),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("SignEntitlementLeaseToken: %v", err)
|
|
}
|
|
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
EntitlementJWT: token,
|
|
TrialStartedAt: trialState.TrialStartedAt,
|
|
},
|
|
}
|
|
source := NewDatabaseSource(store, "org-1", time.Hour).WithExpectedInstanceHost("pulse-b.example.com")
|
|
|
|
if got := source.SubscriptionState(); got != SubStateExpired {
|
|
t.Fatalf("expected subscription_state %q on host mismatch, got %q", SubStateExpired, got)
|
|
}
|
|
if got := source.Capabilities(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected capabilities to be stripped on host mismatch, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected limits to be stripped on host mismatch, got %v", got)
|
|
}
|
|
if got := source.TrialStartedAt(); got == nil || *got != *trialState.TrialStartedAt {
|
|
t.Fatalf("expected trial_started_at %v to be preserved, got %v", trialState.TrialStartedAt, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceImplementsEntitlementSource(t *testing.T) {
|
|
t.Helper()
|
|
var _ EntitlementSource = (*DatabaseSource)(nil)
|
|
}
|
|
|
|
func ptrInt64(v int64) *int64 {
|
|
return &v
|
|
}
|