package api

import (
	"context"
	"encoding/json"
	"net/http"
	"strings"

	"github.com/rcourtman/pulse-go-rewrite/internal/config"
	"github.com/rcourtman/pulse-go-rewrite/internal/models"
	"github.com/rcourtman/pulse-go-rewrite/pkg/auth"
	"github.com/rs/zerolog/log"
)

// OrganizationContextKey is the dedicated type for the context keys declared
// below. Using a named type keeps these keys from colliding with plain string
// keys stored in the same context.
type OrganizationContextKey string

const (
	// OrgIDContextKey is the context key under which TenantMiddleware stores
	// the resolved organization ID (a string).
	OrgIDContextKey OrganizationContextKey = "org_id"
	// OrgContextKey is the context key under which TenantMiddleware stores the
	// *models.Organization for the request.
	OrgContextKey OrganizationContextKey = "org_object"
	// APITokenContextKey is a context key named for an API token record. It is
	// not read or written by the code in this section.
	APITokenContextKey OrganizationContextKey = "api_token_record"
)

// TenantMiddleware extracts the organization ID from the request and
// sets up the context for multi-tenant isolation.
type TenantMiddleware struct {
	// persistence is used to verify that a requested non-default organization
	// exists. When nil, the existence check is skipped.
	persistence *config.MultiTenantPersistence
	// authChecker decides whether the authenticated token/user may access the
	// requested organization. When nil, the authorization step is skipped.
	authChecker AuthorizationChecker
}

// TenantMiddlewareConfig holds configuration for the tenant middleware.
type TenantMiddlewareConfig struct {
	Persistence *config.MultiTenantPersistence
	AuthChecker AuthorizationChecker
}

// NewTenantMiddleware creates a TenantMiddleware backed by p. No authorization
// checker is configured, so the authorization step in Middleware is skipped
// until SetAuthChecker is called.
func NewTenantMiddleware(p *config.MultiTenantPersistence) *TenantMiddleware {
	return &TenantMiddleware{persistence: p}
}

// NewTenantMiddlewareWithConfig creates a new TenantMiddleware with full configuration.
func NewTenantMiddlewareWithConfig(cfg TenantMiddlewareConfig) *TenantMiddleware {
	return &TenantMiddleware{
		persistence: cfg.Persistence,
		authChecker: cfg.AuthChecker,
	}
}

// SetAuthChecker sets the authorization checker for the middleware.
// It is a plain field assignment; no locking is performed here.
func (m *TenantMiddleware) SetAuthChecker(checker AuthorizationChecker) {
	m.authChecker = checker
}

