Merge pull request #9 from safing/develop

Fix clean shutdown with portmaster-control
This commit is contained in:
Daniel 2019-07-05 09:26:58 +02:00 committed by GitHub
commit 9bdebf207b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 105 deletions

View file

@ -31,7 +31,7 @@ func init() {
func main() { func main() {
// Set Info // Set Info
info.Set("Portmaster", "0.3.0", "AGPLv3", true) info.Set("Portmaster", "0.3.1", "AGPLv3", true)
// Start // Start
err := modules.Start() err := modules.Start()

View file

@ -1,17 +1,18 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "time"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
) )
func getFile(identifier string) (*updates.File, error) { func getFile(opts *Options) (*updates.File, error) {
// get newest local file // get newest local file
updates.LoadLatest() updates.LoadLatest()
file, err := updates.GetPlatformFile(identifier)
file, err := updates.GetLocalPlatformFile(opts.Identifier)
if err == nil { if err == nil {
return file, nil return file, nil
} }
@ -19,28 +20,42 @@ func getFile(identifier string) (*updates.File, error) {
return nil, err return nil, err
} }
fmt.Printf("%s downloading %s...\n", logPrefix, identifier) // download
if opts.AllowDownload {
fmt.Printf("%s downloading %s...\n", logPrefix, opts.Identifier)
// if no matching file exists, load index // download indexes
err = updates.LoadIndexes() err = updates.UpdateIndexes()
if err != nil { if err != nil {
if os.IsNotExist(err) { return nil, err
// create dirs }
err = utils.EnsureDirectory(updateStoragePath, 0755)
if err != nil {
return nil, err
}
// download indexes // download file
err = updates.CheckForUpdates() file, err := updates.GetPlatformFile(opts.Identifier)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { return file, nil
}
// wait for 30 seconds
fmt.Printf("%s waiting for download of %s (by Portmaster Core) to complete...\n", logPrefix, opts.Identifier)
// try every 0.5 secs
for tries := 0; tries < 60; tries++ {
time.Sleep(500 * time.Millisecond)
// reload local files
updates.LoadLatest()
// get file
file, err := updates.GetLocalPlatformFile(opts.Identifier)
if err == nil {
return file, nil
}
if err != updates.ErrNotFound {
return nil, err return nil, err
} }
} }
return nil, errors.New("please try again later or check the Portmaster logs")
// get file
return updates.GetPlatformFile(identifier)
} }

View file

@ -5,7 +5,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"os/user"
"path/filepath" "path/filepath"
"runtime"
"strings"
"github.com/safing/portbase/info" "github.com/safing/portbase/info"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -55,7 +58,7 @@ func main() {
// }() // }()
// set meta info // set meta info
info.Set("Portmaster Control", "0.2.0", "AGPLv3", true) info.Set("Portmaster Control", "0.2.1", "AGPLv3", true)
// check if meta info is ok // check if meta info is ok
err := info.CheckVersion() err := info.CheckVersion()
@ -86,7 +89,23 @@ func initPmCtl(cmd *cobra.Command, args []string) error {
return errors.New("please supply the database directory using the --db flag") return errors.New("please supply the database directory using the --db flag")
} }
err := removeOldBin() // check if we are root/admin for self upgrade
userInfo, err := user.Current()
if err != nil {
return nil
}
switch runtime.GOOS {
case "linux":
if userInfo.Username != "root" {
return nil
}
case "windows":
if !strings.HasSuffix(userInfo.Username, "SYSTEM") { // is this correct?
return nil
}
}
err = removeOldBin()
if err != nil { if err != nil {
fmt.Printf("%s warning: failed to remove old upgrade: %s\n", logPrefix, err) fmt.Printf("%s warning: failed to remove old upgrade: %s\n", logPrefix, err)
} }

View file

@ -5,12 +5,20 @@ import (
"io" "io"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"runtime" "runtime"
"strings" "strings"
"syscall"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// Options for starting component
type Options struct {
Identifier string
AllowDownload bool
}
func init() { func init() {
rootCmd.AddCommand(runCmd) rootCmd.AddCommand(runCmd)
runCmd.AddCommand(runCore) runCmd.AddCommand(runCore)
@ -27,7 +35,10 @@ var runCore = &cobra.Command{
Use: "core", Use: "core",
Short: "Run the Portmaster Core", Short: "Run the Portmaster Core",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return run("core/portmaster-core", cmd, false) return run(cmd, &Options{
Identifier: "core/portmaster-core",
AllowDownload: true,
})
}, },
FParseErrWhitelist: cobra.FParseErrWhitelist{ FParseErrWhitelist: cobra.FParseErrWhitelist{
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
@ -39,7 +50,10 @@ var runApp = &cobra.Command{
Use: "app", Use: "app",
Short: "Run the Portmaster App", Short: "Run the Portmaster App",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return run("app/portmaster-app", cmd, true) return run(cmd, &Options{
Identifier: "app/portmaster-app",
AllowDownload: false,
})
}, },
FParseErrWhitelist: cobra.FParseErrWhitelist{ FParseErrWhitelist: cobra.FParseErrWhitelist{
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
@ -51,7 +65,10 @@ var runNotifier = &cobra.Command{
Use: "notifier", Use: "notifier",
Short: "Run the Portmaster Notifier", Short: "Run the Portmaster Notifier",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return run("notifier/portmaster-notifier", cmd, true) return run(cmd, &Options{
Identifier: "notifier/portmaster-notifier",
AllowDownload: false,
})
}, },
FParseErrWhitelist: cobra.FParseErrWhitelist{ FParseErrWhitelist: cobra.FParseErrWhitelist{
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
@ -59,68 +76,42 @@ var runNotifier = &cobra.Command{
}, },
} }
func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { func run(cmd *cobra.Command, opts *Options) error {
// get original arguments // get original arguments
if len(os.Args) <= 3 { var args []string
if len(os.Args) < 4 {
return cmd.Help() return cmd.Help()
} }
var args []string args = os.Args[3:]
// filter out database flag
if filterDatabaseFlag {
skip := false
for _, arg := range os.Args[3:] {
if skip {
skip = false
continue
}
if arg == "--db" {
// flag is seperated, skip two arguments
skip = true
continue
}
if strings.HasPrefix(arg, "--db=") {
// flag is one string, skip one argument
continue
}
args = append(args, arg)
}
} else {
args = os.Args[3:]
}
// adapt identifier // adapt identifier
if windows() { if windows() {
identifier += ".exe" opts.Identifier += ".exe"
} }
// run // run
for { for {
file, err := getFile(identifier) file, err := getFile(opts)
if err != nil { if err != nil {
return fmt.Errorf("%s could not get component: %s", logPrefix, err) return fmt.Errorf("could not get component: %s", err)
} }
// check permission // check permission
if !windows() { if !windows() {
info, err := os.Stat(file.Path()) info, err := os.Stat(file.Path())
if err != nil { if err != nil {
return fmt.Errorf("%s failed to get file info on %s: %s", logPrefix, file.Path(), err) return fmt.Errorf("failed to get file info on %s: %s", file.Path(), err)
} }
if info.Mode() != 0755 { if info.Mode() != 0755 {
err := os.Chmod(file.Path(), 0755) err := os.Chmod(file.Path(), 0755)
if err != nil { if err != nil {
return fmt.Errorf("%s failed to set exec permissions on %s: %s", logPrefix, file.Path(), err) return fmt.Errorf("failed to set exec permissions on %s: %s", file.Path(), err)
} }
} }
} }
fmt.Printf("%s starting %s %s\n", logPrefix, file.Path(), strings.Join(args, " ")) fmt.Printf("%s starting %s %s\n", logPrefix, file.Path(), strings.Join(args, " "))
// os.Exit(0)
// create command // create command
exc := exec.Command(file.Path(), args...) exc := exec.Command(file.Path(), args...)
@ -128,17 +119,17 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error {
// consume stdout/stderr // consume stdout/stderr
stdout, err := exc.StdoutPipe() stdout, err := exc.StdoutPipe()
if err != nil { if err != nil {
return fmt.Errorf("%s failed to connect stdout: %s", logPrefix, err) return fmt.Errorf("failed to connect stdout: %s", err)
} }
stderr, err := exc.StderrPipe() stderr, err := exc.StderrPipe()
if err != nil { if err != nil {
return fmt.Errorf("%s failed to connect stderr: %s", logPrefix, err) return fmt.Errorf("failed to connect stderr: %s", err)
} }
// start // start
err = exc.Start() err = exc.Start()
if err != nil { if err != nil {
return fmt.Errorf("%s failed to start %s: %s", logPrefix, identifier, err) return fmt.Errorf("failed to start %s: %s", opts.Identifier, err)
} }
// start output writers // start output writers
@ -149,6 +140,24 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error {
io.Copy(os.Stderr, stderr) io.Copy(os.Stderr, stderr)
}() }()
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signal.Notify(
signalCh,
os.Interrupt,
os.Kill,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
go func() {
for {
sig := <-signalCh
fmt.Printf("%s got %s signal (ignoring), waiting for %s to exit...\n", logPrefix, sig, opts.Identifier)
}
}()
// wait for completion // wait for completion
err = exc.Wait() err = exc.Wait()
if err != nil { if err != nil {
@ -157,31 +166,30 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error {
switch exErr.ProcessState.ExitCode() { switch exErr.ProcessState.ExitCode() {
case 0: case 0:
// clean exit // clean exit
fmt.Printf("%s clean exit of %s, but with error: %s\n", logPrefix, identifier, err) fmt.Printf("%s clean exit of %s, but with error: %s\n", logPrefix, opts.Identifier, err)
os.Exit(1) os.Exit(1)
case 1: case 1:
// error exit // error exit
fmt.Printf("%s error during execution of %s: %s\n", logPrefix, identifier, err) fmt.Printf("%s error during execution of %s: %s\n", logPrefix, opts.Identifier, err)
os.Exit(1) os.Exit(1)
case 2357427: // Leet Speak for "restart" case 2357427: // Leet Speak for "restart"
// restart request // restart request
fmt.Printf("%s restarting %s\n", logPrefix, identifier) fmt.Printf("%s restarting %s\n", logPrefix, opts.Identifier)
continue continue
default: default:
fmt.Printf("%s unexpected error during execution of %s: %s\n", logPrefix, identifier, err) fmt.Printf("%s unexpected error during execution of %s: %s\n", logPrefix, opts.Identifier, err)
os.Exit(exErr.ProcessState.ExitCode()) os.Exit(exErr.ProcessState.ExitCode())
} }
} else { } else {
fmt.Printf("%s unexpected error type during execution of %s: %s\n", logPrefix, identifier, err) fmt.Printf("%s unexpected error type during execution of %s: %s\n", logPrefix, opts.Identifier, err)
os.Exit(1) os.Exit(1)
} }
} }
// clean exit // clean exit
break break
} }
fmt.Printf("%s %s completed successfully\n", logPrefix, identifier) fmt.Printf("%s %s completed successfully\n", logPrefix, opts.Identifier)
return nil return nil
} }

View file

@ -11,6 +11,10 @@ import (
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
) )
var (
oldBinSuffix = "-old"
)
func checkForUpgrade() (update *updates.File) { func checkForUpgrade() (update *updates.File) {
info := info.GetInfo() info := info.GetInfo()
file, err := updates.GetLocalPlatformFile("control/portmaster-control") file, err := updates.GetLocalPlatformFile("control/portmaster-control")
@ -25,6 +29,8 @@ func checkForUpgrade() (update *updates.File) {
func doSelfUpgrade(file *updates.File) error { func doSelfUpgrade(file *updates.File) error {
// FIXME: fix permissions if needed
// get destination // get destination
dst, err := os.Executable() dst, err := os.Executable()
if err != nil { if err != nil {
@ -36,7 +42,7 @@ func doSelfUpgrade(file *updates.File) error {
} }
// mv destination // mv destination
err = os.Rename(dst, dst+"_old") err = os.Rename(dst, dst+oldBinSuffix)
if err != nil { if err != nil {
return err return err
} }
@ -105,7 +111,7 @@ func removeOldBin() error {
} }
// delete old // delete old
err = os.Remove(dst + "_old") err = os.Remove(dst + oldBinSuffix)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return err return err

View file

@ -11,14 +11,19 @@ import (
"time" "time"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
) )
func updater() { func updater() {
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
for { for {
err := CheckForUpdates() err := UpdateIndexes()
if err != nil { if err != nil {
log.Warningf("updates: failed to check for updates: %s", err) log.Warningf("updates: updating index failed: %s", err)
}
err = DownloadUpdates()
if err != nil {
log.Warningf("updates: downloading updates failed: %s", err)
} }
time.Sleep(1 * time.Hour) time.Sleep(1 * time.Hour)
} }
@ -40,10 +45,9 @@ func markPlatformFileForDownload(identifier string) {
markFileForDownload(identifier) markFileForDownload(identifier)
} }
// CheckForUpdates checks if updates are available and downloads updates of used components. // UpdateIndexes downloads the current update indexes.
func CheckForUpdates() (err error) { func UpdateIndexes() (err error) {
// download new indexes
// download new index
var data []byte var data []byte
for tries := 0; tries < 3; tries++ { for tries := 0; tries < 3; tries++ {
data, err = fetchData("stable.json", tries) data, err = fetchData("stable.json", tries)
@ -52,39 +56,72 @@ func CheckForUpdates() (err error) {
} }
} }
if err != nil { if err != nil {
return err return fmt.Errorf("failed to download: %s", err)
} }
newStableUpdates := make(map[string]string) newStableUpdates := make(map[string]string)
err = json.Unmarshal(data, &newStableUpdates) err = json.Unmarshal(data, &newStableUpdates)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to parse: %s", err)
} }
if len(newStableUpdates) == 0 { if len(newStableUpdates) == 0 {
return errors.New("stable.json is empty") return errors.New("index is empty")
} }
// update stable index
updatesLock.Lock()
stableUpdates = newStableUpdates
updatesLock.Unlock()
// check dir
err = utils.EnsureDirectory(updateStoragePath, 0755)
if err != nil {
return err
}
// save stable index
err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644)
if err != nil {
log.Warningf("updates: failed to save new version of stable.json: %s", err)
}
// update version status
updatesLock.RLock()
updateStatus(versionClassStable, stableUpdates)
updatesLock.RUnlock()
// FIXME IN STABLE: correct log line // FIXME IN STABLE: correct log line
log.Infof("updates: downloaded new update index: stable.json (alpha until we actually reach stable)") log.Infof("updates: updated index stable.json (alpha/beta until we actually reach stable)")
return nil
}
// DownloadUpdates checks if updates are available and downloads updates of used components.
func DownloadUpdates() (err error) {
// ensure important components are always updated // ensure important components are always updated
updatesLock.Lock() updatesLock.Lock()
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
markPlatformFileForDownload("core/portmaster-core.exe")
markPlatformFileForDownload("control/portmaster-control.exe") markPlatformFileForDownload("control/portmaster-control.exe")
markPlatformFileForDownload("app/portmaster-app.exe") markPlatformFileForDownload("app/portmaster-app.exe")
markPlatformFileForDownload("notifier/portmaster-notifier.exe") markPlatformFileForDownload("notifier/portmaster-notifier.exe")
} else { } else {
markPlatformFileForDownload("core/portmaster-core")
markPlatformFileForDownload("control/portmaster-control") markPlatformFileForDownload("control/portmaster-control")
markPlatformFileForDownload("app/portmaster-app") markPlatformFileForDownload("app/portmaster-app")
markPlatformFileForDownload("notifier/portmaster-notifier") markPlatformFileForDownload("notifier/portmaster-notifier")
} }
updatesLock.Unlock() updatesLock.Unlock()
// RLock for the remaining function
updatesLock.RLock()
defer updatesLock.RUnlock()
// update existing files // update existing files
log.Tracef("updates: updating existing files") log.Tracef("updates: updating existing files")
updatesLock.RLock() for identifier, newVersion := range stableUpdates {
for identifier, newVersion := range newStableUpdates {
oldVersion, ok := localUpdates[identifier] oldVersion, ok := localUpdates[identifier]
if ok && newVersion != oldVersion { if ok && newVersion != oldVersion {
@ -103,24 +140,7 @@ func CheckForUpdates() (err error) {
} }
} }
updatesLock.RUnlock()
log.Tracef("updates: finished updating existing files") log.Tracef("updates: finished updating existing files")
// update stable index
updatesLock.Lock()
stableUpdates = newStableUpdates
updatesLock.Unlock()
// save stable index
err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644)
if err != nil {
log.Warningf("updates: failed to save new version of stable.json: %s", err)
}
// update version status
updatesLock.RLock()
defer updatesLock.RUnlock()
updateStatus(versionClassStable, stableUpdates)
return nil return nil
} }