remove the need for the glue library

This commit is contained in:
Vladimir 2022-10-17 23:45:49 -07:00
parent f858ef492f
commit 3b341496af
3 changed files with 96 additions and 151 deletions

View file

@ -10,16 +10,12 @@ import (
// start starts the interception. // start starts the interception.
func start(ch chan packet.Packet) error { func start(ch chan packet.Packet) error {
dllFile, err := updates.GetPlatformFile("kext/portmaster-kext.dll")
if err != nil {
return fmt.Errorf("interception: could not get kext dll: %s", err)
}
kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys") kextFile, err := updates.GetPlatformFile("kext/portmaster-kext.sys")
if err != nil { if err != nil {
return fmt.Errorf("interception: could not get kext sys: %s", err) return fmt.Errorf("interception: could not get kext sys: %s", err)
} }
err = windowskext.Init(dllFile.Path(), kextFile.Path()) err = windowskext.Init("", kextFile.Path())
if err != nil { if err != nil {
return fmt.Errorf("interception: could not init windows kext: %s", err) return fmt.Errorf("interception: could not init windows kext: %s", err)
} }

View file

@ -6,6 +6,7 @@ package windowskext
import ( import (
"errors" "errors"
"fmt" "fmt"
"os/exec"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@ -25,10 +26,10 @@ var (
winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA) winErrInvalidData = uintptr(windows.ERROR_INVALID_DATA)
kext *WinKext
kextLock sync.RWMutex kextLock sync.RWMutex
ready = abool.NewBool(false) ready = abool.NewBool(false)
urgentRequests *int32 urgentRequests *int32
driverPath string
kextHandle windows.Handle kextHandle windows.Handle
) )
@ -38,85 +39,9 @@ func init() {
urgentRequests = &urgentRequestsValue urgentRequests = &urgentRequestsValue
} }
// WinKext holds the DLL handle.
type WinKext struct {
sync.RWMutex
dll *windows.DLL
driverPath string
init *windows.Proc
start *windows.Proc
stop *windows.Proc
recvVerdictRequest *windows.Proc
setVerdict *windows.Proc
getPayload *windows.Proc
clearCache *windows.Proc
setHandle *windows.Proc
}
// Init initializes the DLL and the Kext (Kernel Driver). // Init initializes the DLL and the Kext (Kernel Driver).
func Init(dllPath, driverPath string) error { func Init(dllPath, path string) error {
driverPath = path
new := &WinKext{
driverPath: driverPath,
}
var err error
// load dll
new.dll, err = windows.LoadDLL(dllPath)
if err != nil {
return err
}
// load functions
new.init, err = new.dll.FindProc("PortmasterInit")
if err != nil {
return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err)
}
new.start, err = new.dll.FindProc("PortmasterStart")
if err != nil {
return fmt.Errorf("could not find proc PortmasterStart in dll: %s", err)
}
new.stop, err = new.dll.FindProc("PortmasterStop")
if err != nil {
return fmt.Errorf("could not find proc PortmasterStop in dll: %s", err)
}
new.recvVerdictRequest, err = new.dll.FindProc("PortmasterRecvVerdictRequest")
if err != nil {
return fmt.Errorf("could not find proc PortmasterRecvVerdictRequest in dll: %s", err)
}
new.setVerdict, err = new.dll.FindProc("PortmasterSetVerdict")
if err != nil {
return fmt.Errorf("could not find proc PortmasterSetVerdict in dll: %s", err)
}
new.getPayload, err = new.dll.FindProc("PortmasterGetPayload")
if err != nil {
return fmt.Errorf("could not find proc PortmasterGetPayload in dll: %s", err)
}
new.clearCache, err = new.dll.FindProc("PortmasterClearCache")
if err != nil {
// the loaded dll is an old version
log.Errorf("could not find proc PortmasterClearCache (v1.0.12+) in dll: %s", err)
}
new.setHandle, err = new.dll.FindProc("PortmasterSetDeviceHandle")
if err != nil {
log.Errorf("could not find proc PortmasterSetDeviceHandle in dll: %s", err)
}
// initialize dll/kext
rc, _, lastErr := new.init.Call()
if rc != windows.NO_ERROR {
return formatErr(lastErr, rc)
}
// set kext
kextLock.Lock()
defer kextLock.Unlock()
kext = new
return nil return nil
} }
@ -132,7 +57,7 @@ func Start() error {
return fmt.Errorf("Bad filename: %s", err) return fmt.Errorf("Bad filename: %s", err)
} }
u16DriverPath, err := syscall.UTF16FromString(kext.driverPath) u16DriverPath, err := syscall.UTF16FromString(driverPath)
if err != nil { if err != nil {
return fmt.Errorf("Bad driver path: %s", err) return fmt.Errorf("Bad driver path: %s", err)
} }
@ -157,26 +82,14 @@ func Start() error {
return fmt.Errorf("Faield to kext service: %s %q", err, filename) return fmt.Errorf("Faield to kext service: %s %q", err, filename)
} }
// rc, _, lastErr := kext.start.Call(
// uintptr(unsafe.Pointer(&charArray[0])),
// )
// if rc != windows.NO_ERROR {
// return formatErr(lastErr, rc)
// }
kext.setHandle.Call(uintptr(kextHandle))
ready.Set() ready.Set()
testRead() testRead()
return nil return nil
} }
func testRead() { func testRead() {
buf := [5]byte{1, 2, 3, 4, 5} buf := [5]byte{1, 2, 3, 4, 5}
var read uint32 = 0 _, err := deviceIoControl(IOCTL_TEST, &buf[0], uintptr(len(buf)))
err := windows.ReadFile(kextHandle, buf[:], &read, nil)
if err != nil { if err != nil {
log.Criticalf("Erro reading test data: %s", err) log.Criticalf("Erro reading test data: %s", err)
} }
@ -252,10 +165,11 @@ func Stop() error {
} }
ready.UnSet() ready.UnSet()
rc, _, lastErr := kext.stop.Call() err := windows.CloseHandle(kextHandle)
if rc != windows.NO_ERROR { if err != nil {
return formatErr(lastErr, rc) log.Errorf("kext: faield to close handle: %s", err)
} }
_, _ = exec.Command("sc", "stop", "PortmasterKext").Output()
return nil return nil
} }
@ -266,9 +180,6 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
if !ready.IsSet() { if !ready.IsSet() {
return nil, ErrKextNotReady return nil, ErrKextNotReady
} }
new := &VerdictRequest{}
// wait for urgent requests to complete // wait for urgent requests to complete
for i := 1; i <= 100; i++ { for i := 1; i <= 100; i++ {
if atomic.LoadInt32(urgentRequests) <= 0 { if atomic.LoadInt32(urgentRequests) <= 0 {
@ -280,19 +191,18 @@ func RecvVerdictRequest() (*VerdictRequest, error) {
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
} }
// timestamp := time.Now() timestamp := time.Now()
rc, _, lastErr := kext.recvVerdictRequest.Call( var new VerdictRequest
uintptr(unsafe.Pointer(new)),
)
// log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
if rc != windows.NO_ERROR { data := (*byte)(unsafe.Pointer(&new))
if rc == winErrInvalidData { _, err := deviceIoControl(IOCTL_RECV_VERDICT_REQ, data, unsafe.Sizeof(new))
return nil, nil if err != nil {
return nil, err
} }
return nil, formatErr(lastErr, rc) log.Tracef("winkext: getting verdict request took %s", time.Now().Sub(timestamp))
}
return new, nil log.Criticalf("%v", new)
return &new, nil
} }
// SetVerdict sets the verdict for a packet and/or connection. // SetVerdict sets the verdict for a packet and/or connection.
@ -309,17 +219,18 @@ func SetVerdict(pkt *Packet, verdict network.Verdict) error {
return ErrKextNotReady return ErrKextNotReady
} }
verdictInfo := struct {
id uint32
verdict network.Verdict
}{pkt.verdictRequest.id, verdict}
atomic.AddInt32(urgentRequests, 1) atomic.AddInt32(urgentRequests, 1)
// timestamp := time.Now() _, err := deviceIoControlBufferd(IOCTL_SET_VERDICT,
rc, _, lastErr := kext.setVerdict.Call( (*byte)(unsafe.Pointer(&verdictInfo)), unsafe.Sizeof(verdictInfo), nil, 0)
uintptr(pkt.verdictRequest.id),
uintptr(verdict),
)
// log.Tracef("winkext: settings verdict for packetID %d took %s", packetID, time.Now().Sub(timestamp))
atomic.AddInt32(urgentRequests, -1) atomic.AddInt32(urgentRequests, -1)
if rc != windows.NO_ERROR { if err != nil {
log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id) log.Tracer(pkt.Ctx()).Errorf("kext: failed to set verdict %s on packet %d", verdict, pkt.verdictRequest.id)
return formatErr(lastErr, rc) return err
} }
return nil return nil
} }
@ -338,26 +249,31 @@ func GetPayload(packetID uint32, packetSize uint32) ([]byte, error) {
buf := make([]byte, packetSize) buf := make([]byte, packetSize)
payload := struct {
id uint32
length uint32
}{packetID, packetSize}
atomic.AddInt32(urgentRequests, 1) atomic.AddInt32(urgentRequests, 1)
writenSize, err := deviceIoControlBufferd(IOCTL_GET_PAYLOAD,
(*byte)(unsafe.Pointer(&payload)), unsafe.Sizeof(payload),
&buf[0], uintptr(packetSize))
// timestamp := time.Now() // timestamp := time.Now()
rc, _, lastErr := kext.getPayload.Call(
uintptr(packetID),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&packetSize)),
)
// log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp)) // log.Tracef("winkext: getting payload for packetID %d took %s", packetID, time.Now().Sub(timestamp))
atomic.AddInt32(urgentRequests, -1) atomic.AddInt32(urgentRequests, -1)
if rc != windows.NO_ERROR { if err != nil {
return nil, formatErr(lastErr, rc) return nil, err
} }
if packetSize == 0 { if writenSize == 0 {
return nil, errors.New("windows kext did not return any data") return nil, errors.New("windows kext did not return any data")
} }
if packetSize < uint32(len(buf)) { if writenSize < uint32(len(buf)) {
return buf[:packetSize], nil return buf[:writenSize], nil
} }
return buf, nil return buf, nil
@ -371,23 +287,6 @@ func ClearCache() error {
return ErrKextNotReady return ErrKextNotReady
} }
if kext.clearCache == nil { _, err := deviceIoControl(IOCTL_CLEAR_CACHE, nil, 0)
log.Error("kext: cannot clear cache: clearCache function missing")
}
rc, _, lastErr := kext.clearCache.Call()
if rc != windows.NO_ERROR {
return formatErr(lastErr, rc)
}
return nil
}
func formatErr(err error, rc uintptr) error {
sysErr, ok := err.(syscall.Errno)
if ok {
return fmt.Errorf("%s [LE 0x%X] [RC 0x%X]", err, uintptr(sysErr), rc)
}
return err return err
} }

