diff --git a/service/firewall/interception/windowskext2/handler.go b/service/firewall/interception/windowskext2/handler.go index bb6348dd..166ebbc2 100644 --- a/service/firewall/interception/windowskext2/handler.go +++ b/service/firewall/interception/windowskext2/handler.go @@ -5,11 +5,13 @@ package windowskext import ( "context" + "errors" "fmt" "net" "time" "github.com/safing/portmaster/service/process" + "github.com/safing/portmaster/windows_kext/kextinterface" "github.com/tevino/abool" @@ -32,8 +34,15 @@ func (v *VersionInfo) String() string { func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) { for { packetInfo, err := RecvVerdictRequest() + + if errors.Is(err, kextinterface.ErrUnexpectedInfoSize) || errors.Is(err, kextinterface.ErrUnexpectedReadError) { + log.Criticalf("unexpected kext info data: %s", err) + continue // Depending on the info type this may not affect the functionality. Try to continue reading the next commands. + } + if err != nil { log.Warningf("failed to get packet from windows kext: %s", err) + // Probably IO error, nothing else we can do. return } diff --git a/service/firewall/interception/windowskext2/kext.go b/service/firewall/interception/windowskext2/kext.go index fd6adb72..fa5a8f0d 100644 --- a/service/firewall/interception/windowskext2/kext.go +++ b/service/firewall/interception/windowskext2/kext.go @@ -39,9 +39,12 @@ func Start() error { } // Start service and open file - service.Start(true) - kextFile, err = service.OpenFile(1024) + err = service.Start(true) + if err != nil { + log.Errorf("failed to start service: %s", err) + } + kextFile, err = service.OpenFile(1024) if err != nil { return fmt.Errorf("failed to open driver: %w", err) } @@ -130,7 +133,7 @@ func UpdateVerdict(conn *network.Connection) error { LocalPort: conn.LocalPort, RemoteAddress: [4]byte(conn.Entity.IP), RemotePort: conn.Entity.Port, - Verdict: uint8(conn.Verdict), + Verdict: uint8(getKextVerdictFromConnection(conn)), } return kextinterface.SendUpdateV4Command(kextFile, update) @@ -141,7 +144,7 @@ func UpdateVerdict(conn *network.Connection) error { LocalPort: conn.LocalPort, RemoteAddress: [16]byte(conn.Entity.IP), RemotePort: conn.Entity.Port, - Verdict: uint8(conn.Verdict), + Verdict: uint8(getKextVerdictFromConnection(conn)), } return kextinterface.SendUpdateV6Command(kextFile, update) @@ -149,6 +152,40 @@ func UpdateVerdict(conn *network.Connection) error { return nil } +func getKextVerdictFromConnection(conn *network.Connection) kextinterface.KextVerdict { + switch conn.Verdict { + case network.VerdictUndecided: + return kextinterface.VerdictUndecided + case network.VerdictUndeterminable: + return kextinterface.VerdictUndeterminable + case network.VerdictAccept: + if conn.VerdictPermanent { + return kextinterface.VerdictPermanentAccept + } else { + return kextinterface.VerdictAccept + } + case network.VerdictBlock: + if conn.VerdictPermanent { + return kextinterface.VerdictPermanentBlock + } else { + return kextinterface.VerdictBlock + } + case network.VerdictDrop: + if conn.VerdictPermanent { + return kextinterface.VerdictPermanentDrop + } else { + return kextinterface.VerdictDrop + } + case network.VerdictRerouteToNameserver: + return kextinterface.VerdictRerouteToNameserver + case network.VerdictRerouteToTunnel: + return kextinterface.VerdictRerouteToTunnel + case network.VerdictFailed: + return kextinterface.VerdictFailed + } + return kextinterface.VerdictUndeterminable +} + // Returns the kext version. func GetVersion() (*VersionInfo, error) { data, err := kextinterface.ReadVersion(kextFile) diff --git a/windows_kext/driver/Cargo.lock b/windows_kext/driver/Cargo.lock index 0374a4c7..b8746745 100644 --- a/windows_kext/driver/Cargo.lock +++ b/windows_kext/driver/Cargo.lock @@ -340,19 +340,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" [[package]] name = "windows-sys" version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" dependencies = [ "windows-targets", ] [[package]] name = "windows-targets" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", "windows_i686_gnu", + "windows_i686_gnullvm", "windows_i686_msvc", "windows_x86_64_gnu", "windows_x86_64_gnullvm", @@ -361,38 +362,43 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_i686_gnu" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_i686_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "zerocopy" diff --git a/windows_kext/driver/Cargo.toml b/windows_kext/driver/Cargo.toml index 09ca639d..66dffaca 100644 --- a/windows_kext/driver/Cargo.toml +++ b/windows_kext/driver/Cargo.toml @@ -22,5 +22,5 @@ hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"] # WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels. [dependencies.windows-sys] git = "https://github.com/microsoft/windows-rs" -rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317" features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform"] diff --git a/windows_kext/driver/src/callouts.rs b/windows_kext/driver/src/callouts.rs index 0999a707..d49c7f07 100644 --- a/windows_kext/driver/src/callouts.rs +++ b/windows_kext/driver/src/callouts.rs @@ -12,8 +12,8 @@ pub fn get_callout_vec() -> Vec { // ----------------------------------------- // ALE Auth layers Callout::new( - "AleLayerOutboundV4", - "ALE layer for outbound connection for ipv4", + "Portmaster ALE Outbound IPv4", + "Portmaster uses this layer to block/permit outgoing ipv4 connections", 0x58545073_f893_454c_bbea_a57bc964f46d, Layer::AleAuthConnectV4, consts::FWP_ACTION_CALLOUT_TERMINATING, @@ -21,8 +21,8 @@ pub fn get_callout_vec() -> Vec { ale_callouts::ale_layer_connect_v4, ), Callout::new( - "AleLayerOutboundV6", - "ALE layer for outbound connections for ipv6", + "Portmaster ALE Outbound IPv6", + "Portmaster uses this layer to block/permit outgoing ipv6 connections", 0x4bd2a080_2585_478d_977c_7f340c6bc3d4, Layer::AleAuthConnectV6, consts::FWP_ACTION_CALLOUT_TERMINATING, @@ -32,8 +32,8 @@ pub fn get_callout_vec() -> Vec { // ----------------------------------------- // ALE connection end layers Callout::new( - "AleEndpointClosureV4", - "ALE layer for indicating closing of connection for ipv4", + "Portmaster Endpoint Closure IPv4", + "Portmaster uses this layer to detect when a IPv4 connection has ended", 0x58f02845_ace9_4455_ac80_8a84b86fe566, Layer::AleEndpointClosureV4, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -41,8 +41,8 @@ pub fn get_callout_vec() -> Vec { ale_callouts::endpoint_closure_v4, ), Callout::new( - "AleEndpointClosureV6", - "ALE layer for indicating closing of connection for ipv6", + "Portmaster Endpoint Closure IPv6", + "Portmaster uses this layer to detect when a IPv6 connection has ended", 0x2bc82359_9dc5_4315_9c93_c89467e283ce, Layer::AleEndpointClosureV6, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -61,8 +61,8 @@ pub fn get_callout_vec() -> Vec { // ale_callouts::ale_resource_monitor, // ), Callout::new( - "AleResourceReleaseV4", - "Ipv4 Port release monitor", + "Portmaster resource release IPv4", + "Portmaster uses this layer to detect when a IPv4 port has been released", 0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7, Layer::AleResourceReleaseV4, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -79,8 +79,8 @@ pub fn get_callout_vec() -> Vec { // ale_callouts::ale_resource_monitor, // ), Callout::new( - "AleResourceReleaseV6", - "Ipv6 Port release monitor", + "Portmaster resource release IPv6", + "Portmaster uses this layer to detect when a IPv6 port has been released", 0x6cf36e04_e656_42c3_8cac_a1ce05328bd1, Layer::AleResourceReleaseV6, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -90,8 +90,8 @@ pub fn get_callout_vec() -> Vec { // ----------------------------------------- // Stream layer Callout::new( - "StreamLayerV4", - "Stream layer for ipv4", + "Portmaster Stream IPv4", + "Portmaster uses this layer for bandwidth statistics of IPv4 TCP connections", 0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780, Layer::StreamV4, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -99,8 +99,8 @@ pub fn get_callout_vec() -> Vec { stream_callouts::stream_layer_tcp_v4, ), Callout::new( - "StreamLayerV6", - "Stream layer for ipv6", + "Portmaster Stream IPv6", + "Portmaster uses this layer for bandwidth statistics of IPv6 TCP connections", 0x66c549b3_11e2_4b27_8f73_856e6fd82baa, Layer::StreamV6, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -108,8 +108,8 @@ pub fn get_callout_vec() -> Vec { stream_callouts::stream_layer_tcp_v6, ), Callout::new( - "DatagramDataLayerV4", - "DatagramData layer for ipv4", + "Portmaster Datagram IPv4", + "Portmaster uses this layer for bandwidth statistics of IPv4 UDP connections", 0xe7eeeaba_168a_45bb_8747_e1a702feb2c5, Layer::DatagramDataV4, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -117,8 +117,8 @@ pub fn get_callout_vec() -> Vec { stream_callouts::stream_layer_udp_v4, ), Callout::new( - "DatagramDataLayerV6", - "DatagramData layer for ipv4", + "Portmaster Datagram IPv6", + "Portmaster uses this layer for bandwidth statistics of IPv6 UDP connections", 0xb25862cd_f744_4452_b14a_d0c1e5a25b30, Layer::DatagramDataV6, consts::FWP_ACTION_CALLOUT_INSPECTION, @@ -128,8 +128,8 @@ pub fn get_callout_vec() -> Vec { // ----------------------------------------- // Packet layers Callout::new( - "IPPacketOutboundV4", - "IP packet outbound network layer callout for Ipv4", + "Portmaster Packet Outbound IPv4", + "Portmaster uses this layer to redirect/block/permit outgoing ipv4 packets", 0xf3183afe_dc35_49f1_8ea2_b16b5666dd36, Layer::OutboundIppacketV4, consts::FWP_ACTION_CALLOUT_TERMINATING, @@ -137,8 +137,8 @@ pub fn get_callout_vec() -> Vec { packet_callouts::ip_packet_layer_outbound_v4, ), Callout::new( - "IPPacketInboundV4", - "IP packet inbound network layer callout for Ipv4", + "Portmaster Packet Inbound IPv4", + "Portmaster uses this layer to redirect/block/permit inbound ipv4 packets", 0xf0369374_203d_4bf0_83d2_b2ad3cc17a50, Layer::InboundIppacketV4, consts::FWP_ACTION_CALLOUT_TERMINATING, @@ -146,8 +146,8 @@ pub fn get_callout_vec() -> Vec { packet_callouts::ip_packet_layer_inbound_v4, ), Callout::new( - "IPPacketOutboundV6", - "IP packet outbound network layer callout for Ipv6", + "Portmaster Packet Outbound IPv6", + "Portmaster uses this layer to redirect/block/permit outgoing ipv6 packets", 0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a, Layer::OutboundIppacketV6, consts::FWP_ACTION_CALLOUT_TERMINATING, @@ -155,8 +155,8 @@ pub fn get_callout_vec() -> Vec { packet_callouts::ip_packet_layer_outbound_v6, ), Callout::new( - "IPPacketInboundV6", - "IP packet inbound network layer callout for Ipv6", + "Portmaster Packet Inbound IPv6", + "Portmaster uses this layer to redirect/block/permit inbound ipv6 packets", 0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0, Layer::InboundIppacketV6, consts::FWP_ACTION_CALLOUT_TERMINATING, diff --git a/windows_kext/driver/src/entry.rs b/windows_kext/driver/src/entry.rs index 479fe42a..513c004b 100644 --- a/windows_kext/driver/src/entry.rs +++ b/windows_kext/driver/src/entry.rs @@ -38,10 +38,10 @@ pub extern "system" fn driver_entry( }; // Set driver functions. - driver.set_driver_unload(driver_unload); - driver.set_read_fn(driver_read); - driver.set_write_fn(driver_write); - driver.set_device_control_fn(device_control); + driver.set_driver_unload(Some(driver_unload)); + driver.set_read_fn(Some(driver_read)); + driver.set_write_fn(Some(driver_write)); + driver.set_device_control_fn(Some(device_control)); // Initialize device. unsafe { @@ -70,10 +70,10 @@ unsafe extern "system" fn driver_unload(_object: *const DRIVER_OBJECT) { // driver_read event triggered from user-space on file.Read. unsafe extern "system" fn driver_read( - _device_object: &mut DEVICE_OBJECT, - irp: &mut IRP, + _device_object: *const DEVICE_OBJECT, + irp: *mut IRP, ) -> NTSTATUS { - let mut read_request = ReadRequest::new(irp); + let mut read_request = ReadRequest::new(irp.as_mut().unwrap()); let Some(device) = get_device() else { read_request.complete(); @@ -86,10 +86,10 @@ unsafe extern "system" fn driver_read( /// driver_write event triggered from user-space on file.Write. unsafe extern "system" fn driver_write( - _device_object: &mut DEVICE_OBJECT, - irp: &mut IRP, + _device_object: *const DEVICE_OBJECT, + irp: *mut IRP, ) -> NTSTATUS { - let mut write_request = WriteRequest::new(irp); + let mut write_request = WriteRequest::new(irp.as_mut().unwrap()); let Some(device) = get_device() else { write_request.complete(); return write_request.get_status(); @@ -104,10 +104,10 @@ unsafe extern "system" fn driver_write( /// device_control event triggered from user-space on file.deviceIOControl. unsafe extern "system" fn device_control( - _device_object: &mut DEVICE_OBJECT, - irp: &mut IRP, + _device_object: *const DEVICE_OBJECT, + irp: *mut IRP, ) -> NTSTATUS { - let mut control_request = DeviceControlRequest::new(irp); + let mut control_request = DeviceControlRequest::new(irp.as_mut().unwrap()); let Some(device) = get_device() else { control_request.complete(); return control_request.get_status(); diff --git a/windows_kext/driver/src/packet_callouts.rs b/windows_kext/driver/src/packet_callouts.rs index a1b5733a..fb3ee90b 100644 --- a/windows_kext/driver/src/packet_callouts.rs +++ b/windows_kext/driver/src/packet_callouts.rs @@ -140,7 +140,7 @@ fn ip_packet_layer( } { Ok(key) => key, Err(err) => { - crate::warn!("failed to get key from nbl: {}", err); + crate::dbg!("failed to get key from nbl: {}", err); return; } }; diff --git a/windows_kext/kextinterface/info.go b/windows_kext/kextinterface/info.go index 763c3e8e..a2f5cd91 100644 --- a/windows_kext/kextinterface/info.go +++ b/windows_kext/kextinterface/info.go @@ -3,6 +3,7 @@ package kextinterface import ( "encoding/binary" "errors" + "fmt" "io" ) @@ -18,6 +19,7 @@ const ( var ( ErrUnknownInfoType = errors.New("unknown info type") + ErrUnexpectedInfoSize = errors.New("unexpected info size") ErrUnexpectedReadError = errors.New("unexpected read error") ) @@ -135,117 +137,215 @@ type Info struct { BandwidthStats *BandwidthStatsArray } -func RecvInfo(reader io.Reader) (*Info, error) { - var infoType byte - err := binary.Read(reader, binary.LittleEndian, &infoType) +type readHelper struct { + infoType byte + commandSize uint32 + + readSize int + + reader io.Reader +} + +func newReadHelper(reader io.Reader) (*readHelper, error) { + helper := &readHelper{reader: reader} + + err := binary.Read(reader, binary.LittleEndian, &helper.infoType) if err != nil { return nil, err } - // Read size of data - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = binary.Read(reader, binary.LittleEndian, &helper.commandSize) if err != nil { return nil, err } + return helper, nil +} + +func (r *readHelper) ReadData(data any) error { + err := binary.Read(r, binary.LittleEndian, data) + if err != nil { + return errors.Join(ErrUnexpectedReadError, err) + } + + if err := r.checkOverRead(); err != nil { + return err + } + + return nil +} + +// Passing size = 0 will read the rest of the command. +func (r *readHelper) ReadBytes(size uint32) ([]byte, error) { + if uint32(r.readSize) >= r.commandSize { + return nil, errors.Join(fmt.Errorf("cannot read more bytes than the command size: %d >= %d", r.readSize, r.commandSize), ErrUnexpectedReadError) + } + + if size == 0 { + size = r.commandSize - uint32(r.readSize) + } + + if r.commandSize < uint32(r.readSize)+size { + return nil, ErrUnexpectedInfoSize + } + + bytes := make([]byte, size) + err := binary.Read(r, binary.LittleEndian, bytes) + if err != nil { + return nil, errors.Join(ErrUnexpectedReadError, err) + } + + return bytes, nil +} + +func (r *readHelper) ReadUntilTheEnd() { + _, _ = r.ReadBytes(0) +} + +func (r *readHelper) checkOverRead() error { + if uint32(r.readSize) > r.commandSize { + return ErrUnexpectedInfoSize + } + + return nil +} + +func (r *readHelper) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.readSize += n + return +} + +func RecvInfo(reader io.Reader) (*Info, error) { + helper, err := newReadHelper(reader) + if err != nil { + return nil, err + } + + // Make sure the whole command is read before return. + defer helper.ReadUntilTheEnd() + // Read data - switch infoType { + switch helper.infoType { case InfoConnectionIpv4: { + parseError := fmt.Errorf("failed to parse InfoConnectionIpv4") + newInfo := ConnectionV4{} var fixedSizeValues connectionV4Internal - err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) + // Read fixed size values. + err = helper.ReadData(&fixedSizeValues) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err, fmt.Errorf("fixed")) } - // Read size of payload - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + newInfo.connectionV4Internal = fixedSizeValues + // Read size of payload. + var payloadSize uint32 + err = helper.ReadData(&payloadSize) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err, fmt.Errorf("payloadsize")) } - newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)} - err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) - if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + + // Check if there is payload. + if payloadSize > 0 { + // Read payload. + newInfo.Payload, err = helper.ReadBytes(payloadSize) + if err != nil { + return nil, errors.Join(parseError, err, fmt.Errorf("payload")) + } } return &Info{ConnectionV4: &newInfo}, nil } case InfoConnectionIpv6: { - var fixedSizeValues connectionV6Internal - err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) + parseError := fmt.Errorf("failed to parse InfoConnectionIpv6") + newInfo := ConnectionV6{} + + // Read fixed size values. + err = helper.ReadData(&newInfo.connectionV6Internal) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - // Read size of payload - var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + + // Read size of payload. + var payloadSize uint32 + err = helper.ReadData(&payloadSize) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)} - err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) - if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + + // Check if there is payload. + if payloadSize > 0 { + // Read payload. + newInfo.Payload, err = helper.ReadBytes(payloadSize) + if err != nil { + return nil, errors.Join(parseError, err) + } } + return &Info{ConnectionV6: &newInfo}, nil } case InfoConnectionEndEventV4: { + parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV4") var connectionEnd ConnectionEndV4 - err = binary.Read(reader, binary.LittleEndian, &connectionEnd) + + // Read fixed size values. + err = helper.ReadData(&connectionEnd) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } return &Info{ConnectionEndV4: &connectionEnd}, nil } case InfoConnectionEndEventV6: { + parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV6") var connectionEnd ConnectionEndV6 - err = binary.Read(reader, binary.LittleEndian, &connectionEnd) + + // Read fixed size values. + err = helper.ReadData(&connectionEnd) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } return &Info{ConnectionEndV6: &connectionEnd}, nil } case InfoLogLine: { + parseError := fmt.Errorf("failed to parse InfoLogLine") logLine := LogLine{} // Read severity - err = binary.Read(reader, binary.LittleEndian, &logLine.Severity) + err = helper.ReadData(&logLine.Severity) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read string - line := make([]byte, size-1) // -1 for the severity enum. - err = binary.Read(reader, binary.LittleEndian, &line) + bytes, err := helper.ReadBytes(0) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } - logLine.Line = string(line) + logLine.Line = string(bytes) return &Info{LogLine: &logLine}, nil } case InfoBandwidthStatsV4: { + parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV4") // Read Protocol var protocol uint8 - err = binary.Read(reader, binary.LittleEndian, &protocol) + err = helper.ReadData(&protocol) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read size of array var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = helper.ReadData(&size) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read array statsArray := make([]BandwidthValueV4, size) for i := range int(size) { - err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) + err = helper.ReadData(&statsArray[i]) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } } @@ -253,24 +353,25 @@ func RecvInfo(reader io.Reader) (*Info, error) { } case InfoBandwidthStatsV6: { + parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV6") // Read Protocol var protocol uint8 - err = binary.Read(reader, binary.LittleEndian, &protocol) + err = helper.ReadData(&protocol) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read size of array var size uint32 - err = binary.Read(reader, binary.LittleEndian, &size) + err = helper.ReadData(&size) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } // Read array statsArray := make([]BandwidthValueV6, size) for i := range int(size) { - err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) + err = helper.ReadData(&statsArray[i]) if err != nil { - return nil, errors.Join(ErrUnexpectedReadError, err) + return nil, errors.Join(parseError, err) } } @@ -278,10 +379,5 @@ func RecvInfo(reader io.Reader) (*Info, error) { } } - // Command not recognized, read until the end of command and return. - // During normal operation this should not happen. - unknownData := make([]byte, size) - _, _ = reader.Read(unknownData) - return nil, ErrUnknownInfoType } diff --git a/windows_kext/kextinterface/kext.go b/windows_kext/kextinterface/kext.go index 2707a791..3b0956cc 100644 --- a/windows_kext/kextinterface/kext.go +++ b/windows_kext/kextinterface/kext.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/safing/portbase/log" "golang.org/x/sys/windows" ) @@ -221,7 +222,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro // Check if there is an old service. service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS) if err == nil { - fmt.Println("kext: old driver service was found") + log.Warning("kext: old driver service was found") oldService := &KextService{handle: service, driverName: driverName} oldService.Stop(true) err = oldService.Delete() @@ -234,7 +235,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro } service = winInvalidHandleValue - fmt.Println("kext: old driver service was deleted successfully") + log.Warning("kext: old driver service was deleted successfully") } driverPathU16, err := syscall.UTF16FromString(driverPath) diff --git a/windows_kext/kextinterface/protocol_test.go b/windows_kext/kextinterface/protocol_test.go index cf047442..35a5264d 100644 --- a/windows_kext/kextinterface/protocol_test.go +++ b/windows_kext/kextinterface/protocol_test.go @@ -18,8 +18,18 @@ func TestRustInfoFile(t *testing.T) { defer func() { _ = file.Close() }() + first := true for { info, err := RecvInfo(file) + // First info should be with invalid size. + // This tests if invalid info data is handled properly. + if first { + if !errors.Is(err, ErrUnexpectedInfoSize) { + t.Errorf("unexpected error: %s\n", err) + } + first = false + continue + } if err != nil { if errors.Is(err, ErrUnexpectedReadError) { t.Errorf("unexpected error: %s\n", err) diff --git a/windows_kext/kextinterface/testdata/rust_info_test.bin b/windows_kext/kextinterface/testdata/rust_info_test.bin index 3f9049a9..3b8588c7 100644 Binary files a/windows_kext/kextinterface/testdata/rust_info_test.bin and b/windows_kext/kextinterface/testdata/rust_info_test.bin differ diff --git a/windows_kext/protocol/src/info.rs b/windows_kext/protocol/src/info.rs index b8eb0c79..cb0e7664 100644 --- a/windows_kext/protocol/src/info.rs +++ b/windows_kext/protocol/src/info.rs @@ -441,6 +441,22 @@ fn generate_test_info_file() -> Result<(), std::io::Error> { for _ in 0..selected.capacity() { selected.push(enums.choose(&mut rng).unwrap().clone()); } + // Write wrong size data. To make sure that mismatches between kext and portmaster are handled properly. + let mut info = connection_info_v6( + 1, + 2, + 3, + 4, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + 5, + 6, + 7, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ); + info.assert_size(); + info.0[0] = InfoType::ConnectionIpv4 as u8; + file.write_all(&info.0)?; for value in selected { file.write_all(&match value { @@ -548,5 +564,6 @@ fn generate_test_info_file() -> Result<(), std::io::Error> { } })?; } + return Ok(()); } diff --git a/windows_kext/protocol/testdata/go_command_test.bin b/windows_kext/protocol/testdata/go_command_test.bin index 586c70ad..d518dbf1 100644 Binary files a/windows_kext/protocol/testdata/go_command_test.bin and b/windows_kext/protocol/testdata/go_command_test.bin differ diff --git a/windows_kext/wdk/Cargo.lock b/windows_kext/wdk/Cargo.lock index a6ffcec3..78822535 100644 --- a/windows_kext/wdk/Cargo.lock +++ b/windows_kext/wdk/Cargo.lock @@ -84,19 +84,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" [[package]] name = "windows-sys" version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" dependencies = [ "windows-targets", ] [[package]] name = "windows-targets" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", "windows_i686_gnu", + "windows_i686_gnullvm", "windows_i686_msvc", "windows_x86_64_gnu", "windows_x86_64_gnullvm", @@ -105,35 +106,40 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_i686_gnu" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_i686_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" -source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +version = "0.52.5" +source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317" diff --git a/windows_kext/wdk/Cargo.toml b/windows_kext/wdk/Cargo.toml index a3edbf08..0b85e05a 100644 --- a/windows_kext/wdk/Cargo.toml +++ b/windows_kext/wdk/Cargo.toml @@ -16,5 +16,5 @@ features = ["alloc"] # WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels. [dependencies.windows-sys] git = "https://github.com/microsoft/windows-rs" -rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" +rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317" features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform", "Win32_System_Rpc"] diff --git a/windows_kext/wdk/README.md b/windows_kext/wdk/README.md index 36107c4b..4712225d 100644 --- a/windows_kext/wdk/README.md +++ b/windows_kext/wdk/README.md @@ -10,5 +10,7 @@ see: `wdk/src/driver.rs` see: `wdk/src/irp_helper.rs` Open issues need to be resolved: -https://github.com/microsoft/wdkmetadata/issues/59 https://github.com/microsoft/windows-rs/issues/2805 + +Resolved: +https://github.com/microsoft/wdkmetadata/issues/59 diff --git a/windows_kext/wdk/src/allocator.rs b/windows_kext/wdk/src/allocator.rs index e3f65fa8..f8767b8a 100644 --- a/windows_kext/wdk/src/allocator.rs +++ b/windows_kext/wdk/src/allocator.rs @@ -43,8 +43,8 @@ unsafe impl GlobalAlloc for WindowsAllocator { } unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { - let pool = self.alloc(layout); - pool + + self.alloc(layout) } unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { diff --git a/windows_kext/wdk/src/driver.rs b/windows_kext/wdk/src/driver.rs index a8b7440d..97f08d6d 100644 --- a/windows_kext/wdk/src/driver.rs +++ b/windows_kext/wdk/src/driver.rs @@ -1,6 +1,6 @@ use windows_sys::{ - Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP}, - Win32::Foundation::{HANDLE, NTSTATUS}, + Wdk::Foundation::{DEVICE_OBJECT, DRIVER_DISPATCH, DRIVER_OBJECT, DRIVER_UNLOAD}, + Win32::Foundation::HANDLE, }; use crate::{ @@ -23,11 +23,6 @@ pub struct Driver { } unsafe impl Sync for Driver {} -// This is a workaround for current state of wdk bindings. -// TODO: replace with official version when they are correct: https://github.com/microsoft/wdkmetadata/issues/59 -pub type UnloadFnType = unsafe extern "system" fn(driver_object: *const DRIVER_OBJECT); -pub type MjFnType = unsafe extern "system" fn(&mut DEVICE_OBJECT, &mut IRP) -> NTSTATUS; - impl Driver { pub(crate) fn new( driver_object: *mut DRIVER_OBJECT, @@ -50,54 +45,54 @@ impl Driver { return unsafe { self.device_object.as_mut() }; } - pub fn set_driver_unload(&mut self, driver_unload: UnloadFnType) { + pub fn set_driver_unload(&mut self, driver_unload: DRIVER_UNLOAD) { if let Some(driver) = unsafe { self.driver_object.as_mut() } { - driver.DriverUnload = Some(unsafe { core::mem::transmute(driver_unload) }) + driver.DriverUnload = driver_unload } } - pub fn set_read_fn(&mut self, mj_fn: MjFnType) { + pub fn set_read_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn(windows_sys::Wdk::System::SystemServices::IRP_MJ_READ, mj_fn); } - pub fn set_write_fn(&mut self, mj_fn: MjFnType) { + pub fn set_write_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn( windows_sys::Wdk::System::SystemServices::IRP_MJ_WRITE, mj_fn, ); } - pub fn set_create_fn(&mut self, mj_fn: MjFnType) { + pub fn set_create_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn( windows_sys::Wdk::System::SystemServices::IRP_MJ_CREATE, mj_fn, ); } - pub fn set_device_control_fn(&mut self, mj_fn: MjFnType) { + pub fn set_device_control_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn( windows_sys::Wdk::System::SystemServices::IRP_MJ_DEVICE_CONTROL, mj_fn, ); } - pub fn set_close_fn(&mut self, mj_fn: MjFnType) { + pub fn set_close_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn( windows_sys::Wdk::System::SystemServices::IRP_MJ_CLOSE, mj_fn, ); } - pub fn set_cleanup_fn(&mut self, mj_fn: MjFnType) { + pub fn set_cleanup_fn(&mut self, mj_fn: DRIVER_DISPATCH) { self.set_major_fn( windows_sys::Wdk::System::SystemServices::IRP_MJ_CLEANUP, mj_fn, ); } - fn set_major_fn(&mut self, fn_index: u32, mj_fn: MjFnType) { + fn set_major_fn(&mut self, fn_index: u32, mj_fn: DRIVER_DISPATCH) { if let Some(driver) = unsafe { self.driver_object.as_mut() } { - driver.MajorFunction[fn_index as usize] = Some(unsafe { core::mem::transmute(mj_fn) }) + driver.MajorFunction[fn_index as usize] = mj_fn } } } diff --git a/windows_kext/wdk/src/ffi.rs b/windows_kext/wdk/src/ffi.rs index b7fea16c..c250499e 100644 --- a/windows_kext/wdk/src/ffi.rs +++ b/windows_kext/wdk/src/ffi.rs @@ -270,7 +270,7 @@ impl WdfObjectAttributes { evt_destroy_callback: None, execution_level: WdfExecutionLevel::InheritFromParent, synchronization_scope: WdfSynchronizationScope::InheritFromParent, - parent_object: 0, + parent_object: core::ptr::null_mut(), context_size_override: 0, context_type_info: core::ptr::null(), } diff --git a/windows_kext/wdk/src/filter_engine/callout.rs b/windows_kext/wdk/src/filter_engine/callout.rs index ad4fcf4e..5651de1d 100644 --- a/windows_kext/wdk/src/filter_engine/callout.rs +++ b/windows_kext/wdk/src/filter_engine/callout.rs @@ -1,7 +1,7 @@ use super::{callout_data::CalloutData, ffi, layer::Layer}; use crate::ffi::FwpsCalloutClassifyFn; use alloc::{borrow::ToOwned, format, string::String}; -use windows_sys::Wdk::Foundation::DEVICE_OBJECT; +use windows_sys::{Wdk::Foundation::DEVICE_OBJECT, Win32::Foundation::HANDLE}; pub enum FilterType { Resettable, @@ -49,13 +49,13 @@ impl Callout { pub fn register_filter( &mut self, - filter_engine_handle: isize, + filter_engine_handle: HANDLE, sublayer_guid: u128, ) -> Result<(), String> { match ffi::register_filter( filter_engine_handle, sublayer_guid, - &format!("{}-filter", self.name), + &self.name, &self.description, self.guid, self.layer, @@ -75,14 +75,14 @@ impl Callout { pub(crate) fn register_callout( &mut self, - filter_engine_handle: isize, + filter_engine_handle: HANDLE, device_object: *mut DEVICE_OBJECT, callout_fn: FwpsCalloutClassifyFn, ) -> Result<(), String> { match ffi::register_callout( device_object, filter_engine_handle, - &format!("{}-callout", self.name), + &self.name, &self.description, self.guid, self.layer, diff --git a/windows_kext/wdk/src/filter_engine/callout_data.rs b/windows_kext/wdk/src/filter_engine/callout_data.rs index c09be368..bb861f84 100644 --- a/windows_kext/wdk/src/filter_engine/callout_data.rs +++ b/windows_kext/wdk/src/filter_engine/callout_data.rs @@ -37,9 +37,7 @@ impl ClassifyDefer { } ClassifyDefer::Reauthorization(_callout_id, packet_list) => { // There is no way to reset single filter. If another request for filter reset is trigger at the same time it will fail. - if let Err(err) = filter_engine.reset_all_filters() { - return Err(err); - } + filter_engine.reset_all_filters()?; return Ok(packet_list); } } @@ -140,7 +138,7 @@ impl<'a> CalloutData<'a> { packet_list: Option, ) -> Result { unsafe { - let mut completion_context = 0; + let mut completion_context: HANDLE = core::ptr::null_mut(); if let Some(completion_handle) = (*self.metadata).get_completion_handle() { let status = FwpsPendOperation0(completion_handle, &mut completion_context); check_ntstatus(status)?; diff --git a/windows_kext/wdk/src/filter_engine/ffi.rs b/windows_kext/wdk/src/filter_engine/ffi.rs index 766c1ef1..45103272 100644 --- a/windows_kext/wdk/src/filter_engine/ffi.rs +++ b/windows_kext/wdk/src/filter_engine/ffi.rs @@ -113,9 +113,7 @@ pub(crate) fn register_callout( check_ntstatus(status)?; - if let Err(err) = callout_add(filter_engine_handle, guid, layer, name, description) { - return Err(err); - } + callout_add(filter_engine_handle, guid, layer, name, description)?; return Ok(callout_id); } diff --git a/windows_kext/wdk/src/filter_engine/metadata.rs b/windows_kext/wdk/src/filter_engine/metadata.rs index 29419786..632830fa 100644 --- a/windows_kext/wdk/src/filter_engine/metadata.rs +++ b/windows_kext/wdk/src/filter_engine/metadata.rs @@ -154,10 +154,10 @@ impl FwpsIncomingMetadataValues { #[allow(dead_code)] #[repr(C)] enum FwpsDiscardModule0 { - FwpsDiscardModuleNetwork = 0, - FwpsDiscardModuleTransport = 1, - FwpsDiscardModuleGeneral = 2, - FwpsDiscardModuleMax = 3, + Network = 0, + Transport = 1, + General = 2, + Max = 3, } #[repr(C)] diff --git a/windows_kext/wdk/src/filter_engine/mod.rs b/windows_kext/wdk/src/filter_engine/mod.rs index 7e6cc20f..0405afba 100644 --- a/windows_kext/wdk/src/filter_engine/mod.rs +++ b/windows_kext/wdk/src/filter_engine/mod.rs @@ -107,9 +107,7 @@ impl FilterEngine { filter_engine.callouts = Some(boxed_callouts); } - if let Err(err) = filter_engine.commit() { - return Err(err); - } + filter_engine.commit()? } self.committed = true; info!("transaction committed"); @@ -147,9 +145,7 @@ impl FilterEngine { } } // Commit transaction. - if let Err(err) = filter_engine.commit() { - return Err(err); - } + filter_engine.commit()?; return Ok(()); } @@ -192,7 +188,7 @@ impl Drop for FilterEngine { } } - if self.handle != 0 && self.handle != INVALID_HANDLE_VALUE { + if !self.handle.is_null() && self.handle != INVALID_HANDLE_VALUE { _ = ffi::filter_engine_close(self.handle); } } diff --git a/windows_kext/wdk/src/filter_engine/net_buffer.rs b/windows_kext/wdk/src/filter_engine/net_buffer.rs index f5274547..ff94ca80 100644 --- a/windows_kext/wdk/src/filter_engine/net_buffer.rs +++ b/windows_kext/wdk/src/filter_engine/net_buffer.rs @@ -85,7 +85,7 @@ impl NetBufferList { } // Allocate space in buffer, if buffer is too small. - let mut buffer = alloc::vec![0 as u8; data_length as usize]; + let mut buffer = alloc::vec![0_u8; data_length as usize]; let ptr = NdisGetDataBuffer(nb, data_length, buffer.as_mut_ptr(), 1, 0); @@ -209,7 +209,7 @@ impl Iterator for NetBufferListIter { } } -pub fn read_packet_partial<'a>(nbl: *mut NET_BUFFER_LIST, buffer: &'a mut [u8]) -> Result<(), ()> { +pub fn read_packet_partial(nbl: *mut NET_BUFFER_LIST, buffer: &mut [u8]) -> Result<(), ()> { unsafe { let Some(nbl) = nbl.as_ref() else { return Err(()); diff --git a/windows_kext/wdk/src/filter_engine/packet.rs b/windows_kext/wdk/src/filter_engine/packet.rs index 85e26006..afdcb021 100644 --- a/windows_kext/wdk/src/filter_engine/packet.rs +++ b/windows_kext/wdk/src/filter_engine/packet.rs @@ -105,9 +105,9 @@ impl Injector { } let mut remote_ip: [u8; 16] = [0; 16]; if ipv6 { - remote_ip[0..16].copy_from_slice(&remote_ip_slice); + remote_ip[0..16].copy_from_slice(remote_ip_slice); } else { - remote_ip[0..4].copy_from_slice(&remote_ip_slice); + remote_ip[0..4].copy_from_slice(remote_ip_slice); } TransportPacketList { @@ -163,7 +163,7 @@ impl Injector { let status = if packet_list.inbound { FwpsInjectTransportReceiveAsync0( self.transport_inject_handle, - 0, + core::ptr::null_mut(), core::ptr::null_mut(), 0, address_family, @@ -177,7 +177,7 @@ impl Injector { } else { FwpsInjectTransportSendAsync1( self.transport_inject_handle, - 0, + core::ptr::null_mut(), packet_list.endpoint_handle, 0, &mut send_params, @@ -222,7 +222,7 @@ impl Injector { unsafe { FwpsInjectNetworkReceiveAsync0( inject_handle, - 0, + core::ptr::null_mut(), 0, UNSPECIFIED_COMPARTMENT_ID, inject_info.interface_index, @@ -237,7 +237,7 @@ impl Injector { unsafe { FwpsInjectNetworkSendAsync0( inject_handle, - 0, + core::ptr::null_mut(), 0, UNSPECIFIED_COMPARTMENT_ID, nbl, @@ -269,7 +269,7 @@ impl Injector { } else { self.packet_inject_handle_v4 }; - if inject_handle == INVALID_HANDLE_VALUE || inject_handle == 0 { + if inject_handle == INVALID_HANDLE_VALUE || inject_handle.is_null() { return false; } @@ -309,19 +309,19 @@ impl Drop for Injector { fn drop(&mut self) { unsafe { if self.transport_inject_handle != INVALID_HANDLE_VALUE - && self.transport_inject_handle != 0 + && !self.transport_inject_handle.is_null() { FwpsInjectionHandleDestroy0(self.transport_inject_handle); self.transport_inject_handle = INVALID_HANDLE_VALUE; } if self.packet_inject_handle_v4 != INVALID_HANDLE_VALUE - && self.packet_inject_handle_v4 != 0 + && !self.packet_inject_handle_v4.is_null() { FwpsInjectionHandleDestroy0(self.packet_inject_handle_v4); self.packet_inject_handle_v4 = INVALID_HANDLE_VALUE; } if self.packet_inject_handle_v6 != INVALID_HANDLE_VALUE - && self.packet_inject_handle_v6 != 0 + && !self.packet_inject_handle_v6.is_null() { FwpsInjectionHandleDestroy0(self.packet_inject_handle_v6); self.packet_inject_handle_v6 = INVALID_HANDLE_VALUE; diff --git a/windows_kext/wdk/src/irp_helpers.rs b/windows_kext/wdk/src/irp_helpers.rs index 52960d5e..821c3b13 100644 --- a/windows_kext/wdk/src/irp_helpers.rs +++ b/windows_kext/wdk/src/irp_helpers.rs @@ -67,7 +67,7 @@ impl ReadRequest<'_> { for i in 0..bytes_to_write { self.buffer[self.fill_index + i] = bytes[i]; } - self.fill_index = self.fill_index + bytes_to_write; + self.fill_index += bytes_to_write; bytes_to_write } @@ -94,7 +94,7 @@ impl WriteRequest<'_> { } pub fn get_buffer(&self) -> &[u8] { - &self.buffer + self.buffer } pub fn mark_all_as_read(&mut self) { @@ -155,7 +155,7 @@ impl DeviceControlRequest<'_> { } pub fn get_buffer(&self) -> &[u8] { - &self.buffer + self.buffer } pub fn write(&mut self, bytes: &[u8]) -> usize { let mut bytes_to_write: usize = bytes.len(); @@ -168,7 +168,7 @@ impl DeviceControlRequest<'_> { for i in 0..bytes_to_write { self.buffer[self.fill_index + i] = bytes[i]; } - self.fill_index = self.fill_index + bytes_to_write; + self.fill_index += bytes_to_write; bytes_to_write }