Add packet payload for kext2

This commit is contained in:
Vladimir Stoilov 2024-04-17 11:15:29 +03:00
parent ead271f51c
commit c425007be1
No known key found for this signature in database
GPG key ID: 2F190B67A43A81AF
5 changed files with 157 additions and 129 deletions

View file

@ -105,26 +105,25 @@ func startInterception(packets chan packet.Packet) error {
} }
}) })
// Start kext logging. The worker will periodically send request to the kext to print memory stats. module.StartServiceWorker("kext clean ended connection worker", 0, func(ctx context.Context) error {
// module.StartServiceWorker("kext memory stats request worker", 0, func(ctx context.Context) error { timer := time.NewTicker(30 * time.Second)
// timer := time.NewTicker(20 * time.Second) for {
// for { select {
// select { case <-timer.C:
// case <-timer.C: {
// { err := kext2.SendCleanEndedConnection()
// err := kext2.SendPrintMemoryStatsCommand() if err != nil {
// if err != nil { return err
// return err }
// } }
// } case <-ctx.Done():
// case <-ctx.Done(): {
// { return nil
// return nil }
// } }
// }
// } }
// }) })
} }
return nil return nil

View file

@ -24,6 +24,7 @@ func createKextService(driverName string, driverPath string) (*KextService, erro
} }
defer windows.CloseServiceHandle(manager) defer windows.CloseServiceHandle(manager)
// Convert the driver name to a UTF16 string
driverNameU16, err := syscall.UTF16FromString(driverName) driverNameU16, err := syscall.UTF16FromString(driverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err) return nil, fmt.Errorf("failed to convert driver name to UTF16 string: %w", err)
@ -47,103 +48,103 @@ func createKextService(driverName string, driverPath string) (*KextService, erro
// Create the service // Create the service
service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil) service, err = windows.CreateService(manager, &driverNameU16[0], &driverNameU16[0], windows.SERVICE_ALL_ACCESS, windows.SERVICE_KERNEL_DRIVER, windows.SERVICE_DEMAND_START, windows.SERVICE_ERROR_NORMAL, &driverPathU16[0], nil, nil, nil, nil, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &KextService{handle: service}, nil return &KextService{handle: service}, nil
} }
func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error { func deleteService(manager windows.Handle, service *KextService, driverName []uint16) error {
// Stop and wait before deleting // Stop and wait before deleting
_ = service.stop(true) _ = service.stop(true)
// Try to delete even if stop failed // Try to delete even if stop failed
err := service.delete() err := service.delete()
if err != nil { if err != nil {
return fmt.Errorf("failed to delete old service: %s", err) return fmt.Errorf("failed to delete old service: %s", err)
} }
// Wait until we can no longer open the old service. // Wait until we can no longer open the old service.
// Not very efficient but NotifyServiceStatusChange cannot be used with driver service. // Not very efficient but NotifyServiceStatusChange cannot be used with driver service.
start := time.Now() start := time.Now()
timeLimit := time.Duration(30 * time.Second) timeLimit := time.Duration(30 * time.Second)
for { for {
handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS) handle, err := windows.OpenService(manager, &driverName[0], windows.SERVICE_ALL_ACCESS)
if err != nil { if err != nil {
break break
} }
_ = windows.CloseServiceHandle(handle) _ = windows.CloseServiceHandle(handle)
if time.Since(start) > timeLimit { if time.Since(start) > timeLimit {
return fmt.Errorf("time limit reached") return fmt.Errorf("time limit reached")
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
return nil return nil
} }
func (s *KextService) isValid() bool { func (s *KextService) isValid() bool {
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0 return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
} }
func (s *KextService) isRunning() (bool, error) { func (s *KextService) isRunning() (bool, error) {
if !s.isValid() { if !s.isValid() {
return false, fmt.Errorf("kext service not initialized") return false, fmt.Errorf("kext service not initialized")
} }
var status windows.SERVICE_STATUS var status windows.SERVICE_STATUS
err := windows.QueryServiceStatus(s.handle, &status) err := windows.QueryServiceStatus(s.handle, &status)
if err != nil { if err != nil {
return false, err return false, err
} }
return status.CurrentState == windows.SERVICE_RUNNING, nil return status.CurrentState == windows.SERVICE_RUNNING, nil
} }
func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) { func waitForServiceStatus(handle windows.Handle, neededStatus uint32, timeLimit time.Duration) (bool, error) {
var status windows.SERVICE_STATUS var status windows.SERVICE_STATUS
status.CurrentState = windows.SERVICE_NO_CHANGE status.CurrentState = windows.SERVICE_NO_CHANGE
start := time.Now() start := time.Now()
for status.CurrentState == neededStatus { for status.CurrentState == neededStatus {
err := windows.QueryServiceStatus(handle, &status) err := windows.QueryServiceStatus(handle, &status)
if err != nil { if err != nil {
return false, fmt.Errorf("failed while waiting for service to start: %w", err) return false, fmt.Errorf("failed while waiting for service to start: %w", err)
} }
if time.Since(start) > timeLimit { if time.Since(start) > timeLimit {
return false, fmt.Errorf("time limit reached") return false, fmt.Errorf("time limit reached")
} }
// Sleep for 1/10 of the wait hint, recommended time from microsoft // Sleep for 1/10 of the wait hint, recommended time from microsoft
time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond) time.Sleep(time.Duration((status.WaitHint / 10)) * time.Millisecond)
} }
return true, nil return true, nil
} }
func (s *KextService) start(wait bool) error { func (s *KextService) start(wait bool) error {
if !s.isValid() { if !s.isValid() {
return fmt.Errorf("kext service not initialized") return fmt.Errorf("kext service not initialized")
} }
// Start the service: // Start the service:
err := windows.StartService(s.handle, 0, nil) err := windows.StartService(s.handle, 0, nil)
if err != nil { if err != nil {
err = windows.GetLastError() err = windows.GetLastError()
if err != windows.ERROR_SERVICE_ALREADY_RUNNING { if err != windows.ERROR_SERVICE_ALREADY_RUNNING {
// Failed to start service; clean-up: // Failed to start service; clean-up:
var status windows.SERVICE_STATUS var status windows.SERVICE_STATUS
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status) _ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(s.handle) _ = windows.DeleteService(s.handle)
_ = windows.CloseServiceHandle(s.handle) _ = windows.CloseServiceHandle(s.handle)
s.handle = winInvalidHandleValue s.handle = winInvalidHandleValue
return err return err
} }
} }
// Wait for service to start // Wait for service to start
if wait { if wait {
success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second)) success, err := waitForServiceStatus(s.handle, windows.SERVICE_RUNNING, time.Duration(10*time.Second))
if err != nil || !success { if err != nil || !success {
return fmt.Errorf("service did not start: %w", err) return fmt.Errorf("service did not start: %w", err)

View file

@ -45,8 +45,10 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
// New Packet // New Packet
new := &Packet{ new := &Packet{
verdictRequest: conn.Id, verdictRequest: conn.Id,
payload: conn.Payload,
verdictSet: abool.NewBool(false), verdictSet: abool.NewBool(false),
} }
new.Base.Payload()
info := new.Info() info := new.Info()
info.Inbound = conn.Direction > 0 info.Inbound = conn.Direction > 0
info.InTunnel = false info.InTunnel = false
@ -95,6 +97,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
// New Packet // New Packet
new := &Packet{ new := &Packet{
verdictRequest: conn.Id, verdictRequest: conn.Id,
payload: conn.Payload,
verdictSet: abool.NewBool(false), verdictSet: abool.NewBool(false),
} }
info := new.Info() info := new.Info()

View file

@ -102,13 +102,17 @@ func SendPrintMemoryStatsCommand() error {
return kext_interface.SendPrintMemoryStatsCommand(kextFile) return kext_interface.SendPrintMemoryStatsCommand(kextFile)
} }
func SendCleanEndedConnection() error {
return kext_interface.SendCleanEndedConnectionsCommand(kextFile)
}
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
func RecvVerdictRequest() (*kext_interface.Info, error) { func RecvVerdictRequest() (*kext_interface.Info, error) {
return kext_interface.RecvInfo(kextFile) return kext_interface.RecvInfo(kextFile)
} }
// SetVerdict sets the verdict for a packet and/or connection. // SetVerdict sets the verdict for a packet and/or connection.
func SetVerdict(pkt *Packet, verdict network.Verdict) error { func SetVerdict(pkt *Packet, verdict kext_interface.KextVerdict) error {
verdictCommand := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} verdictCommand := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)}
return kext_interface.SendVerdictCommand(kextFile, verdictCommand) return kext_interface.SendVerdictCommand(kextFile, verdictCommand)
} }

