Add call priorities to winkext

This commit is contained in:
Daniel 2019-05-10 11:55:42 +02:00
parent 0af3ab2305
commit ab81f02d94
3 changed files with 49 additions and 14 deletions

View file

@ -4,9 +4,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/Safing/portbase/log"
"github.com/Safing/portmaster/network" "github.com/Safing/portmaster/network"
"github.com/tevino/abool" "github.com/tevino/abool"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -16,11 +19,19 @@ import (
var ( var (
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands") ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
kext *WinKext winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
kextLock sync.RWMutex
ready = abool.NewBool(false) kext *WinKext
kextLock sync.RWMutex
ready = abool.NewBool(false)
urgentRequests *int32
) )
func init() {
var urgentRequestsValue int32
urgentRequests = &urgentRequestsValue
}
// WinKext holds the DLL handle. // WinKext holds the DLL handle.
type WinKext struct { type WinKext struct {
sync.RWMutex sync.RWMutex
@ -80,7 +91,7 @@ func Init(dllPath, driverPath string) error {
// initialize dll/kext // initialize dll/kext
rc, _, lastErr := new.init.Call() rc, _, lastErr := new.init.Call()
if rc != windows.NO_ERROR { if rc != windows.NO_ERROR {
return formatErr(lastErr) return formatErr(lastErr, rc)
} }
// set kext // set kext
@ -105,7 +116,7 @@ func Start() error {
uintptr(unsafe.Pointer(&charArray[0])), uintptr(unsafe.Pointer(&charArray[0])),
) )
if rc != windows.NO_ERROR { if rc != windows.NO_ERROR {
return formatErr(lastErr) return formatErr(lastErr, rc)
} }
ready.Set() ready.Set()
@ -123,7 +134,7 @@ func Stop() error {
rc, _, lastErr := kext.stop.Call() rc, _, lastErr := kext.stop.Call()
if rc != windows.NO_ERROR { if rc != windows.NO_ERROR {
return formatErr(lastErr) return formatErr(lastErr, rc)
} }
return nil return nil
} }
@ -138,14 +149,28 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
new := &VerdictRequest{} new := &VerdictRequest{}
// wait for urgent requests to complete
for i := 1; i <= 100; i++ {
if atomic.LoadInt32(urgentRequests) <= 0 {
break
}
if i == 100 {
log.Warningf("winkext: RecvVerdictRequest waited 100 times")
}
time.Sleep(100 * time.Microsecond)
}
// timestamp := time.Now()
rc, _, lastErr := kext.recvVerdictRequest.Call( rc, _, lastErr := kext.recvVerdictRequest.Call(
uintptr(unsafe.Pointer(new)), uintptr(unsafe.Pointer(new)),
) )
if rc != 0 { // log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
if rc == 13 /* ERROR_INVALID_DATA */ {
if rc != windows.NO_ERROR {
if rc == winErrInvalidData {
return nil, nil return nil, nil
} }
return nil, formatErr(lastErr) return nil, formatErr(lastErr, rc)
} }
return new, nil return new, nil
} }
@ -158,12 +183,16 @@ func SetVerdict(packetID uint32, verdict network.Verdict) error {
return ErrKextNotReady return ErrKextNotReady
} }
atomic.AddInt32(urgentRequests, 1)
// timestamp := time.Now()
rc, _, lastErr := kext.setVerdict.Call( rc, _, lastErr := kext.setVerdict.Call(
uintptr(packetID), uintptr(packetID),
uintptr(verdict), uintptr(verdict),
) )
// log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp))
atomic.AddInt32(urgentRequests, -1)
if rc != windows.NO_ERROR { if rc != windows.NO_ERROR {
return formatErr(lastErr) return formatErr(lastErr, rc)
} }
return nil return nil
} }
@ -178,13 +207,18 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
buf := make([]byte, packetSize) buf := make([]byte, packetSize)
atomic.AddInt32(urgentRequests, 1)
// timestamp := time.Now()
rc, _, lastErr := kext.getPayload.Call( rc, _, lastErr := kext.getPayload.Call(
uintptr(packetID), uintptr(packetID),
uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&packetSize)), uintptr(unsafe.Pointer(&packetSize)),
) )
// log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp))
atomic.AddInt32(urgentRequests, -1)
if rc != windows.NO_ERROR { if rc != windows.NO_ERROR {
return nil, formatErr(lastErr) return nil, formatErr(lastErr, rc)
} }
if packetSize == 0 { if packetSize == 0 {
@ -197,10 +231,10 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
return buf, nil return buf, nil
} }
func formatErr(err error) error { func formatErr(err error, rc uintptr) error {
sysErr, ok := err.(syscall.Errno) sysErr, ok := err.(syscall.Errno)
if ok { if ok {
return fmt.Errorf("%s [0x%X]", err, uintptr(sysErr)) return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc)
} }
return err return err
} }

View file

@ -31,6 +31,7 @@ func (pkt *Packet) GetPayload() ([]byte, error) {
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize) payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
if err != nil { if err != nil {
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload %s", err)
log.Errorf("windowskext: failed to load payload %s", err) log.Errorf("windowskext: failed to load payload %s", err)
return nil, packet.ErrFailedToLoadPayload return nil, packet.ErrFailedToLoadPayload
} }

View file

@ -68,7 +68,7 @@ func main() {
// stop // stop
err = windowskext.Stop() err = windowskext.Stop()
if err != nil { if err != nil {
fmt.Printf("error stopping: %s\n", err) panic(err)
} }
log.Info("shutdown complete") log.Info("shutdown complete")