diff --git a/service/core/api.go b/service/core/api.go index c633956e..aa1305e0 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -106,6 +106,32 @@ func registerAPIEndpoints() error { return err } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "updates/check", + Read: api.PermitUser, + ActionFunc: func(ar *api.Request) (string, error) { + module.instance.BinaryUpdates().TriggerUpdateCheck() + module.instance.IntelUpdates().TriggerUpdateCheck() + return "update check triggered", nil + }, + Name: "Get the ID of the calling profile", + }); err != nil { + return err + } + + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "updates/apply", + Read: api.PermitUser, + ActionFunc: func(ar *api.Request) (string, error) { + module.instance.BinaryUpdates().TriggerApplyUpdates() + module.instance.IntelUpdates().TriggerApplyUpdates() + return "upgrade triggered", nil + }, + Name: "Get the ID of the calling profile", + }); err != nil { + return err + } + return nil } diff --git a/service/core/core.go b/service/core/core.go index ecbcf948..60ad5857 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -16,6 +16,7 @@ import ( _ "github.com/safing/portmaster/service/status" _ "github.com/safing/portmaster/service/sync" _ "github.com/safing/portmaster/service/ui" + "github.com/safing/portmaster/service/updates" ) // Core is the core service module. @@ -114,4 +115,6 @@ func New(instance instance) (*Core, error) { type instance interface { Shutdown() AddWorkerInfoToDebugInfo(di *debug.Info) + BinaryUpdates() *updates.Updates + IntelUpdates() *updates.Updates } diff --git a/service/updates/bundle.go b/service/updates/bundle.go index af0c1021..c2de5a92 100644 --- a/service/updates/bundle.go +++ b/service/updates/bundle.go @@ -58,44 +58,54 @@ func ParseBundle(dir string, indexFile string) (*Bundle, error) { var bundle Bundle err = json.Unmarshal(content, &bundle) if err != nil { - return nil, err + return nil, fmt.Errorf("%s %w", filepath, err) } + + // Filter artifacts + filtered := make([]Artifact, 0) + for _, a := range bundle.Artifacts { + if a.Platform == "" || a.Platform == currentPlatform { + filtered = append(filtered, a) + } + } + bundle.Artifacts = filtered + return &bundle, nil } // CopyMatchingFilesFromCurrent check if there the current bundle files has matching files with the new bundle and copies them if they match. func (bundle Bundle) CopyMatchingFilesFromCurrent(current Bundle, currentDir, newDir string) error { + // Make sure new dir exists + _ = os.MkdirAll(newDir, defaultDirMode) + for _, currentArtifact := range current.Artifacts { new: for _, newArtifact := range bundle.Artifacts { if currentArtifact.Filename == newArtifact.Filename { if currentArtifact.SHA256 == newArtifact.SHA256 { - // Files match, make sure new dir exists - _ = os.MkdirAll(newDir, defaultDirMode) - - // Open the current file - sourceFilePath := fmt.Sprintf("%s/%s", currentDir, newArtifact.Filename) - sourceFile, err := os.Open(sourceFilePath) + // Read the content of the current file. + sourceFilePath := filepath.Join(currentDir, newArtifact.Filename) + content, err := os.ReadFile(sourceFilePath) if err != nil { - return fmt.Errorf("failed to open %s file: %w", sourceFilePath, err) + return fmt.Errorf("failed to read file %s: %w", sourceFilePath, err) + } + + // Check if the content matches the artifact hash + expectedHash, err := hex.DecodeString(newArtifact.SHA256) + if err != nil || len(expectedHash) != sha256.Size { + return fmt.Errorf("invalid artifact hash %s: %w", newArtifact.SHA256, err) + } + hash := sha256.Sum256(content) + if !bytes.Equal(expectedHash, hash[:]) { + return fmt.Errorf("expected and file hash mismatch: %s", sourceFilePath) } - defer sourceFile.Close() // Create new file - destFilePath := fmt.Sprintf("%s/%s", newDir, newArtifact.Filename) - destFile, err := os.Create(destFilePath) + destFilePath := filepath.Join(newDir, newArtifact.Filename) + err = os.WriteFile(destFilePath, content, defaultFileMode) if err != nil { - return fmt.Errorf("failed to create %s file: %w", destFilePath, err) + return fmt.Errorf("failed to write to file %s: %w", destFilePath, err) } - defer destFile.Close() - - // Copy - _, err = io.Copy(destFile, sourceFile) - if err != nil { - return fmt.Errorf("failed to copy contents: %w", err) - } - // Flush - _ = destFile.Sync() } break new @@ -108,8 +118,7 @@ func (bundle Bundle) CopyMatchingFilesFromCurrent(current Bundle, currentDir, ne func (bundle Bundle) DownloadAndVerify(dir string) { client := http.Client{} for _, artifact := range bundle.Artifacts { - - filePath := fmt.Sprintf("%s/%s", dir, artifact.Filename) + filePath := filepath.Join(dir, artifact.Filename) // TODO(vladimir): is this needed? _ = os.MkdirAll(filepath.Dir(filePath), defaultDirMode) @@ -131,13 +140,7 @@ func (bundle Bundle) DownloadAndVerify(dir string) { // Verify checks if the files are present int the dataDir and have the correct hash. func (bundle Bundle) Verify(dir string) error { for _, artifact := range bundle.Artifacts { - artifactPath := fmt.Sprintf("%s/%s", dir, artifact.Filename) - file, err := os.Open(artifactPath) - if err != nil { - return fmt.Errorf("failed to open file %s: %w", artifactPath, err) - } - defer func() { _ = file.Close() }() - + artifactPath := filepath.Join(dir, artifact.Filename) isValid, err := checkIfFileIsValid(artifactPath, artifact) if err != nil { return err @@ -177,11 +180,6 @@ func checkIfFileIsValid(filename string, artifact Artifact) (bool, error) { } func processArtifact(client *http.Client, artifact Artifact, filePath string) error { - // Skip artifacts not meant for this machine. - if artifact.Platform != "" && artifact.Platform != currentPlatform { - return nil - } - providedHash, err := hex.DecodeString(artifact.SHA256) if err != nil || len(providedHash) != sha256.Size { return fmt.Errorf("invalid provided hash %s: %w", artifact.SHA256, err) @@ -211,20 +209,14 @@ func processArtifact(client *http.Client, artifact Artifact, filePath string) er // Save tmpFilename := fmt.Sprintf("%s.download", filePath) - file, err := os.Create(tmpFilename) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) + fileMode := defaultFileMode + if artifact.Platform == currentPlatform { + fileMode = executableFileMode } - if artifact.Platform == "" { - _ = file.Chmod(defaultFileMode) - } else { - _ = file.Chmod(executableFileMode) - } - _, err = file.Write(content) + err = os.WriteFile(tmpFilename, content, fileMode) if err != nil { return fmt.Errorf("failed to write to file: %w", err) } - file.Close() // Rename err = os.Rename(tmpFilename, filePath) @@ -265,17 +257,11 @@ func downloadFile(client *http.Client, urls []string) ([]byte, error) { func unpack(cType string, fileBytes []byte) ([]byte, error) { switch cType { case "zip": - { - return decompressZip(fileBytes) - } + return decompressZip(fileBytes) case "gz": - { - return decompressGzip(fileBytes) - } + return decompressGzip(fileBytes) default: - { - return nil, fmt.Errorf("unsupported compression type") - } + return nil, fmt.Errorf("unsupported compression type") } } diff --git a/service/updates/index.go b/service/updates/index.go index 76b34000..19713064 100644 --- a/service/updates/index.go +++ b/service/updates/index.go @@ -42,7 +42,7 @@ func (ui *UpdateIndex) downloadIndexFileFromURL(url string) error { } defer func() { _ = resp.Body.Close() }() filePath := fmt.Sprintf("%s/%s", ui.DownloadDirectory, ui.IndexFile) - file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, defaultFileMode) + file, err := os.Create(filePath) if err != nil { return err } diff --git a/service/updates/module.go b/service/updates/module.go index e3976c9c..ac15d1f8 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -14,6 +14,8 @@ import ( "github.com/safing/portmaster/service/mgr" ) +const updateAvailableNotificationID = "updates:update-available" + type File struct { id string path string @@ -65,24 +67,15 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { EventVersionsUpdated: mgr.NewEventMgr[struct{}](VersionUpdateEvent, m), updateIndex: index, + files: make(map[string]File), instance: instance, } // Events module.updateCheckWorkerMgr = m.NewWorkerMgr("update checker", module.checkForUpdates, nil) - module.updateCheckWorkerMgr.Repeat(30 * time.Second) - module.upgraderWorkerMgr = m.NewWorkerMgr("upgrader", func(w *mgr.WorkerCtx) error { - err := applyUpdates(module.updateIndex, *module.updateBundle) - if err != nil { - // TODO(vladimir): Send notification to UI - log.Errorf("updates: failed to apply updates: %s", err) - } else { - // TODO(vladimir): Prompt user to restart? - module.instance.Restart() - } - return nil - }, nil) + module.updateCheckWorkerMgr.Repeat(1 * time.Hour) + module.upgraderWorkerMgr = m.NewWorkerMgr("upgrader", module.applyUpdates, nil) var err error module.bundle, err = ParseBundle(module.updateIndex.Directory, module.updateIndex.IndexFile) @@ -92,18 +85,44 @@ func New(instance instance, name string, index UpdateIndex) (*Updates, error) { // Add bundle artifacts to registry. module.processBundle(module.bundle) + err = module.registerEndpoints() + if err != nil { + log.Errorf("failed to register endpoints: %s", err) + } - // Remove old files - m.Go("old files cleaner", func(ctx *mgr.WorkerCtx) error { - err := os.RemoveAll(module.updateIndex.PurgeDirectory) - if err != nil { - return fmt.Errorf("failed to delete folder: %w", err) - } - return nil - }) return module, nil } +func (u *Updates) registerEndpoints() error { + if err := api.RegisterEndpoint(api.Endpoint{ + Name: "Check for update", + Description: "Trigger update check", + Path: "updates/check", + Read: api.PermitAnyone, + ActionFunc: func(ar *api.Request) (msg string, err error) { + u.updateCheckWorkerMgr.Go() + return "Check for updates triggered", nil + }, + }); err != nil { + return err + } + + if err := api.RegisterEndpoint(api.Endpoint{ + Name: "Apply update", + Description: "Triggers update", + Path: "updates/apply", + Read: api.PermitAnyone, + ActionFunc: func(ar *api.Request) (msg string, err error) { + u.upgraderWorkerMgr.Go() + return "Apply updates triggered", nil + }, + }); err != nil { + return err + } + + return nil +} + func (reg *Updates) processBundle(bundle *Bundle) { for _, artifact := range bundle.Artifacts { artifactPath := fmt.Sprintf("%s/%s", reg.updateIndex.Directory, artifact.Filename) @@ -119,33 +138,52 @@ func (u *Updates) checkForUpdates(_ *mgr.WorkerCtx) error { u.updateBundle, err = ParseBundle(u.updateIndex.DownloadDirectory, u.updateIndex.IndexFile) if err != nil { - return fmt.Errorf("failed parse bundle: %s", err) + return fmt.Errorf("failed parsing bundle: %s", err) } defer u.EventResourcesUpdated.Submit(struct{}{}) - // Compare current and downloaded index version. - currentVersion, err := semver.NewVersion(u.bundle.Version) - downloadVersion, err := semver.NewVersion(u.updateBundle.Version) - if currentVersion.Compare(downloadVersion) <= 0 { - // no updates - log.Info("updates: check complete: no new updates") + hasUpdate, err := u.checkVersionIncrement() + if err != nil { + return fmt.Errorf("failed to compare versions: %s", err) + } + + if !hasUpdate { + log.Infof("updates: check compete: no new updates") return nil } log.Infof("updates: check complete: downloading new version: %s %s", u.updateBundle.Name, u.updateBundle.Version) - err = u.DownloadUpdates() + err = u.downloadUpdates() if err != nil { log.Errorf("updates: failed to download bundle: %s", err) - } else if u.updateIndex.AutoApply { - u.ApplyUpdates() + } else { + notifications.NotifyPrompt(updateAvailableNotificationID, "Update available", "Apply update and restart?", notifications.Action{ + ID: "apply", + Text: "Apply", + Type: notifications.ActionTypeInjectEvent, + Payload: "apply-updates", + }) } return nil } -// DownloadUpdates downloads available binary updates. -func (u *Updates) DownloadUpdates() error { +func (u *Updates) checkVersionIncrement() (bool, error) { + // Compare current and downloaded index version. + currentVersion, err := semver.NewVersion(u.bundle.Version) + if err != nil { + return false, err + } + downloadVersion, err := semver.NewVersion(u.updateBundle.Version) + if err != nil { + return false, err + } + log.Debugf("updates: checking version: curr: %s new: %s", currentVersion.String(), downloadVersion.String()) + return downloadVersion.GreaterThan(currentVersion), nil +} + +func (u *Updates) downloadUpdates() error { if u.updateBundle == nil { - // CheckForBinaryUpdates needs to be called before this. + // checkForUpdates needs to be called before this. return fmt.Errorf("no valid update bundle found") } _ = deleteUnfinishedDownloads(u.updateIndex.DownloadDirectory) @@ -157,7 +195,40 @@ func (u *Updates) DownloadUpdates() error { return nil } -func (u *Updates) ApplyUpdates() { +func (u *Updates) applyUpdates(_ *mgr.WorkerCtx) error { + // Check if we have new version + hasNewVersion, err := u.checkVersionIncrement() + if err != nil { + return fmt.Errorf("error while reading bundle version: %w", err) + } + + if !hasNewVersion { + return fmt.Errorf("there is no new version to apply") + } + + err = u.updateBundle.Verify(u.updateIndex.DownloadDirectory) + if err != nil { + return fmt.Errorf("failed to apply update: %s", err) + } + + err = switchFolders(u.updateIndex, *u.updateBundle) + if err != nil { + // TODO(vladimir): Send notification to UI + log.Errorf("updates: failed to apply updates: %s", err) + } else { + // TODO(vladimir): Prompt user to restart? + u.instance.Restart() + } + return nil +} + +// TriggerUpdateCheck triggers an update check +func (u *Updates) TriggerUpdateCheck() { + u.updateCheckWorkerMgr.Go() +} + +// TriggerApplyUpdates triggers upgrade +func (u *Updates) TriggerApplyUpdates() { u.upgraderWorkerMgr.Go() } @@ -173,7 +244,16 @@ func (u *Updates) Manager() *mgr.Manager { // Start starts the module. func (u *Updates) Start() error { + // Remove old files + u.m.Go("old files cleaner", func(ctx *mgr.WorkerCtx) error { + err := os.RemoveAll(u.updateIndex.PurgeDirectory) + if err != nil { + return fmt.Errorf("failed to delete folder: %w", err) + } + return nil + }) u.updateCheckWorkerMgr.Go() + return nil } diff --git a/service/updates/updater.go b/service/updates/updater.go index 736290e6..3c6b2892 100644 --- a/service/updates/updater.go +++ b/service/updates/updater.go @@ -15,7 +15,7 @@ const ( defaultDirMode = os.FileMode(0o0755) ) -func applyUpdates(updateIndex UpdateIndex, newBundle Bundle) error { +func switchFolders(updateIndex UpdateIndex, newBundle Bundle) error { // Create purge dir. err := os.MkdirAll(updateIndex.PurgeDirectory, defaultDirMode) if err != nil { @@ -30,17 +30,17 @@ func applyUpdates(updateIndex UpdateIndex, newBundle Bundle) error { // Move current version files into purge folder. for _, file := range files { - filepath := fmt.Sprintf("%s/%s", updateIndex.Directory, file.Name()) - purgePath := fmt.Sprintf("%s/%s", updateIndex.PurgeDirectory, file.Name()) - err := os.Rename(filepath, purgePath) + currentFilepath := filepath.Join(updateIndex.Directory, file.Name()) + purgePath := filepath.Join(updateIndex.PurgeDirectory, file.Name()) + err := os.Rename(currentFilepath, purgePath) if err != nil { - return fmt.Errorf("failed to move file %s: %w", filepath, err) + return fmt.Errorf("failed to move file %s: %w", currentFilepath, err) } } // Move the new index file - indexFile := fmt.Sprintf("%s/%s", updateIndex.DownloadDirectory, updateIndex.IndexFile) - newIndexFile := fmt.Sprintf("%s/%s", updateIndex.Directory, updateIndex.IndexFile) + indexFile := filepath.Join(updateIndex.DownloadDirectory, updateIndex.IndexFile) + newIndexFile := filepath.Join(updateIndex.Directory, updateIndex.IndexFile) err = os.Rename(indexFile, newIndexFile) if err != nil { return fmt.Errorf("failed to move index file %s: %w", indexFile, err) @@ -48,8 +48,8 @@ func applyUpdates(updateIndex UpdateIndex, newBundle Bundle) error { // Move downloaded files to the current version folder. for _, artifact := range newBundle.Artifacts { - fromFilepath := fmt.Sprintf("%s/%s", updateIndex.DownloadDirectory, artifact.Filename) - toFilepath := fmt.Sprintf("%s/%s", updateIndex.Directory, artifact.Filename) + fromFilepath := filepath.Join(updateIndex.DownloadDirectory, artifact.Filename) + toFilepath := filepath.Join(updateIndex.Directory, artifact.Filename) err = os.Rename(fromFilepath, toFilepath) if err != nil { return fmt.Errorf("failed to move file %s: %w", fromFilepath, err) @@ -64,12 +64,12 @@ func deleteUnfinishedDownloads(rootDir string) error { return err } - // Check if the current file has the specified extension + // Check if the current file has the download extension if !info.IsDir() && strings.HasSuffix(info.Name(), ".download") { - log.Warningf("updates: deleting unfinished: %s\n", path) + log.Warningf("updates: deleting unfinished download file: %s\n", path) err := os.Remove(path) if err != nil { - return fmt.Errorf("failed to delete file %s: %w", path, err) + log.Errorf("updates: failed to delete unfinished download file %s: %w", path, err) } }