mirror of
https://github.com/safing/portmaster
synced 2025-09-02 02:29:12 +00:00
Fix tests and issues
This commit is contained in:
parent
2d0d711938
commit
ddf7ba170e
10 changed files with 325 additions and 231 deletions
|
@ -29,21 +29,23 @@ func newTestInstance(testName string) (*testInstance, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigPersistence(t *testing.T) {
|
func TestMain(m *testing.M) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
instance, err := newTestInstance("test-config")
|
instance, err := newTestInstance("test-config")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test instance: %s", err)
|
panic(fmt.Errorf("failed to create test instance: %w", err))
|
||||||
}
|
}
|
||||||
defer func() { _ = os.RemoveAll(instance.DataDir()) }()
|
defer func() { _ = os.RemoveAll(instance.DataDir()) }()
|
||||||
|
|
||||||
module, err = New(instance)
|
module, err = New(instance)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to initialize module: %s", err)
|
panic(fmt.Errorf("failed to initialize module: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = SaveConfig()
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigPersistence(t *testing.T) { //nolint:paralleltest
|
||||||
|
err := SaveConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,16 +70,16 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func scan(cmd *cobra.Command, args []string) error {
|
func scan(cmd *cobra.Command, args []string) error {
|
||||||
bundle, err := updates.GenerateIndexFromDir(scanDir, scanConfig)
|
index, err := updates.GenerateIndexFromDir(scanDir, scanConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
bundleStr, err := json.MarshalIndent(&bundle, "", " ")
|
indexJson, err := json.MarshalIndent(&index, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal index: %w", err)
|
return fmt.Errorf("marshal index: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("%s", bundleStr)
|
fmt.Printf("%s", indexJson)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,6 +130,7 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx
|
||||||
dataDir: svcCfg.DataDir,
|
dataDir: svcCfg.DataDir,
|
||||||
}
|
}
|
||||||
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
|
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
|
||||||
|
instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Base modules
|
// Base modules
|
||||||
instance.base, err = base.New(instance)
|
instance.base, err = base.New(instance)
|
||||||
|
@ -651,24 +652,23 @@ func (i *Instance) shutdown(exitCode int) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel main context.
|
||||||
|
i.cancelCtx()
|
||||||
|
|
||||||
// Set given exit code.
|
// Set given exit code.
|
||||||
i.exitCode.Store(int32(exitCode))
|
i.exitCode.Store(int32(exitCode))
|
||||||
|
|
||||||
// Cancel contexts.
|
|
||||||
i.cancelCtx()
|
|
||||||
defer i.cancelShutdownCtx()
|
|
||||||
|
|
||||||
// Start shutdown asynchronously in a separate manager.
|
// Start shutdown asynchronously in a separate manager.
|
||||||
m := mgr.New("instance")
|
m := mgr.New("instance")
|
||||||
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
|
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
|
||||||
for {
|
// Stop all modules.
|
||||||
if err := i.Stop(); err != nil {
|
if err := i.Stop(); err != nil {
|
||||||
w.Error("failed to shutdown", "err", err, "retry", "1s")
|
w.Error("failed to shutdown", "err", err)
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel shutdown process context.
|
||||||
|
i.cancelShutdownCtx()
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warningf("updates/%s: failed to update index from %q: %s", d.u.cfg.Name, url, err)
|
log.Warningf("updates/%s: failed to update index from %q: %s", d.u.cfg.Name, url, err)
|
||||||
err = fmt.Errorf("update index file from %q: %s", url, err)
|
err = fmt.Errorf("update index file from %q: %w", url, err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("all index URLs failed, last error: %w", err)
|
return fmt.Errorf("all index URLs failed, last error: %w", err)
|
||||||
|
@ -65,7 +65,7 @@ func (d *Downloader) updateIndex(ctx context.Context) error {
|
||||||
|
|
||||||
// Write the index into a file.
|
// Write the index into a file.
|
||||||
indexFilepath := filepath.Join(d.u.cfg.DownloadDirectory, d.u.cfg.IndexFile)
|
indexFilepath := filepath.Join(d.u.cfg.DownloadDirectory, d.u.cfg.IndexFile)
|
||||||
err = os.WriteFile(indexFilepath, []byte(indexData), defaultFileMode)
|
err = os.WriteFile(indexFilepath, indexData, defaultFileMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write index file: %w", err)
|
return fmt.Errorf("write index file: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +111,7 @@ func (d *Downloader) gatherExistingFiles(dir string) error {
|
||||||
// Read full file.
|
// Read full file.
|
||||||
fileData, err := os.ReadFile(fullpath)
|
fileData, err := os.ReadFile(fullpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("updates/%s: failed to read file %q while searching for existing files: %w", d.u.cfg.Name, fullpath, err)
|
log.Debugf("updates/%s: failed to read file %q while searching for existing files: %s", d.u.cfg.Name, fullpath, err)
|
||||||
return fmt.Errorf("failed to read file %s: %w", fullpath, err)
|
return fmt.Errorf("failed to read file %s: %w", fullpath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +150,12 @@ artifacts:
|
||||||
if err == nil {
|
if err == nil {
|
||||||
continue artifacts
|
continue artifacts
|
||||||
}
|
}
|
||||||
log.Debugf("updates/%s: failed to copy existing file %s: %w", d.u.cfg.Name, artifact.Filename, err)
|
log.Debugf("updates/%s: failed to copy existing file %s: %s", d.u.cfg.Name, artifact.Filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the artifact has download URLs.
|
||||||
|
if len(artifact.URLs) == 0 {
|
||||||
|
return fmt.Errorf("artifact %s is missing download URLs", artifact.Filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to download the artifact from one of the URLs.
|
// Try to download the artifact from one of the URLs.
|
||||||
|
@ -163,7 +168,7 @@ artifacts:
|
||||||
// Valid artifact found!
|
// Valid artifact found!
|
||||||
break artifactURLs
|
break artifactURLs
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("update index file from %q: %s", url, err)
|
err = fmt.Errorf("update index file from %q: %w", url, err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("all artifact URLs for %s failed, last error: %w", artifact.Filename, err)
|
return fmt.Errorf("all artifact URLs for %s failed, last error: %w", artifact.Filename, err)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
semver "github.com/hashicorp/go-version"
|
semver "github.com/hashicorp/go-version"
|
||||||
|
|
||||||
"github.com/safing/jess"
|
"github.com/safing/jess"
|
||||||
"github.com/safing/jess/filesig"
|
"github.com/safing/jess/filesig"
|
||||||
)
|
)
|
||||||
|
@ -26,7 +27,7 @@ const currentPlatform = runtime.GOOS + "_" + runtime.GOARCH
|
||||||
|
|
||||||
var zeroVersion = semver.Must(semver.NewVersion("0.0.0"))
|
var zeroVersion = semver.Must(semver.NewVersion("0.0.0"))
|
||||||
|
|
||||||
// Artifacts represents a single file with metadata.
|
// Artifact represents a single file with metadata.
|
||||||
type Artifact struct {
|
type Artifact struct {
|
||||||
Filename string `json:"Filename"`
|
Filename string `json:"Filename"`
|
||||||
SHA256 string `json:"SHA256"`
|
SHA256 string `json:"SHA256"`
|
||||||
|
@ -85,7 +86,7 @@ func (a *Artifact) IsNewerThan(b *Artifact) (newer, ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact {
|
func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact {
|
||||||
copy := &Artifact{
|
copied := &Artifact{
|
||||||
Filename: a.Filename,
|
Filename: a.Filename,
|
||||||
SHA256: a.SHA256,
|
SHA256: a.SHA256,
|
||||||
URLs: a.URLs,
|
URLs: a.URLs,
|
||||||
|
@ -98,20 +99,20 @@ func (a *Artifact) export(dir string, indexVersion *semver.Version) *Artifact {
|
||||||
|
|
||||||
// Make sure we have a version number.
|
// Make sure we have a version number.
|
||||||
switch {
|
switch {
|
||||||
case copy.versionNum != nil:
|
case copied.versionNum != nil:
|
||||||
// Version already parsed.
|
// Version already parsed.
|
||||||
case copy.Version != "":
|
case copied.Version != "":
|
||||||
// Need to parse version.
|
// Need to parse version.
|
||||||
v, err := semver.NewVersion(copy.Version)
|
v, err := semver.NewVersion(copied.Version)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
copy.versionNum = v
|
copied.versionNum = v
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// No version defined, inherit index version.
|
// No version defined, inherit index version.
|
||||||
copy.versionNum = indexVersion
|
copied.versionNum = indexVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy
|
return copied
|
||||||
}
|
}
|
||||||
|
|
||||||
// Index represents a collection of artifacts with metadata.
|
// Index represents a collection of artifacts with metadata.
|
||||||
|
@ -146,8 +147,8 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse json.
|
// Parse json.
|
||||||
var index Index
|
index := &Index{}
|
||||||
err := json.Unmarshal([]byte(jsonContent), &index)
|
err := json.Unmarshal(jsonContent, index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse index: %w", err)
|
return nil, fmt.Errorf("parse index: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -158,7 +159,7 @@ func ParseIndex(jsonContent []byte, trustStore jess.TrustStore) (*Index, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (index *Index) init() error {
|
func (index *Index) init() error {
|
||||||
|
@ -219,7 +220,7 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error {
|
||||||
return fmt.Errorf("current index cannot do upgrades: %w", err)
|
return fmt.Errorf("current index cannot do upgrades: %w", err)
|
||||||
}
|
}
|
||||||
if err := newIndex.CanDoUpgrades(); err != nil {
|
if err := newIndex.CanDoUpgrades(); err != nil {
|
||||||
return fmt.Errorf("new index cannot do upgrade: %w")
|
return fmt.Errorf("new index cannot do upgrade: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
@ -229,13 +230,14 @@ func (index *Index) ShouldUpgradeTo(newIndex *Index) error {
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case index.Name != newIndex.Name:
|
case index.Name != newIndex.Name:
|
||||||
return errors.New("index names do not match")
|
return errors.New("new index name does not match")
|
||||||
|
|
||||||
case index.versionNum.GreaterThan(newIndex.versionNum):
|
|
||||||
return errors.New("current index has newer version")
|
|
||||||
|
|
||||||
case index.Published.After(newIndex.Published):
|
case index.Published.After(newIndex.Published):
|
||||||
return errors.New("current index was published later")
|
return errors.New("new index is older (time)")
|
||||||
|
|
||||||
|
case index.versionNum.Segments()[0] > newIndex.versionNum.Segments()[0]:
|
||||||
|
// Downgrades are allowed, if they are not breaking changes.
|
||||||
|
return errors.New("new index is a breaking change downgrade")
|
||||||
|
|
||||||
case index.Published.Equal(newIndex.Published):
|
case index.Published.Equal(newIndex.Published):
|
||||||
// "Do nothing".
|
// "Do nothing".
|
||||||
|
@ -252,7 +254,7 @@ func (index *Index) VerifyArtifacts(dir string) error {
|
||||||
for _, artifact := range index.Artifacts {
|
for _, artifact := range index.Artifacts {
|
||||||
err := checkSHA256SumFile(filepath.Join(dir, artifact.Filename), artifact.SHA256)
|
err := checkSHA256SumFile(filepath.Join(dir, artifact.Filename), artifact.SHA256)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("verify %s: %s", artifact.Filename, err)
|
return fmt.Errorf("verify %s: %w", artifact.Filename, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package updates
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
@ -95,7 +94,7 @@ settings:
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateIndexFromDir generates a index from a given folder.
|
// GenerateIndexFromDir generates a index from a given folder.
|
||||||
func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) {
|
func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error) { //nolint:maintidx
|
||||||
artifacts := make(map[string]Artifact)
|
artifacts := make(map[string]Artifact)
|
||||||
|
|
||||||
// Initialize.
|
// Initialize.
|
||||||
|
@ -107,6 +106,13 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid index dir: %w", err)
|
return nil, fmt.Errorf("invalid index dir: %w", err)
|
||||||
}
|
}
|
||||||
|
var indexVersion *semver.Version
|
||||||
|
if cfg.Version != "" {
|
||||||
|
indexVersion, err = semver.NewVersion(cfg.Version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid index version: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = filepath.WalkDir(sourceDir, func(fullpath string, d fs.DirEntry, err error) error {
|
err = filepath.WalkDir(sourceDir, func(fullpath string, d fs.DirEntry, err error) error {
|
||||||
// Fail on access error.
|
// Fail on access error.
|
||||||
|
@ -227,9 +233,10 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
|
||||||
|
|
||||||
// Create base index.
|
// Create base index.
|
||||||
index := &Index{
|
index := &Index{
|
||||||
Name: cfg.Name,
|
Name: cfg.Name,
|
||||||
Version: cfg.Version,
|
Version: cfg.Version,
|
||||||
Published: time.Now(),
|
Published: time.Now(),
|
||||||
|
versionNum: indexVersion,
|
||||||
}
|
}
|
||||||
if index.Version == "" && cfg.PrimaryArtifact != "" {
|
if index.Version == "" && cfg.PrimaryArtifact != "" {
|
||||||
pv, ok := artifacts[cfg.PrimaryArtifact]
|
pv, ok := artifacts[cfg.PrimaryArtifact]
|
||||||
|
@ -286,45 +293,6 @@ func GenerateIndexFromDir(sourceDir string, cfg IndexScanConfig) (*Index, error)
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func selectLatestArtifacts(artifacts []Artifact) ([]Artifact, error) {
|
|
||||||
artifactsMap := make(map[string]Artifact)
|
|
||||||
|
|
||||||
for _, a := range artifacts {
|
|
||||||
// Make the key platform specific since there can be same filename for multiple platforms.
|
|
||||||
key := a.Filename + a.Platform
|
|
||||||
aMap, ok := artifactsMap[key]
|
|
||||||
if !ok {
|
|
||||||
artifactsMap[key] = a
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if aMap.Version == "" || a.Version == "" {
|
|
||||||
return nil, fmt.Errorf("invalid mix version and non versioned files for: %s", a.Filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
mapVersion, err := semver.NewVersion(aMap.Version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid version for artifact: %s", aMap.Filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
artifactVersion, err := semver.NewVersion(a.Version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid version for artifact: %s", a.Filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
if mapVersion.LessThan(artifactVersion) {
|
|
||||||
artifactsMap[key] = a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
artifactsFiltered := make([]Artifact, 0, len(artifactsMap))
|
|
||||||
for _, a := range artifactsMap {
|
|
||||||
artifactsFiltered = append(artifactsFiltered, a)
|
|
||||||
}
|
|
||||||
|
|
||||||
return artifactsFiltered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSHA256(path string, unpackType string) (string, error) {
|
func getSHA256(path string, unpackType string) (string, error) {
|
||||||
content, err := os.ReadFile(path)
|
content, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -372,50 +340,3 @@ func getIdentifierAndVersion(versionedPath string) (identifier, version string,
|
||||||
// `dirPath + filename` is guaranteed by path.Split()
|
// `dirPath + filename` is guaranteed by path.Split()
|
||||||
return dirPath + filename, version, true
|
return dirPath + filename, version, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateMockFolder generates mock index folder for testing.
|
|
||||||
func GenerateMockFolder(dir, name, version string) error { // FIXME: move this to test?
|
|
||||||
// Make sure dir exists
|
|
||||||
_ = os.MkdirAll(dir, defaultDirMode)
|
|
||||||
|
|
||||||
// Create empty files
|
|
||||||
file, err := os.Create(filepath.Join(dir, "portmaster"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = file.Close()
|
|
||||||
file, err = os.Create(filepath.Join(dir, "portmaster-core"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = file.Close()
|
|
||||||
file, err = os.Create(filepath.Join(dir, "portmaster.zip"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = file.Close()
|
|
||||||
file, err = os.Create(filepath.Join(dir, "assets.zip"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = file.Close()
|
|
||||||
|
|
||||||
index, err := GenerateIndexFromDir(dir, IndexScanConfig{
|
|
||||||
Name: name,
|
|
||||||
Version: version,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
indexJson, err := json.MarshalIndent(index, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.WriteFile(filepath.Join(dir, "index.json"), indexJson, defaultFileMode)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -10,11 +10,12 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tevino/abool"
|
||||||
|
|
||||||
"github.com/safing/jess"
|
"github.com/safing/jess"
|
||||||
"github.com/safing/portmaster/base/log"
|
"github.com/safing/portmaster/base/log"
|
||||||
"github.com/safing/portmaster/base/notifications"
|
"github.com/safing/portmaster/base/notifications"
|
||||||
"github.com/safing/portmaster/service/mgr"
|
"github.com/safing/portmaster/service/mgr"
|
||||||
"github.com/tevino/abool"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -169,7 +170,7 @@ func New(instance instance, name string, cfg Config) (*Updater, error) {
|
||||||
|
|
||||||
// Fall back to scanning the directory.
|
// Fall back to scanning the directory.
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
log.Errorf("updates/%s: invalid index file, falling back to dir scan: %w", cfg.Name, err)
|
log.Errorf("updates/%s: invalid index file, falling back to dir scan: %s", cfg.Name, err)
|
||||||
}
|
}
|
||||||
index, err = GenerateIndexFromDir(cfg.Directory, IndexScanConfig{Version: "0.0.0"})
|
index, err = GenerateIndexFromDir(cfg.Directory, IndexScanConfig{Version: "0.0.0"})
|
||||||
if err == nil && index.init() == nil {
|
if err == nil && index.init() == nil {
|
||||||
|
@ -181,13 +182,12 @@ func New(instance instance, name string, cfg Config) (*Updater, error) {
|
||||||
return module, nil
|
return module, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) {
|
func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreVersion, forceApply bool) (err error) { //nolint:maintidx
|
||||||
// Make sure only one update process is running.
|
// Make sure only one update process is running.
|
||||||
if !u.isUpdateRunning.SetToIf(false, true) {
|
if !u.isUpdateRunning.SetToIf(false, true) {
|
||||||
return fmt.Errorf("an updater task is already running, please try again later")
|
return fmt.Errorf("an updater task is already running, please try again later")
|
||||||
}
|
}
|
||||||
defer u.isUpdateRunning.UnSet()
|
defer u.isUpdateRunning.UnSet()
|
||||||
// FIXME: Switch to mutex?
|
|
||||||
|
|
||||||
// Create a new downloader.
|
// Create a new downloader.
|
||||||
downloader := NewDownloader(u, indexURLs)
|
downloader := NewDownloader(u, indexURLs)
|
||||||
|
@ -201,7 +201,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, load index from download dir.
|
// Otherwise, load index from download dir.
|
||||||
downloader.index, err = LoadIndex(filepath.Join(u.cfg.Directory, u.cfg.IndexFile), u.cfg.Verify)
|
downloader.index, err = LoadIndex(filepath.Join(u.cfg.DownloadDirectory, u.cfg.IndexFile), u.cfg.Verify)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("load previously downloaded index file: %w", err)
|
return fmt.Errorf("load previously downloaded index file: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -215,23 +215,42 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
|
||||||
u.indexLock.Unlock()
|
u.indexLock.Unlock()
|
||||||
// Check with local pointer to index.
|
// Check with local pointer to index.
|
||||||
if err := index.ShouldUpgradeTo(downloader.index); err != nil {
|
if err := index.ShouldUpgradeTo(downloader.index); err != nil {
|
||||||
log.Infof("updates/%s: no new or eligible update: %s", u.cfg.Name, err)
|
if errors.Is(err, ErrSameIndex) {
|
||||||
if u.cfg.Notify && u.instance.Notifications() != nil {
|
log.Infof("updates/%s: no new update", u.cfg.Name)
|
||||||
u.instance.Notifications().Notify(¬ifications.Notification{
|
if u.cfg.Notify && u.instance.Notifications() != nil {
|
||||||
EventID: noNewUpdateNotificationID,
|
u.instance.Notifications().Notify(¬ifications.Notification{
|
||||||
Type: notifications.Info,
|
EventID: noNewUpdateNotificationID,
|
||||||
Title: "Portmaster Is Up-To-Date",
|
Type: notifications.Info,
|
||||||
Message: "Portmaster v" + index.Version + " is the newest version.",
|
Title: "Portmaster Is Up-To-Date",
|
||||||
Expires: time.Now().Add(1 * time.Minute).Unix(),
|
Message: "Portmaster v" + index.Version + " is the newest version.",
|
||||||
AvailableActions: []*notifications.Action{
|
Expires: time.Now().Add(1 * time.Minute).Unix(),
|
||||||
{
|
AvailableActions: []*notifications.Action{
|
||||||
ID: "ack",
|
{
|
||||||
Text: "OK",
|
ID: "ack",
|
||||||
|
Text: "OK",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
})
|
}
|
||||||
|
} else {
|
||||||
|
log.Warningf("updates/%s: cannot update: %s", u.cfg.Name, err)
|
||||||
|
if u.cfg.Notify && u.instance.Notifications() != nil {
|
||||||
|
u.instance.Notifications().Notify(¬ifications.Notification{
|
||||||
|
EventID: noNewUpdateNotificationID,
|
||||||
|
Type: notifications.Info,
|
||||||
|
Title: "Portmaster Is Up-To-Date*",
|
||||||
|
Message: "While Portmaster v" + index.Version + " is the newest version, there is an internal issue with checking for updates: " + err.Error(),
|
||||||
|
Expires: time.Now().Add(1 * time.Minute).Unix(),
|
||||||
|
AvailableActions: []*notifications.Action{
|
||||||
|
{
|
||||||
|
ID: "ack",
|
||||||
|
Text: "OK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return ErrNoUpdateAvailable
|
return fmt.Errorf("%w: %w", ErrNoUpdateAvailable, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,7 +339,10 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
|
||||||
// Install is complete!
|
// Install is complete!
|
||||||
|
|
||||||
// Clean up and notify modules of changed files.
|
// Clean up and notify modules of changed files.
|
||||||
u.cleanupAfterUpgrade()
|
err = u.cleanupAfterUpgrade()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("updates/%s: failed to clean up after upgrade: %s", u.cfg.Name, err)
|
||||||
|
}
|
||||||
u.EventResourcesUpdated.Submit(struct{}{})
|
u.EventResourcesUpdated.Submit(struct{}{})
|
||||||
|
|
||||||
// If no restart is needed, we are done.
|
// If no restart is needed, we are done.
|
||||||
|
@ -363,7 +385,7 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
|
||||||
Type: notifications.ActionTypeWebhook,
|
Type: notifications.ActionTypeWebhook,
|
||||||
Payload: notifications.ActionTypeWebhookPayload{
|
Payload: notifications.ActionTypeWebhookPayload{
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
URL: "updates/apply", // FIXME
|
URL: "core/restart",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -376,15 +398,35 @@ func (u *Updater) updateAndUpgrade(w *mgr.WorkerCtx, indexURLs []string, ignoreV
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) updateCheckWorker(w *mgr.WorkerCtx) error {
|
func (u *Updater) updateCheckWorker(w *mgr.WorkerCtx) error {
|
||||||
_ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false)
|
err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, false)
|
||||||
// FIXME: Handle errors.
|
switch {
|
||||||
return nil
|
case err == nil:
|
||||||
|
return nil // Success!
|
||||||
|
case errors.Is(err, ErrSameIndex):
|
||||||
|
return nil // Nothing to do.
|
||||||
|
case errors.Is(err, ErrNoUpdateAvailable):
|
||||||
|
return nil // Already logged.
|
||||||
|
case errors.Is(err, ErrActionRequired) && !u.cfg.Notify:
|
||||||
|
return fmt.Errorf("user action required, but notifying user is disabled: %w", err)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("udpating failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) upgradeWorker(w *mgr.WorkerCtx) error {
|
func (u *Updater) upgradeWorker(w *mgr.WorkerCtx) error {
|
||||||
_ = u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true)
|
err := u.updateAndUpgrade(w, u.cfg.IndexURLs, false, true)
|
||||||
// FIXME: Handle errors.
|
switch {
|
||||||
return nil
|
case err == nil:
|
||||||
|
return nil // Success!
|
||||||
|
case errors.Is(err, ErrSameIndex):
|
||||||
|
return nil // Nothing to do.
|
||||||
|
case errors.Is(err, ErrNoUpdateAvailable):
|
||||||
|
return nil // Already logged.
|
||||||
|
case errors.Is(err, ErrActionRequired) && !u.cfg.Notify:
|
||||||
|
return fmt.Errorf("user action required, but notifying user is disabled: %w", err)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("udpating failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForceUpdate executes a forced update and upgrade directly and synchronously
|
// ForceUpdate executes a forced update and upgrade directly and synchronously
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
package updates
|
package updates
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/safing/portmaster/base/notifications"
|
"github.com/safing/portmaster/base/notifications"
|
||||||
|
"github.com/safing/portmaster/service/mgr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testInstance struct{}
|
type testInstance struct{}
|
||||||
|
@ -24,66 +27,129 @@ func (i *testInstance) Ready() bool {
|
||||||
|
|
||||||
func (i *testInstance) SetCmdLineOperation(f func() error) {}
|
func (i *testInstance) SetCmdLineOperation(f func() error) {}
|
||||||
|
|
||||||
func TestPreformUpdate(t *testing.T) {
|
func TestPerformUpdate(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Initialize mock instance
|
// Initialize mock instance
|
||||||
stub := &testInstance{}
|
stub := &testInstance{}
|
||||||
|
|
||||||
// Make tmp dirs
|
// Make tmp dirs
|
||||||
installedDir, err := os.MkdirTemp("", "updates_current")
|
installedDir, err := os.MkdirTemp("", "updates_current_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() { _ = os.RemoveAll(installedDir) }()
|
defer func() { _ = os.RemoveAll(installedDir) }()
|
||||||
updateDir, err := os.MkdirTemp("", "updates_new")
|
updateDir, err := os.MkdirTemp("", "updates_new_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() { _ = os.RemoveAll(updateDir) }()
|
defer func() { _ = os.RemoveAll(updateDir) }()
|
||||||
purgeDir, err := os.MkdirTemp("", "updates_purge")
|
purgeDir, err := os.MkdirTemp("", "updates_purge_")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer func() { _ = os.RemoveAll(purgeDir) }()
|
defer func() { _ = os.RemoveAll(purgeDir) }()
|
||||||
|
|
||||||
// Generate mock files
|
// Generate mock files
|
||||||
if err := GenerateMockFolder(installedDir, "Test", "1.0.0"); err != nil {
|
now := time.Now()
|
||||||
panic(err)
|
if err := GenerateMockFolder(installedDir, "Test", "1.0.0", now); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if err := GenerateMockFolder(updateDir, "Test", "1.0.1"); err != nil {
|
if err := GenerateMockFolder(updateDir, "Test", "1.0.1", now.Add(1*time.Minute)); err != nil {
|
||||||
panic(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create updater
|
// Create updater (loads index).
|
||||||
updates, err := New(stub, "Test", Config{
|
updater, err := New(stub, "Test", Config{
|
||||||
|
Name: "Test",
|
||||||
Directory: installedDir,
|
Directory: installedDir,
|
||||||
DownloadDirectory: updateDir,
|
DownloadDirectory: updateDir,
|
||||||
PurgeDirectory: purgeDir,
|
PurgeDirectory: purgeDir,
|
||||||
IndexFile: "index.json",
|
IndexFile: "index.json",
|
||||||
AutoApply: false,
|
AutoDownload: true,
|
||||||
NeedsRestart: false,
|
AutoApply: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
t.Fatal(err)
|
||||||
}
|
|
||||||
// Read and parse the index file
|
|
||||||
if err := updates.downloader.Verify(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to apply the updates
|
// Try to apply the updates
|
||||||
err = updates.applyUpdates(nil)
|
m := mgr.New("updates test")
|
||||||
if err != nil {
|
_ = m.Do("test update and upgrade", func(w *mgr.WorkerCtx) error {
|
||||||
panic(err)
|
if err := updater.updateAndUpgrade(w, nil, false, false); err != nil {
|
||||||
}
|
if data, err := os.ReadFile(filepath.Join(installedDir, "index.json")); err == nil {
|
||||||
|
fmt.Println(string(data))
|
||||||
|
fmt.Println(updater.index.Version)
|
||||||
|
fmt.Println(updater.index.versionNum)
|
||||||
|
}
|
||||||
|
if data, err := os.ReadFile(filepath.Join(updateDir, "index.json")); err == nil {
|
||||||
|
fmt.Println(string(data))
|
||||||
|
idx, err := ParseIndex(data, nil)
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println(idx.Version)
|
||||||
|
fmt.Println(idx.versionNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CHeck if the current version is now the new.
|
t.Fatal(err)
|
||||||
bundle, err := LoadBundle(filepath.Join(installedDir, "index.json"))
|
}
|
||||||
if err != nil {
|
return nil
|
||||||
panic(err)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
if bundle.Version != "1.0.1" {
|
// Check if the current version is now the new.
|
||||||
panic(fmt.Errorf("expected version 1.0.1 found %s", bundle.Version))
|
newIndex, err := LoadIndex(filepath.Join(installedDir, "index.json"), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if newIndex.Version != "1.0.1" {
|
||||||
|
t.Fatalf("expected version 1.0.1 found %s", newIndex.Version)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateMockFolder generates mock index folder for testing.
|
||||||
|
func GenerateMockFolder(dir, name, version string, published time.Time) error {
|
||||||
|
// Make sure dir exists
|
||||||
|
_ = os.MkdirAll(dir, defaultDirMode)
|
||||||
|
|
||||||
|
// Create empty files
|
||||||
|
file, err := os.Create(filepath.Join(dir, "portmaster"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = file.Close()
|
||||||
|
file, err = os.Create(filepath.Join(dir, "portmaster-core"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = file.Close()
|
||||||
|
file, err = os.Create(filepath.Join(dir, "portmaster.zip"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = file.Close()
|
||||||
|
file, err = os.Create(filepath.Join(dir, "assets.zip"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = file.Close()
|
||||||
|
|
||||||
|
index, err := GenerateIndexFromDir(dir, IndexScanConfig{
|
||||||
|
Name: name,
|
||||||
|
Version: version,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
index.Published = published
|
||||||
|
|
||||||
|
indexJSON, err := json.MarshalIndent(index, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "failed to marshal index: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(filepath.Join(dir, "index.json"), indexJSON, defaultFileMode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package updates
|
package updates
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
|
@ -31,7 +32,7 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the upgrade.
|
// Execute the upgrade.
|
||||||
upgradeError := u.upgradeMoveFiles(downloader, ignoreVersion)
|
upgradeError := u.upgradeMoveFiles(downloader)
|
||||||
if upgradeError == nil {
|
if upgradeError == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -43,10 +44,10 @@ func (u *Updater) upgrade(downloader *Downloader, ignoreVersion bool) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recovery failed too.
|
// Recovery failed too.
|
||||||
return fmt.Errorf("upgrade (including recovery) failed: %s", u.cfg.Name, upgradeError)
|
return fmt.Errorf("upgrade (including recovery) failed: %w", upgradeError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) error {
|
func (u *Updater) upgradeMoveFiles(downloader *Downloader) error {
|
||||||
// Important:
|
// Important:
|
||||||
// We assume that the downloader has done its job and all artifacts are verified.
|
// We assume that the downloader has done its job and all artifacts are verified.
|
||||||
// Files will just be moved here.
|
// Files will just be moved here.
|
||||||
|
@ -65,20 +66,28 @@ func (u *Updater) upgradeMoveFiles(downloader *Downloader, ignoreVersion bool) e
|
||||||
}
|
}
|
||||||
files, err := os.ReadDir(u.cfg.Directory)
|
files, err := os.ReadDir(u.cfg.Directory)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read current directory: %w", err)
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
}
|
return fmt.Errorf("read current directory: %w", err)
|
||||||
for _, file := range files {
|
|
||||||
// Check if file is ignored.
|
|
||||||
if slices.Contains(u.cfg.Ignore, file.Name()) {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
err = os.MkdirAll(u.cfg.Directory, defaultDirMode)
|
||||||
// Otherwise, move file to purge dir.
|
|
||||||
src := filepath.Join(u.cfg.Directory, file.Name())
|
|
||||||
dst := filepath.Join(u.cfg.PurgeDirectory, file.Name())
|
|
||||||
err := u.moveFile(src, dst, "", file.Type().Perm())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err)
|
return fmt.Errorf("create current directory: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Move files.
|
||||||
|
for _, file := range files {
|
||||||
|
// Check if file is ignored.
|
||||||
|
if slices.Contains(u.cfg.Ignore, file.Name()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, move file to purge dir.
|
||||||
|
src := filepath.Join(u.cfg.Directory, file.Name())
|
||||||
|
dst := filepath.Join(u.cfg.PurgeDirectory, file.Name())
|
||||||
|
err := u.moveFile(src, dst, "", file.Type().Perm())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to move current file %s to purge dir: %w", file.Name(), err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +127,7 @@ func (u *Updater) moveFile(currentPath, newPath string, sha256sum string, fileMo
|
||||||
// Moving was successful, return.
|
// Moving was successful, return.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %w", u.cfg.Name, newPath, err)
|
log.Tracef("updates/%s: failed to move to %q, falling back to copy+delete: %s", u.cfg.Name, newPath, err)
|
||||||
|
|
||||||
// Copy and check the checksum while we are at it.
|
// Copy and check the checksum while we are at it.
|
||||||
err = copyAndCheckSHA256Sum(currentPath, newPath, sha256sum, fileMode)
|
err = copyAndCheckSHA256Sum(currentPath, newPath, sha256sum, fileMode)
|
||||||
|
@ -144,7 +153,7 @@ func (u *Updater) recoverFromFailedUpgrade() error {
|
||||||
err := u.moveFile(purgedFile, activeFile, "", file.Type().Perm())
|
err := u.moveFile(purgedFile, activeFile, "", file.Type().Perm())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Only warn and continue to recover as many files as possible.
|
// Only warn and continue to recover as many files as possible.
|
||||||
log.Warningf("updates/%s: failed to roll back file %s: %w", u.cfg.Name, file.Name(), err)
|
log.Warningf("updates/%s: failed to roll back file %s: %s", u.cfg.Name, file.Name(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,10 +34,17 @@ import (
|
||||||
|
|
||||||
// Instance is an instance of a Portmaster service.
|
// Instance is an instance of a Portmaster service.
|
||||||
type Instance struct {
|
type Instance struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancelCtx context.CancelFunc
|
cancelCtx context.CancelFunc
|
||||||
|
|
||||||
|
shutdownCtx context.Context
|
||||||
|
cancelShutdownCtx context.CancelFunc
|
||||||
|
|
||||||
serviceGroup *mgr.Group
|
serviceGroup *mgr.Group
|
||||||
|
|
||||||
|
binDir string
|
||||||
|
dataDir string
|
||||||
|
|
||||||
exitCode atomic.Int32
|
exitCode atomic.Int32
|
||||||
|
|
||||||
base *base.Base
|
base *base.Base
|
||||||
|
@ -67,6 +74,7 @@ type Instance struct {
|
||||||
terminal *terminal.TerminalModule
|
terminal *terminal.TerminalModule
|
||||||
|
|
||||||
CommandLineOperation func() error
|
CommandLineOperation func() error
|
||||||
|
ShouldRestart bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new Portmaster service instance.
|
// New returns a new Portmaster service instance.
|
||||||
|
@ -74,6 +82,7 @@ func New() (*Instance, error) {
|
||||||
// Create instance to pass it to modules.
|
// Create instance to pass it to modules.
|
||||||
instance := &Instance{}
|
instance := &Instance{}
|
||||||
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
|
instance.ctx, instance.cancelCtx = context.WithCancel(context.Background())
|
||||||
|
instance.shutdownCtx, instance.cancelShutdownCtx = context.WithCancel(context.Background())
|
||||||
|
|
||||||
binaryUpdateIndex := updates.Config{
|
binaryUpdateIndex := updates.Config{
|
||||||
// FIXME: fill
|
// FIXME: fill
|
||||||
|
@ -234,6 +243,18 @@ func (i *Instance) SetSleep(enabled bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BinDir returns the directory for binaries.
|
||||||
|
// This directory may be read-only.
|
||||||
|
func (i *Instance) BinDir() string {
|
||||||
|
return i.binDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataDir returns the directory for variable data.
|
||||||
|
// This directory is expected to be read/writeable.
|
||||||
|
func (i *Instance) DataDir() string {
|
||||||
|
return i.dataDir
|
||||||
|
}
|
||||||
|
|
||||||
// Database returns the database module.
|
// Database returns the database module.
|
||||||
func (i *Instance) Database() *dbmodule.DBModule {
|
func (i *Instance) Database() *dbmodule.DBModule {
|
||||||
return i.database
|
return i.database
|
||||||
|
@ -379,12 +400,6 @@ func (i *Instance) Ready() bool {
|
||||||
return i.serviceGroup.Ready()
|
return i.serviceGroup.Ready()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ctx returns the instance context.
|
|
||||||
// It is only canceled on shutdown.
|
|
||||||
func (i *Instance) Ctx() context.Context {
|
|
||||||
return i.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the instance.
|
// Start starts the instance.
|
||||||
func (i *Instance) Start() error {
|
func (i *Instance) Start() error {
|
||||||
return i.serviceGroup.Start()
|
return i.serviceGroup.Start()
|
||||||
|
@ -392,7 +407,6 @@ func (i *Instance) Start() error {
|
||||||
|
|
||||||
// Stop stops the instance and cancels the instance context when done.
|
// Stop stops the instance and cancels the instance context when done.
|
||||||
func (i *Instance) Stop() error {
|
func (i *Instance) Stop() error {
|
||||||
defer i.cancelCtx()
|
|
||||||
return i.serviceGroup.Stop()
|
return i.serviceGroup.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -406,6 +420,8 @@ func (i *Instance) Restart() {
|
||||||
i.core.EventRestart.Submit(struct{}{})
|
i.core.EventRestart.Submit(struct{}{})
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Set the restart flag and shutdown.
|
||||||
|
i.ShouldRestart = true
|
||||||
i.shutdown(RestartExitCode)
|
i.shutdown(RestartExitCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -419,32 +435,63 @@ func (i *Instance) Shutdown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) shutdown(exitCode int) {
|
func (i *Instance) shutdown(exitCode int) {
|
||||||
|
// Only shutdown once.
|
||||||
|
if i.IsShuttingDown() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel main context.
|
||||||
|
i.cancelCtx()
|
||||||
|
|
||||||
// Set given exit code.
|
// Set given exit code.
|
||||||
i.exitCode.Store(int32(exitCode))
|
i.exitCode.Store(int32(exitCode))
|
||||||
|
|
||||||
|
// Start shutdown asynchronously in a separate manager.
|
||||||
m := mgr.New("instance")
|
m := mgr.New("instance")
|
||||||
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
|
m.Go("shutdown", func(w *mgr.WorkerCtx) error {
|
||||||
for {
|
// Stop all modules.
|
||||||
if err := i.Stop(); err != nil {
|
if err := i.Stop(); err != nil {
|
||||||
w.Error("failed to shutdown", "err", err, "retry", "1s")
|
w.Error("failed to shutdown", "err", err)
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancel shutdown process context.
|
||||||
|
i.cancelShutdownCtx()
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stopping returns whether the instance is shutting down.
|
// Ctx returns the instance context.
|
||||||
func (i *Instance) Stopping() bool {
|
// It is canceled when shutdown is started.
|
||||||
return i.ctx.Err() == nil
|
func (i *Instance) Ctx() context.Context {
|
||||||
|
return i.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stopped returns a channel that is triggered when the instance has shut down.
|
// IsShuttingDown returns whether the instance is shutting down.
|
||||||
func (i *Instance) Stopped() <-chan struct{} {
|
func (i *Instance) IsShuttingDown() bool {
|
||||||
|
return i.ctx.Err() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShuttingDown returns a channel that is triggered when the instance starts shutting down.
|
||||||
|
func (i *Instance) ShuttingDown() <-chan struct{} {
|
||||||
return i.ctx.Done()
|
return i.ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShutdownCtx returns the instance shutdown context.
|
||||||
|
// It is canceled when shutdown is complete.
|
||||||
|
func (i *Instance) ShutdownCtx() context.Context {
|
||||||
|
return i.shutdownCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsShutDown returns whether the instance has stopped.
|
||||||
|
func (i *Instance) IsShutDown() bool {
|
||||||
|
return i.shutdownCtx.Err() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutDownComplete returns a channel that is triggered when the instance has shut down.
|
||||||
|
func (i *Instance) ShutdownComplete() <-chan struct{} {
|
||||||
|
return i.shutdownCtx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
// ExitCode returns the set exit code of the instance.
|
// ExitCode returns the set exit code of the instance.
|
||||||
func (i *Instance) ExitCode() int {
|
func (i *Instance) ExitCode() int {
|
||||||
return int(i.exitCode.Load())
|
return int(i.exitCode.Load())
|
||||||
|
|
Loading…
Add table
Reference in a new issue