From e308543f4f77dbc186237c28d80823a2ee11835a Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 22 Jan 2024 01:15:56 +0200 Subject: [PATCH] Update kext library --- firewall/interception/interception_windows.go | 2 +- .../windowskext2/bandwidth_stats.go | 132 ------------------ firewall/interception/windowskext2/handler.go | 119 +++------------- firewall/interception/windowskext2/kext.go | 67 +++++---- 4 files changed, 56 insertions(+), 264 deletions(-) delete mode 100644 firewall/interception/windowskext2/bandwidth_stats.go diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index 8deeecdb..d46cbeb8 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -41,7 +41,7 @@ func startInterception(packets chan packet.Packet) error { // Start kext logging. The worker will periodically send request to the kext to send logs. module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error { - timer := time.NewTimer(time.Second) + timer := time.NewTicker(1 * time.Second) for { select { case <-timer.C: diff --git a/firewall/interception/windowskext2/bandwidth_stats.go b/firewall/interception/windowskext2/bandwidth_stats.go deleted file mode 100644 index 2a1bddc0..00000000 --- a/firewall/interception/windowskext2/bandwidth_stats.go +++ /dev/null @@ -1,132 +0,0 @@ -//go:build windows -// +build windows - -package windowskext - -// This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production. - -import ( - "context" - "time" - - "github.com/safing/portbase/log" - "github.com/safing/portmaster/network/packet" -) - -type Rxtxdata struct { - rx uint64 - tx uint64 -} - -type Key struct { - localIP [4]uint32 - remoteIP [4]uint32 - localPort uint16 - remotePort uint16 - ipv6 bool - protocol uint8 -} - -var m = make(map[Key]Rxtxdata) - -func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error { - // Setup ticker. - ticker := time.NewTicker(collectInterval) - defer ticker.Stop() - - // Collect bandwidth at every tick. - for { - select { - case <-ticker.C: - err := reportBandwidth(ctx, bandwidthUpdates) - if err != nil { - return err - } - case <-ctx.Done(): - return nil - } - } -} - -func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error { - stats, err := GetConnectionsStats() - if err != nil { - return err - } - - // Report all statistics. - for i, stat := range stats { - connID := packet.CreateConnectionID( - packet.IPProtocol(stat.protocol), - convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort, - convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort, - false, - ) - update := &packet.BandwidthUpdate{ - ConnID: connID, - BytesReceived: stat.receivedBytes, - BytesSent: stat.transmittedBytes, - Method: packet.Additive, - } - select { - case bandwidthUpdates <- update: - case <-ctx.Done(): - return nil - default: - log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i) - return nil - } - } - - return nil -} - -func StartBandwithConsoleLogger() { - go func() { - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for range ticker.C { - conns, err := GetConnectionsStats() - if err != nil { - continue - } - for _, conn := range conns { - if conn.receivedBytes == 0 && conn.transmittedBytes == 0 { - continue - } - key := Key{ - localIP: conn.localIP, - remoteIP: conn.remoteIP, - localPort: conn.localPort, - remotePort: conn.remotePort, - ipv6: conn.ipV6 == 1, - protocol: conn.protocol, - } - - // First we get a "copy" of the entry - if entry, ok := m[key]; ok { - // Then we modify the copy - entry.rx += conn.receivedBytes - entry.tx += conn.transmittedBytes - - // Then we reassign map entry - m[key] = entry - } else { - m[key] = Rxtxdata{ - rx: conn.receivedBytes, - tx: conn.transmittedBytes, - } - } - } - log.Debug("----------------------------------") - for key, value := range m { - log.Debugf( - "Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol, - convertArrayToIP(key.localIP, key.ipv6), key.localPort, - convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort, - value.rx, value.tx, - ) - } - } - }() -} diff --git a/firewall/interception/windowskext2/handler.go b/firewall/interception/windowskext2/handler.go index c6dff10d..ba5169fd 100644 --- a/firewall/interception/windowskext2/handler.go +++ b/firewall/interception/windowskext2/handler.go @@ -5,11 +5,8 @@ package windowskext import ( "context" - "encoding/binary" "fmt" - "net" "time" - "unsafe" "github.com/safing/portmaster/process" @@ -19,34 +16,6 @@ import ( "github.com/safing/portmaster/network/packet" ) -const ( - // VerdictRequestFlagFastTrackPermitted is set on packets that have been - // already permitted by the kernel extension and the verdict request is only - // informational. - VerdictRequestFlagFastTrackPermitted = 1 - - // VerdictRequestFlagSocketAuth indicates that the verdict request is for a - // connection that was intercepted on an ALE layer instead of in the network - // stack itself. Thus, no packet data is available. - VerdictRequestFlagSocketAuth = 2 - - // VerdictRequestFlagExpectSocketAuth indicates that the next verdict - // requests is expected to be an informational socket auth request from - // the ALE layer. - VerdictRequestFlagExpectSocketAuth = 4 -) - -type ConnectionStat struct { - localIP [4]uint32 //Source Address, only srcIP[0] if IPv4 - remoteIP [4]uint32 //Destination Address - localPort uint16 //Source Port - remotePort uint16 //Destination port - receivedBytes uint64 //Number of bytes recived on this connection - transmittedBytes uint64 //Number of bytes transsmited from this connection - ipV6 uint8 //True: IPv6, False: IPv4 - protocol uint8 //Protocol (UDP, TCP, ...) -} - type VersionInfo struct { major uint8 minor uint8 @@ -79,7 +48,7 @@ func Handler(ctx context.Context, packets chan packet.Packet) { info.Inbound = conn.Direction > 0 info.InTunnel = false info.Protocol = packet.IPProtocol(conn.Protocol) - info.PID = int(*conn.ProcessId) + info.PID = int(conn.ProcessId) info.SeenAt = time.Now() // Check PID @@ -90,21 +59,17 @@ func Handler(ctx context.Context, packets chan packet.Packet) { } // Set IP version - if conn.IpV6 { - info.Version = packet.IPv6 - } else { - info.Version = packet.IPv4 - } + info.Version = packet.IPv4 // Set IPs if info.Inbound { // Inbound - info.Src = net.IP(conn.RemoteIp) - info.Dst = net.IP(conn.LocalIp) + info.Src = conn.RemoteIp[:] + info.Dst = conn.LocalIp[:] } else { // Outbound - info.Src = net.IP(conn.LocalIp) - info.Dst = net.IP(conn.RemoteIp) + info.Src = conn.LocalIp[:] + info.Dst = conn.RemoteIp[:] } // Set Ports @@ -121,61 +86,21 @@ func Handler(ctx context.Context, packets chan packet.Packet) { packets <- new } - if packetInfo.LogLines != nil { - for _, line := range *packetInfo.LogLines { - switch line.Severity { - case int(log.DebugLevel): - log.Debugf("kext: %s", line.Line) - case int(log.InfoLevel): - log.Infof("kext: %s", line.Line) - case int(log.WarningLevel): - log.Warningf("kext: %s", line.Line) - case int(log.ErrorLevel): - log.Errorf("kext: %s", line.Line) - case int(log.CriticalLevel): - log.Criticalf("kext: %s", line.Line) - } - } - } + // if packetInfo.LogLines != nil { + // for _, line := range *packetInfo.LogLines { + // switch line.Severity { + // case int(log.DebugLevel): + // log.Debugf("kext: %s", line.Line) + // case int(log.InfoLevel): + // log.Infof("kext: %s", line.Line) + // case int(log.WarningLevel): + // log.Warningf("kext: %s", line.Line) + // case int(log.ErrorLevel): + // log.Errorf("kext: %s", line.Line) + // case int(log.CriticalLevel): + // log.Criticalf("kext: %s", line.Line) + // } + // } + // } } } - -// convertArrayToIP converts an array of uint32 values to a net.IP address. -func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP { - if !ipv6 { - addressBuf := make([]byte, 4) - binary.BigEndian.PutUint32(addressBuf, input[0]) - return net.IP(addressBuf) - } - - addressBuf := make([]byte, 16) - for i := 0; i < 4; i++ { - binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) - } - return net.IP(addressBuf) -} - -func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 { - array := [4]uint32{0} - if isIPv6 { - for i := 0; i < 4; i++ { - binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i])) - } - } else { - binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0])) - } - - return array -} - -func asByteArray[T any](obj *T) []byte { - return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj)) -} - -func asByteArrayWithLength[T any](obj *T, size uint32) []byte { - return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size) -} - -func getUInt32Value[T any](obj *T) uint32 { - return *(*uint32)(unsafe.Pointer(obj)) -} diff --git a/firewall/interception/windowskext2/kext.go b/firewall/interception/windowskext2/kext.go index f93605b5..ca90b991 100644 --- a/firewall/interception/windowskext2/kext.go +++ b/firewall/interception/windowskext2/kext.go @@ -4,9 +4,7 @@ package windowskext import ( - "errors" "fmt" - "unsafe" "github.com/safing/portbase/log" "github.com/safing/portmaster/network" @@ -15,9 +13,6 @@ import ( // Package errors var ( - ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands") - ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension") - driverPath string service *kext_interface.KextService @@ -28,7 +23,6 @@ const ( driverName = "PortmasterKext" ) -// Init initializes the DLL and the Kext (Kernel Driver). func Init(path string) error { driverPath = path return nil @@ -63,20 +57,32 @@ func Stop() error { log.Warningf("winkext: shutdown request failed: %s", err) } // Close the interface to the driver. Driver will continue to run. - kextFile.Close() + err = kextFile.Close() + if err != nil { + log.Warningf("winkext: failed to close kext file: %s", err) + } // Stop and delete the driver. - service.Stop(true) - service.Delete() + err = service.Stop(true) + if err != nil { + log.Warningf("winkext: failed to stop kernel service: %s", err) + } + + err = service.Delete() + if err != nil { + log.Warningf("winkext: failed to delete kernel service: %s", err) + } return nil } +// Sends a shutdown request. func shutdownRequest() error { - return kext_interface.WriteCommand(kextFile, kext_interface.BuildShutdown()) + return kext_interface.WriteShutdownCommand(kextFile) } +// Send request for logs of the kext. func SendLogRequest() error { - return kext_interface.WriteCommand(kextFile, kext_interface.BuildGetLogs()) + return kext_interface.WriteGetLogsCommand(kextFile) } // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. @@ -87,53 +93,52 @@ func RecvVerdictRequest() (*kext_interface.Info, error) { // SetVerdict sets the verdict for a packet and/or connection. func SetVerdict(pkt *Packet, verdict network.Verdict) error { if verdict == network.VerdictRerouteToNameserver { - redirect := kext_interface.Redirect{Id: pkt.verdictRequest, RemoteAddress: []uint8{127, 0, 0, 1}, RemotePort: 53} - command := kext_interface.BuildRedirect(redirect) - kext_interface.WriteCommand(kextFile, command) + redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{127, 0, 0, 1}, RemotePort: 53} + kext_interface.WriteRedirectCommand(kextFile, redirect) } else if verdict == network.VerdictRerouteToTunnel { - redirect := kext_interface.Redirect{Id: pkt.verdictRequest, RemoteAddress: []uint8{192, 168, 122, 196}, RemotePort: 717} - command := kext_interface.BuildRedirect(redirect) - kext_interface.WriteCommand(kextFile, command) + redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{192, 168, 122, 196}, RemotePort: 717} + kext_interface.WriteRedirectCommand(kextFile, redirect) } else { verdict := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} - command := kext_interface.BuildVerdict(verdict) - kext_interface.WriteCommand(kextFile, command) + kext_interface.WriteVerdictCommand(kextFile, verdict) } return nil } +// Clears the internal connection cache. func ClearCache() error { - return kext_interface.WriteCommand(kextFile, kext_interface.BuildClearCache()) + return kext_interface.WriteClearCacheCommand(kextFile) } +// Updates a specific connection verdict. func UpdateVerdict(conn *network.Connection) error { - redirectAddress := []uint8{} + redirectAddress := [4]byte{} redirectPort := 0 if conn.Verdict.Active == network.VerdictRerouteToNameserver { - redirectAddress = []uint8{127, 0, 0, 1} + redirectAddress = [4]byte{127, 0, 0, 1} redirectPort = 53 } if conn.Verdict.Active == network.VerdictRerouteToTunnel { - redirectAddress = []uint8{192, 168, 122, 196} + redirectAddress = [4]byte{192, 168, 122, 196} redirectPort = 717 } - update := kext_interface.Update{ + update := kext_interface.UpdateV4{ Protocol: conn.Entity.Protocol, - LocalAddress: conn.LocalIP, + LocalAddress: [4]byte(conn.LocalIP), LocalPort: conn.LocalPort, - RemoteAddress: conn.Entity.IP, + RemoteAddress: [4]byte(conn.Entity.IP), RemotePort: conn.Entity.Port, Verdict: uint8(conn.Verdict.Active), RedirectAddress: redirectAddress, RedirectPort: uint16(redirectPort), } - command := kext_interface.BuildUpdate(update) - kext_interface.WriteCommand(kextFile, command) + kext_interface.WriteUpdateCommand(kextFile, update) return nil } +// Returns the kext version. func GetVersion() (*VersionInfo, error) { data, err := kext_interface.ReadVersion(kextFile) if err != nil { @@ -148,9 +153,3 @@ func GetVersion() (*VersionInfo, error) { } return version, nil } - -var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{})) - -func GetConnectionsStats() ([]ConnectionStat, error) { - return nil, nil -}