From 499ab812e32fd8df2b31d8416890927338df0499 Mon Sep 17 00:00:00 2001 From: rcourtman Date: Thu, 5 Mar 2026 23:40:09 +0000 Subject: [PATCH] Fix post-release regressions and lock v5 to single-tenant runtime --- internal/ai/providers/openai.go | 49 +- internal/ai/providers/openai_test.go | 62 +++ internal/api/ai_handlers.go | 105 +++- internal/api/auth.go | 34 +- internal/api/config_handlers.go | 227 ++++++--- ...config_handlers_cluster_additional_test.go | 30 ++ ...config_handlers_helpers_additional_test.go | 31 +- .../api/config_handlers_setup_url_test.go | 52 ++ internal/api/docker_metadata.go | 18 +- internal/api/guest_metadata.go | 23 +- internal/api/host_metadata.go | 18 +- internal/api/license_handlers.go | 27 +- internal/api/license_handlers_test.go | 14 + internal/api/metadata_handlers_test.go | 22 + internal/api/middleware_license.go | 21 + internal/api/middleware_tenant.go | 46 +- .../api/middleware_tenant_additional_test.go | 25 + internal/api/router.go | 17 +- .../router_single_tenant_persistence_test.go | 26 + internal/api/security_regression_test.go | 17 + internal/api/tenant_agent_auth_test.go | 20 + internal/ceph/collector.go | 214 +++++++- internal/ceph/collector_test.go | 97 ++++ internal/hostagent/commands.go | 4 +- internal/hostagent/proxmox_setup.go | 5 +- internal/hostagent/proxmox_setup_test.go | 10 + internal/models/models.go | 14 + internal/monitoring/diagnostic_snapshots.go | 2 + internal/monitoring/host_agent_temps.go | 48 +- internal/monitoring/host_agent_temps_test.go | 76 +++ internal/monitoring/monitor.go | 458 +++++++++--------- .../monitoring/monitor_additional_test.go | 46 +- .../monitoring/monitor_extra_coverage_test.go | 119 ++++- internal/monitoring/monitor_polling.go | 207 ++++---- internal/monitoring/multi_tenant_monitor.go | 69 +++ .../multi_tenant_monitor_additional_test.go | 58 ++- internal/monitoring/reload_test.go | 37 ++ internal/smartctl/collector.go | 152 ++++-- internal/smartctl/collector_coverage_test.go | 51 ++ internal/websocket/hub.go | 16 + internal/websocket/hub_multitenant_test.go | 33 ++ pkg/server/server.go | 50 +- scripts/install.sh | 57 ++- 43 files changed, 2130 insertions(+), 577 deletions(-) create mode 100644 internal/api/router_single_tenant_persistence_test.go create mode 100644 internal/api/tenant_agent_auth_test.go diff --git a/internal/ai/providers/openai.go b/internal/ai/providers/openai.go index 80b760da4..62eca76e9 100644 --- a/internal/ai/providers/openai.go +++ b/internal/ai/providers/openai.go @@ -490,6 +490,24 @@ func (c *OpenAIClient) modelsEndpoint() string { return u.Scheme + "://" + u.Host + path } +func (c *OpenAIClient) usesCustomModelsListing() bool { + if c.isDeepSeek() { + return true + } + + u, err := url.Parse(c.baseURL) + if err != nil { + return false + } + + host := strings.ToLower(strings.TrimSpace(u.Hostname())) + if host == "" { + return false + } + + return host != "api.openai.com" +} + // SupportsThinking returns true if the model supports extended thinking func (c *OpenAIClient) SupportsThinking(model string) bool { // DeepSeek reasoner models support extended thinking @@ -863,23 +881,26 @@ func (c *OpenAIClient) ListModels(ctx context.Context) ([]ModelInfo, error) { models := make([]ModelInfo, 0, len(result.Data)) cache := GetNotableCache() + allowAllModels := c.usesCustomModelsListing() for _, m := range result.Data { - // Filter to only chat-capable models - if strings.Contains(m.ID, "gpt") || strings.Contains(m.ID, "o1") || + chatCapable := strings.Contains(m.ID, "gpt") || strings.Contains(m.ID, "o1") || strings.Contains(m.ID, "o3") || strings.Contains(m.ID, "o4") || - strings.Contains(m.ID, "deepseek") { - // Use correct provider for notable detection - provider := "openai" - if strings.Contains(m.ID, "deepseek") { - provider = "deepseek" - } - models = append(models, ModelInfo{ - ID: m.ID, - Name: m.ID, // OpenAI uses ID as name - CreatedAt: m.Created, - Notable: cache.IsNotable(provider, m.ID, m.Created), - }) + strings.Contains(m.ID, "deepseek") + if !allowAllModels && !chatCapable { + continue } + + // Use correct provider for notable detection + provider := "openai" + if strings.Contains(m.ID, "deepseek") { + provider = "deepseek" + } + models = append(models, ModelInfo{ + ID: m.ID, + Name: m.ID, // OpenAI uses ID as name + CreatedAt: m.Created, + Notable: cache.IsNotable(provider, m.ID, m.Created), + }) } return models, nil diff --git a/internal/ai/providers/openai_test.go b/internal/ai/providers/openai_test.go index 0f7581cbd..2f3da6413 100644 --- a/internal/ai/providers/openai_test.go +++ b/internal/ai/providers/openai_test.go @@ -13,6 +13,12 @@ import ( "github.com/stretchr/testify/require" ) +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return fn(r) +} + func TestOpenAIClient_ChatStream_Success(t *testing.T) { // Mock OpenAI SSE stream server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -209,11 +215,67 @@ func TestOpenAIClient_ListModels(t *testing.T) { models, err := client.ListModels(context.Background()) require.NoError(t, err) + assert.Len(t, models, 3) + assert.Equal(t, "gpt-4", models[0].ID) + assert.Equal(t, "gpt-3.5-turbo", models[1].ID) + assert.Equal(t, "claude-3", models[2].ID) +} + +func TestOpenAIClient_ListModels_OfficialEndpointStillFiltersNonChatModels(t *testing.T) { + client := NewOpenAIClient("sk-test", "gpt-4", "", 0) + client.client = &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "https", r.URL.Scheme) + assert.Equal(t, "api.openai.com", r.URL.Host) + assert.Equal(t, "/v1/models", r.URL.Path) + + rec := httptest.NewRecorder() + rec.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(rec).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "gpt-4", "object": "model", "created": 1234567890, "owned_by": "openai"}, + {"id": "gpt-3.5-turbo", "object": "model", "created": 1234567890, "owned_by": "openai"}, + {"id": "claude-3", "object": "model", "created": 1234567890, "owned_by": "anthropic"}, + }, + }) + return rec.Result(), nil + }), + } + + models, err := client.ListModels(context.Background()) + require.NoError(t, err) + assert.Len(t, models, 2) assert.Equal(t, "gpt-4", models[0].ID) assert.Equal(t, "gpt-3.5-turbo", models[1].ID) } +func TestOpenAIClient_ListModels_CustomEndpointIncludesNonOpenAIModelNames(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "llama3-8b", "object": "model", "created": 1234567890, "owned_by": "localai"}, + {"id": "qwen3.5-27b", "object": "model", "created": 1234567891, "owned_by": "localai"}, + {"id": "gemma-3-4b", "object": "model", "created": 1234567892, "owned_by": "localai"}, + }, + }) + })) + defer server.Close() + + client := NewOpenAIClient("sk-test", "llama3-8b", server.URL+"/custom-openai", 0) + + models, err := client.ListModels(context.Background()) + require.NoError(t, err) + + assert.Len(t, models, 3) + assert.Equal(t, "llama3-8b", models[0].ID) + assert.Equal(t, "qwen3.5-27b", models[1].ID) + assert.Equal(t, "gemma-3-4b", models[2].ID) +} + func TestOpenAIClient_Chat_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/chat/completions", r.URL.Path) diff --git a/internal/api/ai_handlers.go b/internal/api/ai_handlers.go index ab503cfba..e0b0972ca 100644 --- a/internal/api/ai_handlers.go +++ b/internal/api/ai_handlers.go @@ -124,11 +124,23 @@ func NewAISettingsHandler(mtp *config.MultiTenantPersistence, mtm *monitoring.Mu } } +func (h *AISettingsHandler) ensureLegacyAIService() *ai.Service { + if h.legacyAIService != nil || h.legacyPersistence == nil { + return h.legacyAIService + } + + h.legacyAIService = ai.NewService(h.legacyPersistence, h.agentServer) + if err := h.legacyAIService.LoadConfig(); err != nil { + log.Warn().Err(err).Msg("Failed to load AI config on startup") + } + return h.legacyAIService +} + // GetAIService returns the underlying AI service func (h *AISettingsHandler) GetAIService(ctx context.Context) *ai.Service { orgID := GetOrgID(ctx) if orgID == "default" || orgID == "" { - return h.legacyAIService + return h.ensureLegacyAIService() } h.aiServicesMu.RLock() @@ -266,6 +278,17 @@ func (h *AISettingsHandler) SetConfig(cfg *config.Config) { h.legacyConfig = cfg } +// SetLegacyRuntime wires the single-tenant runtime config and persistence explicitly. +func (h *AISettingsHandler) SetLegacyRuntime(cfg *config.Config, persistence *config.ConfigPersistence) { + if cfg != nil { + h.legacyConfig = cfg + } + if persistence != nil { + h.legacyPersistence = persistence + } + h.ensureLegacyAIService() +} + // setSSECORSHeaders validates the request origin against the configured AllowedOrigins // and sets CORS headers only for allowed origins. This prevents arbitrary origin reflection. func (h *AISettingsHandler) setSSECORSHeaders(w http.ResponseWriter, r *http.Request) { @@ -305,7 +328,9 @@ func (h *AISettingsHandler) setSSECORSHeaders(w http.ResponseWriter, r *http.Req // SetStateProvider sets the state provider for infrastructure context func (h *AISettingsHandler) SetStateProvider(sp ai.StateProvider) { h.stateProvider = sp - h.legacyAIService.SetStateProvider(sp) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetStateProvider(sp) + } h.aiServicesMu.Lock() for _, svc := range h.aiServices { @@ -335,7 +360,9 @@ func (h *AISettingsHandler) GetStateProvider() ai.StateProvider { // SetResourceProvider sets the resource provider for unified infrastructure context (Phase 2) func (h *AISettingsHandler) SetResourceProvider(rp ai.ResourceProvider) { h.resourceProvider = rp - h.legacyAIService.SetResourceProvider(rp) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetResourceProvider(rp) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -347,7 +374,9 @@ func (h *AISettingsHandler) SetResourceProvider(rp ai.ResourceProvider) { // SetMetadataProvider sets the metadata provider for AI URL discovery func (h *AISettingsHandler) SetMetadataProvider(mp ai.MetadataProvider) { h.metadataProvider = mp - h.legacyAIService.SetMetadataProvider(mp) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetMetadataProvider(mp) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -369,7 +398,9 @@ func (h *AISettingsHandler) IsAIEnabled(ctx context.Context) bool { // SetPatrolThresholdProvider sets the threshold provider for the patrol service func (h *AISettingsHandler) SetPatrolThresholdProvider(provider ai.ThresholdProvider) { h.patrolThresholdProvider = provider - h.legacyAIService.SetPatrolThresholdProvider(provider) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetPatrolThresholdProvider(provider) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -381,9 +412,11 @@ func (h *AISettingsHandler) SetPatrolThresholdProvider(provider ai.ThresholdProv // SetPatrolFindingsPersistence enables findings persistence for the patrol service func (h *AISettingsHandler) SetPatrolFindingsPersistence(persistence ai.FindingsPersistence) error { var firstErr error - if patrol := h.legacyAIService.GetPatrolService(); patrol != nil { - if err := patrol.SetFindingsPersistence(persistence); err != nil { - firstErr = err + if svc := h.ensureLegacyAIService(); svc != nil { + if patrol := svc.GetPatrolService(); patrol != nil { + if err := patrol.SetFindingsPersistence(persistence); err != nil { + firstErr = err + } } } // Also apply to active services @@ -405,9 +438,11 @@ func (h *AISettingsHandler) SetPatrolFindingsPersistence(persistence ai.Findings // SetPatrolRunHistoryPersistence enables patrol run history persistence for the patrol service func (h *AISettingsHandler) SetPatrolRunHistoryPersistence(persistence ai.PatrolHistoryPersistence) error { var firstErr error - if patrol := h.legacyAIService.GetPatrolService(); patrol != nil { - if err := patrol.SetRunHistoryPersistence(persistence); err != nil { - firstErr = err + if svc := h.ensureLegacyAIService(); svc != nil { + if patrol := svc.GetPatrolService(); patrol != nil { + if err := patrol.SetRunHistoryPersistence(persistence); err != nil { + firstErr = err + } } } // Also apply to active services @@ -429,7 +464,9 @@ func (h *AISettingsHandler) SetPatrolRunHistoryPersistence(persistence ai.Patrol // SetMetricsHistoryProvider sets the metrics history provider for enriched AI context func (h *AISettingsHandler) SetMetricsHistoryProvider(provider ai.MetricsHistoryProvider) { h.metricsHistoryProvider = provider - h.legacyAIService.SetMetricsHistoryProvider(provider) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetMetricsHistoryProvider(provider) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -441,7 +478,9 @@ func (h *AISettingsHandler) SetMetricsHistoryProvider(provider ai.MetricsHistory // SetBaselineStore sets the baseline store for anomaly detection func (h *AISettingsHandler) SetBaselineStore(store *ai.BaselineStore) { h.baselineStore = store - h.legacyAIService.SetBaselineStore(store) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetBaselineStore(store) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -453,7 +492,9 @@ func (h *AISettingsHandler) SetBaselineStore(store *ai.BaselineStore) { // SetChangeDetector sets the change detector for operational memory func (h *AISettingsHandler) SetChangeDetector(detector *ai.ChangeDetector) { h.changeDetector = detector - h.legacyAIService.SetChangeDetector(detector) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetChangeDetector(detector) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -465,7 +506,9 @@ func (h *AISettingsHandler) SetChangeDetector(detector *ai.ChangeDetector) { // SetRemediationLog sets the remediation log for tracking fix attempts func (h *AISettingsHandler) SetRemediationLog(remLog *ai.RemediationLog) { h.remediationLog = remLog - h.legacyAIService.SetRemediationLog(remLog) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetRemediationLog(remLog) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -477,7 +520,9 @@ func (h *AISettingsHandler) SetRemediationLog(remLog *ai.RemediationLog) { // SetIncidentStore sets the incident store for alert timelines. func (h *AISettingsHandler) SetIncidentStore(store *memory.IncidentStore) { h.incidentStore = store - h.legacyAIService.SetIncidentStore(store) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetIncidentStore(store) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -489,7 +534,9 @@ func (h *AISettingsHandler) SetIncidentStore(store *memory.IncidentStore) { // SetPatternDetector sets the pattern detector for failure prediction func (h *AISettingsHandler) SetPatternDetector(detector *ai.PatternDetector) { h.patternDetector = detector - h.legacyAIService.SetPatternDetector(detector) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetPatternDetector(detector) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -501,7 +548,9 @@ func (h *AISettingsHandler) SetPatternDetector(detector *ai.PatternDetector) { // SetCorrelationDetector sets the correlation detector for multi-resource correlation func (h *AISettingsHandler) SetCorrelationDetector(detector *ai.CorrelationDetector) { h.correlationDetector = detector - h.legacyAIService.SetCorrelationDetector(detector) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetCorrelationDetector(detector) + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() @@ -574,8 +623,8 @@ func (h *AISettingsHandler) GetUnifiedStore() *unified.UnifiedStore { func (h *AISettingsHandler) SetDiscoveryStore(store *servicediscovery.Store) { h.discoveryStore = store // Also set on legacy service if it exists - if h.legacyAIService != nil { - h.legacyAIService.SetDiscoveryStore(store) + if svc := h.ensureLegacyAIService(); svc != nil { + svc.SetDiscoveryStore(store) } // Set on all existing tenant services h.aiServicesMu.RLock() @@ -632,7 +681,9 @@ func (h *AISettingsHandler) GetIncidentRecorder() *metrics.IncidentRecorder { // StopPatrol stops the background AI patrol service func (h *AISettingsHandler) StopPatrol() { - h.legacyAIService.StopPatrol() + if svc := h.ensureLegacyAIService(); svc != nil { + svc.StopPatrol() + } h.aiServicesMu.Lock() defer h.aiServicesMu.Unlock() for _, svc := range h.aiServices { @@ -642,7 +693,10 @@ func (h *AISettingsHandler) StopPatrol() { // GetAlertTriggeredAnalyzer returns the alert-triggered analyzer for wiring into alert callbacks func (h *AISettingsHandler) GetAlertTriggeredAnalyzer(ctx context.Context) *ai.AlertTriggeredAnalyzer { - return h.GetAIService(ctx).GetAlertTriggeredAnalyzer() + if svc := h.GetAIService(ctx); svc != nil { + return svc.GetAlertTriggeredAnalyzer() + } + return nil } // SetLicenseHandlers sets the license handlers for Pro feature gating @@ -651,8 +705,13 @@ func (h *AISettingsHandler) SetLicenseHandlers(handlers *LicenseHandlers) { // Update legacy service? // legacy service needs a legacy/default license checker? // We can try to get it using background context (default tenant) + if handlers == nil { + return + } if svc, _, err := handlers.getTenantComponents(context.Background()); err == nil { - h.legacyAIService.SetLicenseChecker(svc) + if legacy := h.ensureLegacyAIService(); legacy != nil { + legacy.SetLicenseChecker(svc) + } } } diff --git a/internal/api/auth.go b/internal/api/auth.go index 07f45e94b..bdf2d658e 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -202,6 +202,12 @@ func CheckAuth(cfg *config.Config, w http.ResponseWriter, r *http.Request) bool return true } + // Requests that already carry validated auth context from the outer + // middleware should not be forced back through global-config auth checks. + if applyAuthContextHeaders(w, r) { + return true + } + config.Mu.RLock() defer config.Mu.RUnlock() @@ -575,6 +581,26 @@ func CheckAuth(cfg *config.Config, w http.ResponseWriter, r *http.Request) bool return false } +func applyAuthContextHeaders(w http.ResponseWriter, r *http.Request) bool { + if r == nil { + return false + } + + username := internalauth.GetUser(r.Context()) + if username == "" { + return false + } + + if w != nil { + w.Header().Set("X-Authenticated-User", username) + if internalauth.GetAPIToken(r.Context()) != nil { + w.Header().Set("X-Auth-Method", "api_token") + } + } + + return true +} + // RequireAuth middleware checks for authentication func RequireAuth(cfg *config.Config, handler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -877,13 +903,7 @@ func extractAndStoreAuthContext(cfg *config.Config, mtm *monitoring.MultiTenantM targetConfig := cfg if mtm != nil { - // Check for Tenant ID in header or cookie - orgID := "default" - if id := r.Header.Get("X-Pulse-Org-ID"); id != "" { - orgID = id - } else if cookie, err := r.Cookie("pulse_org_id"); err == nil && cookie.Value != "" { - orgID = cookie.Value - } + orgID := requestedOrgID(r) // If targeting a specific tenant, try to load that tenant's config if orgID != "default" { diff --git a/internal/api/config_handlers.go b/internal/api/config_handlers.go index 329ed1b54..9c1b02c3c 100644 --- a/internal/api/config_handlers.go +++ b/internal/api/config_handlers.go @@ -244,6 +244,14 @@ func (h *ConfigHandlers) SetConfig(cfg *config.Config) { h.legacyConfig = cfg } +// SetPersistence updates the legacy persistence used for single-tenant runtime paths. +func (h *ConfigHandlers) SetPersistence(p *config.ConfigPersistence) { + if p == nil { + return + } + h.legacyPersistence = p +} + // getContextState helper to retrieve tenant-specific state func (h *ConfigHandlers) getContextState(ctx context.Context) (*config.Config, *config.ConfigPersistence, *monitoring.Monitor) { orgID := "default" @@ -835,6 +843,55 @@ func ipsOnSameNetwork(ip1, ip2 net.IP) bool { return false } +func interfaceNetwork(iface proxmox.NodeNetworkInterface) (*net.IPNet, net.IP) { + if strings.TrimSpace(iface.CIDR) != "" { + if ip, network, err := net.ParseCIDR(strings.TrimSpace(iface.CIDR)); err == nil { + return network, ip + } + } + + address := net.ParseIP(strings.TrimSpace(iface.Address)) + netmask := net.ParseIP(strings.TrimSpace(iface.Netmask)) + if ipv4 := address.To4(); ipv4 != nil { + if maskIPv4 := netmask.To4(); maskIPv4 != nil { + mask := net.IPMask(maskIPv4) + return &net.IPNet{IP: ipv4.Mask(mask), Mask: mask}, ipv4 + } + } + + address6 := net.ParseIP(strings.TrimSpace(iface.Address6)) + if strings.TrimSpace(iface.CIDR) != "" && address6 != nil { + if _, network, err := net.ParseCIDR(strings.TrimSpace(iface.CIDR)); err == nil { + return network, address6 + } + } + + return nil, nil +} + +func likelySameSubnet(ip1, ip2 net.IP) bool { + if ip1 == nil || ip2 == nil { + return false + } + + if ip1v4 := ip1.To4(); ip1v4 != nil { + ip2v4 := ip2.To4() + if ip2v4 == nil { + return false + } + mask := net.CIDRMask(24, 32) + return ip1v4.Mask(mask).Equal(ip2v4.Mask(mask)) + } + + ip1v6 := ip1.To16() + ip2v6 := ip2.To16() + if ip1v6 == nil || ip2v6 == nil { + return false + } + mask := net.CIDRMask(64, 128) + return ip1v6.Mask(mask).Equal(ip2v6.Mask(mask)) +} + // findPreferredIP looks through a list of node network interfaces and returns // an IP that appears to be on the same network as the reference IP. // Returns empty string if no match found. @@ -843,21 +900,38 @@ func findPreferredIP(interfaces []proxmox.NodeNetworkInterface, referenceIP net. return "" } + bestIP := "" + bestPrefix := -1 for _, iface := range interfaces { // Skip inactive interfaces if iface.Active != 1 { continue } - // Check IPv4 address + network, ifaceIP := interfaceNetwork(iface) + if network != nil && ifaceIP != nil && network.Contains(referenceIP) { + ones, _ := network.Mask.Size() + candidate := strings.TrimSpace(iface.Address) + if candidate == "" { + candidate = ifaceIP.String() + } + if candidate != "" && ones > bestPrefix { + bestIP = candidate + bestPrefix = ones + } + continue + } + + // Fallback for older payloads that don't include CIDR/netmask details. if iface.Address != "" { - ip := net.ParseIP(iface.Address) - if ip != nil && ipsOnSameNetwork(ip, referenceIP) { - return iface.Address + ip := net.ParseIP(strings.TrimSpace(iface.Address)) + if ip != nil && likelySameSubnet(ip, referenceIP) { + return strings.TrimSpace(iface.Address) } } } - return "" + + return bestIP } var detectPVECluster = defaultDetectPVECluster @@ -984,32 +1058,27 @@ func defaultDetectPVECluster(clientConfig proxmox.ClientConfig, nodeName string, // Try to find a better IP on the same network as initial connection (management network) // Only do this if no manual override is set - if endpoint.IPOverride == "" && connectionIP != nil && clusterNode.IP != "" { - // Check if cluster-reported IP is already on the same network as our connection - clusterIP := net.ParseIP(clusterNode.IP) - if clusterIP != nil && !ipsOnSameNetwork(clusterIP, connectionIP) { - // Cluster IP is on a different network, try to find one on the same network - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - nodeInterfaces, err := tempClient.GetNodeNetworkInterfaces(ctx, clusterNode.Name) - cancel() + if endpoint.IPOverride == "" && connectionIP != nil && clusterNode.Name != "" { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + nodeInterfaces, err := tempClient.GetNodeNetworkInterfaces(ctx, clusterNode.Name) + cancel() - if err == nil { - preferredIP := findPreferredIP(nodeInterfaces, connectionIP) - if preferredIP != "" && preferredIP != clusterNode.IP { - log.Info(). - Str("node", clusterNode.Name). - Str("cluster_ip", clusterNode.IP). - Str("preferred_ip", preferredIP). - Str("connection_ip", connectionIP.String()). - Msg("Found preferred management IP for cluster node") - endpoint.IPOverride = preferredIP - } - } else { - log.Debug(). - Err(err). + if err == nil { + preferredIP := findPreferredIP(nodeInterfaces, connectionIP) + if preferredIP != "" && preferredIP != clusterNode.IP { + log.Info(). Str("node", clusterNode.Name). - Msg("Could not query node network interfaces for network preference") + Str("cluster_ip", clusterNode.IP). + Str("preferred_ip", preferredIP). + Str("connection_ip", connectionIP.String()). + Msg("Found preferred management IP for cluster node") + endpoint.IPOverride = preferredIP } + } else { + log.Debug(). + Err(err). + Str("node", clusterNode.Name). + Msg("Could not query node network interfaces for network preference") } } @@ -1049,31 +1118,27 @@ func defaultDetectPVECluster(clientConfig proxmox.ClientConfig, nodeName string, // Apply subnet preference even in fallback path (refs #929) // Node validation may have failed because cluster-reported IPs are on internal // network, but we can still query node interfaces via the initial connection - if connectionIP != nil && clusterNode.IP != "" && clusterNode.Name != "" { - clusterIP := net.ParseIP(clusterNode.IP) - if clusterIP != nil && !ipsOnSameNetwork(clusterIP, connectionIP) { - // Cluster IP is on a different network, try to find one on the same network - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - nodeInterfaces, err := tempClient.GetNodeNetworkInterfaces(ctx, clusterNode.Name) - cancel() + if connectionIP != nil && clusterNode.Name != "" { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + nodeInterfaces, err := tempClient.GetNodeNetworkInterfaces(ctx, clusterNode.Name) + cancel() - if err == nil { - preferredIP := findPreferredIP(nodeInterfaces, connectionIP) - if preferredIP != "" && preferredIP != clusterNode.IP { - log.Info(). - Str("node", clusterNode.Name). - Str("cluster_ip", clusterNode.IP). - Str("preferred_ip", preferredIP). - Str("connection_ip", connectionIP.String()). - Msg("Found preferred management IP for unvalidated cluster node") - endpoint.IPOverride = preferredIP - } - } else { - log.Debug(). - Err(err). + if err == nil { + preferredIP := findPreferredIP(nodeInterfaces, connectionIP) + if preferredIP != "" && preferredIP != clusterNode.IP { + log.Info(). Str("node", clusterNode.Name). - Msg("Could not query node network interfaces in fallback path") + Str("cluster_ip", clusterNode.IP). + Str("preferred_ip", preferredIP). + Str("connection_ip", connectionIP.String()). + Msg("Found preferred management IP for unvalidated cluster node") + endpoint.IPOverride = preferredIP } + } else { + log.Debug(). + Err(err). + Str("node", clusterNode.Name). + Msg("Could not query node network interfaces in fallback path") } } @@ -4922,7 +4987,7 @@ func (h *ConfigHandlers) HandleSetupScriptURL(w http.ResponseWriter, r *http.Req Used: false, NodeType: req.Type, Host: req.Host, - OrgID: GetOrgID(r.Context()), + OrgID: "default", } h.codeMutex.Unlock() @@ -5143,11 +5208,43 @@ func (h *ConfigHandlers) HandleAutoRegister(w http.ResponseWriter, r *http.Reque Bool("hasConfigToken", h.getConfig(r.Context()).HasAPITokens()). Msg("Checking authentication for auto-register") + validateAPIToken := func(rawToken string) (*config.APITokenRecord, bool) { + token := strings.TrimSpace(rawToken) + if token == "" { + return nil, false + } + + // Mirror the main auth path: avoid taking the write lock for obviously invalid tokens, + // then update usage metadata only after a positive read-only check. + config.Mu.RLock() + valid := h.getConfig(r.Context()).IsValidAPIToken(token) + config.Mu.RUnlock() + if !valid { + return nil, false + } + + config.Mu.Lock() + record, ok := h.getConfig(r.Context()).ValidateAPIToken(token) + config.Mu.Unlock() + return record, ok + } + + requestAPIToken := func() string { + if token := strings.TrimSpace(r.Header.Get("X-API-Token")); token != "" { + return token + } + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return strings.TrimSpace(authHeader[7:]) + } + return "" + } + // First check for setup code/auth token in the request if authCode != "" { matchedAPIToken := false if h.getConfig(r.Context()).HasAPITokens() { - if record, ok := h.getConfig(r.Context()).ValidateAPIToken(authCode); ok { + if record, ok := validateAPIToken(authCode); ok { // Accept settings:write (admin tokens) or host-agent:report (agent tokens) if record.HasScope(config.ScopeSettingsWrite) || record.HasScope(config.ScopeHostReport) { authenticated = true @@ -5185,11 +5282,6 @@ func (h *ConfigHandlers) HandleAutoRegister(w http.ResponseWriter, r *http.Reque if setupCode.NodeType == req.Type { setupCode.Used = true // Mark as used immediately - // Inject OrgID from setup code into context for subsequent processing - if setupCode.OrgID != "" { - ctx := context.WithValue(r.Context(), OrgIDContextKey, setupCode.OrgID) - r = r.WithContext(ctx) - } // Allow a short grace period for follow-up actions without keeping tokens alive too long graceExpiry := time.Now().Add(1 * time.Minute) if setupCode.ExpiresAt.Before(graceExpiry) { @@ -5221,8 +5313,8 @@ func (h *ConfigHandlers) HandleAutoRegister(w http.ResponseWriter, r *http.Reque // If not authenticated via setup code, check API token if configured if !authenticated && h.getConfig(r.Context()).HasAPITokens() { - apiToken := r.Header.Get("X-API-Token") - if record, ok := h.getConfig(r.Context()).ValidateAPIToken(apiToken); ok { + apiToken := requestAPIToken() + if record, ok := validateAPIToken(apiToken); ok { // Accept settings:write (admin tokens) or host-agent:report (agent tokens) if record.HasScope(config.ScopeSettingsWrite) || record.HasScope(config.ScopeHostReport) { authenticated = true @@ -6210,6 +6302,8 @@ func (h *ConfigHandlers) HandleAgentInstallCommand(w http.ResponseWriter, r *htt return } + defaultCtx := context.WithValue(r.Context(), OrgIDContextKey, "default") + // Generate a new API token with host report and host manage scopes rawToken, err := internalauth.GenerateAPIToken() if err != nil { @@ -6232,16 +6326,15 @@ func (h *ConfigHandlers) HandleAgentInstallCommand(w http.ResponseWriter, r *htt http.Error(w, "Failed to generate token", http.StatusInternalServerError) return } - // Persist the token config.Mu.Lock() - h.getConfig(r.Context()).APITokens = append(h.getConfig(r.Context()).APITokens, *record) - h.getConfig(r.Context()).SortAPITokens() + h.getConfig(defaultCtx).APITokens = append(h.getConfig(defaultCtx).APITokens, *record) + h.getConfig(defaultCtx).SortAPITokens() - if h.getPersistence(r.Context()) != nil { - if err := h.getPersistence(r.Context()).SaveAPITokens(h.getConfig(r.Context()).APITokens); err != nil { + if h.getPersistence(defaultCtx) != nil { + if err := h.getPersistence(defaultCtx).SaveAPITokens(h.getConfig(defaultCtx).APITokens); err != nil { // Rollback the in-memory addition - h.getConfig(r.Context()).APITokens = h.getConfig(r.Context()).APITokens[:len(h.getConfig(r.Context()).APITokens)-1] + h.getConfig(defaultCtx).APITokens = h.getConfig(defaultCtx).APITokens[:len(h.getConfig(defaultCtx).APITokens)-1] config.Mu.Unlock() log.Error().Err(err).Msg("Failed to persist API tokens after creation") http.Error(w, "Failed to save token to disk: "+err.Error(), http.StatusInternalServerError) @@ -6253,9 +6346,9 @@ func (h *ConfigHandlers) HandleAgentInstallCommand(w http.ResponseWriter, r *htt // Derive Pulse URL from the request host := r.Host if parsedHost, parsedPort, err := net.SplitHostPort(host); err == nil { - if (parsedHost == "127.0.0.1" || parsedHost == "localhost") && parsedPort == strconv.Itoa(h.getConfig(r.Context()).FrontendPort) { + if (parsedHost == "127.0.0.1" || parsedHost == "localhost") && parsedPort == strconv.Itoa(h.getConfig(defaultCtx).FrontendPort) { // Prefer a user-configured public URL when we're running on loopback - if publicURL := strings.TrimSpace(h.getConfig(r.Context()).PublicURL); publicURL != "" { + if publicURL := strings.TrimSpace(h.getConfig(defaultCtx).PublicURL); publicURL != "" { if parsedURL, err := url.Parse(publicURL); err == nil && parsedURL.Host != "" { host = parsedURL.Host } diff --git a/internal/api/config_handlers_cluster_additional_test.go b/internal/api/config_handlers_cluster_additional_test.go index 39b54efd2..73799f936 100644 --- a/internal/api/config_handlers_cluster_additional_test.go +++ b/internal/api/config_handlers_cluster_additional_test.go @@ -455,3 +455,33 @@ func TestHandleAgentInstallCommand(t *testing.T) { t.Fatalf("expected API token to be persisted") } } + +func TestHandleAgentInstallCommandIgnoresOrgID(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + body := []byte(`{"type":"pve"}`) + req := httptest.NewRequest(http.MethodPost, "/api/config/agent-install", bytes.NewReader(body)) + req.Host = "example.com:8080" + req = req.WithContext(context.WithValue(req.Context(), OrgIDContextKey, "acme")) + rec := httptest.NewRecorder() + handler.HandleAgentInstallCommand(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp AgentInstallCommandResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if bytes.Contains([]byte(resp.Command), []byte("--org-id")) { + t.Fatalf("expected command to remain single-tenant, got %q", resp.Command) + } + if len(cfg.APITokens) != 1 { + t.Fatalf("expected API token to be persisted") + } + if cfg.APITokens[0].OrgID != "" { + t.Fatalf("expected generated token to stay unbound, got %q", cfg.APITokens[0].OrgID) + } +} diff --git a/internal/api/config_handlers_helpers_additional_test.go b/internal/api/config_handlers_helpers_additional_test.go index 4e130ff59..4f0529dd6 100644 --- a/internal/api/config_handlers_helpers_additional_test.go +++ b/internal/api/config_handlers_helpers_additional_test.go @@ -51,9 +51,9 @@ func TestIPsOnSameNetwork(t *testing.T) { func TestFindPreferredIP(t *testing.T) { interfaces := []proxmox.NodeNetworkInterface{ - {Active: 0, Address: "10.0.0.10"}, - {Active: 1, Address: "10.0.0.11"}, - {Active: 1, Address: "10.0.1.10"}, + {Active: 0, Address: "10.0.0.10", CIDR: "10.0.0.10/24"}, + {Active: 1, Address: "10.0.0.11", CIDR: "10.0.0.11/24"}, + {Active: 1, Address: "10.0.1.10", CIDR: "10.0.1.10/24"}, } ref := net.ParseIP("10.0.0.50") @@ -65,3 +65,28 @@ func TestFindPreferredIP(t *testing.T) { t.Fatalf("findPreferredIP = %q, want empty", got) } } + +func TestFindPreferredIP_UsesMostSpecificMatchingSubnet(t *testing.T) { + interfaces := []proxmox.NodeNetworkInterface{ + {Active: 1, Address: "10.15.5.20", CIDR: "10.15.5.20/24"}, + {Active: 1, Address: "10.15.2.20", CIDR: "10.15.2.20/24"}, + {Active: 1, Address: "10.15.0.20", CIDR: "10.15.0.20/16"}, + } + + ref := net.ParseIP("10.15.2.99") + if got := findPreferredIP(interfaces, ref); got != "10.15.2.20" { + t.Fatalf("findPreferredIP = %q, want 10.15.2.20", got) + } +} + +func TestFindPreferredIP_FallbackDoesNotTreatDifferentThirdOctetAsSameSubnet(t *testing.T) { + interfaces := []proxmox.NodeNetworkInterface{ + {Active: 1, Address: "10.15.5.20"}, + {Active: 1, Address: "10.15.2.20"}, + } + + ref := net.ParseIP("10.15.2.99") + if got := findPreferredIP(interfaces, ref); got != "10.15.2.20" { + t.Fatalf("findPreferredIP fallback = %q, want 10.15.2.20", got) + } +} diff --git a/internal/api/config_handlers_setup_url_test.go b/internal/api/config_handlers_setup_url_test.go index e75895c29..455b3f704 100644 --- a/internal/api/config_handlers_setup_url_test.go +++ b/internal/api/config_handlers_setup_url_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -10,6 +11,7 @@ import ( "testing" "github.com/rcourtman/pulse-go-rewrite/internal/config" + internalauth "github.com/rcourtman/pulse-go-rewrite/pkg/auth" ) func TestHandleSetupScriptURL(t *testing.T) { @@ -130,3 +132,53 @@ func TestHandleSetupScriptURL_MethodNotAllowed(t *testing.T) { t.Errorf("expected method not allowed, got %v", w.Code) } } + +func TestHandleSetupScriptURLIgnoresOrgContext(t *testing.T) { + tempDir, err := os.MkdirTemp("", "pulse-setup-url-test-org") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + cfg := &config.Config{ + DataPath: tempDir, + FrontendPort: 8080, + PublicURL: "https://pulse.example.com", + } + handler := newTestConfigHandlers(t, cfg) + + body := bytes.NewBufferString(`{"type":"pve","host":"delly"}`) + req := httptest.NewRequest(http.MethodPost, "/api/setup-script-url", body) + req = req.WithContext(context.WithValue(req.Context(), OrgIDContextKey, "acme")) + req.Host = "127.0.0.1:8080" + w := httptest.NewRecorder() + + handler.HandleSetupScriptURL(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + token, ok := resp["setupToken"].(string) + if !ok || token == "" { + t.Fatalf("expected setup token in response, got %#v", resp["setupToken"]) + } + + tokenHash := internalauth.HashAPIToken(token) + + handler.codeMutex.RLock() + setupCode := handler.setupCodes[tokenHash] + handler.codeMutex.RUnlock() + + if setupCode == nil { + t.Fatalf("expected setup code to be stored for token hash %q", tokenHash) + } + if setupCode.OrgID != "default" { + t.Fatalf("expected setup code org to be forced to default, got %q", setupCode.OrgID) + } +} diff --git a/internal/api/docker_metadata.go b/internal/api/docker_metadata.go index 12eacd446..83a7fc306 100644 --- a/internal/api/docker_metadata.go +++ b/internal/api/docker_metadata.go @@ -13,7 +13,8 @@ import ( // DockerMetadataHandler handles Docker resource metadata operations type DockerMetadataHandler struct { - mtPersistence *config.MultiTenantPersistence + mtPersistence *config.MultiTenantPersistence + legacyPersistence *config.ConfigPersistence } // NewDockerMetadataHandler creates a new Docker metadata handler @@ -23,6 +24,10 @@ func NewDockerMetadataHandler(mtPersistence *config.MultiTenantPersistence) *Doc } } +func (h *DockerMetadataHandler) SetLegacyPersistence(persistence *config.ConfigPersistence) { + h.legacyPersistence = persistence +} + func (h *DockerMetadataHandler) getStore(ctx context.Context) *config.DockerMetadataStore { orgID := "default" if ctx != nil { @@ -30,8 +35,15 @@ func (h *DockerMetadataHandler) getStore(ctx context.Context) *config.DockerMeta orgID = id } } - p, _ := h.mtPersistence.GetPersistence(orgID) - return p.GetDockerMetadataStore() + if h.mtPersistence != nil { + if p, err := h.mtPersistence.GetPersistence(orgID); err == nil && p != nil { + return p.GetDockerMetadataStore() + } + } + if h.legacyPersistence != nil { + return h.legacyPersistence.GetDockerMetadataStore() + } + return nil } // Store returns the underlying metadata store for default tenant diff --git a/internal/api/guest_metadata.go b/internal/api/guest_metadata.go index f955fa7ec..17605d4ef 100644 --- a/internal/api/guest_metadata.go +++ b/internal/api/guest_metadata.go @@ -13,7 +13,8 @@ import ( // GuestMetadataHandler handles guest metadata operations type GuestMetadataHandler struct { - mtPersistence *config.MultiTenantPersistence + mtPersistence *config.MultiTenantPersistence + legacyPersistence *config.ConfigPersistence } // NewGuestMetadataHandler creates a new guest metadata handler @@ -23,6 +24,10 @@ func NewGuestMetadataHandler(mtPersistence *config.MultiTenantPersistence) *Gues } } +func (h *GuestMetadataHandler) SetLegacyPersistence(persistence *config.ConfigPersistence) { + h.legacyPersistence = persistence +} + func (h *GuestMetadataHandler) getStore(ctx context.Context) *config.GuestMetadataStore { // Default to "default" org if none specified (though middleware should always set it) orgID := "default" @@ -31,8 +36,15 @@ func (h *GuestMetadataHandler) getStore(ctx context.Context) *config.GuestMetada orgID = id } } - p, _ := h.mtPersistence.GetPersistence(orgID) - return p.GetGuestMetadataStore() + if h.mtPersistence != nil { + if p, err := h.mtPersistence.GetPersistence(orgID); err == nil && p != nil { + return p.GetGuestMetadataStore() + } + } + if h.legacyPersistence != nil { + return h.legacyPersistence.GetGuestMetadataStore() + } + return nil } // Reload reloads the guest metadata from disk @@ -43,7 +55,10 @@ func (h *GuestMetadataHandler) Reload() error { // But stores load on init. Reload() method on store might be needed if modified on disk externally. // For now, this is a no-op or TODO for multi-tenant deep reload. // Actually, we can get "default" store and reload it for legacy compat. - return h.getStore(context.Background()).Load() + if store := h.getStore(context.Background()); store != nil { + return store.Load() + } + return nil } // Store returns the underlying metadata store for the default tenant (Legacy support) diff --git a/internal/api/host_metadata.go b/internal/api/host_metadata.go index ce2f7a851..ff7cdf012 100644 --- a/internal/api/host_metadata.go +++ b/internal/api/host_metadata.go @@ -13,7 +13,8 @@ import ( // HostMetadataHandler handles host metadata operations type HostMetadataHandler struct { - mtPersistence *config.MultiTenantPersistence + mtPersistence *config.MultiTenantPersistence + legacyPersistence *config.ConfigPersistence } // NewHostMetadataHandler creates a new host metadata handler @@ -23,6 +24,10 @@ func NewHostMetadataHandler(mtPersistence *config.MultiTenantPersistence) *HostM } } +func (h *HostMetadataHandler) SetLegacyPersistence(persistence *config.ConfigPersistence) { + h.legacyPersistence = persistence +} + func (h *HostMetadataHandler) getStore(ctx context.Context) *config.HostMetadataStore { orgID := "default" if ctx != nil { @@ -30,8 +35,15 @@ func (h *HostMetadataHandler) getStore(ctx context.Context) *config.HostMetadata orgID = id } } - p, _ := h.mtPersistence.GetPersistence(orgID) - return p.GetHostMetadataStore() + if h.mtPersistence != nil { + if p, err := h.mtPersistence.GetPersistence(orgID); err == nil && p != nil { + return p.GetHostMetadataStore() + } + } + if h.legacyPersistence != nil { + return h.legacyPersistence.GetHostMetadataStore() + } + return nil } // Store returns the underlying metadata store for default tenant diff --git a/internal/api/license_handlers.go b/internal/api/license_handlers.go index d0e42114b..95580be64 100644 --- a/internal/api/license_handlers.go +++ b/internal/api/license_handlers.go @@ -17,10 +17,11 @@ import ( // LicenseHandlers handles license management API endpoints. // LicenseHandlers handles license management API endpoints. type LicenseHandlers struct { - mtPersistence *config.MultiTenantPersistence - services sync.Map // map[string]*license.Service - configDir string // Base config dir, though we use mtPersistence for tenants - auditOnce sync.Once + mtPersistence *config.MultiTenantPersistence + legacyPersistence *config.ConfigPersistence + services sync.Map // map[string]*license.Service + configDir string // Base config dir, though we use mtPersistence for tenants + auditOnce sync.Once } // NewLicenseHandlers creates a new license handlers instance. @@ -30,6 +31,10 @@ func NewLicenseHandlers(mtp *config.MultiTenantPersistence) *LicenseHandlers { } } +func (h *LicenseHandlers) SetLegacyPersistence(persistence *config.ConfigPersistence) { + h.legacyPersistence = persistence +} + // getTenantComponents resolves the license service and persistence for the current tenant. // It initializes them if they haven't been loaded yet. func (h *LicenseHandlers) getTenantComponents(ctx context.Context) (*license.Service, *license.Persistence, error) { @@ -81,11 +86,17 @@ func (h *LicenseHandlers) getTenantComponents(ctx context.Context) (*license.Ser } func (h *LicenseHandlers) getPersistenceForOrg(orgID string) (*license.Persistence, error) { - configPersistence, err := h.mtPersistence.GetPersistence(orgID) - if err != nil { - return nil, err + if h.mtPersistence != nil { + configPersistence, err := h.mtPersistence.GetPersistence(orgID) + if err != nil { + return nil, err + } + return license.NewPersistence(configPersistence.GetConfigDir()) } - return license.NewPersistence(configPersistence.GetConfigDir()) + if h.legacyPersistence == nil { + return nil, nil + } + return license.NewPersistence(h.legacyPersistence.GetConfigDir()) } // initAuditLoggerIfLicensed initializes the SQLite audit logger if the license diff --git a/internal/api/license_handlers_test.go b/internal/api/license_handlers_test.go index 6b4ae0265..ad3a74e8e 100644 --- a/internal/api/license_handlers_test.go +++ b/internal/api/license_handlers_test.go @@ -24,6 +24,20 @@ func createTestHandler(t *testing.T) *LicenseHandlers { return NewLicenseHandlers(mtp) } +func TestLicenseHandlers_FallbackToLegacyPersistence(t *testing.T) { + persistence := config.NewConfigPersistence(t.TempDir()) + handler := NewLicenseHandlers(nil) + handler.SetLegacyPersistence(persistence) + + svc, p, err := handler.getTenantComponents(context.Background()) + if err != nil { + t.Fatalf("expected legacy persistence fallback, got error: %v", err) + } + if svc == nil || p == nil { + t.Fatalf("expected service and persistence from legacy fallback") + } +} + type licenseFeaturesResponse struct { LicenseStatus string `json:"license_status"` Features map[string]bool `json:"features"` diff --git a/internal/api/metadata_handlers_test.go b/internal/api/metadata_handlers_test.go index 4b7efff8a..7b82f8489 100644 --- a/internal/api/metadata_handlers_test.go +++ b/internal/api/metadata_handlers_test.go @@ -119,3 +119,25 @@ func TestHostMetadataHandler(t *testing.T) { t.Fatalf("unexpected status: %d", resp.Code) } } + +func TestMetadataHandlers_FallbackToLegacyPersistence(t *testing.T) { + persistence := config.NewConfigPersistence(t.TempDir()) + + guestHandler := NewGuestMetadataHandler(nil) + guestHandler.SetLegacyPersistence(persistence) + if guestHandler.Store() == nil { + t.Fatal("expected guest metadata store from legacy persistence") + } + + dockerHandler := NewDockerMetadataHandler(nil) + dockerHandler.SetLegacyPersistence(persistence) + if dockerHandler.Store() == nil { + t.Fatal("expected docker metadata store from legacy persistence") + } + + hostHandler := NewHostMetadataHandler(nil) + hostHandler.SetLegacyPersistence(persistence) + if hostHandler.Store() == nil { + t.Fatal("expected host metadata store from legacy persistence") + } +} diff --git a/internal/api/middleware_license.go b/internal/api/middleware_license.go index 3cd973516..83f3e175f 100644 --- a/internal/api/middleware_license.go +++ b/internal/api/middleware_license.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "os" + "path/filepath" "strings" "sync" @@ -18,11 +19,31 @@ import ( // AND properly licensed for non-default organizations to work. var multiTenantEnabled = strings.EqualFold(os.Getenv("PULSE_MULTI_TENANT_ENABLED"), "true") +// v5 should behave as single-tenant in real runtime even if dormant multi-tenant +// code remains in the branch. Tests can disable this to exercise legacy paths. +var v5SingleTenantMode = !runningUnderGoTest() + // IsMultiTenantEnabled returns whether multi-tenant functionality is enabled. func IsMultiTenantEnabled() bool { return multiTenantEnabled } +func isV5SingleTenantMode() bool { + return v5SingleTenantMode +} + +func IsV5SingleTenantRuntime() bool { + return isV5SingleTenantMode() +} + +func setV5SingleTenantModeForTests(enabled bool) { + v5SingleTenantMode = enabled +} + +func runningUnderGoTest() bool { + return strings.HasSuffix(filepath.Base(os.Args[0]), ".test") +} + // DefaultMultiTenantChecker implements websocket.MultiTenantChecker for use with the WebSocket hub. type DefaultMultiTenantChecker struct{} diff --git a/internal/api/middleware_tenant.go b/internal/api/middleware_tenant.go index 28db3fb12..e42b304f8 100644 --- a/internal/api/middleware_tenant.go +++ b/internal/api/middleware_tenant.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net/http" + "strings" "github.com/rcourtman/pulse-go-rewrite/internal/config" "github.com/rcourtman/pulse-go-rewrite/internal/models" @@ -51,24 +52,7 @@ func (m *TenantMiddleware) SetAuthChecker(checker AuthorizationChecker) { func (m *TenantMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 1. Extract Org ID - // Priority: - // 1. Header: X-Pulse-Org-ID (for API clients/agents) - // 2. Cookie: pulse_org_id (for browser session) - // 3. Fallback: "default" (for backward compatibility) - - orgID := r.Header.Get("X-Pulse-Org-ID") - if orgID == "" { - // Check cookie - if cookie, err := r.Cookie("pulse_org_id"); err == nil { - orgID = cookie.Value - } - } - - // Fallback to default - if orgID == "" { - orgID = "default" - } + orgID := requestedOrgID(r) // 2. Validate Organization Exists (only for non-default orgs) // This must check existence WITHOUT creating directories to prevent DoS. @@ -174,3 +158,29 @@ func GetOrganization(ctx context.Context) *models.Organization { } return &models.Organization{ID: "default", DisplayName: "Default Organization"} } + +func requestedOrgID(r *http.Request) string { + orgID := "" + if r != nil { + orgID = strings.TrimSpace(r.Header.Get("X-Pulse-Org-ID")) + if orgID == "" { + if cookie, err := r.Cookie("pulse_org_id"); err == nil { + orgID = strings.TrimSpace(cookie.Value) + } + } + } + + if orgID != "" && orgID != "default" && isV5SingleTenantMode() { + log.Debug(). + Str("path", r.URL.Path). + Str("requested_org", orgID). + Msg("Ignoring non-default org for single-tenant v5 runtime") + return "default" + } + + if orgID == "" { + return "default" + } + + return orgID +} diff --git a/internal/api/middleware_tenant_additional_test.go b/internal/api/middleware_tenant_additional_test.go index b4de3a153..c998583f0 100644 --- a/internal/api/middleware_tenant_additional_test.go +++ b/internal/api/middleware_tenant_additional_test.go @@ -137,3 +137,28 @@ func TestTenantMiddleware_MultiTenantLicenseRequired(t *testing.T) { t.Fatalf("expected 402, got %d", rec.Code) } } + +func TestTenantMiddleware_V5SingleTenantModeIgnoresHeaderAndCookie(t *testing.T) { + prevSingleTenant := isV5SingleTenantMode() + setV5SingleTenantModeForTests(true) + t.Cleanup(func() { setV5SingleTenantModeForTests(prevSingleTenant) }) + + mw := NewTenantMiddleware(nil) + handler := mw.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := GetOrgID(r.Context()); got != "default" { + t.Fatalf("expected default org in v5 single-tenant mode, got %q", got) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/api/dashboard", nil) + req.Header.Set("X-Pulse-Org-ID", "header-org") + req.AddCookie(&http.Cookie{Name: "pulse_org_id", Value: "cookie-org"}) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index 0aea7589d..e3bf28141 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -139,6 +139,10 @@ func NewRouter(cfg *config.Config, monitor *monitoring.Monitor, mtMonitor *monit InitSessionStore(cfg.DataPath) InitCSRFStore(cfg.DataPath) + if isV5SingleTenantMode() { + mtMonitor = nil + } + updateHistory, err := updates.NewUpdateHistory(cfg.DataPath) if err != nil { log.Error().Err(err).Msg("Failed to initialize update history") @@ -152,6 +156,11 @@ func NewRouter(cfg *config.Config, monitor *monitoring.Monitor, mtMonitor *monit updateManager := updates.NewManager(cfg) updateManager.SetHistory(updateHistory) + var mtPersistence *config.MultiTenantPersistence + if !isV5SingleTenantMode() { + mtPersistence = config.NewMultiTenantPersistence(cfg.DataPath) + } + r := &Router{ mux: http.NewServeMux(), config: cfg, @@ -164,7 +173,7 @@ func NewRouter(cfg *config.Config, monitor *monitoring.Monitor, mtMonitor *monit exportLimiter: NewRateLimiter(5, 1*time.Minute), // 5 attempts per minute downloadLimiter: NewRateLimiter(60, 1*time.Minute), // downloads/installers per minute per IP persistence: config.NewConfigPersistence(cfg.DataPath), - multiTenant: config.NewMultiTenantPersistence(cfg.DataPath), + multiTenant: mtPersistence, authorizer: auth.GetAuthorizer(), serverVersion: strings.TrimSpace(serverVersion), projectRoot: projectRoot, @@ -253,12 +262,16 @@ func (r *Router) setupRoutes() { r.notificationHandlers = NewNotificationHandlers(r.mtMonitor, NewNotificationMonitorWrapper(r.monitor)) r.notificationQueueHandlers = NewNotificationQueueHandlers(r.monitor) guestMetadataHandler := NewGuestMetadataHandler(r.multiTenant) + guestMetadataHandler.SetLegacyPersistence(r.persistence) dockerMetadataHandler := NewDockerMetadataHandler(r.multiTenant) + dockerMetadataHandler.SetLegacyPersistence(r.persistence) hostMetadataHandler := NewHostMetadataHandler(r.multiTenant) + hostMetadataHandler.SetLegacyPersistence(r.persistence) r.configHandlers = NewConfigHandlers(r.multiTenant, r.mtMonitor, r.reloadFunc, r.wsHub, guestMetadataHandler, r.reloadSystemSettings) if r.monitor != nil { r.configHandlers.SetMonitor(r.monitor) } + r.configHandlers.SetPersistence(r.persistence) updateHandlers := NewUpdateHandlers(r.updateManager, r.updateHistory) r.dockerAgentHandlers = NewDockerAgentHandlers(r.mtMonitor, r.monitor, r.wsHub, r.config) r.kubernetesAgentHandlers = NewKubernetesAgentHandlers(r.mtMonitor, r.monitor, r.wsHub) @@ -266,6 +279,7 @@ func (r *Router) setupRoutes() { r.resourceHandlers = NewResourceHandlers() r.configProfileHandler = NewConfigProfileHandler(r.multiTenant) r.licenseHandlers = NewLicenseHandlers(r.multiTenant) + r.licenseHandlers.SetLegacyPersistence(r.persistence) // Wire license service provider so middleware can access per-tenant license services SetLicenseServiceProvider(r.licenseHandlers) r.reportingHandlers = NewReportingHandlers(r.mtMonitor) @@ -1439,6 +1453,7 @@ func (r *Router) setupRoutes() { // AI settings endpoints r.aiSettingsHandler = NewAISettingsHandler(r.multiTenant, r.mtMonitor, r.agentExecServer) + r.aiSettingsHandler.SetLegacyRuntime(r.config, r.persistence) // Inject state provider so AI has access to full infrastructure context (VMs, containers, IPs) if r.monitor != nil { r.aiSettingsHandler.SetStateProvider(r.monitor) diff --git a/internal/api/router_single_tenant_persistence_test.go b/internal/api/router_single_tenant_persistence_test.go new file mode 100644 index 000000000..23f072cd1 --- /dev/null +++ b/internal/api/router_single_tenant_persistence_test.go @@ -0,0 +1,26 @@ +package api + +import ( + "context" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestNewRouter_V5SingleTenantConfigHandlersUseLegacyPersistence(t *testing.T) { + prevSingleTenant := setV5SingleTenantModeForTests(true) + t.Cleanup(func() { setV5SingleTenantModeForTests(prevSingleTenant) }) + + cfg := &config.Config{DataPath: t.TempDir()} + router := NewRouter(cfg, nil, nil, nil, nil, "1.0.0") + + if router.configHandlers == nil { + t.Fatal("expected config handlers to be initialized") + } + if router.configHandlers.legacyPersistence == nil { + t.Fatal("expected legacy persistence to be wired in single-tenant mode") + } + if router.configHandlers.getPersistence(context.Background()) == nil { + t.Fatal("expected getPersistence to return legacy persistence in single-tenant mode") + } +} diff --git a/internal/api/security_regression_test.go b/internal/api/security_regression_test.go index 3c5ad3af6..e2cab9238 100644 --- a/internal/api/security_regression_test.go +++ b/internal/api/security_regression_test.go @@ -1862,6 +1862,23 @@ func TestAutoRegisterAcceptsAgentToken(t *testing.T) { } } +func TestAutoRegisterAcceptsBearerAgentToken(t *testing.T) { + rawToken := "agent-register-token-bearer-123.12345678" + record := newTokenRecord(t, rawToken, []string{config.ScopeHostReport}, nil) + cfg := newTestConfigWithTokens(t, record) + router := NewRouter(cfg, nil, nil, nil, nil, "1.0.0") + router.configHandlers.SetConfig(cfg) + + body := `{"type":"pve","host":"https://192.168.1.1:8006","tokenId":"test@pam!pulse","tokenValue":"secret"}` + req := httptest.NewRequest(http.MethodPost, "/api/auto-register", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer "+rawToken) + rec := httptest.NewRecorder() + router.Handler().ServeHTTP(rec, req) + if rec.Code == http.StatusUnauthorized { + t.Fatalf("expected bearer agent token with host-agent:report to be accepted, got 401") + } +} + func TestConfigExportRequiresProxyAdmin(t *testing.T) { cfg := newTestConfigWithTokens(t) cfg.ProxyAuthSecret = "proxy-secret" diff --git a/internal/api/tenant_agent_auth_test.go b/internal/api/tenant_agent_auth_test.go new file mode 100644 index 000000000..173042cc3 --- /dev/null +++ b/internal/api/tenant_agent_auth_test.go @@ -0,0 +1,20 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestedOrgID_V5SingleTenantModeIgnoresHostAgentHeader(t *testing.T) { + prevSingleTenant := isV5SingleTenantMode() + setV5SingleTenantModeForTests(true) + t.Cleanup(func() { setV5SingleTenantModeForTests(prevSingleTenant) }) + + req := httptest.NewRequest(http.MethodGet, "/api/agents/host/lookup?hostname=missing-host", nil) + req.Header.Set("X-Pulse-Org-ID", "acme") + + if got := requestedOrgID(req); got != "default" { + t.Fatalf("expected host agent path to collapse to default org in single-tenant mode, got %q", got) + } +} diff --git a/internal/ceph/collector.go b/internal/ceph/collector.go index 13df19103..040ef2d01 100644 --- a/internal/ceph/collector.go +++ b/internal/ceph/collector.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "os/exec" + "sort" "strings" "time" ) @@ -195,21 +196,35 @@ func parseStatus(data []byte) (*ClusterStatus, error) { } `json:"checks"` } `json:"health"` MonMap struct { - Epoch int `json:"epoch"` - Mons []struct { + Epoch int `json:"epoch"` + NumMons int `json:"num_mons"` + QuorumNames []string `json:"quorum_names"` + Mons []struct { Name string `json:"name"` Rank int `json:"rank"` Addr string `json:"addr"` } `json:"mons"` } `json:"monmap"` MgrMap struct { - Available bool `json:"available"` - NumActive int `json:"num_active_name,omitempty"` - ActiveName string `json:"active_name"` - Standbys []struct { + Available bool `json:"available"` + ActiveName string `json:"active_name"` + NumStandbys int `json:"num_standbys"` + NumStandby int `json:"num_standby"` + Standbys []struct { Name string `json:"name"` } `json:"standbys"` } `json:"mgrmap"` + ServiceMap struct { + Services map[string]struct { + Daemons map[string]struct { + Status string `json:"status"` + Hostname string `json:"hostname"` + Metadata struct { + Hostname string `json:"hostname"` + } `json:"metadata"` + } `json:"daemons"` + } `json:"services"` + } `json:"servicemap"` OSDMap struct { Epoch int `json:"epoch"` NumOSD int `json:"num_osds"` @@ -235,6 +250,63 @@ func parseStatus(data []byte) (*ClusterStatus, error) { return nil, fmt.Errorf("unmarshal: %w", err) } + monitors := make([]Monitor, 0, len(raw.MonMap.Mons)) + for _, mon := range raw.MonMap.Mons { + monitors = append(monitors, Monitor{ + Name: mon.Name, + Rank: mon.Rank, + Addr: mon.Addr, + }) + } + if len(monitors) == 0 && len(raw.MonMap.QuorumNames) > 0 { + for i, name := range raw.MonMap.QuorumNames { + monitors = append(monitors, Monitor{ + Name: name, + Rank: i, + Status: "up", + }) + } + } + + monCount := len(monitors) + if raw.MonMap.NumMons > monCount { + monCount = raw.MonMap.NumMons + } + + standbyCount := len(raw.MgrMap.Standbys) + if raw.MgrMap.NumStandbys > standbyCount { + standbyCount = raw.MgrMap.NumStandbys + } + if raw.MgrMap.NumStandby > standbyCount { + standbyCount = raw.MgrMap.NumStandby + } + + managerCount := standbyCount + if raw.MgrMap.ActiveName != "" || raw.MgrMap.Available { + managerCount++ + } + + services := buildServiceSummary(raw.ServiceMap.Services) + if monCount == 0 { + if monService, ok := services["mon"]; ok && monService.Total > 0 { + monCount = monService.Total + if len(monitors) == 0 { + for i, daemon := range monService.Daemons { + monitors = append(monitors, Monitor{ + Name: daemon, + Rank: i, + Status: "unknown", + }) + } + } + } + } else if monService, ok := services["mon"]; ok && monService.Total > monCount { + monCount = monService.Total + } + if mgrService, ok := services["mgr"]; ok && mgrService.Total > managerCount { + managerCount = mgrService.Total + } + status := &ClusterStatus{ FSID: raw.FSID, Health: HealthStatus{ @@ -242,14 +314,15 @@ func parseStatus(data []byte) (*ClusterStatus, error) { Checks: make(map[string]Check), }, MonMap: MonitorMap{ - Epoch: raw.MonMap.Epoch, - NumMons: len(raw.MonMap.Mons), + Epoch: raw.MonMap.Epoch, + NumMons: monCount, + Monitors: monitors, }, MgrMap: ManagerMap{ Available: raw.MgrMap.Available, - NumMgrs: 1 + len(raw.MgrMap.Standbys), + NumMgrs: managerCount, ActiveMgr: raw.MgrMap.ActiveName, - Standbys: len(raw.MgrMap.Standbys), + Standbys: standbyCount, }, OSDMap: OSDMap{ Epoch: raw.OSDMap.Epoch, @@ -279,15 +352,6 @@ func parseStatus(data []byte) (*ClusterStatus, error) { status.PGMap.UsagePercent = float64(raw.PGMap.BytesUsed) / float64(raw.PGMap.BytesTotal) * 100 } - // Parse monitors - for _, mon := range raw.MonMap.Mons { - status.MonMap.Monitors = append(status.MonMap.Monitors, Monitor{ - Name: mon.Name, - Rank: mon.Rank, - Addr: mon.Addr, - }) - } - // Parse health checks for name, check := range raw.Health.Checks { details := make([]string, 0, len(check.Detail)) @@ -301,16 +365,118 @@ func parseStatus(data []byte) (*ClusterStatus, error) { } } - // Build service summary - status.Services = []ServiceInfo{ - {Type: "mon", Running: len(raw.MonMap.Mons), Total: len(raw.MonMap.Mons)}, - {Type: "mgr", Running: boolToInt(raw.MgrMap.Available), Total: status.MgrMap.NumMgrs}, - {Type: "osd", Running: raw.OSDMap.NumUp, Total: raw.OSDMap.NumOSD}, + if len(services) > 0 { + keys := make([]string, 0, len(services)) + for serviceType := range services { + keys = append(keys, serviceType) + } + sort.Strings(keys) + status.Services = make([]ServiceInfo, 0, len(keys)+1) + for _, serviceType := range keys { + status.Services = append(status.Services, services[serviceType]) + } + } else { + status.Services = []ServiceInfo{ + {Type: "mon", Running: status.MonMap.NumMons, Total: status.MonMap.NumMons, Daemons: monitorNames(status.MonMap.Monitors)}, + {Type: "mgr", Running: boolToInt(raw.MgrMap.Available), Total: status.MgrMap.NumMgrs, Daemons: managerNames(raw.MgrMap.ActiveName, raw.MgrMap.Standbys)}, + } + } + if !serviceInfoExists(status.Services, "osd") { + status.Services = append(status.Services, ServiceInfo{Type: "osd", Running: raw.OSDMap.NumUp, Total: raw.OSDMap.NumOSD}) } return status, nil } +func buildServiceSummary(services map[string]struct { + Daemons map[string]struct { + Status string `json:"status"` + Hostname string `json:"hostname"` + Metadata struct { + Hostname string `json:"hostname"` + } `json:"metadata"` + } `json:"daemons"` +}) map[string]ServiceInfo { + if len(services) == 0 { + return nil + } + + result := make(map[string]ServiceInfo, len(services)) + for serviceType, definition := range services { + summary := ServiceInfo{Type: serviceType} + if len(definition.Daemons) > 0 { + daemonNames := make([]string, 0, len(definition.Daemons)) + for daemonName, daemon := range definition.Daemons { + summary.Total++ + if isServiceRunning(daemon.Status) { + summary.Running++ + } + label := daemonName + host := strings.TrimSpace(daemon.Hostname) + if host == "" { + host = strings.TrimSpace(daemon.Metadata.Hostname) + } + if host != "" { + label = fmt.Sprintf("%s@%s", daemonName, host) + } + daemonNames = append(daemonNames, label) + } + sort.Strings(daemonNames) + summary.Daemons = daemonNames + } + result[serviceType] = summary + } + + return result +} + +func isServiceRunning(status string) bool { + switch strings.ToLower(strings.TrimSpace(status)) { + case "running", "active", "up": + return true + default: + return false + } +} + +func monitorNames(monitors []Monitor) []string { + if len(monitors) == 0 { + return nil + } + + names := make([]string, 0, len(monitors)) + for _, mon := range monitors { + if strings.TrimSpace(mon.Name) != "" { + names = append(names, mon.Name) + } + } + return names +} + +func managerNames(activeName string, standbys []struct { + Name string `json:"name"` +}) []string { + names := make([]string, 0, 1+len(standbys)) + if strings.TrimSpace(activeName) != "" { + names = append(names, activeName) + } + for _, standby := range standbys { + if strings.TrimSpace(standby.Name) != "" { + names = append(names, standby.Name) + } + } + return names +} + +func serviceInfoExists(services []ServiceInfo, serviceType string) bool { + for _, service := range services { + if service.Type == serviceType { + return true + } + } + return false +} + // parseDF parses the output of `ceph df --format json`. func parseDF(data []byte) ([]Pool, float64, error) { var raw struct { diff --git a/internal/ceph/collector_test.go b/internal/ceph/collector_test.go index 5afb8d0b2..1d505a974 100644 --- a/internal/ceph/collector_test.go +++ b/internal/ceph/collector_test.go @@ -93,6 +93,103 @@ func TestParseStatus(t *testing.T) { } } +func TestParseStatus_CountOnlyFallbacks(t *testing.T) { + data := []byte(`{ + "fsid":"fsid-counts", + "health":{"status":"HEALTH_OK","checks":{}}, + "monmap":{"epoch":9,"num_mons":3,"quorum_names":["mon-a","mon-b","mon-c"]}, + "mgrmap":{"available":true,"active_name":"mgr-a","num_standbys":1}, + "osdmap":{"epoch":4,"num_osds":6,"num_up_osds":6,"num_in_osds":6}, + "pgmap":{"num_pgs":128,"bytes_total":1000,"bytes_used":100,"bytes_avail":900} + }`) + + status, err := parseStatus(data) + if err != nil { + t.Fatalf("parseStatus returned error: %v", err) + } + + if status.MonMap.NumMons != 3 { + t.Fatalf("NumMons = %d, want 3", status.MonMap.NumMons) + } + if len(status.MonMap.Monitors) != 3 { + t.Fatalf("len(Monitors) = %d, want 3", len(status.MonMap.Monitors)) + } + if status.MgrMap.NumMgrs != 2 { + t.Fatalf("NumMgrs = %d, want 2", status.MgrMap.NumMgrs) + } + if status.MgrMap.Standbys != 1 { + t.Fatalf("Standbys = %d, want 1", status.MgrMap.Standbys) + } + + var monSvc, mgrSvc *ServiceInfo + for i := range status.Services { + switch status.Services[i].Type { + case "mon": + monSvc = &status.Services[i] + case "mgr": + mgrSvc = &status.Services[i] + } + } + if monSvc == nil || monSvc.Total != 3 { + t.Fatalf("mon service = %+v, want total 3", monSvc) + } + if mgrSvc == nil || mgrSvc.Total != 2 { + t.Fatalf("mgr service = %+v, want total 2", mgrSvc) + } +} + +func TestParseStatus_ServiceMapFallbacks(t *testing.T) { + data := []byte(`{ + "fsid":"fsid-servicemap", + "health":{"status":"HEALTH_OK","checks":{}}, + "monmap":{"epoch":1}, + "mgrmap":{"available":true}, + "servicemap":{ + "services":{ + "mon":{"daemons":{ + "a":{"status":"running","hostname":"node1"}, + "b":{"status":"running","hostname":"node2"}, + "c":{"status":"stopped","hostname":"node3"} + }}, + "mgr":{"daemons":{ + "mgr-a":{"status":"active","hostname":"node1"}, + "mgr-b":{"status":"standby","hostname":"node2"} + }} + } + }, + "osdmap":{"epoch":3,"num_osds":3,"num_up_osds":3,"num_in_osds":3}, + "pgmap":{"num_pgs":64,"bytes_total":1000,"bytes_used":100,"bytes_avail":900} + }`) + + status, err := parseStatus(data) + if err != nil { + t.Fatalf("parseStatus returned error: %v", err) + } + + if status.MonMap.NumMons != 3 { + t.Fatalf("NumMons = %d, want 3", status.MonMap.NumMons) + } + if status.MgrMap.NumMgrs != 2 { + t.Fatalf("NumMgrs = %d, want 2", status.MgrMap.NumMgrs) + } + + var monSvc, mgrSvc *ServiceInfo + for i := range status.Services { + switch status.Services[i].Type { + case "mon": + monSvc = &status.Services[i] + case "mgr": + mgrSvc = &status.Services[i] + } + } + if monSvc == nil || monSvc.Total != 3 || monSvc.Running != 2 { + t.Fatalf("mon service = %+v, want total 3 running 2", monSvc) + } + if mgrSvc == nil || mgrSvc.Total != 2 || mgrSvc.Running != 1 { + t.Fatalf("mgr service = %+v, want total 2 running 1", mgrSvc) + } +} + func TestIsAvailable(t *testing.T) { t.Run("available", func(t *testing.T) { withCommandRunner(t, func(ctx context.Context, name string, args ...string) ([]byte, []byte, error) { diff --git a/internal/hostagent/commands.go b/internal/hostagent/commands.go index a76b9366d..734e5c674 100644 --- a/internal/hostagent/commands.go +++ b/internal/hostagent/commands.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net/http" "net/url" "os" "os/exec" @@ -175,7 +176,8 @@ func (c *CommandClient) connectAndHandle(ctx context.Context) error { } // Connect - conn, _, err := dialer.DialContext(ctx, wsURL, nil) + headers := make(http.Header) + conn, _, err := dialer.DialContext(ctx, wsURL, headers) if err != nil { return fmt.Errorf("dial websocket: %w", err) } diff --git a/internal/hostagent/proxmox_setup.go b/internal/hostagent/proxmox_setup.go index 33c78d2b0..c1eca28e0 100644 --- a/internal/hostagent/proxmox_setup.go +++ b/internal/hostagent/proxmox_setup.go @@ -743,7 +743,10 @@ func (p *ProxmoxSetup) doRegisterRequest(ctx context.Context, body []byte) error return fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-API-Token", p.apiToken) + if token := strings.TrimSpace(p.apiToken); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("X-API-Token", token) + } resp, err := p.httpClient.Do(req) if err != nil { diff --git a/internal/hostagent/proxmox_setup_test.go b/internal/hostagent/proxmox_setup_test.go index 50208d598..3576dda02 100644 --- a/internal/hostagent/proxmox_setup_test.go +++ b/internal/hostagent/proxmox_setup_test.go @@ -333,7 +333,11 @@ func TestProxmoxSetup_RunForType(t *testing.T) { func TestRegisterWithPulseRetry(t *testing.T) { // Server returns 503 twice, then 200 var attempt int32 + var gotAuth string + var gotAPIToken string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotAPIToken = r.Header.Get("X-API-Token") n := atomic.AddInt32(&attempt, 1) if n <= 2 { w.WriteHeader(http.StatusServiceUnavailable) @@ -359,6 +363,12 @@ func TestRegisterWithPulseRetry(t *testing.T) { if atomic.LoadInt32(&attempt) != 3 { t.Errorf("expected 3 attempts, got %d", atomic.LoadInt32(&attempt)) } + if gotAuth != "Bearer test-token" { + t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer test-token") + } + if gotAPIToken != "test-token" { + t.Fatalf("X-API-Token = %q, want %q", gotAPIToken, "test-token") + } } func TestRegisterWithPulseNoRetryOn4xx(t *testing.T) { diff --git a/internal/models/models.go b/internal/models/models.go index d28b7a512..f43777434 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -2302,6 +2302,20 @@ func (s *State) GetHosts() []Host { return hosts } +// GetNodeLinkedHostAgentID returns the linked host agent ID for the given node. +func (s *State) GetNodeLinkedHostAgentID(nodeID string) string { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, node := range s.Nodes { + if node.ID == nodeID { + return node.LinkedHostAgentID + } + } + + return "" +} + // RemoveHost removes a host by ID and returns the removed entry. func (s *State) RemoveHost(hostID string) (Host, bool) { s.mu.Lock() diff --git a/internal/monitoring/diagnostic_snapshots.go b/internal/monitoring/diagnostic_snapshots.go index 1ed9a9279..98dd71ec4 100644 --- a/internal/monitoring/diagnostic_snapshots.go +++ b/internal/monitoring/diagnostic_snapshots.go @@ -50,6 +50,8 @@ type VMMemoryRaw struct { StatusMem uint64 `json:"statusMem,omitempty"` StatusFreeMem uint64 `json:"statusFreemem,omitempty"` StatusMaxMem uint64 `json:"statusMaxmem,omitempty"` + RRDAvailable uint64 `json:"rrdAvailable,omitempty"` + RRDUsed uint64 `json:"rrdUsed,omitempty"` Balloon uint64 `json:"balloon,omitempty"` BalloonMin uint64 `json:"balloonMin,omitempty"` MemInfoUsed uint64 `json:"meminfoUsed,omitempty"` diff --git a/internal/monitoring/host_agent_temps.go b/internal/monitoring/host_agent_temps.go index 23d24d8c6..defa637e6 100644 --- a/internal/monitoring/host_agent_temps.go +++ b/internal/monitoring/host_agent_temps.go @@ -36,10 +36,28 @@ func (m *Monitor) getHostAgentTemperatureByID(nodeID, nodeName string) *models.T var matchedHost *models.Host - // First, try to find a host agent that is explicitly linked to this node - // via LinkedNodeID. This is the most reliable method and handles duplicate - // hostnames correctly. + // First, try to resolve through the node's canonical linked host agent ID. + // This is the link preserved by node refreshes and is more reliable than + // host-side hostname matching when node IDs change or FQDN/short names differ. if nodeID != "" { + if linkedHostID := m.state.GetNodeLinkedHostAgentID(nodeID); linkedHostID != "" { + for i := range hosts { + if hosts[i].ID == linkedHostID { + matchedHost = &hosts[i] + log.Debug(). + Str("nodeID", nodeID). + Str("hostAgentID", hosts[i].ID). + Str("hostname", hosts[i].Hostname). + Msg("Matched host agent to node via LinkedHostAgentID") + break + } + } + } + } + + // Fallback: try to find a host agent that is explicitly linked to this node + // via LinkedNodeID. This maintains compatibility with older host-side links. + if matchedHost == nil && nodeID != "" { for i := range hosts { if hosts[i].LinkedNodeID == nodeID { matchedHost = &hosts[i] @@ -54,11 +72,12 @@ func (m *Monitor) getHostAgentTemperatureByID(nodeID, nodeName string) *models.T } // Fallback: match by hostname if no linked host was found - // This maintains backwards compatibility for setups where linking hasn't occurred yet + // This maintains backwards compatibility for setups where linking hasn't occurred yet. + // Compare both FQDN and short hostname forms so the behavior matches node auto-linking. if matchedHost == nil { - nodeLower := strings.ToLower(strings.TrimSpace(nodeName)) + nodeLower := normalizeHostAgentNodeName(nodeName) for i := range hosts { - hostnameLower := strings.ToLower(strings.TrimSpace(hosts[i].Hostname)) + hostnameLower := normalizeHostAgentNodeName(hosts[i].Hostname) if hostnameLower == nodeLower { matchedHost = &hosts[i] break @@ -70,8 +89,9 @@ func (m *Monitor) getHostAgentTemperatureByID(nodeID, nodeName string) *models.T return nil } - // Check if the host agent has temperature data - if len(matchedHost.Sensors.TemperatureCelsius) == 0 { + // Allow SMART-only sensor payloads from platforms like FreeBSD where + // CPU sensors may be unavailable but disk temperature data is still valid. + if len(matchedHost.Sensors.TemperatureCelsius) == 0 && len(matchedHost.Sensors.SMART) == 0 { return nil } @@ -79,6 +99,14 @@ func (m *Monitor) getHostAgentTemperatureByID(nodeID, nodeName string) *models.T return convertHostSensorsToTemperature(matchedHost.Sensors, matchedHost.LastSeen) } +func normalizeHostAgentNodeName(name string) string { + normalized := strings.ToLower(strings.TrimSpace(name)) + if idx := strings.Index(normalized, "."); idx > 0 { + return normalized[:idx] + } + return normalized +} + // convertHostSensorsToTemperature converts HostSensorSummary to the Temperature model. // The host agent reports temperatures in a flat map with keys like: // - "cpu_package" -> CPU package temperature @@ -86,10 +114,6 @@ func (m *Monitor) getHostAgentTemperatureByID(nodeID, nodeName string) *models.T // - "nvme0", "nvme1", etc. -> NVMe temperatures // - "gpu_edge", "gpu_junction", etc. -> GPU temperatures func convertHostSensorsToTemperature(sensors models.HostSensorSummary, lastSeen time.Time) *models.Temperature { - if len(sensors.TemperatureCelsius) == 0 { - return nil - } - temp := &models.Temperature{ Available: true, LastUpdate: lastSeen, diff --git a/internal/monitoring/host_agent_temps_test.go b/internal/monitoring/host_agent_temps_test.go index cf75acf25..f77fb73cd 100644 --- a/internal/monitoring/host_agent_temps_test.go +++ b/internal/monitoring/host_agent_temps_test.go @@ -76,6 +76,29 @@ func TestConvertHostSensorsToTemperature_NVMe(t *testing.T) { } } +func TestConvertHostSensorsToTemperature_SMARTOnly(t *testing.T) { + sensors := models.HostSensorSummary{ + SMART: []models.HostDiskSMART{ + {Device: "ada0", Temperature: 34, Serial: "ABC123"}, + }, + } + + result := convertHostSensorsToTemperature(sensors, time.Now()) + + if result == nil { + t.Fatal("expected non-nil result") + } + if !result.HasSMART { + t.Error("expected HasSMART to be true") + } + if len(result.SMART) != 1 { + t.Fatalf("expected 1 SMART disk, got %d", len(result.SMART)) + } + if result.SMART[0].Device != "/dev/ada0" { + t.Fatalf("expected /dev/ada0, got %q", result.SMART[0].Device) + } +} + func TestConvertHostSensorsToTemperature_GPU(t *testing.T) { sensors := models.HostSensorSummary{ TemperatureCelsius: map[string]float64{ @@ -327,6 +350,26 @@ func TestGetHostAgentTemperature(t *testing.T) { assert.Equal(t, 60.0, result.CPUPackage) }) + t.Run("match by node linked host agent id before hostname fallback", func(t *testing.T) { + m.state.UpdateNodes([]models.Node{{ + ID: "node-456", + Name: "pve01", + LinkedHostAgentID: "host-linked", + }}) + host := models.Host{ + ID: "host-linked", + Hostname: "pve01.example.com", + Sensors: models.HostSensorSummary{ + TemperatureCelsius: map[string]float64{"cpu_package": 61.0}, + }, + } + m.state.UpsertHost(host) + + result := m.getHostAgentTemperatureByID("node-456", "pve01") + assert.NotNil(t, result) + assert.Equal(t, 61.0, result.CPUPackage) + }) + t.Run("match by hostname fallback", func(t *testing.T) { host := models.Host{ ID: "host2", @@ -342,6 +385,21 @@ func TestGetHostAgentTemperature(t *testing.T) { assert.Equal(t, 65.0, result.CPUPackage) }) + t.Run("match by hostname fallback with fqdn host", func(t *testing.T) { + host := models.Host{ + ID: "host-fqdn", + Hostname: "node-fqdn.example.com", + Sensors: models.HostSensorSummary{ + TemperatureCelsius: map[string]float64{"cpu_package": 66.0}, + }, + } + m.state.UpsertHost(host) + + result := m.getHostAgentTemperature("node-fqdn") + assert.NotNil(t, result) + assert.Equal(t, 66.0, result.CPUPackage) + }) + t.Run("no matching host", func(t *testing.T) { result := m.getHostAgentTemperature("node-missing") assert.Nil(t, result) @@ -356,6 +414,24 @@ func TestGetHostAgentTemperature(t *testing.T) { result := m.getHostAgentTemperature("node3") assert.Nil(t, result) }) + + t.Run("matching host with SMART-only data", func(t *testing.T) { + host := models.Host{ + ID: "host4", + Hostname: "node4", + Sensors: models.HostSensorSummary{ + SMART: []models.HostDiskSMART{ + {Device: "ada0", Temperature: 36}, + }, + }, + } + m.state.UpsertHost(host) + result := m.getHostAgentTemperature("node4") + assert.NotNil(t, result) + assert.True(t, result.HasSMART) + assert.Len(t, result.SMART, 1) + assert.Equal(t, "/dev/ada0", result.SMART[0].Device) + }) } func TestConvertHostSensorsToTemperature_ExtraBranches(t *testing.T) { diff --git a/internal/monitoring/monitor.go b/internal/monitoring/monitor.go index 8ea89505f..8e718f6b9 100644 --- a/internal/monitoring/monitor.go +++ b/internal/monitoring/monitor.go @@ -586,11 +586,12 @@ func ensureClusterEndpointURL(raw string) string { return "https://" + net.JoinHostPort(value, "8006") } -func clusterEndpointEffectiveURL(endpoint config.ClusterEndpoint, verifySSL bool, hasFingerprint bool) string { +func clusterEndpointEffectiveURL(endpoint config.ClusterEndpoint, verifySSL bool, baseFingerprint string) string { // When TLS hostname verification is required (VerifySSL=true and no fingerprint), // prefer hostname over IP to ensure certificate CN/SAN validation works correctly. // When TLS is not verified (VerifySSL=false) or a fingerprint is provided (which // bypasses hostname checks), prefer IP to reduce DNS lookups (refs #620). + hasFingerprint := strings.TrimSpace(endpoint.Fingerprint) != "" || strings.TrimSpace(baseFingerprint) != "" requiresHostnameForTLS := verifySSL && !hasFingerprint // Use EffectiveIP() which prefers user-specified IPOverride over auto-discovered IP @@ -616,6 +617,72 @@ func clusterEndpointEffectiveURL(endpoint config.ClusterEndpoint, verifySSL bool return "" } +func buildClusterClientEndpoints(pve config.PVEInstance) ([]string, map[string]string) { + endpoints := make([]string, 0, len(pve.ClusterEndpoints)+1) + endpointFingerprints := make(map[string]string) + hasValidEndpoints := false + + for _, ep := range pve.ClusterEndpoints { + effectiveURL := clusterEndpointEffectiveURL(ep, pve.VerifySSL, pve.Fingerprint) + if effectiveURL == "" { + log.Warn(). + Str("node", ep.NodeName). + Msg("Skipping cluster endpoint with no host/IP") + continue + } + + if parsed, err := url.Parse(effectiveURL); err == nil { + hostname := parsed.Hostname() + if hostname != "" && (strings.Contains(hostname, ".") || net.ParseIP(hostname) != nil) { + hasValidEndpoints = true + } + } else { + hostname := normalizeEndpointHost(effectiveURL) + if hostname != "" && (strings.Contains(hostname, ".") || net.ParseIP(hostname) != nil) { + hasValidEndpoints = true + } + } + + endpoints = append(endpoints, effectiveURL) + if ep.Fingerprint != "" { + endpointFingerprints[effectiveURL] = ep.Fingerprint + } + } + + if !hasValidEndpoints || len(endpoints) == 0 { + fallback := ensureClusterEndpointURL(pve.Host) + if fallback != "" { + log.Info(). + Str("instance", pve.Name). + Str("mainHost", pve.Host). + Msg("Cluster endpoints are not resolvable, using main host for all cluster operations") + return []string{fallback}, endpointFingerprints + } + return nil, endpointFingerprints + } + + mainHostURL := ensureClusterEndpointURL(pve.Host) + if mainHostURL != "" { + mainHostAlreadyIncluded := false + for _, ep := range endpoints { + if ep == mainHostURL { + mainHostAlreadyIncluded = true + break + } + } + if !mainHostAlreadyIncluded { + log.Info(). + Str("instance", pve.Name). + Str("mainHost", mainHostURL). + Int("clusterEndpoints", len(endpoints)). + Msg("Adding main host as fallback for remote cluster access") + endpoints = append(endpoints, mainHostURL) + } + } + + return endpoints, endpointFingerprints +} + // PollExecutor defines the contract for executing polling tasks. type PollExecutor interface { Execute(ctx context.Context, task PollTask) @@ -1269,12 +1336,12 @@ func (m *Monitor) getNodeRRDMetrics(ctx context.Context, client PVEClientInterfa return entry, nil } -// getVMRRDMetrics fetches Proxmox RRD memavailable for a single VM with a +// getVMRRDMetrics fetches Proxmox RRD memory metrics for a single VM with a // short-lived cache to avoid a live API call on every poll for VMs that // consistently lack guest-agent memory data (e.g. Windows VMs). -func (m *Monitor) getVMRRDMetrics(ctx context.Context, client PVEClientInterface, instance, node string, vmid int) (uint64, error) { +func (m *Monitor) getVMRRDMetrics(ctx context.Context, client PVEClientInterface, instance, node string, vmid int) (rrdMemCacheEntry, error) { if client == nil || node == "" || vmid <= 0 { - return 0, fmt.Errorf("invalid arguments for VM RRD lookup") + return rrdMemCacheEntry{}, fmt.Errorf("invalid arguments for VM RRD lookup") } cacheKey := fmt.Sprintf("%s/%s/%d", instance, node, vmid) @@ -1283,39 +1350,75 @@ func (m *Monitor) getVMRRDMetrics(ctx context.Context, client PVEClientInterface m.rrdCacheMu.RLock() if entry, ok := m.vmRRDMemCache[cacheKey]; ok && now.Sub(entry.fetchedAt) < nodeRRDCacheTTL { m.rrdCacheMu.RUnlock() - return entry.available, nil + if entry.negative { + return rrdMemCacheEntry{}, fmt.Errorf("vm RRD mem read previously failed (negative cache)") + } + return entry, nil } m.rrdCacheMu.RUnlock() requestCtx, cancel := context.WithTimeout(ctx, nodeRRDRequestTimeout) defer cancel() - points, err := client.GetVMRRDData(requestCtx, node, vmid, "hour", "AVERAGE", []string{"memavailable"}) + points, err := client.GetVMRRDData(requestCtx, node, vmid, "hour", "AVERAGE", []string{"memavailable", "memused", "maxmem"}) if err != nil { - return 0, err + m.rrdCacheMu.Lock() + m.vmRRDMemCache[cacheKey] = rrdMemCacheEntry{negative: true, fetchedAt: now} + m.rrdCacheMu.Unlock() + return rrdMemCacheEntry{}, err } if len(points) == 0 { - return 0, fmt.Errorf("no RRD points for VM %s/%d", node, vmid) + m.rrdCacheMu.Lock() + m.vmRRDMemCache[cacheKey] = rrdMemCacheEntry{negative: true, fetchedAt: now} + m.rrdCacheMu.Unlock() + return rrdMemCacheEntry{}, fmt.Errorf("no RRD points for VM %s/%d", node, vmid) } + var memUsed uint64 + var memTotal uint64 var memAvailable uint64 for i := len(points) - 1; i >= 0; i-- { p := points[i] + if memTotal == 0 && p.MaxMem != nil && !math.IsNaN(*p.MaxMem) && *p.MaxMem > 0 { + memTotal = uint64(math.Round(*p.MaxMem)) + } if p.MemAvailable != nil && !math.IsNaN(*p.MemAvailable) && *p.MemAvailable > 0 { memAvailable = uint64(math.Round(*p.MemAvailable)) + } + if memUsed == 0 && p.MemUsed != nil && !math.IsNaN(*p.MemUsed) && *p.MemUsed > 0 { + memUsed = uint64(math.Round(*p.MemUsed)) + } + if memAvailable > 0 && (memUsed > 0 || memTotal > 0) { break } } - if memAvailable == 0 { - return 0, fmt.Errorf("rrd memavailable not present for VM %s/%d", node, vmid) + if memTotal > 0 { + if memAvailable > memTotal { + memAvailable = memTotal + } + if memUsed > memTotal { + memUsed = memTotal + } } - entry := rrdMemCacheEntry{available: memAvailable, fetchedAt: now} + if memAvailable == 0 && memUsed == 0 { + m.rrdCacheMu.Lock() + m.vmRRDMemCache[cacheKey] = rrdMemCacheEntry{negative: true, fetchedAt: now} + m.rrdCacheMu.Unlock() + return rrdMemCacheEntry{}, fmt.Errorf("rrd VM memory metrics not present for VM %s/%d", node, vmid) + } + + entry := rrdMemCacheEntry{ + available: memAvailable, + used: memUsed, + total: memTotal, + fetchedAt: now, + } m.rrdCacheMu.Lock() m.vmRRDMemCache[cacheKey] = entry m.rrdCacheMu.Unlock() - return memAvailable, nil + return entry, nil } // getVMAgentMemAvailable reads MemAvailable via the QEMU guest agent file-read @@ -4180,75 +4283,7 @@ func New(cfg *config.Config) (*Monitor, error) { // Check if this is a cluster if pve.IsCluster && len(pve.ClusterEndpoints) > 0 { - // For clusters, check if endpoints have IPs/resolvable hosts - // If not, use the main host for all connections (Proxmox will route cluster API calls) - hasValidEndpoints := false - endpoints := make([]string, 0, len(pve.ClusterEndpoints)) - endpointFingerprints := make(map[string]string) - - for _, ep := range pve.ClusterEndpoints { - hasFingerprint := pve.Fingerprint != "" - effectiveURL := clusterEndpointEffectiveURL(ep, pve.VerifySSL, hasFingerprint) - if effectiveURL == "" { - log.Warn(). - Str("node", ep.NodeName). - Msg("Skipping cluster endpoint with no host/IP") - continue - } - - if parsed, err := url.Parse(effectiveURL); err == nil { - hostname := parsed.Hostname() - if hostname != "" && (strings.Contains(hostname, ".") || net.ParseIP(hostname) != nil) { - hasValidEndpoints = true - } - } else { - hostname := normalizeEndpointHost(effectiveURL) - if hostname != "" && (strings.Contains(hostname, ".") || net.ParseIP(hostname) != nil) { - hasValidEndpoints = true - } - } - - endpoints = append(endpoints, effectiveURL) - // Store per-endpoint fingerprint for TOFU (Trust On First Use) - if ep.Fingerprint != "" { - endpointFingerprints[effectiveURL] = ep.Fingerprint - } - } - - // If endpoints are just node names (not FQDNs or IPs), use main host only - // This is common when cluster nodes are discovered but not directly reachable - if !hasValidEndpoints || len(endpoints) == 0 { - log.Info(). - Str("instance", pve.Name). - Str("mainHost", pve.Host). - Msg("Cluster endpoints are not resolvable, using main host for all cluster operations") - fallback := ensureClusterEndpointURL(pve.Host) - if fallback == "" { - fallback = ensureClusterEndpointURL(pve.Host) - } - endpoints = []string{fallback} - } else { - // Always include the main host URL as a fallback endpoint. - // This handles remote cluster scenarios where Proxmox reports internal IPs - // that aren't reachable from Pulse's network. The user-provided URL is - // reachable, so include it as a fallback for cluster API routing. - mainHostURL := ensureClusterEndpointURL(pve.Host) - mainHostAlreadyIncluded := false - for _, ep := range endpoints { - if ep == mainHostURL { - mainHostAlreadyIncluded = true - break - } - } - if !mainHostAlreadyIncluded && mainHostURL != "" { - log.Info(). - Str("instance", pve.Name). - Str("mainHost", mainHostURL). - Int("clusterEndpoints", len(endpoints)). - Msg("Adding main host as fallback for remote cluster access") - endpoints = append(endpoints, mainHostURL) - } - } + endpoints, endpointFingerprints := buildClusterClientEndpoints(pve) log.Info(). Str("cluster", pve.ClusterName). @@ -4969,39 +5004,7 @@ func (m *Monitor) retryFailedConnections(ctx context.Context) { // Try to recreate PVE clients for _, pve := range missingPVE { if pve.IsCluster && len(pve.ClusterEndpoints) > 0 { - // Create cluster client - hasValidEndpoints := false - endpoints := make([]string, 0, len(pve.ClusterEndpoints)) - endpointFingerprints := make(map[string]string) - - for _, ep := range pve.ClusterEndpoints { - // Use EffectiveIP() which prefers IPOverride over auto-discovered IP - host := ep.EffectiveIP() - if host == "" { - host = ep.Host - } - if host == "" { - continue - } - if strings.Contains(host, ".") || net.ParseIP(host) != nil { - hasValidEndpoints = true - } - if !strings.HasPrefix(host, "http") { - host = fmt.Sprintf("https://%s:8006", host) - } - endpoints = append(endpoints, host) - // Store per-endpoint fingerprint for TOFU - if ep.Fingerprint != "" { - endpointFingerprints[host] = ep.Fingerprint - } - } - - if !hasValidEndpoints || len(endpoints) == 0 { - endpoints = []string{pve.Host} - if !strings.HasPrefix(endpoints[0], "http") { - endpoints[0] = fmt.Sprintf("https://%s:8006", endpoints[0]) - } - } + endpoints, endpointFingerprints := buildClusterClientEndpoints(pve) clientConfig := config.CreateProxmoxConfig(&pve) clientConfig.Timeout = m.config.ConnectionTimeout @@ -6826,7 +6829,6 @@ func (m *Monitor) pollPVEInstance(ctx context.Context, instanceName string, clie } // Update the online status for each cluster endpoint - hasFingerprint := instanceCfg.Fingerprint != "" for i := range instanceCfg.ClusterEndpoints { if online, exists := onlineNodes[instanceCfg.ClusterEndpoints[i].NodeName]; exists { instanceCfg.ClusterEndpoints[i].Online = online @@ -6838,7 +6840,7 @@ func (m *Monitor) pollPVEInstance(ctx context.Context, instanceName string, clie // Update Pulse connectivity status if pulseHealth != nil { // Try to find the endpoint in the health map by matching the effective URL - endpointURL := clusterEndpointEffectiveURL(instanceCfg.ClusterEndpoints[i], instanceCfg.VerifySSL, hasFingerprint) + endpointURL := clusterEndpointEffectiveURL(instanceCfg.ClusterEndpoints[i], instanceCfg.VerifySSL, instanceCfg.Fingerprint) if health, exists := pulseHealth[endpointURL]; exists { reachable := health.Healthy instanceCfg.ClusterEndpoints[i].PulseReachable = &reachable @@ -7044,6 +7046,10 @@ func (m *Monitor) pollVMsAndContainersEfficient(ctx context.Context, instanceNam ListingMaxMem: res.MaxMem, } var detailedStatus *proxmox.VMStatus + memAvailable := uint64(0) + memInfoTotalMinusUsed := uint64(0) + rrdUsed := uint64(0) + agentEnabled := false // Try to get actual disk usage from guest agent if VM is running diskUsed := res.Disk @@ -7078,8 +7084,7 @@ func (m *Monitor) pollVMsAndContainersEfficient(ctx context.Context, instanceNam guestRaw.Balloon = detailedStatus.Balloon guestRaw.BalloonMin = detailedStatus.BalloonMin guestRaw.Agent = detailedStatus.Agent.Value - memAvailable := uint64(0) - memInfoTotalMinusUsed := uint64(0) + agentEnabled = detailedStatus.Agent.Value > 0 if detailedStatus.MemInfo != nil { guestRaw.MemInfoUsed = detailedStatus.MemInfo.Used guestRaw.MemInfoFree = detailedStatus.MemInfo.Free @@ -7114,105 +7119,6 @@ func (m *Monitor) pollVMsAndContainersEfficient(ctx context.Context, instanceNam memTotal = detailedStatus.MaxMem } - // Fallback for Linux VMs when the guest agent doesn't provide MemInfo.Available: - // try Proxmox RRD's memavailable (cache-aware) before falling back to status.Mem - // which can include reclaimable page cache (inflating usage). Refs: #1270 - if memAvailable == 0 { - if rrdAvailable, rrdErr := m.getVMRRDMetrics(ctx, client, instanceName, res.Node, res.VMID); rrdErr == nil && rrdAvailable > 0 { - memAvailable = rrdAvailable - memorySource = "rrd-memavailable" - guestRaw.MemInfoAvailable = memAvailable - log.Debug(). - Str("vm", res.Name). - Str("node", res.Node). - Int("vmid", res.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Msg("QEMU memory: using RRD memavailable fallback (excludes reclaimable cache)") - } else if rrdErr != nil { - log.Debug(). - Err(rrdErr). - Str("instance", instanceName). - Str("vm", res.Name). - Int("vmid", res.VMID). - Msg("RRD memory data unavailable for VM, using status/cluster resources values") - } - } - - // Fallback: use linked Pulse host agent's memory data. - // gopsutil's Used = Total - Available (excludes page cache), - // so we can derive accurate available memory. Refs: #1270 - if memAvailable == 0 { - if agentHost, ok := vmIDToHostAgent[guestID]; ok { - agentAvailable := agentHost.Memory.Total - agentHost.Memory.Used - if agentAvailable > 0 { - memAvailable = uint64(agentAvailable) - memorySource = "host-agent" - guestRaw.HostAgentTotal = uint64(agentHost.Memory.Total) - guestRaw.HostAgentUsed = uint64(agentHost.Memory.Used) - log.Debug(). - Str("vm", res.Name). - Str("node", res.Node). - Int("vmid", res.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Int64("agentTotal", agentHost.Memory.Total). - Int64("agentUsed", agentHost.Memory.Used). - Msg("QEMU memory: using linked Pulse host agent memory (excludes page cache)") - } - } - } - - // Last-resort fallback before status-mem: read /proc/meminfo via the - // QEMU guest agent's file-read endpoint. This works for Linux VMs with - // the guest agent running even when the balloon driver doesn't populate - // the meminfo fields. Results are cached with negative backoff. Refs: #1270 - if memAvailable == 0 && detailedStatus.Agent.Value > 0 { - if agentAvail, agentErr := m.getVMAgentMemAvailable(ctx, client, instanceName, res.Node, res.VMID); agentErr == nil && agentAvail > 0 { - memAvailable = agentAvail - memorySource = "guest-agent-meminfo" - guestRaw.MemInfoAvailable = memAvailable - log.Debug(). - Str("vm", res.Name). - Str("node", res.Node). - Int("vmid", res.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Msg("QEMU memory: using guest agent /proc/meminfo fallback (excludes reclaimable cache)") - } - } - - // Last-chance MemInfo fallback: use Total-Used only after RRD/agent - // attempts, so those more reliable cache-aware sources get priority. - if memAvailable == 0 && memInfoTotalMinusUsed > 0 { - memAvailable = memInfoTotalMinusUsed - memorySource = "meminfo-total-minus-used" - } - - switch { - case memAvailable > 0: - if memAvailable > memTotal { - memAvailable = memTotal - } - memUsed = memTotal - memAvailable - case detailedStatus.Mem > 0: - // Prefer Mem over FreeMem: Proxmox calculates Mem as - // (total_mem - free_mem) using the balloon's guest-visible - // total, which is correct even when ballooning is active. - // FreeMem is relative to the balloon allocation (not MaxMem), - // so subtracting it from MaxMem produces wildly inflated - // usage when the balloon has reduced the VM's memory. - // Refs: #1185 - memUsed = detailedStatus.Mem - memorySource = "status-mem" - case detailedStatus.FreeMem > 0 && memTotal >= detailedStatus.FreeMem: - memUsed = memTotal - detailedStatus.FreeMem - memorySource = "status-freemem" - } - if memUsed > memTotal { - memUsed = memTotal - } - // Gather guest metadata from the agent when available guestIPs, guestIfaces, guestOSName, guestOSVersion, guestAgentVersion := m.fetchGuestAgentMetadata(ctx, client, instanceName, res.Node, res.Name, res.VMID, detailedStatus) if len(guestIPs) > 0 { @@ -7542,9 +7448,125 @@ func (m *Monitor) pollVMsAndContainersEfficient(ctx context.Context, instanceNam } } - if res.Status != "running" { + // Fallback for Linux VMs when the guest agent doesn't provide MemInfo.Available: + // try Proxmox RRD memory first, even if detailed status was unavailable. + if res.Status == "running" && memAvailable == 0 { + if rrdEntry, rrdErr := m.getVMRRDMetrics(ctx, client, instanceName, res.Node, res.VMID); rrdErr == nil { + if rrdEntry.total > 0 { + memTotal = rrdEntry.total + } + if rrdEntry.available > 0 { + memAvailable = rrdEntry.available + memorySource = "rrd-memavailable" + guestRaw.RRDAvailable = rrdEntry.available + guestRaw.MemInfoAvailable = rrdEntry.available + log.Debug(). + Str("vm", res.Name). + Str("node", res.Node). + Int("vmid", res.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Msg("QEMU memory: using RRD memavailable fallback (excludes reclaimable cache)") + } else if rrdEntry.used > 0 { + rrdUsed = rrdEntry.used + memorySource = "rrd-memused" + guestRaw.RRDUsed = rrdEntry.used + log.Debug(). + Str("vm", res.Name). + Str("node", res.Node). + Int("vmid", res.VMID). + Uint64("total", memTotal). + Uint64("used", rrdUsed). + Msg("QEMU memory: using RRD memused fallback") + } + } else { + log.Debug(). + Err(rrdErr). + Str("instance", instanceName). + Str("vm", res.Name). + Int("vmid", res.VMID). + Msg("RRD memory data unavailable for VM, using status/cluster resources values") + } + } + + // Fallback: use linked Pulse host agent's memory data. + // gopsutil's Used = Total - Available (excludes page cache), + // so we can derive accurate available memory. Refs: #1270 + if res.Status == "running" && memAvailable == 0 { + if agentHost, ok := vmIDToHostAgent[guestID]; ok { + agentAvailable := agentHost.Memory.Total - agentHost.Memory.Used + if agentAvailable > 0 { + memAvailable = uint64(agentAvailable) + memorySource = "host-agent" + guestRaw.HostAgentTotal = uint64(agentHost.Memory.Total) + guestRaw.HostAgentUsed = uint64(agentHost.Memory.Used) + log.Debug(). + Str("vm", res.Name). + Str("node", res.Node). + Int("vmid", res.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Int64("agentTotal", agentHost.Memory.Total). + Int64("agentUsed", agentHost.Memory.Used). + Msg("QEMU memory: using linked Pulse host agent memory (excludes page cache)") + } + } + } + + // Last-resort fallback before status-mem: read /proc/meminfo via the + // QEMU guest agent's file-read endpoint. This works for Linux VMs with + // the guest agent running even when the balloon driver doesn't populate + // the meminfo fields. Results are cached with negative backoff. Refs: #1270 + if res.Status == "running" && memAvailable == 0 && agentEnabled { + if agentAvail, agentErr := m.getVMAgentMemAvailable(ctx, client, instanceName, res.Node, res.VMID); agentErr == nil && agentAvail > 0 { + memAvailable = agentAvail + memorySource = "guest-agent-meminfo" + guestRaw.MemInfoAvailable = memAvailable + log.Debug(). + Str("vm", res.Name). + Str("node", res.Node). + Int("vmid", res.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Msg("QEMU memory: using guest agent /proc/meminfo fallback (excludes reclaimable cache)") + } + } + + // Last-chance MemInfo fallback: use Total-Used only after RRD/agent + // attempts, so those more reliable cache-aware sources get priority. + if res.Status == "running" && memAvailable == 0 && memInfoTotalMinusUsed > 0 { + memAvailable = memInfoTotalMinusUsed + memorySource = "meminfo-total-minus-used" + } + + switch { + case res.Status != "running": memorySource = "powered-off" memUsed = 0 + case memAvailable > 0: + if memAvailable > memTotal { + memAvailable = memTotal + } + memUsed = memTotal - memAvailable + case rrdUsed > 0: + memUsed = rrdUsed + memorySource = "rrd-memused" + case detailedStatus != nil && detailedStatus.Mem > 0: + // Prefer Mem over FreeMem: Proxmox calculates Mem as + // (total_mem - free_mem) using the balloon's guest-visible + // total, which is correct even when ballooning is active. + // FreeMem is relative to the balloon allocation (not MaxMem), + // so subtracting it from MaxMem produces wildly inflated + // usage when the balloon has reduced the VM's memory. + // Refs: #1185 + memUsed = detailedStatus.Mem + memorySource = "status-mem" + case detailedStatus != nil && detailedStatus.FreeMem > 0 && memTotal >= detailedStatus.FreeMem: + memUsed = memTotal - detailedStatus.FreeMem + memorySource = "status-freemem" + } + if memUsed > memTotal { + memUsed = memTotal } memFree := uint64(0) diff --git a/internal/monitoring/monitor_additional_test.go b/internal/monitoring/monitor_additional_test.go index 65cce9912..967f2bf8f 100644 --- a/internal/monitoring/monitor_additional_test.go +++ b/internal/monitoring/monitor_additional_test.go @@ -149,27 +149,63 @@ func TestClusterEndpointEffectiveURL(t *testing.T) { IP: "10.0.0.1", } - if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "https://node.local:8006" { + if got := clusterEndpointEffectiveURL(endpoint, true, ""); got != "https://node.local:8006" { t.Fatalf("verifySSL host preference = %q, want %q", got, "https://node.local:8006") } endpoint.Host = "" - if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "https://10.0.0.1:8006" { + if got := clusterEndpointEffectiveURL(endpoint, true, ""); got != "https://10.0.0.1:8006" { t.Fatalf("verifySSL fallback to IP = %q, want %q", got, "https://10.0.0.1:8006") } endpoint.Host = "node.local" - if got := clusterEndpointEffectiveURL(endpoint, false, false); got != "https://10.0.0.1:8006" { + if got := clusterEndpointEffectiveURL(endpoint, false, ""); got != "https://10.0.0.1:8006" { t.Fatalf("non-SSL IP preference = %q, want %q", got, "https://10.0.0.1:8006") } endpoint.IPOverride = "192.168.1.10" - if got := clusterEndpointEffectiveURL(endpoint, false, false); got != "https://192.168.1.10:8006" { + if got := clusterEndpointEffectiveURL(endpoint, false, ""); got != "https://192.168.1.10:8006" { t.Fatalf("override IP preference = %q, want %q", got, "https://192.168.1.10:8006") } + endpoint.Fingerprint = "ep-fingerprint" + if got := clusterEndpointEffectiveURL(endpoint, true, ""); got != "https://192.168.1.10:8006" { + t.Fatalf("per-endpoint fingerprint should allow IP override, got %q", got) + } + endpoint = config.ClusterEndpoint{} - if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "" { + if got := clusterEndpointEffectiveURL(endpoint, true, ""); got != "" { t.Fatalf("empty endpoint = %q, want empty", got) } } + +func TestBuildClusterClientEndpoints_PrefersOverrideWhenEndpointFingerprintPresent(t *testing.T) { + pve := config.PVEInstance{ + Name: "cluster-a", + Host: "https://cluster-a.local:8006", + VerifySSL: true, + IsCluster: true, + ClusterName: "cluster-a", + ClusterEndpoints: []config.ClusterEndpoint{ + { + NodeName: "node1", + Host: "https://node1.local:8006", + IP: "10.15.5.11", + IPOverride: "10.15.2.11", + Fingerprint: "node1-fp", + }, + }, + } + + endpoints, fingerprints := buildClusterClientEndpoints(pve) + + if len(endpoints) != 2 { + t.Fatalf("expected endpoint plus main host fallback, got %d", len(endpoints)) + } + if endpoints[0] != "https://10.15.2.11:8006" { + t.Fatalf("expected endpoint override URL first, got %q", endpoints[0]) + } + if fingerprints["https://10.15.2.11:8006"] != "node1-fp" { + t.Fatalf("expected fingerprint to follow effective endpoint URL, got %q", fingerprints["https://10.15.2.11:8006"]) + } +} diff --git a/internal/monitoring/monitor_extra_coverage_test.go b/internal/monitoring/monitor_extra_coverage_test.go index 63f2bca46..1cdf43d95 100644 --- a/internal/monitoring/monitor_extra_coverage_test.go +++ b/internal/monitoring/monitor_extra_coverage_test.go @@ -241,10 +241,13 @@ func TestMonitor_LinkNodeToHostAgent_Extra(t *testing.T) { type mockPVEClientExtra struct { mockPVEClient - resources []proxmox.ClusterResource - vmStatus *proxmox.VMStatus - fsInfo []proxmox.VMFileSystem - netIfaces []proxmox.VMNetworkInterface + resources []proxmox.ClusterResource + vms []proxmox.VM + vmStatus *proxmox.VMStatus + vmStatusErr error + fsInfo []proxmox.VMFileSystem + netIfaces []proxmox.VMNetworkInterface + vmRRDPoints []proxmox.GuestRRDPoint } func (m *mockPVEClientExtra) GetClusterResources(ctx context.Context, resourceType string) ([]proxmox.ClusterResource, error) { @@ -252,6 +255,9 @@ func (m *mockPVEClientExtra) GetClusterResources(ctx context.Context, resourceTy } func (m *mockPVEClientExtra) GetVMStatus(ctx context.Context, node string, vmid int) (*proxmox.VMStatus, error) { + if m.vmStatusErr != nil { + return nil, m.vmStatusErr + } return m.vmStatus, nil } @@ -304,7 +310,11 @@ func (m *mockPVEClientExtra) GetLXCRRDData(ctx context.Context, node string, vmi } func (m *mockPVEClientExtra) GetVMRRDData(ctx context.Context, node string, vmid int, timeframe string, cf string, ds []string) ([]proxmox.GuestRRDPoint, error) { - return nil, nil + return m.vmRRDPoints, nil +} + +func (m *mockPVEClientExtra) GetVMs(ctx context.Context, node string) ([]proxmox.VM, error) { + return m.vms, nil } func (m *mockPVEClientExtra) GetNodeStatus(ctx context.Context, node string) (*proxmox.NodeStatus, error) { @@ -390,6 +400,105 @@ func TestMonitor_PollVMsAndContainersEfficient_Extra(t *testing.T) { } } +func TestMonitor_PollVMsAndContainersEfficient_UsesVMRRDMemUsedWhenStatusUnavailable(t *testing.T) { + const total = uint64(8 << 30) + const inflatedUsed = uint64(6 << 30) + const rrdUsed = uint64(3 << 30) + + m := &Monitor{ + state: models.NewState(), + guestAgentFSInfoTimeout: time.Second, + guestAgentRetries: 1, + guestAgentNetworkTimeout: time.Second, + guestAgentOSInfoTimeout: time.Second, + guestAgentVersionTimeout: time.Second, + guestMetadataCache: make(map[string]guestMetadataCacheEntry), + guestMetadataLimiter: make(map[string]time.Time), + rateTracker: NewRateTracker(), + metricsHistory: NewMetricsHistory(100, time.Hour), + alertManager: alerts.NewManager(), + stalenessTracker: NewStalenessTracker(nil), + nodeRRDMemCache: make(map[string]rrdMemCacheEntry), + vmRRDMemCache: make(map[string]rrdMemCacheEntry), + vmAgentMemCache: make(map[string]agentMemCacheEntry), + } + defer m.alertManager.Stop() + + client := &mockPVEClientExtra{ + resources: []proxmox.ClusterResource{ + {Type: "qemu", VMID: 100, Name: "vm100", Node: "node1", Status: "running", MaxMem: total, Mem: inflatedUsed}, + }, + vmStatusErr: fmt.Errorf("API error 403: status unavailable"), + vmRRDPoints: []proxmox.GuestRRDPoint{ + {MaxMem: floatPtr(float64(total)), MemUsed: floatPtr(float64(rrdUsed))}, + }, + } + + success := m.pollVMsAndContainersEfficient(context.Background(), "pve1", "", false, client, map[string]string{"node1": "online"}) + if !success { + t.Fatal("pollVMsAndContainersEfficient failed") + } + + state := m.GetState() + if len(state.VMs) != 1 { + t.Fatalf("expected 1 VM, got %d", len(state.VMs)) + } + if state.VMs[0].Memory.Used != int64(rrdUsed) { + t.Fatalf("memory used mismatch: got %d want %d", state.VMs[0].Memory.Used, rrdUsed) + } + if state.VMs[0].MemorySource != "rrd-memused" { + t.Fatalf("memory source mismatch: got %q want rrd-memused", state.VMs[0].MemorySource) + } +} + +func TestMonitor_PollVMsWithNodes_UsesVMRRDMemUsedWhenStatusUnavailable(t *testing.T) { + const total = uint64(8 << 30) + const inflatedUsed = uint64(6 << 30) + const rrdUsed = uint64(3 << 30) + + m := &Monitor{ + state: models.NewState(), + guestAgentFSInfoTimeout: time.Second, + guestAgentRetries: 1, + guestAgentNetworkTimeout: time.Second, + guestAgentOSInfoTimeout: time.Second, + guestAgentVersionTimeout: time.Second, + guestMetadataCache: make(map[string]guestMetadataCacheEntry), + guestMetadataLimiter: make(map[string]time.Time), + rateTracker: NewRateTracker(), + metricsHistory: NewMetricsHistory(100, time.Hour), + alertManager: alerts.NewManager(), + stalenessTracker: NewStalenessTracker(nil), + nodeRRDMemCache: make(map[string]rrdMemCacheEntry), + vmRRDMemCache: make(map[string]rrdMemCacheEntry), + vmAgentMemCache: make(map[string]agentMemCacheEntry), + } + defer m.alertManager.Stop() + + client := &mockPVEClientExtra{ + vms: []proxmox.VM{ + {VMID: 100, Name: "vm100", Node: "node1", Status: "running", MaxMem: total, Mem: inflatedUsed}, + }, + vmStatusErr: fmt.Errorf("API error 403: status unavailable"), + vmRRDPoints: []proxmox.GuestRRDPoint{ + {MaxMem: floatPtr(float64(total)), MemUsed: floatPtr(float64(rrdUsed))}, + }, + } + + m.pollVMsWithNodes(context.Background(), "pve1", "", false, client, []proxmox.Node{{Node: "node1", Status: "online"}}, map[string]string{"node1": "online"}) + + state := m.GetState() + if len(state.VMs) != 1 { + t.Fatalf("expected 1 VM, got %d", len(state.VMs)) + } + if state.VMs[0].Memory.Used != int64(rrdUsed) { + t.Fatalf("memory used mismatch: got %d want %d", state.VMs[0].Memory.Used, rrdUsed) + } + if state.VMs[0].MemorySource != "rrd-memused" { + t.Fatalf("memory source mismatch: got %q want rrd-memused", state.VMs[0].MemorySource) + } +} + func TestMonitor_MiscSetters_Extra(t *testing.T) { m := &Monitor{ state: models.NewState(), diff --git a/internal/monitoring/monitor_polling.go b/internal/monitoring/monitor_polling.go index f8d470d9d..42a2aeebc 100644 --- a/internal/monitoring/monitor_polling.go +++ b/internal/monitoring/monitor_polling.go @@ -285,9 +285,13 @@ func (m *Monitor) pollVMsWithNodes(ctx context.Context, instanceName string, clu networkOutBytes := int64(vm.NetOut) // Get memory info for running VMs (and agent status for disk) - memUsed := uint64(0) + memUsed := vm.Mem memTotal := vm.MaxMem var vmStatus *proxmox.VMStatus + memAvailable := uint64(0) + memInfoTotalMinusUsed := uint64(0) + rrdUsed := uint64(0) + agentEnabled := vm.Agent > 0 var ipAddresses []string var networkInterfaces []models.GuestNetworkInterface var osName, osVersion, guestAgentVersion string @@ -303,8 +307,7 @@ func (m *Monitor) pollVMsWithNodes(ctx context.Context, instanceName string, clu guestRaw.Balloon = status.Balloon guestRaw.BalloonMin = status.BalloonMin guestRaw.Agent = status.Agent.Value - memAvailable := uint64(0) - memInfoTotalMinusUsed := uint64(0) + agentEnabled = status.Agent.Value > 0 if status.MemInfo != nil { guestRaw.MemInfoUsed = status.MemInfo.Used guestRaw.MemInfoFree = status.MemInfo.Free @@ -328,97 +331,6 @@ func (m *Monitor) pollVMsWithNodes(ctx context.Context, instanceName string, clu // confusion (showing 1GB/1GB at 100% when VM is configured for 4GB) // and makes the frontend's balloon marker logic ineffective. // Refs: #1070 - - // Fallback: try RRD memavailable (cached). Refs: #1270 - if memAvailable == 0 { - if rrdAvailable, rrdErr := m.getVMRRDMetrics(ctx, client, instanceName, n.Node, vm.VMID); rrdErr == nil && rrdAvailable > 0 { - memAvailable = rrdAvailable - memorySource = "rrd-memavailable" - guestRaw.MemInfoAvailable = memAvailable - log.Debug(). - Str("vm", vm.Name). - Str("node", n.Node). - Int("vmid", vm.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Msg("QEMU memory: using RRD memavailable fallback (excludes reclaimable cache)") - } - } - - // Fallback: use linked Pulse host agent's memory data. - // gopsutil's Used = Total - Available (excludes page cache), - // so we can derive accurate available memory. Refs: #1270 - if memAvailable == 0 { - if agentHost, ok := vmIDToHostAgent[guestID]; ok { - agentAvailable := agentHost.Memory.Total - agentHost.Memory.Used - if agentAvailable > 0 { - memAvailable = uint64(agentAvailable) - memorySource = "host-agent" - guestRaw.HostAgentTotal = uint64(agentHost.Memory.Total) - guestRaw.HostAgentUsed = uint64(agentHost.Memory.Used) - log.Debug(). - Str("vm", vm.Name). - Str("node", n.Node). - Int("vmid", vm.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Int64("agentTotal", agentHost.Memory.Total). - Int64("agentUsed", agentHost.Memory.Used). - Msg("QEMU memory: using linked Pulse host agent memory (excludes page cache)") - } - } - } - - // Last-resort fallback before status-mem: read /proc/meminfo via the - // QEMU guest agent file-read endpoint. Refs: #1270 - if memAvailable == 0 && status.Agent.Value > 0 { - if agentAvail, agentErr := m.getVMAgentMemAvailable(ctx, client, instanceName, n.Node, vm.VMID); agentErr == nil && agentAvail > 0 { - memAvailable = agentAvail - memorySource = "guest-agent-meminfo" - guestRaw.MemInfoAvailable = memAvailable - log.Debug(). - Str("vm", vm.Name). - Str("node", n.Node). - Int("vmid", vm.VMID). - Uint64("total", memTotal). - Uint64("available", memAvailable). - Msg("QEMU memory: using guest agent /proc/meminfo fallback (excludes reclaimable cache)") - } - } - - // Last-chance MemInfo fallback: use Total-Used only after RRD/agent - // attempts, so those more reliable cache-aware sources get priority. - if memAvailable == 0 && memInfoTotalMinusUsed > 0 { - memAvailable = memInfoTotalMinusUsed - memorySource = "meminfo-total-minus-used" - } - - switch { - case memAvailable > 0: - if memAvailable > memTotal { - memAvailable = memTotal - } - memUsed = memTotal - memAvailable - case vmStatus.Mem > 0: - // Prefer Mem over FreeMem: Proxmox calculates Mem as - // (total_mem - free_mem) using the balloon's guest-visible - // total, which is correct even when ballooning is active. - // FreeMem is relative to the balloon allocation (not MaxMem), - // so subtracting it from MaxMem produces wildly inflated - // usage when the balloon has reduced the VM's memory. - // Refs: #1185 - memUsed = vmStatus.Mem - memorySource = "status-mem" - case vmStatus.FreeMem > 0 && memTotal >= vmStatus.FreeMem: - memUsed = memTotal - vmStatus.FreeMem - memorySource = "status-freemem" - default: - memUsed = 0 - memorySource = "status-unavailable" - } - if memUsed > memTotal { - memUsed = memTotal - } // Use actual disk I/O values from detailed status diskReadBytes = int64(vmStatus.DiskRead) diskWriteBytes = int64(vmStatus.DiskWrite) @@ -428,11 +340,105 @@ func (m *Monitor) pollVMsWithNodes(ctx context.Context, instanceName string, clu cancel() } - if vm.Status != "running" { + if vm.Status == "running" && memAvailable == 0 { + if rrdEntry, rrdErr := m.getVMRRDMetrics(ctx, client, instanceName, n.Node, vm.VMID); rrdErr == nil { + if rrdEntry.total > 0 { + memTotal = rrdEntry.total + } + if rrdEntry.available > 0 { + memAvailable = rrdEntry.available + memorySource = "rrd-memavailable" + guestRaw.RRDAvailable = rrdEntry.available + guestRaw.MemInfoAvailable = rrdEntry.available + log.Debug(). + Str("vm", vm.Name). + Str("node", n.Node). + Int("vmid", vm.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Msg("QEMU memory: using RRD memavailable fallback (excludes reclaimable cache)") + } else if rrdEntry.used > 0 { + rrdUsed = rrdEntry.used + memorySource = "rrd-memused" + guestRaw.RRDUsed = rrdEntry.used + log.Debug(). + Str("vm", vm.Name). + Str("node", n.Node). + Int("vmid", vm.VMID). + Uint64("total", memTotal). + Uint64("used", rrdUsed). + Msg("QEMU memory: using RRD memused fallback") + } + } + } + + if vm.Status == "running" && memAvailable == 0 { + if agentHost, ok := vmIDToHostAgent[guestID]; ok { + agentAvailable := agentHost.Memory.Total - agentHost.Memory.Used + if agentAvailable > 0 { + memAvailable = uint64(agentAvailable) + memorySource = "host-agent" + guestRaw.HostAgentTotal = uint64(agentHost.Memory.Total) + guestRaw.HostAgentUsed = uint64(agentHost.Memory.Used) + log.Debug(). + Str("vm", vm.Name). + Str("node", n.Node). + Int("vmid", vm.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Int64("agentTotal", agentHost.Memory.Total). + Int64("agentUsed", agentHost.Memory.Used). + Msg("QEMU memory: using linked Pulse host agent memory (excludes page cache)") + } + } + } + + if vm.Status == "running" && memAvailable == 0 && agentEnabled { + if agentAvail, agentErr := m.getVMAgentMemAvailable(ctx, client, instanceName, n.Node, vm.VMID); agentErr == nil && agentAvail > 0 { + memAvailable = agentAvail + memorySource = "guest-agent-meminfo" + guestRaw.MemInfoAvailable = memAvailable + log.Debug(). + Str("vm", vm.Name). + Str("node", n.Node). + Int("vmid", vm.VMID). + Uint64("total", memTotal). + Uint64("available", memAvailable). + Msg("QEMU memory: using guest agent /proc/meminfo fallback (excludes reclaimable cache)") + } + } + + if vm.Status == "running" && memAvailable == 0 && memInfoTotalMinusUsed > 0 { + memAvailable = memInfoTotalMinusUsed + memorySource = "meminfo-total-minus-used" + } + + switch { + case vm.Status != "running": + memUsed = 0 memorySource = "powered-off" - } else if vmStatus == nil { + case memAvailable > 0: + if memAvailable > memTotal { + memAvailable = memTotal + } + memUsed = memTotal - memAvailable + case rrdUsed > 0: + memUsed = rrdUsed + memorySource = "rrd-memused" + case vmStatus != nil && vmStatus.Mem > 0: + memUsed = vmStatus.Mem + memorySource = "status-mem" + case vmStatus != nil && vmStatus.FreeMem > 0 && memTotal >= vmStatus.FreeMem: + memUsed = memTotal - vmStatus.FreeMem + memorySource = "status-freemem" + case vmStatus == nil: + memorySource = "listing-mem" + default: memorySource = "status-unavailable" } + if memUsed > memTotal { + memUsed = memTotal + } if vm.Status == "running" && vmStatus != nil { guestIPs, guestIfaces, guestOSName, guestOSVersion, agentVersion := m.fetchGuestAgentMetadata(ctx, client, instanceName, n.Node, vm.Name, vm.VMID, vmStatus) @@ -1732,12 +1738,11 @@ func (m *Monitor) fetchNodeStorageFallback(ctx context.Context, instanceCfg *con } var target string - hasFingerprint := strings.TrimSpace(instanceCfg.Fingerprint) != "" for _, ep := range instanceCfg.ClusterEndpoints { if !strings.EqualFold(ep.NodeName, nodeName) { continue } - target = clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, hasFingerprint) + target = clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, instanceCfg.Fingerprint) if target != "" { break } @@ -1784,10 +1789,9 @@ func (m *Monitor) pollPVENode( connectionHost := instanceCfg.Host guestURL := instanceCfg.GuestURL if instanceCfg.IsCluster && len(instanceCfg.ClusterEndpoints) > 0 { - hasFingerprint := instanceCfg.Fingerprint != "" for _, ep := range instanceCfg.ClusterEndpoints { if strings.EqualFold(ep.NodeName, node.Node) { - if effective := clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, hasFingerprint); effective != "" { + if effective := clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, instanceCfg.Fingerprint); effective != "" { connectionHost = effective } if ep.GuestURL != "" { @@ -2274,10 +2278,9 @@ func (m *Monitor) pollPVENode( if modelNode.IsClusterMember && instanceCfg.IsCluster { // Try to find specific endpoint configuration for this node if len(instanceCfg.ClusterEndpoints) > 0 { - hasFingerprint := instanceCfg.Fingerprint != "" for _, ep := range instanceCfg.ClusterEndpoints { if strings.EqualFold(ep.NodeName, node.Node) { - if effective := clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, hasFingerprint); effective != "" { + if effective := clusterEndpointEffectiveURL(ep, instanceCfg.VerifySSL, instanceCfg.Fingerprint); effective != "" { sshHost = effective foundNodeEndpoint = true } diff --git a/internal/monitoring/multi_tenant_monitor.go b/internal/monitoring/multi_tenant_monitor.go index f426c7267..ca071a3fd 100644 --- a/internal/monitoring/multi_tenant_monitor.go +++ b/internal/monitoring/multi_tenant_monitor.go @@ -37,6 +37,10 @@ func NewMultiTenantMonitor(baseCfg *config.Config, persistence *config.MultiTena // GetMonitor returns the monitor instance for a specific organization. // It lazily initializes the monitor if it doesn't exist. func (mtm *MultiTenantMonitor) GetMonitor(orgID string) (*Monitor, error) { + if orgID == "" { + orgID = "default" + } + mtm.mu.RLock() monitor, exists := mtm.monitors[orgID] mtm.mu.RUnlock() @@ -56,6 +60,24 @@ func (mtm *MultiTenantMonitor) GetMonitor(orgID string) (*Monitor, error) { // Initialize new monitor for this tenant log.Info().Str("org_id", orgID).Msg("Initializing tenant monitor") + // Single-tenant runtime path: no tenant persistence is available, so only + // the default monitor can be created from the base config directly. + if mtm.persistence == nil { + if orgID != "default" { + return nil, fmt.Errorf("tenant monitor unavailable in single-tenant mode: %s", orgID) + } + + var err error + monitor, err = New(mtm.baseConfig.DeepCopy()) + if err != nil { + return nil, fmt.Errorf("failed to create default monitor: %w", err) + } + monitor.SetOrgID("default") + go monitor.Start(mtm.globalCtx, mtm.wsHub) + mtm.monitors[orgID] = monitor + return monitor, nil + } + // 1. Load Tenant Config // Deep copy the base config to ensure tenant isolation. // Each tenant gets its own independent config that won't share @@ -92,6 +114,22 @@ func (mtm *MultiTenantMonitor) GetMonitor(orgID string) (*Monitor, error) { Msg("Loaded tenant nodes config") } + // Load tenant-scoped API tokens in addition to any global tokens inherited + // from the base config so org-specific agents continue working after + // the tenant monitor is recreated. + tenantTokens, err := tenantPersistence.LoadAPITokens() + if err != nil { + log.Warn().Err(err).Str("org_id", orgID).Msg("Failed to load tenant API tokens") + } else if len(tenantTokens) > 0 { + tenantConfig.APITokens = mergeAPITokens(tenantConfig.APITokens, tenantTokens) + tenantConfig.SortAPITokens() + log.Info(). + Str("org_id", orgID). + Int("tenant_tokens", len(tenantTokens)). + Int("total_tokens", len(tenantConfig.APITokens)). + Msg("Loaded tenant API tokens") + } + // 2. Create Monitor // Usage of internal New constructor monitor, err = New(tenantConfig) @@ -148,3 +186,34 @@ func (mtm *MultiTenantMonitor) OrgExists(orgID string) bool { } return mtm.persistence.OrgExists(orgID) } + +func mergeAPITokens(baseTokens, tenantTokens []config.APITokenRecord) []config.APITokenRecord { + if len(baseTokens) == 0 && len(tenantTokens) == 0 { + return nil + } + + merged := make([]config.APITokenRecord, 0, len(baseTokens)+len(tenantTokens)) + seen := make(map[string]struct{}, len(baseTokens)+len(tenantTokens)) + + appendUnique := func(tokens []config.APITokenRecord) { + for _, token := range tokens { + key := token.Hash + if key == "" { + key = token.ID + } + if key == "" { + continue + } + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + merged = append(merged, token.Clone()) + } + } + + appendUnique(baseTokens) + appendUnique(tenantTokens) + + return merged +} diff --git a/internal/monitoring/multi_tenant_monitor_additional_test.go b/internal/monitoring/multi_tenant_monitor_additional_test.go index b5ffd6978..81060234e 100644 --- a/internal/monitoring/multi_tenant_monitor_additional_test.go +++ b/internal/monitoring/multi_tenant_monitor_additional_test.go @@ -1,6 +1,10 @@ package monitoring -import "testing" +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) func TestMultiTenantMonitorRemoveTenant(t *testing.T) { monitor := &Monitor{} @@ -18,3 +22,55 @@ func TestMultiTenantMonitorRemoveTenant(t *testing.T) { // Ensure removal of missing orgs is a no-op. mtm.RemoveTenant("missing") } + +func TestMultiTenantMonitorLoadsTenantAPITokens(t *testing.T) { + baseDir := t.TempDir() + baseCfg := &config.Config{ + DataPath: baseDir, + ConfigPath: baseDir, + } + + globalRecord, err := config.NewAPITokenRecord("global-token-123.12345678", "global", []string{config.ScopeMonitoringRead}) + if err != nil { + t.Fatalf("new global token: %v", err) + } + baseCfg.APITokens = []config.APITokenRecord{*globalRecord} + baseCfg.SortAPITokens() + + mtp := config.NewMultiTenantPersistence(baseDir) + tenantPersistence, err := mtp.GetPersistence("org-1") + if err != nil { + t.Fatalf("get tenant persistence: %v", err) + } + + tenantRecord, err := config.NewAPITokenRecord("tenant-token-123.12345678", "tenant", []string{config.ScopeHostReport}) + if err != nil { + t.Fatalf("new tenant token: %v", err) + } + tenantRecord.OrgID = "org-1" + if err := tenantPersistence.SaveAPITokens([]config.APITokenRecord{*tenantRecord}); err != nil { + t.Fatalf("save tenant tokens: %v", err) + } + + mtm := NewMultiTenantMonitor(baseCfg, mtp, nil) + t.Cleanup(mtm.Stop) + + monitor, err := mtm.GetMonitor("org-1") + if err != nil { + t.Fatalf("get tenant monitor: %v", err) + } + + cfg := monitor.GetConfig() + if cfg == nil { + t.Fatalf("expected tenant config") + } + if len(cfg.APITokens) != 2 { + t.Fatalf("expected 2 merged api tokens, got %d", len(cfg.APITokens)) + } + if !cfg.IsValidAPIToken("global-token-123.12345678") { + t.Fatalf("expected global token to remain valid") + } + if !cfg.IsValidAPIToken("tenant-token-123.12345678") { + t.Fatalf("expected tenant token to be loaded") + } +} diff --git a/internal/monitoring/reload_test.go b/internal/monitoring/reload_test.go index 65b5c4ead..f5f1cd159 100644 --- a/internal/monitoring/reload_test.go +++ b/internal/monitoring/reload_test.go @@ -80,3 +80,40 @@ func TestReloadableMonitor_Lifecycle_Coverage(t *testing.T) { // Test Stop rm.Stop() } + +func TestReloadableMonitor_SingleTenantLifecycle(t *testing.T) { + cfg := &config.Config{ + DataPath: t.TempDir(), + } + + rm, err := NewReloadableMonitor(cfg, nil, nil) + require.NoError(t, err) + require.NotNil(t, rm) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rm.Start(ctx) + + defaultMonitor := rm.GetMonitor() + require.NotNil(t, defaultMonitor) + + defaultState := rm.GetState("default") + require.NotNil(t, defaultState) + + nonDefaultState := rm.GetState("acme") + assert.Nil(t, nonDefaultState) + + errChan := make(chan error, 1) + go func() { + errChan <- rm.Reload() + }() + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(3 * time.Second): + t.Fatal("single-tenant reload timed out") + } + + rm.Stop() +} diff --git a/internal/smartctl/collector.go b/internal/smartctl/collector.go index 1cea8f4ff..0e76c4b13 100644 --- a/internal/smartctl/collector.go +++ b/internal/smartctl/collector.go @@ -250,37 +250,118 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error) return nil, err } - // Run smartctl with standby check to avoid waking sleeping drives - // -n standby: don't check if drive is in standby (return exit code 2) - // -i: device info - // -A: attributes (for temperature) - // --json=o: output original smartctl JSON format - output, err := runCommandOutput(cmdCtx, smartctlPath, "-n", "standby", "-i", "-A", "-H", "--json=o", device) + attempts := smartctlProbeAttempts(device) + var firstParsed *DiskSMART + var lastErr error - // smartctl returns non-zero exit codes for various conditions - // Exit code 2 means drive is in standby - that's okay - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode := exitErr.ExitCode() - // Check for standby (bit 1 set in exit status) - if exitCode&2 != 0 { - return &DiskSMART{ - Device: filepath.Base(device), - Standby: true, - LastUpdated: timeNow(), - }, nil + for i, args := range attempts { + output, err := runCommandOutput(cmdCtx, smartctlPath, args...) + + // smartctl returns non-zero exit codes for various conditions. + // Exit code 2 means drive is in standby - that's okay. + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode := exitErr.ExitCode() + if exitCode&2 != 0 { + return &DiskSMART{ + Device: filepath.Base(device), + Standby: true, + LastUpdated: timeNow(), + }, nil + } + if len(output) == 0 { + lastErr = err + continue + } + } else { + lastErr = err + continue } - // Other exit codes might still have valid JSON output - // Continue parsing if we got output - if len(output) == 0 { - return nil, err - } - } else { - return nil, err + } + + result, parseErr := parseSMARTOutput(output, device) + if parseErr != nil { + lastErr = parseErr + continue + } + if firstParsed == nil { + firstParsed = result + } + if !shouldRetryFreeBSDSMART(device, result, i, len(attempts)) { + log.Debug(). + Str("device", result.Device). + Str("model", result.Model). + Int("temperature", result.Temperature). + Str("health", result.Health). + Msg("Collected SMART data") + return result, nil } } - // Parse JSON output + if firstParsed != nil { + log.Debug(). + Str("device", firstParsed.Device). + Str("model", firstParsed.Model). + Int("temperature", firstParsed.Temperature). + Str("health", firstParsed.Health). + Msg("Collected SMART data") + return firstParsed, nil + } + if lastErr != nil { + return nil, lastErr + } + return nil, nil +} + +func smartctlProbeAttempts(device string) [][]string { + attempts := [][]string{ + smartctlArgs(device, ""), + } + + for _, deviceType := range freeBSDSmartctlDeviceTypes(filepath.Base(device)) { + attempts = append(attempts, smartctlArgs(device, deviceType)) + } + + return attempts +} + +func smartctlArgs(device, deviceType string) []string { + args := []string{} + if deviceType != "" { + args = append(args, "-d", deviceType) + } + args = append(args, "-n", "standby", "-i", "-A", "-H", "--json=o", device) + return args +} + +func freeBSDSmartctlDeviceTypes(device string) []string { + if runtimeGOOS != "freebsd" { + return nil + } + + switch { + case strings.HasPrefix(device, "ada"), strings.HasPrefix(device, "ad"): + return []string{"sat"} + case strings.HasPrefix(device, "da"): + return []string{"sat,auto", "scsi"} + case strings.HasPrefix(device, "nvd"), strings.HasPrefix(device, "nvme"): + return []string{"nvme"} + default: + return nil + } +} + +func shouldRetryFreeBSDSMART(device string, result *DiskSMART, attemptIndex, attemptCount int) bool { + if runtimeGOOS != "freebsd" || attemptIndex >= attemptCount-1 || result == nil { + return false + } + if result.Temperature > 0 { + return false + } + return len(freeBSDSmartctlDeviceTypes(filepath.Base(device))) > 0 +} + +func parseSMARTOutput(output []byte, device string) (*DiskSMART, error) { var smartData smartctlJSON if err := json.Unmarshal(output, &smartData); err != nil { return nil, err @@ -294,27 +375,19 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error) LastUpdated: timeNow(), } - // Build WWN string if available if smartData.WWN.NAA != 0 { result.WWN = formatWWN(smartData.WWN.NAA, smartData.WWN.OUI, smartData.WWN.ID) } - // Get temperature (different location for NVMe vs SATA). - // Try top-level fields first, then fall back to ATA attributes 194/190. if smartData.Temperature.Current > 0 { result.Temperature = smartData.Temperature.Current } else if smartData.NVMeSmartHealthInformationLog.Temperature > 0 { result.Temperature = smartData.NVMeSmartHealthInformationLog.Temperature } else { - // Fallback: extract from ATA SMART attributes 194 (Temperature_Celsius) - // or 190 (Airflow_Temperature_Cel). Some drives don't populate the - // top-level temperature field. for _, attr := range smartData.ATASmartAttributes.Table { if attr.ID == 194 || attr.ID == 190 { - // Temperature is typically in the raw value's lower byte. - // raw.string format: "34" or "34 (Min/Max 22/45)" etc. temp := int(parseRawValue(attr.Raw.String, attr.Raw.Value)) - if temp > 0 && temp < 150 { // sanity: valid operating range + if temp > 0 && temp < 150 { result.Temperature = temp break } @@ -322,23 +395,14 @@ func collectDeviceSMART(ctx context.Context, device string) (*DiskSMART, error) } } - // Get health status if smartData.SmartStatus.Passed { result.Health = "PASSED" } else { result.Health = "FAILED" } - // Parse SMART attributes result.Attributes = parseSMARTAttributes(&smartData, result.Type) - log.Debug(). - Str("device", result.Device). - Str("model", result.Model). - Int("temperature", result.Temperature). - Str("health", result.Health). - Msg("Collected SMART data") - return result, nil } diff --git a/internal/smartctl/collector_coverage_test.go b/internal/smartctl/collector_coverage_test.go index bad30851f..81fad054e 100644 --- a/internal/smartctl/collector_coverage_test.go +++ b/internal/smartctl/collector_coverage_test.go @@ -316,6 +316,57 @@ func TestCollectDeviceSMARTNVMeTempFallback(t *testing.T) { } } +func TestCollectDeviceSMARTFreeBSDAdaFallback(t *testing.T) { + origRun := runCommandOutput + origLook := execLookPath + origNow := timeNow + origGOOS := runtimeGOOS + t.Cleanup(func() { + runCommandOutput = origRun + execLookPath = origLook + timeNow = origNow + runtimeGOOS = origGOOS + }) + + fixed := time.Date(2024, 4, 6, 7, 8, 9, 0, time.UTC) + timeNow = func() time.Time { return fixed } + execLookPath = func(string) (string, error) { return "smartctl", nil } + runtimeGOOS = "freebsd" + + var attempts [][]string + runCommandOutput = func(ctx context.Context, name string, args ...string) ([]byte, error) { + attempts = append(attempts, append([]string(nil), args...)) + + payload := smartctlJSON{} + payload.Device.Protocol = "ATA" + payload.SmartStatus.Passed = true + + if len(args) >= 2 && args[0] == "-d" && args[1] == "sat" { + payload.Temperature.Current = 37 + } + + out, _ := json.Marshal(payload) + return out, nil + } + + result, err := collectDeviceSMART(context.Background(), "/dev/ada0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || result.Temperature != 37 || !result.LastUpdated.Equal(fixed) { + t.Fatalf("unexpected result: %#v", result) + } + if len(attempts) != 2 { + t.Fatalf("expected 2 attempts, got %d", len(attempts)) + } + if len(attempts[0]) >= 2 && attempts[0][0] == "-d" { + t.Fatalf("expected first attempt without explicit device type, got %v", attempts[0]) + } + if len(attempts[1]) < 2 || attempts[1][0] != "-d" || attempts[1][1] != "sat" { + t.Fatalf("expected sat fallback on second attempt, got %v", attempts[1]) + } +} + func TestCollectDeviceSMARTWWN(t *testing.T) { origRun := runCommandOutput origLook := execLookPath diff --git a/internal/websocket/hub.go b/internal/websocket/hub.go index a9f57f1bd..4b32f3635 100644 --- a/internal/websocket/hub.go +++ b/internal/websocket/hub.go @@ -366,6 +366,7 @@ type Hub struct { allowedOrigins []string // Allowed origins for CORS orgAuthChecker OrgAuthChecker // Org authorization checker multiTenantChecker MultiTenantChecker // Multi-tenant feature flag and license checker + singleTenantMode bool // Ignore tenant selection and force default org // Broadcast coalescing fields coalesceWindow time.Duration coalescePending *Message @@ -411,6 +412,13 @@ func (h *Hub) SetMultiTenantChecker(checker MultiTenantChecker) { h.multiTenantChecker = checker } +// SetSingleTenantMode forces all connections to use the default org. +func (h *Hub) SetSingleTenantMode(enabled bool) { + h.mu.Lock() + defer h.mu.Unlock() + h.singleTenantMode = enabled +} + // getStateForClient returns the state for a specific client based on their tenant func (h *Hub) getStateForClient(client *Client) interface{} { h.mu.RLock() @@ -633,8 +641,16 @@ func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { h.mu.RLock() mtChecker := h.multiTenantChecker authChecker := h.orgAuthChecker + singleTenantMode := h.singleTenantMode h.mu.RUnlock() + if singleTenantMode && orgID != "" && orgID != "default" { + log.Debug(). + Str("requested_org", orgID). + Msg("Ignoring non-default org for single-tenant WebSocket runtime") + orgID = "default" + } + if orgID != "default" { // Check if multi-tenant is enabled and licensed if mtChecker != nil { diff --git a/internal/websocket/hub_multitenant_test.go b/internal/websocket/hub_multitenant_test.go index 4bb9df636..d1e53b223 100644 --- a/internal/websocket/hub_multitenant_test.go +++ b/internal/websocket/hub_multitenant_test.go @@ -9,9 +9,13 @@ import ( type fakeMultiTenantChecker struct { result MultiTenantCheckResult + called *bool } func (f fakeMultiTenantChecker) CheckMultiTenant(ctx context.Context, orgID string) MultiTenantCheckResult { + if f.called != nil { + *f.called = true + } return f.result } @@ -58,3 +62,32 @@ func TestHandleWebSocket_MultiTenantUnlicensed(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusPaymentRequired, rec.Code) } } + +func TestHandleWebSocket_SingleTenantModeIgnoresOrgID(t *testing.T) { + hub := NewHub(nil) + hub.SetSingleTenantMode(true) + + called := false + hub.SetMultiTenantChecker(fakeMultiTenantChecker{ + called: &called, + result: MultiTenantCheckResult{ + Allowed: false, + FeatureEnabled: false, + Licensed: false, + Reason: "disabled", + }, + }) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + req.Header.Set("X-Pulse-Org-ID", "tenant-a") + rec := httptest.NewRecorder() + + hub.HandleWebSocket(rec, req) + + if called { + t.Fatal("expected single-tenant mode to bypass multi-tenant checker") + } + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected websocket upgrade failure after collapsing org to default, got %d", rec.Code) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 6fecec5c1..6a44b02c1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -96,6 +96,7 @@ func Run(ctx context.Context, version string) error { if dataDir == "" { dataDir = "/etc/pulse" } + runtimeSingleTenant := api.IsV5SingleTenantRuntime() rbacManager, err := auth.NewFileManager(dataDir) if err != nil { log.Warn().Err(err).Msg("Failed to initialize RBAC manager, role management will be unavailable") @@ -106,7 +107,7 @@ func Run(ctx context.Context, version string) error { // Run multi-tenant data migration only when the feature is explicitly enabled. // This prevents any on-disk layout changes for default (single-tenant) users. - if api.IsMultiTenantEnabled() { + if api.IsMultiTenantEnabled() && !runtimeSingleTenant { if err := config.RunMigrationIfNeeded(dataDir); err != nil { log.Error().Err(err).Msg("Multi-tenant data migration failed") // Continue anyway - migration failure shouldn't block startup @@ -147,9 +148,13 @@ func Run(ctx context.Context, version string) error { wsHub.SetAllowedOrigins([]string{}) } go wsHub.Run() + wsHub.SetSingleTenantMode(runtimeSingleTenant) // Initialize reloadable monitoring system - mtPersistence := config.NewMultiTenantPersistence(cfg.DataPath) + var mtPersistence *config.MultiTenantPersistence + if !runtimeSingleTenant { + mtPersistence = config.NewMultiTenantPersistence(cfg.DataPath) + } reloadableMonitor, err := monitoring.NewReloadableMonitor(cfg, mtPersistence, wsHub) if err != nil { return fmt.Errorf("failed to initialize monitoring system: %w", err) @@ -197,6 +202,11 @@ func Run(ctx context.Context, version string) error { // Set tenant-aware state getter for multi-tenant support wsHub.SetStateGetterForTenant(func(orgID string) interface{} { + if runtimeSingleTenant { + state := reloadableMonitor.GetMonitor().GetState() + return state.ToFrontend() + } + mtMonitor := reloadableMonitor.GetMultiTenantMonitor() if mtMonitor == nil { // Fall back to default monitor @@ -216,12 +226,14 @@ func Run(ctx context.Context, version string) error { // Set org authorization checker for WebSocket connections // This ensures clients can only subscribe to orgs they have access to - orgLoader := api.NewMultiTenantOrganizationLoader(mtPersistence) - wsHub.SetOrgAuthChecker(api.NewAuthorizationChecker(orgLoader)) + if !runtimeSingleTenant { + orgLoader := api.NewMultiTenantOrganizationLoader(mtPersistence) + wsHub.SetOrgAuthChecker(api.NewAuthorizationChecker(orgLoader)) - // Set multi-tenant checker for WebSocket connections - // This ensures the feature flag and license are checked before allowing non-default org connections - wsHub.SetMultiTenantChecker(api.NewMultiTenantChecker()) + // Set multi-tenant checker for WebSocket connections + // This ensures the feature flag and license are checked before allowing non-default org connections + wsHub.SetMultiTenantChecker(api.NewMultiTenantChecker()) + } // Wire up Prometheus metrics for alert lifecycle alerts.SetMetricHooks( @@ -243,19 +255,31 @@ func Run(ctx context.Context, version string) error { } if router != nil { router.SetMonitor(reloadableMonitor.GetMonitor()) - router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + if runtimeSingleTenant { + router.SetMultiTenantMonitor(nil) + } else { + router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + } if cfg := reloadableMonitor.GetConfig(); cfg != nil { router.SetConfig(cfg) } } return nil } - router = api.NewRouter(cfg, reloadableMonitor.GetMonitor(), reloadableMonitor.GetMultiTenantMonitor(), wsHub, reloadFunc, version) + routerMTMonitor := reloadableMonitor.GetMultiTenantMonitor() + if runtimeSingleTenant { + routerMTMonitor = nil + } + router = api.NewRouter(cfg, reloadableMonitor.GetMonitor(), routerMTMonitor, wsHub, reloadFunc, version) // Inject resource store into monitor for WebSocket broadcasts router.SetMonitor(reloadableMonitor.GetMonitor()) // Wire multi-tenant monitor to resource handlers for tenant-aware state - router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + if runtimeSingleTenant { + router.SetMultiTenantMonitor(nil) + } else { + router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + } // Create HTTP server with unified configuration srv := &http.Server{ @@ -319,7 +343,11 @@ func Run(ctx context.Context, version string) error { log.Error().Err(err).Msg("Failed to reload monitor after mock.env change") } else if router != nil { router.SetMonitor(reloadableMonitor.GetMonitor()) - router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + if runtimeSingleTenant { + router.SetMultiTenantMonitor(nil) + } else { + router.SetMultiTenantMonitor(reloadableMonitor.GetMultiTenantMonitor()) + } if cfg := reloadableMonitor.GetConfig(); cfg != nil { router.SetConfig(cfg) } diff --git a/scripts/install.sh b/scripts/install.sh index f8683ac60..2421a3c1b 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -58,17 +58,44 @@ LOG_FILE="/var/log/${AGENT_NAME}.log" QNAP=false QNAP_VOL="" +stop_qnap_agents_procfs() { + local signal="${1:-TERM}" + local proc="" + local pid="" + local cmd="" + + [[ -d /proc ]] || return 0 + + for proc in /proc/[0-9]*; do + [[ -e "$proc/cmdline" ]] || continue + pid="${proc##*/}" + [[ "$pid" == "$$" ]] && continue + + cmd="$(tr '\000' ' ' < "$proc/cmdline" 2>/dev/null || true)" + [[ -n "$cmd" ]] || continue + + case "$cmd" in + *start-pulse-agent.sh*|*/pulse-agent*|*"/usr/local/bin/pulse-agent"*) + kill "-${signal}" "$pid" 2>/dev/null || true + ;; + esac + done +} + # Stop all QNAP pulse-agent processes (wrappers first, then binaries). # Kill wrappers first to prevent watchdog respawn, then binaries. stop_qnap_agents() { # 1. Kill wrapper scripts (watchdog loops) pkill -f "start-pulse-agent\.sh" 2>/dev/null || true + stop_qnap_agents_procfs TERM sleep 1 # 2. Kill agent binaries at both possible paths (with or without leading /) pkill -f "(^|/)pulse-agent( |$)" 2>/dev/null || true + stop_qnap_agents_procfs TERM sleep 2 # 3. Verify — force-kill any survivors pkill -9 -f "(^|/)pulse-agent( |$)" 2>/dev/null || true + stop_qnap_agents_procfs KILL } # TrueNAS SCALE configuration (immutable root filesystem) @@ -1298,8 +1325,34 @@ if [[ "$QNAP" == true ]]; then # Pulse Agent startup script for QNAP # Auto-generated by Pulse installer -# Kill any running pulse-agent binary processes -pkill -f "(^|/)pulse-agent( |\$)" 2>/dev/null || true +stop_existing_agents() { + proc="" + pid="" + cmd="" + + pkill -f "start-pulse-agent\.sh" 2>/dev/null || true + pkill -f "(^|/)pulse-agent( |\$)" 2>/dev/null || true + + if [ -d /proc ]; then + for proc in /proc/[0-9]*; do + [ -e "\$proc/cmdline" ] || continue + pid="\${proc##*/}" + [ "\$pid" = "\$\$" ] && continue + + cmd=\$(tr '\000' ' ' < "\$proc/cmdline" 2>/dev/null || true) + [ -n "\$cmd" ] || continue + + case "\$cmd" in + *start-pulse-agent.sh*|*/pulse-agent*|*"/usr/local/bin/pulse-agent"*) + kill "\$pid" 2>/dev/null || true + ;; + esac + done + fi +} + +# Kill any running pulse-agent wrappers/binaries before starting this watchdog. +stop_existing_agents sleep 2 # Copy binary from persistent storage to /usr/local/bin (RAM disk, wiped on reboot).