[windows_kext] fix all linter error

This commit is contained in:
Vladimir Stoilov 2024-05-16 16:21:27 +03:00
parent 5610c88208
commit 1d6228ea7b
No known key found for this signature in database
GPG key ID: 2F190B67A43A81AF
18 changed files with 137 additions and 131 deletions

View file

@ -44,7 +44,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
conn := packetInfo.ConnectionV4 conn := packetInfo.ConnectionV4
// New Packet // New Packet
newPacket := &Packet{ newPacket := &Packet{
verdictRequest: conn.Id, verdictRequest: conn.ID,
payload: conn.Payload, payload: conn.Payload,
verdictSet: abool.NewBool(false), verdictSet: abool.NewBool(false),
} }
@ -52,7 +52,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
info.Inbound = conn.Direction > 0 info.Inbound = conn.Direction > 0
info.InTunnel = false info.InTunnel = false
info.Protocol = packet.IPProtocol(conn.Protocol) info.Protocol = packet.IPProtocol(conn.Protocol)
info.PID = int(conn.ProcessId) info.PID = int(conn.ProcessID)
info.SeenAt = time.Now() info.SeenAt = time.Now()
// Check PID // Check PID
@ -68,12 +68,12 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
// Set IPs // Set IPs
if info.Inbound { if info.Inbound {
// Inbound // Inbound
info.Src = conn.RemoteIp[:] info.Src = conn.RemoteIP[:]
info.Dst = conn.LocalIp[:] info.Dst = conn.LocalIP[:]
} else { } else {
// Outbound // Outbound
info.Src = conn.LocalIp[:] info.Src = conn.LocalIP[:]
info.Dst = conn.RemoteIp[:] info.Dst = conn.RemoteIP[:]
} }
// Set Ports // Set Ports
@ -95,7 +95,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
conn := packetInfo.ConnectionV6 conn := packetInfo.ConnectionV6
// New Packet // New Packet
newPacket := &Packet{ newPacket := &Packet{
verdictRequest: conn.Id, verdictRequest: conn.ID,
payload: conn.Payload, payload: conn.Payload,
verdictSet: abool.NewBool(false), verdictSet: abool.NewBool(false),
} }
@ -103,7 +103,7 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
info.Inbound = conn.Direction > 0 info.Inbound = conn.Direction > 0
info.InTunnel = false info.InTunnel = false
info.Protocol = packet.IPProtocol(conn.Protocol) info.Protocol = packet.IPProtocol(conn.Protocol)
info.PID = int(conn.ProcessId) info.PID = int(conn.ProcessID)
info.SeenAt = time.Now() info.SeenAt = time.Now()
// Check PID // Check PID
@ -119,12 +119,12 @@ func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate ch
// Set IPs // Set IPs
if info.Inbound { if info.Inbound {
// Inbound // Inbound
info.Src = conn.RemoteIp[:] info.Src = conn.RemoteIP[:]
info.Dst = conn.LocalIp[:] info.Dst = conn.LocalIP[:]
} else { } else {
// Outbound // Outbound
info.Src = conn.LocalIp[:] info.Src = conn.LocalIP[:]
info.Dst = conn.RemoteIp[:] info.Dst = conn.RemoteIP[:]
} }
// Set Ports // Set Ports

View file

