Fix tests and issues

This commit is contained in:
Daniel 2024-11-08 14:39:58 +01:00
parent 2d0d711938
commit ddf7ba170e
10 changed files with 325 additions and 231 deletions

View file

@ -29,21 +29,23 @@ func newTestInstance(testName string) (*testInstance, error) {
}, nil
}
func TestConfigPersistence(t *testing.T) {
t.Parallel()
func TestMain(m *testing.M) {
instance, err := newTestInstance("test-config")
if err != nil {
t.Fatalf("failed to create test instance: %s", err)
panic(fmt.Errorf("failed to create test instance: %w", err))
}
defer func() { _ = os.RemoveAll(instance.DataDir()) }()
module, err = New(instance)
if err != nil {
t.Fatalf("failed to initialize module: %s", err)
panic(fmt.Errorf("failed to initialize module: %w", err))
}
err = SaveConfig()
m.Run()
}
func TestConfigPersistence(t *testing.T) { //nolint:paralleltest
err := SaveConfig()
if err != nil {
t.Fatal(err)
}

View file

@ -70,16 +70,16 @@ func init() {
}
func scan(cmd *cobra.Command, args []string) error {
bundle, err := updates.GenerateIndexFromDir(scanDir, scanConfig)
index, err := updates.GenerateIndexFromDir(scanDir, scanConfig)
if err != nil {
return err
}
bundleStr, err := json.MarshalIndent(&bundle, "", " ")
indexJson, err := json.MarshalIndent(&index, "", " ")
if err != nil {
return fmt.Errorf("marshal index: %w", err)
}
fmt.Printf("%s", bundleStr)
fmt.Printf("%s", indexJson)
return nil
}

View file

@ -130,6 +130,7 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx
dataDir: svcCfg.DataDir,
}
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background())
// Base modules
instance.base, err = base.New(instance)
@ -651,24 +652,23 @@ func (i *Instance) shutdown(exitCode int) {
return
}
// Cancel main context.
i.cancelCtx()
// Set given exit code.
i.exitCode.Store(int32(exitCode))
// Cancel contexts.
i.cancelCtx()
defer i.cancelShutdownCtx()
// Start shutdown asynchronously in a separate manager.
m := mgr.New("instance")
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
for {
if err := i.Stop(); err != nil {
w.Error("failed to shutdown", "err", err, "retry", "1s")
time.Sleep(1 * time.Second)
} else {
return nil
}
// Stop all modules.
if err := i.Stop(); err != nil {
w.Error("failed to shutdown", "err", err)
}
// Cancel shutdown process context.
i.cancelShutdownCtx()
return nil
})
}

View file

