diff --git a/internal/ai/adapters/adapters_additional_test.go b/internal/ai/adapters/adapters_additional_test.go
new file mode 100644
index 000000000..f72454f7c
--- /dev/null
+++ b/internal/ai/adapters/adapters_additional_test.go
@@ -0,0 +1,147 @@
+package adapters
+
+import (
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/rcourtman/pulse-go-rewrite/internal/models"
+	"github.com/rcourtman/pulse-go-rewrite/internal/monitoring"
+)
+
+type stubIncidentRecorder struct {
+	windows []*IncidentWindowData
+	window  *IncidentWindowData
+}
+
+func (s *stubIncidentRecorder) GetWindowsForResource(resourceID string, limit int) []*IncidentWindowData {
+	return s.windows
+}
+
+func (s *stubIncidentRecorder) GetWindow(windowID string) *IncidentWindowData {
+	return s.window
+}
+
+type stubEventCorrelator struct {
+	correlations []EventCorrelationData
+	events       []ProxmoxEventData
+}
+
+func (s *stubEventCorrelator) GetCorrelationsForResource(resourceID string) []EventCorrelationData {
+	return s.correlations
+}
+
+func (s *stubEventCorrelator) GetEventsForResource(resourceID string, limit int) []ProxmoxEventData {
+	return s.events
+}
+
+func TestForecastDataAdapter_GetMetricHistory(t *testing.T) {
+	history := monitoring.NewMetricsHistory(10, time.Hour)
+	now := time.Now()
+
+	history.AddGuestMetric("vm-1", "cpu", 1, now.Add(-10*time.Minute))
+	history.AddGuestMetric("vm-1", "cpu", 2, now.Add(-time.Minute))
+
+	adapter := NewForecastDataAdapter(history)
+	points, err := adapter.GetMetricHistory("vm-1", "cpu", now.Add(-5*time.Minute), now)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(points) != 1 || points[0].Value != 2 {
+		t.Fatalf("expected filtered guest points")
+	}
+
+	nodeHistory := monitoring.NewMetricsHistory(10, time.Hour)
+	nodeHistory.AddNodeMetric("node-1", "cpu", 3, now.Add(-time.Minute))
+	nodeAdapter := NewForecastDataAdapter(nodeHistory)
+	points, err = nodeAdapter.GetMetricHistory("node-1", "cpu", now.Add(-5*time.Minute), now)
+	if err != nil || len(points) != 1 || points[0].Value != 3 {
+		t.Fatalf("expected node points")
+	}
+
+	storageHistory := monitoring.NewMetricsHistory(10, time.Hour)
+	storageHistory.AddStorageMetric("store-1", "usage", 55, now.Add(-time.Minute))
+	storageAdapter := NewForecastDataAdapter(storageHistory)
+	points, err = storageAdapter.GetMetricHistory("store-1", "usage", now.Add(-5*time.Minute), now)
+	if err != nil || len(points) != 1 || points[0].Value != 55 {
+		t.Fatalf("expected storage points")
+	}
+}
+
+func TestMetricsAdapter_GetMonitoredResourceIDs(t *testing.T) {
+	state := models.StateSnapshot{
+		VMs:        []models.VM{{ID: "vm-1"}},
+		Containers: []models.Container{{ID: "ct-1"}},
+		Nodes:      []models.Node{{ID: "node-1"}},
+	}
+	adapter := NewMetricsAdapter(&mockStateProvider{state: state})
+	ids := adapter.GetMonitoredResourceIDs()
+	if len(ids) != 3 {
+		t.Fatalf("expected 3 IDs, got %d", len(ids))
+	}
+}
+
+func TestIncidentRecorderMCPAdapter(t *testing.T) {
+	adapter := NewIncidentRecorderMCPAdapter(nil)
+	if adapter.GetWindowsForResource("res", 1) != nil {
+		t.Fatalf("expected nil windows for nil recorder")
+	}
+	if adapter.GetWindow("id") != nil {
+		t.Fatalf("expected nil window for nil recorder")
+	}
+
+	recorder := &stubIncidentRecorder{
+		windows: []*IncidentWindowData{{ID: "w1"}},
+		window:  &IncidentWindowData{ID: "w1"},
+	}
+	adapter = NewIncidentRecorderMCPAdapter(recorder)
+	if len(adapter.GetWindowsForResource("res", 1)) != 1 {
+		t.Fatalf("expected windows from recorder")
+	}
+	if adapter.GetWindow("w1") == nil {
t.Fatalf("expected window from recorder") + } +} + +func TestEventCorrelatorMCPAdapter(t *testing.T) { + adapter := NewEventCorrelatorMCPAdapter(nil) + if adapter.GetCorrelationsForResource("res", time.Minute) != nil { + t.Fatalf("expected nil correlations for nil correlator") + } + + correlator := &stubEventCorrelator{ + correlations: []EventCorrelationData{{ID: "c1"}}, + } + adapter = NewEventCorrelatorMCPAdapter(correlator) + if len(adapter.GetCorrelationsForResource("res", time.Minute)) != 1 { + t.Fatalf("expected correlations from correlator") + } +} + +func TestKnowledgeStore_SaveLoad(t *testing.T) { + dir := t.TempDir() + store := NewKnowledgeStore(dir) + + if err := store.SaveNote("res-1", "note", "general"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + store.saveToDisk() + + loaded := NewKnowledgeStore(dir) + if err := loaded.loadFromDisk(); err != nil { + t.Fatalf("unexpected load error: %v", err) + } + entries := loaded.GetKnowledge("res-1", "general") + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } +} + +func TestKnowledgeStore_LoadMissingFile(t *testing.T) { + dir := t.TempDir() + store := NewKnowledgeStore(dir) + path := filepath.Join(dir, "knowledge_store.json") + if err := store.loadFromDisk(); err == nil { + t.Fatalf("expected error for missing file %s", path) + } +} diff --git a/internal/ai/alert_provider_additional_test.go b/internal/ai/alert_provider_additional_test.go new file mode 100644 index 000000000..3762be39c --- /dev/null +++ b/internal/ai/alert_provider_additional_test.go @@ -0,0 +1,66 @@ +package ai + +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +type stubAlertResolver struct { + alerts []AlertInfo + clears []string +} + +func (s *stubAlertResolver) GetActiveAlerts() []AlertInfo { + return s.alerts +} + +func (s *stubAlertResolver) ResolveAlert(alertID string) bool { + s.clears = append(s.clears, alertID) + return true +} + +type stubAlertManagerClear struct { + cleared []string +} + +func (s *stubAlertManagerClear) GetActiveAlerts() []alerts.Alert { + return nil +} + +func (s *stubAlertManagerClear) GetRecentlyResolved() []models.ResolvedAlert { + return nil +} + +func (s *stubAlertManagerClear) ClearAlert(alertID string) bool { + s.cleared = append(s.cleared, alertID) + return true +} + +func TestSetAlertResolverAndResolve(t *testing.T) { + resolver := &stubAlertResolver{} + service := &Service{ + patrolService: &PatrolService{}, + } + + service.SetAlertResolver(resolver) + + if service.patrolService.GetAlertResolver() != resolver { + t.Fatalf("expected resolver to be set on patrol service") + } + + manager := &stubAlertManagerClear{} + adapter := NewAlertManagerAdapter(manager) + if !adapter.ResolveAlert("alert-1") { + t.Fatalf("expected ResolveAlert to return true") + } + if len(manager.cleared) != 1 || manager.cleared[0] != "alert-1" { + t.Fatalf("expected alert-1 to be cleared, got %v", manager.cleared) + } + + adapter = NewAlertManagerAdapter(nil) + if adapter.ResolveAlert("alert-2") { + t.Fatalf("expected ResolveAlert to return false when manager nil") + } +} diff --git a/internal/ai/chat/patrol.go b/internal/ai/chat/patrol.go index ca6ce2cfa..b6d7e5f60 100644 --- a/internal/ai/chat/patrol.go +++ b/internal/ai/chat/patrol.go @@ -407,7 +407,7 @@ func (p *PatrolService) parseFindingBlock(block string) *PatrolFinding { severity = "info" } - validCategories := map[string]bool{"performance": true, "reliability": 
true, "security": true, "capacity": true, "configuration": true} + validCategories := map[string]bool{"performance": true, "reliability": true, "security": true, "capacity": true, "backup": true, "configuration": true} if !validCategories[category] { category = "configuration" } diff --git a/internal/ai/chat/patrol_test.go b/internal/ai/chat/patrol_test.go index 16d5da7bf..2befe5734 100644 --- a/internal/ai/chat/patrol_test.go +++ b/internal/ai/chat/patrol_test.go @@ -118,6 +118,29 @@ EVIDENCE: None } } +func TestPatrolService_ParseFindings_BackupCategory(t *testing.T) { + service := NewPatrolService(nil) + + response := ` +[FINDING] +KEY: backup-stale +SEVERITY: warning +CATEGORY: backup +RESOURCE: vm-101 +RESOURCE_TYPE: vm +TITLE: Backup stale +DESCRIPTION: No backup in 48 hours +RECOMMENDATION: Check backup jobs +EVIDENCE: Last backup: 2 days ago +[/FINDING] +` + + findings := service.parseFindings(response) + require.Len(t, findings, 1) + assert.Equal(t, "backup", findings[0].Category) + assert.Equal(t, "Backup stale", findings[0].Title) +} + // MockFindingsStore type MockFindingsStore struct { mock.Mock diff --git a/internal/ai/chat/service.go b/internal/ai/chat/service.go index 74e0b0304..36dae81d1 100644 --- a/internal/ai/chat/service.go +++ b/internal/ai/chat/service.go @@ -51,6 +51,7 @@ type ( EventCorrelatorProvider = tools.EventCorrelatorProvider TopologyProvider = tools.TopologyProvider KnowledgeStoreProvider = tools.KnowledgeStoreProvider + MCPDiscoveryProvider = tools.DiscoveryProvider ) // Config holds service configuration @@ -590,6 +591,14 @@ func (s *Service) SetKnowledgeStoreProvider(provider KnowledgeStoreProvider) { } } +func (s *Service) SetDiscoveryProvider(provider MCPDiscoveryProvider) { + s.mu.Lock() + defer s.mu.Unlock() + if s.executor != nil { + s.executor.SetDiscoveryProvider(provider) + } +} + func (s *Service) UpdateControlSettings(cfg *config.AIConfig) { if cfg == nil { return diff --git a/internal/ai/chat/service_additional_test.go b/internal/ai/chat/service_additional_test.go new file mode 100644 index 000000000..10de6be44 --- /dev/null +++ b/internal/ai/chat/service_additional_test.go @@ -0,0 +1,70 @@ +package chat + +import ( + "context" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/tools" +) + +func TestServiceSettersAndAutonomousMode(t *testing.T) { + executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{}) + loop := &AgenticLoop{} + service := &Service{ + executor: executor, + agenticLoop: loop, + } + + service.SetIncidentRecorderProvider(nil) + service.SetEventCorrelatorProvider(nil) + service.SetTopologyProvider(nil) + service.SetKnowledgeStoreProvider(nil) + + service.SetAutonomousMode(true) + if !service.autonomousMode { + t.Fatalf("expected autonomousMode true") + } + if !loop.autonomousMode { + t.Fatalf("expected agentic loop to be autonomous") + } +} + +func TestServiceExecuteCommand_NoExecutor(t *testing.T) { + service := &Service{} + _, _, err := service.ExecuteCommand(context.Background(), "ls", "") + if err == nil { + t.Fatalf("expected error when executor is unavailable") + } +} + +func TestPatrolServiceSessionLifecycle(t *testing.T) { + store, err := NewSessionStore(t.TempDir()) + if err != nil { + t.Fatalf("NewSessionStore error: %v", err) + } + + service := &Service{ + sessions: store, + started: true, + } + + patrol := NewPatrolService(service) + if err := patrol.CreatePatrolSession(context.Background()); err != nil { + t.Fatalf("CreatePatrolSession error: %v", err) + } + if patrol.GetSessionID() == "" { + 
t.Fatalf("expected session ID to be set") + } + + patrol.mu.Lock() + patrol.running = true + patrol.mu.Unlock() + if !patrol.IsRunning() { + t.Fatalf("expected patrol to be running") + } + + service.started = false + if err := patrol.CreatePatrolSession(context.Background()); err == nil { + t.Fatalf("expected error when service not running") + } +} diff --git a/internal/ai/circuit/breaker_additional_test.go b/internal/ai/circuit/breaker_additional_test.go new file mode 100644 index 000000000..9cb3ca91b --- /dev/null +++ b/internal/ai/circuit/breaker_additional_test.go @@ -0,0 +1,344 @@ +package circuit + +import ( + "errors" + "testing" + "time" +) + +func TestStateString_Unknown(t *testing.T) { + if got := State(99).String(); got != "unknown" { + t.Fatalf("expected unknown, got %s", got) + } +} + +func TestNewBreaker_DefaultsApplied(t *testing.T) { + b := NewBreaker("defaults", Config{}) + + if b.config.FailureThreshold != 3 { + t.Fatalf("expected default FailureThreshold, got %d", b.config.FailureThreshold) + } + if b.config.SuccessThreshold != 2 { + t.Fatalf("expected default SuccessThreshold, got %d", b.config.SuccessThreshold) + } + if b.config.InitialBackoff != time.Second { + t.Fatalf("expected default InitialBackoff, got %s", b.config.InitialBackoff) + } + if b.config.MaxBackoff != 5*time.Minute { + t.Fatalf("expected default MaxBackoff, got %s", b.config.MaxBackoff) + } + if b.config.BackoffMultiplier != 2.0 { + t.Fatalf("expected default BackoffMultiplier, got %.1f", b.config.BackoffMultiplier) + } + if b.config.HalfOpenTimeout != 30*time.Second { + t.Fatalf("expected default HalfOpenTimeout, got %s", b.config.HalfOpenTimeout) + } +} + +func TestBreaker_CanAllow_DoesNotTransition(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + b.mu.Lock() + b.state = StateOpen + b.currentBackoff = time.Hour + b.openedAt = time.Now().Add(-2 * time.Hour) + b.mu.Unlock() + + if !b.CanAllow() { + t.Fatalf("expected CanAllow to return true after backoff") + } + if b.State() != StateOpen { + t.Fatalf("expected state to remain open on CanAllow") + } +} + +func TestBreaker_RecordFailure_InvalidDoesNotTrip(t *testing.T) { + cfg := DefaultConfig() + cfg.FailureThreshold = 1 + b := NewBreaker("test", cfg) + + b.RecordFailureWithCategory(errors.New("bad request"), ErrorCategoryInvalid) + if b.State() != StateClosed { + t.Fatalf("expected invalid error not to trip circuit") + } +} + +func TestBreaker_RecordFailure_RateLimitTrips(t *testing.T) { + cfg := DefaultConfig() + cfg.FailureThreshold = 5 + b := NewBreaker("test", cfg) + + b.RecordFailureWithCategory(errors.New("rate limit"), ErrorCategoryRateLimit) + if b.State() != StateOpen { + t.Fatalf("expected rate limit error to trip circuit") + } +} + +func TestBreaker_RecordFailure_HalfOpenBackoffCaps(t *testing.T) { + cfg := DefaultConfig() + cfg.MaxBackoff = 100 * time.Millisecond + cfg.BackoffMultiplier = 2.0 + b := NewBreaker("test", cfg) + + b.mu.Lock() + b.state = StateHalfOpen + b.currentBackoff = 80 * time.Millisecond + b.mu.Unlock() + + b.RecordFailureWithCategory(errors.New("fail"), ErrorCategoryTransient) + if b.State() != StateOpen { + t.Fatalf("expected state to be open after half-open failure") + } + if b.currentBackoff != cfg.MaxBackoff { + t.Fatalf("expected backoff to cap at max, got %s", b.currentBackoff) + } +} + +func TestBreaker_Callbacks(t *testing.T) { + cfg := DefaultConfig() + cfg.FailureThreshold = 1 + b := NewBreaker("test", cfg) + + stateCh := make(chan State, 1) + tripCh := make(chan error, 1) + 
b.SetOnStateChange(func(_, to State) { + stateCh <- to + }) + b.SetOnTrip(func(err error) { + tripCh <- err + }) + + testErr := errors.New("boom") + b.RecordFailure(testErr) + + select { + case state := <-stateCh: + if state != StateOpen { + t.Fatalf("expected state change to open, got %s", state.String()) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected state change callback to fire") + } + + select { + case err := <-tripCh: + if err == nil || err.Error() != testErr.Error() { + t.Fatalf("expected trip callback with error") + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected trip callback to fire") + } +} + +func TestBreaker_GetStatus_TimeUntilRetry(t *testing.T) { + cfg := DefaultConfig() + cfg.FailureThreshold = 1 + cfg.InitialBackoff = 50 * time.Millisecond + b := NewBreaker("test", cfg) + + b.RecordFailure(errors.New("fail")) + status := b.GetStatus() + + if status.State != "open" { + t.Fatalf("expected status open, got %s", status.State) + } + if status.LastError == "" { + t.Fatalf("expected last error to be set") + } + if status.TimeUntilRetry <= 0 { + t.Fatalf("expected time until retry to be set") + } +} + +func TestBreaker_ExecuteWithCategory_InvalidDoesNotTrip(t *testing.T) { + cfg := DefaultConfig() + cfg.FailureThreshold = 1 + b := NewBreaker("test", cfg) + + err := b.ExecuteWithCategory(func() (error, ErrorCategory) { + return errors.New("invalid"), ErrorCategoryInvalid + }) + if err == nil { + t.Fatalf("expected error") + } + if b.State() != StateClosed { + t.Fatalf("expected state to remain closed") + } +} + +func TestIsCircuitOpen_Additional(t *testing.T) { + if !IsCircuitOpen(ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen to be recognized") + } + if IsCircuitOpen(errors.New("other")) { + t.Fatalf("expected non-circuit error to be false") + } +} + +func TestCategorizeError_Additional(t *testing.T) { + tests := []struct { + err error + expected ErrorCategory + }{ + {errors.New("rate limit exceeded"), ErrorCategoryRateLimit}, + {errors.New("429 too many requests"), ErrorCategoryRateLimit}, + {errors.New("400 bad request"), ErrorCategoryInvalid}, + {errors.New("unauthorized api key"), ErrorCategoryFatal}, + {errors.New("payment required"), ErrorCategoryFatal}, + {errors.New("random failure"), ErrorCategoryTransient}, + {nil, ErrorCategoryTransient}, + } + + for _, tt := range tests { + if got := CategorizeError(tt.err); got != tt.expected { + t.Fatalf("expected %v, got %v for %v", tt.expected, got, tt.err) + } + } +} + +func TestBreaker_IsOpenClosed(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + if !b.IsClosed() || b.IsOpen() { + t.Fatalf("expected breaker to start closed") + } + + b.mu.Lock() + b.state = StateOpen + b.mu.Unlock() + + if !b.IsOpen() || b.IsClosed() { + t.Fatalf("expected breaker to report open") + } +} + +func TestCircuitOpenErrorMessage(t *testing.T) { + err := circuitOpenError{} + if err.Error() != "circuit breaker is open" { + t.Fatalf("unexpected error message: %s", err.Error()) + } +} + +func TestBreaker_CanAllow_Branches(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + if !b.CanAllow() { + t.Fatalf("expected CanAllow true for closed") + } + + b.mu.Lock() + b.state = StateOpen + b.openedAt = time.Now() + b.currentBackoff = time.Hour + b.mu.Unlock() + + if b.CanAllow() { + t.Fatalf("expected CanAllow false before backoff elapses") + } + + b.mu.Lock() + b.openedAt = time.Now().Add(-2 * time.Hour) + b.mu.Unlock() + if !b.CanAllow() { + t.Fatalf("expected CanAllow true after backoff") + } 
+ + b.mu.Lock() + b.state = StateHalfOpen + b.mu.Unlock() + if !b.CanAllow() { + t.Fatalf("expected CanAllow true for half-open") + } +} + +func TestBreaker_ExecuteWithCategory_SuccessAndOpen(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + if err := b.ExecuteWithCategory(func() (error, ErrorCategory) { + return nil, ErrorCategoryTransient + }); err != nil { + t.Fatalf("expected success, got %v", err) + } + + cfg := DefaultConfig() + cfg.FailureThreshold = 1 + cfg.InitialBackoff = time.Hour + b = NewBreaker("test", cfg) + b.RecordFailure(errors.New("fail")) + if err := b.ExecuteWithCategory(func() (error, ErrorCategory) { + return nil, ErrorCategoryTransient + }); err != ErrCircuitOpen { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } +} + +func TestStateString_All(t *testing.T) { + cases := map[State]string{ + StateClosed: "closed", + StateOpen: "open", + StateHalfOpen: "half-open", + } + for state, expected := range cases { + if state.String() != expected { + t.Fatalf("expected %s for state %d", expected, state) + } + } +} + +func TestBreaker_Allow_TransitionsOpenToHalfOpen(t *testing.T) { + cfg := DefaultConfig() + cfg.InitialBackoff = 10 * time.Millisecond + b := NewBreaker("test", cfg) + + b.mu.Lock() + b.state = StateOpen + b.openedAt = time.Now().Add(-time.Second) + b.currentBackoff = 10 * time.Millisecond + b.mu.Unlock() + + if !b.Allow() { + t.Fatalf("expected Allow to return true after backoff") + } + if b.State() != StateHalfOpen { + t.Fatalf("expected state to transition to half-open") + } +} + +func TestBreaker_TransitionTo_NoOp(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + b.transitionTo(StateClosed) + if b.State() != StateClosed { + t.Fatalf("expected state to remain closed") + } +} + +func TestBreaker_Allow_BlocksBeforeBackoff(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + b.mu.Lock() + b.state = StateOpen + b.openedAt = time.Now() + b.currentBackoff = time.Hour + b.mu.Unlock() + + if b.Allow() { + t.Fatalf("expected Allow to return false before backoff elapses") + } + if b.State() != StateOpen { + t.Fatalf("expected state to remain open") + } +} + +func TestBreaker_Allow_HalfOpen(t *testing.T) { + b := NewBreaker("test", DefaultConfig()) + b.mu.Lock() + b.state = StateHalfOpen + b.mu.Unlock() + if !b.Allow() { + t.Fatalf("expected Allow true in half-open") + } +} + +func TestToLower(t *testing.T) { + if toLower("AbC123") != "abc123" { + t.Fatalf("expected toLower to normalize casing") + } + if toLower("lower") != "lower" { + t.Fatalf("expected lowercase input to remain unchanged") + } +} diff --git a/internal/ai/correlation/correlation_additional_test.go b/internal/ai/correlation/correlation_additional_test.go new file mode 100644 index 000000000..6f37d1a67 --- /dev/null +++ b/internal/ai/correlation/correlation_additional_test.go @@ -0,0 +1,184 @@ +package correlation + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestDetector_RecordEventAndCorrelation(t *testing.T) { + cfg := Config{ + MaxEvents: 10, + CorrelationWindow: time.Minute, + MinOccurrences: 1, + RetentionWindow: time.Hour, + } + d := NewDetector(cfg) + + start := time.Now() + d.RecordEvent(Event{ResourceID: "a", EventType: EventAlert, Timestamp: start}) + d.RecordEvent(Event{ResourceID: "b", EventType: EventRestart, Timestamp: start.Add(10 * time.Second)}) + d.RecordEvent(Event{ResourceID: "a", EventType: EventAlert, Timestamp: start.Add(20 * time.Second)}) + d.RecordEvent(Event{ResourceID: "b", EventType: EventRestart, 
Timestamp: start.Add(30 * time.Second)}) + + corrs := d.GetCorrelations() + if len(corrs) == 0 { + t.Fatalf("expected correlations") + } + if corrs[0].Occurrences < 2 { + t.Fatalf("expected occurrences to increase") + } +} + +func TestDetector_ConfidenceAndFormatting(t *testing.T) { + d := NewDetector(Config{MinOccurrences: 3}) + if d.calculateConfidence(1) != 0.1 { + t.Fatalf("expected low confidence") + } + if d.calculateConfidence(3) < 0.3 { + t.Fatalf("expected baseline confidence for threshold") + } + + c := &Correlation{ + SourceID: "src", + TargetID: "dst", + EventPattern: "alert -> restart", + AvgDelay: 2 * time.Minute, + } + desc := d.formatCorrelationDescription(c) + if !strings.Contains(desc, "src") || !strings.Contains(desc, "dst") { + t.Fatalf("expected description to use IDs") + } + + if formatDuration(30*time.Second) != "seconds" { + t.Fatalf("expected seconds duration") + } + if formatDuration(1*time.Minute) != "1 minute" { + t.Fatalf("expected singular minute") + } + if formatDuration(2*time.Hour) != "2 hours" { + t.Fatalf("expected hour format") + } + if formatConfidence(0.5) != "50%" { + t.Fatalf("expected confidence percent") + } +} + +func TestDetector_DependencyQueries(t *testing.T) { + d := NewDetector(Config{MinOccurrences: 1}) + d.correlations[correlationKey("a", "b", EventAlert, EventRestart)] = &Correlation{ + SourceID: "a", + TargetID: "b", + Occurrences: 1, + EventPattern: "alert -> restart", + Confidence: 0.5, + } + + if len(d.GetDependencies("a")) != 1 { + t.Fatalf("expected dependency") + } + if len(d.GetDependsOn("b")) != 1 { + t.Fatalf("expected depends-on") + } + + predictions := d.PredictCascade("a", EventAlert) + if len(predictions) != 1 { + t.Fatalf("expected cascade prediction") + } +} + +func TestDetector_FormatForContextAndTrim(t *testing.T) { + d := NewDetector(Config{MinOccurrences: 1}) + if d.FormatForContext("") != "" { + t.Fatalf("expected empty context without correlations") + } + + d.correlations["k1"] = &Correlation{ + SourceID: "a", + TargetID: "b", + Occurrences: 2, + Confidence: 0.5, + EventPattern: "alert -> restart", + Description: "desc", + } + out := d.FormatForContext("") + if !strings.Contains(out, "desc") { + t.Fatalf("expected description in context") + } + + d.maxEvents = 2 + d.retentionWindow = time.Minute + d.events = []Event{ + {Timestamp: time.Now().Add(-2 * time.Minute)}, + {Timestamp: time.Now()}, + {Timestamp: time.Now()}, + } + d.trimEvents() + if len(d.events) != 2 { + t.Fatalf("expected trimmed events") + } +} + +func TestDetector_SaveLoadAndLargeFile(t *testing.T) { + dir := t.TempDir() + d := NewDetector(Config{ + DataDir: dir, + MinOccurrences: 1, + RetentionWindow: time.Hour, + }) + d.events = []Event{{ID: "e1", ResourceID: "r1", EventType: EventAlert, Timestamp: time.Now()}} + d.correlations["k1"] = &Correlation{ + SourceID: "r1", + TargetID: "r2", + Occurrences: 1, + LastSeen: time.Now(), + } + + if err := d.saveToDisk(); err != nil { + t.Fatalf("unexpected save error: %v", err) + } + + loaded := NewDetector(Config{DataDir: dir, MinOccurrences: 1}) + if len(loaded.events) == 0 { + t.Fatalf("expected events to load") + } + + path := filepath.Join(dir, "ai_correlations.json") + large := make([]byte, (10<<20)+1) + if err := os.WriteFile(path, large, 0600); err != nil { + t.Fatalf("write failed: %v", err) + } + if err := loaded.loadFromDisk(); err == nil { + t.Fatalf("expected error for large file") + } +} + +func TestDetector_LoadMissingAndInvalid(t *testing.T) { + dir := t.TempDir() + d := 
NewDetector(Config{DataDir: dir}) + if err := d.loadFromDisk(); err != nil { + t.Fatalf("unexpected error for missing file: %v", err) + } + + path := filepath.Join(dir, "ai_correlations.json") + if err := os.WriteFile(path, []byte("{bad"), 0600); err != nil { + t.Fatalf("write failed: %v", err) + } + if err := d.loadFromDisk(); err == nil { + t.Fatalf("expected error for invalid json") + } +} + +func TestDetector_HelperFunctions(t *testing.T) { + if generateEventID() == "" { + t.Fatalf("expected event id") + } + if intToStr(12) != "12" || intToStr(0) != "0" { + t.Fatalf("expected intToStr") + } + if correlationKey("a", "b", EventAlert, EventRestart) == "" { + t.Fatalf("expected correlation key") + } +} diff --git a/internal/ai/correlation/rootcause_additional_test.go b/internal/ai/correlation/rootcause_additional_test.go new file mode 100644 index 000000000..e8a4a1cec --- /dev/null +++ b/internal/ai/correlation/rootcause_additional_test.go @@ -0,0 +1,166 @@ +package correlation + +import ( + "strings" + "testing" + "time" +) + +type stubTopologyProvider struct { + relationships []ResourceRelationship + types map[string]string + names map[string]string +} + +func (s *stubTopologyProvider) GetRelationships(resourceID string) []ResourceRelationship { + return s.relationships +} + +func (s *stubTopologyProvider) GetResourceType(resourceID string) string { + return s.types[resourceID] +} + +func (s *stubTopologyProvider) GetResourceName(resourceID string) string { + return s.names[resourceID] +} + +type stubEventProvider struct { + events map[string][]RelatedEvent +} + +func (s *stubEventProvider) GetRecentEvents(resourceID string, window time.Duration) []RelatedEvent { + return s.events[resourceID] +} + +func TestRootCauseEngine_DefaultsAndNilProviders(t *testing.T) { + engine := NewRootCauseEngine(RootCauseEngineConfig{}) + if engine.config.CorrelationWindow == 0 || engine.config.MaxChainLength == 0 { + t.Fatalf("expected defaults to be applied") + } + if engine.Analyze(RelatedEvent{}) != nil { + t.Fatalf("expected nil analysis without providers") + } +} + +func TestRootCauseEngine_AnalyzeAndFormat(t *testing.T) { + triggerTime := time.Now() + topology := &stubTopologyProvider{ + relationships: []ResourceRelationship{ + {SourceID: "node-1", TargetID: "vm-1", Relationship: RelationshipRunsOn}, + }, + types: map[string]string{ + "node-1": "node", + "vm-1": "vm", + }, + names: map[string]string{ + "node-1": "Node 1", + "vm-1": "VM 1", + }, + } + events := &stubEventProvider{ + events: map[string][]RelatedEvent{ + "node-1": { + { + ResourceID: "node-1", + ResourceType: "node", + EventType: "alert", + Metric: "cpu", + Value: 95, + Timestamp: triggerTime.Add(-30 * time.Second), + Description: "CPU spike", + }, + }, + }, + } + + engine := NewRootCauseEngine(DefaultRootCauseEngineConfig()) + engine.config.MinConfidence = 0.1 + engine.SetTopologyProvider(topology) + engine.SetEventProvider(events) + + trigger := RelatedEvent{ + ResourceID: "vm-1", + ResourceName: "VM 1", + ResourceType: "vm", + EventType: "alert", + Metric: "cpu", + Timestamp: triggerTime, + Description: "VM alert", + } + + analysis := engine.Analyze(trigger) + if analysis == nil || analysis.RootCause == nil { + t.Fatalf("expected root cause analysis") + } + if analysis.Explanation == "" { + t.Fatalf("expected explanation") + } + if analysis.Confidence <= 0 { + t.Fatalf("expected confidence") + } + + context := engine.FormatForContext("vm-1") + if !strings.Contains(context, "confidence") { + t.Fatalf("expected context output") + } + + 
patrol := engine.FormatAnalysisForPatrol() + if patrol == "" { + t.Fatalf("expected patrol output") + } +} + +func TestRootCauseEngine_ScoringAndHelpers(t *testing.T) { + engine := NewRootCauseEngine(DefaultRootCauseEngineConfig()) + + trigger := RelatedEvent{ResourceID: "vm-1", ResourceType: "vm", Metric: "cpu", Timestamp: time.Now()} + candidate := RelatedEvent{ResourceID: "node-1", ResourceType: "node", Metric: "cpu", Timestamp: time.Now().Add(-30 * time.Second)} + relationships := []ResourceRelationship{{SourceID: "node-1", TargetID: "vm-1", Relationship: RelationshipRunsOn}} + + score := engine.scoreAsRootCause(&candidate, trigger, relationships) + if score <= 0 { + t.Fatalf("expected score > 0") + } + + root := engine.identifyRootCause(trigger, []RelatedEvent{candidate}, relationships) + if root == nil { + t.Fatalf("expected root cause") + } + + chain := engine.buildCausalChain(root, trigger, []RelatedEvent{candidate}, relationships) + if len(chain) < 2 { + t.Fatalf("expected causal chain") + } + + confidence := engine.calculateConfidence(&RootCauseAnalysis{RootCause: root, RelatedEvents: []RelatedEvent{candidate}, CausalChain: chain, TriggerEvent: trigger}) + if confidence <= 0 { + t.Fatalf("expected confidence") + } + + if formatEventForChain(&RelatedEvent{ResourceID: "r1", Description: "desc"}) == "" { + t.Fatalf("expected formatted chain event") + } + if !isRelatedMetric("cpu", "load") || isRelatedMetric("cpu", "disk") { + t.Fatalf("unexpected metric relation result") + } + if minFloat(1, 2) != 1 { + t.Fatalf("expected min float") + } +} + +func TestRootCauseEngine_AnalysisQueries(t *testing.T) { + engine := NewRootCauseEngine(DefaultRootCauseEngineConfig()) + engine.recentAnalyses = []RootCauseAnalysis{ + { + ID: "a1", + TriggerEvent: RelatedEvent{ResourceID: "vm-1"}, + }, + } + + if len(engine.GetRecentAnalyses(1)) != 1 { + t.Fatalf("expected recent analysis") + } + if len(engine.GetAnalysisForResource("vm-1")) != 1 { + t.Fatalf("expected analysis for resource") + } +} diff --git a/internal/ai/discovery_adapter.go b/internal/ai/discovery_adapter.go new file mode 100644 index 000000000..cd0ae0f44 --- /dev/null +++ b/internal/ai/discovery_adapter.go @@ -0,0 +1,184 @@ +package ai + +import ( + "context" + + "github.com/rcourtman/pulse-go-rewrite/internal/agentexec" + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +// discoveryCommandAdapter adapts agentexec.Server to aidiscovery.CommandExecutor +type discoveryCommandAdapter struct { + server *agentexec.Server +} + +// newDiscoveryCommandAdapter creates a new adapter +func newDiscoveryCommandAdapter(server *agentexec.Server) *discoveryCommandAdapter { + return &discoveryCommandAdapter{server: server} +} + +// ExecuteCommand implements aidiscovery.CommandExecutor +func (a *discoveryCommandAdapter) ExecuteCommand(ctx context.Context, agentID string, cmd aidiscovery.ExecuteCommandPayload) (*aidiscovery.CommandResultPayload, error) { + if a.server == nil { + return &aidiscovery.CommandResultPayload{ + RequestID: cmd.RequestID, + Success: false, + Error: "agent server not available", + }, nil + } + + // Convert to agentexec types + execCmd := agentexec.ExecuteCommandPayload{ + RequestID: cmd.RequestID, + Command: cmd.Command, + TargetType: cmd.TargetType, + TargetID: cmd.TargetID, + Timeout: cmd.Timeout, + } + + result, err := a.server.ExecuteCommand(ctx, agentID, execCmd) + if err != nil { + return &aidiscovery.CommandResultPayload{ + RequestID: cmd.RequestID, 
+ Success: false, + Error: err.Error(), + }, nil + } + + // Convert result back + return &aidiscovery.CommandResultPayload{ + RequestID: result.RequestID, + Success: result.Success, + Stdout: result.Stdout, + Stderr: result.Stderr, + ExitCode: result.ExitCode, + Error: result.Error, + Duration: result.Duration, + }, nil +} + +// GetConnectedAgents implements aidiscovery.CommandExecutor +func (a *discoveryCommandAdapter) GetConnectedAgents() []aidiscovery.ConnectedAgent { + if a.server == nil { + return nil + } + + agents := a.server.GetConnectedAgents() + result := make([]aidiscovery.ConnectedAgent, len(agents)) + for i, agent := range agents { + result[i] = aidiscovery.ConnectedAgent{ + AgentID: agent.AgentID, + Hostname: agent.Hostname, + Version: agent.Version, + Platform: agent.Platform, + Tags: agent.Tags, + ConnectedAt: agent.ConnectedAt, + } + } + return result +} + +// IsAgentConnected implements aidiscovery.CommandExecutor +func (a *discoveryCommandAdapter) IsAgentConnected(agentID string) bool { + if a.server == nil { + return false + } + for _, agent := range a.server.GetConnectedAgents() { + if agent.AgentID == agentID { + return true + } + } + return false +} + +// discoveryStateAdapter adapts StateProvider to aidiscovery.StateProvider +type discoveryStateAdapter struct { + provider StateProvider +} + +// newDiscoveryStateAdapter creates a new state adapter +func newDiscoveryStateAdapter(provider StateProvider) *discoveryStateAdapter { + return &discoveryStateAdapter{provider: provider} +} + +// GetState implements aidiscovery.StateProvider +func (a *discoveryStateAdapter) GetState() aidiscovery.StateSnapshot { + if a.provider == nil { + return aidiscovery.StateSnapshot{} + } + + state := a.provider.GetState() + + // Convert VMs + vms := make([]aidiscovery.VM, len(state.VMs)) + for i, vm := range state.VMs { + vms[i] = aidiscovery.VM{ + VMID: vm.VMID, + Name: vm.Name, + Node: vm.Node, + Status: vm.Status, + Instance: vm.Instance, + } + } + + // Convert Containers + containers := make([]aidiscovery.Container, len(state.Containers)) + for i, c := range state.Containers { + containers[i] = aidiscovery.Container{ + VMID: c.VMID, + Name: c.Name, + Node: c.Node, + Status: c.Status, + Instance: c.Instance, + } + } + + // Convert Docker hosts + dockerHosts := make([]aidiscovery.DockerHost, len(state.DockerHosts)) + for i, dh := range state.DockerHosts { + containers := make([]aidiscovery.DockerContainer, len(dh.Containers)) + for j, dc := range dh.Containers { + ports := make([]aidiscovery.DockerPort, len(dc.Ports)) + for k, p := range dc.Ports { + ports[k] = aidiscovery.DockerPort{ + PublicPort: p.PublicPort, + PrivatePort: p.PrivatePort, + Protocol: p.Protocol, + } + } + mounts := make([]aidiscovery.DockerMount, len(dc.Mounts)) + for k, m := range dc.Mounts { + mounts[k] = aidiscovery.DockerMount{ + Source: m.Source, + Destination: m.Destination, + } + } + containers[j] = aidiscovery.DockerContainer{ + ID: dc.ID, + Name: dc.Name, + Image: dc.Image, + Status: dc.Status, + Ports: ports, + Labels: dc.Labels, + Mounts: mounts, + } + } + dockerHosts[i] = aidiscovery.DockerHost{ + AgentID: dh.AgentID, + Hostname: dh.Hostname, + Containers: containers, + } + } + + return aidiscovery.StateSnapshot{ + VMs: vms, + Containers: containers, + DockerHosts: dockerHosts, + } +} + +// StateProvider interface expected by the adapter (mirrors models.StateSnapshot fields) +type discoveryStateProviderInterface interface { + GetState() models.StateSnapshot +} diff --git a/internal/ai/findings.go 
b/internal/ai/findings.go index 755f1df11..4d3305167 100644 --- a/internal/ai/findings.go +++ b/internal/ai/findings.go @@ -822,12 +822,11 @@ func (s *FindingsStore) GetAll(startTime *time.Time) []*Finding { // Returns the number of findings removed func (s *FindingsStore) ClearAll() int { s.mu.Lock() - defer s.mu.Unlock() - count := len(s.findings) s.findings = make(map[string]*Finding) s.byResource = make(map[string][]string) s.activeCounts = make(map[FindingSeverity]int) + s.mu.Unlock() s.scheduleSave() return count } @@ -975,9 +974,8 @@ func (s FindingsSummary) IsHealthy() bool { // AddSuppressionRule creates a new user-defined suppression rule func (s *FindingsStore) AddSuppressionRule(resourceID, resourceName string, category FindingCategory, description string) *SuppressionRule { s.mu.Lock() - defer s.mu.Unlock() - rule := s.addSuppressionRuleInternal(resourceID, resourceName, category, description, "manual") + s.mu.Unlock() s.scheduleSave() return rule } @@ -1053,11 +1051,11 @@ func (s *FindingsStore) GetSuppressionRules() []*SuppressionRule { // DeleteSuppressionRule removes a suppression rule func (s *FindingsStore) DeleteSuppressionRule(ruleID string) bool { s.mu.Lock() - defer s.mu.Unlock() // Check if it's an explicit rule if _, exists := s.suppressionRules[ruleID]; exists { delete(s.suppressionRules, ruleID) + s.mu.Unlock() s.scheduleSave() return true } @@ -1075,11 +1073,13 @@ func (s *FindingsStore) DeleteSuppressionRule(ruleID string) bool { if !wasActive && f.IsActive() { s.activeCounts[f.Severity]++ } + s.mu.Unlock() s.scheduleSave() return true } } + s.mu.Unlock() return false } diff --git a/internal/ai/findings_additional_test.go b/internal/ai/findings_additional_test.go new file mode 100644 index 000000000..94e6839f0 --- /dev/null +++ b/internal/ai/findings_additional_test.go @@ -0,0 +1,142 @@ +package ai + +import ( + "testing" + "time" +) + +func TestFinding_ShouldInvestigate(t *testing.T) { + now := time.Now() + old := now.Add(-2 * time.Hour) + + base := &Finding{ + ID: "f1", + Severity: FindingSeverityWarning, + Category: FindingCategoryPerformance, + ResourceID: "r1", + } + + if !base.ShouldInvestigate("approval") { + t.Fatalf("expected base finding to be investigated") + } + + if base.ShouldInvestigate("") { + t.Fatalf("expected autonomy disabled to skip investigation") + } + + base.ResolvedAt = &now + if base.ShouldInvestigate("approval") { + t.Fatalf("expected resolved finding to skip investigation") + } + base.ResolvedAt = nil + + base.Suppressed = true + if base.ShouldInvestigate("approval") { + t.Fatalf("expected suppressed finding to skip investigation") + } + base.Suppressed = false + + base.DismissedReason = "not_an_issue" + if base.ShouldInvestigate("approval") { + t.Fatalf("expected dismissed finding to skip investigation") + } + base.DismissedReason = "" + + future := now.Add(time.Minute) + base.SnoozedUntil = &future + if base.ShouldInvestigate("approval") { + t.Fatalf("expected snoozed finding to skip investigation") + } + base.SnoozedUntil = nil + + base.Severity = FindingSeverityInfo + if base.ShouldInvestigate("approval") { + t.Fatalf("expected info severity to skip investigation") + } + base.Severity = FindingSeverityWarning + + base.InvestigationStatus = string(InvestigationStatusRunning) + if base.ShouldInvestigate("approval") { + t.Fatalf("expected running investigation to skip") + } + base.InvestigationStatus = "" + + base.InvestigationAttempts = 3 + if base.ShouldInvestigate("approval") { + t.Fatalf("expected max attempts to skip") + } + 
base.InvestigationAttempts = 0 + + base.LastInvestigatedAt = &now + if base.ShouldInvestigate("approval") { + t.Fatalf("expected cooldown to skip") + } + base.LastInvestigatedAt = &old + if !base.ShouldInvestigate("approval") { + t.Fatalf("expected investigation after cooldown") + } +} + +func TestFindingInvestigationHelpers(t *testing.T) { + f := &Finding{} + f.InvestigationStatus = string(InvestigationStatusRunning) + if !f.IsBeingInvestigated() { + t.Fatalf("expected IsBeingInvestigated true") + } + + f.InvestigationAttempts = 2 + old := time.Now().Add(-2 * time.Hour) + f.LastInvestigatedAt = &old + f.InvestigationStatus = "" + if !f.CanRetryInvestigation() { + t.Fatalf("expected retry to be allowed") + } + + f.InvestigationAttempts = 3 + if f.CanRetryInvestigation() { + t.Fatalf("expected retry blocked by max attempts") + } +} + +func TestFinding_Getters(t *testing.T) { + ts := time.Now() + f := &Finding{ + ID: "f1", + Severity: FindingSeverityCritical, + Category: FindingCategoryBackup, + ResourceID: "r1", + ResourceName: "db-1", + ResourceType: "vm", + Title: "Backup missing", + Description: "no backups", + Recommendation: "configure backups", + Evidence: "pbs: none", + InvestigationSessionID: "sess-1", + InvestigationStatus: string(InvestigationStatusFailed), + InvestigationOutcome: string(InvestigationOutcomeCannotFix), + LastInvestigatedAt: &ts, + InvestigationAttempts: 2, + } + + if f.GetID() != "f1" || f.GetSeverity() != string(FindingSeverityCritical) || f.GetCategory() != string(FindingCategoryBackup) { + t.Fatalf("unexpected basic getters") + } + if f.GetResourceID() != "r1" || f.GetResourceName() != "db-1" || f.GetResourceType() != "vm" { + t.Fatalf("unexpected resource getters") + } + if f.GetTitle() != "Backup missing" || f.GetDescription() != "no backups" { + t.Fatalf("unexpected title/description getters") + } + if f.GetRecommendation() != "configure backups" || f.GetEvidence() != "pbs: none" { + t.Fatalf("unexpected recommendation/evidence getters") + } + if f.GetInvestigationSessionID() != "sess-1" || f.GetInvestigationStatus() != string(InvestigationStatusFailed) { + t.Fatalf("unexpected investigation getters") + } + if f.GetInvestigationOutcome() != string(InvestigationOutcomeCannotFix) || f.GetInvestigationAttempts() != 2 { + t.Fatalf("unexpected investigation outcome/attempts") + } + if f.GetLastInvestigatedAt() == nil { + t.Fatalf("expected last investigated timestamp") + } +} diff --git a/internal/ai/findings_test.go b/internal/ai/findings_test.go index 1a02bbaf7..5d67fd738 100644 --- a/internal/ai/findings_test.go +++ b/internal/ai/findings_test.go @@ -367,6 +367,54 @@ func TestFindingsStore_Cleanup(t *testing.T) { } } +func TestFindingsStore_SuppressedPersistsInContextAndCleanup(t *testing.T) { + store := NewFindingsStore() + old := time.Now().Add(-60 * 24 * time.Hour) + + store.findings["suppressed"] = &Finding{ + ID: "suppressed", + Title: "Suppressed Finding", + ResourceName: "host-1", + Suppressed: true, + DismissedReason: "not_an_issue", + LastSeenAt: old, + } + + removed := store.Cleanup(24 * time.Hour) + if removed != 0 { + t.Errorf("Expected 0 findings removed, got %d", removed) + } + if store.Get("suppressed") == nil { + t.Error("suppressed finding should NOT have been removed") + } + + ctx := store.GetDismissedForContext() + if !strings.Contains(ctx, "Suppressed Finding") { + t.Error("expected suppressed finding to remain in context") + } +} + +func TestFindingsStore_Cleanup_RemovesOldDismissed(t *testing.T) { + store := NewFindingsStore() + old := 
time.Now().Add(-31 * 24 * time.Hour) + + store.findings["dismissed"] = &Finding{ + ID: "dismissed", + Title: "Dismissed Finding", + ResourceName: "host-2", + DismissedReason: "expected_behavior", + LastSeenAt: old, + } + + removed := store.Cleanup(24 * time.Hour) + if removed != 1 { + t.Errorf("Expected 1 finding removed, got %d", removed) + } + if store.Get("dismissed") != nil { + t.Error("dismissed finding should have been removed") + } +} + func TestFindingsStore_GetDismissedForContext(t *testing.T) { store := NewFindingsStore() diff --git a/internal/ai/forecast/service_additional_test.go b/internal/ai/forecast/service_additional_test.go new file mode 100644 index 000000000..b5fc935fc --- /dev/null +++ b/internal/ai/forecast/service_additional_test.go @@ -0,0 +1,504 @@ +package forecast + +import ( + "testing" + "time" +) + +type mockStateProvider struct { + state StateSnapshot +} + +func (m mockStateProvider) GetState() StateSnapshot { + return m.state +} + +func buildLinearData(end time.Time, points int, step time.Duration, startValue, stepValue float64) []MetricDataPoint { + data := make([]MetricDataPoint, points) + start := end.Add(-time.Duration(points-1) * step) + for i := 0; i < points; i++ { + data[i] = MetricDataPoint{ + Timestamp: start.Add(time.Duration(i) * step), + Value: startValue + float64(i)*stepValue, + } + } + return data +} + +func TestNewService_DefaultsApplied(t *testing.T) { + svc := NewService(ForecastConfig{}) + + if svc.config.ShortTermWindow <= 0 { + t.Error("expected ShortTermWindow default to be set") + } + if svc.config.MediumTermWindow <= 0 { + t.Error("expected MediumTermWindow default to be set") + } + if svc.config.LongTermWindow <= 0 { + t.Error("expected LongTermWindow default to be set") + } + if svc.config.DefaultHorizon <= 0 { + t.Error("expected DefaultHorizon default to be set") + } + if svc.config.MaxHorizon <= 0 { + t.Error("expected MaxHorizon default to be set") + } + if svc.config.StableThreshold <= 0 { + t.Error("expected StableThreshold default to be set") + } + if svc.config.VolatileThreshold <= 0 { + t.Error("expected VolatileThreshold default to be set") + } +} + +func TestIsPercentageMetric(t *testing.T) { + cases := map[string]bool{ + "cpu": true, + "CPU": true, + "memory": true, + "mem": true, + "disk": true, + "iops": false, + } + + for metric, expected := range cases { + if got := isPercentageMetric(metric); got != expected { + t.Errorf("metric %q expected %v, got %v", metric, expected, got) + } + } +} + +func TestCalculateTrend_Volatile(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + now := time.Now() + data := []MetricDataPoint{ + {Timestamp: now.Add(-5 * time.Hour), Value: 10}, + {Timestamp: now.Add(-4 * time.Hour), Value: 120}, + {Timestamp: now.Add(-3 * time.Hour), Value: 15}, + {Timestamp: now.Add(-2 * time.Hour), Value: 130}, + {Timestamp: now.Add(-1 * time.Hour), Value: 20}, + {Timestamp: now, Value: 140}, + } + + trend := svc.calculateTrend(data, DefaultForecastConfig()) + if trend.Direction != TrendVolatile { + t.Fatalf("expected volatile trend, got %s", trend.Direction) + } +} + +func TestDetectSeasonality_InsufficientData(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + now := time.Now() + data := buildLinearData(now, 24, time.Hour, 10, 0) + + if seasonality := svc.detectSeasonality(data); seasonality != nil { + t.Fatalf("expected nil seasonality for insufficient data") + } +} + +func TestDetectSeasonality_DailyPeaks(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + base := 
time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + var data []MetricDataPoint + for day := 0; day < 3; day++ { + for hour := 0; hour < 24; hour++ { + value := 1.0 + if hour == 14 { + value = 100.0 + } + data = append(data, MetricDataPoint{ + Timestamp: base.Add(time.Duration(day*24+hour) * time.Hour), + Value: value, + }) + } + } + + seasonality := svc.detectSeasonality(data) + if seasonality == nil || !seasonality.HasDaily { + t.Fatalf("expected daily seasonality to be detected") + } + if seasonality.HasWeekly { + t.Fatalf("expected weekly seasonality to be false") + } + foundPeak := false + for _, hour := range seasonality.PeakHours { + if hour == 14 { + foundPeak = true + break + } + } + if !foundPeak { + t.Fatalf("expected peak hour 14 to be detected") + } +} + +func TestGenerateDescription_WithThresholdAndAcceleration(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + threshold := 90.0 + timeToThreshold := 4 * time.Hour + trend := Trend{ + Direction: TrendIncreasing, + RatePerDay: 12.5, + Acceleration: 1.0, + } + + desc := svc.generateDescription("cpu", 70, 80, trend, &timeToThreshold, threshold) + if !containsStr(desc, "increasing") { + t.Fatalf("expected increasing text in description") + } + if !containsStr(desc, "Will reach 90% in 4 hours") { + t.Fatalf("expected time-to-threshold text in description: %s", desc) + } + if !containsStr(desc, "accelerating") { + t.Fatalf("expected accelerating text in description") + } +} + +func TestGenerateDescription_WeeksAndDecelerating(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + timeToThreshold := 14 * 24 * time.Hour + trend := Trend{ + Direction: TrendDecreasing, + RatePerDay: -2.0, + Acceleration: -1.2, + } + + desc := svc.generateDescription("disk", 80, 60, trend, &timeToThreshold, 50) + if !containsStr(desc, "decreasing") { + t.Fatalf("expected decreasing text in description") + } + if !containsStr(desc, "weeks") { + t.Fatalf("expected weeks time-to-threshold text") + } + if !containsStr(desc, "decelerating") { + t.Fatalf("expected decelerating text in description") + } +} + +func TestCalculateConfidence_ClampHigh(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + now := time.Now() + data := buildLinearData(now, 1000, time.Minute, 10, 0) + trend := Trend{Direction: TrendStable} + + confidence := svc.calculateConfidence(data, trend) + if confidence != 0.95 { + t.Fatalf("expected confidence to be clamped at 0.95, got %.2f", confidence) + } +} + +func TestCalculateConfidence_VolatileAcceleration(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + now := time.Now() + data := buildLinearData(now, 10, time.Minute, 10, 5) + trend := Trend{Direction: TrendVolatile, Acceleration: 2.0} + + confidence := svc.calculateConfidence(data, trend) + if confidence >= 0.5 { + t.Fatalf("expected lower confidence for volatile acceleration, got %.2f", confidence) + } +} + +func TestFormatForContext_LowConfidenceNote(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + forecasts := []*Forecast{ + { + ResourceID: "vm-201", + Metric: "cpu", + Trend: Trend{Direction: TrendIncreasing}, + Description: "CPU is increasing", + Confidence: 0.2, + }, + } + + context := svc.FormatForContext(forecasts) + if !containsStr(context, "low confidence") { + t.Fatalf("expected low confidence note in context") + } +} + +func TestFormatKeyForecasts_NoProviders(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + + if result := svc.FormatKeyForecasts(); result != "" { + t.Fatalf("expected empty result when providers are 
missing") + } +} + +func TestFormatKeyForecasts_NoStateProvider(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + svc.SetDataProvider(&mockDataProvider{data: map[string][]MetricDataPoint{}}) + if result := svc.FormatKeyForecasts(); result != "" { + t.Fatalf("expected empty result when state provider is missing") + } +} + +func TestFormatKeyForecasts_Concerns(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + data := buildLinearData(now, 24, time.Hour, 70, 1.0) + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-1:cpu": data, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{{ID: "vm-1", Name: ""}}, + }, + }) + + result := svc.FormatKeyForecasts() + if result == "" { + t.Fatalf("expected non-empty result for concerning trends") + } + if !containsStr(result, "vm-1") { + t.Fatalf("expected vm-1 to be mentioned in concerns") + } + if !containsStr(result, "increasing") { + t.Fatalf("expected increasing trend note in concerns") + } + if !containsStr(result, "critical") { + t.Fatalf("expected critical note in concerns") + } +} + +func TestFormatKeyForecasts_AllResourceTypes(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + data := buildLinearData(now, 24, time.Hour, 70, 1.0) + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-1:cpu": data, + "ct-1:cpu": data, + "node-1:cpu": data, + "storage-1:disk": data, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{{ID: "vm-1", Name: "vm"}}, + Containers: []ContainerInfo{{ID: "ct-1", Name: "ct"}}, + Nodes: []NodeInfo{{ID: "node-1", Name: "node"}}, + Storage: []StorageInfo{{ID: "storage-1", Name: ""}}, + }, + }) + + result := svc.FormatKeyForecasts() + if result == "" { + t.Fatalf("expected formatted concerns") + } + if !containsStr(result, "storage-1") { + t.Fatalf("expected storage to be included") + } +} + +func TestForecastAll_ActionableSorted(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + rapid := buildLinearData(now, 80, time.Hour, 10, 1.0) // current near threshold + slow := buildLinearData(now, 80, time.Hour, 20, 0.5) // slower breach + flat := buildLinearData(now, 80, time.Hour, 60, 0.0) // non-increasing + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-fast:disk": rapid, + "vm-slow:disk": slow, + "vm-flat:disk": flat, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{ + {ID: "vm-fast", Name: "fast"}, + {ID: "vm-slow", Name: "slow"}, + {ID: "vm-flat", Name: "flat"}, + }, + }, + }) + + resp, err := svc.ForecastAll("disk", 24*time.Hour, 90) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Forecasts) != 2 { + t.Fatalf("expected 2 actionable forecasts, got %d", len(resp.Forecasts)) + } + if resp.Forecasts[0].ResourceID != "vm-fast" { + t.Fatalf("expected vm-fast to be most urgent, got %s", resp.Forecasts[0].ResourceID) + } + if resp.Forecasts[1].ResourceID != "vm-slow" { + t.Fatalf("expected vm-slow to be second, got %s", resp.Forecasts[1].ResourceID) + } +} + +func TestForecastAll_MissingProviders(t *testing.T) { + svc := NewService(DefaultForecastConfig()) + if _, err := svc.ForecastAll("disk", time.Hour, 80); err == nil { + t.Fatalf("expected 
error when data provider is missing") + } + + svc.SetDataProvider(&mockDataProvider{data: map[string][]MetricDataPoint{}}) + if _, err := svc.ForecastAll("disk", time.Hour, 80); err == nil { + t.Fatalf("expected error when state provider is missing") + } +} + +func TestForecastAll_FiltersNonActionable(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + aboveThreshold := buildLinearData(now, 60, time.Hour, 80, 0.3) + lowConfidence := buildLinearData(now, 5, time.Hour, 10, 1.0) + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-above:disk": aboveThreshold, + "vm-low:disk": lowConfidence, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{ + {ID: "vm-above", Name: "above"}, + {ID: "vm-low", Name: "low"}, + }, + }, + }) + + resp, err := svc.ForecastAll("disk", 24*time.Hour, 90) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Forecasts) != 0 { + t.Fatalf("expected no actionable forecasts, got %d", len(resp.Forecasts)) + } +} + +func TestForecastAll_MultipleResourceTypes(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + vmData := buildLinearData(now, 60, time.Hour, 20, 0.8) + ctData := buildLinearData(now, 60, time.Hour, 30, 0.6) + nodeData := buildLinearData(now, 60, time.Hour, 40, 0.4) + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-1:disk": vmData, + "ct-1:disk": ctData, + "node-1:disk": nodeData, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{{ID: "vm-1", Name: "vm"}}, + Containers: []ContainerInfo{{ID: "ct-1", Name: "ct"}}, + Nodes: []NodeInfo{{ID: "node-1", Name: "node"}}, + }, + }) + + resp, err := svc.ForecastAll("disk", 24*time.Hour, 90) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Forecasts) != 3 { + t.Fatalf("expected forecasts for vm, container, node, got %d", len(resp.Forecasts)) + } +} + +func TestForecastAll_SkipsErroredResources(t *testing.T) { + cfg := DefaultForecastConfig() + cfg.VolatileThreshold = 100.0 + svc := NewService(cfg) + + now := time.Now() + vmData := buildLinearData(now, 60, time.Hour, 20, 0.8) + + svc.SetDataProvider(&mockDataProvider{ + data: map[string][]MetricDataPoint{ + "vm-1:disk": vmData, + }, + }) + svc.SetStateProvider(mockStateProvider{ + state: StateSnapshot{ + VMs: []VMInfo{{ID: "vm-1", Name: "vm"}}, + Containers: []ContainerInfo{{ID: "ct-1", Name: "ct"}}, + }, + }) + + resp, err := svc.ForecastAll("disk", 24*time.Hour, 90) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Forecasts) != 1 { + t.Fatalf("expected only VM forecast, got %d", len(resp.Forecasts)) + } +} + +func TestForecastToOverviewItem(t *testing.T) { + ttl := 2 * time.Hour + item := forecastToOverviewItem(&Forecast{ + ResourceID: "vm-9", + ResourceName: "db", + Metric: "disk", + CurrentValue: 70, + PredictedValue: 85, + TimeToThreshold: &ttl, + Confidence: 0.6, + Trend: Trend{Direction: TrendIncreasing}, + }, "vm") + + if item.TimeToThreshold == nil || *item.TimeToThreshold == 0 { + t.Fatalf("expected time to threshold to be converted to seconds") + } + if item.ResourceType != "vm" { + t.Fatalf("expected resource type vm, got %s", item.ResourceType) + } +} + +func TestLinearRegression_Degenerate(t *testing.T) { + now := time.Now() + data := []MetricDataPoint{ + {Timestamp: now, Value: 10}, 
+ {Timestamp: now, Value: 20}, + {Timestamp: now, Value: 30}, + } + + slope, intercept := linearRegression(data) + if slope != 0 { + t.Fatalf("expected zero slope for degenerate timestamps, got %.2f", slope) + } + if intercept != 20 { + t.Fatalf("expected intercept to be mean (20), got %.2f", intercept) + } +} + +func TestFilterByWindow_InclusiveBounds(t *testing.T) { + now := time.Now() + data := []MetricDataPoint{ + {Timestamp: now.Add(-2 * time.Hour), Value: 1}, + {Timestamp: now.Add(-1 * time.Hour), Value: 2}, + {Timestamp: now, Value: 3}, + } + + filtered := filterByWindow(data, now.Add(-2*time.Hour), now) + if len(filtered) != 3 { + t.Fatalf("expected inclusive bounds to include all points, got %d", len(filtered)) + } +} diff --git a/internal/ai/investigation/adapters_test.go b/internal/ai/investigation/adapters_test.go new file mode 100644 index 000000000..167205612 --- /dev/null +++ b/internal/ai/investigation/adapters_test.go @@ -0,0 +1,255 @@ +package investigation + +import ( + "context" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/approval" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/chat" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/providers" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/tools" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +type mockAIFinding struct { + id string + severity string + category string + resource string + name string + resType string + title string + desc string + reco string + evidence string + sessionID string + status string + outcome string + lastAt *time.Time + attempts int +} + +func (m *mockAIFinding) GetID() string { return m.id } +func (m *mockAIFinding) GetSeverity() string { return m.severity } +func (m *mockAIFinding) GetCategory() string { return m.category } +func (m *mockAIFinding) GetResourceID() string { return m.resource } +func (m *mockAIFinding) GetResourceName() string { return m.name } +func (m *mockAIFinding) GetResourceType() string { return m.resType } +func (m *mockAIFinding) GetTitle() string { return m.title } +func (m *mockAIFinding) GetDescription() string { return m.desc } +func (m *mockAIFinding) GetRecommendation() string { return m.reco } +func (m *mockAIFinding) GetEvidence() string { return m.evidence } +func (m *mockAIFinding) GetInvestigationSessionID() string { return m.sessionID } +func (m *mockAIFinding) GetInvestigationStatus() string { return m.status } +func (m *mockAIFinding) GetInvestigationOutcome() string { return m.outcome } +func (m *mockAIFinding) GetLastInvestigatedAt() *time.Time { return m.lastAt } +func (m *mockAIFinding) GetInvestigationAttempts() int { return m.attempts } +func (m *mockAIFinding) SetInvestigationSessionID(string) {} +func (m *mockAIFinding) SetInvestigationStatus(string) {} +func (m *mockAIFinding) SetInvestigationOutcome(string) {} +func (m *mockAIFinding) SetLastInvestigatedAt(*time.Time) {} +func (m *mockAIFinding) SetInvestigationAttempts(int) {} + +type mockAIFindingsStore struct { + finding *mockAIFinding + updated bool +} + +func (m *mockAIFindingsStore) Get(id string) AIFinding { + if m.finding != nil && m.finding.id == id { + return m.finding + } + return nil +} + +func (m *mockAIFindingsStore) UpdateInvestigation(id, sessionID, status, outcome string, lastInvestigatedAt *time.Time, attempts int) bool { + m.updated = true + return true +} + +func TestApprovalAdapter(t *testing.T) { + adapter := NewApprovalAdapter(nil) + if err := adapter.Create(&Approval{ID: "a1", 
RiskLevel: "low"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + store, err := approval.NewStore(approval.StoreConfig{ + DataDir: t.TempDir(), + DisablePersistence: true, + }) + if err != nil { + t.Fatalf("unexpected store error: %v", err) + } + + adapter = NewApprovalAdapter(store) + if err := adapter.Create(&Approval{ + ID: "a1", + RiskLevel: "critical", + Command: "echo ok", + FindingID: "f1", + Description: "desc", + }); err != nil { + t.Fatalf("unexpected create error: %v", err) + } + + req, ok := store.GetApproval("a1") + if !ok { + t.Fatalf("expected approval stored") + } + if req.RiskLevel != approval.RiskHigh { + t.Fatalf("expected risk high, got %s", req.RiskLevel) + } +} + +func TestFindingsStoreAdapter(t *testing.T) { + adapter := NewFindingsStoreAdapter(nil) + if adapter.Get("missing") != nil { + t.Fatalf("expected nil for missing store") + } + if adapter.Update(nil) { + t.Fatalf("expected update to fail with nil store") + } + + finding := &mockAIFinding{ + id: "f1", + severity: "critical", + category: "performance", + resource: "res-1", + name: "node-1", + resType: "node", + title: "title", + desc: "desc", + reco: "reco", + evidence: "evidence", + } + store := &mockAIFindingsStore{finding: finding} + adapter = NewFindingsStoreAdapter(store) + + got := adapter.Get("f1") + if got == nil || got.ID != "f1" { + t.Fatalf("expected finding") + } + if !adapter.Update(got) { + t.Fatalf("expected update to succeed") + } + if !store.updated { + t.Fatalf("expected update to be forwarded") + } +} + +type stubStreamingProvider struct{} + +func (s *stubStreamingProvider) Chat(ctx context.Context, req providers.ChatRequest) (*providers.ChatResponse, error) { + return &providers.ChatResponse{}, nil +} + +func (s *stubStreamingProvider) TestConnection(ctx context.Context) error { + return nil +} + +func (s *stubStreamingProvider) Name() string { + return "stub" +} + +func (s *stubStreamingProvider) ListModels(ctx context.Context) ([]providers.ModelInfo, error) { + return nil, nil +} + +func (s *stubStreamingProvider) ChatStream(ctx context.Context, req providers.ChatRequest, callback providers.StreamCallback) error { + callback(providers.StreamEvent{Type: "content", Data: providers.ContentEvent{Text: "ok"}}) + callback(providers.StreamEvent{Type: "done", Data: providers.DoneEvent{}}) + return nil +} + +func (s *stubStreamingProvider) SupportsThinking(model string) bool { + return false +} + +func setServiceField(t *testing.T, svc *chat.Service, fieldName string, value interface{}) { + t.Helper() + val := reflect.ValueOf(svc).Elem().FieldByName(fieldName) + if !val.IsValid() { + t.Fatalf("field %s not found", fieldName) + } + reflect.NewAt(val.Type(), unsafe.Pointer(val.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func newTestChatService(t *testing.T) *chat.Service { + t.Helper() + dataDir := t.TempDir() + svc := chat.NewService(chat.Config{ + AIConfig: &config.AIConfig{Enabled: true}, + DataDir: dataDir, + }) + + sessions, err := chat.NewSessionStore(dataDir) + if err != nil { + t.Fatalf("failed to create session store: %v", err) + } + + executor := tools.NewPulseToolExecutor(tools.ExecutorConfig{}) + provider := &stubStreamingProvider{} + agentic := chat.NewAgenticLoop(provider, executor, "system") + + setServiceField(t, svc, "sessions", sessions) + setServiceField(t, svc, "agenticLoop", agentic) + setServiceField(t, svc, "executor", executor) + setServiceField(t, svc, "started", true) + + return svc +} + +func TestChatServiceAdapter_BasicFlow(t *testing.T) { + 
adapter := NewChatServiceAdapter(nil) + if adapter.IsRunning() { + t.Fatalf("expected nil service to be not running") + } + adapter.SetAutonomousMode(true) + if _, _, err := adapter.ExecuteCommand(context.Background(), "echo ok", ""); err == nil { + t.Fatalf("expected error for nil service") + } + if err := adapter.ExecuteStream(context.Background(), ExecuteRequest{Prompt: "hi"}, func(StreamEvent) {}); err == nil { + t.Fatalf("expected error for nil service") + } + + service := newTestChatService(t) + adapter = NewChatServiceAdapter(service) + + session, err := adapter.CreateSession(context.Background()) + if err != nil || session == nil { + t.Fatalf("expected session creation") + } + + gotContent := false + err = adapter.ExecuteStream(context.Background(), ExecuteRequest{Prompt: "hello", SessionID: session.ID}, func(event StreamEvent) { + if event.Type == "content" { + gotContent = true + } + }) + if err != nil { + t.Fatalf("unexpected stream error: %v", err) + } + if !gotContent { + t.Fatalf("expected content event") + } + + messages, err := adapter.GetMessages(context.Background(), session.ID) + if err != nil { + t.Fatalf("unexpected message error: %v", err) + } + if len(messages) == 0 { + t.Fatalf("expected messages") + } + + if err := adapter.DeleteSession(context.Background(), session.ID); err != nil { + t.Fatalf("unexpected delete error: %v", err) + } + + setServiceField(t, service, "started", false) + if err := adapter.ExecuteStream(context.Background(), ExecuteRequest{Prompt: "hi"}, func(StreamEvent) {}); err == nil { + t.Fatalf("expected error when service not running") + } +} diff --git a/internal/ai/investigation/guardrails.go b/internal/ai/investigation/guardrails.go index 76b3832d2..0d7a93847 100644 --- a/internal/ai/investigation/guardrails.go +++ b/internal/ai/investigation/guardrails.go @@ -44,7 +44,13 @@ func (g *Guardrails) IsDestructiveAction(command string) bool { // RequiresApproval determines if an action requires user approval // based on finding severity, autonomy level, and whether the command is destructive func (g *Guardrails) RequiresApproval(findingSeverity, autonomyLevel, command string, criticalRequireApproval bool) bool { - // Destructive actions ALWAYS require approval + // Autonomous mode - user explicitly opted into full autonomy, no approvals needed + // This is like "auto-accept" mode in Claude Code - user accepts all risk + if autonomyLevel == "autonomous" { + return false + } + + // Destructive actions ALWAYS require approval (except in autonomous mode above) if g.IsDestructiveAction(command) { return true } diff --git a/internal/ai/investigation/guardrails_test.go b/internal/ai/investigation/guardrails_test.go new file mode 100644 index 000000000..b6d36f055 --- /dev/null +++ b/internal/ai/investigation/guardrails_test.go @@ -0,0 +1,109 @@ +package investigation + +import ( + "strings" + "testing" +) + +func TestGuardrails_DestructiveAndCustomPatterns(t *testing.T) { + g := NewGuardrails() + if !g.IsDestructiveAction("rm -rf /tmp") { + t.Fatalf("expected destructive command") + } + + g.AddDestructivePattern("custom destroy") + if !g.IsDestructiveAction("custom destroy now") { + t.Fatalf("expected custom destructive match") + } + if g.IsDestructiveAction("echo safe") { + t.Fatalf("did not expect safe command to be destructive") + } +} + +func TestGuardrails_RequiresApproval(t *testing.T) { + g := NewGuardrails() + if !g.RequiresApproval("warning", "approval", "echo ok", true) { + t.Fatalf("expected approval mode to require approval") + } + if 
!g.RequiresApproval("critical", "full", "echo ok", true) { + t.Fatalf("expected critical to require approval") + } + if g.RequiresApproval("warning", "full", "echo ok", true) { + t.Fatalf("expected full autonomy warning to skip approval") + } + if !g.RequiresApproval("warning", "controlled", "echo ok", true) { + t.Fatalf("expected default to require approval") + } + + // Autonomous mode bypasses ALL approvals + if g.RequiresApproval("warning", "autonomous", "echo ok", true) { + t.Fatalf("expected autonomous mode to skip approval for safe commands") + } + if g.RequiresApproval("critical", "autonomous", "echo ok", true) { + t.Fatalf("expected autonomous mode to skip approval for critical findings") + } + if g.RequiresApproval("critical", "autonomous", "rm -rf /tmp/test", true) { + t.Fatalf("expected autonomous mode to skip approval even for destructive commands") + } +} + +func TestGuardrails_ClassifyRisk(t *testing.T) { + g := NewGuardrails() + + if g.ClassifyRisk("rm -rf /tmp") != "critical" { + t.Fatalf("expected destructive command to be critical risk") + } + if g.ClassifyRisk("systemctl restart nginx") != "high" { + t.Fatalf("expected restart to be high risk") + } + if g.ClassifyRisk("echo > /etc/hosts") != "medium" { + t.Fatalf("expected config change to be medium risk") + } + if g.ClassifyRisk("cat /etc/hosts") != "low" { + t.Fatalf("expected read-only command to be low risk") + } + if g.ClassifyRisk("unknown-command") != "medium" { + t.Fatalf("expected default to be medium risk") + } +} + +func TestGuardrails_ValidateAndSanitize(t *testing.T) { + g := NewGuardrails() + if valid, reason := g.ValidateCommand(""); valid || reason == "" { + t.Fatalf("expected empty command to be invalid") + } + longCmd := strings.Repeat("a", 4097) + if valid, reason := g.ValidateCommand(longCmd); valid || reason == "" { + t.Fatalf("expected long command to be invalid") + } + if valid, reason := g.ValidateCommand("echo ok; rm -rf /"); valid || reason == "" { + t.Fatalf("expected injection to be invalid") + } + if valid, _ := g.ValidateCommand("echo ok"); !valid { + t.Fatalf("expected command to be valid") + } + + sanitized, changed := g.SanitizeCommand(" echo ok ") + if sanitized != "echo ok" || !changed { + t.Fatalf("expected trim sanitize") + } + sanitized, changed = g.SanitizeCommand(" echo $(rm -rf /) ") + if !strings.Contains(sanitized, "$(") || !changed { + t.Fatalf("expected dangerous command to be flagged as changed") + } +} + +func TestGuardrails_GetDestructivePatterns(t *testing.T) { + g := NewGuardrails() + g.AddDestructivePattern("custom") + patterns := g.GetDestructivePatterns() + found := false + for _, pattern := range patterns { + if pattern == "custom" { + found = true + } + } + if !found { + t.Fatalf("expected custom pattern to be returned") + } +} diff --git a/internal/ai/investigation/orchestrator.go b/internal/ai/investigation/orchestrator.go index 22b871dd3..a93d17ed4 100644 --- a/internal/ai/investigation/orchestrator.go +++ b/internal/ai/investigation/orchestrator.go @@ -102,6 +102,11 @@ type Approval struct { CreatedAt time.Time `json:"created_at"` } +// InfrastructureContextProvider provides discovered infrastructure context for investigations +type InfrastructureContextProvider interface { + GetInfrastructureContext() string +} + // Orchestrator manages the investigation lifecycle type Orchestrator struct { mu sync.RWMutex @@ -114,6 +119,9 @@ type Orchestrator struct { guardrails *Guardrails config InvestigationConfig + // Infrastructure context provider for CLI access 
information + infraContextProvider InfrastructureContextProvider + // Track running investigations runningCount int runningMu sync.Mutex @@ -151,6 +159,15 @@ func (o *Orchestrator) SetCommandExecutor(executor CommandExecutor) { o.commandExecutor = executor } +// SetInfrastructureContextProvider sets the provider for discovered infrastructure context +// This enables investigations to know where services run (Docker, systemd, native) +// and propose correct CLI commands for remediation +func (o *Orchestrator) SetInfrastructureContextProvider(provider InfrastructureContextProvider) { + o.mu.Lock() + defer o.mu.Unlock() + o.infraContextProvider = provider +} + // GetConfig returns the current configuration func (o *Orchestrator) GetConfig() InvestigationConfig { o.mu.RLock() @@ -240,11 +257,32 @@ func (o *Orchestrator) InvestigateFinding(ctx context.Context, finding *Finding, // buildInvestigationPrompt creates the investigation prompt for a finding func (o *Orchestrator) buildInvestigationPrompt(finding *Finding) string { + // Get infrastructure context if available + var infraContext string + o.mu.RLock() + if o.infraContextProvider != nil { + infraContext = o.infraContextProvider.GetInfrastructureContext() + } + o.mu.RUnlock() + + // Build infrastructure context section + var infraSection string + if infraContext != "" { + infraSection = fmt.Sprintf(` +%s +**IMPORTANT**: When proposing commands, use the CLI access method shown above. +- If a service runs in Docker, use 'docker exec <container> <command>' instead of direct commands +- Example: For PBS in Docker, use 'docker exec pbs proxmox-backup-manager gc pbs-delly' not 'proxmox-backup-manager gc pbs-delly' +- This ensures commands execute in the correct environment where the service actually runs. + +`, infraContext) + } + return fmt.Sprintf(`You are investigating a finding from Pulse Patrol. Your goal is to: 1. Understand the issue using available tools 2. Determine if it can be automatically fixed 3.
If fixable, propose a specific remediation command - +%s ## Finding Details - **Title**: %s - **Severity**: %s @@ -288,6 +326,7 @@ Remember: - Only propose commands you're confident will help - Never propose destructive commands (they'll be blocked anyway) - Focus on the specific resource mentioned in the finding`, + infraSection, finding.Title, finding.Severity, finding.Category, diff --git a/internal/ai/investigation/orchestrator_additional_test.go b/internal/ai/investigation/orchestrator_additional_test.go new file mode 100644 index 000000000..f61706fac --- /dev/null +++ b/internal/ai/investigation/orchestrator_additional_test.go @@ -0,0 +1,323 @@ +package investigation + +import ( + "context" + "encoding/json" + "testing" +) + +type stubChatService struct { + sessionID string + execute func(StreamCallback) error +} + +func (s *stubChatService) CreateSession(ctx context.Context) (*Session, error) { + if s.sessionID == "" { + s.sessionID = "session-1" + } + return &Session{ID: s.sessionID}, nil +} + +func (s *stubChatService) ExecuteStream(ctx context.Context, req ExecuteRequest, callback StreamCallback) error { + if s.execute != nil { + return s.execute(callback) + } + return nil +} + +func (s *stubChatService) GetMessages(ctx context.Context, sessionID string) ([]Message, error) { + return nil, nil +} + +func (s *stubChatService) DeleteSession(ctx context.Context, sessionID string) error { + return nil +} + +func (s *stubChatService) SetAutonomousMode(enabled bool) {} + +type stubCommandExecutor struct { + output string + code int + err error +} + +func (s *stubCommandExecutor) ExecuteCommand(ctx context.Context, command, targetHost string) (string, int, error) { + return s.output, s.code, s.err +} + +type stubApprovalStore struct { + called bool + err error +} + +func (s *stubApprovalStore) Create(appr *Approval) error { + s.called = true + return s.err +} + +type stubFindingsStore struct { + finding *Finding + updated bool +} + +func (s *stubFindingsStore) Get(id string) *Finding { + if s.finding != nil && s.finding.ID == id { + return s.finding + } + return nil +} + +func (s *stubFindingsStore) Update(f *Finding) bool { + s.updated = true + s.finding = f + return true +} + +func TestOrchestrator_ConfigAndLimits(t *testing.T) { + store := NewStore("") + orchestrator := NewOrchestrator(&stubChatService{}, store, nil, nil, DefaultConfig()) + + cfg := orchestrator.GetConfig() + cfg.MaxConcurrent = 1 + orchestrator.SetConfig(cfg) + + if !orchestrator.CanStartInvestigation() { + t.Fatalf("expected investigation to be allowed") + } + + orchestrator.runningCount = 1 + if orchestrator.CanStartInvestigation() { + t.Fatalf("expected max concurrent to block") + } +} + +func TestOrchestrator_ExecuteWithLimits_Success(t *testing.T) { + store := NewStore("") + chatService := &stubChatService{ + execute: func(cb StreamCallback) error { + payload, _ := json.Marshal(map[string]string{"text": "analysis"}) + cb(StreamEvent{Type: "content", Data: payload}) + cb(StreamEvent{Type: "tool_end"}) + return nil + }, + } + orchestrator := NewOrchestrator(chatService, store, nil, nil, DefaultConfig()) + investigation := store.Create("finding-1", "session-1") + + if err := orchestrator.executeWithLimits(context.Background(), investigation, "prompt"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if investigation.Summary == "" { + t.Fatalf("expected summary to be set") + } +} + +func TestOrchestrator_ExecuteWithLimits_StreamError(t *testing.T) { + store := NewStore("") + chatService := 
&stubChatService{ + execute: func(cb StreamCallback) error { + payload, _ := json.Marshal(map[string]string{"message": "boom"}) + cb(StreamEvent{Type: "error", Data: payload}) + return nil + }, + } + orchestrator := NewOrchestrator(chatService, store, nil, nil, DefaultConfig()) + investigation := store.Create("finding-1", "session-1") + + if err := orchestrator.executeWithLimits(context.Background(), investigation, "prompt"); err == nil { + t.Fatalf("expected stream error") + } +} + +func TestOrchestrator_ProcessResult_ApprovalFlow(t *testing.T) { + store := NewStore("") + approval := &stubApprovalStore{} + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1", Severity: "critical"}} + orchestrator := NewOrchestrator(&stubChatService{}, store, findings, approval, DefaultConfig()) + + investigation := store.Create("finding-1", "session-1") + investigation.Summary = "PROPOSED_FIX: systemctl restart app\nTARGET_HOST: node-1" + + if err := orchestrator.processResult(context.Background(), investigation, findings.finding, "controlled"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + updated := store.Get(investigation.ID) + if updated.Outcome != OutcomeFixQueued { + t.Fatalf("expected fix queued outcome") + } + if !approval.called { + t.Fatalf("expected approval store to be called") + } +} + +func TestOrchestrator_ProcessResult_AutonomySuccess(t *testing.T) { + store := NewStore("") + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1", Severity: "warning"}} + orchestrator := NewOrchestrator(&stubChatService{}, store, findings, nil, DefaultConfig()) + orchestrator.SetCommandExecutor(&stubCommandExecutor{output: "ok", code: 0}) + + investigation := store.Create("finding-1", "session-1") + investigation.Summary = "PROPOSED_FIX: echo ok\nTARGET_HOST: local" + + if err := orchestrator.processResult(context.Background(), investigation, findings.finding, "full"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + updated := store.Get(investigation.ID) + if updated.Outcome != OutcomeFixExecuted { + t.Fatalf("expected fix executed outcome") + } +} + +func TestOrchestrator_ProcessResult_AutonomyFailure(t *testing.T) { + store := NewStore("") + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1", Severity: "warning"}} + orchestrator := NewOrchestrator(&stubChatService{}, store, findings, nil, DefaultConfig()) + orchestrator.SetCommandExecutor(&stubCommandExecutor{output: "fail", code: 1}) + + investigation := store.Create("finding-1", "session-1") + investigation.Summary = "PROPOSED_FIX: echo fail\nTARGET_HOST: local" + + if err := orchestrator.processResult(context.Background(), investigation, findings.finding, "full"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + updated := store.Get(investigation.ID) + if updated.Outcome != OutcomeFixFailed { + t.Fatalf("expected fix failed outcome") + } +} + +func TestOrchestrator_ProcessResult_NoExecutor(t *testing.T) { + store := NewStore("") + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1", Severity: "warning"}} + orchestrator := NewOrchestrator(&stubChatService{}, store, findings, nil, DefaultConfig()) + + investigation := store.Create("finding-1", "session-1") + investigation.Summary = "PROPOSED_FIX: echo ok\nTARGET_HOST: local" + + if err := orchestrator.processResult(context.Background(), investigation, findings.finding, "full"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + updated := store.Get(investigation.ID) + if updated.Outcome != OutcomeFixQueued { + 
t.Fatalf("expected fix queued outcome") + } +} + +func TestOrchestrator_ProcessResult_NoFix(t *testing.T) { + store := NewStore("") + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1", Severity: "warning"}} + orchestrator := NewOrchestrator(&stubChatService{}, store, findings, nil, DefaultConfig()) + + investigation := store.Create("finding-1", "session-1") + investigation.Summary = "CANNOT_FIX: too complex" + + if err := orchestrator.processResult(context.Background(), investigation, findings.finding, "controlled"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + updated := store.Get(investigation.ID) + if updated.Outcome != OutcomeCannotFix { + t.Fatalf("expected cannot fix outcome") + } +} + +func TestOrchestrator_ParseInvestigationSummary(t *testing.T) { + orchestrator := NewOrchestrator(&stubChatService{}, NewStore(""), nil, nil, DefaultConfig()) + + fix, outcome := orchestrator.parseInvestigationSummary("PROPOSED_FIX: echo ok\nTARGET_HOST: node-1") + if fix == nil || outcome != OutcomeFixQueued { + t.Fatalf("expected fix proposal") + } + + if _, outcome = orchestrator.parseInvestigationSummary("CANNOT_FIX: no"); outcome != OutcomeCannotFix { + t.Fatalf("expected cannot fix outcome") + } + if _, outcome = orchestrator.parseInvestigationSummary("NEEDS_ATTENTION: help"); outcome != OutcomeNeedsAttention { + t.Fatalf("expected needs attention outcome") + } + if _, outcome = orchestrator.parseInvestigationSummary("unknown"); outcome != OutcomeNeedsAttention { + t.Fatalf("expected default needs attention outcome") + } +} + +func TestOrchestrator_ReinvestigateFinding(t *testing.T) { + store := NewStore("") + findings := &stubFindingsStore{finding: &Finding{ID: "finding-1"}} + chatService := &stubChatService{ + execute: func(cb StreamCallback) error { + payload, _ := json.Marshal(map[string]string{"text": "CANNOT_FIX: ok"}) + cb(StreamEvent{Type: "content", Data: payload}) + return nil + }, + } + orchestrator := NewOrchestrator(chatService, store, findings, nil, DefaultConfig()) + + if err := orchestrator.ReinvestigateFinding(context.Background(), "finding-1", "controlled"); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOrchestrator_ReinvestigateFinding_Errors(t *testing.T) { + orchestrator := NewOrchestrator(&stubChatService{}, NewStore(""), nil, nil, DefaultConfig()) + if err := orchestrator.ReinvestigateFinding(context.Background(), "missing", "controlled"); err == nil { + t.Fatalf("expected error for missing store") + } + + findings := &stubFindingsStore{} + orchestrator = NewOrchestrator(&stubChatService{}, NewStore(""), findings, nil, DefaultConfig()) + if err := orchestrator.ReinvestigateFinding(context.Background(), "missing", "controlled"); err == nil { + t.Fatalf("expected error for missing finding") + } +} + +func TestBuildInvestigationPromptAndTrim(t *testing.T) { + orchestrator := NewOrchestrator(&stubChatService{}, NewStore(""), nil, nil, DefaultConfig()) + finding := &Finding{ + ID: "finding-1", + Title: "CPU High", + Severity: "warning", + Category: "performance", + ResourceID: "vm-1", + ResourceName: "web", + ResourceType: "vm", + Description: "desc", + Evidence: "evidence", + Recommendation: "reco", + } + prompt := orchestrator.buildInvestigationPrompt(finding) + if prompt == "" { + t.Fatalf("expected prompt") + } + if trim(" spaced ") != "spaced" { + t.Fatalf("expected trim to remove whitespace") + } +} + +func TestFormatOptional(t *testing.T) { + if formatOptional("Label", "") != "" { + t.Fatalf("expected empty optional 
formatting") + } + if formatOptional("Label", "value") == "" { + t.Fatalf("expected formatted optional") + } +} + +func TestOrchestrator_Getters(t *testing.T) { + store := NewStore("") + orchestrator := NewOrchestrator(&stubChatService{}, store, nil, nil, DefaultConfig()) + + session := store.Create("finding-1", "session-1") + store.UpdateStatus(session.ID, StatusRunning) + + if orchestrator.GetInvestigation(session.ID) == nil { + t.Fatalf("expected investigation") + } + if orchestrator.GetInvestigationByFinding("finding-1") == nil { + t.Fatalf("expected investigation by finding") + } + if len(orchestrator.GetRunningInvestigations()) != 1 { + t.Fatalf("expected running investigations") + } + if orchestrator.GetRunningCount() != 1 { + t.Fatalf("expected running count") + } +} diff --git a/internal/ai/investigation/store_test.go b/internal/ai/investigation/store_test.go new file mode 100644 index 000000000..37fb194ba --- /dev/null +++ b/internal/ai/investigation/store_test.go @@ -0,0 +1,163 @@ +package investigation + +import ( + "testing" + "time" +) + +func TestStore_CreateGetAndLoad(t *testing.T) { + dir := t.TempDir() + store := NewStore(dir) + + session := store.Create("finding-1", "session-1") + if session == nil || session.ID == "" { + t.Fatalf("expected session to be created") + } + + retrieved := store.Get(session.ID) + if retrieved == nil || retrieved.ID != session.ID { + t.Fatalf("expected to retrieve session") + } + + store.sessions[session.ID].ProposedFix = &Fix{ID: "fix-1"} + copy := store.Get(session.ID) + if copy.ProposedFix == nil || copy.ProposedFix.ID != "fix-1" { + t.Fatalf("expected fix to be copied") + } + if copy.ProposedFix == store.sessions[session.ID].ProposedFix { + t.Fatalf("expected fix to be deep copied") + } + + if err := store.ForceSave(); err != nil { + t.Fatalf("unexpected save error: %v", err) + } + + loaded := NewStore(dir) + if err := loaded.LoadFromDisk(); err != nil { + t.Fatalf("unexpected load error: %v", err) + } + if loaded.Get(session.ID) == nil { + t.Fatalf("expected loaded session") + } +} + +func TestStore_ByFindingAndLatest(t *testing.T) { + store := NewStore("") + + first := store.Create("finding-1", "session-1") + second := store.Create("finding-1", "session-2") + + store.sessions[first.ID].StartedAt = time.Now().Add(-time.Hour) + store.sessions[second.ID].StartedAt = time.Now() + + sessions := store.GetByFinding("finding-1") + if len(sessions) != 2 { + t.Fatalf("expected 2 sessions, got %d", len(sessions)) + } + + latest := store.GetLatestByFinding("finding-1") + if latest == nil || latest.ID != second.ID { + t.Fatalf("expected latest session") + } +} + +func TestStore_RunningCounts(t *testing.T) { + store := NewStore("") + first := store.Create("finding-1", "session-1") + second := store.Create("finding-2", "session-2") + + store.UpdateStatus(first.ID, StatusRunning) + store.UpdateStatus(second.ID, StatusRunning) + + if store.CountRunning() != 2 { + t.Fatalf("expected 2 running sessions") + } + if len(store.GetRunning()) != 2 { + t.Fatalf("expected running sessions") + } +} + +func TestStore_UpdateAndStatus(t *testing.T) { + store := NewStore("") + session := store.Create("finding-1", "session-1") + session.Status = StatusRunning + session.ProposedFix = &Fix{ID: "fix-1"} + + if !store.Update(session) { + t.Fatalf("expected update to succeed") + } + retrieved := store.Get(session.ID) + if retrieved.Status != StatusRunning { + t.Fatalf("expected status update") + } + if retrieved.ProposedFix == nil || retrieved.ProposedFix.ID != "fix-1" { + 
t.Fatalf("expected fix update") + } + if !store.UpdateStatus(session.ID, StatusCompleted) { + t.Fatalf("expected status update") + } +} + +func TestStore_CompleteFailAndCounts(t *testing.T) { + store := NewStore("") + session := store.Create("finding-1", "session-1") + + if !store.Complete(session.ID, OutcomeFixExecuted, "summary", &Fix{ID: "fix-1"}) { + t.Fatalf("expected complete") + } + updated := store.Get(session.ID) + if updated.Outcome != OutcomeFixExecuted || updated.Status != StatusCompleted { + t.Fatalf("expected completed outcome") + } + + session2 := store.Create("finding-2", "session-2") + if !store.Fail(session2.ID, "error") { + t.Fatalf("expected fail") + } + if store.CountFixed() != 1 { + t.Fatalf("expected fixed count 1") + } +} + +func TestStore_IncrementAndApproval(t *testing.T) { + store := NewStore("") + session := store.Create("finding-1", "session-1") + + if count := store.IncrementTurnCount(session.ID); count != 1 { + t.Fatalf("expected turn count 1") + } + if !store.SetApprovalID(session.ID, "approval-1") { + t.Fatalf("expected approval id set") + } +} + +func TestStore_GetAllAndCleanup(t *testing.T) { + store := NewStore("") + session := store.Create("finding-1", "session-1") + session2 := store.Create("finding-2", "session-2") + + all := store.GetAll() + if len(all) != 2 { + t.Fatalf("expected 2 sessions") + } + + old := time.Now().Add(-2 * time.Hour) + store.sessions[session.ID].CompletedAt = &old + removed := store.Cleanup(time.Hour) + if removed != 1 { + t.Fatalf("expected 1 removed, got %d", removed) + } + if store.Get(session.ID) != nil { + t.Fatalf("expected session removed") + } + if store.Get(session2.ID) == nil { + t.Fatalf("expected remaining session") + } +} + +func TestStore_ForceSave_NoDir(t *testing.T) { + store := NewStore("") + if err := store.ForceSave(); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/ai/investigation/types_test.go b/internal/ai/investigation/types_test.go new file mode 100644 index 000000000..6e00c994e --- /dev/null +++ b/internal/ai/investigation/types_test.go @@ -0,0 +1,40 @@ +package investigation + +import "testing" + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + if cfg.MaxTurns == 0 || cfg.Timeout == 0 || cfg.MaxConcurrent == 0 { + t.Fatalf("expected non-zero defaults") + } + if !cfg.CriticalRequireApproval { + t.Fatalf("expected critical approval default") + } +} + +func TestIsDestructiveAndHelpers(t *testing.T) { + if !IsDestructive("rm -rf /tmp") { + t.Fatalf("expected destructive command") + } + if IsDestructive("echo safe") { + t.Fatalf("expected non-destructive command") + } + if !containsPattern("Systemctl Stop", "systemctl stop") { + t.Fatalf("expected case-insensitive pattern match") + } + if containsPattern("short", "longpattern") { + t.Fatalf("expected pattern mismatch for longer pattern") + } + if indexString("abc", "d") != -1 { + t.Fatalf("expected -1 for missing substring") + } + if indexString("abc", "") != 0 { + t.Fatalf("expected 0 for empty substring") + } + if toLower("AbC") != "abc" { + t.Fatalf("expected lowercase conversion") + } + if !contains("Hello", "he") { + t.Fatalf("expected case-insensitive contains") + } +} diff --git a/internal/ai/knowledge/store.go b/internal/ai/knowledge/store.go index 1ca36ba3c..9bbdc759f 100644 --- a/internal/ai/knowledge/store.go +++ b/internal/ai/knowledge/store.go @@ -14,6 +14,17 @@ import ( "github.com/rs/zerolog/log" ) +// Category constants for note categorization +const ( + CategoryCredential = "credential" + 
CategoryService = "service" + CategoryPath = "path" + CategoryConfig = "config" + CategoryLearning = "learning" + CategoryHistory = "history" + CategoryInfra = "infrastructure" // Auto-discovered infrastructure facts +) + // Note represents a single piece of learned information type Note struct { ID string `json:"id"` @@ -279,14 +290,15 @@ func (s *Store) FormatForContext(guestID string) string { result = fmt.Sprintf("\n## Previously Learned Information about %s\n", knowledge.GuestName) result += "**If relevant to the current task, use this saved information directly instead of rediscovering it.**\n" - categoryOrder := []string{"credential", "service", "path", "config", "learning", "history"} + categoryOrder := []string{"credential", "service", "path", "config", "learning", "history", "infrastructure"} categoryNames := map[string]string{ - "credential": "Credentials", - "service": "Services", - "path": "Important Paths", - "config": "Configuration", - "learning": "Learnings", - "history": "Session History", + "credential": "Credentials", + "service": "Services", + "path": "Important Paths", + "config": "Configuration", + "learning": "Learnings", + "history": "Session History", + "infrastructure": "Discovered Infrastructure", } for _, cat := range categoryOrder { @@ -437,7 +449,7 @@ func (s *Store) FormatAllForContext() string { var guestSection string guestSection = fmt.Sprintf("\n### %s (%s)", guestName, knowledge.GuestType) - categoryOrder := []string{"credential", "service", "path", "config", "learning"} + categoryOrder := []string{"credential", "service", "path", "config", "learning", "infrastructure"} for _, cat := range categoryOrder { notes, ok := byCategory[cat] if !ok || len(notes) == 0 { @@ -492,3 +504,56 @@ finalize: return result } + +// GetInfrastructureContext returns all discovered infrastructure formatted for AI context. +// This is specifically used by Patrol and investigations to understand where services run +// and how to interact with them (e.g., knowing PBS runs in Docker so commands need docker exec). 
+func (s *Store) GetInfrastructureContext() string { + guests, err := s.ListGuests() + if err != nil || len(guests) == 0 { + return "" + } + + var sb strings.Builder + sb.WriteString("\n## Discovered Infrastructure\n") + sb.WriteString("The following services have been auto-discovered on your infrastructure.\n") + sb.WriteString("Use this information to propose correct commands (e.g., use 'docker exec' for containerized services).\n\n") + + hasNotes := false + for _, guestID := range guests { + knowledge, err := s.GetKnowledge(guestID) + if err != nil { + continue + } + + // Filter for infrastructure notes only + var infraNotes []Note + for _, note := range knowledge.Notes { + if note.Category == CategoryInfra { + infraNotes = append(infraNotes, note) + } + } + + if len(infraNotes) == 0 { + continue + } + + hasNotes = true + guestName := knowledge.GuestName + if guestName == "" { + guestName = guestID + } + + sb.WriteString(fmt.Sprintf("### %s\n", guestName)) + for _, note := range infraNotes { + sb.WriteString(fmt.Sprintf("- %s: %s\n", note.Title, note.Content)) + } + sb.WriteString("\n") + } + + if !hasNotes { + return "" + } + + return sb.String() +} diff --git a/internal/ai/learning/store_additional_test.go b/internal/ai/learning/store_additional_test.go new file mode 100644 index 000000000..3796c29d5 --- /dev/null +++ b/internal/ai/learning/store_additional_test.go @@ -0,0 +1,373 @@ +package learning + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestDefaultLearningStoreConfig(t *testing.T) { + cfg := DefaultLearningStoreConfig() + if cfg.MaxRecords != 10000 { + t.Fatalf("expected MaxRecords 10000, got %d", cfg.MaxRecords) + } + if cfg.RetentionDays != 90 { + t.Fatalf("expected RetentionDays 90, got %d", cfg.RetentionDays) + } +} + +func TestRecordFeedback_GeneratesIDTimestampAndSignal(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + + store.RecordFeedback(FeedbackRecord{ + FindingID: "f-1", + ResourceID: "vm-1", + Category: "performance", + Severity: "warning", + Action: ActionThumbsDown, + }) + + store.mu.RLock() + defer store.mu.RUnlock() + if len(store.feedbackRecords) != 1 { + t.Fatalf("expected 1 feedback record, got %d", len(store.feedbackRecords)) + } + + for _, record := range store.feedbackRecords { + if record.ID == "" { + t.Fatalf("expected record ID to be generated") + } + if record.Timestamp.IsZero() { + t.Fatalf("expected timestamp to be set") + } + if !record.Signal.IsFalsePositive { + t.Fatalf("expected thumbs down to mark false positive") + } + } +} + +func TestComputeFeedbackSignal_Default(t *testing.T) { + signal := computeFeedbackSignal(UserAction("unknown")) + if signal.Confidence != 0.5 { + t.Fatalf("expected default confidence 0.5, got %.2f", signal.Confidence) + } +} + +func TestResourcePreferences_NotesTrim(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + + for i := 0; i < 12; i++ { + store.RecordFeedback(FeedbackRecord{ + FindingID: "f" + intToStr(i), + ResourceID: "vm-1", + Category: "performance", + Severity: "warning", + Action: ActionDismissExpected, + UserNote: "note-" + intToStr(i), + }) + } + + pref := store.GetResourcePreference("vm-1") + if pref == nil { + t.Fatalf("expected resource preference to exist") + } + if len(pref.Notes) != 10 { + t.Fatalf("expected 10 notes after trimming, got %d", len(pref.Notes)) + } + if pref.Notes[0] != "note-2" { + t.Fatalf("expected oldest notes to be trimmed, got %s", pref.Notes[0]) + } + if pref.Notes[len(pref.Notes)-1] != "note-11" { + 
t.Fatalf("expected last note to be retained") + } +} + +func TestShouldSuppress_SeverityThreshold(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + + store.RecordFeedback(FeedbackRecord{ + FindingID: "f1", + ResourceID: "vm-1", + Category: "performance", + Severity: "warning", + Action: ActionDismissNotAnIssue, + }) + + if !store.ShouldSuppress("vm-1", "performance", "info") { + t.Fatalf("expected info severity to be suppressed for thresholded category") + } + if store.ShouldSuppress("vm-1", "performance", "critical") { + t.Fatalf("expected critical severity not to be suppressed") + } +} + +func TestCategoryPreferences_RollingAverage(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + + store.RecordFeedback(FeedbackRecord{ + FindingID: "f1", + ResourceID: "vm-1", + Category: "capacity", + Severity: "warning", + Action: ActionQuickFix, + TimeToAction: 10 * time.Minute, + }) + store.RecordFeedback(FeedbackRecord{ + FindingID: "f2", + ResourceID: "vm-2", + Category: "capacity", + Severity: "warning", + Action: ActionQuickFix, + TimeToAction: 20 * time.Minute, + }) + + pref := store.GetCategoryPreference("capacity") + if pref == nil { + t.Fatalf("expected category preference to exist") + } + expected := 15 * time.Minute + if pref.AverageTimeToAction != expected { + t.Fatalf("expected rolling average %s, got %s", expected, pref.AverageTimeToAction) + } +} + +func TestFormatForContext_Details(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + + for i := 0; i < 6; i++ { + store.RecordFeedback(FeedbackRecord{ + FindingID: "f" + intToStr(i), + ResourceID: "vm-1", + Category: "performance", + Severity: "warning", + Action: ActionDismissNotAnIssue, + UserNote: "note-" + intToStr(i), + }) + } + for i := 0; i < 12; i++ { + action := ActionQuickFix + if i%5 == 0 { + action = ActionDismissNotAnIssue + } + store.RecordFeedback(FeedbackRecord{ + FindingID: "c" + intToStr(i), + ResourceID: "vm-" + intToStr(i), + Category: "capacity", + Severity: "warning", + Action: action, + }) + } + + context := store.FormatForContext() + if context == "" { + t.Fatalf("expected context to be populated") + } + if !containsStr(context, "vm-1") { + t.Fatalf("expected resource preference details in context") + } + if !containsStr(context, "Category value") { + t.Fatalf("expected category section in context") + } +} + +func TestCleanup_TrimMaxRecords(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{MaxRecords: 2}) + + store.RecordFeedback(FeedbackRecord{ + FindingID: "f1", + ResourceID: "vm-1", + Category: "performance", + Severity: "warning", + Action: ActionQuickFix, + }) + store.RecordFeedback(FeedbackRecord{ + FindingID: "f2", + ResourceID: "vm-2", + Category: "performance", + Severity: "warning", + Action: ActionQuickFix, + }) + store.RecordFeedback(FeedbackRecord{ + FindingID: "f3", + ResourceID: "vm-3", + Category: "performance", + Severity: "warning", + Action: ActionQuickFix, + }) + + removed := store.Cleanup() + if removed == 0 { + t.Fatalf("expected Cleanup to trim records when over max") + } + stats := store.GetStatistics() + if stats.TotalFeedbackRecords > 2 { + t.Fatalf("expected records trimmed to max, got %d", stats.TotalFeedbackRecords) + } +} + +func TestSaveAndLoadLearningStore(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "ai_learning.json") + + payload := struct { + FeedbackRecords map[string]*FeedbackRecord `json:"feedback_records"` + ResourcePreferences map[string]*ResourcePreference `json:"resource_preferences"` + 
CategoryPreferences map[string]*CategoryPreference `json:"category_preferences"` + }{ + FeedbackRecords: map[string]*FeedbackRecord{ + "fb-1": { + ID: "fb-1", + FindingID: "finding-1", + Category: "performance", + Action: ActionQuickFix, + Timestamp: time.Now(), + }, + }, + ResourcePreferences: map[string]*ResourcePreference{ + "vm-1": { + ResourceID: "vm-1", + TotalFindings: 3, + ActionedCount: 2, + DismissedCount: 1, + }, + }, + CategoryPreferences: map[string]*CategoryPreference{ + "performance": { + Category: "performance", + TotalFindings: 5, + ActionedCount: 3, + }, + }, + } + + raw, err := json.Marshal(payload) + if err != nil { + t.Fatalf("failed to marshal payload: %v", err) + } + if err := os.WriteFile(path, raw, 0600); err != nil { + t.Fatalf("failed to write payload: %v", err) + } + + loaded := NewLearningStore(LearningStoreConfig{DataDir: dir}) + stats := loaded.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("expected 1 feedback record, got %d", stats.TotalFeedbackRecords) + } + if stats.ResourcePreferences != 1 || stats.CategoryPreferences != 1 { + t.Fatalf("expected resource and category prefs to load") + } +} + +func TestForceSave_PersistsData(t *testing.T) { + dir := t.TempDir() + store := NewLearningStore(LearningStoreConfig{DataDir: dir}) + + store.mu.Lock() + store.feedbackRecords["fb-1"] = &FeedbackRecord{ + ID: "fb-1", + FindingID: "finding-1", + Category: "capacity", + Action: ActionQuickFix, + Timestamp: time.Now(), + } + store.mu.Unlock() + + if err := store.ForceSave(); err != nil { + t.Fatalf("force save failed: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "ai_learning.json")) + if err != nil { + t.Fatalf("expected saved file to exist: %v", err) + } + if !containsStr(string(data), "fb-1") { + t.Fatalf("expected saved data to contain record id") + } +} + +func TestSaveIfDirty_WritesFile(t *testing.T) { + dir := t.TempDir() + store := NewLearningStore(LearningStoreConfig{DataDir: dir}) + + store.mu.Lock() + store.feedbackRecords["fb-1"] = &FeedbackRecord{ + ID: "fb-1", + FindingID: "finding-1", + Category: "capacity", + Action: ActionQuickFix, + Timestamp: time.Now(), + } + store.dirty = true + store.mu.Unlock() + + store.saveIfDirty() + + if _, err := os.Stat(filepath.Join(dir, "ai_learning.json")); err != nil { + t.Fatalf("expected learning file to exist: %v", err) + } +} + +func TestSaveIfDirty_NotDirty(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + store.saveIfDirty() +} + +func TestSaveToDisk_NoDir(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + if err := store.saveToDisk(); err != nil { + t.Fatalf("expected saveToDisk to no-op without DataDir, got %v", err) + } +} + +func TestSaveIfDirty_SaveError(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + store := NewLearningStore(LearningStoreConfig{DataDir: filePath}) + store.mu.Lock() + store.feedbackRecords["fb-1"] = &FeedbackRecord{ + ID: "fb-1", + FindingID: "finding-1", + Category: "capacity", + Action: ActionQuickFix, + Timestamp: time.Now(), + } + store.dirty = true + store.mu.Unlock() + + store.saveIfDirty() + + store.mu.RLock() + dirty := store.dirty + store.mu.RUnlock() + if !dirty { + t.Fatalf("expected dirty to remain true on save error") + } +} + +func TestComputeFeedbackSignal_AdditionalActions(t *testing.T) { + signal := 
computeFeedbackSignal(ActionDismissWillFixLater) + if !signal.WasActionable || signal.Confidence <= 0 { + t.Fatalf("expected actionable signal for dismiss will fix later") + } + + signal = computeFeedbackSignal(ActionAcknowledge) + if !signal.WasActionable || signal.Confidence <= 0 { + t.Fatalf("expected actionable signal for acknowledge") + } +} + +func TestGetPreferences_NotFound(t *testing.T) { + store := NewLearningStore(LearningStoreConfig{}) + if store.GetResourcePreference("missing") != nil { + t.Fatalf("expected nil for missing resource preference") + } + if store.GetCategoryPreference("missing") != nil { + t.Fatalf("expected nil for missing category preference") + } +} diff --git a/internal/ai/memory/context_test.go b/internal/ai/memory/context_test.go new file mode 100644 index 000000000..46d756f54 --- /dev/null +++ b/internal/ai/memory/context_test.go @@ -0,0 +1,300 @@ +package memory + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewContextStore_Defaults(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + if store.config.MaxMemoriesPerType != 1000 { + t.Fatalf("expected MaxMemoriesPerType default, got %d", store.config.MaxMemoriesPerType) + } + if store.config.MaxResourceNotes != 20 { + t.Fatalf("expected MaxResourceNotes default, got %d", store.config.MaxResourceNotes) + } + if store.config.RetentionDays != 90 { + t.Fatalf("expected RetentionDays default, got %d", store.config.RetentionDays) + } + if store.config.RelevanceDecayDays != 7 { + t.Fatalf("expected RelevanceDecayDays default, got %d", store.config.RelevanceDecayDays) + } +} + +func TestRemember_AddsMemoryAndResourceNote(t *testing.T) { + store := NewContextStore(ContextStoreConfig{MaxResourceNotes: 2}) + + id1 := store.Remember("vm-1", "note-1", "user", MemoryTypeResource) + id2 := store.Remember("vm-1", "note-1", "user", MemoryTypeResource) + if id1 == id2 { + t.Fatalf("expected unique memory IDs for separate remembers") + } + + mem := store.GetResourceMemory("vm-1") + if mem == nil || len(mem.Notes) != 1 { + t.Fatalf("expected duplicate notes to be ignored") + } + + store.AddResourceNote("vm-1", "", "", "note-2") + store.AddResourceNote("vm-1", "", "", "note-3") + mem = store.GetResourceMemory("vm-1") + if len(mem.Notes) != 2 { + t.Fatalf("expected notes trimmed to max, got %d", len(mem.Notes)) + } + if mem.Notes[0] != "note-2" { + t.Fatalf("expected oldest note to be trimmed") + } +} + +func TestAddResourceNote_UpdatesMetadata(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + store.AddResourceNote("vm-1", "web-1", "vm", "note") + + mem := store.GetResourceMemory("vm-1") + if mem == nil { + t.Fatalf("expected resource memory to exist") + } + if mem.ResourceName != "web-1" || mem.ResourceType != "vm" { + t.Fatalf("expected metadata to be stored") + } +} + +func TestAddIncidentMemory_CreatesMemory(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + + store.AddIncidentMemory(&IncidentMemory{ + ResourceID: "vm-1", + Summary: "high cpu", + RootCause: "runaway process", + Resolution: "restarted service", + }) + + incidents := store.GetRecentIncidents(10) + if len(incidents) != 1 { + t.Fatalf("expected incident to be stored") + } + + store.mu.RLock() + defer store.mu.RUnlock() + if len(store.memories[MemoryTypeIncident]) != 1 { + t.Fatalf("expected incident memory to be created") + } + for _, mem := range store.memories[MemoryTypeIncident] { + if mem.Content == "" || mem.ResourceID != "vm-1" { + t.Fatalf("expected incident memory 
content to be populated") + } + } +} + +func TestAddPatternMemory_Deduplicates(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + + store.AddPatternMemory(&PatternMemory{ + Pattern: "cpu spike at backup", + Description: "backup time spikes CPU", + Occurrences: 2, + }) + store.AddPatternMemory(&PatternMemory{ + Pattern: "cpu spike at backup", + Description: "same pattern", + Occurrences: 1, + }) + + patterns := store.GetPatterns(0) + if len(patterns) != 1 { + t.Fatalf("expected pattern to deduplicate") + } + if patterns[0].Occurrences != 3 { + t.Fatalf("expected occurrences to be incremented, got %d", patterns[0].Occurrences) + } + if patterns[0].Confidence == 0 { + t.Fatalf("expected confidence to be set") + } +} + +func TestRecallAndRecallByType_SortsAndMarksUsed(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + + store.mu.Lock() + store.memories[MemoryTypeResource]["m1"] = &Memory{ + ID: "m1", + Type: MemoryTypeResource, + ResourceID: "vm-1", + Relevance: 0.5, + UseCount: 1, + LastUsed: time.Now().Add(-24 * time.Hour), + } + store.memories[MemoryTypeResource]["m2"] = &Memory{ + ID: "m2", + Type: MemoryTypeResource, + ResourceID: "vm-1", + Relevance: 0.9, + UseCount: 1, + LastUsed: time.Now().Add(-24 * time.Hour), + } + store.mu.Unlock() + + memories := store.Recall("vm-1") + if len(memories) != 2 { + t.Fatalf("expected 2 memories, got %d", len(memories)) + } + if memories[0].ID != "m2" { + t.Fatalf("expected higher relevance memory first") + } + + store.mu.RLock() + defer store.mu.RUnlock() + if store.memories[MemoryTypeResource]["m1"].UseCount != 2 { + t.Fatalf("expected use count increment") + } + if store.memories[MemoryTypeResource]["m1"].Relevance <= 0.5 { + t.Fatalf("expected relevance to increase") + } +} + +func TestGetPatterns_FilterAndSort(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + + store.mu.Lock() + store.patternMemories["p1"] = &PatternMemory{ID: "p1", Pattern: "a", Confidence: 0.4} + store.patternMemories["p2"] = &PatternMemory{ID: "p2", Pattern: "b", Confidence: 0.9} + store.patternMemories["p3"] = &PatternMemory{ID: "p3", Pattern: "c", Confidence: 0.7} + store.mu.Unlock() + + patterns := store.GetPatterns(0.5) + if len(patterns) != 2 { + t.Fatalf("expected 2 patterns above threshold, got %d", len(patterns)) + } + if patterns[0].Confidence < patterns[1].Confidence { + t.Fatalf("expected patterns sorted by confidence desc") + } +} + +func TestDecayRelevance(t *testing.T) { + store := NewContextStore(ContextStoreConfig{RelevanceDecayDays: 1}) + + store.mu.Lock() + store.memories[MemoryTypeResource]["m1"] = &Memory{ + ID: "m1", + Type: MemoryTypeResource, + ResourceID: "vm-1", + Relevance: 0.2, + LastUsed: time.Now().Add(-30 * 24 * time.Hour), + } + store.mu.Unlock() + + store.DecayRelevance() + store.mu.RLock() + defer store.mu.RUnlock() + if store.memories[MemoryTypeResource]["m1"].Relevance != 0.1 { + t.Fatalf("expected relevance to decay to minimum, got %.2f", store.memories[MemoryTypeResource]["m1"].Relevance) + } +} + +func TestCleanup_RemovesOldAndTrims(t *testing.T) { + store := NewContextStore(ContextStoreConfig{ + MaxMemoriesPerType: 1, + RetentionDays: 1, + }) + + store.mu.Lock() + store.memories[MemoryTypeResource]["old"] = &Memory{ + ID: "old", + Type: MemoryTypeResource, + CreatedAt: time.Now().Add(-48 * time.Hour), + Relevance: 0.9, + } + store.memories[MemoryTypeResource]["low"] = &Memory{ + ID: "low", + Type: MemoryTypeResource, + CreatedAt: time.Now(), + Relevance: 0.05, + } + 
store.memories[MemoryTypeResource]["new"] = &Memory{ + ID: "new", + Type: MemoryTypeResource, + CreatedAt: time.Now(), + Relevance: 0.8, + } + store.mu.Unlock() + + removed := store.Cleanup() + if removed == 0 { + t.Fatalf("expected cleanup to remove memories") + } + + store.mu.RLock() + defer store.mu.RUnlock() + if len(store.memories[MemoryTypeResource]) != 1 { + t.Fatalf("expected memories trimmed to max") + } +} + +func TestFormatForPatrolAndResource(t *testing.T) { + store := NewContextStore(ContextStoreConfig{}) + store.AddResourceNote("vm-1", "vm-one", "vm", "note-1") + store.AddPatternMemory(&PatternMemory{ + Pattern: "pattern", + Description: "pattern desc", + Occurrences: 5, + }) + store.AddIncidentMemory(&IncidentMemory{ + ResourceID: "vm-1", + Summary: "disk full", + }) + + resource := store.GetResourceMemory("vm-1") + resource.Patterns = []string{"daily spike"} + store.mu.Lock() + store.resourceMemories["vm-1"] = resource + store.mu.Unlock() + + context := store.FormatForPatrol() + if context == "" || !containsStr(context, "Resource Notes") || !containsStr(context, "Recent Incidents") { + t.Fatalf("expected patrol context to include sections") + } + + resourceContext := store.FormatForResource("vm-1") + if resourceContext == "" || !containsStr(resourceContext, "Observed patterns") { + t.Fatalf("expected resource context to include patterns") + } +} + +func TestContextStore_SaveLoad(t *testing.T) { + dir := t.TempDir() + store := NewContextStore(ContextStoreConfig{DataDir: dir}) + + store.Remember("vm-1", "note-1", "user", MemoryTypeResource) + store.AddPatternMemory(&PatternMemory{ + Pattern: "pattern", + Description: "pattern desc", + Occurrences: 3, + }) + if err := store.ForceSave(); err != nil { + t.Fatalf("force save failed: %v", err) + } + + loaded := NewContextStore(ContextStoreConfig{DataDir: dir}) + if loaded == nil { + t.Fatalf("expected store to load from disk") + } + if len(loaded.memories[MemoryTypeResource]) == 0 { + t.Fatalf("expected memories to load") + } + if len(loaded.patternMemories) == 0 { + t.Fatalf("expected patterns to load") + } + + // Validate saved file exists and is JSON. 
+ data, err := os.ReadFile(filepath.Join(dir, "ai_context.json")) + if err != nil { + t.Fatalf("expected context file to exist: %v", err) + } + var decoded map[string]interface{} + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("expected valid json, got %v", err) + } +} diff --git a/internal/ai/patrol.go b/internal/ai/patrol.go index b1db63e35..2a4a854fc 100644 --- a/internal/ai/patrol.go +++ b/internal/ai/patrol.go @@ -17,6 +17,7 @@ import ( "github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge" "github.com/rcourtman/pulse-go-rewrite/internal/ai/memory" "github.com/rcourtman/pulse-go-rewrite/internal/ai/remediation" + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" "github.com/rcourtman/pulse-go-rewrite/internal/alerts" "github.com/rcourtman/pulse-go-rewrite/internal/models" "github.com/rs/zerolog/log" @@ -341,6 +342,7 @@ type PatrolService struct { config PatrolConfig findings *FindingsStore knowledgeStore *knowledge.Store // For per-resource notes in patrol context + discoveryStore *aidiscovery.Store // For AI-discovered infrastructure context metricsHistory MetricsHistoryProvider // For trend analysis and predictions baselineStore *baseline.Store // For anomaly detection via learned baselines changeDetector *ChangeDetector // For tracking infrastructure changes @@ -895,6 +897,22 @@ func (p *PatrolService) GetKnowledgeStore() *knowledge.Store { return p.knowledgeStore } +// SetDiscoveryStore sets the AI discovery store for infrastructure context +// This enables the patrol service to include discovered service info in prompts +func (p *PatrolService) SetDiscoveryStore(store *aidiscovery.Store) { + p.mu.Lock() + defer p.mu.Unlock() + p.discoveryStore = store + log.Info().Msg("AI Patrol: Discovery store set for infrastructure context") +} + +// GetDiscoveryStore returns the discovery store for external access +func (p *PatrolService) GetDiscoveryStore() *aidiscovery.Store { + p.mu.RLock() + defer p.mu.RUnlock() + return p.discoveryStore +} + // SetMetricsHistoryProvider sets the metrics history provider for enriched context // This enables the patrol service to compute trends and predictions based on historical data func (p *PatrolService) SetMetricsHistoryProvider(provider MetricsHistoryProvider) { @@ -3218,7 +3236,11 @@ func (p *PatrolService) GetFindingsHistory(startTime *time.Time) []*Finding { // The deep parameter is kept for API backwards compatibility but is ignored // Uses context.Background() since this runs async after the HTTP response func (p *PatrolService) ForcePatrol(ctx context.Context, deep bool) { - go p.runPatrol(context.Background()) + runCtx := context.Background() + if ctx != nil { + runCtx = context.WithoutCancel(ctx) + } + go p.runPatrol(runCtx) } // analyzePBSInstance checks a PBS backup server for issues @@ -4471,17 +4493,27 @@ func (p *PatrolService) buildPatrolPrompt(summary string) string { // Get resource notes from knowledge store (per-resource user notes) var knowledgeContext string + var infraContext string var incidentContext string + var discoveryContext string p.mu.RLock() knowledgeStore := p.knowledgeStore incidentStore := p.incidentStore + discoveryStore := p.discoveryStore p.mu.RUnlock() if knowledgeStore != nil { knowledgeContext = knowledgeStore.FormatAllForContext() + infraContext = knowledgeStore.GetInfrastructureContext() } if incidentStore != nil { incidentContext = incidentStore.FormatForPatrol(8) } + // Get AI discovery context (deep-scanned service info) + if discoveryStore != nil { + if discoveries, 
err := discoveryStore.List(); err == nil && len(discoveries) > 0 { + discoveryContext = aidiscovery.FormatForAIContext(discoveries) + } + } basePrompt := fmt.Sprintf(`Please perform a comprehensive analysis of the following infrastructure and identify any issues, potential problems, or optimization opportunities. @@ -4512,6 +4544,24 @@ If everything looks healthy with stable trends, say so briefly.`, summary) contextAdditions.WriteString("\nIMPORTANT: Consider the user's saved notes above when analyzing. If a user has noted that a resource behaves a certain way (e.g., 'runs hot for transcoding'), do not flag it as an issue.\n") } + // Append infrastructure discovery context (auto-discovered apps and services) + if infraContext != "" { + contextAdditions.WriteString("\n\n") + contextAdditions.WriteString(infraContext) + contextAdditions.WriteString(` +IMPORTANT: When proposing remediation commands, use the CLI access method shown above. +- If a service runs in Docker, use 'docker exec <container> <command>' instead of direct commands +- Example: For PBS in Docker, use 'docker exec pbs proxmox-backup-manager gc pbs-delly' not 'proxmox-backup-manager gc pbs-delly' +- This ensures commands execute in the correct environment where the service actually runs. +`) + } + + // Append deep AI discovery context (service details, versions, config paths, ports) + if discoveryContext != "" { + contextAdditions.WriteString("\n\n") + contextAdditions.WriteString(discoveryContext) + } + // Append user feedback context (dismissed/snoozed findings) if feedbackContext != "" { contextAdditions.WriteString("\n\n") @@ -4874,6 +4924,8 @@ func (p *PatrolService) parseFindingBlock(block string) *Finding { cat = FindingCategorySecurity case "capacity": cat = FindingCategoryCapacity + case "backup": + cat = FindingCategoryBackup case "configuration": cat = FindingCategoryGeneral // Configuration maps to General default: diff --git a/internal/ai/patrol_test.go b/internal/ai/patrol_test.go index 55b297ce7..83b4cf952 100644 --- a/internal/ai/patrol_test.go +++ b/internal/ai/patrol_test.go @@ -746,6 +746,33 @@ EVIDENCE: Usage: 90% } } +func TestPatrolService_ParseFindingBlock_BackupCategory(t *testing.T) { + ps := NewPatrolService(nil, nil) + + block := ` +SEVERITY: warning +CATEGORY: backup +RESOURCE: vm-101 +RESOURCE_TYPE: vm +TITLE: Backup stale +DESCRIPTION: No backup in 48 hours +RECOMMENDATION: Check backup jobs +EVIDENCE: Last backup: 2 days ago +` + + finding := ps.parseFindingBlock(block) + + if finding == nil { + t.Fatal("Expected non-nil finding") + } + if finding.Category != FindingCategoryBackup { + t.Errorf("Expected category backup, got %v", finding.Category) + } + if finding.Title != "Backup stale" { + t.Errorf("Expected title 'Backup stale', got '%s'", finding.Title) + } +} + func TestPatrolService_ParseFindingBlock_MissingRequiredFields(t *testing.T) { ps := NewPatrolService(nil, nil) @@ -969,7 +996,7 @@ func TestJoinParts(t *testing.T) { {[]string{}, ""}, {[]string{"one"}, "one"}, {[]string{"one", "two"}, "one and two"}, - {[]string{"one", "two", "three"}, "[one two], and three"}, + {[]string{"one", "two", "three"}, "one, two, and three"}, } for _, tt := range tests { diff --git a/internal/ai/proxmox/events_additional_test.go b/internal/ai/proxmox/events_additional_test.go new file mode 100644 index 000000000..a15a05e9d --- /dev/null +++ b/internal/ai/proxmox/events_additional_test.go @@ -0,0 +1,448 @@ +package proxmox + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewEventCorrelator_Defaults(t
*testing.T) { + c := NewEventCorrelator(EventCorrelatorConfig{}) + if c.config.CorrelationWindow != 15*time.Minute { + t.Fatalf("expected default correlation window") + } + if c.config.MaxEvents != 5000 { + t.Fatalf("expected default max events") + } + if c.config.MaxCorrelations != 1000 { + t.Fatalf("expected default max correlations") + } + if c.config.RetentionDays != 30 { + t.Fatalf("expected default retention days") + } +} + +func TestRecordEvent_EndOperationClosesWindow(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + start := ProxmoxEvent{ + ID: "start-1", + Type: EventMigrationStart, + ResourceID: "vm-1", + Node: "pve1", + TargetNode: "pve2", + Timestamp: time.Now().Add(-5 * time.Minute), + } + c.RecordEvent(start) + + if len(c.GetActiveOperations()) != 1 { + t.Fatalf("expected active operation") + } + + end := ProxmoxEvent{ + Type: EventMigrationEnd, + ResourceID: "vm-1", + Timestamp: time.Now(), + } + c.RecordEvent(end) + + if len(c.GetActiveOperations()) != 0 { + t.Fatalf("expected active operation to close on end event") + } +} + +func TestRecordAnomaly_OutsideWindow(t *testing.T) { + cfg := DefaultEventCorrelatorConfig() + cfg.CorrelationWindow = time.Minute + c := NewEventCorrelator(cfg) + + c.RecordEvent(ProxmoxEvent{ + Type: EventBackupStart, + ResourceID: "vm-1", + Timestamp: time.Now().Add(-10 * time.Minute), + }) + + anomaly := MetricAnomaly{ + ResourceID: "vm-1", + Metric: "cpu", + Timestamp: time.Now(), + } + if correlation := c.RecordAnomaly(anomaly); correlation != nil { + t.Fatalf("expected no correlation outside window") + } +} + +func TestGetRecentEventsWithLimit(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + now := time.Now() + c.RecordEvent(ProxmoxEvent{Type: EventVMStart, ResourceID: "vm-1", Timestamp: now.Add(-2 * time.Minute)}) + c.RecordEvent(ProxmoxEvent{Type: EventVMStop, ResourceID: "vm-1", Timestamp: now.Add(-1 * time.Minute)}) + + events := c.GetRecentEventsWithLimit(10*time.Minute, 1) + if len(events) != 1 { + t.Fatalf("expected 1 event") + } + if events[0].Type != EventVMStop { + t.Fatalf("expected most recent event") + } +} + +func TestGetRecentEventsWithLimit_NoLimitAndCutoff(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + now := time.Now() + c.RecordEvent(ProxmoxEvent{Type: EventVMStop, ResourceID: "vm-1", Timestamp: now.Add(-2 * time.Hour)}) + c.RecordEvent(ProxmoxEvent{Type: EventVMStart, ResourceID: "vm-1", Timestamp: now.Add(-2 * time.Minute)}) + c.RecordEvent(ProxmoxEvent{Type: EventVMStop, ResourceID: "vm-1", Timestamp: now.Add(-1 * time.Minute)}) + + events := c.GetRecentEventsWithLimit(30*time.Minute, 0) + if len(events) != 2 { + t.Fatalf("expected cutoff to exclude old events, got %d", len(events)) + } +} + +func TestGetEventsForResource_IncludesNodeAndStorage(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + c.RecordEvent(ProxmoxEvent{ + Type: EventBackupStart, + Node: "pve1", + Storage: "local", + }) + + if len(c.GetEventsForResource("pve1", 10)) != 1 { + t.Fatalf("expected node match to return event") + } + if len(c.GetEventsForResource("local", 10)) != 1 { + t.Fatalf("expected storage match to return event") + } +} + +func TestGetCorrelationsForResource(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + c.correlations = append(c.correlations, EventCorrelation{ + ID: "c1", + ImpactedResources: []string{"vm-1"}, + }) + if len(c.GetCorrelationsForResource("vm-1")) != 1 { + t.Fatalf("expected correlation 
for resource") + } +} + +func TestGetActiveOperations_Expires(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + c.activeWindows["w1"] = &OperationWindow{ + EventID: "w1", + EventType: EventBackupStart, + StartTime: time.Now().Add(-2 * time.Hour), + ExpectedEnd: time.Now().Add(-1 * time.Hour), + } + if len(c.GetActiveOperations()) != 0 { + t.Fatalf("expected expired active window to be filtered") + } +} + +func TestFormatForPatrol_IncludesActiveAndCorrelations(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + c.RecordEvent(ProxmoxEvent{ + Type: EventMigrationStart, + ResourceID: "vm-1", + Node: "pve1", + TargetNode: "pve2", + }) + c.RecordAnomaly(MetricAnomaly{ + ResourceID: "vm-1", + Metric: "cpu", + Timestamp: time.Now(), + }) + + context := c.FormatForPatrol(1 * time.Hour) + if !containsStr(context, "Currently Active Operations") { + t.Fatalf("expected active operations section") + } + if !containsStr(context, "Detected Correlations") { + t.Fatalf("expected correlations section") + } +} + +func TestFormatForResource_NoData(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + if c.FormatForResource("vm-1") != "" { + t.Fatalf("expected empty context when no data") + } +} + +func TestFormatForResource_WithCorrelation(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + c.RecordEvent(ProxmoxEvent{ + Type: EventVMStart, + ResourceID: "vm-1", + Status: "running", + }) + c.correlations = append(c.correlations, EventCorrelation{ + ID: "corr-1", + ImpactedResources: []string{"vm-1"}, + Explanation: "VM start spike explained", + CreatedAt: time.Now(), + }) + + context := c.FormatForResource("vm-1") + if !containsStr(context, "Correlated events") { + t.Fatalf("expected correlated events section") + } +} + +func TestTrimEventsAndCorrelations(t *testing.T) { + cfg := DefaultEventCorrelatorConfig() + cfg.MaxEvents = 2 + cfg.MaxCorrelations = 1 + cfg.RetentionDays = 1 + c := NewEventCorrelator(cfg) + + oldTime := time.Now().Add(-48 * time.Hour) + c.events = []ProxmoxEvent{ + {ID: "e1", Timestamp: oldTime}, + {ID: "e2", Timestamp: time.Now()}, + {ID: "e3", Timestamp: time.Now()}, + } + c.correlations = []EventCorrelation{ + {ID: "c1", CreatedAt: oldTime}, + {ID: "c2", CreatedAt: time.Now()}, + } + + c.trimEvents() + c.trimCorrelations() + + if len(c.events) != 2 { + t.Fatalf("expected events trimmed to max") + } + if len(c.correlations) != 1 { + t.Fatalf("expected correlations trimmed to max") + } +} + +func TestSaveLoad(t *testing.T) { + dir := t.TempDir() + c := NewEventCorrelator(EventCorrelatorConfig{}) + c.events = append(c.events, ProxmoxEvent{ + ID: "e1", + Type: EventBackupStart, + ResourceID: "vm-1", + Timestamp: time.Now(), + }) + c.correlations = append(c.correlations, EventCorrelation{ + ID: "c1", + Event: ProxmoxEvent{ID: "e1", Type: EventBackupStart, ResourceID: "vm-1"}, + ImpactedResources: []string{"vm-1"}, + CreatedAt: time.Now(), + }) + c.dataDir = dir + if err := c.saveToDisk(); err != nil { + t.Fatalf("save failed: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "proxmox_events.json")); err != nil { + t.Fatalf("expected file to exist: %v", err) + } + + loaded := NewEventCorrelator(EventCorrelatorConfig{DataDir: dir}) + if len(loaded.events) == 0 || len(loaded.correlations) == 0 { + t.Fatalf("expected data to load") + } +} + +func TestSaveLoad_NoDir(t *testing.T) { + c := NewEventCorrelator(EventCorrelatorConfig{}) + if err := c.saveToDisk(); err != nil { + t.Fatalf("expected 
saveToDisk to no-op without DataDir, got %v", err) + } + if err := c.loadFromDisk(); err != nil { + t.Fatalf("expected loadFromDisk to no-op without DataDir, got %v", err) + } +} + +func TestSaveIfDirty(t *testing.T) { + dir := t.TempDir() + c := NewEventCorrelator(EventCorrelatorConfig{DataDir: dir}) + c.events = append(c.events, ProxmoxEvent{ID: "e1", Timestamp: time.Now()}) + c.saveIfDirty() + + if _, err := os.Stat(filepath.Join(dir, "proxmox_events.json")); err != nil { + t.Fatalf("expected saveIfDirty to persist data: %v", err) + } +} + +func TestSaveIfDirty_Error(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + c := NewEventCorrelator(EventCorrelatorConfig{DataDir: filePath}) + c.events = append(c.events, ProxmoxEvent{ID: "e1", Timestamp: time.Now()}) + c.saveIfDirty() +} + +func TestSaveToDisk_Error(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + c := NewEventCorrelator(EventCorrelatorConfig{DataDir: filePath}) + c.events = append(c.events, ProxmoxEvent{ID: "e1", Timestamp: time.Now()}) + if err := c.saveToDisk(); err == nil { + t.Fatalf("expected saveToDisk to fail with invalid data dir") + } +} + +func TestLoadFromDisk_InvalidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "proxmox_events.json") + if err := os.WriteFile(path, []byte("{not-json"), 0600); err != nil { + t.Fatalf("failed to write invalid json: %v", err) + } + + c := NewEventCorrelator(EventCorrelatorConfig{DataDir: dir}) + if err := c.loadFromDisk(); err == nil { + t.Fatalf("expected loadFromDisk to fail on invalid json") + } +} + +func TestCreateOperationWindow_Types(t *testing.T) { + c := NewEventCorrelator(DefaultEventCorrelatorConfig()) + backup := c.createOperationWindow(ProxmoxEvent{Type: EventBackupStart, ResourceID: "vm-1", Storage: "local"}) + if len(backup.ExpectedMetrics) == 0 || len(backup.AffectedResources) == 0 { + t.Fatalf("expected backup operation window to be populated") + } + + snapshot := c.createOperationWindow(ProxmoxEvent{Type: EventSnapshotDelete, ResourceID: "vm-1", Storage: "local"}) + if len(snapshot.ExpectedMetrics) == 0 || len(snapshot.AffectedResources) == 0 { + t.Fatalf("expected snapshot operation window to be populated") + } + + ha := c.createOperationWindow(ProxmoxEvent{Type: EventHAFailover, ResourceID: "vm-1", Node: "pve1", TargetNode: "pve2"}) + if len(ha.ExpectedMetrics) == 0 || len(ha.AffectedResources) == 0 { + t.Fatalf("expected HA operation window to be populated") + } +} + +func TestHelperFunctions(t *testing.T) { + if !isOngoingOperation(EventBackupStart) || isOngoingOperation(EventBackupEnd) { + t.Fatalf("expected ongoing operation detection") + } + if !isEndOperation(EventMigrationEnd) || isEndOperation(EventMigrationStart) { + t.Fatalf("expected end operation detection") + } + + start := ProxmoxEvent{Type: EventMigrationStart, ResourceID: "vm-1"} + end := ProxmoxEvent{Type: EventMigrationEnd, ResourceID: "vm-1"} + if !matchesEndEvent(start, end) { + t.Fatalf("expected start/end match") + } + backupStart := ProxmoxEvent{Type: EventBackupStart, ResourceID: "vm-1"} + backupEnd := ProxmoxEvent{Type: EventBackupEnd, ResourceID: "vm-1"} + if !matchesEndEvent(backupStart, backupEnd) { + t.Fatalf("expected backup start/end match") + } + if 
matchesEndEvent(ProxmoxEvent{Type: EventSnapshotCreate, ResourceID: "vm-1"}, ProxmoxEvent{Type: EventSnapshotDelete, ResourceID: "vm-1"}) { + t.Fatalf("expected snapshot start/end to not match") + } + if matchesEndEvent(start, ProxmoxEvent{Type: EventMigrationEnd, ResourceID: "vm-2"}) { + t.Fatalf("expected mismatch for different resource") + } + + if estimateOperationDuration(EventBackupStart) != 2*time.Hour { + t.Fatalf("expected backup duration") + } + if estimateOperationDuration(ProxmoxEventType("custom")) != 15*time.Minute { + t.Fatalf("expected default duration") + } + if len(getExpectedMetrics(EventSnapshotCreate)) == 0 { + t.Fatalf("expected metrics for snapshot") + } + if len(getExpectedMetrics(EventHAFailover)) == 0 { + t.Fatalf("expected metrics for HA") + } + if len(getExpectedMetrics(ProxmoxEventType("custom"))) != 0 { + t.Fatalf("expected no metrics for unknown event type") + } + + if formatEventType(ProxmoxEventType("custom")) != "custom" { + t.Fatalf("expected fallback format") + } + + if !containsString([]string{"a", "b"}, "b") { + t.Fatalf("expected containsString to match") + } + if containsString([]string{"a", "b"}, "c") { + t.Fatalf("expected containsString to return false") + } + if minFloat(1.0, 2.0) != 1.0 { + t.Fatalf("expected minFloat to return min") + } +} + +func TestFormatEventType_All(t *testing.T) { + types := []ProxmoxEventType{ + EventMigrationStart, + EventMigrationEnd, + EventBackupStart, + EventBackupEnd, + EventSnapshotCreate, + EventSnapshotDelete, + EventHAFailover, + EventHAMigration, + EventClusterJoin, + EventClusterLeave, + EventStorageOnline, + EventStorageOffline, + EventNodeReboot, + EventVMCreate, + EventVMDestroy, + EventVMStart, + EventVMStop, + } + + for _, eventType := range types { + if formatEventType(eventType) == "" { + t.Fatalf("expected format for %s", eventType) + } + } +} + +func TestGenerateExplanation_Cases(t *testing.T) { + event := ProxmoxEvent{Type: EventBackupStart, ResourceID: "vm-1"} + anomaly := MetricAnomaly{ResourceID: "vm-1", Metric: "io"} + if generateExplanation(event, anomaly) == "" { + t.Fatalf("expected explanation for backup") + } + + event = ProxmoxEvent{Type: EventSnapshotCreate, ResourceID: "vm-1"} + if generateExplanation(event, anomaly) == "" { + t.Fatalf("expected explanation for snapshot") + } + + event = ProxmoxEvent{Type: EventHAFailover, ResourceID: "vm-1"} + if generateExplanation(event, anomaly) == "" { + t.Fatalf("expected explanation for HA failover") + } + + event = ProxmoxEvent{Type: ProxmoxEventType("custom"), ResourceID: "vm-1"} + if generateExplanation(event, anomaly) == "" { + t.Fatalf("expected default explanation") + } +} + +func TestSortEventsByTimestamp_Additional(t *testing.T) { + older := time.Now().Add(-2 * time.Hour) + newer := time.Now().Add(-1 * time.Hour) + events := []ProxmoxEvent{ + {ID: "old", Timestamp: older}, + {ID: "new", Timestamp: newer}, + } + + SortEventsByTimestamp(events) + if events[0].ID != "new" { + t.Fatalf("expected newest event first") + } +} diff --git a/internal/ai/remediation/engine_additional_test.go b/internal/ai/remediation/engine_additional_test.go new file mode 100644 index 000000000..fc89c35d6 --- /dev/null +++ b/internal/ai/remediation/engine_additional_test.go @@ -0,0 +1,486 @@ +package remediation + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewEngine_Defaults(t *testing.T) { + engine := NewEngine(EngineConfig{}) + if engine.config.MaxExecutions != 100 { + t.Fatalf("expected MaxExecutions default, got 
%d", engine.config.MaxExecutions) + } + if engine.config.PlanExpiry != 24*time.Hour { + t.Fatalf("expected PlanExpiry default, got %s", engine.config.PlanExpiry) + } + if engine.config.ExecutionTimeout != 5*time.Minute { + t.Fatalf("expected ExecutionTimeout default, got %s", engine.config.ExecutionTimeout) + } +} + +func TestEngine_ValidatePlan_TitleRequired(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + err := engine.CreatePlan(&RemediationPlan{ + Steps: []RemediationStep{{Command: "echo ok"}}, + }) + if err == nil { + t.Fatalf("expected error for missing title") + } +} + +func TestEngine_IsBlockedCommand_CaseInsensitive(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if !engine.isBlockedCommand("RM -RF /tmp") { + t.Fatalf("expected blocked command detection") + } + if engine.isBlockedCommand("") { + t.Fatalf("expected empty command to be allowed") + } +} + +func TestEngine_AssessRiskAndCategorize(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + + plan := &RemediationPlan{ + Title: "Delete data", + Steps: []RemediationStep{{Command: "delete something"}}, + } + if engine.assessRiskLevel(plan) != RiskHigh { + t.Fatalf("expected high risk") + } + plan.RiskLevel = RiskHigh + if engine.categorize(plan) != CategoryGuided { + t.Fatalf("expected guided for high risk") + } + + plan = &RemediationPlan{ + Title: "Info only", + Steps: []RemediationStep{{Description: "observe"}}, + } + if engine.categorize(plan) != CategoryInformational { + t.Fatalf("expected informational for no commands") + } + + plan = &RemediationPlan{ + Title: "Low risk", + Steps: []RemediationStep{{Command: "echo ok"}}, + RiskLevel: RiskLow, + } + if engine.categorize(plan) != CategoryOneClick { + t.Fatalf("expected one-click for low risk small plan") + } +} + +func TestEngine_ListPlans_SkipsExpiredAndOrders(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + + now := time.Now() + expiredAt := now.Add(-1 * time.Hour) + + planOld := &RemediationPlan{ + Title: "old", + Steps: []RemediationStep{{Command: "echo old"}}, + CreatedAt: now.Add(-2 * time.Hour), + } + planNew := &RemediationPlan{ + Title: "new", + Steps: []RemediationStep{{Command: "echo new"}}, + CreatedAt: now.Add(-1 * time.Hour), + } + planExpired := &RemediationPlan{ + Title: "expired", + Steps: []RemediationStep{{Command: "echo expired"}}, + CreatedAt: now, + ExpiresAt: &expiredAt, + } + + if err := engine.CreatePlan(planOld); err != nil { + t.Fatalf("create plan old failed: %v", err) + } + if err := engine.CreatePlan(planNew); err != nil { + t.Fatalf("create plan new failed: %v", err) + } + if err := engine.CreatePlan(planExpired); err != nil { + t.Fatalf("create plan expired failed: %v", err) + } + + plans := engine.ListPlans(10) + if len(plans) != 2 { + t.Fatalf("expected 2 non-expired plans, got %d", len(plans)) + } + if plans[0].Title != "new" { + t.Fatalf("expected newest plan first") + } + + if len(engine.ListPlans(0)) == 0 { + t.Fatalf("expected ListPlans to return results with default limit") + } +} + +func TestEngine_GetPlanForFinding_SkipsExpired(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + expiredAt := time.Now().Add(-1 * time.Hour) + + plan := &RemediationPlan{ + FindingID: "finding-1", + Title: "expired plan", + Steps: []RemediationStep{{Command: "echo test"}}, + ExpiresAt: &expiredAt, + } + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("create plan failed: %v", err) + } + if engine.GetPlanForFinding("finding-1") != nil { + t.Fatalf("expected expired plan to be 
skipped") + } +} + +func TestEngine_GetLatestExecutionForPlan(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + planID := "plan-1" + + older := time.Now().Add(-2 * time.Hour) + newer := time.Now().Add(-1 * time.Hour) + + engine.executions["e1"] = &RemediationExecution{ + ID: "e1", + PlanID: planID, + CompletedAt: &older, + } + engine.executions["e2"] = &RemediationExecution{ + ID: "e2", + PlanID: planID, + ApprovedAt: &newer, + } + + latest := engine.GetLatestExecutionForPlan(planID) + if latest == nil || latest.ID != "e2" { + t.Fatalf("expected latest execution to be e2") + } +} + +func TestEngine_GetLatestExecutionForPlan_None(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if engine.GetLatestExecutionForPlan("missing") != nil { + t.Fatalf("expected nil when no executions exist") + } +} + +func TestEngine_Execute_Errors(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + + if err := engine.Execute(context.Background(), "missing"); err == nil { + t.Fatalf("expected error for missing execution") + } + + engine.executions["e1"] = &RemediationExecution{ID: "e1", Status: StatusRunning} + if err := engine.Execute(context.Background(), "e1"); err == nil { + t.Fatalf("expected error for non-approved execution") + } + + engine.executions["e2"] = &RemediationExecution{ID: "e2", Status: StatusApproved, PlanID: "missing-plan"} + if err := engine.Execute(context.Background(), "e2"); err == nil { + t.Fatalf("expected error for missing plan") + } + + plan := &RemediationPlan{Title: "p", Steps: []RemediationStep{{Command: "echo"}}} + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("create plan failed: %v", err) + } + exec := &RemediationExecution{ID: "e3", Status: StatusApproved, PlanID: plan.ID} + engine.executions["e3"] = exec + if err := engine.Execute(context.Background(), "e3"); err == nil { + t.Fatalf("expected error when executor is missing") + } +} + +func TestEngine_Rollback_SuccessAndError(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + executor := newMockExecutor() + executor.results["rollback-1"] = struct { + output string + err error + }{output: "ok", err: nil} + executor.results["rollback-2"] = struct { + output string + err error + }{output: "ok", err: nil} + engine.SetCommandExecutor(executor) + + plan := &RemediationPlan{ + Title: "with rollback", + Steps: []RemediationStep{ + {Order: 1, Command: "cmd-1", Rollback: "rollback-1"}, + {Order: 2, Command: "cmd-2", Rollback: "rollback-2"}, + }, + } + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("create plan failed: %v", err) + } + exec, _ := engine.ApprovePlan(plan.ID, "admin") + if err := engine.Execute(context.Background(), exec.ID); err != nil { + t.Fatalf("execute failed: %v", err) + } + + if err := engine.Rollback(context.Background(), exec.ID); err != nil { + t.Fatalf("rollback failed: %v", err) + } + updated := engine.GetExecution(exec.ID) + if updated.Status != StatusRolledBack { + t.Fatalf("expected rolled back status") + } +} + +func TestEngine_Rollback_Error(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + executor := newMockExecutor() + executor.results["rollback-bad"] = struct { + output string + err error + }{output: "", err: errors.New("rollback failed")} + engine.SetCommandExecutor(executor) + + plan := &RemediationPlan{ + Title: "rollback error", + Steps: []RemediationStep{ + {Order: 1, Command: "cmd-1", Rollback: "rollback-bad"}, + }, + } + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("create plan failed: %v", err) + } 
+ exec, _ := engine.ApprovePlan(plan.ID, "admin") + if err := engine.Execute(context.Background(), exec.ID); err != nil { + t.Fatalf("execute failed: %v", err) + } + + if err := engine.Rollback(context.Background(), exec.ID); err == nil { + t.Fatalf("expected rollback error") + } + updated := engine.GetExecution(exec.ID) + if updated.RollbackError == "" { + t.Fatalf("expected rollback error to be recorded") + } +} + +func TestEngine_ListExecutions_Additional(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + old := time.Now().Add(-2 * time.Hour) + newer := time.Now().Add(-1 * time.Hour) + + engine.executions["e1"] = &RemediationExecution{ID: "e1", ApprovedAt: &old} + engine.executions["e2"] = &RemediationExecution{ID: "e2", ApprovedAt: &newer} + + execs := engine.ListExecutions(1) + if len(execs) != 1 { + t.Fatalf("expected limit to apply") + } + if execs[0].ID != "e2" { + t.Fatalf("expected most recent execution first") + } +} + +func TestEngine_AddApprovalRuleAndAutoApprove(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + engine.AddApprovalRule(&ApprovalRule{ + Category: CategoryOneClick, + MaxRiskLevel: RiskMedium, + Enabled: true, + }) + + plan := &RemediationPlan{ + Title: "plan", + Steps: []RemediationStep{{Command: "echo"}}, + Category: CategoryOneClick, + RiskLevel: RiskLow, + } + + if !engine.IsAutoApproved(plan) { + t.Fatalf("expected plan to be auto-approved") + } + + engine.AddApprovalRule(&ApprovalRule{ + Category: CategoryGuided, + MaxRiskLevel: RiskLow, + Enabled: false, + }) + if !engine.IsAutoApproved(plan) { + t.Fatalf("disabled rule should not block auto-approval") + } + + plan.RiskLevel = RiskHigh + if engine.IsAutoApproved(plan) { + t.Fatalf("expected high risk plan to be blocked by max risk") + } +} + +func TestEngine_FormatPlanForContext_Additional(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if engine.FormatPlanForContext(nil) != "" { + t.Fatalf("expected empty format for nil plan") + } + + plan := &RemediationPlan{ + Title: "Fix issue", + Description: "do the thing", + Category: CategoryGuided, + RiskLevel: RiskMedium, + Prerequisites: []string{ + "backup data", + }, + Warnings: []string{"service restart"}, + Steps: []RemediationStep{ + {Order: 1, Description: "step one", Command: "cmd-1", Rollback: "rb-1"}, + }, + } + + formatted := engine.FormatPlanForContext(plan) + if formatted == "" { + t.Fatalf("expected formatted plan") + } + if !contains(formatted, "Prerequisites") || !contains(formatted, "Warnings") { + t.Fatalf("expected prerequisites and warnings sections") + } + if !contains(formatted, "Rollback") { + t.Fatalf("expected rollback details") + } +} + +func TestEngine_SaveLoad(t *testing.T) { + dir := t.TempDir() + engine := NewEngine(EngineConfig{}) + engine.plans["plan-1"] = &RemediationPlan{ + ID: "plan-1", + Title: "save plan", + Steps: []RemediationStep{{Command: "echo ok"}}, + } + engine.executions["exec-1"] = &RemediationExecution{ + ID: "exec-1", + PlanID: "plan-1", + Status: StatusApproved, + } + engine.approvalRules["rule-1"] = &ApprovalRule{ + ID: "rule-1", + Category: CategoryGuided, + MaxRiskLevel: RiskMedium, + Enabled: true, + } + engine.dataDir = dir + + if err := engine.saveToDisk(); err != nil { + t.Fatalf("save failed: %v", err) + } + + if _, err := os.Stat(filepath.Join(dir, "remediation.json")); err != nil { + t.Fatalf("expected remediation.json to exist: %v", err) + } + + loaded := NewEngine(EngineConfig{DataDir: dir}) + if len(loaded.plans) == 0 || len(loaded.executions) == 0 || 
len(loaded.approvalRules) == 0 { + t.Fatalf("expected data to load from disk") + } +} + +func TestTruncateString(t *testing.T) { + short := "abc" + if got := truncateString(short, 10); got != short { + t.Fatalf("expected short string to remain unchanged") + } + + long := "this is a long string" + got := truncateString(long, 4) + if got != "this..." { + t.Fatalf("expected truncation with ellipsis, got %q", got) + } +} + +func TestEngine_Rollback_Errors(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if err := engine.Rollback(context.Background(), "missing"); err == nil { + t.Fatalf("expected error for missing execution") + } + + engine.executions["e1"] = &RemediationExecution{ID: "e1", PlanID: "missing-plan"} + if err := engine.Rollback(context.Background(), "e1"); err == nil { + t.Fatalf("expected error for missing plan") + } + + plan := &RemediationPlan{Title: "p", Steps: []RemediationStep{{Command: "echo"}}} + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("create plan failed: %v", err) + } + engine.executions["e2"] = &RemediationExecution{ID: "e2", PlanID: plan.ID} + if err := engine.Rollback(context.Background(), "e2"); err == nil { + t.Fatalf("expected error for missing executor") + } +} + +func TestEngine_SaveToDisk_NoDir(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if err := engine.saveToDisk(); err != nil { + t.Fatalf("expected saveToDisk to no-op without DataDir, got %v", err) + } +} + +func TestEngine_SaveIfDirty(t *testing.T) { + dir := t.TempDir() + engine := NewEngine(EngineConfig{DataDir: dir}) + engine.plans["p1"] = &RemediationPlan{ID: "p1", Title: "t", Steps: []RemediationStep{{Command: "echo"}}} + engine.saveIfDirty() + + if _, err := os.Stat(filepath.Join(dir, "remediation.json")); err != nil { + t.Fatalf("expected saveIfDirty to persist data: %v", err) + } +} + +func TestEngine_SaveIfDirty_Error(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + engine := NewEngine(EngineConfig{DataDir: filePath}) + engine.plans["p1"] = &RemediationPlan{ID: "p1", Title: "t", Steps: []RemediationStep{{Command: "echo"}}} + engine.saveIfDirty() +} + +func TestEngine_GetExecution_NotFound(t *testing.T) { + engine := NewEngine(DefaultEngineConfig()) + if engine.GetExecution("missing") != nil { + t.Fatalf("expected nil for missing execution") + } +} + +func TestRiskValue_Additional(t *testing.T) { + cases := map[RiskLevel]int{ + RiskLow: 1, + RiskMedium: 2, + RiskHigh: 3, + RiskCritical: 4, + RiskLevel("unknown"): 0, + } + for level, expected := range cases { + if got := riskValue(level); got != expected { + t.Fatalf("expected %d for %s, got %d", expected, level, got) + } + } +} + +func TestEngine_SaveToDisk_Error(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "not-a-dir") + if err := os.WriteFile(filePath, []byte("x"), 0600); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + engine := NewEngine(EngineConfig{DataDir: filePath}) + engine.plans["p1"] = &RemediationPlan{ID: "p1", Title: "t", Steps: []RemediationStep{{Command: "echo"}}} + if err := engine.saveToDisk(); err == nil { + t.Fatalf("expected saveToDisk to fail with invalid data dir") + } +} diff --git a/internal/ai/service.go b/internal/ai/service.go index b1de53ebe..5a730f655 100644 --- a/internal/ai/service.go +++ b/internal/ai/service.go @@ -23,7 +23,9 @@ import ( 
"github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge" "github.com/rcourtman/pulse-go-rewrite/internal/ai/memory" "github.com/rcourtman/pulse-go-rewrite/internal/ai/providers" + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/infradiscovery" "github.com/rcourtman/pulse-go-rewrite/internal/license" "github.com/rcourtman/pulse-go-rewrite/internal/models" "github.com/rcourtman/pulse-go-rewrite/internal/types" @@ -119,6 +121,15 @@ type Service struct { incidentStore *memory.IncidentStore // Incident timelines for alert memory chatService ChatServiceProvider // Chat service for investigation orchestrator + // Infrastructure discovery service - detects apps running on hosts + infraDiscoveryService *infradiscovery.Service + + // AI-powered deep discovery store - detailed service analysis with commands + aiDiscoveryStore *aidiscovery.Store + + // AI-powered deep discovery service - runs commands and AI analysis + aiDiscoveryService *aidiscovery.Service + // Alert-triggered analysis - token-efficient real-time AI insights alertTriggeredAnalyzer *AlertTriggeredAnalyzer @@ -147,6 +158,7 @@ type modelsCache struct { func NewService(persistence *config.ConfigPersistence, agentServer AgentServer) *Service { // Initialize knowledge store var knowledgeStore *knowledge.Store + var aiDiscoveryStore *aidiscovery.Store costStore := cost.NewStore(cost.DefaultMaxDays) if persistence != nil { var err error @@ -157,14 +169,20 @@ func NewService(persistence *config.ConfigPersistence, agentServer AgentServer) if err := costStore.SetPersistence(NewCostPersistenceAdapter(persistence)); err != nil { log.Warn().Err(err).Msg("Failed to initialize AI usage cost store") } + // Initialize AI discovery store for deep infrastructure discovery + aiDiscoveryStore, err = aidiscovery.NewStore(persistence.DataDir()) + if err != nil { + log.Warn().Err(err).Msg("Failed to initialize AI discovery store") + } } return &Service{ - persistence: persistence, - agentServer: agentServer, - policy: agentexec.DefaultPolicy(), - knowledgeStore: knowledgeStore, - costStore: costStore, + persistence: persistence, + agentServer: agentServer, + policy: agentexec.DefaultPolicy(), + knowledgeStore: knowledgeStore, + aiDiscoveryStore: aiDiscoveryStore, + costStore: costStore, limits: executionLimits{ chatSlots: make(chan struct{}, 4), patrolSlots: make(chan struct{}, 1), @@ -242,6 +260,65 @@ func (s *Service) SetStateProvider(sp StateProvider) { if s.incidentStore != nil { s.patrolService.SetIncidentStore(s.incidentStore) } + // Connect AI discovery store for deep infrastructure context + if s.aiDiscoveryStore != nil { + s.patrolService.SetDiscoveryStore(s.aiDiscoveryStore) + } + } + + // Initialize infrastructure discovery service if not already done + // This uses AI to detect applications running in Docker containers + // and saves discoveries to the knowledge store for Patrol to use when proposing commands + if s.infraDiscoveryService == nil && sp != nil && s.knowledgeStore != nil { + s.infraDiscoveryService = infradiscovery.NewService( + sp, + s.knowledgeStore, + infradiscovery.DefaultConfig(), + ) + // Wire the AI service as the analyzer (implements infradiscovery.AIAnalyzer) + s.infraDiscoveryService.SetAIAnalyzer(s) + s.infraDiscoveryService.Start(context.Background()) + log.Info().Msg("AI-powered infrastructure discovery service started") + } + + // Initialize AI-powered deep discovery service if not already done + 
// This runs read-only commands on resources and uses AI to understand services + if s.aiDiscoveryService == nil && sp != nil && s.aiDiscoveryStore != nil { + // Create command executor adapter (wraps agentexec.Server) + var cmdExecutor aidiscovery.CommandExecutor + if agentSrv, ok := s.agentServer.(*agentexec.Server); ok { + cmdExecutor = newDiscoveryCommandAdapter(agentSrv) + } + + // Create state adapter + stateAdapter := newDiscoveryStateAdapter(sp) + + // Create deep scanner + scanner := aidiscovery.NewDeepScanner(cmdExecutor) + + // Create the discovery service with config-driven settings + discoveryCfg := aidiscovery.DefaultConfig() + if s.cfg != nil { + discoveryCfg.Interval = s.cfg.GetDiscoveryInterval() + } + + s.aiDiscoveryService = aidiscovery.NewService( + s.aiDiscoveryStore, + scanner, + stateAdapter, + discoveryCfg, + ) + s.aiDiscoveryService.SetAIAnalyzer(s) + + // Start background discovery if enabled and interval is set + if s.cfg != nil && s.cfg.IsDiscoveryEnabled() && s.cfg.GetDiscoveryInterval() > 0 { + s.aiDiscoveryService.Start(context.Background()) + log.Info(). + Dur("interval", s.cfg.GetDiscoveryInterval()). + Msg("AI-powered deep discovery service started with automatic scanning") + } else { + log.Info().Msg("AI-powered deep discovery service initialized (manual mode)") + } } // Initialize alert-triggered analyzer if not already done @@ -438,6 +515,59 @@ func (s *Service) SetIncidentStore(store *memory.IncidentStore) { } } +// SetAIDiscoveryStore sets the AI discovery store for infrastructure context +func (s *Service) SetAIDiscoveryStore(store *aidiscovery.Store) { + s.mu.Lock() + defer s.mu.Unlock() + s.aiDiscoveryStore = store + + if s.patrolService != nil { + s.patrolService.SetDiscoveryStore(store) + } + log.Info().Msg("AI Service: AI discovery store set for infrastructure context") +} + +// GetAIDiscoveryStore returns the AI discovery store +func (s *Service) GetAIDiscoveryStore() *aidiscovery.Store { + s.mu.RLock() + defer s.mu.RUnlock() + return s.aiDiscoveryStore +} + +// GetAIDiscoveryService returns the AI discovery service for triggering scans +func (s *Service) GetAIDiscoveryService() *aidiscovery.Service { + s.mu.RLock() + defer s.mu.RUnlock() + return s.aiDiscoveryService +} + +// updateDiscoverySettings updates the discovery service based on config changes +// Note: caller must NOT hold s.mu lock +func (s *Service) updateDiscoverySettings(cfg *config.AIConfig) { + if s.aiDiscoveryService == nil || cfg == nil { + return + } + + enabled := cfg.IsDiscoveryEnabled() + interval := cfg.GetDiscoveryInterval() + + if enabled && interval > 0 { + // Update interval and ensure service is running + s.aiDiscoveryService.SetInterval(interval) + s.aiDiscoveryService.Start(context.Background()) + log.Info(). + Bool("enabled", enabled). + Dur("interval", interval). + Msg("Discovery service updated: automatic scanning enabled") + } else { + // Stop background scanning (manual mode) + s.aiDiscoveryService.Stop() + log.Info(). + Bool("enabled", enabled). + Msg("Discovery service updated: manual mode (background scanning stopped)") + } +} + // SetPatternDetector sets the pattern detector for failure prediction func (s *Service) SetPatternDetector(detector *PatternDetector) { s.mu.RLock() @@ -998,6 +1128,9 @@ func (s *Service) LoadConfig() error { Bool("autonomous", cfg.IsAutonomous()). 
Msg("AI service initialized") + // Update discovery service settings based on config + s.updateDiscoverySettings(cfg) + return nil } @@ -2672,6 +2805,67 @@ func (s *Service) DeleteGuestNote(guestID, noteID string) error { return s.knowledgeStore.DeleteNote(guestID, noteID) } +// GetKnowledgeStore returns the knowledge store for external use +// This is used by components like the investigation orchestrator to get +// infrastructure context for proposing correct CLI commands +func (s *Service) GetKnowledgeStore() *knowledge.Store { + s.mu.RLock() + defer s.mu.RUnlock() + return s.knowledgeStore +} + +// AnalyzeForDiscovery implements the infradiscovery.AIAnalyzer interface. +// It sends a prompt to the AI using the discovery model (optimized for cost). +func (s *Service) AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) { + s.mu.RLock() + provider := s.provider + cfg := s.cfg + costStore := s.costStore + s.mu.RUnlock() + + if provider == nil { + return "", fmt.Errorf("AI provider not configured") + } + + if cfg == nil || !cfg.Enabled { + return "", fmt.Errorf("AI is not enabled") + } + + // Get the discovery model (defaults to cheap/fast model) + model := cfg.GetDiscoveryModel() + + // Build simple message for discovery (no tools needed) + messages := []providers.Message{ + { + Role: "user", + Content: prompt, + }, + } + + // Make the API call + resp, err := provider.Chat(ctx, providers.ChatRequest{ + Messages: messages, + Model: model, + MaxTokens: 4096, // Discovery responses need room for detailed JSON + }) + if err != nil { + return "", fmt.Errorf("discovery analysis failed: %w", err) + } + + // Track cost if cost store is available + if costStore != nil { + costStore.Record(cost.UsageEvent{ + Provider: provider.Name(), + RequestModel: model, + UseCase: "discovery", + InputTokens: resp.InputTokens, + OutputTokens: resp.OutputTokens, + }) + } + + return resp.Content, nil +} + // fetchURL fetches content from a URL with size limits and timeout func (s *Service) fetchURL(ctx context.Context, urlStr string) (string, error) { parsedURL, err := parseAndValidateFetchURL(ctx, urlStr) diff --git a/internal/ai/tools/adapters.go b/internal/ai/tools/adapters.go index ddef13b92..5c07f8e8a 100644 --- a/internal/ai/tools/adapters.go +++ b/internal/ai/tools/adapters.go @@ -735,3 +735,221 @@ func trimContainerName(name string) string { } return name } + +// ========== Discovery Provider Adapter ========== + +// DiscoverySource provides access to AI-powered infrastructure discovery data +type DiscoverySource interface { + GetDiscovery(id string) (DiscoverySourceData, error) + GetDiscoveryByResource(resourceType, hostID, resourceID string) (DiscoverySourceData, error) + ListDiscoveries() ([]DiscoverySourceData, error) + ListDiscoveriesByType(resourceType string) ([]DiscoverySourceData, error) + ListDiscoveriesByHost(hostID string) ([]DiscoverySourceData, error) + FormatForAIContext(discoveries []DiscoverySourceData) string +} + +// DiscoverySourceData represents discovery data from the source +type DiscoverySourceData struct { + ID string + ResourceType string + ResourceID string + HostID string + Hostname string + ServiceType string + ServiceName string + ServiceVersion string + Category string + CLIAccess string + Facts []DiscoverySourceFact + ConfigPaths []string + DataPaths []string + UserNotes string + Confidence float64 + AIReasoning string + DiscoveredAt time.Time + UpdatedAt time.Time +} + +// DiscoverySourceFact represents a fact from the source +type DiscoverySourceFact 
struct { + Category string + Key string + Value string + Source string +} + +// DiscoveryMCPAdapter adapts aidiscovery.Service to MCP DiscoveryProvider interface +type DiscoveryMCPAdapter struct { + source DiscoverySource +} + +// NewDiscoveryMCPAdapter creates a new adapter for discovery data +func NewDiscoveryMCPAdapter(source DiscoverySource) *DiscoveryMCPAdapter { + if source == nil { + return nil + } + return &DiscoveryMCPAdapter{source: source} +} + +// GetDiscovery implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) GetDiscovery(id string) (*ResourceDiscoveryInfo, error) { + if a.source == nil { + return nil, fmt.Errorf("discovery source not available") + } + + data, err := a.source.GetDiscovery(id) + if err != nil { + return nil, err + } + + return a.convertToInfo(data), nil +} + +// GetDiscoveryByResource implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) GetDiscoveryByResource(resourceType, hostID, resourceID string) (*ResourceDiscoveryInfo, error) { + if a.source == nil { + return nil, fmt.Errorf("discovery source not available") + } + + data, err := a.source.GetDiscoveryByResource(resourceType, hostID, resourceID) + if err != nil { + return nil, err + } + + return a.convertToInfo(data), nil +} + +// ListDiscoveries implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) ListDiscoveries() ([]*ResourceDiscoveryInfo, error) { + if a.source == nil { + return nil, fmt.Errorf("discovery source not available") + } + + dataList, err := a.source.ListDiscoveries() + if err != nil { + return nil, err + } + + return a.convertList(dataList), nil +} + +// ListDiscoveriesByType implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) ListDiscoveriesByType(resourceType string) ([]*ResourceDiscoveryInfo, error) { + if a.source == nil { + return nil, fmt.Errorf("discovery source not available") + } + + dataList, err := a.source.ListDiscoveriesByType(resourceType) + if err != nil { + return nil, err + } + + return a.convertList(dataList), nil +} + +// ListDiscoveriesByHost implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) ListDiscoveriesByHost(hostID string) ([]*ResourceDiscoveryInfo, error) { + if a.source == nil { + return nil, fmt.Errorf("discovery source not available") + } + + dataList, err := a.source.ListDiscoveriesByHost(hostID) + if err != nil { + return nil, err + } + + return a.convertList(dataList), nil +} + +// FormatForAIContext implements tools.DiscoveryProvider +func (a *DiscoveryMCPAdapter) FormatForAIContext(discoveries []*ResourceDiscoveryInfo) string { + if a.source == nil { + return "" + } + + // Convert back to source data format + sourceData := make([]DiscoverySourceData, 0, len(discoveries)) + for _, d := range discoveries { + if d == nil { + continue + } + facts := make([]DiscoverySourceFact, 0, len(d.Facts)) + for _, f := range d.Facts { + facts = append(facts, DiscoverySourceFact{ + Category: f.Category, + Key: f.Key, + Value: f.Value, + Source: f.Source, + }) + } + sourceData = append(sourceData, DiscoverySourceData{ + ID: d.ID, + ResourceType: d.ResourceType, + ResourceID: d.ResourceID, + HostID: d.HostID, + Hostname: d.Hostname, + ServiceType: d.ServiceType, + ServiceName: d.ServiceName, + ServiceVersion: d.ServiceVersion, + Category: d.Category, + CLIAccess: d.CLIAccess, + Facts: facts, + ConfigPaths: d.ConfigPaths, + DataPaths: d.DataPaths, + UserNotes: d.UserNotes, + Confidence: d.Confidence, + AIReasoning: d.AIReasoning, + DiscoveredAt: d.DiscoveredAt, + UpdatedAt: d.UpdatedAt, + }) + } + + return 
a.source.FormatForAIContext(sourceData) +} + +func (a *DiscoveryMCPAdapter) convertToInfo(data DiscoverySourceData) *ResourceDiscoveryInfo { + if data.ID == "" { + return nil + } + + facts := make([]DiscoveryFact, 0, len(data.Facts)) + for _, f := range data.Facts { + facts = append(facts, DiscoveryFact{ + Category: f.Category, + Key: f.Key, + Value: f.Value, + Source: f.Source, + }) + } + + return &ResourceDiscoveryInfo{ + ID: data.ID, + ResourceType: data.ResourceType, + ResourceID: data.ResourceID, + HostID: data.HostID, + Hostname: data.Hostname, + ServiceType: data.ServiceType, + ServiceName: data.ServiceName, + ServiceVersion: data.ServiceVersion, + Category: data.Category, + CLIAccess: data.CLIAccess, + Facts: facts, + ConfigPaths: data.ConfigPaths, + DataPaths: data.DataPaths, + UserNotes: data.UserNotes, + Confidence: data.Confidence, + AIReasoning: data.AIReasoning, + DiscoveredAt: data.DiscoveredAt, + UpdatedAt: data.UpdatedAt, + } +} + +func (a *DiscoveryMCPAdapter) convertList(dataList []DiscoverySourceData) []*ResourceDiscoveryInfo { + result := make([]*ResourceDiscoveryInfo, 0, len(dataList)) + for _, data := range dataList { + if info := a.convertToInfo(data); info != nil { + result = append(result, info) + } + } + return result +} diff --git a/internal/ai/tools/executor.go b/internal/ai/tools/executor.go index d79227119..6905ea632 100644 --- a/internal/ai/tools/executor.go +++ b/internal/ai/tools/executor.go @@ -93,6 +93,46 @@ type UpdatesProvider interface { IsUpdateActionsEnabled() bool } +// DiscoveryProvider provides AI-powered infrastructure discovery +type DiscoveryProvider interface { + GetDiscovery(id string) (*ResourceDiscoveryInfo, error) + GetDiscoveryByResource(resourceType, hostID, resourceID string) (*ResourceDiscoveryInfo, error) + ListDiscoveries() ([]*ResourceDiscoveryInfo, error) + ListDiscoveriesByType(resourceType string) ([]*ResourceDiscoveryInfo, error) + ListDiscoveriesByHost(hostID string) ([]*ResourceDiscoveryInfo, error) + FormatForAIContext(discoveries []*ResourceDiscoveryInfo) string +} + +// ResourceDiscoveryInfo represents discovered information about a resource +type ResourceDiscoveryInfo struct { + ID string `json:"id"` + ResourceType string `json:"resource_type"` + ResourceID string `json:"resource_id"` + HostID string `json:"host_id"` + Hostname string `json:"hostname"` + ServiceType string `json:"service_type"` + ServiceName string `json:"service_name"` + ServiceVersion string `json:"service_version"` + Category string `json:"category"` + CLIAccess string `json:"cli_access"` + Facts []DiscoveryFact `json:"facts"` + ConfigPaths []string `json:"config_paths"` + DataPaths []string `json:"data_paths"` + UserNotes string `json:"user_notes,omitempty"` + Confidence float64 `json:"confidence"` + AIReasoning string `json:"ai_reasoning,omitempty"` + DiscoveredAt time.Time `json:"discovered_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// DiscoveryFact represents a discovered fact about a resource +type DiscoveryFact struct { + Category string `json:"category"` + Key string `json:"key"` + Value string `json:"value"` + Source string `json:"source,omitempty"` +} + // ControlLevel represents the AI's permission level for infrastructure control type ControlLevel string @@ -136,6 +176,9 @@ type ExecutorConfig struct { TopologyProvider TopologyProvider KnowledgeStoreProvider KnowledgeStoreProvider + // Optional providers - discovery + DiscoveryProvider DiscoveryProvider + // Control settings ControlLevel ControlLevel ProtectedGuests []string // 
VMIDs that AI cannot control @@ -172,6 +215,9 @@ type PulseToolExecutor struct { topologyProvider TopologyProvider knowledgeStoreProvider KnowledgeStoreProvider + // Discovery provider + discoveryProvider DiscoveryProvider + // Control settings controlLevel ControlLevel protectedGuests []string @@ -207,6 +253,7 @@ func NewPulseToolExecutor(cfg ExecutorConfig) *PulseToolExecutor { eventCorrelatorProvider: cfg.EventCorrelatorProvider, topologyProvider: cfg.TopologyProvider, knowledgeStoreProvider: cfg.KnowledgeStoreProvider, + discoveryProvider: cfg.DiscoveryProvider, controlLevel: cfg.ControlLevel, protectedGuests: cfg.ProtectedGuests, registry: NewToolRegistry(), @@ -317,6 +364,11 @@ func (e *PulseToolExecutor) SetKnowledgeStoreProvider(provider KnowledgeStorePro e.knowledgeStoreProvider = provider } +// SetDiscoveryProvider sets the discovery provider for AI-powered discovery +func (e *PulseToolExecutor) SetDiscoveryProvider(provider DiscoveryProvider) { + e.discoveryProvider = provider +} + // ListTools returns the list of available tools func (e *PulseToolExecutor) ListTools() []Tool { tools := e.registry.ListTools(e.controlLevel) @@ -377,6 +429,8 @@ func (e *PulseToolExecutor) isToolAvailable(name string) bool { return e.topologyProvider != nil case "pulse_remember", "pulse_recall": return e.knowledgeStoreProvider != nil + case "pulse_get_discovery", "pulse_list_discoveries": + return e.discoveryProvider != nil default: return e.stateProvider != nil } @@ -415,6 +469,9 @@ func (e *PulseToolExecutor) registerTools() { // Intelligence tools (incident analysis, knowledge management) e.registerIntelligenceTools() + // Discovery tools (AI-powered infrastructure discovery) + e.registerDiscoveryTools() + // Control tools (conditional on control level) e.registerControlTools() } diff --git a/internal/ai/tools/tools_control.go b/internal/ai/tools/tools_control.go index 49d2eff4f..59e0da3d2 100644 --- a/internal/ai/tools/tools_control.go +++ b/internal/ai/tools/tools_control.go @@ -168,22 +168,20 @@ func (e *PulseToolExecutor) executeRunCommand(ctx context.Context, args map[stri } } - // Skip approval checks if pre-approved - if !preApproved && e.controlLevel == ControlLevelControlled { - // Auto-approve read-only commands when in autonomous mode (investigations) - // This allows AI to gather diagnostic data without user approval - if e.isAutonomous && safety.IsReadOnlyCommand(command) { - log.Debug(). - Str("command", command). - Msg("Auto-approving read-only command for autonomous investigation") - } else { - targetType := "container" - if runOnHost { - targetType = "host" - } - approvalID := createApprovalRecord(command, targetType, e.targetID, targetHost, "Control level requires approval") - return NewTextResult(formatApprovalNeeded(command, "Control level requires approval", approvalID)), nil + // Skip approval checks if pre-approved or in autonomous mode + if !preApproved && !e.isAutonomous && e.controlLevel == ControlLevelControlled { + targetType := "container" + if runOnHost { + targetType = "host" } + approvalID := createApprovalRecord(command, targetType, e.targetID, targetHost, "Control level requires approval") + return NewTextResult(formatApprovalNeeded(command, "Control level requires approval", approvalID)), nil + } + if e.isAutonomous { + log.Debug(). + Str("command", command). + Bool("read_only", safety.IsReadOnlyCommand(command)). 
+ Msg("Auto-approving command for autonomous investigation") } if !preApproved && decision == agentexec.PolicyRequireApproval && !e.isAutonomous { targetType := "container" diff --git a/internal/ai/tools/tools_discovery.go b/internal/ai/tools/tools_discovery.go new file mode 100644 index 000000000..64fdab37f --- /dev/null +++ b/internal/ai/tools/tools_discovery.go @@ -0,0 +1,258 @@ +package tools + +import ( + "context" + "fmt" + "strings" +) + +// registerDiscoveryTools registers AI-powered infrastructure discovery tools +func (e *PulseToolExecutor) registerDiscoveryTools() { + e.registry.Register(RegisteredTool{ + Definition: Tool{ + Name: "pulse_get_discovery", + Description: `Get AI-discovered information about a specific resource (VM, LXC container, Docker container, or host). + +Returns: JSON with service type, version, config paths, data paths, CLI access command, and discovered facts. + +Use when: You need detailed context about a resource before proposing remediation commands, or when investigating what services are running on infrastructure. + +The discovery includes: +- Service type and version (e.g., "Frigate NVR v0.13.2") +- Configuration file locations +- Data/storage paths +- CLI access command (e.g., "pct exec 101 -- ") +- Discovered facts (ports, GPUs, connected services, etc.) +- User-added notes + +This information is critical for proposing correct remediation commands that match the actual service configuration.`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]PropertySchema{ + "resource_type": { + Type: "string", + Description: "Type of resource: 'vm', 'lxc', 'docker', or 'host'", + Enum: []string{"vm", "lxc", "docker", "host"}, + }, + "resource_id": { + Type: "string", + Description: "Resource identifier (VMID for VM/LXC, container name for Docker, hostname for host)", + }, + "host_id": { + Type: "string", + Description: "Optional: Host/node ID where the resource runs (required for Docker containers)", + }, + }, + Required: []string{"resource_type", "resource_id"}, + }, + }, + Handler: func(ctx context.Context, exec *PulseToolExecutor, args map[string]interface{}) (CallToolResult, error) { + return exec.executeGetDiscovery(ctx, args) + }, + }) + + e.registry.Register(RegisteredTool{ + Definition: Tool{ + Name: "pulse_list_discoveries", + Description: `List all AI-discovered infrastructure information with optional filtering. + +Returns: JSON array of discoveries with service types, versions, and summaries. + +Use when: You need an overview of what services are running across infrastructure, or want to find specific service types. 
+ +Filters: +- type: Filter by resource type (vm, lxc, docker, host) +- host: Filter by host/node ID +- service_type: Filter by discovered service type (e.g., "frigate", "postgresql")`, + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]PropertySchema{ + "type": { + Type: "string", + Description: "Optional: Filter by resource type", + Enum: []string{"vm", "lxc", "docker", "host"}, + }, + "host": { + Type: "string", + Description: "Optional: Filter by host/node ID", + }, + "service_type": { + Type: "string", + Description: "Optional: Filter by discovered service type", + }, + "limit": { + Type: "integer", + Description: "Maximum number of results (default: 50)", + }, + }, + }, + }, + Handler: func(ctx context.Context, exec *PulseToolExecutor, args map[string]interface{}) (CallToolResult, error) { + return exec.executeListDiscoveries(ctx, args) + }, + }) +} + +func (e *PulseToolExecutor) executeGetDiscovery(_ context.Context, args map[string]interface{}) (CallToolResult, error) { + if e.discoveryProvider == nil { + return NewTextResult("Discovery service not available. Run a discovery scan first."), nil + } + + resourceType, _ := args["resource_type"].(string) + resourceID, _ := args["resource_id"].(string) + hostID, _ := args["host_id"].(string) + + if resourceType == "" { + return NewErrorResult(fmt.Errorf("resource_type is required")), nil + } + if resourceID == "" { + return NewErrorResult(fmt.Errorf("resource_id is required")), nil + } + + discovery, err := e.discoveryProvider.GetDiscoveryByResource(resourceType, hostID, resourceID) + if err != nil { + return NewErrorResult(fmt.Errorf("failed to get discovery: %w", err)), nil + } + + if discovery == nil { + return NewJSONResult(map[string]interface{}{ + "found": false, + "resource_type": resourceType, + "resource_id": resourceID, + "message": "No discovery data found for this resource. 
Run a discovery scan to gather information.", + }), nil + } + + // Return the discovery information + response := map[string]interface{}{ + "found": true, + "id": discovery.ID, + "resource_type": discovery.ResourceType, + "resource_id": discovery.ResourceID, + "host_id": discovery.HostID, + "hostname": discovery.Hostname, + "service_type": discovery.ServiceType, + "service_name": discovery.ServiceName, + "service_version": discovery.ServiceVersion, + "category": discovery.Category, + "cli_access": discovery.CLIAccess, + "config_paths": discovery.ConfigPaths, + "data_paths": discovery.DataPaths, + "confidence": discovery.Confidence, + "discovered_at": discovery.DiscoveredAt, + "updated_at": discovery.UpdatedAt, + } + + // Add facts if present + if len(discovery.Facts) > 0 { + facts := make([]map[string]string, 0, len(discovery.Facts)) + for _, f := range discovery.Facts { + facts = append(facts, map[string]string{ + "category": f.Category, + "key": f.Key, + "value": f.Value, + }) + } + response["facts"] = facts + } + + // Add user notes if present + if discovery.UserNotes != "" { + response["user_notes"] = discovery.UserNotes + } + + // Add AI reasoning for context + if discovery.AIReasoning != "" { + response["ai_reasoning"] = discovery.AIReasoning + } + + return NewJSONResult(response), nil +} + +func (e *PulseToolExecutor) executeListDiscoveries(_ context.Context, args map[string]interface{}) (CallToolResult, error) { + if e.discoveryProvider == nil { + return NewTextResult("Discovery service not available."), nil + } + + filterType, _ := args["type"].(string) + filterHost, _ := args["host"].(string) + filterServiceType, _ := args["service_type"].(string) + limit := intArg(args, "limit", 50) + + var discoveries []*ResourceDiscoveryInfo + var err error + + // Get discoveries based on filters + if filterType != "" { + discoveries, err = e.discoveryProvider.ListDiscoveriesByType(filterType) + } else if filterHost != "" { + discoveries, err = e.discoveryProvider.ListDiscoveriesByHost(filterHost) + } else { + discoveries, err = e.discoveryProvider.ListDiscoveries() + } + + if err != nil { + return NewErrorResult(fmt.Errorf("failed to list discoveries: %w", err)), nil + } + + // Filter by service type if specified + if filterServiceType != "" { + filtered := make([]*ResourceDiscoveryInfo, 0) + filterLower := strings.ToLower(filterServiceType) + for _, d := range discoveries { + if strings.Contains(strings.ToLower(d.ServiceType), filterLower) || + strings.Contains(strings.ToLower(d.ServiceName), filterLower) { + filtered = append(filtered, d) + } + } + discoveries = filtered + } + + // Apply limit + if len(discoveries) > limit { + discoveries = discoveries[:limit] + } + + // Build response + results := make([]map[string]interface{}, 0, len(discoveries)) + for _, d := range discoveries { + result := map[string]interface{}{ + "id": d.ID, + "resource_type": d.ResourceType, + "resource_id": d.ResourceID, + "host_id": d.HostID, + "hostname": d.Hostname, + "service_type": d.ServiceType, + "service_name": d.ServiceName, + "service_version": d.ServiceVersion, + "category": d.Category, + "cli_access": d.CLIAccess, + "confidence": d.Confidence, + "updated_at": d.UpdatedAt, + } + + // Add key facts count + if len(d.Facts) > 0 { + result["facts_count"] = len(d.Facts) + } + + results = append(results, result) + } + + response := map[string]interface{}{ + "discoveries": results, + "total": len(results), + } + + if filterType != "" { + response["filter_type"] = filterType + } + if filterHost != "" { + 
response["filter_host"] = filterHost + } + if filterServiceType != "" { + response["filter_service_type"] = filterServiceType + } + + return NewJSONResult(response), nil +} diff --git a/internal/ai/unified/alerts_adapter_test.go b/internal/ai/unified/alerts_adapter_test.go new file mode 100644 index 000000000..b06c79bef --- /dev/null +++ b/internal/ai/unified/alerts_adapter_test.go @@ -0,0 +1,166 @@ +package unified + +import ( + "reflect" + "testing" + "time" + "unsafe" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" +) + +func setUnexportedField(t *testing.T, target interface{}, fieldName string, value interface{}) { + t.Helper() + + val := reflect.ValueOf(target) + if val.Kind() != reflect.Ptr { + t.Fatalf("target must be a pointer") + } + field := val.Elem().FieldByName(fieldName) + if !field.IsValid() { + t.Fatalf("field %s not found", fieldName) + } + + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func getUnexportedField(t *testing.T, target interface{}, fieldName string) reflect.Value { + t.Helper() + + val := reflect.ValueOf(target) + if val.Kind() != reflect.Ptr { + t.Fatalf("target must be a pointer") + } + field := val.Elem().FieldByName(fieldName) + if !field.IsValid() { + t.Fatalf("field %s not found", fieldName) + } + + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() +} + +func TestAlertManagerAdapter_NilManager(t *testing.T) { + adapter := NewAlertManagerAdapter(nil) + if adapter.GetActiveAlerts() != nil { + t.Fatalf("expected nil alerts for nil manager") + } + if adapter.GetAlert("missing") != nil { + t.Fatalf("expected nil alert for nil manager") + } + + adapter.SetAlertCallback(nil) + adapter.SetResolvedCallback(nil) +} + +func TestAlertManagerAdapter_WithManagerAndCallbacks(t *testing.T) { + manager := alerts.NewManager() + alert := &alerts.Alert{ + ID: "alert-1", + Type: "cpu", + Level: alerts.AlertLevelCritical, + ResourceID: "vm-100", + ResourceName: "web-1", + Node: "node-1", + Message: "high cpu", + Value: 92.5, + Threshold: 80, + StartTime: time.Now().Add(-time.Minute), + LastSeen: time.Now(), + Metadata: map[string]interface{}{"resourceType": "node"}, + } + + activeAlerts := map[string]*alerts.Alert{ + alert.ID: alert, + } + setUnexportedField(t, manager, "activeAlerts", activeAlerts) + + adapter := NewAlertManagerAdapter(manager) + active := adapter.GetActiveAlerts() + if len(active) != 1 { + t.Fatalf("expected 1 active alert, got %d", len(active)) + } + if active[0].GetAlertID() != alert.ID { + t.Fatalf("expected alert ID %s", alert.ID) + } + if active[0].GetAlertLevel() != string(alert.Level) { + t.Fatalf("expected alert level %s", alert.Level) + } + + found := adapter.GetAlert(alert.ID) + if found == nil || found.GetAlertID() != alert.ID { + t.Fatalf("expected to find alert %s", alert.ID) + } + if adapter.GetAlert("missing") != nil { + t.Fatalf("expected nil for missing alert") + } + + alertCh := make(chan string, 1) + adapter.SetAlertCallback(func(ad AlertAdapter) { + alertCh <- ad.GetAlertID() + }) + onAlert := getUnexportedField(t, manager, "onAlert").Interface().(func(alert *alerts.Alert)) + onAlert(alert) + select { + case got := <-alertCh: + if got != alert.ID { + t.Fatalf("expected alert callback for %s, got %s", alert.ID, got) + } + default: + t.Fatalf("expected alert callback to fire") + } + + resolvedCh := make(chan string, 1) + adapter.SetResolvedCallback(func(alertID string) { + resolvedCh <- alertID + }) + onResolved := getUnexportedField(t, 
manager, "onResolved").Interface().(func(alertID string)) + onResolved(alert.ID) + select { + case got := <-resolvedCh: + if got != alert.ID { + t.Fatalf("expected resolved callback for %s, got %s", alert.ID, got) + } + default: + t.Fatalf("expected resolved callback to fire") + } +} + +func TestAlertWrapper_NilAlert(t *testing.T) { + wrapper := &alertWrapper{} + if wrapper.GetAlertID() != "" { + t.Fatalf("expected empty ID") + } + if wrapper.GetAlertType() != "" { + t.Fatalf("expected empty type") + } + if wrapper.GetAlertLevel() != "" { + t.Fatalf("expected empty level") + } + if wrapper.GetResourceID() != "" { + t.Fatalf("expected empty resource id") + } + if wrapper.GetResourceName() != "" { + t.Fatalf("expected empty resource name") + } + if wrapper.GetNode() != "" { + t.Fatalf("expected empty node") + } + if wrapper.GetMessage() != "" { + t.Fatalf("expected empty message") + } + if wrapper.GetValue() != 0 { + t.Fatalf("expected zero value") + } + if wrapper.GetThreshold() != 0 { + t.Fatalf("expected zero threshold") + } + if !wrapper.GetStartTime().IsZero() { + t.Fatalf("expected zero start time") + } + if !wrapper.GetLastSeen().IsZero() { + t.Fatalf("expected zero last seen time") + } + if wrapper.GetMetadata() != nil { + t.Fatalf("expected nil metadata") + } +} diff --git a/internal/ai/unified/bridge_test.go b/internal/ai/unified/bridge_test.go new file mode 100644 index 000000000..044309039 --- /dev/null +++ b/internal/ai/unified/bridge_test.go @@ -0,0 +1,208 @@ +package unified + +import ( + "testing" + "time" +) + +type stubAlertProvider struct { + alerts []AlertAdapter + alertCb func(AlertAdapter) + resolvedCb func(string) +} + +func (s *stubAlertProvider) GetActiveAlerts() []AlertAdapter { + return s.alerts +} + +func (s *stubAlertProvider) GetAlert(alertID string) AlertAdapter { + for _, alert := range s.alerts { + if alert.GetAlertID() == alertID { + return alert + } + } + return nil +} + +func (s *stubAlertProvider) SetAlertCallback(cb func(AlertAdapter)) { + s.alertCb = cb +} + +func (s *stubAlertProvider) SetResolvedCallback(cb func(alertID string)) { + s.resolvedCb = cb +} + +func TestAlertBridge_StartStopAndSync(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + bridge := NewAlertBridge(store, DefaultBridgeConfig()) + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + provider := &stubAlertProvider{alerts: []AlertAdapter{alert}} + bridge.SetAlertProvider(provider) + + bridge.Start() + stats := bridge.Stats() + if !stats.Running { + t.Fatalf("expected bridge running") + } + if store.GetByAlert("alert-1") == nil { + t.Fatalf("expected alert to be synced into store") + } + + bridge.Stop() + stats = bridge.Stats() + if stats.Running { + t.Fatalf("expected bridge stopped") + } +} + +func TestAlertBridge_HandleNewAlertAndEnhance(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + bridge := NewAlertBridge(store, BridgeConfig{ + AutoEnhance: true, + EnhanceDelay: 10 * time.Millisecond, + TriggerPatrolOnNew: true, + }) + bridge.running = true + + patrolCh := make(chan string, 1) + bridge.SetPatrolTrigger(func(resourceID, reason string) { + patrolCh <- reason + }) + + enhanceCh := make(chan string, 1) + bridge.SetAIEnhancement(func(findingID string) error { + enhanceCh <- findingID + return nil + }) + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + 
AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + + bridge.handleNewAlert(alert) + + select { + case reason := <-patrolCh: + if reason != "alert_fired" { + t.Fatalf("expected patrol reason alert_fired, got %s", reason) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected patrol trigger") + } + + select { + case <-enhanceCh: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected enhancement callback") + } +} + +func TestAlertBridge_HandleAlertResolved(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + bridge := NewAlertBridge(store, BridgeConfig{ + TriggerPatrolOnClear: true, + }) + bridge.running = true + + patrolCh := make(chan string, 1) + bridge.SetPatrolTrigger(func(resourceID, reason string) { + patrolCh <- reason + }) + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := store.AddFromAlert(alert) + bridge.pendingEnhancements[finding.ID] = time.AfterFunc(time.Second, func() {}) + + bridge.handleAlertResolved(alert.AlertID) + + select { + case reason := <-patrolCh: + if reason != "alert_cleared" { + t.Fatalf("expected patrol reason alert_cleared, got %s", reason) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected patrol trigger") + } + + if _, ok := bridge.pendingEnhancements[finding.ID]; ok { + t.Fatalf("expected enhancement to be canceled") + } +} + +func TestAlertBridge_ScheduleEnhancementInactiveFinding(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + bridge := NewAlertBridge(store, DefaultBridgeConfig()) + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := store.AddFromAlert(alert) + store.Resolve(finding.ID) + + enhanceCh := make(chan string, 1) + bridge.scheduleEnhancement(finding.ID, 10*time.Millisecond, func(id string) error { + enhanceCh <- id + return nil + }) + + select { + case <-enhanceCh: + t.Fatalf("did not expect enhancement on inactive finding") + case <-time.After(50 * time.Millisecond): + } +} + +func TestAlertBridge_StatsAndStore(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + bridge := NewAlertBridge(store, DefaultBridgeConfig()) + bridge.running = true + bridge.pendingEnhancements["f1"] = time.AfterFunc(time.Second, func() {}) + + stats := bridge.Stats() + if !stats.Running { + t.Fatalf("expected running stats") + } + if stats.PendingEnhancements != 1 { + t.Fatalf("expected pending enhancements to be 1") + } + if bridge.GetUnifiedStore() != store { + t.Fatalf("expected store to be returned") + } +} diff --git a/internal/ai/unified/integration_test.go b/internal/ai/unified/integration_test.go new file mode 100644 index 000000000..337d14e40 --- /dev/null +++ b/internal/ai/unified/integration_test.go @@ -0,0 +1,250 @@ +package unified + +import ( + "fmt" + "strings" + "testing" + "time" +) + +type stubCorrelationEngine struct { + rootCauseID string + correlated []string + explanation string + returnErr error + calledWithID string +} + +func (s *stubCorrelationEngine) AnalyzeForFinding(findingID string, resourceID string) (string, []string, 
string, error) { + s.calledWithID = findingID + return s.rootCauseID, s.correlated, s.explanation, s.returnErr +} + +type stubRemediationEngine struct { + planID string + err error + called chan string +} + +func (s *stubRemediationEngine) GeneratePlanForFinding(finding *UnifiedFinding) (string, error) { + if s.called != nil { + s.called <- finding.ID + } + return s.planID, s.err +} + +type stubLearningStore struct { + suppress bool + last string +} + +func (s *stubLearningStore) RecordFindingFeedback(findingID, resourceID, category, action, reason, note string) { + s.last = fmt.Sprintf("%s:%s:%s", findingID, action, reason) +} + +func (s *stubLearningStore) ShouldSuppress(resourceID, category, severity string) bool { + return s.suppress +} + +func TestIntegration_AddAIFinding_Suppressed(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + learning := &stubLearningStore{suppress: true} + integration.SetLearningStore(learning) + + finding := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryCapacity, + ResourceID: "res-1", + Title: "test", + } + + result, isNew := integration.AddAIFinding(finding) + if result != nil || isNew { + t.Fatalf("expected suppressed finding to be dropped") + } +} + +func TestIntegration_AddAIFinding_Remediation(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + remediation := &stubRemediationEngine{ + planID: "plan-1", + called: make(chan string, 1), + } + integration.SetRemediationEngine(remediation) + + finding := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryCapacity, + ResourceID: "res-1", + Title: "test", + } + integration.AddAIFinding(finding) + + select { + case <-remediation.called: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected remediation to be invoked") + } + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + stored := integration.store.Get(finding.ID) + if stored != nil && stored.RemediationID == "plan-1" { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("expected remediation ID to be linked") +} + +func TestIntegration_EnhanceFindingWithCorrelation(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + engine := &stubCorrelationEngine{ + rootCauseID: "root-1", + correlated: []string{"c1", "c2", "c3"}, + explanation: "cause", + } + integration.SetCorrelationEngine(engine) + + finding := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryCapacity, + ResourceID: "res-1", + Title: "test", + } + integration.store.AddFromAI(finding) + + if err := integration.enhanceFindingWithCorrelation("ai-1"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + updated := integration.store.Get("ai-1") + if updated == nil || !updated.EnhancedByAI { + t.Fatalf("expected finding to be enhanced") + } + if updated.AIConfidence != 1.0 { + t.Fatalf("expected confidence 1.0, got %f", updated.AIConfidence) + } + if updated.RootCauseID != "root-1" { + t.Fatalf("expected root cause ID") + } +} + +func TestIntegration_EnhanceFindingWithCorrelation_Errors(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + engine := &stubCorrelationEngine{returnErr: errTest} + integration.SetCorrelationEngine(engine) + if err := integration.enhanceFindingWithCorrelation("missing"); err == nil { + 
t.Fatalf("expected error for missing finding") + } + finding := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryCapacity, + ResourceID: "res-1", + Title: "test", + } + integration.store.AddFromAI(finding) + if err := integration.enhanceFindingWithCorrelation("ai-1"); err == nil { + t.Fatalf("expected correlation error") + } +} + +func TestIntegration_DismissAndSnoozeFeedback(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + learning := &stubLearningStore{} + integration.SetLearningStore(learning) + + finding := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryCapacity, + ResourceID: "res-1", + Title: "test", + } + integration.store.AddFromAI(finding) + + if !integration.DismissFinding("ai-1", "expected", "note") { + t.Fatalf("expected dismiss to succeed") + } + if !strings.Contains(learning.last, "dismiss") { + t.Fatalf("expected dismiss feedback to be recorded") + } + + if !integration.SnoozeFinding("ai-1", time.Minute) { + t.Fatalf("expected snooze to succeed") + } + if !strings.Contains(learning.last, "snooze") { + t.Fatalf("expected snooze feedback to be recorded") + } +} + +func TestIntegration_SummaryAndSnapshots(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "critical", + ResourceID: "vm-1", + ResourceName: "web", + Value: 95, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := integration.store.AddFromAlert(alert) + integration.store.EnhanceWithAI(finding.ID, "context", 0.8, "", nil) + + summary := integration.GetActiveIssuesSummary() + if !strings.Contains(summary, "active issues") { + t.Fatalf("expected active issues summary") + } + + before := integration.TakeSnapshot() + integration.store.Resolve(finding.ID) + after := integration.TakeSnapshot() + + diff := CompareSnapshots(before, after) + if diff == nil || !diff.HasChanges() { + t.Fatalf("expected snapshot changes") + } + if diff.Summary() == "" { + t.Fatalf("expected summary") + } +} + +func TestIntegration_GetContextForPatrol(t *testing.T) { + integration := NewIntegration(DefaultIntegrationConfig(t.TempDir())) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + integration.store.AddFromAlert(alert) + + context := integration.GetContextForPatrol() + if context == "" { + t.Fatalf("expected patrol context") + } +} + +func TestMinInt(t *testing.T) { + if minInt(3, 1) != 1 { + t.Fatalf("expected minInt result") + } +} diff --git a/internal/ai/unified/persistence_test.go b/internal/ai/unified/persistence_test.go new file mode 100644 index 000000000..25de45ba7 --- /dev/null +++ b/internal/ai/unified/persistence_test.go @@ -0,0 +1,100 @@ +package unified + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestFilePersistence_SaveLoad(t *testing.T) { + dir := t.TempDir() + p := NewFilePersistence(dir) + + findings := map[string]*UnifiedFinding{ + "f1": { + ID: "f1", + Source: SourceThreshold, + ResourceID: "res-1", + DetectedAt: time.Now(), + }, + } + + if err := p.SaveFindings(findings); err != nil { + t.Fatalf("unexpected save error: %v", err) + } + + loaded, err := p.LoadFindings() + if err != nil { + t.Fatalf("unexpected load 
error: %v", err) + } + if len(loaded) != 1 { + t.Fatalf("expected 1 finding, got %d", len(loaded)) + } +} + +func TestFilePersistence_LoadMissingAndInvalid(t *testing.T) { + dir := t.TempDir() + p := NewFilePersistence(dir) + + loaded, err := p.LoadFindings() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(loaded) != 0 { + t.Fatalf("expected empty map for missing file") + } + + path := filepath.Join(dir, "unified_findings.json") + if err := os.WriteFile(path, []byte("{bad json"), 0644); err != nil { + t.Fatalf("write failed: %v", err) + } + + loaded, err = p.LoadFindings() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(loaded) != 0 { + t.Fatalf("expected empty map for invalid json") + } +} + +func TestVersionedPersistence_SaveLoad(t *testing.T) { + dir := t.TempDir() + p := NewVersionedPersistence(dir) + + findings := map[string]*UnifiedFinding{ + "f1": {ID: "f1", Source: SourceAIPatrol, ResourceID: "res-1"}, + } + + if err := p.SaveFindings(findings); err != nil { + t.Fatalf("unexpected save error: %v", err) + } + + loaded, err := p.LoadFindings() + if err != nil { + t.Fatalf("unexpected load error: %v", err) + } + if len(loaded) != 1 { + t.Fatalf("expected 1 finding, got %d", len(loaded)) + } +} + +func TestVersionedPersistence_LoadLegacy(t *testing.T) { + dir := t.TempDir() + p := NewVersionedPersistence(dir) + + path := filepath.Join(dir, "unified_findings.json") + legacy := `[{"id":"f1","source":"ai-patrol","resource_id":"res-1"}]` + if err := os.WriteFile(path, []byte(legacy), 0644); err != nil { + t.Fatalf("write failed: %v", err) + } + + loaded, err := p.LoadFindings() + if err != nil { + t.Fatalf("unexpected load error: %v", err) + } + if len(loaded) != 1 { + t.Fatalf("expected 1 finding, got %d", len(loaded)) + } +} diff --git a/internal/ai/unified/setup_test.go b/internal/ai/unified/setup_test.go new file mode 100644 index 000000000..a3e9f2248 --- /dev/null +++ b/internal/ai/unified/setup_test.go @@ -0,0 +1,42 @@ +package unified + +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" +) + +func TestSetup_Defaults(t *testing.T) { + result, err := Setup(SetupConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == nil || result.Integration == nil || result.Store == nil || result.Bridge == nil { + t.Fatalf("expected setup components") + } + if result.Adapter != nil { + t.Fatalf("expected nil adapter when no alert manager provided") + } +} + +func TestQuickSetup(t *testing.T) { + manager := alerts.NewManager() + result, err := QuickSetup(manager, t.TempDir()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Adapter == nil { + t.Fatalf("expected adapter with alert manager") + } +} + +func TestSetupWithPatrol(t *testing.T) { + manager := alerts.NewManager() + result, err := SetupWithPatrol(manager, t.TempDir(), func(resourceID, reason string) {}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Bridge == nil { + t.Fatalf("expected bridge") + } +} diff --git a/internal/ai/unified/store_additional_test.go b/internal/ai/unified/store_additional_test.go new file mode 100644 index 000000000..3d6fc3b2f --- /dev/null +++ b/internal/ai/unified/store_additional_test.go @@ -0,0 +1,380 @@ +package unified + +import ( + "strings" + "testing" + "time" +) + +type stubUnifiedPersistence struct { + loadFindings map[string]*UnifiedFinding + loadErr error + saved map[string]*UnifiedFinding + saveCalls int + saveErr error +} + +func (s 
*stubUnifiedPersistence) SaveFindings(findings map[string]*UnifiedFinding) error { + s.saveCalls++ + s.saved = make(map[string]*UnifiedFinding, len(findings)) + for id, f := range findings { + copy := *f + s.saved[id] = &copy + } + return s.saveErr +} + +func (s *stubUnifiedPersistence) LoadFindings() (map[string]*UnifiedFinding, error) { + return s.loadFindings, s.loadErr +} + +func TestUnifiedStore_SetPersistence_LoadsFindings(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + loaded := map[string]*UnifiedFinding{ + "f1": { + ID: "f1", + Source: SourceThreshold, + ResourceID: "r1", + AlertID: "a1", + }, + "f2": { + ID: "f2", + Source: SourceAIPatrol, + ResourceID: "r1", + }, + } + persistence := &stubUnifiedPersistence{loadFindings: loaded} + if err := store.SetPersistence(persistence); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if store.Get("f1") == nil { + t.Fatalf("expected finding f1 to load") + } + if store.GetByAlert("a1") == nil { + t.Fatalf("expected alert index to load") + } + byResource := store.GetByResource("r1") + if len(byResource) != 2 { + t.Fatalf("expected 2 findings for resource, got %d", len(byResource)) + } +} + +func TestUnifiedStore_SetPersistence_Error(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + persistence := &stubUnifiedPersistence{loadErr: errTest} + if err := store.SetPersistence(persistence); err == nil { + t.Fatalf("expected error") + } +} + +func TestUnifiedStore_ConvertAlert_MetadataAndRecommendation(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "disk", + AlertLevel: "critical", + ResourceID: "vm-1", + ResourceName: "db", + Message: "disk full", + Value: 96, + Threshold: 90, + StartTime: time.Now().Add(-time.Minute), + LastSeen: time.Now(), + Metadata: map[string]interface{}{"resourceType": "node"}, + } + + finding := store.ConvertAlert(alert) + if finding.Severity != SeverityCritical { + t.Fatalf("expected critical severity") + } + if finding.Category != CategoryCapacity { + t.Fatalf("expected capacity category") + } + if finding.ResourceType != "node" { + t.Fatalf("expected resource type from metadata") + } + if !strings.Contains(finding.Recommendation, "URGENT") { + t.Fatalf("expected urgent recommendation, got %q", finding.Recommendation) + } +} + +func TestUnifiedStore_AddFromAlert_ReopensResolved(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now().Add(-time.Minute), + LastSeen: time.Now(), + } + + finding, _ := store.AddFromAlert(alert) + now := time.Now().Add(-time.Second) + store.findings[finding.ID].ResolvedAt = &now + + alert.AlertLevel = "critical" + updated, isNew := store.AddFromAlert(alert) + if isNew { + t.Fatalf("expected update for existing alert") + } + if updated.ResolvedAt != nil { + t.Fatalf("expected resolvedAt cleared") + } + if updated.Severity != SeverityCritical { + t.Fatalf("expected severity to upgrade") + } +} + +func TestUnifiedStore_AddFromAI_Existing(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + initial := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWatch, + ResourceID: "res-1", + Title: "initial", + } + store.AddFromAI(initial) + + update := &UnifiedFinding{ + ID: "ai-1", + Source: 
SourceAIPatrol, + Severity: SeverityCritical, + ResourceID: "res-1", + Description: "updated", + } + updated, isNew := store.AddFromAI(update) + if isNew { + t.Fatalf("expected existing update") + } + if updated.Severity != SeverityCritical { + t.Fatalf("expected severity update") + } + if updated.Description != "updated" { + t.Fatalf("expected description update") + } +} + +func TestUnifiedStore_BasicMutations(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := store.AddFromAlert(alert) + + if store.ResolveByAlert("missing") { + t.Fatalf("expected resolve to fail for missing alert") + } + if store.EnhanceWithAI("missing", "ctx", 0.5, "", nil) { + t.Fatalf("expected enhance to fail for missing finding") + } + if store.LinkRemediation("missing", "plan") { + t.Fatalf("expected link remediation to fail for missing finding") + } + + if !store.Acknowledge(finding.ID) { + t.Fatalf("expected acknowledge to succeed") + } + if !store.Snooze(finding.ID, time.Minute) { + t.Fatalf("expected snooze to succeed") + } + if !store.Resolve(finding.ID) { + t.Fatalf("expected resolve to succeed") + } + resolved := store.Get(finding.ID) + if resolved.ResolvedAt == nil { + t.Fatalf("expected resolved at set") + } + if resolved.SnoozedUntil != nil { + t.Fatalf("expected snooze cleared on resolve") + } +} + +func TestUnifiedStore_AIFilteringAndDismiss(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + ResourceName: "web", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + store.AddFromAlert(alert) + + ai := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryPerformance, + ResourceID: "vm-2", + ResourceName: "db", + Title: "ai", + } + store.AddFromAI(ai) + + if len(store.GetAIFindings()) != 1 { + t.Fatalf("expected 1 AI finding") + } + if len(store.GetUnenhancedThresholdFindings()) != 1 { + t.Fatalf("expected 1 unenhanced threshold finding") + } + + for _, f := range store.GetThresholdFindings() { + if !store.Dismiss(f.ID, "not_an_issue", "expected") { + t.Fatalf("expected dismiss to succeed") + } + dismissed := store.Get(f.ID) + if !dismissed.Suppressed { + t.Fatalf("expected suppression for not_an_issue") + } + } +} + +func TestUnifiedStore_FormatForContext(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "critical", + ResourceID: "vm-1", + ResourceName: "web", + Value: 95, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := store.AddFromAlert(alert) + store.EnhanceWithAI(finding.ID, "context", 0.9, "root-1", nil) + + ai := &UnifiedFinding{ + ID: "ai-1", + Source: SourceAIPatrol, + Severity: SeverityWarning, + Category: CategoryPerformance, + ResourceID: "vm-2", + ResourceName: "db", + Title: "ai finding", + RootCauseID: "root-1", + } + store.AddFromAI(ai) + + out := store.FormatForContext() + if !strings.Contains(out, "Threshold Alerts") { + t.Fatalf("expected threshold section") + } + if !strings.Contains(out, "AI context") { + t.Fatalf("expected AI context") + } + if 
!strings.Contains(out, "Root cause linked") { + t.Fatalf("expected root cause linkage") + } +} + +func TestUnifiedStore_SummaryIncludesEnhanced(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + finding, _ := store.AddFromAlert(alert) + store.EnhanceWithAI(finding.ID, "context", 0.8, "", nil) + + summary := store.GetSummary() + if summary.EnhancedByAI != 1 { + t.Fatalf("expected enhanced count 1, got %d", summary.EnhancedByAI) + } +} + +func TestUnifiedStore_ForceSave(t *testing.T) { + store := NewUnifiedStore(DefaultAlertToFindingConfig()) + persistence := &stubUnifiedPersistence{} + store.persistence = persistence + + alert := &SimpleAlertAdapter{ + AlertID: "alert-1", + AlertType: "cpu", + AlertLevel: "warning", + ResourceID: "vm-1", + Value: 90, + Threshold: 80, + StartTime: time.Now(), + LastSeen: time.Now(), + } + store.AddFromAlert(alert) + + if err := store.ForceSave(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if persistence.saveCalls == 0 { + t.Fatalf("expected save to be called") + } +} + +func TestUnifiedHelpers(t *testing.T) { + if severityOrder("unknown") != 0 { + t.Fatalf("expected default severity order") + } + if determineResourceType("nodeOffline", nil) != "node" { + t.Fatalf("expected node resource type") + } + if determineResourceType("usage", nil) != "storage" { + t.Fatalf("expected storage resource type") + } + if determineResourceType("backup", nil) != "backup" { + t.Fatalf("expected backup resource type") + } + if determineResourceType("snapshot", nil) != "snapshot" { + t.Fatalf("expected snapshot resource type") + } + if determineResourceType("imageUpdateAvail", nil) != "docker" { + t.Fatalf("expected docker resource type") + } + if determineResourceType("other", map[string]interface{}{"resourceType": "custom"}) != "custom" { + t.Fatalf("expected custom resource type") + } + + title := generateTitle("offline", "node-1", 0, 0) + if !strings.Contains(title, "offline") { + t.Fatalf("expected offline title") + } + if !strings.Contains(generateRecommendation("temperature", 0, 0), "cooling") { + t.Fatalf("expected temperature recommendation") + } + if !strings.Contains(generateRecommendation("unknown", 0, 0), "Investigate") { + t.Fatalf("expected default recommendation") + } + if formatSourceName("custom") != "custom" { + t.Fatalf("expected fallback source name") + } +} + +var errTest = &testError{} + +type testError struct{} + +func (e *testError) Error() string { + return "test error" +} diff --git a/internal/aidiscovery/commands.go b/internal/aidiscovery/commands.go new file mode 100644 index 000000000..24da8ce90 --- /dev/null +++ b/internal/aidiscovery/commands.go @@ -0,0 +1,442 @@ +package aidiscovery + +import ( + "fmt" + "strings" +) + +// DiscoveryCommand represents a command to run during discovery. +type DiscoveryCommand struct { + Name string // Human-readable name + Command string // The command template + Description string // What this discovers + Categories []string // What categories of info this provides + Timeout int // Timeout in seconds (0 = default) + Optional bool // If true, don't fail if command fails +} + +// CommandSet represents a set of commands for a resource type. 
+type CommandSet struct { + ResourceType ResourceType + Commands []DiscoveryCommand +} + +// GetCommandsForResource returns the commands to run for a given resource type. +func GetCommandsForResource(resourceType ResourceType) []DiscoveryCommand { + switch resourceType { + case ResourceTypeLXC: + return getLXCCommands() + case ResourceTypeVM: + return getVMCommands() + case ResourceTypeDocker: + return getDockerCommands() + case ResourceTypeDockerVM, ResourceTypeDockerLXC: + return getNestedDockerCommands() + case ResourceTypeK8s: + return getK8sCommands() + case ResourceTypeHost: + return getHostCommands() + default: + return []DiscoveryCommand{} + } +} + +// getLXCCommands returns commands for discovering LXC containers. +func getLXCCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "os_release", + Command: "cat /etc/os-release", + Description: "Operating system identification", + Categories: []string{"version", "config"}, + Optional: true, + }, + { + Name: "hostname", + Command: "hostname", + Description: "Container hostname", + Categories: []string{"config"}, + Optional: true, + }, + { + Name: "running_services", + Command: "systemctl list-units --type=service --state=running --no-pager 2>/dev/null | head -30 || service --status-all 2>/dev/null | grep '+' | head -30", + Description: "Running services and daemons", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "listening_ports", + Command: "ss -tlnp 2>/dev/null | head -25 || netstat -tlnp 2>/dev/null | head -25", + Description: "Network ports listening", + Categories: []string{"port", "network"}, + Optional: true, + }, + { + Name: "top_processes", + Command: "ps aux --sort=-rss 2>/dev/null | head -15 || ps aux | head -15", + Description: "Top processes by memory", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "disk_usage", + Command: "df -h 2>/dev/null | head -15", + Description: "Disk usage and mount points", + Categories: []string{"storage"}, + Optional: true, + }, + { + Name: "docker_check", + Command: "docker ps --format '{{.Names}}: {{.Image}} ({{.Status}})' 2>/dev/null | head -20 || echo 'no_docker'", + Description: "Docker containers if running", + Categories: []string{"service", "container"}, + Optional: true, + }, + { + Name: "installed_packages", + Command: "dpkg -l 2>/dev/null | grep -E '^ii' | awk '{print $2}' | head -50 || rpm -qa 2>/dev/null | head -50 || apk list --installed 2>/dev/null | head -50", + Description: "Installed packages", + Categories: []string{"version", "service"}, + Optional: true, + }, + { + Name: "config_files", + Command: "find /etc -name '*.conf' -o -name '*.yml' -o -name '*.yaml' -o -name '*.json' 2>/dev/null | head -30", + Description: "Configuration files", + Categories: []string{"config"}, + Optional: true, + }, + { + Name: "cron_jobs", + Command: "crontab -l 2>/dev/null | grep -v '^#' | head -10 || ls -la /etc/cron.d/ 2>/dev/null | head -10", + Description: "Scheduled jobs", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "hardware_info", + Command: "lspci 2>/dev/null | head -20 || echo 'no_lspci'", + Description: "Hardware devices (e.g., Coral TPU)", + Categories: []string{"hardware"}, + Optional: true, + }, + { + Name: "gpu_devices", + Command: "ls -la /dev/dri/ 2>/dev/null; ls -la /dev/apex* 2>/dev/null; nvidia-smi -L 2>/dev/null || echo 'no_gpu'", + Description: "GPU and TPU devices", + Categories: []string{"hardware"}, + Optional: true, + }, + } +} + +// getVMCommands returns commands for discovering 
VMs (via QEMU guest agent). +func getVMCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "os_release", + Command: "cat /etc/os-release", + Description: "Operating system identification", + Categories: []string{"version", "config"}, + Optional: true, + }, + { + Name: "hostname", + Command: "hostname", + Description: "VM hostname", + Categories: []string{"config"}, + Optional: true, + }, + { + Name: "running_services", + Command: "systemctl list-units --type=service --state=running --no-pager 2>/dev/null | head -30", + Description: "Running services and daemons", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "listening_ports", + Command: "ss -tlnp 2>/dev/null | head -25 || netstat -tlnp 2>/dev/null | head -25", + Description: "Network ports listening", + Categories: []string{"port", "network"}, + Optional: true, + }, + { + Name: "top_processes", + Command: "ps aux --sort=-rss 2>/dev/null | head -15", + Description: "Top processes by memory", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "disk_usage", + Command: "df -h 2>/dev/null | head -15", + Description: "Disk usage and mount points", + Categories: []string{"storage"}, + Optional: true, + }, + { + Name: "docker_check", + Command: "docker ps --format '{{.Names}}: {{.Image}} ({{.Status}})' 2>/dev/null | head -20 || echo 'no_docker'", + Description: "Docker containers if running", + Categories: []string{"service", "container"}, + Optional: true, + }, + { + Name: "hardware_info", + Command: "lspci 2>/dev/null | head -20", + Description: "PCI hardware devices", + Categories: []string{"hardware"}, + Optional: true, + }, + { + Name: "gpu_devices", + Command: "ls -la /dev/dri/ 2>/dev/null; nvidia-smi -L 2>/dev/null || echo 'no_gpu'", + Description: "GPU devices", + Categories: []string{"hardware"}, + Optional: true, + }, + } +} + +// getDockerCommands returns commands for discovering Docker containers. +// These are run inside the container via docker exec. +func getDockerCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "os_release", + Command: "cat /etc/os-release 2>/dev/null || cat /etc/alpine-release 2>/dev/null || echo 'unknown'", + Description: "Container OS", + Categories: []string{"version"}, + Optional: true, + }, + { + Name: "processes", + Command: "ps aux 2>/dev/null || echo 'no_ps'", + Description: "Running processes", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "listening_ports", + Command: "ss -tlnp 2>/dev/null || netstat -tlnp 2>/dev/null || echo 'no_ss'", + Description: "Listening ports inside container", + Categories: []string{"port"}, + Optional: true, + }, + { + Name: "env_vars", + Command: "env 2>/dev/null | grep -vE '(PASSWORD|SECRET|KEY|TOKEN|CREDENTIAL)' | head -30", + Description: "Environment variables (filtered)", + Categories: []string{"config"}, + Optional: true, + }, + { + Name: "config_files", + Command: "find /config /data /app /etc -maxdepth 2 -name '*.conf' -o -name '*.yml' -o -name '*.yaml' -o -name '*.json' 2>/dev/null | head -20", + Description: "Configuration files", + Categories: []string{"config"}, + Optional: true, + }, + } +} + +// getNestedDockerCommands returns commands for Docker inside VMs or LXCs. 
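Elsewhere in this change the deep scanner addresses Docker-in-VM/LXC targets with a compound resource ID of the form "vmid:container_name" (see buildCommand and splitResourceID in deep_scanner.go). As a rough illustration of how such a target is requested, assuming placeholder IDs ("pve1", VMID 101, container "web") and that DiscoveryRequest is the type defined elsewhere in this package:

```go
// Sketch only; DiscoveryRequest is defined elsewhere in the aidiscovery package,
// and the host, VMID and container name below are placeholder values.
req := DiscoveryRequest{
	ResourceType: ResourceTypeDockerVM, // Docker running inside a VM
	ResourceID:   "101:web",            // compound "<vmid>:<container_name>" ID
	HostID:       "pve1",
	Hostname:     "pve1",
}
```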
+func getNestedDockerCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "docker_containers", + Command: "docker ps -a --format '{{.Names}}|{{.Image}}|{{.Status}}|{{.Ports}}'", + Description: "All Docker containers", + Categories: []string{"container", "service"}, + Optional: false, + }, + { + Name: "docker_images", + Command: "docker images --format '{{.Repository}}:{{.Tag}}' | head -20", + Description: "Docker images", + Categories: []string{"version"}, + Optional: true, + }, + { + Name: "docker_compose", + Command: "find /opt /home /root -name 'docker-compose*.yml' -o -name 'compose*.yml' 2>/dev/null | head -10", + Description: "Docker compose files", + Categories: []string{"config"}, + Optional: true, + }, + } +} + +// getK8sCommands returns commands for discovering Kubernetes pods. +func getK8sCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "processes", + Command: "ps aux 2>/dev/null || echo 'no_ps'", + Description: "Running processes in pod", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "listening_ports", + Command: "ss -tlnp 2>/dev/null || netstat -tlnp 2>/dev/null || echo 'no_ss'", + Description: "Listening ports", + Categories: []string{"port"}, + Optional: true, + }, + { + Name: "env_vars", + Command: "env 2>/dev/null | grep -vE '(PASSWORD|SECRET|KEY|TOKEN|CREDENTIAL)' | head -30", + Description: "Environment variables (filtered)", + Categories: []string{"config"}, + Optional: true, + }, + } +} + +// getHostCommands returns commands for discovering host systems. +func getHostCommands() []DiscoveryCommand { + return []DiscoveryCommand{ + { + Name: "os_release", + Command: "cat /etc/os-release", + Description: "Operating system", + Categories: []string{"version", "config"}, + Optional: true, + }, + { + Name: "hostname", + Command: "hostname -f 2>/dev/null || hostname", + Description: "Full hostname", + Categories: []string{"config"}, + Optional: true, + }, + { + Name: "running_services", + Command: "systemctl list-units --type=service --state=running --no-pager 2>/dev/null | head -40", + Description: "Running services", + Categories: []string{"service"}, + Optional: true, + }, + { + Name: "listening_ports", + Command: "ss -tlnp 2>/dev/null | head -30", + Description: "Listening network ports", + Categories: []string{"port", "network"}, + Optional: true, + }, + { + Name: "docker_containers", + Command: "docker ps --format '{{.Names}}: {{.Image}} ({{.Status}})' 2>/dev/null | head -30 || echo 'no_docker'", + Description: "Docker containers on host", + Categories: []string{"container", "service"}, + Optional: true, + }, + { + Name: "proxmox_version", + Command: "pveversion 2>/dev/null || echo 'not_proxmox'", + Description: "Proxmox version if applicable", + Categories: []string{"version"}, + Optional: true, + }, + { + Name: "zfs_pools", + Command: "zpool list 2>/dev/null | head -10 || echo 'no_zfs'", + Description: "ZFS pools", + Categories: []string{"storage"}, + Optional: true, + }, + { + Name: "disk_usage", + Command: "df -h | head -20", + Description: "Disk usage", + Categories: []string{"storage"}, + Optional: true, + }, + { + Name: "hardware_info", + Command: "lscpu | head -20", + Description: "CPU information", + Categories: []string{"hardware"}, + Optional: true, + }, + { + Name: "memory_info", + Command: "free -h", + Description: "Memory information", + Categories: []string{"hardware"}, + Optional: true, + }, + } +} + +// BuildLXCCommand wraps a command for execution in an LXC container. 
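To make the quoting concrete, a quick sketch of the strings the builders below produce; the VMID and container name are placeholder values, and the expected results follow directly from the %q formatting used in the implementations:

```go
lxcCmd := BuildLXCCommand("101", "cat /etc/os-release")
// lxcCmd == `pct exec 101 -- sh -c "cat /etc/os-release"`

nested := BuildNestedDockerCommand("101", true, "frigate", "uptime")
// The docker exec command is quoted again when wrapped for the LXC:
// nested == `pct exec 101 -- sh -c "docker exec frigate sh -c \"uptime\""`
```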
+func BuildLXCCommand(vmid string, cmd string) string { + return fmt.Sprintf("pct exec %s -- sh -c %q", vmid, cmd) +} + +// BuildVMCommand wraps a command for execution in a VM via QEMU guest agent. +// Note: This requires the guest agent to be running. +func BuildVMCommand(vmid string, cmd string) string { + // For VMs, we use qm guest exec which requires the guest agent + return fmt.Sprintf("qm guest exec %s -- sh -c %q", vmid, cmd) +} + +// BuildDockerCommand wraps a command for execution in a Docker container. +func BuildDockerCommand(containerName string, cmd string) string { + return fmt.Sprintf("docker exec %s sh -c %q", containerName, cmd) +} + +// BuildNestedDockerCommand builds a command to run inside Docker on a VM/LXC. +func BuildNestedDockerCommand(vmid string, isLXC bool, containerName string, cmd string) string { + dockerCmd := BuildDockerCommand(containerName, cmd) + if isLXC { + return BuildLXCCommand(vmid, dockerCmd) + } + return BuildVMCommand(vmid, dockerCmd) +} + +// BuildK8sCommand builds a command to run in a Kubernetes pod. +func BuildK8sCommand(namespace, podName, containerName, cmd string) string { + if containerName != "" { + return fmt.Sprintf("kubectl exec -n %s %s -c %s -- sh -c %q", namespace, podName, containerName, cmd) + } + return fmt.Sprintf("kubectl exec -n %s %s -- sh -c %q", namespace, podName, cmd) +} + +// GetCLIAccessTemplate returns a CLI access template for a resource type. +func GetCLIAccessTemplate(resourceType ResourceType) string { + switch resourceType { + case ResourceTypeLXC: + return "pct exec {vmid} -- {command}" + case ResourceTypeVM: + return "qm guest exec {vmid} -- {command}" + case ResourceTypeDocker: + return "docker exec {container} {command}" + case ResourceTypeDockerLXC: + return "pct exec {vmid} -- docker exec {container} {command}" + case ResourceTypeDockerVM: + return "qm guest exec {vmid} -- docker exec {container} {command}" + case ResourceTypeK8s: + return "kubectl exec -n {namespace} {pod} -- {command}" + case ResourceTypeHost: + return "{command}" + default: + return "{command}" + } +} + +// FormatCLIAccess formats a CLI access string with actual values. 
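A short usage sketch with placeholder values: FormatCLIAccess below substitutes the {vmid}, {container}, {namespace} and {pod} placeholders, while {command} is intentionally left in the result for the caller to fill in.

```go
access := FormatCLIAccess(ResourceTypeDockerLXC, "101", "frigate", "", "")
// access == "pct exec 101 -- docker exec frigate {command}"
```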
+func FormatCLIAccess(resourceType ResourceType, vmid, containerName, namespace, podName string) string { + template := GetCLIAccessTemplate(resourceType) + result := template + + result = strings.ReplaceAll(result, "{vmid}", vmid) + result = strings.ReplaceAll(result, "{container}", containerName) + result = strings.ReplaceAll(result, "{namespace}", namespace) + result = strings.ReplaceAll(result, "{pod}", podName) + + return result +} diff --git a/internal/aidiscovery/commands_test.go b/internal/aidiscovery/commands_test.go new file mode 100644 index 000000000..eeac3a030 --- /dev/null +++ b/internal/aidiscovery/commands_test.go @@ -0,0 +1,78 @@ +package aidiscovery + +import ( + "strings" + "testing" +) + +func TestCommandsAndTemplates(t *testing.T) { + resourceTypes := []ResourceType{ + ResourceTypeLXC, + ResourceTypeVM, + ResourceTypeDocker, + ResourceTypeDockerVM, + ResourceTypeDockerLXC, + ResourceTypeK8s, + ResourceTypeHost, + } + + for _, rt := range resourceTypes { + cmds := GetCommandsForResource(rt) + if len(cmds) == 0 { + t.Fatalf("expected commands for %s", rt) + } + } + + if len(GetCommandsForResource(ResourceType("unknown"))) != 0 { + t.Fatalf("expected no commands for unknown resource type") + } + + if !strings.Contains(BuildLXCCommand("101", "echo hi"), "pct exec 101") { + t.Fatalf("unexpected LXC command") + } + if !strings.Contains(BuildVMCommand("101", "echo hi"), "qm guest exec 101") { + t.Fatalf("unexpected VM command") + } + if !strings.Contains(BuildDockerCommand("web", "echo hi"), "docker exec web") { + t.Fatalf("unexpected docker command") + } + + nestedLXC := BuildNestedDockerCommand("201", true, "web", "echo hi") + if !strings.Contains(nestedLXC, "pct exec 201") || !strings.Contains(nestedLXC, "docker exec web") { + t.Fatalf("unexpected nested LXC command: %s", nestedLXC) + } + + nestedVM := BuildNestedDockerCommand("301", false, "web", "echo hi") + if !strings.Contains(nestedVM, "qm guest exec 301") || !strings.Contains(nestedVM, "docker exec web") { + t.Fatalf("unexpected nested VM command: %s", nestedVM) + } + + withContainer := BuildK8sCommand("default", "pod", "app", "echo hi") + if !strings.Contains(withContainer, "-c app") || !strings.Contains(withContainer, "kubectl exec") { + t.Fatalf("unexpected k8s command: %s", withContainer) + } + + withoutContainer := BuildK8sCommand("default", "pod", "", "echo hi") + if strings.Contains(withoutContainer, "-c app") { + t.Fatalf("unexpected container selector: %s", withoutContainer) + } + + template := GetCLIAccessTemplate(ResourceTypeK8s) + if !strings.Contains(template, "{namespace}") || !strings.Contains(template, "{pod}") { + t.Fatalf("unexpected template: %s", template) + } + + for _, rt := range resourceTypes { + if tmpl := GetCLIAccessTemplate(rt); tmpl == "" { + t.Fatalf("expected template for %s", rt) + } + } + if tmpl := GetCLIAccessTemplate(ResourceType("unknown")); tmpl != "{command}" { + t.Fatalf("unexpected default template: %s", tmpl) + } + + formatted := FormatCLIAccess(ResourceTypeK8s, "101", "container", "default", "pod") + if !strings.Contains(formatted, "default") || !strings.Contains(formatted, "pod") { + t.Fatalf("unexpected formatted access: %s", formatted) + } +} diff --git a/internal/aidiscovery/deep_scanner.go b/internal/aidiscovery/deep_scanner.go new file mode 100644 index 000000000..291ceea07 --- /dev/null +++ b/internal/aidiscovery/deep_scanner.go @@ -0,0 +1,378 @@ +package aidiscovery + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + 
"github.com/rs/zerolog/log" +) + +// CommandExecutor executes commands on infrastructure. +type CommandExecutor interface { + ExecuteCommand(ctx context.Context, agentID string, cmd ExecuteCommandPayload) (*CommandResultPayload, error) + GetConnectedAgents() []ConnectedAgent + IsAgentConnected(agentID string) bool +} + +// ExecuteCommandPayload mirrors agentexec.ExecuteCommandPayload +type ExecuteCommandPayload struct { + RequestID string `json:"request_id"` + Command string `json:"command"` + TargetType string `json:"target_type"` // "host", "container", "vm" + TargetID string `json:"target_id,omitempty"` // VMID for container/VM + Timeout int `json:"timeout,omitempty"` +} + +// CommandResultPayload mirrors agentexec.CommandResultPayload +type CommandResultPayload struct { + RequestID string `json:"request_id"` + Success bool `json:"success"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` + ExitCode int `json:"exit_code"` + Error string `json:"error,omitempty"` + Duration int64 `json:"duration_ms"` +} + +// ConnectedAgent mirrors agentexec.ConnectedAgent +type ConnectedAgent struct { + AgentID string + Hostname string + Version string + Platform string + Tags []string + ConnectedAt time.Time +} + +// DeepScanner runs discovery commands on resources. +type DeepScanner struct { + executor CommandExecutor + mu sync.RWMutex + progress map[string]*DiscoveryProgress // resourceID -> progress + maxParallel int + timeout time.Duration +} + +// NewDeepScanner creates a new deep scanner. +func NewDeepScanner(executor CommandExecutor) *DeepScanner { + return &DeepScanner{ + executor: executor, + progress: make(map[string]*DiscoveryProgress), + maxParallel: 3, // Run up to 3 commands in parallel per resource + timeout: 30 * time.Second, + } +} + +// ScanResult contains the results of a deep scan. +type ScanResult struct { + ResourceType ResourceType + ResourceID string + HostID string + Hostname string + CommandOutputs map[string]string + Errors map[string]string + StartedAt time.Time + CompletedAt time.Time +} + +// Scan runs discovery commands on a resource and returns the outputs. 
+func (s *DeepScanner) Scan(ctx context.Context, req DiscoveryRequest) (*ScanResult, error) { + resourceID := MakeResourceID(req.ResourceType, req.HostID, req.ResourceID) + + // Initialize progress + s.mu.Lock() + s.progress[resourceID] = &DiscoveryProgress{ + ResourceID: resourceID, + Status: DiscoveryStatusRunning, + CurrentStep: "initializing", + StartedAt: time.Now(), + } + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.progress, resourceID) + s.mu.Unlock() + }() + + result := &ScanResult{ + ResourceType: req.ResourceType, + ResourceID: req.ResourceID, + HostID: req.HostID, + Hostname: req.Hostname, + CommandOutputs: make(map[string]string), + Errors: make(map[string]string), + StartedAt: time.Now(), + } + + // Check if we have an agent for this host + if s.executor == nil { + return nil, fmt.Errorf("no command executor available") + } + + // Find the agent for this host + agentID := s.findAgentForHost(req.HostID, req.Hostname) + if agentID == "" { + return nil, fmt.Errorf("no connected agent for host %s (%s)", req.HostID, req.Hostname) + } + + // Get commands for this resource type + commands := GetCommandsForResource(req.ResourceType) + if len(commands) == 0 { + return nil, fmt.Errorf("no commands defined for resource type %s", req.ResourceType) + } + + // Update progress + s.mu.Lock() + if prog, ok := s.progress[resourceID]; ok { + prog.TotalSteps = len(commands) + prog.CurrentStep = "running commands" + } + s.mu.Unlock() + + // Run commands with limited parallelism + semaphore := make(chan struct{}, s.maxParallel) + var wg sync.WaitGroup + var mu sync.Mutex + + for _, cmd := range commands { + wg.Add(1) + go func(cmd DiscoveryCommand) { + defer wg.Done() + + select { + case semaphore <- struct{}{}: + defer func() { <-semaphore }() + case <-ctx.Done(): + return + } + + // Build the actual command to run + actualCmd := s.buildCommand(req.ResourceType, req.ResourceID, cmd.Command) + + // Execute the command + cmdCtx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + cmdResult, err := s.executor.ExecuteCommand(cmdCtx, agentID, ExecuteCommandPayload{ + RequestID: uuid.New().String(), + Command: actualCmd, + TargetType: s.getTargetType(req.ResourceType), + TargetID: req.ResourceID, + Timeout: cmd.Timeout, + }) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + if !cmd.Optional { + result.Errors[cmd.Name] = err.Error() + } + log.Debug(). + Err(err). + Str("command", cmd.Name). + Str("resource", resourceID). + Msg("Command failed during discovery") + return + } + + if cmdResult != nil { + output := cmdResult.Stdout + if cmdResult.Stderr != "" && output != "" { + output += "\n--- stderr ---\n" + cmdResult.Stderr + } else if cmdResult.Stderr != "" { + output = cmdResult.Stderr + } + + if output != "" { + result.CommandOutputs[cmd.Name] = output + } + + if !cmdResult.Success && cmdResult.Error != "" && !cmd.Optional { + result.Errors[cmd.Name] = cmdResult.Error + } + } + + // Update progress + s.mu.Lock() + if prog, ok := s.progress[resourceID]; ok { + prog.CompletedSteps++ + } + s.mu.Unlock() + }(cmd) + } + + wg.Wait() + result.CompletedAt = time.Now() + + log.Info(). + Str("resource", resourceID). + Int("outputs", len(result.CommandOutputs)). + Int("errors", len(result.Errors)). + Dur("duration", result.CompletedAt.Sub(result.StartedAt)). + Msg("Deep scan completed") + + return result, nil +} + +// buildCommand wraps the command appropriately for the resource type. 
+// NOTE: For LXC/VM, the agent handles wrapping via pct exec / qm guest exec +// based on TargetType, so we don't wrap here. We only wrap for Docker containers +// since Docker isn't a recognized TargetType in the agent. +func (s *DeepScanner) buildCommand(resourceType ResourceType, resourceID string, cmd string) string { + switch resourceType { + case ResourceTypeLXC: + // Agent wraps with pct exec based on TargetType="container" + return cmd + case ResourceTypeVM: + // Agent wraps with qm guest exec based on TargetType="vm" + return cmd + case ResourceTypeDocker: + // Docker needs wrapping here since agent doesn't handle it + return BuildDockerCommand(resourceID, cmd) + case ResourceTypeHost: + // Commands run directly on host + return cmd + case ResourceTypeDockerLXC: + // Docker inside LXC - agent wraps with pct exec, we just add docker exec + // resourceID format: "vmid:container_name" + parts := splitResourceID(resourceID) + if len(parts) >= 2 { + return BuildDockerCommand(parts[1], cmd) + } + return cmd + case ResourceTypeDockerVM: + // Docker inside VM - agent wraps with qm guest exec, we just add docker exec + parts := splitResourceID(resourceID) + if len(parts) >= 2 { + return BuildDockerCommand(parts[1], cmd) + } + return cmd + default: + return cmd + } +} + +// getTargetType returns the target type for the agent execution payload. +func (s *DeepScanner) getTargetType(resourceType ResourceType) string { + switch resourceType { + case ResourceTypeLXC: + return "container" + case ResourceTypeVM: + return "vm" + case ResourceTypeDocker: + return "host" // Docker commands run on host via docker exec + case ResourceTypeHost: + return "host" + default: + return "host" + } +} + +// findAgentForHost finds the agent ID for a given host. +func (s *DeepScanner) findAgentForHost(hostID, hostname string) string { + agents := s.executor.GetConnectedAgents() + + // First try exact match on agent ID + for _, agent := range agents { + if agent.AgentID == hostID { + return agent.AgentID + } + } + + // Then try hostname match + for _, agent := range agents { + if agent.Hostname == hostname || agent.Hostname == hostID { + return agent.AgentID + } + } + + // If only one agent connected, use it + if len(agents) == 1 { + return agents[0].AgentID + } + + return "" +} + +// GetProgress returns the current progress of a scan. +func (s *DeepScanner) GetProgress(resourceID string) *DiscoveryProgress { + s.mu.RLock() + defer s.mu.RUnlock() + if prog, ok := s.progress[resourceID]; ok { + return prog + } + return nil +} + +// IsScanning returns whether a resource is currently being scanned. +func (s *DeepScanner) IsScanning(resourceID string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.progress[resourceID] + return ok +} + +// splitResourceID splits a compound resource ID (e.g., "101:container_name"). +func splitResourceID(id string) []string { + var parts []string + start := 0 + for i, c := range id { + if c == ':' { + parts = append(parts, id[start:i]) + start = i + 1 + } + } + if start < len(id) { + parts = append(parts, id[start:]) + } + return parts +} + +// ScanDocker runs discovery on Docker containers via the host. +func (s *DeepScanner) ScanDocker(ctx context.Context, hostID, hostname, containerName string) (*ScanResult, error) { + req := DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: containerName, + HostID: hostID, + Hostname: hostname, + } + return s.Scan(ctx, req) +} + +// ScanLXC runs discovery on an LXC container. 
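One usage note on the wrappers that follow: Scan deletes its progress entry when it returns, so GetProgress and IsScanning are only meaningful while a scan is still in flight, for example when polled from another goroutine. A rough sketch, assuming a *DeepScanner named scanner, placeholder IDs, and that the DiscoveryProgress fields have the integer/string types their use in Scan implies:

```go
// MakeResourceID is the same helper Scan uses to build its progress key;
// "pve1" and "105" are placeholder values.
id := MakeResourceID(ResourceTypeLXC, "pve1", "105")
if scanner.IsScanning(id) {
	if prog := scanner.GetProgress(id); prog != nil {
		fmt.Printf("%s: step %d/%d (%s)\n", id, prog.CompletedSteps, prog.TotalSteps, prog.CurrentStep)
	}
}
```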
+func (s *DeepScanner) ScanLXC(ctx context.Context, hostID, hostname, vmid string) (*ScanResult, error) { + req := DiscoveryRequest{ + ResourceType: ResourceTypeLXC, + ResourceID: vmid, + HostID: hostID, + Hostname: hostname, + } + return s.Scan(ctx, req) +} + +// ScanVM runs discovery on a VM via QEMU guest agent. +func (s *DeepScanner) ScanVM(ctx context.Context, hostID, hostname, vmid string) (*ScanResult, error) { + req := DiscoveryRequest{ + ResourceType: ResourceTypeVM, + ResourceID: vmid, + HostID: hostID, + Hostname: hostname, + } + return s.Scan(ctx, req) +} + +// ScanHost runs discovery on a host system. +func (s *DeepScanner) ScanHost(ctx context.Context, hostID, hostname string) (*ScanResult, error) { + req := DiscoveryRequest{ + ResourceType: ResourceTypeHost, + ResourceID: hostID, + HostID: hostID, + Hostname: hostname, + } + return s.Scan(ctx, req) +} diff --git a/internal/aidiscovery/deep_scanner_test.go b/internal/aidiscovery/deep_scanner_test.go new file mode 100644 index 000000000..fc5ad780b --- /dev/null +++ b/internal/aidiscovery/deep_scanner_test.go @@ -0,0 +1,325 @@ +package aidiscovery + +import ( + "context" + "strings" + "sync" + "testing" + "time" +) + +type stubExecutor struct { + mu sync.Mutex + commands []string + agents []ConnectedAgent +} + +func (s *stubExecutor) ExecuteCommand(ctx context.Context, agentID string, cmd ExecuteCommandPayload) (*CommandResultPayload, error) { + s.mu.Lock() + s.commands = append(s.commands, cmd.Command) + s.mu.Unlock() + + if err := ctx.Err(); err != nil { + return nil, err + } + + if strings.Contains(cmd.Command, "docker ps -a") { + return &CommandResultPayload{ + RequestID: cmd.RequestID, + Success: false, + Error: "boom", + }, nil + } + + return &CommandResultPayload{ + RequestID: cmd.RequestID, + Success: true, + Stdout: cmd.Command, + Duration: 5, + }, nil +} + +func (s *stubExecutor) GetConnectedAgents() []ConnectedAgent { + return s.agents +} + +func (s *stubExecutor) IsAgentConnected(agentID string) bool { + for _, agent := range s.agents { + if agent.AgentID == agentID { + return true + } + } + return false +} + +type outputExecutor struct{} + +func (outputExecutor) ExecuteCommand(ctx context.Context, agentID string, cmd ExecuteCommandPayload) (*CommandResultPayload, error) { + switch { + case strings.Contains(cmd.Command, "docker ps -a"): + return &CommandResultPayload{Success: true, Stdout: "out", Stderr: "err"}, nil + case strings.Contains(cmd.Command, "docker images"): + return &CommandResultPayload{Success: true, Stderr: "err-only"}, nil + default: + return &CommandResultPayload{Success: true}, nil + } +} + +func (outputExecutor) GetConnectedAgents() []ConnectedAgent { + return []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}} +} + +func (outputExecutor) IsAgentConnected(string) bool { return true } + +type errorExecutor struct{} + +func (errorExecutor) ExecuteCommand(ctx context.Context, agentID string, cmd ExecuteCommandPayload) (*CommandResultPayload, error) { + return nil, context.DeadlineExceeded +} + +func (errorExecutor) GetConnectedAgents() []ConnectedAgent { + return []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}} +} + +func (errorExecutor) IsAgentConnected(string) bool { return true } + +func TestDeepScanner_Scan_NestedDockerCommands(t *testing.T) { + exec := &stubExecutor{ + agents: []ConnectedAgent{ + {AgentID: "host1", Hostname: "host1", ConnectedAt: time.Now()}, + }, + } + scanner := NewDeepScanner(exec) + + result, err := scanner.Scan(context.Background(), DiscoveryRequest{ + 
ResourceType: ResourceTypeDockerVM, + ResourceID: "101:web", + HostID: "host1", + Hostname: "host1", + }) + if err != nil { + t.Fatalf("Scan error: %v", err) + } + if len(result.CommandOutputs) == 0 { + t.Fatalf("expected command outputs") + } + if _, ok := result.Errors["docker_containers"]; !ok { + t.Fatalf("expected docker_containers error, got %#v", result.Errors) + } + + exec.mu.Lock() + defer exec.mu.Unlock() + foundWrapped := false + for _, cmd := range exec.commands { + if strings.Contains(cmd, "qm guest exec 101") && strings.Contains(cmd, "docker exec web") { + foundWrapped = true + break + } + } + if !foundWrapped { + t.Fatalf("expected nested docker command, got %#v", exec.commands) + } +} + +func TestDeepScanner_FindAgentAndTargetType(t *testing.T) { + exec := &stubExecutor{ + agents: []ConnectedAgent{ + {AgentID: "a1", Hostname: "node1"}, + {AgentID: "a2", Hostname: "node2"}, + }, + } + scanner := NewDeepScanner(exec) + + if got := scanner.findAgentForHost("a2", ""); got != "a2" { + t.Fatalf("expected direct agent match, got %s", got) + } + if got := scanner.findAgentForHost("node1", "node1"); got != "a1" { + t.Fatalf("expected hostname match, got %s", got) + } + + exec.agents = []ConnectedAgent{{AgentID: "solo", Hostname: "only"}} + if got := scanner.findAgentForHost("missing", "missing"); got != "solo" { + t.Fatalf("expected single agent fallback, got %s", got) + } + exec.agents = nil + if got := scanner.findAgentForHost("missing", "missing"); got != "" { + t.Fatalf("expected no agent, got %s", got) + } + + if scanner.getTargetType(ResourceTypeLXC) != "container" { + t.Fatalf("unexpected target type for lxc") + } + if scanner.getTargetType(ResourceTypeVM) != "vm" { + t.Fatalf("unexpected target type for vm") + } + if scanner.getTargetType(ResourceTypeDocker) != "host" { + t.Fatalf("unexpected target type for docker") + } + if scanner.getTargetType(ResourceTypeHost) != "host" { + t.Fatalf("unexpected target type for host") + } +} + +func TestSplitResourceID(t *testing.T) { + parts := splitResourceID("101:web:extra") + if len(parts) != 3 || parts[0] != "101" || parts[1] != "web" || parts[2] != "extra" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestDeepScanner_BuildCommandAndProgress(t *testing.T) { + scanner := NewDeepScanner(&stubExecutor{}) + + if cmd := scanner.buildCommand(ResourceTypeLXC, "101", "echo hi"); !strings.Contains(cmd, "pct exec 101") { + t.Fatalf("unexpected lxc command: %s", cmd) + } + if cmd := scanner.buildCommand(ResourceTypeVM, "101", "echo hi"); cmd != "echo hi" { + t.Fatalf("unexpected vm command: %s", cmd) + } + if cmd := scanner.buildCommand(ResourceTypeDocker, "web", "echo hi"); !strings.Contains(cmd, "docker exec web") { + t.Fatalf("unexpected docker command: %s", cmd) + } + if cmd := scanner.buildCommand(ResourceTypeHost, "host", "echo hi"); cmd != "echo hi" { + t.Fatalf("unexpected host command: %s", cmd) + } + + dockerLXC := scanner.buildCommand(ResourceTypeDockerLXC, "201:web", "echo hi") + if !strings.Contains(dockerLXC, "pct exec 201") || !strings.Contains(dockerLXC, "docker exec web") { + t.Fatalf("unexpected docker lxc command: %s", dockerLXC) + } + if cmd := scanner.buildCommand(ResourceTypeDockerLXC, "bad", "echo hi"); cmd != "echo hi" { + t.Fatalf("expected fallback lxc command, got %s", cmd) + } + dockerVM := scanner.buildCommand(ResourceTypeDockerVM, "301:web", "echo hi") + if !strings.Contains(dockerVM, "qm guest exec 301") || !strings.Contains(dockerVM, "docker exec web") { + t.Fatalf("unexpected docker vm command: 
%s", dockerVM) + } + if cmd := scanner.buildCommand(ResourceTypeDockerVM, "bad", "echo hi"); cmd != "echo hi" { + t.Fatalf("expected fallback command, got %s", cmd) + } + if cmd := scanner.buildCommand(ResourceType("unknown"), "id", "echo hi"); cmd != "echo hi" { + t.Fatalf("expected default command, got %s", cmd) + } + + scanner.progress["id"] = &DiscoveryProgress{ResourceID: "id"} + if scanner.GetProgress("id") == nil { + t.Fatalf("expected progress") + } + if !scanner.IsScanning("id") { + t.Fatalf("expected IsScanning true") + } + if scanner.GetProgress("missing") != nil { + t.Fatalf("expected nil progress") + } + if scanner.IsScanning("missing") { + t.Fatalf("expected IsScanning false") + } + + noExec := NewDeepScanner(nil) + if _, err := noExec.ScanHost(context.Background(), "host1", "host1"); err == nil { + t.Fatalf("expected error without executor") + } +} + +func TestDeepScanner_ScanWrappers(t *testing.T) { + exec := &stubExecutor{ + agents: []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}}, + } + scanner := NewDeepScanner(exec) + scanner.maxParallel = 1 + + if _, err := scanner.ScanDocker(context.Background(), "host1", "host1", "web"); err != nil { + t.Fatalf("ScanDocker error: %v", err) + } + if _, err := scanner.ScanLXC(context.Background(), "host1", "host1", "101"); err != nil { + t.Fatalf("ScanLXC error: %v", err) + } + if _, err := scanner.ScanVM(context.Background(), "host1", "host1", "102"); err != nil { + t.Fatalf("ScanVM error: %v", err) + } +} + +func TestDeepScanner_ScanErrors(t *testing.T) { + exec := &stubExecutor{ + agents: []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}}, + } + scanner := NewDeepScanner(exec) + if _, err := scanner.Scan(context.Background(), DiscoveryRequest{ + ResourceType: ResourceType("unknown"), + ResourceID: "id", + HostID: "host1", + Hostname: "host1", + }); err == nil { + t.Fatalf("expected error for unknown resource type") + } + + exec.agents = nil + if _, err := scanner.Scan(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + }); err == nil { + t.Fatalf("expected error for missing agent") + } +} + +func TestDeepScanner_OutputHandling(t *testing.T) { + exec := outputExecutor{} + scanner := NewDeepScanner(exec) + scanner.maxParallel = 1 + + result, err := scanner.Scan(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDockerVM, + ResourceID: "101:web", + HostID: "host1", + Hostname: "host1", + }) + if err != nil { + t.Fatalf("Scan error: %v", err) + } + if out := result.CommandOutputs["docker_containers"]; !strings.Contains(out, "--- stderr ---") { + t.Fatalf("expected combined stderr output, got %s", out) + } + if out := result.CommandOutputs["docker_images"]; out != "err-only" { + t.Fatalf("expected stderr-only output, got %s", out) + } +} + +func TestDeepScanner_CommandErrorHandling(t *testing.T) { + scanner := NewDeepScanner(errorExecutor{}) + scanner.maxParallel = 1 + + result, err := scanner.Scan(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDockerVM, + ResourceID: "101:web", + HostID: "host1", + Hostname: "host1", + }) + if err != nil { + t.Fatalf("Scan error: %v", err) + } + if _, ok := result.Errors["docker_containers"]; !ok { + t.Fatalf("expected error for non-optional command") + } +} + +func TestDeepScanner_ScanCanceledContext(t *testing.T) { + exec := &stubExecutor{ + agents: []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}}, + } + scanner := NewDeepScanner(exec) + scanner.maxParallel = 
0
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	if _, err := scanner.Scan(ctx, DiscoveryRequest{
+		ResourceType: ResourceTypeDockerVM,
+		ResourceID:   "101:web",
+		HostID:       "host1",
+		Hostname:     "host1",
+	}); err != nil {
+		t.Fatalf("Scan error: %v", err)
+	}
+}
diff --git a/internal/aidiscovery/formatters.go b/internal/aidiscovery/formatters.go
new file mode 100644
index 000000000..889af6a5b
--- /dev/null
+++ b/internal/aidiscovery/formatters.go
@@ -0,0 +1,337 @@
+package aidiscovery
+
+import (
+	"fmt"
+	"strings"
+	"time"
+)
+
+// FormatForAIContext formats discoveries for inclusion in AI prompts.
+// This provides context about resources for Patrol, Investigation, and Chat.
+func FormatForAIContext(discoveries []*ResourceDiscovery) string {
+	if len(discoveries) == 0 {
+		return ""
+	}
+
+	var sb strings.Builder
+	sb.WriteString("## Infrastructure Discovery\n\n")
+	sb.WriteString("The following has been discovered about the affected resources:\n\n")
+
+	for _, d := range discoveries {
+		sb.WriteString(formatSingleDiscovery(d))
+		sb.WriteString("\n")
+	}
+
+	sb.WriteString("\n**IMPORTANT:** Use the CLI access methods shown above. For example:\n")
+	sb.WriteString("- For LXC containers, use `pct exec <vmid> -- <command>`\n")
+	sb.WriteString("- For VMs with guest agent, use `qm guest exec <vmid> -- <command>`\n")
+	sb.WriteString("- For Docker containers, use `docker exec <container> <command>`\n")
+
+	return sb.String()
+}
+
+// FormatSingleForAIContext formats a single discovery for AI context.
+func FormatSingleForAIContext(d *ResourceDiscovery) string {
+	if d == nil {
+		return ""
+	}
+	return formatSingleDiscovery(d)
+}
+
+// formatSingleDiscovery formats a single discovery entry.
+func formatSingleDiscovery(d *ResourceDiscovery) string {
+	var sb strings.Builder
+
+	// Header with service info
+	sb.WriteString(fmt.Sprintf("### %s (%s)\n", d.ServiceName, d.ID))
+	sb.WriteString(fmt.Sprintf("- **Type:** %s\n", d.ResourceType))
+	sb.WriteString(fmt.Sprintf("- **Host:** %s\n", d.Hostname))
+
+	if d.ServiceVersion != "" {
+		sb.WriteString(fmt.Sprintf("- **Version:** %s\n", d.ServiceVersion))
+	}
+
+	if d.Category != "" && d.Category != CategoryUnknown {
+		sb.WriteString(fmt.Sprintf("- **Category:** %s\n", d.Category))
+	}
+
+	// CLI access (most important for remediation)
+	if d.CLIAccess != "" {
+		sb.WriteString(fmt.Sprintf("- **CLI Access:** `%s`\n", d.CLIAccess))
+	}
+
+	// Config and data paths
+	if len(d.ConfigPaths) > 0 {
+		sb.WriteString(fmt.Sprintf("- **Config Paths:** %s\n", strings.Join(d.ConfigPaths, ", ")))
+	}
+	if len(d.DataPaths) > 0 {
+		sb.WriteString(fmt.Sprintf("- **Data Paths:** %s\n", strings.Join(d.DataPaths, ", ")))
+	}
+
+	// Ports
+	if len(d.Ports) > 0 {
+		var ports []string
+		for _, p := range d.Ports {
+			ports = append(ports, fmt.Sprintf("%d/%s", p.Port, p.Protocol))
+		}
+		sb.WriteString(fmt.Sprintf("- **Ports:** %s\n", strings.Join(ports, ", ")))
+	}
+
+	// Important facts
+	importantFacts := filterImportantFacts(d.Facts)
+	if len(importantFacts) > 0 {
+		sb.WriteString("- **Key Facts:**\n")
+		for _, f := range importantFacts {
+			sb.WriteString(fmt.Sprintf("  - %s: %s\n", f.Key, f.Value))
+		}
+	}
+
+	// User notes (critical for context)
+	if d.UserNotes != "" {
+		sb.WriteString(fmt.Sprintf("- **User Notes:** %s\n", d.UserNotes))
+	}
+
+	return sb.String()
+}
+
+// filterImportantFacts returns the most relevant facts for AI context.
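+// For example, a fact such as {Category: FactCategoryHardware, Key: "gpu", Value: "nvidia",
+// Confidence: 0.9} is kept, while facts outside the priority categories or with
+// confidence below 0.7 are dropped (values shown are illustrative).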
+func filterImportantFacts(facts []DiscoveryFact) []DiscoveryFact { + var important []DiscoveryFact + + // Priority categories + priorityCategories := map[FactCategory]bool{ + FactCategoryHardware: true, // GPU, TPU + FactCategoryDependency: true, // MQTT, database connections + FactCategorySecurity: true, // Auth info + FactCategoryVersion: true, // Version info + } + + for _, f := range facts { + if priorityCategories[f.Category] && f.Confidence >= 0.7 { + important = append(important, f) + } + } + + // Limit to top 5 facts + if len(important) > 5 { + important = important[:5] + } + + return important +} + +// FormatDiscoverySummary formats a summary of all discoveries. +func FormatDiscoverySummary(discoveries []*ResourceDiscovery) string { + if len(discoveries) == 0 { + return "No infrastructure discovery data available." + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Infrastructure Discovery Summary (%d resources):\n\n", len(discoveries))) + + // Group by resource type + byType := make(map[ResourceType][]*ResourceDiscovery) + for _, d := range discoveries { + byType[d.ResourceType] = append(byType[d.ResourceType], d) + } + + for rt, ds := range byType { + sb.WriteString(fmt.Sprintf("**%s** (%d):\n", rt, len(ds))) + for _, d := range ds { + confidence := "" + if d.Confidence >= 0.9 { + confidence = " [high confidence]" + } else if d.Confidence >= 0.7 { + confidence = " [medium confidence]" + } + sb.WriteString(fmt.Sprintf(" - %s: %s%s\n", d.ResourceID, d.ServiceName, confidence)) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// FormatForRemediation formats discovery specifically for remediation context. +func FormatForRemediation(d *ResourceDiscovery) string { + if d == nil { + return "" + } + + var sb strings.Builder + sb.WriteString("## Resource Context for Remediation\n\n") + + sb.WriteString(fmt.Sprintf("**Resource:** %s (%s)\n", d.ServiceName, d.ID)) + sb.WriteString(fmt.Sprintf("**Type:** %s on %s\n\n", d.ResourceType, d.Hostname)) + + // CLI access is most critical + if d.CLIAccess != "" { + sb.WriteString("### How to Execute Commands\n") + sb.WriteString(fmt.Sprintf("```\n%s\n```\n\n", d.CLIAccess)) + } + + // Service-specific info + if d.ServiceType != "" { + sb.WriteString(fmt.Sprintf("**Service:** %s", d.ServiceType)) + if d.ServiceVersion != "" { + sb.WriteString(fmt.Sprintf(" v%s", d.ServiceVersion)) + } + sb.WriteString("\n\n") + } + + // Config paths for potential fixes + if len(d.ConfigPaths) > 0 { + sb.WriteString("### Configuration Files\n") + for _, p := range d.ConfigPaths { + sb.WriteString(fmt.Sprintf("- `%s`\n", p)) + } + sb.WriteString("\n") + } + + // User notes may contain important context + if d.UserNotes != "" { + sb.WriteString("### User Notes\n") + sb.WriteString(d.UserNotes) + sb.WriteString("\n\n") + } + + // Hardware info for special considerations + for _, f := range d.Facts { + if f.Category == FactCategoryHardware { + sb.WriteString(fmt.Sprintf("**Hardware:** %s = %s\n", f.Key, f.Value)) + } + } + + return sb.String() +} + +// FormatDiscoveryAge returns a human-readable age string. 
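+// For example, an UpdatedAt of 30 seconds ago yields "just now", 90 seconds ago
+// yields "1 minute ago", and 3 days ago yields "3 days ago".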
+func FormatDiscoveryAge(d *ResourceDiscovery) string { + if d == nil || d.UpdatedAt.IsZero() { + return "unknown" + } + + age := time.Since(d.UpdatedAt) + switch { + case age < time.Minute: + return "just now" + case age < time.Hour: + mins := int(age.Minutes()) + if mins == 1 { + return "1 minute ago" + } + return fmt.Sprintf("%d minutes ago", mins) + case age < 24*time.Hour: + hours := int(age.Hours()) + if hours == 1 { + return "1 hour ago" + } + return fmt.Sprintf("%d hours ago", hours) + default: + days := int(age.Hours() / 24) + if days == 1 { + return "1 day ago" + } + return fmt.Sprintf("%d days ago", days) + } +} + +// GetCLIExample returns an example CLI command for the resource. +func GetCLIExample(d *ResourceDiscovery, exampleCmd string) string { + if d == nil || d.CLIAccess == "" { + return "" + } + + // Replace the placeholder with the example command + cli := d.CLIAccess + cli = strings.ReplaceAll(cli, "...", exampleCmd) + cli = strings.ReplaceAll(cli, "{command}", exampleCmd) + + return cli +} + +// FormatFactsTable formats facts as a simple table. +func FormatFactsTable(facts []DiscoveryFact) string { + if len(facts) == 0 { + return "" + } + + var sb strings.Builder + sb.WriteString("| Category | Key | Value |\n") + sb.WriteString("|----------|-----|-------|\n") + + for _, f := range facts { + value := f.Value + if len(value) > 50 { + value = value[:47] + "..." + } + sb.WriteString(fmt.Sprintf("| %s | %s | %s |\n", f.Category, f.Key, value)) + } + + return sb.String() +} + +// BuildResourceContextForPatrol builds context for Patrol findings. +func BuildResourceContextForPatrol(store *Store, resourceIDs []string) string { + if store == nil || len(resourceIDs) == 0 { + return "" + } + + discoveries, err := store.GetMultiple(resourceIDs) + if err != nil || len(discoveries) == 0 { + return "" + } + + return FormatForAIContext(discoveries) +} + +// ToJSON converts a discovery to a JSON-friendly map. 
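+// The keys mirror the ResourceDiscovery fields in snake_case (e.g. "service_name",
+// "cli_access", "ports"), with facts and ports emitted as slices of maps.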
+func ToJSON(d *ResourceDiscovery) map[string]any { + if d == nil { + return nil + } + + facts := make([]map[string]any, 0, len(d.Facts)) + for _, f := range d.Facts { + facts = append(facts, map[string]any{ + "category": f.Category, + "key": f.Key, + "value": f.Value, + "source": f.Source, + "confidence": f.Confidence, + }) + } + + ports := make([]map[string]any, 0, len(d.Ports)) + for _, p := range d.Ports { + ports = append(ports, map[string]any{ + "port": p.Port, + "protocol": p.Protocol, + "process": p.Process, + "address": p.Address, + }) + } + + return map[string]any{ + "id": d.ID, + "resource_type": d.ResourceType, + "resource_id": d.ResourceID, + "host_id": d.HostID, + "hostname": d.Hostname, + "service_type": d.ServiceType, + "service_name": d.ServiceName, + "service_version": d.ServiceVersion, + "category": d.Category, + "cli_access": d.CLIAccess, + "facts": facts, + "config_paths": d.ConfigPaths, + "data_paths": d.DataPaths, + "ports": ports, + "user_notes": d.UserNotes, + "confidence": d.Confidence, + "ai_reasoning": d.AIReasoning, + "discovered_at": d.DiscoveredAt, + "updated_at": d.UpdatedAt, + "scan_duration": d.ScanDuration, + } +} diff --git a/internal/aidiscovery/formatters_test.go b/internal/aidiscovery/formatters_test.go new file mode 100644 index 000000000..c6841e785 --- /dev/null +++ b/internal/aidiscovery/formatters_test.go @@ -0,0 +1,195 @@ +package aidiscovery + +import ( + "strings" + "testing" + "time" +) + +func TestFormattersAndTables(t *testing.T) { + if FormatForAIContext(nil) != "" { + t.Fatalf("expected empty context for nil discoveries") + } + + discovery := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, "host1", "app"), + ResourceType: ResourceTypeDocker, + ResourceID: "app", + HostID: "host1", + Hostname: "host1", + ServiceType: "app", + ServiceName: "App Service", + ServiceVersion: "1.0", + Category: CategoryWebServer, + CLIAccess: "docker exec app ...", + ConfigPaths: []string{"/etc/app/config.yml"}, + DataPaths: []string{"/var/lib/app"}, + Ports: []PortInfo{{Port: 80, Protocol: "tcp"}}, + UserNotes: "keepalive enabled", + Facts: []DiscoveryFact{ + {Category: FactCategoryHardware, Key: "gpu", Value: "nvidia", Confidence: 0.9}, + {Category: FactCategoryService, Key: "worker", Value: "enabled", Confidence: 0.9}, + }, + } + + ctx := FormatForAIContext([]*ResourceDiscovery{discovery}) + if !strings.Contains(ctx, "Infrastructure Discovery") || !strings.Contains(ctx, "App Service") { + t.Fatalf("unexpected context: %s", ctx) + } + if !strings.Contains(ctx, "docker exec") || !strings.Contains(ctx, "User Notes") { + t.Fatalf("missing expected fields in context") + } + + if FormatSingleForAIContext(nil) != "" { + t.Fatalf("expected empty string for nil discovery") + } + if !strings.Contains(FormatSingleForAIContext(discovery), "App Service") { + t.Fatalf("expected single discovery output") + } + + remediation := FormatForRemediation(discovery) + if !strings.Contains(remediation, "How to Execute Commands") || !strings.Contains(remediation, "Hardware") { + t.Fatalf("unexpected remediation output: %s", remediation) + } + if FormatForRemediation(nil) != "" { + t.Fatalf("expected empty remediation output for nil") + } + + example := GetCLIExample(discovery, "ls /") + if !strings.Contains(example, "ls /") { + t.Fatalf("unexpected cli example: %s", example) + } + if GetCLIExample(&ResourceDiscovery{}, "ls /") != "" { + t.Fatalf("expected empty example when cli access missing") + } + + table := FormatFactsTable([]DiscoveryFact{ + {Category: 
FactCategoryVersion, Key: "app", Value: strings.Repeat("x", 60)}, + }) + if !strings.Contains(table, "...") { + t.Fatalf("expected truncated table value: %s", table) + } + if FormatFactsTable(nil) != "" { + t.Fatalf("expected empty facts table for nil") + } + + jsonMap := ToJSON(discovery) + if jsonMap["service_name"] != "App Service" || jsonMap["resource_id"] != "app" { + t.Fatalf("unexpected json map: %#v", jsonMap) + } + if ToJSON(nil) != nil { + t.Fatalf("expected nil json map for nil discovery") + } +} + +func TestFormatDiscoverySummaryAndAge(t *testing.T) { + now := time.Now() + if FormatDiscoverySummary(nil) == "" { + t.Fatalf("expected summary text for empty list") + } + if FormatDiscoveryAge(nil) != "unknown" { + t.Fatalf("expected unknown age for nil") + } + if FormatDiscoveryAge(&ResourceDiscovery{}) != "unknown" { + t.Fatalf("expected unknown age for zero timestamp") + } + discoveries := []*ResourceDiscovery{ + { + ID: MakeResourceID(ResourceTypeVM, "node1", "101"), + ResourceType: ResourceTypeVM, + ResourceID: "101", + HostID: "node1", + ServiceName: "VM One", + Confidence: 0.95, + UpdatedAt: now.Add(-2 * time.Hour), + }, + { + ID: MakeResourceID(ResourceTypeDocker, "host1", "app"), + ResourceType: ResourceTypeDocker, + ResourceID: "app", + HostID: "host1", + ServiceName: "App", + Confidence: 0.75, + UpdatedAt: now.Add(-2 * 24 * time.Hour), + }, + } + + summary := FormatDiscoverySummary(discoveries) + if !strings.Contains(summary, "[high confidence]") || !strings.Contains(summary, "[medium confidence]") { + t.Fatalf("unexpected summary: %s", summary) + } + + tests := []struct { + name string + updated time.Time + expected string + }{ + {name: "just-now", updated: now.Add(-30 * time.Second), expected: "just now"}, + {name: "one-minute", updated: now.Add(-1 * time.Minute), expected: "1 minute ago"}, + {name: "minutes", updated: now.Add(-10 * time.Minute), expected: "10 minutes ago"}, + {name: "one-hour", updated: now.Add(-1 * time.Hour), expected: "1 hour ago"}, + {name: "hours", updated: now.Add(-2 * time.Hour), expected: "2 hours ago"}, + {name: "one-day", updated: now.Add(-24 * time.Hour), expected: "1 day ago"}, + {name: "days", updated: now.Add(-3 * 24 * time.Hour), expected: "3 days ago"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatDiscoveryAge(&ResourceDiscovery{UpdatedAt: tt.updated}) + if got != tt.expected { + t.Fatalf("expected %s, got %s", tt.expected, got) + } + }) + } +} + +func TestBuildResourceContextForPatrol(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + discovery := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, "host1", "app"), + ResourceType: ResourceTypeDocker, + ResourceID: "app", + HostID: "host1", + ServiceName: "App Service", + } + if err := store.Save(discovery); err != nil { + t.Fatalf("Save error: %v", err) + } + + ctx := BuildResourceContextForPatrol(store, []string{discovery.ID}) + if !strings.Contains(ctx, "App Service") { + t.Fatalf("unexpected patrol context: %s", ctx) + } + + if BuildResourceContextForPatrol(nil, []string{discovery.ID}) != "" { + t.Fatalf("expected empty context for nil store") + } + if BuildResourceContextForPatrol(store, nil) != "" { + t.Fatalf("expected empty context for empty ids") + } + if BuildResourceContextForPatrol(store, []string{"missing"}) != "" { + t.Fatalf("expected empty context for missing discoveries") + } +} + +func TestFilterImportantFactsLimit(t *testing.T) { 
+ var facts []DiscoveryFact + for i := 0; i < 7; i++ { + facts = append(facts, DiscoveryFact{ + Category: FactCategoryVersion, + Key: "k", + Value: "v", + Confidence: 0.9, + }) + } + + important := filterImportantFacts(facts) + if len(important) != 5 { + t.Fatalf("expected 5 facts, got %d", len(important)) + } +} diff --git a/internal/aidiscovery/service.go b/internal/aidiscovery/service.go new file mode 100644 index 000000000..a6def0c4a --- /dev/null +++ b/internal/aidiscovery/service.go @@ -0,0 +1,778 @@ +// Package aidiscovery provides AI-powered infrastructure discovery capabilities. +// It discovers services, versions, configurations, and CLI access methods +// for VMs, LXCs, Docker containers, Kubernetes pods, and hosts. +package aidiscovery + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +// StateProvider provides access to the current infrastructure state. +type StateProvider interface { + GetState() StateSnapshot +} + +// StateSnapshot represents the infrastructure state. This mirrors models.StateSnapshot +// to avoid circular dependencies. +type StateSnapshot struct { + VMs []VM + Containers []Container + DockerHosts []DockerHost +} + +// VM represents a virtual machine. +type VM struct { + VMID int + Name string + Node string + Status string + Instance string +} + +// Container represents an LXC container. +type Container struct { + VMID int + Name string + Node string + Status string + Instance string +} + +// DockerHost represents a Docker host. +type DockerHost struct { + AgentID string + Hostname string + Containers []DockerContainer +} + +// DockerContainer represents a Docker container. +type DockerContainer struct { + ID string + Name string + Image string + Status string + Ports []DockerPort + Labels map[string]string + Mounts []DockerMount +} + +// DockerPort represents a port mapping. +type DockerPort struct { + PublicPort int + PrivatePort int + Protocol string +} + +// DockerMount represents a mount point. +type DockerMount struct { + Source string + Destination string +} + +// AIAnalyzer provides AI analysis capabilities for discovery. +type AIAnalyzer interface { + AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) +} + +// Service manages AI-powered infrastructure discovery. +type Service struct { + store *Store + scanner *DeepScanner + stateProvider StateProvider + aiAnalyzer AIAnalyzer + + mu sync.RWMutex + running bool + stopCh chan struct{} + interval time.Duration + initialDelay time.Duration + lastRun time.Time + + // Cache for AI analysis results (by image name) + analysisCache map[string]*AIAnalysisResponse + cacheMu sync.RWMutex + cacheExpiry time.Duration + lastCacheUpdate time.Time +} + +// Config holds discovery service configuration. +type Config struct { + DataDir string + Interval time.Duration // How often to run background discovery + CacheExpiry time.Duration // How long to cache AI analysis results +} + +// DefaultConfig returns the default discovery configuration. +func DefaultConfig() Config { + return Config{ + Interval: 10 * time.Minute, + CacheExpiry: 1 * time.Hour, + } +} + +// NewService creates a new discovery service. 
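+// A minimal wiring sketch (the data directory, executor, state provider, and
+// analyzer below are placeholders for the caller's own implementations):
+//
+//	store, err := NewStore("/var/lib/pulse/aidiscovery")
+//	if err != nil {
+//		// handle store initialization failure
+//	}
+//	svc := NewService(store, NewDeepScanner(executor), stateProvider, DefaultConfig())
+//	svc.SetAIAnalyzer(analyzer)
+//	svc.Start(ctx)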
+func NewService(store *Store, scanner *DeepScanner, stateProvider StateProvider, cfg Config) *Service { + if cfg.Interval == 0 { + cfg.Interval = 10 * time.Minute + } + if cfg.CacheExpiry == 0 { + cfg.CacheExpiry = 1 * time.Hour + } + + return &Service{ + store: store, + scanner: scanner, + stateProvider: stateProvider, + interval: cfg.Interval, + initialDelay: 30 * time.Second, + cacheExpiry: cfg.CacheExpiry, + stopCh: make(chan struct{}), + analysisCache: make(map[string]*AIAnalysisResponse), + } +} + +// SetAIAnalyzer sets the AI analyzer for discovery. +func (s *Service) SetAIAnalyzer(analyzer AIAnalyzer) { + s.mu.Lock() + defer s.mu.Unlock() + s.aiAnalyzer = analyzer +} + +// Start begins the background discovery service. +func (s *Service) Start(ctx context.Context) { + s.mu.Lock() + if s.running { + s.mu.Unlock() + return + } + s.running = true + s.stopCh = make(chan struct{}) + s.mu.Unlock() + + log.Info(). + Dur("interval", s.interval). + Msg("Starting AI-powered infrastructure discovery service") + + go s.discoveryLoop(ctx) +} + +// Stop stops the background discovery service. +func (s *Service) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.running { + close(s.stopCh) + s.running = false + } +} + +// SetInterval updates the scan interval. Takes effect on next Start(). +func (s *Service) SetInterval(interval time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + s.interval = interval +} + +// IsRunning returns whether the background discovery loop is active. +func (s *Service) IsRunning() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.running +} + +// discoveryLoop runs periodic discovery. +func (s *Service) discoveryLoop(ctx context.Context) { + delay := s.initialDelay + if delay <= 0 { + delay = 30 * time.Second + } + + // Run initial discovery after a short delay + select { + case <-time.After(delay): + case <-s.stopCh: + return + case <-ctx.Done(): + return + } + + s.runBackgroundDiscovery(ctx) + + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.runBackgroundDiscovery(ctx) + case <-s.stopCh: + log.Info().Msg("Stopping AI discovery service") + return + case <-ctx.Done(): + log.Info().Msg("AI discovery context cancelled") + return + } + } +} + +// runBackgroundDiscovery runs discovery on all resources in the background. +func (s *Service) runBackgroundDiscovery(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + log.Error().Interface("panic", r).Stack().Msg("Recovered from panic in background AI discovery") + } + }() + + s.mu.Lock() + s.lastRun = time.Now() + s.mu.Unlock() + + // For background discovery, we only do shallow analysis based on metadata + // Deep scanning is triggered on-demand via DiscoverResource + if s.stateProvider == nil { + return + } + + state := s.stateProvider.GetState() + s.discoverDockerContainers(ctx, state.DockerHosts) +} + +// discoverDockerContainers runs discovery on Docker containers using metadata. 
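+// Containers whose stored discovery is still fresh (per store.NeedsRefresh and the
+// configured cache expiry) are skipped, so repeated background runs stay cheap.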
+func (s *Service) discoverDockerContainers(ctx context.Context, hosts []DockerHost) { + s.mu.RLock() + analyzer := s.aiAnalyzer + s.mu.RUnlock() + + if analyzer == nil { + log.Debug().Msg("AI analyzer not set, skipping Docker discovery") + return + } + + for _, host := range hosts { + for _, container := range host.Containers { + select { + case <-ctx.Done(): + return + default: + } + + // Build resource ID + id := MakeResourceID(ResourceTypeDocker, host.AgentID, container.Name) + + // Check if we already have a recent discovery + if !s.store.NeedsRefresh(id, s.cacheExpiry) { + continue + } + + // Analyze using metadata (shallow discovery) + discovery := s.analyzeDockerContainer(ctx, analyzer, container, host) + if discovery != nil { + if err := s.store.Save(discovery); err != nil { + log.Warn().Err(err).Str("id", id).Msg("Failed to save discovery") + } + } + } + } +} + +// analyzeDockerContainer analyzes a Docker container using AI. +func (s *Service) analyzeDockerContainer(ctx context.Context, analyzer AIAnalyzer, c DockerContainer, host DockerHost) *ResourceDiscovery { + // Check cache first + s.cacheMu.RLock() + cached, found := s.analysisCache[c.Image] + cacheValid := time.Since(s.lastCacheUpdate) < s.cacheExpiry + s.cacheMu.RUnlock() + + var result *AIAnalysisResponse + + if found && cacheValid { + result = cached + } else { + // Build prompt for AI analysis + prompt := s.buildMetadataAnalysisPrompt(c, host) + + response, err := analyzer.AnalyzeForDiscovery(ctx, prompt) + if err != nil { + log.Warn().Err(err).Str("container", c.Name).Msg("AI analysis failed") + return nil + } + + result = s.parseAIResponse(response) + if result == nil { + log.Warn().Str("container", c.Name).Msg("Failed to parse AI response") + return nil + } + + // Cache the result + s.cacheMu.Lock() + s.analysisCache[c.Image] = result + s.lastCacheUpdate = time.Now() + s.cacheMu.Unlock() + } + + // Skip unknown/low-confidence results + if result.ServiceType == "unknown" || result.Confidence < 0.5 { + return nil + } + + // Build CLI access string + cliAccess := result.CLIAccess + if cliAccess != "" { + cliAccess = strings.ReplaceAll(cliAccess, "{container}", c.Name) + } + + // Extract ports + var ports []PortInfo + for _, p := range c.Ports { + ports = append(ports, PortInfo{ + Port: p.PrivatePort, + Protocol: p.Protocol, + Address: fmt.Sprintf(":%d", p.PublicPort), + }) + } + + return &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, host.AgentID, c.Name), + ResourceType: ResourceTypeDocker, + ResourceID: c.Name, + HostID: host.AgentID, + Hostname: host.Hostname, + ServiceType: result.ServiceType, + ServiceName: result.ServiceName, + ServiceVersion: result.ServiceVersion, + Category: result.Category, + CLIAccess: cliAccess, + Facts: result.Facts, + ConfigPaths: result.ConfigPaths, + DataPaths: result.DataPaths, + Ports: ports, + Confidence: result.Confidence, + AIReasoning: result.Reasoning, + DiscoveredAt: time.Now(), + UpdatedAt: time.Now(), + } +} + +// DiscoverResource performs deep discovery on a specific resource. 
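+// Illustrative usage (the resource identifiers below are placeholders):
+//
+//	disc, err := svc.DiscoverResource(ctx, DiscoveryRequest{
+//		ResourceType: ResourceTypeLXC,
+//		ResourceID:   "101",
+//		HostID:       "pve1",
+//		Hostname:     "pve1",
+//		Force:        true, // bypass the five-minute freshness shortcut
+//	})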
+func (s *Service) DiscoverResource(ctx context.Context, req DiscoveryRequest) (*ResourceDiscovery, error) { + resourceID := MakeResourceID(req.ResourceType, req.HostID, req.ResourceID) + + // Check if we have a recent discovery and force isn't set + if !req.Force { + existing, err := s.store.Get(resourceID) + if err == nil && existing != nil { + age := time.Since(existing.UpdatedAt) + if age < 5*time.Minute { + log.Debug().Str("id", resourceID).Dur("age", age).Msg("Using recent discovery") + return existing, nil + } + } + } + + s.mu.RLock() + analyzer := s.aiAnalyzer + s.mu.RUnlock() + + if analyzer == nil { + return nil, fmt.Errorf("AI analyzer not configured") + } + + // Run deep scan if scanner is available + var scanResult *ScanResult + if s.scanner != nil { + var err error + scanResult, err = s.scanner.Scan(ctx, req) + if err != nil { + log.Warn().Err(err).Str("id", resourceID).Msg("Deep scan failed, using metadata only") + } + } + + // Build analysis request + analysisReq := AIAnalysisRequest{ + ResourceType: req.ResourceType, + ResourceID: req.ResourceID, + HostID: req.HostID, + Hostname: req.Hostname, + } + + if scanResult != nil { + analysisReq.CommandOutputs = scanResult.CommandOutputs + } + + // Add metadata if available + if s.stateProvider != nil { + analysisReq.Metadata = s.getResourceMetadata(req) + } + + // Build prompt and analyze + prompt := s.buildDeepAnalysisPrompt(analysisReq) + response, err := analyzer.AnalyzeForDiscovery(ctx, prompt) + if err != nil { + return nil, fmt.Errorf("AI analysis failed: %w", err) + } + + result := s.parseAIResponse(response) + if result == nil { + // Truncate response for error message + truncated := response + if len(truncated) > 500 { + truncated = truncated[:500] + "..." + } + return nil, fmt.Errorf("failed to parse AI response: %s", truncated) + } + + // Build discovery result + discovery := &ResourceDiscovery{ + ID: resourceID, + ResourceType: req.ResourceType, + ResourceID: req.ResourceID, + HostID: req.HostID, + Hostname: req.Hostname, + ServiceType: result.ServiceType, + ServiceName: result.ServiceName, + ServiceVersion: result.ServiceVersion, + Category: result.Category, + CLIAccess: s.formatCLIAccess(req.ResourceType, req.ResourceID, result.CLIAccess), + Facts: result.Facts, + ConfigPaths: result.ConfigPaths, + DataPaths: result.DataPaths, + Ports: result.Ports, + Confidence: result.Confidence, + AIReasoning: result.Reasoning, + DiscoveredAt: time.Now(), + UpdatedAt: time.Now(), + } + + if scanResult != nil { + discovery.RawCommandOutput = scanResult.CommandOutputs + discovery.ScanDuration = scanResult.CompletedAt.Sub(scanResult.StartedAt).Milliseconds() + } + + // Preserve user notes from existing discovery + existing, _ := s.store.Get(resourceID) + if existing != nil { + discovery.UserNotes = existing.UserNotes + discovery.UserSecrets = existing.UserSecrets + if discovery.DiscoveredAt.IsZero() || existing.DiscoveredAt.Before(discovery.DiscoveredAt) { + discovery.DiscoveredAt = existing.DiscoveredAt + } + } + + // Save discovery + if err := s.store.Save(discovery); err != nil { + return nil, fmt.Errorf("failed to save discovery: %w", err) + } + + return discovery, nil +} + +// getResourceMetadata retrieves metadata for a resource from the state. 
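+// For a VM or LXC request it returns keys such as "name", "status" and "vmid";
+// for Docker it returns "image", "status" and "labels". An empty map is returned
+// when nothing in the snapshot matches.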
+func (s *Service) getResourceMetadata(req DiscoveryRequest) map[string]any {
+	if s.stateProvider == nil {
+		return nil
+	}
+
+	state := s.stateProvider.GetState()
+	metadata := make(map[string]any)
+
+	switch req.ResourceType {
+	case ResourceTypeLXC:
+		for _, c := range state.Containers {
+			if fmt.Sprintf("%d", c.VMID) == req.ResourceID && c.Node == req.HostID {
+				metadata["name"] = c.Name
+				metadata["status"] = c.Status
+				metadata["vmid"] = c.VMID
+				break
+			}
+		}
+	case ResourceTypeVM:
+		for _, vm := range state.VMs {
+			if fmt.Sprintf("%d", vm.VMID) == req.ResourceID && vm.Node == req.HostID {
+				metadata["name"] = vm.Name
+				metadata["status"] = vm.Status
+				metadata["vmid"] = vm.VMID
+				break
+			}
+		}
+	case ResourceTypeDocker:
+		for _, host := range state.DockerHosts {
+			if host.AgentID == req.HostID || host.Hostname == req.HostID {
+				for _, c := range host.Containers {
+					if c.Name == req.ResourceID {
+						metadata["image"] = c.Image
+						metadata["status"] = c.Status
+						metadata["labels"] = c.Labels
+						break
+					}
+				}
+				break
+			}
+		}
+	}
+
+	return metadata
+}
+
+// formatCLIAccess formats the CLI access string with actual values.
+func (s *Service) formatCLIAccess(resourceType ResourceType, resourceID, cliTemplate string) string {
+	if cliTemplate == "" {
+		// Use default template
+		cliTemplate = GetCLIAccessTemplate(resourceType)
+	}
+
+	result := cliTemplate
+	result = strings.ReplaceAll(result, "{vmid}", resourceID)
+	result = strings.ReplaceAll(result, "{container}", resourceID)
+	result = strings.ReplaceAll(result, "{command}", "...")
+
+	return result
+}
+
+// buildMetadataAnalysisPrompt builds a prompt for shallow metadata-based analysis.
+func (s *Service) buildMetadataAnalysisPrompt(c DockerContainer, host DockerHost) string {
+	info := map[string]any{
+		"name":   c.Name,
+		"image":  c.Image,
+		"status": c.Status,
+		"host":   host.Hostname,
+	}
+
+	if len(c.Ports) > 0 {
+		var ports []map[string]any
+		for _, p := range c.Ports {
+			ports = append(ports, map[string]any{
+				"public":   p.PublicPort,
+				"private":  p.PrivatePort,
+				"protocol": p.Protocol,
+			})
+		}
+		info["ports"] = ports
+	}
+
+	if len(c.Labels) > 0 {
+		info["labels"] = c.Labels
+	}
+
+	if len(c.Mounts) > 0 {
+		var mounts []string
+		for _, m := range c.Mounts {
+			mounts = append(mounts, m.Destination)
+		}
+		info["mounts"] = mounts
+	}
+
+	infoJSON, _ := json.MarshalIndent(info, "", "  ")
+
+	return fmt.Sprintf(`Analyze this Docker container and identify what service it's running.
+
+Container Information:
+%s
+
+Based on the image name, ports, labels, and mounts, determine:
+1. What service/application is this?
+2. What category does it belong to?
+3. How should CLI commands be executed?
+
+Respond in this exact JSON format:
+{
+  "service_type": "lowercase_type",
+  "service_name": "Human Readable Name",
+  "service_version": "version if detectable from image tag",
+  "category": "database|web_server|cache|monitoring|backup|nvr|storage|container|network|security|media|home_automation|unknown",
+  "cli_access": "docker exec {container} <command>",
+  "facts": [],
+  "config_paths": [],
+  "data_paths": [],
+  "ports": [],
+  "confidence": 0.0-1.0,
+  "reasoning": "Brief explanation"
+}
+
+Respond with ONLY valid JSON.`, string(infoJSON))
+}
+
+// buildDeepAnalysisPrompt builds a prompt for deep analysis with command outputs.
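+// Individual command outputs longer than 2000 characters are truncated before
+// being embedded so a single noisy command cannot dominate the prompt.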
+func (s *Service) buildDeepAnalysisPrompt(req AIAnalysisRequest) string { + var sections []string + + sections = append(sections, fmt.Sprintf(`Resource Type: %s +Resource ID: %s +Host: %s (%s)`, req.ResourceType, req.ResourceID, req.Hostname, req.HostID)) + + if len(req.Metadata) > 0 { + metaJSON, _ := json.MarshalIndent(req.Metadata, "", " ") + sections = append(sections, fmt.Sprintf("Metadata:\n%s", string(metaJSON))) + } + + if len(req.CommandOutputs) > 0 { + sections = append(sections, "Command Outputs:") + for name, output := range req.CommandOutputs { + // Truncate long outputs + if len(output) > 2000 { + output = output[:2000] + "\n... (truncated)" + } + sections = append(sections, fmt.Sprintf("--- %s ---\n%s", name, output)) + } + } + + return fmt.Sprintf(`Analyze this infrastructure resource and provide detailed discovery information. + +%s + +Based on all available information, determine: +1. What service/application is running? +2. What version is it? +3. What are the important configuration paths? +4. What data paths should be backed up? +5. What ports are in use? +6. Any special hardware (GPU, TPU, etc.)? +7. Any dependencies (databases, message queues, etc.)? + +Respond in this exact JSON format: +{ + "service_type": "lowercase_type (e.g., frigate, postgres, pbs)", + "service_name": "Human Readable Name", + "service_version": "version number if found", + "category": "database|web_server|cache|monitoring|backup|nvr|storage|container|virtualizer|network|security|media|home_automation|unknown", + "cli_access": "command to access this service's CLI", + "facts": [ + {"category": "version|config|service|port|hardware|network|storage|dependency|security", "key": "fact_name", "value": "fact_value", "source": "command_name", "confidence": 0.9} + ], + "config_paths": ["/path/to/config.yml"], + "data_paths": ["/path/to/data"], + "ports": [{"port": 8080, "protocol": "tcp", "process": "nginx", "address": "0.0.0.0"}], + "confidence": 0.0-1.0, + "reasoning": "Explanation of identification" +} + +Important: +- Extract version numbers from package lists, process output, or config files +- Identify config and data paths from mount points and file listings +- Note any special hardware like Coral TPU, NVIDIA GPU +- For LXC/VM, the CLI access should use pct exec or qm guest exec +- For Docker, use docker exec + +Respond with ONLY valid JSON.`, strings.Join(sections, "\n\n")) +} + +// parseAIResponse parses the AI's JSON response. 
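+// Accepts either bare JSON or a markdown-fenced block; for example a response of
+// "```json\n{...}\n```" is unwrapped, and any text surrounding the outermost
+// braces is discarded before decoding (example payload is illustrative).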
+func (s *Service) parseAIResponse(response string) *AIAnalysisResponse { + log.Debug().Str("raw_response", response).Msg("AI discovery raw response") + response = strings.TrimSpace(response) + + // Handle markdown code blocks + if strings.HasPrefix(response, "```") { + lines := strings.Split(response, "\n") + var jsonLines []string + inBlock := false + for _, line := range lines { + if strings.HasPrefix(line, "```") { + inBlock = !inBlock + continue + } + if inBlock { + jsonLines = append(jsonLines, line) + } + } + response = strings.Join(jsonLines, "\n") + } + + // Find JSON object + start := strings.Index(response, "{") + end := strings.LastIndex(response, "}") + if start >= 0 && end > start { + response = response[start : end+1] + } + + var result AIAnalysisResponse + if err := json.Unmarshal([]byte(response), &result); err != nil { + log.Debug().Err(err).Str("response", response).Msg("Failed to parse AI response") + return nil + } + + // Set discovered_at for facts + now := time.Now() + for i := range result.Facts { + result.Facts[i].DiscoveredAt = now + } + + return &result +} + +// GetDiscovery retrieves a discovery by ID. +func (s *Service) GetDiscovery(id string) (*ResourceDiscovery, error) { + return s.store.Get(id) +} + +// GetDiscoveryByResource retrieves a discovery by resource type and ID. +func (s *Service) GetDiscoveryByResource(resourceType ResourceType, hostID, resourceID string) (*ResourceDiscovery, error) { + return s.store.GetByResource(resourceType, hostID, resourceID) +} + +// ListDiscoveries returns all discoveries. +func (s *Service) ListDiscoveries() ([]*ResourceDiscovery, error) { + return s.store.List() +} + +// ListDiscoveriesByType returns discoveries for a specific resource type. +func (s *Service) ListDiscoveriesByType(resourceType ResourceType) ([]*ResourceDiscovery, error) { + return s.store.ListByType(resourceType) +} + +// ListDiscoveriesByHost returns discoveries for a specific host. +func (s *Service) ListDiscoveriesByHost(hostID string) ([]*ResourceDiscovery, error) { + return s.store.ListByHost(hostID) +} + +// UpdateNotes updates user notes for a discovery. +func (s *Service) UpdateNotes(id string, notes string, secrets map[string]string) error { + return s.store.UpdateNotes(id, notes, secrets) +} + +// DeleteDiscovery deletes a discovery. +func (s *Service) DeleteDiscovery(id string) error { + return s.store.Delete(id) +} + +// GetProgress returns the progress of an ongoing discovery. +func (s *Service) GetProgress(resourceID string) *DiscoveryProgress { + if s.scanner == nil { + return nil + } + return s.scanner.GetProgress(resourceID) +} + +// GetStatus returns the service status. +func (s *Service) GetStatus() map[string]any { + s.mu.RLock() + defer s.mu.RUnlock() + + s.cacheMu.RLock() + cacheSize := len(s.analysisCache) + s.cacheMu.RUnlock() + + return map[string]any{ + "running": s.running, + "last_run": s.lastRun, + "interval": s.interval.String(), + "cache_size": cacheSize, + "ai_analyzer_set": s.aiAnalyzer != nil, + "scanner_set": s.scanner != nil, + "store_set": s.store != nil, + } +} + +// ClearCache clears the AI analysis cache. 
+func (s *Service) ClearCache() { + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + s.analysisCache = make(map[string]*AIAnalysisResponse) + s.lastCacheUpdate = time.Time{} +} diff --git a/internal/aidiscovery/service_test.go b/internal/aidiscovery/service_test.go new file mode 100644 index 000000000..6c49c70ba --- /dev/null +++ b/internal/aidiscovery/service_test.go @@ -0,0 +1,658 @@ +package aidiscovery + +import ( + "context" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +type stubAnalyzer struct { + mu sync.Mutex + calls int + response string +} + +func (s *stubAnalyzer) AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) { + s.mu.Lock() + s.calls++ + s.mu.Unlock() + return s.response, nil +} + +type errorAnalyzer struct{} + +func (errorAnalyzer) AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) { + return "", context.Canceled +} + +type stubStateProvider struct { + state StateSnapshot +} + +func (s stubStateProvider) GetState() StateSnapshot { + return s.state +} + +type panicStateProvider struct{} + +func (panicStateProvider) GetState() StateSnapshot { + panic("boom") +} + +func TestService_parseAIResponse_Markdown(t *testing.T) { + service := &Service{} + response := "```json\n{\n \"service_type\": \"nginx\",\n \"service_name\": \"Nginx\",\n \"service_version\": \"1.2\",\n \"category\": \"web_server\",\n \"cli_access\": \"docker exec {container} bash\",\n \"facts\": [{\"category\": \"version\", \"key\": \"nginx\", \"value\": \"1.2\", \"source\": \"cmd\", \"confidence\": 0.9}],\n \"config_paths\": [\"/etc/nginx/nginx.conf\"],\n \"data_paths\": [\"/var/www\"],\n \"ports\": [{\"port\": 80, \"protocol\": \"tcp\", \"process\": \"nginx\", \"address\": \"0.0.0.0\"}],\n \"confidence\": 0.9,\n \"reasoning\": \"image name\"\n}\n```" + + parsed := service.parseAIResponse(response) + if parsed == nil { + t.Fatalf("expected parsed response") + } + if parsed.ServiceType != "nginx" || parsed.ServiceName != "Nginx" { + t.Fatalf("unexpected parsed result: %#v", parsed) + } + if len(parsed.Facts) != 1 || parsed.Facts[0].DiscoveredAt.IsZero() { + t.Fatalf("expected fact timestamp set: %#v", parsed.Facts) + } + + if service.parseAIResponse("not json") != nil { + t.Fatalf("expected nil for invalid json") + } +} + +func TestService_analyzeDockerContainer_CacheAndPorts(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + service := NewService(store, nil, nil, Config{CacheExpiry: time.Hour}) + + analyzer := &stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + } + + container := DockerContainer{ + Name: "web", + Image: "nginx:latest", + Status: "running", + Ports: []DockerPort{ + {PublicPort: 8080, PrivatePort: 80, Protocol: "tcp"}, + }, + } + host := DockerHost{ + AgentID: "host1", + Hostname: "host1", + } + + first := service.analyzeDockerContainer(context.Background(), analyzer, container, host) + if first == nil { + t.Fatalf("expected discovery") + } + if !strings.Contains(first.CLIAccess, "web") { + t.Fatalf("expected cli access to include container name, got %s", first.CLIAccess) + } + if len(first.Ports) != 1 || first.Ports[0].Port != 80 || first.Ports[0].Address != ":8080" { + t.Fatalf("unexpected ports: %#v", first.Ports) + } + + second := 
service.analyzeDockerContainer(context.Background(), analyzer, container, host) + if second == nil { + t.Fatalf("expected cached discovery") + } + + analyzer.mu.Lock() + calls := analyzer.calls + analyzer.mu.Unlock() + if calls != 1 { + t.Fatalf("expected analyzer called once, got %d", calls) + } + + lowAnalyzer := &stubAnalyzer{ + response: `{"service_type":"unknown","service_name":"","service_version":"","category":"unknown","cli_access":"","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.4,"reasoning":""}`, + } + lowContainer := DockerContainer{Name: "mystery", Image: "unknown:latest"} + if got := service.analyzeDockerContainer(context.Background(), lowAnalyzer, lowContainer, host); got != nil { + t.Fatalf("expected low confidence discovery to be skipped") + } +} + +func TestService_DiscoverResource_RecentAndNoAnalyzer(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + service := NewService(store, nil, nil, DefaultConfig()) + + req := DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "nginx", + HostID: "host1", + Hostname: "host1", + } + discovery := &ResourceDiscovery{ + ID: MakeResourceID(req.ResourceType, req.HostID, req.ResourceID), + ResourceType: req.ResourceType, + ResourceID: req.ResourceID, + HostID: req.HostID, + Hostname: req.Hostname, + ServiceName: "Existing", + } + if err := store.Save(discovery); err != nil { + t.Fatalf("Save error: %v", err) + } + + found, err := service.DiscoverResource(context.Background(), req) + if err != nil { + t.Fatalf("DiscoverResource error: %v", err) + } + if found == nil || found.ServiceName != "Existing" { + t.Fatalf("unexpected discovery: %#v", found) + } + + _, err = service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeVM, + ResourceID: "101", + HostID: "node1", + Hostname: "node1", + Force: true, + }) + if err == nil || !strings.Contains(err.Error(), "AI analyzer") { + t.Fatalf("expected analyzer error, got %v", err) + } + + service.SetAIAnalyzer(errorAnalyzer{}) + _, err = service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeVM, + ResourceID: "102", + HostID: "node1", + Hostname: "node1", + Force: true, + }) + if err == nil || !strings.Contains(err.Error(), "AI analysis failed") { + t.Fatalf("expected analysis error, got %v", err) + } + + service.SetAIAnalyzer(&stubAnalyzer{response: "not json"}) + _, err = service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeVM, + ResourceID: "103", + HostID: "node1", + Hostname: "node1", + Force: true, + }) + if err == nil || !strings.Contains(err.Error(), "failed to parse") { + t.Fatalf("expected parse error, got %v", err) + } +} + +func TestService_getResourceMetadata(t *testing.T) { + state := StateSnapshot{ + VMs: []VM{ + {VMID: 101, Name: "vm1", Node: "node1", Status: "running"}, + }, + Containers: []Container{ + {VMID: 201, Name: "lxc1", Node: "node2", Status: "stopped"}, + }, + DockerHosts: []DockerHost{ + { + AgentID: "agent1", + Hostname: "dock1", + Containers: []DockerContainer{ + {Name: "redis", Image: "redis:latest", Status: "running", Labels: map[string]string{"tier": "cache"}}, + }, + }, + }, + } + + service := NewService(nil, nil, stubStateProvider{state: state}, DefaultConfig()) + + vmMeta := service.getResourceMetadata(DiscoveryRequest{ + ResourceType: ResourceTypeVM, + ResourceID: "101", + HostID: "node1", + }) + if vmMeta["name"] != "vm1" || 
vmMeta["vmid"] != 101 { + t.Fatalf("unexpected vm metadata: %#v", vmMeta) + } + + lxcMeta := service.getResourceMetadata(DiscoveryRequest{ + ResourceType: ResourceTypeLXC, + ResourceID: "201", + HostID: "node2", + }) + if lxcMeta["name"] != "lxc1" || lxcMeta["status"] != "stopped" { + t.Fatalf("unexpected lxc metadata: %#v", lxcMeta) + } + + dockerMeta := service.getResourceMetadata(DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "redis", + HostID: "agent1", + }) + if dockerMeta["image"] != "redis:latest" || dockerMeta["status"] != "running" { + t.Fatalf("unexpected docker metadata: %#v", dockerMeta) + } + + dockerByHost := service.getResourceMetadata(DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "redis", + HostID: "dock1", + }) + if dockerByHost["image"] != "redis:latest" { + t.Fatalf("unexpected docker hostname metadata: %#v", dockerByHost) + } +} + +func TestService_formatCLIAccessAndStatus(t *testing.T) { + service := NewService(nil, nil, nil, DefaultConfig()) + formatted := service.formatCLIAccess(ResourceTypeDocker, "redis", "") + if !strings.Contains(formatted, "redis") || !strings.Contains(formatted, "...") { + t.Fatalf("unexpected cli access: %s", formatted) + } + + service.analysisCache = map[string]*AIAnalysisResponse{"nginx:latest": {ServiceType: "nginx"}} + service.running = true + status := service.GetStatus() + if status["running"] != true || status["cache_size"] != 1 { + t.Fatalf("unexpected status: %#v", status) + } + + service.ClearCache() + if len(service.analysisCache) != 0 { + t.Fatalf("expected cache cleared") + } +} + +func TestService_DefaultsAndSetAnalyzer(t *testing.T) { + service := NewService(nil, nil, nil, Config{}) + if service.interval == 0 || service.cacheExpiry == 0 { + t.Fatalf("expected defaults for interval and cache expiry") + } + + analyzer := &stubAnalyzer{response: `{}`} + service.SetAIAnalyzer(analyzer) + if service.aiAnalyzer == nil { + t.Fatalf("expected analyzer set") + } + if service.GetProgress("missing") != nil { + t.Fatalf("expected nil progress without scanner") + } + if service.getResourceMetadata(DiscoveryRequest{}) != nil { + t.Fatalf("expected nil metadata without state provider") + } +} + +func TestService_RunBackgroundDiscoveryAndWrappers(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + state := StateSnapshot{ + DockerHosts: []DockerHost{ + { + AgentID: "host1", + Hostname: "host1", + Containers: []DockerContainer{ + {Name: "web", Image: "nginx:latest", Status: "running"}, + }, + }, + }, + } + service := NewService(store, nil, stubStateProvider{state: state}, DefaultConfig()) + service.SetAIAnalyzer(&stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + }) + + service.runBackgroundDiscovery(context.Background()) + id := MakeResourceID(ResourceTypeDocker, "host1", "web") + + if got, err := service.GetDiscovery(id); err != nil || got == nil { + t.Fatalf("GetDiscovery error: %v", err) + } + if got, err := service.GetDiscoveryByResource(ResourceTypeDocker, "host1", "web"); err != nil || got == nil { + t.Fatalf("GetDiscoveryByResource error: %v", err) + } + + if list, err := service.ListDiscoveries(); err != nil || len(list) != 1 { + t.Fatalf("ListDiscoveries unexpected: %v len=%d", err, len(list)) + } 
+ if list, err := service.ListDiscoveriesByType(ResourceTypeDocker); err != nil || len(list) != 1 { + t.Fatalf("ListDiscoveriesByType unexpected: %v len=%d", err, len(list)) + } + if list, err := service.ListDiscoveriesByHost("host1"); err != nil || len(list) != 1 { + t.Fatalf("ListDiscoveriesByHost unexpected: %v len=%d", err, len(list)) + } + + if err := service.UpdateNotes(id, "note", map[string]string{"k": "v"}); err != nil { + t.Fatalf("UpdateNotes error: %v", err) + } + updated, err := service.GetDiscovery(id) + if err != nil || updated.UserNotes != "note" { + t.Fatalf("expected updated notes: %#v err=%v", updated, err) + } + + scanner := NewDeepScanner(&stubExecutor{}) + scanner.progress[id] = &DiscoveryProgress{ResourceID: id} + service.scanner = scanner + if service.GetProgress(id) == nil { + t.Fatalf("expected progress") + } + + if err := service.DeleteDiscovery(id); err != nil { + t.Fatalf("DeleteDiscovery error: %v", err) + } + + service.stateProvider = nil + service.runBackgroundDiscovery(context.Background()) +} + +func TestService_PromptsAndDiscoveryLoop(t *testing.T) { + service := NewService(nil, nil, nil, DefaultConfig()) + + container := DockerContainer{ + Name: "web", + Image: "nginx:latest", + Status: "running", + Ports: []DockerPort{ + {PublicPort: 8080, PrivatePort: 80, Protocol: "tcp"}, + }, + Labels: map[string]string{"app": "nginx"}, + Mounts: []DockerMount{{Destination: "/etc/nginx"}}, + } + host := DockerHost{Hostname: "host1"} + prompt := service.buildMetadataAnalysisPrompt(container, host) + if !strings.Contains(prompt, "\"ports\"") || !strings.Contains(prompt, "\"labels\"") || !strings.Contains(prompt, "\"mounts\"") { + t.Fatalf("unexpected metadata prompt: %s", prompt) + } + + longOutput := strings.Repeat("a", 2100) + deepPrompt := service.buildDeepAnalysisPrompt(AIAnalysisRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + Metadata: map[string]any{"image": "nginx"}, + CommandOutputs: map[string]string{ + "ps": longOutput, + }, + }) + if !strings.Contains(deepPrompt, "(truncated)") || !strings.Contains(deepPrompt, "Metadata:") { + t.Fatalf("unexpected deep prompt") + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + service.initialDelay = time.Millisecond + service.Start(ctx) + service.Start(ctx) + service.Stop() + + service.stopCh = make(chan struct{}) + close(service.stopCh) + service.discoveryLoop(context.Background()) + + service.initialDelay = 0 + service.stopCh = make(chan struct{}) + close(service.stopCh) + service.discoveryLoop(context.Background()) +} + +func TestService_DiscoveryLoop_StopAndCancel(t *testing.T) { + state := StateSnapshot{ + DockerHosts: []DockerHost{ + { + AgentID: "host1", + Hostname: "host1", + Containers: []DockerContainer{ + {Name: "web", Image: "nginx:latest", Status: "running"}, + }, + }, + }, + } + + runLoop := func(stopWithCancel bool) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + service := NewService(store, nil, stubStateProvider{state: state}, DefaultConfig()) + analyzer := &stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + } + service.SetAIAnalyzer(analyzer) + service.initialDelay = time.Millisecond + service.interval = time.Millisecond + 
service.cacheExpiry = time.Nanosecond + + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + service.discoveryLoop(ctx) + close(done) + }() + + time.Sleep(5 * time.Millisecond) + if stopWithCancel { + cancel() + } else { + close(service.stopCh) + } + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fatalf("discoveryLoop did not stop") + } + + analyzer.mu.Lock() + calls := analyzer.calls + analyzer.mu.Unlock() + if calls < 2 { + t.Fatalf("expected multiple discoveries, got %d", calls) + } + } + + runLoop(false) + runLoop(true) +} + +func TestService_DiscoverDockerContainersSkips(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + service := NewService(store, nil, nil, DefaultConfig()) + service.discoverDockerContainers(context.Background(), []DockerHost{{AgentID: "host1"}}) + + service.SetAIAnalyzer(&stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + }) + + id := MakeResourceID(ResourceTypeDocker, "host1", "web") + if err := store.Save(&ResourceDiscovery{ID: id, ResourceType: ResourceTypeDocker}); err != nil { + t.Fatalf("Save error: %v", err) + } + service.cacheExpiry = time.Hour + service.discoverDockerContainers(context.Background(), []DockerHost{ + {AgentID: "host1", Containers: []DockerContainer{{Name: "web", Image: "nginx:latest"}}}, + }) + + badAnalyzer := &stubAnalyzer{response: "not json"} + if got := service.analyzeDockerContainer(context.Background(), badAnalyzer, DockerContainer{Name: "bad", Image: "bad"}, DockerHost{AgentID: "host1"}); got != nil { + t.Fatalf("expected nil for bad analysis") + } + + canceled, cancel := context.WithCancel(context.Background()) + cancel() + analyzer := &stubAnalyzer{response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`} + service.SetAIAnalyzer(analyzer) + service.discoverDockerContainers(canceled, []DockerHost{ + {AgentID: "host1", Containers: []DockerContainer{{Name: "web2", Image: "nginx:latest"}}}, + }) + analyzer.mu.Lock() + calls := analyzer.calls + analyzer.mu.Unlock() + if calls != 0 { + t.Fatalf("expected analyzer not called on canceled context") + } + + errAnalyzer := errorAnalyzer{} + if got := service.analyzeDockerContainer(context.Background(), errAnalyzer, DockerContainer{Name: "err", Image: "err"}, DockerHost{AgentID: "host1"}); got != nil { + t.Fatalf("expected nil when analyzer returns error") + } + + storePath := filepath.Join(t.TempDir(), "file") + if err := os.WriteFile(storePath, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + service.store.dataDir = storePath + service.discoverDockerContainers(context.Background(), []DockerHost{ + {AgentID: "host1", Containers: []DockerContainer{{Name: "web3", Image: "nginx:latest"}}}, + }) +} + +func TestService_RunBackgroundDiscoveryRecover(t *testing.T) { + service := NewService(nil, nil, panicStateProvider{}, DefaultConfig()) + service.runBackgroundDiscovery(context.Background()) +} + +func TestService_DiscoverResource_SaveError(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != 
nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + badPath := filepath.Join(t.TempDir(), "file") + if err := os.WriteFile(badPath, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + store.dataDir = badPath + + service := NewService(store, nil, nil, DefaultConfig()) + service.SetAIAnalyzer(&stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + }) + + _, err = service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + Force: true, + }) + if err == nil || !strings.Contains(err.Error(), "failed to save discovery") { + t.Fatalf("expected save error, got %v", err) + } +} + +func TestService_DiscoverResource_ScanError(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + scanner := NewDeepScanner(nil) + service := NewService(store, scanner, nil, DefaultConfig()) + service.SetAIAnalyzer(&stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[],"confidence":0.9,"reasoning":"image"}`, + }) + + _, err = service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + Force: true, + }) + if err != nil { + t.Fatalf("expected scan error to be tolerated, got %v", err) + } +} + +func TestService_DiscoveryLoop_ContextDoneAtStart(t *testing.T) { + service := NewService(nil, nil, nil, DefaultConfig()) + service.initialDelay = time.Hour + service.stopCh = make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + service.discoveryLoop(ctx) +} + +func TestService_DiscoverResource_WithScanResult(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + exec := &stubExecutor{ + agents: []ConnectedAgent{{AgentID: "host1", Hostname: "host1"}}, + } + scanner := NewDeepScanner(exec) + scanner.maxParallel = 1 + + state := StateSnapshot{ + DockerHosts: []DockerHost{ + { + AgentID: "host1", + Hostname: "host1", + Containers: []DockerContainer{ + {Name: "web", Image: "nginx:latest", Status: "running"}, + }, + }, + }, + } + + service := NewService(store, scanner, stubStateProvider{state: state}, DefaultConfig()) + service.SetAIAnalyzer(&stubAnalyzer{ + response: `{"service_type":"nginx","service_name":"Nginx","service_version":"1.2","category":"web_server","cli_access":"docker exec {container} nginx -v","facts":[],"config_paths":[],"data_paths":[],"ports":[{"port":80,"protocol":"tcp","process":"nginx","address":"0.0.0.0"}],"confidence":0.9,"reasoning":"image"}`, + }) + + existing := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, "host1", "web"), + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + UserNotes: "keep", + UserSecrets: map[string]string{"token": "secret"}, + DiscoveredAt: time.Now().Add(-2 * time.Hour), + } + if err := store.Save(existing); err != nil { + t.Fatalf("Save error: %v", err) + } + + found, err := 
service.DiscoverResource(context.Background(), DiscoveryRequest{ + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + Hostname: "host1", + Force: true, + }) + if err != nil { + t.Fatalf("DiscoverResource error: %v", err) + } + if found.UserNotes != "keep" || found.UserSecrets["token"] != "secret" { + t.Fatalf("expected user fields preserved: %#v", found) + } + if len(found.RawCommandOutput) == 0 { + t.Fatalf("expected raw command output") + } + if found.DiscoveredAt.After(existing.DiscoveredAt) { + t.Fatalf("expected older discovered_at preserved") + } +} diff --git a/internal/aidiscovery/store.go b/internal/aidiscovery/store.go new file mode 100644 index 000000000..a38411d5a --- /dev/null +++ b/internal/aidiscovery/store.go @@ -0,0 +1,347 @@ +package aidiscovery + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/crypto" + "github.com/rs/zerolog/log" +) + +// CryptoManager interface for encryption/decryption. +type CryptoManager interface { + Encrypt(plaintext []byte) ([]byte, error) + Decrypt(ciphertext []byte) ([]byte, error) +} + +// Store provides encrypted per-resource storage for discovery data. +type Store struct { + mu sync.RWMutex + dataDir string + crypto CryptoManager + cache map[string]*ResourceDiscovery // In-memory cache + cacheTime map[string]time.Time // Cache timestamps + cacheTTL time.Duration +} + +// For testing - allows injecting a mock crypto manager +var newCryptoManagerAt = crypto.NewCryptoManagerAt + +// For testing - allows injecting a mock marshaler. +var marshalDiscovery = json.Marshal + +// NewStore creates a new discovery store with automatic encryption. +func NewStore(dataDir string) (*Store, error) { + discoveryDir := filepath.Join(dataDir, "discovery") + if err := os.MkdirAll(discoveryDir, 0700); err != nil { + return nil, fmt.Errorf("failed to create discovery directory: %w", err) + } + + // Initialize crypto manager for encryption (uses same key as other Pulse secrets) + cryptoMgr, err := newCryptoManagerAt(dataDir) + if err != nil { + log.Warn().Err(err).Msg("Failed to initialize crypto for discovery store, data will be unencrypted") + } + + return &Store{ + dataDir: discoveryDir, + crypto: cryptoMgr, + cache: make(map[string]*ResourceDiscovery), + cacheTime: make(map[string]time.Time), + cacheTTL: 5 * time.Minute, + }, nil +} + +// getFilePath returns the file path for a resource ID. +func (s *Store) getFilePath(id string) string { + // Sanitize ID for filename: replace : with _ + safeID := strings.ReplaceAll(id, ":", "_") + safeID = strings.ReplaceAll(safeID, "/", "_") + return filepath.Join(s.dataDir, safeID+".enc") +} + +// Save persists a discovery to encrypted storage. 
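+// The write is atomic: the record is marshaled, encrypted when a crypto manager is configured,
+// written to a temporary file, and renamed into place before the in-memory cache is updated.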
+func (s *Store) Save(d *ResourceDiscovery) error { + s.mu.Lock() + defer s.mu.Unlock() + + if d.ID == "" { + return fmt.Errorf("discovery ID is required") + } + + // Update timestamp + d.UpdatedAt = time.Now() + if d.DiscoveredAt.IsZero() { + d.DiscoveredAt = d.UpdatedAt + } + + data, err := marshalDiscovery(d) + if err != nil { + return fmt.Errorf("failed to marshal discovery: %w", err) + } + + // Encrypt if crypto is available + if s.crypto != nil { + encrypted, err := s.crypto.Encrypt(data) + if err != nil { + return fmt.Errorf("failed to encrypt discovery: %w", err) + } + data = encrypted + } + + // Write atomically using tmp file + rename + filePath := s.getFilePath(d.ID) + tmpPath := filePath + ".tmp" + + if err := os.WriteFile(tmpPath, data, 0600); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + + if err := os.Rename(tmpPath, filePath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("failed to finalize discovery file: %w", err) + } + + // Update cache + s.cache[d.ID] = d + s.cacheTime[d.ID] = time.Now() + + log.Debug().Str("id", d.ID).Str("service", d.ServiceType).Msg("Discovery saved") + return nil +} + +// Get retrieves a discovery from storage. +func (s *Store) Get(id string) (*ResourceDiscovery, error) { + s.mu.RLock() + // Check cache first + if cached, ok := s.cache[id]; ok { + if cacheTime, hasTime := s.cacheTime[id]; hasTime { + if time.Since(cacheTime) < s.cacheTTL { + s.mu.RUnlock() + return cached, nil + } + } + } + s.mu.RUnlock() + + s.mu.Lock() + defer s.mu.Unlock() + + filePath := s.getFilePath(id) + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil // Not found is not an error + } + return nil, fmt.Errorf("failed to read discovery file: %w", err) + } + + // Decrypt if crypto is available + if s.crypto != nil { + decrypted, err := s.crypto.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("failed to decrypt discovery: %w", err) + } + data = decrypted + } + + var discovery ResourceDiscovery + if err := json.Unmarshal(data, &discovery); err != nil { + return nil, fmt.Errorf("failed to unmarshal discovery: %w", err) + } + + // Update cache + s.cache[id] = &discovery + s.cacheTime[id] = time.Now() + + return &discovery, nil +} + +// GetByResource retrieves a discovery by resource type and ID. +func (s *Store) GetByResource(resourceType ResourceType, hostID, resourceID string) (*ResourceDiscovery, error) { + id := MakeResourceID(resourceType, hostID, resourceID) + return s.Get(id) +} + +// Delete removes a discovery from storage. +func (s *Store) Delete(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + filePath := s.getFilePath(id) + if err := os.Remove(filePath); err != nil { + if os.IsNotExist(err) { + return nil // Already deleted + } + return fmt.Errorf("failed to delete discovery file: %w", err) + } + + // Remove from cache + delete(s.cache, id) + delete(s.cacheTime, id) + + log.Debug().Str("id", id).Msg("Discovery deleted") + return nil +} + +// List returns all discoveries. +func (s *Store) List() ([]*ResourceDiscovery, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entries, err := os.ReadDir(s.dataDir) + if err != nil { + if os.IsNotExist(err) { + return []*ResourceDiscovery{}, nil + } + return nil, fmt.Errorf("failed to list discovery directory: %w", err) + } + + var discoveries []*ResourceDiscovery + for _, entry := range entries { + // Skip tmp files first to avoid reading partial writes. 
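+ // Leftover .tmp files can remain if a previous Save was interrupted between the write and the rename.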
+ if strings.HasSuffix(entry.Name(), ".tmp") { + continue + } + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".enc") { + continue + } + + data, err := os.ReadFile(filepath.Join(s.dataDir, entry.Name())) + if err != nil { + log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to read discovery file") + continue + } + + // Decrypt if crypto is available + if s.crypto != nil { + decrypted, err := s.crypto.Decrypt(data) + if err != nil { + log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to decrypt discovery") + continue + } + data = decrypted + } + + var discovery ResourceDiscovery + if err := json.Unmarshal(data, &discovery); err != nil { + log.Warn().Err(err).Str("file", entry.Name()).Msg("Failed to unmarshal discovery") + continue + } + + discoveries = append(discoveries, &discovery) + } + + return discoveries, nil +} + +// ListByType returns discoveries for a specific resource type. +func (s *Store) ListByType(resourceType ResourceType) ([]*ResourceDiscovery, error) { + all, err := s.List() + if err != nil { + return nil, err + } + + var filtered []*ResourceDiscovery + for _, d := range all { + if d.ResourceType == resourceType { + filtered = append(filtered, d) + } + } + return filtered, nil +} + +// ListByHost returns discoveries for a specific host. +func (s *Store) ListByHost(hostID string) ([]*ResourceDiscovery, error) { + all, err := s.List() + if err != nil { + return nil, err + } + + var filtered []*ResourceDiscovery + for _, d := range all { + if d.HostID == hostID { + filtered = append(filtered, d) + } + } + return filtered, nil +} + +// UpdateNotes updates just the user notes and secrets for a discovery. +func (s *Store) UpdateNotes(id string, notes string, secrets map[string]string) error { + discovery, err := s.Get(id) + if err != nil { + return err + } + if discovery == nil { + return fmt.Errorf("discovery not found: %s", id) + } + + discovery.UserNotes = notes + if secrets != nil { + discovery.UserSecrets = secrets + } + + return s.Save(discovery) +} + +// GetMultiple retrieves multiple discoveries by ID. +func (s *Store) GetMultiple(ids []string) ([]*ResourceDiscovery, error) { + var discoveries []*ResourceDiscovery + for _, id := range ids { + d, err := s.Get(id) + if err != nil { + log.Warn().Err(err).Str("id", id).Msg("Failed to get discovery") + continue + } + if d != nil { + discoveries = append(discoveries, d) + } + } + return discoveries, nil +} + +// ClearCache clears the in-memory cache. +func (s *Store) ClearCache() { + s.mu.Lock() + defer s.mu.Unlock() + s.cache = make(map[string]*ResourceDiscovery) + s.cacheTime = make(map[string]time.Time) +} + +// Exists checks if a discovery exists for the given ID. +func (s *Store) Exists(id string) bool { + s.mu.RLock() + if _, ok := s.cache[id]; ok { + s.mu.RUnlock() + return true + } + s.mu.RUnlock() + + filePath := s.getFilePath(id) + _, err := os.Stat(filePath) + return err == nil +} + +// GetAge returns how old the discovery is, or -1 if not found. +func (s *Store) GetAge(id string) time.Duration { + d, err := s.Get(id) + if err != nil || d == nil { + return -1 + } + return time.Since(d.UpdatedAt) +} + +// NeedsRefresh checks if a discovery needs to be refreshed. 
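+// It returns true when no discovery exists for the given ID or when the stored record is older than maxAge.
+//
+// Illustrative usage:
+//
+//	if store.NeedsRefresh(id, 24*time.Hour) {
+//		// trigger a new discovery scan
+//	}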
+func (s *Store) NeedsRefresh(id string, maxAge time.Duration) bool { + age := s.GetAge(id) + if age < 0 { + return true // Not found, needs discovery + } + return age > maxAge +} diff --git a/internal/aidiscovery/store_test.go b/internal/aidiscovery/store_test.go new file mode 100644 index 000000000..ec1e32121 --- /dev/null +++ b/internal/aidiscovery/store_test.go @@ -0,0 +1,469 @@ +package aidiscovery + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/crypto" +) + +type fakeCrypto struct{} + +func (fakeCrypto) Encrypt(plaintext []byte) ([]byte, error) { + out := make([]byte, len(plaintext)) + for i := range plaintext { + out[i] = plaintext[len(plaintext)-1-i] + } + return out, nil +} + +func (fakeCrypto) Decrypt(ciphertext []byte) ([]byte, error) { + return fakeCrypto{}.Encrypt(ciphertext) +} + +type errorCrypto struct{} + +func (errorCrypto) Encrypt(plaintext []byte) ([]byte, error) { + return nil, os.ErrInvalid +} + +func (errorCrypto) Decrypt(ciphertext []byte) ([]byte, error) { + return nil, os.ErrInvalid +} + +func TestStore_SaveGetListAndNotes(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + d1 := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, "host1", "nginx"), + ResourceType: ResourceTypeDocker, + ResourceID: "nginx", + HostID: "host1", + ServiceName: "Nginx", + } + if err := store.Save(d1); err != nil { + t.Fatalf("Save error: %v", err) + } + + got, err := store.Get(d1.ID) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if got == nil || got.ServiceName != "Nginx" { + t.Fatalf("unexpected discovery: %#v", got) + } + if !store.Exists(d1.ID) { + t.Fatalf("expected discovery to exist") + } + + if err := store.UpdateNotes(d1.ID, "notes", map[string]string{"token": "abc"}); err != nil { + t.Fatalf("UpdateNotes error: %v", err) + } + updated, err := store.Get(d1.ID) + if err != nil { + t.Fatalf("Get updated error: %v", err) + } + if updated.UserNotes != "notes" || updated.UserSecrets["token"] != "abc" { + t.Fatalf("notes not updated: %#v", updated) + } + + d2 := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeVM, "node1", "101"), + ResourceType: ResourceTypeVM, + ResourceID: "101", + HostID: "node1", + ServiceName: "VM", + } + if err := store.Save(d2); err != nil { + t.Fatalf("Save d2 error: %v", err) + } + + list, err := store.List() + if err != nil { + t.Fatalf("List error: %v", err) + } + if len(list) != 2 { + t.Fatalf("expected 2 discoveries, got %d", len(list)) + } + + byType, err := store.ListByType(ResourceTypeVM) + if err != nil { + t.Fatalf("ListByType error: %v", err) + } + if len(byType) != 1 || byType[0].ID != d2.ID { + t.Fatalf("unexpected ListByType: %#v", byType) + } + + byHost, err := store.ListByHost("host1") + if err != nil { + t.Fatalf("ListByHost error: %v", err) + } + if len(byHost) != 1 || byHost[0].ID != d1.ID { + t.Fatalf("unexpected ListByHost: %#v", byHost) + } + + summary := updated.ToSummary() + if summary.ID != d1.ID || !summary.HasUserNotes { + t.Fatalf("unexpected summary: %#v", summary) + } + + if err := store.Delete(d1.ID); err != nil { + t.Fatalf("Delete error: %v", err) + } + if store.Exists(d1.ID) { + t.Fatalf("expected discovery to be deleted") + } +} + +func TestStore_CryptoRoundTripAndPaths(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = fakeCrypto{} + + id := 
"docker:host1:app/name" + d := &ResourceDiscovery{ + ID: id, + ResourceType: ResourceTypeDocker, + ResourceID: "app/name", + HostID: "host1", + ServiceName: "App", + } + if err := store.Save(d); err != nil { + t.Fatalf("Save error: %v", err) + } + + path := store.getFilePath(id) + base := filepath.Base(path) + if strings.Contains(base, ":") || strings.Contains(base, "/") { + t.Fatalf("expected sanitized base filename, got %s", base) + } + + loaded, err := store.Get(id) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if loaded == nil || loaded.ServiceName != "App" { + t.Fatalf("unexpected discovery: %#v", loaded) + } + + store.ClearCache() + if _, err := store.Get(id); err != nil { + t.Fatalf("Get with decrypt error: %v", err) + } + list, err := store.List() + if err != nil || len(list) != 1 { + t.Fatalf("List with decrypt error: %v len=%d", err, len(list)) + } +} + +func TestStore_NeedsRefreshAndGetMultiple(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + if !store.NeedsRefresh("missing", time.Minute) { + t.Fatalf("expected missing discovery to need refresh") + } + + d := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeHost, "host1", "host1"), + ResourceType: ResourceTypeHost, + ResourceID: "host1", + HostID: "host1", + ServiceName: "Host", + } + if err := store.Save(d); err != nil { + t.Fatalf("Save error: %v", err) + } + + path := store.getFilePath(d.ID) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + var saved ResourceDiscovery + if err := json.Unmarshal(data, &saved); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + saved.UpdatedAt = time.Now().Add(-2 * time.Hour) + data, err = json.Marshal(&saved) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + store.ClearCache() + if !store.NeedsRefresh(d.ID, time.Minute) { + t.Fatalf("expected old discovery to need refresh") + } + + ids := []string{d.ID, "missing"} + multi, err := store.GetMultiple(ids) + if err != nil { + t.Fatalf("GetMultiple error: %v", err) + } + if len(multi) != 1 || multi[0].ID != d.ID { + t.Fatalf("unexpected GetMultiple: %#v", multi) + } +} + +func TestStore_ErrorsAndListSkips(t *testing.T) { + dir := t.TempDir() + store, err := NewStore(dir) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + if err := store.Save(&ResourceDiscovery{}); err == nil { + t.Fatalf("expected error for empty ID") + } + + store.crypto = errorCrypto{} + if err := store.Save(&ResourceDiscovery{ID: "bad"}); err == nil { + t.Fatalf("expected encrypt error") + } + + store.crypto = nil + if _, err := store.Get("missing"); err != nil { + t.Fatalf("unexpected missing error: %v", err) + } + + d := &ResourceDiscovery{ + ID: MakeResourceID(ResourceTypeDocker, "host1", "web"), + ResourceType: ResourceTypeDocker, + ResourceID: "web", + HostID: "host1", + ServiceName: "Web", + UserSecrets: map[string]string{"token": "abc"}, + } + if err := store.Save(d); err != nil { + t.Fatalf("Save error: %v", err) + } + + // Corrupt file to force unmarshal error during List. 
+ badPath := filepath.Join(store.dataDir, "bad.enc") + if err := os.WriteFile(badPath, []byte("{bad"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + if err := os.WriteFile(filepath.Join(store.dataDir, "note.txt"), []byte("skip"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + if err := os.WriteFile(filepath.Join(store.dataDir, "skip.enc.tmp"), []byte("skip"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + if err := os.MkdirAll(filepath.Join(store.dataDir, "dir"), 0700); err != nil { + t.Fatalf("MkdirAll error: %v", err) + } + unreadable := filepath.Join(store.dataDir, "unreadable.enc") + if err := os.WriteFile(unreadable, []byte("nope"), 0000); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + list, err := store.List() + if err != nil { + t.Fatalf("List error: %v", err) + } + if len(list) != 1 { + t.Fatalf("expected 1 discovery, got %d", len(list)) + } + + store.crypto = errorCrypto{} + list, err = store.List() + if err != nil { + t.Fatalf("List with crypto error: %v", err) + } + if len(list) != 0 { + t.Fatalf("expected crypto errors to skip entries") + } + + store.crypto = errorCrypto{} + store.ClearCache() + if _, err := store.Get(d.ID); err == nil { + t.Fatalf("expected decrypt error") + } + + store.crypto = nil + if got, err := store.GetByResource(ResourceTypeDocker, "host1", "web"); err != nil || got == nil { + t.Fatalf("GetByResource error: %v", err) + } + + if err := store.UpdateNotes(d.ID, "notes-only", nil); err != nil { + t.Fatalf("UpdateNotes error: %v", err) + } + updated, err := store.Get(d.ID) + if err != nil || updated.UserSecrets == nil { + t.Fatalf("expected secrets to be preserved: %#v err=%v", updated, err) + } + + store.crypto = errorCrypto{} + store.ClearCache() + if err := store.UpdateNotes(d.ID, "notes", nil); err == nil { + t.Fatalf("expected update notes error with crypto failure") + } + if got, err := store.GetMultiple([]string{d.ID}); err != nil || len(got) != 0 { + t.Fatalf("expected GetMultiple to skip errors") + } + + if err := store.UpdateNotes("missing", "notes", nil); err == nil { + t.Fatalf("expected error for missing discovery") + } + + if err := store.Delete("missing"); err != nil { + t.Fatalf("unexpected delete error: %v", err) + } +} + +func TestStore_NewStoreError(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "file") + if err := os.WriteFile(file, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + if _, err := NewStore(file); err == nil { + t.Fatalf("expected error for file data dir") + } +} + +func TestStore_NewStoreCryptoFailure(t *testing.T) { + orig := newCryptoManagerAt + newCryptoManagerAt = func(dataDir string) (*crypto.CryptoManager, error) { + manager, err := crypto.NewCryptoManagerAt(dataDir) + if err != nil { + return nil, err + } + return manager, os.ErrInvalid + } + t.Cleanup(func() { + newCryptoManagerAt = orig + }) + + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + if store.crypto == nil { + t.Fatalf("expected crypto manager despite init warning") + } +} + +func TestStore_SaveMarshalError(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + orig := marshalDiscovery + marshalDiscovery = func(any) ([]byte, error) { + return nil, os.ErrInvalid + } + t.Cleanup(func() { + marshalDiscovery = orig + }) + + if err := store.Save(&ResourceDiscovery{ID: "marshal"}); err == nil { + t.Fatalf("expected 
marshal error") + } +} + +func TestStore_SaveAndGetErrors(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + id := MakeResourceID(ResourceTypeDocker, "host1", "web") + filePath := store.getFilePath(id) + if err := os.MkdirAll(filePath, 0700); err != nil { + t.Fatalf("MkdirAll error: %v", err) + } + if err := store.Save(&ResourceDiscovery{ID: id}); err == nil { + t.Fatalf("expected rename error") + } + + tmpFile := filepath.Join(t.TempDir(), "file") + if err := os.WriteFile(tmpFile, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + store.dataDir = tmpFile + if err := store.Save(&ResourceDiscovery{ID: "bad"}); err == nil { + t.Fatalf("expected write error") + } + + store.dataDir = t.TempDir() + store.crypto = nil + badPath := store.getFilePath("bad") + if err := os.WriteFile(badPath, []byte("{bad"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + if _, err := store.Get("bad"); err == nil { + t.Fatalf("expected unmarshal error") + } +} + +func TestStore_ListErrors(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + store.dataDir = filepath.Join(t.TempDir(), "missing") + list, err := store.List() + if err != nil || len(list) != 0 { + t.Fatalf("expected empty list for missing dir") + } + + file := filepath.Join(t.TempDir(), "file") + if err := os.WriteFile(file, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + store.dataDir = file + if _, err := store.List(); err == nil { + t.Fatalf("expected list error for file path") + } + if _, err := store.ListByType(ResourceTypeDocker); err == nil { + t.Fatalf("expected list by type error") + } + if _, err := store.ListByHost("host1"); err == nil { + t.Fatalf("expected list by host error") + } +} + +func TestStore_DeleteError(t *testing.T) { + store, err := NewStore(t.TempDir()) + if err != nil { + t.Fatalf("NewStore error: %v", err) + } + store.crypto = nil + + id := MakeResourceID(ResourceTypeDocker, "host1", "dir") + filePath := store.getFilePath(id) + if err := os.MkdirAll(filePath, 0700); err != nil { + t.Fatalf("MkdirAll error: %v", err) + } + nested := filepath.Join(filePath, "nested") + if err := os.WriteFile(nested, []byte("x"), 0600); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + if err := store.Delete(id); err == nil { + t.Fatalf("expected delete error for non-empty dir") + } +} diff --git a/internal/aidiscovery/tools_adapter.go b/internal/aidiscovery/tools_adapter.go new file mode 100644 index 000000000..5658f7fe9 --- /dev/null +++ b/internal/aidiscovery/tools_adapter.go @@ -0,0 +1,155 @@ +package aidiscovery + +import ( + "github.com/rcourtman/pulse-go-rewrite/internal/ai/tools" +) + +// ToolsAdapter wraps Service to implement tools.DiscoverySource +type ToolsAdapter struct { + service *Service +} + +// NewToolsAdapter creates a new adapter for the discovery service +func NewToolsAdapter(service *Service) *ToolsAdapter { + if service == nil { + return nil + } + return &ToolsAdapter{service: service} +} + +// GetDiscovery implements tools.DiscoverySource +func (a *ToolsAdapter) GetDiscovery(id string) (tools.DiscoverySourceData, error) { + discovery, err := a.service.GetDiscovery(id) + if err != nil { + return tools.DiscoverySourceData{}, err + } + if discovery == nil { + return tools.DiscoverySourceData{}, nil + } + return a.convertToSourceData(discovery), nil +} + +// 
GetDiscoveryByResource implements tools.DiscoverySource +func (a *ToolsAdapter) GetDiscoveryByResource(resourceType, hostID, resourceID string) (tools.DiscoverySourceData, error) { + discovery, err := a.service.GetDiscoveryByResource(ResourceType(resourceType), hostID, resourceID) + if err != nil { + return tools.DiscoverySourceData{}, err + } + if discovery == nil { + return tools.DiscoverySourceData{}, nil + } + return a.convertToSourceData(discovery), nil +} + +// ListDiscoveries implements tools.DiscoverySource +func (a *ToolsAdapter) ListDiscoveries() ([]tools.DiscoverySourceData, error) { + discoveries, err := a.service.ListDiscoveries() + if err != nil { + return nil, err + } + return a.convertList(discoveries), nil +} + +// ListDiscoveriesByType implements tools.DiscoverySource +func (a *ToolsAdapter) ListDiscoveriesByType(resourceType string) ([]tools.DiscoverySourceData, error) { + discoveries, err := a.service.ListDiscoveriesByType(ResourceType(resourceType)) + if err != nil { + return nil, err + } + return a.convertList(discoveries), nil +} + +// ListDiscoveriesByHost implements tools.DiscoverySource +func (a *ToolsAdapter) ListDiscoveriesByHost(hostID string) ([]tools.DiscoverySourceData, error) { + discoveries, err := a.service.ListDiscoveriesByHost(hostID) + if err != nil { + return nil, err + } + return a.convertList(discoveries), nil +} + +// FormatForAIContext implements tools.DiscoverySource +func (a *ToolsAdapter) FormatForAIContext(sourceData []tools.DiscoverySourceData) string { + // Convert back to ResourceDiscovery for formatting + discoveries := make([]*ResourceDiscovery, 0, len(sourceData)) + for _, sd := range sourceData { + discoveries = append(discoveries, a.convertFromSourceData(sd)) + } + return FormatForAIContext(discoveries) +} + +func (a *ToolsAdapter) convertToSourceData(d *ResourceDiscovery) tools.DiscoverySourceData { + facts := make([]tools.DiscoverySourceFact, 0, len(d.Facts)) + for _, f := range d.Facts { + facts = append(facts, tools.DiscoverySourceFact{ + Category: string(f.Category), + Key: f.Key, + Value: f.Value, + Source: f.Source, + }) + } + + return tools.DiscoverySourceData{ + ID: d.ID, + ResourceType: string(d.ResourceType), + ResourceID: d.ResourceID, + HostID: d.HostID, + Hostname: d.Hostname, + ServiceType: d.ServiceType, + ServiceName: d.ServiceName, + ServiceVersion: d.ServiceVersion, + Category: string(d.Category), + CLIAccess: d.CLIAccess, + Facts: facts, + ConfigPaths: d.ConfigPaths, + DataPaths: d.DataPaths, + UserNotes: d.UserNotes, + Confidence: d.Confidence, + AIReasoning: d.AIReasoning, + DiscoveredAt: d.DiscoveredAt, + UpdatedAt: d.UpdatedAt, + } +} + +func (a *ToolsAdapter) convertFromSourceData(sd tools.DiscoverySourceData) *ResourceDiscovery { + facts := make([]DiscoveryFact, 0, len(sd.Facts)) + for _, f := range sd.Facts { + facts = append(facts, DiscoveryFact{ + Category: FactCategory(f.Category), + Key: f.Key, + Value: f.Value, + Source: f.Source, + }) + } + + return &ResourceDiscovery{ + ID: sd.ID, + ResourceType: ResourceType(sd.ResourceType), + ResourceID: sd.ResourceID, + HostID: sd.HostID, + Hostname: sd.Hostname, + ServiceType: sd.ServiceType, + ServiceName: sd.ServiceName, + ServiceVersion: sd.ServiceVersion, + Category: ServiceCategory(sd.Category), + CLIAccess: sd.CLIAccess, + Facts: facts, + ConfigPaths: sd.ConfigPaths, + DataPaths: sd.DataPaths, + UserNotes: sd.UserNotes, + Confidence: sd.Confidence, + AIReasoning: sd.AIReasoning, + DiscoveredAt: sd.DiscoveredAt, + UpdatedAt: sd.UpdatedAt, + } +} + +func (a 
*ToolsAdapter) convertList(discoveries []*ResourceDiscovery) []tools.DiscoverySourceData { + result := make([]tools.DiscoverySourceData, 0, len(discoveries)) + for _, d := range discoveries { + if d != nil { + result = append(result, a.convertToSourceData(d)) + } + } + return result +} diff --git a/internal/aidiscovery/types.go b/internal/aidiscovery/types.go new file mode 100644 index 000000000..48ca9aaf9 --- /dev/null +++ b/internal/aidiscovery/types.go @@ -0,0 +1,236 @@ +// Package discovery provides AI-powered infrastructure discovery capabilities. +// It discovers services, versions, configurations, and CLI access methods +// for VMs, LXCs, Docker containers, Kubernetes pods, and hosts. +package aidiscovery + +import ( + "fmt" + "time" +) + +// ResourceType identifies the type of infrastructure resource. +type ResourceType string + +const ( + ResourceTypeVM ResourceType = "vm" + ResourceTypeLXC ResourceType = "lxc" + ResourceTypeDocker ResourceType = "docker" + ResourceTypeK8s ResourceType = "k8s" + ResourceTypeHost ResourceType = "host" + ResourceTypeDockerVM ResourceType = "docker_vm" // Docker on a VM + ResourceTypeDockerLXC ResourceType = "docker_lxc" // Docker in an LXC +) + +// FactCategory categorizes discovery facts. +type FactCategory string + +const ( + FactCategoryVersion FactCategory = "version" + FactCategoryConfig FactCategory = "config" + FactCategoryService FactCategory = "service" + FactCategoryPort FactCategory = "port" + FactCategoryHardware FactCategory = "hardware" + FactCategoryNetwork FactCategory = "network" + FactCategoryStorage FactCategory = "storage" + FactCategoryDependency FactCategory = "dependency" + FactCategorySecurity FactCategory = "security" +) + +// ServiceCategory categorizes the type of service discovered. +type ServiceCategory string + +const ( + CategoryDatabase ServiceCategory = "database" + CategoryWebServer ServiceCategory = "web_server" + CategoryCache ServiceCategory = "cache" + CategoryMessageQueue ServiceCategory = "message_queue" + CategoryMonitoring ServiceCategory = "monitoring" + CategoryBackup ServiceCategory = "backup" + CategoryNVR ServiceCategory = "nvr" + CategoryStorage ServiceCategory = "storage" + CategoryContainer ServiceCategory = "container" + CategoryVirtualizer ServiceCategory = "virtualizer" + CategoryNetwork ServiceCategory = "network" + CategorySecurity ServiceCategory = "security" + CategoryMedia ServiceCategory = "media" + CategoryHomeAuto ServiceCategory = "home_automation" + CategoryUnknown ServiceCategory = "unknown" +) + +// ResourceDiscovery is the main data model for discovered resource information. +type ResourceDiscovery struct { + // Identity + ID string `json:"id"` // Unique ID: "lxc:minipc:101" + ResourceType ResourceType `json:"resource_type"` // vm, lxc, docker, k8s, host + ResourceID string `json:"resource_id"` // 101, container-name, etc. + HostID string `json:"host_id"` // Proxmox node name or host agent ID + Hostname string `json:"hostname"` // Human-readable host name + + // AI-discovered info + ServiceType string `json:"service_type"` // frigate, postgres, pbs + ServiceName string `json:"service_name"` // Human-readable name + ServiceVersion string `json:"service_version"` // v0.13.2 + Category ServiceCategory `json:"category"` // nvr, database, backup + CLIAccess string `json:"cli_access"` // pct exec 101 -- ... 
+ + // Deep discovery facts + Facts []DiscoveryFact `json:"facts"` + ConfigPaths []string `json:"config_paths"` + DataPaths []string `json:"data_paths"` + Ports []PortInfo `json:"ports"` + + // User-added (also encrypted) + UserNotes string `json:"user_notes"` + UserSecrets map[string]string `json:"user_secrets"` // tokens, creds + + // Metadata + Confidence float64 `json:"confidence"` // 0-1 confidence score + AIReasoning string `json:"ai_reasoning"` // AI explanation + DiscoveredAt time.Time `json:"discovered_at"` // First discovery + UpdatedAt time.Time `json:"updated_at"` // Last update + ScanDuration int64 `json:"scan_duration"` // Scan duration in ms + + // Raw data for debugging/re-analysis + RawCommandOutput map[string]string `json:"raw_command_output,omitempty"` +} + +// DiscoveryFact represents a single discovered fact about a resource. +type DiscoveryFact struct { + Category FactCategory `json:"category"` // version, config, service, port + Key string `json:"key"` // e.g., "coral_tpu", "mqtt_broker" + Value string `json:"value"` // e.g., "/dev/apex_0", "mosquitto:1883" + Source string `json:"source"` // command that found this + Confidence float64 `json:"confidence"` // 0-1 confidence for this fact + DiscoveredAt time.Time `json:"discovered_at"` +} + +// PortInfo represents information about a listening port. +type PortInfo struct { + Port int `json:"port"` + Protocol string `json:"protocol"` // tcp, udp + Process string `json:"process"` // process name + Address string `json:"address"` // bind address +} + +// MakeResourceID creates a standardized resource ID. +func MakeResourceID(resourceType ResourceType, hostID, resourceID string) string { + return fmt.Sprintf("%s:%s:%s", resourceType, hostID, resourceID) +} + +// ParseResourceID parses a resource ID into its components. +func ParseResourceID(id string) (resourceType ResourceType, hostID, resourceID string, err error) { + var parts [3]string + count := 0 + start := 0 + for i, c := range id { + if c == ':' { + if count < 2 { + parts[count] = id[start:i] + count++ + start = i + 1 + } + } + } + if count == 2 { + parts[2] = id[start:] + return ResourceType(parts[0]), parts[1], parts[2], nil + } + return "", "", "", fmt.Errorf("invalid resource ID format: %s", id) +} + +// DiscoveryRequest represents a request to discover a resource. +type DiscoveryRequest struct { + ResourceType ResourceType `json:"resource_type"` + ResourceID string `json:"resource_id"` + HostID string `json:"host_id"` + Hostname string `json:"hostname"` + Force bool `json:"force"` // Force re-scan even if recent +} + +// DiscoveryStatus represents the status of a discovery scan. +type DiscoveryStatus string + +const ( + DiscoveryStatusPending DiscoveryStatus = "pending" + DiscoveryStatusRunning DiscoveryStatus = "running" + DiscoveryStatusCompleted DiscoveryStatus = "completed" + DiscoveryStatusFailed DiscoveryStatus = "failed" + DiscoveryStatusNotStarted DiscoveryStatus = "not_started" +) + +// DiscoveryProgress represents the progress of an ongoing discovery. +type DiscoveryProgress struct { + ResourceID string `json:"resource_id"` + Status DiscoveryStatus `json:"status"` + CurrentStep string `json:"current_step"` + TotalSteps int `json:"total_steps"` + CompletedSteps int `json:"completed_steps"` + StartedAt time.Time `json:"started_at"` + Error string `json:"error,omitempty"` +} + +// UpdateNotesRequest represents a request to update user notes. 
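+// Omitting UserSecrets (nil) leaves the stored secrets unchanged; a non-nil map replaces them (see Store.UpdateNotes).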
+type UpdateNotesRequest struct { + UserNotes string `json:"user_notes"` + UserSecrets map[string]string `json:"user_secrets,omitempty"` +} + +// DiscoverySummary provides a summary of discoveries for listing. +type DiscoverySummary struct { + ID string `json:"id"` + ResourceType ResourceType `json:"resource_type"` + ResourceID string `json:"resource_id"` + HostID string `json:"host_id"` + Hostname string `json:"hostname"` + ServiceType string `json:"service_type"` + ServiceName string `json:"service_name"` + ServiceVersion string `json:"service_version"` + Category ServiceCategory `json:"category"` + Confidence float64 `json:"confidence"` + HasUserNotes bool `json:"has_user_notes"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ToSummary converts a full discovery to a summary. +func (d *ResourceDiscovery) ToSummary() DiscoverySummary { + return DiscoverySummary{ + ID: d.ID, + ResourceType: d.ResourceType, + ResourceID: d.ResourceID, + HostID: d.HostID, + Hostname: d.Hostname, + ServiceType: d.ServiceType, + ServiceName: d.ServiceName, + ServiceVersion: d.ServiceVersion, + Category: d.Category, + Confidence: d.Confidence, + HasUserNotes: d.UserNotes != "", + UpdatedAt: d.UpdatedAt, + } +} + +// AIAnalysisRequest is sent to the AI for analysis. +type AIAnalysisRequest struct { + ResourceType ResourceType `json:"resource_type"` + ResourceID string `json:"resource_id"` + HostID string `json:"host_id"` + Hostname string `json:"hostname"` + CommandOutputs map[string]string `json:"command_outputs"` + ExistingFacts []DiscoveryFact `json:"existing_facts,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` // Image, labels, etc. +} + +// AIAnalysisResponse is returned by the AI. +type AIAnalysisResponse struct { + ServiceType string `json:"service_type"` + ServiceName string `json:"service_name"` + ServiceVersion string `json:"service_version"` + Category ServiceCategory `json:"category"` + CLIAccess string `json:"cli_access"` + Facts []DiscoveryFact `json:"facts"` + ConfigPaths []string `json:"config_paths"` + DataPaths []string `json:"data_paths"` + Ports []PortInfo `json:"ports"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning"` +} diff --git a/internal/aidiscovery/types_test.go b/internal/aidiscovery/types_test.go new file mode 100644 index 000000000..405ca10a0 --- /dev/null +++ b/internal/aidiscovery/types_test.go @@ -0,0 +1,22 @@ +package aidiscovery + +import "testing" + +func TestResourceIDHelpers(t *testing.T) { + id := MakeResourceID(ResourceTypeDocker, "host1", "app") + if id != "docker:host1:app" { + t.Fatalf("unexpected id: %s", id) + } + + rt, host, res, err := ParseResourceID(id) + if err != nil { + t.Fatalf("ParseResourceID error: %v", err) + } + if rt != ResourceTypeDocker || host != "host1" || res != "app" { + t.Fatalf("unexpected parse result: %s %s %s", rt, host, res) + } + + if _, _, _, err := ParseResourceID("invalid"); err == nil { + t.Fatalf("expected parse error for invalid id") + } +} diff --git a/internal/api/agent_profiles_tools_extra_test.go b/internal/api/agent_profiles_tools_extra_test.go new file mode 100644 index 000000000..95be7a3cc --- /dev/null +++ b/internal/api/agent_profiles_tools_extra_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "context" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/license" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +func TestFormatValidationIssues(t *testing.T) { + result := 
models.ValidationResult{ + Errors: []models.ValidationError{ + {Key: "bad", Message: "invalid"}, + }, + Warnings: []models.ValidationError{ + {Key: "warn", Message: "unknown"}, + }, + } + msg := formatValidationIssues(result) + if !strings.Contains(msg, "bad") || !strings.Contains(msg, "warning") { + t.Fatalf("expected errors and warnings in message: %s", msg) + } + + if formatValidationIssues(models.ValidationResult{}) != "unknown validation error" { + t.Fatalf("expected fallback message") + } +} + +func TestBuildScopeProfileName(t *testing.T) { + if buildScopeProfileName("", "agent") != "Patrol Scope: agent" { + t.Fatalf("expected default name") + } + if buildScopeProfileName("agent", "agent") != "Patrol Scope: agent" { + t.Fatalf("expected label equal to agent ID to be simplified") + } + if buildScopeProfileName("Alpha", "agent") != "Patrol Scope: Alpha (agent)" { + t.Fatalf("expected label to be included") + } +} + +func TestMCPAgentProfileManager_ValidateSettings(t *testing.T) { + manager := newTestProfileManager(t) + if err := manager.validateSettings(map[string]interface{}{"unknown_key": true}); err == nil { + t.Fatalf("expected validation error for warnings") + } + if err := manager.validateSettings(map[string]interface{}{"enable_host": "nope"}); err == nil { + t.Fatalf("expected validation error for invalid type") + } +} + +func TestMCPAgentProfileManager_RequireLicense(t *testing.T) { + persistence := config.NewConfigPersistence(t.TempDir()) + licenseService := license.NewService() + manager := NewMCPAgentProfileManager(persistence, licenseService) + + _, _, _, err := manager.ApplyAgentScope(context.Background(), "agent-1", "Alpha", map[string]interface{}{"enable_host": true}) + if err == nil { + t.Fatalf("expected license error") + } +} + +func TestMCPAgentProfileManager_SaveVersion(t *testing.T) { + manager := newTestProfileManager(t) + profile := models.AgentProfile{ + ID: "profile-1", + Name: "Default", + Version: 2, + Config: map[string]interface{}{ + "enable_host": true, + }, + } + if err := manager.saveVersion(profile, "note"); err != nil { + t.Fatalf("unexpected saveVersion error: %v", err) + } + versions, err := manager.persistence.LoadAgentProfileVersions() + if err != nil { + t.Fatalf("unexpected load versions error: %v", err) + } + if len(versions) != 1 || versions[0].Version != 2 { + t.Fatalf("expected version to be saved") + } +} + +func TestMCPAgentProfileManager_AssignProfile_NotFound(t *testing.T) { + manager := newTestProfileManager(t) + if _, err := manager.AssignProfile(context.Background(), "agent-1", "missing"); err == nil { + t.Fatalf("expected error for missing profile") + } +} + +func TestMCPAgentProfileManager_GetScope_EmptyAgent(t *testing.T) { + manager := newTestProfileManager(t) + if _, err := manager.GetAgentScope(context.Background(), ""); err == nil { + t.Fatalf("expected error for empty agent ID") + } +} diff --git a/internal/api/ai_handler.go b/internal/api/ai_handler.go index 4a36ac378..7e68d51fa 100644 --- a/internal/api/ai_handler.go +++ b/internal/api/ai_handler.go @@ -58,6 +58,7 @@ type AIService interface { SetIncidentRecorderProvider(provider chat.IncidentRecorderProvider) SetEventCorrelatorProvider(provider chat.EventCorrelatorProvider) SetTopologyProvider(provider chat.TopologyProvider) + SetDiscoveryProvider(provider chat.MCPDiscoveryProvider) UpdateControlSettings(cfg *config.AIConfig) GetBaseURL() string } diff --git a/internal/api/ai_handler_test.go b/internal/api/ai_handler_test.go index 1ebc05781..2b3671ab7 100644 --- 
a/internal/api/ai_handler_test.go +++ b/internal/api/ai_handler_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "path/filepath" "strings" "testing" @@ -12,6 +13,7 @@ import ( "github.com/rcourtman/pulse-go-rewrite/internal/ai/chat" "github.com/rcourtman/pulse-go-rewrite/internal/config" "github.com/rcourtman/pulse-go-rewrite/internal/models" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -820,3 +822,87 @@ func TestHandleStatus_NoService(t *testing.T) { json.NewDecoder(w.Body).Decode(&resp) assert.False(t, resp["running"].(bool)) } + +func TestGetService_MultiTenantInitAndCache(t *testing.T) { + oldNewService := newChatService + defer func() { newChatService = oldNewService }() + + tempDir := t.TempDir() + mtp := config.NewMultiTenantPersistence(tempDir) + h := NewAIHandler(mtp, nil, nil) + + mockSvc := new(MockAIService) + mockSvc.On("Start", mock.Anything).Return(nil).Once() + + var gotCfg chat.Config + newChatService = func(cfg chat.Config) AIService { + gotCfg = cfg + return mockSvc + } + + ctx := context.WithValue(context.Background(), OrgIDContextKey, "acme") + svc := h.GetService(ctx) + assert.Same(t, mockSvc, svc) + + expectedDir := filepath.Join(tempDir, "orgs", "acme") + assert.Equal(t, expectedDir, gotCfg.DataDir) + assert.NotNil(t, gotCfg.AIConfig) + + // Second call should return cached service without re-starting + svc = h.GetService(ctx) + assert.Same(t, mockSvc, svc) + + mockSvc.AssertExpectations(t) +} + +func TestRemoveTenantService(t *testing.T) { + h := NewAIHandler(nil, nil, nil) + mockSvc := new(MockAIService) + mockSvc.On("Stop", mock.Anything).Return(assert.AnError).Once() + h.services["acme"] = mockSvc + + err := h.RemoveTenantService(context.Background(), "acme") + assert.NoError(t, err) + _, exists := h.services["acme"] + assert.False(t, exists) + + mockSvc.AssertExpectations(t) +} + +func TestRemoveTenantService_DefaultNoop(t *testing.T) { + h := NewAIHandler(nil, nil, nil) + mockSvc := new(MockAIService) + h.services["default"] = mockSvc + + err := h.RemoveTenantService(context.Background(), "default") + assert.NoError(t, err) + _, exists := h.services["default"] + assert.True(t, exists) +} + +func TestGetConfig_DefaultFallback(t *testing.T) { + cfg := &config.Config{APIToken: "token"} + h := newTestAIHandler(cfg, nil, nil) + ctx := context.WithValue(context.Background(), OrgIDContextKey, "acme") + + result := h.getConfig(ctx) + assert.Same(t, cfg, result) +} + +func TestGetDataDirDefault(t *testing.T) { + h := newTestAIHandler(nil, nil, nil) + assert.Equal(t, "data", h.getDataDir(nil, "")) + assert.Equal(t, "custom", h.getDataDir(nil, "custom")) +} + +func TestSetMultiTenantPointers(t *testing.T) { + h := NewAIHandler(nil, nil, nil) + mtp := config.NewMultiTenantPersistence(t.TempDir()) + mtm := &monitoring.MultiTenantMonitor{} + + h.SetMultiTenantPersistence(mtp) + h.SetMultiTenantMonitor(mtm) + + assert.Same(t, mtp, h.mtPersistence) + assert.Same(t, mtm, h.mtMonitor) +} diff --git a/internal/api/ai_handlers.go b/internal/api/ai_handlers.go index 74dc9257e..4f8a9faff 100644 --- a/internal/api/ai_handlers.go +++ b/internal/api/ai_handlers.go @@ -30,6 +30,7 @@ import ( "github.com/rcourtman/pulse-go-rewrite/internal/ai/proxmox" "github.com/rcourtman/pulse-go-rewrite/internal/ai/remediation" "github.com/rcourtman/pulse-go-rewrite/internal/ai/unified" + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" 
"github.com/rcourtman/pulse-go-rewrite/internal/config" "github.com/rcourtman/pulse-go-rewrite/internal/license" "github.com/rcourtman/pulse-go-rewrite/internal/metrics" @@ -83,6 +84,9 @@ type AISettingsHandler struct { chatHandler *AIHandler // Chat service handler for investigations investigationStores map[string]*investigation.Store // Investigation stores per org investigationMu sync.RWMutex + + // AI Discovery store for deep infrastructure discovery + aiDiscoveryStore *aidiscovery.Store } // NewAISettingsHandler creates a new AI settings handler @@ -189,6 +193,9 @@ func (h *AISettingsHandler) GetAIService(ctx context.Context) *ai.Service { if h.correlationDetector != nil { svc.SetCorrelationDetector(h.correlationDetector) } + if h.aiDiscoveryStore != nil { + svc.SetAIDiscoveryStore(h.aiDiscoveryStore) + } // Set license checker if handler available if h.licenseHandlers != nil { @@ -527,6 +534,26 @@ func (h *AISettingsHandler) GetUnifiedStore() *unified.UnifiedStore { return h.unifiedStore } +// SetAIDiscoveryStore sets the AI discovery store for deep infrastructure discovery +func (h *AISettingsHandler) SetAIDiscoveryStore(store *aidiscovery.Store) { + h.aiDiscoveryStore = store + // Also set on legacy service if it exists + if h.legacyAIService != nil { + h.legacyAIService.SetAIDiscoveryStore(store) + } + // Set on all existing tenant services + h.aiServicesMu.RLock() + defer h.aiServicesMu.RUnlock() + for _, svc := range h.aiServices { + svc.SetAIDiscoveryStore(store) + } +} + +// GetAIDiscoveryStore returns the AI discovery store +func (h *AISettingsHandler) GetAIDiscoveryStore() *aidiscovery.Store { + return h.aiDiscoveryStore +} + // SetAlertBridge sets the alert bridge func (h *AISettingsHandler) SetAlertBridge(bridge *unified.AlertBridge) { h.alertBridge = bridge @@ -739,6 +766,13 @@ func (h *AISettingsHandler) setupInvestigationOrchestrator(orgID string, svc *ai // The chatAdapter implements both ChatService and CommandExecutor interfaces orchestrator.SetCommandExecutor(chatAdapter) + // Set infrastructure context provider for CLI access information + // This enables investigations to know where services run (Docker, systemd, native) + // and propose correct commands (e.g., 'docker exec pbs proxmox-backup-manager ...') + if knowledgeStore := svc.GetKnowledgeStore(); knowledgeStore != nil { + orchestrator.SetInfrastructureContextProvider(knowledgeStore) + } + // Create adapter to bridge investigation.Orchestrator to ai.InvestigationOrchestrator interface adapter := ai.NewInvestigationOrchestratorAdapter(orchestrator) @@ -812,6 +846,9 @@ type AISettingsResponse struct { // Infrastructure control settings ControlLevel string `json:"control_level"` // "read_only", "controlled", "autonomous" ProtectedGuests []string `json:"protected_guests,omitempty"` // VMIDs/names that AI cannot control + // AI Discovery settings + DiscoveryEnabled bool `json:"discovery_enabled"` // true if AI discovery is enabled + DiscoveryIntervalHours int `json:"discovery_interval_hours,omitempty"` // Hours between auto-scans (0 = manual only) } // AISettingsUpdateRequest is the request body for PUT /api/settings/ai @@ -854,6 +891,9 @@ type AISettingsUpdateRequest struct { // Infrastructure control settings ControlLevel *string `json:"control_level,omitempty"` // "read_only", "controlled", "autonomous" ProtectedGuests []string `json:"protected_guests,omitempty"` // VMIDs/names that AI cannot control (nil = don't update, empty = clear) + // AI Discovery settings + DiscoveryEnabled *bool 
`json:"discovery_enabled,omitempty"` // Enable AI discovery + DiscoveryIntervalHours *int `json:"discovery_interval_hours,omitempty"` // Hours between auto-scans (0 = manual only) } // HandleGetAISettings returns the current AI settings (GET /api/settings/ai) @@ -908,18 +948,20 @@ func (h *AISettingsHandler) HandleGetAISettings(w http.ResponseWriter, r *http.R UseProactiveThresholds: settings.UseProactiveThresholds, AvailableModels: nil, // Now populated via /api/ai/models endpoint // Multi-provider configuration - AnthropicConfigured: settings.HasProvider(config.AIProviderAnthropic), - OpenAIConfigured: settings.HasProvider(config.AIProviderOpenAI), - DeepSeekConfigured: settings.HasProvider(config.AIProviderDeepSeek), - GeminiConfigured: settings.HasProvider(config.AIProviderGemini), - OllamaConfigured: settings.HasProvider(config.AIProviderOllama), - OllamaBaseURL: settings.GetBaseURLForProvider(config.AIProviderOllama), - OpenAIBaseURL: settings.OpenAIBaseURL, - ConfiguredProviders: settings.GetConfiguredProviders(), - CostBudgetUSD30d: settings.CostBudgetUSD30d, - RequestTimeoutSeconds: settings.RequestTimeoutSeconds, - ControlLevel: settings.GetControlLevel(), - ProtectedGuests: settings.GetProtectedGuests(), + AnthropicConfigured: settings.HasProvider(config.AIProviderAnthropic), + OpenAIConfigured: settings.HasProvider(config.AIProviderOpenAI), + DeepSeekConfigured: settings.HasProvider(config.AIProviderDeepSeek), + GeminiConfigured: settings.HasProvider(config.AIProviderGemini), + OllamaConfigured: settings.HasProvider(config.AIProviderOllama), + OllamaBaseURL: settings.GetBaseURLForProvider(config.AIProviderOllama), + OpenAIBaseURL: settings.OpenAIBaseURL, + ConfiguredProviders: settings.GetConfiguredProviders(), + CostBudgetUSD30d: settings.CostBudgetUSD30d, + RequestTimeoutSeconds: settings.RequestTimeoutSeconds, + ControlLevel: settings.GetControlLevel(), + ProtectedGuests: settings.GetProtectedGuests(), + DiscoveryEnabled: settings.IsDiscoveryEnabled(), + DiscoveryIntervalHours: settings.DiscoveryIntervalHours, } if err := utils.WriteJSONResponse(w, response); err != nil { @@ -1269,6 +1311,18 @@ func (h *AISettingsHandler) HandleUpdateAISettings(w http.ResponseWriter, r *htt settings.ProtectedGuests = req.ProtectedGuests } + // Handle discovery settings + if req.DiscoveryEnabled != nil { + settings.DiscoveryEnabled = *req.DiscoveryEnabled + } + if req.DiscoveryIntervalHours != nil { + if *req.DiscoveryIntervalHours < 0 { + http.Error(w, "discovery_interval_hours cannot be negative", http.StatusBadRequest) + return + } + settings.DiscoveryIntervalHours = *req.DiscoveryIntervalHours + } + // Save settings if err := h.getPersistence(r.Context()).SaveAIConfig(*settings); err != nil { log.Error().Err(err).Msg("Failed to save AI settings") @@ -1342,17 +1396,19 @@ func (h *AISettingsHandler) HandleUpdateAISettings(w http.ResponseWriter, r *htt UseProactiveThresholds: settings.UseProactiveThresholds, AvailableModels: nil, // Now populated via /api/ai/models endpoint // Multi-provider configuration - AnthropicConfigured: settings.HasProvider(config.AIProviderAnthropic), - OpenAIConfigured: settings.HasProvider(config.AIProviderOpenAI), - DeepSeekConfigured: settings.HasProvider(config.AIProviderDeepSeek), - GeminiConfigured: settings.HasProvider(config.AIProviderGemini), - OllamaConfigured: settings.HasProvider(config.AIProviderOllama), - OllamaBaseURL: settings.GetBaseURLForProvider(config.AIProviderOllama), - OpenAIBaseURL: settings.OpenAIBaseURL, - ConfiguredProviders: 
settings.GetConfiguredProviders(), - RequestTimeoutSeconds: settings.RequestTimeoutSeconds, - ControlLevel: settings.GetControlLevel(), - ProtectedGuests: settings.GetProtectedGuests(), + AnthropicConfigured: settings.HasProvider(config.AIProviderAnthropic), + OpenAIConfigured: settings.HasProvider(config.AIProviderOpenAI), + DeepSeekConfigured: settings.HasProvider(config.AIProviderDeepSeek), + GeminiConfigured: settings.HasProvider(config.AIProviderGemini), + OllamaConfigured: settings.HasProvider(config.AIProviderOllama), + OllamaBaseURL: settings.GetBaseURLForProvider(config.AIProviderOllama), + OpenAIBaseURL: settings.OpenAIBaseURL, + ConfiguredProviders: settings.GetConfiguredProviders(), + RequestTimeoutSeconds: settings.RequestTimeoutSeconds, + ControlLevel: settings.GetControlLevel(), + ProtectedGuests: settings.GetProtectedGuests(), + DiscoveryEnabled: settings.DiscoveryEnabled, + DiscoveryIntervalHours: settings.DiscoveryIntervalHours, } if err := utils.WriteJSONResponse(w, response); err != nil { @@ -5210,7 +5266,7 @@ func (h *AISettingsHandler) HandleUpdatePatrolAutonomy(w http.ResponseWriter, r // Validate autonomy level if !config.IsValidPatrolAutonomyLevel(req.AutonomyLevel) { writeErrorResponse(w, http.StatusBadRequest, "invalid_autonomy_level", - fmt.Sprintf("Invalid autonomy level: %s. Must be 'monitor', 'approval', or 'full'", req.AutonomyLevel), nil) + fmt.Sprintf("Invalid autonomy level: %s. Must be 'monitor', 'approval', 'full', or 'autonomous'", req.AutonomyLevel), nil) return } @@ -5249,6 +5305,13 @@ func (h *AISettingsHandler) HandleUpdatePatrolAutonomy(w http.ResponseWriter, r return } + // Reload config to update in-memory state + if err := aiService.LoadConfig(); err != nil { + // Log but don't fail - config was saved successfully + LogAuditEvent("patrol_autonomy_reload_warning", "", "", r.URL.Path, false, + fmt.Sprintf("Config saved but failed to reload: %v", err)) + } + // Log audit event username := getAuthUsername(h.getConfig(r.Context()), r) LogAuditEvent("patrol_autonomy_updated", username, GetClientIP(r), r.URL.Path, true, @@ -5301,6 +5364,90 @@ func (h *AISettingsHandler) HandleGetInvestigation(w http.ResponseWriter, r *htt json.NewEncoder(w).Encode(investigation) } +// HandleReapproveInvestigationFix creates a new approval from an investigation's proposed fix (POST /api/ai/findings/{id}/reapprove) +// This is useful when the original approval has expired but the user still wants to execute the fix. 
+func (h *AISettingsHandler) HandleReapproveInvestigationFix(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check license + if !h.GetAIService(r.Context()).HasLicenseFeature(license.FeatureAIAutoFix) { + writeErrorResponse(w, http.StatusForbidden, "license_required", "Pulse Patrol Auto-Fix feature requires Pro license", nil) + return + } + + // Extract finding ID from path + findingID := strings.TrimPrefix(r.URL.Path, "/api/ai/findings/") + findingID = strings.TrimSuffix(findingID, "/reapprove") + if findingID == "" { + writeErrorResponse(w, http.StatusBadRequest, "missing_id", "Finding ID is required", nil) + return + } + + aiService := h.GetAIService(r.Context()) + patrol := aiService.GetPatrolService() + if patrol == nil { + writeErrorResponse(w, http.StatusServiceUnavailable, "not_initialized", "Patrol service not initialized", nil) + return + } + + // Get investigation from orchestrator + orchestrator := patrol.GetInvestigationOrchestrator() + if orchestrator == nil { + writeErrorResponse(w, http.StatusServiceUnavailable, "not_initialized", "Investigation orchestrator not initialized", nil) + return + } + + inv := orchestrator.GetInvestigationByFinding(findingID) + if inv == nil { + writeErrorResponse(w, http.StatusNotFound, "not_found", "No investigation found for this finding", nil) + return + } + + // Check if investigation has a proposed fix + if inv.ProposedFix == nil || len(inv.ProposedFix.Commands) == 0 { + writeErrorResponse(w, http.StatusBadRequest, "no_fix", "Investigation has no proposed fix", nil) + return + } + + // Check approval store + store := approval.GetStore() + if store == nil { + writeErrorResponse(w, http.StatusServiceUnavailable, "not_initialized", "Approval store not initialized", nil) + return + } + + // Create new approval request + req := &approval.ApprovalRequest{ + ToolID: "investigation_fix", + Command: inv.ProposedFix.Commands[0], + TargetType: "investigation", + TargetID: findingID, + TargetName: inv.ProposedFix.Description, + Context: fmt.Sprintf("Re-approval of fix from investigation: %s", inv.ProposedFix.Description), + RiskLevel: approval.AssessRiskLevel(inv.ProposedFix.Commands[0], "investigation"), + } + + if err := store.CreateApproval(req); err != nil { + writeErrorResponse(w, http.StatusInternalServerError, "create_failed", "Failed to create approval: "+err.Error(), nil) + return + } + + log.Info(). + Str("finding_id", findingID). + Str("approval_id", req.ID). + Str("command", truncateForLog(req.Command, 100)). + Msg("Re-created approval for investigation fix") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "approval_id": req.ID, + "message": "Approval created. 
You can now approve and execute the fix.", + }) +} + // HandleGetInvestigationMessages returns chat messages for an investigation (GET /api/ai/findings/{id}/investigation/messages) func (h *AISettingsHandler) HandleGetInvestigationMessages(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { diff --git a/internal/api/ai_handlers_cost_export_additional_test.go b/internal/api/ai_handlers_cost_export_additional_test.go new file mode 100644 index 000000000..93d8da553 --- /dev/null +++ b/internal/api/ai_handlers_cost_export_additional_test.go @@ -0,0 +1,116 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestHandleExportAICostHistory_JSON(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + + events := []config.AIUsageEventRecord{ + { + Timestamp: time.Now().UTC(), + Provider: "openai", + RequestModel: "gpt-4o-mini", + UseCase: "chat", + InputTokens: 42, + OutputTokens: 17, + TargetType: "vm", + TargetID: "vm-1", + FindingID: "finding-1", + }, + } + if err := persistence.SaveAIUsageHistory(events); err != nil { + t.Fatalf("SaveAIUsageHistory: %v", err) + } + + handler := newTestAISettingsHandler(cfg, persistence, nil) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/cost/export?days=7&format=json", nil) + rec := httptest.NewRecorder() + + handler.HandleExportAICostHistory(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") { + t.Fatalf("expected json content type") + } + + var resp struct { + Days int `json:"days"` + Events []struct { + Provider string `json:"provider"` + PricingKnown bool `json:"pricing_known"` + } `json:"events"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Days != 7 { + t.Fatalf("days = %d, want 7", resp.Days) + } + if len(resp.Events) != 1 { + t.Fatalf("events = %d, want 1", len(resp.Events)) + } + if resp.Events[0].Provider != "openai" { + t.Fatalf("provider = %s, want openai", resp.Events[0].Provider) + } + if !resp.Events[0].PricingKnown { + t.Fatalf("expected pricing known") + } +} + +func TestHandleExportAICostHistory_CSV(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + + events := []config.AIUsageEventRecord{ + { + Timestamp: time.Now().UTC(), + Provider: "openai", + RequestModel: "gpt-4o-mini", + UseCase: "chat", + InputTokens: 5, + OutputTokens: 3, + TargetType: "node", + TargetID: "node-1", + }, + } + if err := persistence.SaveAIUsageHistory(events); err != nil { + t.Fatalf("SaveAIUsageHistory: %v", err) + } + + handler := newTestAISettingsHandler(cfg, persistence, nil) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/cost/export?days=1&format=csv", nil) + rec := httptest.NewRecorder() + + handler.HandleExportAICostHistory(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if !strings.Contains(rec.Header().Get("Content-Type"), "text/csv") { + t.Fatalf("expected csv content type") + } + + lines := strings.Split(strings.TrimSpace(rec.Body.String()), "\n") + if len(lines) < 2 { + t.Fatalf("expected header and data rows") + } + if !strings.HasPrefix(lines[0], "timestamp,provider,request_model") { + 
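+ // The CSV export is expected to begin with a fixed header row (timestamp, provider, request_model, ...).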
t.Fatalf("unexpected header: %s", lines[0]) + } +} diff --git a/internal/api/ai_handlers_helpers_test.go b/internal/api/ai_handlers_helpers_test.go new file mode 100644 index 000000000..af1367a21 --- /dev/null +++ b/internal/api/ai_handlers_helpers_test.go @@ -0,0 +1,316 @@ +package api + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/chat" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/circuit" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/forecast" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/learning" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/proxmox" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/remediation" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +type stubStateProvider struct{} + +func (s *stubStateProvider) GetState() models.StateSnapshot { + return models.StateSnapshot{} +} + +type fakeChatWrapper struct { + *chat.Service +} + +func newTestAISettingsHandlerLite() *AISettingsHandler { + return &AISettingsHandler{ + legacyAIService: ai.NewService(nil, nil), + aiServices: make(map[string]*ai.Service), + } +} + +func TestPreviewTitle(t *testing.T) { + cases := map[ai.FindingCategory]string{ + ai.FindingCategoryPerformance: "Performance issue detected", + ai.FindingCategoryCapacity: "Capacity issue detected", + ai.FindingCategoryReliability: "Reliability issue detected", + ai.FindingCategoryBackup: "Backup issue detected", + ai.FindingCategorySecurity: "Security issue detected", + ai.FindingCategory("other"): "Potential issue detected", + } + + for category, expected := range cases { + if got := previewTitle(category); got != expected { + t.Fatalf("category %s expected %q, got %q", category, expected, got) + } + } +} + +func TestPreviewResourceName(t *testing.T) { + cases := map[string]string{ + "node": "Node", + "vm": "VM", + "container": "Container", + "oci_container": "Container", + "docker_host": "Docker host", + "docker_container": "Docker container", + "storage": "Storage", + "pbs": "PBS server", + "pbs_datastore": "PBS datastore", + "pbs_job": "PBS job", + "host": "Host", + "host_raid": "RAID array", + "host_sensor": "Host sensor", + "unknown": "Resource", + } + + for resourceType, expected := range cases { + if got := previewResourceName(resourceType); got != expected { + t.Fatalf("resource %s expected %q, got %q", resourceType, expected, got) + } + } +} + +func TestRedactFindingsForPreview(t *testing.T) { + now := time.Now() + finding := &ai.Finding{ + Key: "key", + ResourceID: "res-1", + ResourceName: "db-1", + ResourceType: "vm", + Node: "node-1", + Category: ai.FindingCategoryPerformance, + Title: "Original title", + Description: "Description", + Recommendation: "Recommendation", + Evidence: "Evidence", + AlertID: "alert-1", + AcknowledgedAt: &now, + SnoozedUntil: &now, + ResolvedAt: &now, + AutoResolved: true, + DismissedReason: "reason", + UserNote: "note", + TimesRaised: 3, + Suppressed: true, + Source: "original", + } + + redacted := redactFindingsForPreview([]*ai.Finding{nil, finding}) + if len(redacted) != 1 { + t.Fatalf("expected 1 redacted finding, got %d", len(redacted)) + } + got := redacted[0] + if got.Key != "" || got.ResourceID != "" || got.Node != "" { + t.Fatalf("expected identifiers to be cleared") + } + if got.ResourceName != "VM" { + t.Fatalf("expected resource name to be preview value, got %q", got.ResourceName) + } + if got.Title != "Performance issue detected" { + 
t.Fatalf("expected preview title, got %q", got.Title) + } + if got.Description != "Upgrade to view full analysis." { + t.Fatalf("expected preview description") + } + if got.Recommendation != "" || got.Evidence != "" || got.AlertID != "" { + t.Fatalf("expected sensitive fields to be cleared") + } + if got.AcknowledgedAt != nil || got.SnoozedUntil != nil || got.ResolvedAt != nil { + t.Fatalf("expected timestamps to be cleared") + } + if got.AutoResolved || got.DismissedReason != "" || got.UserNote != "" { + t.Fatalf("expected status fields to be cleared") + } + if got.TimesRaised != 0 || got.Suppressed { + t.Fatalf("expected counters to be cleared") + } + if got.Source != "preview" { + t.Fatalf("expected source to be preview") + } + if finding.Title != "Original title" { + t.Fatalf("expected original finding to remain unchanged") + } +} + +func TestRedactPatrolRunHistory(t *testing.T) { + runs := []ai.PatrolRunRecord{ + { + ID: "run-1", + AIAnalysis: "analysis", + InputTokens: 100, + OutputTokens: 200, + FindingIDs: []string{"a", "b"}, + }, + } + + redacted := redactPatrolRunHistory(runs) + if redacted[0].AIAnalysis != "" || redacted[0].InputTokens != 0 || redacted[0].OutputTokens != 0 { + t.Fatalf("expected AI analysis fields to be cleared") + } + if redacted[0].FindingIDs != nil { + t.Fatalf("expected finding IDs to be cleared") + } +} + +func TestIsMCPToolCall(t *testing.T) { + handler := &AISettingsHandler{} + if !handler.isMCPToolCall("pulse_control_guest(guest_id='102')") { + t.Fatalf("expected MCP tool call to be detected") + } + if !handler.isMCPToolCall("default_api:pulse_get_resource(id='1')") { + t.Fatalf("expected MCP tool call with default_api prefix") + } + if handler.isMCPToolCall("echo hello") { + t.Fatalf("expected non-tool command to be false") + } +} + +func TestCleanTargetHost(t *testing.T) { + handler := &AISettingsHandler{} + if got := handler.cleanTargetHost("delly (The container's host is 'delly')"); got != "delly" { + t.Fatalf("expected cleaned host, got %q", got) + } + if got := handler.cleanTargetHost("delly extra"); got != "delly" { + t.Fatalf("expected first token, got %q", got) + } + if got := handler.cleanTargetHost(" delly "); got != "delly" { + t.Fatalf("expected trimmed host, got %q", got) + } + if got := handler.cleanTargetHost(""); got != "" { + t.Fatalf("expected empty host") + } +} + +func TestSplitToolArgs(t *testing.T) { + handler := &AISettingsHandler{} + args := "action='start', guest_id=\"102\", note='hello, world', path=\"/tmp/a,b\", escaped=\"\\\"quote\\\"\"" + parts := handler.splitToolArgs(args) + expected := []string{ + "action='start'", + "guest_id=\"102\"", + "note='hello, world'", + "path=\"/tmp/a,b\"", + "escaped=\"\\\"quote\\\"\"", + } + if len(parts) != len(expected) { + t.Fatalf("expected %d parts, got %d", len(expected), len(parts)) + } + for i := range expected { + if strings.TrimSpace(parts[i]) != expected[i] { + t.Fatalf("expected part %q, got %q", expected[i], parts[i]) + } + } +} + +func TestParseMCPToolCall(t *testing.T) { + handler := &AISettingsHandler{} + tool, args, err := handler.parseMCPToolCall("default_api:pulse_control_guest(guest_id=\"102\", action='start')") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tool != "pulse_control_guest" { + t.Fatalf("expected tool name pulse_control_guest, got %q", tool) + } + if args["guest_id"] != "102" || args["action"] != "start" { + t.Fatalf("unexpected args: %#v", args) + } + + tool, args, err = handler.parseMCPToolCall("pulse_run_command()") + if err != nil { + 
t.Fatalf("unexpected error for empty args: %v", err) + } + if tool != "pulse_run_command" || len(args) != 0 { + t.Fatalf("expected empty args, got %#v", args) + } + + if _, _, err = handler.parseMCPToolCall("pulse_control_guest"); err == nil { + t.Fatalf("expected error for missing parenthesis") + } + if _, _, err = handler.parseMCPToolCall("pulse_control_guest("); err == nil { + t.Fatalf("expected error for missing closing parenthesis") + } +} + +func TestExecuteMCPToolFix_Errors(t *testing.T) { + handler := &AISettingsHandler{} + if _, _, err := handler.executeMCPToolFix(context.Background(), "pulse_control_guest()", ""); err == nil { + t.Fatalf("expected error when chat handler is missing") + } + + handler.chatHandler = &AIHandler{} + if _, _, err := handler.executeMCPToolFix(context.Background(), "pulse_control_guest()", ""); err == nil { + t.Fatalf("expected error when chat service is missing") + } + + handler.chatHandler.legacyService = &fakeChatWrapper{} + if _, _, err := handler.executeMCPToolFix(context.Background(), "pulse_control_guest()", ""); err == nil { + t.Fatalf("expected error for chat service type mismatch") + } +} + +func TestAISettingsHandler_Setters(t *testing.T) { + handler := newTestAISettingsHandlerLite() + stateProvider := &stubStateProvider{} + handler.SetStateProvider(stateProvider) + if handler.GetStateProvider() != stateProvider { + t.Fatalf("expected state provider to be set") + } + + breaker := &circuit.Breaker{} + handler.SetCircuitBreaker(breaker) + if handler.GetCircuitBreaker() != breaker { + t.Fatalf("expected circuit breaker to be set") + } + + learningStore := &learning.LearningStore{} + handler.SetLearningStore(learningStore) + if handler.GetLearningStore() != learningStore { + t.Fatalf("expected learning store to be set") + } + + forecastSvc := &forecast.Service{} + handler.SetForecastService(forecastSvc) + if handler.GetForecastService() != forecastSvc { + t.Fatalf("expected forecast service to be set") + } + + correlator := &proxmox.EventCorrelator{} + handler.SetProxmoxCorrelator(correlator) + if handler.GetProxmoxCorrelator() != correlator { + t.Fatalf("expected correlator to be set") + } + + engine := &remediation.Engine{} + handler.SetRemediationEngine(engine) + if handler.GetRemediationEngine() != engine { + t.Fatalf("expected remediation engine to be set") + } +} + +func TestAISettingsHandler_RemoveTenantService(t *testing.T) { + handler := newTestAISettingsHandlerLite() + handler.aiServices["org-1"] = ai.NewService(nil, nil) + handler.aiServices["default"] = ai.NewService(nil, nil) + + handler.RemoveTenantService("org-1") + if _, ok := handler.aiServices["org-1"]; ok { + t.Fatalf("expected org-1 to be removed") + } + + handler.RemoveTenantService("default") + if _, ok := handler.aiServices["default"]; !ok { + t.Fatalf("expected default to remain") + } +} + +func TestAISettingsHandler_IsAIEnabled(t *testing.T) { + handler := newTestAISettingsHandlerLite() + if handler.IsAIEnabled(context.Background()) { + t.Fatalf("expected AI to be disabled by default") + } +} diff --git a/internal/api/ai_handlers_intelligence_additional_test.go b/internal/api/ai_handlers_intelligence_additional_test.go new file mode 100644 index 000000000..f5d109e1a --- /dev/null +++ b/internal/api/ai_handlers_intelligence_additional_test.go @@ -0,0 +1,49 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func 
TestHandleGetIntelligence_PatrolUnavailable(t *testing.T) { + persistence := config.NewConfigPersistence(t.TempDir()) + handler := &AISettingsHandler{ + legacyAIService: ai.NewService(persistence, nil), + } + + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence", nil) + rr := httptest.NewRecorder() + + handler.HandleGetIntelligence(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rr.Code) + } + + var payload map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if payload["error"] != "Pulse Patrol service not available" { + t.Fatalf("error = %v, want Pulse Patrol service not available", payload["error"]) + } +} + +func TestHandlePatrolStream_PatrolUnavailable(t *testing.T) { + persistence := config.NewConfigPersistence(t.TempDir()) + handler := &AISettingsHandler{ + legacyAIService: ai.NewService(persistence, nil), + } + + req := httptest.NewRequest(http.MethodGet, "/api/ai/patrol/stream", nil) + rr := httptest.NewRecorder() + + handler.HandlePatrolStream(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rr.Code) + } +} diff --git a/internal/api/ai_handlers_investigation_additional_test.go b/internal/api/ai_handlers_investigation_additional_test.go new file mode 100644 index 000000000..29cba8754 --- /dev/null +++ b/internal/api/ai_handlers_investigation_additional_test.go @@ -0,0 +1,640 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rcourtman/pulse-go-rewrite/internal/agentexec" + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/approval" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/chat" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/investigation" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +type stubInvestigationOrchestrator struct { + session *ai.InvestigationSession + reinvestigateCh chan reinvestigateCall + lastAutonomy string + lastReinvestigate string +} + +type reinvestigateCall struct { + findingID string + autonomy string +} + +func (s *stubInvestigationOrchestrator) InvestigateFinding(ctx context.Context, finding *ai.InvestigationFinding, autonomyLevel string) error { + return nil +} + +func (s *stubInvestigationOrchestrator) GetInvestigationByFinding(findingID string) *ai.InvestigationSession { + if s.session == nil || s.session.FindingID != findingID { + return nil + } + return s.session +} + +func (s *stubInvestigationOrchestrator) GetRunningCount() int { + return 0 +} + +func (s *stubInvestigationOrchestrator) GetFixedCount() int { + return 0 +} + +func (s *stubInvestigationOrchestrator) CanStartInvestigation() bool { + return true +} + +func (s *stubInvestigationOrchestrator) ReinvestigateFinding(ctx context.Context, findingID, autonomyLevel string) error { + s.lastReinvestigate = findingID + s.lastAutonomy = autonomyLevel + if s.reinvestigateCh != nil { + s.reinvestigateCh <- reinvestigateCall{findingID: findingID, autonomy: autonomyLevel} + } + return nil +} + +type stubChatService struct { + messages []ai.ChatMessage +} + +func (s *stubChatService) CreateSession(ctx context.Context) (*ai.ChatSession, error) { + return &ai.ChatSession{ID: "session-1"}, nil +} + +func (s *stubChatService) ExecuteStream(ctx context.Context, req ai.ChatExecuteRequest, callback 
ai.ChatStreamCallback) error { + return nil +} + +func (s *stubChatService) GetMessages(ctx context.Context, sessionID string) ([]ai.ChatMessage, error) { + return s.messages, nil +} + +func (s *stubChatService) DeleteSession(ctx context.Context, sessionID string) error { + return nil +} + +func TestFindingsStoreWrapper_GetAndUpdate(t *testing.T) { + store := ai.NewFindingsStore() + store.Add(&ai.Finding{ + ID: "finding-1", + Severity: ai.FindingSeverityWarning, + Category: ai.FindingCategoryPerformance, + ResourceID: "res-1", + ResourceName: "res-1", + ResourceType: "host", + Title: "title", + Description: "desc", + }) + + wrapper := &findingsStoreWrapper{store: store} + found := wrapper.Get("finding-1") + if found == nil || found.GetID() != "finding-1" { + t.Fatalf("expected finding to be returned") + } + if wrapper.Get("missing") != nil { + t.Fatalf("expected missing finding to return nil") + } + + updated := wrapper.UpdateInvestigation("finding-1", "session-1", "running", "outcome", nil, 2) + if !updated { + t.Fatalf("expected UpdateInvestigation to return true") + } + got := store.Get("finding-1") + if got.InvestigationOutcome != "outcome" || got.InvestigationStatus != "running" || got.InvestigationAttempts != 2 { + t.Fatalf("unexpected investigation update: %+v", got) + } + + nilWrapper := &findingsStoreWrapper{store: nil} + if nilWrapper.Get("finding-1") != nil { + t.Fatalf("expected nil store to return nil") + } + if nilWrapper.UpdateInvestigation("finding-1", "session-1", "running", "outcome", nil, 1) { + t.Fatalf("expected nil store update to return false") + } +} + +func TestHandleClearAllFindings(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service to be initialized") + } + + findings := patrol.GetFindings() + findings.Add(&ai.Finding{ + ID: "finding-1", + Severity: ai.FindingSeverityWarning, + Category: ai.FindingCategoryPerformance, + ResourceID: "res-1", + ResourceName: "res-1", + ResourceType: "host", + Title: "title", + Description: "desc", + }) + findings.Add(&ai.Finding{ + ID: "finding-2", + Severity: ai.FindingSeverityCritical, + Category: ai.FindingCategorySecurity, + ResourceID: "res-2", + ResourceName: "res-2", + ResourceType: "host", + Title: "title", + Description: "desc", + }) + + req := httptest.NewRequest(http.MethodDelete, "/api/ai/patrol/findings?confirm=true", nil) + rec := httptest.NewRecorder() + handler.HandleClearAllFindings(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["cleared"].(float64) != 2 { + t.Fatalf("expected 2 findings cleared, got %v", resp["cleared"]) + } + if got := patrol.GetFindings().GetAll(nil); len(got) != 0 { + t.Fatalf("expected findings store to be empty") + } +} + +func TestHandleListApprovals(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + prevStore := approval.GetStore() + t.Cleanup(func() { approval.SetStore(prevStore) }) + + store, 
err := approval.NewStore(approval.StoreConfig{ + DataDir: tmp, + DisablePersistence: true, + }) + if err != nil { + t.Fatalf("failed to create approval store: %v", err) + } + approval.SetStore(store) + + if err := store.CreateApproval(&approval.ApprovalRequest{Command: "echo ok"}); err != nil { + t.Fatalf("failed to create approval: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/ai/approvals", nil) + rec := httptest.NewRecorder() + handler.HandleListApprovals(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp struct { + Approvals []approval.ApprovalRequest `json:"approvals"` + Stats map[string]int `json:"stats"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.Approvals) != 1 { + t.Fatalf("expected 1 approval, got %d", len(resp.Approvals)) + } + if resp.Stats["pending"] != 1 { + t.Fatalf("expected pending approvals to be 1, got %d", resp.Stats["pending"]) + } +} + +func TestHandlePatrolAutonomyGetAndUpdate(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + + aiCfg := config.NewDefaultAIConfig() + aiCfg.PatrolAutonomyLevel = config.PatrolAutonomyApproval + aiCfg.PatrolInvestigationBudget = 8 + aiCfg.PatrolInvestigationTimeoutSec = 120 + aiCfg.PatrolCriticalRequireApproval = true + if err := persistence.SaveAIConfig(*aiCfg); err != nil { + t.Fatalf("SaveAIConfig: %v", err) + } + + handler := newTestAISettingsHandler(cfg, persistence, nil) + + getReq := httptest.NewRequest(http.MethodGet, "/api/ai/patrol/autonomy", nil) + getRec := httptest.NewRecorder() + handler.HandleGetPatrolAutonomy(getRec, getReq) + + if getRec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", getRec.Code, getRec.Body.String()) + } + + var getResp PatrolAutonomySettings + if err := json.Unmarshal(getRec.Body.Bytes(), &getResp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if getResp.AutonomyLevel != config.PatrolAutonomyApproval || getResp.InvestigationBudget != 8 || getResp.InvestigationTimeoutSec != 120 { + t.Fatalf("unexpected autonomy settings: %+v", getResp) + } + + update := PatrolAutonomySettings{ + AutonomyLevel: config.PatrolAutonomyFull, + InvestigationBudget: 3, + InvestigationTimeoutSec: 10, + CriticalRequireApproval: false, + } + body, _ := json.Marshal(update) + updateReq := httptest.NewRequest(http.MethodPut, "/api/ai/patrol/autonomy", strings.NewReader(string(body))) + updateRec := httptest.NewRecorder() + handler.HandleUpdatePatrolAutonomy(updateRec, updateReq) + + if updateRec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", updateRec.Code, updateRec.Body.String()) + } + + var updateResp map[string]interface{} + if err := json.Unmarshal(updateRec.Body.Bytes(), &updateResp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + settings := updateResp["settings"].(map[string]interface{}) + if settings["autonomy_level"] != config.PatrolAutonomyFull { + t.Fatalf("unexpected autonomy level %v", settings["autonomy_level"]) + } + if settings["investigation_budget"].(float64) != 5 { + t.Fatalf("expected clamped budget to 5, got %v", settings["investigation_budget"]) + } + if settings["investigation_timeout_sec"].(float64) != 60 { + t.Fatalf("expected clamped timeout to 60, got %v", settings["investigation_timeout_sec"]) + } + + loaded, err := persistence.LoadAIConfig() + 
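+ // Verify the clamped values were persisted to disk, not just echoed in the HTTP response.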
if err != nil { + t.Fatalf("LoadAIConfig: %v", err) + } + if loaded.PatrolAutonomyLevel != config.PatrolAutonomyFull || loaded.PatrolInvestigationBudget != 5 || loaded.PatrolInvestigationTimeoutSec != 60 { + t.Fatalf("unexpected persisted settings: %+v", loaded) + } +} + +func TestHandleGetInvestigation(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + session := &ai.InvestigationSession{ + ID: "inv-1", + FindingID: "finding-1", + SessionID: "session-1", + Status: "completed", + StartedAt: time.Now(), + } + orchestrator := &stubInvestigationOrchestrator{session: session} + patrol.SetInvestigationOrchestrator(orchestrator) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/findings/finding-1/investigation", nil) + rec := httptest.NewRecorder() + handler.HandleGetInvestigation(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp ai.InvestigationSession + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.ID != "inv-1" || resp.FindingID != "finding-1" { + t.Fatalf("unexpected investigation response: %+v", resp) + } +} + +func TestHandleReapproveInvestigationFix(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + prevStore := approval.GetStore() + t.Cleanup(func() { approval.SetStore(prevStore) }) + + store, err := approval.NewStore(approval.StoreConfig{ + DataDir: tmp, + DisablePersistence: true, + }) + if err != nil { + t.Fatalf("failed to create approval store: %v", err) + } + approval.SetStore(store) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + session := &ai.InvestigationSession{ + ID: "inv-1", + FindingID: "finding-1", + SessionID: "session-1", + Status: "completed", + StartedAt: time.Now(), + ProposedFix: &ai.InvestigationFix{ + ID: "fix-1", + Description: "Restart service", + Commands: []string{"systemctl restart foo"}, + }, + } + orchestrator := &stubInvestigationOrchestrator{session: session} + patrol.SetInvestigationOrchestrator(orchestrator) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/findings/finding-1/reapprove", nil) + rec := httptest.NewRecorder() + handler.HandleReapproveInvestigationFix(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + approvalID := resp["approval_id"] + if approvalID == "" { + t.Fatalf("expected approval_id in response") + } + if _, ok := store.GetApproval(approvalID); !ok { + t.Fatalf("expected approval %s to exist", approvalID) + } +} + +func TestHandleGetInvestigationMessages(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := 
newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + svc.SetChatService(&stubChatService{ + messages: []ai.ChatMessage{ + {ID: "msg-1", Role: "assistant", Content: "hello"}, + }, + }) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + session := &ai.InvestigationSession{ + ID: "inv-1", + FindingID: "finding-1", + SessionID: "session-1", + Status: "completed", + StartedAt: time.Now(), + } + orchestrator := &stubInvestigationOrchestrator{session: session} + patrol.SetInvestigationOrchestrator(orchestrator) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/findings/finding-1/investigation/messages", nil) + rec := httptest.NewRecorder() + handler.HandleGetInvestigationMessages(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["session_id"] != "session-1" { + t.Fatalf("unexpected session_id %v", resp["session_id"]) + } + msgs := resp["messages"].([]interface{}) + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } +} + +func TestHandleReinvestigateFinding(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + aiCfg := config.NewDefaultAIConfig() + aiCfg.PatrolAutonomyLevel = config.PatrolAutonomyApproval + if err := persistence.SaveAIConfig(*aiCfg); err != nil { + t.Fatalf("SaveAIConfig: %v", err) + } + if err := svc.LoadConfig(); err != nil { + t.Fatalf("LoadConfig: %v", err) + } + svc.SetStateProvider(&MockStateProvider{}) + + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + callCh := make(chan reinvestigateCall, 1) + orchestrator := &stubInvestigationOrchestrator{reinvestigateCh: callCh} + patrol.SetInvestigationOrchestrator(orchestrator) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/findings/finding-1/reinvestigate", nil) + rec := httptest.NewRecorder() + handler.HandleReinvestigateFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + select { + case call := <-callCh: + if call.findingID != "finding-1" || call.autonomy != config.PatrolAutonomyApproval { + t.Fatalf("unexpected reinvestigation call: %+v", call) + } + case <-time.After(2 * time.Second): + t.Fatalf("expected reinvestigation to be triggered") + } +} + +func TestExecuteInvestigationFix_MCPTool(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + findingID := "finding-1" + findings := patrol.GetFindings() + findings.Add(&ai.Finding{ + ID: findingID, + Severity: ai.FindingSeverityWarning, + Category: ai.FindingCategoryPerformance, + ResourceID: "res-1", + ResourceName: "res-1", + ResourceType: "host", + Title: "title", + Description: "desc", + }) + + store := 
investigation.NewStore("") + session := store.Create(findingID, "session-1") + session.ProposedFix = &investigation.Fix{ + ID: "fix-1", + Description: "Get capabilities", + Commands: []string{"pulse_get_capabilities()"}, + } + if !store.Update(session) { + t.Fatalf("failed to update investigation session") + } + handler.investigationStores = map[string]*investigation.Store{"default": store} + + chatSvc := chat.NewService(chat.Config{AIConfig: config.NewDefaultAIConfig()}) + handler.chatHandler = &AIHandler{legacyService: chatSvc} + + req := httptest.NewRequest(http.MethodPost, "/api/ai/approvals/exec", nil) + rec := httptest.NewRecorder() + handler.executeInvestigationFix(rec, req, &approval.ApprovalRequest{ + ID: "approval-1", + ToolID: "investigation_fix", + Command: "pulse_get_capabilities()", + TargetID: findingID, + }) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["success"] != true { + t.Fatalf("expected success response, got %v", resp["success"]) + } + + updatedFinding := findings.Get(findingID) + if updatedFinding == nil || updatedFinding.InvestigationOutcome != string(investigation.OutcomeFixExecuted) { + t.Fatalf("unexpected finding outcome: %+v", updatedFinding) + } + + updatedSession := store.Get(session.ID) + if updatedSession == nil || updatedSession.Outcome != investigation.OutcomeFixExecuted { + t.Fatalf("unexpected investigation outcome: %+v", updatedSession) + } +} + +func wsURLForHTTP(url string) string { + if strings.HasPrefix(url, "https://") { + return "wss://" + strings.TrimPrefix(url, "https://") + } + return "ws://" + strings.TrimPrefix(url, "http://") +} + +func registerAgent(t *testing.T, url, agentID, hostname string) *websocket.Conn { + t.Helper() + + conn, _, err := websocket.DefaultDialer.Dial(wsURLForHTTP(url), nil) + if err != nil { + t.Fatalf("failed to dial websocket: %v", err) + } + + msg := agentexec.Message{ + Type: agentexec.MsgTypeAgentRegister, + Timestamp: time.Now(), + Payload: agentexec.AgentRegisterPayload{ + AgentID: agentID, + Hostname: hostname, + Version: "1.0.0", + Platform: "linux", + Token: "ok", + }, + } + if err := conn.WriteJSON(msg); err != nil { + conn.Close() + t.Fatalf("failed to write registration message: %v", err) + } + + _, raw, err := conn.ReadMessage() + if err != nil { + conn.Close() + t.Fatalf("failed to read registration response: %v", err) + } + + var resp agentexec.Message + if err := json.Unmarshal(raw, &resp); err != nil { + conn.Close() + t.Fatalf("failed to decode registration response: %v", err) + } + payloadBytes, _ := json.Marshal(resp.Payload) + var reg agentexec.RegisteredPayload + if err := json.Unmarshal(payloadBytes, ®); err != nil { + conn.Close() + t.Fatalf("failed to decode registration payload: %v", err) + } + if !reg.Success { + conn.Close() + t.Fatalf("registration failed: %s", reg.Message) + } + + return conn +} + +func TestFindAgentForTarget(t *testing.T) { + server := agentexec.NewServer(func(string) bool { return true }) + ts := httptest.NewServer(http.HandlerFunc(server.HandleWebSocket)) + defer ts.Close() + + conn1 := registerAgent(t, ts.URL, "agent-1", "host-a") + defer conn1.Close() + conn2 := registerAgent(t, ts.URL, "agent-2", "host-b") + defer conn2.Close() + + handler := &AISettingsHandler{agentServer: server} + + if got := handler.findAgentForTarget("host-a"); 
got != "agent-1" { + t.Fatalf("expected agent-1, got %q", got) + } + if got := handler.findAgentForTarget("agent-2"); got != "agent-2" { + t.Fatalf("expected agent-2, got %q", got) + } + if got := handler.findAgentForTarget(""); got != "" { + t.Fatalf("expected empty agent when multiple connected, got %q", got) + } +} diff --git a/internal/api/ai_handlers_oauth_test.go b/internal/api/ai_handlers_oauth_test.go new file mode 100644 index 000000000..3e7069a60 --- /dev/null +++ b/internal/api/ai_handlers_oauth_test.go @@ -0,0 +1,161 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/providers" +) + +func resetOAuthSessions() { + oauthSessionsMu.Lock() + oauthSessions = make(map[string]*providers.OAuthSession) + oauthSessionsMu.Unlock() +} + +func TestHandleOAuthStart(t *testing.T) { + resetOAuthSessions() + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/start", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthStart(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + + var resp map[string]string + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["auth_url"] == "" || resp["state"] == "" { + t.Fatalf("expected auth_url and state in response") + } + if !strings.Contains(resp["auth_url"], "claude.ai/oauth/authorize") { + t.Fatalf("expected auth_url to contain authorize endpoint") + } + + oauthSessionsMu.Lock() + delete(oauthSessions, resp["state"]) + oauthSessionsMu.Unlock() +} + +func TestHandleOAuthStart_MethodNotAllowed(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodPut, "/api/ai/oauth/start", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthStart(rr, req) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected 405, got %d", rr.Code) + } +} + +func TestHandleOAuthExchange_MethodNotAllowed(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/exchange", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthExchange(rr, req) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected 405, got %d", rr.Code) + } +} + +func TestHandleOAuthExchange_InvalidBody(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodPost, "/api/ai/oauth/exchange", strings.NewReader("{")) + rr := httptest.NewRecorder() + + handler.HandleOAuthExchange(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rr.Code) + } +} + +func TestHandleOAuthExchange_MissingFields(t *testing.T) { + handler := &AISettingsHandler{} + body := []byte(`{"code":"","state":""}`) + req := httptest.NewRequest(http.MethodPost, "/api/ai/oauth/exchange", bytes.NewReader(body)) + rr := httptest.NewRecorder() + + handler.HandleOAuthExchange(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rr.Code) + } +} + +func TestHandleOAuthExchange_UnknownState(t *testing.T) { + resetOAuthSessions() + handler := &AISettingsHandler{} + body := []byte(`{"code":"code123","state":"missing"}`) + req := httptest.NewRequest(http.MethodPost, "/api/ai/oauth/exchange", bytes.NewReader(body)) + rr := httptest.NewRecorder() + + handler.HandleOAuthExchange(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rr.Code) + } +} + +func 
TestHandleOAuthCallback_ErrorParam(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/callback?error=access_denied&error_description=no", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthCallback(rr, req) + if rr.Code != http.StatusTemporaryRedirect { + t.Fatalf("expected 307, got %d", rr.Code) + } + location := rr.Header().Get("Location") + if !strings.Contains(location, "ai_oauth_error=access_denied") { + t.Fatalf("expected redirect to include error, got %q", location) + } +} + +func TestHandleOAuthCallback_MissingParams(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/callback?code=abc", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthCallback(rr, req) + if rr.Code != http.StatusTemporaryRedirect { + t.Fatalf("expected 307, got %d", rr.Code) + } + location := rr.Header().Get("Location") + if !strings.Contains(location, "ai_oauth_error=missing_params") { + t.Fatalf("expected missing_params redirect, got %q", location) + } +} + +func TestHandleOAuthCallback_InvalidState(t *testing.T) { + resetOAuthSessions() + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/callback?code=abc&state=missing", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthCallback(rr, req) + if rr.Code != http.StatusTemporaryRedirect { + t.Fatalf("expected 307, got %d", rr.Code) + } + location := rr.Header().Get("Location") + if !strings.Contains(location, "ai_oauth_error=invalid_state") { + t.Fatalf("expected invalid_state redirect, got %q", location) + } +} + +func TestHandleOAuthDisconnect_MethodNotAllowed(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodGet, "/api/ai/oauth/disconnect", nil) + rr := httptest.NewRecorder() + + handler.HandleOAuthDisconnect(rr, req) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected 405, got %d", rr.Code) + } +} diff --git a/internal/api/ai_handlers_patrol_actions_additional_test.go b/internal/api/ai_handlers_patrol_actions_additional_test.go new file mode 100644 index 000000000..db3a48f44 --- /dev/null +++ b/internal/api/ai_handlers_patrol_actions_additional_test.go @@ -0,0 +1,347 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/learning" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/unified" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func setupAIHandlerWithPatrol(t *testing.T) (*AISettingsHandler, *ai.PatrolService, *unified.UnifiedStore, *learning.LearningStore) { + t.Helper() + + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + handler.legacyAIService.SetStateProvider(&stubStateProvider{}) + + patrol := handler.legacyAIService.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service to be initialized") + } + + unifiedStore := unified.NewUnifiedStore(unified.DefaultAlertToFindingConfig()) + handler.SetUnifiedStore(unifiedStore) + + learningStore := learning.NewLearningStore(learning.LearningStoreConfig{}) + handler.SetLearningStore(learningStore) + + return handler, patrol, unifiedStore, learningStore +} + +func addPatrolFinding(t *testing.T, patrol *ai.PatrolService, id string, detectedAt 
time.Time) *ai.Finding { + t.Helper() + + finding := &ai.Finding{ + ID: id, + Key: "key-" + id, + Severity: ai.FindingSeverityWarning, + Category: ai.FindingCategoryPerformance, + Title: "CPU spike", + Description: "CPU high", + ResourceID: "vm-1", + ResourceName: "vm-1", + ResourceType: "vm", + DetectedAt: detectedAt, + LastSeenAt: detectedAt, + } + patrol.GetFindings().Add(finding) + return finding +} + +func addUnifiedFinding(store *unified.UnifiedStore, id string, detectedAt time.Time) *unified.UnifiedFinding { + finding := &unified.UnifiedFinding{ + ID: id, + Source: unified.SourceAIPatrol, + Severity: unified.SeverityWarning, + Category: unified.CategoryPerformance, + ResourceID: "vm-1", + ResourceName: "vm-1", + ResourceType: "vm", + Title: "CPU spike", + Description: "CPU high", + DetectedAt: detectedAt, + } + store.AddFromAI(finding) + return finding +} + +func TestHandleAcknowledgeFinding_PatrolAndUnified(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-2 * time.Hour) + addPatrolFinding(t, patrol, "finding-ack", detectedAt) + addUnifiedFinding(unifiedStore, "finding-ack", detectedAt) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/acknowledge", strings.NewReader(`{"finding_id":"finding-ack"}`)) + rec := httptest.NewRecorder() + + handler.HandleAcknowledgeFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + patrolFinding := patrol.GetFindings().Get("finding-ack") + if patrolFinding == nil || patrolFinding.AcknowledgedAt == nil { + t.Fatalf("expected patrol finding to be acknowledged") + } + + unifiedFinding := unifiedStore.Get("finding-ack") + if unifiedFinding == nil || unifiedFinding.AcknowledgedAt == nil { + t.Fatalf("expected unified finding to be acknowledged") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleAcknowledgeFinding_UnifiedOnly(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-1 * time.Hour) + addUnifiedFinding(unifiedStore, "finding-unified-only", detectedAt) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/acknowledge", strings.NewReader(`{"finding_id":"finding-unified-only"}`)) + rec := httptest.NewRecorder() + + handler.HandleAcknowledgeFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + if patrol.GetFindings().Get("finding-unified-only") != nil { + t.Fatalf("expected patrol finding to be absent") + } + + unifiedFinding := unifiedStore.Get("finding-unified-only") + if unifiedFinding == nil || unifiedFinding.AcknowledgedAt == nil { + t.Fatalf("expected unified finding to be acknowledged") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleSnoozeFinding_CapsDuration(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-30 * time.Minute) + addPatrolFinding(t, patrol, "finding-snooze", detectedAt) + addUnifiedFinding(unifiedStore, "finding-snooze", detectedAt) + + body := `{"finding_id":"finding-snooze","duration_hours":200}` + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/snooze", strings.NewReader(body)) + rec := 
httptest.NewRecorder() + + handler.HandleSnoozeFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + message, _ := resp["message"].(string) + if !strings.Contains(message, "168") { + t.Fatalf("expected capped duration in message, got %q", message) + } + + patrolFinding := patrol.GetFindings().Get("finding-snooze") + if patrolFinding == nil || patrolFinding.SnoozedUntil == nil { + t.Fatalf("expected snoozed patrol finding") + } + if patrolFinding.SnoozedUntil.Before(time.Now().Add(167 * time.Hour)) { + t.Fatalf("snooze duration not applied") + } + + unifiedFinding := unifiedStore.Get("finding-snooze") + if unifiedFinding == nil || unifiedFinding.SnoozedUntil == nil { + t.Fatalf("expected snoozed unified finding") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleResolveFinding_SetsResolved(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-2 * time.Hour) + addPatrolFinding(t, patrol, "finding-resolve", detectedAt) + addUnifiedFinding(unifiedStore, "finding-resolve", detectedAt) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/resolve", strings.NewReader(`{"finding_id":"finding-resolve"}`)) + rec := httptest.NewRecorder() + + handler.HandleResolveFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + patrolFinding := patrol.GetFindings().Get("finding-resolve") + if patrolFinding == nil || patrolFinding.ResolvedAt == nil { + t.Fatalf("expected patrol finding to be resolved") + } + + unifiedFinding := unifiedStore.Get("finding-resolve") + if unifiedFinding == nil || unifiedFinding.ResolvedAt == nil { + t.Fatalf("expected unified finding to be resolved") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleDismissFinding_ValidReason(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-3 * time.Hour) + addPatrolFinding(t, patrol, "finding-dismiss", detectedAt) + addUnifiedFinding(unifiedStore, "finding-dismiss", detectedAt) + + payload := map[string]string{ + "finding_id": "finding-dismiss", + "reason": "expected_behavior", + "note": "known load test", + } + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/dismiss", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleDismissFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + patrolFinding := patrol.GetFindings().Get("finding-dismiss") + if patrolFinding == nil || patrolFinding.DismissedReason != "expected_behavior" { + t.Fatalf("expected patrol finding to be dismissed") + } + if patrolFinding.UserNote != "known load test" { + t.Fatalf("expected patrol note to be recorded") + } + + unifiedFinding := unifiedStore.Get("finding-dismiss") + if unifiedFinding == nil || unifiedFinding.DismissedReason != "expected_behavior" { + t.Fatalf("expected unified finding to be dismissed") + } + if unifiedFinding.UserNote != "known load test" { + t.Fatalf("expected unified note 
to be recorded") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleSuppressFinding_SetsSuppressed(t *testing.T) { + handler, patrol, unifiedStore, learningStore := setupAIHandlerWithPatrol(t) + + detectedAt := time.Now().Add(-45 * time.Minute) + addPatrolFinding(t, patrol, "finding-suppress", detectedAt) + addUnifiedFinding(unifiedStore, "finding-suppress", detectedAt) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/suppress", strings.NewReader(`{"finding_id":"finding-suppress"}`)) + rec := httptest.NewRecorder() + + handler.HandleSuppressFinding(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + patrolFinding := patrol.GetFindings().Get("finding-suppress") + if patrolFinding == nil || !patrolFinding.Suppressed { + t.Fatalf("expected patrol finding to be suppressed") + } + if patrolFinding.DismissedReason != "suppressed" { + t.Fatalf("expected patrol dismissal reason to be set") + } + + unifiedFinding := unifiedStore.Get("finding-suppress") + if unifiedFinding == nil || !unifiedFinding.Suppressed { + t.Fatalf("expected unified finding to be suppressed") + } + if unifiedFinding.DismissedReason != "not_an_issue" { + t.Fatalf("expected unified dismissal reason to be not_an_issue") + } + + stats := learningStore.GetStatistics() + if stats.TotalFeedbackRecords != 1 { + t.Fatalf("feedback records = %d, want 1", stats.TotalFeedbackRecords) + } +} + +func TestHandleGetFindingsHistory_StartTimeFilter(t *testing.T) { + handler, patrol, _, _ := setupAIHandlerWithPatrol(t) + + oldTime := time.Now().Add(-3 * time.Hour) + recentTime := time.Now().Add(-30 * time.Minute) + addPatrolFinding(t, patrol, "finding-old", oldTime) + addPatrolFinding(t, patrol, "finding-recent", recentTime) + + startTime := time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339) + req := httptest.NewRequest(http.MethodGet, "/api/ai/patrol/history?start_time="+startTime, nil) + rec := httptest.NewRecorder() + + handler.HandleGetFindingsHistory(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var findings []ai.Finding + if err := json.Unmarshal(rec.Body.Bytes(), &findings); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(findings) != 1 { + t.Fatalf("findings = %d, want 1", len(findings)) + } + if findings[0].ID != "finding-recent" { + t.Fatalf("expected recent finding, got %s", findings[0].ID) + } +} + +func TestHandleForcePatrol_ConfigDisabled(t *testing.T) { + handler, patrol, _, _ := setupAIHandlerWithPatrol(t) + + cfg := ai.DefaultPatrolConfig() + cfg.Enabled = false + patrol.SetConfig(cfg) + + req := httptest.NewRequest(http.MethodPost, "/api/ai/patrol/run?deep=true", nil) + rec := httptest.NewRecorder() + + handler.HandleForcePatrol(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if !strings.Contains(rec.Body.String(), "Triggered patrol run") { + t.Fatalf("expected success message") + } +} diff --git a/internal/api/ai_handlers_setters_additional_test.go b/internal/api/ai_handlers_setters_additional_test.go new file mode 100644 index 000000000..58225e14b --- /dev/null +++ b/internal/api/ai_handlers_setters_additional_test.go @@ -0,0 +1,59 @@ +package api + +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/unified" + 
"github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/metrics" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" +) + +func TestAISettingsHandler_SettersAndGetters(t *testing.T) { + handler := &AISettingsHandler{} + + mtp := config.NewMultiTenantPersistence(t.TempDir()) + handler.SetMultiTenantPersistence(mtp) + if handler.mtPersistence != mtp { + t.Fatalf("mtPersistence not set") + } + + mtm := &monitoring.MultiTenantMonitor{} + handler.SetMultiTenantMonitor(mtm) + if handler.mtMonitor != mtm { + t.Fatalf("mtMonitor not set") + } + + store := unified.NewUnifiedStore(unified.DefaultAlertToFindingConfig()) + handler.SetUnifiedStore(store) + if handler.GetUnifiedStore() != store { + t.Fatalf("GetUnifiedStore returned unexpected store") + } + + bridge := unified.NewAlertBridge(store, unified.DefaultBridgeConfig()) + handler.SetAlertBridge(bridge) + if handler.GetAlertBridge() != bridge { + t.Fatalf("GetAlertBridge returned unexpected bridge") + } + + triggerManager := ai.NewTriggerManager(ai.DefaultTriggerManagerConfig()) + handler.SetTriggerManager(triggerManager) + if handler.GetTriggerManager() != triggerManager { + t.Fatalf("GetTriggerManager returned unexpected manager") + } + + coordinator := ai.NewIncidentCoordinator(ai.IncidentCoordinatorConfig{}) + handler.SetIncidentCoordinator(coordinator) + if handler.GetIncidentCoordinator() != coordinator { + t.Fatalf("GetIncidentCoordinator returned unexpected coordinator") + } + + recorder := &metrics.IncidentRecorder{} + handler.SetIncidentRecorder(recorder) + if handler.GetIncidentRecorder() != recorder { + t.Fatalf("GetIncidentRecorder returned unexpected recorder") + } + + handler.WireOrchestratorAfterChatStart() +} diff --git a/internal/api/ai_handlers_stream_test.go b/internal/api/ai_handlers_stream_test.go new file mode 100644 index 000000000..1b87312cb --- /dev/null +++ b/internal/api/ai_handlers_stream_test.go @@ -0,0 +1,206 @@ +package api + +import ( + "bytes" + "encoding/json" + "os" + "strings" + "testing" + + "net/http" + "net/http/httptest" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +type stubLicenseChecker struct { + allow bool +} + +func (s stubLicenseChecker) HasFeature(feature string) bool { + return s.allow +} + +func (s stubLicenseChecker) GetLicenseStateString() (string, bool) { + if s.allow { + return "active", true + } + return "expired", false +} + +func withEnv(t *testing.T, key, value string, fn func()) { + t.Helper() + old := os.Getenv(key) + if err := os.Setenv(key, value); err != nil { + t.Fatalf("setenv failed: %v", err) + } + defer func() { + _ = os.Setenv(key, old) + }() + fn() +} + +func newTestAISettingsHandlerWithService(t *testing.T) *AISettingsHandler { + t.Helper() + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + + handler := NewAISettingsHandler(nil, nil, nil) + handler.legacyConfig = cfg + handler.legacyPersistence = persistence + handler.legacyAIService = ai.NewService(persistence, nil) + return handler +} + +func TestHandleExecuteStream_LicenseRequired(t *testing.T) { + withEnv(t, "PULSE_MOCK_MODE", "true", func() { + handler := newTestAISettingsHandlerWithService(t) + handler.legacyAIService.SetLicenseChecker(stubLicenseChecker{allow: false}) + + body := `{"prompt":"hi","use_case":"autofix"}` + req := httptest.NewRequest(http.MethodPost, 
"/api/ai/execute/stream", strings.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleExecuteStream(rec, req) + + if rec.Code != http.StatusPaymentRequired { + t.Fatalf("expected payment required, got %d", rec.Code) + } + if !strings.Contains(rec.Body.String(), "license_required") { + t.Fatalf("expected license error body") + } + }) +} + +func TestHandleExecuteStream_PromptRequired(t *testing.T) { + withEnv(t, "PULSE_MOCK_MODE", "true", func() { + handler := newTestAISettingsHandlerWithService(t) + + body := `{"prompt":""}` + req := httptest.NewRequest(http.MethodPost, "/api/ai/execute/stream", strings.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleExecuteStream(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected bad request, got %d", rec.Code) + } + }) +} + +func TestHandleExecuteStream_Success(t *testing.T) { + withEnv(t, "PULSE_MOCK_MODE", "true", func() { + handler := newTestAISettingsHandlerWithService(t) + + body := `{"prompt":"hello"}` + req := httptest.NewRequest(http.MethodPost, "/api/ai/execute/stream", strings.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleExecuteStream(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected OK, got %d", rec.Code) + } + if !strings.Contains(rec.Header().Get("Content-Type"), "text/event-stream") { + t.Fatalf("expected SSE content type") + } + }) +} + +func TestHandleExportGuestKnowledge(t *testing.T) { + handler := newTestAISettingsHandlerWithService(t) + if err := handler.legacyAIService.SaveGuestNote("guest-1", "VM 1", "vm", "ops", "Note", "Content"); err != nil { + t.Fatalf("save note error: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/api/ai/knowledge/export?guest_id=guest-1", nil) + rec := httptest.NewRecorder() + handler.HandleExportGuestKnowledge(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected OK, got %d", rec.Code) + } + if !strings.Contains(rec.Header().Get("Content-Disposition"), "guest-1") { + t.Fatalf("expected content disposition with guest id") + } + + var exported knowledge.GuestKnowledge + if err := json.NewDecoder(rec.Body).Decode(&exported); err != nil { + t.Fatalf("decode error: %v", err) + } + if len(exported.Notes) == 0 { + t.Fatalf("expected exported notes") + } + + missingReq := httptest.NewRequest(http.MethodGet, "/api/ai/knowledge/export", nil) + missingRec := httptest.NewRecorder() + handler.HandleExportGuestKnowledge(missingRec, missingReq) + if missingRec.Code != http.StatusBadRequest { + t.Fatalf("expected bad request for missing guest id") + } +} + +func TestHandleImportGuestKnowledge(t *testing.T) { + handler := newTestAISettingsHandlerWithService(t) + if err := handler.legacyAIService.SaveGuestNote("guest-1", "VM 1", "vm", "ops", "Old", "Old content"); err != nil { + t.Fatalf("save note error: %v", err) + } + + invalidReq := httptest.NewRequest(http.MethodPost, "/api/ai/knowledge/import", strings.NewReader("{bad")) + invalidRec := httptest.NewRecorder() + handler.HandleImportGuestKnowledge(invalidRec, invalidReq) + if invalidRec.Code != http.StatusBadRequest { + t.Fatalf("expected bad request for invalid json") + } + + methodReq := httptest.NewRequest(http.MethodGet, "/api/ai/knowledge/import", nil) + methodRec := httptest.NewRecorder() + handler.HandleImportGuestKnowledge(methodRec, methodReq) + if methodRec.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed") + } + + emptyReq := httptest.NewRequest(http.MethodPost, "/api/ai/knowledge/import", 
strings.NewReader(`{"guest_id":""}`)) + emptyRec := httptest.NewRecorder() + handler.HandleImportGuestKnowledge(emptyRec, emptyReq) + if emptyRec.Code != http.StatusBadRequest { + t.Fatalf("expected bad request for missing guest id") + } + + noNotesReq := httptest.NewRequest(http.MethodPost, "/api/ai/knowledge/import", strings.NewReader(`{"guest_id":"guest-1","notes":[]}`)) + noNotesRec := httptest.NewRecorder() + handler.HandleImportGuestKnowledge(noNotesRec, noNotesReq) + if noNotesRec.Code != http.StatusBadRequest { + t.Fatalf("expected bad request for empty notes") + } + + payload := map[string]interface{}{ + "guest_id": "guest-1", + "guest_name": "VM 1", + "guest_type": "vm", + "merge": false, + "notes": []map[string]string{ + {"category": "ops", "title": "New", "content": "New content"}, + {"category": "", "title": "Skip", "content": "Bad"}, + }, + } + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/api/ai/knowledge/import", bytes.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleImportGuestKnowledge(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected OK, got %d", rec.Code) + } + if !strings.Contains(rec.Body.String(), "\"imported\":1") { + t.Fatalf("expected import count") + } + + knowledge, err := handler.legacyAIService.GetGuestKnowledge("guest-1") + if err != nil { + t.Fatalf("get knowledge error: %v", err) + } + if len(knowledge.Notes) != 1 { + t.Fatalf("expected only imported notes, got %d", len(knowledge.Notes)) + } +} diff --git a/internal/api/ai_intelligence_handlers_additional_test.go b/internal/api/ai_intelligence_handlers_additional_test.go new file mode 100644 index 000000000..71d2d5092 --- /dev/null +++ b/internal/api/ai_intelligence_handlers_additional_test.go @@ -0,0 +1,196 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/forecast" +) + +func decodeJSON(t *testing.T, rr *httptest.ResponseRecorder) map[string]interface{} { + t.Helper() + var payload map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode JSON: %v", err) + } + return payload +} + +func TestAIIntelligenceHandlers_NoServices(t *testing.T) { + handler := &AISettingsHandler{} + + t.Run("anomalies-disabled", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence/anomalies", nil) + rr := httptest.NewRecorder() + + handler.HandleGetAnomalies(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Pulse Patrol is not enabled" { + t.Fatalf("message = %v, want Pulse Patrol is not enabled", payload["message"]) + } + }) + + t.Run("learning-disabled", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence/learning", nil) + rr := httptest.NewRecorder() + + handler.HandleGetLearningStatus(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["status"] != "ai_disabled" { + t.Fatalf("status = %v, want ai_disabled", payload["status"]) + } + }) + + t.Run("learning-preferences-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/learning/preferences", nil) + rr := httptest.NewRecorder() + + handler.HandleGetLearningPreferences(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := 
decodeJSON(t, rr) + if payload["message"] != "Learning store not available" { + t.Fatalf("message = %v, want Learning store not available", payload["message"]) + } + }) + + t.Run("unified-findings-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/unified/findings", nil) + rr := httptest.NewRecorder() + + handler.HandleGetUnifiedFindings(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Unified store not available" { + t.Fatalf("message = %v, want Unified store not available", payload["message"]) + } + }) + + t.Run("proxmox-events-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/proxmox/events", nil) + rr := httptest.NewRecorder() + + handler.HandleGetProxmoxEvents(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Proxmox event correlator not available" { + t.Fatalf("message = %v, want Proxmox event correlator not available", payload["message"]) + } + }) + + t.Run("proxmox-correlations-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/proxmox/correlations", nil) + rr := httptest.NewRecorder() + + handler.HandleGetProxmoxCorrelations(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Proxmox event correlator not available" { + t.Fatalf("message = %v, want Proxmox event correlator not available", payload["message"]) + } + }) + + t.Run("remediation-plans-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/remediation/plans", nil) + rr := httptest.NewRecorder() + + handler.HandleGetRemediationPlans(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Remediation engine not available" { + t.Fatalf("message = %v, want Remediation engine not available", payload["message"]) + } + }) + + t.Run("remediation-plan-missing-engine", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/remediation/plans/1?plan_id=1", nil) + rr := httptest.NewRecorder() + + handler.HandleGetRemediationPlan(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rr.Code) + } + }) + + t.Run("approve-remediation-plan-missing-engine", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/ai/remediation/plans/1/approve", nil) + rr := httptest.NewRecorder() + + handler.HandleApproveRemediationPlan(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503", rr.Code) + } + }) +} + +func TestForecastHandlers(t *testing.T) { + handler := &AISettingsHandler{} + + t.Run("forecast-service-missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/ai/forecast", nil) + rr := httptest.NewRecorder() + + handler.HandleGetForecast(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["message"] != "Forecast service not available" { + t.Fatalf("message = %v, want Forecast service not available", payload["message"]) + } + }) + + t.Run("forecast-missing-params", func(t *testing.T) { + handler.SetForecastService(forecast.NewService(forecast.DefaultForecastConfig())) + req := httptest.NewRequest(http.MethodGet, 
"/api/ai/forecast", nil) + rr := httptest.NewRecorder() + + handler.HandleGetForecast(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rr.Code) + } + }) + + t.Run("forecast-overview-error", func(t *testing.T) { + handler.SetForecastService(forecast.NewService(forecast.DefaultForecastConfig())) + req := httptest.NewRequest(http.MethodGet, "/api/ai/forecasts/overview", nil) + rr := httptest.NewRecorder() + + handler.HandleGetForecastOverview(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rr.Code) + } + payload := decodeJSON(t, rr) + if payload["error"] == nil { + t.Fatalf("expected error in response") + } + }) +} + +func TestAIIntelligenceHandlers_MethodNotAllowed(t *testing.T) { + handler := &AISettingsHandler{} + req := httptest.NewRequest(http.MethodPost, "/api/ai/intelligence/anomalies", nil) + rr := httptest.NewRecorder() + + handler.HandleGetAnomalies(rr, req) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want 405", rr.Code) + } +} diff --git a/internal/api/ai_intelligence_handlers_data_additional_test.go b/internal/api/ai_intelligence_handlers_data_additional_test.go new file mode 100644 index 000000000..684769dc1 --- /dev/null +++ b/internal/api/ai_intelligence_handlers_data_additional_test.go @@ -0,0 +1,188 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func setupAIHandlerWithIntelligence(t *testing.T) (*AISettingsHandler, *ai.PatrolService) { + t.Helper() + + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + handler.legacyAIService.SetStateProvider(&stubStateProvider{}) + + patrol := handler.legacyAIService.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service to be initialized") + } + + return handler, patrol +} + +func seedPatternDetector(now time.Time) *ai.PatternDetector { + detector := ai.NewPatternDetector(ai.PatternDetectorConfig{ + MinOccurrences: 2, + PatternWindow: 24 * time.Hour, + PredictionLimit: 48 * time.Hour, + }) + + detector.RecordEvent(ai.HistoricalEvent{ + ResourceID: "vm-1", + EventType: ai.EventHighCPU, + Timestamp: now.Add(-1 * time.Hour), + Duration: 10 * time.Minute, + }) + detector.RecordEvent(ai.HistoricalEvent{ + ResourceID: "vm-1", + EventType: ai.EventHighCPU, + Timestamp: now, + Duration: 5 * time.Minute, + }) + + return detector +} + +func seedCorrelationDetector(now time.Time) *ai.CorrelationDetector { + detector := ai.NewCorrelationDetector(ai.CorrelationConfig{ + MinOccurrences: 1, + CorrelationWindow: 10 * time.Minute, + RetentionWindow: 24 * time.Hour, + MaxEvents: 100, + }) + + detector.RecordEvent(ai.CorrelationEvent{ + ResourceID: "node-1", + ResourceName: "node-1", + ResourceType: "node", + EventType: ai.CorrelationEventHighCPU, + Timestamp: now.Add(-2 * time.Minute), + }) + detector.RecordEvent(ai.CorrelationEvent{ + ResourceID: "vm-1", + ResourceName: "vm-1", + ResourceType: "vm", + EventType: ai.CorrelationEventRestart, + Timestamp: now.Add(-1 * time.Minute), + }) + + return detector +} + +func TestHandleGetPatterns_LockedWithData(t *testing.T) { + t.Setenv("PULSE_MOCK_MODE", "true") + handler, _ := setupAIHandlerWithIntelligence(t) + + handler.legacyAIService.SetPatternDetector(seedPatternDetector(time.Now())) + 
handler.legacyAIService.SetLicenseChecker(stubLicenseChecker{allow: false}) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence/patterns", nil) + rec := httptest.NewRecorder() + + handler.HandleGetPatterns(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if rec.Header().Get("X-License-Required") != "true" { + t.Fatalf("expected license header to be set") + } + + var resp struct { + Patterns []map[string]interface{} `json:"patterns"` + Count int `json:"count"` + LicenseRequired bool `json:"license_required"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Count != 1 { + t.Fatalf("count = %d, want 1", resp.Count) + } + if len(resp.Patterns) != 0 { + t.Fatalf("expected patterns to be redacted") + } + if !resp.LicenseRequired { + t.Fatalf("expected license_required=true") + } +} + +func TestHandleGetPredictions_WithData(t *testing.T) { + t.Setenv("PULSE_MOCK_MODE", "true") + handler, _ := setupAIHandlerWithIntelligence(t) + + handler.legacyAIService.SetPatternDetector(seedPatternDetector(time.Now())) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence/predictions?resource_id=vm-1", nil) + rec := httptest.NewRecorder() + + handler.HandleGetPredictions(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var resp struct { + Predictions []struct { + ResourceID string `json:"resource_id"` + IsOverdue bool `json:"is_overdue"` + } `json:"predictions"` + Count int `json:"count"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Count != 1 || len(resp.Predictions) != 1 { + t.Fatalf("predictions count = %d, want 1", resp.Count) + } + if resp.Predictions[0].ResourceID != "vm-1" { + t.Fatalf("resource_id = %s, want vm-1", resp.Predictions[0].ResourceID) + } + if resp.Predictions[0].IsOverdue { + t.Fatalf("expected prediction to not be overdue") + } +} + +func TestHandleGetCorrelations_WithData(t *testing.T) { + t.Setenv("PULSE_MOCK_MODE", "true") + handler, _ := setupAIHandlerWithIntelligence(t) + + handler.legacyAIService.SetCorrelationDetector(seedCorrelationDetector(time.Now())) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/intelligence/correlations?resource_id=vm-1", nil) + rec := httptest.NewRecorder() + + handler.HandleGetCorrelations(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var resp struct { + Correlations []struct { + TargetID string `json:"target_id"` + EventPattern string `json:"event_pattern"` + } `json:"correlations"` + Count int `json:"count"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Count != 1 || len(resp.Correlations) != 1 { + t.Fatalf("correlations count = %d, want 1", resp.Count) + } + if resp.Correlations[0].TargetID != "vm-1" { + t.Fatalf("target_id = %s, want vm-1", resp.Correlations[0].TargetID) + } + if resp.Correlations[0].EventPattern == "" { + t.Fatalf("expected event_pattern to be set") + } +} diff --git a/internal/api/ai_intelligence_handlers_remediation_additional_test.go b/internal/api/ai_intelligence_handlers_remediation_additional_test.go new file mode 100644 index 000000000..8f2795190 --- /dev/null +++ b/internal/api/ai_intelligence_handlers_remediation_additional_test.go @@ -0,0 +1,311 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + 
"net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/circuit" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/memory" + "github.com/rcourtman/pulse-go-rewrite/internal/ai/remediation" + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +type stubRemediationExecutor struct { + mu sync.Mutex + calls []remediationCall +} + +type remediationCall struct { + target string + command string +} + +func (s *stubRemediationExecutor) Execute(ctx context.Context, target, command string) (string, error) { + s.mu.Lock() + s.calls = append(s.calls, remediationCall{target: target, command: command}) + s.mu.Unlock() + return "ok", nil +} + +func (s *stubRemediationExecutor) Calls() []remediationCall { + s.mu.Lock() + defer s.mu.Unlock() + calls := make([]remediationCall, len(s.calls)) + copy(calls, s.calls) + return calls +} + +func TestHandleExecuteRemediationPlan(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + executor := &stubRemediationExecutor{} + engine := remediation.NewEngine(remediation.EngineConfig{DataDir: ""}) + engine.SetCommandExecutor(executor) + handler.SetRemediationEngine(engine) + + plan := &remediation.RemediationPlan{ + ID: "plan-1", + FindingID: "finding-1", + ResourceID: "res-1", + Title: "Restart service", + Description: "Restart to recover", + Steps: []remediation.RemediationStep{ + { + Order: 0, + Description: "Restart", + Command: "echo ok", + Target: "host-1", + }, + }, + } + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("CreatePlan: %v", err) + } + + body, _ := json.Marshal(map[string]string{"plan_id": plan.ID}) + req := httptest.NewRequest(http.MethodPost, "/api/ai/remediation/plans/plan-1/execute", bytes.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleExecuteRemediationPlan(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var execution remediation.RemediationExecution + if err := json.Unmarshal(rec.Body.Bytes(), &execution); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if execution.PlanID != plan.ID || execution.Status != remediation.StatusCompleted { + t.Fatalf("unexpected execution: %+v", execution) + } + if len(executor.Calls()) != 1 { + t.Fatalf("expected 1 command call, got %d", len(executor.Calls())) + } +} + +func TestHandleRollbackRemediationPlan(t *testing.T) { + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + executor := &stubRemediationExecutor{} + engine := remediation.NewEngine(remediation.EngineConfig{DataDir: ""}) + engine.SetCommandExecutor(executor) + handler.SetRemediationEngine(engine) + + plan := &remediation.RemediationPlan{ + ID: "plan-rollback", + FindingID: "finding-1", + ResourceID: "res-1", + Title: "Restart service", + Description: "Restart to recover", + Steps: []remediation.RemediationStep{ + { + Order: 0, + Description: "Restart", + Command: "echo ok", + Target: "host-1", + Rollback: "echo rollback", + }, + }, + } + if err := engine.CreatePlan(plan); err != nil { + t.Fatalf("CreatePlan: %v", err) + } + + exec, err := engine.ApprovePlan(plan.ID, "tester") 
+ if err != nil { + t.Fatalf("ApprovePlan: %v", err) + } + if err := engine.Execute(context.Background(), exec.ID); err != nil { + t.Fatalf("Execute: %v", err) + } + + body, _ := json.Marshal(map[string]string{"execution_id": exec.ID}) + req := httptest.NewRequest(http.MethodPost, "/api/ai/remediation/plans/plan-rollback/rollback", bytes.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleRollbackRemediationPlan(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + calls := executor.Calls() + if len(calls) < 2 { + t.Fatalf("expected rollback to execute command, got %d calls", len(calls)) + } + if calls[len(calls)-1].command != "echo rollback" { + t.Fatalf("expected rollback command, got %q", calls[len(calls)-1].command) + } +} + +func TestHandleGetCircuitBreakerStatus(t *testing.T) { + handler := &AISettingsHandler{} + + req := httptest.NewRequest(http.MethodGet, "/api/ai/circuit/status", nil) + rec := httptest.NewRecorder() + handler.HandleGetCircuitBreakerStatus(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d", rec.Code) + } + + breaker := circuit.NewBreaker("patrol", circuit.Config{ + FailureThreshold: 1, + SuccessThreshold: 1, + InitialBackoff: time.Minute, + MaxBackoff: time.Minute, + BackoffMultiplier: 1, + HalfOpenTimeout: time.Minute, + }) + breaker.RecordFailure(context.Canceled) + + handler.SetCircuitBreaker(breaker) + rec = httptest.NewRecorder() + handler.HandleGetCircuitBreakerStatus(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d", rec.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["state"] != "open" { + t.Fatalf("expected state open, got %v", resp["state"]) + } + if resp["can_patrol"].(bool) { + t.Fatalf("expected can_patrol to be false when breaker is open") + } +} + +func setupIncidentHandler(t *testing.T) (*AISettingsHandler, *memory.IncidentStore) { + t.Helper() + tmp := t.TempDir() + cfg := &config.Config{DataPath: tmp} + persistence := config.NewConfigPersistence(tmp) + handler := newTestAISettingsHandler(cfg, persistence, nil) + + svc := handler.GetAIService(context.Background()) + svc.SetStateProvider(&MockStateProvider{}) + patrol := svc.GetPatrolService() + if patrol == nil { + t.Fatalf("expected patrol service") + } + + store := memory.NewIncidentStore(memory.IncidentStoreConfig{DataDir: ""}) + patrol.SetIncidentStore(store) + + coordinator := ai.NewIncidentCoordinator(ai.IncidentCoordinatorConfig{EnableRecorder: false}) + coordinator.SetIncidentStore(store) + coordinator.Start() + handler.SetIncidentCoordinator(coordinator) + + alert := &alerts.Alert{ + ID: "alert-1", + Type: "cpu", + Level: alerts.AlertLevelWarning, + ResourceID: "res-1", + ResourceName: "node-1", + StartTime: time.Now(), + LastSeen: time.Now(), + } + coordinator.OnAlertFired(alert) + + return handler, store +} + +func TestHandleGetRecentIncidents(t *testing.T) { + handler, _ := setupIncidentHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/incidents?resource_id=res-1&limit=5", nil) + rec := httptest.NewRecorder() + handler.HandleGetRecentIncidents(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode 
response: %v", err) + } + incidents := resp["incidents"].([]interface{}) + if len(incidents) != 1 { + t.Fatalf("expected 1 incident, got %d", len(incidents)) + } + if resp["active_count"].(float64) < 1 { + t.Fatalf("expected active_count >= 1") + } +} + +func TestHandleGetRecentIncidentsSummary(t *testing.T) { + handler, _ := setupIncidentHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/api/ai/incidents?limit=5", nil) + rec := httptest.NewRecorder() + handler.HandleGetRecentIncidents(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["incident_summary"] == "" { + t.Fatalf("expected incident_summary to be populated") + } +} + +func TestHandleGetIncidentData(t *testing.T) { + handler, store := setupIncidentHandler(t) + + alert := &alerts.Alert{ + ID: "alert-2", + Type: "disk", + Level: alerts.AlertLevelCritical, + ResourceID: "node/pve", + ResourceName: "node/pve", + StartTime: time.Now(), + LastSeen: time.Now(), + } + store.RecordAlertFired(alert) + + escaped := url.PathEscape("node/pve") + req := httptest.NewRequest(http.MethodGet, "/api/ai/incidents/"+escaped+"?limit=5", nil) + rec := httptest.NewRecorder() + handler.HandleGetIncidentData(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["resource_id"] != "node/pve" { + t.Fatalf("unexpected resource_id %v", resp["resource_id"]) + } + incidents := resp["incidents"].([]interface{}) + if len(incidents) == 0 { + t.Fatalf("expected incidents to be returned") + } + if resp["formatted_context"] == "" { + t.Fatalf("expected formatted_context to be populated") + } +} diff --git a/internal/api/ai_intelligence_helpers_additional_test.go b/internal/api/ai_intelligence_helpers_additional_test.go new file mode 100644 index 000000000..02dbd4add --- /dev/null +++ b/internal/api/ai_intelligence_helpers_additional_test.go @@ -0,0 +1,40 @@ +package api + +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai" +) + +func TestRemediationStatsFromRecords(t *testing.T) { + records := []ai.RemediationRecord{ + {Outcome: ai.OutcomeResolved, Automatic: true}, + {Outcome: ai.OutcomeResolved, Automatic: false}, + {Outcome: ai.OutcomePartial, Automatic: true}, + {Outcome: ai.OutcomeFailed, Automatic: false}, + {Outcome: "unknown", Automatic: true}, + } + + stats := remediationStatsFromRecords(records) + if stats["total"] != 5 { + t.Fatalf("total = %d, want 5", stats["total"]) + } + if stats["resolved"] != 2 { + t.Fatalf("resolved = %d, want 2", stats["resolved"]) + } + if stats["partial"] != 1 { + t.Fatalf("partial = %d, want 1", stats["partial"]) + } + if stats["failed"] != 1 { + t.Fatalf("failed = %d, want 1", stats["failed"]) + } + if stats["unknown"] != 1 { + t.Fatalf("unknown = %d, want 1", stats["unknown"]) + } + if stats["automatic"] != 3 { + t.Fatalf("automatic = %d, want 3", stats["automatic"]) + } + if stats["manual"] != 2 { + t.Fatalf("manual = %d, want 2", stats["manual"]) + } +} diff --git a/internal/api/aidiscovery_handlers.go b/internal/api/aidiscovery_handlers.go new file mode 100644 index 000000000..6718cddee --- /dev/null +++ 
b/internal/api/aidiscovery_handlers.go @@ -0,0 +1,349 @@ +package api + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" + "github.com/rs/zerolog/log" +) + +// AIDiscoveryHandlers handles AI-powered infrastructure discovery endpoints. +type AIDiscoveryHandlers struct { + service *aidiscovery.Service +} + +// NewAIDiscoveryHandlers creates new AI discovery handlers. +func NewAIDiscoveryHandlers(service *aidiscovery.Service) *AIDiscoveryHandlers { + return &AIDiscoveryHandlers{ + service: service, + } +} + +// SetService sets the discovery service (used for late initialization after routes are registered). +func (h *AIDiscoveryHandlers) SetService(service *aidiscovery.Service) { + h.service = service +} + +// writeDiscoveryJSON writes a JSON response. +func writeDiscoveryJSON(w http.ResponseWriter, data any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(data) +} + +// writeDiscoveryError writes a JSON error response. +func writeDiscoveryError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(map[string]any{ + "error": true, + "message": message, + }) +} + +// HandleListDiscoveries handles GET /api/aidiscovery +func (h *AIDiscoveryHandlers) HandleListDiscoveries(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + discoveries, err := h.service.ListDiscoveries() + if err != nil { + log.Error().Err(err).Msg("Failed to list discoveries") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to list discoveries") + return + } + + // Convert to summaries for list view + summaries := make([]aidiscovery.DiscoverySummary, 0, len(discoveries)) + for _, d := range discoveries { + summaries = append(summaries, d.ToSummary()) + } + + writeDiscoveryJSON(w, map[string]any{ + "discoveries": summaries, + "total": len(summaries), + }) +} + +// HandleGetDiscovery handles GET /api/aidiscovery/{type}/{host}/{id} +func (h *AIDiscoveryHandlers) HandleGetDiscovery(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path: /api/aidiscovery/{type}/{host}/{id} + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/") + parts := strings.SplitN(path, "/", 3) + if len(parts) < 3 { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid path: expected /api/aidiscovery/{type}/{host}/{id}") + return + } + + resourceType := aidiscovery.ResourceType(parts[0]) + hostID := parts[1] + resourceID := parts[2] + + discovery, err := h.service.GetDiscoveryByResource(resourceType, hostID, resourceID) + if err != nil { + log.Error().Err(err).Str("type", string(resourceType)).Str("host", hostID).Str("id", resourceID).Msg("Failed to get discovery") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to get discovery") + return + } + + if discovery == nil { + writeDiscoveryError(w, http.StatusNotFound, "Discovery not found") + return + } + + writeDiscoveryJSON(w, discovery) +} + +// HandleTriggerDiscovery handles POST /api/aidiscovery/{type}/{host}/{id} +func (h *AIDiscoveryHandlers) HandleTriggerDiscovery(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI 
discovery service not configured") + return + } + + // Parse path + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/") + parts := strings.SplitN(path, "/", 3) + if len(parts) < 3 { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid path: expected /api/aidiscovery/{type}/{host}/{id}") + return + } + + resourceType := aidiscovery.ResourceType(parts[0]) + hostID := parts[1] + resourceID := parts[2] + + // Parse optional request body for force flag and hostname + var reqBody struct { + Force bool `json:"force"` + Hostname string `json:"hostname"` + } + if r.Body != nil { + _ = json.NewDecoder(r.Body).Decode(&reqBody) + } + + // Build discovery request + req := aidiscovery.DiscoveryRequest{ + ResourceType: resourceType, + ResourceID: resourceID, + HostID: hostID, + Hostname: reqBody.Hostname, + Force: reqBody.Force, + } + + // If hostname not provided, try to use hostID + if req.Hostname == "" { + req.Hostname = hostID + } + + discovery, err := h.service.DiscoverResource(r.Context(), req) + if err != nil { + log.Error().Err(err). + Str("type", string(resourceType)). + Str("host", hostID). + Str("id", resourceID). + Msg("Failed to trigger discovery") + writeDiscoveryError(w, http.StatusInternalServerError, "Discovery failed: "+err.Error()) + return + } + + writeDiscoveryJSON(w, discovery) +} + +// HandleUpdateNotes handles PUT /api/aidiscovery/{type}/{host}/{id}/notes +func (h *AIDiscoveryHandlers) HandleUpdateNotes(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/") + path = strings.TrimSuffix(path, "/notes") + parts := strings.SplitN(path, "/", 3) + if len(parts) < 3 { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid path") + return + } + + resourceType := aidiscovery.ResourceType(parts[0]) + hostID := parts[1] + resourceID := parts[2] + + // Build the full ID + id := aidiscovery.MakeResourceID(resourceType, hostID, resourceID) + + // Parse request body + var req aidiscovery.UpdateNotesRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid request body") + return + } + + if err := h.service.UpdateNotes(id, req.UserNotes, req.UserSecrets); err != nil { + log.Error().Err(err).Str("id", id).Msg("Failed to update notes") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to update notes: "+err.Error()) + return + } + + // Return updated discovery + discovery, err := h.service.GetDiscovery(id) + if err != nil { + writeDiscoveryError(w, http.StatusInternalServerError, "Notes updated but failed to fetch result") + return + } + + writeDiscoveryJSON(w, discovery) +} + +// HandleDeleteDiscovery handles DELETE /api/aidiscovery/{type}/{host}/{id} +func (h *AIDiscoveryHandlers) HandleDeleteDiscovery(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/") + parts := strings.SplitN(path, "/", 3) + if len(parts) < 3 { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid path") + return + } + + resourceType := aidiscovery.ResourceType(parts[0]) + hostID := parts[1] + resourceID := parts[2] + + id := aidiscovery.MakeResourceID(resourceType, hostID, resourceID) + + if err := h.service.DeleteDiscovery(id); 
err != nil { + log.Error().Err(err).Str("id", id).Msg("Failed to delete discovery") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to delete discovery") + return + } + + writeDiscoveryJSON(w, map[string]any{"success": true, "id": id}) +} + +// HandleGetProgress handles GET /api/aidiscovery/{type}/{host}/{id}/progress +func (h *AIDiscoveryHandlers) HandleGetProgress(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/") + path = strings.TrimSuffix(path, "/progress") + parts := strings.SplitN(path, "/", 3) + if len(parts) < 3 { + writeDiscoveryError(w, http.StatusBadRequest, "Invalid path") + return + } + + resourceType := aidiscovery.ResourceType(parts[0]) + hostID := parts[1] + resourceID := parts[2] + + id := aidiscovery.MakeResourceID(resourceType, hostID, resourceID) + + progress := h.service.GetProgress(id) + if progress == nil { + // Not currently scanning - check if we have a discovery + discovery, err := h.service.GetDiscovery(id) + if err == nil && discovery != nil { + writeDiscoveryJSON(w, map[string]any{ + "status": "completed", + "resource_id": id, + "updated_at": discovery.UpdatedAt, + }) + return + } + + writeDiscoveryJSON(w, map[string]any{ + "status": "not_started", + "resource_id": id, + }) + return + } + + writeDiscoveryJSON(w, progress) +} + +// HandleGetStatus handles GET /api/aidiscovery/status +func (h *AIDiscoveryHandlers) HandleGetStatus(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + writeDiscoveryJSON(w, h.service.GetStatus()) +} + +// HandleListByType handles GET /api/aidiscovery/type/{type} +func (h *AIDiscoveryHandlers) HandleListByType(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path + path := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/type/") + resourceType := aidiscovery.ResourceType(path) + + discoveries, err := h.service.ListDiscoveriesByType(resourceType) + if err != nil { + log.Error().Err(err).Str("type", string(resourceType)).Msg("Failed to list discoveries by type") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to list discoveries") + return + } + + summaries := make([]aidiscovery.DiscoverySummary, 0, len(discoveries)) + for _, d := range discoveries { + summaries = append(summaries, d.ToSummary()) + } + + writeDiscoveryJSON(w, map[string]any{ + "discoveries": summaries, + "total": len(summaries), + "type": resourceType, + }) +} + +// HandleListByHost handles GET /api/aidiscovery/host/{host} +func (h *AIDiscoveryHandlers) HandleListByHost(w http.ResponseWriter, r *http.Request) { + if h.service == nil { + writeDiscoveryError(w, http.StatusServiceUnavailable, "AI discovery service not configured") + return + } + + // Parse path + hostID := strings.TrimPrefix(r.URL.Path, "/api/aidiscovery/host/") + + discoveries, err := h.service.ListDiscoveriesByHost(hostID) + if err != nil { + log.Error().Err(err).Str("host", hostID).Msg("Failed to list discoveries by host") + writeDiscoveryError(w, http.StatusInternalServerError, "Failed to list discoveries") + return + } + + summaries := make([]aidiscovery.DiscoverySummary, 0, len(discoveries)) + for _, d := 
range discoveries { + summaries = append(summaries, d.ToSummary()) + } + + writeDiscoveryJSON(w, map[string]any{ + "discoveries": summaries, + "total": len(summaries), + "host": hostID, + }) +} diff --git a/internal/api/auth_oidc_refresh_additional_test.go b/internal/api/auth_oidc_refresh_additional_test.go new file mode 100644 index 000000000..ad5b4e730 --- /dev/null +++ b/internal/api/auth_oidc_refresh_additional_test.go @@ -0,0 +1,232 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func newOIDCTestServer(t *testing.T, tokenStatus int, tokenBody map[string]interface{}) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + baseURL := scheme + "://" + r.Host + + switch r.URL.Path { + case "/.well-known/openid-configuration": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": baseURL, + "authorization_endpoint": baseURL + "/auth", + "token_endpoint": baseURL + "/token", + "jwks_uri": baseURL + "/jwks", + }) + case "/jwks": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{"keys": []interface{}{}}) + case "/token": + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tokenStatus) + if tokenBody != nil { + _ = json.NewEncoder(w).Encode(tokenBody) + } + default: + http.NotFound(w, r) + } + })) +} + +func TestRefreshOIDCSessionTokens_Success(t *testing.T) { + InitSessionStore(t.TempDir()) + store := GetSessionStore() + + tokenResp := map[string]interface{}{ + "access_token": "new-access", + "refresh_token": "new-refresh", + "expires_in": 3600, + "token_type": "Bearer", + } + server := newOIDCTestServer(t, http.StatusOK, tokenResp) + defer server.Close() + + cfg := &config.Config{ + OIDC: &config.OIDCConfig{ + Enabled: true, + IssuerURL: server.URL, + ClientID: "client", + ClientSecret: "secret", + RedirectURL: "http://localhost/callback", + Scopes: []string{"openid"}, + }, + } + + sessionToken := "oidc-session-success" + store.CreateOIDCSession(sessionToken, time.Hour, "agent", "127.0.0.1", "user", &OIDCTokenInfo{ + RefreshToken: "old-refresh", + AccessTokenExp: time.Now().Add(-time.Minute), + Issuer: server.URL, + ClientID: "client", + }) + + session := store.GetSession(sessionToken) + if session == nil { + t.Fatalf("expected session to exist") + } + + refreshOIDCSessionTokens(cfg, sessionToken, session) + + updated := store.GetSession(sessionToken) + if updated == nil { + t.Fatalf("expected session to remain after refresh") + } + if updated.OIDCRefreshToken != "new-refresh" { + t.Fatalf("expected refresh token to update, got %q", updated.OIDCRefreshToken) + } + if time.Until(updated.OIDCAccessTokenExp) <= 0 { + t.Fatalf("expected access token expiry to be in the future") + } +} + +func TestRefreshOIDCSessionTokens_IssuerMismatchInvalidates(t *testing.T) { + InitSessionStore(t.TempDir()) + store := GetSessionStore() + + cfg := &config.Config{ + OIDC: &config.OIDCConfig{ + Enabled: true, + IssuerURL: "https://issuer.example", + }, + } + + sessionToken := "oidc-session-mismatch" + store.CreateOIDCSession(sessionToken, time.Hour, "agent", "127.0.0.1", "user", &OIDCTokenInfo{ + 
RefreshToken: "refresh", + AccessTokenExp: time.Now().Add(-time.Minute), + Issuer: "https://different-issuer", + ClientID: "client", + }) + + session := store.GetSession(sessionToken) + if session == nil { + t.Fatalf("expected session to exist") + } + + refreshOIDCSessionTokens(cfg, sessionToken, session) + + if store.GetSession(sessionToken) != nil { + t.Fatalf("expected session to be invalidated on issuer mismatch") + } +} + +func TestRefreshOIDCSessionTokens_RefreshFailureInvalidates(t *testing.T) { + InitSessionStore(t.TempDir()) + store := GetSessionStore() + + tokenResp := map[string]interface{}{ + "error": "invalid_grant", + "error_description": "refresh token expired", + } + server := newOIDCTestServer(t, http.StatusBadRequest, tokenResp) + defer server.Close() + + cfg := &config.Config{ + OIDC: &config.OIDCConfig{ + Enabled: true, + IssuerURL: server.URL, + ClientID: "client", + ClientSecret: "secret", + RedirectURL: "http://localhost/callback", + Scopes: []string{"openid"}, + }, + } + + sessionToken := "oidc-session-failure" + store.CreateOIDCSession(sessionToken, time.Hour, "agent", "127.0.0.1", "user", &OIDCTokenInfo{ + RefreshToken: "refresh", + AccessTokenExp: time.Now().Add(-time.Minute), + Issuer: server.URL, + ClientID: "client", + }) + + session := store.GetSession(sessionToken) + if session == nil { + t.Fatalf("expected session to exist") + } + + refreshOIDCSessionTokens(cfg, sessionToken, session) + + if store.GetSession(sessionToken) != nil { + t.Fatalf("expected session to be invalidated after refresh failure") + } +} + +func TestRefreshOIDCSessionTokens_OIDCDisabledDoesNotInvalidate(t *testing.T) { + InitSessionStore(t.TempDir()) + store := GetSessionStore() + + cfg := &config.Config{ + OIDC: &config.OIDCConfig{ + Enabled: false, + }, + } + + sessionToken := "oidc-session-disabled" + store.CreateOIDCSession(sessionToken, time.Hour, "agent", "127.0.0.1", "user", &OIDCTokenInfo{ + RefreshToken: "refresh", + AccessTokenExp: time.Now().Add(-time.Minute), + Issuer: "https://issuer.example", + ClientID: "client", + }) + + session := store.GetSession(sessionToken) + if session == nil { + t.Fatalf("expected session to exist") + } + + refreshOIDCSessionTokens(cfg, sessionToken, session) + + if store.GetSession(sessionToken) == nil { + t.Fatalf("expected session to remain when OIDC is disabled") + } +} + +func TestNewOIDCTestServer_IssuerFieldUsesURL(t *testing.T) { + server := newOIDCTestServer(t, http.StatusOK, nil) + defer server.Close() + + resp, err := http.Get(server.URL + "/.well-known/openid-configuration") + if err != nil { + t.Fatalf("failed to fetch discovery doc: %v", err) + } + defer resp.Body.Close() + + var body map[string]string + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode discovery doc: %v", err) + } + + issuer := body["issuer"] + if issuer == "" { + t.Fatalf("expected issuer in discovery doc") + } + if !strings.Contains(issuer, "http://") { + t.Fatalf("expected issuer to include scheme, got %q", issuer) + } + if _, err := url.Parse(issuer); err != nil { + t.Fatalf("expected issuer to parse as URL: %v", err) + } +} diff --git a/internal/api/authorization_additional_test.go b/internal/api/authorization_additional_test.go new file mode 100644 index 000000000..e2693cfa9 --- /dev/null +++ b/internal/api/authorization_additional_test.go @@ -0,0 +1,17 @@ +package api + +import "testing" + +func TestMultiTenantOrganizationLoader_NoPersistence(t *testing.T) { + loader := NewMultiTenantOrganizationLoader(nil) + if _, 
err := loader.GetOrganization("org"); err == nil { + t.Fatalf("expected error when persistence is nil") + } +} + +func TestDefaultAuthorizationChecker_CanAccessOrg_Default(t *testing.T) { + checker := NewAuthorizationChecker(nil) + if !checker.CanAccessOrg("user", nil, "default") { + t.Fatalf("expected default org access") + } +} diff --git a/internal/api/config_handlers_cluster_additional_test.go b/internal/api/config_handlers_cluster_additional_test.go new file mode 100644 index 000000000..157e1a260 --- /dev/null +++ b/internal/api/config_handlers_cluster_additional_test.go @@ -0,0 +1,458 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/mock" + "github.com/rcourtman/pulse-go-rewrite/pkg/proxmox" +) + +func TestMaybeRefreshClusterInfo_UpdatesMetadata(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + originalDetect := detectPVECluster + t.Cleanup(func() { detectPVECluster = originalDetect }) + + called := false + detectPVECluster = func(clientConfig proxmox.ClientConfig, nodeName string, existing []config.ClusterEndpoint) (bool, string, []config.ClusterEndpoint) { + called = true + return true, "unknown cluster", []config.ClusterEndpoint{ + {NodeName: "node-1", Host: "https://node-1.local:8006"}, + } + } + + instance := config.PVEInstance{ + Name: "pve-1", + Host: "https://pve-1.local:8006", + TokenValue: "token", + } + + handler.maybeRefreshClusterInfo(context.Background(), &instance) + + if !called { + t.Fatalf("expected detectPVECluster to be called") + } + if !instance.IsCluster { + t.Fatalf("expected instance to be marked as cluster") + } + if instance.ClusterName != "pve-1" { + t.Fatalf("expected cluster name to default to instance name, got %q", instance.ClusterName) + } + if len(instance.ClusterEndpoints) != 1 { + t.Fatalf("expected cluster endpoints to be updated") + } +} + +func TestMaybeRefreshClusterInfo_SkipsWithoutCredentials(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + originalDetect := detectPVECluster + t.Cleanup(func() { detectPVECluster = originalDetect }) + + called := false + detectPVECluster = func(clientConfig proxmox.ClientConfig, nodeName string, existing []config.ClusterEndpoint) (bool, string, []config.ClusterEndpoint) { + called = true + return true, "cluster", nil + } + + instance := config.PVEInstance{ + Name: "pve-1", + Host: "https://pve-1.local:8006", + } + + handler.maybeRefreshClusterInfo(context.Background(), &instance) + + if called { + t.Fatalf("expected detectPVECluster to be skipped without credentials") + } +} + +func TestMaybeRefreshClusterInfo_SkipsWithinCooldown(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + originalDetect := detectPVECluster + t.Cleanup(func() { detectPVECluster = originalDetect }) + + called := false + detectPVECluster = func(clientConfig proxmox.ClientConfig, nodeName string, existing []config.ClusterEndpoint) (bool, string, []config.ClusterEndpoint) { + called = true + return true, "cluster", nil + } + + instance := config.PVEInstance{ + Name: "pve-1", + Host: "https://pve-1.local:8006", + TokenValue: "token", + } + handler.lastClusterDetection[instance.Name] = time.Now() + + handler.maybeRefreshClusterInfo(context.Background(), &instance) + + if called { + 
t.Fatalf("expected detectPVECluster to be skipped during cooldown") + } +} + +func TestIsContainerSSHRestricted(t *testing.T) { + t.Setenv("PULSE_DOCKER", "true") + t.Setenv("PULSE_DEV_ALLOW_CONTAINER_SSH", "") + + if !isContainerSSHRestricted() { + t.Fatalf("expected SSH to be restricted in container") + } + + t.Setenv("PULSE_DEV_ALLOW_CONTAINER_SSH", "true") + if isContainerSSHRestricted() { + t.Fatalf("expected SSH restriction to be disabled when override is true") + } +} + +func TestResolveHostnameToIP(t *testing.T) { + if got := resolveHostnameToIP("https://127.0.0.1:8006"); got != "127.0.0.1" { + t.Fatalf("expected IP passthrough, got %q", got) + } + + got := resolveHostnameToIP("https://localhost:8006") + if got == "" || (got != "127.0.0.1" && got != "::1") { + t.Fatalf("expected localhost to resolve to loopback, got %q", got) + } + + if got := resolveHostnameToIP("not-a-url"); got != "" { + t.Fatalf("expected invalid host to return empty string, got %q", got) + } +} + +func TestGetAllNodesForAPI(t *testing.T) { + monitorDisks := true + tempEnabled := true + cfg := &config.Config{ + DataPath: t.TempDir(), + PVEInstances: []config.PVEInstance{ + { + Name: "pve-1", + Host: "https://pve-1.local:8006", + GuestURL: "https://guest.local", + User: "root@pam", + Password: "secret", + TokenName: "token", + TokenValue: "token-value", + Fingerprint: "fp", + VerifySSL: true, + MonitorVMs: true, + MonitorContainers: true, + MonitorStorage: false, + MonitorBackups: true, + MonitorPhysicalDisks: &monitorDisks, + PhysicalDiskPollingMinutes: 15, + TemperatureMonitoringEnabled: &tempEnabled, + IsCluster: true, + ClusterName: "cluster-1", + ClusterEndpoints: []config.ClusterEndpoint{ + {NodeName: "pve-1", Host: "https://pve-1.local:8006"}, + }, + Source: "agent", + }, + }, + PBSInstances: []config.PBSInstance{ + { + Name: "pbs-1", + Host: "https://pbs-1.local:8007", + GuestURL: "https://pbs-guest.local", + User: "backup@pam", + TokenName: "token", + TokenValue: "pbs-token", + VerifySSL: false, + MonitorDatastores: true, + ExcludeDatastores: []string{"ds1"}, + Source: "script", + }, + }, + PMGInstances: []config.PMGInstance{ + { + Name: "pmg-1", + Host: "https://pmg-1.local:8008", + User: "admin@pam", + }, + }, + } + + handler := newTestConfigHandlers(t, cfg) + nodes := handler.GetAllNodesForAPI(context.Background()) + + if len(nodes) != 3 { + t.Fatalf("expected 3 nodes, got %d", len(nodes)) + } + + var pveNode, pbsNode, pmgNode *NodeResponse + for i := range nodes { + node := nodes[i] + switch node.Type { + case "pve": + pveNode = &node + case "pbs": + pbsNode = &node + case "pmg": + pmgNode = &node + } + } + + if pveNode == nil || pbsNode == nil || pmgNode == nil { + t.Fatalf("expected pve, pbs, and pmg nodes to be present") + } + + if !pveNode.HasPassword || !pveNode.HasToken || pveNode.ClusterName != "cluster-1" { + t.Fatalf("unexpected PVE node fields: %+v", pveNode) + } + if pveNode.MonitorPhysicalDisks == nil || !*pveNode.MonitorPhysicalDisks { + t.Fatalf("expected PVE MonitorPhysicalDisks to be true") + } + if pveNode.Status != "disconnected" { + t.Fatalf("expected PVE status to be disconnected, got %q", pveNode.Status) + } + + if !pbsNode.HasToken || len(pbsNode.ExcludeDatastores) != 1 { + t.Fatalf("unexpected PBS node fields: %+v", pbsNode) + } + + if !pmgNode.MonitorMailStats || pmgNode.MonitorQueues || pmgNode.MonitorQuarantine { + t.Fatalf("unexpected PMG monitoring flags: %+v", pmgNode) + } +} + +func TestHandleRefreshClusterNodes_Success(t *testing.T) { + cfg := &config.Config{ + 
DataPath: t.TempDir(), + PVEInstances: []config.PVEInstance{ + { + Name: "pve-1", + Host: "https://pve-1.local:8006", + TokenValue: "token", + }, + }, + } + handler := newTestConfigHandlers(t, cfg) + + originalDetect := detectPVECluster + t.Cleanup(func() { detectPVECluster = originalDetect }) + detectPVECluster = func(clientConfig proxmox.ClientConfig, nodeName string, existing []config.ClusterEndpoint) (bool, string, []config.ClusterEndpoint) { + return true, "cluster-1", []config.ClusterEndpoint{ + {NodeName: "node-1", Host: "https://node-1.local:8006"}, + {NodeName: "node-2", Host: "https://node-2.local:8006"}, + } + } + + req := httptest.NewRequest(http.MethodPost, "/api/config/nodes/pve-0/refresh-cluster", nil) + rec := httptest.NewRecorder() + handler.HandleRefreshClusterNodes(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp["clusterName"] != "cluster-1" { + t.Fatalf("expected clusterName to be cluster-1, got %v", resp["clusterName"]) + } + if cfg.PVEInstances[0].ClusterName != "cluster-1" || !cfg.PVEInstances[0].IsCluster { + t.Fatalf("expected instance to be updated as cluster") + } + if len(cfg.PVEInstances[0].ClusterEndpoints) != 2 { + t.Fatalf("expected cluster endpoints to be updated") + } +} + +func TestHandleRefreshClusterNodes_InvalidNodeType(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + req := httptest.NewRequest(http.MethodPost, "/api/config/nodes/pbs-0/refresh-cluster", nil) + rec := httptest.NewRecorder() + handler.HandleRefreshClusterNodes(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", rec.Code) + } +} + +func TestHandleTestNodeConfig_InvalidType(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + body := []byte(`{"type":"invalid"}`) + req := httptest.NewRequest(http.MethodPost, "/api/config/nodes/test-config", bytes.NewReader(body)) + rec := httptest.NewRecorder() + handler.HandleTestNodeConfig(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", rec.Code) + } +} + +func TestHandleTestNode_InvalidPath(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + req := httptest.NewRequest(http.MethodGet, "/api/config/nodes/pve-0", nil) + rec := httptest.NewRecorder() + handler.HandleTestNode(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", rec.Code) + } +} + +func TestGetNodeStatus_RecentlyAutoRegistered(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + handler.markAutoRegistered("pve", "node-1") + + if status := handler.getNodeStatus(context.Background(), "pve", "node-1"); status != "connected" { + t.Fatalf("expected connected for recently auto-registered node, got %q", status) + } + if status := handler.getNodeStatus(context.Background(), "pve", "node-2"); status != "disconnected" { + t.Fatalf("expected disconnected for unknown node, got %q", status) + } +} + +func TestHandleGetSystemSettings_ConfigOverrides(t *testing.T) { + cfg := &config.Config{ + DataPath: t.TempDir(), + PVEPollingInterval: 30 * time.Second, + PBSPollingInterval: 90 * time.Second, + 
BackupPollingInterval: 12 * time.Second, + BackendPort: 8081, + FrontendPort: 3000, + AllowedOrigins: "https://example.com", + ConnectionTimeout: 15 * time.Second, + UpdateChannel: "stable", + AutoUpdateEnabled: true, + AutoUpdateCheckInterval: 6 * time.Hour, + AutoUpdateTime: "03:30", + LogLevel: "debug", + DiscoveryEnabled: true, + DiscoverySubnet: "10.0.0.0/24", + Discovery: config.DefaultDiscoveryConfig(), + EnableBackupPolling: false, + PublicURL: "https://public.example", + TemperatureMonitoringEnabled: true, + } + handler := newTestConfigHandlers(t, cfg) + + settings := config.DefaultSystemSettings() + settings.Theme = "light" + settings.FullWidthMode = true + if err := handler.getPersistence(context.Background()).SaveSystemSettings(*settings); err != nil { + t.Fatalf("SaveSystemSettings: %v", err) + } + + t.Setenv("PULSE_AUTH_HIDE_LOCAL_LOGIN", "1") + + req := httptest.NewRequest(http.MethodGet, "/api/config/system", nil) + rec := httptest.NewRecorder() + handler.HandleGetSystemSettings(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d", rec.Code) + } + + var resp struct { + config.SystemSettings + EnvOverrides map[string]bool `json:"envOverrides"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.PVEPollingInterval != 30 { + t.Fatalf("expected PVEPollingInterval to be 30, got %d", resp.PVEPollingInterval) + } + if resp.Theme != "light" || !resp.FullWidthMode { + t.Fatalf("expected persisted theme settings to remain, got %+v", resp.SystemSettings) + } + if resp.BackupPollingEnabled == nil || *resp.BackupPollingEnabled { + t.Fatalf("expected backup polling enabled to be false") + } + if !resp.EnvOverrides["hideLocalLogin"] { + t.Fatalf("expected hideLocalLogin env override to be true") + } +} + +func TestHandleGetMockMode(t *testing.T) { + prevConfig := mock.GetConfig() + prevEnabled := mock.IsMockEnabled() + t.Cleanup(func() { + mock.SetMockConfig(prevConfig) + mock.SetEnabled(prevEnabled) + }) + + mock.SetEnabled(false) + mock.SetMockConfig(mock.MockConfig{ + NodeCount: 3, + RandomMetrics: false, + }) + + handler := newTestConfigHandlers(t, &config.Config{DataPath: t.TempDir()}) + req := httptest.NewRequest(http.MethodGet, "/api/config/mock", nil) + rec := httptest.NewRecorder() + handler.HandleGetMockMode(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d", rec.Code) + } + + var resp struct { + Enabled bool `json:"enabled"` + Config mock.MockConfig `json:"config"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Enabled { + t.Fatalf("expected mock mode disabled") + } + if resp.Config.NodeCount != 3 || resp.Config.RandomMetrics { + t.Fatalf("unexpected mock config: %+v", resp.Config) + } +} + +func TestHandleAgentInstallCommand(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + handler := newTestConfigHandlers(t, cfg) + + body := []byte(`{"type":"pve"}`) + req := httptest.NewRequest(http.MethodPost, "/api/config/agent-install", bytes.NewReader(body)) + req.Host = "example.com:8080" + rec := httptest.NewRecorder() + handler.HandleAgentInstallCommand(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status OK, got %d: %s", rec.Code, rec.Body.String()) + } + + var resp AgentInstallCommandResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if 
resp.Token == "" || resp.Command == "" { + t.Fatalf("expected token and command in response") + } + if !bytes.Contains([]byte(resp.Command), []byte(resp.Token)) { + t.Fatalf("expected command to include token") + } + if len(cfg.APITokens) != 1 { + t.Fatalf("expected API token to be persisted") + } +} diff --git a/internal/api/config_handlers_helpers_additional_test.go b/internal/api/config_handlers_helpers_additional_test.go new file mode 100644 index 000000000..4e130ff59 --- /dev/null +++ b/internal/api/config_handlers_helpers_additional_test.go @@ -0,0 +1,67 @@ +package api + +import ( + "net" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/pkg/proxmox" +) + +func TestFindExistingIPOverride(t *testing.T) { + endpoints := []config.ClusterEndpoint{ + {NodeName: "node1", IPOverride: "10.0.0.10"}, + {NodeName: "node2", IPOverride: "10.0.0.11"}, + } + + if got := findExistingIPOverride("node2", endpoints); got != "10.0.0.11" { + t.Fatalf("findExistingIPOverride = %q, want 10.0.0.11", got) + } + if got := findExistingIPOverride("missing", endpoints); got != "" { + t.Fatalf("findExistingIPOverride = %q, want empty", got) + } +} + +func TestExtractIPFromHost(t *testing.T) { + ip := extractIPFromHost("https://10.1.1.5:8006") + if ip == nil || !ip.Equal(net.ParseIP("10.1.1.5")) { + t.Fatalf("extractIPFromHost returned %v, want 10.1.1.5", ip) + } + + ip = extractIPFromHost("10.2.3.4") + if ip == nil || !ip.Equal(net.ParseIP("10.2.3.4")) { + t.Fatalf("extractIPFromHost returned %v, want 10.2.3.4", ip) + } +} + +func TestIPsOnSameNetwork(t *testing.T) { + if !ipsOnSameNetwork(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.50")) { + t.Fatalf("expected 10.0.0.1 and 10.0.0.50 to match") + } + if ipsOnSameNetwork(net.ParseIP("10.0.0.1"), net.ParseIP("10.1.0.1")) { + t.Fatalf("expected 10.0.0.1 and 10.1.0.1 to differ") + } + + ipv6a := net.ParseIP("2001:db8::1") + ipv6b := net.ParseIP("2001:db8::2") + if !ipsOnSameNetwork(ipv6a, ipv6b) { + t.Fatalf("expected IPv6 addresses to match") + } +} + +func TestFindPreferredIP(t *testing.T) { + interfaces := []proxmox.NodeNetworkInterface{ + {Active: 0, Address: "10.0.0.10"}, + {Active: 1, Address: "10.0.0.11"}, + {Active: 1, Address: "10.0.1.10"}, + } + + ref := net.ParseIP("10.0.0.50") + if got := findPreferredIP(interfaces, ref); got != "10.0.0.11" { + t.Fatalf("findPreferredIP = %q, want 10.0.0.11", got) + } + + if got := findPreferredIP(nil, ref); got != "" { + t.Fatalf("findPreferredIP = %q, want empty", got) + } +} diff --git a/internal/api/config_handlers_secure_auto_register_test.go b/internal/api/config_handlers_secure_auto_register_test.go new file mode 100644 index 000000000..20b45219c --- /dev/null +++ b/internal/api/config_handlers_secure_auto_register_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestHandleSecureAutoRegister_PVE(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", tempDir) + + cfg := &config.Config{ + DataPath: tempDir, + ConfigPath: tempDir, + } + handler := newTestConfigHandlers(t, cfg) + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + reqBody := AutoRegisterRequest{ + Type: "pve", + Host: server.URL, + ServerName: "test-node", + RequestToken: true, + Username: "root@pam", + Password: 
"secret", + } + + req := httptest.NewRequest(http.MethodPost, "/api/auto-register/secure", nil) + rec := httptest.NewRecorder() + + handler.handleSecureAutoRegister(rec, req, &reqBody, "127.0.0.1") + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["status"] != "success" { + t.Fatalf("status = %v, want success", resp["status"]) + } + if resp["tokenId"] == "" || resp["tokenValue"] == "" { + t.Fatalf("expected token details in response") + } + if resp["action"] != "create_token" { + t.Fatalf("action = %v, want create_token", resp["action"]) + } + + if len(handler.legacyConfig.PVEInstances) != 1 { + t.Fatalf("expected 1 PVE instance, got %d", len(handler.legacyConfig.PVEInstances)) + } + instance := handler.legacyConfig.PVEInstances[0] + if !strings.Contains(instance.Host, "https://") { + t.Fatalf("expected normalized host, got %q", instance.Host) + } + if instance.TokenName == "" || instance.TokenValue == "" { + t.Fatalf("expected stored token values") + } +} diff --git a/internal/api/config_profiles_additional_test.go b/internal/api/config_profiles_additional_test.go new file mode 100644 index 000000000..0cc6065d0 --- /dev/null +++ b/internal/api/config_profiles_additional_test.go @@ -0,0 +1,199 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +func newConfigProfileHandler(t *testing.T) (*ConfigProfileHandler, *config.ConfigPersistence) { + t.Helper() + tempDir := t.TempDir() + mtp := config.NewMultiTenantPersistence(tempDir) + persistence, err := mtp.GetPersistence("default") + if err != nil { + t.Fatalf("GetPersistence: %v", err) + } + handler := NewConfigProfileHandler(mtp) + return handler, persistence +} + +func createProfile(t *testing.T, handler *ConfigProfileHandler, name string, cfg models.AgentConfigMap) models.AgentProfile { + t.Helper() + profile := models.AgentProfile{ + Name: name, + Config: cfg, + } + body, _ := json.Marshal(profile) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req = req.WithContext(context.WithValue(req.Context(), "username", "tester")) + rec := httptest.NewRecorder() + + handler.CreateProfile(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("CreateProfile status = %d, body=%s", rec.Code, rec.Body.String()) + } + + var created models.AgentProfile + if err := json.NewDecoder(rec.Body).Decode(&created); err != nil { + t.Fatalf("decode create response: %v", err) + } + return created +} + +func TestConfigProfileHandler_GetProfile(t *testing.T) { + handler, _ := newConfigProfileHandler(t) + created := createProfile(t, handler, "Profile One", models.AgentConfigMap{"interval": "10s"}) + + req := httptest.NewRequest(http.MethodGet, "/"+created.ID, nil) + rec := httptest.NewRecorder() + + handler.GetProfile(rec, req, created.ID) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var got models.AgentProfile + if err := json.NewDecoder(rec.Body).Decode(&got); err != nil { + t.Fatalf("decode response: %v", err) + } + if got.ID != created.ID { + t.Fatalf("profile ID = %s, want %s", got.ID, created.ID) + } +} + +func TestConfigProfileHandler_GetChangeLog_Filtered(t *testing.T) { + handler, persistence := 
newConfigProfileHandler(t) + created := createProfile(t, handler, "Profile Log", models.AgentConfigMap{"log_level": "debug"}) + + change := models.ProfileChangeLog{ + ID: "log-1", + ProfileID: created.ID, + ProfileName: created.Name, + Action: "create", + NewVersion: 1, + User: "tester", + } + if err := persistence.AppendProfileChangeLog(change); err != nil { + t.Fatalf("AppendProfileChangeLog: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/changelog?profile_id="+created.ID, nil) + rec := httptest.NewRecorder() + + handler.GetChangeLog(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var logs []models.ProfileChangeLog + if err := json.NewDecoder(rec.Body).Decode(&logs); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(logs) == 0 { + t.Fatalf("expected change log entries") + } + if logs[0].ProfileID != created.ID { + t.Fatalf("profile_id = %s, want %s", logs[0].ProfileID, created.ID) + } +} + +func TestConfigProfileHandler_DeploymentStatusLifecycle(t *testing.T) { + handler, _ := newConfigProfileHandler(t) + created := createProfile(t, handler, "Profile Deploy", models.AgentConfigMap{"feature": true}) + + update := models.ProfileDeploymentStatus{ + AgentID: "agent-1", + ProfileID: created.ID, + AssignedVersion: created.Version, + DeployedVersion: created.Version, + DeploymentStatus: "deployed", + } + body, _ := json.Marshal(update) + req := httptest.NewRequest(http.MethodPost, "/deployments", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.UpdateDeploymentStatus(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/deployments?agent_id=agent-1", nil) + rec = httptest.NewRecorder() + handler.GetDeploymentStatus(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var statuses []models.ProfileDeploymentStatus + if err := json.NewDecoder(rec.Body).Decode(&statuses); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("statuses = %d, want 1", len(statuses)) + } + if statuses[0].AgentID != "agent-1" { + t.Fatalf("agent_id = %s, want agent-1", statuses[0].AgentID) + } +} + +func TestConfigProfileHandler_VersionsAndRollback(t *testing.T) { + handler, _ := newConfigProfileHandler(t) + created := createProfile(t, handler, "Profile Versioned", models.AgentConfigMap{"log_level": "debug"}) + + update := models.AgentProfile{ + Name: "Profile Versioned", + Config: models.AgentConfigMap{"log_level": "info"}, + } + updateBody, _ := json.Marshal(update) + updateReq := httptest.NewRequest(http.MethodPut, "/"+created.ID, bytes.NewReader(updateBody)) + updateReq = updateReq.WithContext(context.WithValue(updateReq.Context(), "username", "tester")) + updateRec := httptest.NewRecorder() + handler.UpdateProfile(updateRec, updateReq, created.ID) + if updateRec.Code != http.StatusOK { + t.Fatalf("UpdateProfile status = %d, body=%s", updateRec.Code, updateRec.Body.String()) + } + + req := httptest.NewRequest(http.MethodGet, "/"+created.ID+"/versions", nil) + rec := httptest.NewRecorder() + handler.GetProfileVersions(rec, req, created.ID) + if rec.Code != http.StatusOK { + t.Fatalf("GetProfileVersions status = %d", rec.Code) + } + var versions []models.AgentProfileVersion + if err := json.NewDecoder(rec.Body).Decode(&versions); err != nil { + t.Fatalf("decode versions: %v", err) + } + if len(versions) < 2 { + t.Fatalf("expected multiple versions, 
got %d", len(versions)) + } + + rollbackReq := httptest.NewRequest(http.MethodPost, "/"+created.ID+"/rollback/1", nil) + rollbackReq = rollbackReq.WithContext(context.WithValue(rollbackReq.Context(), "username", "tester")) + rollbackRec := httptest.NewRecorder() + handler.RollbackProfile(rollbackRec, rollbackReq, created.ID, "1") + if rollbackRec.Code != http.StatusOK { + t.Fatalf("RollbackProfile status = %d, body=%s", rollbackRec.Code, rollbackRec.Body.String()) + } + + var rolled models.AgentProfile + if err := json.NewDecoder(rollbackRec.Body).Decode(&rolled); err != nil { + t.Fatalf("decode rollback response: %v", err) + } + if rolled.Version != created.Version+2 { + t.Fatalf("version = %d, want %d", rolled.Version, created.Version+2) + } + if rolled.Config["log_level"] != "debug" { + t.Fatalf("config log_level = %v, want debug", rolled.Config["log_level"]) + } +} diff --git a/internal/api/diagnostics_additional_test.go b/internal/api/diagnostics_additional_test.go new file mode 100644 index 000000000..7e7ae0ff9 --- /dev/null +++ b/internal/api/diagnostics_additional_test.go @@ -0,0 +1,447 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" + agentsdocker "github.com/rcourtman/pulse-go-rewrite/pkg/agents/docker" + "github.com/rcourtman/pulse-go-rewrite/pkg/proxmox" +) + +type stubAIPersistence struct { + cfg *config.AIConfig + dataDir string + err error +} + +func (s stubAIPersistence) LoadAIConfig() (*config.AIConfig, error) { + if s.err != nil { + return nil, s.err + } + return s.cfg, nil +} + +func (s stubAIPersistence) DataDir() string { + return s.dataDir +} + +type proxmoxTestResponse struct { + status int + body string +} + +func newProxmoxTestServer(t *testing.T, responses map[string]proxmoxTestResponse) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp, ok := responses[r.URL.Path] + if !ok { + http.NotFound(w, r) + return + } + status := resp.status + if status == 0 { + status = http.StatusOK + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if resp.body != "" { + _, _ = w.Write([]byte(resp.body)) + } + })) +} + +func newProxmoxClient(t *testing.T, serverURL string) *proxmox.Client { + t.Helper() + + client, err := proxmox.NewClient(proxmox.ClientConfig{ + Host: serverURL, + TokenName: "user@pam!token", + TokenValue: "secret", + VerifySSL: false, + }) + if err != nil { + t.Fatalf("proxmox.NewClient: %v", err) + } + return client +} + +func newMonitorForDiagnostics(t *testing.T, cfg *config.Config) *monitoring.Monitor { + t.Helper() + + if cfg == nil { + cfg = &config.Config{DataPath: t.TempDir()} + } + if cfg.DataPath == "" { + cfg.DataPath = t.TempDir() + } + + monitor, err := monitoring.New(cfg) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { monitor.Stop() }) + return monitor +} + +func TestHandleDiagnostics_CacheHit(t *testing.T) { + cached := DiagnosticsInfo{Version: "cached"} + cachedAt := time.Now() + + diagnosticsCacheMu.Lock() + prevCache := diagnosticsCache + prevTimestamp := diagnosticsCacheTimestamp + diagnosticsCache = cached + diagnosticsCacheTimestamp = cachedAt + diagnosticsCacheMu.Unlock() + + t.Cleanup(func() { + diagnosticsCacheMu.Lock() + diagnosticsCache = prevCache + diagnosticsCacheTimestamp = prevTimestamp + 
diagnosticsCacheMu.Unlock() + }) + + router := &Router{config: &config.Config{}} + req := httptest.NewRequest(http.MethodGet, "/api/diagnostics", nil) + rec := httptest.NewRecorder() + + router.handleDiagnostics(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var payload DiagnosticsInfo + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode diagnostics: %v", err) + } + if payload.Version != "cached" { + t.Fatalf("version = %q, want cached", payload.Version) + } + if rec.Header().Get("X-Diagnostics-Cached-At") == "" { + t.Fatalf("expected cached-at header to be set") + } +} + +func TestComputeDiagnostics_Basic(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + monitor := newMonitorForDiagnostics(t, cfg) + + router := &Router{config: cfg, monitor: monitor} + diag := router.computeDiagnostics(context.Background()) + + if diag.System.OS == "" { + t.Fatalf("expected system OS to be populated") + } + if diag.MetricsStore == nil { + t.Fatalf("expected metrics store diagnostics") + } + if diag.APITokens == nil { + t.Fatalf("expected api token diagnostics") + } + if diag.Discovery == nil { + t.Fatalf("expected discovery diagnostics") + } + if diag.AIChat == nil { + t.Fatalf("expected ai chat diagnostics") + } +} + +func TestBuildAPITokenDiagnostic_WithDockerUsage(t *testing.T) { + now := time.Now() + lastUsed := now.Add(-time.Hour) + cfg := &config.Config{ + DataPath: t.TempDir(), + EnvOverrides: map[string]bool{ + "API_TOKEN": true, + }, + APITokens: []config.APITokenRecord{ + { + ID: "token-1", + Name: "Environment token", + Prefix: "pre", + Suffix: "suf", + CreatedAt: now, + }, + { + ID: "token-2", + Name: "Legacy token", + CreatedAt: now, + LastUsedAt: &lastUsed, + }, + }, + } + + monitor := newMonitorForDiagnostics(t, cfg) + + report := agentsdocker.Report{ + Agent: agentsdocker.AgentInfo{ + ID: "agent-1", + Version: "1.0.0", + }, + Host: agentsdocker.HostInfo{ + Hostname: "docker-host-1", + MachineID: "machine-1", + }, + Timestamp: time.Now().UTC(), + } + if _, err := monitor.ApplyDockerReport(report, &config.APITokenRecord{ID: "token-1"}); err != nil { + t.Fatalf("ApplyDockerReport: %v", err) + } + + legacyReport := report + legacyReport.Agent.ID = "agent-2" + legacyReport.Host.Hostname = "docker-legacy" + legacyReport.Host.MachineID = "machine-2" + if _, err := monitor.ApplyDockerReport(legacyReport, nil); err != nil { + t.Fatalf("ApplyDockerReport legacy: %v", err) + } + + diag := buildAPITokenDiagnostic(cfg, monitor) + if diag == nil || !diag.Enabled { + t.Fatalf("expected diagnostics enabled") + } + if diag.TokenCount != 2 { + t.Fatalf("token count = %d, want 2", diag.TokenCount) + } + if !diag.HasEnvTokens || !diag.HasLegacyToken { + t.Fatalf("expected env and legacy tokens to be detected") + } + if diag.LegacyDockerHostCount != 1 { + t.Fatalf("legacy docker host count = %d, want 1", diag.LegacyDockerHostCount) + } + if diag.UnusedTokenCount != 1 { + t.Fatalf("unused token count = %d, want 1", diag.UnusedTokenCount) + } + if len(diag.Usage) != 1 || diag.Usage[0].TokenID != "token-1" { + t.Fatalf("expected token usage for token-1") + } +} + +func TestBuildDockerAgentDiagnostic(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + monitor := newMonitorForDiagnostics(t, cfg) + + now := time.Now().UTC() + report := agentsdocker.Report{ + Agent: agentsdocker.AgentInfo{ + ID: "agent-1", + Version: "0.9.0", + }, + Host: agentsdocker.HostInfo{ + Hostname: "docker-old", + MachineID: 
"machine-old", + }, + Timestamp: now.Add(-5 * time.Minute), + } + if _, err := monitor.ApplyDockerReport(report, &config.APITokenRecord{ID: "token-1"}); err != nil { + t.Fatalf("ApplyDockerReport: %v", err) + } + + legacyReport := report + legacyReport.Agent.ID = "agent-2" + legacyReport.Agent.Version = "" + legacyReport.Host.Hostname = "docker-legacy" + legacyReport.Host.MachineID = "machine-legacy" + legacyReport.Timestamp = now.Add(-20 * time.Minute) + if _, err := monitor.ApplyDockerReport(legacyReport, nil); err != nil { + t.Fatalf("ApplyDockerReport legacy: %v", err) + } + + diag := buildDockerAgentDiagnostic(monitor, "1.0.0") + if diag == nil { + t.Fatalf("expected diagnostics") + } + if diag.HostsTotal != 2 { + t.Fatalf("hosts total = %d, want 2", diag.HostsTotal) + } + if diag.HostsOutdatedVersion != 1 { + t.Fatalf("outdated version count = %d, want 1", diag.HostsOutdatedVersion) + } + if diag.HostsWithoutVersion != 1 { + t.Fatalf("missing version count = %d, want 1", diag.HostsWithoutVersion) + } + if diag.HostsWithoutTokenBinding != 1 { + t.Fatalf("hosts without token binding = %d, want 1", diag.HostsWithoutTokenBinding) + } + if diag.HostsNeedingAttention == 0 { + t.Fatalf("expected hosts needing attention") + } +} + +func TestBuildAlertsDiagnostic_LegacySettings(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir()} + monitor := newMonitorForDiagnostics(t, cfg) + + manager := monitor.GetAlertManager() + alertCfg := manager.GetConfig() + legacy := 90.0 + alertCfg.GuestDefaults.CPULegacy = &legacy + alertCfg.TimeThreshold = 5 + alertCfg.Schedule.GroupingWindow = 10 + alertCfg.Schedule.Grouping.Window = 0 + alertCfg.Schedule.Cooldown = 0 + manager.UpdateConfig(alertCfg) + + diag := buildAlertsDiagnostic(monitor) + if diag == nil { + t.Fatalf("expected diagnostics") + } + if !diag.LegacyThresholdsDetected { + t.Fatalf("expected legacy thresholds to be detected") + } + if !diag.MissingCooldown || !diag.MissingGroupingWindow { + t.Fatalf("expected missing schedule settings") + } + if len(diag.Notes) == 0 { + t.Fatalf("expected notes to be populated") + } +} + +func TestBuildDiscoveryDiagnostic_ConfigOnly(t *testing.T) { + cfg := &config.Config{ + DiscoveryEnabled: true, + DiscoverySubnet: "", + Discovery: config.DiscoveryConfig{ + EnvironmentOverride: " 10.0.0.0/24 ", + }, + } + + diag := buildDiscoveryDiagnostic(cfg, nil) + if diag == nil { + t.Fatalf("expected discovery diagnostics") + } + if diag.ConfiguredSubnet != "auto" { + t.Fatalf("configured subnet = %q, want auto", diag.ConfiguredSubnet) + } + if diag.EnvironmentOverride != "10.0.0.0/24" { + t.Fatalf("environment override = %q, want trimmed", diag.EnvironmentOverride) + } + if diag.SubnetAllowlist == nil { + t.Fatalf("expected allowlist to be initialized") + } +} + +func TestBuildAIChatDiagnostic_WithService(t *testing.T) { + aiCfg := &config.AIConfig{ + Enabled: true, + Provider: config.AIProviderOllama, + ChatModel: "ollama:llama3", + } + handler := &AIHandler{ + legacyPersistence: stubAIPersistence{cfg: aiCfg, dataDir: t.TempDir()}, + } + + mockSvc := new(MockAIService) + mockSvc.On("IsRunning").Return(true) + mockSvc.On("GetBaseURL").Return("http://localhost:1234") + handler.legacyService = mockSvc + + diag := buildAIChatDiagnostic(&config.Config{}, handler) + if diag == nil || !diag.Enabled { + t.Fatalf("expected ai chat to be enabled") + } + if !diag.Running || !diag.Healthy { + t.Fatalf("expected ai chat to be running and healthy") + } + if diag.Port != 1234 { + t.Fatalf("port = %d, want 1234", diag.Port) + 
} + if diag.Model != "ollama:llama3" { + t.Fatalf("model = %q, want ollama:llama3", diag.Model) + } +} + +func TestCheckVMDiskMonitoring_Success(t *testing.T) { + responses := map[string]proxmoxTestResponse{ + "/api2/json/nodes": {body: `{"data":[{"node":"pve1","status":"online"}]}`}, + "/api2/json/nodes/pve1/qemu": {body: `{"data":[{"vmid":100,"name":"vm-100","node":"pve1","status":"running","template":0}]}`}, + "/api2/json/nodes/pve1/qemu/100/status/current": {body: `{"data":{"agent":1}}`}, + "/api2/json/nodes/pve1/qemu/100/agent/get-fsinfo": {body: `{"data":{"result":[{"name":"root","type":"ext4","mountpoint":"/","total-bytes":100,"used-bytes":50}]}}`}, + } + server := newProxmoxTestServer(t, responses) + defer server.Close() + + client := newProxmoxClient(t, server.URL) + router := &Router{} + result := router.checkVMDiskMonitoring(context.Background(), client, "") + + if result.VMsFound != 1 || result.VMsWithAgent != 1 || result.VMsWithDiskData != 1 { + t.Fatalf("unexpected VM stats: %+v", result) + } + if !strings.Contains(result.TestResult, "SUCCESS") { + t.Fatalf("expected success test result, got %q", result.TestResult) + } +} + +func TestCheckVMDiskMonitoring_NoNodes(t *testing.T) { + responses := map[string]proxmoxTestResponse{ + "/api2/json/nodes": {body: `{"data":[]}`}, + } + server := newProxmoxTestServer(t, responses) + defer server.Close() + + client := newProxmoxClient(t, server.URL) + router := &Router{} + result := router.checkVMDiskMonitoring(context.Background(), client, "") + if result.TestResult != "No nodes found" { + t.Fatalf("test result = %q, want No nodes found", result.TestResult) + } +} + +func TestCheckPhysicalDisks_Found(t *testing.T) { + responses := map[string]proxmoxTestResponse{ + "/api2/json/nodes": {body: `{"data":[{"node":"pve1","status":"online"}]}`}, + "/api2/json/nodes/pve1/disks/list": {body: `{"data":[{"devpath":"/dev/sda","model":"Test","serial":"ABC","type":"sata","health":"PASSED"}]}`}, + } + server := newProxmoxTestServer(t, responses) + defer server.Close() + + client := newProxmoxClient(t, server.URL) + router := &Router{} + result := router.checkPhysicalDisks(context.Background(), client, "") + + if result.NodesWithDisks != 1 || result.TotalDisks != 1 { + t.Fatalf("unexpected disk totals: %+v", result) + } + if !strings.Contains(result.TestResult, "Found 1 disks") { + t.Fatalf("unexpected test result: %q", result.TestResult) + } +} + +func TestCheckPhysicalDisks_PermissionDenied(t *testing.T) { + responses := map[string]proxmoxTestResponse{ + "/api2/json/nodes": {body: `{"data":[{"node":"pve1","status":"online"}]}`}, + "/api2/json/nodes/pve1/disks/list": {status: http.StatusForbidden, body: `{"errors":"forbidden"}`}, + } + server := newProxmoxTestServer(t, responses) + defer server.Close() + + client := newProxmoxClient(t, server.URL) + router := &Router{} + result := router.checkPhysicalDisks(context.Background(), client, "") + + if len(result.NodeResults) != 1 { + t.Fatalf("expected one node result") + } + if result.NodeResults[0].APIResponse != "Permission denied" { + t.Fatalf("api response = %q, want Permission denied", result.NodeResults[0].APIResponse) + } + foundNote := false + for _, note := range result.Recommendations { + if strings.Contains(note, "permissions") { + foundNote = true + break + } + } + if !foundNote { + t.Fatalf("expected permissions recommendation") + } +} diff --git a/internal/api/docker_agents_additional_test.go b/internal/api/docker_agents_additional_test.go new file mode 100644 index 000000000..d10d2abb1 --- 
/dev/null +++ b/internal/api/docker_agents_additional_test.go @@ -0,0 +1,256 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" + "github.com/rcourtman/pulse-go-rewrite/internal/websocket" + agentsdocker "github.com/rcourtman/pulse-go-rewrite/pkg/agents/docker" +) + +func newDockerAgentHandlers(t *testing.T, cfg *config.Config) (*DockerAgentHandlers, *monitoring.Monitor) { + t.Helper() + + if cfg == nil { + cfg = &config.Config{DataPath: t.TempDir()} + } + if cfg.DataPath == "" { + cfg.DataPath = t.TempDir() + } + + monitor, err := monitoring.New(cfg) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { monitor.Stop() }) + + hub := websocket.NewHub(nil) + handler := NewDockerAgentHandlers(nil, monitor, hub, cfg) + return handler, monitor +} + +func seedDockerHost(t *testing.T, monitor *monitoring.Monitor) string { + t.Helper() + + report := agentsdocker.Report{ + Agent: agentsdocker.AgentInfo{ + ID: "agent-1", + Version: "1.0.0", + IntervalSeconds: 30, + }, + Host: agentsdocker.HostInfo{ + Hostname: "docker-host", + Name: "Docker Host", + MachineID: "machine-1", + DockerVersion: "26.0.0", + TotalCPU: 4, + TotalMemoryBytes: 8 << 30, + UptimeSeconds: 120, + }, + Timestamp: time.Now().UTC(), + } + + host, err := monitor.ApplyDockerReport(report, nil) + if err != nil { + t.Fatalf("ApplyDockerReport: %v", err) + } + if host.ID == "" { + t.Fatalf("expected host ID to be set") + } + return host.ID +} + +func TestDockerAgentHandlers_HandleReport(t *testing.T) { + handler, _ := newDockerAgentHandlers(t, nil) + + report := agentsdocker.Report{ + Agent: agentsdocker.AgentInfo{ + ID: "agent-2", + Version: "1.0.0", + IntervalSeconds: 30, + }, + Host: agentsdocker.HostInfo{ + Hostname: "docker-host-2", + Name: "Docker Host 2", + MachineID: "machine-2", + DockerVersion: "26.0.0", + TotalCPU: 4, + TotalMemoryBytes: 4 << 30, + UptimeSeconds: 60, + }, + Timestamp: time.Now().UTC(), + } + body, _ := json.Marshal(report) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/docker/report", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleReport(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp["success"] != true { + t.Fatalf("success = %v, want true", resp["success"]) + } + if resp["hostId"] == "" { + t.Fatalf("expected hostId in response") + } +} + +func TestDockerAgentHandlers_HandleCommandAck(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + cmdStatus, err := monitor.QueueDockerHostStop(hostID) + if err != nil { + t.Fatalf("QueueDockerHostStop: %v", err) + } + + reqBody := map[string]string{ + "hostId": hostID, + "status": "completed", + } + body, _ := json.Marshal(reqBody) + path := "/api/agents/docker/commands/" + cmdStatus.ID + "/ack" + req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleCommandAck(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleDockerHostActions(t *testing.T) { + handler, monitor := 
newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/docker/hosts/"+hostID+"/allow-reenroll", nil) + rec := httptest.NewRecorder() + + handler.HandleDockerHostActions(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleDeleteHost(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + req := httptest.NewRequest(http.MethodDelete, "/api/agents/docker/hosts/"+hostID+"?force=true", nil) + rec := httptest.NewRecorder() + + handler.HandleDeleteHost(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleUnhideHost(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + req := httptest.NewRequest(http.MethodPut, "/api/agents/docker/hosts/"+hostID+"/unhide", nil) + rec := httptest.NewRecorder() + + handler.HandleUnhideHost(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleMarkPendingUninstall(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + req := httptest.NewRequest(http.MethodPut, "/api/agents/docker/hosts/"+hostID+"/pending-uninstall", nil) + rec := httptest.NewRecorder() + + handler.HandleMarkPendingUninstall(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleSetCustomDisplayName(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + body := []byte(`{"displayName":"My Docker Host"}`) + req := httptest.NewRequest(http.MethodPut, "/api/agents/docker/hosts/"+hostID+"/display-name", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleSetCustomDisplayName(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleContainerUpdate(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, &config.Config{DataPath: t.TempDir()}) + hostID := seedDockerHost(t, monitor) + + reqBody := map[string]string{ + "hostId": hostID, + "containerId": "container-1", + "containerName": "nginx", + } + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/agents/docker/containers/container-1/update", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleContainerUpdate(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestDockerAgentHandlers_HandleContainerUpdate_Disabled(t *testing.T) { + cfg := &config.Config{DataPath: t.TempDir(), DisableDockerUpdateActions: true} + handler, monitor := newDockerAgentHandlers(t, cfg) + hostID := seedDockerHost(t, monitor) + + reqBody := map[string]string{ + "hostId": hostID, + "containerId": "container-2", + "containerName": "redis", + } + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/agents/docker/containers/container-2/update", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleContainerUpdate(rec, req) + if rec.Code != http.StatusForbidden { + t.Fatalf("status 
= %d, want 403", rec.Code) + } +} + +func TestDockerAgentHandlers_HandleCheckUpdates(t *testing.T) { + handler, monitor := newDockerAgentHandlers(t, nil) + hostID := seedDockerHost(t, monitor) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/docker/hosts/"+hostID+"/check-updates", nil) + rec := httptest.NewRecorder() + + handler.HandleCheckUpdates(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } + + if !strings.Contains(rec.Body.String(), "Check for updates") { + t.Fatalf("expected check updates message") + } +} diff --git a/internal/api/docker_metadata_additional_test.go b/internal/api/docker_metadata_additional_test.go new file mode 100644 index 000000000..a63e24a59 --- /dev/null +++ b/internal/api/docker_metadata_additional_test.go @@ -0,0 +1,174 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func newDockerMetadataHandler(t *testing.T) *DockerMetadataHandler { + t.Helper() + tempDir := t.TempDir() + mtp := config.NewMultiTenantPersistence(tempDir) + if _, err := mtp.GetPersistence("default"); err != nil { + t.Fatalf("GetPersistence: %v", err) + } + return NewDockerMetadataHandler(mtp) +} + +func TestDockerMetadataHandlers_ContainerMetadata(t *testing.T) { + handler := newDockerMetadataHandler(t) + + t.Run("get-all-empty", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/docker/metadata", nil) + rec := httptest.NewRecorder() + + handler.HandleGetMetadata(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + var payload map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(payload) != 0 { + t.Fatalf("expected empty map, got %v", payload) + } + }) + + t.Run("update-invalid-url", func(t *testing.T) { + meta := config.DockerMetadata{CustomURL: "ftp://example.com"} + body, _ := json.Marshal(meta) + req := httptest.NewRequest(http.MethodPut, "/api/docker/metadata/host1:container:abc", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleUpdateMetadata(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + }) + + t.Run("update-get-delete", func(t *testing.T) { + meta := config.DockerMetadata{ + CustomURL: "https://example.com", + Description: "test container", + Tags: []string{"app"}, + } + body, _ := json.Marshal(meta) + req := httptest.NewRequest(http.MethodPut, "/api/docker/metadata/host1:container:abc", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleUpdateMetadata(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + getReq := httptest.NewRequest(http.MethodGet, "/api/docker/metadata/host1:container:abc", nil) + getRec := httptest.NewRecorder() + handler.HandleGetMetadata(getRec, getReq) + if getRec.Code != http.StatusOK { + t.Fatalf("get status = %d, want 200", getRec.Code) + } + var got config.DockerMetadata + if err := json.Unmarshal(getRec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode get response: %v", err) + } + if got.CustomURL != "https://example.com" { + t.Fatalf("custom_url = %q, want https://example.com", got.CustomURL) + } + + delReq := httptest.NewRequest(http.MethodDelete, "/api/docker/metadata/host1:container:abc", nil) + delRec := httptest.NewRecorder() + 
handler.HandleDeleteMetadata(delRec, delReq) + if delRec.Code != http.StatusNoContent { + t.Fatalf("delete status = %d, want 204", delRec.Code) + } + + getReq = httptest.NewRequest(http.MethodGet, "/api/docker/metadata/host1:container:abc", nil) + getRec = httptest.NewRecorder() + handler.HandleGetMetadata(getRec, getReq) + if getRec.Code != http.StatusOK { + t.Fatalf("get status = %d, want 200", getRec.Code) + } + var empty config.DockerMetadata + if err := json.Unmarshal(getRec.Body.Bytes(), &empty); err != nil { + t.Fatalf("decode get response: %v", err) + } + if empty.ID != "host1:container:abc" { + t.Fatalf("expected empty metadata with ID, got %q", empty.ID) + } + }) +} + +func TestDockerMetadataHandlers_HostMetadata(t *testing.T) { + handler := newDockerMetadataHandler(t) + + t.Run("update-and-get-host", func(t *testing.T) { + meta := config.DockerHostMetadata{ + CustomDisplayName: "Host A", + CustomURL: "https://portainer.local", + Notes: []string{"note1"}, + } + body, _ := json.Marshal(meta) + req := httptest.NewRequest(http.MethodPut, "/api/docker/hosts/metadata/host-1", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleUpdateHostMetadata(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + getReq := httptest.NewRequest(http.MethodGet, "/api/docker/hosts/metadata/host-1", nil) + getRec := httptest.NewRecorder() + handler.HandleGetHostMetadata(getRec, getReq) + if getRec.Code != http.StatusOK { + t.Fatalf("get status = %d, want 200", getRec.Code) + } + var got config.DockerHostMetadata + if err := json.Unmarshal(getRec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode get response: %v", err) + } + if got.CustomDisplayName != "Host A" { + t.Fatalf("custom_display_name = %q, want Host A", got.CustomDisplayName) + } + }) + + t.Run("merge-host-metadata", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/docker/hosts/metadata/host-1", bytes.NewReader([]byte(`{}`))) + rec := httptest.NewRecorder() + + handler.HandleUpdateHostMetadata(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + + getReq := httptest.NewRequest(http.MethodGet, "/api/docker/hosts/metadata/host-1", nil) + getRec := httptest.NewRecorder() + handler.HandleGetHostMetadata(getRec, getReq) + if getRec.Code != http.StatusOK { + t.Fatalf("get status = %d, want 200", getRec.Code) + } + var got config.DockerHostMetadata + if err := json.Unmarshal(getRec.Body.Bytes(), &got); err != nil { + t.Fatalf("decode get response: %v", err) + } + if got.CustomDisplayName != "Host A" { + t.Fatalf("expected merged display name, got %q", got.CustomDisplayName) + } + }) + + t.Run("delete-host-metadata", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/api/docker/hosts/metadata/host-1", nil) + rec := httptest.NewRecorder() + + handler.HandleDeleteHostMetadata(rec, req) + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want 204", rec.Code) + } + }) +} diff --git a/internal/api/guest_metadata_additional_test.go b/internal/api/guest_metadata_additional_test.go new file mode 100644 index 000000000..ba4ce059d --- /dev/null +++ b/internal/api/guest_metadata_additional_test.go @@ -0,0 +1,16 @@ +package api + +import ( + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestGuestMetadataHandler_Reload(t *testing.T) { + mtp := config.NewMultiTenantPersistence(t.TempDir()) + handler := NewGuestMetadataHandler(mtp) + + if err := handler.Reload(); err != 
nil { + t.Fatalf("Reload error: %v", err) + } +} diff --git a/internal/api/host_agents_additional_test.go b/internal/api/host_agents_additional_test.go new file mode 100644 index 000000000..61737f768 --- /dev/null +++ b/internal/api/host_agents_additional_test.go @@ -0,0 +1,191 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/models" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" + "github.com/rcourtman/pulse-go-rewrite/internal/websocket" + agentshost "github.com/rcourtman/pulse-go-rewrite/pkg/agents/host" +) + +func newHostAgentHandlers(t *testing.T, cfg *config.Config) (*HostAgentHandlers, *monitoring.Monitor) { + t.Helper() + + if cfg == nil { + cfg = &config.Config{DataPath: t.TempDir()} + } + if cfg.DataPath == "" { + cfg.DataPath = t.TempDir() + } + + monitor, err := monitoring.New(cfg) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { monitor.Stop() }) + + hub := websocket.NewHub(nil) + handler := NewHostAgentHandlers(nil, monitor, hub) + return handler, monitor +} + +func monitorState(t *testing.T, monitor *monitoring.Monitor) *models.State { + t.Helper() + + v := reflect.ValueOf(monitor).Elem().FieldByName("state") + ptr := unsafe.Pointer(v.UnsafeAddr()) + return reflect.NewAt(v.Type(), ptr).Elem().Interface().(*models.State) +} + +func seedHostAgent(t *testing.T, monitor *monitoring.Monitor) string { + t.Helper() + + report := agentshost.Report{ + Agent: agentshost.AgentInfo{ + ID: "agent-1", + Version: "1.0.0", + }, + Host: agentshost.HostInfo{ + ID: "machine-1", + Hostname: "host-1.local", + Platform: "linux", + }, + Timestamp: time.Now().UTC(), + } + + host, err := monitor.ApplyHostReport(report, nil) + if err != nil { + t.Fatalf("ApplyHostReport: %v", err) + } + if host.ID == "" { + t.Fatalf("expected host ID to be set") + } + return host.ID +} + +func TestHostAgentHandlers_HandleReport(t *testing.T) { + handler, _ := newHostAgentHandlers(t, nil) + + report := agentshost.Report{ + Agent: agentshost.AgentInfo{ + ID: "agent-2", + Version: "1.0.0", + }, + Host: agentshost.HostInfo{ + ID: "machine-2", + Hostname: "host-2.local", + Platform: "linux", + }, + Timestamp: time.Now().UTC(), + } + body, _ := json.Marshal(report) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/host/report", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleReport(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestHostAgentHandlers_HandleDeleteHost(t *testing.T) { + handler, monitor := newHostAgentHandlers(t, nil) + hostID := seedHostAgent(t, monitor) + + req := httptest.NewRequest(http.MethodDelete, "/api/agents/host/"+hostID, nil) + rec := httptest.NewRecorder() + + handler.HandleDeleteHost(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestHostAgentHandlers_HandleConfigPatch(t *testing.T) { + handler, monitor := newHostAgentHandlers(t, nil) + hostID := seedHostAgent(t, monitor) + + body := []byte(`{"commandsEnabled":true}`) + req := httptest.NewRequest(http.MethodPatch, "/api/agents/host/"+hostID+"/config", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleConfig(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", 
rec.Code, rec.Body.String()) + } +} + +func TestHostAgentHandlers_EnsureHostTokenMatch(t *testing.T) { + handler, monitor := newHostAgentHandlers(t, nil) + state := monitorState(t, monitor) + state.UpsertHost(models.Host{ + ID: "host-3", + Hostname: "host-3.local", + TokenID: "token-1", + }) + + req := httptest.NewRequest(http.MethodGet, "/api/agents/host/host-3/config", nil) + attachAPITokenRecord(req, &config.APITokenRecord{ + ID: "token-2", + Scopes: []string{config.ScopeHostConfigRead}, + }) + rec := httptest.NewRecorder() + + ok := handler.ensureHostTokenMatch(rec, req, "host-3") + if ok { + t.Fatalf("expected token mismatch") + } + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want 403", rec.Code) + } +} + +func TestHostAgentHandlers_HandleUninstall(t *testing.T) { + handler, monitor := newHostAgentHandlers(t, nil) + hostID := seedHostAgent(t, monitor) + + body := []byte(`{"hostId":"` + hostID + `"}`) + req := httptest.NewRequest(http.MethodPost, "/api/agents/host/uninstall", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleUninstall(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestHostAgentHandlers_HandleLinkUnlink(t *testing.T) { + handler, monitor := newHostAgentHandlers(t, nil) + hostID := seedHostAgent(t, monitor) + + state := monitorState(t, monitor) + state.UpdateNodes([]models.Node{{ID: "node-1", Name: "node-1"}}) + + linkBody := []byte(`{"hostId":"` + hostID + `","nodeId":"node-1"}`) + linkReq := httptest.NewRequest(http.MethodPost, "/api/agents/host/link", bytes.NewReader(linkBody)) + linkRec := httptest.NewRecorder() + + handler.HandleLink(linkRec, linkReq) + if linkRec.Code != http.StatusOK { + t.Fatalf("link status = %d, want 200: %s", linkRec.Code, linkRec.Body.String()) + } + + unlinkBody := []byte(`{"hostId":"` + hostID + `"}`) + unlinkReq := httptest.NewRequest(http.MethodPost, "/api/agents/host/unlink", bytes.NewReader(unlinkBody)) + unlinkRec := httptest.NewRecorder() + + handler.HandleUnlink(unlinkRec, unlinkReq) + if unlinkRec.Code != http.StatusOK { + t.Fatalf("unlink status = %d, want 200: %s", unlinkRec.Code, unlinkRec.Body.String()) + } +} diff --git a/internal/api/kubernetes_agents_additional_test.go b/internal/api/kubernetes_agents_additional_test.go new file mode 100644 index 000000000..79ef4acb7 --- /dev/null +++ b/internal/api/kubernetes_agents_additional_test.go @@ -0,0 +1,194 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" + "github.com/rcourtman/pulse-go-rewrite/internal/websocket" + agentsk8s "github.com/rcourtman/pulse-go-rewrite/pkg/agents/kubernetes" +) + +func newKubernetesAgentHandlers(t *testing.T, cfg *config.Config) (*KubernetesAgentHandlers, *monitoring.Monitor) { + t.Helper() + + if cfg == nil { + cfg = &config.Config{DataPath: t.TempDir()} + } + if cfg.DataPath == "" { + cfg.DataPath = t.TempDir() + } + + monitor, err := monitoring.New(cfg) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { monitor.Stop() }) + + hub := websocket.NewHub(nil) + handler := NewKubernetesAgentHandlers(nil, monitor, hub) + return handler, monitor +} + +func seedKubernetesCluster(t *testing.T, monitor *monitoring.Monitor) string { + t.Helper() + + report := agentsk8s.Report{ + Agent: agentsk8s.AgentInfo{ + ID: 
"agent-1", + Version: "1.0.0", + IntervalSeconds: 30, + }, + Cluster: agentsk8s.ClusterInfo{ + ID: "cluster-1", + Name: "cluster-1", + Version: "1.28.0", + }, + Nodes: []agentsk8s.Node{ + {Name: "node-1", Ready: true}, + }, + Pods: []agentsk8s.Pod{ + {Name: "pod-1", Namespace: "default", Phase: "Running"}, + }, + Timestamp: time.Now().UTC(), + } + + cluster, err := monitor.ApplyKubernetesReport(report, nil) + if err != nil { + t.Fatalf("ApplyKubernetesReport: %v", err) + } + if cluster.ID == "" { + t.Fatalf("expected cluster ID to be set") + } + return cluster.ID +} + +func TestKubernetesAgentHandlers_SetMonitorGetMonitor(t *testing.T) { + handler, monitor := newKubernetesAgentHandlers(t, nil) + + other, err := monitoring.New(&config.Config{DataPath: t.TempDir()}) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { other.Stop() }) + + handler.SetMonitor(other) + if got := handler.getMonitor(context.Background()); got != other { + t.Fatalf("getMonitor = %v, want %v", got, other) + } + + handler.SetMonitor(monitor) + if got := handler.getMonitor(context.Background()); got != monitor { + t.Fatalf("getMonitor = %v, want %v", got, monitor) + } +} + +func TestKubernetesAgentHandlers_HandleReport(t *testing.T) { + handler, _ := newKubernetesAgentHandlers(t, nil) + + report := agentsk8s.Report{ + Agent: agentsk8s.AgentInfo{ + ID: "agent-2", + Version: "1.0.0", + IntervalSeconds: 30, + }, + Cluster: agentsk8s.ClusterInfo{ + ID: "cluster-2", + Name: "cluster-2", + Version: "1.28.0", + }, + Timestamp: time.Now().UTC(), + } + body, _ := json.Marshal(report) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/kubernetes/report", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleReport(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleClusterActions(t *testing.T) { + handler, _ := newKubernetesAgentHandlers(t, nil) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/kubernetes/clusters/cluster-1/allow-reenroll", nil) + rec := httptest.NewRecorder() + + handler.HandleClusterActions(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleDeleteCluster(t *testing.T) { + handler, monitor := newKubernetesAgentHandlers(t, nil) + clusterID := seedKubernetesCluster(t, monitor) + + req := httptest.NewRequest(http.MethodDelete, "/api/agents/kubernetes/clusters/"+clusterID, nil) + rec := httptest.NewRecorder() + + handler.HandleDeleteCluster(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleAllowReenroll(t *testing.T) { + handler, _ := newKubernetesAgentHandlers(t, nil) + + req := httptest.NewRequest(http.MethodPost, "/api/agents/kubernetes/clusters/cluster-1/allow-reenroll", nil) + rec := httptest.NewRecorder() + + handler.HandleAllowReenroll(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleUnhideCluster(t *testing.T) { + handler, monitor := newKubernetesAgentHandlers(t, nil) + clusterID := seedKubernetesCluster(t, monitor) + + req := httptest.NewRequest(http.MethodPut, "/api/agents/kubernetes/clusters/"+clusterID+"/unhide", nil) + rec := httptest.NewRecorder() + + handler.HandleUnhideCluster(rec, req) + if 
rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleMarkPendingUninstall(t *testing.T) { + handler, monitor := newKubernetesAgentHandlers(t, nil) + clusterID := seedKubernetesCluster(t, monitor) + + req := httptest.NewRequest(http.MethodPut, "/api/agents/kubernetes/clusters/"+clusterID+"/pending-uninstall", nil) + rec := httptest.NewRecorder() + + handler.HandleMarkPendingUninstall(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestKubernetesAgentHandlers_HandleSetCustomDisplayName(t *testing.T) { + handler, monitor := newKubernetesAgentHandlers(t, nil) + clusterID := seedKubernetesCluster(t, monitor) + + body := []byte(`{"displayName":"Custom Cluster"}`) + req := httptest.NewRequest(http.MethodPut, "/api/agents/kubernetes/clusters/"+clusterID+"/display-name", bytes.NewReader(body)) + rec := httptest.NewRecorder() + + handler.HandleSetCustomDisplayName(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/api/multi_tenant_setters_additional_test.go b/internal/api/multi_tenant_setters_additional_test.go new file mode 100644 index 000000000..319a6b6d4 --- /dev/null +++ b/internal/api/multi_tenant_setters_additional_test.go @@ -0,0 +1,32 @@ +package api + +import "testing" + +func TestConfigHandlersSetMultiTenantMonitor(t *testing.T) { + handler := &ConfigHandlers{} + handler.SetMultiTenantMonitor(nil) + if handler.mtMonitor != nil { + t.Fatalf("mtMonitor should be nil after SetMultiTenantMonitor(nil)") + } +} + +func TestRouterSetMultiTenantMonitor(t *testing.T) { + router := &Router{ + alertHandlers: &AlertHandlers{}, + notificationHandlers: &NotificationHandlers{}, + dockerAgentHandlers: &DockerAgentHandlers{}, + hostAgentHandlers: &HostAgentHandlers{}, + kubernetesAgentHandlers: &KubernetesAgentHandlers{}, + systemSettingsHandler: &SystemSettingsHandler{}, + resourceHandlers: NewResourceHandlers(), + } + + router.SetMultiTenantMonitor(nil) + + if router.mtMonitor != nil { + t.Fatalf("mtMonitor should be nil after SetMultiTenantMonitor(nil)") + } + if router.resourceHandlers.tenantStateProvider == nil { + t.Fatalf("tenantStateProvider should be set on resource handlers") + } +} diff --git a/internal/api/notification_queue_additional_test.go b/internal/api/notification_queue_additional_test.go new file mode 100644 index 000000000..2a5ec578a --- /dev/null +++ b/internal/api/notification_queue_additional_test.go @@ -0,0 +1,117 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/monitoring" + "github.com/rcourtman/pulse-go-rewrite/internal/notifications" +) + +func newNotificationQueueHandlers(t *testing.T) (*NotificationQueueHandlers, *notifications.NotificationQueue) { + t.Helper() + + t.Setenv("PULSE_DATA_DIR", t.TempDir()) + cfg := &config.Config{DataPath: t.TempDir()} + + monitor, err := monitoring.New(cfg) + if err != nil { + t.Fatalf("monitoring.New: %v", err) + } + t.Cleanup(func() { monitor.Stop() }) + + queue := monitor.GetNotificationManager().GetQueue() + if queue == nil { + t.Fatalf("expected notification queue to be initialized") + } + + handler := NewNotificationQueueHandlers(monitor) + return handler, queue 
+} + +func enqueueDLQNotification(t *testing.T, queue *notifications.NotificationQueue, id string) { + t.Helper() + + notification := &notifications.QueuedNotification{ + ID: id, + Type: "webhook", + Status: notifications.QueueStatusDLQ, + Alerts: []*alerts.Alert{{ID: "alert-1", Type: "test"}}, + Config: json.RawMessage(`{}`), + } + if err := queue.Enqueue(notification); err != nil { + t.Fatalf("queue.Enqueue: %v", err) + } +} + +func TestNotificationQueueHandlers_GetDLQAndStats(t *testing.T) { + handler, queue := newNotificationQueueHandlers(t) + enqueueDLQNotification(t, queue, "notif-1") + + req := httptest.NewRequest(http.MethodGet, "/api/notifications/dlq?limit=10", nil) + rec := httptest.NewRecorder() + handler.GetDLQ(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("GetDLQ status = %d, want 200", rec.Code) + } + + var dlq []notifications.QueuedNotification + if err := json.Unmarshal(rec.Body.Bytes(), &dlq); err != nil { + t.Fatalf("decode DLQ: %v", err) + } + if len(dlq) != 1 || dlq[0].ID != "notif-1" { + t.Fatalf("DLQ = %+v, want notif-1", dlq) + } + + req = httptest.NewRequest(http.MethodGet, "/api/notifications/queue/stats", nil) + rec = httptest.NewRecorder() + handler.GetQueueStats(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("GetQueueStats status = %d, want 200", rec.Code) + } +} + +func TestNotificationQueueHandlers_RetryAndDelete(t *testing.T) { + handler, queue := newNotificationQueueHandlers(t) + enqueueDLQNotification(t, queue, "notif-2") + + retryBody := []byte(`{"id":"notif-2"}`) + req := httptest.NewRequest(http.MethodPost, "/api/notifications/dlq/retry", bytes.NewReader(retryBody)) + rec := httptest.NewRecorder() + handler.RetryDLQItem(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("RetryDLQItem status = %d, want 200: %s", rec.Code, rec.Body.String()) + } + + deleteBody := []byte(`{"id":"notif-2"}`) + req = httptest.NewRequest(http.MethodPost, "/api/notifications/dlq/delete", bytes.NewReader(deleteBody)) + rec = httptest.NewRecorder() + handler.DeleteDLQItem(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("DeleteDLQItem status = %d, want 200: %s", rec.Code, rec.Body.String()) + } +} + +func TestNotificationQueueHandlers_HandleNotificationQueue(t *testing.T) { + handler, queue := newNotificationQueueHandlers(t) + enqueueDLQNotification(t, queue, "notif-3") + + req := httptest.NewRequest(http.MethodGet, "/api/notifications/dlq", nil) + rec := httptest.NewRecorder() + handler.HandleNotificationQueue(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("HandleNotificationQueue DLQ status = %d, want 200", rec.Code) + } + + req = httptest.NewRequest(http.MethodGet, "/api/notifications/unknown", nil) + rec = httptest.NewRecorder() + handler.HandleNotificationQueue(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("HandleNotificationQueue status = %d, want 404", rec.Code) + } +} diff --git a/internal/api/oidc_handlers_additional_test.go b/internal/api/oidc_handlers_additional_test.go new file mode 100644 index 000000000..a98392238 --- /dev/null +++ b/internal/api/oidc_handlers_additional_test.go @@ -0,0 +1,96 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestHandleOIDCLogin_DisabledGetRedirect(t *testing.T) { + router := &Router{config: &config.Config{OIDC: &config.OIDCConfig{Enabled: false}}} + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/login", nil) + rec := httptest.NewRecorder() + +
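+	// With OIDC disabled, a browser GET is expected to redirect back to the login page with an error flag rather than return a JSON error body.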
router.handleOIDCLogin(rec, req) + if rec.Code != http.StatusFound { + t.Fatalf("status = %d, want 302", rec.Code) + } + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc=error") || !strings.Contains(location, "oidc_error=oidc_disabled") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestHandleOIDCLogin_DisabledPost(t *testing.T) { + router := &Router{config: &config.Config{OIDC: &config.OIDCConfig{Enabled: false}}} + + req := httptest.NewRequest(http.MethodPost, "/api/oidc/login", nil) + rec := httptest.NewRecorder() + + router.handleOIDCLogin(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rec.Code) + } + + var payload struct { + Code string `json:"code"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload.Code != "oidc_disabled" { + t.Fatalf("code = %q, want oidc_disabled", payload.Code) + } +} + +func TestHandleOIDCCallback_Disabled(t *testing.T) { + router := &Router{config: &config.Config{OIDC: &config.OIDCConfig{Enabled: false}}} + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback", nil) + rec := httptest.NewRecorder() + + router.handleOIDCCallback(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want 404", rec.Code) + } +} + +func TestGetOIDCService_Disabled(t *testing.T) { + router := &Router{config: &config.Config{OIDC: &config.OIDCConfig{Enabled: false}}} + + if _, err := router.getOIDCService(context.Background(), "https://example.com/callback"); err == nil { + t.Fatalf("expected error when oidc disabled") + } +} + +func TestRedirectOIDCError(t *testing.T) { + router := &Router{config: &config.Config{}} + + req := httptest.NewRequest(http.MethodGet, "/api/oidc/callback", nil) + rec := httptest.NewRecorder() + + router.redirectOIDCError(rec, req, "/login?foo=bar", "bad") + location := rec.Header().Get("Location") + if !strings.Contains(location, "oidc=error") || !strings.Contains(location, "oidc_error=bad") { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestEnsureOIDCConfig_Defaults(t *testing.T) { + cfg := &config.Config{PublicURL: "https://pulse.example.com"} + router := &Router{config: cfg} + + oidcCfg := router.ensureOIDCConfig() + if oidcCfg == nil { + t.Fatalf("expected oidc config to be initialized") + } + if oidcCfg.RedirectURL != "https://pulse.example.com"+config.DefaultOIDCCallbackPath { + t.Fatalf("redirect url = %q, want default", oidcCfg.RedirectURL) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index 69e0c40c6..c41e46953 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -37,6 +37,7 @@ import ( "github.com/rcourtman/pulse-go-rewrite/internal/ai/remediation" "github.com/rcourtman/pulse-go-rewrite/internal/ai/tools" "github.com/rcourtman/pulse-go-rewrite/internal/ai/unified" + "github.com/rcourtman/pulse-go-rewrite/internal/aidiscovery" "github.com/rcourtman/pulse-go-rewrite/internal/alerts" "github.com/rcourtman/pulse-go-rewrite/internal/config" "github.com/rcourtman/pulse-go-rewrite/internal/license" @@ -67,6 +68,7 @@ type Router struct { systemSettingsHandler *SystemSettingsHandler aiSettingsHandler *AISettingsHandler aiHandler *AIHandler // AI chat handler + aiDiscoveryHandlers *AIDiscoveryHandlers resourceHandlers *ResourceHandlers reportingHandlers *ReportingHandlers configProfileHandler *ConfigProfileHandler @@ -1267,6 +1269,11 @@ func (r *Router) setupRoutes() { // AI chat handler 
r.aiHandler = NewAIHandler(r.multiTenant, r.mtMonitor, r.agentExecServer) + + // AI-powered infrastructure discovery handlers + // Note: The actual service is wired up later via SetAIDiscoveryService + r.aiDiscoveryHandlers = NewAIDiscoveryHandlers(nil) + // Wire license checker for Pro feature gating (AI Patrol, Alert Analysis, Auto-Fix) r.aiSettingsHandler.SetLicenseHandlers(r.licenseHandlers) // Wire model change callback to restart AI chat service when model is changed @@ -1412,6 +1419,8 @@ func (r *Router) setupRoutes() { r.aiSettingsHandler.HandleGetInvestigation(w, req) case strings.HasSuffix(path, "/reinvestigate"): r.aiSettingsHandler.HandleReinvestigateFinding(w, req) + case strings.HasSuffix(path, "/reapprove"): + r.aiSettingsHandler.HandleReapproveInvestigationFix(w, req) default: http.Error(w, "Not found", http.StatusNotFound) } @@ -1493,6 +1502,47 @@ func (r *Router) setupRoutes() { // AI question endpoints r.mux.HandleFunc("/api/ai/question/", RequireAuth(r.config, r.routeQuestions)) + // AI-powered infrastructure discovery endpoints + r.mux.HandleFunc("/api/aidiscovery", RequireAuth(r.config, RequireScope(config.ScopeMonitoringRead, r.aiDiscoveryHandlers.HandleListDiscoveries))) + r.mux.HandleFunc("/api/aidiscovery/status", RequireAuth(r.config, RequireScope(config.ScopeMonitoringRead, r.aiDiscoveryHandlers.HandleGetStatus))) + r.mux.HandleFunc("/api/aidiscovery/type/", RequireAuth(r.config, RequireScope(config.ScopeMonitoringRead, r.aiDiscoveryHandlers.HandleListByType))) + r.mux.HandleFunc("/api/aidiscovery/host/", RequireAuth(r.config, RequireScope(config.ScopeMonitoringRead, r.aiDiscoveryHandlers.HandleListByHost))) + r.mux.HandleFunc("/api/aidiscovery/", RequireAuth(r.config, func(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path + switch req.Method { + case http.MethodGet: + if !ensureScope(w, req, config.ScopeMonitoringRead) { + return + } + if strings.HasSuffix(path, "/progress") { + r.aiDiscoveryHandlers.HandleGetProgress(w, req) + } else { + r.aiDiscoveryHandlers.HandleGetDiscovery(w, req) + } + case http.MethodPost: + if !ensureScope(w, req, config.ScopeMonitoringWrite) { + return + } + r.aiDiscoveryHandlers.HandleTriggerDiscovery(w, req) + case http.MethodPut: + if !ensureScope(w, req, config.ScopeMonitoringWrite) { + return + } + if strings.HasSuffix(path, "/notes") { + r.aiDiscoveryHandlers.HandleUpdateNotes(w, req) + } else { + http.Error(w, "Not found", http.StatusNotFound) + } + case http.MethodDelete: + if !ensureScope(w, req, config.ScopeMonitoringWrite) { + return + } + r.aiDiscoveryHandlers.HandleDeleteDiscovery(w, req) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + })) + // Agent WebSocket for AI command execution r.mux.HandleFunc("/api/agent/ws", r.handleAgentWebSocket) @@ -1873,6 +1923,13 @@ func (r *Router) SetConfig(cfg *config.Config) { } } +// SetAIDiscoveryService sets the AI discovery service for the router. 
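+// The handlers are created with a nil service in setupRoutes, so the /api/aidiscovery endpoints only become functional once this is called with the real service (wired up in StartPatrol).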
+func (r *Router) SetAIDiscoveryService(svc *aidiscovery.Service) { + if r.aiDiscoveryHandlers != nil { + r.aiDiscoveryHandlers.SetService(svc) + } +} + // StartPatrol starts the AI patrol service for background infrastructure monitoring func (r *Router) StartPatrol(ctx context.Context) { if r.aiSettingsHandler != nil { @@ -2085,6 +2142,16 @@ func (r *Router) StartPatrol(ctx context.Context) { // Finally start the actual patrol loop r.aiSettingsHandler.StartPatrol(ctx) + + // Wire up AI discovery service to the handlers + // This enables the /api/aidiscovery endpoints to trigger discovery scans + aiService := r.aiSettingsHandler.GetAIService(ctx) + if aiService != nil { + if discoveryService := aiService.GetAIDiscoveryService(); discoveryService != nil { + r.SetAIDiscoveryService(discoveryService) + log.Info().Msg("AI Discovery: Service wired to API handlers") + } + } } } @@ -2556,6 +2623,19 @@ func (r *Router) wireAIChatProviders() { } } + // Wire discovery provider for AI-powered infrastructure discovery (pulse_get_discovery, pulse_list_discoveries) + if r.aiSettingsHandler != nil { + if aiSvc := r.aiSettingsHandler.GetAIService(context.Background()); aiSvc != nil { + if discoverySvc := aiSvc.GetAIDiscoveryService(); discoverySvc != nil { + adapter := aidiscovery.NewToolsAdapter(discoverySvc) + if adapter != nil { + service.SetDiscoveryProvider(tools.NewDiscoveryMCPAdapter(adapter)) + log.Debug().Msg("AI chat: Discovery provider wired") + } + } + } + } + log.Info().Msg("AI chat MCP tool providers wired") } diff --git a/internal/api/router_integration_test.go b/internal/api/router_integration_test.go index 0920e3ba2..ad1ed41ba 100644 --- a/internal/api/router_integration_test.go +++ b/internal/api/router_integration_test.go @@ -90,6 +90,13 @@ func newIntegrationServerWithConfig(t *testing.T, customize func(*config.Config) srv := httptest.NewServer(router.Handler()) t.Cleanup(func() { srv.Close() + if monitor != nil { + monitor.StopDiscoveryService() + monitor.Stop() + } + if hub != nil { + hub.Stop() + } mock.SetEnabled(false) }) diff --git a/internal/config/ai.go b/internal/config/ai.go index 4979ace6f..6cc3f6d71 100644 --- a/internal/config/ai.go +++ b/internal/config/ai.go @@ -16,14 +16,15 @@ const ( // This is stored in ai.enc (encrypted) in the config directory type AIConfig struct { Enabled bool `json:"enabled"` - Provider string `json:"provider"` // DEPRECATED: legacy single provider field, kept for migration - APIKey string `json:"api_key"` // DEPRECATED: legacy single API key, kept for migration - Model string `json:"model"` // Currently selected default model (format: "provider:model-name") - ChatModel string `json:"chat_model,omitempty"` // Model for interactive chat (defaults to Model) - PatrolModel string `json:"patrol_model,omitempty"` // Model for background patrol (defaults to Model, can be cheaper) - BaseURL string `json:"base_url"` // DEPRECATED: legacy base URL, kept for migration - AutonomousMode bool `json:"autonomous_mode"` // when true, AI executes commands without approval - CustomContext string `json:"custom_context"` // user-provided context about their infrastructure + Provider string `json:"provider"` // DEPRECATED: legacy single provider field, kept for migration + APIKey string `json:"api_key"` // DEPRECATED: legacy single API key, kept for migration + Model string `json:"model"` // Currently selected default model (format: "provider:model-name") + ChatModel string `json:"chat_model,omitempty"` // Model for interactive chat (defaults to Model) + PatrolModel 
string `json:"patrol_model,omitempty"` // Model for background patrol (defaults to Model, can be cheaper) + DiscoveryModel string `json:"discovery_model,omitempty"` // Model for infrastructure discovery (defaults to cheapest available, e.g., haiku) + BaseURL string `json:"base_url"` // DEPRECATED: legacy base URL, kept for migration + AutonomousMode bool `json:"autonomous_mode"` // when true, AI executes commands without approval + CustomContext string `json:"custom_context"` // user-provided context about their infrastructure // Multi-provider credentials - each provider can be configured independently AnthropicAPIKey string `json:"anthropic_api_key,omitempty"` // Anthropic API key @@ -72,6 +73,10 @@ type AIConfig struct { PatrolInvestigationBudget int `json:"patrol_investigation_budget,omitempty"` // Max turns per investigation (default: 15) PatrolInvestigationTimeoutSec int `json:"patrol_investigation_timeout_sec,omitempty"` // Max seconds per investigation (default: 300) PatrolCriticalRequireApproval bool `json:"patrol_critical_require_approval"` // Critical findings always require approval (default: true) + + // AI Discovery settings - controls automatic infrastructure discovery + DiscoveryEnabled bool `json:"discovery_enabled"` // Enable AI-powered infrastructure discovery + DiscoveryIntervalHours int `json:"discovery_interval_hours,omitempty"` // Hours between automatic re-scans (0 = manual only, default: 0) } // AIProvider constants @@ -101,6 +106,9 @@ const ( PatrolAutonomyApproval = "approval" // PatrolAutonomyFull - Spawn Chat sessions to investigate, execute non-critical fixes automatically PatrolAutonomyFull = "full" + // PatrolAutonomyAutonomous - Full autonomy, execute ALL fixes including destructive ones without approval + // User accepts full risk - similar to "auto-accept" mode in Claude Code + PatrolAutonomyAutonomous = "autonomous" ) // Default patrol investigation settings @@ -393,6 +401,16 @@ func (c *AIConfig) GetPatrolModel() string { return c.GetModel() } +// GetDiscoveryModel returns the model for infrastructure discovery +// Falls back to the main model since discovery needs to use the same provider +func (c *AIConfig) GetDiscoveryModel() string { + if c.DiscoveryModel != "" { + return c.DiscoveryModel + } + // Fall back to the main model to ensure we use the same provider + return c.GetModel() +} + // GetAutoFixModel returns the model for automatic remediation actions // Falls back to PatrolModel, then to the main Model if AutoFixModel is not set // Auto-fix may warrant a more capable model since it takes actions @@ -559,7 +577,7 @@ func (c *AIConfig) GetPatrolAutonomyLevel() string { return PatrolAutonomyMonitor } switch c.PatrolAutonomyLevel { - case PatrolAutonomyMonitor, PatrolAutonomyApproval, PatrolAutonomyFull: + case PatrolAutonomyMonitor, PatrolAutonomyApproval, PatrolAutonomyFull, PatrolAutonomyAutonomous: return c.PatrolAutonomyLevel default: return PatrolAutonomyMonitor @@ -609,7 +627,7 @@ func (c *AIConfig) ShouldCriticalRequireApproval() bool { // IsValidPatrolAutonomyLevel checks if a patrol autonomy level string is valid func IsValidPatrolAutonomyLevel(level string) bool { switch level { - case PatrolAutonomyMonitor, PatrolAutonomyApproval, PatrolAutonomyFull: + case PatrolAutonomyMonitor, PatrolAutonomyApproval, PatrolAutonomyFull, PatrolAutonomyAutonomous: return true default: return false @@ -621,3 +639,17 @@ func (c *AIConfig) IsPatrolAutonomyEnabled() bool { level := c.GetPatrolAutonomyLevel() return level != PatrolAutonomyMonitor } + +// 
IsDiscoveryEnabled returns whether AI-powered infrastructure discovery is enabled +func (c *AIConfig) IsDiscoveryEnabled() bool { + return c.DiscoveryEnabled +} + +// GetDiscoveryInterval returns the interval between automatic discovery scans +// Returns 0 if discovery is manual-only +func (c *AIConfig) GetDiscoveryInterval() time.Duration { + if c.DiscoveryIntervalHours <= 0 { + return 0 // Manual only + } + return time.Duration(c.DiscoveryIntervalHours) * time.Hour +} diff --git a/internal/config/ai_additional_test.go b/internal/config/ai_additional_test.go new file mode 100644 index 000000000..80d24d985 --- /dev/null +++ b/internal/config/ai_additional_test.go @@ -0,0 +1,134 @@ +package config + +import "testing" + +func TestAIConfig_DiscoveryAndControl(t *testing.T) { + cfg := &AIConfig{} + // Default discovery model should fall back to GetModel() + if got := cfg.GetDiscoveryModel(); got != cfg.GetModel() { + t.Fatalf("default discovery model should fall back to GetModel(), got %q", got) + } + + cfg.DiscoveryModel = "custom:discovery" + if got := cfg.GetDiscoveryModel(); got != "custom:discovery" { + t.Fatalf("custom discovery model = %q", got) + } + + cfg = &AIConfig{} + if got := cfg.GetControlLevel(); got != ControlLevelReadOnly { + t.Fatalf("default control level = %q", got) + } + + cfg.AutonomousMode = true + if got := cfg.GetControlLevel(); got != ControlLevelAutonomous { + t.Fatalf("legacy autonomous mode = %q", got) + } + + cfg.ControlLevel = "suggest" + if got := cfg.GetControlLevel(); got != ControlLevelControlled { + t.Fatalf("suggest control level = %q", got) + } + + cfg.ControlLevel = "invalid" + if got := cfg.GetControlLevel(); got != ControlLevelReadOnly { + t.Fatalf("invalid control level = %q", got) + } + + cfg.ControlLevel = ControlLevelControlled + if !cfg.IsControlEnabled() { + t.Fatalf("control should be enabled for controlled level") + } + if cfg.IsAutonomous() { + t.Fatalf("autonomous should be false for controlled level") + } +} + +func TestAIConfig_PatrolSettings(t *testing.T) { + cfg := &AIConfig{} + if got := cfg.GetPatrolAutonomyLevel(); got != PatrolAutonomyMonitor { + t.Fatalf("default patrol autonomy = %q", got) + } + if cfg.IsPatrolAutonomyEnabled() { + t.Fatalf("patrol autonomy should be disabled by default") + } + + cfg.PatrolAutonomyLevel = PatrolAutonomyFull + if got := cfg.GetPatrolAutonomyLevel(); got != PatrolAutonomyFull { + t.Fatalf("patrol autonomy = %q", got) + } + if !cfg.IsPatrolAutonomyEnabled() { + t.Fatalf("patrol autonomy should be enabled for full mode") + } + + cfg.PatrolAutonomyLevel = "invalid" + if got := cfg.GetPatrolAutonomyLevel(); got != PatrolAutonomyMonitor { + t.Fatalf("invalid autonomy should fallback to monitor, got %q", got) + } + + cfg.PatrolInvestigationBudget = 2 + if got := cfg.GetPatrolInvestigationBudget(); got != 5 { + t.Fatalf("budget should clamp to 5, got %d", got) + } + + cfg.PatrolInvestigationBudget = 40 + if got := cfg.GetPatrolInvestigationBudget(); got != 30 { + t.Fatalf("budget should clamp to 30, got %d", got) + } + + cfg.PatrolInvestigationBudget = 10 + if got := cfg.GetPatrolInvestigationBudget(); got != 10 { + t.Fatalf("budget should be 10, got %d", got) + } + + cfg.PatrolInvestigationTimeoutSec = 30 + if got := cfg.GetPatrolInvestigationTimeout(); got.Seconds() != 60 { + t.Fatalf("timeout should clamp to 60s, got %s", got) + } + + cfg.PatrolInvestigationTimeoutSec = 1900 + if got := cfg.GetPatrolInvestigationTimeout(); got.Seconds() != 1800 { + t.Fatalf("timeout should clamp to 1800s, got %s", got) + 
} + + cfg.PatrolInvestigationTimeoutSec = 120 + if got := cfg.GetPatrolInvestigationTimeout(); got.Seconds() != 120 { + t.Fatalf("timeout should be 120s, got %s", got) + } + + cfg.PatrolAutonomyLevel = "" + cfg.PatrolCriticalRequireApproval = false + if !cfg.ShouldCriticalRequireApproval() { + t.Fatalf("critical approval should default to true when level unset") + } + + cfg.PatrolAutonomyLevel = PatrolAutonomyMonitor + if cfg.ShouldCriticalRequireApproval() { + t.Fatalf("critical approval should be false when explicitly disabled") + } +} + +func TestAIConfig_ProtectedGuestsAndValidation(t *testing.T) { + cfg := &AIConfig{} + if guests := cfg.GetProtectedGuests(); len(guests) != 0 { + t.Fatalf("expected empty protected guests, got %v", guests) + } + + cfg.ProtectedGuests = []string{"vm-100", "vm-200"} + guests := cfg.GetProtectedGuests() + if len(guests) != 2 || guests[0] != "vm-100" { + t.Fatalf("unexpected protected guests: %v", guests) + } + + if IsValidControlLevel("bad") { + t.Fatalf("expected invalid control level to be false") + } + if !IsValidControlLevel(ControlLevelAutonomous) { + t.Fatalf("expected autonomous to be valid") + } + if IsValidPatrolAutonomyLevel("bad") { + t.Fatalf("expected invalid patrol autonomy to be false") + } + if !IsValidPatrolAutonomyLevel(PatrolAutonomyApproval) { + t.Fatalf("expected patrol approval to be valid") + } +} diff --git a/internal/config/api_tokens_additional_test.go b/internal/config/api_tokens_additional_test.go new file mode 100644 index 000000000..4fc915f94 --- /dev/null +++ b/internal/config/api_tokens_additional_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "testing" + "time" +) + +func TestAPITokenRecord_IsLegacyToken(t *testing.T) { + record := &APITokenRecord{} + if !record.IsLegacyToken() { + t.Fatalf("expected legacy token when no org bindings") + } + + record.OrgID = "org-1" + if record.IsLegacyToken() { + t.Fatalf("expected non-legacy token when OrgID is set") + } + + record.OrgID = "" + record.OrgIDs = []string{"org-2"} + if record.IsLegacyToken() { + t.Fatalf("expected non-legacy token when OrgIDs is set") + } +} + +func TestNewAPITokenRecord(t *testing.T) { + if _, err := NewAPITokenRecord("", "name", nil); err == nil { + t.Fatalf("expected error for empty raw token") + } + + record, err := NewAPITokenRecord("token-abcdef123456", "name", nil) + if err != nil { + t.Fatalf("NewAPITokenRecord error: %v", err) + } + if record.Name != "name" { + t.Fatalf("Name = %q", record.Name) + } + if record.Hash == "" { + t.Fatalf("expected hash to be set") + } + if record.Prefix != "token-" { + t.Fatalf("Prefix = %q", record.Prefix) + } + if record.Suffix != "3456" { + t.Fatalf("Suffix = %q", record.Suffix) + } + if len(record.Scopes) != 1 || record.Scopes[0] != ScopeWildcard { + t.Fatalf("Scopes = %v", record.Scopes) + } +} + +func TestNewHashedAPITokenRecord(t *testing.T) { + if _, err := NewHashedAPITokenRecord("", "name", time.Time{}, nil); err == nil { + t.Fatalf("expected error for empty hashed token") + } + + record, err := NewHashedAPITokenRecord("hashed-token-1234", "name", time.Time{}, []string{ScopeSettingsRead}) + if err != nil { + t.Fatalf("NewHashedAPITokenRecord error: %v", err) + } + if record.Hash != "hashed-token-1234" { + t.Fatalf("Hash = %q", record.Hash) + } + if record.Prefix != "hashed" { + t.Fatalf("Prefix = %q", record.Prefix) + } + if record.Suffix != "1234" { + t.Fatalf("Suffix = %q", record.Suffix) + } + if len(record.Scopes) != 1 || record.Scopes[0] != ScopeSettingsRead { + t.Fatalf("Scopes = %v", 
record.Scopes) + } + if record.CreatedAt.IsZero() { + t.Fatalf("expected CreatedAt to be set") + } +} diff --git a/internal/infradiscovery/service.go b/internal/infradiscovery/service.go new file mode 100644 index 000000000..1caec4637 --- /dev/null +++ b/internal/infradiscovery/service.go @@ -0,0 +1,605 @@ +// Package infradiscovery provides AI-powered infrastructure discovery for detecting +// applications and services running on monitored hosts. It uses LLM analysis to +// identify services from Docker containers, enabling AI systems like Patrol to +// understand where services run and propose correct remediation commands. +package infradiscovery + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge" + "github.com/rcourtman/pulse-go-rewrite/internal/models" + "github.com/rs/zerolog/log" +) + +// StateProvider provides access to the current infrastructure state. +type StateProvider interface { + GetState() models.StateSnapshot +} + +// AIAnalyzer provides AI analysis capabilities for discovery. +// This interface allows the discovery service to use LLM analysis +// without creating circular dependencies with the AI package. +type AIAnalyzer interface { + // AnalyzeForDiscovery sends a prompt to the AI and returns the response. + // The model parameter specifies which model to use (e.g., "anthropic:claude-3-5-haiku-latest") + AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) +} + +// DiscoveredApp represents a detected application or service. +type DiscoveredApp struct { + ID string `json:"id"` // Unique ID: "docker:hostname:container" + Type string `json:"type"` // Application type: pbs, postgres, nginx, custom, etc. + Name string `json:"name"` // Human-readable name: "Proxmox Backup Server" + Category string `json:"category"` // Category: backup, database, web, monitoring, unknown + RunsIn string `json:"runs_in"` // Runtime: docker, systemd, native + HostID string `json:"host_id"` // Host identifier (agent ID or hostname) + Hostname string `json:"hostname"` // Human-readable hostname + ContainerID string `json:"container_id"` // Docker container ID (if applicable) + ContainerName string `json:"container_name"` // Docker container name (if applicable) + ServiceUnit string `json:"service_unit"` // Systemd unit name (if applicable) + Ports []int `json:"ports"` // Exposed ports + CLIAccess string `json:"cli_access"` // How to access CLI: "docker exec pbs proxmox-backup-manager" + Confidence float64 `json:"confidence"` // Detection confidence 0-1 + DetectedAt time.Time `json:"detected_at"` // When this app was detected + AIReasoning string `json:"ai_reasoning"` // AI's reasoning for the identification +} + +// ContainerInfo holds information about a container for AI analysis. +type ContainerInfo struct { + Name string `json:"name"` + Image string `json:"image"` + Ports []PortInfo `json:"ports,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + EnvVarNames []string `json:"env_var_names,omitempty"` // Just names, not values (security) + Mounts []string `json:"mounts,omitempty"` + Networks []string `json:"networks,omitempty"` + Status string `json:"status,omitempty"` + Command string `json:"command,omitempty"` +} + +// PortInfo holds port mapping information. 
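+// HostPort is zero (and omitted from the JSON sent to the AI) when the container port is not published on the host.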
+type PortInfo struct { + HostPort int `json:"host_port,omitempty"` + ContainerPort int `json:"container_port"` + Protocol string `json:"protocol,omitempty"` +} + +// AIDiscoveryResult represents the AI's analysis of a container. +type AIDiscoveryResult struct { + ServiceType string `json:"service_type"` // e.g., "postgres", "pbs", "nginx", "unknown" + ServiceName string `json:"service_name"` // Human-readable name + Category string `json:"category"` // backup, database, web, monitoring, etc. + CLICommand string `json:"cli_command"` // How to run CLI commands in this container + Confidence float64 `json:"confidence"` // 0-1 confidence score + Reasoning string `json:"reasoning"` // Why the AI made this determination +} + +// Service manages AI-powered infrastructure discovery. +type Service struct { + stateProvider StateProvider + knowledgeStore *knowledge.Store + aiAnalyzer AIAnalyzer + mu sync.RWMutex + lastRun time.Time + interval time.Duration + stopCh chan struct{} + running bool + discoveries []DiscoveredApp + + // Cache to avoid re-analyzing the same containers + // Key: image name, Value: analysis result + analysisCache map[string]*AIDiscoveryResult + cacheMu sync.RWMutex + cacheExpiry time.Duration + lastCacheUpdate time.Time +} + +// Config holds discovery service configuration. +type Config struct { + Interval time.Duration // How often to run discovery (default: 5 minutes) + CacheExpiry time.Duration // How long to cache analysis results (default: 1 hour) +} + +// DefaultConfig returns the default discovery configuration. +func DefaultConfig() Config { + return Config{ + Interval: 5 * time.Minute, + CacheExpiry: 1 * time.Hour, + } +} + +// NewService creates a new AI-powered infrastructure discovery service. +func NewService(stateProvider StateProvider, knowledgeStore *knowledge.Store, cfg Config) *Service { + if cfg.Interval == 0 { + cfg.Interval = 5 * time.Minute + } + if cfg.CacheExpiry == 0 { + cfg.CacheExpiry = 1 * time.Hour + } + + return &Service{ + stateProvider: stateProvider, + knowledgeStore: knowledgeStore, + interval: cfg.Interval, + cacheExpiry: cfg.CacheExpiry, + stopCh: make(chan struct{}), + discoveries: make([]DiscoveredApp, 0), + analysisCache: make(map[string]*AIDiscoveryResult), + } +} + +// SetAIAnalyzer sets the AI analyzer for discovery. +// This must be called before Start() for AI-powered discovery to work. +func (s *Service) SetAIAnalyzer(analyzer AIAnalyzer) { + s.mu.Lock() + defer s.mu.Unlock() + s.aiAnalyzer = analyzer +} + +// Start begins the background discovery service. +func (s *Service) Start(ctx context.Context) { + s.mu.Lock() + if s.running { + s.mu.Unlock() + return + } + s.running = true + s.mu.Unlock() + + log.Info(). + Dur("interval", s.interval). + Msg("Starting AI-powered infrastructure discovery service") + + // Run immediately on startup + go func() { + defer func() { + if r := recover(); r != nil { + log.Error(). + Interface("panic", r). + Stack(). + Msg("Recovered from panic in initial infrastructure discovery") + } + }() + s.RunDiscovery(ctx) + }() + + // Start periodic discovery loop + go func() { + defer func() { + if r := recover(); r != nil { + log.Error(). + Interface("panic", r). + Stack(). + Msg("Recovered from panic in infrastructure discovery loop") + } + }() + s.discoveryLoop(ctx) + }() +} + +// Stop stops the background discovery service. +func (s *Service) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.running { + close(s.stopCh) + s.running = false + } +} + +// discoveryLoop runs periodic discovery. 
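+// It runs until Stop() closes stopCh or the parent context is cancelled.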
+func (s *Service) discoveryLoop(ctx context.Context) { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.RunDiscovery(ctx) + case <-s.stopCh: + log.Info().Msg("Stopping infrastructure discovery service") + return + case <-ctx.Done(): + log.Info().Msg("Infrastructure discovery context cancelled") + return + } + } +} + +// RunDiscovery performs a discovery scan using AI analysis. +func (s *Service) RunDiscovery(ctx context.Context) []DiscoveredApp { + start := time.Now() + state := s.stateProvider.GetState() + + s.mu.RLock() + analyzer := s.aiAnalyzer + s.mu.RUnlock() + + if analyzer == nil { + log.Debug().Msg("AI analyzer not set, skipping discovery") + return nil + } + + var apps []DiscoveredApp + + // Collect all containers from all Docker hosts + var allContainers []struct { + Container models.DockerContainer + Host models.DockerHost + } + + for _, dockerHost := range state.DockerHosts { + for _, container := range dockerHost.Containers { + allContainers = append(allContainers, struct { + Container models.DockerContainer + Host models.DockerHost + }{container, dockerHost}) + } + } + + if len(allContainers) == 0 { + log.Debug().Msg("No Docker containers found for discovery") + s.mu.Lock() + s.lastRun = time.Now() + s.mu.Unlock() + return apps + } + + // Analyze containers (check cache first, batch uncached ones) + for _, item := range allContainers { + app := s.analyzeContainer(ctx, analyzer, item.Container, item.Host) + if app != nil { + apps = append(apps, *app) + } + } + + // Save discoveries to knowledge store + s.saveDiscoveries(apps) + + // Update cache + s.mu.Lock() + s.discoveries = apps + s.lastRun = time.Now() + s.mu.Unlock() + + log.Info(). + Int("containers_scanned", len(allContainers)). + Int("apps_discovered", len(apps)). + Dur("duration", time.Since(start)). + Msg("AI infrastructure discovery completed") + + return apps +} + +// analyzeContainer uses AI to analyze a single container. +func (s *Service) analyzeContainer(ctx context.Context, analyzer AIAnalyzer, c models.DockerContainer, host models.DockerHost) *DiscoveredApp { + // Check cache first + s.cacheMu.RLock() + cached, found := s.analysisCache[c.Image] + cacheValid := time.Since(s.lastCacheUpdate) < s.cacheExpiry + s.cacheMu.RUnlock() + + var result *AIDiscoveryResult + + if found && cacheValid { + result = cached + log.Debug(). + Str("container", c.Name). + Str("image", c.Image). + Msg("Using cached analysis result") + } else { + // Build container info for AI analysis + info := s.buildContainerInfo(c) + + // Create analysis prompt + prompt := s.buildAnalysisPrompt(info) + + // Call AI + response, err := analyzer.AnalyzeForDiscovery(ctx, prompt) + if err != nil { + log.Warn(). + Err(err). + Str("container", c.Name). + Str("image", c.Image). + Msg("AI analysis failed for container") + return nil + } + + // Parse response + result = s.parseAIResponse(response) + if result == nil { + log.Warn(). + Str("container", c.Name). + Str("response", response). + Msg("Failed to parse AI response") + return nil + } + + // Cache the result + s.cacheMu.Lock() + s.analysisCache[c.Image] = result + s.lastCacheUpdate = time.Now() + s.cacheMu.Unlock() + + log.Debug(). + Str("container", c.Name). + Str("image", c.Image). + Str("service_type", result.ServiceType). + Float64("confidence", result.Confidence). 
+ Msg("AI analyzed container") + } + + // Skip unknown/low-confidence results + if result.ServiceType == "unknown" || result.Confidence < 0.5 { + return nil + } + + // Build CLI access string + cliAccess := result.CLICommand + if cliAccess != "" { + // Replace placeholder with actual container name + cliAccess = strings.ReplaceAll(cliAccess, "{container}", c.Name) + cliAccess = strings.ReplaceAll(cliAccess, "${container}", c.Name) + } + + // Extract ports + var ports []int + for _, p := range c.Ports { + if p.PublicPort > 0 { + ports = append(ports, int(p.PublicPort)) + } else if p.PrivatePort > 0 { + ports = append(ports, int(p.PrivatePort)) + } + } + + return &DiscoveredApp{ + ID: fmt.Sprintf("docker:%s:%s", host.Hostname, c.Name), + Type: result.ServiceType, + Name: result.ServiceName, + Category: result.Category, + RunsIn: "docker", + HostID: host.AgentID, + Hostname: host.Hostname, + ContainerID: c.ID, + ContainerName: c.Name, + Ports: ports, + CLIAccess: cliAccess, + Confidence: result.Confidence, + DetectedAt: time.Now(), + AIReasoning: result.Reasoning, + } +} + +// buildContainerInfo extracts relevant information from a container for AI analysis. +func (s *Service) buildContainerInfo(c models.DockerContainer) ContainerInfo { + info := ContainerInfo{ + Name: c.Name, + Image: c.Image, + Status: c.Status, + } + + // Extract ports + for _, p := range c.Ports { + info.Ports = append(info.Ports, PortInfo{ + HostPort: int(p.PublicPort), + ContainerPort: int(p.PrivatePort), + Protocol: p.Protocol, + }) + } + + // Extract labels + if len(c.Labels) > 0 { + info.Labels = c.Labels + } + + // Extract mount destinations + for _, m := range c.Mounts { + info.Mounts = append(info.Mounts, m.Destination) + } + + // Extract network names + for _, n := range c.Networks { + info.Networks = append(info.Networks, n.Name) + } + + return info +} + +// buildAnalysisPrompt creates the prompt for AI container analysis. +func (s *Service) buildAnalysisPrompt(info ContainerInfo) string { + // Convert info to JSON for the prompt + infoJSON, _ := json.MarshalIndent(info, "", " ") + + return fmt.Sprintf(`Analyze this Docker container and identify what service or application it's running. + +Container Information: +%s + +Based on the image name, ports, labels, environment variables, mounts, and other signals, determine: +1. What service/application is this? (e.g., postgres, redis, nginx, proxmox-backup-server, grafana, etc.) +2. What category does it belong to? (database, cache, web, backup, monitoring, message_queue, storage, etc.) +3. How should CLI commands be executed for this service? 
+ +Respond in this exact JSON format: +{ + "service_type": "the_service_type", + "service_name": "Human Readable Name", + "category": "category", + "cli_command": "docker exec {container} ", + "confidence": 0.95, + "reasoning": "Brief explanation of why you identified it this way" +} + +Important guidelines: +- service_type should be lowercase, no spaces (e.g., "postgres", "redis", "pbs", "nginx") +- For CLI command, use {container} as a placeholder for the container name +- If the service has a CLI tool, include it (e.g., "docker exec {container} psql -U postgres" for PostgreSQL) +- If no CLI is applicable, use empty string for cli_command +- Set confidence between 0 and 1 (1 = certain, 0.5 = guess) +- If you cannot identify the service, use service_type "unknown" with low confidence + +Common services to look for: +- Databases: PostgreSQL, MySQL, MariaDB, MongoDB, Redis, Elasticsearch +- Backup: Proxmox Backup Server (PBS), Restic, Borg +- Web: Nginx, Apache, Traefik, Caddy, HAProxy +- Monitoring: Prometheus, Grafana, Loki, Alertmanager +- Message queues: RabbitMQ, Kafka +- Storage: MinIO, Nextcloud +- Home automation: Home Assistant +- Media: Plex, Jellyfin +- CI/CD: Jenkins, Drone, GitLab Runner + +Respond with ONLY the JSON, no other text.`, string(infoJSON)) +} + +// parseAIResponse parses the AI's JSON response. +func (s *Service) parseAIResponse(response string) *AIDiscoveryResult { + // Try to extract JSON from the response + response = strings.TrimSpace(response) + + // Handle markdown code blocks + if strings.HasPrefix(response, "```") { + lines := strings.Split(response, "\n") + var jsonLines []string + inBlock := false + for _, line := range lines { + if strings.HasPrefix(line, "```") { + inBlock = !inBlock + continue + } + if inBlock { + jsonLines = append(jsonLines, line) + } + } + response = strings.Join(jsonLines, "\n") + } + + // Find JSON object in response + start := strings.Index(response, "{") + end := strings.LastIndex(response, "}") + if start >= 0 && end > start { + response = response[start : end+1] + } + + var result AIDiscoveryResult + if err := json.Unmarshal([]byte(response), &result); err != nil { + log.Debug(). + Err(err). + Str("response", response). + Msg("Failed to parse AI response as JSON") + return nil + } + + return &result +} + +// saveDiscoveries persists discovered applications to the knowledge store. +func (s *Service) saveDiscoveries(apps []DiscoveredApp) { + if s.knowledgeStore == nil { + return + } + + for _, app := range apps { + // Create a descriptive note for each discovered application + title := fmt.Sprintf("%s (%s)", app.Name, app.RunsIn) + + var content string + if app.CLIAccess != "" { + content = fmt.Sprintf( + "Detected %s running in %s on %s. CLI access: %s", + app.Name, + app.RunsIn, + app.Hostname, + app.CLIAccess, + ) + } else { + content = fmt.Sprintf( + "Detected %s running in %s on %s. No CLI access available.", + app.Name, + app.RunsIn, + app.Hostname, + ) + } + + // Save to knowledge store under the host's ID + err := s.knowledgeStore.SaveNote( + app.HostID, + app.Hostname, + "host", + knowledge.CategoryInfra, + title, + content, + ) + if err != nil { + log.Warn(). + Err(err). + Str("app_id", app.ID). + Str("host", app.Hostname). + Msg("Failed to save infrastructure discovery to knowledge store") + } + } +} + +// GetDiscoveries returns the cached list of discovered applications. 
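+// The slice is copied so callers cannot mutate the service's internal state.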
+func (s *Service) GetDiscoveries() []DiscoveredApp { + s.mu.RLock() + defer s.mu.RUnlock() + result := make([]DiscoveredApp, len(s.discoveries)) + copy(result, s.discoveries) + return result +} + +// GetLastRun returns the time of the last discovery run. +func (s *Service) GetLastRun() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastRun +} + +// ForceRefresh triggers an immediate discovery scan. +func (s *Service) ForceRefresh(ctx context.Context) { + go func() { + defer func() { + if r := recover(); r != nil { + log.Error(). + Interface("panic", r). + Stack(). + Msg("Recovered from panic in ForceRefresh infrastructure discovery") + } + }() + s.RunDiscovery(ctx) + }() +} + +// ClearCache clears the analysis cache, forcing re-analysis of all containers. +func (s *Service) ClearCache() { + s.cacheMu.Lock() + defer s.cacheMu.Unlock() + s.analysisCache = make(map[string]*AIDiscoveryResult) + s.lastCacheUpdate = time.Time{} +} + +// GetStatus returns the current service status. +func (s *Service) GetStatus() map[string]interface{} { + s.mu.RLock() + defer s.mu.RUnlock() + + s.cacheMu.RLock() + cacheSize := len(s.analysisCache) + s.cacheMu.RUnlock() + + return map[string]interface{}{ + "running": s.running, + "last_run": s.lastRun, + "interval": s.interval.String(), + "discovered_apps": len(s.discoveries), + "cache_size": cacheSize, + "ai_analyzer_set": s.aiAnalyzer != nil, + } +} diff --git a/internal/infradiscovery/service_additional_test.go b/internal/infradiscovery/service_additional_test.go new file mode 100644 index 000000000..5cc449b9c --- /dev/null +++ b/internal/infradiscovery/service_additional_test.go @@ -0,0 +1,153 @@ +package infradiscovery + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/ai/knowledge" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +func waitFor(t *testing.T, timeout time.Duration, check func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if check() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("condition not met before timeout") +} + +func TestStartStopUpdatesStatus(t *testing.T) { + provider := &mockStateProvider{state: models.StateSnapshot{}} + service := NewService(provider, nil, Config{ + Interval: 10 * time.Millisecond, + CacheExpiry: time.Millisecond, + }) + service.SetAIAnalyzer(&mockAIAnalyzer{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + service.Start(ctx) + status := service.GetStatus() + if running, ok := status["running"].(bool); !ok || !running { + t.Fatalf("expected running status true, got %v", status["running"]) + } + + waitFor(t, 500*time.Millisecond, func() bool { + return !service.GetLastRun().IsZero() + }) + + service.Stop() + status = service.GetStatus() + if running, ok := status["running"].(bool); !ok || running { + t.Fatalf("expected running status false, got %v", status["running"]) + } +} + +func TestForceRefreshUpdatesLastRun(t *testing.T) { + provider := &mockStateProvider{state: models.StateSnapshot{}} + service := NewService(provider, nil, DefaultConfig()) + service.SetAIAnalyzer(&mockAIAnalyzer{}) + + service.ForceRefresh(context.Background()) + + waitFor(t, 500*time.Millisecond, func() bool { + return !service.GetLastRun().IsZero() + }) +} + +func TestSaveDiscoveriesWritesKnowledge(t *testing.T) { + store, err := knowledge.NewStore(t.TempDir()) + if err != nil { + t.Fatalf("create knowledge store: %v", err) + } + service := 
NewService(&mockStateProvider{}, store, DefaultConfig()) + + apps := []DiscoveredApp{ + { + ID: "docker:host1:pg", + Name: "PostgreSQL", + RunsIn: "docker", + HostID: "host-1", + Hostname: "host1", + CLIAccess: "docker exec pg psql", + }, + { + ID: "docker:host1:redis", + Name: "Redis", + RunsIn: "docker", + HostID: "host-1", + Hostname: "host1", + CLIAccess: "", + }, + } + + service.saveDiscoveries(apps) + + knowledgeData, err := store.GetKnowledge("host-1") + if err != nil { + t.Fatalf("load knowledge: %v", err) + } + if knowledgeData == nil || len(knowledgeData.Notes) != 2 { + t.Fatalf("expected 2 notes, got %+v", knowledgeData) + } + + var sawCLI, sawNoCLI bool + for _, note := range knowledgeData.Notes { + if strings.Contains(note.Content, "CLI access:") { + sawCLI = true + } + if strings.Contains(note.Content, "No CLI access available.") { + sawNoCLI = true + } + } + if !sawCLI || !sawNoCLI { + t.Fatalf("expected notes with and without CLI access, got %+v", knowledgeData.Notes) + } +} + +func TestGetDiscoveriesReturnsCopy(t *testing.T) { + provider := &mockStateProvider{ + state: models.StateSnapshot{ + DockerHosts: []models.DockerHost{ + { + AgentID: "agent-1", + Hostname: "host1", + Containers: []models.DockerContainer{ + {ID: "1", Name: "web", Image: "nginx:latest"}, + }, + }, + }, + }, + } + + analyzer := &mockAIAnalyzer{ + responses: map[string]string{ + "nginx:latest": `{"service_type": "nginx", "service_name": "Nginx", "category": "web", "cli_command": "", "confidence": 0.9, "reasoning": "Web server"}`, + }, + } + + service := NewService(provider, nil, DefaultConfig()) + service.SetAIAnalyzer(analyzer) + service.RunDiscovery(context.Background()) + + first := service.GetDiscoveries() + if len(first) != 1 { + t.Fatalf("expected 1 discovery, got %d", len(first)) + } + first[0].Name = "changed" + + second := service.GetDiscoveries() + if len(second) != 1 { + t.Fatalf("expected 1 discovery, got %d", len(second)) + } + if second[0].Name == "changed" { + t.Fatalf("expected discoveries to be immutable copy, got %v", second[0].Name) + } +} diff --git a/internal/infradiscovery/service_test.go b/internal/infradiscovery/service_test.go new file mode 100644 index 000000000..e56ed398e --- /dev/null +++ b/internal/infradiscovery/service_test.go @@ -0,0 +1,379 @@ +package infradiscovery + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +// mockStateProvider implements StateProvider for testing +type mockStateProvider struct { + state models.StateSnapshot +} + +func (m *mockStateProvider) GetState() models.StateSnapshot { + return m.state +} + +// mockAIAnalyzer implements AIAnalyzer for testing +type mockAIAnalyzer struct { + responses map[string]string // image -> response + callCount int +} + +func (m *mockAIAnalyzer) AnalyzeForDiscovery(ctx context.Context, prompt string) (string, error) { + m.callCount++ + // Return a mock response based on what's in the prompt + // In real tests, we'd parse the prompt to determine which container + for image, response := range m.responses { + if containsString(prompt, image) { + return response, nil + } + } + // Default unknown response + return `{"service_type": "unknown", "service_name": "Unknown", "category": "unknown", "cli_command": "", "confidence": 0.3, "reasoning": "Could not identify"}`, nil +} + +func containsString(s, substr string) bool { + return len(substr) > 0 && len(s) >= len(substr) && (s == substr || strings.Contains(s, substr)) +} + +func TestNewService(t 
*testing.T) { + provider := &mockStateProvider{} + service := NewService(provider, nil, DefaultConfig()) + + if service == nil { + t.Fatal("NewService returned nil") + } + + if service.interval != 5*time.Minute { + t.Errorf("interval = %v, want 5m", service.interval) + } + + if service.cacheExpiry != 1*time.Hour { + t.Errorf("cacheExpiry = %v, want 1h", service.cacheExpiry) + } +} + +func TestParseAIResponse(t *testing.T) { + service := &Service{} + + tests := []struct { + name string + response string + want *AIDiscoveryResult + }{ + { + name: "valid JSON", + response: `{ + "service_type": "postgres", + "service_name": "PostgreSQL", + "category": "database", + "cli_command": "docker exec {container} psql -U postgres", + "confidence": 0.95, + "reasoning": "Image name contains postgres" + }`, + want: &AIDiscoveryResult{ + ServiceType: "postgres", + ServiceName: "PostgreSQL", + Category: "database", + CLICommand: "docker exec {container} psql -U postgres", + Confidence: 0.95, + Reasoning: "Image name contains postgres", + }, + }, + { + name: "JSON in markdown code block", + response: "```json\n{\"service_type\": \"redis\", \"service_name\": \"Redis\", \"category\": \"cache\", \"cli_command\": \"docker exec {container} redis-cli\", \"confidence\": 0.9, \"reasoning\": \"Redis image\"}\n```", + want: &AIDiscoveryResult{ + ServiceType: "redis", + ServiceName: "Redis", + Category: "cache", + CLICommand: "docker exec {container} redis-cli", + Confidence: 0.9, + Reasoning: "Redis image", + }, + }, + { + name: "invalid JSON", + response: "not json at all", + want: nil, + }, + { + name: "JSON with extra text", + response: `Here's my analysis: + {"service_type": "nginx", "service_name": "Nginx", "category": "web", "cli_command": "", "confidence": 0.85, "reasoning": "Web server"} + That's my answer.`, + want: &AIDiscoveryResult{ + ServiceType: "nginx", + ServiceName: "Nginx", + Category: "web", + CLICommand: "", + Confidence: 0.85, + Reasoning: "Web server", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := service.parseAIResponse(tt.response) + if tt.want == nil { + if got != nil { + t.Errorf("parseAIResponse() = %v, want nil", got) + } + return + } + if got == nil { + t.Fatal("parseAIResponse() = nil, want non-nil") + } + if got.ServiceType != tt.want.ServiceType { + t.Errorf("ServiceType = %q, want %q", got.ServiceType, tt.want.ServiceType) + } + if got.ServiceName != tt.want.ServiceName { + t.Errorf("ServiceName = %q, want %q", got.ServiceName, tt.want.ServiceName) + } + if got.Category != tt.want.Category { + t.Errorf("Category = %q, want %q", got.Category, tt.want.Category) + } + if got.CLICommand != tt.want.CLICommand { + t.Errorf("CLICommand = %q, want %q", got.CLICommand, tt.want.CLICommand) + } + }) + } +} + +func TestBuildContainerInfo(t *testing.T) { + service := &Service{} + + container := models.DockerContainer{ + ID: "abc123", + Name: "mydb", + Image: "postgres:14", + Status: "running", + Ports: []models.DockerContainerPort{ + {PublicPort: 5432, PrivatePort: 5432, Protocol: "tcp"}, + }, + Labels: map[string]string{ + "app": "database", + }, + Mounts: []models.DockerContainerMount{ + {Destination: "/var/lib/postgresql/data"}, + }, + Networks: []models.DockerContainerNetworkLink{ + {Name: "backend"}, + }, + } + + info := service.buildContainerInfo(container) + + if info.Name != "mydb" { + t.Errorf("Name = %q, want 'mydb'", info.Name) + } + if info.Image != "postgres:14" { + t.Errorf("Image = %q, want 'postgres:14'", info.Image) + } + if 
len(info.Ports) != 1 { + t.Errorf("Ports length = %d, want 1", len(info.Ports)) + } + if info.Ports[0].ContainerPort != 5432 { + t.Errorf("ContainerPort = %d, want 5432", info.Ports[0].ContainerPort) + } + if info.Labels["app"] != "database" { + t.Errorf("Labels[app] = %q, want 'database'", info.Labels["app"]) + } + if len(info.Mounts) != 1 || info.Mounts[0] != "/var/lib/postgresql/data" { + t.Errorf("Mounts = %v, want [/var/lib/postgresql/data]", info.Mounts) + } +} + +func TestRunDiscovery_NoAnalyzer(t *testing.T) { + provider := &mockStateProvider{ + state: models.StateSnapshot{ + DockerHosts: []models.DockerHost{ + { + Hostname: "host1", + Containers: []models.DockerContainer{ + {ID: "1", Name: "test", Image: "test:latest"}, + }, + }, + }, + }, + } + + service := NewService(provider, nil, DefaultConfig()) + // Don't set analyzer + + apps := service.RunDiscovery(context.Background()) + if apps != nil { + t.Errorf("RunDiscovery() without analyzer should return nil, got %v", apps) + } +} + +func TestRunDiscovery_WithAnalyzer(t *testing.T) { + provider := &mockStateProvider{ + state: models.StateSnapshot{ + DockerHosts: []models.DockerHost{ + { + AgentID: "agent-1", + Hostname: "docker-host", + Containers: []models.DockerContainer{ + {ID: "1", Name: "mydb", Image: "postgres:14"}, + {ID: "2", Name: "cache", Image: "redis:7"}, + }, + }, + }, + }, + } + + analyzer := &mockAIAnalyzer{ + responses: map[string]string{ + "postgres:14": `{"service_type": "postgres", "service_name": "PostgreSQL", "category": "database", "cli_command": "docker exec {container} psql -U postgres", "confidence": 0.95, "reasoning": "PostgreSQL database"}`, + "redis:7": `{"service_type": "redis", "service_name": "Redis", "category": "cache", "cli_command": "docker exec {container} redis-cli", "confidence": 0.9, "reasoning": "Redis cache"}`, + }, + } + + service := NewService(provider, nil, DefaultConfig()) + service.SetAIAnalyzer(analyzer) + + apps := service.RunDiscovery(context.Background()) + + if len(apps) != 2 { + t.Fatalf("RunDiscovery() returned %d apps, want 2", len(apps)) + } + + // Check PostgreSQL was detected + foundPostgres := false + foundRedis := false + for _, app := range apps { + if app.Type == "postgres" { + foundPostgres = true + if app.ContainerName != "mydb" { + t.Errorf("Postgres ContainerName = %q, want 'mydb'", app.ContainerName) + } + if app.CLIAccess != "docker exec mydb psql -U postgres" { + t.Errorf("Postgres CLIAccess = %q, want 'docker exec mydb psql -U postgres'", app.CLIAccess) + } + } + if app.Type == "redis" { + foundRedis = true + if app.ContainerName != "cache" { + t.Errorf("Redis ContainerName = %q, want 'cache'", app.ContainerName) + } + } + } + + if !foundPostgres { + t.Error("PostgreSQL not detected") + } + if !foundRedis { + t.Error("Redis not detected") + } +} + +func TestCaching(t *testing.T) { + provider := &mockStateProvider{ + state: models.StateSnapshot{ + DockerHosts: []models.DockerHost{ + { + AgentID: "agent-1", + Hostname: "host1", + Containers: []models.DockerContainer{ + {ID: "1", Name: "db1", Image: "postgres:14"}, + {ID: "2", Name: "db2", Image: "postgres:14"}, // Same image + }, + }, + }, + }, + } + + analyzer := &mockAIAnalyzer{ + responses: map[string]string{ + "postgres:14": `{"service_type": "postgres", "service_name": "PostgreSQL", "category": "database", "cli_command": "docker exec {container} psql", "confidence": 0.95, "reasoning": "PostgreSQL"}`, + }, + } + + service := NewService(provider, nil, DefaultConfig()) + service.SetAIAnalyzer(analyzer) + + // 
First run + service.RunDiscovery(context.Background()) + + // Should have called AI once (cached for second container with same image) + if analyzer.callCount != 1 { + t.Errorf("First run: analyzer called %d times, want 1 (caching)", analyzer.callCount) + } + + // Second run should use cache + analyzer.callCount = 0 + service.RunDiscovery(context.Background()) + + if analyzer.callCount != 0 { + t.Errorf("Second run: analyzer called %d times, want 0 (should use cache)", analyzer.callCount) + } + + // Clear cache and run again + service.ClearCache() + service.RunDiscovery(context.Background()) + + if analyzer.callCount != 1 { + t.Errorf("After cache clear: analyzer called %d times, want 1", analyzer.callCount) + } +} + +func TestGetStatus(t *testing.T) { + provider := &mockStateProvider{} + service := NewService(provider, nil, DefaultConfig()) + + status := service.GetStatus() + + if status["running"] != false { + t.Errorf("status['running'] = %v, want false", status["running"]) + } + if status["ai_analyzer_set"] != false { + t.Errorf("status['ai_analyzer_set'] = %v, want false", status["ai_analyzer_set"]) + } + + // Set analyzer + service.SetAIAnalyzer(&mockAIAnalyzer{}) + status = service.GetStatus() + + if status["ai_analyzer_set"] != true { + t.Errorf("status['ai_analyzer_set'] = %v, want true after setting analyzer", status["ai_analyzer_set"]) + } +} + +func TestLowConfidenceFiltering(t *testing.T) { + provider := &mockStateProvider{ + state: models.StateSnapshot{ + DockerHosts: []models.DockerHost{ + { + AgentID: "agent-1", + Hostname: "host1", + Containers: []models.DockerContainer{ + {ID: "1", Name: "mystery", Image: "custom/unknown:latest"}, + }, + }, + }, + }, + } + + analyzer := &mockAIAnalyzer{ + responses: map[string]string{ + "custom/unknown:latest": `{"service_type": "unknown", "service_name": "Unknown Service", "category": "unknown", "cli_command": "", "confidence": 0.3, "reasoning": "Cannot identify"}`, + }, + } + + service := NewService(provider, nil, DefaultConfig()) + service.SetAIAnalyzer(analyzer) + + apps := service.RunDiscovery(context.Background()) + + // Low confidence results should be filtered out + if len(apps) != 0 { + t.Errorf("RunDiscovery() returned %d apps, want 0 (low confidence should be filtered)", len(apps)) + } +} diff --git a/internal/kubernetesagent/agent_additional_test.go b/internal/kubernetesagent/agent_additional_test.go new file mode 100644 index 000000000..356ff855b --- /dev/null +++ b/internal/kubernetesagent/agent_additional_test.go @@ -0,0 +1,114 @@ +package kubernetesagent + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/buffer" + agentsk8s "github.com/rcourtman/pulse-go-rewrite/pkg/agents/kubernetes" + "github.com/rs/zerolog" + "k8s.io/apimachinery/pkg/version" + fakediscovery "k8s.io/client-go/discovery/fake" + "k8s.io/client-go/kubernetes/fake" +) + +func TestDiscoverClusterMetadata(t *testing.T) { + clientset := fake.NewSimpleClientset() + discovery := clientset.Discovery().(*fakediscovery.FakeDiscovery) + discovery.FakedServerVersion = &version.Info{GitVersion: "v1.2.3"} + + agent := &Agent{ + kubeClient: clientset, + } + + if err := agent.discoverClusterMetadata(context.Background()); err != nil { + t.Fatalf("discoverClusterMetadata: %v", err) + } + if agent.clusterVersion != "v1.2.3" { + t.Fatalf("clusterVersion = %q, want v1.2.3", agent.clusterVersion) + } +} + +func TestRun_StopsOnContextCancel(t *testing.T) { + requested := 
make(chan struct{}, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/agents/kubernetes/report" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Fatalf("unexpected method: %s", r.Method) + } + select { + case requested <- struct{}{}: + default: + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := zerolog.New(io.Discard) + agent := &Agent{ + cfg: Config{APIToken: "token"}, + logger: logger, + httpClient: server.Client(), + pulseURL: server.URL, + agentID: "agent-1", + agentVersion: "v1", + interval: 10 * time.Millisecond, + kubeClient: fake.NewSimpleClientset(), + reportBuffer: buffer.New[agentsk8s.Report](5), + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- agent.Run(ctx) + }() + + select { + case <-requested: + case <-time.After(200 * time.Millisecond): + t.Fatal("expected report request") + } + + cancel() + + select { + case err := <-done: + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("Run did not return after cancel") + } +} + +func TestSendReport_ErrorWithBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("nope")) + })) + defer server.Close() + + logger := zerolog.New(io.Discard) + agent := &Agent{ + cfg: Config{APIToken: "token"}, + logger: logger, + httpClient: server.Client(), + pulseURL: server.URL, + agentVersion: "v1", + } + + err := agent.sendReport(context.Background(), agentsk8s.Report{Timestamp: time.Now().UTC()}) + if err == nil { + t.Fatal("expected error from sendReport") + } + if !strings.Contains(err.Error(), "nope") { + t.Fatalf("error = %v, want body in message", err) + } +} diff --git a/internal/kubernetesagent/agent_new_test.go b/internal/kubernetesagent/agent_new_test.go new file mode 100644 index 000000000..7a41c3581 --- /dev/null +++ b/internal/kubernetesagent/agent_new_test.go @@ -0,0 +1,106 @@ +package kubernetesagent + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/rs/zerolog" +) + +func writeTestKubeconfig(t *testing.T, path, serverURL, contextName string) { + t.Helper() + kubeconfig := fmt.Sprintf(` +apiVersion: v1 +kind: Config +clusters: +- name: c1 + cluster: + server: %s +contexts: +- name: %s + context: + cluster: c1 + user: u1 +current-context: %s +users: +- name: u1 + user: + token: test +`, serverURL, contextName, contextName) + + if err := os.WriteFile(path, []byte(kubeconfig), 0o600); err != nil { + t.Fatalf("write kubeconfig: %v", err) + } +} + +func TestBuildRESTConfig_DefaultKubeconfig(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/version" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"gitVersion":"v1.25.0"}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmp := t.TempDir() + kubeconfigPath := filepath.Join(tmp, "config") + writeTestKubeconfig(t, kubeconfigPath, server.URL, "ctx-default") + + t.Setenv("KUBECONFIG", kubeconfigPath) + + restCfg, ctxName, err := buildRESTConfig("", "") + if err != nil { + t.Fatalf("buildRESTConfig: %v", err) + } + if restCfg.Host != server.URL { + t.Fatalf("restCfg.Host = %q, want %q", restCfg.Host, 
server.URL) + } + if ctxName != "ctx-default" { + t.Fatalf("contextName = %q, want ctx-default", ctxName) + } +} + +func TestNew_WithKubeconfig(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/version" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"gitVersion":"v1.26.1"}`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmp := t.TempDir() + kubeconfigPath := filepath.Join(tmp, "config") + writeTestKubeconfig(t, kubeconfigPath, server.URL, "ctx-test") + + agent, err := New(Config{ + PulseURL: "http://pulse.local", + APIToken: "token", + KubeconfigPath: kubeconfigPath, + KubeContext: "ctx-test", + LogLevel: zerolog.Disabled, + }) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + if agent.clusterServer != server.URL { + t.Fatalf("clusterServer = %q, want %q", agent.clusterServer, server.URL) + } + if agent.clusterContext != "ctx-test" { + t.Fatalf("clusterContext = %q, want ctx-test", agent.clusterContext) + } + if agent.clusterVersion != "v1.26.1" { + t.Fatalf("clusterVersion = %q, want v1.26.1", agent.clusterVersion) + } + if agent.agentID == "" || agent.clusterID == "" { + t.Fatalf("expected non-empty agent and cluster IDs") + } +} diff --git a/internal/metrics/incident_recorder_additional_test.go b/internal/metrics/incident_recorder_additional_test.go new file mode 100644 index 000000000..36b90c17f --- /dev/null +++ b/internal/metrics/incident_recorder_additional_test.go @@ -0,0 +1,127 @@ +package metrics + +import ( + "sync/atomic" + "testing" + "time" +) + +type countingProvider struct { + ids []string + metrics map[string]map[string]float64 + calls int32 +} + +func (c *countingProvider) GetCurrentMetrics(resourceID string) (map[string]float64, error) { + atomic.AddInt32(&c.calls, 1) + metrics, ok := c.metrics[resourceID] + if !ok { + return nil, errNoMetrics(resourceID) + } + copied := make(map[string]float64, len(metrics)) + for k, v := range metrics { + copied[k] = v + } + return copied, nil +} + +func (c *countingProvider) GetMonitoredResourceIDs() []string { + return append([]string{}, c.ids...) 
+} + +func waitForCalls(t *testing.T, provider *countingProvider, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if atomic.LoadInt32(&provider.calls) > 0 { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("expected provider to be called") +} + +func TestDefaultIncidentRecorderConfig(t *testing.T) { + cfg := DefaultIncidentRecorderConfig() + if cfg.SampleInterval == 0 || cfg.PreIncidentWindow == 0 || cfg.PostIncidentWindow == 0 { + t.Fatalf("default config should be non-zero, got %+v", cfg) + } + if cfg.MaxDataPointsPerWindow == 0 || cfg.MaxWindows == 0 || cfg.RetentionDuration == 0 { + t.Fatalf("default config should be non-zero, got %+v", cfg) + } +} + +func TestIncidentRecorderStartStop(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{ + SampleInterval: 5 * time.Millisecond, + PreIncidentWindow: 10 * time.Millisecond, + PostIncidentWindow: 10 * time.Millisecond, + MaxDataPointsPerWindow: 5, + }) + + provider := &countingProvider{ + ids: []string{"res-1"}, + metrics: map[string]map[string]float64{ + "res-1": {"cpu": 1}, + }, + } + recorder.SetMetricsProvider(provider) + + recorder.Start() + waitForCalls(t, provider, 200*time.Millisecond) + recorder.Stop() + + if recorder.running { + t.Fatalf("expected recorder to be stopped") + } +} + +func TestGetWindowsForResource(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{}) + + active := &IncidentWindow{ID: "active-1", ResourceID: "res-1"} + recorder.activeWindows["active-1"] = active + + got := recorder.GetWindowsForResource("res-1", 0) + if len(got) != 1 || got[0].ID != "active-1" { + t.Fatalf("expected active window, got %+v", got) + } + + recorder.activeWindows = map[string]*IncidentWindow{} + recorder.completedWindows = []*IncidentWindow{ + {ID: "old", ResourceID: "res-1"}, + {ID: "new", ResourceID: "res-1"}, + } + + limited := recorder.GetWindowsForResource("res-1", 1) + if len(limited) != 1 || limited[0].ID != "new" { + t.Fatalf("expected most recent completed window, got %+v", limited) + } +} + +func TestGetRecentWindows(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{}) + recorder.activeWindows["active"] = &IncidentWindow{ID: "active", ResourceID: "res-1"} + recorder.completedWindows = []*IncidentWindow{ + {ID: "old", ResourceID: "res-2"}, + {ID: "new", ResourceID: "res-3"}, + } + + recent := recorder.GetRecentWindows(2) + if len(recent) != 2 { + t.Fatalf("expected 2 windows, got %d", len(recent)) + } + + var sawActive, sawNew bool + for _, window := range recent { + if window.ID == "active" { + sawActive = true + } + if window.ID == "new" { + sawNew = true + } + } + if !sawActive || !sawNew { + t.Fatalf("expected active and newest completed windows, got %+v", recent) + } +} diff --git a/internal/metrics/incident_recorder_test.go b/internal/metrics/incident_recorder_test.go new file mode 100644 index 000000000..1235e0806 --- /dev/null +++ b/internal/metrics/incident_recorder_test.go @@ -0,0 +1,296 @@ +package metrics + +import ( + "strings" + "testing" + "time" +) + +type stubMetricsProvider struct { + metricsByID map[string]map[string]float64 + ids []string +} + +func (s *stubMetricsProvider) GetCurrentMetrics(resourceID string) (map[string]float64, error) { + metrics, ok := s.metricsByID[resourceID] + if !ok { + return nil, errNoMetrics(resourceID) + } + copied := make(map[string]float64, len(metrics)) + for k, v := range metrics { + copied[k] = v + } + return copied, nil +} + +func 
(s *stubMetricsProvider) GetMonitoredResourceIDs() []string { + return append([]string{}, s.ids...) +} + +type errNoMetrics string + +func (e errNoMetrics) Error() string { + return "no metrics for " + string(e) +} + +func TestNewIncidentRecorderDefaults(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{}) + + if recorder.config.SampleInterval != 5*time.Second { + t.Fatalf("expected default sample interval, got %s", recorder.config.SampleInterval) + } + if recorder.config.PreIncidentWindow != 5*time.Minute { + t.Fatalf("expected default pre-incident window, got %s", recorder.config.PreIncidentWindow) + } + if recorder.config.PostIncidentWindow != 10*time.Minute { + t.Fatalf("expected default post-incident window, got %s", recorder.config.PostIncidentWindow) + } + if recorder.config.MaxDataPointsPerWindow != 500 { + t.Fatalf("expected default max data points, got %d", recorder.config.MaxDataPointsPerWindow) + } + if recorder.config.MaxWindows != 100 { + t.Fatalf("expected default max windows, got %d", recorder.config.MaxWindows) + } + if recorder.config.RetentionDuration != 24*time.Hour { + t.Fatalf("expected default retention, got %s", recorder.config.RetentionDuration) + } +} + +func TestStartRecordingExtendsWindow(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{ + PreIncidentWindow: time.Minute, + PostIncidentWindow: time.Minute, + }) + + firstID := recorder.StartRecording("res-1", "db", "host", "alert", "alert-1") + firstWindow := recorder.activeWindows[firstID] + if firstWindow == nil { + t.Fatalf("expected window for %s", firstID) + } + firstEnd := *firstWindow.EndTime + + secondID := recorder.StartRecording("res-1", "db", "host", "alert", "alert-2") + if secondID != firstID { + t.Fatalf("expected same window ID, got %s and %s", firstID, secondID) + } + secondWindow := recorder.activeWindows[secondID] + if secondWindow.EndTime.Before(firstEnd) { + t.Fatalf("expected end time to extend or remain, got %s before %s", secondWindow.EndTime, firstEnd) + } +} + +func TestRecordSampleBuffersAndCleansUp(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{ + PreIncidentWindow: time.Minute, + PostIncidentWindow: time.Minute, + MaxDataPointsPerWindow: 10, + }) + + provider := &stubMetricsProvider{ + metricsByID: map[string]map[string]float64{ + "res-1": {"cpu": 1}, + "res-2": {"cpu": 2}, + }, + ids: []string{"res-1", "res-2"}, + } + recorder.SetMetricsProvider(provider) + + recorder.preIncidentBuffer["gone"] = []IncidentDataPoint{ + {Timestamp: time.Now().Add(-time.Minute), Metrics: map[string]float64{"cpu": 0.5}}, + } + + windowID := recorder.StartRecording("res-1", "db", "host", "alert", "alert-1") + recorder.recordSample() + + window := recorder.activeWindows[windowID] + if window == nil { + t.Fatalf("expected active window %s", windowID) + } + if len(window.DataPoints) != 1 { + t.Fatalf("expected 1 data point, got %d", len(window.DataPoints)) + } + + if len(recorder.preIncidentBuffer["res-1"]) == 0 { + t.Fatalf("expected pre-incident buffer for res-1") + } + if len(recorder.preIncidentBuffer["res-2"]) == 0 { + t.Fatalf("expected pre-incident buffer for res-2") + } + if _, ok := recorder.preIncidentBuffer["gone"]; ok { + t.Fatalf("expected cleanup of unmonitored resource buffer") + } +} + +func TestStopRecordingCompletesWindow(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{ + PreIncidentWindow: time.Minute, + PostIncidentWindow: time.Minute, + }) + provider := &stubMetricsProvider{ + metricsByID: 
map[string]map[string]float64{ + "res-1": {"cpu": 1}, + }, + ids: []string{"res-1"}, + } + recorder.SetMetricsProvider(provider) + + windowID := recorder.StartRecording("res-1", "db", "host", "alert", "alert-1") + recorder.recordSample() + recorder.StopRecording(windowID) + + if _, ok := recorder.activeWindows[windowID]; ok { + t.Fatalf("expected window %s to be removed from active windows", windowID) + } + if len(recorder.completedWindows) != 1 { + t.Fatalf("expected 1 completed window, got %d", len(recorder.completedWindows)) + } + if recorder.completedWindows[0].Status != IncidentWindowStatusComplete { + t.Fatalf("expected completed status, got %s", recorder.completedWindows[0].Status) + } + if recorder.completedWindows[0].Summary == nil { + t.Fatalf("expected summary to be computed") + } +} + +func TestComputeSummary(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{}) + start := time.Now().Add(-time.Second) + end := start.Add(time.Second) + window := &IncidentWindow{ + DataPoints: []IncidentDataPoint{ + {Timestamp: start, Metrics: map[string]float64{"cpu": 1, "mem": 4}}, + {Timestamp: end, Metrics: map[string]float64{"cpu": 3, "mem": 2}}, + }, + } + + summary := recorder.computeSummary(window) + if summary == nil { + t.Fatalf("expected summary") + } + if summary.DataPoints != 2 { + t.Fatalf("expected 2 data points, got %d", summary.DataPoints) + } + if summary.Peaks["cpu"] != 3 || summary.Lows["cpu"] != 1 { + t.Fatalf("unexpected cpu stats: peaks=%v lows=%v", summary.Peaks["cpu"], summary.Lows["cpu"]) + } + if summary.Peaks["mem"] != 4 || summary.Lows["mem"] != 2 { + t.Fatalf("unexpected mem stats: peaks=%v lows=%v", summary.Peaks["mem"], summary.Lows["mem"]) + } + if summary.Averages["cpu"] != 2 { + t.Fatalf("unexpected cpu average: %v", summary.Averages["cpu"]) + } + if summary.Averages["mem"] != 3 { + t.Fatalf("unexpected mem average: %v", summary.Averages["mem"]) + } + if summary.Changes["cpu"] != 2 || summary.Changes["mem"] != -2 { + t.Fatalf("unexpected changes: cpu=%v mem=%v", summary.Changes["cpu"], summary.Changes["mem"]) + } + if summary.Duration != time.Second { + t.Fatalf("unexpected duration: %s", summary.Duration) + } +} + +func TestFormatForContextIncludesSummaryAndData(t *testing.T) { + recorder := NewIncidentRecorder(IncidentRecorderConfig{}) + now := time.Now() + end := now.Add(2 * time.Second) + window := &IncidentWindow{ + ID: "window-1", + ResourceID: "res-1", + Status: IncidentWindowStatusComplete, + StartTime: now, + EndTime: &end, + DataPoints: []IncidentDataPoint{ + {Timestamp: now, Metrics: map[string]float64{"cpu": 1.25}}, + }, + Summary: &IncidentSummary{ + Duration: 2 * time.Second, + DataPoints: 1, + Peaks: map[string]float64{"cpu": 1.25}, + Changes: map[string]float64{"cpu": 0.5}, + }, + } + recorder.completedWindows = []*IncidentWindow{window} + + formatted := recorder.FormatForContext("", "window-1") + if formatted == "" { + t.Fatalf("expected formatted context") + } + required := []string{ + "Incident Recording Data", + "Summary", + "Duration: 2s", + "Data points: 1", + "Peak values", + "cpu: 1.25", + "Changes during incident", + "cpu: +0.50", + "Recent Data Points", + } + for _, snippet := range required { + if !strings.Contains(formatted, snippet) { + t.Fatalf("expected formatted output to include %q", snippet) + } + } +} + +func TestCopyWindowDeepCopy(t *testing.T) { + now := time.Now() + end := now.Add(time.Second) + window := &IncidentWindow{ + ID: "window-1", + EndTime: &end, + DataPoints: []IncidentDataPoint{ + 
{Timestamp: now, Metrics: map[string]float64{"cpu": 1}}, + }, + Summary: &IncidentSummary{ + Peaks: map[string]float64{"cpu": 1}, + }, + } + + clone := copyWindow(window) + if clone == nil || clone == window { + t.Fatalf("expected deep copy") + } + if clone.Summary == window.Summary { + t.Fatalf("expected summary to be copied") + } + + window.DataPoints[0].Metrics["cpu"] = 9 + *window.EndTime = end.Add(5 * time.Second) + + if clone.DataPoints[0].Metrics["cpu"] != 1 { + t.Fatalf("expected data points to be copied") + } + if clone.EndTime.Equal(*window.EndTime) { + t.Fatalf("expected end time to be copied") + } +} + +func TestSaveAndLoad(t *testing.T) { + dir := t.TempDir() + recorder := NewIncidentRecorder(IncidentRecorderConfig{DataDir: dir}) + + end := time.Now() + recorder.completedWindows = []*IncidentWindow{ + { + ID: "window-1", + EndTime: &end, + Status: IncidentWindowStatusComplete, + DataPoints: []IncidentDataPoint{{Timestamp: end, Metrics: map[string]float64{"cpu": 1}}}, + }, + } + + if err := recorder.saveToDisk(); err != nil { + t.Fatalf("save failed: %v", err) + } + + loaded := NewIncidentRecorder(IncidentRecorderConfig{DataDir: dir}) + window := loaded.GetWindow("window-1") + if window == nil { + t.Fatalf("expected window to load from disk") + } + if window.Status != IncidentWindowStatusComplete { + t.Fatalf("expected status to persist, got %s", window.Status) + } +} diff --git a/internal/models/converters_additional_test.go b/internal/models/converters_additional_test.go new file mode 100644 index 000000000..b0951afe6 --- /dev/null +++ b/internal/models/converters_additional_test.go @@ -0,0 +1,198 @@ +package models + +import ( + "encoding/json" + "testing" + "time" +) + +func TestRemovedDockerHostToFrontend_DisplayName(t *testing.T) { + removedAt := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + host := RemovedDockerHost{ + ID: "host-1", + Hostname: "docker-1", + DisplayName: "Docker One", + RemovedAt: removedAt, + } + + frontend := host.ToFrontend() + if frontend.ID != host.ID { + t.Fatalf("ID = %q, want %q", frontend.ID, host.ID) + } + if frontend.Hostname != host.Hostname { + t.Fatalf("Hostname = %q, want %q", frontend.Hostname, host.Hostname) + } + if frontend.DisplayName != host.DisplayName { + t.Fatalf("DisplayName = %q, want %q", frontend.DisplayName, host.DisplayName) + } + if frontend.RemovedAt != removedAt.Unix()*1000 { + t.Fatalf("RemovedAt = %d, want %d", frontend.RemovedAt, removedAt.Unix()*1000) + } +} + +func TestKubernetesClusterToFrontend(t *testing.T) { + now := time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC) + lastUsed := time.Date(2024, 2, 4, 5, 6, 7, 0, time.UTC) + + cluster := KubernetesCluster{ + ID: "cluster-1", + AgentID: "agent-1", + Name: "prod", + DisplayName: "", + Server: "https://k8s", + Context: "ctx", + Version: "v1.25.0", + Status: "healthy", + LastSeen: now, + IntervalSeconds: 30, + AgentVersion: "v1", + TokenID: "token-1", + TokenName: "token-name", + TokenHint: "hint", + TokenLastUsedAt: &lastUsed, + Hidden: true, + PendingUninstall: true, + Nodes: []KubernetesNode{ + { + UID: "node-1", + Name: "node-a", + Ready: true, + Roles: []string{"master"}, + }, + }, + Pods: []KubernetesPod{ + { + UID: "pod-1", + Name: "pod-a", + Namespace: "ns", + Labels: map[string]string{"app": "svc"}, + Containers: []KubernetesPodContainer{ + {Name: "c1", Ready: true}, + }, + }, + }, + Deployments: []KubernetesDeployment{ + { + UID: "dep-1", + Name: "deploy-a", + Namespace: "ns", + DesiredReplicas: 2, + Labels: map[string]string{"tier": "web"}, + }, + }, + } + + 
frontend := cluster.ToFrontend() + if frontend.DisplayName != cluster.Name { + t.Fatalf("DisplayName = %q, want %q", frontend.DisplayName, cluster.Name) + } + if frontend.TokenLastUsedAt == nil || *frontend.TokenLastUsedAt != lastUsed.Unix()*1000 { + t.Fatalf("TokenLastUsedAt = %v, want %d", frontend.TokenLastUsedAt, lastUsed.Unix()*1000) + } + if len(frontend.Nodes) != 1 || frontend.Nodes[0].Name != "node-a" { + t.Fatalf("Nodes = %#v, want 1 node", frontend.Nodes) + } + if len(frontend.Pods) != 1 || frontend.Pods[0].Name != "pod-a" { + t.Fatalf("Pods = %#v, want 1 pod", frontend.Pods) + } + if len(frontend.Deployments) != 1 || frontend.Deployments[0].Name != "deploy-a" { + t.Fatalf("Deployments = %#v, want 1 deployment", frontend.Deployments) + } +} + +func TestKubernetesClusterToFrontend_DisplayNameFallback(t *testing.T) { + cluster := KubernetesCluster{ + ID: "cluster-2", + Name: "", + } + + frontend := cluster.ToFrontend() + if frontend.DisplayName != cluster.ID { + t.Fatalf("DisplayName = %q, want %q", frontend.DisplayName, cluster.ID) + } +} + +func TestTimeToUnixMillis(t *testing.T) { + if got := timeToUnixMillis(time.Time{}); got != 0 { + t.Fatalf("timeToUnixMillis(zero) = %d, want 0", got) + } + + now := time.Date(2024, 3, 4, 5, 6, 7, 0, time.UTC) + if got := timeToUnixMillis(now); got != now.Unix()*1000 { + t.Fatalf("timeToUnixMillis = %d, want %d", got, now.Unix()*1000) + } +} + +func TestRemovedKubernetesClusterToFrontend(t *testing.T) { + removedAt := time.Date(2024, 4, 5, 6, 7, 8, 0, time.UTC) + cluster := RemovedKubernetesCluster{ + ID: "cluster-3", + Name: "old", + DisplayName: "Old Cluster", + RemovedAt: removedAt, + } + + frontend := cluster.ToFrontend() + if frontend.ID != cluster.ID { + t.Fatalf("ID = %q, want %q", frontend.ID, cluster.ID) + } + if frontend.Name != cluster.Name { + t.Fatalf("Name = %q, want %q", frontend.Name, cluster.Name) + } + if frontend.DisplayName != cluster.DisplayName { + t.Fatalf("DisplayName = %q, want %q", frontend.DisplayName, cluster.DisplayName) + } + if frontend.RemovedAt != removedAt.Unix()*1000 { + t.Fatalf("RemovedAt = %d, want %d", frontend.RemovedAt, removedAt.Unix()*1000) + } +} + +func TestConvertResourceToFrontend(t *testing.T) { + total := int64(100) + used := int64(60) + free := int64(40) + identity := &ResourceIdentityInput{ + Hostname: "host-a", + MachineID: "machine-1", + IPs: []string{"10.0.0.1"}, + } + + input := ResourceConvertInput{ + ID: "res-1", + Type: "node", + Name: "node-a", + DisplayName: "Node A", + Status: "healthy", + CPU: &ResourceMetricInput{Current: 12.5, Total: &total, Used: &used, Free: &free}, + Memory: &ResourceMetricInput{Current: 50.1}, + Disk: &ResourceMetricInput{Current: 70.2}, + NetworkRX: 1000, + NetworkTX: 2000, + HasNetwork: true, + Tags: []string{"prod"}, + Labels: map[string]string{"env": "prod"}, + LastSeenUnix: 12345, + Alerts: []ResourceAlertInput{ + {ID: "a1", Type: "cpu", Level: "warn", Message: "high", Value: 90, Threshold: 80, StartTimeUnix: 111}, + }, + Identity: identity, + PlatformData: json.RawMessage(`{"key":"value"}`), + } + + frontend := ConvertResourceToFrontend(input) + if frontend.ID != input.ID || frontend.Name != input.Name { + t.Fatalf("frontend = %#v, want ID %q Name %q", frontend, input.ID, input.Name) + } + if frontend.CPU == nil || frontend.CPU.Total == nil || *frontend.CPU.Total != total { + t.Fatalf("CPU = %#v, want total %d", frontend.CPU, total) + } + if frontend.Network == nil || frontend.Network.RXBytes != 1000 || frontend.Network.TXBytes != 2000 { + 
t.Fatalf("Network = %#v, want RX/TX set", frontend.Network) + } + if len(frontend.Alerts) != 1 || frontend.Alerts[0].ID != "a1" { + t.Fatalf("Alerts = %#v, want 1 alert", frontend.Alerts) + } + if frontend.Identity == nil || frontend.Identity.Hostname != identity.Hostname { + t.Fatalf("Identity = %#v, want hostname %q", frontend.Identity, identity.Hostname) + } +} diff --git a/internal/models/organization_additional_test.go b/internal/models/organization_additional_test.go new file mode 100644 index 000000000..efc58e5b1 --- /dev/null +++ b/internal/models/organization_additional_test.go @@ -0,0 +1,33 @@ +package models + +import "testing" + +func TestOrganizationAccessors(t *testing.T) { + org := &Organization{ + ID: "org-1", + OwnerUserID: "owner", + Members: []OrganizationMember{ + {UserID: "admin", Role: OrgRoleAdmin}, + {UserID: "member", Role: OrgRoleMember}, + }, + } + + if !org.HasMember("admin") || org.HasMember("missing") { + t.Fatalf("HasMember results unexpected") + } + if role := org.GetMemberRole("admin"); role != OrgRoleAdmin { + t.Fatalf("GetMemberRole = %q, want admin", role) + } + if role := org.GetMemberRole("missing"); role != "" { + t.Fatalf("GetMemberRole for missing = %q, want empty", role) + } + if !org.IsOwner("owner") || org.IsOwner("admin") { + t.Fatalf("IsOwner results unexpected") + } + if !org.CanUserAccess("owner") || !org.CanUserAccess("member") || org.CanUserAccess("missing") { + t.Fatalf("CanUserAccess results unexpected") + } + if !org.CanUserManage("owner") || !org.CanUserManage("admin") || org.CanUserManage("member") { + t.Fatalf("CanUserManage results unexpected") + } +} diff --git a/internal/models/profile_validation_additional_test.go b/internal/models/profile_validation_additional_test.go new file mode 100644 index 000000000..e2c93cab6 --- /dev/null +++ b/internal/models/profile_validation_additional_test.go @@ -0,0 +1,10 @@ +package models + +import "testing" + +func TestValidationErrorError(t *testing.T) { + err := ValidationError{Key: "cpu_threshold", Message: "invalid"} + if got := err.Error(); got != "cpu_threshold: invalid" { + t.Fatalf("Error() = %q, want cpu_threshold: invalid", got) + } +} diff --git a/internal/models/state_additional_test.go b/internal/models/state_additional_test.go new file mode 100644 index 000000000..abbd5b076 --- /dev/null +++ b/internal/models/state_additional_test.go @@ -0,0 +1,263 @@ +package models + +import ( + "strings" + "testing" + "time" +) + +func TestStateClearAllDockerHosts(t *testing.T) { + state := &State{ + DockerHosts: []DockerHost{ + {ID: "h1"}, + {ID: "h2"}, + }, + } + + count := state.ClearAllDockerHosts() + if count != 2 { + t.Fatalf("count = %d, want 2", count) + } + if len(state.DockerHosts) != 0 { + t.Fatalf("DockerHosts = %#v, want empty", state.DockerHosts) + } +} + +func TestStateKubernetesClusterLifecycle(t *testing.T) { + state := &State{} + + initial := KubernetesCluster{ + ID: "c1", + Name: "alpha", + CustomDisplayName: "keep", + Hidden: true, + PendingUninstall: true, + Status: "init", + } + state.UpsertKubernetesCluster(initial) + + update := KubernetesCluster{ + ID: "c1", + Name: "alpha", + CustomDisplayName: "", + Hidden: false, + PendingUninstall: false, + Status: "ready", + } + state.UpsertKubernetesCluster(update) + + clusters := state.GetKubernetesClusters() + if len(clusters) != 1 { + t.Fatalf("clusters = %#v, want 1", clusters) + } + if clusters[0].CustomDisplayName != "keep" { + t.Fatalf("CustomDisplayName = %q, want keep", clusters[0].CustomDisplayName) + } + if 
!clusters[0].Hidden || !clusters[0].PendingUninstall { + t.Fatalf("expected Hidden and PendingUninstall preserved") + } + + if ok := state.SetKubernetesClusterStatus("c1", "ok"); !ok { + t.Fatalf("SetKubernetesClusterStatus returned false") + } + if _, ok := state.SetKubernetesClusterHidden("c1", false); !ok { + t.Fatalf("SetKubernetesClusterHidden returned false") + } + if _, ok := state.SetKubernetesClusterPendingUninstall("c1", false); !ok { + t.Fatalf("SetKubernetesClusterPendingUninstall returned false") + } + if _, ok := state.SetKubernetesClusterCustomDisplayName("c1", "custom"); !ok { + t.Fatalf("SetKubernetesClusterCustomDisplayName returned false") + } + + removed, ok := state.RemoveKubernetesCluster("c1") + if !ok || removed.ID != "c1" { + t.Fatalf("RemoveKubernetesCluster = (%v, %v), want c1", removed, ok) + } + if _, ok := state.RemoveKubernetesCluster("missing"); ok { + t.Fatalf("expected RemoveKubernetesCluster to fail for missing") + } +} + +func TestStateRemovedKubernetesClusters(t *testing.T) { + state := &State{} + t1 := time.Date(2024, 1, 1, 1, 0, 0, 0, time.UTC) + t2 := time.Date(2024, 1, 2, 1, 0, 0, 0, time.UTC) + + state.AddRemovedKubernetesCluster(RemovedKubernetesCluster{ID: "c1", RemovedAt: t1}) + state.AddRemovedKubernetesCluster(RemovedKubernetesCluster{ID: "c2", RemovedAt: t2}) + state.AddRemovedKubernetesCluster(RemovedKubernetesCluster{ID: "c1", DisplayName: "updated", RemovedAt: t1}) + + entries := state.GetRemovedKubernetesClusters() + if len(entries) != 2 { + t.Fatalf("entries = %#v, want 2", entries) + } + if entries[0].ID != "c2" { + t.Fatalf("entries[0].ID = %q, want c2", entries[0].ID) + } + + state.RemoveRemovedKubernetesCluster("c1") + entries = state.GetRemovedKubernetesClusters() + if len(entries) != 1 || entries[0].ID != "c2" { + t.Fatalf("entries = %#v, want c2 only", entries) + } +} + +func TestStateClearAllHosts(t *testing.T) { + state := &State{ + Hosts: []Host{{ID: "h1"}, {ID: "h2"}}, + } + + count := state.ClearAllHosts() + if count != 2 { + t.Fatalf("count = %d, want 2", count) + } + if len(state.Hosts) != 0 { + t.Fatalf("Hosts = %#v, want empty", state.Hosts) + } +} + +func TestStateLinkNodeToHostAgent(t *testing.T) { + state := &State{ + Nodes: []Node{{ID: "n1"}}, + } + + if ok := state.LinkNodeToHostAgent("n1", "h1"); !ok { + t.Fatalf("LinkNodeToHostAgent returned false") + } + if state.Nodes[0].LinkedHostAgentID != "h1" { + t.Fatalf("LinkedHostAgentID = %q, want h1", state.Nodes[0].LinkedHostAgentID) + } + if ok := state.LinkNodeToHostAgent("missing", "h1"); ok { + t.Fatalf("expected false for missing node") + } +} + +func TestStateUnlinkNodesFromHostAgent(t *testing.T) { + state := &State{ + Nodes: []Node{ + {ID: "n1", LinkedHostAgentID: "h1"}, + {ID: "n2", LinkedHostAgentID: "h1"}, + {ID: "n3", LinkedHostAgentID: "h2"}, + }, + } + + count := state.UnlinkNodesFromHostAgent("h1") + if count != 2 { + t.Fatalf("count = %d, want 2", count) + } + for _, node := range state.Nodes[:2] { + if node.LinkedHostAgentID != "" { + t.Fatalf("expected LinkedHostAgentID cleared, got %q", node.LinkedHostAgentID) + } + } +} + +func TestStateLinkHostAgentToNode(t *testing.T) { + state := &State{ + Hosts: []Host{ + {ID: "h1", LinkedNodeID: "n1"}, + {ID: "h2", LinkedVMID: "vm1", LinkedContainerID: "ct1"}, + }, + Nodes: []Node{ + {ID: "n1", LinkedHostAgentID: "h1"}, + {ID: "n2"}, + }, + } + + if err := state.LinkHostAgentToNode("h2", "n2"); err != nil { + t.Fatalf("LinkHostAgentToNode error: %v", err) + } + if state.Hosts[1].LinkedNodeID != "n2" { + 
t.Fatalf("LinkedNodeID = %q, want n2", state.Hosts[1].LinkedNodeID) + } + if state.Nodes[1].LinkedHostAgentID != "h2" { + t.Fatalf("LinkedHostAgentID = %q, want h2", state.Nodes[1].LinkedHostAgentID) + } + if state.Hosts[1].LinkedVMID != "" || state.Hosts[1].LinkedContainerID != "" { + t.Fatalf("expected VM/container links cleared") + } + + if err := state.LinkHostAgentToNode("missing", "n2"); err == nil || !strings.Contains(err.Error(), "host agent not found") { + t.Fatalf("expected host not found error, got %v", err) + } + if err := state.LinkHostAgentToNode("h2", "missing"); err == nil || !strings.Contains(err.Error(), "node not found") { + t.Fatalf("expected node not found error, got %v", err) + } +} + +func TestStateUnlinkHostAgent(t *testing.T) { + state := &State{ + Hosts: []Host{{ID: "h1", LinkedNodeID: "n1", LinkedVMID: "vm", LinkedContainerID: "ct"}}, + Nodes: []Node{{ID: "n1", LinkedHostAgentID: "h1"}}, + } + + if ok := state.UnlinkHostAgent("h1"); !ok { + t.Fatalf("UnlinkHostAgent returned false") + } + if state.Hosts[0].LinkedNodeID != "" || state.Hosts[0].LinkedVMID != "" || state.Hosts[0].LinkedContainerID != "" { + t.Fatalf("expected host links cleared") + } + if state.Nodes[0].LinkedHostAgentID != "" { + t.Fatalf("expected node link cleared") + } + if ok := state.UnlinkHostAgent("missing"); ok { + t.Fatalf("expected false for missing host") + } +} + +func TestStateUpsertCephCluster(t *testing.T) { + state := &State{} + state.UpsertCephCluster(CephCluster{ID: "c1", Name: "b"}) + state.UpsertCephCluster(CephCluster{ID: "c2", Name: "a"}) + state.UpsertCephCluster(CephCluster{ID: "c1", Name: "c"}) + + if len(state.CephClusters) != 2 { + t.Fatalf("clusters = %#v, want 2", state.CephClusters) + } + if state.CephClusters[0].Name != "a" || state.CephClusters[1].Name != "c" { + t.Fatalf("clusters order = %#v, want a then c", state.CephClusters) + } +} + +func TestStateSetHostCommandsEnabled(t *testing.T) { + state := &State{ + Hosts: []Host{{ID: "h1", CommandsEnabled: false}}, + } + + if ok := state.SetHostCommandsEnabled("h1", true); !ok { + t.Fatalf("SetHostCommandsEnabled returned false") + } + if !state.Hosts[0].CommandsEnabled { + t.Fatalf("CommandsEnabled not updated") + } + if ok := state.SetHostCommandsEnabled("missing", true); ok { + t.Fatalf("expected false for missing host") + } +} + +func TestStateContainers(t *testing.T) { + now := time.Now() + state := &State{ + Containers: []Container{{ID: "ct1"}}, + } + + containers := state.GetContainers() + if len(containers) != 1 || containers[0].ID != "ct1" { + t.Fatalf("containers = %#v, want ct1", containers) + } + containers[0].ID = "changed" + if state.Containers[0].ID != "ct1" { + t.Fatalf("state containers should not be modified by copy") + } + + if ok := state.UpdateContainerDockerStatus("ct1", true, now); !ok { + t.Fatalf("UpdateContainerDockerStatus returned false") + } + if !state.Containers[0].HasDocker { + t.Fatalf("HasDocker not updated") + } + if ok := state.UpdateContainerDockerStatus("missing", true, now); ok { + t.Fatalf("expected false for missing container") + } +} diff --git a/internal/monitoring/monitor_additional_test.go b/internal/monitoring/monitor_additional_test.go new file mode 100644 index 000000000..65cce9912 --- /dev/null +++ b/internal/monitoring/monitor_additional_test.go @@ -0,0 +1,175 @@ +package monitoring + +import ( + "context" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" + "github.com/rcourtman/pulse-go-rewrite/internal/models" +) + +type fakeDockerChecker 
struct{} + +func (f *fakeDockerChecker) CheckDockerInContainer(ctx context.Context, node string, vmid int) (bool, error) { + return false, nil +} + +func TestMonitorGetConfig(t *testing.T) { + cfg := &config.Config{DataPath: "/tmp/pulse-test"} + monitor := &Monitor{config: cfg} + + if got := monitor.GetConfig(); got != cfg { + t.Fatalf("GetConfig = %v, want %v", got, cfg) + } +} + +func TestMonitorSetGetDockerChecker(t *testing.T) { + monitor := &Monitor{} + checker := &fakeDockerChecker{} + + monitor.SetDockerChecker(checker) + if got := monitor.GetDockerChecker(); got != checker { + t.Fatalf("GetDockerChecker = %v, want %v", got, checker) + } + + monitor.SetDockerChecker(nil) + if got := monitor.GetDockerChecker(); got != nil { + t.Fatalf("GetDockerChecker = %v, want nil", got) + } +} + +func TestMonitorGetDockerHosts(t *testing.T) { + monitor := &Monitor{state: models.NewState()} + monitor.state.UpsertDockerHost(models.DockerHost{ID: "host-1", Hostname: "host-1"}) + + hosts := monitor.GetDockerHosts() + if len(hosts) != 1 { + t.Fatalf("GetDockerHosts length = %d, want 1", len(hosts)) + } + if hosts[0].ID != "host-1" { + t.Fatalf("GetDockerHosts[0].ID = %q, want %q", hosts[0].ID, "host-1") + } +} + +func TestMonitorGetDockerHostsNilReceiver(t *testing.T) { + var monitor *Monitor + if got := monitor.GetDockerHosts(); got != nil { + t.Fatalf("GetDockerHosts = %v, want nil", got) + } +} + +func TestMonitorLinkHostAgent(t *testing.T) { + monitor := &Monitor{state: models.NewState()} + + if err := monitor.LinkHostAgent("", "node-1"); err == nil { + t.Fatalf("expected error on empty host ID") + } + if err := monitor.LinkHostAgent("host-1", ""); err == nil { + t.Fatalf("expected error on empty node ID") + } + + monitor.state.UpsertHost(models.Host{ID: "host-1", Hostname: "host-1"}) + monitor.state.UpdateNodes([]models.Node{{ID: "node-1", Name: "node-1"}}) + + if err := monitor.LinkHostAgent("host-1", "node-1"); err != nil { + t.Fatalf("LinkHostAgent error: %v", err) + } + + hosts := monitor.state.GetHosts() + if len(hosts) != 1 || hosts[0].LinkedNodeID != "node-1" { + t.Fatalf("LinkedNodeID = %q, want %q", hosts[0].LinkedNodeID, "node-1") + } + if len(monitor.state.Nodes) != 1 || monitor.state.Nodes[0].LinkedHostAgentID != "host-1" { + t.Fatalf("LinkedHostAgentID = %q, want %q", monitor.state.Nodes[0].LinkedHostAgentID, "host-1") + } +} + +func TestMonitorInvalidateAgentProfileCache(t *testing.T) { + monitor := &Monitor{ + agentProfileCache: &agentProfileCacheEntry{ + profiles: []models.AgentProfile{{ID: "profile-1"}}, + loadedAt: time.Now(), + }, + } + + monitor.InvalidateAgentProfileCache() + if monitor.agentProfileCache != nil { + t.Fatalf("expected cache to be cleared") + } +} + +func TestMonitorMarkDockerHostPendingUninstall(t *testing.T) { + monitor := &Monitor{state: models.NewState()} + + if _, err := monitor.MarkDockerHostPendingUninstall(""); err == nil { + t.Fatalf("expected error on empty host ID") + } + if _, err := monitor.MarkDockerHostPendingUninstall("missing"); err == nil { + t.Fatalf("expected error on missing host") + } + + monitor.state.UpsertDockerHost(models.DockerHost{ID: "host-1", Hostname: "host-1"}) + host, err := monitor.MarkDockerHostPendingUninstall("host-1") + if err != nil { + t.Fatalf("MarkDockerHostPendingUninstall error: %v", err) + } + if !host.PendingUninstall { + t.Fatalf("expected PendingUninstall to be true") + } + + hosts := monitor.state.GetDockerHosts() + if len(hosts) != 1 || !hosts[0].PendingUninstall { + t.Fatalf("state PendingUninstall = %v, want 
true", hosts[0].PendingUninstall) + } +} + +func TestEnsureClusterEndpointURL(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"https://node.example:8006", "https://node.example:8006"}, + {"node.example", "https://node.example:8006"}, + {"node.example:9006", "https://node.example:9006"}, + {" node.example ", "https://node.example:8006"}, + } + + for _, tt := range tests { + if got := ensureClusterEndpointURL(tt.input); got != tt.expected { + t.Fatalf("ensureClusterEndpointURL(%q) = %q, want %q", tt.input, got, tt.expected) + } + } +} + +func TestClusterEndpointEffectiveURL(t *testing.T) { + endpoint := config.ClusterEndpoint{ + Host: "node.local", + IP: "10.0.0.1", + } + + if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "https://node.local:8006" { + t.Fatalf("verifySSL host preference = %q, want %q", got, "https://node.local:8006") + } + + endpoint.Host = "" + if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "https://10.0.0.1:8006" { + t.Fatalf("verifySSL fallback to IP = %q, want %q", got, "https://10.0.0.1:8006") + } + + endpoint.Host = "node.local" + if got := clusterEndpointEffectiveURL(endpoint, false, false); got != "https://10.0.0.1:8006" { + t.Fatalf("non-SSL IP preference = %q, want %q", got, "https://10.0.0.1:8006") + } + + endpoint.IPOverride = "192.168.1.10" + if got := clusterEndpointEffectiveURL(endpoint, false, false); got != "https://192.168.1.10:8006" { + t.Fatalf("override IP preference = %q, want %q", got, "https://192.168.1.10:8006") + } + + endpoint = config.ClusterEndpoint{} + if got := clusterEndpointEffectiveURL(endpoint, true, false); got != "" { + t.Fatalf("empty endpoint = %q, want empty", got) + } +} diff --git a/internal/monitoring/multi_tenant_monitor_additional_test.go b/internal/monitoring/multi_tenant_monitor_additional_test.go new file mode 100644 index 000000000..b5ffd6978 --- /dev/null +++ b/internal/monitoring/multi_tenant_monitor_additional_test.go @@ -0,0 +1,20 @@ +package monitoring + +import "testing" + +func TestMultiTenantMonitorRemoveTenant(t *testing.T) { + monitor := &Monitor{} + mtm := &MultiTenantMonitor{ + monitors: map[string]*Monitor{ + "org-1": monitor, + }, + } + + mtm.RemoveTenant("org-1") + if _, ok := mtm.monitors["org-1"]; ok { + t.Fatalf("expected org-1 to be removed") + } + + // Ensure removal of missing orgs is a no-op. 
+ mtm.RemoveTenant("missing") +} diff --git a/internal/notifications/notifications_additional_test.go b/internal/notifications/notifications_additional_test.go new file mode 100644 index 000000000..7b3680964 --- /dev/null +++ b/internal/notifications/notifications_additional_test.go @@ -0,0 +1,193 @@ +package notifications + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rcourtman/pulse-go-rewrite/internal/alerts" +) + +func TestSendResolvedAppriseCLI(t *testing.T) { + manager := NewNotificationManager("") + defer manager.Stop() + + var called bool + manager.appriseExec = func(ctx context.Context, path string, args []string) ([]byte, error) { + called = true + if path != "apprise" { + t.Fatalf("expected CLI path apprise, got %q", path) + } + if len(args) == 0 || args[len(args)-1] != "target-1" { + t.Fatalf("expected target to be passed, got %v", args) + } + if !containsArg(args, "-t") || !containsArg(args, "-b") { + t.Fatalf("expected title/body args, got %v", args) + } + return nil, nil + } + + config := AppriseConfig{ + Enabled: true, + Mode: AppriseModeCLI, + Targets: []string{"target-1"}, + CLIPath: "apprise", + TimeoutSeconds: 1, + } + + alertList := []*alerts.Alert{ + { + ID: "a1", + Type: "cpu", + Level: alerts.AlertLevelWarning, + ResourceID: "r1", + ResourceName: "db-1", + Message: "cpu high", + Value: 91, + Threshold: 80, + StartTime: time.Now().Add(-time.Minute), + }, + } + + if err := manager.sendResolvedApprise(config, alertList, time.Now()); err != nil { + t.Fatalf("sendResolvedApprise error: %v", err) + } + if !called { + t.Fatalf("expected apprise exec to be called") + } +} + +func TestSendGroupedWebhookGeneric(t *testing.T) { + var gotMethod string + var gotBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + body, _ := io.ReadAll(r.Body) + gotBody = body + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + manager := NewNotificationManager("") + defer manager.Stop() + manager.webhookClient = server.Client() + if err := manager.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil { + t.Fatalf("allowlist: %v", err) + } + + alertsList := []*alerts.Alert{ + { + ID: "a1", + Type: "cpu", + Level: alerts.AlertLevelCritical, + ResourceID: "r1", + ResourceName: "db-1", + Message: "cpu critical", + Value: 99, + Threshold: 90, + StartTime: time.Now().Add(-2 * time.Minute), + }, + { + ID: "a2", + Type: "mem", + Level: alerts.AlertLevelWarning, + ResourceID: "r2", + ResourceName: "cache-1", + Message: "memory high", + Value: 85, + Threshold: 80, + StartTime: time.Now().Add(-time.Minute), + }, + } + + webhook := WebhookConfig{ + Name: "generic", + URL: server.URL + "/hook", + Enabled: true, + } + + if err := manager.sendGroupedWebhook(webhook, alertsList); err != nil { + t.Fatalf("sendGroupedWebhook error: %v", err) + } + if gotMethod != http.MethodPost { + t.Fatalf("expected POST, got %s", gotMethod) + } + + var payload map[string]any + if err := json.Unmarshal(gotBody, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if grouped, ok := payload["grouped"].(bool); !ok || !grouped { + t.Fatalf("expected grouped payload, got %v", payload["grouped"]) + } + if count, ok := payload["count"].(float64); !ok || int(count) != len(alertsList) { + t.Fatalf("expected count %d, got %v", len(alertsList), payload["count"]) + } +} + +func TestSendResolvedWebhookHTTP(t *testing.T) { + var gotBody []byte + 
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + gotBody = body + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + manager := NewNotificationManager("") + defer manager.Stop() + manager.webhookClient = server.Client() + if err := manager.UpdateAllowedPrivateCIDRs("127.0.0.1/32"); err != nil { + t.Fatalf("allowlist: %v", err) + } + + alertList := []*alerts.Alert{ + { + ID: "a1", + Type: "disk", + Level: alerts.AlertLevelWarning, + ResourceID: "r1", + ResourceName: "storage-1", + Message: "disk high", + Value: 92, + Threshold: 90, + StartTime: time.Now().Add(-time.Minute), + }, + } + + webhook := WebhookConfig{ + Name: "resolved", + URL: server.URL + "/resolved", + Enabled: true, + } + + if err := manager.sendResolvedWebhook(webhook, alertList, time.Now()); err != nil { + t.Fatalf("sendResolvedWebhook error: %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(gotBody, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if payload["event"] != "resolved" { + t.Fatalf("expected event resolved, got %v", payload["event"]) + } + if payload["alertId"] != "a1" { + t.Fatalf("expected alertId a1, got %v", payload["alertId"]) + } +} + +func containsArg(args []string, value string) bool { + for _, arg := range args { + if strings.TrimSpace(arg) == value { + return true + } + } + return false +} diff --git a/internal/remoteconfig/client_additional_test.go b/internal/remoteconfig/client_additional_test.go new file mode 100644 index 000000000..93ca4455e --- /dev/null +++ b/internal/remoteconfig/client_additional_test.go @@ -0,0 +1,269 @@ +package remoteconfig + +import ( + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestClientFetchWithSignature(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(pub)) + + issuedAt := time.Now().UTC() + expiresAt := issuedAt.Add(5 * time.Minute) + commands := true + settings := map[string]interface{}{"interval": "1m"} + + payload := SignedConfigPayload{ + HostID: "agent-1", + IssuedAt: issuedAt, + ExpiresAt: expiresAt, + CommandsEnabled: &commands, + Settings: settings, + } + signature, err := SignConfigPayload(payload, priv) + if err != nil { + t.Fatalf("SignConfigPayload: %v", err) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/agents/host/agent-1/config" { + w.WriteHeader(http.StatusNotFound) + return + } + resp := Response{ + Success: true, + HostID: "agent-1", + } + resp.Config.CommandsEnabled = &commands + resp.Config.Settings = settings + resp.Config.IssuedAt = issuedAt + resp.Config.ExpiresAt = expiresAt + resp.Config.Signature = signature + _ = json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + client := New(Config{ + PulseURL: ts.URL, + APIToken: "token-123", + AgentID: "agent-1", + }) + + gotSettings, gotCommands, err := client.Fetch(context.Background()) + if err != nil { + t.Fatalf("Fetch error: %v", err) + } + if gotCommands == nil || *gotCommands != true { + t.Fatalf("expected commands enabled, got %v", gotCommands) + } + if gotSettings["interval"] != "1m" { + t.Fatalf("unexpected settings: %#v", gotSettings) + } +} + +func TestClientFetchSignatureFailures(t *testing.T) { + pub, priv, err := 
ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(pub)) + + settings := map[string]interface{}{"interval": "1m"} + makeResp := func(sig string, issued, expires time.Time) Response { + resp := Response{ + Success: true, + HostID: "agent-1", + } + resp.Config.Settings = settings + resp.Config.Signature = sig + resp.Config.IssuedAt = issued + resp.Config.ExpiresAt = expires + return resp + } + + issuedAt := time.Now().UTC() + expiresAt := issuedAt.Add(5 * time.Minute) + payload := SignedConfigPayload{ + HostID: "agent-1", + IssuedAt: issuedAt, + ExpiresAt: expiresAt, + Settings: settings, + } + signature, err := SignConfigPayload(payload, priv) + if err != nil { + t.Fatalf("SignConfigPayload: %v", err) + } + + tests := []struct { + name string + resp Response + wantText string + }{ + {name: "missing timestamps", resp: makeResp(signature, time.Time{}, time.Time{}), wantText: "missing timestamp"}, + {name: "expired", resp: makeResp(signature, issuedAt.Add(-10*time.Minute), issuedAt.Add(-5*time.Minute)), wantText: "expired"}, + {name: "future", resp: makeResp(signature, issuedAt.Add(10*time.Minute), issuedAt.Add(20*time.Minute)), wantText: "future"}, + {name: "invalid signature", resp: makeResp("nope", issuedAt, expiresAt), wantText: "verification failed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(tt.resp) + })) + defer ts.Close() + + client := New(Config{PulseURL: ts.URL, APIToken: "t", AgentID: "agent-1"}) + _, _, err := client.Fetch(context.Background()) + if err == nil || !strings.Contains(err.Error(), tt.wantText) { + t.Fatalf("expected error containing %q, got %v", tt.wantText, err) + } + }) + } +} + +func TestClientFetchHostLookupAndErrors(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/agents/host/lookup": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"success":true,"host":{"id":"host-9"}}`)) + case "/api/agents/host/host-9/config": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"success":true,"hostId":"host-9","config":{"settings":{"mode":"ok"}}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + client := New(Config{ + PulseURL: ts.URL, + APIToken: "t", + AgentID: "agent-1", + Hostname: "known", + }) + settings, _, err := client.Fetch(context.Background()) + if err != nil { + t.Fatalf("Fetch error: %v", err) + } + if settings["mode"] != "ok" { + t.Fatalf("unexpected settings: %#v", settings) + } + + redirect := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/other", http.StatusFound) + })) + defer redirect.Close() + + client = New(Config{PulseURL: redirect.URL, APIToken: "t", AgentID: "agent-1"}) + if _, _, err := client.Fetch(context.Background()); err == nil { + t.Fatalf("expected redirect error") + } + + badJSON := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("{bad")) + })) + defer badJSON.Close() + + client = New(Config{PulseURL: badJSON.URL, APIToken: "t", AgentID: "agent-1"}) + if _, _, err := client.Fetch(context.Background()); err == nil { + t.Fatalf("expected decode error") + } +} + +func 
TestClientNewDefaultsAndHostLookupNotFound(t *testing.T) { + client := New(Config{InsecureSkipVerify: true}) + if client.cfg.PulseURL != "http://localhost:7655" { + t.Fatalf("unexpected default PulseURL: %s", client.cfg.PulseURL) + } + + transport, ok := client.httpClient.Transport.(*http.Transport) + if !ok || transport.TLSClientConfig == nil || !transport.TLSClientConfig.InsecureSkipVerify { + t.Fatalf("expected insecure TLS config") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + client = New(Config{ + PulseURL: ts.URL, + APIToken: "t", + Hostname: "missing", + }) + if got, err := client.resolveHostID(context.Background()); err != nil || got != "" { + t.Fatalf("expected empty host ID, got %q err=%v", got, err) + } +} + +func TestClientFetchResolveHostIDError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/lookup") { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + client := New(Config{ + PulseURL: ts.URL, + APIToken: "t", + AgentID: "agent-1", + Hostname: "known", + }) + if _, _, err := client.Fetch(context.Background()); err == nil { + t.Fatalf("expected resolve host error") + } +} + +type errorRoundTripper struct{} + +func (errorRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, context.Canceled +} + +func TestClientFetchInvalidURL(t *testing.T) { + client := New(Config{ + PulseURL: "http://bad url", + APIToken: "t", + AgentID: "agent-1", + }) + if _, _, err := client.Fetch(context.Background()); err == nil || !strings.Contains(err.Error(), "create request") { + t.Fatalf("expected create request error, got %v", err) + } +} + +func TestClientResolveHostIDRequestErrors(t *testing.T) { + client := New(Config{ + PulseURL: "http://bad url", + APIToken: "t", + Hostname: "host", + }) + if _, err := client.resolveHostID(context.Background()); err == nil || !strings.Contains(err.Error(), "create host lookup request") { + t.Fatalf("expected request error, got %v", err) + } + + client = New(Config{ + PulseURL: "http://example.com", + APIToken: "t", + Hostname: "host", + }) + client.httpClient = &http.Client{Transport: errorRoundTripper{}} + if _, err := client.resolveHostID(context.Background()); err == nil || !strings.Contains(err.Error(), "host lookup request") { + t.Fatalf("expected transport error, got %v", err) + } +} diff --git a/internal/remoteconfig/signature_additional_test.go b/internal/remoteconfig/signature_additional_test.go new file mode 100644 index 000000000..216eb367c --- /dev/null +++ b/internal/remoteconfig/signature_additional_test.go @@ -0,0 +1,223 @@ +package remoteconfig + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "strings" + "testing" + "time" +) + +func TestDecodeEd25519PrivateKeyInvalidLength(t *testing.T) { + raw := []byte("short") + encoded := base64.StdEncoding.EncodeToString(raw) + if _, err := DecodeEd25519PrivateKey(encoded); err == nil { + t.Fatalf("expected invalid length error") + } +} + +func TestSignConfigPayloadMissingKey(t *testing.T) { + if _, err := SignConfigPayload(SignedConfigPayload{}, nil); err == nil { + t.Fatalf("expected missing key error") + } +} + +func TestSignConfigPayloadInvalidSettings(t *testing.T) { + _, priv, err := ed25519.GenerateKey(nil) + if err != nil 
{ + t.Fatalf("GenerateKey: %v", err) + } + + payload := SignedConfigPayload{ + HostID: "host-1", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + Settings: map[string]interface{}{"bad": func() {}}, + } + if _, err := SignConfigPayload(payload, priv); err == nil { + t.Fatalf("expected settings error") + } +} + +func TestVerifyConfigPayloadSignatureInvalidBase64(t *testing.T) { + payload := SignedConfigPayload{ + HostID: "host", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + } + if err := VerifyConfigPayloadSignature(payload, "not-base64"); err == nil { + t.Fatalf("expected base64 error") + } +} + +func TestVerifyConfigPayloadSignatureMissing(t *testing.T) { + payload := SignedConfigPayload{ + HostID: "host", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + } + if err := VerifyConfigPayloadSignature(payload, ""); err == nil { + t.Fatalf("expected missing signature error") + } +} + +func TestTrustedConfigPublicKeysErrors(t *testing.T) { + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", "") + if keys, err := trustedConfigPublicKeys(); err != nil || len(keys) == 0 { + t.Fatalf("expected default keys, got %d err=%v", len(keys), err) + } + + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", ",") + if _, err := trustedConfigPublicKeys(); err == nil { + t.Fatalf("expected no trusted keys error") + } + + block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("nope")}) + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(block)) + if _, err := trustedConfigPublicKeys(); err == nil { + t.Fatalf("expected no trusted keys error for wrong PEM type") + } + + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + rsaPub, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) + if err != nil { + t.Fatalf("MarshalPKIXPublicKey: %v", err) + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(rsaPub)) + if _, err := trustedConfigPublicKeys(); err == nil || !strings.Contains(err.Error(), "Ed25519") { + t.Fatalf("expected ed25519 error, got %v", err) + } + + garbage := base64.StdEncoding.EncodeToString([]byte("garbage")) + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", garbage) + if _, err := trustedConfigPublicKeys(); err == nil { + t.Fatalf("expected parse error for garbage") + } + + pub, _, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + raw, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + t.Fatalf("MarshalPKIXPublicKey: %v", err) + } + block = pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: raw}) + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(block)) + if keys, err := trustedConfigPublicKeys(); err != nil || len(keys) != 1 { + t.Fatalf("expected 1 key, got %d err=%v", len(keys), err) + } + + badBlock := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: []byte("bad")}) + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(badBlock)) + if _, err := trustedConfigPublicKeys(); err == nil { + t.Fatalf("expected parse error for invalid pem") + } + + rsaBlock := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: rsaPub}) + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(rsaBlock)) + if _, err := trustedConfigPublicKeys(); err == nil || !strings.Contains(err.Error(), "Ed25519") { + t.Fatalf("expected pem ed25519 error, got %v", err) + } + + multi := append(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("nope")}), block...) 
+ t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(multi)) + if keys, err := trustedConfigPublicKeys(); err != nil || len(keys) != 1 { + t.Fatalf("expected 1 key from multi pem, got %d err=%v", len(keys), err) + } +} + +func TestVerifyConfigPayloadSignatureFailure(t *testing.T) { + pub, _, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(pub)) + + payload := SignedConfigPayload{ + HostID: "host-1", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + Settings: map[string]interface{}{"key": "value"}, + } + if err := VerifyConfigPayloadSignature(payload, base64.StdEncoding.EncodeToString([]byte("bad"))); err == nil { + t.Fatalf("expected verification failure") + } +} + +func TestCanonicalConfigPayloadEmptySettings(t *testing.T) { + payload := SignedConfigPayload{ + HostID: "host-1", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + } + data, err := canonicalConfigPayload(payload) + if err != nil { + t.Fatalf("canonicalConfigPayload error: %v", err) + } + if !strings.Contains(string(data), `"hostId":"host-1"`) { + t.Fatalf("unexpected payload: %s", string(data)) + } +} + +func TestMarshalSortedMapEmptyAndInvalid(t *testing.T) { + if data, err := marshalSortedMap(map[string]interface{}{}); err != nil || data != nil { + t.Fatalf("expected nil for empty map") + } + + if _, err := marshalSortedMap(map[string]interface{}{"bad": func() {}}); err == nil { + t.Fatalf("expected marshal error") + } +} + +func TestVerifyConfigPayloadSignatureCanonicalError(t *testing.T) { + payload := SignedConfigPayload{ + HostID: "host", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + Settings: map[string]interface{}{"bad": func() {}}, + } + if err := VerifyConfigPayloadSignature(payload, base64.StdEncoding.EncodeToString([]byte("sig"))); err == nil { + t.Fatalf("expected canonical error") + } +} + +func TestMarshalCanonicalValueSliceError(t *testing.T) { + if _, err := marshalCanonicalValue([]interface{}{func() {}}); err == nil { + t.Fatalf("expected slice marshal error") + } +} + +func TestVerifyConfigPayloadSignatureTrustedKeysError(t *testing.T) { + payload := SignedConfigPayload{ + HostID: "host", + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Minute), + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", "not-base64") + if err := VerifyConfigPayloadSignature(payload, base64.StdEncoding.EncodeToString([]byte("sig"))); err == nil { + t.Fatalf("expected trusted key error") + } +} + +func TestTrustedConfigPublicKeysPKIXEd25519(t *testing.T) { + pub, _, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + pkix, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + t.Fatalf("MarshalPKIXPublicKey: %v", err) + } + t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(pkix)) + if keys, err := trustedConfigPublicKeys(); err != nil || len(keys) != 1 { + t.Fatalf("expected 1 pkix key, got %d err=%v", len(keys), err) + } +} diff --git a/internal/sensors/parser_additional_test.go b/internal/sensors/parser_additional_test.go new file mode 100644 index 000000000..a351ac633 --- /dev/null +++ b/internal/sensors/parser_additional_test.go @@ -0,0 +1,102 @@ +package sensors + +import ( + "math" + "testing" +) + +func TestNormalizeSensorKey(t *testing.T) { + tests := []struct { + chipName string + sensorName string + expected string + }{ + 
{"nct6687-isa-0a20", "CPU Fan", "nct6687_cpu_fan"}, + {"amdgpu-pci-0400", "edge-temp", "amdgpu_edge_temp"}, + {"nvme-pci-0100", "Composite", "nvme_composite"}, + } + + for _, tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + got := normalizeSensorKey(tc.chipName, tc.sensorName) + if got != tc.expected { + t.Fatalf("normalizeSensorKey(%q, %q) = %q, want %q", tc.chipName, tc.sensorName, got, tc.expected) + } + }) + } +} + +func TestExtractNumericValue(t *testing.T) { + tests := []struct { + name string + value interface{} + expected float64 + }{ + {"float64", 12.5, 12.5}, + {"int", 42, 42}, + {"int64", int64(99), 99}, + {"string", "nope", 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extractNumericValue(tc.value) + if got != tc.expected { + t.Fatalf("extractNumericValue(%v) = %v, want %v", tc.value, got, tc.expected) + } + }) + } +} + +func TestExtractNVMeCompositeTemp(t *testing.T) { + t.Run("found", func(t *testing.T) { + chipMap := map[string]interface{}{ + "Composite": map[string]interface{}{ + "temp1_input": 42.0, + }, + "Other": map[string]interface{}{ + "temp1_input": 30.0, + }, + } + + temp, ok := extractNVMeCompositeTemp(chipMap) + if !ok { + t.Fatalf("expected composite temp to be found") + } + if temp != 42.0 { + t.Fatalf("temp = %v, want 42.0", temp) + } + }) + + t.Run("missing", func(t *testing.T) { + chipMap := map[string]interface{}{ + "Temperature 2": map[string]interface{}{ + "temp1_input": 30.0, + }, + } + + temp, ok := extractNVMeCompositeTemp(chipMap) + if ok { + t.Fatalf("expected composite temp to be missing, got %v", temp) + } + if temp != 0 { + t.Fatalf("temp = %v, want 0", temp) + } + }) + + t.Run("invalid", func(t *testing.T) { + chipMap := map[string]interface{}{ + "Composite": map[string]interface{}{ + "temp1_input": -1.0, + }, + } + + temp, ok := extractNVMeCompositeTemp(chipMap) + if ok { + t.Fatalf("expected composite temp to be rejected, got %v", temp) + } + if !math.IsNaN(temp) && temp != 0 { + t.Fatalf("temp = %v, want 0", temp) + } + }) +} diff --git a/internal/sensors/power.go b/internal/sensors/power.go index 876b2c5b3..826af456d 100644 --- a/internal/sensors/power.go +++ b/internal/sensors/power.go @@ -34,7 +34,7 @@ type PowerData struct { // raplBasePath is the base path for Intel RAPL (Running Average Power Limit) readings. // RAPL provides energy counters that we sample to calculate power. -const raplBasePath = "/sys/class/powercap/intel-rapl" +var raplBasePath = "/sys/class/powercap/intel-rapl" // sampleInterval is the time between energy counter readings. // Shorter intervals are less accurate; longer intervals add latency. @@ -204,7 +204,7 @@ func readStringFile(path string) (string, error) { } // hwmonBasePath is the base path for hwmon devices (used by AMD energy driver). -const hwmonBasePath = "/sys/class/hwmon" +var hwmonBasePath = "/sys/class/hwmon" // collectAMDEnergy reads power data from AMD energy driver via hwmon. // The amd_energy module exposes energy counters similar to Intel RAPL. 
diff --git a/internal/sensors/power_additional_test.go b/internal/sensors/power_additional_test.go new file mode 100644 index 000000000..9941cb3d8 --- /dev/null +++ b/internal/sensors/power_additional_test.go @@ -0,0 +1,69 @@ +package sensors + +import ( + "os" + "path/filepath" + "testing" +) + +func TestReadAMDEnergy(t *testing.T) { + tmpDir := t.TempDir() + + energy1 := filepath.Join(tmpDir, "energy1_input") + if err := os.WriteFile(energy1, []byte("1000"), 0644); err != nil { + t.Fatalf("write energy1_input: %v", err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "energy1_label"), []byte("package"), 0644); err != nil { + t.Fatalf("write energy1_label: %v", err) + } + + energy2 := filepath.Join(tmpDir, "energy2_input") + if err := os.WriteFile(energy2, []byte("2000"), 0644); err != nil { + t.Fatalf("write energy2_input: %v", err) + } + + energy3 := filepath.Join(tmpDir, "energy3_input") + if err := os.WriteFile(energy3, []byte("not-a-number"), 0644); err != nil { + t.Fatalf("write energy3_input: %v", err) + } + + result, err := readAMDEnergy(tmpDir) + if err != nil { + t.Fatalf("readAMDEnergy returned error: %v", err) + } + + if len(result) != 2 { + t.Fatalf("result = %#v, want 2 readings", result) + } + if result["package"] != 1000 { + t.Fatalf("package = %d, want 1000", result["package"]) + } + if result["energy2_input"] != 2000 { + t.Fatalf("energy2_input = %d, want 2000", result["energy2_input"]) + } + if _, ok := result["energy3_input"]; ok { + t.Fatalf("energy3_input should be skipped due to parse error") + } +} + +func TestReadAMDEnergy_NoFiles(t *testing.T) { + tmpDir := t.TempDir() + + _, err := readAMDEnergy(tmpDir) + if err == nil { + t.Fatalf("expected error for missing energy files") + } +} + +func TestReadAMDEnergy_NoReadings(t *testing.T) { + tmpDir := t.TempDir() + energy1 := filepath.Join(tmpDir, "energy1_input") + if err := os.WriteFile(energy1, []byte("invalid"), 0644); err != nil { + t.Fatalf("write energy1_input: %v", err) + } + + _, err := readAMDEnergy(tmpDir) + if err == nil { + t.Fatalf("expected error for unreadable energy values") + } +} diff --git a/internal/sensors/power_sysfs_test.go b/internal/sensors/power_sysfs_test.go new file mode 100644 index 000000000..182ee7dd4 --- /dev/null +++ b/internal/sensors/power_sysfs_test.go @@ -0,0 +1,200 @@ +package sensors + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" +) + +func writeEnergyFile(path string, value string) error { + return os.WriteFile(path, []byte(value), 0644) +} + +func TestCollectRALP_MockSysfs(t *testing.T) { + tmpDir := t.TempDir() + original := raplBasePath + raplBasePath = tmpDir + t.Cleanup(func() { + raplBasePath = original + }) + + pkg0 := filepath.Join(tmpDir, "intel-rapl:0") + if err := os.MkdirAll(pkg0, 0755); err != nil { + t.Fatalf("mkdir pkg0: %v", err) + } + if err := writeEnergyFile(filepath.Join(pkg0, "energy_uj"), "1000000"); err != nil { + t.Fatalf("write package energy: %v", err) + } + if err := os.WriteFile(filepath.Join(pkg0, "name"), []byte("package-0"), 0644); err != nil { + t.Fatalf("write name: %v", err) + } + + core := filepath.Join(pkg0, "intel-rapl:0:0") + if err := os.MkdirAll(core, 0755); err != nil { + t.Fatalf("mkdir core: %v", err) + } + if err := writeEnergyFile(filepath.Join(core, "energy_uj"), "200000"); err != nil { + t.Fatalf("write core energy: %v", err) + } + if err := os.WriteFile(filepath.Join(core, "name"), []byte("core"), 0644); err != nil { + t.Fatalf("write core name: %v", err) + } + + dram := filepath.Join(pkg0, 
"intel-rapl:0:1") + if err := os.MkdirAll(dram, 0755); err != nil { + t.Fatalf("mkdir dram: %v", err) + } + if err := writeEnergyFile(filepath.Join(dram, "energy_uj"), "300000"); err != nil { + t.Fatalf("write dram energy: %v", err) + } + if err := os.WriteFile(filepath.Join(dram, "name"), []byte("dram"), 0644); err != nil { + t.Fatalf("write dram name: %v", err) + } + + errCh := make(chan error, 1) + go func() { + time.Sleep(20 * time.Millisecond) + if err := writeEnergyFile(filepath.Join(pkg0, "energy_uj"), "2000000"); err != nil { + errCh <- err + return + } + if err := writeEnergyFile(filepath.Join(core, "energy_uj"), "400000"); err != nil { + errCh <- err + return + } + if err := writeEnergyFile(filepath.Join(dram, "energy_uj"), "600000"); err != nil { + errCh <- err + return + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + data, err := collectRALP(ctx) + if err != nil { + t.Fatalf("collectRALP error: %v", err) + } + select { + case err := <-errCh: + t.Fatalf("update energy files: %v", err) + default: + } + if !data.Available { + t.Fatalf("expected data.Available true") + } + if data.Source != "rapl" { + t.Fatalf("expected source rapl, got %q", data.Source) + } + if data.PackageWatts <= 0 || data.CoreWatts <= 0 || data.DRAMWatts <= 0 { + t.Fatalf("expected non-zero watts, got %+v", data) + } +} + +func TestCollectAMDEnergy_MockSysfs(t *testing.T) { + tmpDir := t.TempDir() + original := hwmonBasePath + hwmonBasePath = tmpDir + t.Cleanup(func() { + hwmonBasePath = original + }) + + hwmon := filepath.Join(tmpDir, "hwmon0") + if err := os.MkdirAll(hwmon, 0755); err != nil { + t.Fatalf("mkdir hwmon: %v", err) + } + if err := os.WriteFile(filepath.Join(hwmon, "name"), []byte("amd_energy"), 0644); err != nil { + t.Fatalf("write hwmon name: %v", err) + } + if err := writeEnergyFile(filepath.Join(hwmon, "energy1_input"), "1000000"); err != nil { + t.Fatalf("write energy input: %v", err) + } + if err := os.WriteFile(filepath.Join(hwmon, "energy1_label"), []byte("socket"), 0644); err != nil { + t.Fatalf("write energy label: %v", err) + } + + errCh := make(chan error, 1) + go func() { + time.Sleep(20 * time.Millisecond) + if err := writeEnergyFile(filepath.Join(hwmon, "energy1_input"), "2000000"); err != nil { + errCh <- err + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + data, err := collectAMDEnergy(ctx) + if err != nil { + t.Fatalf("collectAMDEnergy error: %v", err) + } + select { + case err := <-errCh: + t.Fatalf("update energy input: %v", err) + default: + } + if !data.Available { + t.Fatalf("expected data.Available true") + } + if data.Source != "amd_energy" { + t.Fatalf("expected source amd_energy, got %q", data.Source) + } + if data.PackageWatts <= 0 { + t.Fatalf("expected package watts > 0, got %+v", data) + } +} + +func TestCollectPower_FallbackToAMD(t *testing.T) { + originalRAPL := raplBasePath + originalHwmon := hwmonBasePath + tmpDir := t.TempDir() + raplBasePath = filepath.Join(tmpDir, "missing-rapl") + hwmonBasePath = filepath.Join(tmpDir, "hwmon") + t.Cleanup(func() { + raplBasePath = originalRAPL + hwmonBasePath = originalHwmon + }) + + if err := os.MkdirAll(hwmonBasePath, 0755); err != nil { + t.Fatalf("mkdir hwmon base: %v", err) + } + hwmon := filepath.Join(hwmonBasePath, "hwmon0") + if err := os.MkdirAll(hwmon, 0755); err != nil { + t.Fatalf("mkdir hwmon: %v", err) + } + if err := os.WriteFile(filepath.Join(hwmon, "name"), []byte("amd_energy"), 0644); err != 
nil { + t.Fatalf("write hwmon name: %v", err) + } + if err := writeEnergyFile(filepath.Join(hwmon, "energy1_input"), "1000000"); err != nil { + t.Fatalf("write energy input: %v", err) + } + if err := os.WriteFile(filepath.Join(hwmon, "energy1_label"), []byte("package"), 0644); err != nil { + t.Fatalf("write energy label: %v", err) + } + + errCh := make(chan error, 1) + go func() { + time.Sleep(20 * time.Millisecond) + if err := writeEnergyFile(filepath.Join(hwmon, "energy1_input"), "2000000"); err != nil { + errCh <- err + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + data, err := CollectPower(ctx) + if err != nil { + t.Fatalf("CollectPower error: %v", err) + } + select { + case err := <-errCh: + t.Fatalf("update energy input: %v", err) + default: + } + if data.Source != "amd_energy" { + t.Fatalf("expected amd_energy fallback, got %q", data.Source) + } +} diff --git a/internal/updates/adapter_installsh_extra_test.go b/internal/updates/adapter_installsh_extra_test.go new file mode 100644 index 000000000..2dcb38584 --- /dev/null +++ b/internal/updates/adapter_installsh_extra_test.go @@ -0,0 +1,259 @@ +package updates + +import ( + "context" + "crypto/sha256" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +func writeCurlStub(t *testing.T, dir string) { + t.Helper() + + cpPath, err := exec.LookPath("cp") + if err != nil { + t.Fatalf("find cp: %v", err) + } + + script := fmt.Sprintf(`#!/bin/sh +set -e +out="" +url="" +while [ "$#" -gt 0 ]; do + case "$1" in + -o) out="$2"; shift 2;; + *) url="$1"; shift;; + esac +done +if [ -z "$out" ]; then + exit 1 +fi +if echo "$url" | grep -q "\.sha256$"; then + exec %s "$PULSE_TEST_SHA" "$out" +fi +exec %s "$PULSE_TEST_FILE" "$out" +`, cpPath, cpPath) + + writeStub(t, dir, "curl", script) +} + +func TestInstallShAdapterDetectServiceName(t *testing.T) { + stubDir := t.TempDir() + writeStub(t, stubDir, "systemctl", `#!/bin/sh +if [ "$1" = "is-active" ] && [ "$2" = "pulse-backend" ]; then + echo "active" + exit 0 +fi +exit 1 +`) + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + + adapter := &InstallShAdapter{} + name, err := adapter.detectServiceName() + if err != nil { + t.Fatalf("detectServiceName error: %v", err) + } + if name != "pulse-backend" { + t.Fatalf("expected pulse-backend, got %s", name) + } + + writeStub(t, stubDir, "systemctl", `#!/bin/sh +echo "inactive" +exit 0 +`) + name, err = adapter.detectServiceName() + if err != nil { + t.Fatalf("detectServiceName error: %v", err) + } + if name != "pulse" { + t.Fatalf("expected default pulse, got %s", name) + } +} + +func TestInstallShAdapterDownloadInstallScript(t *testing.T) { + content := []byte("#!/bin/sh\necho ok\n") + sum := sha256.Sum256(content) + + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "install.sh") + if err := os.WriteFile(scriptPath, content, 0600); err != nil { + t.Fatalf("write install.sh: %v", err) + } + checksumPath := filepath.Join(tmpDir, "install.sh.sha256") + if err := os.WriteFile(checksumPath, []byte(fmt.Sprintf("%x install.sh\n", sum)), 0600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + stubDir := t.TempDir() + writeCurlStub(t, stubDir) + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("PULSE_TEST_FILE", scriptPath) + t.Setenv("PULSE_TEST_SHA", checksumPath) + + adapter := NewInstallShAdapter(nil) + got, err := adapter.downloadInstallScript(context.Background()) + if err != nil { + 
t.Fatalf("downloadInstallScript error: %v", err) + } + if got != string(content) { + t.Fatalf("unexpected script content: %q", got) + } + + if err := os.WriteFile(checksumPath, []byte("deadbeef"), 0600); err != nil { + t.Fatalf("write checksum: %v", err) + } + if _, err := adapter.downloadInstallScript(context.Background()); err == nil { + t.Fatalf("expected checksum error") + } +} + +func TestInstallShAdapterDownloadBinary(t *testing.T) { + extractDir := t.TempDir() + tarball := filepath.Join(extractDir, "pulse-v1.2.3-linux-amd64.tar.gz") + writeTarGz(t, tarball, map[string]string{ + "bin/pulse": "binary", + }) + + data, err := os.ReadFile(tarball) + if err != nil { + t.Fatalf("read tarball: %v", err) + } + sum := sha256.Sum256(data) + checksumPath := tarball + ".sha256" + if err := os.WriteFile(checksumPath, []byte(fmt.Sprintf("%x %s\n", sum, filepath.Base(tarball))), 0600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + stubDir := t.TempDir() + writeCurlStub(t, stubDir) + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("PULSE_TEST_FILE", tarball) + t.Setenv("PULSE_TEST_SHA", checksumPath) + + adapter := &InstallShAdapter{} + binaryPath, err := adapter.downloadBinary(context.Background(), "1.2.3") + if err != nil { + t.Fatalf("downloadBinary error: %v", err) + } + got, err := os.ReadFile(binaryPath) + if err != nil { + t.Fatalf("read binary: %v", err) + } + if string(got) != "binary" { + t.Fatalf("unexpected binary content: %q", string(got)) + } + + if err := os.WriteFile(checksumPath, []byte("deadbeef"), 0600); err != nil { + t.Fatalf("write checksum: %v", err) + } + if _, err := adapter.downloadBinary(context.Background(), "1.2.3"); err == nil { + t.Fatalf("expected checksum mismatch error") + } +} + +func TestInstallShAdapterReadLastLines(t *testing.T) { + adapter := &InstallShAdapter{} + if adapter.readLastLines(filepath.Join(t.TempDir(), "missing"), 2) != "" { + t.Fatalf("expected empty for missing file") + } + + path := filepath.Join(t.TempDir(), "log.txt") + if err := os.WriteFile(path, []byte("a\nb\nc\n"), 0600); err != nil { + t.Fatalf("write log: %v", err) + } + got := adapter.readLastLines(path, 2) + if got != "b\nc" { + t.Fatalf("unexpected lines: %q", got) + } + if adapter.readLastLines(path, 0) != "" { + t.Fatalf("expected empty for zero lines") + } +} + +func TestInstallShAdapterRestoreConfigAndInstallBinary(t *testing.T) { + cpPath, err := exec.LookPath("cp") + if err != nil { + t.Fatalf("find cp: %v", err) + } + stubDir := t.TempDir() + writeStub(t, stubDir, "cp", fmt.Sprintf("#!/bin/sh\nexec %s \"$@\"\n", cpPath)) + writeStub(t, stubDir, "chown", "#!/bin/sh\nexit 0\n") + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + + adapter := &InstallShAdapter{} + + srcDir := filepath.Join(t.TempDir(), "backup") + destDir := filepath.Join(t.TempDir(), "config") + if err := os.MkdirAll(srcDir, 0755); err != nil { + t.Fatalf("mkdir src: %v", err) + } + if err := os.WriteFile(filepath.Join(srcDir, "app.conf"), []byte("ok"), 0600); err != nil { + t.Fatalf("write src file: %v", err) + } + if err := os.MkdirAll(destDir, 0755); err != nil { + t.Fatalf("mkdir dest: %v", err) + } + + if err := adapter.restoreConfig(context.Background(), srcDir, destDir); err != nil { + t.Fatalf("restoreConfig error: %v", err) + } + if _, err := os.Stat(filepath.Join(destDir, "app.conf")); err != nil { + t.Fatalf("expected config restored: %v", err) + } + + srcBinary := filepath.Join(t.TempDir(), "pulse") + if err := 
os.WriteFile(srcBinary, []byte("bin"), 0755); err != nil { + t.Fatalf("write source binary: %v", err) + } + targetBinary := filepath.Join(t.TempDir(), "pulse") + if err := os.WriteFile(targetBinary, []byte("old"), 0755); err != nil { + t.Fatalf("write target binary: %v", err) + } + + if err := adapter.installBinary(context.Background(), srcBinary, targetBinary); err != nil { + t.Fatalf("installBinary error: %v", err) + } + if _, err := os.Stat(targetBinary + ".pre-rollback"); err != nil { + t.Fatalf("expected backup binary: %v", err) + } +} + +func TestInstallShAdapterWaitForHealth(t *testing.T) { + stubDir := t.TempDir() + writeStub(t, stubDir, "curl", "#!/bin/sh\nexit 0\n") + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + + adapter := &InstallShAdapter{} + if err := adapter.waitForHealth(context.Background(), time.Second); err != nil { + t.Fatalf("waitForHealth success error: %v", err) + } + + writeStub(t, stubDir, "curl", "#!/bin/sh\nexit 1\n") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := adapter.waitForHealth(ctx, time.Second); err == nil || !strings.Contains(err.Error(), "canceled") { + t.Fatalf("expected context error, got %v", err) + } + + if err := adapter.waitForHealth(context.Background(), 0); err == nil { + t.Fatalf("expected timeout error") + } +} + +func TestInstallShAdapterExecuteRollbackDownloadError(t *testing.T) { + stubDir := t.TempDir() + writeStub(t, stubDir, "systemctl", "#!/bin/sh\nexit 0\n") + writeStub(t, stubDir, "curl", "#!/bin/sh\nexit 1\n") + t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH")) + + adapter := &InstallShAdapter{} + entry := &UpdateHistoryEntry{BackupPath: t.TempDir()} + if err := adapter.executeRollback(context.Background(), entry, "1.2.3"); err == nil { + t.Fatalf("expected download error") + } +} diff --git a/internal/updates/manager_additional_test.go b/internal/updates/manager_additional_test.go new file mode 100644 index 000000000..2397cf8ca --- /dev/null +++ b/internal/updates/manager_additional_test.go @@ -0,0 +1,41 @@ +package updates + +import ( + "errors" + "strings" + "testing" + + "github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func TestGetCachedUpdateInfo_WithChannel(t *testing.T) { + info := &UpdateInfo{LatestVersion: "9.9.9"} + manager := &Manager{ + config: &config.Config{UpdateChannel: "stable"}, + checkCache: map[string]*UpdateInfo{"stable": info}, + } + + got := manager.GetCachedUpdateInfo() + if got != info { + t.Fatalf("expected cached info, got %+v", got) + } +} + +func TestSanitizeErrorTruncation(t *testing.T) { + if got := sanitizeError(nil); got != "" { + t.Fatalf("expected empty string for nil error, got %q", got) + } + + if got := sanitizeError(errors.New("short")); got != "short" { + t.Fatalf("expected short error, got %q", got) + } + + long := strings.Repeat("x", 600) + got := sanitizeError(errors.New(long)) + if len(got) <= 500 { + t.Fatalf("expected truncated error > 500 chars, got len=%d", len(got)) + } + if !strings.HasSuffix(got, "...") { + t.Fatalf("expected truncated error suffix ..., got %q", got[len(got)-3:]) + } +} diff --git a/internal/updates/manager_check_updates_test.go b/internal/updates/manager_check_updates_test.go new file mode 100644 index 000000000..8b1ee3b48 --- /dev/null +++ b/internal/updates/manager_check_updates_test.go @@ -0,0 +1,177 @@ +package updates + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "sync/atomic" + "testing" + "time" + + 
"github.com/rcourtman/pulse-go-rewrite/internal/config" +) + +func newReleaseServer(t *testing.T, releases []ReleaseInfo, hitCount *int32) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/rcourtman/Pulse/releases" { + w.WriteHeader(http.StatusNotFound) + return + } + if hitCount != nil { + atomic.AddInt32(hitCount, 1) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(releases) + })) +} + +func TestCheckForUpdatesWithChannel_SourceBuild(t *testing.T) { + markerPath := "BUILD_FROM_SOURCE" + if err := os.WriteFile(markerPath, []byte("1"), 0644); err != nil { + t.Fatalf("write %s: %v", markerPath, err) + } + t.Cleanup(func() { + _ = os.Remove(markerPath) + }) + + manager := NewManager(&config.Config{UpdateChannel: "stable"}) + + info, err := manager.CheckForUpdatesWithChannel(context.Background(), "") + if err != nil { + t.Fatalf("CheckForUpdatesWithChannel returned error: %v", err) + } + if info.Available { + t.Fatalf("expected no updates for source build, got available") + } + if info.LatestVersion != info.CurrentVersion { + t.Fatalf("LatestVersion = %q, want %q", info.LatestVersion, info.CurrentVersion) + } +} + +func TestCheckForUpdatesWithChannel_AvailableUsesCache(t *testing.T) { + var hits int32 + releaseTime := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + releases := []ReleaseInfo{ + { + TagName: "v99.0.0", + Name: "v99.0.0", + Body: "Release notes", + Prerelease: false, + PublishedAt: releaseTime, + Assets: []struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` + }{ + { + Name: "pulse-v99.0.0-linux-amd64.tar.gz", + BrowserDownloadURL: "https://example.com/pulse-v99.0.0-linux-amd64.tar.gz", + }, + }, + }, + } + + server := newReleaseServer(t, releases, &hits) + defer server.Close() + + t.Setenv("PULSE_UPDATE_SERVER", server.URL) + + manager := NewManager(&config.Config{UpdateChannel: "stable"}) + + info, err := manager.CheckForUpdatesWithChannel(context.Background(), "") + if err != nil { + t.Fatalf("CheckForUpdatesWithChannel returned error: %v", err) + } + if !info.Available { + t.Fatalf("expected update to be available") + } + if info.LatestVersion != "99.0.0" { + t.Fatalf("LatestVersion = %q, want 99.0.0", info.LatestVersion) + } + if info.DownloadURL == "" { + t.Fatalf("DownloadURL not set") + } + + info2, err := manager.CheckForUpdatesWithChannel(context.Background(), "") + if err != nil { + t.Fatalf("CheckForUpdatesWithChannel second call error: %v", err) + } + if info2.LatestVersion != info.LatestVersion { + t.Fatalf("cached LatestVersion = %q, want %q", info2.LatestVersion, info.LatestVersion) + } + if got := atomic.LoadInt32(&hits); got != 1 { + t.Fatalf("expected 1 request, got %d", got) + } +} + +func TestCheckForUpdatesWithChannel_NoReleases(t *testing.T) { + var hits int32 + releases := []ReleaseInfo{ + { + TagName: "v99.0.0-rc.1", + Name: "v99.0.0-rc.1", + Prerelease: true, + PublishedAt: time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), + }, + } + + server := newReleaseServer(t, releases, &hits) + defer server.Close() + + t.Setenv("PULSE_UPDATE_SERVER", server.URL) + + manager := NewManager(&config.Config{UpdateChannel: "stable"}) + + info, err := manager.CheckForUpdatesWithChannel(context.Background(), "") + if err != nil { + t.Fatalf("CheckForUpdatesWithChannel returned error: %v", err) + } + if info.Available { + t.Fatalf("expected no updates for stable channel with only prereleases") + } + 
if got := atomic.LoadInt32(&hits); got != 1 { + t.Fatalf("expected 1 request, got %d", got) + } +} + +func TestCheckForUpdates_Wrapper(t *testing.T) { + var hits int32 + releases := []ReleaseInfo{ + { + TagName: "v99.1.0", + Name: "v99.1.0", + Body: "Release notes", + Prerelease: false, + PublishedAt: time.Date(2024, 2, 3, 4, 5, 6, 0, time.UTC), + Assets: []struct { + Name string `json:"name"` + BrowserDownloadURL string `json:"browser_download_url"` + }{ + { + Name: "pulse-v99.1.0-linux-amd64.tar.gz", + BrowserDownloadURL: "https://example.com/pulse-v99.1.0-linux-amd64.tar.gz", + }, + }, + }, + } + + server := newReleaseServer(t, releases, &hits) + defer server.Close() + + t.Setenv("PULSE_UPDATE_SERVER", server.URL) + + manager := NewManager(&config.Config{UpdateChannel: "stable"}) + + info, err := manager.CheckForUpdates(context.Background()) + if err != nil { + t.Fatalf("CheckForUpdates returned error: %v", err) + } + if !info.Available { + t.Fatalf("expected update to be available") + } + if info.LatestVersion != "99.1.0" { + t.Fatalf("LatestVersion = %q, want 99.1.0", info.LatestVersion) + } +} diff --git a/internal/updates/manager_sse_helpers_test.go b/internal/updates/manager_sse_helpers_test.go new file mode 100644 index 000000000..39042a0ce --- /dev/null +++ b/internal/updates/manager_sse_helpers_test.go @@ -0,0 +1,42 @@ +package updates + +import ( + "net/http/httptest" + "testing" + "time" +) + +type flushRecorder struct { + *httptest.ResponseRecorder +} + +func (f *flushRecorder) Flush() {} + +func TestManagerSSEHelpers(t *testing.T) { + m := &Manager{ + sseBroadcast: NewSSEBroadcaster(), + } + + if m.GetSSEBroadcaster() == nil { + t.Fatal("expected non-nil SSE broadcaster") + } + + rec := &flushRecorder{ResponseRecorder: httptest.NewRecorder()} + client := m.AddSSEClient(rec, "client-1") + if client == nil { + t.Fatal("expected AddSSEClient to return client") + } + + status, ts := m.GetSSECachedStatus() + if status.Status == "" { + t.Fatalf("expected cached status, got empty") + } + if ts.IsZero() { + t.Fatalf("expected cached status time to be set") + } + + m.RemoveSSEClient("client-1") + + // Give background goroutines a moment to handle send and removal. 
+ time.Sleep(10 * time.Millisecond) +} diff --git a/internal/updates/mock_updater_additional_test.go b/internal/updates/mock_updater_additional_test.go new file mode 100644 index 000000000..79b708eef --- /dev/null +++ b/internal/updates/mock_updater_additional_test.go @@ -0,0 +1,13 @@ +package updates + +import ( + "context" + "testing" +) + +func TestMockUpdaterRollback(t *testing.T) { + updater := NewMockUpdater() + if err := updater.Rollback(context.Background(), "event-1"); err != nil { + t.Fatalf("Rollback returned error: %v", err) + } +} diff --git a/internal/updates/version_additional_test.go b/internal/updates/version_additional_test.go new file mode 100644 index 000000000..646cb51c4 --- /dev/null +++ b/internal/updates/version_additional_test.go @@ -0,0 +1,85 @@ +package updates + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGetCurrentVersion_UsesVersionFile(t *testing.T) { + oldwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + defer func() { + _ = os.Chdir(oldwd) + }() + + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("chdir: %v", err) + } + + t.Setenv("PATH", "") + t.Setenv("PULSE_MOCK_MODE", "") + t.Setenv("PULSE_ALLOW_DOCKER_UPDATES", "") + + versionPath := filepath.Join(tmpDir, "VERSION") + if err := os.WriteFile(versionPath, []byte("1.2.3"), 0644); err != nil { + t.Fatalf("write VERSION: %v", err) + } + + info, err := GetCurrentVersion() + if err != nil { + t.Fatalf("GetCurrentVersion error: %v", err) + } + if info.Version != "1.2.3" { + t.Fatalf("Version = %q, want 1.2.3", info.Version) + } + if info.Build != "release" { + t.Fatalf("Build = %q, want release", info.Build) + } + if info.IsDevelopment { + t.Fatalf("IsDevelopment = true, want false") + } + if info.Channel != "stable" { + t.Fatalf("Channel = %q, want stable", info.Channel) + } +} + +func TestGetDeploymentType_Mock(t *testing.T) { + t.Setenv("PULSE_MOCK_MODE", "true") + if got := GetDeploymentType(); got != "mock" { + t.Fatalf("GetDeploymentType = %q, want mock", got) + } +} + +func TestGetDeploymentType_Manual(t *testing.T) { + oldwd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + defer func() { + _ = os.Chdir(oldwd) + }() + + oldArgs := os.Args + defer func() { + os.Args = oldArgs + }() + + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("chdir: %v", err) + } + + t.Setenv("PULSE_MOCK_MODE", "") + t.Setenv("PULSE_ALLOW_DOCKER_UPDATES", "") + t.Setenv("PATH", "") + + os.Args = []string{"pulse"} + + if got := GetDeploymentType(); got != "manual" { + t.Fatalf("GetDeploymentType = %q, want manual", got) + } +} diff --git a/pkg/audit/tenant_logger_manager_test.go b/pkg/audit/tenant_logger_manager_test.go new file mode 100644 index 000000000..3f42a4b26 --- /dev/null +++ b/pkg/audit/tenant_logger_manager_test.go @@ -0,0 +1,130 @@ +package audit + +import ( + "path/filepath" + "testing" +) + +type stubLogger struct { + events []Event + queryCalls int + countCalls int + closed bool + urls []string +} + +func (s *stubLogger) Log(event Event) error { + s.events = append(s.events, event) + return nil +} + +func (s *stubLogger) Query(filter QueryFilter) ([]Event, error) { + s.queryCalls++ + return []Event{{EventType: "test"}}, nil +} + +func (s *stubLogger) Count(filter QueryFilter) (int, error) { + s.countCalls++ + return 7, nil +} + +func (s *stubLogger) GetWebhookURLs() []string { + return s.urls +} + +func (s *stubLogger) UpdateWebhookURLs(urls []string) error { + s.urls = urls + return nil +} + 
+func (s *stubLogger) Close() error { + s.closed = true + return nil +} + +type stubLoggerFactory struct { + created []string + logger Logger + err error +} + +func (f *stubLoggerFactory) CreateLogger(dbPath string) (Logger, error) { + f.created = append(f.created, dbPath) + if f.err != nil { + return nil, f.err + } + return f.logger, nil +} + +func TestTenantLoggerManager_GetLogger_Creates(t *testing.T) { + logger := &stubLogger{} + factory := &stubLoggerFactory{logger: logger} + manager := NewTenantLoggerManager("data", factory) + + got := manager.GetLogger("org-1") + if got != logger { + t.Fatalf("expected factory logger") + } + + expectedPath := filepath.Join("data", "orgs", "org-1", "audit.db") + if len(factory.created) != 1 || factory.created[0] != expectedPath { + t.Fatalf("expected db path %q, got %v", expectedPath, factory.created) + } +} + +func TestTenantLoggerManager_GetLogger_Default(t *testing.T) { + manager := NewTenantLoggerManager("data", &stubLoggerFactory{logger: &stubLogger{}}) + logger := manager.GetLogger("default") + if logger == nil { + t.Fatalf("expected default logger") + } +} + +func TestTenantLoggerManager_LogQueryCount(t *testing.T) { + logger := &stubLogger{} + manager := NewTenantLoggerManager("data", &stubLoggerFactory{logger: logger}) + + if err := manager.Log("org-1", "login", "user", "ip", "/path", true, "details"); err != nil { + t.Fatalf("unexpected log error: %v", err) + } + if len(logger.events) != 1 { + t.Fatalf("expected 1 logged event") + } + + if _, err := manager.Query("org-1", QueryFilter{}); err != nil || logger.queryCalls != 1 { + t.Fatalf("expected query to be called") + } + if _, err := manager.Count("org-1", QueryFilter{}); err != nil || logger.countCalls != 1 { + t.Fatalf("expected count to be called") + } +} + +func TestTenantLoggerManager_CloseAndRemove(t *testing.T) { + logger := &stubLogger{} + manager := NewTenantLoggerManager("data", &stubLoggerFactory{logger: logger}) + manager.GetLogger("org-1") + + manager.RemoveTenantLogger("org-1") + if !logger.closed { + t.Fatalf("expected logger to be closed on removal") + } + if len(manager.GetAllLoggers()) != 0 { + t.Fatalf("expected logger map to be empty after removal") + } + + manager.GetLogger("org-1") + manager.Close() + if len(manager.GetAllLoggers()) != 0 { + t.Fatalf("expected logger map to be cleared on close") + } +} + +func TestConsoleLogger_WebhookMethods(t *testing.T) { + logger := NewConsoleLogger() + if len(logger.GetWebhookURLs()) != 0 { + t.Fatalf("expected empty webhook URLs") + } + if err := logger.UpdateWebhookURLs([]string{"http://example.com"}); err != nil { + t.Fatalf("expected UpdateWebhookURLs to succeed") + } +} diff --git a/pkg/audit/webhook_validation_test.go b/pkg/audit/webhook_validation_test.go new file mode 100644 index 000000000..d3bac6e50 --- /dev/null +++ b/pkg/audit/webhook_validation_test.go @@ -0,0 +1,98 @@ +package audit + +import ( + "context" + "net" + "testing" + "time" +) + +func TestValidateWebhookURL(t *testing.T) { + origResolver := resolveWebhookIPs + defer func() { resolveWebhookIPs = origResolver }() + + resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: net.ParseIP("8.8.8.8")}}, nil + } + + if err := validateWebhookURL(context.Background(), ""); err == nil { + t.Fatalf("expected error for empty URL") + } + if err := validateWebhookURL(context.Background(), "not a url"); err == nil { + t.Fatalf("expected error for invalid URL") + } + if err := validateWebhookURL(context.Background(), 
"ftp://example.com"); err == nil { + t.Fatalf("expected error for invalid scheme") + } + if err := validateWebhookURL(context.Background(), "http://"); err == nil { + t.Fatalf("expected error for missing host") + } + if err := validateWebhookURL(context.Background(), "http://localhost"); err == nil { + t.Fatalf("expected error for localhost") + } + if err := validateWebhookURL(context.Background(), "http://127.0.0.1"); err == nil { + t.Fatalf("expected error for loopback") + } + if err := validateWebhookURL(context.Background(), "http://192.168.1.5"); err == nil { + t.Fatalf("expected error for private IP") + } + if err := validateWebhookURL(context.Background(), "http://metadata.google.internal"); err == nil { + t.Fatalf("expected error for blocked hostname") + } + + if err := validateWebhookURL(context.Background(), "https://example.com"); err != nil { + t.Fatalf("expected valid URL, got %v", err) + } + + resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) { + return nil, context.DeadlineExceeded + } + if err := validateWebhookURL(context.Background(), "https://example.com"); err == nil { + t.Fatalf("expected resolution error") + } + + resolveWebhookIPs = func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil + } + if err := validateWebhookURL(context.Background(), "https://example.com"); err == nil { + t.Fatalf("expected private IP resolution error") + } +} + +func TestIsPrivateOrReservedIP(t *testing.T) { + cases := map[string]bool{ + "127.0.0.1": true, + "10.0.0.1": true, + "169.254.1.1": true, + "0.0.0.0": true, + "8.8.8.8": false, + } + for ipStr, expected := range cases { + if got := isPrivateOrReservedIP(net.ParseIP(ipStr)); got != expected { + t.Fatalf("ip %s expected %v, got %v", ipStr, expected, got) + } + } +} + +func TestWebhookDelivery_QueueAndURLs(t *testing.T) { + delivery := NewWebhookDelivery([]string{"http://example.com"}) + if delivery.QueueLength() != 0 { + t.Fatalf("expected empty queue") + } + + delivery.Enqueue(Event{ID: "e1", EventType: "login", Timestamp: time.Now()}) + if delivery.QueueLength() != 1 { + t.Fatalf("expected queued event") + } + + delivery.UpdateURLs([]string{"http://new.example.com"}) + urls := delivery.GetURLs() + if len(urls) != 1 || urls[0] != "http://new.example.com" { + t.Fatalf("expected updated URLs") + } + + urls[0] = "mutated" + if delivery.GetURLs()[0] != "http://new.example.com" { + t.Fatalf("expected URLs to be copied defensively") + } +} diff --git a/pkg/discovery/envdetect/envdetect_additional_test.go b/pkg/discovery/envdetect/envdetect_additional_test.go new file mode 100644 index 000000000..e5c50ff35 --- /dev/null +++ b/pkg/discovery/envdetect/envdetect_additional_test.go @@ -0,0 +1,456 @@ +package envdetect + +import ( + "errors" + "net" + "os" + "strings" + "testing" +) + +func TestDetectContainer_SystemdDetectVirtDocker(t *testing.T) { + probe := fakeEnvironmentProbe{ + lookPathPresent: map[string]bool{"systemd-detect-virt": true}, + commandOutput: map[string][]byte{ + "systemd-detect-virt\x00--container": []byte("docker\n"), + }, + } + + isContainer, containerType := detectContainer(probe) + if !isContainer { + t.Fatalf("expected container, got non-container") + } + if containerType != "docker" { + t.Fatalf("containerType = %q, want docker", containerType) + } +} + +func TestDetectContainer_SystemdDetectVirtUnknown(t *testing.T) { + probe := fakeEnvironmentProbe{ + lookPathPresent: map[string]bool{"systemd-detect-virt": true}, + 
commandOutput: map[string][]byte{ + "systemd-detect-virt\x00--container": []byte("rkt\n"), + }, + } + + isContainer, containerType := detectContainer(probe) + if !isContainer { + t.Fatalf("expected container, got non-container") + } + if containerType != "rkt" { + t.Fatalf("containerType = %q, want rkt", containerType) + } +} + +func TestDetectContainer_MarkersAndCgroup(t *testing.T) { + t.Run("marker-containerenv", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + statPresent: map[string]bool{"/run/.containerenv": true}, + } + + isContainer, containerType := detectContainer(probe) + if !isContainer || containerType != "docker" { + t.Fatalf("marker detection = (%v, %q), want (true, docker)", isContainer, containerType) + } + }) + + t.Run("cgroup-docker", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{ + "/proc/1/cgroup": []byte("12:cpu:/docker/abc\n"), + }, + } + + isContainer, containerType := detectContainer(probe) + if !isContainer || containerType != "docker" { + t.Fatalf("cgroup detection = (%v, %q), want (true, docker)", isContainer, containerType) + } + }) + + t.Run("environ-lxc", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{ + "/proc/1/cgroup": errors.New("nope"), + }, + fileData: map[string][]byte{ + "/proc/1/environ": []byte("container=lxc\x00"), + }, + } + + isContainer, containerType := detectContainer(probe) + if !isContainer || containerType != "lxc" { + t.Fatalf("environ detection = (%v, %q), want (true, lxc)", isContainer, containerType) + } + }) +} + +func TestDetectContainer_NoMarkers(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{ + "/proc/1/cgroup": os.ErrNotExist, + "/proc/1/environ": os.ErrNotExist, + }, + } + + isContainer, containerType := detectContainer(probe) + if isContainer { + t.Fatalf("expected non-container, got container type %q", containerType) + } + if containerType != "" { + t.Fatalf("containerType = %q, want empty", containerType) + } +} + +func TestDetectNativeEnvironment_FallbackOnError(t *testing.T) { + profile := &EnvironmentProfile{ + Policy: DefaultScanPolicy(), + Metadata: map[string]string{}, + } + probe := fakeEnvironmentProbe{interfacesErr: errors.New("boom")} + + result, err := detectNativeEnvironment(profile, probe) + if err != nil { + t.Fatalf("detectNativeEnvironment returned error: %v", err) + } + if len(result.Phases) == 0 || result.Phases[0].Name != "fallback_common_subnets" { + t.Fatalf("expected fallback subnets, got %#v", result.Phases) + } + + found := false + for _, warn := range result.Warnings { + if strings.Contains(warn, "Failed to enumerate interfaces") { + found = true + break + } + } + if !found { + t.Fatalf("expected warning about interface enumeration, got %#v", result.Warnings) + } +} + +func TestDetectNativeEnvironment_NoSubnetsFallback(t *testing.T) { + profile := &EnvironmentProfile{ + Policy: DefaultScanPolicy(), + Metadata: map[string]string{}, + } + probe := fakeEnvironmentProbe{ + interfaces: []ifaceInfo{ + { + Name: "lo", + Flags: net.FlagUp | net.FlagLoopback, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}, + }, + }, + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(169, 254, 1, 2), Mask: net.CIDRMask(16, 32)}, + }, + }, + }, + } + + result, err := detectNativeEnvironment(profile, probe) + if err != nil { + t.Fatalf("detectNativeEnvironment returned error: %v", err) + } + if len(result.Phases) == 0 || result.Phases[0].Name != 
"fallback_common_subnets" { + t.Fatalf("expected fallback subnets, got %#v", result.Phases) + } + + found := false + for _, warn := range result.Warnings { + if strings.Contains(warn, "No active IPv4 interfaces found") { + found = true + break + } + } + if !found { + t.Fatalf("expected warning about missing interfaces, got %#v", result.Warnings) + } +} + +func TestGetAllLocalSubnets_DedupAndSkip(t *testing.T) { + probe := fakeEnvironmentProbe{ + interfaces: []ifaceInfo{ + { + Name: "lo", + Flags: net.FlagUp | net.FlagLoopback, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}}, + }, + { + Name: "down0", + Flags: 0, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(10, 0, 0, 2), Mask: net.CIDRMask(24, 32)}}, + }, + { + Name: "err0", + Flags: net.FlagUp, + AddrsErr: errors.New("boom"), + }, + { + Name: "v6", + Flags: net.FlagUp, + Addrs: []net.Addr{&net.IPNet{IP: net.ParseIP("fe80::1"), Mask: net.CIDRMask(64, 128)}}, + }, + { + Name: "eth0", + Flags: net.FlagUp, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(192, 168, 50, 10), Mask: net.CIDRMask(24, 32)}}, + }, + { + Name: "eth1", + Flags: net.FlagUp, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(192, 168, 50, 20), Mask: net.CIDRMask(24, 32)}}, + }, + { + Name: "linklocal", + Flags: net.FlagUp, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(169, 254, 10, 2), Mask: net.CIDRMask(16, 32)}}, + }, + }, + } + + subnets, err := getAllLocalSubnets(probe) + if err != nil { + t.Fatalf("getAllLocalSubnets returned error: %v", err) + } + if len(subnets) != 1 { + t.Fatalf("subnets = %#v, want 1 unique subnet", subnets) + } + if got := subnets[0].String(); got != mustIPNet(t, "192.168.50.0/24").String() { + t.Fatalf("subnet = %q, want 192.168.50.0/24", got) + } +} + +func TestIsDockerHostMode_InterfaceError(t *testing.T) { + probe := fakeEnvironmentProbe{interfacesErr: errors.New("boom")} + + hostMode, warnings := isDockerHostMode(probe) + if hostMode { + t.Fatalf("expected hostMode=false on error") + } + if len(warnings) == 0 || !strings.Contains(warnings[0], "Unable to enumerate interfaces") { + t.Fatalf("warnings = %#v, want interface enumeration warning", warnings) + } +} + +func TestDetectDockerEnvironment_BridgeFallback(t *testing.T) { + profile := &EnvironmentProfile{ + Policy: DefaultScanPolicy(), + Metadata: map[string]string{}, + } + profile.Policy.ScanGateways = false + + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{ + "/proc/net/route": errors.New("nope"), + }, + interfaces: []ifaceInfo{ + { + Name: "lo", + Flags: net.FlagUp | net.FlagLoopback, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}}, + }, + }, + } + + result, err := detectDockerEnvironment(profile, probe) + if err != nil { + t.Fatalf("detectDockerEnvironment returned error: %v", err) + } + if result.Type != DockerBridge { + t.Fatalf("Type = %v, want %v", result.Type, DockerBridge) + } + if len(result.Phases) == 0 || result.Phases[0].Name != "fallback_common_subnets" { + t.Fatalf("expected fallback subnets, got %#v", result.Phases) + } +} + +func TestDetectHostNetworkFromContainer_NonStandardGateway(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway", + "eth0\t00000000\t2A00000A", + }, "\n") + + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + } + + subnets, confidence, warnings := detectHostNetworkFromContainer(probe) + if confidence != 0.4 { + t.Fatalf("confidence = %v, want 0.4", confidence) + } + if len(subnets) != 1 || 
subnets[0].String() != "10.0.0.0/24" { + t.Fatalf("subnets = %#v, want 10.0.0.0/24", subnets) + } + found := false + for _, warn := range warnings { + if strings.Contains(warn, "does not end with .1 or .254") { + found = true + break + } + } + if !found { + t.Fatalf("warnings = %#v, want non-standard gateway warning", warnings) + } +} + +func TestDetectHostNetworkFromContainer_GatewayErrorFallback(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{"/proc/net/route": errors.New("boom")}, + } + + subnets, confidence, warnings := detectHostNetworkFromContainer(probe) + if confidence != 0.3 { + t.Fatalf("confidence = %v, want 0.3", confidence) + } + if len(subnets) == 0 { + t.Fatalf("expected fallback subnets, got none") + } + if len(warnings) == 0 || !strings.Contains(warnings[0], "Could not determine default gateway") { + t.Fatalf("warnings = %#v, want gateway warning", warnings) + } +} + +func TestGetDefaultGateway_InvalidHex(t *testing.T) { + route := strings.Join([]string{ + "Iface\tDestination\tGateway", + "eth0\t00000000\tZZZZZZZZ", + }, "\n") + + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/net/route": []byte(route)}, + } + + _, err := getDefaultGateway(probe) + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to parse default gateway") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestCountKernelRoutes_ReadFileError(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{"/proc/net/route": errors.New("boom")}, + } + + count, warn := countKernelRoutes(probe) + if count != 0 { + t.Fatalf("count = %d, want 0", count) + } + if warn == "" || !strings.Contains(warn, "Unable to read /proc/net/route") { + t.Fatalf("warn = %q, want read error warning", warn) + } +} + +func TestIsLXCPrivileged_ErrorPaths(t *testing.T) { + t.Run("permission", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileErr: map[string]error{"/proc/self/uid_map": os.ErrPermission}, + } + privileged, warn := isLXCPrivileged(probe) + if privileged { + t.Fatalf("expected unprivileged on permission error") + } + if !strings.Contains(warn, "permission denied") { + t.Fatalf("warn = %q, want permission warning", warn) + } + }) + + t.Run("bad-format", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/self/uid_map": []byte("0 0\n")}, + } + privileged, warn := isLXCPrivileged(probe) + if privileged { + t.Fatalf("expected unprivileged on format error") + } + if !strings.Contains(warn, "Unexpected format") { + t.Fatalf("warn = %q, want format warning", warn) + } + }) + + t.Run("bad-hostuid", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/self/uid_map": []byte("0 x 1\n")}, + } + privileged, warn := isLXCPrivileged(probe) + if privileged { + t.Fatalf("expected unprivileged on parse error") + } + if !strings.Contains(warn, "Failed to parse uid_map") { + t.Fatalf("warn = %q, want parse warning", warn) + } + }) + + t.Run("bad-length", func(t *testing.T) { + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/self/uid_map": []byte("0 0 nope\n")}, + } + privileged, warn := isLXCPrivileged(probe) + if privileged { + t.Fatalf("expected unprivileged on length parse error") + } + if !strings.Contains(warn, "Failed to parse uid_map length") { + t.Fatalf("warn = %q, want length warning", warn) + } + }) +} + +func TestDetectEnvironment_UnsupportedContainerType(t *testing.T) { + probe := fakeEnvironmentProbe{ 
+ lookPathPresent: map[string]bool{"systemd-detect-virt": true}, + commandOutput: map[string][]byte{ + "systemd-detect-virt\x00--container": []byte("rkt\n"), + }, + } + + profile, err := detectEnvironment(probe) + if err != nil { + t.Fatalf("detectEnvironment returned error: %v", err) + } + if profile.Type != Unknown { + t.Fatalf("Type = %v, want %v", profile.Type, Unknown) + } + if profile.Metadata["container_type"] != "rkt" { + t.Fatalf("container_type = %q, want rkt", profile.Metadata["container_type"]) + } + if len(profile.Phases) == 0 || profile.Phases[0].Name != "fallback_common_subnets" { + t.Fatalf("expected fallback subnets, got %#v", profile.Phases) + } +} + +func TestDetectLXCEnvironment_UnprivilegedFallback(t *testing.T) { + profile := &EnvironmentProfile{ + Policy: DefaultScanPolicy(), + Metadata: map[string]string{}, + } + profile.Policy.ScanGateways = false + + probe := fakeEnvironmentProbe{ + fileData: map[string][]byte{"/proc/self/uid_map": []byte("0 100000 65536\n")}, + interfaces: []ifaceInfo{ + { + Name: "lo", + Flags: net.FlagUp | net.FlagLoopback, + Addrs: []net.Addr{&net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}}, + }, + }, + } + + result, err := detectLXCEnvironment(profile, probe) + if err != nil { + t.Fatalf("detectLXCEnvironment returned error: %v", err) + } + if result.Type != LXCUnprivileged { + t.Fatalf("Type = %v, want %v", result.Type, LXCUnprivileged) + } + if len(result.Phases) == 0 || result.Phases[0].Name != "fallback_common_subnets" { + t.Fatalf("expected fallback subnets, got %#v", result.Phases) + } +}
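The envdetect tests above construct fakeEnvironmentProbe literals and call mustIPNet without defining either, so those helpers presumably already exist in the package's test files. A sketch of the shape these tests assume; the field names are inferred from the literals above and are assumptions, not the package's actual declarations:

package envdetect_sketch // inferred shapes only; the real helpers live in the envdetect package tests

import "net"

// ifaceInfo mirrors the per-interface data the fake probe hands back when the
// detection code enumerates interfaces.
type ifaceInfo struct {
	Name     string
	Flags    net.Flags
	Addrs    []net.Addr
	AddrsErr error // lets a single interface fail address lookup
}

// fakeEnvironmentProbe stubs the filesystem, exec, and network lookups the
// detection code performs, keyed the way the tests above populate it.
type fakeEnvironmentProbe struct {
	lookPathPresent map[string]bool   // binaries that LookPath should report as found
	commandOutput   map[string][]byte // keyed as command name + "\x00" + args
	statPresent     map[string]bool   // paths that Stat should report as existing
	fileData        map[string][]byte // canned ReadFile contents
	fileErr         map[string]error  // forced ReadFile errors
	interfaces      []ifaceInfo       // interfaces returned during enumeration
	interfacesErr   error             // forced enumeration failure
}

Only the fields exercised above are sketched; whatever probe interface detectEnvironment and its helpers actually accept is not reproduced here.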