Pulse/internal/cloudcp/registry/registry.go
2026-03-18 16:06:30 +00:00

1973 lines
64 KiB
Go

package registry
import (
"database/sql"
"database/sql/driver"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
pkglicensing "github.com/rcourtman/pulse-go-rewrite/pkg/licensing"
_ "modernc.org/sqlite"
)
// TenantRegistry is the SQLite-backed store for tenant records and the
// related account, Stripe, user/membership and hosted-entitlement tables.
// Construct it with NewTenantRegistry; the zero value has no database handle.
type TenantRegistry struct {
	db *sql.DB // pooled handle opened by NewTenantRegistry (limited to one connection there)
}
const stripeEventProcessingLeaseSeconds int64 = 120
// canonicalizeRegistryPlanVersion strips surrounding whitespace from a plan
// version and then normalizes it with the shared licensing helper, so every
// plan_version written to or read from the registry has one spelling.
func canonicalizeRegistryPlanVersion(planVersion string) string {
	trimmed := strings.TrimSpace(planVersion)
	return pkglicensing.CanonicalizePlanVersion(trimmed)
}
// NewTenantRegistry opens (or creates) the tenant registry database in dir.
//
// The directory is created with mode 0755 if needed, and the database lives at
// dir/tenants.db. Each connection gets busy_timeout(30000), foreign_keys(ON),
// journal_mode(WAL) and synchronous(NORMAL) through the DSN. The schema and
// migrations are applied before returning. On failure the handle is closed and
// the error is returned.
func NewTenantRegistry(dir string) (*TenantRegistry, error) {
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return nil, fmt.Errorf("create registry dir: %w", mkErr)
	}

	pragmas := []string{
		"busy_timeout(30000)",
		"foreign_keys(ON)",
		"journal_mode(WAL)",
		"synchronous(NORMAL)",
	}
	params := url.Values{"_pragma": pragmas}
	dsn := filepath.Join(dir, "tenants.db") + "?" + params.Encode()

	db, openErr := sql.Open("sqlite", dsn)
	if openErr != nil {
		return nil, fmt.Errorf("open tenant registry db: %w", openErr)
	}

	// Keep exactly one long-lived connection. The migrations issue
	// connection-scoped PRAGMAs (e.g. foreign_keys) through r.db and then begin
	// a transaction. That only works if both run on the same connection.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	db.SetConnMaxLifetime(0)

	registry := &TenantRegistry{db: db}
	if schemaErr := registry.initSchema(); schemaErr != nil {
		_ = db.Close()
		return nil, schemaErr
	}
	return registry, nil
}
// initSchema creates the registry tables (when absent) and applies in-place
// migrations for databases written by older builds.
//
// Ordering matters. CREATE TABLE IF NOT EXISTS is a no-op for a table that
// already exists in an older shape. So any index over a column that a later
// migration adds is created only after that migration has run. This covers
// tenants.account_id and the hosted_entitlements columns introduced by the
// table rebuild. Creating those indexes in the initial schema statement fails
// with "no such column" on a legacy database, aborting the upgrade before any
// migration runs.
func (r *TenantRegistry) initSchema() error {
	schema := `
CREATE TABLE IF NOT EXISTS tenants (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'provisioning',
stripe_customer_id TEXT NOT NULL DEFAULT '',
stripe_subscription_id TEXT NOT NULL DEFAULT '',
stripe_price_id TEXT NOT NULL DEFAULT '',
plan_version TEXT NOT NULL DEFAULT '',
container_id TEXT NOT NULL DEFAULT '',
current_image_digest TEXT NOT NULL DEFAULT '',
desired_image_digest TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
last_health_check INTEGER,
health_check_ok INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_tenants_state ON tenants(state);
CREATE INDEX IF NOT EXISTS idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
CREATE INDEX IF NOT EXISTS idx_tenants_created_at ON tenants(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_tenants_state_created_at ON tenants(state, created_at DESC);
CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
kind TEXT NOT NULL DEFAULT 'individual',
display_name TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_accounts_created_at ON accounts(created_at DESC);
CREATE TABLE IF NOT EXISTS stripe_accounts (
account_id TEXT PRIMARY KEY,
stripe_customer_id TEXT NOT NULL UNIQUE,
stripe_subscription_id TEXT,
stripe_sub_item_workspaces_id TEXT,
plan_version TEXT NOT NULL DEFAULT '',
subscription_state TEXT NOT NULL DEFAULT 'trial',
grace_started_at INTEGER,
trial_ends_at INTEGER,
current_period_end INTEGER,
updated_at INTEGER NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id)
);
CREATE INDEX IF NOT EXISTS idx_stripe_accounts_customer ON stripe_accounts(stripe_customer_id);
CREATE TABLE IF NOT EXISTS stripe_events (
stripe_event_id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
received_at INTEGER NOT NULL,
processing_started_at INTEGER,
processed_at INTEGER,
processing_error TEXT
);
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
created_at INTEGER NOT NULL,
last_login_at INTEGER,
session_version INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS account_memberships (
account_id TEXT NOT NULL,
user_id TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'tech',
created_at INTEGER NOT NULL,
PRIMARY KEY (account_id, user_id),
FOREIGN KEY (account_id) REFERENCES accounts(id),
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_memberships_user_id ON account_memberships(user_id);
CREATE INDEX IF NOT EXISTS idx_memberships_user_id_created_at ON account_memberships(user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_memberships_account_id_created_at ON account_memberships(account_id, created_at DESC);
CREATE TABLE IF NOT EXISTS hosted_entitlements (
id TEXT PRIMARY KEY,
kind TEXT NOT NULL DEFAULT 'paid',
tenant_id TEXT,
trial_request_id TEXT,
org_id TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
return_url TEXT NOT NULL DEFAULT '',
instance_token TEXT NOT NULL DEFAULT '',
instance_host TEXT NOT NULL DEFAULT '',
trial_started_at INTEGER,
refresh_token TEXT NOT NULL UNIQUE,
activation_token TEXT NOT NULL DEFAULT '',
issued_at INTEGER NOT NULL,
activation_issued_at INTEGER,
last_refreshed_at INTEGER,
redeemed_at INTEGER,
revoked_at INTEGER,
FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE
);
`
	if _, err := r.db.Exec(schema); err != nil {
		return fmt.Errorf("init tenant registry schema: %w", err)
	}

	// ensureColumn runs ddl (an ALTER TABLE ... ADD COLUMN) when table does not
	// yet have column.
	ensureColumn := func(table, column, ddl string) error {
		has, err := r.tableHasColumn(table, column)
		if err != nil {
			return fmt.Errorf("check %s schema for %s: %w", table, column, err)
		}
		if has {
			return nil
		}
		if _, err := r.db.Exec(ddl); err != nil {
			return fmt.Errorf("migrate %s: add %s: %w", table, column, err)
		}
		return nil
	}
	// createIndex runs an idempotent CREATE INDEX statement, naming the index in
	// any error.
	createIndex := func(name, ddl string) error {
		if _, err := r.db.Exec(ddl); err != nil {
			return fmt.Errorf("init tenant registry schema: create %s: %w", name, err)
		}
		return nil
	}

	// Migration: add account_id to tenants if not present. SQLite cannot add a
	// FOREIGN KEY through ALTER TABLE, and tenants.account_id carries no FK in
	// the fresh schema either, so it is added as a plain column.
	if err := ensureColumn("tenants", "account_id", `ALTER TABLE tenants ADD COLUMN account_id TEXT NOT NULL DEFAULT ''`); err != nil {
		return err
	}
	// Recorded now, acted on after hosted_entitlements has its current shape.
	hasEntitlementRefreshToken, err := r.tableHasColumn("tenants", "entitlement_refresh_token")
	if err != nil {
		return fmt.Errorf("check tenants schema for entitlement_refresh_token: %w", err)
	}
	// account_id is guaranteed to exist only from here on.
	if err := createIndex("idx_tenants_account_id", `CREATE INDEX IF NOT EXISTS idx_tenants_account_id ON tenants(account_id)`); err != nil {
		return err
	}
	if err := createIndex("idx_tenants_account_id_created_at", `CREATE INDEX IF NOT EXISTS idx_tenants_account_id_created_at ON tenants(account_id, created_at DESC)`); err != nil {
		return err
	}

	// A hosted_entitlements table without an id column predates the current
	// shape and is rebuilt.
	hasHostedEntitlementID, err := r.tableHasColumn("hosted_entitlements", "id")
	if err != nil {
		return fmt.Errorf("check hosted_entitlements schema for id: %w", err)
	}
	if !hasHostedEntitlementID {
		if err := r.migrateLegacyHostedEntitlementsTable(); err != nil {
			return fmt.Errorf("migrate hosted_entitlements table: %w", err)
		}
	}
	// kind and trial_request_id exist only after the rebuild above.
	hostedIndexes := []struct{ name, ddl string }{
		{"idx_hosted_entitlements_tenant_id", `CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_tenant_id ON hosted_entitlements(tenant_id) WHERE tenant_id <> ''`},
		{"idx_hosted_entitlements_trial_request_id", `CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_trial_request_id ON hosted_entitlements(trial_request_id) WHERE trial_request_id <> ''`},
		{"idx_hosted_entitlements_refresh_token", `CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_refresh_token ON hosted_entitlements(refresh_token)`},
		{"idx_hosted_entitlements_kind", `CREATE INDEX IF NOT EXISTS idx_hosted_entitlements_kind ON hosted_entitlements(kind)`},
	}
	for _, idx := range hostedIndexes {
		if err := createIndex(idx.name, idx.ddl); err != nil {
			return err
		}
	}
	if err := ensureColumn("hosted_entitlements", "activation_token", `ALTER TABLE hosted_entitlements ADD COLUMN activation_token TEXT NOT NULL DEFAULT ''`); err != nil {
		return err
	}
	if err := ensureColumn("hosted_entitlements", "activation_issued_at", `ALTER TABLE hosted_entitlements ADD COLUMN activation_issued_at INTEGER`); err != nil {
		return err
	}
	if err := createIndex("idx_hosted_entitlements_activation_token", `CREATE UNIQUE INDEX IF NOT EXISTS idx_hosted_entitlements_activation_token ON hosted_entitlements(activation_token) WHERE activation_token <> ''`); err != nil {
		return err
	}

	// Move refresh tokens stored on tenants into hosted_entitlements as paid
	// entitlements, then rebuild tenants without the legacy column.
	if hasEntitlementRefreshToken {
		if _, err := r.db.Exec(`
INSERT OR IGNORE INTO hosted_entitlements (
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
)
SELECT 'paid:' || id, 'paid', id, NULL, '', '', '', '', '', NULL, entitlement_refresh_token, '', updated_at, NULL, NULL, NULL, NULL
FROM tenants
WHERE entitlement_refresh_token <> ''
`); err != nil {
			return fmt.Errorf("migrate hosted entitlements from tenants: %w", err)
		}
		if err := r.removeLegacyTenantEntitlementRefreshTokenColumn(); err != nil {
			return fmt.Errorf("migrate tenants: drop legacy entitlement_refresh_token: %w", err)
		}
	}

	// Migration: add session_version to users if missing, and normalize any
	// NULL or non-positive values to 1.
	if err := ensureColumn("users", "session_version", `ALTER TABLE users ADD COLUMN session_version INTEGER NOT NULL DEFAULT 1`); err != nil {
		return err
	}
	if _, err := r.db.Exec(`UPDATE users SET session_version = 1 WHERE session_version IS NULL OR session_version < 1`); err != nil {
		return fmt.Errorf("migrate users: backfill session_version: %w", err)
	}

	// Migration: add processing_started_at to stripe_events if missing.
	if err := ensureColumn("stripe_events", "processing_started_at", `ALTER TABLE stripe_events ADD COLUMN processing_started_at INTEGER`); err != nil {
		return err
	}
	// Migration: add grace_started_at to stripe_accounts if missing.
	if err := ensureColumn("stripe_accounts", "grace_started_at", `ALTER TABLE stripe_accounts ADD COLUMN grace_started_at INTEGER`); err != nil {
		return err
	}
	// Backfill legacy past_due/grace rows to preserve existing grace windows.
	if _, err := r.db.Exec(`
UPDATE stripe_accounts
SET grace_started_at = updated_at
WHERE grace_started_at IS NULL
AND subscription_state IN ('past_due', 'grace')
`); err != nil {
		return fmt.Errorf("migrate stripe_accounts: backfill grace_started_at: %w", err)
	}
	return nil
}
// tenantsHasColumn reports whether the tenants table currently has a column
// called name. It is shorthand for tableHasColumn("tenants", name).
func (r *TenantRegistry) tenantsHasColumn(name string) (bool, error) {
	const table = "tenants"
	return r.tableHasColumn(table, name)
}
// tableHasColumn reports whether tableName has a column named name, using
// PRAGMA table_info. tableName is concatenated into the statement rather than
// bound, so callers must pass trusted, hard-coded table names only.
func (r *TenantRegistry) tableHasColumn(tableName, name string) (bool, error) {
	query := "PRAGMA table_info(" + tableName + ")"
	rows, err := r.db.Query(query)
	if err != nil {
		return false, fmt.Errorf("pragma table_info(%s): %w", tableName, err)
	}
	defer rows.Close()

	// table_info yields (cid, name, type, notnull, dflt_value, pk); only the
	// name is inspected, but every column must be scanned.
	var (
		cid, notNull, pk int
		colName, colType string
		dflt             sql.NullString
	)
	for rows.Next() {
		if scanErr := rows.Scan(&cid, &colName, &colType, &notNull, &dflt, &pk); scanErr != nil {
			return false, fmt.Errorf("scan table_info(%s): %w", tableName, scanErr)
		}
		if colName == name {
			return true, nil
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return false, fmt.Errorf("iterate table_info(%s): %w", tableName, iterErr)
	}
	return false, nil
}
// removeLegacyTenantEntitlementRefreshTokenColumn rebuilds the tenants table
// without the legacy entitlement_refresh_token column, using SQLite's
// create-copy-drop-rename procedure inside one transaction, and recreates the
// tenants indexes.
//
// Foreign keys are switched off around the rebuild. hosted_entitlements
// references tenants(id) ON DELETE CASCADE, and dropping tenants with
// enforcement on would cascade-delete those rows. The PRAGMA is issued outside
// the transaction because SQLite ignores it inside one. It only affects the
// transaction because the pool is limited to a single connection.
//
// Re-enabling foreign keys is part of the result. If it fails, the error is
// returned (unless an earlier error is already being returned) instead of
// silently leaving the long-lived connection without FK enforcement.
func (r *TenantRegistry) removeLegacyTenantEntitlementRefreshTokenColumn() (err error) {
	if _, err = r.db.Exec(`PRAGMA foreign_keys = OFF`); err != nil {
		return fmt.Errorf("disable foreign keys for tenants migration: %w", err)
	}
	// Registered first so it runs last, after the rollback below.
	defer func() {
		if _, fkErr := r.db.Exec(`PRAGMA foreign_keys = ON`); fkErr != nil && err == nil {
			err = fmt.Errorf("re-enable foreign keys after tenants migration: %w", fkErr)
		}
	}()
	tx, err := r.db.Begin()
	if err != nil {
		return fmt.Errorf("begin tenants migration: %w", err)
	}
	defer func() {
		if tx != nil {
			_ = tx.Rollback()
		}
	}()
	// Dropping the old table also drops its indexes, so the plain CREATE INDEX
	// statements below cannot collide with existing names.
	if _, err = tx.Exec(`
CREATE TABLE tenants_new (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'provisioning',
stripe_customer_id TEXT NOT NULL DEFAULT '',
stripe_subscription_id TEXT NOT NULL DEFAULT '',
stripe_price_id TEXT NOT NULL DEFAULT '',
plan_version TEXT NOT NULL DEFAULT '',
container_id TEXT NOT NULL DEFAULT '',
current_image_digest TEXT NOT NULL DEFAULT '',
desired_image_digest TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
last_health_check INTEGER,
health_check_ok INTEGER NOT NULL DEFAULT 0
);
INSERT INTO tenants_new (
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
)
SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants;
DROP TABLE tenants;
ALTER TABLE tenants_new RENAME TO tenants;
CREATE INDEX idx_tenants_state ON tenants(state);
CREATE INDEX idx_tenants_stripe_customer_id ON tenants(stripe_customer_id);
CREATE INDEX idx_tenants_created_at ON tenants(created_at DESC);
CREATE INDEX idx_tenants_state_created_at ON tenants(state, created_at DESC);
CREATE INDEX idx_tenants_account_id ON tenants(account_id);
CREATE INDEX idx_tenants_account_id_created_at ON tenants(account_id, created_at DESC);
`); err != nil {
		return fmt.Errorf("rebuild tenants table without legacy entitlement column: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit tenants migration: %w", err)
	}
	tx = nil
	return nil
}
// migrateLegacyHostedEntitlementsTable rebuilds a legacy hosted_entitlements
// table (one without an id column) into the current shape in one transaction.
// Every legacy row becomes a paid entitlement with id "paid:<tenant_id>",
// carrying over tenant_id, refresh_token, issued_at, last_refreshed_at and
// revoked_at. All trial/activation fields start empty or NULL. The table's
// indexes are then recreated.
//
// Foreign keys are disabled around the rebuild; the PRAGMA is issued outside
// the transaction because SQLite ignores it inside one. If re-enabling them
// fails, that error is returned (unless an earlier error is already being
// returned) instead of silently leaving the long-lived connection without FK
// enforcement.
func (r *TenantRegistry) migrateLegacyHostedEntitlementsTable() (err error) {
	if _, err = r.db.Exec(`PRAGMA foreign_keys = OFF`); err != nil {
		return fmt.Errorf("disable foreign keys for hosted entitlements migration: %w", err)
	}
	// Registered first so it runs last, after the rollback below.
	defer func() {
		if _, fkErr := r.db.Exec(`PRAGMA foreign_keys = ON`); fkErr != nil && err == nil {
			err = fmt.Errorf("re-enable foreign keys after hosted entitlements migration: %w", fkErr)
		}
	}()
	tx, err := r.db.Begin()
	if err != nil {
		return fmt.Errorf("begin hosted entitlements migration: %w", err)
	}
	defer func() {
		if tx != nil {
			_ = tx.Rollback()
		}
	}()
	if _, err = tx.Exec(`
CREATE TABLE hosted_entitlements_new (
id TEXT PRIMARY KEY,
kind TEXT NOT NULL DEFAULT 'paid',
tenant_id TEXT,
trial_request_id TEXT,
org_id TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
return_url TEXT NOT NULL DEFAULT '',
instance_token TEXT NOT NULL DEFAULT '',
instance_host TEXT NOT NULL DEFAULT '',
trial_started_at INTEGER,
refresh_token TEXT NOT NULL UNIQUE,
activation_token TEXT NOT NULL DEFAULT '',
issued_at INTEGER NOT NULL,
activation_issued_at INTEGER,
last_refreshed_at INTEGER,
redeemed_at INTEGER,
revoked_at INTEGER,
FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE
);
INSERT INTO hosted_entitlements_new (
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
)
SELECT
'paid:' || tenant_id,
'paid',
tenant_id,
NULL,
'',
'',
'',
'',
'',
NULL,
refresh_token,
'',
issued_at,
NULL,
last_refreshed_at,
NULL,
revoked_at
FROM hosted_entitlements;
DROP TABLE hosted_entitlements;
ALTER TABLE hosted_entitlements_new RENAME TO hosted_entitlements;
CREATE UNIQUE INDEX idx_hosted_entitlements_tenant_id ON hosted_entitlements(tenant_id) WHERE tenant_id <> '';
CREATE UNIQUE INDEX idx_hosted_entitlements_trial_request_id ON hosted_entitlements(trial_request_id) WHERE trial_request_id <> '';
CREATE UNIQUE INDEX idx_hosted_entitlements_refresh_token ON hosted_entitlements(refresh_token);
CREATE UNIQUE INDEX idx_hosted_entitlements_activation_token ON hosted_entitlements(activation_token) WHERE activation_token <> '';
CREATE INDEX idx_hosted_entitlements_kind ON hosted_entitlements(kind);
`); err != nil {
		return fmt.Errorf("rebuild hosted_entitlements table: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit hosted entitlements migration: %w", err)
	}
	tx = nil
	return nil
}
// Ping verifies that the registry database is reachable, returning the
// driver's error unchanged. Intended for readiness probes.
func (r *TenantRegistry) Ping() error {
	if err := r.db.Ping(); err != nil {
		return err
	}
	return nil
}
// Close releases the underlying database handle. It is safe to call on a nil
// registry or one without a handle, in which case it returns nil.
func (r *TenantRegistry) Close() error {
	if r != nil && r.db != nil {
		return r.db.Close()
	}
	return nil
}
// Create inserts a new tenant record.
//
// The passed tenant is mutated before the insert. CreatedAt is set to the
// current UTC time if it is zero, UpdatedAt is always set to that time, and
// PlanVersion is canonicalized. A nil tenant is rejected.
func (r *TenantRegistry) Create(t *Tenant) error {
	if t == nil {
		return fmt.Errorf("tenant is nil")
	}
	stamp := time.Now().UTC()
	if t.CreatedAt.IsZero() {
		t.CreatedAt = stamp
	}
	t.PlanVersion = canonicalizeRegistryPlanVersion(t.PlanVersion)
	t.UpdatedAt = stamp

	args := []any{
		t.ID, t.AccountID, t.Email, t.DisplayName, string(t.State),
		t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
		t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
		t.CreatedAt.Unix(), t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
	}
	if _, execErr := r.db.Exec(`
INSERT INTO tenants (
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, args...); execErr != nil {
		return fmt.Errorf("create tenant: %w", execErr)
	}
	return nil
}
// Get retrieves a tenant by its ID. A missing tenant yields (nil, nil) rather
// than an error, because scanTenant maps sql.ErrNoRows to a nil tenant.
func (r *TenantRegistry) Get(tenantID string) (*Tenant, error) {
	const query = `SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants WHERE id = ?`
	return scanTenant(r.db.QueryRow(query, tenantID))
}
// GetByStripeCustomerID retrieves a tenant by Stripe customer ID.
//
// Returns (nil, nil) when no tenant matches. A blank (empty or whitespace-only)
// customerID also returns (nil, nil) without querying.
// tenants.stripe_customer_id defaults to '', so matching on an empty string
// would hand back an arbitrary tenant that simply has no Stripe customer.
//
// NOTE(review): stripe_customer_id is indexed but not UNIQUE on tenants. If
// several tenants share a customer ID, the row returned is whichever SQLite
// yields first.
func (r *TenantRegistry) GetByStripeCustomerID(customerID string) (*Tenant, error) {
	if strings.TrimSpace(customerID) == "" {
		return nil, nil
	}
	row := r.db.QueryRow(`SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants WHERE stripe_customer_id = ?`, customerID)
	return scanTenant(row)
}
// Update overwrites every mutable column of an existing tenant row.
//
// PlanVersion is canonicalized and UpdatedAt is set to the current UTC time on
// the passed tenant before writing; created_at is never changed. It returns an
// error for a nil tenant or when no row has t.ID.
func (r *TenantRegistry) Update(t *Tenant) error {
	if t == nil {
		return fmt.Errorf("tenant is nil")
	}
	t.PlanVersion = canonicalizeRegistryPlanVersion(t.PlanVersion)
	t.UpdatedAt = time.Now().UTC()

	result, execErr := r.db.Exec(`
UPDATE tenants SET
account_id = ?, email = ?, display_name = ?, state = ?,
stripe_customer_id = ?, stripe_subscription_id = ?, stripe_price_id = ?,
plan_version = ?, container_id = ?, current_image_digest = ?, desired_image_digest = ?,
updated_at = ?, last_health_check = ?, health_check_ok = ?
WHERE id = ?`,
		t.AccountID, t.Email, t.DisplayName, string(t.State),
		t.StripeCustomerID, t.StripeSubscriptionID, t.StripePriceID,
		t.PlanVersion, t.ContainerID, t.CurrentImageDigest, t.DesiredImageDigest,
		t.UpdatedAt.Unix(), nullableTimeUnix(t.LastHealthCheck), boolToInt(t.HealthCheckOK),
		t.ID,
	)
	if execErr != nil {
		return fmt.Errorf("update tenant: %w", execErr)
	}

	changed, countErr := result.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("get rows affected: %w", countErr)
	case changed == 0:
		return fmt.Errorf("tenant %q not found", t.ID)
	}
	return nil
}
// Delete removes a tenant record by ID together with its hosted entitlement
// rows, atomically.
//
// Both DELETEs run in one transaction. If the tenant does not exist, or either
// statement or the commit fails, nothing is removed and an error is returned.
// Previously the entitlement delete committed on its own, so a failed tenant
// delete left the tenant without its entitlements. The explicit
// hosted_entitlements delete does not rely on the ON DELETE CASCADE clause,
// which SQLite honours only when foreign_keys is enabled on the connection.
func (r *TenantRegistry) Delete(id string) error {
	tx, err := r.db.Begin()
	if err != nil {
		return fmt.Errorf("begin delete tenant tx: %w", err)
	}
	defer func() {
		if tx != nil {
			_ = tx.Rollback()
		}
	}()
	if _, err := tx.Exec(`DELETE FROM hosted_entitlements WHERE tenant_id = ?`, id); err != nil {
		return fmt.Errorf("delete hosted entitlement: %w", err)
	}
	res, err := tx.Exec(`DELETE FROM tenants WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("delete tenant: %w", err)
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("get rows affected: %w", err)
	}
	if affected == 0 {
		// Deferred rollback also undoes the entitlement delete above.
		return fmt.Errorf("tenant %q not found", id)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit delete tenant tx: %w", err)
	}
	tx = nil
	return nil
}
// List returns every tenant, newest first (ORDER BY created_at DESC).
func (r *TenantRegistry) List() ([]*Tenant, error) {
	const query = `SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants ORDER BY created_at DESC`
	rows, queryErr := r.db.Query(query)
	if queryErr != nil {
		return nil, fmt.Errorf("list tenants: %w", queryErr)
	}
	defer rows.Close()
	return scanTenants(rows)
}
// ListByState returns every tenant whose state column equals state, newest
// first (ORDER BY created_at DESC).
func (r *TenantRegistry) ListByState(state TenantState) ([]*Tenant, error) {
	const query = `SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants WHERE state = ? ORDER BY created_at DESC`
	rows, queryErr := r.db.Query(query, string(state))
	if queryErr != nil {
		return nil, fmt.Errorf("list tenants by state: %w", queryErr)
	}
	defer rows.Close()
	return scanTenants(rows)
}
// ListByAccountID returns every tenant whose account_id equals accountID,
// newest first (ORDER BY created_at DESC).
func (r *TenantRegistry) ListByAccountID(accountID string) ([]*Tenant, error) {
	const query = `SELECT
id, account_id, email, display_name, state,
stripe_customer_id, stripe_subscription_id, stripe_price_id,
plan_version, container_id, current_image_digest, desired_image_digest,
created_at, updated_at, last_health_check, health_check_ok
FROM tenants WHERE account_id = ? ORDER BY created_at DESC`
	rows, queryErr := r.db.Query(query, accountID)
	if queryErr != nil {
		return nil, fmt.Errorf("list tenants by account id: %w", queryErr)
	}
	defer rows.Close()
	return scanTenants(rows)
}
// CountActiveByAccountID counts the tenants owned by accountID whose state is
// anything other than 'deleting', 'deleted' or 'canceled'.
func (r *TenantRegistry) CountActiveByAccountID(accountID string) (int, error) {
	const query = `
SELECT COUNT(*) FROM tenants
WHERE account_id = ?
AND state NOT IN ('deleting', 'deleted', 'canceled')`
	var total int
	if scanErr := r.db.QueryRow(query, accountID).Scan(&total); scanErr != nil {
		return 0, fmt.Errorf("count active tenants for account %q: %w", accountID, scanErr)
	}
	return total, nil
}
// GetTenantForAccount loads tenantID and returns it only if it is owned by
// accountID. It returns (nil, nil) when the tenant is missing, has a blank
// account_id, or belongs to a different account. Lookup errors are returned
// unchanged.
func (r *TenantRegistry) GetTenantForAccount(accountID, tenantID string) (*Tenant, error) {
	tenant, err := r.Get(tenantID)
	switch {
	case err != nil:
		return nil, err
	case tenant == nil:
		return nil, nil
	case strings.TrimSpace(tenant.AccountID) == "", tenant.AccountID != accountID:
		// Unowned tenants never match, even when accountID is also blank.
		return nil, nil
	}
	return tenant, nil
}
// CountByState returns the number of tenants in each state present in the
// table. States with no tenants are absent from the map.
//
// On any query, scan or iteration error the map is nil. Previously an
// iteration error was returned unwrapped alongside a partially-filled map.
func (r *TenantRegistry) CountByState() (map[TenantState]int, error) {
	rows, err := r.db.Query(`SELECT state, COUNT(*) FROM tenants GROUP BY state`)
	if err != nil {
		return nil, fmt.Errorf("count tenants by state: %w", err)
	}
	defer rows.Close()
	counts := make(map[TenantState]int)
	for rows.Next() {
		var (
			state string
			count int
		)
		if err := rows.Scan(&state, &count); err != nil {
			return nil, fmt.Errorf("scan count: %w", err)
		}
		counts[TenantState(state)] = count
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate tenant state counts: %w", err)
	}
	return counts, nil
}
// HealthSummary counts active tenants by their last health-check outcome.
// Tenants with health_check_ok = 1 count as healthy and those with 0 as
// unhealthy. COALESCE yields zeros when there are no active tenants.
func (r *TenantRegistry) HealthSummary() (healthy, unhealthy int, err error) {
	const query = `SELECT
COALESCE(SUM(CASE WHEN health_check_ok = 1 THEN 1 ELSE 0 END), 0),
COALESCE(SUM(CASE WHEN health_check_ok = 0 THEN 1 ELSE 0 END), 0)
FROM tenants WHERE state = ?`
	scanErr := r.db.QueryRow(query, string(TenantStateActive)).Scan(&healthy, &unhealthy)
	if scanErr != nil {
		return 0, 0, fmt.Errorf("health summary: %w", scanErr)
	}
	return healthy, unhealthy, nil
}
// scanner captures the single Scan method shared by *sql.Row and *sql.Rows,
// so one decoding routine can read either a single-row lookup or the current
// row of a result set.
type scanner interface {
	// Scan copies the current row's columns into dest, in column order.
	Scan(dest ...any) error
}
// scanTenant decodes one tenants row, in the column order used by every
// tenant SELECT in this file, into a Tenant.
//
// sql.ErrNoRows becomes (nil, nil) so lookups can treat "not found" as a nil
// tenant. Unix-second timestamps are converted to UTC time.Time values, a NULL
// last_health_check leaves LastHealthCheck nil, and PlanVersion is
// canonicalized.
func scanTenant(s scanner) (*Tenant, error) {
	var (
		tenant      Tenant
		rawState    string
		createdUnix int64
		updatedUnix int64
		lastCheck   sql.NullInt64
		healthFlag  int
	)
	scanErr := s.Scan(
		&tenant.ID, &tenant.AccountID, &tenant.Email, &tenant.DisplayName, &rawState,
		&tenant.StripeCustomerID, &tenant.StripeSubscriptionID, &tenant.StripePriceID,
		&tenant.PlanVersion, &tenant.ContainerID, &tenant.CurrentImageDigest, &tenant.DesiredImageDigest,
		&createdUnix, &updatedUnix, &lastCheck, &healthFlag,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, nil
	case scanErr != nil:
		return nil, fmt.Errorf("scan tenant: %w", scanErr)
	}

	tenant.State = TenantState(rawState)
	tenant.PlanVersion = canonicalizeRegistryPlanVersion(tenant.PlanVersion)
	tenant.CreatedAt = time.Unix(createdUnix, 0).UTC()
	tenant.UpdatedAt = time.Unix(updatedUnix, 0).UTC()
	if lastCheck.Valid {
		checkedAt := time.Unix(lastCheck.Int64, 0).UTC()
		tenant.LastHealthCheck = &checkedAt
	}
	tenant.HealthCheckOK = healthFlag != 0
	return &tenant, nil
}
// paidHostedEntitlementID builds the hosted_entitlements primary key for a
// paid tenant: "paid:" followed by the whitespace-trimmed tenant ID.
func paidHostedEntitlementID(tenantID string) string {
	const prefix = "paid:"
	return prefix + strings.TrimSpace(tenantID)
}
// trialHostedEntitlementID builds the hosted_entitlements primary key for a
// trial request: "trial:" followed by the whitespace-trimmed request ID.
func trialHostedEntitlementID(requestID string) string {
	const prefix = "trial:"
	return prefix + strings.TrimSpace(requestID)
}
// StoreOrIssueHostedEntitlement stores a new hosted entitlement refresh token for a paid tenant,
// or returns the existing active token if one has already been issued.
//
// tenantID and token are whitespace-trimmed and must be non-empty, and the
// tenant must already exist in the tenants table. All reads and writes happen
// in a single transaction.
//
// Returns (refreshToken, issued, err):
//   - If the tenant's hosted_entitlements row has a non-blank refresh_token and
//     no revoked_at, that stored token is returned with issued=false and
//     nothing is written.
//   - Otherwise the supplied token is persisted and returned with issued=true.
//     It is inserted as id "paid:<tenantID>" when no row exists. If a row
//     exists but is revoked or blank, that row is overwritten in place and its
//     activation, redeem, refresh and revoke state is reset.
//
// NOTE(review): issuedAt is not validated; a zero time.Time would be stored as
// its Unix() value (year 1). Confirm callers always pass a real timestamp.
func (r *TenantRegistry) StoreOrIssueHostedEntitlement(tenantID, token string, issuedAt time.Time) (string, bool, error) {
tenantID = strings.TrimSpace(tenantID)
token = strings.TrimSpace(token)
issuedAt = issuedAt.UTC()
if tenantID == "" {
return "", false, fmt.Errorf("missing tenant id")
}
if token == "" {
return "", false, fmt.Errorf("missing entitlement refresh token")
}
tx, err := r.db.Begin()
if err != nil {
return "", false, fmt.Errorf("begin entitlement refresh token tx: %w", err)
}
// Roll back on every early return; tx is set to nil once committed.
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
// Refuse to issue an entitlement for a tenant that does not exist.
if err := tx.QueryRow(`SELECT id FROM tenants WHERE id = ?`, tenantID).Scan(new(string)); err != nil {
if err == sql.ErrNoRows {
return "", false, fmt.Errorf("tenant %q not found", tenantID)
}
return "", false, fmt.Errorf("load tenant for hosted entitlement: %w", err)
}
// A nil rec is treated below as "no entitlement row for this tenant yet".
rec, err := loadHostedEntitlement(tx.QueryRow(`
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
FROM hosted_entitlements WHERE tenant_id = ?`,
tenantID,
))
if err != nil {
return "", false, fmt.Errorf("load hosted entitlement: %w", err)
}
// Existing, unrevoked token: hand it back unchanged (issued=false).
if rec != nil && strings.TrimSpace(rec.RefreshToken) != "" && rec.RevokedAt == nil {
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit hosted entitlement tx: %w", err)
}
tx = nil
return rec.RefreshToken, false, nil
}
if rec == nil {
// First entitlement for this tenant.
if _, err := tx.Exec(`
INSERT INTO hosted_entitlements (
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
) VALUES (?, ?, ?, NULL, '', '', '', '', '', NULL, ?, '', ?, NULL, NULL, NULL, NULL)`,
paidHostedEntitlementID(tenantID),
string(HostedEntitlementKindPaid),
tenantID,
token,
issuedAt.Unix(),
); err != nil {
return "", false, fmt.Errorf("insert hosted entitlement: %w", err)
}
} else {
// Row exists but is revoked or has a blank token: rotate it in place,
// normalizing id/kind and clearing all trial, activation and revoke state.
if _, err := tx.Exec(`
UPDATE hosted_entitlements
SET id = ?, kind = ?, tenant_id = ?, trial_request_id = NULL, org_id = '', email = '',
return_url = '', instance_token = '', instance_host = '', trial_started_at = NULL,
refresh_token = ?, activation_token = '', issued_at = ?, activation_issued_at = NULL,
last_refreshed_at = NULL, redeemed_at = NULL, revoked_at = NULL
WHERE tenant_id = ?`,
paidHostedEntitlementID(tenantID),
string(HostedEntitlementKindPaid),
tenantID,
token,
issuedAt.Unix(),
tenantID,
); err != nil {
return "", false, fmt.Errorf("rotate hosted entitlement: %w", err)
}
}
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit hosted entitlement tx: %w", err)
}
tx = nil
return token, true, nil
}
// StoreOrIssueTrialHostedEntitlement stores a refresh token for a trial
// request, or returns the one already issued for that request.
//
// RequestID and RefreshToken are whitespace-trimmed and must be non-empty.
// OrgID, Email, ReturnURL and InstanceHost must be non-blank; InstanceToken is
// optional. String metadata is stored trimmed. All work happens in one
// transaction keyed by trial_request_id.
//
// Returns (refreshToken, issued, err):
//   - If an entitlement row for the request has a non-blank refresh_token and
//     no revoked_at, its org/email/return URL/instance/trial-start metadata is
//     refreshed. redeemed_at is set only if still NULL. The stored token is
//     returned with issued=false.
//   - If no row exists, a new one with id "trial:<requestID>" and kind trial
//     is inserted and the supplied token is returned with issued=true.
//   - If a row exists but is revoked or has a blank token, it is overwritten
//     with the supplied token, and revoked_at and last_refreshed_at are
//     cleared. The supplied token is returned with issued=true.
//
// NOTE(review): IssuedAt, RedeemedAt and TrialStartedAt are not validated; a
// zero time.Time would be stored as its Unix() value (year 1). Confirm callers
// always populate them.
// NOTE(review): the rotate branch leaves activation_token and
// activation_issued_at untouched, unlike the paid-entitlement rotation which
// clears them. Confirm this is intended.
func (r *TenantRegistry) StoreOrIssueTrialHostedEntitlement(input TrialHostedEntitlementInput) (string, bool, error) {
requestID := strings.TrimSpace(input.RequestID)
token := strings.TrimSpace(input.RefreshToken)
if requestID == "" {
return "", false, fmt.Errorf("missing trial request id")
}
if token == "" {
return "", false, fmt.Errorf("missing entitlement refresh token")
}
if strings.TrimSpace(input.OrgID) == "" || strings.TrimSpace(input.Email) == "" || strings.TrimSpace(input.ReturnURL) == "" || strings.TrimSpace(input.InstanceHost) == "" {
return "", false, fmt.Errorf("trial entitlement input is incomplete")
}
issuedAt := input.IssuedAt.UTC()
redeemedAt := input.RedeemedAt.UTC()
trialStartedAt := input.TrialStartedAt.UTC()
tx, err := r.db.Begin()
if err != nil {
return "", false, fmt.Errorf("begin trial hosted entitlement tx: %w", err)
}
// Roll back on every early return; tx is set to nil once committed.
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
// A nil rec is treated below as "no entitlement row for this request yet".
rec, err := loadHostedEntitlement(tx.QueryRow(`
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
FROM hosted_entitlements WHERE trial_request_id = ?`,
requestID,
))
if err != nil {
return "", false, fmt.Errorf("load trial hosted entitlement: %w", err)
}
// Existing, unrevoked token: refresh metadata only and return the stored token.
if rec != nil && strings.TrimSpace(rec.RefreshToken) != "" && rec.RevokedAt == nil {
// COALESCE keeps the first recorded redemption time.
if _, err := tx.Exec(`
UPDATE hosted_entitlements
SET org_id = ?, email = ?, return_url = ?, instance_token = ?, instance_host = ?, trial_started_at = ?,
redeemed_at = COALESCE(redeemed_at, ?)
WHERE id = ?`,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
redeemedAt.Unix(),
rec.ID,
); err != nil {
return "", false, fmt.Errorf("update trial hosted entitlement metadata: %w", err)
}
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit trial hosted entitlement tx: %w", err)
}
tx = nil
return rec.RefreshToken, false, nil
}
if rec == nil {
// First entitlement for this trial request; tenant_id stays NULL.
if _, err := tx.Exec(`
INSERT INTO hosted_entitlements (
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
) VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, '', ?, NULL, NULL, ?, NULL)`,
trialHostedEntitlementID(requestID),
string(HostedEntitlementKindTrial),
requestID,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
token,
issuedAt.Unix(),
redeemedAt.Unix(),
); err != nil {
return "", false, fmt.Errorf("insert trial hosted entitlement: %w", err)
}
} else {
// Row exists but is revoked or has a blank token: rotate it in place and
// clear revoked_at / last_refreshed_at.
if _, err := tx.Exec(`
UPDATE hosted_entitlements
SET id = ?, kind = ?, tenant_id = NULL, trial_request_id = ?, org_id = ?, email = ?, return_url = ?,
instance_token = ?, instance_host = ?, trial_started_at = ?, refresh_token = ?, issued_at = ?,
last_refreshed_at = NULL, redeemed_at = ?, revoked_at = NULL
WHERE trial_request_id = ?`,
trialHostedEntitlementID(requestID),
string(HostedEntitlementKindTrial),
requestID,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
token,
issuedAt.Unix(),
redeemedAt.Unix(),
requestID,
); err != nil {
return "", false, fmt.Errorf("rotate trial hosted entitlement: %w", err)
}
}
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit trial hosted entitlement tx: %w", err)
}
tx = nil
return token, true, nil
}
// StoreOrRotateTrialActivation stores the activation token for the hosted trial
// identified by input.RequestID, or returns the existing one while it is still
// fresh.
//
// Inputs (trimmed before checking): RequestID, ActivationToken, RefreshToken,
// OrgID, Email, ReturnURL and InstanceHost must be non-blank (InstanceToken may
// be blank), ttl must be positive and IssuedAt must be non-zero. A zero
// TrialStartedAt falls back to IssuedAt.
//
// Inside one transaction it loads the hosted_entitlements row whose
// trial_request_id matches and then:
//   - reuse: if that row has a non-blank activation token whose
//     activation_issued_at is strictly after IssuedAt-ttl and the row is not
//     revoked, only the org/email/return URL/instance fields and
//     trial_started_at are updated, and the stored activation token is
//     returned with false;
//   - insert: if no row exists, a new trial entitlement row is created;
//   - rotate: otherwise the existing row is overwritten with the new
//     activation token.
//
// Insert and rotate return the trimmed input.ActivationToken with true.
func (r *TenantRegistry) StoreOrRotateTrialActivation(input TrialHostedActivationInput, ttl time.Duration) (string, bool, error) {
requestID := strings.TrimSpace(input.RequestID)
activationToken := strings.TrimSpace(input.ActivationToken)
refreshToken := strings.TrimSpace(input.RefreshToken)
if requestID == "" {
return "", false, fmt.Errorf("missing trial request id")
}
if activationToken == "" {
return "", false, fmt.Errorf("missing activation token")
}
if refreshToken == "" {
return "", false, fmt.Errorf("missing entitlement refresh token")
}
if strings.TrimSpace(input.OrgID) == "" || strings.TrimSpace(input.Email) == "" || strings.TrimSpace(input.ReturnURL) == "" || strings.TrimSpace(input.InstanceHost) == "" {
return "", false, fmt.Errorf("trial activation input is incomplete")
}
if ttl <= 0 {
return "", false, fmt.Errorf("activation ttl is required")
}
issuedAt := input.IssuedAt.UTC()
if issuedAt.IsZero() {
return "", false, fmt.Errorf("activation issued_at is required")
}
trialStartedAt := input.TrialStartedAt.UTC()
if trialStartedAt.IsZero() {
trialStartedAt = issuedAt
}
// Activations issued at or before this instant are stale and get rotated.
rotateBefore := issuedAt.Add(-ttl)
tx, err := r.db.Begin()
if err != nil {
return "", false, fmt.Errorf("begin trial activation tx: %w", err)
}
// Roll back on every early return; tx is set to nil after a successful commit.
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
rec, err := loadHostedEntitlement(tx.QueryRow(`
SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
FROM hosted_entitlements WHERE trial_request_id = ?`,
requestID,
))
if err != nil {
return "", false, fmt.Errorf("load trial activation entitlement: %w", err)
}
// Reuse path: the stored activation token is still inside the ttl window and
// the row is not revoked, so only the descriptive metadata is refreshed.
if rec != nil && strings.TrimSpace(rec.ActivationToken) != "" && rec.ActivationIssuedAt != nil && rec.ActivationIssuedAt.After(rotateBefore) && rec.RevokedAt == nil {
if _, err := tx.Exec(`
UPDATE hosted_entitlements
SET org_id = ?, email = ?, return_url = ?, instance_token = ?, instance_host = ?, trial_started_at = ?
WHERE id = ?`,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
rec.ID,
); err != nil {
return "", false, fmt.Errorf("update trial activation metadata: %w", err)
}
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit trial activation tx: %w", err)
}
tx = nil
return rec.ActivationToken, false, nil
}
if rec == nil {
// First activation for this request: the refresh and activation tokens are
// inserted together, with issued_at and activation_issued_at both set to
// issuedAt. last_refreshed_at, redeemed_at and revoked_at start as NULL.
if _, err := tx.Exec(`
INSERT INTO hosted_entitlements (
id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
) VALUES (?, ?, NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL)`,
trialHostedEntitlementID(requestID),
string(HostedEntitlementKindTrial),
requestID,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
refreshToken,
activationToken,
issuedAt.Unix(),
issuedAt.Unix(),
); err != nil {
return "", false, fmt.Errorf("insert trial activation entitlement: %w", err)
}
} else {
// Rotate path: a previously stored non-blank refresh token is kept; the
// input's refresh token is only used when the row has none.
effectiveRefreshToken := strings.TrimSpace(rec.RefreshToken)
if effectiveRefreshToken == "" {
effectiveRefreshToken = refreshToken
}
// issued_at is preserved when already positive. last_refreshed_at and
// revoked_at are cleared, so a revoked trial row becomes active again.
// redeemed_at is not touched.
if _, err := tx.Exec(`
UPDATE hosted_entitlements
SET id = ?, kind = ?, tenant_id = NULL, trial_request_id = ?, org_id = ?, email = ?, return_url = ?,
instance_token = ?, instance_host = ?, trial_started_at = ?, refresh_token = ?, activation_token = ?,
issued_at = CASE WHEN issued_at > 0 THEN issued_at ELSE ? END, activation_issued_at = ?,
last_refreshed_at = NULL, revoked_at = NULL
WHERE trial_request_id = ?`,
trialHostedEntitlementID(requestID),
string(HostedEntitlementKindTrial),
requestID,
strings.TrimSpace(input.OrgID),
strings.TrimSpace(input.Email),
strings.TrimSpace(input.ReturnURL),
strings.TrimSpace(input.InstanceToken),
strings.TrimSpace(input.InstanceHost),
trialStartedAt.Unix(),
effectiveRefreshToken,
activationToken,
issuedAt.Unix(),
issuedAt.Unix(),
requestID,
); err != nil {
return "", false, fmt.Errorf("rotate trial activation entitlement: %w", err)
}
}
if err := tx.Commit(); err != nil {
return "", false, fmt.Errorf("commit trial activation tx: %w", err)
}
tx = nil
return activationToken, true, nil
}
// scanTenants decodes every remaining row of rows with scanTenant.
//
// It does not close rows; that stays with the caller. A scan failure aborts
// with a nil slice. The error reported by rows.Err after iteration is returned
// alongside whatever tenants were collected.
func scanTenants(rows *sql.Rows) ([]*Tenant, error) {
	var collected []*Tenant
	for rows.Next() {
		tenant, scanErr := scanTenant(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, tenant)
	}
	iterErr := rows.Err()
	return collected, iterErr
}
// GetHostedEntitlementByRefreshToken retrieves the hosted entitlement record for a refresh token.
//
// The token is trimmed before lookup. A blank token never matches anything: it
// returns (nil, nil) without querying. Otherwise an empty credential could
// resolve to a row whose refresh_token column happens to be empty. As with the
// query path, a nil record with a nil error means no entitlement holds the
// token.
func (r *TenantRegistry) GetHostedEntitlementByRefreshToken(token string) (*HostedEntitlement, error) {
	token = strings.TrimSpace(token)
	if token == "" {
		// Blank credentials must never authenticate against stored rows.
		return nil, nil
	}
	row := r.db.QueryRow(`
	SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
	trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
	FROM hosted_entitlements
	WHERE refresh_token = ?`,
		token,
	)
	return loadHostedEntitlement(row)
}
// GetHostedEntitlementByActivationToken retrieves the hosted entitlement record
// holding the given activation token.
//
// The token is trimmed before lookup. A blank token returns (nil, nil) without
// querying. Trial entitlements can be stored with activation_token = '', so
// querying with an empty token would otherwise return one of those rows as if
// the caller had presented its credential. A nil record with a nil error
// means no entitlement holds the token.
func (r *TenantRegistry) GetHostedEntitlementByActivationToken(token string) (*HostedEntitlement, error) {
	token = strings.TrimSpace(token)
	if token == "" {
		// Blank credentials must never match rows whose activation_token is ''.
		return nil, nil
	}
	row := r.db.QueryRow(`
	SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
	trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
	FROM hosted_entitlements
	WHERE activation_token = ?`,
		token,
	)
	return loadHostedEntitlement(row)
}
// GetHostedEntitlementByTrialRequestID retrieves the hosted entitlement created
// for the given trial request.
//
// The request ID is trimmed before lookup. A blank ID returns (nil, nil)
// without querying, matching the blank-credential handling of the token
// lookups. A nil record with a nil error means no entitlement exists for the
// request.
func (r *TenantRegistry) GetHostedEntitlementByTrialRequestID(requestID string) (*HostedEntitlement, error) {
	requestID = strings.TrimSpace(requestID)
	if requestID == "" {
		return nil, nil
	}
	row := r.db.QueryRow(`
	SELECT id, kind, tenant_id, trial_request_id, org_id, email, return_url, instance_token, instance_host,
	trial_started_at, refresh_token, activation_token, issued_at, activation_issued_at, last_refreshed_at, redeemed_at, revoked_at
	FROM hosted_entitlements
	WHERE trial_request_id = ?`,
		requestID,
	)
	return loadHostedEntitlement(row)
}
// MarkHostedEntitlementRefreshed stamps last_refreshed_at (UTC unix seconds) on
// the hosted entitlement with the given id.
//
// The id is trimmed and must be non-blank. Revoked entitlements are not
// updated and are reported with the same "not found" error as a missing id.
func (r *TenantRegistry) MarkHostedEntitlementRefreshed(id string, refreshedAt time.Time) error {
	entitlementID := strings.TrimSpace(id)
	if entitlementID == "" {
		return fmt.Errorf("missing hosted entitlement id")
	}
	const stmt = `
	UPDATE hosted_entitlements
	SET last_refreshed_at = ?
	WHERE id = ? AND revoked_at IS NULL`
	res, err := r.db.Exec(stmt, refreshedAt.UTC().Unix(), entitlementID)
	if err != nil {
		return fmt.Errorf("mark hosted entitlement refreshed: %w", err)
	}
	touched, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case touched == 0:
		return fmt.Errorf("hosted entitlement %q not found", entitlementID)
	default:
		return nil
	}
}
// RevokeHostedEntitlement sets revoked_at on every paid hosted entitlement
// belonging to tenantID (trimmed, must be non-blank).
//
// An already-set revoked_at is kept, so repeated calls preserve the first
// revocation time. Matching no rows is not treated as an error.
func (r *TenantRegistry) RevokeHostedEntitlement(tenantID string, revokedAt time.Time) error {
	tid := strings.TrimSpace(tenantID)
	if tid == "" {
		return fmt.Errorf("missing tenant id")
	}
	const stmt = `
	UPDATE hosted_entitlements
	SET revoked_at = COALESCE(revoked_at, ?)
	WHERE tenant_id = ? AND kind = ?`
	if _, err := r.db.Exec(stmt, revokedAt.UTC().Unix(), tid, string(HostedEntitlementKindPaid)); err != nil {
		return fmt.Errorf("revoke hosted entitlement: %w", err)
	}
	return nil
}
// loadHostedEntitlement decodes one hosted_entitlements row into a
// HostedEntitlement. It expects the 17-column order used by the SELECTs in
// this file.
//
// sql.ErrNoRows becomes (nil, nil), so a missing row is an absent record
// rather than an error. A blank kind column defaults to
// HostedEntitlementKindPaid. NULL tenant/trial IDs become "" and NULL
// timestamps become nil pointers. All times are unix seconds converted to UTC.
func loadHostedEntitlement(s scanner) (*HostedEntitlement, error) {
	var (
		rec                HostedEntitlement
		kind               string
		tenantID           sql.NullString
		trialRequestID     sql.NullString
		issuedAt           int64
		trialStartedAt     sql.NullInt64
		activationIssuedAt sql.NullInt64
		lastRefreshedAt    sql.NullInt64
		redeemedAt         sql.NullInt64
		revokedAt          sql.NullInt64
	)
	err := s.Scan(
		&rec.ID,
		&kind,
		&tenantID,
		&trialRequestID,
		&rec.OrgID,
		&rec.Email,
		&rec.ReturnURL,
		&rec.InstanceToken,
		&rec.InstanceHost,
		&trialStartedAt,
		&rec.RefreshToken,
		&rec.ActivationToken,
		&issuedAt,
		&activationIssuedAt,
		&lastRefreshedAt,
		&redeemedAt,
		&revokedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("scan hosted entitlement: %w", err)
	}

	// optionalTime turns a nullable unix-seconds column into a UTC *time.Time.
	optionalTime := func(v sql.NullInt64) *time.Time {
		if !v.Valid {
			return nil
		}
		ts := time.Unix(v.Int64, 0).UTC()
		return &ts
	}

	if strings.TrimSpace(kind) == "" {
		kind = string(HostedEntitlementKindPaid)
	}
	rec.Kind = HostedEntitlementKind(kind)
	// A NULL column scans as Valid=false with an empty String.
	rec.TenantID = tenantID.String
	rec.TrialRequestID = trialRequestID.String
	rec.TrialStartedAt = optionalTime(trialStartedAt)
	rec.IssuedAt = time.Unix(issuedAt, 0).UTC()
	rec.ActivationIssuedAt = optionalTime(activationIssuedAt)
	rec.LastRefreshedAt = optionalTime(lastRefreshedAt)
	rec.RedeemedAt = optionalTime(redeemedAt)
	rec.RevokedAt = optionalTime(revokedAt)
	return &rec, nil
}
// CreateAccount inserts a into the accounts table.
//
// A zero CreatedAt is set to the current UTC time, and UpdatedAt is always
// reset to that same instant. Both are written back to a before the insert.
// An empty Kind is stored as AccountKindIndividual and copied back to a.Kind
// only once the insert succeeds.
func (r *TenantRegistry) CreateAccount(a *Account) error {
	if a == nil {
		return fmt.Errorf("account is nil")
	}
	stamp := time.Now().UTC()
	if a.CreatedAt.IsZero() {
		a.CreatedAt = stamp
	}
	a.UpdatedAt = stamp

	storedKind := string(AccountKindIndividual)
	if requested := string(a.Kind); requested != "" {
		storedKind = requested
	}

	const insertAccount = `
	INSERT INTO accounts (
	id, kind, display_name, created_at, updated_at
	) VALUES (?, ?, ?, ?, ?)`
	if _, err := r.db.Exec(insertAccount, a.ID, storedKind, a.DisplayName, a.CreatedAt.Unix(), a.UpdatedAt.Unix()); err != nil {
		return fmt.Errorf("create account: %w", err)
	}
	a.Kind = AccountKind(storedKind)
	return nil
}
// GetAccount loads the account with the given ID, or returns (nil, nil) when
// no such account exists.
func (r *TenantRegistry) GetAccount(accountID string) (*Account, error) {
	const query = `SELECT
	id, kind, display_name, created_at, updated_at
	FROM accounts WHERE id = ?`
	return scanAccount(r.db.QueryRow(query, accountID))
}
// UpdateAccount writes a's kind and display name back to the accounts table.
//
// UpdatedAt is set to the current UTC time before the write. An empty Kind is
// stored as AccountKindIndividual and copied back to a.Kind only after a
// successful update. An ID that matches no row returns an error.
func (r *TenantRegistry) UpdateAccount(a *Account) error {
	if a == nil {
		return fmt.Errorf("account is nil")
	}
	a.UpdatedAt = time.Now().UTC()

	storedKind := string(AccountKindIndividual)
	if requested := string(a.Kind); requested != "" {
		storedKind = requested
	}

	const updateAccount = `
	UPDATE accounts SET
	kind = ?, display_name = ?, updated_at = ?
	WHERE id = ?`
	res, err := r.db.Exec(updateAccount, storedKind, a.DisplayName, a.UpdatedAt.Unix(), a.ID)
	if err != nil {
		return fmt.Errorf("update account: %w", err)
	}
	changed, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("get rows affected: %w", err)
	}
	if changed == 0 {
		return fmt.Errorf("account %q not found", a.ID)
	}
	a.Kind = AccountKind(storedKind)
	return nil
}
// ListAccounts returns every account, newest created_at first.
func (r *TenantRegistry) ListAccounts() ([]*Account, error) {
	const query = `SELECT
	id, kind, display_name, created_at, updated_at
	FROM accounts ORDER BY created_at DESC`
	rows, err := r.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("list accounts: %w", err)
	}
	defer rows.Close()
	return scanAccounts(rows)
}
// CreateUser inserts u into the users table.
//
// A zero CreatedAt defaults to the current UTC time and a SessionVersion below
// 1 is raised to 1. Both adjustments are written back to u before the insert.
// A nil LastLoginAt is stored as NULL.
func (r *TenantRegistry) CreateUser(u *User) error {
	if u == nil {
		return fmt.Errorf("user is nil")
	}
	if u.CreatedAt.IsZero() {
		u.CreatedAt = time.Now().UTC()
	}
	if u.SessionVersion < 1 {
		u.SessionVersion = 1
	}
	const insertUser = `
	INSERT INTO users (
	id, email, created_at, last_login_at, session_version
	) VALUES (?, ?, ?, ?, ?)`
	if _, err := r.db.Exec(insertUser, u.ID, u.Email, u.CreatedAt.Unix(), nullableTimeUnix(u.LastLoginAt), u.SessionVersion); err != nil {
		return fmt.Errorf("create user: %w", err)
	}
	return nil
}
// GetUser loads the user with the given ID, or returns (nil, nil) when no such
// user exists.
func (r *TenantRegistry) GetUser(userID string) (*User, error) {
	const query = `SELECT
	id, email, created_at, last_login_at, session_version
	FROM users WHERE id = ?`
	return scanUser(r.db.QueryRow(query, userID))
}
// GetUserByEmail loads the user whose email column equals email exactly (no
// trimming or case folding), or returns (nil, nil) when none matches.
func (r *TenantRegistry) GetUserByEmail(email string) (*User, error) {
	const query = `SELECT
	id, email, created_at, last_login_at, session_version
	FROM users WHERE email = ?`
	return scanUser(r.db.QueryRow(query, email))
}
// UpdateUserLastLogin sets last_login_at for userID to the current UTC time
// (unix seconds). It returns an error when no user has that ID.
func (r *TenantRegistry) UpdateUserLastLogin(userID string) error {
	loginAt := time.Now().UTC().Unix()
	res, err := r.db.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, loginAt, userID)
	if err != nil {
		return fmt.Errorf("update user last login: %w", err)
	}
	changed, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case changed == 0:
		return fmt.Errorf("user %q not found", userID)
	default:
		return nil
	}
}
// GetUserSessionVersion returns the stored session_version for userID.
//
// The ID is trimmed and must be non-blank. A missing user is an error, and
// stored values below 1 are reported as 1.
func (r *TenantRegistry) GetUserSessionVersion(userID string) (int64, error) {
	userID = strings.TrimSpace(userID)
	if userID == "" {
		return 0, fmt.Errorf("missing user id")
	}
	var current int64
	err := r.db.QueryRow(`SELECT session_version FROM users WHERE id = ?`, userID).Scan(&current)
	switch {
	case err == sql.ErrNoRows:
		return 0, fmt.Errorf("user %q not found", userID)
	case err != nil:
		return 0, fmt.Errorf("get user session version: %w", err)
	}
	if current < 1 {
		return 1, nil
	}
	return current, nil
}
// RevokeUserSessions increments the user's session version, invalidating all
// previously issued sessions, and returns the new version.
//
// The increment and the read-back run in one transaction, so the returned
// value is the one this call wrote. The ID is trimmed and must be non-blank.
// An unknown user is an error. A result below 1 is reported as 1.
func (r *TenantRegistry) RevokeUserSessions(userID string) (int64, error) {
	userID = strings.TrimSpace(userID)
	if userID == "" {
		return 0, fmt.Errorf("missing user id")
	}
	tx, err := r.db.Begin()
	if err != nil {
		return 0, fmt.Errorf("begin revoke sessions tx: %w", err)
	}
	committed := false
	defer func() {
		// Any path that did not commit (including a failed Commit) rolls back.
		if !committed {
			_ = tx.Rollback()
		}
	}()

	res, err := tx.Exec(`UPDATE users SET session_version = session_version + 1 WHERE id = ?`, userID)
	if err != nil {
		return 0, fmt.Errorf("increment session version: %w", err)
	}
	changed, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("get rows affected: %w", err)
	}
	if changed == 0 {
		return 0, fmt.Errorf("user %q not found", userID)
	}

	var bumped int64
	if err := tx.QueryRow(`SELECT session_version FROM users WHERE id = ?`, userID).Scan(&bumped); err != nil {
		return 0, fmt.Errorf("read incremented session version: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit revoke sessions tx: %w", err)
	}
	committed = true

	if bumped < 1 {
		bumped = 1
	}
	return bumped, nil
}
// CreateMembership inserts m into account_memberships.
//
// A zero CreatedAt defaults to the current UTC time and is written back to m
// before the insert. An empty Role is stored as MemberRoleTech and copied back
// to m.Role only after the insert succeeds.
func (r *TenantRegistry) CreateMembership(m *AccountMembership) error {
	if m == nil {
		return fmt.Errorf("membership is nil")
	}
	if m.CreatedAt.IsZero() {
		m.CreatedAt = time.Now().UTC()
	}
	storedRole := string(MemberRoleTech)
	if requested := string(m.Role); requested != "" {
		storedRole = requested
	}
	const insertMembership = `
	INSERT INTO account_memberships (
	account_id, user_id, role, created_at
	) VALUES (?, ?, ?, ?)`
	if _, err := r.db.Exec(insertMembership, m.AccountID, m.UserID, storedRole, m.CreatedAt.Unix()); err != nil {
		return fmt.Errorf("create membership: %w", err)
	}
	m.Role = MemberRole(storedRole)
	return nil
}
// GetMembership returns the membership linking userID to accountID, or
// (nil, nil) when the user is not a member of that account.
func (r *TenantRegistry) GetMembership(accountID, userID string) (*AccountMembership, error) {
	const query = `SELECT
	account_id, user_id, role, created_at
	FROM account_memberships
	WHERE account_id = ? AND user_id = ?`
	return scanMembership(r.db.QueryRow(query, accountID, userID))
}
// ListMembersByAccount returns every membership of accountID, newest
// created_at first.
func (r *TenantRegistry) ListMembersByAccount(accountID string) ([]*AccountMembership, error) {
	const query = `SELECT
	account_id, user_id, role, created_at
	FROM account_memberships
	WHERE account_id = ?
	ORDER BY created_at DESC`
	rows, err := r.db.Query(query, accountID)
	if err != nil {
		return nil, fmt.Errorf("list members by account: %w", err)
	}
	defer rows.Close()
	return scanMemberships(rows)
}
// ListAccountsByUser returns the IDs of every account userID is a member of,
// ordered by membership created_at, newest first. The slice is nil when the
// user has no memberships.
func (r *TenantRegistry) ListAccountsByUser(userID string) ([]string, error) {
	rows, err := r.db.Query(`SELECT account_id FROM account_memberships WHERE user_id = ? ORDER BY created_at DESC`, userID)
	if err != nil {
		return nil, fmt.Errorf("list accounts by user: %w", err)
	}
	defer rows.Close()

	var ids []string
	for rows.Next() {
		var id string
		if scanErr := rows.Scan(&id); scanErr != nil {
			return nil, fmt.Errorf("scan account id: %w", scanErr)
		}
		ids = append(ids, id)
	}
	return ids, rows.Err()
}
// UpdateMembershipRole sets the role of userID within accountID. It returns an
// error when no such membership exists.
func (r *TenantRegistry) UpdateMembershipRole(accountID, userID string, role MemberRole) error {
	const stmt = `UPDATE account_memberships SET role = ? WHERE account_id = ? AND user_id = ?`
	res, err := r.db.Exec(stmt, string(role), accountID, userID)
	if err != nil {
		return fmt.Errorf("update membership role: %w", err)
	}
	changed, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case changed == 0:
		return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
	default:
		return nil
	}
}
// DeleteMembership removes userID from accountID. It returns an error when no
// such membership exists.
func (r *TenantRegistry) DeleteMembership(accountID, userID string) error {
	const stmt = `DELETE FROM account_memberships WHERE account_id = ? AND user_id = ?`
	res, err := r.db.Exec(stmt, accountID, userID)
	if err != nil {
		return fmt.Errorf("delete membership: %w", err)
	}
	removed, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case removed == 0:
		return fmt.Errorf("membership (%q, %q) not found", accountID, userID)
	default:
		return nil
	}
}
// CreateStripeAccount inserts a stripe_accounts row for sa.
//
// Before validation every string field of sa is trimmed in place and
// PlanVersion is canonicalized. AccountID and StripeCustomerID are required.
// A blank SubscriptionState defaults to "trial" and a zero UpdatedAt defaults
// to the current UTC unix time. Blank subscription/subscription-item IDs and
// nil timestamps are stored as NULL.
func (r *TenantRegistry) CreateStripeAccount(sa *StripeAccount) error {
	if sa == nil {
		return fmt.Errorf("stripe account is nil")
	}
	for _, field := range []*string{
		&sa.AccountID,
		&sa.StripeCustomerID,
		&sa.StripeSubscriptionID,
		&sa.StripeSubItemWorkspacesID,
		&sa.SubscriptionState,
	} {
		*field = strings.TrimSpace(*field)
	}
	sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)

	switch {
	case sa.AccountID == "":
		return fmt.Errorf("missing account id")
	case sa.StripeCustomerID == "":
		return fmt.Errorf("missing stripe customer id")
	}
	if sa.SubscriptionState == "" {
		sa.SubscriptionState = "trial"
	}
	if sa.UpdatedAt == 0 {
		sa.UpdatedAt = time.Now().UTC().Unix()
	}

	const insertStripeAccount = `
	INSERT INTO stripe_accounts (
	account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
	plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
	) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	args := []any{
		sa.AccountID,
		sa.StripeCustomerID,
		nullableString(sa.StripeSubscriptionID),
		nullableString(sa.StripeSubItemWorkspacesID),
		sa.PlanVersion,
		sa.SubscriptionState,
		nullableInt64Ptr(sa.GraceStartedAt),
		nullableInt64Ptr(sa.TrialEndsAt),
		nullableInt64Ptr(sa.CurrentPeriodEnd),
		sa.UpdatedAt,
	}
	if _, err := r.db.Exec(insertStripeAccount, args...); err != nil {
		return fmt.Errorf("create stripe account: %w", err)
	}
	return nil
}
// GetStripeAccount returns the Stripe mapping for accountID (trimmed), or
// (nil, nil) when the account has none.
func (r *TenantRegistry) GetStripeAccount(accountID string) (*StripeAccount, error) {
	key := strings.TrimSpace(accountID)
	const query = `SELECT
	account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
	plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
	FROM stripe_accounts WHERE account_id = ?`
	return scanStripeAccount(r.db.QueryRow(query, key))
}
// GetStripeAccountByCustomerID returns the Stripe mapping whose customer ID
// equals customerID (trimmed), or (nil, nil) when none matches.
func (r *TenantRegistry) GetStripeAccountByCustomerID(customerID string) (*StripeAccount, error) {
	key := strings.TrimSpace(customerID)
	const query = `SELECT
	account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
	plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
	FROM stripe_accounts WHERE stripe_customer_id = ?`
	return scanStripeAccount(r.db.QueryRow(query, key))
}
// ListStripeAccounts returns every Stripe account mapping, most recently
// updated first. Iteration errors are wrapped and discard the partial result.
func (r *TenantRegistry) ListStripeAccounts() ([]*StripeAccount, error) {
	const query = `SELECT
	account_id, stripe_customer_id, stripe_subscription_id, stripe_sub_item_workspaces_id,
	plan_version, subscription_state, grace_started_at, trial_ends_at, current_period_end, updated_at
	FROM stripe_accounts
	ORDER BY updated_at DESC`
	rows, err := r.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("list stripe accounts: %w", err)
	}
	defer rows.Close()

	var accounts []*StripeAccount
	for rows.Next() {
		sa, err := scanStripeAccount(rows)
		if err != nil {
			return nil, err
		}
		if sa == nil {
			continue
		}
		accounts = append(accounts, sa)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate stripe accounts: %w", iterErr)
	}
	return accounts, nil
}
// UpdateStripeAccount overwrites the stripe_accounts row for sa.AccountID.
//
// Before validation every string field of sa is trimmed in place and
// PlanVersion is canonicalized. AccountID and StripeCustomerID are required.
// A blank SubscriptionState defaults to "trial", and UpdatedAt is always set
// to the current UTC unix time. A missing row returns an error.
func (r *TenantRegistry) UpdateStripeAccount(sa *StripeAccount) error {
	if sa == nil {
		return fmt.Errorf("stripe account is nil")
	}
	for _, field := range []*string{
		&sa.AccountID,
		&sa.StripeCustomerID,
		&sa.StripeSubscriptionID,
		&sa.StripeSubItemWorkspacesID,
		&sa.SubscriptionState,
	} {
		*field = strings.TrimSpace(*field)
	}
	sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)

	switch {
	case sa.AccountID == "":
		return fmt.Errorf("missing account id")
	case sa.StripeCustomerID == "":
		return fmt.Errorf("missing stripe customer id")
	}
	if sa.SubscriptionState == "" {
		sa.SubscriptionState = "trial"
	}
	sa.UpdatedAt = time.Now().UTC().Unix()

	const updateStripeAccount = `
	UPDATE stripe_accounts SET
	stripe_customer_id = ?, stripe_subscription_id = ?, stripe_sub_item_workspaces_id = ?,
	plan_version = ?, subscription_state = ?, grace_started_at = ?, trial_ends_at = ?, current_period_end = ?, updated_at = ?
	WHERE account_id = ?`
	args := []any{
		sa.StripeCustomerID,
		nullableString(sa.StripeSubscriptionID),
		nullableString(sa.StripeSubItemWorkspacesID),
		sa.PlanVersion,
		sa.SubscriptionState,
		nullableInt64Ptr(sa.GraceStartedAt),
		nullableInt64Ptr(sa.TrialEndsAt),
		nullableInt64Ptr(sa.CurrentPeriodEnd),
		sa.UpdatedAt,
		sa.AccountID,
	}
	res, err := r.db.Exec(updateStripeAccount, args...)
	if err != nil {
		return fmt.Errorf("update stripe account: %w", err)
	}
	changed, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case changed == 0:
		return fmt.Errorf("stripe account %q not found", sa.AccountID)
	default:
		return nil
	}
}
// RecordStripeEvent claims a Stripe webhook delivery for processing and reports
// whether the caller should skip it. eventID and eventType are trimmed and
// must be non-blank.
//
// It returns alreadyProcessed=true (skip) when either:
//   - the event already has processed_at set with no processing_error, or
//   - another delivery set processing_started_at no more than
//     stripeEventProcessingLeaseSeconds ago and has neither finished nor
//     recorded an error.
//
// In every other case it returns false and the caller owns the event: the
// event is new, a previous attempt recorded a processing_error, the lease went
// stale, or the row carries no processing marker. Existing rows are reset so
// that processing_started_at marks this attempt and processed_at and
// processing_error are cleared. MarkStripeEventProcessed clears
// processing_started_at and records the outcome.
func (r *TenantRegistry) RecordStripeEvent(eventID, eventType string) (alreadyProcessed bool, err error) {
eventID = strings.TrimSpace(eventID)
eventType = strings.TrimSpace(eventType)
if eventID == "" {
return false, fmt.Errorf("missing stripe event id")
}
if eventType == "" {
return false, fmt.Errorf("missing stripe event type")
}
tx, err := r.db.Begin()
if err != nil {
return false, fmt.Errorf("begin record stripe event tx: %w", err)
}
// Roll back on every early return; tx is set to nil after a successful commit.
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
now := time.Now().UTC().Unix()
// INSERT OR IGNORE avoids driver-specific error parsing for duplicates.
res, err := tx.Exec(`
INSERT OR IGNORE INTO stripe_events (
stripe_event_id, event_type, received_at, processing_started_at, processed_at, processing_error
) VALUES (?, ?, ?, ?, NULL, NULL)`,
eventID, eventType, now, now,
)
if err != nil {
return false, fmt.Errorf("record stripe event: %w", err)
}
affected, err := res.RowsAffected()
if err != nil {
return false, fmt.Errorf("get rows affected: %w", err)
}
// One inserted row means this is the first delivery, already claimed by the
// insert (processing_started_at = now).
if affected == 1 {
if err := tx.Commit(); err != nil {
return false, fmt.Errorf("commit record stripe event tx: %w", err)
}
tx = nil
return false, nil
}
// The insert was ignored, so a row already exists; inspect its state.
var processedAt sql.NullInt64
var processingStartedAt sql.NullInt64
var processingError sql.NullString
if err := tx.QueryRow(`
SELECT processed_at, processing_started_at, processing_error
FROM stripe_events WHERE stripe_event_id = ?`,
eventID,
).Scan(&processedAt, &processingStartedAt, &processingError); err != nil {
if err == sql.ErrNoRows {
return false, fmt.Errorf("stripe event %q not found after insert-or-ignore", eventID)
}
return false, fmt.Errorf("query stripe event status: %w", err)
}
// Exact duplicates that were already processed successfully are skipped.
if processedAt.Valid && strings.TrimSpace(processingError.String) == "" {
if err := tx.Commit(); err != nil {
return false, fmt.Errorf("commit record stripe event tx: %w", err)
}
tx = nil
return true, nil
}
// If another request is currently processing this event and that processing
// window is still fresh, skip duplicate execution.
if processingStartedAt.Valid && !processedAt.Valid && strings.TrimSpace(processingError.String) == "" {
if now-processingStartedAt.Int64 <= stripeEventProcessingLeaseSeconds {
if err := tx.Commit(); err != nil {
return false, fmt.Errorf("commit record stripe event tx: %w", err)
}
tx = nil
return true, nil
}
}
// Reclaim failed/stale event deliveries for retry processing.
_, err = tx.Exec(`
UPDATE stripe_events SET
event_type = ?,
received_at = ?,
processing_started_at = ?,
processed_at = NULL,
processing_error = NULL
WHERE stripe_event_id = ?`,
eventType, now, now, eventID,
)
if err != nil {
return false, fmt.Errorf("reclaim stripe event for retry: %w", err)
}
if err := tx.Commit(); err != nil {
return false, fmt.Errorf("commit record stripe event tx: %w", err)
}
tx = nil
return false, nil
}
// MarkStripeEventProcessed records the outcome of processing a Stripe event.
//
// It stamps processed_at with the current UTC time and clears
// processing_started_at. processingError is stored trimmed, and a blank value
// is stored as NULL. An eventID that matches no row returns an error.
func (r *TenantRegistry) MarkStripeEventProcessed(eventID string, processingError string) error {
	id := strings.TrimSpace(eventID)
	if id == "" {
		return fmt.Errorf("missing stripe event id")
	}
	outcome := nullableString(strings.TrimSpace(processingError))
	finishedAt := time.Now().UTC().Unix()
	const stmt = `
	UPDATE stripe_events SET
	processing_started_at = NULL,
	processed_at = ?, processing_error = ?
	WHERE stripe_event_id = ?`
	res, err := r.db.Exec(stmt, finishedAt, outcome, id)
	if err != nil {
		return fmt.Errorf("mark stripe event processed: %w", err)
	}
	changed, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("get rows affected: %w", err)
	case changed == 0:
		return fmt.Errorf("stripe event %q not found", id)
	default:
		return nil
	}
}
// scanAccount decodes one accounts row (id, kind, display_name, created_at,
// updated_at) into an Account. Timestamps are unix seconds converted to UTC.
// sql.ErrNoRows is reported as (nil, nil).
func scanAccount(s scanner) (*Account, error) {
	var (
		acct             Account
		kind             string
		created, updated int64
	)
	switch err := s.Scan(&acct.ID, &kind, &acct.DisplayName, &created, &updated); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("scan account: %w", err)
	}
	acct.Kind = AccountKind(kind)
	acct.CreatedAt = time.Unix(created, 0).UTC()
	acct.UpdatedAt = time.Unix(updated, 0).UTC()
	return &acct, nil
}
// scanAccounts decodes every remaining row of rows with scanAccount.
//
// It does not close rows. A scan failure aborts with a nil slice. The error
// from rows.Err is returned alongside the accounts collected so far.
func scanAccounts(rows *sql.Rows) ([]*Account, error) {
	var collected []*Account
	for rows.Next() {
		acct, scanErr := scanAccount(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, acct)
	}
	iterErr := rows.Err()
	return collected, iterErr
}
// scanUser decodes one users row (id, email, created_at, last_login_at,
// session_version) into a User.
//
// A NULL last_login_at leaves LastLoginAt nil, and a session_version below 1
// is raised to 1. sql.ErrNoRows is reported as (nil, nil).
func scanUser(s scanner) (*User, error) {
	u := new(User)
	var created int64
	var lastLogin sql.NullInt64
	err := s.Scan(&u.ID, &u.Email, &created, &lastLogin, &u.SessionVersion)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("scan user: %w", err)
	}
	u.CreatedAt = time.Unix(created, 0).UTC()
	if lastLogin.Valid {
		at := time.Unix(lastLogin.Int64, 0).UTC()
		u.LastLoginAt = &at
	}
	if u.SessionVersion < 1 {
		u.SessionVersion = 1
	}
	return u, nil
}
// scanMembership decodes one account_memberships row (account_id, user_id,
// role, created_at) into an AccountMembership. sql.ErrNoRows is reported as
// (nil, nil).
func scanMembership(s scanner) (*AccountMembership, error) {
	m := new(AccountMembership)
	var role string
	var created int64
	switch err := s.Scan(&m.AccountID, &m.UserID, &role, &created); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("scan membership: %w", err)
	}
	m.Role = MemberRole(role)
	m.CreatedAt = time.Unix(created, 0).UTC()
	return m, nil
}
// scanMemberships decodes every remaining row of rows with scanMembership.
//
// It does not close rows. A scan failure aborts with a nil slice. The error
// from rows.Err is returned alongside the memberships collected so far.
func scanMemberships(rows *sql.Rows) ([]*AccountMembership, error) {
	var collected []*AccountMembership
	for rows.Next() {
		membership, scanErr := scanMembership(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, membership)
	}
	iterErr := rows.Err()
	return collected, iterErr
}
// nullableTimeUnix converts an optional timestamp into a SQL argument: nil
// becomes NULL, anything else becomes its Unix time in seconds (int64).
func nullableTimeUnix(t *time.Time) driver.Value {
	if t != nil {
		return t.Unix()
	}
	return nil
}
// nullableInt64Ptr converts an optional int64 into a SQL argument: nil becomes
// NULL, otherwise the pointed-to value is passed through.
func nullableInt64Ptr(v *int64) driver.Value {
	if v != nil {
		return *v
	}
	return nil
}
// nullableString converts a string into a SQL argument. Surrounding whitespace
// is trimmed, and a string that is blank after trimming becomes NULL.
func nullableString(s string) driver.Value {
	if trimmed := strings.TrimSpace(s); trimmed != "" {
		return trimmed
	}
	return nil
}
// boolToInt maps true to 1 and false to 0.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
// scanStripeAccount decodes one stripe_accounts row into a StripeAccount. It
// expects the column order used by the SELECTs in this file.
//
// NULL subscription/subscription-item IDs become "" and NULL timestamps become
// nil pointers. PlanVersion is canonicalized. sql.ErrNoRows is reported as
// (nil, nil).
func scanStripeAccount(s scanner) (*StripeAccount, error) {
	var sa StripeAccount
	var subscriptionID sql.NullString
	var subItemID sql.NullString
	var graceStart sql.NullInt64
	var trialEnd sql.NullInt64
	var periodEnd sql.NullInt64
	err := s.Scan(
		&sa.AccountID,
		&sa.StripeCustomerID,
		&subscriptionID,
		&subItemID,
		&sa.PlanVersion,
		&sa.SubscriptionState,
		&graceStart,
		&trialEnd,
		&periodEnd,
		&sa.UpdatedAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("scan stripe account: %w", err)
	}

	// optionalInt64 turns a nullable integer column into an *int64.
	optionalInt64 := func(v sql.NullInt64) *int64 {
		if !v.Valid {
			return nil
		}
		n := v.Int64
		return &n
	}

	// A NULL column scans as Valid=false with an empty String.
	sa.StripeSubscriptionID = subscriptionID.String
	sa.StripeSubItemWorkspacesID = subItemID.String
	sa.GraceStartedAt = optionalInt64(graceStart)
	sa.TrialEndsAt = optionalInt64(trialEnd)
	sa.CurrentPeriodEnd = optionalInt64(periodEnd)
	sa.PlanVersion = canonicalizeRegistryPlanVersion(sa.PlanVersion)
	return &sa, nil
}