@ -8,7 +8,7 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/windows_kext/kext_interface" "github.com/safing/portmaster/windows_kext/kextinterface"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@ -16,8 +16,8 @@ import (
var ( var (
driverPath string driverPath string
service *kext_interface.KextService service *kextinterface.KextService
kextFile *kext_interface.KextFile kextFile *kextinterface.KextFile
) )
const ( const (
@ -31,10 +31,9 @@ func Init(path string) error {
// Start intercepting. // Start intercepting.
func Start() error { func Start() error {
// initialize and start driver service // initialize and start driver service
var err error var err error
service, err = kext_interface.CreateKextService(driverName, driverPath) service, err = kextinterface.CreateKextService(driverName, driverPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to create service: %w", err) return fmt.Errorf("failed to create service: %w", err)
} }
@ -86,46 +85,46 @@ func Stop() error {
// Sends a shutdown request. // Sends a shutdown request.
func shutdownRequest() error { func shutdownRequest() error {
return kext_interface.SendShutdownCommand(kextFile) return kextinterface.SendShutdownCommand(kextFile)
} }
// Send request for logs of the kext. // Send request for logs of the kext.
func SendLogRequest() error { func SendLogRequest() error {
return kext_interface.SendGetLogsCommand(kextFile) return kextinterface.SendGetLogsCommand(kextFile)
} }
func SendBandwidthStatsRequest() error { func SendBandwidthStatsRequest() error {
return kext_interface.SendGetBandwidthStatsCommand(kextFile) return kextinterface.SendGetBandwidthStatsCommand(kextFile)
} }
func SendPrintMemoryStatsCommand() error { func SendPrintMemoryStatsCommand() error {
return kext_interface.SendPrintMemoryStatsCommand(kextFile) return kextinterface.SendPrintMemoryStatsCommand(kextFile)
} }
func SendCleanEndedConnection() error { func SendCleanEndedConnection() error {
return kext_interface.SendCleanEndedConnectionsCommand(kextFile) return kextinterface.SendCleanEndedConnectionsCommand(kextFile)
} }
// RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil. // RecvVerdictRequest waits for the next verdict request from the kext. If a timeout is reached, both *VerdictRequest and error will be nil.
func RecvVerdictRequest() (*kext_interface.Info, error) { func RecvVerdictRequest() (*kextinterface.Info, error) {
return kext_interface.RecvInfo(kextFile) return kextinterface.RecvInfo(kextFile)
} }
// SetVerdict sets the verdict for a packet and/or connection. // SetVerdict sets the verdict for a packet and/or connection.
func SetVerdict(pkt *Packet, verdict kext_interface.KextVerdict) error { func SetVerdict(pkt *Packet, verdict kextinterface.KextVerdict) error {
verdictCommand := kext_interface.Verdict{Id: pkt.verdictRequest, Verdict: uint8(verdict)} verdictCommand := kextinterface.Verdict{ID: pkt.verdictRequest, Verdict: uint8(verdict)}
return kext_interface.SendVerdictCommand(kextFile, verdictCommand) return kextinterface.SendVerdictCommand(kextFile, verdictCommand)
} }
// Clears the internal connection cache. // Clears the internal connection cache.
func ClearCache() error { func ClearCache() error {
return kext_interface.SendClearCacheCommand(kextFile) return kextinterface.SendClearCacheCommand(kextFile)
} }
// Updates a specific connection verdict. // Updates a specific connection verdict.
func UpdateVerdict(conn *network.Connection) error { func UpdateVerdict(conn *network.Connection) error {
if conn.IPVersion == 4 { if conn.IPVersion == 4 {
update := kext_interface.UpdateV4{ update := kextinterface.UpdateV4{
Protocol: conn.Entity.Protocol, Protocol: conn.Entity.Protocol,
LocalAddress: [4]byte(conn.LocalIP), LocalAddress: [4]byte(conn.LocalIP),
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
@ -134,9 +133,9 @@ func UpdateVerdict(conn *network.Connection) error {
Verdict: uint8(conn.Verdict), Verdict: uint8(conn.Verdict),
} }
return kext_interface.SendUpdateV4Command(kextFile, update) return kextinterface.SendUpdateV4Command(kextFile, update)
} else if conn.IPVersion == 6 { } else if conn.IPVersion == 6 {
update := kext_interface.UpdateV6{ update := kextinterface.UpdateV6{
Protocol: conn.Entity.Protocol, Protocol: conn.Entity.Protocol,
LocalAddress: [16]byte(conn.LocalIP), LocalAddress: [16]byte(conn.LocalIP),
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
@ -145,14 +144,14 @@ func UpdateVerdict(conn *network.Connection) error {
Verdict: uint8(conn.Verdict), Verdict: uint8(conn.Verdict),
} }
return kext_interface.SendUpdateV6Command(kextFile, update) return kextinterface.SendUpdateV6Command(kextFile, update)
} }
return nil return nil
} }
// Returns the kext version. // Returns the kext version.
func GetVersion() (*VersionInfo, error) { func GetVersion() (*VersionInfo, error) {
data, err := kext_interface.ReadVersion(kextFile) data, err := kextinterface.ReadVersion(kextFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -10,7 +10,7 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/windows_kext/kext_interface" "github.com/safing/portmaster/windows_kext/kextinterface"
) )
// Packet represents an IP packet. // Packet represents an IP packet.
@ -70,7 +70,7 @@ func (pkt *Packet) LoadPacketData() error {
// Accept accepts the packet. // Accept accepts the packet.
func (pkt *Packet) Accept() error { func (pkt *Packet) Accept() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictAccept) return SetVerdict(pkt, kextinterface.VerdictAccept)
} }
return nil return nil
} }
@ -78,7 +78,7 @@ func (pkt *Packet) Accept() error {
// Block blocks the packet. // Block blocks the packet.
func (pkt *Packet) Block() error { func (pkt *Packet) Block() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictBlock) return SetVerdict(pkt, kextinterface.VerdictBlock)
} }
return nil return nil
} }
@ -86,7 +86,7 @@ func (pkt *Packet) Block() error {
// Drop drops the packet. // Drop drops the packet.
func (pkt *Packet) Drop() error { func (pkt *Packet) Drop() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictDrop) return SetVerdict(pkt, kextinterface.VerdictDrop)
} }
return nil return nil
} }
@ -94,7 +94,7 @@ func (pkt *Packet) Drop() error {
// PermanentAccept permanently accepts connection (and the current packet). // PermanentAccept permanently accepts connection (and the current packet).
func (pkt *Packet) PermanentAccept() error { func (pkt *Packet) PermanentAccept() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictPermanentAccept) return SetVerdict(pkt, kextinterface.VerdictPermanentAccept)
} }
return nil return nil
} }
@ -102,7 +102,7 @@ func (pkt *Packet) PermanentAccept() error {
// PermanentBlock permanently blocks connection (and the current packet). // PermanentBlock permanently blocks connection (and the current packet).
func (pkt *Packet) PermanentBlock() error { func (pkt *Packet) PermanentBlock() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictPermanentBlock) return SetVerdict(pkt, kextinterface.VerdictPermanentBlock)
} }
return nil return nil
} }
@ -110,7 +110,7 @@ func (pkt *Packet) PermanentBlock() error {
// PermanentDrop permanently drops connection (and the current packet). // PermanentDrop permanently drops connection (and the current packet).
func (pkt *Packet) PermanentDrop() error { func (pkt *Packet) PermanentDrop() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictPermanentDrop) return SetVerdict(pkt, kextinterface.VerdictPermanentDrop)
} }
return nil return nil
} }
@ -118,7 +118,7 @@ func (pkt *Packet) PermanentDrop() error {
// RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet). // RerouteToNameserver permanently reroutes the connection to the local nameserver (and the current packet).
func (pkt *Packet) RerouteToNameserver() error { func (pkt *Packet) RerouteToNameserver() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictRerouteToNameserver) return SetVerdict(pkt, kextinterface.VerdictRerouteToNameserver)
} }
return nil return nil
} }
@ -126,7 +126,7 @@ func (pkt *Packet) RerouteToNameserver() error {
// RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet). // RerouteToTunnel permanently reroutes the connection to the local tunnel entrypoint (and the current packet).
func (pkt *Packet) RerouteToTunnel() error { func (pkt *Packet) RerouteToTunnel() error {
if pkt.verdictSet.SetToIf(false, true) { if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kext_interface.VerdictRerouteToTunnel) return SetVerdict(pkt, kextinterface.VerdictRerouteToTunnel)
} }
return nil return nil
} }