// Middleware wraps next with organization resolution and access checks.
//
// For every request it resolves the organization ID via requestedOrgID. For a
// non-default organization it then, in order: verifies the organization exists
// (400 "invalid_org" otherwise), verifies multi-tenancy is enabled and
// licensed, and, when an auth checker is configured and the request carries
// auth context, verifies access (403 "access_denied" otherwise). On success the
// organization ID and an Organization object are injected into the request
// context before next is invoked.
func (m *TenantMiddleware) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// 1. Resolve the requested organization (falls back to "default").
		orgID := requestedOrgID(r)

		// 2. Validate Organization Exists (only for non-default orgs)
		// This must check existence WITHOUT creating directories to prevent DoS.
		// It also must run BEFORE feature checks to ensure invalid org IDs return
		// 400 (Bad Request) rather than 501/402 (feature disabled/unlicensed).
		if orgID != "default" && m.persistence != nil && !m.persistence.OrgExists(orgID) {
			writeJSONError(w, http.StatusBadRequest, "invalid_org", "Invalid Organization ID")
			return
		}

		// 3. Feature flag and License Check for multi-tenant access
		// Non-default orgs require:
		//   1. Feature flag enabled (PULSE_MULTI_TENANT_ENABLED=true) - returns 501 if disabled
		//   2. Enterprise license - returns 402 if unlicensed
		if orgID != "default" {
			// Check feature flag first - 501 Not Implemented if disabled
			if !IsMultiTenantEnabled() {
				writeMultiTenantDisabledError(w)
				return
			}

			// Feature is enabled, check license - 402 Payment Required if unlicensed.
			// The license check reads the org ID from the context, so hand it a
			// context that already carries the requested organization.
			checkCtx := context.WithValue(r.Context(), OrgIDContextKey, orgID)
			if !hasMultiTenantFeatureForContext(checkCtx) {
				writeMultiTenantRequiredError(w)
				return
			}
		}

		// 4. Authorization Check
		// Check if the authenticated user/token is allowed to access this organization
		// Note: This runs AFTER AuthContextMiddleware, so auth context is available
		if m.authChecker != nil && orgID != "default" {
			// API token and user ID are set by AuthContextMiddleware. A missing
			// or differently-typed token value leaves token nil.
			token, _ := auth.GetAPIToken(r.Context()).(*config.APITokenRecord)
			userID := auth.GetUser(r.Context())

			// Only perform authorization check if we have auth context
			// If no auth context, the route's RequireAuth will handle authentication errors
			if token != nil || userID != "" {
				result := m.authChecker.CheckAccess(token, userID, orgID)
				if !result.Allowed {
					log.Warn().
						Str("org_id", orgID).
						Str("user_id", userID).
						Str("reason", result.Reason).
						Msg("Unauthorized access attempt to organization")
					writeJSONError(w, http.StatusForbidden, "access_denied", result.Reason)
					return
				}

				// Log warning for legacy tokens accessing non-default orgs
				if result.IsLegacyToken {
					log.Warn().
						Str("org_id", orgID).
						Msg("Legacy token with wildcard access used - consider binding to specific org")
				}
			}
		}

		// 5. Inject into Context
		ctx := context.WithValue(r.Context(), OrgIDContextKey, orgID)

		// Also store a mock organization object for now
		org := &models.Organization{ID: orgID, DisplayName: orgID}
		ctx = context.WithValue(ctx, OrgContextKey, org)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// writeJSONError writes a JSON error response with the given HTTP status.
// The body is an object with two string fields: "error" (code) and "message".
func writeJSONError(w http.ResponseWriter, status int, code, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	payload := map[string]string{
		"error":   code,
		"message": message,
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		// The status line is already written, so the failure cannot be
		// reported to the client; record it instead of dropping it silently.
		log.Debug().
			Err(err).
			Int("status", status).
			Str("code", code).
			Msg("Failed to write JSON error response")
	}
}

// GetOrgID returns the organization ID stored in ctx under OrgIDContextKey,
// or "default" when no string value is stored under that key.
func GetOrgID(ctx context.Context) string {
	if id, ok := ctx.Value(OrgIDContextKey).(string); ok {
		return id
	}
	return "default"
}

// GetOrganization returns the organization stored in ctx under OrgContextKey.
// When none is stored it returns a newly allocated placeholder for the default
// organization, so the result is never nil unless a nil pointer was stored.
func GetOrganization(ctx context.Context) *models.Organization {
	if org, ok := ctx.Value(OrgContextKey).(*models.Organization); ok {
		return org
	}
	return &models.Organization{ID: "default", DisplayName: "Default Organization"}
}

// requestedOrgID determines which organization a request targets.
//
// The X-Pulse-Org-ID header takes precedence; if it is empty after trimming
// whitespace, the pulse_org_id cookie is consulted. A nil request or an empty
// value yields "default". When isV5SingleTenantMode reports true, any
// non-default request is logged at debug level and coerced to "default".
func requestedOrgID(r *http.Request) string {
	if r == nil {
		return "default"
	}

	orgID := strings.TrimSpace(r.Header.Get("X-Pulse-Org-ID"))
	if orgID == "" {
		if cookie, err := r.Cookie("pulse_org_id"); err == nil {
			orgID = strings.TrimSpace(cookie.Value)
		}
	}

	if orgID == "" || orgID == "default" {
		return "default"
	}

	if isV5SingleTenantMode() {
		// r.URL is populated for server requests but may be nil on a
		// hand-constructed *http.Request; do not panic just to log a path.
		path := ""
		if r.URL != nil {
			path = r.URL.Path
		}
		log.Debug().
			Str("path", path).
			Str("requested_org", orgID).
			Msg("Ignoring non-default org for single-tenant v5 runtime")
		return "default"
	}

	return orgID
}