From 7babfb13abf78e89699ff87f9c88ca801d6ac24e Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 29 Jan 2024 22:23:54 +0200 Subject: [PATCH] Add bandwidth stats support --- firewall/interception/interception_windows.go | 24 +- .../windowskext/bandwidth_stats.go | 2 +- firewall/interception/windowskext2/handler.go | 209 +++++++++++++----- firewall/interception/windowskext2/kext.go | 20 +- firewall/interception/windowskext2/packet.go | 3 +- go.mod | 2 + 6 files changed, 185 insertions(+), 75 deletions(-) diff --git a/firewall/interception/interception_windows.go b/firewall/interception/interception_windows.go index d46cbeb8..54954316 100644 --- a/firewall/interception/interception_windows.go +++ b/firewall/interception/interception_windows.go @@ -30,14 +30,30 @@ func startInterception(packets chan packet.Packet) error { // Start packet handler. module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error { - windowskext.Handler(ctx, packets) + windowskext.Handler(ctx, packets, BandwidthUpdates) return nil }) // Start bandwidth stats monitor. - // module.StartServiceWorker("kext bandwidth stats monitor", 0, func(ctx context.Context) error { - // return windowskext.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates) - // }) + module.StartServiceWorker("kext bandwidth request worker", 0, func(ctx context.Context) error { + timer := time.NewTicker(1 * time.Second) + for { + select { + case <-timer.C: + { + err := windowskext.SendBandwidthStatsRequest() + if err != nil { + return err + } + } + case <-ctx.Done(): + { + return nil + } + } + + } + }) // Start kext logging. The worker will periodically send request to the kext to send logs. module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error { diff --git a/firewall/interception/windowskext/bandwidth_stats.go b/firewall/interception/windowskext/bandwidth_stats.go index 2a1bddc0..482d81b6 100644 --- a/firewall/interception/windowskext/bandwidth_stats.go +++ b/firewall/interception/windowskext/bandwidth_stats.go @@ -81,7 +81,7 @@ func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.Bandwidt return nil } -func StartBandwithConsoleLogger() { +func StartBandwidthConsoleLogger() { go func() { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() diff --git a/firewall/interception/windowskext2/handler.go b/firewall/interception/windowskext2/handler.go index ba5169fd..1b71d229 100644 --- a/firewall/interception/windowskext2/handler.go +++ b/firewall/interception/windowskext2/handler.go @@ -6,6 +6,7 @@ package windowskext import ( "context" "fmt" + "net" "time" "github.com/safing/portmaster/process" @@ -28,7 +29,7 @@ func (v *VersionInfo) String() string { } // Handler transforms received packets to the Packet interface. -func Handler(ctx context.Context, packets chan packet.Packet) { +func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) { for { packetInfo, err := RecvVerdictRequest() if err != nil { @@ -36,71 +37,157 @@ func Handler(ctx context.Context, packets chan packet.Packet) { return } - if packetInfo.Connection != nil { - log.Tracef("packet: %+v", packetInfo.Connection) - conn := packetInfo.Connection - // New Packet - new := &Packet{ - verdictRequest: conn.Id, - verdictSet: abool.NewBool(false), + switch { + case packetInfo.ConnectionV4 != nil: + { + log.Tracef("packet: %+v", packetInfo.ConnectionV4) + conn := packetInfo.ConnectionV4 + // New Packet + new := &Packet{ + verdictRequest: conn.Id, + verdictSet: abool.NewBool(false), + } + info := new.Info() + info.Inbound = conn.Direction > 0 + info.InTunnel = false + info.Protocol = packet.IPProtocol(conn.Protocol) + info.PID = int(conn.ProcessId) + info.SeenAt = time.Now() + + // Check PID + if info.PID == 0 { + // Windows does not have zero PIDs. + // Set to UndefinedProcessID. + info.PID = process.UndefinedProcessID + } + + // Set IP version + info.Version = packet.IPv4 + + // Set IPs + if info.Inbound { + // Inbound + info.Src = conn.RemoteIp[:] + info.Dst = conn.LocalIp[:] + } else { + // Outbound + info.Src = conn.LocalIp[:] + info.Dst = conn.RemoteIp[:] + } + + // Set Ports + if info.Inbound { + // Inbound + info.SrcPort = conn.RemotePort + info.DstPort = conn.LocalPort + } else { + // Outbound + info.SrcPort = conn.LocalPort + info.DstPort = conn.RemotePort + } + + packets <- new } - info := new.Info() - info.Inbound = conn.Direction > 0 - info.InTunnel = false - info.Protocol = packet.IPProtocol(conn.Protocol) - info.PID = int(conn.ProcessId) - info.SeenAt = time.Now() + case packetInfo.ConnectionV6 != nil: + { + log.Tracef("packet: %+v", packetInfo.ConnectionV6) + conn := packetInfo.ConnectionV6 + // New Packet + new := &Packet{ + verdictRequest: conn.Id, + verdictSet: abool.NewBool(false), + } + info := new.Info() + info.Inbound = conn.Direction > 0 + info.InTunnel = false + info.Protocol = packet.IPProtocol(conn.Protocol) + info.PID = int(conn.ProcessId) + info.SeenAt = time.Now() - // Check PID - if info.PID == 0 { - // Windows does not have zero PIDs. - // Set to UndefinedProcessID. - info.PID = process.UndefinedProcessID + // Check PID + if info.PID == 0 { + // Windows does not have zero PIDs. + // Set to UndefinedProcessID. + info.PID = process.UndefinedProcessID + } + + // Set IP version + info.Version = packet.IPv6 + + // Set IPs + if info.Inbound { + // Inbound + info.Src = conn.RemoteIp[:] + info.Dst = conn.LocalIp[:] + } else { + // Outbound + info.Src = conn.LocalIp[:] + info.Dst = conn.RemoteIp[:] + } + + // Set Ports + if info.Inbound { + // Inbound + info.SrcPort = conn.RemotePort + info.DstPort = conn.LocalPort + } else { + // Outbound + info.SrcPort = conn.LocalPort + info.DstPort = conn.RemotePort + } + + packets <- new } - - // Set IP version - info.Version = packet.IPv4 - - // Set IPs - if info.Inbound { - // Inbound - info.Src = conn.RemoteIp[:] - info.Dst = conn.LocalIp[:] - } else { - // Outbound - info.Src = conn.LocalIp[:] - info.Dst = conn.RemoteIp[:] + case packetInfo.LogLine != nil: + { + line := packetInfo.LogLine + switch line.Severity { + case byte(log.DebugLevel): + log.Debugf("kext: %s", line.Line) + case byte(log.InfoLevel): + log.Infof("kext: %s", line.Line) + case byte(log.WarningLevel): + log.Warningf("kext: %s", line.Line) + case byte(log.ErrorLevel): + log.Errorf("kext: %s", line.Line) + case byte(log.CriticalLevel): + log.Criticalf("kext: %s", line.Line) + } } - - // Set Ports - if info.Inbound { - // Inbound - info.SrcPort = conn.RemotePort - info.DstPort = conn.LocalPort - } else { - // Outbound - info.SrcPort = conn.LocalPort - info.DstPort = conn.RemotePort + case packetInfo.BandwidthStats != nil: + { + bandwidthStats := packetInfo.BandwidthStats + for _, stat := range bandwidthStats.ValuesV4 { + connID := packet.CreateConnectionID( + packet.IPProtocol(bandwidthStats.Protocol), + net.IP(stat.LocalIP[:]), stat.LocalPort, + net.IP(stat.RemoteIP[:]), stat.RemotePort, + false, + ) + update := &packet.BandwidthUpdate{ + ConnID: connID, + BytesReceived: stat.ReceivedBytes, + BytesSent: stat.TransmittedBytes, + Method: packet.Additive, + } + bandwidthUpdate <- update + } + for _, stat := range bandwidthStats.ValuesV6 { + connID := packet.CreateConnectionID( + packet.IPProtocol(bandwidthStats.Protocol), + net.IP(stat.LocalIP[:]), stat.LocalPort, + net.IP(stat.RemoteIP[:]), stat.RemotePort, + false, + ) + update := &packet.BandwidthUpdate{ + ConnID: connID, + BytesReceived: stat.ReceivedBytes, + BytesSent: stat.TransmittedBytes, + Method: packet.Additive, + } + bandwidthUpdate <- update + } } - - packets <- new } - - // if packetInfo.LogLines != nil { - // for _, line := range *packetInfo.LogLines { - // switch line.Severity { - // case int(log.DebugLevel): - // log.Debugf("kext: %s", line.Line) - // case int(log.InfoLevel): - // log.Infof("kext: %s", line.Line) - // case int(log.WarningLevel): - // log.Warningf("kext: %s", line.Line) - // case int(log.ErrorLevel): - // log.Errorf("kext: %s", line.Line) - // case int(log.CriticalLevel): - // log.Criticalf("kext: %s", line.Line) - // } - // } - // } } } diff --git a/firewall/interception/windowskext2/kext.go b/firewall/interception/windowskext2/kext.go index ca90b991..cae347c4 100644 --- a/firewall/interception/windowskext2/kext.go +++ b/firewall/interception/windowskext2/kext.go @@ -77,37 +77,41 @@ func Stop() error { // Sends a shutdown request. func shutdownRequest() error { - return kext_interface.WriteShutdownCommand(kextFile) + return kext_interface.SendShutdownCommand(kextFile) } // Send request for logs of the kext. func SendLogRequest() error { - return kext_interface.WriteGetLogsCommand(kextFile) + return kext_interface.SendGetLogsCommand(kextFile) +} + +func SendBandwidthStatsRequest() error { + return kext_interface.SendGetBandwidthStatsCommand(kextFile) } // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. func RecvVerdictRequest() (*kext_interface.Info, error) { - return kext_interface.ReadInfo(kextFile) + return kext_interface.RecvInfo(kextFile) } // SetVerdict sets the verdict for a packet and/or connection. func SetVerdict(pkt *Packet, verdict network.Verdict) error { if verdict == network.VerdictRerouteToNameserver { redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{127, 0, 0, 1}, RemotePort: 53} - kext_interface.WriteRedirectCommand(kextFile, redirect) + kext_interface.SendRedirectV4Command(kextFile, redirect) } else if verdict == network.VerdictRerouteToTunnel { redirect := kext_interface.RedirectV4{Id: pkt.verdictRequest, RemoteAddress: [4]uint8{192, 168, 122, 196}, RemotePort: 717} - kext_interface.WriteRedirectCommand(kextFile, redirect) + kext_interface.SendRedirectV4Command(kextFile, redirect) } else { verdict := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} - kext_interface.WriteVerdictCommand(kextFile, verdict) + kext_interface.SendVerdictCommand(kextFile, verdict) } return nil } // Clears the internal connection cache. func ClearCache() error { - return kext_interface.WriteClearCacheCommand(kextFile) + return kext_interface.SendClearCacheCommand(kextFile) } // Updates a specific connection verdict. @@ -134,7 +138,7 @@ func UpdateVerdict(conn *network.Connection) error { RedirectPort: uint16(redirectPort), } - kext_interface.WriteUpdateCommand(kextFile, update) + kext_interface.SendUpdateV4Command(kextFile, update) return nil } diff --git a/firewall/interception/windowskext2/packet.go b/firewall/interception/windowskext2/packet.go index 318d1de6..2ec37aaa 100644 --- a/firewall/interception/windowskext2/packet.go +++ b/firewall/interception/windowskext2/packet.go @@ -4,6 +4,7 @@ package windowskext import ( + "fmt" "sync" "github.com/tevino/abool" @@ -42,7 +43,7 @@ func (pkt *Packet) ExpectInfo() bool { // GetPayload returns the full raw packet. func (pkt *Packet) LoadPacketData() error { - return nil + return fmt.Errorf("Not implemented") } // Accept accepts the packet. diff --git a/go.mod b/go.mod index 78158f63..88c2138b 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ toolchain go1.21.2 // TODO: Remove when https://github.com/tc-hib/winres/pull/4 is merged or changes are otherwise integrated. replace github.com/tc-hib/winres => github.com/dhaavi/winres v0.2.2 +replace github.com/vlabo/portmaster_windows_rust_kext/kext_interface => /home/vladimir/Dev/Safing/portmaster_windows_rust_kext/kext_interface + require ( github.com/Xuanwo/go-locale v1.1.0 github.com/agext/levenshtein v1.2.3