@ -56,7 +56,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error {
}
log.Warningf("updates/%s: failed to update index from %q: %s", d.u.cfg.Name, url, err)
err = fmt.Errorf("update index file from %q: %s", url, err)
err = fmt.Errorf("update index file from %q: %w", url, err)
}
if err != nil {
return fmt.Errorf("all index URLs failed, last error: %w", err)
@ -65,7 +65,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error {
// Write the index into a file.
indexFilepath := filepath.Join(d.u.cfg.DownloadDirectory, d.u.cfg.IndexFile)
err = os.WriteFile(indexFilepath, []byte(indexData), defaultFileMode)
err = os.WriteFile(indexFilepath, indexData, defaultFileMode)
if err != nil {
return fmt.Errorf("write index file: %w", err)
}
@ -111,7 +111,7 @@ func (d *Downloader) gatherExistingFiles(dir string) error {
// Read full file.
fileData, err := os.ReadFile(fullpath)
if err != nil {
log.Debugf("updates/%s: failed to read file %q while searching for existing files: %w", d.u.cfg.Name, fullpath, err)
log.Debugf("updates/%s: failed to read file %q while searching for existing files: %s", d.u.cfg.Name, fullpath, err)
return fmt.Errorf("failed to read file %s: %w", fullpath, err)
}
@ -150,7 +150,12 @@ artifacts:
if err == nil {
continue artifacts
}
log.Debugf("updates/%s: failed to copy existing file %s: %w", d.u.cfg.Name, artifact.Filename, err)
log.Debugf("updates/%s: failed to copy existing file %s: %s", d.u.cfg.Name, artifact.Filename, err)
}
// Check if the artifact has download URLs.
if len(artifact.URLs) == 0 {
return fmt.Errorf("artifact %s is missing download URLs", artifact.Filename)
}
// Try to download the artifact from one of the URLs.
@ -163,7 +168,7 @@ artifacts:
// Valid artifact found!
break artifactURLs
}
err = fmt.Errorf("update index file from %q: %s", url, err)
err = fmt.Errorf("update index file from %q: %w", url, err)
}
if err != nil {
return fmt.Errorf("all artifact URLs for %s failed, last error: %w", artifact.Filename, err)

View file

@ -15,6 +15,7 @@ import (
"time"
semver "github.com/hashicorp/go-version"
"github.com/safing/jess"
"github.com/safing/jess/filesig"
)
@ -26,7 +27,7 @@ const currentPlatform = runtime.GOOS + "_" + runtime.GOARCH
var zeroVersion = semver.Must(semver.NewVersion("0.0.0"))
// Artifacts represents a single file with metadata.
// Artifact represents a single file with metadata.
type Artifact struct {
Filename string `json:"Filename"`
SHA256 string `json:"SHA256"`
@ -85,7 +86,7 @@ func (a *Artifact) IsNewerThan(b *Artifact) (newer, ok bool) {
}
func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact {
copy := &Artifact{
copied := &Artifact{
Filename: a.Filename,
SHA256: a.SHA256,
URLs: a.URLs,
@ -98,20 +99,20 @@ func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact {
// Make sure we have a version number.
switch {
case copy.versionNum != nil:
case copied.versionNum != nil:
// Version already parsed.
case copy.Version != "":
case copied.Version != "":
// Need to parse version.
v, err := semver.NewVersion(copy.Version)
v, err := semver.NewVersion(copied.Version)
if err == nil {
copy.versionNum = v
copied.versionNum = v
}
default:
// No version defined, inherit index version.
copy.versionNum = indexVersion
copied.versionNum = indexVersion
}
return copy
return copied
}
// Index represents a collection of artifacts with metadata.
@ -146,8 +147,8 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error)
}
// Parse json.
var index Index
err := json.Unmarshal([]byte(jsonContent), &index)
index := &Index{}
err := json.Unmarshal(jsonContent, index)
if err != nil {
return nil, fmt.Errorf("parse index: %w", err)
}
@ -158,7 +159,7 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error)
return nil, err
}
return &index, nil
return index, nil
}
func (index *Index) init() error {
@ -219,7 +220,7 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error {
return fmt.Errorf("current index cannot do upgrades: %w", err)
}
if err := newIndex.CanDoUpgrades(); err != nil {
return fmt.Errorf("new index cannot do upgrade: %w")
return fmt.Errorf("new index cannot do upgrade: %w", err)
}
switch {
@ -229,13 +230,14 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error {
return nil
case index.Name != newIndex.Name:
return errors.New("index names do not match")
case index.versionNum.GreaterThan(newIndex.versionNum):
return errors.New("current index has newer version")
return errors.New("new index name does not match")
case index.Published.After(newIndex.Published):
return errors.New("current index was published later")
return errors.New("new index is older (time)")
case index.versionNum.Segments()[0] > newIndex.versionNum.Segments()[0]:
// Downgrades are allowed, if they are not breaking changes.
return errors.New("new index is a breaking change downgrade")
case index.Published.Equal(newIndex.Published):
// "Do nothing".
@ -252,7 +254,7 @@ func (index *Index) VerifyArtifacts(dir string) error {
for _, artifact := range index.Artifacts {
err := checkSHA256SumFile(filepath.Join(dir, artifact.Filename), artifact.SHA256)
if err != nil {
return fmt.Errorf("verify %s: %s", artifact.Filename, err)
return fmt.Errorf("verify %s: %w", artifact.Filename, err)
}
}

View file

@ -3,7 +3,6 @@ package updates
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io/fs"
@ -95,7 +94,7 @@ settings:
}
// GenerateIndexFromDir generates a index from a given folder.
func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) {
func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) { //nolint:maintidx
artifacts := make(map[string]Artifact)
// Initialize.
@ -107,6 +106,13 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
if err != nil {
return nil, fmt.Errorf("invalid index dir: %w", err)
}
var indexVersion *semver.Version
if cfg.Version != "" {
indexVersion, err = semver.NewVersion(cfg.Version)
if err != nil {
return nil, fmt.Errorf("invalid index version: %w", err)
}
}
err = filepath.WalkDir(sourceDir, func(fullpath string, d fs.DirEntry, err error) error {
// Fail on access error.
@ -227,9 +233,10 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
// Create base index.
index := &Index{
Name: cfg.Name,
Version: cfg.Version,
Published: time.Now(),
Name: cfg.Name,
Version: cfg.Version,
Published: time.Now(),
versionNum: indexVersion,
}
if index.Version == "" && cfg.PrimaryArtifact != "" {
pv, ok := artifacts[cfg.PrimaryArtifact]
@ -286,45 +293,6 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
return index, nil
}
func selectLatestArtifacts(artifacts []Artifact) ([]Artifact, error) {
artifactsMap := make(map[string]Artifact)
for _, a := range artifacts {
// Make the key platform specific since there can be same filename for multiple platforms.
key := a.Filename + a.Platform
aMap, ok := artifactsMap[key]
if !ok {
artifactsMap[key] = a
continue
}
if aMap.Version == "" || a.Version == "" {
return nil, fmt.Errorf("invalid mix version and non versioned files for: %s", a.Filename)
}
mapVersion, err := semver.NewVersion(aMap.Version)
if err != nil {
return nil, fmt.Errorf("invalid version for artifact: %s", aMap.Filename)
}
artifactVersion, err := semver.NewVersion(a.Version)
if err != nil {
return nil, fmt.Errorf("invalid version for artifact: %s", a.Filename)
}
if mapVersion.LessThan(artifactVersion) {
artifactsMap[key] = a
}
}
artifactsFiltered := make([]Artifact, 0, len(artifactsMap))
for _, a := range artifactsMap {
artifactsFiltered = append(artifactsFiltered, a)
}
return artifactsFiltered, nil
}
func getSHA256(path string, unpackType string) (string, error) {
content, err := os.ReadFile(path)
if err != nil {
@ -372,50 +340,3 @@ func getIdentifierAndVersion(versionedPath string) (identifier, version string,
// `dirPath + filename` is guaranteed by path.Split()
return dirPath + filename, version, true
}
// GenerateMockFolder generates mock index folder for testing.
func GenerateMockFolder(dir, name, version string) error { // FIXME: move this to test?
// Make sure dir exists
_ = os.MkdirAll(dir, defaultDirMode)
// Create empty files
file, err := os.Create(filepath.Join(dir, "portmaster"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "portmaster-core"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "portmaster.zip"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "assets.zip"))
if err != nil {
return err
}
_ = file.Close()
index, err := GenerateIndexFromDir(dir, IndexScanConfig{
Name: name,
Version: version,
})
if err != nil {
return err
}
indexJson, err := json.MarshalIndent(index, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err)
}
err = os.WriteFile(filepath.Join(dir, "index.json"), indexJson, defaultFileMode)
if err != nil {
return err
}
return nil
}

View file

@ -10,11 +10,12 @@ import (
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/jess"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/base/notifications"
"github.com/safing/portmaster/service/mgr"
"github.com/tevino/abool"
)
const (
@ -169,7 +170,7 @@ func New(instance instance, name string, cfg Config) (*Updater, error) {
// Fall back to scanning the directory.
if !errors.Is(err, os.ErrNotExist) {
log.Errorf("updates/%s: invalid index file, falling back to dir scan: %w", cfg.Name, err)
log.Errorf("updates/%s: invalid index file, falling back to dir scan: %s", cfg.Name, err)
}
index, err = GenerateIndexFromDir(cfg.Directory, IndexScanConfig{Version: "0.0.0"})
if err == nil && index.init() == nil {
@ -181,13 +182,12 @@ func New(instance instance, name string, cfg Config) (*Updater, error) {
return module, nil
}
func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) {
func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) { //nolint:maintidx
// Make sure only one update process is running.
if !u.isUpdateRunning.SetToIf(false, true) {
return fmt.Errorf("an updater task is already running, please try again later")
}
defer u.isUpdateRunning.UnSet()
// FIXME: Switch to mutex?
// Create a new downloader.
downloader := NewDownloader(u, indexURLs)
@ -201,7 +201,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
}
} else {
// Otherwise, load index from download dir.
downloader.index, err = LoadIndex(filepath.Join(u.cfg.Directory, u.cfg.IndexFile), u.cfg.Verify)
downloader.index, err = LoadIndex(filepath.Join(u.cfg.DownloadDirectory, u.cfg.IndexFile), u.cfg.Verify)
if err != nil {
return fmt.Errorf("load previously downloaded index file: %w", err)
}
@ -215,23 +215,42 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
u.indexLock.Unlock()
// Check with local pointer to index.
if err := index.ShouldUpgradeTo(downloader.index); err != nil {
log.Infof("updates/%s: no new or eligible update: %s", u.cfg.Name, err)
if u.cfg.Notify && u.instance.Notifications() != nil {
u.instance.Notifications().Notify(&notifications.Notification{
EventID: noNewUpdateNotificationID,
Type: notifications.Info,
Title: "Portmaster Is Up-To-Date",
Message: "Portmaster v" + index.Version + " is the newest version.",
Expires: time.Now().Add(1 * time.Minute).Unix(),
AvailableActions: []*notifications.Action{
{
ID: "ack",
Text: "OK",
if errors.Is(err, ErrSameIndex) {
log.Infof("updates/%s: no new update", u.cfg.Name)
if u.cfg.Notify && u.instance.Notifications() != nil {
u.instance.Notifications().Notify(&notifications.Notification{
EventID: noNewUpdateNotificationID,
Type: notifications.Info,
Title: "Portmaster Is Up-To-Date",
Message: "Portmaster v" + index.Version + " is the newest version.",
Expires: time.Now().Add(1 * time.Minute).Unix(),
AvailableActions: []*notifications.Action{
{
ID: "ack",
Text: "OK",
},
},
},
})
})
}
} else {
log.Warningf("updates/%s: cannot update: %s", u.cfg.Name, err)
if u.cfg.Notify && u.instance.Notifications() != nil {
u.instance.Notifications().Notify(&notifications.Notification{
EventID: noNewUpdateNotificationID,
Type: notifications.Info,
Title: "Portmaster Is Up-To-Date*",
Message: "While Portmaster v" + index.Version + " is the newest version, there is an internal issue with checking for updates: " + err.Error(),
Expires: time.Now().Add(1 * time.Minute).Unix(),
AvailableActions: []*notifications.Action{
{
ID: "ack",
Text: "OK",
},
},
})
}
}
return ErrNoUpdateAvailable
return fmt.Errorf("%w: %w", ErrNoUpdateAvailable, err)
}
}
@ -320,7 +339,10 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
// Install is complete!
// Clean up and notify modules of changed files.
u.cleanupAfterUpgrade()
err = u.cleanupAfterUpgrade()
if err != nil {
log.Debugf("updates/%s: failed to clean up after upgrade: %s", u.cfg.Name, err)
}
u.EventResourcesUpdated.Submit(struct{}{})
// If no restart is needed, we are done.
@ -363,7 +385,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
Type: notifications.ActionTypeWebhook,
Payload: notifications.ActionTypeWebhookPayload{
Method: "POST",
URL: "updates/apply", // FIXME
URL: "core/restart",
},
},
)
@ -376,15 +398,35 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
}
func (u *Updater) updateCheckWorker(w *mgr.WorkerCtx) error {
_ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false)
// FIXME: Handle errors.
return nil
err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false)
switch {
case err == nil:
return nil // Success!
case errors.Is(err, ErrSameIndex):
return nil // Nothing to do.
case errors.Is(err, ErrNoUpdateAvailable):
return nil // Already logged.
case errors.Is(err, ErrActionRequired) && !u.cfg.Notify:
return fmt.Errorf("user action required, but notifying user is disabled: %w", err)
default:
return fmt.Errorf("udpating failed: %w", err)
}
}
func (u *Updater) upgradeWorker(w *mgr.WorkerCtx) error {
_ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true)
// FIXME: Handle errors.
return nil
err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true)
switch {
case err == nil:
return nil // Success!
case errors.Is(err, ErrSameIndex):
return nil // Nothing to do.
case errors.Is(err, ErrNoUpdateAvailable):
return nil // Already logged.
case errors.Is(err, ErrActionRequired) && !u.cfg.Notify:
return fmt.Errorf("user action required, but notifying user is disabled: %w", err)
default:
return fmt.Errorf("udpating failed: %w", err)
}
}
// ForceUpdate executes a forced update and upgrade directly and synchronously

View file

@ -1,12 +1,15 @@
package updates
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/safing/portmaster/base/notifications"
"github.com/safing/portmaster/service/mgr"
)
type testInstance struct{}
@ -24,66 +27,129 @@ func (i *testInstance) Ready() bool {
func (i *testInstance) SetCmdLineOperation(f func() error) {}
func TestPreformUpdate(t *testing.T) {
func TestPerformUpdate(t *testing.T) {
t.Parallel()
// Initialize mock instance
stub := &testInstance{}
// Make tmp dirs
installedDir, err := os.MkdirTemp("", "updates_current")
installedDir, err := os.MkdirTemp("", "updates_current_")
if err != nil {
panic(err)
t.Fatal(err)
}
defer func() { _ = os.RemoveAll(installedDir) }()
updateDir, err := os.MkdirTemp("", "updates_new")
updateDir, err := os.MkdirTemp("", "updates_new_")
if err != nil {
panic(err)
t.Fatal(err)
}
defer func() { _ = os.RemoveAll(updateDir) }()
purgeDir, err := os.MkdirTemp("", "updates_purge")
purgeDir, err := os.MkdirTemp("", "updates_purge_")
if err != nil {
panic(err)
t.Fatal(err)
}
defer func() { _ = os.RemoveAll(purgeDir) }()
// Generate mock files
if err := GenerateMockFolder(installedDir, "Test", "1.0.0"); err != nil {
panic(err)
now := time.Now()
if err := GenerateMockFolder(installedDir, "Test", "1.0.0", now); err != nil {
t.Fatal(err)
}
if err := GenerateMockFolder(updateDir, "Test", "1.0.1"); err != nil {
panic(err)
if err := GenerateMockFolder(updateDir, "Test", "1.0.1", now.Add(1*time.Minute)); err != nil {
t.Fatal(err)
}
// Create updater
updates, err := New(stub, "Test", Config{
// Create updater (loads index).
updater, err := New(stub, "Test", Config{
Name: "Test",
Directory: installedDir,
DownloadDirectory: updateDir,
PurgeDirectory: purgeDir,
IndexFile: "index.json",
AutoApply: false,
NeedsRestart: false,
AutoDownload: true,
AutoApply: true,
})
if err != nil {
panic(err)
}
// Read and parse the index file
if err := updates.downloader.Verify(); err != nil {
panic(err)
t.Fatal(err)
}
// Try to apply the updates
err = updates.applyUpdates(nil)
if err != nil {
panic(err)
}
m := mgr.New("updates test")
_ = m.Do("test update and upgrade", func(w *mgr.WorkerCtx) error {
if err := updater.updateAndUpgrade(w, nil, false, false); err != nil {
if data, err := os.ReadFile(filepath.Join(installedDir, "index.json")); err == nil {
fmt.Println(string(data))
fmt.Println(updater.index.Version)
fmt.Println(updater.index.versionNum)
}
if data, err := os.ReadFile(filepath.Join(updateDir, "index.json")); err == nil {
fmt.Println(string(data))
idx, err := ParseIndex(data, nil)
if err == nil {
fmt.Println(idx.Version)
fmt.Println(idx.versionNum)
}
}
// CHeck if the current version is now the new.
bundle, err := LoadBundle(filepath.Join(installedDir, "index.json"))
if err != nil {
panic(err)
}
t.Fatal(err)
}
return nil
})
if bundle.Version != "1.0.1" {
panic(fmt.Errorf("expected version 1.0.1 found %s", bundle.Version))
// Check if the current version is now the new.
newIndex, err := LoadIndex(filepath.Join(installedDir, "index.json"), nil)
if err != nil {
t.Fatal(err)
}
if newIndex.Version != "1.0.1" {
t.Fatalf("expected version 1.0.1 found %s", newIndex.Version)
}
}
// GenerateMockFolder generates mock index folder for testing.
func GenerateMockFolder(dir, name, version string, published time.Time) error {
// Make sure dir exists
_ = os.MkdirAll(dir, defaultDirMode)
// Create empty files
file, err := os.Create(filepath.Join(dir, "portmaster"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "portmaster-core"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "portmaster.zip"))
if err != nil {
return err
}
_ = file.Close()
file, err = os.Create(filepath.Join(dir, "assets.zip"))
if err != nil {
return err
}
_ = file.Close()
index, err := GenerateIndexFromDir(dir, IndexScanConfig{
Name: name,
Version: version,
})
if err != nil {
return err
}
index.Published = published
indexJSON, err := json.MarshalIndent(index, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err)
}
err = os.WriteFile(filepath.Join(dir, "index.json"), indexJSON, defaultFileMode)
if err != nil {
return err
}
return nil
}

View file

@ -1,6 +1,7 @@
package updates
import (
"errors"
"fmt"
"io/fs"
"os"
@ -31,7 +32,7 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error {
}
// Execute the upgrade.
upgradeError := u.upgradeMoveFiles(downloader, ignoreVersion)
upgradeError := u.upgradeMoveFiles(downloader)
if upgradeError == nil {
return nil
}
@ -43,10 +44,10 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error {
}
// Recovery failed too.
return fmt.Errorf("upgrade (including recovery) failed: %s", u.cfg.Name, upgradeError)
return fmt.Errorf("upgrade (including recovery) failed: %w", upgradeError)
}
func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) error {
func (u *Updater) upgradeMoveFiles(downloader *Downloader) error {
// Important:
// We assume that the downloader has done its job and all artifacts are verified.
// Files will just be moved here.
@ -65,20 +66,28 @@ func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) e
}
files, err := os.ReadDir(u.cfg.Directory)
if err != nil {
return fmt.Errorf("read current directory: %w", err)
}
for _, file := range files {
// Check if file is ignored.
if slices.Contains(u.cfg.Ignore, file.Name()) {
continue
if !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("read current directory: %w", err)
}
// Otherwise, move file to purge dir.
src := filepath.Join(u.cfg.Directory, file.Name())
dst := filepath.Join(u.cfg.PurgeDirectory, file.Name())
err := u.moveFile(src, dst, "", file.Type().Perm())
err = os.MkdirAll(u.cfg.Directory, defaultDirMode)
if err != nil {
return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err)
return fmt.Errorf("create current directory: %w", err)
}
} else {
// Move files.
for _, file := range files {
// Check if file is ignored.
if slices.Contains(u.cfg.Ignore, file.Name()) {
continue
}
// Otherwise, move file to purge dir.
src := filepath.Join(u.cfg.Directory, file.Name())
dst := filepath.Join(u.cfg.PurgeDirectory, file.Name())
err := u.moveFile(src, dst, "", file.Type().Perm())
if err != nil {
return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err)
}
}
}
@ -118,7 +127,7 @@ func (u *Updater) moveFile(currentPath, newPath string, sha256sum string, fileMo
// Moving was successful, return.
return nil
}
log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %w", u.cfg.Name, newPath, err)
log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %s", u.cfg.Name, newPath, err)
// Copy and check the checksum while we are at it.
err = copyAndCheckSHA256Sum(currentPath, newPath, sha256sum, fileMode)
@ -144,7 +153,7 @@ func (u *Updater) recoverFromFailedUpgrade() error {
err := u.moveFile(purgedFile, activeFile, "", file.Type().Perm())
if err != nil {
// Only warn and continue to recover as many files as possible.
log.Warningf("updates/%s: failed to roll back file %s: %w", u.cfg.Name, file.Name(), err)
log.Warningf("updates/%s: failed to roll back file %s: %s", u.cfg.Name, file.Name(), err)
}
}

View file

@ -34,10 +34,17 @@ import (
// Instance is an instance of a Portmaster service.
type Instance struct {
ctx context.Context
cancelCtx context.CancelFunc
ctx context.Context
cancelCtx context.CancelFunc
shutdownCtx context.Context
cancelShutdownCtx context.CancelFunc
serviceGroup *mgr.Group
binDir string
dataDir string
exitCode atomic.Int32
base *base.Base
@ -67,6 +74,7 @@ type Instance struct {
terminal *terminal.TerminalModule
CommandLineOperation func() error
ShouldRestart bool
}
// New returns a new Portmaster service instance.
@ -74,6 +82,7 @@ func New() (*Instance, error) {
// Create instance to pass it to modules.
instance := &Instance{}
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background())
binaryUpdateIndex := updates.Config{
// FIXME: fill
@ -234,6 +243,18 @@ func (i *Instance) SetSleep(enabled bool) {
}
}
// BinDir returns the directory for binaries.
// This directory may be read-only.
func (i *Instance) BinDir() string {
return i.binDir
}
// DataDir returns the directory for variable data.
// This directory is expected to be read/writeable.
func (i *Instance) DataDir() string {
return i.dataDir
}
// Database returns the database module.
func (i *Instance) Database() *dbmodule.DBModule {
return i.database
@ -379,12 +400,6 @@ func (i *Instance) Ready() bool {
return i.serviceGroup.Ready()
}
// Ctx returns the instance context.
// It is only canceled on shutdown.
func (i *Instance) Ctx() context.Context {
return i.ctx
}
// Start starts the instance.
func (i *Instance) Start() error {
return i.serviceGroup.Start()
@ -392,7 +407,6 @@ func (i *Instance) Start() error {
// Stop stops the instance and cancels the instance context when done.
func (i *Instance) Stop() error {
defer i.cancelCtx()
return i.serviceGroup.Stop()
}
@ -406,6 +420,8 @@ func (i *Instance) Restart() {
i.core.EventRestart.Submit(struct{}{})
time.Sleep(10 * time.Millisecond)
// Set the restart flag and shutdown.
i.ShouldRestart = true
i.shutdown(RestartExitCode)
}
@ -419,32 +435,63 @@ func (i *Instance) Shutdown() {
}
func (i *Instance) shutdown(exitCode int) {
// Only shutdown once.
if i.IsShuttingDown() {
return
}
// Cancel main context.
i.cancelCtx()
// Set given exit code.
i.exitCode.Store(int32(exitCode))
// Start shutdown asynchronously in a separate manager.
m := mgr.New("instance")
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
for {
if err := i.Stop(); err != nil {
w.Error("failed to shutdown", "err", err, "retry", "1s")
time.Sleep(1 * time.Second)
} else {
return nil
}
// Stop all modules.
if err := i.Stop(); err != nil {
w.Error("failed to shutdown", "err", err)
}
// Cancel shutdown process context.
i.cancelShutdownCtx()
return nil
})
}
// Stopping returns whether the instance is shutting down.
func (i *Instance) Stopping() bool {
return i.ctx.Err() == nil
// Ctx returns the instance context.
// It is canceled when shutdown is started.
func (i *Instance) Ctx() context.Context {
return i.ctx
}
// Stopped returns a channel that is triggered when the instance has shut down.
func (i *Instance) Stopped() <-chan struct{} {
// IsShuttingDown returns whether the instance is shutting down.
func (i *Instance) IsShuttingDown() bool {
return i.ctx.Err() != nil
}
// ShuttingDown returns a channel that is triggered when the instance starts shutting down.
func (i *Instance) ShuttingDown() <-chan struct{} {
return i.ctx.Done()
}
// ShutdownCtx returns the instance shutdown context.
// It is canceled when shutdown is complete.
func (i *Instance) ShutdownCtx() context.Context {
return i.shutdownCtx
}
// IsShutDown returns whether the instance has stopped.
func (i *Instance) IsShutDown() bool {
return i.shutdownCtx.Err() != nil
}
// ShutDownComplete returns a channel that is triggered when the instance has shut down.
func (i *Instance) ShutdownComplete() <-chan struct{} {
return i.shutdownCtx.Done()
}
// ExitCode returns the set exit code of the instance.
func (i *Instance) ExitCode() int {
return int(i.exitCode.Load())