diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index d0da7baf..6a9045fa 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -102,10 +102,6 @@ func Stop() error { if err != nil { log.Warningf("winkext: failed to delete service: %s", err) } - err = service.closeHandle() - if err != nil { - log.Warningf("winkext: failed to close the handle: %s", err) - } kextHandle = winInvalidHandleValue return nil @@ -274,7 +270,7 @@ func GetVersion() (*VersionInfo, error) { } data := make([]uint8, 4) - _, err := deviceIOControl(kextHandle, IOCTL_VERSION, data, nil) + _, err := deviceIOControl(kextHandle, IOCTL_VERSION, nil, data) if err != nil { return nil, err diff --git a/firewall/interception/windowskext/service.go b/firewall/interception/windowskext/service.go index ed4429e8..44b49ba1 100644 --- a/firewall/interception/windowskext/service.go +++ b/firewall/interception/windowskext/service.go @@ -8,6 +8,7 @@ import ( "syscall" "time" + "github.com/safing/portbase/log" "golang.org/x/sys/windows" ) @@ -27,10 +28,19 @@ func createKextService(driverName string, driverPath string) (*KextService, erro if err != nil { return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err) } - // Check if it's already created + + // Check if there is an old service. service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS) if err == nil { - return &KextService{handle: service}, nil // service was already created + log.Warning("kext: old driver service was found") + oldService := &KextService{handle: service} + err := deleteService(manager, oldService, driverNameU16) + if err != nil { + return nil, fmt.Errorf("failed to delete old service: %s", err) + } + + service = winInvalidHandleValue + log.Info("kext: old driver service was deleted successful") } driverPathU16, err := syscall.UTF16FromString(driverPath) @@ -44,6 +54,36 @@ func createKextService(driverName string, driverPath string) (*KextService, erro return &KextService{handle: service}, nil } +func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error { + // Stop and wait before deleting + _ = service.stop(true) + + // Try to delete even if stop failed + err := service.delete() + if err != nil { + return fmt.Errorf("failed to delete old service: %s", err) + } + + // Wait until we can no longer open the old service. + // Not very efficient but NotifyServiceStatusChange cannot be used with driver service. + start := time.Now() + timeLimit := time.Duration(30 * time.Second) + for true { + handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS) + if err != nil { + break + } + _ = windows.CloseServiceHandle(handle) + + if time.Since(start) > timeLimit { + return fmt.Errorf("time limit reached") + } + + time.Sleep(100 * time.Millisecond) + } + return nil +} + func (s *KextService) isValid() bool { return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 } @@ -145,17 +185,13 @@ func (s *KextService) delete() error { if err != nil { return fmt.Errorf("failed to delete service: %s", err) } - return nil -} -func (s *KextService) closeHandle() error { - if !s.isValid() { - return fmt.Errorf("kext service not initialized") - } - - err := windows.CloseServiceHandle(s.handle) + // Service wont be deleted until all handles are closed. + err = windows.CloseServiceHandle(s.handle) if err != nil { return fmt.Errorf("failed to close service handle: %s", err) } + + s.handle = winInvalidHandleValue return nil }