mirror of
https://github.com/safing/portmaster
synced 2025-09-02 02:29:12 +00:00
Add call priorities to winkext
This commit is contained in:
parent
0af3ab2305
commit
ab81f02d94
3 changed files with 49 additions and 14 deletions
|
@ -4,9 +4,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/Safing/portbase/log"
|
||||||
"github.com/Safing/portmaster/network"
|
"github.com/Safing/portmaster/network"
|
||||||
"github.com/tevino/abool"
|
"github.com/tevino/abool"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
@ -16,11 +19,19 @@ import (
|
||||||
var (
|
var (
|
||||||
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
ErrKextNotReady = errors.New("the windows kernel extension (driver) is not ready to accept commands")
|
||||||
|
|
||||||
kext *WinKext
|
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
|
||||||
kextLock sync.RWMutex
|
|
||||||
ready = abool.NewBool(false)
|
kext *WinKext
|
||||||
|
kextLock sync.RWMutex
|
||||||
|
ready = abool.NewBool(false)
|
||||||
|
urgentRequests *int32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var urgentRequestsValue int32
|
||||||
|
urgentRequests = &urgentRequestsValue
|
||||||
|
}
|
||||||
|
|
||||||
// WinKext holds the DLL handle.
|
// WinKext holds the DLL handle.
|
||||||
type WinKext struct {
|
type WinKext struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
@ -80,7 +91,7 @@ func Init(dllPath, driverPath string) error {
|
||||||
// initialize dll/kext
|
// initialize dll/kext
|
||||||
rc, _, lastErr := new.init.Call()
|
rc, _, lastErr := new.init.Call()
|
||||||
if rc != windows.NO_ERROR {
|
if rc != windows.NO_ERROR {
|
||||||
return formatErr(lastErr)
|
return formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set kext
|
// set kext
|
||||||
|
@ -105,7 +116,7 @@ func Start() error {
|
||||||
uintptr(unsafe.Pointer(&charArray[0])),
|
uintptr(unsafe.Pointer(&charArray[0])),
|
||||||
)
|
)
|
||||||
if rc != windows.NO_ERROR {
|
if rc != windows.NO_ERROR {
|
||||||
return formatErr(lastErr)
|
return formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
ready.Set()
|
ready.Set()
|
||||||
|
@ -123,7 +134,7 @@ func Stop() error {
|
||||||
|
|
||||||
rc, _, lastErr := kext.stop.Call()
|
rc, _, lastErr := kext.stop.Call()
|
||||||
if rc != windows.NO_ERROR {
|
if rc != windows.NO_ERROR {
|
||||||
return formatErr(lastErr)
|
return formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -138,14 +149,28 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
|
||||||
|
|
||||||
new := &VerdictRequest{}
|
new := &VerdictRequest{}
|
||||||
|
|
||||||
|
// wait for urgent requests to complete
|
||||||
|
for i := 1; i <= 100; i++ {
|
||||||
|
if atomic.LoadInt32(urgentRequests) <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == 100 {
|
||||||
|
log.Warningf("winkext: RecvVerdictRequest waited 100 times")
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Microsecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// timestamp := time.Now()
|
||||||
rc, _, lastErr := kext.recvVerdictRequest.Call(
|
rc, _, lastErr := kext.recvVerdictRequest.Call(
|
||||||
uintptr(unsafe.Pointer(new)),
|
uintptr(unsafe.Pointer(new)),
|
||||||
)
|
)
|
||||||
if rc != 0 {
|
// log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
|
||||||
if rc == 13 /* ERROR_INVALID_DATA */ {
|
|
||||||
|
if rc != windows.NO_ERROR {
|
||||||
|
if rc == winErrInvalidData {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, formatErr(lastErr)
|
return nil, formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
return new, nil
|
return new, nil
|
||||||
}
|
}
|
||||||
|
@ -158,12 +183,16 @@ func SetVerdict(packetID uint32, verdict network.Verdict) error {
|
||||||
return ErrKextNotReady
|
return ErrKextNotReady
|
||||||
}
|
}
|
||||||
|
|
||||||
|
atomic.AddInt32(urgentRequests, 1)
|
||||||
|
// timestamp := time.Now()
|
||||||
rc, _, lastErr := kext.setVerdict.Call(
|
rc, _, lastErr := kext.setVerdict.Call(
|
||||||
uintptr(packetID),
|
uintptr(packetID),
|
||||||
uintptr(verdict),
|
uintptr(verdict),
|
||||||
)
|
)
|
||||||
|
// log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp))
|
||||||
|
atomic.AddInt32(urgentRequests, -1)
|
||||||
if rc != windows.NO_ERROR {
|
if rc != windows.NO_ERROR {
|
||||||
return formatErr(lastErr)
|
return formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -178,13 +207,18 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||||
|
|
||||||
buf := make([]byte, packetSize)
|
buf := make([]byte, packetSize)
|
||||||
|
|
||||||
|
atomic.AddInt32(urgentRequests, 1)
|
||||||
|
// timestamp := time.Now()
|
||||||
rc, _, lastErr := kext.getPayload.Call(
|
rc, _, lastErr := kext.getPayload.Call(
|
||||||
uintptr(packetID),
|
uintptr(packetID),
|
||||||
uintptr(unsafe.Pointer(&buf[0])),
|
uintptr(unsafe.Pointer(&buf[0])),
|
||||||
uintptr(unsafe.Pointer(&packetSize)),
|
uintptr(unsafe.Pointer(&packetSize)),
|
||||||
)
|
)
|
||||||
|
// log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp))
|
||||||
|
atomic.AddInt32(urgentRequests, -1)
|
||||||
|
|
||||||
if rc != windows.NO_ERROR {
|
if rc != windows.NO_ERROR {
|
||||||
return nil, formatErr(lastErr)
|
return nil, formatErr(lastErr, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if packetSize == 0 {
|
if packetSize == 0 {
|
||||||
|
@ -197,10 +231,10 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatErr(err error) error {
|
func formatErr(err error, rc uintptr) error {
|
||||||
sysErr, ok := err.(syscall.Errno)
|
sysErr, ok := err.(syscall.Errno)
|
||||||
if ok {
|
if ok {
|
||||||
return fmt.Errorf("%s [0x%X]", err, uintptr(sysErr))
|
return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@ func (pkt *Packet) GetPayload() ([]byte, error) {
|
||||||
|
|
||||||
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
|
payload, err := GetPayload(pkt.verdictRequest.id, pkt.verdictRequest.packetSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Tracer(pkt.Ctx()).Warningf("windowskext: failed to load payload %s", err)
|
||||||
log.Errorf("windowskext: failed to load payload %s", err)
|
log.Errorf("windowskext: failed to load payload %s", err)
|
||||||
return nil, packet.ErrFailedToLoadPayload
|
return nil, packet.ErrFailedToLoadPayload
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ func main() {
|
||||||
// stop
|
// stop
|
||||||
err = windowskext.Stop()
|
err = windowskext.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error stopping: %s\n", err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("shutdown complete")
|
log.Info("shutdown complete")
|
||||||
|
|
Loading…
Add table
Reference in a new issue