diff --git a/internal/config/export.go b/internal/config/export.go index 0e74b35a9..eb4cc0f4a 100644 --- a/internal/config/export.go +++ b/internal/config/export.go @@ -161,6 +161,23 @@ func (c *ConfigPersistence) ImportConfig(encryptedData string, passphrase string default: fmt.Printf("Warning: Config was exported from unsupported version %s. Proceeding with best effort.\n", exportData.Version) } + + tx, err := newImportTransaction(c.configDir) + if err != nil { + return fmt.Errorf("failed to start import transaction: %w", err) + } + defer tx.Cleanup() + + c.beginTransaction(tx) + defer c.endTransaction(tx) + + committed := false + defer func() { + if !committed { + tx.Rollback() + } + }() + // Import all configurations if err := c.SaveNodesConfig(exportData.Nodes.PVEInstances, exportData.Nodes.PBSInstances, exportData.Nodes.PMGInstances); err != nil { return fmt.Errorf("failed to import nodes config: %w", err) @@ -201,7 +218,14 @@ func (c *ConfigPersistence) ImportConfig(encryptedData string, passphrase string if err := c.SaveOIDCConfig(*exportData.OIDC); err != nil { return fmt.Errorf("failed to import oidc configuration: %w", err) } - } else { + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit import transaction: %w", err) + } + committed = true + + if exportData.OIDC == nil { // Remove existing OIDC config if backup did not include one if err := os.Remove(c.oidcFile); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove existing oidc configuration: %w", err) diff --git a/internal/config/import_transaction.go b/internal/config/import_transaction.go new file mode 100644 index 000000000..79f08dbf3 --- /dev/null +++ b/internal/config/import_transaction.go @@ -0,0 +1,191 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +// importTransaction coordinates staging config writes during an import so they +// can be committed atomically or rolled back on failure. +type importTransaction struct { + configDir string + stagingDir string + timestamp string + + staged map[string]string // target path -> staged temp file + backups map[string]string // target path -> backup file path + + committed bool +} + +func newImportTransaction(configDir string) (*importTransaction, error) { + stagingDir, err := os.MkdirTemp(configDir, ".import-staging-*") + if err != nil { + return nil, fmt.Errorf("create import staging dir: %w", err) + } + + tx := &importTransaction{ + configDir: configDir, + stagingDir: stagingDir, + timestamp: time.Now().UTC().Format("20060102-150405"), + staged: make(map[string]string), + backups: make(map[string]string), + } + return tx, nil +} + +// StageFile writes the provided data to a temporary file within the staging +// directory and records it for later commit. +func (tx *importTransaction) StageFile(target string, data []byte, perm os.FileMode) error { + if tx.committed { + return fmt.Errorf("transaction already committed") + } + + if err := os.MkdirAll(tx.stagingDir, 0o700); err != nil { + return fmt.Errorf("ensure staging dir: %w", err) + } + + // Remove any previously staged data for this target. + if existing, ok := tx.staged[target]; ok { + _ = os.Remove(existing) + } + + base := filepath.Base(target) + if base == "" || base == string(os.PathSeparator) { + base = "staged" + } + prefix := strings.ReplaceAll(base, string(os.PathSeparator), "_") + if prefix == "" { + prefix = "staged" + } + if !strings.Contains(prefix, "*") { + prefix = prefix + ".tmp-*" + } + + // Create the staged file. + tmpFile, err := os.CreateTemp(tx.stagingDir, prefix) + if err != nil { + return fmt.Errorf("create staged file for %s: %w", target, err) + } + defer tmpFile.Close() + + if _, err := tmpFile.Write(data); err != nil { + _ = os.Remove(tmpFile.Name()) + return fmt.Errorf("write staged file for %s: %w", target, err) + } + + if err := tmpFile.Chmod(perm); err != nil { + _ = os.Remove(tmpFile.Name()) + return fmt.Errorf("chmod staged file for %s: %w", target, err) + } + + tx.staged[target] = tmpFile.Name() + return nil +} + +// Commit atomically applies all staged files. If any step fails the transaction +// restores previous backups and returns an error. +func (tx *importTransaction) Commit() error { + if tx.committed { + return fmt.Errorf("transaction already committed") + } + tx.committed = true + + targets := make([]string, 0, len(tx.staged)) + for target := range tx.staged { + targets = append(targets, target) + } + sort.Strings(targets) + + applied := make([]string, 0, len(targets)) + + restore := func() { + for i := len(applied) - 1; i >= 0; i-- { + target := applied[i] + stagedPath := tx.staged[target] + + // Ensure staged file removed (best effort). + _ = os.Remove(stagedPath) + + // Restore backup if present. + if backup := tx.backups[target]; backup != "" { + if _, err := os.Stat(backup); err == nil { + _ = os.Remove(target) + if err := os.Rename(backup, target); err == nil { + tx.backups[target] = "" + } + } + } + } + } + + for _, target := range targets { + stagedPath := tx.staged[target] + + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + restore() + return fmt.Errorf("ensure dir for %s: %w", target, err) + } + + // Move current file to backup (if it exists and isn't already a dir). + if info, err := os.Stat(target); err == nil { + if info.IsDir() { + restore() + return fmt.Errorf("destination %s is a directory", target) + } + backupPath := fmt.Sprintf("%s.import-backup-%s", target, tx.timestamp) + if err := os.Rename(target, backupPath); err != nil { + restore() + return fmt.Errorf("backup existing file %s: %w", target, err) + } + tx.backups[target] = backupPath + } else if !os.IsNotExist(err) { + restore() + return fmt.Errorf("stat destination %s: %w", target, err) + } + + if err := os.Rename(stagedPath, target); err != nil { + restore() + return fmt.Errorf("apply staged file to %s: %w", target, err) + } + + applied = append(applied, target) + } + + // Successful commit: remove backups (best effort). + for _, target := range applied { + if backup := tx.backups[target]; backup != "" { + _ = os.Remove(backup) + tx.backups[target] = "" + } + } + return nil +} + +// Rollback drops all staged files and restores any backups already created. +func (tx *importTransaction) Rollback() { + for target, stagedPath := range tx.staged { + _ = os.Remove(stagedPath) + + if backup := tx.backups[target]; backup != "" { + // Only attempt restore when backup still exists. + if _, err := os.Stat(backup); err != nil { + continue + } + _ = os.Remove(target) + if err := os.Rename(backup, target); err != nil { + continue + } + tx.backups[target] = "" + } + } +} + +// Cleanup removes the staging directory. +func (tx *importTransaction) Cleanup() { + _ = os.RemoveAll(tx.stagingDir) +} diff --git a/internal/config/persistence.go b/internal/config/persistence.go index a23acd711..4920cf6f6 100644 --- a/internal/config/persistence.go +++ b/internal/config/persistence.go @@ -20,6 +20,7 @@ import ( // ConfigPersistence handles saving and loading configuration type ConfigPersistence struct { mu sync.RWMutex + tx *importTransaction configDir string alertFile string emailFile string @@ -73,6 +74,36 @@ func (c *ConfigPersistence) EnsureConfigDir() error { return os.MkdirAll(c.configDir, 0700) } +func (c *ConfigPersistence) beginTransaction(tx *importTransaction) { + c.mu.Lock() + c.tx = tx + c.mu.Unlock() +} + +func (c *ConfigPersistence) endTransaction(tx *importTransaction) { + c.mu.Lock() + if c.tx == tx { + c.tx = nil + } + c.mu.Unlock() +} + +func (c *ConfigPersistence) writeConfigFileLocked(path string, data []byte, perm os.FileMode) error { + if c.tx != nil { + return c.tx.StageFile(path, data, perm) + } + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, perm); err != nil { + return err + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return err + } + return nil +} + // LoadAPITokens loads API token metadata from disk. func (c *ConfigPersistence) LoadAPITokens() ([]APITokenRecord, error) { c.mu.RLock() @@ -119,17 +150,7 @@ func (c *ConfigPersistence) SaveAPITokens(tokens []APITokenRecord) error { return err } - tmp := c.apiTokensFile + ".tmp" - if err := os.WriteFile(tmp, data, 0600); err != nil { - return err - } - - if err := os.Rename(tmp, c.apiTokensFile); err != nil { - os.Remove(tmp) - return err - } - - return nil + return c.writeConfigFileLocked(c.apiTokensFile, data, 0600) } // SaveAlertConfig saves alert configuration to file @@ -214,7 +235,7 @@ func (c *ConfigPersistence) SaveAlertConfig(config alerts.AlertConfig) error { return err } - if err := os.WriteFile(c.alertFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.alertFile, data, 0600); err != nil { return err } @@ -383,7 +404,7 @@ func (c *ConfigPersistence) SaveEmailConfig(config notifications.EmailConfig) er } // Save with restricted permissions (owner read/write only) - if err := os.WriteFile(c.emailFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.emailFile, data, 0600); err != nil { return err } @@ -458,7 +479,7 @@ func (c *ConfigPersistence) SaveAppriseConfig(config notifications.AppriseConfig data = encrypted } - if err := os.WriteFile(c.appriseFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.appriseFile, data, 0600); err != nil { return err } @@ -533,7 +554,7 @@ func (c *ConfigPersistence) SaveWebhooks(webhooks []notifications.WebhookConfig) data = encrypted } - if err := os.WriteFile(c.webhookFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.webhookFile, data, 0600); err != nil { return err } @@ -791,15 +812,7 @@ func (c *ConfigPersistence) saveNodesConfig(pveInstances []PVEInstance, pbsInsta data = encrypted } - // Write to temporary file first, then atomically rename - tempFile := c.nodesFile + ".tmp" - if err := os.WriteFile(tempFile, data, 0600); err != nil { - return err - } - - // Atomic rename - if err := os.Rename(tempFile, c.nodesFile); err != nil { - os.Remove(tempFile) // Clean up temp file on error + if err := c.writeConfigFileLocked(c.nodesFile, data, 0600); err != nil { return err } @@ -1077,7 +1090,7 @@ func (c *ConfigPersistence) SaveSystemSettings(settings SystemSettings) error { return err } - if err := os.WriteFile(c.systemFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.systemFile, data, 0600); err != nil { return err } @@ -1117,7 +1130,7 @@ func (c *ConfigPersistence) SaveOIDCConfig(settings OIDCConfig) error { data = encrypted } - if err := os.WriteFile(c.oidcFile, data, 0600); err != nil { + if err := c.writeConfigFileLocked(c.oidcFile, data, 0600); err != nil { return err } diff --git a/internal/config/persistence_test.go b/internal/config/persistence_test.go index b0aca3fc4..262e680e5 100644 --- a/internal/config/persistence_test.go +++ b/internal/config/persistence_test.go @@ -1,15 +1,24 @@ package config_test import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" + "io" "os" "path/filepath" "reflect" "testing" + "time" "github.com/rcourtman/pulse-go-rewrite/internal/alerts" "github.com/rcourtman/pulse-go-rewrite/internal/config" "github.com/rcourtman/pulse-go-rewrite/internal/notifications" + "golang.org/x/crypto/pbkdf2" ) func TestSaveAlertConfig_PreservesStorageOverrideHysteresis(t *testing.T) { @@ -256,3 +265,645 @@ func TestAppriseConfigPersistence(t *testing.T) { t.Fatalf("expected disabled configuration when no targets stored") } } + +func TestExportConfigIncludesAPITokens(t *testing.T) { + t.Setenv("PULSE_DATA_DIR", t.TempDir()) + + tempDir := t.TempDir() + cp := config.NewConfigPersistence(tempDir) + if err := cp.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + createdAt := time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC) + tokens := []config.APITokenRecord{ + { + ID: "token-1", + Name: "automation", + Hash: "hash-1", + Prefix: "hash-1", + Suffix: "-0001", + CreatedAt: createdAt, + }, + { + ID: "token-2", + Name: "metrics", + Hash: "hash-2", + Prefix: "hash-2", + Suffix: "-0002", + CreatedAt: createdAt.Add(time.Hour), + }, + } + + if err := cp.SaveAPITokens(tokens); err != nil { + t.Fatalf("SaveAPITokens: %v", err) + } + + passphrase := "strong-passphrase" + exported, err := cp.ExportConfig(passphrase) + if err != nil { + t.Fatalf("ExportConfig: %v", err) + } + + decoded := mustDecodeExport(t, exported, passphrase) + + if decoded.Version != "4.1" { + t.Fatalf("expected export version 4.1, got %q", decoded.Version) + } + + assertJSONEqual(t, decoded.APITokens, tokens, "api tokens") +} + +func TestImportConfigTransactionalSuccess(t *testing.T) { + const passphrase = "import-success" + + sourceDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", sourceDataDir) + + sourceConfigDir := t.TempDir() + source := config.NewConfigPersistence(sourceConfigDir) + if err := source.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + newNodes := []config.PVEInstance{ + { + Name: "pve-new", + Host: "https://pve-new.example:8006", + User: "root@pam", + MonitorVMs: true, + MonitorStorage: true, + }, + } + newPBS := []config.PBSInstance{ + { + Name: "pbs-new", + Host: "https://pbs-new.example:8007", + User: "pbs@pam", + MonitorBackups: true, + }, + } + if err := source.SaveNodesConfig(newNodes, newPBS, nil); err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + newAlerts := alerts.AlertConfig{ + Enabled: true, + HysteresisMargin: 3.5, + StorageDefault: alerts.HysteresisThreshold{ + Trigger: 70, + Clear: 65, + }, + TimeThreshold: 10, + TimeThresholds: map[string]int{ + "guest": 10, + "node": 10, + "storage": 10, + "pbs": 10, + }, + Overrides: map[string]alerts.ThresholdConfig{ + "node/pve-new": { + CPU: &alerts.HysteresisThreshold{Trigger: 80, Clear: 72}, + }, + }, + } + if err := source.SaveAlertConfig(newAlerts); err != nil { + t.Fatalf("SaveAlertConfig: %v", err) + } + + newSystem := config.SystemSettings{ + PBSPollingInterval: 45, + PMGPollingInterval: 50, + AutoUpdateEnabled: true, + DiscoveryEnabled: false, + DiscoverySubnet: "192.168.10.0/24", + DiscoveryConfig: config.DefaultDiscoveryConfig(), + Theme: "dark", + AllowEmbedding: true, + } + if err := source.SaveSystemSettings(newSystem); err != nil { + t.Fatalf("SaveSystemSettings: %v", err) + } + + newTokens := []config.APITokenRecord{ + { + ID: "token-new-1", + Name: "automation", + Hash: "hash-new-1", + Prefix: "hashn1", + Suffix: "n1", + CreatedAt: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), + }, + } + if err := source.SaveAPITokens(newTokens); err != nil { + t.Fatalf("SaveAPITokens: %v", err) + } + + exported, err := source.ExportConfig(passphrase) + if err != nil { + t.Fatalf("ExportConfig: %v", err) + } + exportedData := mustDecodeExport(t, exported, passphrase) + + targetDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", targetDataDir) + + targetConfigDir := t.TempDir() + target := config.NewConfigPersistence(targetConfigDir) + if err := target.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + oldNodes := []config.PVEInstance{ + { + Name: "pve-old", + Host: "https://pve-old.example:8006", + User: "root@pam", + }, + } + if err := target.SaveNodesConfig(oldNodes, nil, nil); err != nil { + t.Fatalf("SaveNodesConfig baseline: %v", err) + } + + oldAlerts := alerts.AlertConfig{ + Enabled: true, + HysteresisMargin: 5, + StorageDefault: alerts.HysteresisThreshold{ + Trigger: 85, + Clear: 80, + }, + Overrides: map[string]alerts.ThresholdConfig{}, + } + if err := target.SaveAlertConfig(oldAlerts); err != nil { + t.Fatalf("SaveAlertConfig baseline: %v", err) + } + + oldSystem := config.SystemSettings{ + PBSPollingInterval: 120, + PMGPollingInterval: 120, + AutoUpdateEnabled: false, + DiscoveryEnabled: true, + DiscoverySubnet: "auto", + DiscoveryConfig: config.DefaultDiscoveryConfig(), + Theme: "light", + } + if err := target.SaveSystemSettings(oldSystem); err != nil { + t.Fatalf("SaveSystemSettings baseline: %v", err) + } + + oldTokens := []config.APITokenRecord{ + { + ID: "token-old-1", + Name: "legacy", + Hash: "hash-old-1", + Prefix: "hasho1", + Suffix: "o1", + CreatedAt: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), + }, + } + if err := target.SaveAPITokens(oldTokens); err != nil { + t.Fatalf("SaveAPITokens baseline: %v", err) + } + + if err := target.ImportConfig(exported, passphrase); err != nil { + t.Fatalf("ImportConfig: %v", err) + } + + nodesAfter, err := target.LoadNodesConfig() + if err != nil { + t.Fatalf("LoadNodesConfig: %v", err) + } + assertJSONEqual(t, nodesAfter, exportedData.Nodes, "nodes") + + alertsAfter, err := target.LoadAlertConfig() + if err != nil { + t.Fatalf("LoadAlertConfig: %v", err) + } + assertJSONEqual(t, alertsAfter, exportedData.Alerts, "alerts") + + systemAfter, err := target.LoadSystemSettings() + if err != nil { + t.Fatalf("LoadSystemSettings: %v", err) + } + if systemAfter == nil { + t.Fatal("expected system settings after import") + } + assertJSONEqual(t, systemAfter, exportedData.System, "system settings") + + tokensAfter, err := target.LoadAPITokens() + if err != nil { + t.Fatalf("LoadAPITokens: %v", err) + } + assertJSONEqual(t, tokensAfter, exportedData.APITokens, "api tokens") + + tmpFiles, err := filepath.Glob(filepath.Join(targetConfigDir, "*.tmp")) + if err != nil { + t.Fatalf("Glob tmp files: %v", err) + } + if len(tmpFiles) != 0 { + t.Fatalf("expected no tmp files after import, found %v", tmpFiles) + } +} + +func TestImportConfigRollbackOnFailure(t *testing.T) { + const passphrase = "import-rollback" + + sourceDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", sourceDataDir) + + sourceConfigDir := t.TempDir() + source := config.NewConfigPersistence(sourceConfigDir) + if err := source.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + newNodes := []config.PVEInstance{ + { + Name: "pve-new", + Host: "https://pve-new.example:8006", + User: "root@pam", + }, + } + if err := source.SaveNodesConfig(newNodes, nil, nil); err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + newAlerts := alerts.AlertConfig{ + Enabled: true, + HysteresisMargin: 4, + StorageDefault: alerts.HysteresisThreshold{ + Trigger: 65, + Clear: 60, + }, + Overrides: map[string]alerts.ThresholdConfig{}, + } + if err := source.SaveAlertConfig(newAlerts); err != nil { + t.Fatalf("SaveAlertConfig: %v", err) + } + + newSystem := config.SystemSettings{ + PBSPollingInterval: 30, + PMGPollingInterval: 30, + AutoUpdateEnabled: true, + DiscoveryEnabled: false, + DiscoverySubnet: "10.20.0.0/24", + DiscoveryConfig: config.DefaultDiscoveryConfig(), + } + if err := source.SaveSystemSettings(newSystem); err != nil { + t.Fatalf("SaveSystemSettings: %v", err) + } + + newTokens := []config.APITokenRecord{ + { + ID: "token-new", + Name: "new", + Hash: "hash-new", + Prefix: "hashn", + Suffix: "-n", + CreatedAt: time.Date(2024, 2, 2, 12, 0, 0, 0, time.UTC), + }, + } + if err := source.SaveAPITokens(newTokens); err != nil { + t.Fatalf("SaveAPITokens: %v", err) + } + + exported, err := source.ExportConfig(passphrase) + if err != nil { + t.Fatalf("ExportConfig: %v", err) + } + + targetDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", targetDataDir) + + targetConfigDir := t.TempDir() + target := config.NewConfigPersistence(targetConfigDir) + if err := target.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + baselineNodes := []config.PVEInstance{ + { + Name: "pve-original", + Host: "https://pve-original.example:8006", + User: "root@pam", + }, + } + if err := target.SaveNodesConfig(baselineNodes, nil, nil); err != nil { + t.Fatalf("SaveNodesConfig baseline: %v", err) + } + baselineAlerts := alerts.AlertConfig{ + Enabled: true, + HysteresisMargin: 5, + StorageDefault: alerts.HysteresisThreshold{ + Trigger: 90, + Clear: 85, + }, + Overrides: map[string]alerts.ThresholdConfig{}, + } + if err := target.SaveAlertConfig(baselineAlerts); err != nil { + t.Fatalf("SaveAlertConfig baseline: %v", err) + } + baselineTokens := []config.APITokenRecord{ + { + ID: "token-original", + Name: "original", + Hash: "hash-original", + Prefix: "hasho", + Suffix: "-o", + CreatedAt: time.Date(2023, 3, 3, 12, 0, 0, 0, time.UTC), + }, + } + if err := target.SaveAPITokens(baselineTokens); err != nil { + t.Fatalf("SaveAPITokens baseline: %v", err) + } + + originalNodes, err := target.LoadNodesConfig() + if err != nil { + t.Fatalf("LoadNodesConfig: %v", err) + } + originalNodesJSON := mustMarshalJSON(t, originalNodes) + + originalAlerts, err := target.LoadAlertConfig() + if err != nil { + t.Fatalf("LoadAlertConfig: %v", err) + } + originalAlertsJSON := mustMarshalJSON(t, originalAlerts) + + originalTokens, err := target.LoadAPITokens() + if err != nil { + t.Fatalf("LoadAPITokens: %v", err) + } + originalTokensJSON := mustMarshalJSON(t, originalTokens) + + if err := os.Mkdir(filepath.Join(targetConfigDir, "system.json"), 0o700); err != nil { + t.Fatalf("creating obstacle directory: %v", err) + } + + if err := target.ImportConfig(exported, passphrase); err == nil { + t.Fatal("expected import to fail, but it succeeded") + } + + nodesAfter, err := target.LoadNodesConfig() + if err != nil { + t.Fatalf("LoadNodesConfig after failure: %v", err) + } + if !bytes.Equal(mustMarshalJSON(t, nodesAfter), originalNodesJSON) { + t.Fatalf("nodes changed despite rollback:\noriginal: %s\ncurrent: %s", + originalNodesJSON, mustMarshalJSON(t, nodesAfter)) + } + + alertsAfter, err := target.LoadAlertConfig() + if err != nil { + t.Fatalf("LoadAlertConfig after failure: %v", err) + } + if !bytes.Equal(mustMarshalJSON(t, alertsAfter), originalAlertsJSON) { + t.Fatalf("alerts changed despite rollback:\noriginal: %s\ncurrent: %s", + originalAlertsJSON, mustMarshalJSON(t, alertsAfter)) + } + + tokensAfter, err := target.LoadAPITokens() + if err != nil { + t.Fatalf("LoadAPITokens after failure: %v", err) + } + if !bytes.Equal(mustMarshalJSON(t, tokensAfter), originalTokensJSON) { + t.Fatalf("api tokens changed despite rollback:\noriginal: %s\ncurrent: %s", + originalTokensJSON, mustMarshalJSON(t, tokensAfter)) + } + + tmpFiles, err := filepath.Glob(filepath.Join(targetConfigDir, "*.tmp")) + if err != nil { + t.Fatalf("Glob tmp files: %v", err) + } + if len(tmpFiles) != 0 { + t.Fatalf("expected tmp files cleaned up after rollback, found %v", tmpFiles) + } +} + +func TestImportAcceptsVersion40Bundle(t *testing.T) { + const passphrase = "import-legacy" + + sourceDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", sourceDataDir) + + sourceConfigDir := t.TempDir() + source := config.NewConfigPersistence(sourceConfigDir) + if err := source.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + newNodes := []config.PVEInstance{ + { + Name: "pve-legacy", + Host: "https://pve-legacy.example:8006", + User: "root@pam", + }, + } + if err := source.SaveNodesConfig(newNodes, nil, nil); err != nil { + t.Fatalf("SaveNodesConfig: %v", err) + } + + newAlerts := alerts.AlertConfig{ + Enabled: true, + HysteresisMargin: 4, + StorageDefault: alerts.HysteresisThreshold{ + Trigger: 75, + Clear: 70, + }, + Overrides: map[string]alerts.ThresholdConfig{}, + } + if err := source.SaveAlertConfig(newAlerts); err != nil { + t.Fatalf("SaveAlertConfig: %v", err) + } + + newSystem := config.SystemSettings{ + PBSPollingInterval: 80, + PMGPollingInterval: 90, + AutoUpdateEnabled: true, + DiscoveryEnabled: true, + DiscoverySubnet: "172.16.0.0/24", + DiscoveryConfig: config.DefaultDiscoveryConfig(), + } + if err := source.SaveSystemSettings(newSystem); err != nil { + t.Fatalf("SaveSystemSettings: %v", err) + } + + exported, err := source.ExportConfig(passphrase) + if err != nil { + t.Fatalf("ExportConfig: %v", err) + } + + exportData := mustDecodeExport(t, exported, passphrase) + exportData.Version = "4.0" + exportData.APITokens = nil + + legacyPayload := mustEncodeExport(t, exportData, passphrase) + + targetDataDir := t.TempDir() + t.Setenv("PULSE_DATA_DIR", targetDataDir) + + targetConfigDir := t.TempDir() + target := config.NewConfigPersistence(targetConfigDir) + if err := target.EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir: %v", err) + } + + baselineTokens := []config.APITokenRecord{ + { + ID: "token-legacy", + Name: "keep-me", + Hash: "hash-keep", + Prefix: "hashk", + Suffix: "-k", + CreatedAt: time.Date(2022, 4, 4, 12, 0, 0, 0, time.UTC), + }, + } + if err := target.SaveAPITokens(baselineTokens); err != nil { + t.Fatalf("SaveAPITokens baseline: %v", err) + } + + if err := target.ImportConfig(legacyPayload, passphrase); err != nil { + t.Fatalf("ImportConfig (legacy 4.0): %v", err) + } + + nodesAfter, err := target.LoadNodesConfig() + if err != nil { + t.Fatalf("LoadNodesConfig: %v", err) + } + assertJSONEqual(t, nodesAfter, exportData.Nodes, "nodes (4.0 import)") + + alertsAfter, err := target.LoadAlertConfig() + if err != nil { + t.Fatalf("LoadAlertConfig: %v", err) + } + assertJSONEqual(t, alertsAfter, exportData.Alerts, "alerts (4.0 import)") + + systemAfter, err := target.LoadSystemSettings() + if err != nil { + t.Fatalf("LoadSystemSettings: %v", err) + } + if systemAfter == nil { + t.Fatal("expected system settings after legacy import") + } + assertJSONEqual(t, systemAfter, exportData.System, "system settings (4.0 import)") + + tokensAfter, err := target.LoadAPITokens() + if err != nil { + t.Fatalf("LoadAPITokens: %v", err) + } + assertJSONEqual(t, tokensAfter, baselineTokens, "api tokens unchanged for 4.0 import") +} + +func mustDecodeExport(t *testing.T, payload, passphrase string) config.ExportData { + t.Helper() + + raw, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + t.Fatalf("base64 decode: %v", err) + } + + plaintext, err := decryptExportPayload(raw, passphrase) + if err != nil { + t.Fatalf("decrypt export: %v", err) + } + + var data config.ExportData + if err := json.Unmarshal(plaintext, &data); err != nil { + t.Fatalf("unmarshal export data: %v", err) + } + return data +} + +func mustEncodeExport(t *testing.T, data config.ExportData, passphrase string) string { + t.Helper() + + plaintext, err := json.Marshal(data) + if err != nil { + t.Fatalf("marshal export data: %v", err) + } + + ciphertext, err := encryptExportPayload(plaintext, passphrase) + if err != nil { + t.Fatalf("encrypt export data: %v", err) + } + + return base64.StdEncoding.EncodeToString(ciphertext) +} + +func encryptExportPayload(plaintext []byte, passphrase string) ([]byte, error) { + salt := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, err + } + + key := pbkdf2.Key([]byte(passphrase), salt, 100000, 32, sha256.New) + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + result := append(salt, ciphertext...) + return result, nil +} + +func decryptExportPayload(ciphertext []byte, passphrase string) ([]byte, error) { + if len(ciphertext) < 32 { + return nil, io.ErrUnexpectedEOF + } + + salt := ciphertext[:32] + cipherbody := ciphertext[32:] + + key := pbkdf2.Key([]byte(passphrase), salt, 100000, 32, sha256.New) + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + if len(cipherbody) < gcm.NonceSize() { + return nil, io.ErrUnexpectedEOF + } + + nonce := cipherbody[:gcm.NonceSize()] + payload := cipherbody[gcm.NonceSize():] + + return gcm.Open(nil, nonce, payload, nil) +} + +func mustMarshalJSON(t *testing.T, v interface{}) []byte { + t.Helper() + + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal json: %v", err) + } + return data +} + +func assertJSONEqual(t *testing.T, got interface{}, want interface{}, context string) { + t.Helper() + + gotJSON := mustMarshalJSON(t, got) + wantJSON := mustMarshalJSON(t, want) + + if !bytes.Equal(gotJSON, wantJSON) { + t.Fatalf("%s mismatch:\n got: %s\nwant: %s", context, gotJSON, wantJSON) + } +}