mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-05-07 08:57:12 +00:00
1973 lines
64 KiB
Go
1973 lines
64 KiB
Go
package registry
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
pkglicensing "github.com/rcourtman/pulse-go-rewrite/pkg/licensing"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// TenantRegistry is the SQLite-backed store for tenant records and the
// related account, Stripe, user/membership and hosted-entitlement tables.
// Construct it with NewTenantRegistry; the zero value has no database handle.
type TenantRegistry struct {
	// db is the handle opened by NewTenantRegistry.
	db *sql.DB
}
|
|
|
|
// stripeEventProcessingLeaseSeconds is an int64 constant of 120. It is not
// referenced in the visible part of this file; presumably it is the number of
// seconds a stripe_events row stays claimed via processing_started_at before
// another worker may retry it — TODO confirm against its users.
const stripeEventProcessingLeaseSeconds = int64(120)
|
|
|
|
func canonicalizeRegistryPlanVersion(planVersion string) string {
|
|
return pkglicensing.CanonicalizePlanVersion(strings.TrimSpace(planVersion))
|
|
}
|
|
|
|
// NewTenantRegistry opens (or creates) the tenant registry database in dir.
|
|
func NewTenantRegistry(dir string) (*TenantRegistry, error) {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return nil, fmt.Errorf("create registry dir: %w", err)
|
|
}
|
|
|
|
dbPath := filepath.Join(dir, "tenants.db")
|
|
dsn := dbPath + "?" + url.Values{
|
|
"_pragma": []string{
|
|
"busy_timeout(30000)",
|
|
"foreign_keys(ON)",
|
|
"journal_mode(WAL)",
|
|
"synchronous(NORMAL)",
|
|
},
|
|
}.Encode()
|
|
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open tenant registry db: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1)
|
|
db.SetConnMaxLifetime(0)
|
|
|
|
r := &TenantRegistry{db: db}
|
|
if err := r.initSchema(); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
func (r *TenantRegistry) initSchema() error {
|
|
schema := `
|
|
CREATE TABLE IF NOT EXISTS tenants (
|
|
id TEXT PRIMARY KEY,
|
|
account_id TEXT NOT NULL DEFAULT '',
|
|
email TEXT NOT NULL DEFAULT '',
|
|
display_name TEXT NOT NULL DEFAULT '',
|
|
state TEXT NOT NULL DEFAULT 'provisioning',
|
|
stripe_customer_id TEXT NOT NULL DEFAULT '',
|
|
stripe_subscription_id TEXT NOT NULL DEFAULT '',
|
|
stripe_price_id TEXT NOT NULL DEFAULT '',
|
|
plan_version TEXT NOT NULL DEFAULT '',
|
|
container_id TEXT NOT NULL DEFAULT '',
|
|
current_image_digest TEXT NOT NULL DEFAULT '',
|
|
desired_image_digest TEXT NOT NULL DEFAULT '',
|
|
created_at INTEGER NOT NULL,
|
|
updated_at INTEGER NOT NULL,
|
|
last_health_check INTEGER,
|
|
health_check_ok INTEGER NOT NULL DEFAULT 0
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state);
|
|
CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
|
|
CREATE INDEX IF NOT EXISTS idx_tenants_created_at ON tenants(created_at DESC);
|
|
CREATE INDEX IF NOT EXISTS idx_tenants_state_created_at ON tenants(state, created_at DESC);
|
|
CREATE INDEX IF NOT EXISTS idx_tenants_account_id_created_at ON tenants(account_id, created_at DESC);
|
|
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
|
id TEXT PRIMARY KEY,
|
|
kind TEXT NOT NULL DEFAULT 'individual',
|
|
display_name TEXT NOT NULL DEFAULT '',
|
|
created_at INTEGER NOT NULL,
|
|
updated_at INTEGER NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_accounts_created_at ON accounts(created_at DESC);
|
|
|
|
CREATE TABLE IF NOT EXISTS stripe_accounts (
|
|
account_id TEXT PRIMARY KEY,
|
|
stripe_customer_id TEXT NOT NULL UNIQUE,
|
|
stripe_subscription_id TEXT,
|
|
stripe_sub_item_workspaces_id TEXT,
|
|
plan_version TEXT NOT NULL DEFAULT '',
|
|
subscription_state TEXT NOT NULL DEFAULT 'trial',
|
|
grace_started_at INTEGER,
|
|
trial_ends_at INTEGER,
|
|
current_period_end INTEGER,
|
|
updated_at INTEGER NOT NULL,
|
|
FOREIGN KEY (account_id) REFERENCES accounts(id)
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_stripe_accounts_customer ON stripe_accounts(stripe_customer_id);
|
|
|
|
CREATE TABLE IF NOT EXISTS stripe_events (
|
|
stripe_event_id TEXT PRIMARY KEY,
|
|
event_type TEXT NOT NULL,
|
|
received_at INTEGER NOT NULL,
|
|
processing_started_at INTEGER,
|
|
processed_at INTEGER,
|
|
processing_error TEXT
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id TEXT PRIMARY KEY,
|
|
email TEXT NOT NULL UNIQUE,
|
|
created_at INTEGER NOT NULL,
|
|
last_login_at INTEGER,
|
|
session_version INTEGER NOT NULL DEFAULT 1
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS account_memberships (
|
|
account_id TEXT NOT NULL,
|
|
user_id TEXT NOT NULL,
|
|
role TEXT NOT NULL DEFAULT 'tech',
|
|
created_at INTEGER NOT NULL,
|
|
PRIMARY KEY (account_id, user_id),
|
|
FOREIGN KEY (account_id) REFERENCES accounts(id),
|
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_memberships_user_id ON account_memberships(user_id);
|
|
CREATE INDEX IF NOT EXISTS idx_memberships_user_id_created_at ON account_memberships(user_id, created_at DESC);
|
|
CREATE INDEX IF NOT EXISTS idx_memberships_account_id_created_at ON account_memberships(account_id, created_at DESC);
|
|
|
|
CREATE TABLE IF NOT EXISTS hosted_entitlements (
|
|
id TEXT PRIMARY KEY,
|
|
kind TEXT NOT NULL DEFAULT 'paid',
|
|
tenant_id TEXT,
|
|
trial_request_id TEXT,
|
|
org_id TEXT NOT NULL DEFAULT '',
|
|
email TEXT NOT NULL DEFAULT '',
|
|
return_url TEXT NOT NULL DEFAULT '',
|
|
instance_token TEXT NOT NULL DEFAULT '',
|
|
instance_host TEXT NOT NULL DEFAULT '',
|
|
trial_started_at INTEGER,
|
|
refresh_token TEXT NOT NULL UNIQUE,
|
|
activation_token TEXT NOT NULL DEFAULT '',
|
|
issued_at INTEGER NOT NULL,
|
|
activation_issued_at INTEGER,
|
|
last_refreshed_at INTEGER,
|
|
redeemed_at INTEGER,
|
|
revoked_at INTEGER,
|
|
FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE
|
|
);
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_tenant_id ON hosted_entitlements(tenant_id) WHERE tenant_id <> '';
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_trial_request_id ON hosted_entitlements(trial_request_id) WHERE trial_request_id <> '';
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_refresh_token ON hosted_entitlements(refresh_token);
|
|
CREATE INDEX IF NOT EXISTS idx_hosted_entitlements_kind ON hosted_entitlements(kind);
|
|
`
|
|
if _, err := r.db.Exec(schema); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: %w", err)
|
|
}
|
|
|
|
// Migration: add account_id to tenants if not present.
|
|
// (SQLite makes it awkward to add FK constraints via ALTER TABLE, and FK
|
|
// enforcement is off by default; this keeps the change backwards-compatible.)
|
|
hasAccountID, err := r.tenantsHasColumn("account_id")
|
|
if err != nil {
|
|
return fmt.Errorf("check tenants schema for account_id: %w", err)
|
|
}
|
|
if !hasAccountID {
|
|
if _, err := r.db.Exec(`ALTER TABLE tenants ADD COLUMN account_id TEXT NOT NULL DEFAULT ''`); err != nil {
|
|
return fmt.Errorf("migrate tenants: add account_id: %w", err)
|
|
}
|
|
}
|
|
hasEntitlementRefreshToken, err := r.tenantsHasColumn("entitlement_refresh_token")
|
|
if err != nil {
|
|
return fmt.Errorf("check tenants schema for entitlement_refresh_token: %w", err)
|
|
}
|
|
if _, err := r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_tenants_account_id ON tenants(account_id)`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_tenants_account_id: %w", err)
|
|
}
|
|
hasHostedEntitlementID, err := r.tableHasColumn("hosted_entitlements", "id")
|
|
if err != nil {
|
|
return fmt.Errorf("check hosted_entitlements schema for id: %w", err)
|
|
}
|
|
if !hasHostedEntitlementID {
|
|
if err := r.migrateLegacyHostedEntitlementsTable(); err != nil {
|
|
return fmt.Errorf("migrate hosted_entitlements table: %w", err)
|
|
}
|
|
}
|
|
if _, err := r.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_tenant_id ON hosted_entitlements(tenant_id) WHERE tenant_id <> ''`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_hosted_entitlements_tenant_id: %w", err)
|
|
}
|
|
if _, err := r.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_trial_request_id ON hosted_entitlements(trial_request_id) WHERE trial_request_id <> ''`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_hosted_entitlements_trial_request_id: %w", err)
|
|
}
|
|
if _, err := r.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_refresh_token ON hosted_entitlements(refresh_token)`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_hosted_entitlements_refresh_token: %w", err)
|
|
}
|
|
if _, err := r.db.Exec(`CREATE INDEX IF NOT EXISTS idx_hosted_entitlements_kind ON hosted_entitlements(kind)`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_hosted_entitlements_kind: %w", err)
|
|
}
|
|
hasHostedActivationToken, err := r.tableHasColumn("hosted_entitlements", "activation_token")
|
|
if err != nil {
|
|
return fmt.Errorf("check hosted_entitlements schema for activation_token: %w", err)
|
|
}
|
|
if !hasHostedActivationToken {
|
|
if _, err := r.db.Exec(`ALTER TABLE hosted_entitlements ADD COLUMN activation_token TEXT NOT NULL DEFAULT ''`); err != nil {
|
|
return fmt.Errorf("migrate hosted_entitlements: add activation_token: %w", err)
|
|
}
|
|
}
|
|
hasHostedActivationIssuedAt, err := r.tableHasColumn("hosted_entitlements", "activation_issued_at")
|
|
if err != nil {
|
|
return fmt.Errorf("check hosted_entitlements schema for activation_issued_at: %w", err)
|
|
}
|
|
if !hasHostedActivationIssuedAt {
|
|
if _, err := r.db.Exec(`ALTER TABLE hosted_entitlements ADD COLUMN activation_issued_at INTEGER`); err != nil {
|
|
return fmt.Errorf("migrate hosted_entitlements: add activation_issued_at: %w", err)
|
|
}
|
|
}
|
|
if _, err := r.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_activation_token ON hosted_entitlements(activation_token) WHERE activation_token <> ''`); err != nil {
|
|
return fmt.Errorf("init tenant registry schema: create idx_hosted_entitlements_activation_token: %w", err)
|
|
}
|
|
if hasEntitlementRefreshToken {
|
|
if _, err := r.db.Exec(`
|
|
INSERT OR IGNORE INTO hosted_entitlements (
|
|
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
)
|
|
SELECT 'paid:' || id, 'paid', id, NULL, '', '', '', '', '', NULL, entitlement_refresh_token, '', updated_at, NULL, NULL, NULL, NULL
|
|
FROM tenants
|
|
WHERE entitlement_refresh_token <> ''
|
|
`); err != nil {
|
|
return fmt.Errorf("migrate hosted entitlements from tenants: %w", err)
|
|
}
|
|
if err := r.removeLegacyTenantEntitlementRefreshTokenColumn(); err != nil {
|
|
return fmt.Errorf("migrate tenants: drop legacy entitlement_refresh_token: %w", err)
|
|
}
|
|
}
|
|
|
|
// Migration: add session_version to users if missing.
|
|
hasSessionVersion, err := r.tableHasColumn("users", "session_version")
|
|
if err != nil {
|
|
return fmt.Errorf("check users schema for session_version: %w", err)
|
|
}
|
|
if !hasSessionVersion {
|
|
if _, err := r.db.Exec(`ALTER TABLE users ADD COLUMN session_version INTEGER NOT NULL DEFAULT 1`); err != nil {
|
|
return fmt.Errorf("migrate users: add session_version: %w", err)
|
|
}
|
|
}
|
|
if _, err := r.db.Exec(`UPDATE users SET session_version = 1 WHERE session_version IS NULL OR session_version < 1`); err != nil {
|
|
return fmt.Errorf("migrate users: backfill session_version: %w", err)
|
|
}
|
|
|
|
// Migration: add processing_started_at to stripe_events if missing.
|
|
hasStripeProcessingStarted, err := r.tableHasColumn("stripe_events", "processing_started_at")
|
|
if err != nil {
|
|
return fmt.Errorf("check stripe_events schema for processing_started_at: %w", err)
|
|
}
|
|
if !hasStripeProcessingStarted {
|
|
if _, err := r.db.Exec(`ALTER TABLE stripe_events ADD COLUMN processing_started_at INTEGER`); err != nil {
|
|
return fmt.Errorf("migrate stripe_events: add processing_started_at: %w", err)
|
|
}
|
|
}
|
|
|
|
// Migration: add grace_started_at to stripe_accounts if missing.
|
|
hasGraceStartedAt, err := r.tableHasColumn("stripe_accounts", "grace_started_at")
|
|
if err != nil {
|
|
return fmt.Errorf("check stripe_accounts schema for grace_started_at: %w", err)
|
|
}
|
|
if !hasGraceStartedAt {
|
|
if _, err := r.db.Exec(`ALTER TABLE stripe_accounts ADD COLUMN grace_started_at INTEGER`); err != nil {
|
|
return fmt.Errorf("migrate stripe_accounts: add grace_started_at: %w", err)
|
|
}
|
|
}
|
|
// Backfill legacy past_due/grace rows to preserve existing grace windows.
|
|
if _, err := r.db.Exec(`
|
|
UPDATE stripe_accounts
|
|
SET grace_started_at = updated_at
|
|
WHERE grace_started_at IS NULL
|
|
AND subscription_state IN ('past_due', 'grace')
|
|
`); err != nil {
|
|
return fmt.Errorf("migrate stripe_accounts: backfill grace_started_at: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *TenantRegistry) tenantsHasColumn(name string) (bool, error) {
|
|
return r.tableHasColumn("tenants", name)
|
|
}
|
|
|
|
func (r *TenantRegistry) tableHasColumn(tableName, name string) (bool, error) {
|
|
rows, err := r.db.Query(`PRAGMA table_info(` + tableName + `)`)
|
|
if err != nil {
|
|
return false, fmt.Errorf("pragma table_info(%s): %w", tableName, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var (
|
|
cid int
|
|
colName string
|
|
colType string
|
|
notNull int
|
|
dflt sql.NullString
|
|
pk int
|
|
)
|
|
if err := rows.Scan(&cid, &colName, &colType, ¬Null, &dflt, &pk); err != nil {
|
|
return false, fmt.Errorf("scan table_info(%s): %w", tableName, err)
|
|
}
|
|
if colName == name {
|
|
return true, nil
|
|
}
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return false, fmt.Errorf("iterate table_info(%s): %w", tableName, err)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (r *TenantRegistry) removeLegacyTenantEntitlementRefreshTokenColumn() error {
|
|
if _, err := r.db.Exec(`PRAGMA foreign_keys = OFF`); err != nil {
|
|
return fmt.Errorf("disable foreign keys for tenants migration: %w", err)
|
|
}
|
|
defer func() {
|
|
_, _ = r.db.Exec(`PRAGMA foreign_keys = ON`)
|
|
}()
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("begin tenants migration: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
if _, err := tx.Exec(`
|
|
CREATE TABLE tenants_new (
|
|
id TEXT PRIMARY KEY,
|
|
account_id TEXT NOT NULL DEFAULT '',
|
|
email TEXT NOT NULL DEFAULT '',
|
|
display_name TEXT NOT NULL DEFAULT '',
|
|
state TEXT NOT NULL DEFAULT 'provisioning',
|
|
stripe_customer_id TEXT NOT NULL DEFAULT '',
|
|
stripe_subscription_id TEXT NOT NULL DEFAULT '',
|
|
stripe_price_id TEXT NOT NULL DEFAULT '',
|
|
plan_version TEXT NOT NULL DEFAULT '',
|
|
container_id TEXT NOT NULL DEFAULT '',
|
|
current_image_digest TEXT NOT NULL DEFAULT '',
|
|
desired_image_digest TEXT NOT NULL DEFAULT '',
|
|
created_at INTEGER NOT NULL,
|
|
updated_at INTEGER NOT NULL,
|
|
last_health_check INTEGER,
|
|
health_check_ok INTEGER NOT NULL DEFAULT 0
|
|
);
|
|
INSERT INTO tenants_new (
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
)
|
|
SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants;
|
|
DROP TABLE tenants;
|
|
ALTER TABLE tenants_new RENAME TO tenants;
|
|
CREATE INDEX idx_tenants_state ON tenants(state);
|
|
CREATE INDEX idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
|
|
CREATE INDEX idx_tenants_created_at ON tenants(created_at DESC);
|
|
CREATE INDEX idx_tenants_state_created_at ON tenants(state, created_at DESC);
|
|
CREATE INDEX idx_tenants_account_id ON tenants(account_id);
|
|
CREATE INDEX idx_tenants_account_id_created_at ON tenants(account_id, created_at DESC);
|
|
`); err != nil {
|
|
return fmt.Errorf("rebuild tenants table without legacy entitlement column: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit tenants migration: %w", err)
|
|
}
|
|
tx = nil
|
|
return nil
|
|
}
|
|
|
|
func (r *TenantRegistry) migrateLegacyHostedEntitlementsTable() error {
|
|
if _, err := r.db.Exec(`PRAGMA foreign_keys = OFF`); err != nil {
|
|
return fmt.Errorf("disable foreign keys for hosted entitlements migration: %w", err)
|
|
}
|
|
defer func() {
|
|
_, _ = r.db.Exec(`PRAGMA foreign_keys = ON`)
|
|
}()
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("begin hosted entitlements migration: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
if _, err := tx.Exec(`
|
|
CREATE TABLE hosted_entitlements_new (
|
|
id TEXT PRIMARY KEY,
|
|
kind TEXT NOT NULL DEFAULT 'paid',
|
|
tenant_id TEXT,
|
|
trial_request_id TEXT,
|
|
org_id TEXT NOT NULL DEFAULT '',
|
|
email TEXT NOT NULL DEFAULT '',
|
|
return_url TEXT NOT NULL DEFAULT '',
|
|
instance_token TEXT NOT NULL DEFAULT '',
|
|
instance_host TEXT NOT NULL DEFAULT '',
|
|
trial_started_at INTEGER,
|
|
refresh_token TEXT NOT NULL UNIQUE,
|
|
activation_token TEXT NOT NULL DEFAULT '',
|
|
issued_at INTEGER NOT NULL,
|
|
activation_issued_at INTEGER,
|
|
last_refreshed_at INTEGER,
|
|
redeemed_at INTEGER,
|
|
revoked_at INTEGER,
|
|
FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE
|
|
);
|
|
INSERT INTO hosted_entitlements_new (
|
|
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
)
|
|
SELECT
|
|
'paid:' || tenant_id,
|
|
'paid',
|
|
tenant_id,
|
|
NULL,
|
|
'',
|
|
'',
|
|
'',
|
|
'',
|
|
'',
|
|
NULL,
|
|
refresh_token,
|
|
'',
|
|
issued_at,
|
|
NULL,
|
|
last_refreshed_at,
|
|
NULL,
|
|
revoked_at
|
|
FROM hosted_entitlements;
|
|
DROP TABLE hosted_entitlements;
|
|
ALTER TABLE hosted_entitlements_new RENAME TO hosted_entitlements;
|
|
CREATE UNIQUE INDEX idx_hosted_entitlements_tenant_id ON hosted_entitlements(tenant_id) WHERE tenant_id <> '';
|
|
CREATE UNIQUE INDEX idx_hosted_entitlements_trial_request_id ON hosted_entitlements(trial_request_id) WHERE trial_request_id <> '';
|
|
CREATE UNIQUE INDEX idx_hosted_entitlements_refresh_token ON hosted_entitlements(refresh_token);
|
|
CREATE UNIQUE INDEX idx_hosted_entitlements_activation_token ON hosted_entitlements(activation_token) WHERE activation_token <> '';
|
|
CREATE INDEX idx_hosted_entitlements_kind ON hosted_entitlements(kind);
|
|
`); err != nil {
|
|
return fmt.Errorf("rebuild hosted_entitlements table: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit hosted entitlements migration: %w", err)
|
|
}
|
|
tx = nil
|
|
return nil
|
|
}
|
|
|
|
// Ping checks database connectivity (used for readiness probes).
|
|
func (r *TenantRegistry) Ping() error {
|
|
return r.db.Ping()
|
|
}
|
|
|
|
// Close closes the underlying database connection.
|
|
func (r *TenantRegistry) Close() error {
|
|
if r == nil || r.db == nil {
|
|
return nil
|
|
}
|
|
return r.db.Close()
|
|
}
|
|
|
|
// Create inserts a new tenant record.
|
|
func (r *TenantRegistry) Create(t *Tenant) error {
|
|
if t == nil {
|
|
return fmt.Errorf("tenant is nil")
|
|
}
|
|
now := time.Now().UTC()
|
|
if t.CreatedAt.IsZero() {
|
|
t.CreatedAt = now
|
|
}
|
|
t.PlanVersion = canonicalizeRegistryPlanVersion(t.PlanVersion)
|
|
t.UpdatedAt = now
|
|
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO tenants (
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
t.ID, t.AccountID, t.Email, t.DisplayName, string(t.State),
|
|
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
|
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
|
t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("create tenant: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a tenant by ID.
|
|
func (r *TenantRegistry) Get(tenantID string) (*Tenant, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants WHERE id = ?`, tenantID)
|
|
return scanTenant(row)
|
|
}
|
|
|
|
// GetByStripeCustomerID retrieves a tenant by Stripe customer ID.
|
|
func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants WHERE stripe_customer_id = ?`, customerID)
|
|
return scanTenant(row)
|
|
}
|
|
|
|
// Update modifies an existing tenant record.
|
|
func (r *TenantRegistry) Update(t *Tenant) error {
|
|
if t == nil {
|
|
return fmt.Errorf("tenant is nil")
|
|
}
|
|
t.PlanVersion = canonicalizeRegistryPlanVersion(t.PlanVersion)
|
|
t.UpdatedAt = time.Now().UTC()
|
|
|
|
res, err := r.db.Exec(`
|
|
UPDATE tenants SET
|
|
account_id = ?, email = ?, display_name = ?, state = ?,
|
|
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?,
|
|
plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?,
|
|
updated_at = ?, last_health_check = ?, health_check_ok = ?
|
|
WHERE id = ?`,
|
|
t.AccountID, t.Email, t.DisplayName, string(t.State),
|
|
t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
|
|
t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
|
|
t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
|
|
t.ID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update tenant: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("tenant %q not found", t.ID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Delete removes a tenant record by ID.
|
|
func (r *TenantRegistry) Delete(id string) error {
|
|
if _, err := r.db.Exec(`DELETE FROM hosted_entitlements WHERE tenant_id = ?`, id); err != nil {
|
|
return fmt.Errorf("delete hosted entitlement: %w", err)
|
|
}
|
|
res, err := r.db.Exec(`DELETE FROM tenants WHERE id = ?`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete tenant: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("tenant %q not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List returns all tenants.
|
|
func (r *TenantRegistry) List() ([]*Tenant, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list tenants: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanTenants(rows)
|
|
}
|
|
|
|
// ListByState returns all tenants matching the given state.
|
|
func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants WHERE state = ? ORDER BY created_at DESC`, string(state))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list tenants by state: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanTenants(rows)
|
|
}
|
|
|
|
// ListByAccountID returns all tenants belonging to the given account ID.
|
|
func (r *TenantRegistry) ListByAccountID(accountID string) ([]*Tenant, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
id, account_id, email, display_name, state,
|
|
stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
|
plan_version, container_id, current_image_digest, desired_image_digest,
|
|
created_at, updated_at, last_health_check, health_check_ok
|
|
FROM tenants WHERE account_id = ? ORDER BY created_at DESC`, accountID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list tenants by account id: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanTenants(rows)
|
|
}
|
|
|
|
// CountActiveByAccountID returns the number of non-deleted tenants belonging to
|
|
// the given account. States counted: provisioning, active, suspended, failed.
|
|
// States excluded: deleting, deleted, canceled.
|
|
func (r *TenantRegistry) CountActiveByAccountID(accountID string) (int, error) {
|
|
var count int
|
|
err := r.db.QueryRow(`
|
|
SELECT COUNT(*) FROM tenants
|
|
WHERE account_id = ?
|
|
AND state NOT IN ('deleting', 'deleted', 'canceled')`,
|
|
accountID,
|
|
).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("count active tenants for account %q: %w", accountID, err)
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// GetTenantForAccount retrieves a tenant by ID and verifies it belongs to the
|
|
// given account. Returns (nil, nil) if the tenant does not exist or belongs to
|
|
// a different account.
|
|
func (r *TenantRegistry) GetTenantForAccount(accountID, tenantID string) (*Tenant, error) {
|
|
t, err := r.Get(tenantID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if t == nil {
|
|
return nil, nil
|
|
}
|
|
if strings.TrimSpace(t.AccountID) == "" || t.AccountID != accountID {
|
|
return nil, nil
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// CountByState returns a map of state -> count.
|
|
func (r *TenantRegistry) CountByState() (map[TenantState]int, error) {
|
|
rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("count tenants by state: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
counts := make(map[TenantState]int)
|
|
for rows.Next() {
|
|
var state string
|
|
var count int
|
|
if err := rows.Scan(&state, &count); err != nil {
|
|
return nil, fmt.Errorf("scan count: %w", err)
|
|
}
|
|
counts[TenantState(state)] = count
|
|
}
|
|
return counts, rows.Err()
|
|
}
|
|
|
|
// HealthSummary returns the number of healthy and unhealthy active tenants.
|
|
func (r *TenantRegistry) HealthSummary() (healthy, unhealthy int, err error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
COALESCE(SUM(CASE WHEN health_check_ok = 1 THEN 1 ELSE 0 END), 0),
|
|
COALESCE(SUM(CASE WHEN health_check_ok = 0 THEN 1 ELSE 0 END), 0)
|
|
FROM tenants WHERE state = ?`, string(TenantStateActive))
|
|
if err := row.Scan(&healthy, &unhealthy); err != nil {
|
|
return 0, 0, fmt.Errorf("health summary: %w", err)
|
|
}
|
|
return healthy, unhealthy, nil
|
|
}
|
|
|
|
// scanner is the Scan method that *sql.Row and *sql.Rows have in common. It
// lets one decoding function serve both single-row and multi-row queries.
type scanner interface {
	Scan(dst ...any) error
}
|
|
|
|
func scanTenant(s scanner) (*Tenant, error) {
|
|
var t Tenant
|
|
var state string
|
|
var createdAt, updatedAt int64
|
|
var lastHealthCheck sql.NullInt64
|
|
var healthOK int
|
|
|
|
err := s.Scan(
|
|
&t.ID, &t.AccountID, &t.Email, &t.DisplayName, &state,
|
|
&t.StripeCustomerID, &t.StripeSubscriptionID, &t.StripePriceID,
|
|
&t.PlanVersion, &t.ContainerID, &t.CurrentImageDigest, &t.DesiredImageDigest,
|
|
&createdAt, &updatedAt, &lastHealthCheck, &healthOK,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan tenant: %w", err)
|
|
}
|
|
|
|
t.State = TenantState(state)
|
|
t.PlanVersion = canonicalizeRegistryPlanVersion(t.PlanVersion)
|
|
t.CreatedAt = time.Unix(createdAt, 0).UTC()
|
|
t.UpdatedAt = time.Unix(updatedAt, 0).UTC()
|
|
if lastHealthCheck.Valid {
|
|
ts := time.Unix(lastHealthCheck.Int64, 0).UTC()
|
|
t.LastHealthCheck = &ts
|
|
}
|
|
t.HealthCheckOK = healthOK != 0
|
|
return &t, nil
|
|
}
|
|
|
|
// paidHostedEntitlementID builds the hosted_entitlements primary key for a
// paid tenant: "paid:" followed by the whitespace-trimmed tenant ID.
func paidHostedEntitlementID(tenantID string) string {
	const prefix = "paid:"
	return prefix + strings.TrimSpace(tenantID)
}
|
|
|
|
// trialHostedEntitlementID builds the hosted_entitlements primary key for a
// trial request: "trial:" followed by the whitespace-trimmed request ID.
func trialHostedEntitlementID(requestID string) string {
	const prefix = "trial:"
	return prefix + strings.TrimSpace(requestID)
}
|
|
|
|
// StoreOrIssueHostedEntitlement stores a new hosted entitlement refresh token for a paid tenant,
|
|
// or returns the existing active token if one has already been issued.
|
|
func (r *TenantRegistry) StoreOrIssueHostedEntitlement(tenantID, token string, issuedAt time.Time) (string, bool, error) {
|
|
tenantID = strings.TrimSpace(tenantID)
|
|
token = strings.TrimSpace(token)
|
|
issuedAt = issuedAt.UTC()
|
|
if tenantID == "" {
|
|
return "", false, fmt.Errorf("missing tenant id")
|
|
}
|
|
if token == "" {
|
|
return "", false, fmt.Errorf("missing entitlement refresh token")
|
|
}
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("begin entitlement refresh token tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
if err := tx.QueryRow(`SELECT id FROM tenants WHERE id = ?`, tenantID).Scan(new(string)); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return "", false, fmt.Errorf("tenant %q not found", tenantID)
|
|
}
|
|
return "", false, fmt.Errorf("load tenant for hosted entitlement: %w", err)
|
|
}
|
|
|
|
rec, err := loadHostedEntitlement(tx.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements WHERE tenant_id = ?`,
|
|
tenantID,
|
|
))
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("load hosted entitlement: %w", err)
|
|
}
|
|
if rec != nil && strings.TrimSpace(rec.RefreshToken) != "" && rec.RevokedAt == nil {
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit hosted entitlement tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return rec.RefreshToken, false, nil
|
|
}
|
|
|
|
if rec == nil {
|
|
if _, err := tx.Exec(`
|
|
INSERT INTO hosted_entitlements (
|
|
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
) VALUES (?, ?, ?, NULL, '', '', '', '', '', NULL, ?, '', ?, NULL, NULL, NULL, NULL)`,
|
|
paidHostedEntitlementID(tenantID),
|
|
string(HostedEntitlementKindPaid),
|
|
tenantID,
|
|
token,
|
|
issuedAt.Unix(),
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("insert hosted entitlement: %w", err)
|
|
}
|
|
} else {
|
|
if _, err := tx.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET id = ?, kind = ?, tenant_id = ?, trial_request_id = NULL, org_id = '', email = '',
|
|
return_url = '', instance_token = '', instance_host = '', trial_started_at = NULL,
|
|
refresh_token = ?, activation_token = '', issued_at = ?, activation_issued_at = NULL,
|
|
last_refreshed_at = NULL, redeemed_at = NULL, revoked_at = NULL
|
|
WHERE tenant_id = ?`,
|
|
paidHostedEntitlementID(tenantID),
|
|
string(HostedEntitlementKindPaid),
|
|
tenantID,
|
|
token,
|
|
issuedAt.Unix(),
|
|
tenantID,
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("rotate hosted entitlement: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit hosted entitlement tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return token, true, nil
|
|
}
|
|
|
|
func (r *TenantRegistry) StoreOrIssueTrialHostedEntitlement(input TrialHostedEntitlementInput) (string, bool, error) {
|
|
requestID := strings.TrimSpace(input.RequestID)
|
|
token := strings.TrimSpace(input.RefreshToken)
|
|
if requestID == "" {
|
|
return "", false, fmt.Errorf("missing trial request id")
|
|
}
|
|
if token == "" {
|
|
return "", false, fmt.Errorf("missing entitlement refresh token")
|
|
}
|
|
if strings.TrimSpace(input.OrgID) == "" || strings.TrimSpace(input.Email) == "" || strings.TrimSpace(input.ReturnURL) == "" || strings.TrimSpace(input.InstanceHost) == "" {
|
|
return "", false, fmt.Errorf("trial entitlement input is incomplete")
|
|
}
|
|
|
|
issuedAt := input.IssuedAt.UTC()
|
|
redeemedAt := input.RedeemedAt.UTC()
|
|
trialStartedAt := input.TrialStartedAt.UTC()
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("begin trial hosted entitlement tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
rec, err := loadHostedEntitlement(tx.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements WHERE trial_request_id = ?`,
|
|
requestID,
|
|
))
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("load trial hosted entitlement: %w", err)
|
|
}
|
|
if rec != nil && strings.TrimSpace(rec.RefreshToken) != "" && rec.RevokedAt == nil {
|
|
if _, err := tx.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET org_id = ?, email = ?, return_url = ?, instance_token = ?, instance_host = ?, trial_started_at = ?,
|
|
redeemed_at = COALESCE(redeemed_at, ?)
|
|
WHERE id = ?`,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
redeemedAt.Unix(),
|
|
rec.ID,
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("update trial hosted entitlement metadata: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit trial hosted entitlement tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return rec.RefreshToken, false, nil
|
|
}
|
|
|
|
if rec == nil {
|
|
if _, err := tx.Exec(`
|
|
INSERT INTO hosted_entitlements (
|
|
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
) VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, '', ?, NULL, NULL, ?, NULL)`,
|
|
trialHostedEntitlementID(requestID),
|
|
string(HostedEntitlementKindTrial),
|
|
requestID,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
token,
|
|
issuedAt.Unix(),
|
|
redeemedAt.Unix(),
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("insert trial hosted entitlement: %w", err)
|
|
}
|
|
} else {
|
|
if _, err := tx.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET id = ?, kind = ?, tenant_id = NULL, trial_request_id = ?, org_id = ?, email = ?, return_url = ?,
|
|
instance_token = ?, instance_host = ?, trial_started_at = ?, refresh_token = ?, issued_at = ?,
|
|
last_refreshed_at = NULL, redeemed_at = ?, revoked_at = NULL
|
|
WHERE trial_request_id = ?`,
|
|
trialHostedEntitlementID(requestID),
|
|
string(HostedEntitlementKindTrial),
|
|
requestID,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
token,
|
|
issuedAt.Unix(),
|
|
redeemedAt.Unix(),
|
|
requestID,
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("rotate trial hosted entitlement: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit trial hosted entitlement tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return token, true, nil
|
|
}
|
|
|
|
func (r *TenantRegistry) StoreOrRotateTrialActivation(input TrialHostedActivationInput, ttl time.Duration) (string, bool, error) {
|
|
requestID := strings.TrimSpace(input.RequestID)
|
|
activationToken := strings.TrimSpace(input.ActivationToken)
|
|
refreshToken := strings.TrimSpace(input.RefreshToken)
|
|
if requestID == "" {
|
|
return "", false, fmt.Errorf("missing trial request id")
|
|
}
|
|
if activationToken == "" {
|
|
return "", false, fmt.Errorf("missing activation token")
|
|
}
|
|
if refreshToken == "" {
|
|
return "", false, fmt.Errorf("missing entitlement refresh token")
|
|
}
|
|
if strings.TrimSpace(input.OrgID) == "" || strings.TrimSpace(input.Email) == "" || strings.TrimSpace(input.ReturnURL) == "" || strings.TrimSpace(input.InstanceHost) == "" {
|
|
return "", false, fmt.Errorf("trial activation input is incomplete")
|
|
}
|
|
if ttl <= 0 {
|
|
return "", false, fmt.Errorf("activation ttl is required")
|
|
}
|
|
|
|
issuedAt := input.IssuedAt.UTC()
|
|
if issuedAt.IsZero() {
|
|
return "", false, fmt.Errorf("activation issued_at is required")
|
|
}
|
|
trialStartedAt := input.TrialStartedAt.UTC()
|
|
if trialStartedAt.IsZero() {
|
|
trialStartedAt = issuedAt
|
|
}
|
|
rotateBefore := issuedAt.Add(-ttl)
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("begin trial activation tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
rec, err := loadHostedEntitlement(tx.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements WHERE trial_request_id = ?`,
|
|
requestID,
|
|
))
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("load trial activation entitlement: %w", err)
|
|
}
|
|
|
|
if rec != nil && strings.TrimSpace(rec.ActivationToken) != "" && rec.ActivationIssuedAt != nil && rec.ActivationIssuedAt.After(rotateBefore) && rec.RevokedAt == nil {
|
|
if _, err := tx.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET org_id = ?, email = ?, return_url = ?, instance_token = ?, instance_host = ?, trial_started_at = ?
|
|
WHERE id = ?`,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
rec.ID,
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("update trial activation metadata: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit trial activation tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return rec.ActivationToken, false, nil
|
|
}
|
|
|
|
if rec == nil {
|
|
if _, err := tx.Exec(`
|
|
INSERT INTO hosted_entitlements (
|
|
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
) VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL)`,
|
|
trialHostedEntitlementID(requestID),
|
|
string(HostedEntitlementKindTrial),
|
|
requestID,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
refreshToken,
|
|
activationToken,
|
|
issuedAt.Unix(),
|
|
issuedAt.Unix(),
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("insert trial activation entitlement: %w", err)
|
|
}
|
|
} else {
|
|
effectiveRefreshToken := strings.TrimSpace(rec.RefreshToken)
|
|
if effectiveRefreshToken == "" {
|
|
effectiveRefreshToken = refreshToken
|
|
}
|
|
if _, err := tx.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET id = ?, kind = ?, tenant_id = NULL, trial_request_id = ?, org_id = ?, email = ?, return_url = ?,
|
|
instance_token = ?, instance_host = ?, trial_started_at = ?, refresh_token = ?, activation_token = ?,
|
|
issued_at = CASE WHEN issued_at > 0 THEN issued_at ELSE ? END, activation_issued_at = ?,
|
|
last_refreshed_at = NULL, revoked_at = NULL
|
|
WHERE trial_request_id = ?`,
|
|
trialHostedEntitlementID(requestID),
|
|
string(HostedEntitlementKindTrial),
|
|
requestID,
|
|
strings.TrimSpace(input.OrgID),
|
|
strings.TrimSpace(input.Email),
|
|
strings.TrimSpace(input.ReturnURL),
|
|
strings.TrimSpace(input.InstanceToken),
|
|
strings.TrimSpace(input.InstanceHost),
|
|
trialStartedAt.Unix(),
|
|
effectiveRefreshToken,
|
|
activationToken,
|
|
issuedAt.Unix(),
|
|
issuedAt.Unix(),
|
|
requestID,
|
|
); err != nil {
|
|
return "", false, fmt.Errorf("rotate trial activation entitlement: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return "", false, fmt.Errorf("commit trial activation tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return activationToken, true, nil
|
|
}
|
|
|
|
func scanTenants(rows *sql.Rows) ([]*Tenant, error) {
|
|
var tenants []*Tenant
|
|
for rows.Next() {
|
|
t, err := scanTenant(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tenants = append(tenants, t)
|
|
}
|
|
return tenants, rows.Err()
|
|
}
|
|
|
|
// GetHostedEntitlementByRefreshToken retrieves the hosted entitlement record for a refresh token.
|
|
func (r *TenantRegistry) GetHostedEntitlementByRefreshToken(token string) (*HostedEntitlement, error) {
|
|
row := r.db.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements
|
|
WHERE refresh_token = ?`,
|
|
strings.TrimSpace(token),
|
|
)
|
|
return loadHostedEntitlement(row)
|
|
}
|
|
|
|
func (r *TenantRegistry) GetHostedEntitlementByActivationToken(token string) (*HostedEntitlement, error) {
|
|
row := r.db.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements
|
|
WHERE activation_token = ?`,
|
|
strings.TrimSpace(token),
|
|
)
|
|
return loadHostedEntitlement(row)
|
|
}
|
|
|
|
func (r *TenantRegistry) GetHostedEntitlementByTrialRequestID(requestID string) (*HostedEntitlement, error) {
|
|
row := r.db.QueryRow(`
|
|
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
|
|
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
|
|
FROM hosted_entitlements
|
|
WHERE trial_request_id = ?`,
|
|
strings.TrimSpace(requestID),
|
|
)
|
|
return loadHostedEntitlement(row)
|
|
}
|
|
|
|
// MarkHostedEntitlementRefreshed records the last successful hosted entitlement refresh time.
|
|
func (r *TenantRegistry) MarkHostedEntitlementRefreshed(id string, refreshedAt time.Time) error {
|
|
id = strings.TrimSpace(id)
|
|
if id == "" {
|
|
return fmt.Errorf("missing hosted entitlement id")
|
|
}
|
|
res, err := r.db.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET last_refreshed_at = ?
|
|
WHERE id = ? AND revoked_at IS NULL`,
|
|
refreshedAt.UTC().Unix(),
|
|
id,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("mark hosted entitlement refreshed: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("hosted entitlement %q not found", id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RevokeHostedEntitlement revokes a tenant's hosted entitlement refresh authority.
|
|
func (r *TenantRegistry) RevokeHostedEntitlement(tenantID string, revokedAt time.Time) error {
|
|
tenantID = strings.TrimSpace(tenantID)
|
|
if tenantID == "" {
|
|
return fmt.Errorf("missing tenant id")
|
|
}
|
|
_, err := r.db.Exec(`
|
|
UPDATE hosted_entitlements
|
|
SET revoked_at = COALESCE(revoked_at, ?)
|
|
WHERE tenant_id = ? AND kind = ?`,
|
|
revokedAt.UTC().Unix(),
|
|
tenantID,
|
|
string(HostedEntitlementKindPaid),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("revoke hosted entitlement: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func loadHostedEntitlement(s scanner) (*HostedEntitlement, error) {
|
|
var rec HostedEntitlement
|
|
var kind string
|
|
var tenantID sql.NullString
|
|
var trialRequestID sql.NullString
|
|
var issuedAt int64
|
|
var trialStartedAt sql.NullInt64
|
|
var activationIssuedAt sql.NullInt64
|
|
var lastRefreshedAt sql.NullInt64
|
|
var redeemedAt sql.NullInt64
|
|
var revokedAt sql.NullInt64
|
|
if err := s.Scan(
|
|
&rec.ID,
|
|
&kind,
|
|
&tenantID,
|
|
&trialRequestID,
|
|
&rec.OrgID,
|
|
&rec.Email,
|
|
&rec.ReturnURL,
|
|
&rec.InstanceToken,
|
|
&rec.InstanceHost,
|
|
&trialStartedAt,
|
|
&rec.RefreshToken,
|
|
&rec.ActivationToken,
|
|
&issuedAt,
|
|
&activationIssuedAt,
|
|
&lastRefreshedAt,
|
|
&redeemedAt,
|
|
&revokedAt,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan hosted entitlement: %w", err)
|
|
}
|
|
if strings.TrimSpace(kind) == "" {
|
|
kind = string(HostedEntitlementKindPaid)
|
|
}
|
|
rec.Kind = HostedEntitlementKind(kind)
|
|
if tenantID.Valid {
|
|
rec.TenantID = tenantID.String
|
|
}
|
|
if trialRequestID.Valid {
|
|
rec.TrialRequestID = trialRequestID.String
|
|
}
|
|
if trialStartedAt.Valid {
|
|
ts := time.Unix(trialStartedAt.Int64, 0).UTC()
|
|
rec.TrialStartedAt = &ts
|
|
}
|
|
rec.IssuedAt = time.Unix(issuedAt, 0).UTC()
|
|
if activationIssuedAt.Valid {
|
|
ts := time.Unix(activationIssuedAt.Int64, 0).UTC()
|
|
rec.ActivationIssuedAt = &ts
|
|
}
|
|
if lastRefreshedAt.Valid {
|
|
ts := time.Unix(lastRefreshedAt.Int64, 0).UTC()
|
|
rec.LastRefreshedAt = &ts
|
|
}
|
|
if redeemedAt.Valid {
|
|
ts := time.Unix(redeemedAt.Int64, 0).UTC()
|
|
rec.RedeemedAt = &ts
|
|
}
|
|
if revokedAt.Valid {
|
|
ts := time.Unix(revokedAt.Int64, 0).UTC()
|
|
rec.RevokedAt = &ts
|
|
}
|
|
return &rec, nil
|
|
}
|
|
|
|
// CreateAccount inserts a new account record.
|
|
func (r *TenantRegistry) CreateAccount(a *Account) error {
|
|
if a == nil {
|
|
return fmt.Errorf("account is nil")
|
|
}
|
|
now := time.Now().UTC()
|
|
if a.CreatedAt.IsZero() {
|
|
a.CreatedAt = now
|
|
}
|
|
a.UpdatedAt = now
|
|
|
|
kind := string(a.Kind)
|
|
if kind == "" {
|
|
kind = string(AccountKindIndividual)
|
|
}
|
|
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO accounts (
|
|
id, kind, display_name, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?)`,
|
|
a.ID, kind, a.DisplayName, a.CreatedAt.Unix(), a.UpdatedAt.Unix(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("create account: %w", err)
|
|
}
|
|
a.Kind = AccountKind(kind)
|
|
return nil
|
|
}
|
|
|
|
// GetAccount retrieves an account by ID.
|
|
func (r *TenantRegistry) GetAccount(accountID string) (*Account, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
id, kind, display_name, created_at, updated_at
|
|
FROM accounts WHERE id = ?`, accountID)
|
|
return scanAccount(row)
|
|
}
|
|
|
|
// UpdateAccount modifies an existing account record.
|
|
func (r *TenantRegistry) UpdateAccount(a *Account) error {
|
|
if a == nil {
|
|
return fmt.Errorf("account is nil")
|
|
}
|
|
a.UpdatedAt = time.Now().UTC()
|
|
|
|
kind := string(a.Kind)
|
|
if kind == "" {
|
|
kind = string(AccountKindIndividual)
|
|
}
|
|
|
|
res, err := r.db.Exec(`
|
|
UPDATE accounts SET
|
|
kind = ?, display_name = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
kind, a.DisplayName, a.UpdatedAt.Unix(),
|
|
a.ID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update account: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("account %q not found", a.ID)
|
|
}
|
|
a.Kind = AccountKind(kind)
|
|
return nil
|
|
}
|
|
|
|
// ListAccounts returns all accounts.
|
|
func (r *TenantRegistry) ListAccounts() ([]*Account, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
id, kind, display_name, created_at, updated_at
|
|
FROM accounts ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list accounts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanAccounts(rows)
|
|
}
|
|
|
|
// CreateUser inserts a new user record.
|
|
func (r *TenantRegistry) CreateUser(u *User) error {
|
|
if u == nil {
|
|
return fmt.Errorf("user is nil")
|
|
}
|
|
now := time.Now().UTC()
|
|
if u.CreatedAt.IsZero() {
|
|
u.CreatedAt = now
|
|
}
|
|
if u.SessionVersion < 1 {
|
|
u.SessionVersion = 1
|
|
}
|
|
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO users (
|
|
id, email, created_at, last_login_at, session_version
|
|
) VALUES (?, ?, ?, ?, ?)`,
|
|
u.ID, u.Email, u.CreatedAt.Unix(), nullableTimeUnix(u.LastLoginAt), u.SessionVersion,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("create user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetUser retrieves a user by ID.
|
|
func (r *TenantRegistry) GetUser(userID string) (*User, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
id, email, created_at, last_login_at, session_version
|
|
FROM users WHERE id = ?`, userID)
|
|
return scanUser(row)
|
|
}
|
|
|
|
// GetUserByEmail retrieves a user by email.
|
|
func (r *TenantRegistry) GetUserByEmail(email string) (*User, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
id, email, created_at, last_login_at, session_version
|
|
FROM users WHERE email = ?`, email)
|
|
return scanUser(row)
|
|
}
|
|
|
|
// UpdateUserLastLogin sets last_login_at for the given user ID to the current time.
|
|
func (r *TenantRegistry) UpdateUserLastLogin(userID string) error {
|
|
now := time.Now().UTC()
|
|
res, err := r.db.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, now.Unix(), userID)
|
|
if err != nil {
|
|
return fmt.Errorf("update user last login: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("user %q not found", userID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetUserSessionVersion returns the current session version for the user.
|
|
func (r *TenantRegistry) GetUserSessionVersion(userID string) (int64, error) {
|
|
userID = strings.TrimSpace(userID)
|
|
if userID == "" {
|
|
return 0, fmt.Errorf("missing user id")
|
|
}
|
|
|
|
var version int64
|
|
row := r.db.QueryRow(`SELECT session_version FROM users WHERE id = ?`, userID)
|
|
if err := row.Scan(&version); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return 0, fmt.Errorf("user %q not found", userID)
|
|
}
|
|
return 0, fmt.Errorf("get user session version: %w", err)
|
|
}
|
|
if version < 1 {
|
|
version = 1
|
|
}
|
|
return version, nil
|
|
}
|
|
|
|
// RevokeUserSessions increments the user's session version, invalidating all
|
|
// previously issued sessions, and returns the new version.
|
|
func (r *TenantRegistry) RevokeUserSessions(userID string) (int64, error) {
|
|
userID = strings.TrimSpace(userID)
|
|
if userID == "" {
|
|
return 0, fmt.Errorf("missing user id")
|
|
}
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("begin revoke sessions tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
res, err := tx.Exec(`UPDATE users SET session_version = session_version + 1 WHERE id = ?`, userID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("increment session version: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return 0, fmt.Errorf("user %q not found", userID)
|
|
}
|
|
|
|
var newVersion int64
|
|
if err := tx.QueryRow(`SELECT session_version FROM users WHERE id = ?`, userID).Scan(&newVersion); err != nil {
|
|
return 0, fmt.Errorf("read incremented session version: %w", err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return 0, fmt.Errorf("commit revoke sessions tx: %w", err)
|
|
}
|
|
tx = nil
|
|
|
|
if newVersion < 1 {
|
|
newVersion = 1
|
|
}
|
|
return newVersion, nil
|
|
}
|
|
|
|
// CreateMembership inserts a new membership record.
|
|
func (r *TenantRegistry) CreateMembership(m *AccountMembership) error {
|
|
if m == nil {
|
|
return fmt.Errorf("membership is nil")
|
|
}
|
|
now := time.Now().UTC()
|
|
if m.CreatedAt.IsZero() {
|
|
m.CreatedAt = now
|
|
}
|
|
role := string(m.Role)
|
|
if role == "" {
|
|
role = string(MemberRoleTech)
|
|
}
|
|
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO account_memberships (
|
|
account_id, user_id, role, created_at
|
|
) VALUES (?, ?, ?, ?)`,
|
|
m.AccountID, m.UserID, role, m.CreatedAt.Unix(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("create membership: %w", err)
|
|
}
|
|
m.Role = MemberRole(role)
|
|
return nil
|
|
}
|
|
|
|
// GetMembership retrieves a membership record by account ID and user ID.
|
|
func (r *TenantRegistry) GetMembership(accountID, userID string) (*AccountMembership, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
account_id, user_id, role, created_at
|
|
FROM account_memberships
|
|
WHERE account_id = ? AND user_id = ?`, accountID, userID)
|
|
return scanMembership(row)
|
|
}
|
|
|
|
// ListMembersByAccount returns all membership records for a given account ID.
|
|
func (r *TenantRegistry) ListMembersByAccount(accountID string) ([]*AccountMembership, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
account_id, user_id, role, created_at
|
|
FROM account_memberships
|
|
WHERE account_id = ?
|
|
ORDER BY created_at DESC`, accountID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list members by account: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanMemberships(rows)
|
|
}
|
|
|
|
// ListAccountsByUser returns account IDs for all accounts the given user belongs to.
|
|
func (r *TenantRegistry) ListAccountsByUser(userID string) ([]string, error) {
|
|
rows, err := r.db.Query(`SELECT account_id FROM account_memberships WHERE user_id = ? ORDER BY created_at DESC`, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list accounts by user: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var accountIDs []string
|
|
for rows.Next() {
|
|
var accountID string
|
|
if err := rows.Scan(&accountID); err != nil {
|
|
return nil, fmt.Errorf("scan account id: %w", err)
|
|
}
|
|
accountIDs = append(accountIDs, accountID)
|
|
}
|
|
return accountIDs, rows.Err()
|
|
}
|
|
|
|
// UpdateMembershipRole updates a membership role.
|
|
func (r *TenantRegistry) UpdateMembershipRole(accountID, userID string, role MemberRole) error {
|
|
res, err := r.db.Exec(`UPDATE account_memberships SET role = ? WHERE account_id = ? AND user_id = ?`, string(role), accountID, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("update membership role: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteMembership deletes a membership record.
|
|
func (r *TenantRegistry) DeleteMembership(accountID, userID string) error {
|
|
res, err := r.db.Exec(`DELETE FROM account_memberships WHERE account_id = ? AND user_id = ?`, accountID, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete membership: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateStripeAccount inserts a new StripeAccount mapping row.
|
|
func (r *TenantRegistry) CreateStripeAccount(sa *StripeAccount) error {
|
|
if sa == nil {
|
|
return fmt.Errorf("stripe account is nil")
|
|
}
|
|
sa.AccountID = strings.TrimSpace(sa.AccountID)
|
|
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
|
|
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
|
|
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
|
|
sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)
|
|
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
|
|
|
|
if sa.AccountID == "" {
|
|
return fmt.Errorf("missing account id")
|
|
}
|
|
if sa.StripeCustomerID == "" {
|
|
return fmt.Errorf("missing stripe customer id")
|
|
}
|
|
if sa.SubscriptionState == "" {
|
|
sa.SubscriptionState = "trial"
|
|
}
|
|
if sa.UpdatedAt == 0 {
|
|
sa.UpdatedAt = time.Now().UTC().Unix()
|
|
}
|
|
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO stripe_accounts (
|
|
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
|
plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
sa.AccountID,
|
|
sa.StripeCustomerID,
|
|
nullableString(sa.StripeSubscriptionID),
|
|
nullableString(sa.StripeSubItemWorkspacesID),
|
|
sa.PlanVersion,
|
|
sa.SubscriptionState,
|
|
nullableInt64Ptr(sa.GraceStartedAt),
|
|
nullableInt64Ptr(sa.TrialEndsAt),
|
|
nullableInt64Ptr(sa.CurrentPeriodEnd),
|
|
sa.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("create stripe account: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetStripeAccount retrieves the StripeAccount row by account ID.
|
|
func (r *TenantRegistry) GetStripeAccount(accountID string) (*StripeAccount, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
|
plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
|
|
FROM stripe_accounts WHERE account_id = ?`, strings.TrimSpace(accountID))
|
|
return scanStripeAccount(row)
|
|
}
|
|
|
|
// GetStripeAccountByCustomerID retrieves the StripeAccount row by Stripe customer ID.
|
|
func (r *TenantRegistry) GetStripeAccountByCustomerID(customerID string) (*StripeAccount, error) {
|
|
row := r.db.QueryRow(`SELECT
|
|
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
|
plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
|
|
FROM stripe_accounts WHERE stripe_customer_id = ?`, strings.TrimSpace(customerID))
|
|
return scanStripeAccount(row)
|
|
}
|
|
|
|
// ListStripeAccounts returns all Stripe account mappings.
|
|
func (r *TenantRegistry) ListStripeAccounts() ([]*StripeAccount, error) {
|
|
rows, err := r.db.Query(`SELECT
|
|
account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
|
|
plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
|
|
FROM stripe_accounts
|
|
ORDER BY updated_at DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list stripe accounts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []*StripeAccount
|
|
for rows.Next() {
|
|
sa, scanErr := scanStripeAccount(rows)
|
|
if scanErr != nil {
|
|
return nil, scanErr
|
|
}
|
|
if sa != nil {
|
|
out = append(out, sa)
|
|
}
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate stripe accounts: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// UpdateStripeAccount modifies an existing StripeAccount row.
|
|
func (r *TenantRegistry) UpdateStripeAccount(sa *StripeAccount) error {
|
|
if sa == nil {
|
|
return fmt.Errorf("stripe account is nil")
|
|
}
|
|
sa.AccountID = strings.TrimSpace(sa.AccountID)
|
|
sa.StripeCustomerID = strings.TrimSpace(sa.StripeCustomerID)
|
|
sa.StripeSubscriptionID = strings.TrimSpace(sa.StripeSubscriptionID)
|
|
sa.StripeSubItemWorkspacesID = strings.TrimSpace(sa.StripeSubItemWorkspacesID)
|
|
sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)
|
|
sa.SubscriptionState = strings.TrimSpace(sa.SubscriptionState)
|
|
|
|
if sa.AccountID == "" {
|
|
return fmt.Errorf("missing account id")
|
|
}
|
|
if sa.StripeCustomerID == "" {
|
|
return fmt.Errorf("missing stripe customer id")
|
|
}
|
|
if sa.SubscriptionState == "" {
|
|
sa.SubscriptionState = "trial"
|
|
}
|
|
|
|
sa.UpdatedAt = time.Now().UTC().Unix()
|
|
|
|
res, err := r.db.Exec(`
|
|
UPDATE stripe_accounts SET
|
|
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_sub_item_workspaces_id = ?,
|
|
plan_version = ?, subscription_state = ?, grace_started_at = ?, trial_ends_at = ?, current_period_end = ?, updated_at = ?
|
|
WHERE account_id = ?`,
|
|
sa.StripeCustomerID,
|
|
nullableString(sa.StripeSubscriptionID),
|
|
nullableString(sa.StripeSubItemWorkspacesID),
|
|
sa.PlanVersion,
|
|
sa.SubscriptionState,
|
|
nullableInt64Ptr(sa.GraceStartedAt),
|
|
nullableInt64Ptr(sa.TrialEndsAt),
|
|
nullableInt64Ptr(sa.CurrentPeriodEnd),
|
|
sa.UpdatedAt,
|
|
sa.AccountID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update stripe account: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("stripe account %q not found", sa.AccountID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RecordStripeEvent inserts a webhook event record and returns true if the
|
|
// event was already recorded (duplicate Stripe delivery).
|
|
func (r *TenantRegistry) RecordStripeEvent(eventID, eventType string) (alreadyProcessed bool, err error) {
|
|
eventID = strings.TrimSpace(eventID)
|
|
eventType = strings.TrimSpace(eventType)
|
|
if eventID == "" {
|
|
return false, fmt.Errorf("missing stripe event id")
|
|
}
|
|
if eventType == "" {
|
|
return false, fmt.Errorf("missing stripe event type")
|
|
}
|
|
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return false, fmt.Errorf("begin record stripe event tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
now := time.Now().UTC().Unix()
|
|
|
|
// INSERT OR IGNORE avoids driver-specific error parsing for duplicates.
|
|
res, err := tx.Exec(`
|
|
INSERT OR IGNORE INTO stripe_events (
|
|
stripe_event_id, event_type, received_at, processing_started_at, processed_at, processing_error
|
|
) VALUES (?, ?, ?, ?, NULL, NULL)`,
|
|
eventID, eventType, now, now,
|
|
)
|
|
if err != nil {
|
|
return false, fmt.Errorf("record stripe event: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return false, fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 1 {
|
|
if err := tx.Commit(); err != nil {
|
|
return false, fmt.Errorf("commit record stripe event tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return false, nil
|
|
}
|
|
|
|
var processedAt sql.NullInt64
|
|
var processingStartedAt sql.NullInt64
|
|
var processingError sql.NullString
|
|
if err := tx.QueryRow(`
|
|
SELECT processed_at, processing_started_at, processing_error
|
|
FROM stripe_events WHERE stripe_event_id = ?`,
|
|
eventID,
|
|
).Scan(&processedAt, &processingStartedAt, &processingError); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, fmt.Errorf("stripe event %q not found after insert-or-ignore", eventID)
|
|
}
|
|
return false, fmt.Errorf("query stripe event status: %w", err)
|
|
}
|
|
|
|
// Exact duplicates that were already processed successfully are skipped.
|
|
if processedAt.Valid && strings.TrimSpace(processingError.String) == "" {
|
|
if err := tx.Commit(); err != nil {
|
|
return false, fmt.Errorf("commit record stripe event tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return true, nil
|
|
}
|
|
|
|
// If another request is currently processing this event and that processing
|
|
// window is still fresh, skip duplicate execution.
|
|
if processingStartedAt.Valid && !processedAt.Valid && strings.TrimSpace(processingError.String) == "" {
|
|
if now-processingStartedAt.Int64 <= stripeEventProcessingLeaseSeconds {
|
|
if err := tx.Commit(); err != nil {
|
|
return false, fmt.Errorf("commit record stripe event tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
// Reclaim failed/stale event deliveries for retry processing.
|
|
_, err = tx.Exec(`
|
|
UPDATE stripe_events SET
|
|
event_type = ?,
|
|
received_at = ?,
|
|
processing_started_at = ?,
|
|
processed_at = NULL,
|
|
processing_error = NULL
|
|
WHERE stripe_event_id = ?`,
|
|
eventType, now, now, eventID,
|
|
)
|
|
if err != nil {
|
|
return false, fmt.Errorf("reclaim stripe event for retry: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return false, fmt.Errorf("commit record stripe event tx: %w", err)
|
|
}
|
|
tx = nil
|
|
return false, nil
|
|
}
|
|
|
|
// MarkStripeEventProcessed marks a previously recorded event as processed.
|
|
// processingError is stored (nullable) for troubleshooting.
|
|
func (r *TenantRegistry) MarkStripeEventProcessed(eventID string, processingError string) error {
|
|
eventID = strings.TrimSpace(eventID)
|
|
if eventID == "" {
|
|
return fmt.Errorf("missing stripe event id")
|
|
}
|
|
processingError = strings.TrimSpace(processingError)
|
|
|
|
res, err := r.db.Exec(`
|
|
UPDATE stripe_events SET
|
|
processing_started_at = NULL,
|
|
processed_at = ?, processing_error = ?
|
|
WHERE stripe_event_id = ?`,
|
|
time.Now().UTC().Unix(),
|
|
nullableString(processingError),
|
|
eventID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("mark stripe event processed: %w", err)
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("get rows affected: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("stripe event %q not found", eventID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanAccount(s scanner) (*Account, error) {
|
|
var a Account
|
|
var kind string
|
|
var createdAt, updatedAt int64
|
|
if err := s.Scan(&a.ID, &kind, &a.DisplayName, &createdAt, &updatedAt); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan account: %w", err)
|
|
}
|
|
a.Kind = AccountKind(kind)
|
|
a.CreatedAt = time.Unix(createdAt, 0).UTC()
|
|
a.UpdatedAt = time.Unix(updatedAt, 0).UTC()
|
|
return &a, nil
|
|
}
|
|
|
|
func scanAccounts(rows *sql.Rows) ([]*Account, error) {
|
|
var accounts []*Account
|
|
for rows.Next() {
|
|
a, err := scanAccount(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
accounts = append(accounts, a)
|
|
}
|
|
return accounts, rows.Err()
|
|
}
|
|
|
|
func scanUser(s scanner) (*User, error) {
|
|
var u User
|
|
var createdAt int64
|
|
var lastLogin sql.NullInt64
|
|
if err := s.Scan(&u.ID, &u.Email, &createdAt, &lastLogin, &u.SessionVersion); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
|
}
|
|
u.CreatedAt = time.Unix(createdAt, 0).UTC()
|
|
if lastLogin.Valid {
|
|
ts := time.Unix(lastLogin.Int64, 0).UTC()
|
|
u.LastLoginAt = &ts
|
|
}
|
|
if u.SessionVersion < 1 {
|
|
u.SessionVersion = 1
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
func scanMembership(s scanner) (*AccountMembership, error) {
|
|
var m AccountMembership
|
|
var role string
|
|
var createdAt int64
|
|
if err := s.Scan(&m.AccountID, &m.UserID, &role, &createdAt); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan membership: %w", err)
|
|
}
|
|
m.Role = MemberRole(role)
|
|
m.CreatedAt = time.Unix(createdAt, 0).UTC()
|
|
return &m, nil
|
|
}
|
|
|
|
func scanMemberships(rows *sql.Rows) ([]*AccountMembership, error) {
|
|
var memberships []*AccountMembership
|
|
for rows.Next() {
|
|
m, err := scanMembership(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
memberships = append(memberships, m)
|
|
}
|
|
return memberships, rows.Err()
|
|
}
|
|
|
|
// nullableTimeUnix converts an optional timestamp into a SQL parameter value.
// A nil t becomes nil (SQL NULL). Otherwise the result is t's Unix time in
// whole seconds, as an int64.
func nullableTimeUnix(t *time.Time) driver.Value {
	if t != nil {
		return t.Unix()
	}
	return nil
}
|
|
|
|
// nullableInt64Ptr converts an optional int64 into a SQL parameter value.
// A nil v becomes nil (SQL NULL). Otherwise the result is the pointed-to
// value; zero is passed through as int64(0), not NULL.
func nullableInt64Ptr(v *int64) driver.Value {
	if v != nil {
		return *v
	}
	return nil
}
|
|
|
|
// nullableString converts s into a SQL parameter value.
// Surrounding whitespace is removed. If nothing remains, the result is nil
// (SQL NULL); otherwise it is the trimmed string.
func nullableString(s string) driver.Value {
	// Trim once and reuse the result instead of trimming twice.
	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return nil
	}
	return trimmed
}
|
|
|
|
// boolToInt encodes b as an integer: 1 for true, 0 for false.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
|
|
|
|
func scanStripeAccount(s scanner) (*StripeAccount, error) {
|
|
var sa StripeAccount
|
|
var subID, subItemID sql.NullString
|
|
var graceStartedAt, trialEnds, periodEnd sql.NullInt64
|
|
if err := s.Scan(
|
|
&sa.AccountID,
|
|
&sa.StripeCustomerID,
|
|
&subID,
|
|
&subItemID,
|
|
&sa.PlanVersion,
|
|
&sa.SubscriptionState,
|
|
&graceStartedAt,
|
|
&trialEnds,
|
|
&periodEnd,
|
|
&sa.UpdatedAt,
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan stripe account: %w", err)
|
|
}
|
|
if subID.Valid {
|
|
sa.StripeSubscriptionID = subID.String
|
|
}
|
|
if subItemID.Valid {
|
|
sa.StripeSubItemWorkspacesID = subItemID.String
|
|
}
|
|
if graceStartedAt.Valid {
|
|
v := graceStartedAt.Int64
|
|
sa.GraceStartedAt = &v
|
|
}
|
|
if trialEnds.Valid {
|
|
v := trialEnds.Int64
|
|
sa.TrialEndsAt = &v
|
|
}
|
|
if periodEnd.Valid {
|
|
v := periodEnd.Int64
|
|
sa.CurrentPeriodEnd = &v
|
|
}
|
|
sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)
|
|
return &sa, nil
|
|
}
|