View file

@ -3,8 +3,8 @@
package windowskext package windowskext
import "github.com/safing/portmaster/windows_kext/kext_interface" import "github.com/safing/portmaster/windows_kext/kextinterface"
func createKextService(driverName string, driverPath string) (*kext_interface.KextService, error) { func createKextService(driverName string, driverPath string) (*kextinterface.KextService, error) {
return kext_interface.CreateKextService(driverName, driverPath) return kextinterface.CreateKextService(driverName, driverPath)
} }

View file

@ -7,7 +7,7 @@ use wdk::{err, info, interface};
use windows_sys::Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP}; use windows_sys::Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP};
use windows_sys::Win32::Foundation::{NTSTATUS, STATUS_SUCCESS}; use windows_sys::Win32::Foundation::{NTSTATUS, STATUS_SUCCESS};
static VERSION: [u8; 4] = include!("../../kext_interface/version.txt"); static VERSION: [u8; 4] = include!("../../kextinterface/version.txt");
static mut DEVICE: *mut device::Device = core::ptr::null_mut(); static mut DEVICE: *mut device::Device = core::ptr::null_mut();
pub fn get_device() -> Option<&'static mut device::Device> { pub fn get_device() -> Option<&'static mut device::Device> {

View file

@ -1,4 +1,4 @@
package kext_interface package kextinterface
import ( import (
"encoding/binary" "encoding/binary"
@ -44,20 +44,6 @@ type Verdict struct {
Verdict uint8 Verdict uint8
} }
type RedirectV4 struct {
command uint8
ID uint64
RemoteAddress [4]byte
RemotePort uint16
}
type RedirectV6 struct {
command uint8
ID uint64
RemoteAddress [16]byte
RemotePort uint16
}
type UpdateV4 struct { type UpdateV4 struct {
command uint8 command uint8
Protocol uint8 Protocol uint8

View file

@ -1,4 +1,4 @@
package kext_interface package kextinterface
import ( import (
"encoding/binary" "encoding/binary"
@ -16,11 +16,14 @@ const (
InfoBandwidthStatsV6 = 6 InfoBandwidthStatsV6 = 6
) )
var ErrorUnknownInfoType = errors.New("unknown info type") var (
ErrUnknownInfoType = errors.New("unknown info type")
ErrUnexpectedReadError = errors.New("unexpected read error")
)
type connectionV4Internal struct { type connectionV4Internal struct {
Id uint64 ID uint64
ProcessId uint64 ProcessID uint64
Direction byte Direction byte
Protocol byte Protocol byte
LocalIP [4]byte LocalIP [4]byte
@ -36,8 +39,8 @@ type ConnectionV4 struct {
} }
func (c *ConnectionV4) Compare(other *ConnectionV4) bool { func (c *ConnectionV4) Compare(other *ConnectionV4) bool {
return c.Id == other.Id && return c.ID == other.ID &&
c.ProcessId == other.ProcessId && c.ProcessID == other.ProcessID &&
c.Direction == other.Direction && c.Direction == other.Direction &&
c.Protocol == other.Protocol && c.Protocol == other.Protocol &&
c.LocalIP == other.LocalIP && c.LocalIP == other.LocalIP &&
@ -47,7 +50,7 @@ func (c *ConnectionV4) Compare(other *ConnectionV4) bool {
} }
type connectionV6Internal struct { type connectionV6Internal struct {
Id uint64 ID uint64
ProcessID uint64 ProcessID uint64
Direction byte Direction byte
Protocol byte Protocol byte
@ -64,7 +67,7 @@ type ConnectionV6 struct {
} }
func (c ConnectionV6) Compare(other *ConnectionV6) bool { func (c ConnectionV6) Compare(other *ConnectionV6) bool {
return c.Id == other.Id && return c.ID == other.ID &&
c.ProcessID == other.ProcessID && c.ProcessID == other.ProcessID &&
c.Direction == other.Direction && c.Direction == other.Direction &&
c.Protocol == other.Protocol && c.Protocol == other.Protocol &&
@ -75,21 +78,21 @@ func (c ConnectionV6) Compare(other *ConnectionV6) bool {
} }
type ConnectionEndV4 struct { type ConnectionEndV4 struct {
ProcessId uint64 ProcessID uint64
Direction byte Direction byte
Protocol byte Protocol byte
LocalIp [4]byte LocalIP [4]byte
RemoteIp [4]byte RemoteIP [4]byte
LocalPort uint16 LocalPort uint16
RemotePort uint16 RemotePort uint16
} }
type ConnectionEndV6 struct { type ConnectionEndV6 struct {
ProcessId uint64 ProcessID uint64
Direction byte Direction byte
Protocol byte Protocol byte
LocalIp [16]byte LocalIP [16]byte
RemoteIp [16]byte RemoteIP [16]byte
LocalPort uint16 LocalPort uint16
RemotePort uint16 RemotePort uint16
} }
@ -142,6 +145,9 @@ func RecvInfo(reader io.Reader) (*Info, error) {
// Read size of data // Read size of data
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil {
return nil, err
}
// Read data // Read data
switch infoType { switch infoType {
@ -150,16 +156,19 @@ func RecvInfo(reader io.Reader) (*Info, error) {
var fixedSizeValues connectionV4Internal var fixedSizeValues connectionV4Internal
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read size of payload // Read size of payload
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)} newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload)
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
return &Info{ConnectionV4: &newInfo}, nil return &Info{ConnectionV4: &newInfo}, nil
} }
case InfoConnectionIpv6: case InfoConnectionIpv6:
@ -167,47 +176,53 @@ func RecvInfo(reader io.Reader) (*Info, error) {
var fixedSizeValues connectionV6Internal var fixedSizeValues connectionV6Internal
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read size of payload // Read size of payload
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)} newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload)
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
return &Info{ConnectionV6: &newInfo}, nil return &Info{ConnectionV6: &newInfo}, nil
} }
case InfoConnectionEndEventV4: case InfoConnectionEndEventV4:
{ {
var new ConnectionEndV4 var connectionEnd ConnectionEndV4
err = binary.Read(reader, binary.LittleEndian, &new) err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
return &Info{ConnectionEndV4: &new}, nil return &Info{ConnectionEndV4: &connectionEnd}, nil
} }
case InfoConnectionEndEventV6: case InfoConnectionEndEventV6:
{ {
var new ConnectionEndV6 var connectionEnd ConnectionEndV6
err = binary.Read(reader, binary.LittleEndian, &new) err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
return &Info{ConnectionEndV6: &new}, nil return &Info{ConnectionEndV6: &connectionEnd}, nil
} }
case InfoLogLine: case InfoLogLine:
{ {
var logLine = LogLine{} logLine := LogLine{}
// Read severity // Read severity
err = binary.Read(reader, binary.LittleEndian, &logLine.Severity) err = binary.Read(reader, binary.LittleEndian, &logLine.Severity)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read string // Read string
var line = make([]byte, size-1) // -1 for the severity enum. line := make([]byte, size-1) // -1 for the severity enum.
err = binary.Read(reader, binary.LittleEndian, &line) err = binary.Read(reader, binary.LittleEndian, &line)
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
logLine.Line = string(line) logLine.Line = string(line)
return &Info{LogLine: &logLine}, nil return &Info{LogLine: &logLine}, nil
} }
@ -217,21 +232,24 @@ func RecvInfo(reader io.Reader) (*Info, error) {
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = binary.Read(reader, binary.LittleEndian, &protocol)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read array // Read array
var stats_array = make([]BandwidthValueV4, size) statsArray := make([]BandwidthValueV4, size)
for i := 0; i < int(size); i++ { for i := 0; i < int(size); i++ {
binary.Read(reader, binary.LittleEndian, &stats_array[i]) err = binary.Read(reader, binary.LittleEndian, &statsArray[i])
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
} }
return &Info{BandwidthStats: &BandwidthStatsArray{Protocol: protocol, ValuesV4: stats_array}}, nil return &Info{BandwidthStats: &BandwidthStatsArray{Protocol: protocol, ValuesV4: statsArray}}, nil
} }
case InfoBandwidthStatsV6: case InfoBandwidthStatsV6:
{ {
@ -239,25 +257,31 @@ func RecvInfo(reader io.Reader) (*Info, error) {
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = binary.Read(reader, binary.LittleEndian, &protocol)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, errors.Join(ErrUnexpectedReadError, err)
} }
// Read array // Read array
var stats_array = make([]BandwidthValueV6, size) statsArray := make([]BandwidthValueV6, size)
for i := 0; i < int(size); i++ { for i := 0; i < int(size); i++ {
binary.Read(reader, binary.LittleEndian, &stats_array[i]) err = binary.Read(reader, binary.LittleEndian, &statsArray[i])
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
} }
return &Info{BandwidthStats: &BandwidthStatsArray{Protocol: protocol, ValuesV6: stats_array}}, nil return &Info{BandwidthStats: &BandwidthStatsArray{Protocol: protocol, ValuesV6: statsArray}}, nil
} }
} }
// Command not recognized, read until the end of command and return.
// During normal operation this should not happen.
unknownData := make([]byte, size) unknownData := make([]byte, size)
reader.Read(unknownData) _, _ = reader.Read(unknownData)
return nil, ErrorUnknownInfoType
return nil, ErrUnknownInfoType
} }

