diff --git a/Gopkg.lock b/Gopkg.lock index 74c0b71c..478ded07 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -211,7 +211,7 @@ [[projects]] branch = "master" - digest = "1:a931c816b1c08002eddb0ec1b920ed1987a0e873dad1f9f443e4905d70b59c66" + digest = "1:84945c0665ea5fc3ccbd067c35890a7d28e369131ac411b8a820b40115245c19" name = "golang.org/x/sys" packages = [ "cpu", @@ -219,7 +219,9 @@ "windows", "windows/registry", "windows/svc", + "windows/svc/debug", "windows/svc/eventlog", + "windows/svc/mgr", ] pruneopts = "UT" revision = "04f50cda93cbb67f2afa353c52f342100e80e625" @@ -251,7 +253,9 @@ "golang.org/x/net/ipv4", "golang.org/x/sys/windows", "golang.org/x/sys/windows/svc", + "golang.org/x/sys/windows/svc/debug", "golang.org/x/sys/windows/svc/eventlog", + "golang.org/x/sys/windows/svc/mgr", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/pmctl/install_windows.go b/pmctl/install_windows.go new file mode 100644 index 00000000..b5bca3bd --- /dev/null +++ b/pmctl/install_windows.go @@ -0,0 +1,186 @@ +package main + +// Based on the offical Go examples from +// https://github.com/golang/sys/blob/master/windows/svc/example +// by The Go Authors. +// Original LICENSE (sha256sum: 2d36597f7117c38b006835ae7f537487207d8ec407aa9d9980794b2030cbc067) can be found in vendor/pkg cache directory. + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/spf13/cobra" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" +) + +func init() { + rootCmd.AddCommand(installCmd) + installCmd.AddCommand(installService) + + rootCmd.AddCommand(uninstallCmd) + uninstallCmd.AddCommand(uninstallService) +} + +var installCmd = &cobra.Command{ + Use: "install", + Short: "Install system integrations", +} + +var uninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall system integrations", +} + +var installService = &cobra.Command{ + Use: "core-service", + Short: "Install Portmaster Core Windows Service", + RunE: installWindowsService, +} + +var uninstallService = &cobra.Command{ + Use: "core-service", + Short: "Uninstall Portmaster Core Windows Service", + RunE: uninstallWindowsService, +} + +func getExePath() (string, error) { + // get own filepath + prog := os.Args[0] + p, err := filepath.Abs(prog) + if err != nil { + return "", err + } + // check if the path is valid + fi, err := os.Stat(p) + if err == nil { + if !fi.Mode().IsDir() { + return p, nil + } + err = fmt.Errorf("%s is directory", p) + } + // check if we have a .exe extension, add and check if not + if filepath.Ext(p) == "" { + p += ".exe" + fi, err := os.Stat(p) + if err == nil { + if !fi.Mode().IsDir() { + return p, nil + } + err = fmt.Errorf("%s is directory", p) + } + } + return "", err +} + +func getServiceExecCommand(exePath string) string { + return fmt.Sprintf( + "%s run core-service --db %s --input-signals", + windows.EscapeArg(exePath), + windows.EscapeArg(*databaseRootDir), + ) +} + +func getServiceConfig(exePath string) mgr.Config { + return mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + BinaryPathName: getServiceExecCommand(exePath), + DisplayName: "Portmaster Core", + Description: "Portmaster Application Firewall - Core Service", + } +} + +func getRecoveryActions() (recoveryActions []mgr.RecoveryAction, resetPeriod uint32) { + return []mgr.RecoveryAction{ + //mgr.RecoveryAction{ + // Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + // Delay: 1 * time.Minute, // the time to wait before performing the specified action + //}, + // mgr.RecoveryAction{ + // Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + // Delay: 1 * time.Minute, // the time to wait before performing the specified action + // }, + mgr.RecoveryAction{ + Type: mgr.ServiceRestart, // one of NoAction, ComputerReboot, ServiceRestart or RunCommand + Delay: 1 * time.Minute, // the time to wait before performing the specified action + }, + }, 86400 +} + +func installWindowsService(cmd *cobra.Command, args []string) error { + // get exe path + exePath, err := getExePath() + if err != nil { + return fmt.Errorf("failed to get exe path: %s", err) + } + + // connect to Windows service manager + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %s", err) + } + defer m.Disconnect() + + // open service + created := false + s, err := m.OpenService(serviceName) + if err != nil { + // create service + s, err = m.CreateService(serviceName, getServiceExecCommand(exePath), getServiceConfig(exePath)) + if err != nil { + return fmt.Errorf("failed to create service: %s", err) + } + defer s.Close() + created = true + } else { + // update service + s.UpdateConfig(getServiceConfig(exePath)) + if err != nil { + return fmt.Errorf("failed to update service: %s", err) + } + defer s.Close() + } + + // update recovery actions + err = s.SetRecoveryActions(getRecoveryActions()) + if err != nil { + return fmt.Errorf("failed to update recovery actions: %s", err) + } + + if created { + fmt.Printf("%s created service %s\n", logPrefix, serviceName) + } else { + fmt.Printf("%s updated service %s\n", logPrefix, serviceName) + } + + return nil +} + +func uninstallWindowsService(cmd *cobra.Command, args []string) error { + // connect to Windows service manager + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + // open service + s, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer s.Close() + + // delete service + err = s.Delete() + if err != nil { + return fmt.Errorf("failed to delete service: %s", err) + } + + fmt.Printf("%s uninstalled service %s\n", logPrefix, serviceName) + return nil +} diff --git a/pmctl/run.go b/pmctl/run.go index 6b5a8844..dbafd019 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -8,6 +8,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "time" "github.com/safing/portbase/container" @@ -190,6 +191,15 @@ func execute(opts *Options, args []string) (cont bool, err error) { hideWindow(exc) } + // check if input signals are enabled + inputSignalsEnabled := false + for _, arg := range args { + if strings.HasSuffix(arg, "-input-signals") { + inputSignalsEnabled = true + break + } + } + // consume stdout/stderr stdout, err := exc.StdoutPipe() if err != nil { @@ -199,6 +209,13 @@ func execute(opts *Options, args []string) (cont bool, err error) { if err != nil { return true, fmt.Errorf("failed to connect stderr: %s", err) } + var stdin io.WriteCloser + if inputSignalsEnabled { + stdin, err = exc.StdinPipe() + if err != nil { + return true, fmt.Errorf("failed to connect stdin: %s", err) + } + } // start err = exc.Start() @@ -208,6 +225,8 @@ func execute(opts *Options, args []string) (cont bool, err error) { childIsRunning.Set() // start output writers + var wg sync.WaitGroup + wg.Add(2) go func() { var logFileError error if logFile == nil { @@ -218,6 +237,7 @@ func execute(opts *Options, args []string) (cont bool, err error) { if logFileError != nil { fmt.Printf("%s failed write logs: %s\n", logPrefix, logFileError) } + wg.Done() }() go func() { var errorFileError error @@ -229,10 +249,11 @@ func execute(opts *Options, args []string) (cont bool, err error) { if errorFileError != nil { fmt.Printf("%s failed write error logs: %s\n", logPrefix, errorFileError) } + wg.Done() }() // give some time to finish log file writing defer func() { - time.Sleep(100 * time.Millisecond) + wg.Wait() childIsRunning.UnSet() }() @@ -247,14 +268,25 @@ func execute(opts *Options, args []string) (cont bool, err error) { for { select { case <-shuttingDown: - err := exc.Process.Signal(os.Interrupt) + // signal process shutdown + if inputSignalsEnabled { + // for windows + _, err = stdin.Write([]byte("SIGINT\n")) + } else { + err = exc.Process.Signal(os.Interrupt) + } if err != nil { fmt.Printf("%s failed to signal %s to shutdown: %s\n", logPrefix, opts.Identifier, err) - fmt.Printf("%s forcing shutdown...\n", logPrefix) - // wait until shut down - <-finished - return false, nil + err = exc.Process.Kill() + if err != nil { + fmt.Printf("%s failed to kill %s: %s\n", logPrefix, opts.Identifier, err) + } else { + fmt.Printf("%s killed %s\n", logPrefix, opts.Identifier) + } } + // wait until shut down + <-finished + return false, nil case err := <-finished: if err != nil { exErr, ok := err.(*exec.ExitError) diff --git a/pmctl/service_default.go b/pmctl/service_default.go deleted file mode 100644 index 93b6a133..00000000 --- a/pmctl/service_default.go +++ /dev/null @@ -1,9 +0,0 @@ -// +build !windows - -package main - -import "github.com/spf13/cobra" - -func runService(cmd *cobra.Command, opts *Options) { - run(cmd, opts) -} diff --git a/pmctl/service_windows.go b/pmctl/service_windows.go index 6deaaca6..36b6c7d8 100644 --- a/pmctl/service_windows.go +++ b/pmctl/service_windows.go @@ -1,11 +1,17 @@ package main +// Based on the offical Go examples from +// https://github.com/golang/sys/blob/master/windows/svc/example +// by The Go Authors. +// Original LICENSE (sha256sum: 2d36597f7117c38b006835ae7f537487207d8ec407aa9d9980794b2030cbc067) can be found in vendor/pkg cache directory. + import ( "fmt" "time" "github.com/spf13/cobra" "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" "golang.org/x/sys/windows/svc/eventlog" ) @@ -29,6 +35,9 @@ var ( // helpers for execution runError chan error runWrapper func() error + + // eventlog + eventlogger *eventlog.Log ) func init() { @@ -50,10 +59,10 @@ func (ws *windowsService) Execute(args []string, changeRequests <-chan svc.Chang }() // poll for start completion - var started chan struct{} + started := make(chan struct{}) go func() { for { - time.Sleep(100 * time.Millisecond) + time.Sleep(10 * time.Millisecond) if childIsRunning.IsSet() { close(started) return @@ -66,39 +75,41 @@ func (ws *windowsService) Execute(args []string, changeRequests <-chan svc.Chang case err := <-runError: // TODO: log error to windows fmt.Printf("%s start error: %s", logPrefix, err) + eventlogger.Error(4, fmt.Sprintf("failed to start Portmaster Core: %s", err)) + changes <- svc.Status{State: svc.Stopped} return false, 1 case <-started: // give some more time for enabling packet interception time.Sleep(500 * time.Millisecond) changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} - fmt.Printf("%s startup complete, entered service running state", logPrefix) + fmt.Printf("%s startup complete, entered service running state\n", logPrefix) } // wait for change requests +serviceLoop: for { select { case <-shuttingDown: - // signal that we are shutting down - changes <- svc.Status{State: svc.StopPending} - // wait for program to exit - <-programEnded - return + break serviceLoop case c := <-changeRequests: switch c.Cmd { case svc.Interrogate: changes <- c.CurrentStatus case svc.Stop, svc.Shutdown: - changes <- svc.Status{State: svc.StopPending} initiateShutdown() - // wait for program to exit - <-programEnded - return + break serviceLoop default: - fmt.Printf("%s unexpected control request: #%d", logPrefix, c) + fmt.Printf("%s unexpected control request: #%d\n", logPrefix, c) } } } + + // signal that we are shutting down changes <- svc.Status{State: svc.StopPending} + // wait for program to exit + <-programEnded + // signal shutdown complete + changes <- svc.Status{State: svc.Stopped} return } @@ -108,6 +119,12 @@ func runService(cmd *cobra.Command, opts *Options) error { return run(cmd, opts) } + // check if we are running interactively + isDebug, err := svc.IsAnInteractiveSession() + if err != nil { + return fmt.Errorf("could not determine if running interactively: %s", err) + } + // open eventlog // TODO: do something useful with eventlog elog, err := eventlog.Open(serviceName) @@ -115,9 +132,17 @@ func runService(cmd *cobra.Command, opts *Options) error { return fmt.Errorf("failed to open eventlog: %s", err) } defer elog.Close() + eventlogger = elog elog.Info(1, fmt.Sprintf("starting %s service", serviceName)) - err = svc.Run(serviceName, &windowsService{}) + // select run method bas + run := svc.Run + if isDebug { + fmt.Printf("%s WARNING: running interactively, switching to debug execution (no real service).\n", logPrefix) + run = debug.Run + } + // run + err = run(serviceName, &windowsService{}) if err != nil { elog.Error(3, fmt.Sprintf("%s service failed: %v", serviceName, err)) return fmt.Errorf("failed to start service: %s", err)