diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 382869be..0441d3a6 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -10,16 +10,12 @@ import ( // start starts the interception. func start(ch chan packet.Packet) error { - dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll") - if err != nil { - return fmt.Errorf("interception: could not get kext dll: %s", err) - } kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys") if err != nil { return fmt.Errorf("interception: could not get kext sys: %s", err) } - err = windowskext.Init(dllFile.Path(), kextFile.Path()) + err = windowskext.Init(kextFile.Path()) if err != nil { return fmt.Errorf("interception: could not init windows kext: %s", err) } diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 623c49a3..a19cbfdc 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package windowskext @@ -10,6 +11,7 @@ import ( "github.com/tevino/abool" "github.com/safing/portbase/log" + "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/packet" ) @@ -43,6 +45,11 @@ type VerdictRequest struct { packetSize uint32 } +type VerdictInfo struct { + id uint32 // ID from RegisterPacket + verdict network.Verdict // verdict for the connection +} + // Handler transforms received packets to the Packet interface. func Handler(packets chan packet.Packet) { if !ready.IsSet() { diff --git a/firewall/interception/windowskext/kext.go b/firewall/interception/windowskext/kext.go index d438538e..ca8315e6 100644 --- a/firewall/interception/windowskext/kext.go +++ b/firewall/interception/windowskext/kext.go @@ -6,9 +6,9 @@ package windowskext import ( "errors" "fmt" + "os/exec" "sync" "sync/atomic" - "syscall" "time" "unsafe" @@ -25,90 +25,24 @@ var ( winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA) - kext *WinKext kextLock sync.RWMutex ready = abool.NewBool(false) urgentRequests *int32 + driverPath string + + kextHandle windows.Handle ) +const driverName = "PortmasterKext" + func init() { var urgentRequestsValue int32 urgentRequests = &urgentRequestsValue } -// WinKext holds the DLL handle. -type WinKext struct { - sync.RWMutex - - dll *windows.DLL - driverPath string - - init *windows.Proc - start *windows.Proc - stop *windows.Proc - recvVerdictRequest *windows.Proc - setVerdict *windows.Proc - getPayload *windows.Proc - clearCache *windows.Proc -} - // Init initializes the DLL and the Kext (Kernel Driver). -func Init(dllPath, driverPath string) error { - - new := &WinKext{ - driverPath: driverPath, - } - - var err error - - // load dll - new.dll, err = windows.LoadDLL(dllPath) - if err != nil { - return err - } - - // load functions - new.init, err = new.dll.FindProc("PortmasterInit") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) - } - new.start, err = new.dll.FindProc("PortmasterStart") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err) - } - new.stop, err = new.dll.FindProc("PortmasterStop") - if err != nil { - return fmt.Errorf("could not find proc PortmasterStop in dll: %s", err) - } - new.recvVerdictRequest, err = new.dll.FindProc("PortmasterRecvVerdictRequest") - if err != nil { - return fmt.Errorf("could not find proc PortmasterRecvVerdictRequest in dll: %s", err) - } - new.setVerdict, err = new.dll.FindProc("PortmasterSetVerdict") - if err != nil { - return fmt.Errorf("could not find proc PortmasterSetVerdict in dll: %s", err) - } - new.getPayload, err = new.dll.FindProc("PortmasterGetPayload") - if err != nil { - return fmt.Errorf("could not find proc PortmasterGetPayload in dll: %s", err) - } - new.clearCache, err = new.dll.FindProc("PortmasterClearCache") - if err != nil { - // the loaded dll is an old version - log.Errorf("could not find proc PortmasterClearCache (v1.0.12+) in dll: %s", err) - } - - // initialize dll/kext - rc, _, lastErr := new.init.Call() - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) - } - - // set kext - kextLock.Lock() - defer kextLock.Unlock() - kext = new - +func Init(path string) error { + driverPath = path return nil } @@ -117,16 +51,31 @@ func Start() error { kextLock.Lock() defer kextLock.Unlock() - // convert to C string - charArray := make([]byte, len(kext.driverPath)+1) - copy(charArray, []byte(kext.driverPath)) - charArray[len(charArray)-1] = 0 // force NULL byte at the end + filename := `\\.\` + driverName - rc, _, lastErr := kext.start.Call( - uintptr(unsafe.Pointer(&charArray[0])), - ) - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) + // check if driver is already installed + var err error + kextHandle, err = openDriver(filename) + if err == nil { + return nil // device was already initialized + } + + // initialize and start driver service + service, err := driverInstall(driverPath) + if err != nil { + return fmt.Errorf("Failed to start service: %s", err) + } + + // open the driver + kextHandle, err = openDriver(filename) + + // close the service handles + _ = windows.DeleteService(service) + _ = windows.CloseServiceHandle(service) + + // driver was not installed + if err != nil { + return fmt.Errorf("Failed to start the kext service: %s %q", err, filename) } ready.Set() @@ -142,9 +91,14 @@ func Stop() error { } ready.UnSet() - rc, _, lastErr := kext.stop.Call() - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) + err := closeDriver(kextHandle) + if err != nil { + log.Errorf("winkext: failed to close the handle: %s", err) + } + + _, err = exec.Command("sc", "stop", driverName).Output() // This is a question of taste, but it is a robust and solid solution + if err != nil { + log.Errorf("winkext: failed to stop the service: %q", err) } return nil } @@ -156,9 +110,6 @@ func RecvVerdictRequest() (*VerdictRequest, error) { if !ready.IsSet() { return nil, ErrKextNotReady } - - new := &VerdictRequest{} - // wait for urgent requests to complete for i := 1; i <= 100; i++ { if atomic.LoadInt32(urgentRequests) <= 0 { @@ -170,19 +121,22 @@ func RecvVerdictRequest() (*VerdictRequest, error) { time.Sleep(100 * time.Microsecond) } - // timestamp := time.Now() - rc, _, lastErr := kext.recvVerdictRequest.Call( - uintptr(unsafe.Pointer(new)), - ) - // log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp)) + timestamp := time.Now() + // Initialize struct for the output data + var new VerdictRequest - if rc != windows.NO_ERROR { - if rc == winErrInvalidData { - return nil, nil - } - return nil, formatErr(lastErr, rc) + // Make driver request + data := asByteArray(&new) + bytesRead, err := deviceIoControlRead(kextHandle, IOCTL_RECV_VERDICT_REQ, data) + if err != nil { + return nil, err } - return new, nil + if bytesRead == 0 { + return nil, nil // no error, no new verdict request + } + + log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp)) + return &new, nil } // SetVerdict sets the verdict for a packet and/or connection. @@ -199,17 +153,16 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error { return ErrKextNotReady } + verdictInfo := VerdictInfo{pkt.verdictRequest.id, verdict} + + // Make driver request atomic.AddInt32(urgentRequests, 1) - // timestamp := time.Now() - rc, _, lastErr := kext.setVerdict.Call( - uintptr(pkt.verdictRequest.id), - uintptr(verdict), - ) - // log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp)) + data := asByteArray(&verdictInfo) + _, err := deviceIoControlWrite(kextHandle, IOCTL_SET_VERDICT, data) atomic.AddInt32(urgentRequests, -1) - if rc != windows.NO_ERROR { + if err != nil { log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id) - return formatErr(lastErr, rc) + return err } return nil } @@ -220,6 +173,7 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { return nil, ErrNoPacketID } + // Check if driver is initialized kextLock.RLock() defer kextLock.RUnlock() if !ready.IsSet() { @@ -228,26 +182,30 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { buf := make([]byte, packetSize) + // Combine id and length + payload := struct { + id uint32 + length uint32 + }{packetID, packetSize} + + // Make driver request atomic.AddInt32(urgentRequests, 1) - // timestamp := time.Now() - rc, _, lastErr := kext.getPayload.Call( - uintptr(packetID), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&packetSize)), - ) - // log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp)) + data := asByteArray(&payload) + bytesRead, err := deviceIoControlReadWrite(kextHandle, IOCTL_GET_PAYLOAD, data, unsafe.Slice(&buf[0], packetSize)) + atomic.AddInt32(urgentRequests, -1) - if rc != windows.NO_ERROR { - return nil, formatErr(lastErr, rc) + if err != nil { + return nil, err } - if packetSize == 0 { + // check the result and return + if bytesRead == 0 { return nil, errors.New("windows kext did not return any data") } - if packetSize < uint32(len(buf)) { - return buf[:packetSize], nil + if bytesRead < uint32(len(buf)) { + return buf[:bytesRead], nil } return buf, nil @@ -256,28 +214,18 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) { func ClearCache() error { kextLock.RLock() defer kextLock.RUnlock() + + // Check if driver is initialized if !ready.IsSet() { log.Error("kext: failed to clear the cache: kext not ready") return ErrKextNotReady } - if kext.clearCache == nil { - log.Error("kext: cannot clear cache: clearCache function missing") - } - - rc, _, lastErr := kext.clearCache.Call() - - if rc != windows.NO_ERROR { - return formatErr(lastErr, rc) - } - - return nil -} - -func formatErr(err error, rc uintptr) error { - sysErr, ok := err.(syscall.Errno) - if ok { - return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc) - } + // Make driver request + _, err := deviceIoControlRead(kextHandle, IOCTL_CLEAR_CACHE, nil) return err } + +func asByteArray[T any](obj *T) []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj)) +} diff --git a/firewall/interception/windowskext/service.go b/firewall/interception/windowskext/service.go new file mode 100644 index 00000000..a2807a59 --- /dev/null +++ b/firewall/interception/windowskext/service.go @@ -0,0 +1,87 @@ +//go:build windows +// +build windows + +package windowskext + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/windows" +) + +func createService(manager windows.Handle, portmasterKextPath *uint16) (windows.Handle, error) { + u16filename, err := syscall.UTF16FromString(driverName) + if err != nil { + return 0, fmt.Errorf("Bad service: %s", err) + } + // Check if it's already created + service, err := windows.OpenService(manager, &u16filename[0], windows.SERVICE_ALL_ACCESS) + if err == nil { + return service, nil + } + + // Create the service + service, err = windows.CreateService(manager, &u16filename[0], &u16filename[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, portmasterKextPath, nil, nil, nil, nil, nil) + if err != nil { + return 0, err + } + + return service, nil +} + +func driverInstall(portmasterKextPath string) (windows.Handle, error) { + u16kextPath, _ := syscall.UTF16FromString(portmasterKextPath) + // Open the service manager: + manager, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_ALL_ACCESS) + if err != nil { + return 0, fmt.Errorf("Failed to open service manager: %d", err) + } + defer windows.CloseServiceHandle(manager) + + // Try to create the service. Retry if it fails. + var service windows.Handle +retryLoop: + for i := 0; i < 3; i++ { + service, err = createService(manager, &u16kextPath[0]) + if err == nil { + break retryLoop + } + } + + if err != nil { + return 0, fmt.Errorf("Failed to create service: %s", err) + } + + // Start the service: + err = windows.StartService(service, 0, nil) + + if err != nil { + err = windows.GetLastError() + if err != windows.ERROR_SERVICE_ALREADY_RUNNING { + // Failed to start service; clean-up: + var status windows.SERVICE_STATUS + _ = windows.ControlService(service, windows.SERVICE_CONTROL_STOP, &status) + _ = windows.DeleteService(service) + _ = windows.CloseServiceHandle(service) + service = 0 + } + } + + return service, nil +} + +func openDriver(filename string) (windows.Handle, error) { + u16filename, _ := syscall.UTF16FromString(filename) + + handle, err := windows.CreateFile(&u16filename[0], windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0) + if err != nil { + return 0, err + } + + return handle, nil +} + +func closeDriver(handle windows.Handle) error { + return windows.CloseHandle(handle) +} diff --git a/firewall/interception/windowskext/syscall.go b/firewall/interception/windowskext/syscall.go new file mode 100644 index 00000000..6ec11348 --- /dev/null +++ b/firewall/interception/windowskext/syscall.go @@ -0,0 +1,91 @@ +//go:build windows +// +build windows + +package windowskext + +import "golang.org/x/sys/windows" + +const ( + METHOD_BUFFERED = 0 + METHOD_IN_DIRECT = 1 + METHOD_OUT_DIRECT = 2 + METHOD_NEITHER = 3 + + SIOCTL_TYPE = 40000 +) + +var ( + IOCTL_HELLO = ctlCode(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_RECV_VERDICT_REQ_POLL = ctlCode(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_RECV_VERDICT_REQ = ctlCode(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_SET_VERDICT = ctlCode(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_GET_PAYLOAD = ctlCode(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) + IOCTL_CLEAR_CACHE = ctlCode(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA) +) + +func ctlCode(device_type, function, method, access uint32) uint32 { + return (device_type << 16) | (access << 14) | (function << 2) | method +} + +func deviceIoControlRead(handle windows.Handle, code uint32, data []byte) (uint32, error) { + var bytesReturned uint32 + + var dataPtr *byte = nil + var dataSize uint32 = 0 + if data != nil { + dataPtr = &data[0] + dataSize = uint32(len(data)) + } + + err := windows.DeviceIoControl(handle, + code, + nil, 0, + dataPtr, dataSize, + &bytesReturned, nil) + + return bytesReturned, err +} + +func deviceIoControlWrite(handle windows.Handle, code uint32, data []byte) (uint32, error) { + var bytesReturned uint32 + + var dataPtr *byte = nil + var dataSize uint32 = 0 + if data != nil { + dataPtr = &data[0] + dataSize = uint32(len(data)) + } + + err := windows.DeviceIoControl(handle, + code, + dataPtr, dataSize, + nil, 0, + &bytesReturned, nil) + + return bytesReturned, err +} + +func deviceIoControlReadWrite(handle windows.Handle, code uint32, inData []byte, outData []byte) (uint32, error) { + var bytesReturned uint32 + + var inDataPtr *byte = nil + var inDataSize uint32 = 0 + if inData != nil { + inDataPtr = &inData[0] + inDataSize = uint32(len(inData)) + } + + var outDataPtr *byte = nil + var outDataSize uint32 = 0 + if outData != nil { + outDataPtr = &outData[0] + outDataSize = uint32(len(outData)) + } + err := windows.DeviceIoControl(handle, + code, + inDataPtr, inDataSize, + outDataPtr, outDataSize, + &bytesReturned, nil) + + return bytesReturned, err +} diff --git a/updates/helper/updates.go b/updates/helper/updates.go index 17fe6116..80193ee8 100644 --- a/updates/helper/updates.go +++ b/updates/helper/updates.go @@ -52,7 +52,6 @@ func MandatoryUpdates() (identifiers []string) { identifiers = append( identifiers, PlatformIdentifier("core/portmaster-core.exe"), - PlatformIdentifier("kext/portmaster-kext.dll"), PlatformIdentifier("kext/portmaster-kext.sys"), PlatformIdentifier("kext/portmaster-kext.pdb"), PlatformIdentifier("start/portmaster-start.exe"),