mirror of
https://github.com/rcourtman/Pulse.git
synced 2026-04-26 10:31:17 +00:00
test: Add comprehensive test coverage across packages
New test files with expanded coverage: API tests: - ai_handler_test.go: AI handler unit tests with mocking - agent_profiles_tools_test.go: Profile management tests - alerts_endpoints_test.go: Alert API endpoint tests - alerts_test.go: Updated for interface changes - audit_handlers_test.go: Audit handler tests - frontend_embed_test.go: Frontend embedding tests - metadata_handlers_test.go, metadata_provider_test.go: Metadata tests - notifications_test.go: Updated for interface changes - profile_suggestions_test.go: Profile suggestion tests - saml_service_test.go: SAML authentication tests - sensor_proxy_gate_test.go: Sensor proxy tests - updates_test.go: Updated for interface changes Agent tests: - dockeragent/signature_test.go: Docker agent signature tests - hostagent/agent_metrics_test.go: Host agent metrics tests - hostagent/commands_test.go: Command execution tests - hostagent/network_helpers_test.go: Network helper tests - hostagent/proxmox_setup_test.go: Updated setup tests - kubernetesagent/*_test.go: Kubernetes agent tests Core package tests: - monitoring/kubernetes_agents_test.go, reload_test.go - remoteconfig/client_test.go, signature_test.go - sensors/collector_test.go - updates/adapter_installsh_*_test.go: Install adapter tests - updates/manager_*_test.go: Update manager tests - websocket/hub_*_test.go: WebSocket hub tests Library tests: - pkg/audit/export_test.go: Audit export tests - pkg/metrics/store_test.go: Metrics store tests - pkg/proxmox/*_test.go: Proxmox client tests - pkg/reporting/reporting_test.go: Reporting tests - pkg/server/*_test.go: Server tests - pkg/tlsutil/extra_test.go: TLS utility tests Total: ~8000 lines of new test code
This commit is contained in:
parent
d06ed2edb3
commit
a6a8efaa65
49 changed files with 8141 additions and 398 deletions
112
internal/api/agent_profiles_tools_test.go
Normal file
112
internal/api/agent_profiles_tools_test.go
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
)
|
||||
|
||||
func newTestProfileManager(t *testing.T) *MCPAgentProfileManager {
|
||||
t.Helper()
|
||||
persistence := config.NewConfigPersistence(t.TempDir())
|
||||
return NewMCPAgentProfileManager(persistence, nil)
|
||||
}
|
||||
|
||||
func TestMCPAgentProfileManagerApplyAndGetScope(t *testing.T) {
|
||||
manager := newTestProfileManager(t)
|
||||
ctx := context.Background()
|
||||
|
||||
settings := map[string]interface{}{
|
||||
"enable_host": true,
|
||||
}
|
||||
|
||||
profileID, profileName, created, err := manager.ApplyAgentScope(ctx, "agent-1", "Alpha", settings)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyAgentScope error: %v", err)
|
||||
}
|
||||
if !created || profileID == "" || profileName == "" {
|
||||
t.Fatalf("unexpected apply result: id=%q name=%q created=%v", profileID, profileName, created)
|
||||
}
|
||||
|
||||
scope, err := manager.GetAgentScope(ctx, "agent-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgentScope error: %v", err)
|
||||
}
|
||||
if scope == nil || scope.ProfileID != profileID || scope.ProfileVersion != 1 {
|
||||
t.Fatalf("unexpected scope: %+v", scope)
|
||||
}
|
||||
if scope.Settings["enable_host"] != true {
|
||||
t.Fatalf("unexpected settings: %+v", scope.Settings)
|
||||
}
|
||||
|
||||
updatedSettings := map[string]interface{}{
|
||||
"enable_host": false,
|
||||
}
|
||||
_, _, created, err = manager.ApplyAgentScope(ctx, "agent-1", "Alpha", updatedSettings)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyAgentScope update error: %v", err)
|
||||
}
|
||||
if created {
|
||||
t.Fatal("expected update to reuse profile")
|
||||
}
|
||||
|
||||
scope, err = manager.GetAgentScope(ctx, "agent-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgentScope error: %v", err)
|
||||
}
|
||||
if scope.ProfileVersion != 2 {
|
||||
t.Fatalf("expected profile version 2, got %d", scope.ProfileVersion)
|
||||
}
|
||||
if scope.Settings["enable_host"] != false {
|
||||
t.Fatalf("unexpected updated settings: %+v", scope.Settings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPAgentProfileManagerAssignProfile(t *testing.T) {
|
||||
manager := newTestProfileManager(t)
|
||||
ctx := context.Background()
|
||||
|
||||
profile := models.AgentProfile{
|
||||
ID: "profile-1",
|
||||
Name: "Default",
|
||||
Description: "default",
|
||||
Config: map[string]interface{}{
|
||||
"enable_host": true,
|
||||
},
|
||||
Version: 1,
|
||||
}
|
||||
|
||||
if err := manager.persistence.SaveAgentProfiles([]models.AgentProfile{profile}); err != nil {
|
||||
t.Fatalf("SaveAgentProfiles error: %v", err)
|
||||
}
|
||||
|
||||
name, err := manager.AssignProfile(ctx, "agent-2", profile.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("AssignProfile error: %v", err)
|
||||
}
|
||||
if name != profile.Name {
|
||||
t.Fatalf("unexpected profile name: %q", name)
|
||||
}
|
||||
|
||||
scope, err := manager.GetAgentScope(ctx, "agent-2")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgentScope error: %v", err)
|
||||
}
|
||||
if scope == nil || scope.ProfileID != profile.ID || scope.ProfileName != profile.Name {
|
||||
t.Fatalf("unexpected scope: %+v", scope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPAgentProfileManagerGetScopeMissing(t *testing.T) {
|
||||
manager := newTestProfileManager(t)
|
||||
|
||||
scope, err := manager.GetAgentScope(context.Background(), "missing")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAgentScope error: %v", err)
|
||||
}
|
||||
if scope != nil {
|
||||
t.Fatalf("expected nil scope, got %+v", scope)
|
||||
}
|
||||
}
|
||||
799
internal/api/ai_handler_test.go
Normal file
799
internal/api/ai_handler_test.go
Normal file
|
|
@ -0,0 +1,799 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/chat"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockAIService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAIService) Start(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) Stop(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) Restart(ctx context.Context, newCfg *config.AIConfig) error {
|
||||
args := m.Called(ctx, newCfg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) IsRunning() bool {
|
||||
args := m.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) Execute(ctx context.Context, req chat.ExecuteRequest) (map[string]interface{}, error) {
|
||||
args := m.Called(ctx, req)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) ExecuteStream(ctx context.Context, req chat.ExecuteRequest, callback chat.StreamCallback) error {
|
||||
args := m.Called(ctx, req, callback)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) ListSessions(ctx context.Context) ([]chat.Session, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Get(0).([]chat.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) CreateSession(ctx context.Context) (*chat.Session, error) {
|
||||
args := m.Called(ctx)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*chat.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) GetMessages(ctx context.Context, sessionID string) ([]chat.Message, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Get(0).([]chat.Message), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) AbortSession(ctx context.Context, sessionID string) error {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) SummarizeSession(ctx context.Context, sessionID string) (map[string]interface{}, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) GetSessionDiff(ctx context.Context, sessionID string) (map[string]interface{}, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) ForkSession(ctx context.Context, sessionID string) (*chat.Session, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*chat.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) RevertSession(ctx context.Context, sessionID string) (map[string]interface{}, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) UnrevertSession(ctx context.Context, sessionID string) (map[string]interface{}, error) {
|
||||
args := m.Called(ctx, sessionID)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAIService) AnswerQuestion(ctx context.Context, questionID string, answers []chat.QuestionAnswer) error {
|
||||
args := m.Called(ctx, questionID, answers)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAIService) SetAlertProvider(provider chat.MCPAlertProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetFindingsProvider(provider chat.MCPFindingsProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetBaselineProvider(provider chat.MCPBaselineProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetPatternProvider(provider chat.MCPPatternProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetMetricsHistory(provider chat.MCPMetricsHistoryProvider) {
|
||||
m.Called(provider)
|
||||
}
|
||||
func (m *MockAIService) SetBackupProvider(provider chat.MCPBackupProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetStorageProvider(provider chat.MCPStorageProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetDiskHealthProvider(provider chat.MCPDiskHealthProvider) {
|
||||
m.Called(provider)
|
||||
}
|
||||
func (m *MockAIService) SetUpdatesProvider(provider chat.MCPUpdatesProvider) { m.Called(provider) }
|
||||
func (m *MockAIService) SetAgentProfileManager(manager chat.AgentProfileManager) {
|
||||
m.Called(manager)
|
||||
}
|
||||
func (m *MockAIService) SetFindingsManager(manager chat.FindingsManager) { m.Called(manager) }
|
||||
func (m *MockAIService) SetMetadataUpdater(updater chat.MetadataUpdater) { m.Called(updater) }
|
||||
|
||||
func (m *MockAIService) UpdateControlSettings(cfg *config.AIConfig) { m.Called(cfg) }
|
||||
func (m *MockAIService) GetBaseURL() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
type MockAIPersistence struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAIPersistence) LoadAIConfig() (*config.AIConfig, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*config.AIConfig), args.Error(1)
|
||||
}
|
||||
|
||||
type MockAIStateProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAIStateProvider) GetState() models.StateSnapshot {
|
||||
args := m.Called()
|
||||
return args.Get(0).(models.StateSnapshot)
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
// Mock newChatService
|
||||
oldNewService := newChatService
|
||||
defer func() { newChatService = oldNewService }()
|
||||
|
||||
mockSvc := new(MockAIService)
|
||||
newChatService = func(cfg chat.Config) AIService {
|
||||
return mockSvc
|
||||
}
|
||||
|
||||
mockPersist := new(MockAIPersistence)
|
||||
h := NewAIHandler(&config.Config{}, mockPersist, nil)
|
||||
|
||||
// AI disabled in config
|
||||
mockPersist.On("LoadAIConfig").Return(&config.AIConfig{Enabled: false}, nil).Once()
|
||||
err := h.Start(context.Background(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, h.service)
|
||||
|
||||
// AI enabled
|
||||
aiCfg := &config.AIConfig{Enabled: true, Model: "test"}
|
||||
mockPersist.On("LoadAIConfig").Return(aiCfg, nil).Once()
|
||||
mockSvc.On("Start", mock.Anything).Return(nil).Once()
|
||||
|
||||
err = h.Start(context.Background(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, mockSvc, h.service)
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
mockSvc := new(MockAIService)
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("Stop", mock.Anything).Return(nil)
|
||||
err := h.Stop(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Nil service
|
||||
h.service = nil
|
||||
err = h.Stop(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStart_Error(t *testing.T) {
|
||||
oldNewService := newChatService
|
||||
defer func() { newChatService = oldNewService }()
|
||||
|
||||
mockSvc := new(MockAIService)
|
||||
newChatService = func(cfg chat.Config) AIService {
|
||||
return mockSvc
|
||||
}
|
||||
|
||||
mockPersist := new(MockAIPersistence)
|
||||
h := NewAIHandler(&config.Config{}, mockPersist, nil)
|
||||
|
||||
aiCfg := &config.AIConfig{Enabled: true, Model: "test"}
|
||||
mockPersist.On("LoadAIConfig").Return(aiCfg, nil)
|
||||
mockSvc.On("Start", mock.Anything).Return(assert.AnError)
|
||||
|
||||
err := h.Start(context.Background(), nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRestart(t *testing.T) {
|
||||
mockPersist := new(MockAIPersistence)
|
||||
mockSvc := new(MockAIService)
|
||||
h := NewAIHandler(nil, mockPersist, nil)
|
||||
h.service = mockSvc
|
||||
|
||||
aiCfg := &config.AIConfig{}
|
||||
mockPersist.On("LoadAIConfig").Return(aiCfg, nil)
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("Restart", mock.Anything, aiCfg).Return(nil)
|
||||
err := h.Restart(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Service nil
|
||||
h.service = nil
|
||||
err = h.Restart(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetService(t *testing.T) {
|
||||
mockSvc := new(MockAIService)
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
h.service = mockSvc
|
||||
assert.Equal(t, mockSvc, h.GetService())
|
||||
}
|
||||
|
||||
func TestGetAIConfig(t *testing.T) {
|
||||
mockPersist := new(MockAIPersistence)
|
||||
h := NewAIHandler(nil, mockPersist, nil)
|
||||
|
||||
aiCfg := &config.AIConfig{Model: "test"}
|
||||
mockPersist.On("LoadAIConfig").Return(aiCfg, nil)
|
||||
|
||||
result := h.GetAIConfig()
|
||||
assert.Equal(t, aiCfg, result)
|
||||
}
|
||||
|
||||
func TestLoadAIConfig_Error(t *testing.T) {
|
||||
mockPersist := new(MockAIPersistence)
|
||||
h := NewAIHandler(nil, mockPersist, nil)
|
||||
|
||||
mockPersist.On("LoadAIConfig").Return((*config.AIConfig)(nil), assert.AnError)
|
||||
|
||||
result := h.loadAIConfig()
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestHandleStatus(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
APIToken: "test-token",
|
||||
}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/status", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleStatus(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var resp map[string]interface{}
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, resp["running"].(bool))
|
||||
assert.Equal(t, "direct", resp["engine"])
|
||||
}
|
||||
|
||||
func TestHandleSessions(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
sessions := []chat.Session{{ID: "s1"}, {ID: "s2"}}
|
||||
mockSvc.On("ListSessions", mock.Anything).Return(sessions, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleSessions(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleCreateSession(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
session := &chat.Session{ID: "new-session"}
|
||||
mockSvc.On("CreateSession", mock.Anything).Return(session, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleCreateSession(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleDeleteSession(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("DeleteSession", mock.Anything, "s1").Return(nil)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/ai/sessions/s1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleDeleteSession(w, req, "s1")
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleMessages(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
messages := []chat.Message{{Role: "user", Content: "hello"}}
|
||||
mockSvc.On("GetMessages", mock.Anything, "s1").Return(messages, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions/s1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleMessages(w, req, "s1")
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleChat_NotRunning(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(false)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/chat", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleChat(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleChat_InvalidJSON(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/chat", strings.NewReader("invalid"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleChat(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleChat_Success(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
|
||||
// Mock ExecuteStream to just return nil
|
||||
mockSvc.On("ExecuteStream", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
callback := args.Get(2).(chat.StreamCallback)
|
||||
data, _ := json.Marshal("hello")
|
||||
callback(chat.StreamEvent{Type: "content", Data: data})
|
||||
})
|
||||
|
||||
body := `{"prompt": "hi"}`
|
||||
req := httptest.NewRequest("POST", "/api/ai/chat", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleChat(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Header().Get("Content-Type"), "text/event-stream")
|
||||
}
|
||||
|
||||
func TestHandleAnswerQuestion(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("AnswerQuestion", mock.Anything, "q1", mock.Anything).Return(nil)
|
||||
|
||||
body := `{"answers": [{"id": "a1", "value": "v1"}]}`
|
||||
req := httptest.NewRequest("POST", "/api/ai/question/q1/answer", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleAnswerQuestion(w, req, "q1")
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessions_NotRunning(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleSessions(w, req)
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleSessions_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("ListSessions", mock.Anything).Return(([]chat.Session)(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleSessions(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleCreateSession_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("CreateSession", mock.Anything).Return((*chat.Session)(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleCreateSession(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleDeleteSession_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("DeleteSession", mock.Anything, "s1").Return(assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/ai/sessions/s1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleDeleteSession(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleMessages_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("GetMessages", mock.Anything, "s1").Return(([]chat.Message)(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions/s1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleMessages(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleAbort_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("AbortSession", mock.Anything, "s1").Return(assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/abort", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAbort(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleSummarize_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("SummarizeSession", mock.Anything, "s1").Return((map[string]interface{})(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/summarize", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleSummarize(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleAnswerQuestion_InvalidJSON(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/question/q1/answer", strings.NewReader("invalid"))
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAnswerQuestion(w, req, "q1")
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleAnswerQuestion_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("AnswerQuestion", mock.Anything, "q1", mock.Anything).Return(assert.AnError)
|
||||
|
||||
body := `{"answers": []}`
|
||||
req := httptest.NewRequest("POST", "/api/ai/question/q1/answer", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAnswerQuestion(w, req, "q1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleChat_Options(t *testing.T) {
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
req := httptest.NewRequest("OPTIONS", "/api/ai/chat", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleChat(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "http://example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
func TestHandleChat_MethodNotAllowed(t *testing.T) {
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
req := httptest.NewRequest("GET", "/api/ai/chat", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleChat(w, req)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleChat_Error(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("ExecuteStream", mock.Anything, mock.Anything, mock.Anything).Return(assert.AnError)
|
||||
|
||||
body := `{"prompt": "hi"}`
|
||||
req := httptest.NewRequest("POST", "/api/ai/chat", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleChat(w, req)
|
||||
// ExecuteStream error happens after headers are sent, so w.Code might be 200
|
||||
// but the error is returned.
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleDiff_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("GetSessionDiff", mock.Anything, "s1").Return((map[string]interface{})(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions/s1/diff", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleDiff(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleFork_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("ForkSession", mock.Anything, "s1").Return((*chat.Session)(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/fork", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleFork(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleRevert_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("RevertSession", mock.Anything, "s1").Return((map[string]interface{})(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/revert", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleRevert(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleUnrevert_Error(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("UnrevertSession", mock.Anything, "s1").Return((map[string]interface{})(nil), assert.AnError)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/unrevert", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleUnrevert(w, req, "s1")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleStatus_NotRunning(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(false)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleStatus(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code) // HandleStatus returns 200 even if not running
|
||||
var resp map[string]interface{}
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.False(t, resp["running"].(bool))
|
||||
}
|
||||
|
||||
func TestMockUnimplemented(t *testing.T) {
|
||||
mockSvc := new(MockAIService)
|
||||
mockSvc.On("SetFindingsManager", mock.Anything).Return()
|
||||
mockSvc.On("SetMetadataUpdater", mock.Anything).Return()
|
||||
mockSvc.On("UpdateControlSettings", mock.Anything).Return()
|
||||
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
h.service = mockSvc
|
||||
|
||||
h.SetFindingsManager(nil)
|
||||
h.SetMetadataUpdater(nil)
|
||||
h.UpdateControlSettings(nil)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestProviders(t *testing.T) {
|
||||
h := NewAIHandler(nil, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
|
||||
mockSvc.On("SetAlertProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetFindingsProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetBaselineProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetPatternProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetMetricsHistory", mock.Anything).Return()
|
||||
mockSvc.On("SetAgentProfileManager", mock.Anything).Return()
|
||||
mockSvc.On("SetStorageProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetBackupProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetDiskHealthProvider", mock.Anything).Return()
|
||||
mockSvc.On("SetUpdatesProvider", mock.Anything).Return()
|
||||
|
||||
h.SetAlertProvider(nil)
|
||||
h.SetFindingsProvider(nil)
|
||||
h.SetBaselineProvider(nil)
|
||||
h.SetPatternProvider(nil)
|
||||
h.SetMetricsHistory(nil)
|
||||
h.SetAgentProfileManager(nil)
|
||||
h.SetStorageProvider(nil)
|
||||
h.SetBackupProvider(nil)
|
||||
h.SetDiskHealthProvider(nil)
|
||||
h.SetUpdatesProvider(nil)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHandleAbort_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("AbortSession", mock.Anything, "s1").Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/abort", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAbort(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleSummarize_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("SummarizeSession", mock.Anything, "s1").Return(map[string]interface{}{"summary": "ok"}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/summarize", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleSummarize(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleDiff_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("GetSessionDiff", mock.Anything, "s1").Return(map[string]interface{}{"diff": "test"}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/sessions/s1/diff", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleDiff(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleFork_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("ForkSession", mock.Anything, "s1").Return(&chat.Session{ID: "s2"}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/fork", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleFork(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleRevert_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("RevertSession", mock.Anything, "s1").Return(map[string]interface{}{"reverted": true}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/revert", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleRevert(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleUnrevert_Success(t *testing.T) {
|
||||
h := NewAIHandler(&config.Config{}, nil, nil)
|
||||
mockSvc := new(MockAIService)
|
||||
h.service = mockSvc
|
||||
mockSvc.On("IsRunning").Return(true)
|
||||
mockSvc.On("UnrevertSession", mock.Anything, "s1").Return(map[string]interface{}{"unreverted": true}, nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ai/sessions/s1/unrevert", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleUnrevert(w, req, "s1")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestHandleStatus_NoService(t *testing.T) {
|
||||
// HandleStatus with no service initialized should still return 200 with running=false
|
||||
cfg := &config.Config{}
|
||||
h := NewAIHandler(cfg, nil, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/ai/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.HandleStatus(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var resp map[string]interface{}
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.False(t, resp["running"].(bool))
|
||||
}
|
||||
254
internal/api/alerts_endpoints_test.go
Normal file
254
internal/api/alerts_endpoints_test.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
package api_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
||||
)
|
||||
|
||||
func TestAlertsEndpoints(t *testing.T) {
|
||||
srv := newIntegrationServer(t)
|
||||
|
||||
// 1. Get initial alert config
|
||||
t.Run("GetAlertConfig", func(t *testing.T) {
|
||||
res, err := http.Get(srv.server.URL + "/api/alerts/config")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
var config alerts.AlertConfig
|
||||
if err := json.NewDecoder(res.Body).Decode(&config); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// 2. Update alert config
|
||||
t.Run("UpdateAlertConfig", func(t *testing.T) {
|
||||
newConfig := alerts.AlertConfig{
|
||||
Schedule: alerts.ScheduleConfig{
|
||||
Cooldown: 300,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(newConfig)
|
||||
// HandleAlerts expects PUT for config updates
|
||||
req, err := http.NewRequest(http.MethodPut, srv.server.URL+"/api/alerts/config", bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Verify update persistence
|
||||
resVerify, err := http.Get(srv.server.URL + "/api/alerts/config")
|
||||
if err != nil {
|
||||
t.Fatalf("verify request failed: %v", err)
|
||||
}
|
||||
defer resVerify.Body.Close()
|
||||
|
||||
var updatedConfig alerts.AlertConfig
|
||||
if err := json.NewDecoder(resVerify.Body).Decode(&updatedConfig); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
if updatedConfig.Schedule.Cooldown != 300 {
|
||||
t.Errorf("expected cooldown 300, got %d", updatedConfig.Schedule.Cooldown)
|
||||
}
|
||||
})
|
||||
|
||||
// 3. Activate alerts
|
||||
t.Run("ActivateAlerts", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, srv.server.URL+"/api/alerts/activate", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Activate again (should be idempotent)
|
||||
res2, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("retry request failed: %v", err)
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res2.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 4. Get active alerts
|
||||
t.Run("GetActiveAlerts", func(t *testing.T) {
|
||||
res, err := http.Get(srv.server.URL + "/api/alerts/active")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 5. Get alert history
|
||||
t.Run("GetAlertHistory", func(t *testing.T) {
|
||||
res, err := http.Get(srv.server.URL + "/api/alerts/history")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Test filters
|
||||
resFilter, err := http.Get(srv.server.URL + "/api/alerts/history?limit=10&severity=critical")
|
||||
if err != nil {
|
||||
t.Fatalf("filter request failed: %v", err)
|
||||
}
|
||||
defer resFilter.Body.Close()
|
||||
if resFilter.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", resFilter.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 6. Clear alert history
|
||||
t.Run("ClearAlertHistory", func(t *testing.T) {
|
||||
// HandleAlerts expects DELETE on /history
|
||||
req, err := http.NewRequest(http.MethodDelete, srv.server.URL+"/api/alerts/history", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 7. Acknowledge Alert (Single)
|
||||
t.Run("AcknowledgeAlert", func(t *testing.T) {
|
||||
body := map[string]string{"id": "test-alert-id"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, srv.server.URL+"/api/alerts/acknowledge", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Should be 404 because alert doesn't exist, but that proves the handler code ran
|
||||
if res.StatusCode != http.StatusNotFound && res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want 404 or 200", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
// 8. Bulk Acknowledge
|
||||
t.Run("BulkAcknowledge", func(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"alertIds": []string{"alert-1", "alert-2"},
|
||||
}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, srv.server.URL+"/api/alerts/bulk/acknowledge", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 9. Bulk Clear
|
||||
t.Run("BulkClear", func(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"alertIds": []string{"alert-1", "alert-2"},
|
||||
}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, srv.server.URL+"/api/alerts/bulk/clear", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
// 10. Incident Timeline
|
||||
t.Run("GetIncidentTimeline", func(t *testing.T) {
|
||||
// Test timeline list by resource
|
||||
res, err := http.Get(srv.server.URL + "/api/alerts/incidents?resource_id=test-node")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Test specific alert timeline
|
||||
res2, err := http.Get(srv.server.URL + "/api/alerts/incidents?alert_id=test-alert")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
// 200 OK (empty/null) or 404 depending on impl. Implementation returns null/empty usually if not found but status 200.
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", res2.StatusCode, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -1,6 +1,227 @@
|
|||
package api
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ai/memory"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/alerts"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/notifications"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
testifymock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type MockAlertManager struct {
|
||||
testifymock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) GetConfig() alerts.AlertConfig {
|
||||
args := m.Called()
|
||||
return args.Get(0).(alerts.AlertConfig)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) UpdateConfig(cfg alerts.AlertConfig) {
|
||||
m.Called(cfg)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) GetActiveAlerts() []alerts.Alert {
|
||||
args := m.Called()
|
||||
return args.Get(0).([]alerts.Alert)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) NotifyExistingAlert(id string) {
|
||||
m.Called(id)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) ClearAlertHistory() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) UnacknowledgeAlert(id string) error {
|
||||
args := m.Called(id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) AcknowledgeAlert(id, user string) error {
|
||||
args := m.Called(id, user)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) ClearAlert(id string) bool {
|
||||
args := m.Called(id)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) GetAlertHistory(limit int) []alerts.Alert {
|
||||
args := m.Called(limit)
|
||||
return args.Get(0).([]alerts.Alert)
|
||||
}
|
||||
|
||||
func (m *MockAlertManager) GetAlertHistorySince(since time.Time, limit int) []alerts.Alert {
|
||||
args := m.Called(since, limit)
|
||||
return args.Get(0).([]alerts.Alert)
|
||||
}
|
||||
|
||||
type MockAlertMonitor struct {
|
||||
testifymock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) GetAlertManager() AlertManager {
|
||||
args := m.Called()
|
||||
return args.Get(0).(AlertManager)
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) GetConfigPersistence() ConfigPersistence {
|
||||
args := m.Called()
|
||||
return args.Get(0).(ConfigPersistence)
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) GetIncidentStore() *memory.IncidentStore {
|
||||
args := m.Called()
|
||||
if store := args.Get(0); store != nil {
|
||||
return store.(*memory.IncidentStore)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) GetNotificationManager() *notifications.NotificationManager {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*notifications.NotificationManager)
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) SyncAlertState() {
|
||||
m.Called()
|
||||
}
|
||||
|
||||
func (m *MockAlertMonitor) GetState() models.StateSnapshot {
|
||||
args := m.Called()
|
||||
return args.Get(0).(models.StateSnapshot)
|
||||
}
|
||||
|
||||
type MockConfigPersistence struct {
|
||||
testifymock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfigPersistence) SaveAlertConfig(cfg alerts.AlertConfig) error {
|
||||
args := m.Called(cfg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Tests
|
||||
func TestGetAlertConfig(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
cfg := alerts.AlertConfig{Enabled: true}
|
||||
mockManager.On("GetConfig").Return(cfg)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/alerts/config", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetAlertConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp alerts.AlertConfig
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.True(t, resp.Enabled)
|
||||
}
|
||||
|
||||
func TestUpdateAlertConfig(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockPersist := new(MockConfigPersistence)
|
||||
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("GetConfigPersistence").Return(mockPersist)
|
||||
mockMonitor.On("GetNotificationManager").Return(¬ifications.NotificationManager{})
|
||||
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
cfg := alerts.AlertConfig{Enabled: true}
|
||||
mockManager.On("UpdateConfig", testifymock.Anything).Return()
|
||||
mockManager.On("GetConfig").Return(cfg)
|
||||
mockPersist.On("SaveAlertConfig", testifymock.Anything).Return(nil)
|
||||
|
||||
body, _ := json.Marshal(cfg)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/config", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.UpdateAlertConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestGetActiveAlerts(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("GetActiveAlerts").Return([]alerts.Alert{{ID: "a1"}})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/alerts/active", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetActiveAlerts(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp []alerts.Alert
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.Len(t, resp, 1)
|
||||
assert.Equal(t, "a1", resp[0].ID)
|
||||
}
|
||||
|
||||
func TestAcknowledgeAlert(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("AcknowledgeAlert", "a1", testifymock.Anything).Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a1/acknowledge", nil)
|
||||
req.SetPathValue("id", "a1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.AcknowledgeAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestClearAlert(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("ClearAlert", "a1").Return(true)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a1/clear", nil)
|
||||
req.SetPathValue("id", "a1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ClearAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestValidateAlertID(t *testing.T) {
|
||||
testCases := []struct {
|
||||
|
|
@ -18,20 +239,545 @@ func TestValidateAlertID(t *testing.T) {
|
|||
{name: "path traversal middle", id: "pve/../secret", valid: false},
|
||||
}
|
||||
|
||||
// Populate the oversized id string once to avoid zero bytes being mistaken for a valid character set.
|
||||
for i := range testCases {
|
||||
if testCases[i].name == "too long" {
|
||||
value := make([]byte, 501)
|
||||
for j := range value {
|
||||
value[j] = 'a'
|
||||
}
|
||||
testCases[i].id = string(value)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if got := validateAlertID(tc.id); got != tc.valid {
|
||||
t.Errorf("validateAlertID(%s) = %v, want %v", tc.name, got, tc.valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestAlertHandlers_SetMonitor(t *testing.T) {
|
||||
mockMonitor1 := new(MockAlertMonitor)
|
||||
mockMonitor2 := new(MockAlertMonitor)
|
||||
h := NewAlertHandlers(mockMonitor1, nil)
|
||||
assert.Equal(t, mockMonitor1, h.monitor)
|
||||
h.SetMonitor(mockMonitor2)
|
||||
assert.Equal(t, mockMonitor2, h.monitor)
|
||||
}
|
||||
|
||||
func TestGetAlertHistory(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("GetAlertHistory", testifymock.Anything).Return([]alerts.Alert{{ID: "h1"}})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/alerts/history?limit=10", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAlertHistory(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp []alerts.Alert
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.Len(t, resp, 1)
|
||||
}
|
||||
|
||||
func TestUnacknowledgeAlert(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("UnacknowledgeAlert", "a1").Return(nil)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a1/unacknowledge", nil)
|
||||
req.SetPathValue("id", "a1")
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestClearAlertHistory(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("ClearAlertHistory").Return(nil).Once()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/history/clear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertHistory(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestAcknowledgeAlertURL_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("AcknowledgeAlert", "a/b", "admin").Return(nil).Once()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a%2Fb/acknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestUnacknowledgeAlertURL_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("UnacknowledgeAlert", "a/b").Return(nil).Once()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a%2Fb/unacknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestClearAlertURL_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("ClearAlert", "a/b").Return(true).Once()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a%2Fb/clear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlert(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestSaveAlertIncidentNote(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
// Create an incident first so RecordNote has something to attach to
|
||||
alert := &alerts.Alert{ID: "a1", Type: "test"}
|
||||
mockStore.RecordAlertFired(alert)
|
||||
|
||||
body := `{"alert_id": "a1", "note": "test note", "user": "admin"}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestBulkAcknowledgeAlerts(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil)
|
||||
mockManager.On("AcknowledgeAlert", "a2", "admin").Return(fmt.Errorf("error"))
|
||||
|
||||
body := `{"alertIds": ["a1", "a2"], "user": "admin"}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bulk/acknowledge", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.BulkAcknowledgeAlerts(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp struct {
|
||||
Results []map[string]interface{} `json:"results"`
|
||||
}
|
||||
json.NewDecoder(w.Body).Decode(&resp)
|
||||
assert.Len(t, resp.Results, 2)
|
||||
}
|
||||
|
||||
func TestHandleAlerts(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("GetConfigPersistence").Return(new(MockConfigPersistence))
|
||||
mockMonitor.On("GetNotificationManager").Return(¬ifications.NotificationManager{})
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
type route struct {
|
||||
method string
|
||||
path string
|
||||
setup func()
|
||||
}
|
||||
|
||||
routes := []route{
|
||||
{"GET", "/api/alerts/active", func() { mockManager.On("GetActiveAlerts").Return([]alerts.Alert{}).Once() }},
|
||||
{"GET", "/api/alerts/history", func() {
|
||||
mockManager.On("GetAlertHistory", mock.MatchedBy(func(int) bool { return true })).Return([]alerts.Alert{}).Once()
|
||||
}},
|
||||
{"GET", "/api/alerts/incidents?alert_id=a1", func() {
|
||||
mockMonitor.On("GetIncidentStore").Return(memory.NewIncidentStore(memory.IncidentStoreConfig{})).Once()
|
||||
}},
|
||||
{"POST", "/api/alerts/incidents/note", func() {
|
||||
store := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
store.RecordAlertFired(&alerts.Alert{ID: "a1", Type: "test"})
|
||||
mockMonitor.On("GetIncidentStore").Return(store).Once()
|
||||
}},
|
||||
{"DELETE", "/api/alerts/history", func() { mockManager.On("ClearAlertHistory").Return(nil).Once() }},
|
||||
{"POST", "/api/alerts/bulk/acknowledge", func() {
|
||||
mockManager.On("AcknowledgeAlert", mock.Anything, mock.Anything).Return(nil)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/bulk/clear", func() {
|
||||
mockManager.On("ClearAlert", mock.Anything).Return(true)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/acknowledge", func() {
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/unacknowledge", func() {
|
||||
mockManager.On("UnacknowledgeAlert", "a1").Return(nil).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/clear", func() {
|
||||
mockManager.On("ClearAlert", "a1").Return(true).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/a1/acknowledge", func() {
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/a1/unacknowledge", func() {
|
||||
mockManager.On("UnacknowledgeAlert", "a1").Return(nil).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
{"POST", "/api/alerts/a1/clear", func() {
|
||||
mockManager.On("ClearAlert", "a1").Return(true).Once()
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, rt := range routes {
|
||||
t.Run(rt.method+"_"+rt.path, func(t *testing.T) {
|
||||
rt.setup()
|
||||
var body []byte
|
||||
if rt.method == "POST" || rt.method == "PUT" || rt.method == "DELETE" {
|
||||
if strings.Contains(rt.path, "bulk") {
|
||||
body = []byte(`{"alertIds": ["a1"]}`)
|
||||
} else if strings.Contains(rt.path, "note") {
|
||||
body = []byte(`{"alert_id": "a1", "note": "test"}`)
|
||||
} else {
|
||||
body = []byte(`{"id": "a1", "user": "admin"}`)
|
||||
}
|
||||
}
|
||||
req := httptest.NewRequest(rt.method, rt.path, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAlerts(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// Test NotFound
|
||||
req := httptest.NewRequest("GET", "/api/alerts/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAlerts(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestBulkClearAlerts(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("ClearAlert", "a1").Return(true)
|
||||
mockManager.On("ClearAlert", "a2").Return(false)
|
||||
|
||||
body := `{"alertIds": ["a1", "a2"]}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bulk/clear", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.BulkClearAlerts(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestAcknowledgeAlertByBody_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(nil)
|
||||
|
||||
body := `{"id": "a1", "user": "admin"}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlertByBody(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestUnacknowledgeAlertByBody_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("UnacknowledgeAlert", "a1").Return(nil)
|
||||
|
||||
body := `{"id": "a1"}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/unacknowledge", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlertByBody(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestClearAlertByBody_Success(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
mockMonitor.On("SyncAlertState").Return()
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
mockManager.On("ClearAlert", "a1").Return(true)
|
||||
|
||||
body := `{"id": "a1"}`
|
||||
req := httptest.NewRequest("POST", "/api/alerts/clear", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertByBody(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestAlertHandlers_ErrorCases(t *testing.T) {
|
||||
mockMonitor := new(MockAlertMonitor)
|
||||
mockManager := new(MockAlertManager)
|
||||
mockMonitor.On("GetAlertManager").Return(mockManager)
|
||||
h := NewAlertHandlers(mockMonitor, nil)
|
||||
|
||||
t.Run("AcknowledgeAlertByBody_InvalidJSON", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(`{invalid`))
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlertByBody_MissingID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(`{"id": ""}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlertByBody_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(`{"id": "bad\x01"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlertByBody_ManagerError", func(t *testing.T) {
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(fmt.Errorf("error")).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/acknowledge", strings.NewReader(`{"id": "a1", "user": "admin"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlertByBody_MissingID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/unacknowledge", strings.NewReader(`{"id": ""}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlertByBody_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/unacknowledge", strings.NewReader(`{"id": "bad\x01"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlertByBody_MissingID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/clear", strings.NewReader(`{"id": ""}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlertByBody_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/clear", strings.NewReader(`{"id": "bad\x01"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlertByBody_NotFound", func(t *testing.T) {
|
||||
mockManager.On("ClearAlert", "unknown").Return(false).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/clear", strings.NewReader(`{"id": "unknown"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertByBody(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
|
||||
t.Run("BulkAcknowledgeAlerts_InvalidJSON", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bulk/acknowledge", strings.NewReader(`{invalid`))
|
||||
w := httptest.NewRecorder()
|
||||
h.BulkAcknowledgeAlerts(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("BulkAcknowledgeAlerts_NoIDs", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bulk/acknowledge", strings.NewReader(`{"alertIds": []}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.BulkAcknowledgeAlerts(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlertByBody_InvalidJSON", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/unacknowledge", strings.NewReader(`{invalid`))
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlertByBody_InvalidJSON", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/clear", strings.NewReader(`{invalid`))
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertByBody(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_NoStore", func(t *testing.T) {
|
||||
mockMonitor2 := new(MockAlertMonitor)
|
||||
mockMonitor2.On("GetIncidentStore").Return(nil)
|
||||
h2 := NewAlertHandlers(mockMonitor2, nil)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
h2.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 503, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_InvalidBody", func(t *testing.T) {
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{invalid`))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_MissingIDs", func(t *testing.T) {
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{"note": "test"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_InvalidAlertID", func(t *testing.T) {
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{"alert_id": "bad\x01", "note": "test"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_MissingNote", func(t *testing.T) {
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{"alert_id": "a1", "note": ""}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("SaveAlertIncidentNote_NotFound", func(t *testing.T) {
|
||||
mockStore := memory.NewIncidentStore(memory.IncidentStoreConfig{})
|
||||
mockMonitor.On("GetIncidentStore").Return(mockStore)
|
||||
// alert_id non-existent in store
|
||||
req := httptest.NewRequest("POST", "/api/alerts/note", strings.NewReader(`{"alert_id": "none", "note": "test"}`))
|
||||
w := httptest.NewRecorder()
|
||||
h.SaveAlertIncidentNote(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlertHistory_Error", func(t *testing.T) {
|
||||
mockManager.On("ClearAlertHistory").Return(errors.New("failed")).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/history/clear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlertHistory(w, req)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlert_InvalidURL", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a/notack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlert_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bad%01/acknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AcknowledgeAlert_Error", func(t *testing.T) {
|
||||
mockManager.On("AcknowledgeAlert", "a1", "admin").Return(errors.New("not found")).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a1/acknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.AcknowledgeAlert(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlert_InvalidURL", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a/notunack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlert_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bad%01/unacknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UnacknowledgeAlert_Error", func(t *testing.T) {
|
||||
mockManager.On("UnacknowledgeAlert", "a1").Return(errors.New("not found")).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a1/unacknowledge", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.UnacknowledgeAlert(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlert_InvalidURL", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/a/notclear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlert_InvalidID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/alerts/bad%01/clear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlert(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
})
|
||||
|
||||
t.Run("ClearAlert_NotFound", func(t *testing.T) {
|
||||
mockManager.On("ClearAlert", "none").Return(false).Once()
|
||||
req := httptest.NewRequest("POST", "/api/alerts/none/clear", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ClearAlert(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,16 @@ package api
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/pkg/audit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type verifyResponse struct {
|
||||
|
|
@ -19,6 +24,8 @@ type testAuditLogger struct {
|
|||
events []audit.Event
|
||||
verifyResult bool
|
||||
queryErr error
|
||||
updateErr error
|
||||
countErr error
|
||||
}
|
||||
|
||||
func (l *testAuditLogger) Log(event audit.Event) error {
|
||||
|
|
@ -42,6 +49,9 @@ func (l *testAuditLogger) Query(filter audit.QueryFilter) ([]audit.Event, error)
|
|||
}
|
||||
|
||||
func (l *testAuditLogger) Count(filter audit.QueryFilter) (int, error) {
|
||||
if l.countErr != nil {
|
||||
return 0, l.countErr
|
||||
}
|
||||
events, err := l.Query(filter)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
@ -62,7 +72,7 @@ func (l *testAuditLogger) GetWebhookURLs() []string {
|
|||
}
|
||||
|
||||
func (l *testAuditLogger) UpdateWebhookURLs(urls []string) error {
|
||||
return nil
|
||||
return l.updateErr
|
||||
}
|
||||
|
||||
type testAuditLoggerNoVerify struct {
|
||||
|
|
@ -226,4 +236,400 @@ func TestHandleVerifyAuditEvent_Failed(t *testing.T) {
|
|||
if resp.Verified {
|
||||
t.Fatalf("expected verified to be false")
|
||||
}
|
||||
|
||||
t.Run("Event not found", func(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
events: []audit.Event{},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/missing/verify", nil)
|
||||
req.SetPathValue("id", "missing")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleVerifyAuditEvent(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
var resp APIError
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "Audit event not found", resp.ErrorMessage)
|
||||
})
|
||||
|
||||
t.Run("Not persistent logger", func(t *testing.T) {
|
||||
setAuditLogger(t, audit.NewConsoleLogger())
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/abc/verify", nil)
|
||||
req.SetPathValue("id", "abc")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleVerifyAuditEvent(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // Console logger returns 200 with available: false
|
||||
})
|
||||
|
||||
t.Run("Missing ID", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit//verify", nil)
|
||||
// Don't set path value
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleVerifyAuditEvent(rec, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
})
|
||||
|
||||
t.Run("Query error", func(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
queryErr: fmt.Errorf("query error"),
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/abc/verify", nil)
|
||||
req.SetPathValue("id", "abc")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleVerifyAuditEvent(rec, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleListAuditEvents(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
events: []audit.Event{{ID: "1", EventType: "login"}, {ID: "2", EventType: "logout"}},
|
||||
})
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
// Test success
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit?event=login", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
if resp["total"].(float64) != 2 {
|
||||
t.Errorf("expected total 2, got %v", resp["total"])
|
||||
}
|
||||
|
||||
// Test parse error for startTime/endTime
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/audit?startTime=invalid&endTime=invalid", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // It just ignores invalid times
|
||||
|
||||
// Test method not allowed
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/audit", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetWebhooks(t *testing.T) {
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
// Test success
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/webhooks", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleGetWebhooks(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
if _, ok := resp["urls"]; !ok {
|
||||
t.Error("expected urls field in response")
|
||||
}
|
||||
|
||||
// Test method not allowed
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/audit/webhooks", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleGetWebhooks(rec, req)
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpdateWebhooks(t *testing.T) {
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
// Test success
|
||||
body := `{"urls": ["https://example.com/webhook"]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/audit/webhooks", strings.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleUpdateWebhooks(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusNoContent, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
// Test invalid URL (loopback)
|
||||
body = `{"urls": ["http://127.0.0.1/webhook"]}`
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/audit/webhooks", strings.NewReader(body))
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleUpdateWebhooks(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d for loopback URL, got %d", http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// Test invalid JSON
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/audit/webhooks", strings.NewReader("invalid"))
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleUpdateWebhooks(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// Test method not allowed
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/audit/webhooks", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleUpdateWebhooks(rec, req)
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
|
||||
// Test update error
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
updateErr: fmt.Errorf("update failed"),
|
||||
})
|
||||
body = `{"urls": ["https://example.com/webhook"]}`
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/audit/webhooks", strings.NewReader(body))
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleUpdateWebhooks(rec, req)
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleExportAuditEvents_NotPersistent(t *testing.T) {
|
||||
setAuditLogger(t, audit.NewConsoleLogger())
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/export", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleExportAuditEvents(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNotImplemented, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAuditSummary_NotPersistent(t *testing.T) {
|
||||
setAuditLogger(t, audit.NewConsoleLogger())
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit/summary", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleAuditSummary(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNotImplemented, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateOrReservedIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
ip string
|
||||
reserved bool
|
||||
}{
|
||||
{"192.168.1.1", true},
|
||||
{"10.0.0.1", true},
|
||||
{"172.16.0.1", true},
|
||||
{"127.0.0.1", true},
|
||||
{"8.8.8.8", false},
|
||||
{"169.254.1.1", true},
|
||||
{"224.0.0.1", true},
|
||||
{"0.0.0.0", true},
|
||||
{"0.255.255.255", true},
|
||||
{"::1", true},
|
||||
{"fe80::1", true},
|
||||
{"ff02::1", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if got := isPrivateOrReservedIP(ip); got != tc.reserved {
|
||||
t.Errorf("isPrivateOrReservedIP(%s) = %v, want %v", tc.ip, got, tc.reserved)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWebhookURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
url string
|
||||
wantErr bool
|
||||
}{
|
||||
{"https://example.com", false},
|
||||
{"http://test.com/hook", false},
|
||||
{"", true},
|
||||
{" ", true},
|
||||
{"://", true},
|
||||
{"ftp://example.com", true},
|
||||
{"https://", true},
|
||||
{"https://localhost", true},
|
||||
{"http://127.0.0.1", true},
|
||||
{"https://192.168.1.100", true},
|
||||
{"https://metadata.google", true},
|
||||
{"https://internal.site", true},
|
||||
{"http://example.local", true},
|
||||
{"https://example.com/path\x7f", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
err := validateWebhookURL(tc.url)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("validateWebhookURL(%s) error = %v, wantErr %v", tc.url, err, tc.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListAuditEvents_Filters(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
events: []audit.Event{{ID: "1", EventType: "login", Success: true}},
|
||||
})
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
// Test with various filters
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit?limit=10&offset=0&startTime=2023-01-01T00:00:00Z&endTime=2024-01-01T00:00:00Z&success=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListAuditEvents_QueryError(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
queryErr: fmt.Errorf("db error"),
|
||||
})
|
||||
handler := NewAuditHandlers()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/audit", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
// Test Count error
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
countErr: fmt.Errorf("count error"),
|
||||
})
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/audit", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.HandleListAuditEvents(rec, req)
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected count error status %d, got %d", http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
}
|
||||
func TestHandleExportAuditEvents(t *testing.T) {
|
||||
oldLogger := audit.GetLogger()
|
||||
defer audit.SetLogger(oldLogger)
|
||||
|
||||
logger := &testAuditLogger{
|
||||
events: []audit.Event{
|
||||
{ID: "1", EventType: "test", Success: true, Timestamp: time.Now()},
|
||||
},
|
||||
}
|
||||
audit.SetLogger(logger)
|
||||
|
||||
h := NewAuditHandlers()
|
||||
|
||||
t.Run("JSON format", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/audit/export?format=json", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleExportAuditEvents(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type"))
|
||||
assert.Contains(t, w.Header().Get("Content-Disposition"), "attachment; filename=audit-log-")
|
||||
assert.Equal(t, "1", w.Header().Get("X-Event-Count"))
|
||||
})
|
||||
|
||||
t.Run("CSV format", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/audit/export?format=csv", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleExportAuditEvents(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "text/csv; charset=utf-8", w.Header().Get("Content-Type"))
|
||||
assert.Contains(t, w.Header().Get("Content-Disposition"), "attachment; filename=audit-log-")
|
||||
})
|
||||
|
||||
t.Run("With filters", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/audit/export?event=test&user=admin&startTime=2026-01-01T00:00:00Z&endTime=2026-12-31T23:59:59Z&success=true&verify=true", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleExportAuditEvents(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("Method not allowed", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/audit/export", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleExportAuditEvents(w, req)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
})
|
||||
|
||||
t.Run("Export error", func(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
queryErr: fmt.Errorf("query error"),
|
||||
})
|
||||
req := httptest.NewRequest("GET", "/api/audit/export", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleExportAuditEvents(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleAuditSummary(t *testing.T) {
|
||||
oldLogger := audit.GetLogger()
|
||||
defer audit.SetLogger(oldLogger)
|
||||
|
||||
logger := &testAuditLogger{
|
||||
events: []audit.Event{
|
||||
{ID: "1", EventType: "login", Success: true, Timestamp: time.Now()},
|
||||
{ID: "2", EventType: "login", Success: false, Timestamp: time.Now()},
|
||||
},
|
||||
}
|
||||
audit.SetLogger(logger)
|
||||
|
||||
h := NewAuditHandlers()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/audit/summary?verify=true", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAuditSummary(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||
|
||||
var summary audit.ExportSummary
|
||||
err := json.NewDecoder(w.Body).Decode(&summary)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, summary.TotalEvents)
|
||||
assert.Equal(t, 1, summary.SuccessCount)
|
||||
assert.Equal(t, 1, summary.FailureCount)
|
||||
})
|
||||
|
||||
t.Run("With filters", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/audit/summary?event=login&user=admin&startTime=2026-01-01T00:00:00Z&endTime=2026-12-31T23:59:59Z", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAuditSummary(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("Method not allowed", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/audit/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAuditSummary(w, req)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
|
||||
})
|
||||
|
||||
t.Run("Summary error", func(t *testing.T) {
|
||||
setAuditLogger(t, &testAuditLogger{
|
||||
queryErr: fmt.Errorf("query error"),
|
||||
})
|
||||
req := httptest.NewRequest("GET", "/api/audit/summary", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleAuditSummary(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
97
internal/api/frontend_embed_test.go
Normal file
97
internal/api/frontend_embed_test.go
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func resetDevProxy() {
|
||||
devProxyOnce = sync.Once{}
|
||||
devProxy = nil
|
||||
devProxyErr = nil
|
||||
}
|
||||
|
||||
func writeFile(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
path := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFrontendFSOverride(t *testing.T) {
|
||||
resetDevProxy()
|
||||
dir := t.TempDir()
|
||||
writeFile(t, dir, "index.html", "<html>ok</html>")
|
||||
t.Setenv("PULSE_FRONTEND_DIR", dir)
|
||||
|
||||
fsys, err := getFrontendFS()
|
||||
if err != nil {
|
||||
t.Fatalf("getFrontendFS error: %v", err)
|
||||
}
|
||||
|
||||
f, err := fsys.Open("index.html")
|
||||
if err != nil {
|
||||
t.Fatalf("open index: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
t.Fatalf("read index: %v", err)
|
||||
}
|
||||
if string(data) != "<html>ok</html>" {
|
||||
t.Fatalf("unexpected content: %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeFrontendHandler(t *testing.T) {
|
||||
resetDevProxy()
|
||||
dir := t.TempDir()
|
||||
writeFile(t, dir, "index.html", "<html>index</html>")
|
||||
writeFile(t, dir, "app-123.js", "console.log('ok');")
|
||||
t.Setenv("PULSE_FRONTEND_DIR", dir)
|
||||
t.Setenv("FRONTEND_DEV_SERVER", "")
|
||||
|
||||
handler := serveFrontendHandler()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handler(rec, req)
|
||||
if rec.Code != http.StatusOK || !strings.Contains(rec.Body.String(), "index") {
|
||||
t.Fatalf("unexpected root response: %d %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if rec.Header().Get("Cache-Control") == "" {
|
||||
t.Fatal("expected cache headers for index")
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/app-123.js", nil)
|
||||
handler(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected asset response: %d", rec.Code)
|
||||
}
|
||||
if !strings.Contains(rec.Header().Get("Cache-Control"), "immutable") {
|
||||
t.Fatalf("expected immutable cache header, got %s", rec.Header().Get("Cache-Control"))
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/missing", nil)
|
||||
handler(rec, req)
|
||||
if rec.Code != http.StatusOK || !strings.Contains(rec.Body.String(), "index") {
|
||||
t.Fatalf("expected SPA fallback, got %d %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
handler(rec, req)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404 for api path, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
119
internal/api/metadata_handlers_test.go
Normal file
119
internal/api/metadata_handlers_test.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestGuestMetadataHandler(t *testing.T) {
|
||||
handler := NewGuestMetadataHandler(t.TempDir())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/guests/metadata", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.HandleGetMetadata(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
var all map[string]config.GuestMetadata
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &all); err != nil {
|
||||
t.Fatalf("decode all guests: %v", err)
|
||||
}
|
||||
if len(all) != 0 {
|
||||
t.Fatalf("expected empty metadata, got %v", all)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/guests/metadata/", strings.NewReader(`{}`))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleUpdateMetadata(resp, req)
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected bad request, got %d", resp.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/guests/metadata/100", strings.NewReader(`{"customUrl":"ftp://example.com"}`))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleUpdateMetadata(resp, req)
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected bad request, got %d", resp.Code)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), "http:// or https://") {
|
||||
t.Fatalf("unexpected error: %s", resp.Body.String())
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/guests/metadata/100", strings.NewReader(`{"customUrl":"https://example.com","description":"desc"}`))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleUpdateMetadata(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
var meta config.GuestMetadata
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &meta); err != nil {
|
||||
t.Fatalf("decode guest metadata: %v", err)
|
||||
}
|
||||
if meta.CustomURL != "https://example.com" {
|
||||
t.Fatalf("unexpected metadata: %+v", meta)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/guests/metadata/100", nil)
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleGetMetadata(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &meta); err != nil {
|
||||
t.Fatalf("decode guest metadata: %v", err)
|
||||
}
|
||||
if meta.CustomURL != "https://example.com" {
|
||||
t.Fatalf("unexpected metadata: %+v", meta)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/guests/metadata/100", nil)
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleDeleteMetadata(resp, req)
|
||||
if resp.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostMetadataHandler(t *testing.T) {
|
||||
handler := NewHostMetadataHandler(t.TempDir())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/hosts/metadata", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.HandleGetMetadata(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
var all map[string]config.HostMetadata
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &all); err != nil {
|
||||
t.Fatalf("decode all hosts: %v", err)
|
||||
}
|
||||
if len(all) != 0 {
|
||||
t.Fatalf("expected empty metadata, got %v", all)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/hosts/metadata/host1", strings.NewReader(`{"customUrl":"http://host.local"}`))
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleUpdateMetadata(resp, req)
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
var meta config.HostMetadata
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &meta); err != nil {
|
||||
t.Fatalf("decode host metadata: %v", err)
|
||||
}
|
||||
if meta.CustomURL != "http://host.local" {
|
||||
t.Fatalf("unexpected metadata: %+v", meta)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/hosts/metadata/host1", nil)
|
||||
resp = httptest.NewRecorder()
|
||||
handler.HandleDeleteMetadata(resp, req)
|
||||
if resp.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
}
|
||||
51
internal/api/metadata_provider_test.go
Normal file
51
internal/api/metadata_provider_test.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestMetadataProvider_SetURLs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
guestStore := config.NewGuestMetadataStore(dir, nil)
|
||||
dockerStore := config.NewDockerMetadataStore(dir, nil)
|
||||
hostStore := config.NewHostMetadataStore(dir, nil)
|
||||
|
||||
provider := NewMetadataProvider(guestStore, dockerStore, hostStore)
|
||||
|
||||
if err := provider.SetGuestURL("guest1", "https://guest"); err != nil {
|
||||
t.Fatalf("SetGuestURL error: %v", err)
|
||||
}
|
||||
if got := guestStore.Get("guest1"); got == nil || got.CustomURL != "https://guest" {
|
||||
t.Fatalf("unexpected guest metadata: %+v", got)
|
||||
}
|
||||
|
||||
if err := provider.SetDockerURL("ctr1", "https://docker"); err != nil {
|
||||
t.Fatalf("SetDockerURL error: %v", err)
|
||||
}
|
||||
if got := dockerStore.Get("ctr1"); got == nil || got.CustomURL != "https://docker" {
|
||||
t.Fatalf("unexpected docker metadata: %+v", got)
|
||||
}
|
||||
|
||||
if err := provider.SetHostURL("host1", "https://host"); err != nil {
|
||||
t.Fatalf("SetHostURL error: %v", err)
|
||||
}
|
||||
if got := hostStore.Get("host1"); got == nil || got.CustomURL != "https://host" {
|
||||
t.Fatalf("unexpected host metadata: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataProvider_SetURLMissingStore(t *testing.T) {
|
||||
provider := NewMetadataProvider(nil, nil, nil)
|
||||
|
||||
if err := provider.SetGuestURL("guest1", "https://guest"); err == nil {
|
||||
t.Fatal("expected guest store error")
|
||||
}
|
||||
if err := provider.SetDockerURL("ctr1", "https://docker"); err == nil {
|
||||
t.Fatal("expected docker store error")
|
||||
}
|
||||
if err := provider.SetHostURL("host1", "https://host"); err == nil {
|
||||
t.Fatal("expected host store error")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,16 @@
|
|||
package api
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/notifications"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestRedactSecretsFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
|
@ -139,3 +149,478 @@ func TestRedactSecretsFromURL(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MockNotificationMonitor struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationMonitor) GetNotificationManager() NotificationManager {
|
||||
args := m.Called()
|
||||
return args.Get(0).(NotificationManager)
|
||||
}
|
||||
|
||||
func (m *MockNotificationMonitor) GetConfigPersistence() NotificationConfigPersistence {
|
||||
args := m.Called()
|
||||
return args.Get(0).(NotificationConfigPersistence)
|
||||
}
|
||||
|
||||
func (m *MockNotificationMonitor) GetState() models.StateSnapshot {
|
||||
args := m.Called()
|
||||
return args.Get(0).(models.StateSnapshot)
|
||||
}
|
||||
|
||||
type MockNotificationManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) GetEmailConfig() notifications.EmailConfig {
|
||||
args := m.Called()
|
||||
return args.Get(0).(notifications.EmailConfig)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SetEmailConfig(cfg notifications.EmailConfig) {
|
||||
m.Called(cfg)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) GetAppriseConfig() notifications.AppriseConfig {
|
||||
args := m.Called()
|
||||
return args.Get(0).(notifications.AppriseConfig)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SetAppriseConfig(cfg notifications.AppriseConfig) {
|
||||
m.Called(cfg)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) GetWebhooks() []notifications.WebhookConfig {
|
||||
args := m.Called()
|
||||
return args.Get(0).([]notifications.WebhookConfig)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) ValidateWebhookURL(url string) error {
|
||||
args := m.Called(url)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) AddWebhook(w notifications.WebhookConfig) {
|
||||
m.Called(w)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) UpdateWebhook(id string, w notifications.WebhookConfig) error {
|
||||
args := m.Called(id, w)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) DeleteWebhook(id string) error {
|
||||
args := m.Called(id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SendTestWebhook(w notifications.WebhookConfig) error {
|
||||
args := m.Called(w)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SendTestNotificationWithConfig(method string, cfg *notifications.EmailConfig, nodeInfo *notifications.TestNodeInfo) error {
|
||||
args := m.Called(method, cfg, nodeInfo)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SendTestAppriseWithConfig(cfg notifications.AppriseConfig) error {
|
||||
args := m.Called(cfg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) SendTestNotification(method string) error {
|
||||
args := m.Called(method)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) GetWebhookHistory() []notifications.WebhookDelivery {
|
||||
args := m.Called()
|
||||
return args.Get(0).([]notifications.WebhookDelivery)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) TestEnhancedWebhook(w notifications.EnhancedWebhookConfig) (int, string, error) {
|
||||
args := m.Called(w)
|
||||
return args.Int(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockNotificationManager) GetQueueStats() (map[string]int, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(map[string]int), args.Error(1)
|
||||
}
|
||||
|
||||
type MockNotificationConfigPersistence struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationConfigPersistence) SaveEmailConfig(cfg notifications.EmailConfig) error {
|
||||
args := m.Called(cfg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationConfigPersistence) SaveAppriseConfig(cfg notifications.AppriseConfig) error {
|
||||
args := m.Called(cfg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationConfigPersistence) SaveWebhooks(w []notifications.WebhookConfig) error {
|
||||
args := m.Called(w)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockNotificationConfigPersistence) IsEncryptionEnabled() bool {
|
||||
args := m.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func TestNotificationHandlers(t *testing.T) {
|
||||
mockMonitor := new(MockNotificationMonitor)
|
||||
mockManager := new(MockNotificationManager)
|
||||
mockPersistence := new(MockNotificationConfigPersistence)
|
||||
|
||||
mockMonitor.On("GetNotificationManager").Return(mockManager)
|
||||
mockMonitor.On("GetConfigPersistence").Return(mockPersistence)
|
||||
|
||||
h := NewNotificationHandlers(mockMonitor)
|
||||
|
||||
t.Run("GetEmailConfig", func(t *testing.T) {
|
||||
cfg := notifications.EmailConfig{
|
||||
Enabled: true,
|
||||
SMTPHost: "smtp.example.com",
|
||||
Password: "password123",
|
||||
}
|
||||
mockManager.On("GetEmailConfig").Return(cfg).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/notifications/email", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetEmailConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp notifications.EmailConfig
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "smtp.example.com", resp.SMTPHost)
|
||||
assert.Empty(t, resp.Password) // Should be redacted
|
||||
})
|
||||
|
||||
t.Run("UpdateEmailConfig", func(t *testing.T) {
|
||||
cfg := notifications.EmailConfig{
|
||||
Enabled: true,
|
||||
SMTPHost: "smtp.example.com",
|
||||
Password: "newpassword",
|
||||
}
|
||||
mockManager.On("SetEmailConfig", mock.Anything).Return().Once()
|
||||
mockPersistence.On("SaveEmailConfig", mock.Anything).Return(nil).Once()
|
||||
|
||||
body, _ := json.Marshal(cfg)
|
||||
req := httptest.NewRequest("PUT", "/api/notifications/email", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateEmailConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
mockManager.AssertExpectations(t)
|
||||
mockPersistence.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("GetWebhooks", func(t *testing.T) {
|
||||
webhooks := []notifications.WebhookConfig{
|
||||
{
|
||||
ID: "wh1",
|
||||
Name: "Test Webhook",
|
||||
URL: "https://example.com",
|
||||
Headers: map[string]string{"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
mockManager.On("GetWebhooks").Return(webhooks).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/notifications/webhooks", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetWebhooks(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp []map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, 1, len(resp))
|
||||
assert.Equal(t, "wh1", resp[0]["id"])
|
||||
headers := resp[0]["headers"].(map[string]interface{})
|
||||
assert.Equal(t, "***REDACTED***", headers["Authorization"])
|
||||
})
|
||||
|
||||
t.Run("CreateWebhook", func(t *testing.T) {
|
||||
webhook := notifications.WebhookConfig{
|
||||
Name: "New Webhook",
|
||||
URL: "https://example.com/new",
|
||||
}
|
||||
mockManager.On("ValidateWebhookURL", "https://example.com/new").Return(nil).Once()
|
||||
mockManager.On("AddWebhook", mock.Anything).Return().Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
|
||||
body, _ := json.Marshal(webhook)
|
||||
req := httptest.NewRequest("POST", "/api/notifications/webhooks", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.CreateWebhook(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("GetNotificationHealth", func(t *testing.T) {
|
||||
stats := map[string]int{
|
||||
"pending": 1,
|
||||
"sending": 2,
|
||||
"sent": 10,
|
||||
"failed": 0,
|
||||
"dlq": 0,
|
||||
}
|
||||
mockManager.On("GetQueueStats").Return(stats, nil).Once()
|
||||
mockManager.On("GetEmailConfig").Return(notifications.EmailConfig{}).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("IsEncryptionEnabled").Return(true).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/notifications/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetNotificationHealth(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
queue := resp["queue"].(map[string]interface{})
|
||||
assert.Equal(t, float64(1), queue["pending"])
|
||||
assert.Equal(t, true, queue["healthy"])
|
||||
})
|
||||
|
||||
t.Run("GetAppriseConfig", func(t *testing.T) {
|
||||
cfg := notifications.AppriseConfig{Enabled: true}
|
||||
mockManager.On("GetAppriseConfig").Return(cfg).Once()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/notifications/apprise", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetAppriseConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UpdateAppriseConfig", func(t *testing.T) {
|
||||
cfg := notifications.AppriseConfig{Enabled: true}
|
||||
mockManager.On("SetAppriseConfig", mock.Anything).Return().Once()
|
||||
mockPersistence.On("SaveAppriseConfig", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("GetAppriseConfig").Return(cfg).Once()
|
||||
|
||||
body, _ := json.Marshal(cfg)
|
||||
req := httptest.NewRequest("PUT", "/api/notifications/apprise", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateAppriseConfig(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UpdateWebhook", func(t *testing.T) {
|
||||
webhook := notifications.WebhookConfig{ID: "wh1", Name: "Updated"}
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{{ID: "wh1"}}).Once()
|
||||
mockManager.On("ValidateWebhookURL", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("UpdateWebhook", "wh1", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{webhook}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
|
||||
body, _ := json.Marshal(webhook)
|
||||
req := httptest.NewRequest("PUT", "/api/notifications/webhooks/wh1", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateWebhook(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("DeleteWebhook", func(t *testing.T) {
|
||||
mockManager.On("DeleteWebhook", "wh1").Return(nil).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/notifications/webhooks/wh1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.DeleteWebhook(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("GetWebhookTemplates", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/notifications/webhooks/templates", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetWebhookTemplates(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("GetWebhookHistory", func(t *testing.T) {
|
||||
mockManager.On("GetWebhookHistory").Return([]notifications.WebhookDelivery{}).Once()
|
||||
req := httptest.NewRequest("GET", "/api/notifications/webhooks/history", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetWebhookHistory(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("GetEmailProviders", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/notifications/email/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetEmailProviders(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("HandleNotifications_Router", func(t *testing.T) {
|
||||
routes := []struct {
|
||||
method string
|
||||
path string
|
||||
setup func()
|
||||
}{
|
||||
{"GET", "/api/notifications/email", func() { mockManager.On("GetEmailConfig").Return(notifications.EmailConfig{}).Once() }},
|
||||
{"PUT", "/api/notifications/email", func() {
|
||||
mockManager.On("GetEmailConfig").Return(notifications.EmailConfig{}).Once()
|
||||
mockManager.On("SetEmailConfig", mock.Anything).Return().Once()
|
||||
mockPersistence.On("SaveEmailConfig", mock.Anything).Return(nil).Once()
|
||||
}},
|
||||
{"GET", "/api/notifications/apprise", func() { mockManager.On("GetAppriseConfig").Return(notifications.AppriseConfig{}).Once() }},
|
||||
{"PUT", "/api/notifications/apprise", func() {
|
||||
mockManager.On("SetAppriseConfig", mock.Anything).Return().Once()
|
||||
mockPersistence.On("SaveAppriseConfig", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("GetAppriseConfig").Return(notifications.AppriseConfig{}).Once()
|
||||
}},
|
||||
{"GET", "/api/notifications/webhooks", func() { mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once() }},
|
||||
{"POST", "/api/notifications/webhooks", func() {
|
||||
mockManager.On("ValidateWebhookURL", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("AddWebhook", mock.Anything).Return().Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
}},
|
||||
{"POST", "/api/notifications/webhooks/test", func() {
|
||||
mockManager.On("TestEnhancedWebhook", mock.Anything).Return(200, "OK", nil).Once()
|
||||
}},
|
||||
{"PUT", "/api/notifications/webhooks/wh1", func() {
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{{ID: "wh1"}}).Once()
|
||||
mockManager.On("ValidateWebhookURL", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("UpdateWebhook", "wh1", mock.Anything).Return(nil).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{{ID: "wh1"}}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
}},
|
||||
{"DELETE", "/api/notifications/webhooks/wh1", func() {
|
||||
mockManager.On("DeleteWebhook", "wh1").Return(nil).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
}},
|
||||
{"GET", "/api/notifications/webhook-templates", func() {}},
|
||||
{"GET", "/api/notifications/webhook-history", func() { mockManager.On("GetWebhookHistory").Return([]notifications.WebhookDelivery{}).Once() }},
|
||||
{"GET", "/api/notifications/email-providers", func() {}},
|
||||
{"GET", "/api/notifications/health", func() {
|
||||
mockManager.On("GetQueueStats").Return(map[string]int{}, nil).Once()
|
||||
mockManager.On("GetEmailConfig").Return(notifications.EmailConfig{}).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{}).Once()
|
||||
mockPersistence.On("IsEncryptionEnabled").Return(true).Once()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
t.Run(route.method+"_"+route.path, func(t *testing.T) {
|
||||
route.setup()
|
||||
var body []byte
|
||||
if route.method == "POST" || route.method == "PUT" {
|
||||
body = []byte("{}")
|
||||
}
|
||||
req := httptest.NewRequest(route.method, route.path, bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleNotifications(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// Test 404
|
||||
req := httptest.NewRequest("GET", "/api/notifications/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleNotifications(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
})
|
||||
|
||||
t.Run("TestNotification", func(t *testing.T) {
|
||||
mockMonitor.On("GetState").Return(models.StateSnapshot{}).Once()
|
||||
mockManager.On("SendTestNotification", "email").Return(nil).Once()
|
||||
body, _ := json.Marshal(map[string]string{"method": "email"})
|
||||
req := httptest.NewRequest("POST", "/api/notifications/test", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.TestNotification(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("TestNotification_Webhook", func(t *testing.T) {
|
||||
mockMonitor.On("GetState").Return(models.StateSnapshot{}).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{{ID: "wh1"}}).Once()
|
||||
mockManager.On("SendTestWebhook", mock.Anything).Return(nil).Once()
|
||||
body, _ := json.Marshal(map[string]string{"method": "webhook", "webhookId": "wh1"})
|
||||
req := httptest.NewRequest("POST", "/api/notifications/test", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.TestNotification(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("TestWebhook", func(t *testing.T) {
|
||||
mockManager.On("TestEnhancedWebhook", mock.Anything).Return(200, "OK", nil).Once()
|
||||
body, _ := json.Marshal(map[string]string{"url": "https://example.com/test", "service": "ntfy"})
|
||||
req := httptest.NewRequest("POST", "/api/notifications/webhooks/test", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.TestWebhook(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("TestNotification_EmailWithConfig", func(t *testing.T) {
|
||||
mockMonitor.On("GetState").Return(models.StateSnapshot{}).Once()
|
||||
mockManager.On("SendTestNotificationWithConfig", "email", mock.Anything, mock.Anything).Return(nil).Once()
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"method": "email",
|
||||
"config": notifications.EmailConfig{Enabled: true, SMTPHost: "smtp.example.com", Password: "test"},
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/api/notifications/test", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.TestNotification(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("TestNotification_AppriseWithConfig", func(t *testing.T) {
|
||||
mockMonitor.On("GetState").Return(models.StateSnapshot{}).Once()
|
||||
mockManager.On("SendTestAppriseWithConfig", mock.Anything).Return(nil).Once()
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"method": "apprise",
|
||||
"config": notifications.AppriseConfig{Enabled: true, APIKey: "test"},
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/api/notifications/test", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.TestNotification(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
|
||||
t.Run("UpdateWebhook_PreserveRedacted", func(t *testing.T) {
|
||||
existing := notifications.WebhookConfig{
|
||||
ID: "wh1",
|
||||
Headers: map[string]string{"Auth": "secret"},
|
||||
CustomFields: map[string]string{"Key": "value"},
|
||||
}
|
||||
updated := notifications.WebhookConfig{
|
||||
ID: "wh1",
|
||||
URL: "https://example.com/new",
|
||||
Headers: map[string]string{"Auth": "***REDACTED***"},
|
||||
CustomFields: map[string]string{"Key": "***REDACTED***"},
|
||||
}
|
||||
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{existing}).Once()
|
||||
mockManager.On("ValidateWebhookURL", "https://example.com/new").Return(nil).Once()
|
||||
mockManager.On("UpdateWebhook", "wh1", mock.MatchedBy(func(w notifications.WebhookConfig) bool {
|
||||
return w.Headers["Auth"] == "secret" && w.CustomFields["Key"] == "value"
|
||||
})).Return(nil).Once()
|
||||
mockManager.On("GetWebhooks").Return([]notifications.WebhookConfig{updated}).Once()
|
||||
mockPersistence.On("SaveWebhooks", mock.Anything).Return(nil).Once()
|
||||
|
||||
body, _ := json.Marshal(updated)
|
||||
req := httptest.NewRequest("PUT", "/api/notifications/webhooks/wh1", bytes.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
h.UpdateWebhook(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
58
internal/api/profile_suggestions_test.go
Normal file
58
internal/api/profile_suggestions_test.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseAISuggestion(t *testing.T) {
|
||||
payload := "Here is your suggestion:\n```json\n" +
|
||||
`{"name":"Media Server","config":{"enable_docker":true},"rationale":["Uses Docker"]}` +
|
||||
"\n```\nThanks!"
|
||||
|
||||
suggestion, err := parseAISuggestion(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("parseAISuggestion error: %v", err)
|
||||
}
|
||||
if suggestion.Name != "Media Server" {
|
||||
t.Fatalf("unexpected name: %s", suggestion.Name)
|
||||
}
|
||||
if suggestion.Description == "" {
|
||||
t.Fatal("expected default description")
|
||||
}
|
||||
if suggestion.Config["enable_docker"] != true {
|
||||
t.Fatalf("unexpected config: %+v", suggestion.Config)
|
||||
}
|
||||
if len(suggestion.Rationale) != 1 {
|
||||
t.Fatalf("unexpected rationale: %+v", suggestion.Rationale)
|
||||
}
|
||||
|
||||
payload = `{"name":"Test","description":"Has braces { in text"}`
|
||||
suggestion, err = parseAISuggestion(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("parseAISuggestion error: %v", err)
|
||||
}
|
||||
if suggestion.Name != "Test" || suggestion.Description != "Has braces { in text" {
|
||||
t.Fatalf("unexpected suggestion: %+v", suggestion)
|
||||
}
|
||||
|
||||
if _, err := parseAISuggestion("no json here"); err == nil {
|
||||
t.Fatal("expected error without JSON")
|
||||
}
|
||||
if _, err := parseAISuggestion("{\"name\":"); err == nil {
|
||||
t.Fatal("expected error for incomplete JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigSchemaDoc(t *testing.T) {
|
||||
doc := buildConfigSchemaDoc()
|
||||
if doc == "" {
|
||||
t.Fatal("expected schema doc")
|
||||
}
|
||||
if !strings.Contains(doc, "- interval (duration string") {
|
||||
t.Fatalf("expected interval key in doc:\n%s", doc)
|
||||
}
|
||||
if !strings.Contains(doc, "enable_docker") {
|
||||
t.Fatalf("expected enable_docker key in doc:\n%s", doc)
|
||||
}
|
||||
}
|
||||
215
internal/api/saml_service_test.go
Normal file
215
internal/api/saml_service_test.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte, key *rsa.PrivateKey) {
|
||||
t.Helper()
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("create cert: %v", err)
|
||||
}
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
return certPEM, keyPEM, priv
|
||||
}
|
||||
|
||||
func TestParseIDPMetadataXML(t *testing.T) {
|
||||
xml := `<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp-1">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp/sso"/>
|
||||
</IDPSSODescriptor>
|
||||
</EntityDescriptor>`
|
||||
|
||||
metadata, err := parseIDPMetadataXML([]byte(xml))
|
||||
if err != nil {
|
||||
t.Fatalf("parse metadata: %v", err)
|
||||
}
|
||||
if metadata.EntityID != "idp-1" {
|
||||
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
|
||||
}
|
||||
|
||||
wrapped := `<?xml version="1.0"?>
|
||||
<EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
|
||||
<EntityDescriptor entityID="idp-2">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"></IDPSSODescriptor>
|
||||
</EntityDescriptor>
|
||||
</EntitiesDescriptor>`
|
||||
metadata, err = parseIDPMetadataXML([]byte(wrapped))
|
||||
if err != nil {
|
||||
t.Fatalf("parse wrapped metadata: %v", err)
|
||||
}
|
||||
if metadata.EntityID != "idp-2" {
|
||||
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
|
||||
}
|
||||
|
||||
if _, err := parseIDPMetadataXML([]byte("<bad")); err == nil {
|
||||
t.Fatal("expected error for invalid xml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildManualMetadataAndCertificate(t *testing.T) {
|
||||
cfg := &config.SAMLProviderConfig{}
|
||||
service := &SAMLService{config: cfg}
|
||||
if _, err := service.buildManualMetadata(); err == nil {
|
||||
t.Fatal("expected error for missing SSO URL")
|
||||
}
|
||||
|
||||
cfg.IDPSSOURL = "http://idp/sso"
|
||||
cfg.IDPSLOUrl = "http://idp/slo"
|
||||
cfg.IDPIssuer = "issuer"
|
||||
certPEM, _, _ := generateTestCert(t)
|
||||
cfg.IDPCertificate = string(certPEM)
|
||||
|
||||
metadata, err := service.buildManualMetadata()
|
||||
if err != nil {
|
||||
t.Fatalf("build metadata: %v", err)
|
||||
}
|
||||
if metadata.EntityID != "issuer" {
|
||||
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
|
||||
}
|
||||
if len(metadata.IDPSSODescriptors) == 0 || len(metadata.IDPSSODescriptors[0].SingleLogoutServices) == 0 {
|
||||
t.Fatal("expected SLO service in metadata")
|
||||
}
|
||||
if len(metadata.IDPSSODescriptors[0].KeyDescriptors) == 0 {
|
||||
t.Fatal("expected key descriptor with certificate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSPCredentials(t *testing.T) {
|
||||
cfg := &config.SAMLProviderConfig{}
|
||||
service := &SAMLService{config: cfg}
|
||||
if _, _, err := service.loadSPCredentials(); err == nil {
|
||||
t.Fatal("expected error for missing cert/key")
|
||||
}
|
||||
|
||||
certPEM, keyPEM, _ := generateTestCert(t)
|
||||
cfg.SPCertificate = string(certPEM)
|
||||
if _, _, err := service.loadSPCredentials(); err == nil {
|
||||
t.Fatal("expected error for missing key")
|
||||
}
|
||||
cfg.SPCertificate = "bad"
|
||||
cfg.SPPrivateKey = "bad"
|
||||
if _, _, err := service.loadSPCredentials(); err == nil {
|
||||
t.Fatal("expected error for invalid pem")
|
||||
}
|
||||
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate ec key: %v", err)
|
||||
}
|
||||
pkcs8, err := x509.MarshalPKCS8PrivateKey(ecKey)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal pkcs8: %v", err)
|
||||
}
|
||||
cfg.SPCertificate = string(certPEM)
|
||||
cfg.SPPrivateKey = string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8}))
|
||||
if _, _, err := service.loadSPCredentials(); err == nil {
|
||||
t.Fatal("expected error for non-rsa key")
|
||||
}
|
||||
|
||||
cfg.SPPrivateKey = string(keyPEM)
|
||||
cert, key, err := service.loadSPCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("load credentials: %v", err)
|
||||
}
|
||||
if cert == nil || key == nil {
|
||||
t.Fatal("expected cert and key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSAMLServiceBasicFlows(t *testing.T) {
|
||||
certPEM, _, _ := generateTestCert(t)
|
||||
cfg := &config.SAMLProviderConfig{
|
||||
IDPMetadataXML: `<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp/sso"/>
|
||||
<SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://idp/slo"/>
|
||||
</IDPSSODescriptor>
|
||||
</EntityDescriptor>`,
|
||||
IDPCertificate: string(certPEM),
|
||||
}
|
||||
|
||||
service, err := NewSAMLService(context.Background(), "idp", cfg, "http://localhost:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("new service: %v", err)
|
||||
}
|
||||
|
||||
url, err := service.MakeAuthRequest("")
|
||||
if err != nil || !strings.Contains(url, "SAMLRequest") {
|
||||
t.Fatalf("unexpected auth url: %v %s", err, url)
|
||||
}
|
||||
|
||||
if _, err := service.GetMetadata(); err != nil {
|
||||
t.Fatalf("metadata error: %v", err)
|
||||
}
|
||||
|
||||
logoutURL, err := service.MakeLogoutRequest("user", "sess")
|
||||
if err != nil || !strings.Contains(logoutURL, "SAMLRequest") {
|
||||
t.Fatalf("unexpected logout url: %v %s", err, logoutURL)
|
||||
}
|
||||
|
||||
service = &SAMLService{config: &config.SAMLProviderConfig{}}
|
||||
if _, err := service.MakeAuthRequest(""); err == nil {
|
||||
t.Fatal("expected error when sp missing")
|
||||
}
|
||||
if _, err := service.GetMetadata(); err == nil {
|
||||
t.Fatal("expected error when sp missing")
|
||||
}
|
||||
if _, err := service.MakeLogoutRequest("user", "sess"); err == nil {
|
||||
t.Fatal("expected error when sp missing")
|
||||
}
|
||||
if err := service.RefreshMetadata(context.Background()); err == nil {
|
||||
t.Fatal("expected refresh error without url")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchMetadataFromURL(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="idp-url">
|
||||
<IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"></IDPSSODescriptor>
|
||||
</EntityDescriptor>`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.SAMLProviderConfig{IDPMetadataURL: server.URL}
|
||||
service := &SAMLService{config: cfg, httpClient: newSAMLHTTPClient()}
|
||||
metadata, err := service.fetchIDPMetadataFromURL(context.Background(), server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("fetch metadata: %v", err)
|
||||
}
|
||||
if metadata.EntityID != "idp-url" {
|
||||
t.Fatalf("unexpected entity id: %s", metadata.EntityID)
|
||||
}
|
||||
}
|
||||
48
internal/api/sensor_proxy_gate_test.go
Normal file
48
internal/api/sensor_proxy_gate_test.go
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestSensorProxyGate(t *testing.T) {
|
||||
r := &Router{config: &config.Config{EnableSensorProxy: true}}
|
||||
if !r.isSensorProxyEnabled() {
|
||||
t.Fatal("expected sensor proxy enabled")
|
||||
}
|
||||
|
||||
allowed := r.requireSensorProxyEnabled(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sensor", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
allowed(resp, req)
|
||||
if resp.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
|
||||
r.config.EnableSensorProxy = false
|
||||
denied := r.requireSensorProxyEnabled(func(http.ResponseWriter, *http.Request) {
|
||||
t.Fatal("handler should not be called")
|
||||
})
|
||||
resp = httptest.NewRecorder()
|
||||
denied(resp, req)
|
||||
if resp.Code != http.StatusGone {
|
||||
t.Fatalf("unexpected status: %d", resp.Code)
|
||||
}
|
||||
if warning := resp.Header().Get("Warning"); warning == "" {
|
||||
t.Fatal("expected Warning header")
|
||||
}
|
||||
|
||||
var apiErr APIError
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &apiErr); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if apiErr.Code != "sensor_proxy_disabled" {
|
||||
t.Fatalf("unexpected error code: %s", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,463 +1,390 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/updates"
|
||||
)
|
||||
|
||||
func TestUpdateHandlers_HandleCheckUpdates_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
// MockUpdateManager implements UpdateManager interface for testing
|
||||
type MockUpdateManager struct {
|
||||
CheckForUpdatesFunc func(ctx context.Context, channel string) (*updates.UpdateInfo, error)
|
||||
ApplyUpdateFunc func(ctx context.Context, req updates.ApplyUpdateRequest) error
|
||||
GetStatusFunc func() updates.UpdateStatus
|
||||
GetSSECachedStatusFunc func() (updates.UpdateStatus, time.Time)
|
||||
AddSSEClientFunc func(w http.ResponseWriter, clientID string) *updates.SSEClient
|
||||
RemoveSSEClientFunc func(clientID string)
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
func (m *MockUpdateManager) CheckForUpdatesWithChannel(ctx context.Context, channel string) (*updates.UpdateInfo, error) {
|
||||
if m.CheckForUpdatesFunc != nil {
|
||||
return m.CheckForUpdatesFunc(ctx, channel)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/check", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
func (m *MockUpdateManager) ApplyUpdate(ctx context.Context, req updates.ApplyUpdateRequest) error {
|
||||
if m.ApplyUpdateFunc != nil {
|
||||
return m.ApplyUpdateFunc(ctx, req)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
handlers.HandleCheckUpdates(rec, req)
|
||||
func (m *MockUpdateManager) GetStatus() updates.UpdateStatus {
|
||||
if m.GetStatusFunc != nil {
|
||||
return m.GetStatusFunc()
|
||||
}
|
||||
return updates.UpdateStatus{}
|
||||
}
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
func (m *MockUpdateManager) GetSSECachedStatus() (updates.UpdateStatus, time.Time) {
|
||||
if m.GetSSECachedStatusFunc != nil {
|
||||
return m.GetSSECachedStatusFunc()
|
||||
}
|
||||
return updates.UpdateStatus{}, time.Time{}
|
||||
}
|
||||
|
||||
func (m *MockUpdateManager) AddSSEClient(w http.ResponseWriter, clientID string) *updates.SSEClient {
|
||||
if m.AddSSEClientFunc != nil {
|
||||
return m.AddSSEClientFunc(w, clientID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUpdateManager) RemoveSSEClient(clientID string) {
|
||||
if m.RemoveSSEClientFunc != nil {
|
||||
m.RemoveSSEClientFunc(clientID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleApplyUpdate_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestHandleCheckUpdates_Success(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
CheckForUpdatesFunc: func(ctx context.Context, channel string) (*updates.UpdateInfo, error) {
|
||||
return &updates.UpdateInfo{
|
||||
Available: true,
|
||||
LatestVersion: "v1.2.3",
|
||||
CurrentVersion: "v1.0.0",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/check", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/apply", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleCheckUpdates(w, r)
|
||||
|
||||
handlers.HandleApplyUpdate(rec, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
var info updates.UpdateInfo
|
||||
json.NewDecoder(w.Body).Decode(&info)
|
||||
if !info.Available || info.LatestVersion != "v1.2.3" {
|
||||
t.Errorf("Unexpected response: %+v", info)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleApplyUpdate_InvalidJSONBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestHandleCheckUpdates_Error(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
CheckForUpdatesFunc: func(ctx context.Context, channel string) (*updates.UpdateInfo, error) {
|
||||
return nil, errors.New("github down")
|
||||
},
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/check", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/apply", strings.NewReader("invalid json"))
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleCheckUpdates(w, r)
|
||||
|
||||
handlers.HandleApplyUpdate(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleApplyUpdate_MissingDownloadURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestHandleApplyUpdate_Success(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
ApplyUpdateFunc: func(ctx context.Context, req updates.ApplyUpdateRequest) error {
|
||||
if req.DownloadURL != "http://example.com/update.tar.gz" {
|
||||
return errors.New("wrong url")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"downloadUrl": "http://example.com/update.tar.gz"}`
|
||||
r := httptest.NewRequest("POST", "/updates/apply", strings.NewReader(body))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/apply", strings.NewReader(`{}`))
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleApplyUpdate(w, r)
|
||||
|
||||
handlers.HandleApplyUpdate(rec, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
|
||||
// Note: ApplyUpdate runs in background, so we just check it was accepted
|
||||
}
|
||||
|
||||
func TestHandleUpdateStatus_Fresh(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
GetStatusFunc: func() updates.UpdateStatus {
|
||||
return updates.UpdateStatus{Status: "idle"}
|
||||
},
|
||||
}
|
||||
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/status", nil)
|
||||
|
||||
h.HandleUpdateStatus(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("X-Cache") != "MISS" {
|
||||
t.Error("Expected X-Cache: MISS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleUpdateStatus_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestHandleUpdateStatus_Cached(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
GetStatusFunc: func() updates.UpdateStatus {
|
||||
return updates.UpdateStatus{Status: "fresh"}
|
||||
},
|
||||
GetSSECachedStatusFunc: func() (updates.UpdateStatus, time.Time) {
|
||||
return updates.UpdateStatus{Status: "cached"}, time.Now()
|
||||
},
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
// First request - MISS
|
||||
r1 := httptest.NewRequest("GET", "/updates/status", nil)
|
||||
r1.RemoteAddr = "1.2.3.4:1234"
|
||||
w1 := httptest.NewRecorder()
|
||||
h.HandleUpdateStatus(w1, r1)
|
||||
|
||||
handlers.HandleUpdateStatus(rec, req)
|
||||
if w1.Header().Get("X-Cache") != "MISS" {
|
||||
t.Error("Expected first request to be MISS")
|
||||
}
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
// Second request immediately after - HIT
|
||||
r2 := httptest.NewRequest("GET", "/updates/status", nil)
|
||||
r2.RemoteAddr = "1.2.3.4:5678" // Same IP
|
||||
w2 := httptest.NewRecorder()
|
||||
h.HandleUpdateStatus(w2, r2)
|
||||
|
||||
if w2.Header().Get("X-Cache") != "HIT" {
|
||||
t.Error("Expected second request to be HIT")
|
||||
}
|
||||
|
||||
var status updates.UpdateStatus
|
||||
json.NewDecoder(w2.Body).Decode(&status)
|
||||
if status.Status != "cached" {
|
||||
t.Errorf("Expected cached status, got %s", status.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleUpdateStream_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
func TestHandleUpdateStream(t *testing.T) {
|
||||
mockManager := &MockUpdateManager{
|
||||
AddSSEClientFunc: func(w http.ResponseWriter, clientID string) *updates.SSEClient {
|
||||
return &updates.SSEClient{
|
||||
ID: clientID,
|
||||
Done: make(chan bool),
|
||||
Flusher: w.(http.Flusher),
|
||||
}
|
||||
},
|
||||
RemoveSSEClientFunc: func(clientID string) {},
|
||||
}
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
h := NewUpdateHandlers(mockManager, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/stream", nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/stream", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
// Create context that we can cancel to simulate client disconnect
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handlers.HandleUpdateStream(rec, req)
|
||||
// This blocks until context cancel, so run in goroutine
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
h.HandleUpdateStream(w, r)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
// Give it a moment to establish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Cancel/Disconnect
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("HandleUpdateStream didn't return after context cancel")
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Error("Expected text/event-stream content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleGetUpdatePlan_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/plan", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleGetUpdatePlan(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleListUpdateHistory_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/history", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleListUpdateHistory(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleListUpdateHistory_NoHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlers := &UpdateHandlers{
|
||||
history: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/history", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleListUpdateHistory(rec, req)
|
||||
|
||||
if rec.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status %d, got %d", http.StatusServiceUnavailable, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleListUpdateHistory_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
func TestHandleListUpdateHistory(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
history, _ := updates.NewUpdateHistory(tmp)
|
||||
|
||||
handlers := &UpdateHandlers{
|
||||
history: history,
|
||||
// Pre-populate history
|
||||
history.CreateEntry(context.Background(), updates.UpdateHistoryEntry{
|
||||
EventID: "test-entry",
|
||||
Status: updates.StatusSuccess,
|
||||
VersionTo: "v1.2.3",
|
||||
})
|
||||
|
||||
h := NewUpdateHandlers(&MockUpdateManager{}, history)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/history", nil)
|
||||
|
||||
h.HandleListUpdateHistory(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/history", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleListUpdateHistory(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Errorf("expected content-type application/json, got %q", ct)
|
||||
var entries []updates.UpdateHistoryEntry
|
||||
json.NewDecoder(w.Body).Decode(&entries)
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("Expected 1 entry, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleGetUpdateHistoryEntry_MethodNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlers := &UpdateHandlers{}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/updates/history/123", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleGetUpdateHistoryEntry(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleGetUpdateHistoryEntry_NoHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlers := &UpdateHandlers{
|
||||
history: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/history/123", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleGetUpdateHistoryEntry(rec, req)
|
||||
|
||||
if rec.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status %d, got %d", http.StatusServiceUnavailable, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleGetUpdateHistoryEntry_MissingID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
func TestHandleGetUpdateHistoryEntry(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
history, _ := updates.NewUpdateHistory(tmp)
|
||||
|
||||
handlers := &UpdateHandlers{
|
||||
history: history,
|
||||
// Pre-populate history
|
||||
history.CreateEntry(context.Background(), updates.UpdateHistoryEntry{
|
||||
EventID: "test-entry-1",
|
||||
Status: updates.StatusSuccess,
|
||||
VersionTo: "v1.2.3",
|
||||
})
|
||||
|
||||
h := NewUpdateHandlers(&MockUpdateManager{}, history)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/updates/history/entry?id=test-entry-1", nil)
|
||||
|
||||
h.HandleGetUpdateHistoryEntry(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/history/entry", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleGetUpdateHistoryEntry(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateHandlers_HandleGetUpdateHistoryEntry_NotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmp := t.TempDir()
|
||||
history, _ := updates.NewUpdateHistory(tmp)
|
||||
|
||||
handlers := &UpdateHandlers{
|
||||
history: history,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/updates/history/entry?id=nonexistent", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlers.HandleGetUpdateHistoryEntry(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d", http.StatusNotFound, rec.Code)
|
||||
var entry updates.UpdateHistoryEntry
|
||||
json.NewDecoder(w.Body).Decode(&entry)
|
||||
if entry.EventID != "test-entry-1" {
|
||||
t.Errorf("Expected EventID test-entry-1, got %s", entry.EventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
// Re-include the IP tests as they were useful
|
||||
tests := []struct {
|
||||
name string
|
||||
xff string // X-Forwarded-For header
|
||||
xri string // X-Real-IP header
|
||||
remoteAddr string // Request.RemoteAddr
|
||||
expectedIP string
|
||||
}{
|
||||
// X-Forwarded-For takes priority
|
||||
{
|
||||
name: "XFF with valid IPv4",
|
||||
xff: "192.168.1.100",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "XFF with valid IPv6",
|
||||
xff: "2001:db8::1",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
name: "XFF with IPv6 loopback",
|
||||
xff: "::1",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "::1",
|
||||
},
|
||||
|
||||
// X-Real-IP fallback when XFF not valid
|
||||
{
|
||||
name: "XRI with valid IPv4",
|
||||
xff: "",
|
||||
xri: "172.16.0.50",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "172.16.0.50",
|
||||
},
|
||||
{
|
||||
name: "XRI with valid IPv6",
|
||||
xff: "",
|
||||
xri: "fe80::1",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "fe80::1",
|
||||
},
|
||||
{
|
||||
name: "XRI preferred when XFF invalid",
|
||||
xff: "invalid-ip",
|
||||
xri: "192.168.1.1",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
|
||||
// RemoteAddr fallback
|
||||
{
|
||||
name: "RemoteAddr with port",
|
||||
xff: "",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr IPv6 with port",
|
||||
xff: "",
|
||||
xri: "",
|
||||
remoteAddr: "[::1]:12345",
|
||||
expectedIP: "::1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
xff: "",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
|
||||
// Invalid headers fall through
|
||||
{
|
||||
name: "XFF invalid falls to XRI",
|
||||
xff: "not-an-ip",
|
||||
xri: "192.168.1.1",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "Both headers invalid falls to RemoteAddr",
|
||||
xff: "not-an-ip",
|
||||
xri: "also-not-an-ip",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "10.0.0.1",
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "Empty XFF ignored",
|
||||
xff: "",
|
||||
xri: "192.168.1.1",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "All empty uses RemoteAddr",
|
||||
xff: "",
|
||||
xri: "",
|
||||
remoteAddr: "127.0.0.1:8080",
|
||||
expectedIP: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "Loopback IPv4",
|
||||
xff: "127.0.0.1",
|
||||
xri: "",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "127.0.0.1",
|
||||
},
|
||||
|
||||
// Note: The current implementation has a bug with multiple IPs in XFF
|
||||
// It tries to parse the entire string as a single IP, which fails
|
||||
// This test documents current behavior, not ideal behavior
|
||||
{
|
||||
name: "XFF with multiple IPs - current behavior",
|
||||
xff: "192.168.1.100, 10.0.0.1",
|
||||
xri: "172.16.0.1",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "172.16.0.1", // Falls through because "192.168.1.100, 10.0.0.1" is not a valid single IP
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Header: make(http.Header),
|
||||
RemoteAddr: tt.remoteAddr,
|
||||
}
|
||||
|
||||
if tt.xff != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
|
||||
result := getClientIP(req)
|
||||
if result != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %q, want %q", result, tt.expectedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP_NilHeaders(t *testing.T) {
|
||||
// Test with a request that has nil headers (edge case)
|
||||
req := &http.Request{
|
||||
Header: nil,
|
||||
RemoteAddr: "10.0.0.1:12345",
|
||||
}
|
||||
|
||||
// This will panic if headers aren't handled correctly
|
||||
// The function should gracefully handle nil headers
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("getClientIP panicked with nil headers: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Note: With nil headers, Header.Get() will panic
|
||||
// This test documents that the function expects non-nil headers
|
||||
// If this test panics, it's documenting current behavior
|
||||
_ = getClientIP(req)
|
||||
}
|
||||
|
||||
func TestGetClientIP_HeaderCaseSensitivity(t *testing.T) {
|
||||
// HTTP headers are case-insensitive per RFC 7230
|
||||
// http.Header.Get handles this automatically
|
||||
tests := []struct {
|
||||
name string
|
||||
headerKey string
|
||||
headerVal string
|
||||
remoteAddr string
|
||||
expectedIP string
|
||||
headers map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "lowercase x-forwarded-for",
|
||||
headerKey: "x-forwarded-for",
|
||||
headerVal: "192.168.1.100",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "lowercase x-real-ip",
|
||||
headerKey: "x-real-ip",
|
||||
headerVal: "192.168.1.100",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "mixed case X-Forwarded-FOR",
|
||||
headerKey: "X-Forwarded-FOR",
|
||||
headerVal: "192.168.1.100",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
expectedIP: "192.168.1.100",
|
||||
},
|
||||
{"RemoteAddr", "1.2.3.4:1234", nil, "1.2.3.4"},
|
||||
{"XFF", "1.1.1.1:1234", map[string]string{"X-Forwarded-For": "2.2.2.2"}, "2.2.2.2"},
|
||||
{"X-Real-IP", "1.1.1.1:1234", map[string]string{"X-Real-IP": "3.3.3.3"}, "3.3.3.3"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Header: make(http.Header),
|
||||
RemoteAddr: tt.remoteAddr,
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.RemoteAddr = tt.remoteAddr
|
||||
for k, v := range tt.headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
req.Header.Set(tt.headerKey, tt.headerVal)
|
||||
|
||||
result := getClientIP(req)
|
||||
if result != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %q, want %q", result, tt.expectedIP)
|
||||
// getClientIP is strict internal but exposed via tests in same package
|
||||
ip := getClientIP(r)
|
||||
if ip != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoCleanupRateLimits(t *testing.T) {
|
||||
h := NewUpdateHandlers(nil, nil)
|
||||
now := time.Now()
|
||||
h.statusRateLimits["old"] = now.Add(-15 * time.Minute)
|
||||
h.statusRateLimits["new"] = now.Add(-5 * time.Minute)
|
||||
|
||||
h.doCleanupRateLimits(now)
|
||||
|
||||
if _, ok := h.statusRateLimits["old"]; ok {
|
||||
t.Error("Old entry not cleaned up")
|
||||
}
|
||||
if _, ok := h.statusRateLimits["new"]; !ok {
|
||||
t.Error("New entry cleaned up prematurely")
|
||||
}
|
||||
}
|
||||
|
||||
type mockUpdater struct {
|
||||
updates.Updater
|
||||
prepareFunc func(ctx context.Context, req updates.UpdateRequest) (*updates.UpdatePlan, error)
|
||||
}
|
||||
|
||||
func (m *mockUpdater) PrepareUpdate(ctx context.Context, req updates.UpdateRequest) (*updates.UpdatePlan, error) {
|
||||
return m.prepareFunc(ctx, req)
|
||||
}
|
||||
|
||||
func TestHandleGetUpdatePlan(t *testing.T) {
|
||||
// Set mock mode so GetCurrentVersion returns "mock"
|
||||
t.Setenv("PULSE_MOCK_MODE", "true")
|
||||
|
||||
mu := &mockUpdater{
|
||||
prepareFunc: func(ctx context.Context, req updates.UpdateRequest) (*updates.UpdatePlan, error) {
|
||||
return &updates.UpdatePlan{
|
||||
Instructions: []string{"test"},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
h := NewUpdateHandlers(nil, nil)
|
||||
h.registry.Register("mock", mu)
|
||||
|
||||
// Test missing version
|
||||
r := httptest.NewRequest("GET", "/api/updates/plan", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.HandleGetUpdatePlan(w, r)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test success
|
||||
r = httptest.NewRequest("GET", "/api/updates/plan?version=v1.2.3", nil)
|
||||
w = httptest.NewRecorder()
|
||||
h.HandleGetUpdatePlan(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var plan updates.UpdatePlan
|
||||
json.NewDecoder(w.Body).Decode(&plan)
|
||||
if len(plan.Instructions) != 1 {
|
||||
t.Errorf("Expected 1 instruction, got %d", len(plan.Instructions))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
119
internal/dockeragent/signature_test.go
Normal file
119
internal/dockeragent/signature_test.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package dockeragent
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerifySignature(t *testing.T) {
|
||||
originalKeys := trustedPublicKeysPEM
|
||||
defer func() {
|
||||
trustedPublicKeysPEM = originalKeys
|
||||
}()
|
||||
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
pubBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal public key: %v", err)
|
||||
}
|
||||
trustedPublicKeysPEM = []string{string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}))}
|
||||
|
||||
data := []byte("payload")
|
||||
sig := ed25519.Sign(privateKey, data)
|
||||
signature := base64.StdEncoding.EncodeToString(sig)
|
||||
|
||||
if err := verifySignature(data, signature); err != nil {
|
||||
t.Fatalf("expected signature to verify: %v", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(data, ""); err == nil {
|
||||
t.Fatal("expected missing signature error")
|
||||
}
|
||||
if err := verifySignature(data, "!!!"); err == nil {
|
||||
t.Fatal("expected invalid base64 error")
|
||||
}
|
||||
|
||||
// Invalid signature
|
||||
invalidSig := base64.StdEncoding.EncodeToString([]byte("bad"))
|
||||
if err := verifySignature(data, invalidSig); err == nil {
|
||||
t.Fatal("expected invalid signature error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifySignatureInvalidKeys(t *testing.T) {
|
||||
originalKeys := trustedPublicKeysPEM
|
||||
defer func() {
|
||||
trustedPublicKeysPEM = originalKeys
|
||||
}()
|
||||
|
||||
trustedPublicKeysPEM = []string{"not-pem"}
|
||||
if err := verifySignature([]byte("data"), base64.StdEncoding.EncodeToString([]byte("sig"))); err == nil {
|
||||
t.Fatal("expected error for invalid pem")
|
||||
}
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate rsa: %v", err)
|
||||
}
|
||||
pubBytes, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal rsa: %v", err)
|
||||
}
|
||||
trustedPublicKeysPEM = []string{string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}))}
|
||||
|
||||
if err := verifySignature([]byte("data"), base64.StdEncoding.EncodeToString([]byte("sig"))); err == nil {
|
||||
t.Fatal("expected error for non-ed25519 key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyFileSignature(t *testing.T) {
|
||||
originalKeys := trustedPublicKeysPEM
|
||||
defer func() {
|
||||
trustedPublicKeysPEM = originalKeys
|
||||
}()
|
||||
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
pubBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal public key: %v", err)
|
||||
}
|
||||
trustedPublicKeysPEM = []string{string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}))}
|
||||
|
||||
file := filepathJoin(t)
|
||||
data := []byte("file")
|
||||
if err := os.WriteFile(file, data, 0600); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
|
||||
sig := ed25519.Sign(privateKey, data)
|
||||
signature := base64.StdEncoding.EncodeToString(sig)
|
||||
if err := verifyFileSignature(file, signature); err != nil {
|
||||
t.Fatalf("expected file signature to verify: %v", err)
|
||||
}
|
||||
|
||||
if err := verifyFileSignature("missing", signature); err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
|
||||
if err := verifyFileSignature(file, "!!!"); err == nil {
|
||||
t.Fatal("expected base64 error")
|
||||
}
|
||||
}
|
||||
|
||||
func filepathJoin(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmp := t.TempDir()
|
||||
return tmp + "/payload"
|
||||
}
|
||||
284
internal/hostagent/agent_metrics_test.go
Normal file
284
internal/hostagent/agent_metrics_test.go
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/ceph"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/hostmetrics"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/sensors"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/smartctl"
|
||||
agentshost "github.com/rcourtman/pulse-go-rewrite/pkg/agents/host"
|
||||
"github.com/shirou/gopsutil/v4/host"
|
||||
)
|
||||
|
||||
func TestBuildReport(t *testing.T) {
|
||||
// Backup original functions
|
||||
origHostInfo := hostInfoWithContext
|
||||
origUptime := hostUptimeWithContext
|
||||
origHostMetrics := hostmetricsCollect
|
||||
origSensors := sensorsCollectPower
|
||||
origMdadm := mdadmCollectArrays
|
||||
origSmart := smartctlCollectLocal
|
||||
origNow := nowUTC
|
||||
|
||||
// Restore functions after test
|
||||
t.Cleanup(func() {
|
||||
hostInfoWithContext = origHostInfo
|
||||
hostUptimeWithContext = origUptime
|
||||
hostmetricsCollect = origHostMetrics
|
||||
sensorsCollectPower = origSensors
|
||||
mdadmCollectArrays = origMdadm
|
||||
smartctlCollectLocal = origSmart
|
||||
nowUTC = origNow
|
||||
})
|
||||
|
||||
// Setup mocks
|
||||
fixedTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
nowUTC = func() time.Time { return fixedTime }
|
||||
|
||||
hostInfoWithContext = func(ctx context.Context) (*host.InfoStat, error) {
|
||||
return &host.InfoStat{
|
||||
Hostname: "test-host",
|
||||
Uptime: 1000,
|
||||
BootTime: 1000000,
|
||||
Procs: 100,
|
||||
OS: "linux",
|
||||
Platform: "debian",
|
||||
PlatformFamily: "debian",
|
||||
PlatformVersion: "11",
|
||||
KernelVersion: "5.10.0",
|
||||
VirtualizationSystem: "kvm",
|
||||
VirtualizationRole: "guest",
|
||||
HostID: "host-id-123",
|
||||
}, nil
|
||||
}
|
||||
|
||||
hostUptimeWithContext = func(ctx context.Context) (uint64, error) {
|
||||
return 3600, nil
|
||||
}
|
||||
|
||||
hostmetricsCollect = func(ctx context.Context, diskExclude []string) (hostmetrics.Snapshot, error) {
|
||||
return hostmetrics.Snapshot{
|
||||
CPUUsagePercent: 50.0,
|
||||
Memory: agentshost.MemoryMetric{
|
||||
TotalBytes: 1000,
|
||||
UsedBytes: 500,
|
||||
Usage: 50.0,
|
||||
},
|
||||
Disks: []agentshost.Disk{
|
||||
{
|
||||
Device: "/dev/sda1",
|
||||
Mountpoint: "/",
|
||||
UsedBytes: 200,
|
||||
TotalBytes: 1000,
|
||||
Usage: 20.0,
|
||||
},
|
||||
},
|
||||
Network: []agentshost.NetworkInterface{
|
||||
{
|
||||
Name: "eth0",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Mock optional collectors with correct types
|
||||
sensorsCollectPower = func(context.Context) (*sensors.PowerData, error) {
|
||||
return &sensors.PowerData{}, nil
|
||||
}
|
||||
mdadmCollectArrays = func(context.Context) ([]agentshost.RAIDArray, error) { return nil, nil }
|
||||
smartctlCollectLocal = func(context.Context, []string) ([]smartctl.DiskSMART, error) { return nil, nil }
|
||||
|
||||
// Create Agent
|
||||
cfg := Config{
|
||||
AgentID: "agent-123",
|
||||
APIToken: "test-token", // Required
|
||||
LogLevel: -1, // Disabled
|
||||
}
|
||||
agent, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test case 1: Successful collection
|
||||
t.Run("Successful collection", func(t *testing.T) {
|
||||
report, err := agent.buildReport(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("buildReport failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify Agent Info
|
||||
if report.Agent.ID != "agent-123" {
|
||||
t.Errorf("Agent.ID = %q, want %q", report.Agent.ID, "agent-123")
|
||||
}
|
||||
if report.Agent.Version != "dev" {
|
||||
// Version is usually set by linker or defaults to 0.0.0/dev.
|
||||
// hostagent.New calls buildVersion() which might rely on version package.
|
||||
// Let's just check it is present.
|
||||
t.Logf("Agent Version: %s", report.Agent.Version)
|
||||
}
|
||||
|
||||
// Verify Host Info
|
||||
if report.Host.Hostname != "test-host" {
|
||||
t.Errorf("Host.Hostname = %q, want %q", report.Host.Hostname, "test-host")
|
||||
}
|
||||
if report.Host.UptimeSeconds != 3600 {
|
||||
t.Errorf("Host.UptimeSeconds = %d, want %d", report.Host.UptimeSeconds, 3600)
|
||||
}
|
||||
// agent.go lines 166-169:
|
||||
// osName := strings.TrimSpace(info.Platform) ... fallback to PlatformFamily
|
||||
// Our mock returns Platform: "debian"
|
||||
if report.Host.OSName != "debian" {
|
||||
t.Errorf("Host.OSName = %q, want %q", report.Host.OSName, "debian")
|
||||
}
|
||||
|
||||
// Verify Metrics
|
||||
if report.Metrics.CPUUsagePercent != 50.0 {
|
||||
t.Errorf("CPU Usage = %f, want 50.0", report.Metrics.CPUUsagePercent)
|
||||
}
|
||||
|
||||
// Verify Timestamp
|
||||
if !report.Timestamp.Equal(fixedTime) {
|
||||
t.Errorf("Timestamp = %v, want %v", report.Timestamp, fixedTime)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Host info failure (should partially fail or return error depending on implementation)
|
||||
// Looking at agent.go:374, if hostInfo fails, it logs error and continues with cached/empty?
|
||||
// Actually buildReport calls agent.hostInfo (cached) if available or calls hostInfoWithContext again?
|
||||
// It seems New() populates initial hostInfo.
|
||||
// Let's test runtime failure of hostUptime.
|
||||
t.Run("Uptime failure", func(t *testing.T) {
|
||||
hostUptimeWithContext = func(ctx context.Context) (uint64, error) {
|
||||
return 0, errors.New("uptime failed")
|
||||
}
|
||||
|
||||
report, err := agent.buildReport(context.Background())
|
||||
if err != nil {
|
||||
// It might be acceptable to fail or just have 0 uptime
|
||||
// Implementation uses: uptime, err := hostUptimeWithContext(ctx)
|
||||
// if err != nil ...
|
||||
t.Logf("buildReport returned error on uptime fail: %v", err)
|
||||
} else {
|
||||
if report.Host.UptimeSeconds != 0 {
|
||||
// If logic falls back to something else?
|
||||
// If hostInfo is present, it might use that.
|
||||
t.Logf("Host Uptime reported as %d", report.Host.UptimeSeconds)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 3: RAID Array collection
|
||||
t.Run("RAID collection", func(t *testing.T) {
|
||||
mdadmCollectArrays = func(ctx context.Context) ([]agentshost.RAIDArray, error) {
|
||||
return []agentshost.RAIDArray{
|
||||
{Name: "md0", State: "clean"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
report, err := agent.buildReport(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("buildReport failed: %v", err)
|
||||
}
|
||||
|
||||
if len(report.RAID) != 1 {
|
||||
t.Errorf("Expected 1 RAID array, got %d", len(report.RAID))
|
||||
} else if report.RAID[0].Name != "md0" {
|
||||
t.Errorf("Expected RAID name md0, got %s", report.RAID[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 4: Ceph collection
|
||||
t.Run("Ceph collection", func(t *testing.T) {
|
||||
origCeph := cephCollect
|
||||
defer func() { cephCollect = origCeph }()
|
||||
|
||||
cephCollect = func(ctx context.Context) (*ceph.ClusterStatus, error) {
|
||||
return &ceph.ClusterStatus{
|
||||
FSID: "ceph-fsid-123",
|
||||
Health: ceph.HealthStatus{
|
||||
Status: "HEALTH_OK",
|
||||
},
|
||||
MonMap: ceph.MonitorMap{
|
||||
Epoch: 100,
|
||||
NumMons: 3,
|
||||
Monitors: []ceph.Monitor{
|
||||
{Name: "a"},
|
||||
},
|
||||
},
|
||||
MgrMap: ceph.ManagerMap{
|
||||
Available: true,
|
||||
ActiveMgr: "a",
|
||||
},
|
||||
OSDMap: ceph.OSDMap{
|
||||
NumOSDs: 10,
|
||||
NumUp: 10,
|
||||
},
|
||||
PGMap: ceph.PGMap{
|
||||
UsagePercent: 55.5,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
report, err := agent.buildReport(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("buildReport failed: %v", err)
|
||||
}
|
||||
|
||||
// On non-Linux this will be nil, so check logic conditionally or skip
|
||||
// Since user is Linux, we expect it to be populated.
|
||||
if report.Ceph == nil {
|
||||
// If we are erroneously detecting non-linux in test env (e.g. valid MacOS dev machine)
|
||||
// But user says "USER's OS version is linux".
|
||||
// We can check if runtime.GOOS == "linux"
|
||||
t.Log("Ceph report is nil (likely not running on Linux or DisableCeph=true)")
|
||||
} else {
|
||||
if report.Ceph.FSID != "ceph-fsid-123" {
|
||||
t.Errorf("Expected Ceph FSID ceph-fsid-123, got %s", report.Ceph.FSID)
|
||||
}
|
||||
if report.Ceph.Health.Status != "HEALTH_OK" {
|
||||
t.Errorf("Expected Ceph status HEALTH_OK, got %s", report.Ceph.Health.Status)
|
||||
}
|
||||
if len(report.Ceph.MonMap.Monitors) != 1 {
|
||||
t.Errorf("Expected 1 monitor, got %d", len(report.Ceph.MonMap.Monitors))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 5: SMART collection
|
||||
t.Run("SMART collection", func(t *testing.T) {
|
||||
origSmart := smartctlCollectLocal
|
||||
defer func() { smartctlCollectLocal = origSmart }()
|
||||
|
||||
smartctlCollectLocal = func(_ context.Context, _ []string) ([]smartctl.DiskSMART, error) {
|
||||
return []smartctl.DiskSMART{
|
||||
{
|
||||
Device: "/dev/sda",
|
||||
Model: "TestDisk",
|
||||
Health: "PASSED",
|
||||
Temperature: 35,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
report, err := agent.buildReport(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("buildReport failed: %v", err)
|
||||
}
|
||||
|
||||
// SMART data is attached to Sensors in the report
|
||||
if len(report.Sensors.SMART) != 1 {
|
||||
t.Errorf("Expected 1 SMART disk, got %d", len(report.Sensors.SMART))
|
||||
} else {
|
||||
if report.Sensors.SMART[0].Device != "/dev/sda" {
|
||||
t.Errorf("Expected device /dev/sda, got %s", report.Sensors.SMART[0].Device)
|
||||
}
|
||||
if report.Sensors.SMART[0].Health != "PASSED" {
|
||||
t.Errorf("Expected Health=PASSED, got %s", report.Sensors.SMART[0].Health)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
130
internal/hostagent/commands_test.go
Normal file
130
internal/hostagent/commands_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func TestCommandClient_Run(t *testing.T) {
|
||||
// Setup mock WebSocket server
|
||||
upgrader := websocket.Upgrader{}
|
||||
|
||||
// Channels to verify interaction
|
||||
registerReceived := make(chan bool)
|
||||
commandSent := make(chan bool)
|
||||
resultReceived := make(chan bool)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 1. Expect Registration
|
||||
var msg wsMessage
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
t.Errorf("read registration failed: %v", err)
|
||||
return
|
||||
}
|
||||
if msg.Type != msgTypeAgentRegister {
|
||||
t.Errorf("expected register msg, got %s", msg.Type)
|
||||
return
|
||||
}
|
||||
registerReceived <- true
|
||||
|
||||
// 2. Send Registration Success
|
||||
respPayload, _ := json.Marshal(registeredPayload{Success: true})
|
||||
conn.WriteJSON(wsMessage{
|
||||
Type: msgTypeRegistered,
|
||||
Timestamp: time.Now(),
|
||||
Payload: respPayload,
|
||||
})
|
||||
|
||||
// 3. Send Execute Command
|
||||
cmdPayload, _ := json.Marshal(executeCommandPayload{
|
||||
RequestID: "cmd-1",
|
||||
Command: "echo hello",
|
||||
Timeout: 5,
|
||||
})
|
||||
conn.WriteJSON(wsMessage{
|
||||
Type: msgTypeExecuteCmd,
|
||||
Timestamp: time.Now(),
|
||||
Payload: cmdPayload,
|
||||
})
|
||||
commandSent <- true
|
||||
|
||||
// 4. Expect Result
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
t.Errorf("read result failed: %v", err)
|
||||
return
|
||||
}
|
||||
if msg.Type != msgTypeCommandResult {
|
||||
t.Errorf("expected result msg, got %s", msg.Type)
|
||||
return
|
||||
}
|
||||
resultReceived <- true
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Mock execCommandContext
|
||||
origExec := execCommandContext
|
||||
defer func() { execCommandContext = origExec }()
|
||||
|
||||
execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd {
|
||||
// Just run real echo for simplicity, or mock using test helper
|
||||
// echo is safe and standard
|
||||
return exec.CommandContext(ctx, name, args...)
|
||||
}
|
||||
|
||||
logger := zerolog.Nop()
|
||||
cfg := Config{
|
||||
PulseURL: server.URL,
|
||||
APIToken: "test-token",
|
||||
Logger: &logger,
|
||||
}
|
||||
client := NewCommandClient(cfg, "agent-1", "host-1", "linux", "v1")
|
||||
|
||||
// Run client in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
client.Run(ctx)
|
||||
}()
|
||||
|
||||
// Verify sequence
|
||||
select {
|
||||
case <-registerReceived:
|
||||
t.Log("Registration received")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for registration")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-commandSent:
|
||||
t.Log("Command sent")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for command send")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-resultReceived:
|
||||
t.Log("Result received")
|
||||
case <-time.After(5 * time.Second): // Allow execution time
|
||||
t.Fatal("Timeout waiting for result")
|
||||
}
|
||||
|
||||
// Terminate
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
16
internal/hostagent/network_helpers_test.go
Normal file
16
internal/hostagent/network_helpers_test.go
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
package hostagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsLikelyVirtualInterfaceName(t *testing.T) {
|
||||
virtual := []string{"", "lo", "docker0", "veth123", "br-abc", "cni0", "flannel.1", "virbr0", "ztabc"}
|
||||
for _, name := range virtual {
|
||||
if !isLikelyVirtualInterfaceName(name) {
|
||||
t.Fatalf("expected %q to be virtual", name)
|
||||
}
|
||||
}
|
||||
|
||||
if isLikelyVirtualInterfaceName("eth0") {
|
||||
t.Fatal("expected eth0 to be non-virtual")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
|
@ -248,3 +255,599 @@ func TestGetHostURL(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenValue(t *testing.T) {
|
||||
setup := &ProxmoxSetup{
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "parses standard table output",
|
||||
output: `
|
||||
┌──────────────┬──────────────────────────────────────┐
|
||||
│ key │ value │
|
||||
╞══════════════╪══════════════════════════════════════╡
|
||||
│ full-tokenid │ pulse-monitor@pam!pulse-monitor │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ info │ {"privsep":1} │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
└──────────────┴──────────────────────────────────────┘
|
||||
`,
|
||||
expected: "7c5709fb-6aee-4c32-8b9f-5c2656912345",
|
||||
},
|
||||
{
|
||||
name: "parses output with extra whitespace",
|
||||
output: `
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
`,
|
||||
expected: "7c5709fb-6aee-4c32-8b9f-5c2656912345",
|
||||
},
|
||||
{
|
||||
name: "returns empty on missing value",
|
||||
output: `│ other │ something │`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
output: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := setup.parseTokenValue(tt.output)
|
||||
if result != tt.expected {
|
||||
t.Errorf("parseTokenValue() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePBSTokenValue(t *testing.T) {
|
||||
setup := &ProxmoxSetup{
|
||||
logger: zerolog.Nop(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "parses standard JSON output",
|
||||
output: `{"tokenid":"pulse-monitor@pbs!pulse-monitor","value":"pbs-api-token-value-12345"}`,
|
||||
expected: "pbs-api-token-value-12345",
|
||||
},
|
||||
{
|
||||
name: "parses JSON with extra fields",
|
||||
output: `{"other":"stuff","value":"my-secret-token","more":"stuff"}`,
|
||||
expected: "my-secret-token",
|
||||
},
|
||||
{
|
||||
name: "returns empty on invalid JSON",
|
||||
output: `not-json`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "returns empty when value missing",
|
||||
output: `{"tokenid":"foo"}`,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := setup.parsePBSTokenValue(tt.output)
|
||||
if result != tt.expected {
|
||||
t.Errorf("parsePBSTokenValue() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupPVEToken(t *testing.T) {
|
||||
// Backup and restore
|
||||
origRunCommandOutput := runCommandOutput
|
||||
origRunCommand := runCommand
|
||||
defer func() {
|
||||
runCommandOutput = origRunCommandOutput
|
||||
runCommand = origRunCommand
|
||||
}()
|
||||
|
||||
mockOutput := `
|
||||
┌──────────────┬──────────────────────────────────────┐
|
||||
│ key │ value │
|
||||
╞══════════════╪══════════════════════════════════════╡
|
||||
│ full-tokenid │ pulse-monitor@pam!pulse-monitor │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
└──────────────┴──────────────────────────────────────┘
|
||||
`
|
||||
var capturedCmd string
|
||||
var capturedArgs []string
|
||||
|
||||
// Mock runCommand to do nothing (success)
|
||||
runCommand = func(ctx context.Context, name string, args ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
runCommandOutput = func(ctx context.Context, name string, args ...string) (string, error) {
|
||||
capturedCmd = name
|
||||
capturedArgs = args
|
||||
return mockOutput, nil
|
||||
}
|
||||
|
||||
setup := NewProxmoxSetup(zerolog.Nop(), nil, "", "", "pve", "", "", false)
|
||||
id, value, err := setup.setupPVEToken(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("setupPVEToken failed: %v", err)
|
||||
}
|
||||
|
||||
if id != "pulse-monitor@pam!test-token" {
|
||||
t.Errorf("expected token ID pulse-monitor@pam!test-token, got %s", id)
|
||||
}
|
||||
if value != "7c5709fb-6aee-4c32-8b9f-5c2656912345" {
|
||||
t.Errorf("expected token value 7c5709fb-6aee-4c32-8b9f-5c2656912345, got %s", value)
|
||||
}
|
||||
|
||||
if capturedCmd != "pveum" {
|
||||
t.Errorf("expected command pveum, got %s", capturedCmd)
|
||||
}
|
||||
// Verify critical args
|
||||
foundAdd := false
|
||||
foundTokenName := false
|
||||
for _, arg := range capturedArgs {
|
||||
if arg == "add" {
|
||||
foundAdd = true
|
||||
}
|
||||
if arg == "test-token" {
|
||||
foundTokenName = true
|
||||
}
|
||||
}
|
||||
if !foundAdd || !foundTokenName {
|
||||
t.Errorf("missing critical args in %v", capturedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupPBSToken(t *testing.T) {
|
||||
// Backup and restore
|
||||
origRunCommandOutput := runCommandOutput
|
||||
origRunCommand := runCommand
|
||||
defer func() {
|
||||
runCommandOutput = origRunCommandOutput
|
||||
runCommand = origRunCommand
|
||||
}()
|
||||
|
||||
mockOutput := `{"tokenid":"pulse-monitor@pbs!test-token","value":"pbs-api-token-value-12345"}`
|
||||
|
||||
// Mock runCommand to do nothing
|
||||
runCommand = func(ctx context.Context, name string, args ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
runCommandOutput = func(ctx context.Context, name string, args ...string) (string, error) {
|
||||
return mockOutput, nil
|
||||
}
|
||||
|
||||
setup := NewProxmoxSetup(zerolog.Nop(), nil, "", "", "pbs", "", "", false)
|
||||
id, value, err := setup.setupPBSToken(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("setupPBSToken failed: %v", err)
|
||||
}
|
||||
|
||||
if id != "pulse-monitor@pbs!test-token" {
|
||||
t.Errorf("expected token ID pulse-monitor@pbs!test-token, got %s", id)
|
||||
}
|
||||
if value != "pbs-api-token-value-12345" {
|
||||
t.Errorf("expected token value pbs-api-token-value-12345, got %s", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectProxmoxTypes(t *testing.T) {
|
||||
origLookPath := lookPath
|
||||
defer func() { lookPath = origLookPath }()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mockPaths map[string]bool // map[exe]exists
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "detects pve only",
|
||||
mockPaths: map[string]bool{
|
||||
"pvesh": true,
|
||||
"proxmox-backup-manager": false,
|
||||
},
|
||||
expected: []string{"pve"},
|
||||
},
|
||||
{
|
||||
name: "detects pbs only",
|
||||
mockPaths: map[string]bool{
|
||||
"pvesh": false,
|
||||
"proxmox-backup-manager": true,
|
||||
},
|
||||
expected: []string{"pbs"},
|
||||
},
|
||||
{
|
||||
name: "detects both",
|
||||
mockPaths: map[string]bool{
|
||||
"pvesh": true,
|
||||
"proxmox-backup-manager": true,
|
||||
},
|
||||
expected: []string{"pve", "pbs"},
|
||||
},
|
||||
{
|
||||
name: "detects none",
|
||||
mockPaths: map[string]bool{
|
||||
"pvesh": false,
|
||||
"proxmox-backup-manager": false,
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
lookPath = func(file string) (string, error) {
|
||||
if exists := tt.mockPaths[file]; exists {
|
||||
return "/usr/bin/" + file, nil
|
||||
}
|
||||
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
|
||||
}
|
||||
|
||||
setup := &ProxmoxSetup{}
|
||||
result := setup.detectProxmoxTypes()
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
} else {
|
||||
for i, v := range result {
|
||||
if v != tt.expected[i] {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunForType(t *testing.T) {
|
||||
// Setup temporary state dir
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Backup and restore state file paths
|
||||
origStateFileDir := stateFileDir
|
||||
origStateFilePath := stateFilePath
|
||||
origStateFilePVE := stateFilePVE
|
||||
origStateFilePBS := stateFilePBS
|
||||
|
||||
stateFileDir = tmpDir
|
||||
stateFilePath = filepath.Join(tmpDir, "proxmox-registered")
|
||||
stateFilePVE = filepath.Join(tmpDir, "proxmox-pve-registered")
|
||||
stateFilePBS = filepath.Join(tmpDir, "proxmox-pbs-registered")
|
||||
|
||||
defer func() {
|
||||
stateFileDir = origStateFileDir
|
||||
stateFilePath = origStateFilePath
|
||||
stateFilePVE = origStateFilePVE
|
||||
stateFilePBS = origStateFilePBS
|
||||
}()
|
||||
|
||||
// Backup and restore runCommand functions
|
||||
origRunCommand := runCommand
|
||||
origRunCommandOutput := runCommandOutput
|
||||
defer func() {
|
||||
runCommand = origRunCommand
|
||||
runCommandOutput = origRunCommandOutput
|
||||
}()
|
||||
|
||||
// Mock Proxmox commands
|
||||
mockTokenOutput := `
|
||||
┌──────────────┬──────────────────────────────────────┐
|
||||
│ key │ value │
|
||||
╞══════════════╪══════════════════════════════════════╡
|
||||
│ full-tokenid │ pulse-monitor@pam!test-token │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
└──────────────┴──────────────────────────────────────┘
|
||||
`
|
||||
runCommand = func(ctx context.Context, name string, args ...string) error {
|
||||
return nil
|
||||
}
|
||||
runCommandOutput = func(ctx context.Context, name string, args ...string) (string, error) {
|
||||
return mockTokenOutput, nil
|
||||
}
|
||||
|
||||
// Mock HTTP Client to capture registration
|
||||
var capturedReq *http.Request
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedReq = r
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"success":true}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
setup := NewProxmoxSetup(zerolog.Nop(), server.Client(), server.URL, "api-token", "pve", "test-host", "", false)
|
||||
|
||||
// Test case 1: Not registered yet
|
||||
result, err := setup.runForType(context.Background(), "pve")
|
||||
if err != nil {
|
||||
t.Fatalf("runForType failed: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected result, got nil")
|
||||
}
|
||||
if !result.Registered {
|
||||
t.Error("expected Registered=true")
|
||||
}
|
||||
if !strings.HasPrefix(result.TokenID, "pulse-monitor@pam!pulse-") {
|
||||
t.Errorf("expected token ID starting with pulse-monitor@pam!pulse-, got %s", result.TokenID)
|
||||
}
|
||||
|
||||
// Verify HTTP registration call
|
||||
if capturedReq == nil {
|
||||
t.Error("expected HTTP registration request")
|
||||
} else {
|
||||
if capturedReq.URL.Path != "/api/auto-register" {
|
||||
t.Errorf("expected path /api/auto-register, got %s", capturedReq.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify state file created
|
||||
if _, err := os.Stat(stateFilePVE); os.IsNotExist(err) {
|
||||
t.Error("expected state file to be created")
|
||||
}
|
||||
|
||||
// Test case 2: Already registered (should skip)
|
||||
capturedReq = nil // Reset capture
|
||||
result, err = setup.runForType(context.Background(), "pve")
|
||||
if err != nil {
|
||||
t.Fatalf("runForType (2nd call) failed: %v", err)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Error("expected nil result (skipped), got something")
|
||||
}
|
||||
if capturedReq != nil {
|
||||
t.Error("did not expect HTTP call on 2nd run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAll(t *testing.T) {
|
||||
// Setup temporary state dir
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Backup and restore state variables
|
||||
origStateFileDir := stateFileDir
|
||||
origStateFilePath := stateFilePath
|
||||
origStateFilePVE := stateFilePVE
|
||||
origStateFilePBS := stateFilePBS
|
||||
|
||||
stateFileDir = tmpDir
|
||||
stateFilePath = filepath.Join(tmpDir, "proxmox-registered")
|
||||
stateFilePVE = filepath.Join(tmpDir, "proxmox-pve-registered")
|
||||
stateFilePBS = filepath.Join(tmpDir, "proxmox-pbs-registered")
|
||||
|
||||
defer func() {
|
||||
stateFileDir = origStateFileDir
|
||||
stateFilePath = origStateFilePath
|
||||
stateFilePVE = origStateFilePVE
|
||||
stateFilePBS = origStateFilePBS
|
||||
}()
|
||||
|
||||
// Backup and restore lookPath
|
||||
origLookPath := lookPath
|
||||
defer func() { lookPath = origLookPath }()
|
||||
|
||||
// Backup runCommand
|
||||
origRunCommand := runCommand
|
||||
origRunCommandOutput := runCommandOutput
|
||||
defer func() {
|
||||
runCommand = origRunCommand
|
||||
runCommandOutput = origRunCommandOutput
|
||||
}()
|
||||
|
||||
// Mock LookPath to find BOTH PVE and PBS
|
||||
lookPath = func(file string) (string, error) {
|
||||
if file == "pvesh" || file == "proxmox-backup-manager" {
|
||||
return "/usr/bin/" + file, nil
|
||||
}
|
||||
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
|
||||
}
|
||||
|
||||
// Mock Command execution
|
||||
runCommand = func(ctx context.Context, name string, args ...string) error {
|
||||
return nil
|
||||
}
|
||||
runCommandOutput = func(ctx context.Context, name string, args ...string) (string, error) {
|
||||
// Return valid tokens for both types
|
||||
if name == "pveum" {
|
||||
return `
|
||||
┌──────────────┬──────────────────────────────────────┐
|
||||
│ key │ value │
|
||||
╞══════════════╪══════════════════════════════════════╡
|
||||
│ full-tokenid │ pulse-monitor@pam!pulse-pve-token │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
└──────────────┴──────────────────────────────────────┘
|
||||
`, nil
|
||||
}
|
||||
if name == "proxmox-backup-manager" {
|
||||
return `{"tokenid":"pulse-monitor@pbs!pulse-pbs-token","value":"pbs-value"}`, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Mock HTTP Server
|
||||
callCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"success":true}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
setup := NewProxmoxSetup(zerolog.Nop(), server.Client(), server.URL, "api-token", "", "test-host", "", false)
|
||||
|
||||
// RunAll
|
||||
results, err := setup.RunAll(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("RunAll failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have 2 results
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Should have made 2 HTTP calls
|
||||
if callCount != 2 {
|
||||
t.Errorf("expected 2 HTTP calls, got %d", callCount)
|
||||
}
|
||||
|
||||
// Check state files
|
||||
if _, err := os.Stat(stateFilePVE); os.IsNotExist(err) {
|
||||
t.Error("expected PVE state file")
|
||||
}
|
||||
if _, err := os.Stat(stateFilePBS); os.IsNotExist(err) {
|
||||
t.Error("expected PBS state file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_Legacy(t *testing.T) {
|
||||
// Setup temporary state dir
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Backup
|
||||
origStateFilePath := stateFilePath
|
||||
stateFilePath = filepath.Join(tmpDir, "proxmox-registered")
|
||||
defer func() { stateFilePath = origStateFilePath }()
|
||||
|
||||
origLookPath := lookPath
|
||||
defer func() { lookPath = origLookPath }()
|
||||
|
||||
origRunCommand := runCommand
|
||||
origRunCommandOutput := runCommandOutput
|
||||
defer func() {
|
||||
runCommand = origRunCommand
|
||||
runCommandOutput = origRunCommandOutput
|
||||
}()
|
||||
|
||||
// Mock LookPath - find PVE
|
||||
lookPath = func(file string) (string, error) {
|
||||
if file == "pvesh" {
|
||||
return "/usr/bin/pvesh", nil
|
||||
}
|
||||
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
|
||||
}
|
||||
|
||||
// Mock Token
|
||||
runCommand = func(ctx context.Context, name string, args ...string) error { return nil }
|
||||
runCommandOutput = func(ctx context.Context, name string, args ...string) (string, error) {
|
||||
if name == "pveum" {
|
||||
return `
|
||||
┌──────────────┬──────────────────────────────────────┐
|
||||
│ key │ value │
|
||||
╞══════════════╪══════════════════════════════════════╡
|
||||
│ full-tokenid │ pulse-monitor@pam!pulse-pve-token │
|
||||
├──────────────┼──────────────────────────────────────┤
|
||||
│ value │ 7c5709fb-6aee-4c32-8b9f-5c2656912345 │
|
||||
└──────────────┴──────────────────────────────────────┘
|
||||
`, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Mock HTTP
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"success":true}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test Run()
|
||||
setup := NewProxmoxSetup(zerolog.Nop(), server.Client(), server.URL, "api-token", "", "test-host", "", false) // empty ptype -> auto-detect
|
||||
result, err := setup.Run(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Run failed: %v", err)
|
||||
}
|
||||
if result == nil || !result.Registered {
|
||||
t.Error("expected successful registration")
|
||||
}
|
||||
|
||||
// Verify legacy state file
|
||||
if _, err := os.Stat(stateFilePath); os.IsNotExist(err) {
|
||||
t.Error("expected legacy state file")
|
||||
}
|
||||
|
||||
// Run again - checks isAlreadyRegistered (idempotency)
|
||||
result2, err := setup.Run(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Run 2 failed: %v", err)
|
||||
}
|
||||
if result2 != nil {
|
||||
t.Error("expected nil result (skipped) on 2nd run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTypeRegistered_Legacy(t *testing.T) {
|
||||
// Setup temporary state dir
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Backup
|
||||
origStateFilePath := stateFilePath
|
||||
origStateFilePVE := stateFilePVE // Ensure new files don't exist
|
||||
stateFilePath = filepath.Join(tmpDir, "proxmox-registered")
|
||||
stateFilePVE = filepath.Join(tmpDir, "proxmox-pve-registered")
|
||||
|
||||
defer func() {
|
||||
stateFilePath = origStateFilePath
|
||||
stateFilePVE = origStateFilePVE
|
||||
}()
|
||||
|
||||
origLookPath := lookPath
|
||||
defer func() { lookPath = origLookPath }()
|
||||
|
||||
// Create legacy state file
|
||||
if err := os.WriteFile(stateFilePath, []byte("legacy"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setup := &ProxmoxSetup{}
|
||||
|
||||
// Scenario 1: PVE installed. Requesting PVE check. Should be true.
|
||||
lookPath = func(file string) (string, error) {
|
||||
if file == "pvesh" {
|
||||
return "/bin/pvesh", nil
|
||||
}
|
||||
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
|
||||
}
|
||||
if !setup.isTypeRegistered("pve") {
|
||||
t.Error("Legacy: PVE installed + legacy file should => registered")
|
||||
}
|
||||
|
||||
// Scenario 2: PVE installed. Requesting PBS check. Should be false (PVE assumed primary).
|
||||
if setup.isTypeRegistered("pbs") {
|
||||
t.Error("Legacy: PVE installed + legacy file should => PBS NOT registered")
|
||||
}
|
||||
|
||||
// Scenario 3: Only PBS installed. Requesting PBS check. Should be true.
|
||||
lookPath = func(file string) (string, error) {
|
||||
if file == "proxmox-backup-manager" {
|
||||
return "/bin/proxmox-backup-manager", nil
|
||||
}
|
||||
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
|
||||
}
|
||||
if !setup.isTypeRegistered("pbs") {
|
||||
t.Error("Legacy: Only PBS installed + legacy file should => PBS registered")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
125
internal/kubernetesagent/agent_helpers_test.go
Normal file
125
internal/kubernetesagent/agent_helpers_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package kubernetesagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
)
|
||||
|
||||
func TestIsProblemPod(t *testing.T) {
|
||||
if !isProblemPod(corev1.Pod{Status: corev1.PodStatus{Phase: corev1.PodPending}}) {
|
||||
t.Fatal("expected pending pod to be a problem")
|
||||
}
|
||||
if !isProblemPod(corev1.Pod{Status: corev1.PodStatus{Phase: corev1.PodFailed}}) {
|
||||
t.Fatal("expected failed pod to be a problem")
|
||||
}
|
||||
if !isProblemPod(corev1.Pod{Status: corev1.PodStatus{Phase: corev1.PodUnknown}}) {
|
||||
t.Fatal("expected unknown pod to be a problem")
|
||||
}
|
||||
|
||||
okPod := corev1.Pod{
|
||||
Status: corev1.PodStatus{
|
||||
Phase: corev1.PodRunning,
|
||||
ContainerStatuses: []corev1.ContainerStatus{
|
||||
{
|
||||
Ready: true,
|
||||
State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if isProblemPod(okPod) {
|
||||
t.Fatal("expected healthy running pod to be non-problem")
|
||||
}
|
||||
|
||||
waitingPod := corev1.Pod{
|
||||
Status: corev1.PodStatus{
|
||||
Phase: corev1.PodRunning,
|
||||
ContainerStatuses: []corev1.ContainerStatus{
|
||||
{
|
||||
Ready: false,
|
||||
State: corev1.ContainerState{
|
||||
Waiting: &corev1.ContainerStateWaiting{Reason: "ImagePullBackOff"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !isProblemPod(waitingPod) {
|
||||
t.Fatal("expected waiting container to be a problem")
|
||||
}
|
||||
|
||||
initFailedPod := corev1.Pod{
|
||||
Status: corev1.PodStatus{
|
||||
Phase: corev1.PodRunning,
|
||||
InitContainerStatuses: []corev1.ContainerStatus{
|
||||
{
|
||||
Ready: false,
|
||||
State: corev1.ContainerState{
|
||||
Terminated: &corev1.ContainerStateTerminated{ExitCode: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !isProblemPod(initFailedPod) {
|
||||
t.Fatal("expected failed init container to be a problem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeContainerState(t *testing.T) {
|
||||
status, reason, message := summarizeContainerState(corev1.ContainerStatus{
|
||||
State: corev1.ContainerState{Running: &corev1.ContainerStateRunning{}},
|
||||
})
|
||||
if status != "running" || reason != "" || message != "" {
|
||||
t.Fatalf("unexpected running summary: %s %s %s", status, reason, message)
|
||||
}
|
||||
|
||||
status, reason, message = summarizeContainerState(corev1.ContainerStatus{
|
||||
State: corev1.ContainerState{
|
||||
Waiting: &corev1.ContainerStateWaiting{Reason: "CrashLoopBackOff", Message: " waiting "},
|
||||
},
|
||||
})
|
||||
if status != "waiting" || reason != "CrashLoopBackOff" || message != "waiting" {
|
||||
t.Fatalf("unexpected waiting summary: %s %s %s", status, reason, message)
|
||||
}
|
||||
|
||||
status, reason, message = summarizeContainerState(corev1.ContainerStatus{
|
||||
State: corev1.ContainerState{
|
||||
Terminated: &corev1.ContainerStateTerminated{Reason: "Error", Message: " boom "},
|
||||
},
|
||||
})
|
||||
if status != "terminated" || reason != "Error" || message != "boom" {
|
||||
t.Fatalf("unexpected terminated summary: %s %s %s", status, reason, message)
|
||||
}
|
||||
|
||||
status, reason, message = summarizeContainerState(corev1.ContainerStatus{})
|
||||
if status != "unknown" || reason != "" || message != "" {
|
||||
t.Fatalf("unexpected unknown summary: %s %s %s", status, reason, message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOwnerRef(t *testing.T) {
|
||||
controller := true
|
||||
refs := []metav1.OwnerReference{
|
||||
{Kind: "ReplicaSet", Name: "rs-1"},
|
||||
{Kind: "Deployment", Name: "deploy-1", Controller: &controller},
|
||||
}
|
||||
kind, name := ownerRef(refs)
|
||||
if kind != "Deployment" || name != "deploy-1" {
|
||||
t.Fatalf("expected controller ref, got %s %s", kind, name)
|
||||
}
|
||||
|
||||
kind, name = ownerRef([]metav1.OwnerReference{
|
||||
{Kind: "Job", Name: "job-1"},
|
||||
})
|
||||
if kind != "Job" || name != "job-1" {
|
||||
t.Fatalf("expected first ref, got %s %s", kind, name)
|
||||
}
|
||||
|
||||
kind, name = ownerRef(nil)
|
||||
if kind != "" || name != "" {
|
||||
t.Fatalf("expected empty ref, got %s %s", kind, name)
|
||||
}
|
||||
}
|
||||
111
internal/kubernetesagent/agent_more_test.go
Normal file
111
internal/kubernetesagent/agent_more_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package kubernetesagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/buffer"
|
||||
agentsk8s "github.com/rcourtman/pulse-go-rewrite/pkg/agents/kubernetes"
|
||||
"github.com/rs/zerolog"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
)
|
||||
|
||||
func TestComputeClusterID(t *testing.T) {
|
||||
id1 := computeClusterID("https://k8s", "ctx", "name")
|
||||
id2 := computeClusterID("https://k8s", "ctx", "name")
|
||||
if id1 == "" || id1 != id2 {
|
||||
t.Fatalf("expected stable cluster ID, got %s and %s", id1, id2)
|
||||
}
|
||||
|
||||
id3 := computeClusterID("https://k8s", "ctx", "other")
|
||||
if id3 == id1 {
|
||||
t.Fatalf("expected different IDs for different inputs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushReportsStopsOnError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
logger := zerolog.New(io.Discard)
|
||||
reportBuffer := buffer.New[agentsk8s.Report](10)
|
||||
reportBuffer.Push(agentsk8s.Report{Timestamp: time.Now().UTC()})
|
||||
|
||||
agent := &Agent{
|
||||
cfg: Config{APIToken: "token"},
|
||||
logger: logger,
|
||||
httpClient: server.Client(),
|
||||
pulseURL: server.URL,
|
||||
agentVersion: "v1",
|
||||
reportBuffer: reportBuffer,
|
||||
}
|
||||
|
||||
agent.flushReports(context.Background())
|
||||
if _, ok := reportBuffer.Peek(); !ok {
|
||||
t.Fatal("expected report to remain buffered after failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushReportsSuccess(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
logger := zerolog.New(io.Discard)
|
||||
reportBuffer := buffer.New[agentsk8s.Report](10)
|
||||
reportBuffer.Push(agentsk8s.Report{Timestamp: time.Now().UTC()})
|
||||
reportBuffer.Push(agentsk8s.Report{Timestamp: time.Now().UTC()})
|
||||
|
||||
agent := &Agent{
|
||||
cfg: Config{APIToken: "token"},
|
||||
logger: logger,
|
||||
httpClient: server.Client(),
|
||||
pulseURL: server.URL,
|
||||
agentVersion: "v1",
|
||||
reportBuffer: reportBuffer,
|
||||
}
|
||||
|
||||
agent.flushReports(context.Background())
|
||||
if _, ok := reportBuffer.Peek(); ok {
|
||||
t.Fatal("expected report buffer to be empty after flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOnceBuffersOnSendError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
logger := zerolog.New(io.Discard)
|
||||
reportBuffer := buffer.New[agentsk8s.Report](10)
|
||||
agent := &Agent{
|
||||
cfg: Config{APIToken: "token"},
|
||||
logger: logger,
|
||||
httpClient: server.Client(),
|
||||
pulseURL: server.URL,
|
||||
agentID: "agent1",
|
||||
agentVersion: "v1",
|
||||
interval: time.Second,
|
||||
clusterID: "cluster",
|
||||
clusterName: "cluster",
|
||||
clusterServer: "https://k8s",
|
||||
clusterContext: "ctx",
|
||||
kubeClient: fake.NewSimpleClientset(),
|
||||
reportBuffer: reportBuffer,
|
||||
includeNamespaces: nil,
|
||||
excludeNamespaces: nil,
|
||||
}
|
||||
|
||||
agent.runOnce(context.Background())
|
||||
if _, ok := reportBuffer.Peek(); !ok {
|
||||
t.Fatal("expected report to be buffered on send failure")
|
||||
}
|
||||
}
|
||||
180
internal/monitoring/kubernetes_agents_test.go
Normal file
180
internal/monitoring/kubernetes_agents_test.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package monitoring
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/models"
|
||||
agentsk8s "github.com/rcourtman/pulse-go-rewrite/pkg/agents/kubernetes"
|
||||
)
|
||||
|
||||
func newKubernetesTestMonitor() *Monitor {
|
||||
return &Monitor{
|
||||
state: models.NewState(),
|
||||
config: &config.Config{},
|
||||
removedKubernetesClusters: make(map[string]time.Time),
|
||||
kubernetesTokenBindings: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKubernetesClusterIdentifier(t *testing.T) {
|
||||
report := agentsk8s.Report{
|
||||
Cluster: agentsk8s.ClusterInfo{ID: "cluster-1"},
|
||||
Agent: agentsk8s.AgentInfo{ID: "agent-1"},
|
||||
}
|
||||
if got := normalizeKubernetesClusterIdentifier(report); got != "cluster-1" {
|
||||
t.Fatalf("unexpected identifier: %s", got)
|
||||
}
|
||||
|
||||
report.Cluster.ID = ""
|
||||
if got := normalizeKubernetesClusterIdentifier(report); got != "agent-1" {
|
||||
t.Fatalf("unexpected identifier: %s", got)
|
||||
}
|
||||
|
||||
report.Agent.ID = ""
|
||||
report.Cluster.Server = "https://server"
|
||||
report.Cluster.Context = "ctx"
|
||||
report.Cluster.Name = "name"
|
||||
stableKey := "https://server|ctx|name"
|
||||
sum := sha256.Sum256([]byte(stableKey))
|
||||
expected := hex.EncodeToString(sum[:])
|
||||
if got := normalizeKubernetesClusterIdentifier(report); got != expected {
|
||||
t.Fatalf("unexpected hashed identifier: %s", got)
|
||||
}
|
||||
|
||||
report.Cluster.Server = ""
|
||||
report.Cluster.Context = ""
|
||||
report.Cluster.Name = ""
|
||||
if got := normalizeKubernetesClusterIdentifier(report); got != "" {
|
||||
t.Fatalf("expected empty identifier, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyKubernetesReport(t *testing.T) {
|
||||
monitor := newKubernetesTestMonitor()
|
||||
report := agentsk8s.Report{
|
||||
Agent: agentsk8s.AgentInfo{ID: "agent-1", IntervalSeconds: 10},
|
||||
Cluster: agentsk8s.ClusterInfo{ID: "cluster-1", Name: "cluster"},
|
||||
}
|
||||
|
||||
cluster, err := monitor.ApplyKubernetesReport(report, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cluster.ID != "cluster-1" || cluster.DisplayName != "cluster" {
|
||||
t.Fatalf("unexpected cluster: %+v", cluster)
|
||||
}
|
||||
if !monitor.state.ConnectionHealth[kubernetesConnectionPrefix+"cluster-1"] {
|
||||
t.Fatal("expected connection health to be true")
|
||||
}
|
||||
|
||||
monitor.removedKubernetesClusters["cluster-2"] = time.Now()
|
||||
report.Cluster.ID = "cluster-2"
|
||||
if _, err := monitor.ApplyKubernetesReport(report, nil); err == nil {
|
||||
t.Fatal("expected error for removed cluster")
|
||||
}
|
||||
|
||||
token := &config.APITokenRecord{ID: "token-1", Name: "Token"}
|
||||
monitor.kubernetesTokenBindings["token-1"] = "other-agent"
|
||||
report.Cluster.ID = "cluster-3"
|
||||
report.Agent.ID = "agent-1"
|
||||
if _, err := monitor.ApplyKubernetesReport(report, token); err == nil {
|
||||
t.Fatal("expected error for token bound to different agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveAndReenrollKubernetesCluster(t *testing.T) {
|
||||
monitor := newKubernetesTestMonitor()
|
||||
monitor.kubernetesTokenBindings["token-1"] = "agent-1"
|
||||
monitor.config.APITokens = []config.APITokenRecord{{ID: "token-1"}}
|
||||
monitor.state.UpsertKubernetesCluster(models.KubernetesCluster{
|
||||
ID: "cluster-1",
|
||||
Name: "cluster",
|
||||
DisplayName: "cluster",
|
||||
TokenID: "token-1",
|
||||
TokenName: "Token",
|
||||
})
|
||||
monitor.state.SetConnectionHealth(kubernetesConnectionPrefix+"cluster-1", true)
|
||||
|
||||
_, err := monitor.RemoveKubernetesCluster("cluster-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(monitor.state.GetKubernetesClusters()) != 0 {
|
||||
t.Fatal("expected cluster removed")
|
||||
}
|
||||
if _, exists := monitor.kubernetesTokenBindings["token-1"]; exists {
|
||||
t.Fatal("expected token binding removed")
|
||||
}
|
||||
if _, exists := monitor.state.ConnectionHealth[kubernetesConnectionPrefix+"cluster-1"]; exists {
|
||||
t.Fatal("expected connection health removed")
|
||||
}
|
||||
if len(monitor.state.GetRemovedKubernetesClusters()) != 1 {
|
||||
t.Fatal("expected removed cluster entry")
|
||||
}
|
||||
|
||||
if err := monitor.AllowKubernetesClusterReenroll("cluster-1"); err != nil {
|
||||
t.Fatalf("unexpected reenroll error: %v", err)
|
||||
}
|
||||
if len(monitor.state.GetRemovedKubernetesClusters()) != 0 {
|
||||
t.Fatal("expected removed entry cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKubernetesClusterUpdates(t *testing.T) {
|
||||
monitor := newKubernetesTestMonitor()
|
||||
monitor.state.UpsertKubernetesCluster(models.KubernetesCluster{
|
||||
ID: "cluster-1",
|
||||
Name: "cluster",
|
||||
LastSeen: time.Now().Add(-10 * time.Second),
|
||||
Status: "online",
|
||||
IntervalSeconds: 5,
|
||||
})
|
||||
monitor.state.UpsertKubernetesCluster(models.KubernetesCluster{
|
||||
ID: "cluster-2",
|
||||
Name: "cluster2",
|
||||
LastSeen: time.Now().Add(-10 * time.Hour),
|
||||
Status: "online",
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
monitor.evaluateKubernetesAgents(now)
|
||||
if monitor.state.ConnectionHealth[kubernetesConnectionPrefix+"cluster-1"] != true {
|
||||
t.Fatal("expected cluster-1 healthy")
|
||||
}
|
||||
if monitor.state.ConnectionHealth[kubernetesConnectionPrefix+"cluster-2"] != false {
|
||||
t.Fatal("expected cluster-2 unhealthy")
|
||||
}
|
||||
|
||||
_, err := monitor.UnhideKubernetesCluster("cluster-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected unhide error: %v", err)
|
||||
}
|
||||
if _, err := monitor.MarkKubernetesClusterPendingUninstall("cluster-1"); err != nil {
|
||||
t.Fatalf("unexpected pending uninstall error: %v", err)
|
||||
}
|
||||
if _, err := monitor.SetKubernetesClusterCustomDisplayName("cluster-1", "custom"); err != nil {
|
||||
t.Fatalf("unexpected set display name error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupRemovedKubernetesClusters(t *testing.T) {
|
||||
monitor := newKubernetesTestMonitor()
|
||||
monitor.removedKubernetesClusters["cluster-1"] = time.Now().Add(-2 * removedKubernetesClustersTTL)
|
||||
monitor.state.AddRemovedKubernetesCluster(models.RemovedKubernetesCluster{
|
||||
ID: "cluster-1",
|
||||
Name: "cluster",
|
||||
RemovedAt: time.Now().Add(-2 * removedKubernetesClustersTTL),
|
||||
})
|
||||
|
||||
monitor.cleanupRemovedKubernetesClusters(time.Now())
|
||||
if len(monitor.removedKubernetesClusters) != 0 {
|
||||
t.Fatal("expected removed clusters cleaned up")
|
||||
}
|
||||
if len(monitor.state.GetRemovedKubernetesClusters()) != 0 {
|
||||
t.Fatal("expected state cleanup")
|
||||
}
|
||||
}
|
||||
58
internal/monitoring/reload_test.go
Normal file
58
internal/monitoring/reload_test.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/mock"
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/websocket"
|
||||
)
|
||||
|
||||
func TestReloadableMonitorLifecycle(t *testing.T) {
|
||||
t.Setenv("PULSE_DATA_DIR", t.TempDir())
|
||||
mock.SetEnabled(true)
|
||||
defer mock.SetEnabled(false)
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
hub := websocket.NewHub(nil)
|
||||
rm, err := NewReloadableMonitor(cfg, hub)
|
||||
if err != nil {
|
||||
t.Fatalf("new reloadable monitor: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
rm.Start(ctx)
|
||||
|
||||
if err := rm.Reload(); err != nil {
|
||||
t.Fatalf("reload error: %v", err)
|
||||
}
|
||||
|
||||
if rm.GetMonitor() == nil {
|
||||
t.Fatal("expected monitor instance")
|
||||
}
|
||||
if rm.GetConfig() == nil {
|
||||
t.Fatal("expected config instance")
|
||||
}
|
||||
|
||||
rm.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Allow any goroutines to observe cancel without blocking test.
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadableMonitorGetConfigNil(t *testing.T) {
|
||||
rm := &ReloadableMonitor{}
|
||||
if rm.GetConfig() != nil {
|
||||
t.Fatal("expected nil config")
|
||||
}
|
||||
}
|
||||
|
|
@ -111,3 +111,58 @@ func TestClient_Fetch(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_ResolveHostID(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/agents/host/lookup" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Query().Get("hostname") {
|
||||
case "known":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"success":true,"host":{"id":"host-123"}}`))
|
||||
case "unknown":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"success":false}`))
|
||||
case "bad":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case "invalid":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`not-json`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := New(Config{
|
||||
PulseURL: ts.URL,
|
||||
APIToken: "token",
|
||||
})
|
||||
|
||||
if got, err := client.resolveHostID(context.Background()); err != nil || got != "" {
|
||||
t.Fatalf("expected empty hostID for blank hostname, got %q err=%v", got, err)
|
||||
}
|
||||
|
||||
client.cfg.Hostname = "known"
|
||||
if got, err := client.resolveHostID(context.Background()); err != nil || got != "host-123" {
|
||||
t.Fatalf("expected host-123, got %q err=%v", got, err)
|
||||
}
|
||||
|
||||
client.cfg.Hostname = "unknown"
|
||||
if got, err := client.resolveHostID(context.Background()); err != nil || got != "" {
|
||||
t.Fatalf("expected empty hostID, got %q err=%v", got, err)
|
||||
}
|
||||
|
||||
client.cfg.Hostname = "bad"
|
||||
if _, err := client.resolveHostID(context.Background()); err == nil {
|
||||
t.Fatal("expected error for server failure")
|
||||
}
|
||||
|
||||
client.cfg.Hostname = "invalid"
|
||||
if _, err := client.resolveHostID(context.Background()); err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
package remoteconfig
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -34,3 +37,83 @@ func TestVerifyConfigPayloadSignature_WithEnvKey(t *testing.T) {
|
|||
t.Fatalf("VerifyConfigPayloadSignature: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeEd25519PrivateKey(t *testing.T) {
|
||||
if _, err := DecodeEd25519PrivateKey(""); err == nil {
|
||||
t.Fatal("expected error for empty key")
|
||||
}
|
||||
if _, err := DecodeEd25519PrivateKey("not-base64"); err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
|
||||
_, priv, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
full := base64.StdEncoding.EncodeToString(priv)
|
||||
decoded, err := DecodeEd25519PrivateKey(full)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeEd25519PrivateKey full: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decoded, priv) {
|
||||
t.Fatal("expected decoded private key to match")
|
||||
}
|
||||
|
||||
seed := base64.StdEncoding.EncodeToString(priv.Seed())
|
||||
decoded, err = DecodeEd25519PrivateKey(seed)
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeEd25519PrivateKey seed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decoded.Seed(), priv.Seed()) {
|
||||
t.Fatal("expected decoded seed to match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustedConfigPublicKeys(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", base64.StdEncoding.EncodeToString(pub))
|
||||
keys, err := trustedConfigPublicKeys()
|
||||
if err != nil || len(keys) != 1 {
|
||||
t.Fatalf("expected 1 key, got %d err=%v", len(keys), err)
|
||||
}
|
||||
|
||||
raw, err := x509.MarshalPKIXPublicKey(pub)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalPKIXPublicKey: %v", err)
|
||||
}
|
||||
block := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: raw})
|
||||
t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", string(block))
|
||||
keys, err = trustedConfigPublicKeys()
|
||||
if err != nil || len(keys) != 1 {
|
||||
t.Fatalf("expected 1 pem key, got %d err=%v", len(keys), err)
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_AGENT_CONFIG_PUBLIC_KEYS", "not-base64")
|
||||
if _, err := trustedConfigPublicKeys(); err == nil {
|
||||
t.Fatal("expected error for invalid public key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalCanonicalValue(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"b": 1,
|
||||
"a": []interface{}{
|
||||
map[string]interface{}{"d": "x", "c": "y"},
|
||||
2,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := marshalCanonicalValue(input)
|
||||
if err != nil {
|
||||
t.Fatalf("marshalCanonicalValue error: %v", err)
|
||||
}
|
||||
expected := `{"a":[{"c":"y","d":"x"},2],"b":1}`
|
||||
if string(data) != expected {
|
||||
t.Fatalf("unexpected canonical JSON: %s", string(data))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
55
internal/sensors/collector_test.go
Normal file
55
internal/sensors/collector_test.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
package sensors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeScript(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
path := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(path, []byte(content), 0700); err != nil {
|
||||
t.Fatalf("write script: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalMissingSensors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
t.Setenv("PATH", dir)
|
||||
|
||||
if _, err := CollectLocal(context.Background()); err == nil {
|
||||
t.Fatal("expected error when sensors missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalSensorsOutput(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeScript(t, dir, "sensors", "#!/bin/sh\necho '{\"chip\":{\"temp\":{\"temp1_input\":42}}}'\n")
|
||||
t.Setenv("PATH", dir+":"+os.Getenv("PATH"))
|
||||
|
||||
out, err := CollectLocal(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out != "{\"chip\":{\"temp\":{\"temp1_input\":42}}}" {
|
||||
t.Fatalf("unexpected output: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectLocalFallbackToPiTemp(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeScript(t, dir, "sensors", "#!/bin/sh\necho '{}'\n")
|
||||
writeScript(t, dir, "cat", "#!/bin/sh\necho '42000'\n")
|
||||
t.Setenv("PATH", dir+":"+os.Getenv("PATH"))
|
||||
|
||||
out, err := CollectLocal(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
expected := `{"cpu_thermal-virtual-0":{"temp1":{"temp1_input":42000}}}`
|
||||
if out != expected {
|
||||
t.Fatalf("unexpected fallback output: %s", out)
|
||||
}
|
||||
}
|
||||
175
internal/updates/adapter_installsh_exec_test.go
Normal file
175
internal/updates/adapter_installsh_exec_test.go
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInstallShAdapter_DownloadBinary(t *testing.T) {
|
||||
tarball := filepath.Join(t.TempDir(), "pulse.tar.gz")
|
||||
writeTarGz(t, tarball, map[string]string{
|
||||
"bin/pulse": "binary",
|
||||
})
|
||||
|
||||
data, err := os.ReadFile(tarball)
|
||||
if err != nil {
|
||||
t.Fatalf("read tarball: %v", err)
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
dir := t.TempDir()
|
||||
curl := filepath.Join(dir, "curl")
|
||||
script := strings.Join([]string{
|
||||
"#!/bin/sh",
|
||||
`out=""`,
|
||||
`url=""`,
|
||||
`while [ "$#" -gt 0 ]; do`,
|
||||
` if [ "$1" = "-o" ]; then`,
|
||||
` out="$2"`,
|
||||
` shift 2`,
|
||||
` continue`,
|
||||
` fi`,
|
||||
` url="$1"`,
|
||||
` shift`,
|
||||
`done`,
|
||||
`if echo "$url" | grep -q ".sha256$"; then`,
|
||||
` echo "$PULSE_TEST_CHECKSUM pulse.tar.gz" > "$out"`,
|
||||
`else`,
|
||||
` cat "$PULSE_TEST_TARBALL" > "$out"`,
|
||||
`fi`,
|
||||
``,
|
||||
}, "\n")
|
||||
if err := os.WriteFile(curl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write curl stub: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
t.Setenv("PULSE_TEST_TARBALL", tarball)
|
||||
t.Setenv("PULSE_TEST_CHECKSUM", checksum)
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
binaryPath, err := adapter.downloadBinary(context.Background(), "1.2.3")
|
||||
if err != nil {
|
||||
t.Fatalf("downloadBinary error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(filepath.Dir(filepath.Dir(filepath.Dir(binaryPath))))
|
||||
})
|
||||
|
||||
payload, err := os.ReadFile(binaryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read binary: %v", err)
|
||||
}
|
||||
if string(payload) != "binary" {
|
||||
t.Fatalf("unexpected binary content: %q", string(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_WaitForHealth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
curl := filepath.Join(dir, "curl")
|
||||
script := "#!/bin/sh\nexit 0\n"
|
||||
if err := os.WriteFile(curl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write curl stub: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
if err := adapter.waitForHealth(context.Background(), 200*time.Millisecond); err != nil {
|
||||
t.Fatalf("waitForHealth error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_ServiceCommands(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
systemctl := filepath.Join(dir, "systemctl")
|
||||
script := "#!/bin/sh\nexit 0\n"
|
||||
if err := os.WriteFile(systemctl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write systemctl stub: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
if err := adapter.stopService(context.Background(), "pulse"); err != nil {
|
||||
t.Fatalf("stopService error: %v", err)
|
||||
}
|
||||
if err := adapter.startService(context.Background(), "pulse"); err != nil {
|
||||
t.Fatalf("startService error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_RestoreConfig(t *testing.T) {
|
||||
backupDir := filepath.Join(t.TempDir(), "backup")
|
||||
targetDir := filepath.Join(t.TempDir(), "target")
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir backup: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir target: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(backupDir, "config.txt"), []byte("ok"), 0600); err != nil {
|
||||
t.Fatalf("write backup file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(targetDir, "old.txt"), []byte("old"), 0600); err != nil {
|
||||
t.Fatalf("write target file: %v", err)
|
||||
}
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
if err := adapter.restoreConfig(context.Background(), backupDir, targetDir); err != nil {
|
||||
t.Fatalf("restoreConfig error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(targetDir, "config.txt")); err != nil {
|
||||
t.Fatalf("expected restored file: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(targetDir, "old.txt")); err == nil {
|
||||
t.Fatal("expected old file to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_InstallBinary(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
source := filepath.Join(dir, "source")
|
||||
if err := os.WriteFile(source, []byte("payload"), 0600); err != nil {
|
||||
t.Fatalf("write source: %v", err)
|
||||
}
|
||||
|
||||
targetDir := filepath.Join(dir, "bin")
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir target: %v", err)
|
||||
}
|
||||
target := filepath.Join(targetDir, "pulse")
|
||||
if err := os.WriteFile(target, []byte("old"), 0600); err != nil {
|
||||
t.Fatalf("write target: %v", err)
|
||||
}
|
||||
|
||||
chownDir := t.TempDir()
|
||||
chown := filepath.Join(chownDir, "chown")
|
||||
if err := os.WriteFile(chown, []byte("#!/bin/sh\nexit 0\n"), 0755); err != nil {
|
||||
t.Fatalf("write chown stub: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", chownDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
if err := adapter.installBinary(context.Background(), source, target); err != nil {
|
||||
t.Fatalf("installBinary error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(target + ".pre-rollback"); err != nil {
|
||||
t.Fatalf("expected backup file: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(target)
|
||||
if err != nil {
|
||||
t.Fatalf("read target: %v", err)
|
||||
}
|
||||
if string(content) != "payload" {
|
||||
t.Fatalf("unexpected target content: %q", string(content))
|
||||
}
|
||||
}
|
||||
119
internal/updates/adapter_installsh_execute_test.go
Normal file
119
internal/updates/adapter_installsh_execute_test.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupInstallShCurlStub(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
dir := t.TempDir()
|
||||
curl := filepath.Join(dir, "curl")
|
||||
script := strings.Join([]string{
|
||||
"#!/bin/sh",
|
||||
`out=""`,
|
||||
`url=""`,
|
||||
`while [ "$#" -gt 0 ]; do`,
|
||||
` if [ "$1" = "-o" ]; then`,
|
||||
` out="$2"`,
|
||||
` shift 2`,
|
||||
` continue`,
|
||||
` fi`,
|
||||
` url="$1"`,
|
||||
` shift`,
|
||||
`done`,
|
||||
`if echo "$url" | grep -q ".sha256$"; then`,
|
||||
` echo "` + checksum + ` install.sh" > "$out"`,
|
||||
`else`,
|
||||
` printf '%s' "` + content + `" > "$out"`,
|
||||
`fi`,
|
||||
``,
|
||||
}, "\n")
|
||||
if err := os.WriteFile(curl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write curl stub: %v", err)
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
||||
|
||||
func setupBashStub(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
bash := filepath.Join(dir, "bash")
|
||||
script := strings.Join([]string{
|
||||
"#!/bin/sh",
|
||||
"cat >/dev/null",
|
||||
"echo \"Backup: /tmp/backup\"",
|
||||
"echo \"Installing\"",
|
||||
"echo \"Success\"",
|
||||
"exit 0",
|
||||
"",
|
||||
}, "\n")
|
||||
if err := os.WriteFile(bash, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write bash stub: %v", err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestInstallShAdapterExecuteSuccess(t *testing.T) {
|
||||
scriptContent := "echo ok"
|
||||
curlDir := setupInstallShCurlStub(t, scriptContent)
|
||||
bashDir := setupBashStub(t)
|
||||
|
||||
t.Setenv("PATH", strings.Join([]string{curlDir, bashDir, os.Getenv("PATH")}, string(os.PathListSeparator)))
|
||||
|
||||
adapter := &InstallShAdapter{
|
||||
installScriptURL: "http://example/install.sh",
|
||||
logDir: t.TempDir(),
|
||||
}
|
||||
|
||||
var updates []UpdateProgress
|
||||
progress := func(p UpdateProgress) {
|
||||
updates = append(updates, p)
|
||||
}
|
||||
|
||||
err := adapter.Execute(context.Background(), UpdateRequest{Version: "v1.2.3"}, progress)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected progress updates")
|
||||
}
|
||||
foundCompleted := false
|
||||
for _, progress := range updates {
|
||||
if progress.Stage == "completed" && progress.IsComplete {
|
||||
foundCompleted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundCompleted {
|
||||
t.Fatalf("expected completed progress, got %+v", updates)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapterExecuteInvalidVersion(t *testing.T) {
|
||||
scriptContent := "echo ok"
|
||||
curlDir := setupInstallShCurlStub(t, scriptContent)
|
||||
t.Setenv("PATH", curlDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{
|
||||
installScriptURL: "http://example/install.sh",
|
||||
logDir: t.TempDir(),
|
||||
}
|
||||
|
||||
err := adapter.Execute(context.Background(), UpdateRequest{Version: "bad version"}, func(UpdateProgress) {})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid version")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid version format") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
80
internal/updates/adapter_installsh_helpers_test.go
Normal file
80
internal/updates/adapter_installsh_helpers_test.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInstallShAdapter_DetectServiceName(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
systemctl := filepath.Join(dir, "systemctl")
|
||||
script := `#!/bin/sh
|
||||
if [ "$1" = "is-active" ] && [ "$2" = "pulse-backend" ]; then
|
||||
echo "active"
|
||||
exit 0
|
||||
fi
|
||||
echo "inactive"
|
||||
exit 0
|
||||
`
|
||||
if err := os.WriteFile(systemctl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write systemctl: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{}
|
||||
name, err := adapter.detectServiceName()
|
||||
if err != nil {
|
||||
t.Fatalf("detectServiceName error: %v", err)
|
||||
}
|
||||
if name != "pulse-backend" {
|
||||
t.Fatalf("expected pulse-backend, got %q", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_DownloadInstallScript(t *testing.T) {
|
||||
content := "echo hi"
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
dir := t.TempDir()
|
||||
curl := filepath.Join(dir, "curl")
|
||||
script := strings.Join([]string{
|
||||
"#!/bin/sh",
|
||||
`out=""`,
|
||||
`url=""`,
|
||||
`while [ "$#" -gt 0 ]; do`,
|
||||
` if [ "$1" = "-o" ]; then`,
|
||||
` out="$2"`,
|
||||
` shift 2`,
|
||||
` continue`,
|
||||
` fi`,
|
||||
` url="$1"`,
|
||||
` shift`,
|
||||
`done`,
|
||||
`if echo "$url" | grep -q ".sha256$"; then`,
|
||||
` echo "` + checksum + ` install.sh" > "$out"`,
|
||||
`else`,
|
||||
` printf '%s' "` + content + `" > "$out"`,
|
||||
`fi`,
|
||||
``,
|
||||
}, "\n")
|
||||
if err := os.WriteFile(curl, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write curl: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
adapter := &InstallShAdapter{installScriptURL: "http://example/install.sh"}
|
||||
out, err := adapter.downloadInstallScript(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("downloadInstallScript error: %v", err)
|
||||
}
|
||||
if out != content {
|
||||
t.Fatalf("unexpected script content: %q", out)
|
||||
}
|
||||
}
|
||||
78
internal/updates/adapter_installsh_more_test.go
Normal file
78
internal/updates/adapter_installsh_more_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInstallShAdapter_PrepareUpdate(t *testing.T) {
|
||||
adapter := NewInstallShAdapter(nil)
|
||||
|
||||
plan, err := adapter.PrepareUpdate(context.Background(), UpdateRequest{Version: "v1.2.3"})
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareUpdate error: %v", err)
|
||||
}
|
||||
if !plan.CanAutoUpdate || !plan.RequiresRoot || !plan.RollbackSupport {
|
||||
t.Fatalf("unexpected plan: %+v", plan)
|
||||
}
|
||||
if len(plan.Instructions) == 0 || len(plan.Prerequisites) == 0 {
|
||||
t.Fatalf("expected instructions and prerequisites: %+v", plan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallShAdapter_RollbackErrors(t *testing.T) {
|
||||
history, err := NewUpdateHistory(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("NewUpdateHistory error: %v", err)
|
||||
}
|
||||
adapter := NewInstallShAdapter(history)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := adapter.Rollback(ctx, "missing"); err == nil {
|
||||
t.Fatal("expected error for missing history entry")
|
||||
}
|
||||
|
||||
eventNoBackup, err := history.CreateEntry(ctx, UpdateHistoryEntry{
|
||||
Action: ActionUpdate,
|
||||
Status: StatusSuccess,
|
||||
VersionFrom: "v1.0.0",
|
||||
VersionTo: "v1.1.0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEntry error: %v", err)
|
||||
}
|
||||
if err := adapter.Rollback(ctx, eventNoBackup); err == nil || !strings.Contains(err.Error(), "no backup path") {
|
||||
t.Fatalf("expected backup path error, got %v", err)
|
||||
}
|
||||
|
||||
eventMissingBackup, err := history.CreateEntry(ctx, UpdateHistoryEntry{
|
||||
Action: ActionUpdate,
|
||||
Status: StatusSuccess,
|
||||
VersionFrom: "v1.0.0",
|
||||
VersionTo: "v1.1.0",
|
||||
BackupPath: filepath.Join(t.TempDir(), "missing"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEntry error: %v", err)
|
||||
}
|
||||
if err := adapter.Rollback(ctx, eventMissingBackup); err == nil || !strings.Contains(err.Error(), "backup not found") {
|
||||
t.Fatalf("expected backup not found error, got %v", err)
|
||||
}
|
||||
|
||||
backupDir := t.TempDir()
|
||||
eventNoTarget, err := history.CreateEntry(ctx, UpdateHistoryEntry{
|
||||
Action: ActionUpdate,
|
||||
Status: StatusSuccess,
|
||||
VersionFrom: "",
|
||||
VersionTo: "v1.1.0",
|
||||
BackupPath: backupDir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateEntry error: %v", err)
|
||||
}
|
||||
if err := adapter.Rollback(ctx, eventNoTarget); err == nil || !strings.Contains(err.Error(), "no target version") {
|
||||
t.Fatalf("expected target version error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
|
@ -410,14 +411,14 @@ func TestDockerUpdater(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Execute returns error", func(t *testing.T) {
|
||||
err := updater.Execute(nil, UpdateRequest{}, nil)
|
||||
err := updater.Execute(context.Background(), UpdateRequest{}, nil)
|
||||
if err == nil {
|
||||
t.Error("Execute() should return error for docker deployments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Rollback returns error", func(t *testing.T) {
|
||||
err := updater.Rollback(nil, "event-123")
|
||||
err := updater.Rollback(context.Background(), "event-123")
|
||||
if err == nil {
|
||||
t.Error("Rollback() should return error for docker deployments")
|
||||
}
|
||||
|
|
@ -440,14 +441,14 @@ func TestAURUpdater(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Execute returns error", func(t *testing.T) {
|
||||
err := updater.Execute(nil, UpdateRequest{}, nil)
|
||||
err := updater.Execute(context.Background(), UpdateRequest{}, nil)
|
||||
if err == nil {
|
||||
t.Error("Execute() should return error for AUR deployments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Rollback returns error", func(t *testing.T) {
|
||||
err := updater.Rollback(nil, "event-123")
|
||||
err := updater.Rollback(context.Background(), "event-123")
|
||||
if err == nil {
|
||||
t.Error("Rollback() should return error for AUR deployments")
|
||||
}
|
||||
|
|
|
|||
122
internal/updates/manager_applyupdate_test.go
Normal file
122
internal/updates/manager_applyupdate_test.go
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeStub(t *testing.T, dir, name, script string) {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(path, []byte(script), 0755); err != nil {
|
||||
t.Fatalf("write stub %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func copyFile(src, dest string, mode os.FileMode) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, copyErr := io.Copy(out, in)
|
||||
closeErr := out.Close()
|
||||
if copyErr != nil {
|
||||
return copyErr
|
||||
}
|
||||
if closeErr != nil {
|
||||
return closeErr
|
||||
}
|
||||
if mode != 0 {
|
||||
return os.Chmod(dest, mode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManagerApplyUpdateFilesMissingBinary(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
if err := manager.applyUpdateFiles(t.TempDir()); err == nil {
|
||||
t.Fatal("expected error for missing pulse binary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerApplyUpdateFilesCopiesPulseBinary(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
cpPath, err := exec.LookPath("cp")
|
||||
if err != nil {
|
||||
t.Fatalf("find cp: %v", err)
|
||||
}
|
||||
mvPath, err := exec.LookPath("mv")
|
||||
if err != nil {
|
||||
t.Fatalf("find mv: %v", err)
|
||||
}
|
||||
|
||||
stubDir := t.TempDir()
|
||||
writeStub(t, stubDir, "cp", fmt.Sprintf("#!/bin/sh\nexec %s \"$@\"\n", cpPath))
|
||||
writeStub(t, stubDir, "mv", fmt.Sprintf("#!/bin/sh\nexec %s \"$@\"\n", mvPath))
|
||||
writeStub(t, stubDir, "chown", "#!/bin/sh\nexit 0\n")
|
||||
t.Setenv("PATH", stubDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
binaryPath, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatalf("executable path: %v", err)
|
||||
}
|
||||
info, err := os.Stat(binaryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat executable: %v", err)
|
||||
}
|
||||
backup := filepath.Join(t.TempDir(), "orig-binary")
|
||||
if err := copyFile(binaryPath, backup, info.Mode()); err != nil {
|
||||
t.Fatalf("backup binary: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := copyFile(backup, binaryPath, info.Mode()); err != nil {
|
||||
t.Fatalf("restore binary: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
extractDir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(extractDir, "pulse"), []byte("newbinary"), 0755); err != nil {
|
||||
t.Fatalf("write pulse: %v", err)
|
||||
}
|
||||
if err := manager.applyUpdateFiles(extractDir); err != nil {
|
||||
t.Fatalf("applyUpdateFiles root pulse: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(binaryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read replaced binary: %v", err)
|
||||
}
|
||||
if string(data) != "newbinary" {
|
||||
t.Fatalf("unexpected root binary contents: %q", string(data))
|
||||
}
|
||||
|
||||
extractDir = t.TempDir()
|
||||
if err := os.MkdirAll(filepath.Join(extractDir, "bin"), 0755); err != nil {
|
||||
t.Fatalf("mkdir bin: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(extractDir, "bin", "pulse"), []byte("newbinary2"), 0755); err != nil {
|
||||
t.Fatalf("write pulse bin: %v", err)
|
||||
}
|
||||
if err := manager.applyUpdateFiles(extractDir); err != nil {
|
||||
t.Fatalf("applyUpdateFiles bin pulse: %v", err)
|
||||
}
|
||||
data, err = os.ReadFile(binaryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read replaced binary (bin): %v", err)
|
||||
}
|
||||
if string(data) != "newbinary2" {
|
||||
t.Fatalf("unexpected bin binary contents: %q", string(data))
|
||||
}
|
||||
}
|
||||
140
internal/updates/manager_checksum_test.go
Normal file
140
internal/updates/manager_checksum_test.go
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
func TestManagerVerifyChecksum(t *testing.T) {
|
||||
tarballPath := filepath.Join(t.TempDir(), "pulse.tar.gz")
|
||||
if err := os.WriteFile(tarballPath, []byte("payload"), 0600); err != nil {
|
||||
t.Fatalf("write tarball: %v", err)
|
||||
}
|
||||
sum := sha256.Sum256([]byte("payload"))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "SHA256SUMS") {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(checksum + " pulse.tar.gz\n"))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager(nil)
|
||||
tarballURL := server.URL + "/pulse.tar.gz"
|
||||
|
||||
if err := manager.verifyChecksum(context.Background(), tarballURL, tarballPath); err != nil {
|
||||
t.Fatalf("verifyChecksum error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdaterPrepareUpdateInstructions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
docker := NewDockerUpdater()
|
||||
plan, err := docker.PrepareUpdate(ctx, UpdateRequest{Version: "v1.2.3"})
|
||||
if err != nil {
|
||||
t.Fatalf("docker PrepareUpdate error: %v", err)
|
||||
}
|
||||
if plan.CanAutoUpdate || len(plan.Instructions) == 0 {
|
||||
t.Fatalf("unexpected docker plan: %+v", plan)
|
||||
}
|
||||
|
||||
aur := NewAURUpdater()
|
||||
plan, err = aur.PrepareUpdate(ctx, UpdateRequest{Version: "v1.2.3"})
|
||||
if err != nil {
|
||||
t.Fatalf("aur PrepareUpdate error: %v", err)
|
||||
}
|
||||
if plan.CanAutoUpdate || len(plan.Instructions) == 0 {
|
||||
t.Fatalf("unexpected aur plan: %+v", plan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdaterExecuteRollbackErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := NewDockerUpdater().Execute(ctx, UpdateRequest{}, nil); err == nil {
|
||||
t.Fatal("expected docker Execute error")
|
||||
}
|
||||
if err := NewDockerUpdater().Rollback(ctx, "event"); err == nil {
|
||||
t.Fatal("expected docker Rollback error")
|
||||
}
|
||||
|
||||
if err := NewAURUpdater().Execute(ctx, UpdateRequest{}, nil); err == nil {
|
||||
t.Fatal("expected aur Execute error")
|
||||
}
|
||||
if err := NewAURUpdater().Rollback(ctx, "event"); err == nil {
|
||||
t.Fatal("expected aur Rollback error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCloseAndBackup(t *testing.T) {
|
||||
manager := NewManager(&config.Config{})
|
||||
manager.Close()
|
||||
|
||||
select {
|
||||
case _, ok := <-manager.GetProgressChannel():
|
||||
if ok {
|
||||
t.Fatal("expected progress channel to be closed")
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected closed progress channel")
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_DATA_DIR", t.TempDir())
|
||||
|
||||
dataDir := filepath.Join("/opt/pulse", "data")
|
||||
configDir := filepath.Join("/opt/pulse", "config")
|
||||
|
||||
if _, err := os.Stat(dataDir); err == nil {
|
||||
t.Skip("data dir already exists; skip backup test to avoid interference")
|
||||
}
|
||||
if _, err := os.Stat(configDir); err == nil {
|
||||
t.Skip("config dir already exists; skip backup test to avoid interference")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir data: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir config: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(dataDir)
|
||||
_ = os.RemoveAll(configDir)
|
||||
})
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dataDir, "data.txt"), []byte("ok"), 0600); err != nil {
|
||||
t.Fatalf("write data file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(configDir, "config.txt"), []byte("ok"), 0600); err != nil {
|
||||
t.Fatalf("write config file: %v", err)
|
||||
}
|
||||
|
||||
backupDir, err := manager.createBackup()
|
||||
if err != nil {
|
||||
t.Fatalf("createBackup error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(backupDir)
|
||||
})
|
||||
|
||||
if _, err := os.Stat(filepath.Join(backupDir, "data", "data.txt")); err != nil {
|
||||
t.Fatalf("expected data backup: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(backupDir, "config", "config.txt")); err != nil {
|
||||
t.Fatalf("expected config backup: %v", err)
|
||||
}
|
||||
}
|
||||
139
internal/updates/manager_fileops_test.go
Normal file
139
internal/updates/manager_fileops_test.go
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeTarGz(t *testing.T, path string, files map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
var buf bytes.Buffer
|
||||
gzw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gzw)
|
||||
|
||||
for name, content := range files {
|
||||
hdr := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0644,
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
t.Fatalf("write header: %v", err)
|
||||
}
|
||||
if _, err := io.WriteString(tw, content); err != nil {
|
||||
t.Fatalf("write content: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tw.Close(); err != nil {
|
||||
t.Fatalf("close tar: %v", err)
|
||||
}
|
||||
if err := gzw.Close(); err != nil {
|
||||
t.Fatalf("close gzip: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, buf.Bytes(), 0600); err != nil {
|
||||
t.Fatalf("write tarball: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExtractTarball(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
src := filepath.Join(t.TempDir(), "update.tar.gz")
|
||||
dest := filepath.Join(t.TempDir(), "extract")
|
||||
|
||||
writeTarGz(t, src, map[string]string{
|
||||
"bin/pulse": "binary",
|
||||
})
|
||||
|
||||
if err := manager.extractTarball(src, dest); err != nil {
|
||||
t.Fatalf("extractTarball error: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dest, "bin", "pulse"))
|
||||
if err != nil {
|
||||
t.Fatalf("read extracted file: %v", err)
|
||||
}
|
||||
if string(data) != "binary" {
|
||||
t.Fatalf("unexpected file content: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExtractTarballRejectsUnsafePaths(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
src := filepath.Join(t.TempDir(), "bad.tar.gz")
|
||||
dest := filepath.Join(t.TempDir(), "extract")
|
||||
|
||||
writeTarGz(t, src, map[string]string{
|
||||
"../evil": "nope",
|
||||
})
|
||||
|
||||
if err := manager.extractTarball(src, dest); err == nil {
|
||||
t.Fatal("expected error for unsafe path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCopyFileSafe(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
dir := t.TempDir()
|
||||
|
||||
src := filepath.Join(dir, "src.txt")
|
||||
dest := filepath.Join(dir, "dest.txt")
|
||||
if err := os.WriteFile(src, []byte("payload"), 0600); err != nil {
|
||||
t.Fatalf("write src: %v", err)
|
||||
}
|
||||
|
||||
if err := manager.copyFileSafe(src, dest); err != nil {
|
||||
t.Fatalf("copyFileSafe error: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("read dest: %v", err)
|
||||
}
|
||||
if string(data) != "payload" {
|
||||
t.Fatalf("unexpected dest content: %q", string(data))
|
||||
}
|
||||
|
||||
link := filepath.Join(dir, "link.txt")
|
||||
if err := os.Symlink(src, link); err != nil {
|
||||
t.Fatalf("symlink: %v", err)
|
||||
}
|
||||
skipDest := filepath.Join(dir, "skip.txt")
|
||||
if err := manager.copyFileSafe(link, skipDest); err != nil {
|
||||
t.Fatalf("copyFileSafe symlink error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(skipDest); err == nil {
|
||||
t.Fatal("expected symlink copy to be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCopyDirSafe(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
srcDir := filepath.Join(t.TempDir(), "src")
|
||||
destDir := filepath.Join(t.TempDir(), "dest")
|
||||
|
||||
if err := os.MkdirAll(srcDir, 0755); err != nil {
|
||||
t.Fatalf("mkdir src: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(srcDir, "ok.txt"), []byte("ok"), 0600); err != nil {
|
||||
t.Fatalf("write ok: %v", err)
|
||||
}
|
||||
if err := os.Symlink(filepath.Join(srcDir, "ok.txt"), filepath.Join(srcDir, "link.txt")); err != nil {
|
||||
t.Fatalf("symlink: %v", err)
|
||||
}
|
||||
|
||||
if err := manager.copyDirSafe(srcDir, destDir); err != nil {
|
||||
t.Fatalf("copyDirSafe error: %v", err)
|
||||
}
|
||||
if _, err := os.ReadFile(filepath.Join(destDir, "ok.txt")); err != nil {
|
||||
t.Fatalf("expected ok.txt copied: %v", err)
|
||||
}
|
||||
if _, err := os.Lstat(filepath.Join(destDir, "link.txt")); err == nil {
|
||||
t.Fatal("expected symlink to be skipped")
|
||||
}
|
||||
}
|
||||
173
internal/updates/manager_more_test.go
Normal file
173
internal/updates/manager_more_test.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
)
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestResolveChannel(t *testing.T) {
|
||||
manager := NewManager(&config.Config{UpdateChannel: "stable"})
|
||||
|
||||
if got := manager.resolveChannel("rc", nil); got != "rc" {
|
||||
t.Fatalf("expected requested channel, got %s", got)
|
||||
}
|
||||
if got := manager.resolveChannel("", nil); got != "stable" {
|
||||
t.Fatalf("expected config channel, got %s", got)
|
||||
}
|
||||
if got := manager.resolveChannel("", &VersionInfo{Channel: "rc"}); got != "stable" {
|
||||
t.Fatalf("expected config to win, got %s", got)
|
||||
}
|
||||
|
||||
manager.config.UpdateChannel = ""
|
||||
if got := manager.resolveChannel("", &VersionInfo{Channel: "rc"}); got != "rc" {
|
||||
t.Fatalf("expected version channel, got %s", got)
|
||||
}
|
||||
if got := manager.resolveChannel("", nil); got != "stable" {
|
||||
t.Fatalf("expected default channel, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCachedUpdateInfo(t *testing.T) {
|
||||
manager := NewManager(&config.Config{UpdateChannel: "stable"})
|
||||
expected := &UpdateInfo{Available: true, LatestVersion: "v1.2.3"}
|
||||
manager.statusMu.Lock()
|
||||
manager.checkCache["stable"] = expected
|
||||
manager.cacheTime["stable"] = time.Now()
|
||||
manager.statusMu.Unlock()
|
||||
|
||||
if got := manager.GetCachedUpdateInfo(); got != expected {
|
||||
t.Fatalf("expected cached info, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerUpdateStatus(t *testing.T) {
|
||||
manager := NewManager(&config.Config{})
|
||||
|
||||
manager.updateStatus("checking", 12, "progress", errors.New("boom"))
|
||||
status := manager.GetStatus()
|
||||
if status.Status != "checking" || status.Progress != 12 || status.Message != "progress" {
|
||||
t.Fatalf("unexpected status: %+v", status)
|
||||
}
|
||||
if status.Error == "" || !strings.Contains(status.Error, "boom") {
|
||||
t.Fatalf("unexpected status error: %s", status.Error)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-manager.GetProgressChannel():
|
||||
if got.Status != "checking" || got.Progress != 12 {
|
||||
t.Fatalf("unexpected progress: %+v", got)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected progress update on channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfiguredStageDelay(t *testing.T) {
|
||||
stageDelayOnce = sync.Once{}
|
||||
stageDelayValue = 0
|
||||
t.Setenv("PULSE_UPDATE_STAGE_DELAY_MS", "15")
|
||||
if got := configuredStageDelay(); got != 15*time.Millisecond {
|
||||
t.Fatalf("expected 15ms, got %v", got)
|
||||
}
|
||||
if got := statusDelayForStage("downloading"); got != 15*time.Millisecond {
|
||||
t.Fatalf("expected 15ms delay for downloading, got %v", got)
|
||||
}
|
||||
if got := statusDelayForStage("idle"); got != 0 {
|
||||
t.Fatalf("expected 0 delay for idle, got %v", got)
|
||||
}
|
||||
|
||||
stageDelayOnce = sync.Once{}
|
||||
stageDelayValue = 0
|
||||
t.Setenv("PULSE_UPDATE_STAGE_DELAY_MS", "bad")
|
||||
if got := configuredStageDelay(); got != 0 {
|
||||
t.Fatalf("expected 0 delay for invalid value, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLatestReleaseFromFeedMocked(t *testing.T) {
|
||||
feed := `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<feed xmlns="http://www.w3.org/2005/Atom">
|
||||
<entry><title>Pulse v5.0.0-rc.1</title></entry>
|
||||
<entry><title>Pulse v4.36.2</title></entry>
|
||||
</feed>`
|
||||
|
||||
origTransport := http.DefaultTransport
|
||||
http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
body := io.NopCloser(strings.NewReader(feed))
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Status: "200 OK",
|
||||
Body: body,
|
||||
Header: http.Header{"Content-Type": []string{"application/atom+xml"}},
|
||||
Request: req,
|
||||
}, nil
|
||||
})
|
||||
t.Cleanup(func() { http.DefaultTransport = origTransport })
|
||||
|
||||
manager := NewManager(&config.Config{})
|
||||
|
||||
release, err := manager.getLatestReleaseFromFeed(context.Background(), "stable")
|
||||
if err != nil {
|
||||
t.Fatalf("stable feed error: %v", err)
|
||||
}
|
||||
if release.TagName != "v4.36.2" {
|
||||
t.Fatalf("unexpected stable tag: %s", release.TagName)
|
||||
}
|
||||
|
||||
release, err = manager.getLatestReleaseFromFeed(context.Background(), "rc")
|
||||
if err != nil {
|
||||
t.Fatalf("rc feed error: %v", err)
|
||||
}
|
||||
if release.TagName != "v5.0.0-rc.1" {
|
||||
t.Fatalf("unexpected rc tag: %s", release.TagName)
|
||||
}
|
||||
|
||||
feed = `<?xml version="1.0" encoding="UTF-8"?><feed></feed>`
|
||||
if _, err := manager.getLatestReleaseFromFeed(context.Background(), "stable"); err == nil {
|
||||
t.Fatal("expected error for empty feed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDownloadFile(t *testing.T) {
|
||||
content := "payload"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(content))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager(&config.Config{})
|
||||
dest := filepath.Join(t.TempDir(), "file.bin")
|
||||
|
||||
n, err := manager.downloadFile(context.Background(), server.URL, dest)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadFile error: %v", err)
|
||||
}
|
||||
if n != int64(len(content)) {
|
||||
t.Fatalf("expected %d bytes, got %d", len(content), n)
|
||||
}
|
||||
data, err := os.ReadFile(dest)
|
||||
if err != nil {
|
||||
t.Fatalf("read file error: %v", err)
|
||||
}
|
||||
if string(data) != content {
|
||||
t.Fatalf("unexpected file content: %s", string(data))
|
||||
}
|
||||
}
|
||||
48
internal/updates/mock_updater_test.go
Normal file
48
internal/updates/mock_updater_test.go
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMockUpdaterBasics(t *testing.T) {
|
||||
updater := NewMockUpdater()
|
||||
if updater == nil {
|
||||
t.Fatal("expected updater")
|
||||
}
|
||||
if !updater.SupportsApply() {
|
||||
t.Fatal("expected SupportsApply true")
|
||||
}
|
||||
if updater.GetDeploymentType() != "mock" {
|
||||
t.Fatalf("unexpected deployment type: %s", updater.GetDeploymentType())
|
||||
}
|
||||
|
||||
plan, err := updater.PrepareUpdate(context.Background(), UpdateRequest{Version: "1.2.3"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if plan == nil || !plan.CanAutoUpdate || len(plan.Instructions) == 0 {
|
||||
t.Fatal("unexpected plan result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockUpdaterExecute(t *testing.T) {
|
||||
updater := NewMockUpdater()
|
||||
var stages []UpdateProgress
|
||||
err := updater.Execute(context.Background(), UpdateRequest{}, func(stage UpdateProgress) {
|
||||
stages = append(stages, stage)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(stages) == 0 || !stages[len(stages)-1].IsComplete {
|
||||
t.Fatalf("unexpected stages: %+v", stages)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
err = updater.Execute(ctx, UpdateRequest{}, func(UpdateProgress) {})
|
||||
if err == nil {
|
||||
t.Fatal("expected context error")
|
||||
}
|
||||
}
|
||||
149
internal/websocket/hub_more2_test.go
Normal file
149
internal/websocket/hub_more2_test.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestBroadcastStateEnqueuesRawData(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
state := struct {
|
||||
DockerHosts []string
|
||||
}{DockerHosts: []string{"a", "b"}}
|
||||
|
||||
hub.BroadcastState(state)
|
||||
|
||||
select {
|
||||
case msg := <-hub.broadcastSeq:
|
||||
if msg.Type != "rawData" {
|
||||
t.Fatalf("unexpected message type: %s", msg.Type)
|
||||
}
|
||||
if !reflect.DeepEqual(msg.Data, state) {
|
||||
t.Fatalf("unexpected state payload: %+v", msg.Data)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected broadcastSeq message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastAlertResolvedAndCustom(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
|
||||
hub.BroadcastAlertResolved("alert-1")
|
||||
select {
|
||||
case data := <-hub.broadcast:
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type != "alertResolved" {
|
||||
t.Fatalf("unexpected type: %s", msg.Type)
|
||||
}
|
||||
payload := msg.Data.(map[string]interface{})
|
||||
if payload["alertId"] != "alert-1" {
|
||||
t.Fatalf("unexpected alertId: %v", payload["alertId"])
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected alertResolved broadcast")
|
||||
}
|
||||
|
||||
hub.Broadcast(map[string]string{"status": "ok"})
|
||||
select {
|
||||
case data := <-hub.broadcast:
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type != "custom" {
|
||||
t.Fatalf("unexpected type: %s", msg.Type)
|
||||
}
|
||||
if msg.Timestamp == "" {
|
||||
t.Fatal("expected timestamp on custom broadcast")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected custom broadcast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPingEnqueuesMessage(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
hub.sendPing()
|
||||
|
||||
select {
|
||||
case data := <-hub.broadcast:
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type != "ping" {
|
||||
t.Fatalf("unexpected type: %s", msg.Type)
|
||||
}
|
||||
payload := msg.Data.(map[string]interface{})
|
||||
if _, ok := payload["timestamp"]; !ok {
|
||||
t.Fatal("expected ping timestamp")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected ping broadcast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopClosesChannel(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
hub.Stop()
|
||||
|
||||
select {
|
||||
case _, ok := <-hub.stopChan:
|
||||
if ok {
|
||||
t.Fatal("expected stopChan to be closed")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected stopChan closure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketPingPong(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
go hub.Run()
|
||||
t.Cleanup(hub.Stop)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(hub.HandleWebSocket))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.WriteJSON(Message{Type: "ping"}); err != nil {
|
||||
t.Fatalf("write ping: %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(1 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
|
||||
t.Fatalf("set read deadline: %v", err)
|
||||
}
|
||||
_, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type == "pong" {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatal("expected pong response")
|
||||
}
|
||||
127
internal/websocket/hub_more_test.go
Normal file
127
internal/websocket/hub_more_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDispatchToClientsDropsFullClient(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
client := &Client{
|
||||
hub: hub,
|
||||
send: make(chan []byte, 1),
|
||||
id: "client-1",
|
||||
}
|
||||
|
||||
hub.mu.Lock()
|
||||
hub.clients[client] = true
|
||||
hub.mu.Unlock()
|
||||
|
||||
client.send <- []byte("filled")
|
||||
|
||||
hub.dispatchToClients([]byte("payload"), "drop")
|
||||
|
||||
if hub.GetClientCount() != 0 {
|
||||
t.Fatalf("expected client to be dropped")
|
||||
}
|
||||
|
||||
<-client.send
|
||||
if _, ok := <-client.send; ok {
|
||||
t.Fatal("expected send channel to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunBroadcastSequencerImmediate(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
client := &Client{
|
||||
hub: hub,
|
||||
send: make(chan []byte, 1),
|
||||
id: "client-1",
|
||||
}
|
||||
|
||||
hub.mu.Lock()
|
||||
hub.clients[client] = true
|
||||
hub.mu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hub.runBroadcastSequencer()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
hub.broadcastSeq <- Message{
|
||||
Type: "alert",
|
||||
Data: map[string]string{"id": "a1"},
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-client.send:
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type != "alert" {
|
||||
t.Fatalf("unexpected message type: %s", msg.Type)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected broadcast message")
|
||||
}
|
||||
|
||||
close(hub.stopChan)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("broadcast sequencer did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunBroadcastSequencerCoalescesRawData(t *testing.T) {
|
||||
hub := NewHub(nil)
|
||||
hub.coalesceWindow = 5 * time.Millisecond
|
||||
client := &Client{
|
||||
hub: hub,
|
||||
send: make(chan []byte, 1),
|
||||
id: "client-1",
|
||||
}
|
||||
|
||||
hub.mu.Lock()
|
||||
hub.clients[client] = true
|
||||
hub.mu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hub.runBroadcastSequencer()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
hub.broadcastSeq <- Message{Type: "rawData", Data: map[string]string{"value": "first"}}
|
||||
hub.broadcastSeq <- Message{Type: "rawData", Data: map[string]string{"value": "second"}}
|
||||
|
||||
select {
|
||||
case data := <-client.send:
|
||||
var msg Message
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
payload := msg.Data.(map[string]interface{})
|
||||
if payload["value"] != "second" {
|
||||
t.Fatalf("expected coalesced value 'second', got %v", payload["value"])
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected coalesced message")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-client.send:
|
||||
t.Fatal("expected only one coalesced message")
|
||||
default:
|
||||
}
|
||||
|
||||
close(hub.stopChan)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("broadcast sequencer did not exit")
|
||||
}
|
||||
}
|
||||
94
pkg/audit/export_test.go
Normal file
94
pkg/audit/export_test.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExporterExportAndSummary(t *testing.T) {
|
||||
logger, err := NewSQLiteLogger(SQLiteLoggerConfig{DataDir: t.TempDir()})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create logger: %v", err)
|
||||
}
|
||||
defer logger.Close()
|
||||
|
||||
events := []Event{
|
||||
{
|
||||
ID: "e1",
|
||||
Timestamp: time.Now().Add(-time.Minute),
|
||||
EventType: "login",
|
||||
User: "alice",
|
||||
IP: "127.0.0.1",
|
||||
Success: true,
|
||||
Details: "ok",
|
||||
},
|
||||
{
|
||||
ID: "e2",
|
||||
Timestamp: time.Now(),
|
||||
EventType: "config_change",
|
||||
User: "",
|
||||
IP: "127.0.0.2",
|
||||
Success: false,
|
||||
Details: "failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
if err := logger.Log(event); err != nil {
|
||||
t.Fatalf("log event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
exporter := NewExporter(logger)
|
||||
result, err := exporter.Export(QueryFilter{}, ExportFormatCSV, true)
|
||||
if err != nil {
|
||||
t.Fatalf("export csv: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(result.Filename, "audit-log-") || !strings.HasSuffix(result.Filename, ".csv") {
|
||||
t.Fatalf("unexpected filename: %s", result.Filename)
|
||||
}
|
||||
|
||||
reader := csv.NewReader(strings.NewReader(string(result.Data)))
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
t.Fatalf("read csv: %v", err)
|
||||
}
|
||||
if len(records) < 3 || records[0][0] != "ID" {
|
||||
t.Fatalf("unexpected csv records: %+v", records)
|
||||
}
|
||||
|
||||
jsonResult, err := exporter.Export(QueryFilter{}, ExportFormatJSON, false)
|
||||
if err != nil {
|
||||
t.Fatalf("export json: %v", err)
|
||||
}
|
||||
var parsed struct {
|
||||
EventCount int `json:"event_count"`
|
||||
Events []ExportEvent `json:"events"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonResult.Data, &parsed); err != nil {
|
||||
t.Fatalf("decode json export: %v", err)
|
||||
}
|
||||
if parsed.EventCount != 2 || len(parsed.Events) != 2 {
|
||||
t.Fatalf("unexpected json export: %+v", parsed)
|
||||
}
|
||||
|
||||
if _, err := exporter.Export(QueryFilter{}, "xml", false); err == nil {
|
||||
t.Fatal("expected error for unsupported format")
|
||||
}
|
||||
|
||||
// Tamper with signature to test verification in summary/export
|
||||
if _, err := logger.db.Exec(`UPDATE audit_events SET signature = 'bad' WHERE id = ?`, "e2"); err != nil {
|
||||
t.Fatalf("tamper signature: %v", err)
|
||||
}
|
||||
|
||||
summary, err := exporter.GenerateSummary(QueryFilter{}, true)
|
||||
if err != nil {
|
||||
t.Fatalf("summary error: %v", err)
|
||||
}
|
||||
if summary.TotalEvents != 2 || summary.InvalidSigCount == 0 {
|
||||
t.Fatalf("unexpected summary: %+v", summary)
|
||||
}
|
||||
}
|
||||
|
|
@ -99,3 +99,133 @@ func TestStoreSelectTierAndStats(t *testing.T) {
|
|||
t.Fatalf("expected stats DB info to be populated: %+v", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRollupTier(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := DefaultConfig(dir)
|
||||
cfg.DBPath = filepath.Join(dir, "metrics-rollup.db")
|
||||
cfg.FlushInterval = time.Hour
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore returned error: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
base := time.Now().Add(-2 * time.Minute).Truncate(time.Minute)
|
||||
ts := base.Unix()
|
||||
|
||||
_, err = store.db.Exec(
|
||||
`INSERT INTO metrics (resource_type, resource_id, metric_type, value, timestamp, tier) VALUES
|
||||
('vm','vm-101','cpu',1.0,?, 'raw'),
|
||||
('vm','vm-101','cpu',3.0,?, 'raw')`,
|
||||
ts, base.Add(10*time.Second).Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert metrics returned error: %v", err)
|
||||
}
|
||||
|
||||
store.rollupTier(TierRaw, TierMinute, time.Minute, 0)
|
||||
|
||||
var countRaw int
|
||||
if err := store.db.QueryRow(`SELECT COUNT(*) FROM metrics WHERE tier = 'raw'`).Scan(&countRaw); err != nil {
|
||||
t.Fatalf("query raw count: %v", err)
|
||||
}
|
||||
if countRaw != 0 {
|
||||
t.Fatalf("expected raw metrics to be rolled up, got %d", countRaw)
|
||||
}
|
||||
|
||||
var value, minValue, maxValue float64
|
||||
var bucketTs int64
|
||||
if err := store.db.QueryRow(
|
||||
`SELECT value, min_value, max_value, timestamp FROM metrics WHERE tier = 'minute'`,
|
||||
).Scan(&value, &minValue, &maxValue, &bucketTs); err != nil {
|
||||
t.Fatalf("query minute tier: %v", err)
|
||||
}
|
||||
|
||||
expectedBucket := (ts / 60) * 60
|
||||
if bucketTs != expectedBucket {
|
||||
t.Fatalf("expected bucket %d, got %d", expectedBucket, bucketTs)
|
||||
}
|
||||
if value != 2.0 || minValue != 1.0 || maxValue != 3.0 {
|
||||
t.Fatalf("unexpected rollup values: value=%v min=%v max=%v", value, minValue, maxValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRetentionPrunesOldData(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := DefaultConfig(dir)
|
||||
cfg.DBPath = filepath.Join(dir, "metrics-retention.db")
|
||||
cfg.RetentionRaw = time.Minute
|
||||
cfg.RetentionMinute = time.Minute
|
||||
cfg.RetentionHourly = time.Minute
|
||||
cfg.RetentionDaily = time.Minute
|
||||
cfg.FlushInterval = time.Hour
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore returned error: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
oldTs := time.Now().Add(-2 * time.Hour).Unix()
|
||||
newTs := time.Now().Unix()
|
||||
|
||||
_, err = store.db.Exec(
|
||||
`INSERT INTO metrics (resource_type, resource_id, metric_type, value, timestamp, tier) VALUES
|
||||
('vm','vm-101','cpu',1.0,?, 'raw'),
|
||||
('vm','vm-101','cpu',2.0,?, 'minute'),
|
||||
('vm','vm-101','cpu',3.0,?, 'hourly'),
|
||||
('vm','vm-101','cpu',4.0,?, 'daily'),
|
||||
('vm','vm-101','cpu',5.0,?, 'raw')`,
|
||||
oldTs, oldTs, oldTs, oldTs, newTs,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert metrics returned error: %v", err)
|
||||
}
|
||||
|
||||
store.runRetention()
|
||||
|
||||
var rawCount int
|
||||
if err := store.db.QueryRow(`SELECT COUNT(*) FROM metrics WHERE tier = 'raw'`).Scan(&rawCount); err != nil {
|
||||
t.Fatalf("query raw count: %v", err)
|
||||
}
|
||||
if rawCount != 1 {
|
||||
t.Fatalf("expected 1 raw metric after retention, got %d", rawCount)
|
||||
}
|
||||
var total int
|
||||
if err := store.db.QueryRow(`SELECT COUNT(*) FROM metrics`).Scan(&total); err != nil {
|
||||
t.Fatalf("query total count: %v", err)
|
||||
}
|
||||
if total != 1 {
|
||||
t.Fatalf("expected only newest metric to remain, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreWriteFlushesBuffer(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := DefaultConfig(dir)
|
||||
cfg.DBPath = filepath.Join(dir, "metrics-buffer.db")
|
||||
cfg.WriteBufferSize = 1
|
||||
cfg.FlushInterval = time.Hour
|
||||
|
||||
store, err := NewStore(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore returned error: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ts := time.Now().Add(-time.Second)
|
||||
store.Write("vm", "vm-101", "cpu", 1.5, ts)
|
||||
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
points, err := store.Query("vm", "vm-101", "cpu", ts.Add(-time.Second), ts.Add(time.Second))
|
||||
if err == nil && len(points) == 1 {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatal("expected buffered metric to flush to database")
|
||||
}
|
||||
|
|
|
|||
100
pkg/proxmox/ceph_test.go
Normal file
100
pkg/proxmox/ceph_test.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package proxmox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetCephStatus(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/cluster/ceph/status" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"fsid": "fsid-1",
|
||||
"health": map[string]interface{}{
|
||||
"status": "HEALTH_OK",
|
||||
"summary": []map[string]interface{}{
|
||||
{"severity": "info", "summary": "ok"},
|
||||
},
|
||||
"checks": map[string]interface{}{},
|
||||
},
|
||||
"servicemap": map[string]interface{}{
|
||||
"services": map[string]interface{}{},
|
||||
},
|
||||
"osdmap": map[string]interface{}{
|
||||
"num_osds": 1,
|
||||
"num_up_osds": 1,
|
||||
"num_in_osds": 1,
|
||||
},
|
||||
"pgmap": map[string]interface{}{
|
||||
"num_pgs": 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &Client{baseURL: server.URL, httpClient: server.Client()}
|
||||
status, err := client.GetCephStatus(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if status.FSID != "fsid-1" || status.Health.Status != "HEALTH_OK" {
|
||||
t.Fatalf("unexpected status: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCephDF(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/cluster/ceph/df" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"stats": map[string]interface{}{
|
||||
"total_bytes": 100,
|
||||
"total_used_bytes": 40,
|
||||
"total_avail_bytes": 60,
|
||||
"total_used_raw_bytes": 45,
|
||||
"percent_used": 40.0,
|
||||
},
|
||||
"pools": []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"name": "pool1",
|
||||
"stats": map[string]interface{}{
|
||||
"bytes_used": 10,
|
||||
"kb_used": 20,
|
||||
"max_avail": 30,
|
||||
"objects": 40,
|
||||
"percent_used": 10.0,
|
||||
"dirty": 0,
|
||||
"stored_raw": 50,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &Client{baseURL: server.URL, httpClient: server.Client()}
|
||||
df, err := client.GetCephDF(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if df.Data.Stats.TotalBytes != 100 || len(df.Data.Pools) != 1 {
|
||||
t.Fatalf("unexpected df: %+v", df)
|
||||
}
|
||||
}
|
||||
102
pkg/proxmox/client_api_more2_test.go
Normal file
102
pkg/proxmox/client_api_more2_test.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package proxmox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientNodeStatusAndRRD(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/status":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": NodeStatus{CPU: 0.5, KernelVersion: "6.1"},
|
||||
})
|
||||
case "/api2/json/nodes/node1/rrddata":
|
||||
if !strings.Contains(r.URL.RawQuery, "timeframe=hour") || !strings.Contains(r.URL.RawQuery, "cf=AVERAGE") {
|
||||
http.Error(w, "bad query", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []NodeRRDPoint{{Time: 123}},
|
||||
})
|
||||
case "/api2/json/nodes/node1/lxc/101/rrddata":
|
||||
if !strings.Contains(r.URL.RawQuery, "ds=memused") {
|
||||
http.Error(w, "bad query", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []GuestRRDPoint{{Time: 456}},
|
||||
})
|
||||
case "/api2/json/nodes/node1/disks/list":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Disk{{DevPath: "/dev/sda", Model: "Disk"}},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := client.GetNodeStatus(ctx, "node1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetNodeStatus error: %v", err)
|
||||
}
|
||||
if status.KernelVersion != "6.1" {
|
||||
t.Fatalf("unexpected node status: %+v", status)
|
||||
}
|
||||
|
||||
rrd, err := client.GetNodeRRDData(ctx, "node1", "", "", []string{"cpu"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetNodeRRDData error: %v", err)
|
||||
}
|
||||
if len(rrd) != 1 || rrd[0].Time != 123 {
|
||||
t.Fatalf("unexpected node rrd: %+v", rrd)
|
||||
}
|
||||
|
||||
guestRRD, err := client.GetLXCRRDData(ctx, "node1", 101, "", "", []string{"memused"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetLXCRRDData error: %v", err)
|
||||
}
|
||||
if len(guestRRD) != 1 || guestRRD[0].Time != 456 {
|
||||
t.Fatalf("unexpected guest rrd: %+v", guestRRD)
|
||||
}
|
||||
|
||||
disks, err := client.GetDisks(ctx, "node1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetDisks error: %v", err)
|
||||
}
|
||||
if len(disks) != 1 || disks[0].DevPath != "/dev/sda" {
|
||||
t.Fatalf("unexpected disks: %+v", disks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNodeNetworkInterfaces(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/network":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []NodeNetworkInterface{{Iface: "eth0", Active: 1}},
|
||||
})
|
||||
case "/api2/json/nodes/bad/network":
|
||||
http.Error(w, "boom", http.StatusInternalServerError)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ifaces, err := client.GetNodeNetworkInterfaces(ctx, "node1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetNodeNetworkInterfaces error: %v", err)
|
||||
}
|
||||
if len(ifaces) != 1 || ifaces[0].Iface != "eth0" {
|
||||
t.Fatalf("unexpected interfaces: %+v", ifaces)
|
||||
}
|
||||
|
||||
if _, err := client.GetNodeNetworkInterfaces(ctx, "bad"); err == nil {
|
||||
t.Fatal("expected error for non-200 response")
|
||||
}
|
||||
}
|
||||
95
pkg/proxmox/client_api_more3_test.go
Normal file
95
pkg/proxmox/client_api_more3_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package proxmox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientVMFSInfoParsing(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/qemu/100/agent/get-fsinfo":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"result": []map[string]interface{}{
|
||||
{
|
||||
"name": "root",
|
||||
"type": "ext4",
|
||||
"mountpoint": "/",
|
||||
"total-bytes": 100,
|
||||
"used-bytes": 10,
|
||||
"disk": []map[string]interface{}{
|
||||
{"dev": "/dev/sda"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "windows",
|
||||
"type": "ntfs",
|
||||
"mountpoint": "C:\\Windows",
|
||||
"total-bytes": 200,
|
||||
"used-bytes": 20,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
filesystems, err := client.GetVMFSInfo(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMFSInfo error: %v", err)
|
||||
}
|
||||
if len(filesystems) != 2 {
|
||||
t.Fatalf("expected 2 filesystems, got %d", len(filesystems))
|
||||
}
|
||||
if filesystems[0].Disk != "/dev/sda" {
|
||||
t.Fatalf("expected disk from metadata, got %q", filesystems[0].Disk)
|
||||
}
|
||||
if filesystems[1].Disk != "C:" {
|
||||
t.Fatalf("expected windows drive disk, got %q", filesystems[1].Disk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientVMFSInfoObjectResult(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/qemu/100/agent/get-fsinfo":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"result": map[string]interface{}{"error": "no info"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
filesystems, err := client.GetVMFSInfo(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMFSInfo error: %v", err)
|
||||
}
|
||||
if len(filesystems) != 0 {
|
||||
t.Fatalf("expected empty filesystem list, got %d", len(filesystems))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientContainerInterfacesError(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/lxc/101/interfaces":
|
||||
http.Error(w, "boom", http.StatusBadRequest)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
if _, err := client.GetContainerInterfaces(ctx, "node1", 101); err == nil {
|
||||
t.Fatal("expected error for non-200 interface response")
|
||||
}
|
||||
}
|
||||
274
pkg/proxmox/client_api_more_test.go
Normal file
274
pkg/proxmox/client_api_more_test.go
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
package proxmox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestClient(t *testing.T, handler http.HandlerFunc) *Client {
|
||||
t.Helper()
|
||||
server := httptest.NewServer(handler)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
cfg := ClientConfig{
|
||||
Host: server.URL,
|
||||
TokenName: "user@pve!token",
|
||||
TokenValue: "secret",
|
||||
VerifySSL: false,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
client, err := NewClient(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient failed: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func writeJSON(t *testing.T, w http.ResponseWriter, payload interface{}) {
|
||||
t.Helper()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(payload); err != nil {
|
||||
t.Fatalf("encode json: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientStorageAndTasks(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/storage":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Storage{{Storage: "local", Type: "dir"}},
|
||||
})
|
||||
case "/api2/json/nodes":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Node{{Node: "node1", Status: "online"}, {Node: "node2", Status: "offline"}},
|
||||
})
|
||||
case "/api2/json/nodes/node1/tasks":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Task{
|
||||
{UPID: "1", Type: "vzdump"},
|
||||
{UPID: "2", Type: "other"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
storage, err := client.GetAllStorage(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllStorage error: %v", err)
|
||||
}
|
||||
if len(storage) != 1 || storage[0].Storage != "local" {
|
||||
t.Fatalf("unexpected storage: %+v", storage)
|
||||
}
|
||||
|
||||
tasks, err := client.GetBackupTasks(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetBackupTasks error: %v", err)
|
||||
}
|
||||
if len(tasks) != 1 || tasks[0].Type != "vzdump" {
|
||||
t.Fatalf("unexpected tasks: %+v", tasks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSnapshotsAndContent(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/storage/local/content":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []StorageContent{
|
||||
{Volid: "backup1", Content: "backup"},
|
||||
{Volid: "iso1", Content: "iso"},
|
||||
{Volid: "tmpl1", Content: "vztmpl"},
|
||||
},
|
||||
})
|
||||
case "/api2/json/nodes/node1/qemu/100/snapshot":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Snapshot{{Name: "current"}, {Name: "snap1"}},
|
||||
})
|
||||
case "/api2/json/nodes/node1/lxc/101/snapshot":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []Snapshot{{Name: "current"}, {Name: "snap2"}},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
content, err := client.GetStorageContent(ctx, "node1", "local")
|
||||
if err != nil {
|
||||
t.Fatalf("GetStorageContent error: %v", err)
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 backup-related items, got %d", len(content))
|
||||
}
|
||||
|
||||
snaps, err := client.GetVMSnapshots(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMSnapshots error: %v", err)
|
||||
}
|
||||
if len(snaps) != 1 || snaps[0].Name != "snap1" || snaps[0].VMID != 100 {
|
||||
t.Fatalf("unexpected VM snapshots: %+v", snaps)
|
||||
}
|
||||
|
||||
ctSnaps, err := client.GetContainerSnapshots(ctx, "node1", 101)
|
||||
if err != nil {
|
||||
t.Fatalf("GetContainerSnapshots error: %v", err)
|
||||
}
|
||||
if len(ctSnaps) != 1 || ctSnaps[0].Name != "snap2" || ctSnaps[0].VMID != 101 {
|
||||
t.Fatalf("unexpected container snapshots: %+v", ctSnaps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientClusterAndAgentInfo(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/cluster/status":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []ClusterStatus{
|
||||
{Type: "cluster", Name: "prod"},
|
||||
{Type: "node", Name: "node1"},
|
||||
},
|
||||
})
|
||||
case "/api2/json/nodes/node1/qemu/100/config":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{"name": "vm1"},
|
||||
})
|
||||
case "/api2/json/nodes/node1/qemu/100/agent/get-osinfo":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{"id": "linux"},
|
||||
})
|
||||
case "/api2/json/nodes/node1/qemu/100/agent/info":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"version": "1.2.3",
|
||||
},
|
||||
},
|
||||
})
|
||||
case "/api2/json/nodes/node1/qemu/100/agent/network-get-interfaces":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{
|
||||
"result": []VMNetworkInterface{{Name: "eth0"}},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := client.GetClusterStatus(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClusterStatus error: %v", err)
|
||||
}
|
||||
if len(status) != 2 {
|
||||
t.Fatalf("unexpected cluster status: %+v", status)
|
||||
}
|
||||
member, err := client.IsClusterMember(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("IsClusterMember error: %v", err)
|
||||
}
|
||||
if !member {
|
||||
t.Fatal("expected cluster membership")
|
||||
}
|
||||
|
||||
config, err := client.GetVMConfig(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMConfig error: %v", err)
|
||||
}
|
||||
if config["name"] != "vm1" {
|
||||
t.Fatalf("unexpected vm config: %+v", config)
|
||||
}
|
||||
|
||||
osInfo, err := client.GetVMAgentInfo(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMAgentInfo error: %v", err)
|
||||
}
|
||||
if osInfo["id"] != "linux" {
|
||||
t.Fatalf("unexpected agent info: %+v", osInfo)
|
||||
}
|
||||
|
||||
version, err := client.GetVMAgentVersion(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMAgentVersion error: %v", err)
|
||||
}
|
||||
if version != "1.2.3" {
|
||||
t.Fatalf("unexpected agent version: %q", version)
|
||||
}
|
||||
|
||||
ifaces, err := client.GetVMNetworkInterfaces(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMNetworkInterfaces error: %v", err)
|
||||
}
|
||||
if len(ifaces) != 1 || ifaces[0].Name != "eth0" {
|
||||
t.Fatalf("unexpected interfaces: %+v", ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientStatusAndResources(t *testing.T) {
|
||||
client := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api2/json/nodes/node1/qemu/100/status/current":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{"status": "running"},
|
||||
})
|
||||
case "/api2/json/nodes/node1/lxc/101/status/current":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": map[string]interface{}{"status": "running", "vmid": 101},
|
||||
})
|
||||
case "/api2/json/cluster/resources":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []ClusterResource{{ID: "qemu/100", Type: "vm", VMID: 100}},
|
||||
})
|
||||
case "/api2/json/nodes/node1/lxc/101/interfaces":
|
||||
writeJSON(t, w, map[string]interface{}{
|
||||
"data": []ContainerInterface{{Name: "eth0", HWAddr: "aa:bb"}},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
vmStatus, err := client.GetVMStatus(ctx, "node1", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("GetVMStatus error: %v", err)
|
||||
}
|
||||
if vmStatus.Status != "running" {
|
||||
t.Fatalf("unexpected VM status: %+v", vmStatus)
|
||||
}
|
||||
|
||||
ctStatus, err := client.GetContainerStatus(ctx, "node1", 101)
|
||||
if err != nil {
|
||||
t.Fatalf("GetContainerStatus error: %v", err)
|
||||
}
|
||||
if ctStatus.Status != "running" || ctStatus.VMID != 101 {
|
||||
t.Fatalf("unexpected container status: %+v", ctStatus)
|
||||
}
|
||||
|
||||
resources, err := client.GetClusterResources(ctx, "vm")
|
||||
if err != nil {
|
||||
t.Fatalf("GetClusterResources error: %v", err)
|
||||
}
|
||||
if len(resources) != 1 || resources[0].VMID != 100 {
|
||||
t.Fatalf("unexpected resources: %+v", resources)
|
||||
}
|
||||
|
||||
ifaces, err := client.GetContainerInterfaces(ctx, "node1", 101)
|
||||
if err != nil {
|
||||
t.Fatalf("GetContainerInterfaces error: %v", err)
|
||||
}
|
||||
if len(ifaces) != 1 || ifaces[0].Name != "eth0" {
|
||||
t.Fatalf("unexpected container interfaces: %+v", ifaces)
|
||||
}
|
||||
}
|
||||
112
pkg/proxmox/cluster_client_more_test.go
Normal file
112
pkg/proxmox/cluster_client_more_test.go
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
package proxmox
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClusterClientEndpointFingerprint(t *testing.T) {
|
||||
cc := &ClusterClient{
|
||||
config: ClientConfig{Fingerprint: "base"},
|
||||
endpointFingerprints: map[string]string{
|
||||
"node1": "node-fp",
|
||||
},
|
||||
}
|
||||
|
||||
if got := cc.getEndpointFingerprint("node1"); got != "node-fp" {
|
||||
t.Fatalf("expected node fingerprint, got %s", got)
|
||||
}
|
||||
if got := cc.getEndpointFingerprint("node2"); got != "base" {
|
||||
t.Fatalf("expected base fingerprint, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterClientMarkAndClearError(t *testing.T) {
|
||||
cc := &ClusterClient{
|
||||
name: "cluster",
|
||||
nodeHealth: map[string]bool{"node1": true},
|
||||
lastError: make(map[string]string),
|
||||
lastHealthCheck: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
cc.markUnhealthyWithError("node1", "connection refused")
|
||||
if cc.nodeHealth["node1"] {
|
||||
t.Fatal("expected node to be unhealthy")
|
||||
}
|
||||
if errMsg := cc.lastError["node1"]; errMsg == "" || !strings.Contains(errMsg, "Connection refused") {
|
||||
t.Fatalf("unexpected error message: %q", errMsg)
|
||||
}
|
||||
|
||||
cc.clearEndpointError("node1")
|
||||
if !cc.nodeHealth["node1"] {
|
||||
t.Fatal("expected node to be healthy after clear")
|
||||
}
|
||||
if _, ok := cc.lastError["node1"]; ok {
|
||||
t.Fatal("expected lastError to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterClientApplyRateLimitCooldown(t *testing.T) {
|
||||
cc := &ClusterClient{rateLimitUntil: make(map[string]time.Time)}
|
||||
cc.applyRateLimitCooldown("node1", 100*time.Millisecond)
|
||||
if when, ok := cc.rateLimitUntil["node1"]; !ok || time.Until(when) <= 0 {
|
||||
t.Fatalf("expected cooldown set, got %v", when)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithFailoverSkipsUnhealthyMarking(t *testing.T) {
|
||||
cc := &ClusterClient{
|
||||
name: "cluster",
|
||||
endpoints: []string{"node1"},
|
||||
clients: map[string]*Client{"node1": {}},
|
||||
nodeHealth: map[string]bool{"node1": true},
|
||||
lastError: make(map[string]string),
|
||||
lastHealthCheck: map[string]time.Time{"node1": time.Now()},
|
||||
rateLimitUntil: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
err := cc.executeWithFailover(context.Background(), func(*Client) error {
|
||||
return fmt.Errorf("No QEMU guest agent")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !cc.nodeHealth["node1"] {
|
||||
t.Fatal("expected node to remain healthy for VM-specific error")
|
||||
}
|
||||
if len(cc.lastError) != 0 {
|
||||
t.Fatalf("expected no lastError, got %+v", cc.lastError)
|
||||
}
|
||||
|
||||
err = cc.executeWithFailover(context.Background(), func(*Client) error {
|
||||
return fmt.Errorf("authentication failed")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected auth error")
|
||||
}
|
||||
if !cc.nodeHealth["node1"] {
|
||||
t.Fatal("expected node to remain healthy for auth error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithFailoverClearsErrorOnSuccess(t *testing.T) {
|
||||
cc := &ClusterClient{
|
||||
name: "cluster",
|
||||
endpoints: []string{"node1"},
|
||||
clients: map[string]*Client{"node1": {}},
|
||||
nodeHealth: map[string]bool{"node1": true},
|
||||
lastError: map[string]string{"node1": "stale"},
|
||||
lastHealthCheck: map[string]time.Time{"node1": time.Now()},
|
||||
rateLimitUntil: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if err := cc.executeWithFailover(context.Background(), func(*Client) error { return nil }); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if _, ok := cc.lastError["node1"]; ok {
|
||||
t.Fatal("expected lastError to be cleared")
|
||||
}
|
||||
}
|
||||
25
pkg/reporting/reporting_test.go
Normal file
25
pkg/reporting/reporting_test.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
package reporting
|
||||
|
||||
import "testing"
|
||||
|
||||
type fakeEngine struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (f *fakeEngine) Generate(req MetricReportRequest) ([]byte, string, error) {
|
||||
f.called = true
|
||||
return []byte("ok"), "text/plain", nil
|
||||
}
|
||||
|
||||
func TestSetGetEngine(t *testing.T) {
|
||||
engine := &fakeEngine{}
|
||||
SetEngine(engine)
|
||||
if GetEngine() != engine {
|
||||
t.Fatal("expected engine to be set")
|
||||
}
|
||||
|
||||
SetEngine(nil)
|
||||
if GetEngine() != nil {
|
||||
t.Fatal("expected engine to be cleared")
|
||||
}
|
||||
}
|
||||
52
pkg/server/metrics_test.go
Normal file
52
pkg/server/metrics_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStartMetricsServer(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
addr := listener.Addr().String()
|
||||
listener.Close()
|
||||
|
||||
startMetricsServer(ctx, addr)
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
var status int
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err := client.Get("http://" + addr + "/metrics")
|
||||
if err == nil {
|
||||
status = resp.StatusCode
|
||||
resp.Body.Close()
|
||||
if status == http.StatusOK {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
t.Fatalf("expected metrics endpoint to respond, got status %d", status)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := client.Get("http://" + addr + "/metrics"); err != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("expected metrics server to stop after context cancellation")
|
||||
}
|
||||
105
pkg/server/server_helpers_test.go
Normal file
105
pkg/server/server_helpers_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldAutoImport(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
t.Setenv("PULSE_DATA_DIR", dir)
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "")
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "")
|
||||
|
||||
if ShouldAutoImport() {
|
||||
t.Fatal("expected auto-import false without config")
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "payload")
|
||||
if !ShouldAutoImport() {
|
||||
t.Fatal("expected auto-import true with data")
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "")
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "/tmp/file")
|
||||
if !ShouldAutoImport() {
|
||||
t.Fatal("expected auto-import true with file")
|
||||
}
|
||||
|
||||
file := filepath.Join(dir, "nodes.enc")
|
||||
if err := os.WriteFile(file, []byte("x"), 0600); err != nil {
|
||||
t.Fatalf("write nodes: %v", err)
|
||||
}
|
||||
if ShouldAutoImport() {
|
||||
t.Fatal("expected auto-import false when nodes.enc exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeImportPayload(t *testing.T) {
|
||||
if _, err := NormalizeImportPayload([]byte("")); err == nil {
|
||||
t.Fatal("expected error for empty payload")
|
||||
}
|
||||
|
||||
raw := []byte("hello")
|
||||
encoded := base64.StdEncoding.EncodeToString(raw)
|
||||
out, err := NormalizeImportPayload([]byte(encoded))
|
||||
if err != nil {
|
||||
t.Fatalf("normalize error: %v", err)
|
||||
}
|
||||
if out != encoded {
|
||||
t.Fatalf("unexpected output: %s", out)
|
||||
}
|
||||
|
||||
double := base64.StdEncoding.EncodeToString([]byte(encoded))
|
||||
out, err = NormalizeImportPayload([]byte(double))
|
||||
if err != nil {
|
||||
t.Fatalf("normalize error: %v", err)
|
||||
}
|
||||
if out != encoded {
|
||||
t.Fatalf("unexpected output: %s", out)
|
||||
}
|
||||
|
||||
out, err = NormalizeImportPayload([]byte("not-base64"))
|
||||
if err != nil {
|
||||
t.Fatalf("normalize error: %v", err)
|
||||
}
|
||||
if out == "not-base64" {
|
||||
t.Fatal("expected payload to be encoded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeBase64(t *testing.T) {
|
||||
if LooksLikeBase64("") {
|
||||
t.Fatal("expected false for empty")
|
||||
}
|
||||
if !LooksLikeBase64("aGVsbG8=") {
|
||||
t.Fatal("expected true for base64")
|
||||
}
|
||||
if LooksLikeBase64("nope!!") {
|
||||
t.Fatal("expected false for invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformAutoImportErrors(t *testing.T) {
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "data")
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "")
|
||||
t.Setenv("PULSE_INIT_CONFIG_PASSPHRASE", "")
|
||||
if err := PerformAutoImport(); err == nil {
|
||||
t.Fatal("expected error without passphrase")
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_INIT_CONFIG_PASSPHRASE", "pass")
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "/tmp/missing-file")
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "")
|
||||
if err := PerformAutoImport(); err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "")
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", "")
|
||||
if err := PerformAutoImport(); err == nil {
|
||||
t.Fatal("expected error for missing data")
|
||||
}
|
||||
}
|
||||
100
pkg/server/server_test.go
Normal file
100
pkg/server/server_test.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rcourtman/pulse-go-rewrite/internal/config"
|
||||
"github.com/rcourtman/pulse-go-rewrite/pkg/metrics"
|
||||
)
|
||||
|
||||
func TestBusinessHooks(t *testing.T) {
|
||||
called := false
|
||||
hook := func(store *metrics.Store) {
|
||||
called = true
|
||||
}
|
||||
|
||||
SetBusinessHooks(BusinessHooks{
|
||||
OnMetricsStoreReady: hook,
|
||||
})
|
||||
|
||||
globalHooksMu.Lock()
|
||||
defer globalHooksMu.Unlock()
|
||||
|
||||
if globalHooks.OnMetricsStoreReady == nil {
|
||||
t.Error("expected OnMetricsStoreReady to be set")
|
||||
}
|
||||
|
||||
// Manually trigger to verify it works
|
||||
globalHooks.OnMetricsStoreReady(nil)
|
||||
if !called {
|
||||
t.Error("expected hook to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformAutoImport_Success(t *testing.T) {
|
||||
// Setup temp directory
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PULSE_DATA_DIR", tmpDir)
|
||||
|
||||
// Create a persistence instance to generate valid encrypted payload
|
||||
sourceDir := t.TempDir()
|
||||
sourcePersistence := config.NewConfigPersistence(sourceDir)
|
||||
|
||||
passphrase := "test-pass"
|
||||
encryptedData, err := sourcePersistence.ExportConfig(passphrase)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate export data: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PULSE_INIT_CONFIG_DATA", encryptedData)
|
||||
t.Setenv("PULSE_INIT_CONFIG_FILE", "")
|
||||
t.Setenv("PULSE_INIT_CONFIG_PASSPHRASE", passphrase)
|
||||
|
||||
// Run PerformAutoImport
|
||||
if err := PerformAutoImport(); err != nil {
|
||||
t.Fatalf("PerformAutoImport failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify persistence file created (nodes.enc is a good indicator)
|
||||
_, err = os.Stat(filepath.Join(tmpDir, "nodes.enc"))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
t.Error("expected nodes.enc to be created")
|
||||
} else {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Minimal test for Server startup context cancellation
|
||||
func TestServerRun_Shutdown(t *testing.T) {
|
||||
// Setup minimal environment
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PULSE_DATA_DIR", tmpDir)
|
||||
t.Setenv("PULSE_CONFIG_PATH", tmpDir)
|
||||
|
||||
// Create a dummy config.yaml
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
// Use 0 port to try to avoid conflicts, though Run() might default it.
|
||||
if err := os.WriteFile(configFile, []byte("backendHost: 127.0.0.1\nfrontendPort: 0"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Cancel immediately/shortly to trigger shutdown path
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := Run(ctx, "test-version")
|
||||
|
||||
if err != nil && err != context.Canceled {
|
||||
t.Logf("Run returned: %v", err)
|
||||
}
|
||||
}
|
||||
67
pkg/tlsutil/extra_test.go
Normal file
67
pkg/tlsutil/extra_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package tlsutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDialContextWithCache(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialContextWithCache(ctx, "tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("DialContextWithCache error: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected server accept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFingerprint(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cert := server.TLS.Certificates[0]
|
||||
if len(cert.Certificate) == 0 {
|
||||
t.Fatal("expected server certificate")
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(cert.Certificate[0])
|
||||
expected := hex.EncodeToString(sum[:])
|
||||
|
||||
fingerprint, err := FetchFingerprint(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchFingerprint error: %v", err)
|
||||
}
|
||||
if fingerprint != expected {
|
||||
t.Fatalf("unexpected fingerprint: %s", fingerprint)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue