Update kext library

This commit is contained in:
Vladimir Stoilov 2024-01-22 01:15:56 +02:00
parent 1f2f0e5213
commit e308543f4f
4 changed files with 56 additions and 264 deletions

View file

@ -41,7 +41,7 @@ func startInterception(packets chan packet.Packet) error {
// Start kext logging. The worker will periodically send request to the kext to send logs. // Start kext logging. The worker will periodically send request to the kext to send logs.
module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error { module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error {
timer := time.NewTimer(time.Second) timer := time.NewTicker(1 * time.Second)
for { for {
select { select {
case <-timer.C: case <-timer.C:

View file

@ -1,132 +0,0 @@
//go:build windows
// +build windows
package windowskext
// This file contains example code how to read bandwidth stats from the kext. Its not ment to be used in production.
import (
"context"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
)
type Rxtxdata struct {
rx uint64
tx uint64
}
type Key struct {
localIP [4]uint32
remoteIP [4]uint32
localPort uint16
remotePort uint16
ipv6 bool
protocol uint8
}
var m = make(map[Key]Rxtxdata)
func BandwidthStatsWorker(ctx context.Context, collectInterval time.Duration, bandwidthUpdates chan *packet.BandwidthUpdate) error {
// Setup ticker.
ticker := time.NewTicker(collectInterval)
defer ticker.Stop()
// Collect bandwidth at every tick.
for {
select {
case <-ticker.C:
err := reportBandwidth(ctx, bandwidthUpdates)
if err != nil {
return err
}
case <-ctx.Done():
return nil
}
}
}
func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.BandwidthUpdate) error {
stats, err := GetConnectionsStats()
if err != nil {
return err
}
// Report all statistics.
for i, stat := range stats {
connID := packet.CreateConnectionID(
packet.IPProtocol(stat.protocol),
convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort,
convertArrayToIP(stat.remoteIP, stat.ipV6 == 1), stat.remotePort,
false,
)
update := &packet.BandwidthUpdate{
ConnID: connID,
BytesReceived: stat.receivedBytes,
BytesSent: stat.transmittedBytes,
Method: packet.Additive,
}
select {
case bandwidthUpdates <- update:
case <-ctx.Done():
return nil
default:
log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i)
return nil
}
}
return nil
}
func StartBandwithConsoleLogger() {
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for range ticker.C {
conns, err := GetConnectionsStats()
if err != nil {
continue
}
for _, conn := range conns {
if conn.receivedBytes == 0 && conn.transmittedBytes == 0 {
continue
}
key := Key{
localIP: conn.localIP,
remoteIP: conn.remoteIP,
localPort: conn.localPort,
remotePort: conn.remotePort,
ipv6: conn.ipV6 == 1,
protocol: conn.protocol,
}
// First we get a "copy" of the entry
if entry, ok := m[key]; ok {
// Then we modify the copy
entry.rx += conn.receivedBytes
entry.tx += conn.transmittedBytes
// Then we reassign map entry
m[key] = entry
} else {
m[key] = Rxtxdata{
rx: conn.receivedBytes,
tx: conn.transmittedBytes,
}
}
}
log.Debug("----------------------------------")
for key, value := range m {
log.Debugf(
"Conn: %d %s:%d %s:%d rx:%d tx:%d", key.protocol,
convertArrayToIP(key.localIP, key.ipv6), key.localPort,
convertArrayToIP(key.remoteIP, key.ipv6), key.remotePort,
value.rx, value.tx,
)
}
}
}()
}

View file

@ -5,11 +5,8 @@ package windowskext
import ( import (
"context" "context"
"encoding/binary"
"fmt" "fmt"
"net"
"time" "time"
"unsafe"
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
@ -19,34 +16,6 @@ import (
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
) )
const (
// VerdictRequestFlagFastTrackPermitted is set on packets that have been
// already permitted by the kernel extension and the verdict request is only
// informational.
VerdictRequestFlagFastTrackPermitted = 1
// VerdictRequestFlagSocketAuth indicates that the verdict request is for a
// connection that was intercepted on an ALE layer instead of in the network
// stack itself. Thus, no packet data is available.
VerdictRequestFlagSocketAuth = 2
// VerdictRequestFlagExpectSocketAuth indicates that the next verdict
// requests is expected to be an informational socket auth request from
// the ALE layer.
VerdictRequestFlagExpectSocketAuth = 4
)
type ConnectionStat struct {
localIP [4]uint32 //Source Address, only srcIP[0] if IPv4
remoteIP [4]uint32 //Destination Address
localPort uint16 //Source Port
remotePort uint16 //Destination port
receivedBytes uint64 //Number of bytes recived on this connection
transmittedBytes uint64 //Number of bytes transsmited from this connection
ipV6 uint8 //True: IPv6, False: IPv4
protocol uint8 //Protocol (UDP, TCP, ...)
}
type VersionInfo struct { type VersionInfo struct {
major uint8 major uint8
minor uint8 minor uint8
@ -79,7 +48,7 @@ func Handler(ctx context.Context, packets chan packet.Packet) {
info.Inbound = conn.Direction > 0 info.Inbound = conn.Direction > 0
info.InTunnel = false info.InTunnel = false
info.Protocol = packet.IPProtocol(conn.Protocol) info.Protocol = packet.IPProtocol(conn.Protocol)
info.PID = int(*conn.ProcessId) info.PID = int(conn.ProcessId)
info.SeenAt = time.Now() info.SeenAt = time.Now()
// Check PID // Check PID
@ -90,21 +59,17 @@ func Handler(ctx context.Context, packets chan packet.Packet) {
} }
// Set IP version // Set IP version
if conn.IpV6 {
info.Version = packet.IPv6
} else {
info.Version = packet.IPv4 info.Version = packet.IPv4
}
// Set IPs // Set IPs
if info.Inbound { if info.Inbound {
// Inbound // Inbound
info.Src = net.IP(conn.RemoteIp) info.Src = conn.RemoteIp[:]
info.Dst = net.IP(conn.LocalIp) info.Dst = conn.LocalIp[:]
} else { } else {
// Outbound // Outbound
info.Src = net.IP(conn.LocalIp) info.Src = conn.LocalIp[:]
info.Dst = net.IP(conn.RemoteIp) info.Dst = conn.RemoteIp[:]
} }
// Set Ports // Set Ports
@ -121,61 +86,21 @@ func Handler(ctx context.Context, packets chan packet.Packet) {
packets <- new packets <- new
} }
if packetInfo.LogLines != nil { // if packetInfo.LogLines != nil {
for _, line := range *packetInfo.LogLines { // for _, line := range *packetInfo.LogLines {
switch line.Severity { // switch line.Severity {
case int(log.DebugLevel): // case int(log.DebugLevel):
log.Debugf("kext: %s", line.Line) // log.Debugf("kext: %s", line.Line)
case int(log.InfoLevel): // case int(log.InfoLevel):
log.Infof("kext: %s", line.Line) // log.Infof("kext: %s", line.Line)
case int(log.WarningLevel): // case int(log.WarningLevel):
log.Warningf("kext: %s", line.Line) // log.Warningf("kext: %s", line.Line)
case int(log.ErrorLevel): // case int(log.ErrorLevel):
log.Errorf("kext: %s", line.Line) // log.Errorf("kext: %s", line.Line)
case int(log.CriticalLevel): // case int(log.CriticalLevel):
log.Criticalf("kext: %s", line.Line) // log.Criticalf("kext: %s", line.Line)
// }
// }
// }
} }
} }
}
}
}
// convertArrayToIP converts an array of uint32 values to a net.IP address.
func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP {
if !ipv6 {
addressBuf := make([]byte, 4)
binary.BigEndian.PutUint32(addressBuf, input[0])
return net.IP(addressBuf)
}
addressBuf := make([]byte, 16)
for i := 0; i < 4; i++ {
binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i])
}
return net.IP(addressBuf)
}
func ipAddressToArray(ip net.IP, isIPv6 bool) [4]uint32 {
array := [4]uint32{0}
if isIPv6 {
for i := 0; i < 4; i++ {
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[i], 4), getUInt32Value(&ip[i]))
}
} else {
binary.BigEndian.PutUint32(asByteArrayWithLength(&array[0], 4), getUInt32Value(&ip[0]))
}
return array
}
func asByteArray[T any](obj *T) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), unsafe.Sizeof(*obj))
}
func asByteArrayWithLength[T any](obj *T, size uint32) []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(obj)), size)
}
func getUInt32Value[T any](obj *T) uint32 {
return *(*uint32)(unsafe.Pointer(obj))
}

View file

@ -4,9 +4,7 @@
package windowskext package windowskext
import ( import (
"errors"
"fmt" "fmt"
"unsafe"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
@ -15,9 +13,6 @@ import (
// Package errors // Package errors
var ( var (
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
ErrNoPacketID = errors.New("the packet has no ID, possibly because it was fast-tracked by the kernel extension")
driverPath string driverPath string
service *kext_interface.KextService service *kext_interface.KextService
@ -28,7 +23,6 @@ const (
driverName = "PortmasterKext" driverName = "PortmasterKext"
) )
// Init initializes the DLL and the Kext (Kernel Driver).
func Init(path string) error { func Init(path string) error {
driverPath = path driverPath = path
return nil return nil
@ -63,20 +57,32 @@ func Stop() error {
log.Warningf("winkext: shutdown request failed: %s", err) log.Warningf("winkext: shutdown request failed: %s", err)
} }
// Close the interface to the driver. Driver will continue to run. // Close the interface to the driver. Driver will continue to run.
kextFile.Close() err = kextFile.Close()
if err != nil {
log.Warningf("winkext: failed to close kext file: %s", err)
}
// Stop and delete the driver. // Stop and delete the driver.
service.Stop(true) err = service.Stop(true)
service.Delete() if err != nil {
log.Warningf("winkext: failed to stop kernel service: %s", err)
}
err = service.Delete()
if err != nil {
log.Warningf("winkext: failed to delete kernel service: %s", err)
}
return nil return nil
} }
// Sends a shutdown request.
func shutdownRequest() error { func shutdownRequest() error {
return kext_interface.WriteCommand(kextFile, kext_interface.BuildShutdown()) return kext_interface.WriteShutdownCommand(kextFile)
} }
// Send request for logs of the kext.
func SendLogRequest() error { func SendLogRequest() error {
return kext_interface.WriteCommand(kextFile, kext_interface.BuildGetLogs()) return kext_interface.WriteGetLogsCommand(kextFile)
} }
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
@ -87,53 +93,52 @@ func RecvVerdictRequest() (*kext_interface.Info, error) {
// SetVerdict sets the verdict for a packet and/or connection. // SetVerdict sets the verdict for a packet and/or connection.
func SetVerdict(pkt *Packet, verdict network.Verdict) error { func SetVerdict(pkt *Packet, verdict network.Verdict) error {
if verdict == network.VerdictRerouteToNameserver { if verdict == network.VerdictRerouteToNameserver {
redirect := kext_interface.Redirect{Id: pkt.verdictRequest, RemoteAddress: []uint8{127, 0, 0, 1}, RemotePort: 53} redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{127, 0, 0, 1}, RemotePort: 53}
command := kext_interface.BuildRedirect(redirect) kext_interface.WriteRedirectCommand(kextFile, redirect)
kext_interface.WriteCommand(kextFile, command)
} else if verdict == network.VerdictRerouteToTunnel { } else if verdict == network.VerdictRerouteToTunnel {
redirect := kext_interface.Redirect{Id: pkt.verdictRequest, RemoteAddress: []uint8{192, 168, 122, 196}, RemotePort: 717} redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{192, 168, 122, 196}, RemotePort: 717}
command := kext_interface.BuildRedirect(redirect) kext_interface.WriteRedirectCommand(kextFile, redirect)
kext_interface.WriteCommand(kextFile, command)
} else { } else {
verdict := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} verdict := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)}
command := kext_interface.BuildVerdict(verdict) kext_interface.WriteVerdictCommand(kextFile, verdict)
kext_interface.WriteCommand(kextFile, command)
} }
return nil return nil
} }
// Clears the internal connection cache.
func ClearCache() error { func ClearCache() error {
return kext_interface.WriteCommand(kextFile, kext_interface.BuildClearCache()) return kext_interface.WriteClearCacheCommand(kextFile)
} }
// Updates a specific connection verdict.
func UpdateVerdict(conn *network.Connection) error { func UpdateVerdict(conn *network.Connection) error {
redirectAddress := []uint8{} redirectAddress := [4]byte{}
redirectPort := 0 redirectPort := 0
if conn.Verdict.Active == network.VerdictRerouteToNameserver { if conn.Verdict.Active == network.VerdictRerouteToNameserver {
redirectAddress = []uint8{127, 0, 0, 1} redirectAddress = [4]byte{127, 0, 0, 1}
redirectPort = 53 redirectPort = 53
} }
if conn.Verdict.Active == network.VerdictRerouteToTunnel { if conn.Verdict.Active == network.VerdictRerouteToTunnel {
redirectAddress = []uint8{192, 168, 122, 196} redirectAddress = [4]byte{192, 168, 122, 196}
redirectPort = 717 redirectPort = 717
} }
update := kext_interface.Update{ update := kext_interface.UpdateV4{
Protocol: conn.Entity.Protocol, Protocol: conn.Entity.Protocol,
LocalAddress: conn.LocalIP, LocalAddress: [4]byte(conn.LocalIP),
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
RemoteAddress: conn.Entity.IP, RemoteAddress: [4]byte(conn.Entity.IP),
RemotePort: conn.Entity.Port, RemotePort: conn.Entity.Port,
Verdict: uint8(conn.Verdict.Active), Verdict: uint8(conn.Verdict.Active),
RedirectAddress: redirectAddress, RedirectAddress: redirectAddress,
RedirectPort: uint16(redirectPort), RedirectPort: uint16(redirectPort),
} }
command := kext_interface.BuildUpdate(update) kext_interface.WriteUpdateCommand(kextFile, update)
kext_interface.WriteCommand(kextFile, command)
return nil return nil
} }
// Returns the kext version.
func GetVersion() (*VersionInfo, error) { func GetVersion() (*VersionInfo, error) {
data, err := kext_interface.ReadVersion(kextFile) data, err := kext_interface.ReadVersion(kextFile)
if err != nil { if err != nil {
@ -148,9 +153,3 @@ func GetVersion() (*VersionInfo, error) {
} }
return version, nil return version, nil
} }
var sizeOfConnectionStat = uint32(unsafe.Sizeof(ConnectionStat{}))
func GetConnectionsStats() ([]ConnectionStat, error) {
return nil, nil
}