View file

@ -4,12 +4,12 @@
package windowskext package windowskext
import ( import (
"fmt"
"sync" "sync"
"github.com/tevino/abool" "github.com/tevino/abool"
"github.com/vlabo/portmaster_windows_rust_kext/kext_interface"
"github.com/safing/portmaster/network" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
) )
@ -18,6 +18,7 @@ type Packet struct {
packet.Base packet.Base
verdictRequest uint64 verdictRequest uint64
payload []byte
verdictSet *abool.AtomicBool verdictSet *abool.AtomicBool
payloadLoaded bool payloadLoaded bool
@ -33,7 +34,7 @@ func (pkt *Packet) FastTrackedByIntegration() bool {
// InfoOnly returns whether the packet is informational only and does not // InfoOnly returns whether the packet is informational only and does not
// represent an actual packet. // represent an actual packet.
func (pkt *Packet) InfoOnly() bool { func (pkt *Packet) InfoOnly() bool {
return pkt.verdictRequest == 0 return false
} }
// ExpectInfo returns whether the next packet is expected to be informational only. // ExpectInfo returns whether the next packet is expected to be informational only.
@ -43,13 +44,33 @@ func (pkt *Packet) ExpectInfo() bool {
// GetPayload returns the full raw packet. // GetPayload returns the full raw packet.
func (pkt *Packet) LoadPacketData() error { func (pkt *Packet) LoadPacketData() error {
return fmt.Errorf("Not implemented") pkt.lock.Lock()
defer pkt.lock.Unlock()
if !pkt.payloadLoaded {
pkt.payloadLoaded = true
if len(pkt.payload) > 0 {
err := packet.Parse(pkt.payload, &pkt.Base)
if err != nil {
log.Tracef("payload: %#v", pkt.payload)
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to parse payload: %s", err)
return packet.ErrFailedToLoadPayload
}
}
}
if len(pkt.Raw()) == 0 {
return packet.ErrFailedToLoadPayload
}
return nil
} }
// Accept accepts the packet. // Accept accepts the packet.
func (pkt *Packet) Accept() error { func (pkt *Packet) Accept() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictAccept) return SetVerdict(pkt, kext_interface.VerdictAccept)
} }
return nil return nil
} }
@ -57,7 +78,7 @@ func (pkt *Packet) Accept() error {
// Block blocks the packet. // Block blocks the packet.
func (pkt *Packet) Block() error { func (pkt *Packet) Block() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictBlock) return SetVerdict(pkt, kext_interface.VerdictBlock)
} }
return nil return nil
} }
@ -65,7 +86,7 @@ func (pkt *Packet) Block() error {
// Drop drops the packet. // Drop drops the packet.
func (pkt *Packet) Drop() error { func (pkt *Packet) Drop() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, -network.VerdictDrop) return SetVerdict(pkt, kext_interface.VerdictDrop)
} }
return nil return nil
} }
@ -73,7 +94,7 @@ func (pkt *Packet) Drop() error {
// PermanentAccept permanently accepts connection (and the current packet). // PermanentAccept permanently accepts connection (and the current packet).
func (pkt *Packet) PermanentAccept() error { func (pkt *Packet) PermanentAccept() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictAccept) return SetVerdict(pkt, kext_interface.VerdictAccept)
} }
return nil return nil
} }
@ -81,7 +102,7 @@ func (pkt *Packet) PermanentAccept() error {
// PermanentBlock permanently blocks connection (and the current packet). // PermanentBlock permanently blocks connection (and the current packet).
func (pkt *Packet) PermanentBlock() error { func (pkt *Packet) PermanentBlock() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictBlock) return SetVerdict(pkt, kext_interface.VerdictBlock)
} }
return nil return nil
} }
@ -89,7 +110,7 @@ func (pkt *Packet) PermanentBlock() error {
// PermanentDrop permanently drops connection (and the current packet). // PermanentDrop permanently drops connection (and the current packet).
func (pkt *Packet) PermanentDrop() error { func (pkt *Packet) PermanentDrop() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictDrop) return SetVerdict(pkt, kext_interface.VerdictDrop)
} }
return nil return nil
} }
@ -97,7 +118,7 @@ func (pkt *Packet) PermanentDrop() error {
// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet). // RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet).
func (pkt *Packet) RerouteToNameserver() error { func (pkt *Packet) RerouteToNameserver() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictRerouteToNameserver) return SetVerdict(pkt, kext_interface.VerdictRerouteToNameserver)
} }
return nil return nil
} }
@ -105,7 +126,7 @@ func (pkt *Packet) RerouteToNameserver() error {
// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet). // RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet).
func (pkt *Packet) RerouteToTunnel() error { func (pkt *Packet) RerouteToTunnel() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictRerouteToTunnel) return SetVerdict(pkt, kext_interface.VerdictRerouteToTunnel)
} }
return nil return nil
} }