mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 08:57:12 +00:00
271 lines
7.2 KiB
Go
271 lines
7.2 KiB
Go
package entitlements
|
|
|
|
import (
|
|
"errors"
|
|
"reflect"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type mockBillingStore struct {
|
|
mu sync.Mutex
|
|
state *BillingState
|
|
err error
|
|
calls int
|
|
}
|
|
|
|
func (m *mockBillingStore) GetBillingState(_ string) (*BillingState, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.calls++
|
|
if m.err != nil {
|
|
return nil, m.err
|
|
}
|
|
if m.state == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
state := cloneBillingStateForTest(*m.state)
|
|
return &state, nil
|
|
}
|
|
|
|
func (m *mockBillingStore) setState(state *BillingState) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.state = state
|
|
}
|
|
|
|
func (m *mockBillingStore) setError(err error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.err = err
|
|
}
|
|
|
|
func (m *mockBillingStore) callCount() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.calls
|
|
}
|
|
|
|
func TestDatabaseSourceHappyPath(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"rbac", "relay"},
|
|
Limits: map[string]int64{"max_monitored_systems": 50},
|
|
MetersEnabled: []string{"active_agents"},
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac", "relay"}) {
|
|
t.Fatalf("expected capabilities %v, got %v", []string{"rbac", "relay"}, got)
|
|
}
|
|
if got := source.Limits(); !reflect.DeepEqual(got, map[string]int64{"max_monitored_systems": 50}) {
|
|
t.Fatalf("expected limits %v, got %v", map[string]int64{"max_monitored_systems": 50}, got)
|
|
}
|
|
if got := source.MetersEnabled(); !reflect.DeepEqual(got, []string{"active_agents"}) {
|
|
t.Fatalf("expected meters %v, got %v", []string{"active_agents"}, got)
|
|
}
|
|
if got := source.PlanVersion(); got != "pro-v2" {
|
|
t.Fatalf("expected plan_version %q, got %q", "pro-v2", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateActive {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateActive, got)
|
|
}
|
|
|
|
if store.callCount() != 1 {
|
|
t.Fatalf("expected store to be called once, got %d", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCacheHit(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
_ = source.PlanVersion()
|
|
_ = source.PlanVersion()
|
|
|
|
if store.callCount() != 1 {
|
|
t.Fatalf("expected cache hit to avoid second store call, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceCacheMissRefresh(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", 5*time.Millisecond)
|
|
|
|
if got := source.PlanVersion(); got != "pro-v1" {
|
|
t.Fatalf("expected initial plan_version %q, got %q", "pro-v1", got)
|
|
}
|
|
|
|
store.setState(&BillingState{
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
})
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
if got := source.PlanVersion(); got != "pro-v2" {
|
|
t.Fatalf("expected refreshed plan_version %q, got %q", "pro-v2", got)
|
|
}
|
|
|
|
if store.callCount() != 2 {
|
|
t.Fatalf("expected store refresh call, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceFailOpenWithStaleCache(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"rbac"},
|
|
PlanVersion: "pro-v1",
|
|
SubscriptionState: SubStateActive,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", 5*time.Millisecond)
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac"}) {
|
|
t.Fatalf("expected initial capabilities %v, got %v", []string{"rbac"}, got)
|
|
}
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
store.setError(errors.New("store unavailable"))
|
|
store.setState(&BillingState{
|
|
Capabilities: []string{"new_capability"},
|
|
PlanVersion: "pro-v2",
|
|
SubscriptionState: SubStateActive,
|
|
})
|
|
|
|
if got := source.Capabilities(); !reflect.DeepEqual(got, []string{"rbac"}) {
|
|
t.Fatalf("expected stale cached capabilities on failure, got %v", got)
|
|
}
|
|
|
|
if store.callCount() != 2 {
|
|
t.Fatalf("expected refresh attempt with failure, got %d calls", store.callCount())
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceFailOpenWithNoCache(t *testing.T) {
|
|
store := &mockBillingStore{
|
|
err: errors.New("store unavailable"),
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.Capabilities(); got != nil {
|
|
t.Fatalf("expected default nil capabilities, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil {
|
|
t.Fatalf("expected default nil limits, got %v", got)
|
|
}
|
|
if got := source.MetersEnabled(); got != nil {
|
|
t.Fatalf("expected default nil meters_enabled, got %v", got)
|
|
}
|
|
if got := source.PlanVersion(); got != "trial" {
|
|
t.Fatalf("expected default plan_version %q, got %q", "trial", got)
|
|
}
|
|
if got := source.SubscriptionState(); got != SubStateTrial {
|
|
t.Fatalf("expected default subscription_state %q, got %q", SubStateTrial, got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceTrialExpiryMarksExpiredAndStripsCapabilities(t *testing.T) {
|
|
expiredAt := time.Now().Add(-1 * time.Hour).Unix()
|
|
store := &mockBillingStore{
|
|
state: &BillingState{
|
|
Capabilities: []string{"ai_autofix", "relay"},
|
|
Limits: map[string]int64{"max_monitored_systems": 50},
|
|
MetersEnabled: []string{"active_agents"},
|
|
PlanVersion: "trial",
|
|
SubscriptionState: SubStateTrial,
|
|
TrialStartedAt: ptrInt64(time.Now().Add(-15 * 24 * time.Hour).Unix()),
|
|
TrialEndsAt: &expiredAt,
|
|
},
|
|
}
|
|
|
|
source := NewDatabaseSource(store, "org-1", time.Hour)
|
|
|
|
if got := source.SubscriptionState(); got != SubStateExpired {
|
|
t.Fatalf("expected subscription_state %q, got %q", SubStateExpired, got)
|
|
}
|
|
if got := source.Capabilities(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected capabilities to be stripped on expiry, got %v", got)
|
|
}
|
|
if got := source.Limits(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected limits to be stripped on expiry, got %v", got)
|
|
}
|
|
if got := source.MetersEnabled(); got != nil && len(got) != 0 {
|
|
t.Fatalf("expected meters_enabled to be stripped on expiry, got %v", got)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseSourceImplementsEntitlementSource(t *testing.T) {
|
|
t.Helper()
|
|
var _ EntitlementSource = (*DatabaseSource)(nil)
|
|
}
|
|
|
|
// ptrInt64 returns a pointer to a freshly allocated copy of v. It is handy for
// filling optional *int64 fields in test fixtures.
func ptrInt64(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
|
|
|
|
func cloneBillingStateForTest(state BillingState) BillingState {
|
|
// Start with a full value copy so new fields are never silently dropped.
|
|
cp := state
|
|
|
|
// Deep-clone reference types to break aliasing.
|
|
cp.Capabilities = cloneStringSliceForTest(state.Capabilities)
|
|
cp.Limits = cloneInt64MapForTest(state.Limits)
|
|
cp.MetersEnabled = cloneStringSliceForTest(state.MetersEnabled)
|
|
cp.TrialStartedAt = cloneInt64PtrForTest(state.TrialStartedAt)
|
|
cp.TrialEndsAt = cloneInt64PtrForTest(state.TrialEndsAt)
|
|
cp.TrialExtendedAt = cloneInt64PtrForTest(state.TrialExtendedAt)
|
|
|
|
return cp
|
|
}
|
|
|
|
// cloneStringSliceForTest returns an independent copy of values. It keeps the
// nil/empty distinction: nil stays nil, and an empty slice stays a non-nil
// empty slice.
func cloneStringSliceForTest(values []string) []string {
	if values != nil {
		return append(make([]string, 0, len(values)), values...)
	}
	return nil
}
|
|
|
|
// cloneInt64MapForTest returns an independent copy of values. A nil map yields
// nil, and an empty map yields a fresh non-nil empty map.
func cloneInt64MapForTest(values map[string]int64) map[string]int64 {
	var out map[string]int64
	if values != nil {
		out = make(map[string]int64, len(values))
		for key, val := range values {
			out[key] = val
		}
	}
	return out
}
|
|
|
|
// cloneInt64PtrForTest returns a pointer to a new copy of *v, or nil when v is nil.
func cloneInt64PtrForTest(v *int64) *int64 {
	var out *int64
	if v != nil {
		copied := *v
		out = &copied
	}
	return out
}
|