View file

@ -1,7 +1,7 @@
//go:build windows //go:build windows
// +build windows // +build windows
package kext_interface package kextinterface
import ( import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -28,7 +28,6 @@ var (
func ReadVersion(file *KextFile) ([]uint8, error) { func ReadVersion(file *KextFile) ([]uint8, error) {
data := make([]uint8, 4) data := make([]uint8, 4)
_, err := file.deviceIOControl(IOCTL_VERSION, nil, data) _, err := file.deviceIOControl(IOCTL_VERSION, nil, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,7 +1,7 @@
//go:build windows //go:build windows
// +build windows // +build windows
package kext_interface package kextinterface
import ( import (
_ "embed" _ "embed"
@ -36,8 +36,10 @@ var (
}() }()
) )
const winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value const (
const stopServiceTimeoutDuration = time.Duration(30 * time.Second) winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
stopServiceTimeoutDuration = time.Duration(30 * time.Second)
)
type KextService struct { type KextService struct {
handle windows.Handle handle windows.Handle
@ -88,7 +90,6 @@ func (s *KextService) Start(wait bool) error {
// Start the service: // Start the service:
err := windows.StartService(s.handle, 0, nil) err := windows.StartService(s.handle, 0, nil)
if err != nil { if err != nil {
err = windows.GetLastError() err = windows.GetLastError()
if err != windows.ERROR_SERVICE_ALREADY_RUNNING { if err != windows.ERROR_SERVICE_ALREADY_RUNNING {

View file

@ -1,7 +1,7 @@
//go:build windows //go:build windows
// +build windows // +build windows
package kext_interface package kextinterface
import ( import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -85,7 +85,6 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
inDataPtr, inDataSize, inDataPtr, inDataSize,
outDataPtr, outDataSize, outDataPtr, outDataSize,
nil, overlapped) nil, overlapped)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,7 +1,7 @@
//go:build linux //go:build linux
// +build linux // +build linux
package kext_interface package kextinterface
type KextFile struct{} type KextFile struct{}
@ -9,4 +9,4 @@ func (f *KextFile) Read(buffer []byte) (int, error) {
return 0, nil return 0, nil
} }
func (f *KextFile) flushBuffer() {} // func (f *KextFile) flushBuffer() {}

View file

@ -1,9 +1,8 @@
package kext_interface package kextinterface
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"math/rand" "math/rand"
"os" "os"
"testing" "testing"
@ -22,7 +21,7 @@ func TestRustInfoFile(t *testing.T) {
for { for {
info, err := RecvInfo(file) info, err := RecvInfo(file)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, ErrUnexpectedReadError) {
t.Errorf("unexpected error: %s\n", err) t.Errorf("unexpected error: %s\n", err)
} }
return return
@ -40,8 +39,8 @@ func TestRustInfoFile(t *testing.T) {
case info.ConnectionV4 != nil: case info.ConnectionV4 != nil:
conn := info.ConnectionV4 conn := info.ConnectionV4
expected := connectionV4Internal{ expected := connectionV4Internal{
Id: 1, ID: 1,
ProcessId: 2, ProcessID: 2,
Direction: 3, Direction: 3,
Protocol: 4, Protocol: 4,
LocalIP: [4]byte{1, 2, 3, 4}, LocalIP: [4]byte{1, 2, 3, 4},
@ -60,7 +59,7 @@ func TestRustInfoFile(t *testing.T) {
case info.ConnectionV6 != nil: case info.ConnectionV6 != nil:
conn := info.ConnectionV6 conn := info.ConnectionV6
expected := connectionV6Internal{ expected := connectionV6Internal{
Id: 1, ID: 1,
ProcessID: 2, ProcessID: 2,
Direction: 3, Direction: 3,
Protocol: 4, Protocol: 4,
@ -80,11 +79,11 @@ func TestRustInfoFile(t *testing.T) {
case info.ConnectionEndV4 != nil: case info.ConnectionEndV4 != nil:
endEvent := info.ConnectionEndV4 endEvent := info.ConnectionEndV4
expected := ConnectionEndV4{ expected := ConnectionEndV4{
ProcessId: 1, ProcessID: 1,
Direction: 2, Direction: 2,
Protocol: 3, Protocol: 3,
LocalIp: [4]byte{1, 2, 3, 4}, LocalIP: [4]byte{1, 2, 3, 4},
RemoteIp: [4]byte{2, 3, 4, 5}, RemoteIP: [4]byte{2, 3, 4, 5},
LocalPort: 4, LocalPort: 4,
RemotePort: 5, RemotePort: 5,
} }
@ -95,11 +94,11 @@ func TestRustInfoFile(t *testing.T) {
case info.ConnectionEndV6 != nil: case info.ConnectionEndV6 != nil:
endEvent := info.ConnectionEndV6 endEvent := info.ConnectionEndV6
expected := ConnectionEndV6{ expected := ConnectionEndV6{
ProcessId: 1, ProcessID: 1,
Direction: 2, Direction: 2,
Protocol: 3, Protocol: 3,
LocalIp: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, LocalIP: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
RemoteIp: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, RemoteIP: [16]byte{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17},
LocalPort: 4, LocalPort: 4,
RemotePort: 5, RemotePort: 5,
} }
@ -273,5 +272,4 @@ func TestGenerateCommandFile(t *testing.T) {
} }
} }
} }
} }

View file

@ -1,4 +1,4 @@
# Protocol # Protocol
Defines protocol that communicates with `kext_interface` / Portmaster. Defines protocol that communicates with `kextinterface` / Portmaster.

View file

@ -86,7 +86,7 @@ use std::panic;
#[test] #[test]
fn test_go_command_file() { fn test_go_command_file() {
let mut file = File::open("../kext_interface/go_command_test.bin").unwrap(); let mut file = File::open("../kextinterface/go_command_test.bin").unwrap();
loop { loop {
let mut command: [u8; 1] = [0]; let mut command: [u8; 1] = [0];
let bytes_count = file.read(&mut command).unwrap(); let bytes_count = file.read(&mut command).unwrap();

View file

@ -1,7 +1,7 @@
# Kext release tool # Kext release tool
### Generate the zip file ### Generate the zip file
- Make sure `kext_interface/version.txt` is up to date - Make sure `kextinterface/version.txt` is up to date
- Execute: `cargo run` - Execute: `cargo run`
* This will generate release `kext_release_vX-X-X.zip` file. Which contains all the necessary files to make the release. * This will generate release `kext_release_vX-X-X.zip` file. Which contains all the necessary files to make the release.

View file

@ -5,7 +5,7 @@ use handlebars::Handlebars;
use serde_json::json; use serde_json::json;
use zip::{write::FileOptions, ZipWriter}; use zip::{write::FileOptions, ZipWriter};
static VERSION: [u8; 4] = include!("../../kext_interface/version.txt"); static VERSION: [u8; 4] = include!("../../kextinterface/version.txt");
static LIB_PATH: &'static str = "./build/x86_64-pc-windows-msvc/release/driver.lib"; static LIB_PATH: &'static str = "./build/x86_64-pc-windows-msvc/release/driver.lib";
fn main() { fn main() {

View file

@ -5,7 +5,7 @@ echo ========================
cd protocol cd protocol
cargo test info::generate_test_info_file cargo test info::generate_test_info_file
cd ../kext_interface cd ../kextinterface
go test -v -run TestGenerateCommandFile go test -v -run TestGenerateCommandFile
cd .. cd ..
@ -15,7 +15,7 @@ echo ========================
cd protocol cd protocol
cargo test command::test_go_command_file cargo test command::test_go_command_file
cd ../kext_interface cd ../kextinterface
go test -v -run TestRustInfoFile go test -v -run TestRustInfoFile
echo ======================== echo ========================