View file

@ -0,0 +1,50 @@
//go:build windows
// +build windows
package windowskext
import "golang.org/x/sys/windows"
const (
METHOD_BUFFERED = 0
METHOD_IN_DIRECT = 1
METHOD_OUT_DIRECT = 2
METHOD_NEITHER = 3
SIOCTL_TYPE = 40000
)
var (
IOCTL_HELLO = ctl_code(SIOCTL_TYPE, 0x800, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ_POLL = ctl_code(SIOCTL_TYPE, 0x801, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_RECV_VERDICT_REQ = ctl_code(SIOCTL_TYPE, 0x802, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_SET_VERDICT = ctl_code(SIOCTL_TYPE, 0x803, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_GET_PAYLOAD = ctl_code(SIOCTL_TYPE, 0x804, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_CLEAR_CACHE = ctl_code(SIOCTL_TYPE, 0x805, METHOD_BUFFERED, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
IOCTL_TEST = ctl_code(SIOCTL_TYPE, 0x806, METHOD_NEITHER, windows.FILE_READ_DATA|windows.FILE_WRITE_DATA)
)
func ctl_code(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func deviceIoControl(code uint32, data *byte, size uintptr) (uint32, error) {
var bytesReturned uint32
err := windows.DeviceIoControl(kextHandle,
code,
nil, 0,
data, uint32(size),
&bytesReturned, nil)
return bytesReturned, err
}
func deviceIoControlBufferd(code uint32, inData *byte, inSize uintptr, outData *byte, outSize uintptr) (uint32, error) {
var bytesReturned uint32
err := windows.DeviceIoControl(kextHandle,
code,
inData, uint32(inSize),
outData, uint32(outSize),
&bytesReturned, nil)
return bytesReturned, err
}