Merge pull request #1599 from safing/fix/kext-bug

[service] Fix kext verdict of update command
This commit is contained in:
Daniel Hååvi 2024-07-02 11:29:33 +02:00 committed by GitHub
commit b7c32ec6de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 370 additions and 199 deletions

View file

@ -5,11 +5,13 @@ package windowskext
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/windows_kext/kextinterface"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -32,8 +34,15 @@ func (v *VersionInfo) String() string {
func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) { func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) {
for { for {
packetInfo, err := RecvVerdictRequest() packetInfo, err := RecvVerdictRequest()
if errors.Is(err, kextinterface.ErrUnexpectedInfoSize) || errors.Is(err, kextinterface.ErrUnexpectedReadError) {
log.Criticalf("unexpected kext info data: %s", err)
continue // Depending on the info type this may not affect the functionality. Try to continue reading the next commands.
}
if err != nil { if err != nil {
log.Warningf("failed to get packet from windows kext: %s", err) log.Warningf("failed to get packet from windows kext: %s", err)
// Probably IO error, nothing else we can do.
return return
} }

View file

@ -39,9 +39,12 @@ func Start() error {
} }
// Start service and open file // Start service and open file
service.Start(true) err = service.Start(true)
kextFile, err = service.OpenFile(1024) if err != nil {
log.Errorf("failed to start service: %s", err)
}
kextFile, err = service.OpenFile(1024)
if err != nil { if err != nil {
return fmt.Errorf("failed to open driver: %w", err) return fmt.Errorf("failed to open driver: %w", err)
} }
@ -130,7 +133,7 @@ func UpdateVerdict(conn *network.Connection) error {
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
RemoteAddress: [4]byte(conn.Entity.IP), RemoteAddress: [4]byte(conn.Entity.IP),
RemotePort: conn.Entity.Port, RemotePort: conn.Entity.Port,
Verdict: uint8(conn.Verdict), Verdict: uint8(getKextVerdictFromConnection(conn)),
} }
return kextinterface.SendUpdateV4Command(kextFile, update) return kextinterface.SendUpdateV4Command(kextFile, update)
@ -141,7 +144,7 @@ func UpdateVerdict(conn *network.Connection) error {
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
RemoteAddress: [16]byte(conn.Entity.IP), RemoteAddress: [16]byte(conn.Entity.IP),
RemotePort: conn.Entity.Port, RemotePort: conn.Entity.Port,
Verdict: uint8(conn.Verdict), Verdict: uint8(getKextVerdictFromConnection(conn)),
} }
return kextinterface.SendUpdateV6Command(kextFile, update) return kextinterface.SendUpdateV6Command(kextFile, update)
@ -149,6 +152,40 @@ func UpdateVerdict(conn *network.Connection) error {
return nil return nil
} }
func getKextVerdictFromConnection(conn *network.Connection) kextinterface.KextVerdict {
switch conn.Verdict {
case network.VerdictUndecided:
return kextinterface.VerdictUndecided
case network.VerdictUndeterminable:
return kextinterface.VerdictUndeterminable
case network.VerdictAccept:
if conn.VerdictPermanent {
return kextinterface.VerdictPermanentAccept
} else {
return kextinterface.VerdictAccept
}
case network.VerdictBlock:
if conn.VerdictPermanent {
return kextinterface.VerdictPermanentBlock
} else {
return kextinterface.VerdictBlock
}
case network.VerdictDrop:
if conn.VerdictPermanent {
return kextinterface.VerdictPermanentDrop
} else {
return kextinterface.VerdictDrop
}
case network.VerdictRerouteToNameserver:
return kextinterface.VerdictRerouteToNameserver
case network.VerdictRerouteToTunnel:
return kextinterface.VerdictRerouteToTunnel
case network.VerdictFailed:
return kextinterface.VerdictFailed
}
return kextinterface.VerdictUndeterminable
}
// Returns the kext version. // Returns the kext version.
func GetVersion() (*VersionInfo, error) { func GetVersion() (*VersionInfo, error) {
data, err := kextinterface.ReadVersion(kextFile) data, err := kextinterface.ReadVersion(kextFile)

View file

@ -340,19 +340,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
dependencies = [ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
dependencies = [ dependencies = [
"windows_aarch64_gnullvm", "windows_aarch64_gnullvm",
"windows_aarch64_msvc", "windows_aarch64_msvc",
"windows_i686_gnu", "windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc", "windows_i686_msvc",
"windows_x86_64_gnu", "windows_x86_64_gnu",
"windows_x86_64_gnullvm", "windows_x86_64_gnullvm",
@ -361,38 +362,43 @@ dependencies = [
[[package]] [[package]]
name = "windows_aarch64_gnullvm" name = "windows_aarch64_gnullvm"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"

View file

@ -22,5 +22,5 @@ hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels. # WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
[dependencies.windows-sys] [dependencies.windows-sys]
git = "https://github.com/microsoft/windows-rs" git = "https://github.com/microsoft/windows-rs"
rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317"
features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform"] features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform"]

View file

@ -12,8 +12,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ----------------------------------------- // -----------------------------------------
// ALE Auth layers // ALE Auth layers
Callout::new( Callout::new(
"AleLayerOutboundV4", "Portmaster ALE Outbound IPv4",
"ALE layer for outbound connection for ipv4", "Portmaster uses this layer to block/permit outgoing ipv4 connections",
0x58545073_f893_454c_bbea_a57bc964f46d, 0x58545073_f893_454c_bbea_a57bc964f46d,
Layer::AleAuthConnectV4, Layer::AleAuthConnectV4,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,
@ -21,8 +21,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
ale_callouts::ale_layer_connect_v4, ale_callouts::ale_layer_connect_v4,
), ),
Callout::new( Callout::new(
"AleLayerOutboundV6", "Portmaster ALE Outbound IPv6",
"ALE layer for outbound connections for ipv6", "Portmaster uses this layer to block/permit outgoing ipv6 connections",
0x4bd2a080_2585_478d_977c_7f340c6bc3d4, 0x4bd2a080_2585_478d_977c_7f340c6bc3d4,
Layer::AleAuthConnectV6, Layer::AleAuthConnectV6,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,
@ -32,8 +32,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ----------------------------------------- // -----------------------------------------
// ALE connection end layers // ALE connection end layers
Callout::new( Callout::new(
"AleEndpointClosureV4", "Portmaster Endpoint Closure IPv4",
"ALE layer for indicating closing of connection for ipv4", "Portmaster uses this layer to detect when a IPv4 connection has ended",
0x58f02845_ace9_4455_ac80_8a84b86fe566, 0x58f02845_ace9_4455_ac80_8a84b86fe566,
Layer::AleEndpointClosureV4, Layer::AleEndpointClosureV4,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -41,8 +41,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
ale_callouts::endpoint_closure_v4, ale_callouts::endpoint_closure_v4,
), ),
Callout::new( Callout::new(
"AleEndpointClosureV6", "Portmaster Endpoint Closure IPv6",
"ALE layer for indicating closing of connection for ipv6", "Portmaster uses this layer to detect when a IPv6 connection has ended",
0x2bc82359_9dc5_4315_9c93_c89467e283ce, 0x2bc82359_9dc5_4315_9c93_c89467e283ce,
Layer::AleEndpointClosureV6, Layer::AleEndpointClosureV6,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -61,8 +61,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ale_callouts::ale_resource_monitor, // ale_callouts::ale_resource_monitor,
// ), // ),
Callout::new( Callout::new(
"AleResourceReleaseV4", "Portmaster resource release IPv4",
"Ipv4 Port release monitor", "Portmaster uses this layer to detect when a IPv4 port has been released",
0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7, 0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7,
Layer::AleResourceReleaseV4, Layer::AleResourceReleaseV4,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -79,8 +79,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ale_callouts::ale_resource_monitor, // ale_callouts::ale_resource_monitor,
// ), // ),
Callout::new( Callout::new(
"AleResourceReleaseV6", "Portmaster resource release IPv6",
"Ipv6 Port release monitor", "Portmaster uses this layer to detect when a IPv6 port has been released",
0x6cf36e04_e656_42c3_8cac_a1ce05328bd1, 0x6cf36e04_e656_42c3_8cac_a1ce05328bd1,
Layer::AleResourceReleaseV6, Layer::AleResourceReleaseV6,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -90,8 +90,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ----------------------------------------- // -----------------------------------------
// Stream layer // Stream layer
Callout::new( Callout::new(
"StreamLayerV4", "Portmaster Stream IPv4",
"Stream layer for ipv4", "Portmaster uses this layer for bandwidth statistics of IPv4 TCP connections",
0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780, 0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780,
Layer::StreamV4, Layer::StreamV4,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -99,8 +99,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
stream_callouts::stream_layer_tcp_v4, stream_callouts::stream_layer_tcp_v4,
), ),
Callout::new( Callout::new(
"StreamLayerV6", "Portmaster Stream IPv6",
"Stream layer for ipv6", "Portmaster uses this layer for bandwidth statistics of IPv6 TCP connections",
0x66c549b3_11e2_4b27_8f73_856e6fd82baa, 0x66c549b3_11e2_4b27_8f73_856e6fd82baa,
Layer::StreamV6, Layer::StreamV6,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -108,8 +108,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
stream_callouts::stream_layer_tcp_v6, stream_callouts::stream_layer_tcp_v6,
), ),
Callout::new( Callout::new(
"DatagramDataLayerV4", "Portmaster Datagram IPv4",
"DatagramData layer for ipv4", "Portmaster uses this layer for bandwidth statistics of IPv4 UDP connections",
0xe7eeeaba_168a_45bb_8747_e1a702feb2c5, 0xe7eeeaba_168a_45bb_8747_e1a702feb2c5,
Layer::DatagramDataV4, Layer::DatagramDataV4,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -117,8 +117,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
stream_callouts::stream_layer_udp_v4, stream_callouts::stream_layer_udp_v4,
), ),
Callout::new( Callout::new(
"DatagramDataLayerV6", "Portmaster Datagram IPv6",
"DatagramData layer for ipv4", "Portmaster uses this layer for bandwidth statistics of IPv6 UDP connections",
0xb25862cd_f744_4452_b14a_d0c1e5a25b30, 0xb25862cd_f744_4452_b14a_d0c1e5a25b30,
Layer::DatagramDataV6, Layer::DatagramDataV6,
consts::FWP_ACTION_CALLOUT_INSPECTION, consts::FWP_ACTION_CALLOUT_INSPECTION,
@ -128,8 +128,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
// ----------------------------------------- // -----------------------------------------
// Packet layers // Packet layers
Callout::new( Callout::new(
"IPPacketOutboundV4", "Portmaster Packet Outbound IPv4",
"IP packet outbound network layer callout for Ipv4", "Portmaster uses this layer to redirect/block/permit outgoing ipv4 packets",
0xf3183afe_dc35_49f1_8ea2_b16b5666dd36, 0xf3183afe_dc35_49f1_8ea2_b16b5666dd36,
Layer::OutboundIppacketV4, Layer::OutboundIppacketV4,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,
@ -137,8 +137,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
packet_callouts::ip_packet_layer_outbound_v4, packet_callouts::ip_packet_layer_outbound_v4,
), ),
Callout::new( Callout::new(
"IPPacketInboundV4", "Portmaster Packet Inbound IPv4",
"IP packet inbound network layer callout for Ipv4", "Portmaster uses this layer to redirect/block/permit inbound ipv4 packets",
0xf0369374_203d_4bf0_83d2_b2ad3cc17a50, 0xf0369374_203d_4bf0_83d2_b2ad3cc17a50,
Layer::InboundIppacketV4, Layer::InboundIppacketV4,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,
@ -146,8 +146,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
packet_callouts::ip_packet_layer_inbound_v4, packet_callouts::ip_packet_layer_inbound_v4,
), ),
Callout::new( Callout::new(
"IPPacketOutboundV6", "Portmaster Packet Outbound IPv6",
"IP packet outbound network layer callout for Ipv6", "Portmaster uses this layer to redirect/block/permit outgoing ipv6 packets",
0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a, 0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a,
Layer::OutboundIppacketV6, Layer::OutboundIppacketV6,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,
@ -155,8 +155,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
packet_callouts::ip_packet_layer_outbound_v6, packet_callouts::ip_packet_layer_outbound_v6,
), ),
Callout::new( Callout::new(
"IPPacketInboundV6", "Portmaster Packet Inbound IPv6",
"IP packet inbound network layer callout for Ipv6", "Portmaster uses this layer to redirect/block/permit inbound ipv6 packets",
0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0, 0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0,
Layer::InboundIppacketV6, Layer::InboundIppacketV6,
consts::FWP_ACTION_CALLOUT_TERMINATING, consts::FWP_ACTION_CALLOUT_TERMINATING,

View file

@ -38,10 +38,10 @@ pub extern "system" fn driver_entry(
}; };
// Set driver functions. // Set driver functions.
driver.set_driver_unload(driver_unload); driver.set_driver_unload(Some(driver_unload));
driver.set_read_fn(driver_read); driver.set_read_fn(Some(driver_read));
driver.set_write_fn(driver_write); driver.set_write_fn(Some(driver_write));
driver.set_device_control_fn(device_control); driver.set_device_control_fn(Some(device_control));
// Initialize device. // Initialize device.
unsafe { unsafe {
@ -70,10 +70,10 @@ unsafe extern "system" fn driver_unload(_object: *const DRIVER_OBJECT) {
// driver_read event triggered from user-space on file.Read. // driver_read event triggered from user-space on file.Read.
unsafe extern "system" fn driver_read( unsafe extern "system" fn driver_read(
_device_object: &mut DEVICE_OBJECT, _device_object: *const DEVICE_OBJECT,
irp: &mut IRP, irp: *mut IRP,
) -> NTSTATUS { ) -> NTSTATUS {
let mut read_request = ReadRequest::new(irp); let mut read_request = ReadRequest::new(irp.as_mut().unwrap());
let Some(device) = get_device() else { let Some(device) = get_device() else {
read_request.complete(); read_request.complete();
@ -86,10 +86,10 @@ unsafe extern "system" fn driver_read(
/// driver_write event triggered from user-space on file.Write. /// driver_write event triggered from user-space on file.Write.
unsafe extern "system" fn driver_write( unsafe extern "system" fn driver_write(
_device_object: &mut DEVICE_OBJECT, _device_object: *const DEVICE_OBJECT,
irp: &mut IRP, irp: *mut IRP,
) -> NTSTATUS { ) -> NTSTATUS {
let mut write_request = WriteRequest::new(irp); let mut write_request = WriteRequest::new(irp.as_mut().unwrap());
let Some(device) = get_device() else { let Some(device) = get_device() else {
write_request.complete(); write_request.complete();
return write_request.get_status(); return write_request.get_status();
@ -104,10 +104,10 @@ unsafe extern "system" fn driver_write(
/// device_control event triggered from user-space on file.deviceIOControl. /// device_control event triggered from user-space on file.deviceIOControl.
unsafe extern "system" fn device_control( unsafe extern "system" fn device_control(
_device_object: &mut DEVICE_OBJECT, _device_object: *const DEVICE_OBJECT,
irp: &mut IRP, irp: *mut IRP,
) -> NTSTATUS { ) -> NTSTATUS {
let mut control_request = DeviceControlRequest::new(irp); let mut control_request = DeviceControlRequest::new(irp.as_mut().unwrap());
let Some(device) = get_device() else { let Some(device) = get_device() else {
control_request.complete(); control_request.complete();
return control_request.get_status(); return control_request.get_status();

View file

@ -140,7 +140,7 @@ fn ip_packet_layer(
} { } {
Ok(key) => key, Ok(key) => key,
Err(err) => { Err(err) => {
crate::warn!("failed to get key from nbl: {}", err); crate::dbg!("failed to get key from nbl: {}", err);
return; return;
} }
}; };

View file

@ -3,6 +3,7 @@ package kextinterface
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
) )
@ -18,6 +19,7 @@ const (
var ( var (
ErrUnknownInfoType = errors.New("unknown info type") ErrUnknownInfoType = errors.New("unknown info type")
ErrUnexpectedInfoSize = errors.New("unexpected info size")
ErrUnexpectedReadError = errors.New("unexpected read error") ErrUnexpectedReadError = errors.New("unexpected read error")
) )
@ -135,117 +137,215 @@ type Info struct {
BandwidthStats *BandwidthStatsArray BandwidthStats *BandwidthStatsArray
} }
func RecvInfo(reader io.Reader) (*Info, error) { type readHelper struct {
var infoType byte infoType byte
err := binary.Read(reader, binary.LittleEndian, &infoType) commandSize uint32
readSize int
reader io.Reader
}
func newReadHelper(reader io.Reader) (*readHelper, error) {
helper := &readHelper{reader: reader}
err := binary.Read(reader, binary.LittleEndian, &helper.infoType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Read size of data err = binary.Read(reader, binary.LittleEndian, &helper.commandSize)
var size uint32
err = binary.Read(reader, binary.LittleEndian, &size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return helper, nil
}
func (r *readHelper) ReadData(data any) error {
err := binary.Read(r, binary.LittleEndian, data)
if err != nil {
return errors.Join(ErrUnexpectedReadError, err)
}
if err := r.checkOverRead(); err != nil {
return err
}
return nil
}
// Passing size = 0 will read the rest of the command.
func (r *readHelper) ReadBytes(size uint32) ([]byte, error) {
if uint32(r.readSize) >= r.commandSize {
return nil, errors.Join(fmt.Errorf("cannot read more bytes than the command size: %d >= %d", r.readSize, r.commandSize), ErrUnexpectedReadError)
}
if size == 0 {
size = r.commandSize - uint32(r.readSize)
}
if r.commandSize < uint32(r.readSize)+size {
return nil, ErrUnexpectedInfoSize
}
bytes := make([]byte, size)
err := binary.Read(r, binary.LittleEndian, bytes)
if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err)
}
return bytes, nil
}
func (r *readHelper) ReadUntilTheEnd() {
_, _ = r.ReadBytes(0)
}
func (r *readHelper) checkOverRead() error {
if uint32(r.readSize) > r.commandSize {
return ErrUnexpectedInfoSize
}
return nil
}
func (r *readHelper) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.readSize += n
return
}
func RecvInfo(reader io.Reader) (*Info, error) {
helper, err := newReadHelper(reader)
if err != nil {
return nil, err
}
// Make sure the whole command is read before return.
defer helper.ReadUntilTheEnd()
// Read data // Read data
switch infoType { switch helper.infoType {
case InfoConnectionIpv4: case InfoConnectionIpv4:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionIpv4")
newInfo := ConnectionV4{}
var fixedSizeValues connectionV4Internal var fixedSizeValues connectionV4Internal
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) // Read fixed size values.
err = helper.ReadData(&fixedSizeValues)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err, fmt.Errorf("fixed"))
} }
// Read size of payload newInfo.connectionV4Internal = fixedSizeValues
var size uint32 // Read size of payload.
err = binary.Read(reader, binary.LittleEndian, &size) var payloadSize uint32
err = helper.ReadData(&payloadSize)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err, fmt.Errorf("payloadsize"))
} }
newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) // Check if there is payload.
if err != nil { if payloadSize > 0 {
return nil, errors.Join(ErrUnexpectedReadError, err) // Read payload.
newInfo.Payload, err = helper.ReadBytes(payloadSize)
if err != nil {
return nil, errors.Join(parseError, err, fmt.Errorf("payload"))
}
} }
return &Info{ConnectionV4: &newInfo}, nil return &Info{ConnectionV4: &newInfo}, nil
} }
case InfoConnectionIpv6: case InfoConnectionIpv6:
{ {
var fixedSizeValues connectionV6Internal parseError := fmt.Errorf("failed to parse InfoConnectionIpv6")
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues) newInfo := ConnectionV6{}
// Read fixed size values.
err = helper.ReadData(&newInfo.connectionV6Internal)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of payload
var size uint32 // Read size of payload.
err = binary.Read(reader, binary.LittleEndian, &size) var payloadSize uint32
err = helper.ReadData(&payloadSize)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)}
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload) // Check if there is payload.
if err != nil { if payloadSize > 0 {
return nil, errors.Join(ErrUnexpectedReadError, err) // Read payload.
newInfo.Payload, err = helper.ReadBytes(payloadSize)
if err != nil {
return nil, errors.Join(parseError, err)
}
} }
return &Info{ConnectionV6: &newInfo}, nil return &Info{ConnectionV6: &newInfo}, nil
} }
case InfoConnectionEndEventV4: case InfoConnectionEndEventV4:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV4")
var connectionEnd ConnectionEndV4 var connectionEnd ConnectionEndV4
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
// Read fixed size values.
err = helper.ReadData(&connectionEnd)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
return &Info{ConnectionEndV4: &connectionEnd}, nil return &Info{ConnectionEndV4: &connectionEnd}, nil
} }
case InfoConnectionEndEventV6: case InfoConnectionEndEventV6:
{ {
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV6")
var connectionEnd ConnectionEndV6 var connectionEnd ConnectionEndV6
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
// Read fixed size values.
err = helper.ReadData(&connectionEnd)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
return &Info{ConnectionEndV6: &connectionEnd}, nil return &Info{ConnectionEndV6: &connectionEnd}, nil
} }
case InfoLogLine: case InfoLogLine:
{ {
parseError := fmt.Errorf("failed to parse InfoLogLine")
logLine := LogLine{} logLine := LogLine{}
// Read severity // Read severity
err = binary.Read(reader, binary.LittleEndian, &logLine.Severity) err = helper.ReadData(&logLine.Severity)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read string // Read string
line := make([]byte, size-1) // -1 for the severity enum. bytes, err := helper.ReadBytes(0)
err = binary.Read(reader, binary.LittleEndian, &line)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
logLine.Line = string(line) logLine.Line = string(bytes)
return &Info{LogLine: &logLine}, nil return &Info{LogLine: &logLine}, nil
} }
case InfoBandwidthStatsV4: case InfoBandwidthStatsV4:
{ {
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV4")
// Read Protocol // Read Protocol
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = helper.ReadData(&protocol)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = helper.ReadData(&size)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read array // Read array
statsArray := make([]BandwidthValueV4, size) statsArray := make([]BandwidthValueV4, size)
for i := range int(size) { for i := range int(size) {
err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) err = helper.ReadData(&statsArray[i])
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
} }
@ -253,24 +353,25 @@ func RecvInfo(reader io.Reader) (*Info, error) {
} }
case InfoBandwidthStatsV6: case InfoBandwidthStatsV6:
{ {
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV6")
// Read Protocol // Read Protocol
var protocol uint8 var protocol uint8
err = binary.Read(reader, binary.LittleEndian, &protocol) err = helper.ReadData(&protocol)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read size of array // Read size of array
var size uint32 var size uint32
err = binary.Read(reader, binary.LittleEndian, &size) err = helper.ReadData(&size)
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
// Read array // Read array
statsArray := make([]BandwidthValueV6, size) statsArray := make([]BandwidthValueV6, size)
for i := range int(size) { for i := range int(size) {
err = binary.Read(reader, binary.LittleEndian, &statsArray[i]) err = helper.ReadData(&statsArray[i])
if err != nil { if err != nil {
return nil, errors.Join(ErrUnexpectedReadError, err) return nil, errors.Join(parseError, err)
} }
} }
@ -278,10 +379,5 @@ func RecvInfo(reader io.Reader) (*Info, error) {
} }
} }
// Command not recognized, read until the end of command and return.
// During normal operation this should not happen.
unknownData := make([]byte, size)
_, _ = reader.Read(unknownData)
return nil, ErrUnknownInfoType return nil, ErrUnknownInfoType
} }

View file

@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/safing/portbase/log"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@ -221,7 +222,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
// Check if there is an old service. // Check if there is an old service.
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS) service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
if err == nil { if err == nil {
fmt.Println("kext: old driver service was found") log.Warning("kext: old driver service was found")
oldService := &KextService{handle: service, driverName: driverName} oldService := &KextService{handle: service, driverName: driverName}
oldService.Stop(true) oldService.Stop(true)
err = oldService.Delete() err = oldService.Delete()
@ -234,7 +235,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
} }
service = winInvalidHandleValue service = winInvalidHandleValue
fmt.Println("kext: old driver service was deleted successfully") log.Warning("kext: old driver service was deleted successfully")
} }
driverPathU16, err := syscall.UTF16FromString(driverPath) driverPathU16, err := syscall.UTF16FromString(driverPath)

View file

@ -18,8 +18,18 @@ func TestRustInfoFile(t *testing.T) {
defer func() { defer func() {
_ = file.Close() _ = file.Close()
}() }()
first := true
for { for {
info, err := RecvInfo(file) info, err := RecvInfo(file)
// First info should be with invalid size.
// This tests if invalid info data is handled properly.
if first {
if !errors.Is(err, ErrUnexpectedInfoSize) {
t.Errorf("unexpected error: %s\n", err)
}
first = false
continue
}
if err != nil { if err != nil {
if errors.Is(err, ErrUnexpectedReadError) { if errors.Is(err, ErrUnexpectedReadError) {
t.Errorf("unexpected error: %s\n", err) t.Errorf("unexpected error: %s\n", err)

View file

@ -441,6 +441,22 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
for _ in 0..selected.capacity() { for _ in 0..selected.capacity() {
selected.push(enums.choose(&mut rng).unwrap().clone()); selected.push(enums.choose(&mut rng).unwrap().clone());
} }
// Write wrong size data. To make sure that mismatches between kext and portmaster are handled properly.
let mut info = connection_info_v6(
1,
2,
3,
4,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
5,
6,
7,
&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
);
info.assert_size();
info.0[0] = InfoType::ConnectionIpv4 as u8;
file.write_all(&info.0)?;
for value in selected { for value in selected {
file.write_all(&match value { file.write_all(&match value {
@ -548,5 +564,6 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
} }
})?; })?;
} }
return Ok(()); return Ok(());
} }

View file

@ -84,19 +84,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
dependencies = [ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]] [[package]]
name = "windows-targets" name = "windows-targets"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
dependencies = [ dependencies = [
"windows_aarch64_gnullvm", "windows_aarch64_gnullvm",
"windows_aarch64_msvc", "windows_aarch64_msvc",
"windows_i686_gnu", "windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc", "windows_i686_msvc",
"windows_x86_64_gnu", "windows_x86_64_gnu",
"windows_x86_64_gnullvm", "windows_x86_64_gnullvm",
@ -105,35 +106,40 @@ dependencies = [
[[package]] [[package]]
name = "windows_aarch64_gnullvm" name = "windows_aarch64_gnullvm"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.52.0" version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"

View file

@ -16,5 +16,5 @@ features = ["alloc"]
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels. # WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
[dependencies.windows-sys] [dependencies.windows-sys]
git = "https://github.com/microsoft/windows-rs" git = "https://github.com/microsoft/windows-rs"
rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06" rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317"
features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform", "Win32_System_Rpc"] features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform", "Win32_System_Rpc"]

View file

@ -10,5 +10,7 @@ see: `wdk/src/driver.rs`
see: `wdk/src/irp_helper.rs` see: `wdk/src/irp_helper.rs`
Open issues need to be resolved: Open issues need to be resolved:
https://github.com/microsoft/wdkmetadata/issues/59
https://github.com/microsoft/windows-rs/issues/2805 https://github.com/microsoft/windows-rs/issues/2805
Resolved:
https://github.com/microsoft/wdkmetadata/issues/59

View file

@ -43,8 +43,8 @@ unsafe impl GlobalAlloc for WindowsAllocator {
} }
unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
let pool = self.alloc(layout);
pool self.alloc(layout)
} }
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {

View file

@ -1,6 +1,6 @@
use windows_sys::{ use windows_sys::{
Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP}, Wdk::Foundation::{DEVICE_OBJECT, DRIVER_DISPATCH, DRIVER_OBJECT, DRIVER_UNLOAD},
Win32::Foundation::{HANDLE, NTSTATUS}, Win32::Foundation::HANDLE,
}; };
use crate::{ use crate::{
@ -23,11 +23,6 @@ pub struct Driver {
} }
unsafe impl Sync for Driver {} unsafe impl Sync for Driver {}
// This is a workaround for current state of wdk bindings.
// TODO: replace with official version when they are correct: https://github.com/microsoft/wdkmetadata/issues/59
pub type UnloadFnType = unsafe extern "system" fn(driver_object: *const DRIVER_OBJECT);
pub type MjFnType = unsafe extern "system" fn(&mut DEVICE_OBJECT, &mut IRP) -> NTSTATUS;
impl Driver { impl Driver {
pub(crate) fn new( pub(crate) fn new(
driver_object: *mut DRIVER_OBJECT, driver_object: *mut DRIVER_OBJECT,
@ -50,54 +45,54 @@ impl Driver {
return unsafe { self.device_object.as_mut() }; return unsafe { self.device_object.as_mut() };
} }
pub fn set_driver_unload(&mut self, driver_unload: UnloadFnType) { pub fn set_driver_unload(&mut self, driver_unload: DRIVER_UNLOAD) {
if let Some(driver) = unsafe { self.driver_object.as_mut() } { if let Some(driver) = unsafe { self.driver_object.as_mut() } {
driver.DriverUnload = Some(unsafe { core::mem::transmute(driver_unload) }) driver.DriverUnload = driver_unload
} }
} }
pub fn set_read_fn(&mut self, mj_fn: MjFnType) { pub fn set_read_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn(windows_sys::Wdk::System::SystemServices::IRP_MJ_READ, mj_fn); self.set_major_fn(windows_sys::Wdk::System::SystemServices::IRP_MJ_READ, mj_fn);
} }
pub fn set_write_fn(&mut self, mj_fn: MjFnType) { pub fn set_write_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn( self.set_major_fn(
windows_sys::Wdk::System::SystemServices::IRP_MJ_WRITE, windows_sys::Wdk::System::SystemServices::IRP_MJ_WRITE,
mj_fn, mj_fn,
); );
} }
pub fn set_create_fn(&mut self, mj_fn: MjFnType) { pub fn set_create_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn( self.set_major_fn(
windows_sys::Wdk::System::SystemServices::IRP_MJ_CREATE, windows_sys::Wdk::System::SystemServices::IRP_MJ_CREATE,
mj_fn, mj_fn,
); );
} }
pub fn set_device_control_fn(&mut self, mj_fn: MjFnType) { pub fn set_device_control_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn( self.set_major_fn(
windows_sys::Wdk::System::SystemServices::IRP_MJ_DEVICE_CONTROL, windows_sys::Wdk::System::SystemServices::IRP_MJ_DEVICE_CONTROL,
mj_fn, mj_fn,
); );
} }
pub fn set_close_fn(&mut self, mj_fn: MjFnType) { pub fn set_close_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn( self.set_major_fn(
windows_sys::Wdk::System::SystemServices::IRP_MJ_CLOSE, windows_sys::Wdk::System::SystemServices::IRP_MJ_CLOSE,
mj_fn, mj_fn,
); );
} }
pub fn set_cleanup_fn(&mut self, mj_fn: MjFnType) { pub fn set_cleanup_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
self.set_major_fn( self.set_major_fn(
windows_sys::Wdk::System::SystemServices::IRP_MJ_CLEANUP, windows_sys::Wdk::System::SystemServices::IRP_MJ_CLEANUP,
mj_fn, mj_fn,
); );
} }
fn set_major_fn(&mut self, fn_index: u32, mj_fn: MjFnType) { fn set_major_fn(&mut self, fn_index: u32, mj_fn: DRIVER_DISPATCH) {
if let Some(driver) = unsafe { self.driver_object.as_mut() } { if let Some(driver) = unsafe { self.driver_object.as_mut() } {
driver.MajorFunction[fn_index as usize] = Some(unsafe { core::mem::transmute(mj_fn) }) driver.MajorFunction[fn_index as usize] = mj_fn
} }
} }
} }

View file

@ -270,7 +270,7 @@ impl WdfObjectAttributes {
evt_destroy_callback: None, evt_destroy_callback: None,
execution_level: WdfExecutionLevel::InheritFromParent, execution_level: WdfExecutionLevel::InheritFromParent,
synchronization_scope: WdfSynchronizationScope::InheritFromParent, synchronization_scope: WdfSynchronizationScope::InheritFromParent,
parent_object: 0, parent_object: core::ptr::null_mut(),
context_size_override: 0, context_size_override: 0,
context_type_info: core::ptr::null(), context_type_info: core::ptr::null(),
} }

View file

@ -1,7 +1,7 @@
use super::{callout_data::CalloutData, ffi, layer::Layer}; use super::{callout_data::CalloutData, ffi, layer::Layer};
use crate::ffi::FwpsCalloutClassifyFn; use crate::ffi::FwpsCalloutClassifyFn;
use alloc::{borrow::ToOwned, format, string::String}; use alloc::{borrow::ToOwned, format, string::String};
use windows_sys::Wdk::Foundation::DEVICE_OBJECT; use windows_sys::{Wdk::Foundation::DEVICE_OBJECT, Win32::Foundation::HANDLE};
pub enum FilterType { pub enum FilterType {
Resettable, Resettable,
@ -49,13 +49,13 @@ impl Callout {
pub fn register_filter( pub fn register_filter(
&mut self, &mut self,
filter_engine_handle: isize, filter_engine_handle: HANDLE,
sublayer_guid: u128, sublayer_guid: u128,
) -> Result<(), String> { ) -> Result<(), String> {
match ffi::register_filter( match ffi::register_filter(
filter_engine_handle, filter_engine_handle,
sublayer_guid, sublayer_guid,
&format!("{}-filter", self.name), &self.name,
&self.description, &self.description,
self.guid, self.guid,
self.layer, self.layer,
@ -75,14 +75,14 @@ impl Callout {
pub(crate) fn register_callout( pub(crate) fn register_callout(
&mut self, &mut self,
filter_engine_handle: isize, filter_engine_handle: HANDLE,
device_object: *mut DEVICE_OBJECT, device_object: *mut DEVICE_OBJECT,
callout_fn: FwpsCalloutClassifyFn, callout_fn: FwpsCalloutClassifyFn,
) -> Result<(), String> { ) -> Result<(), String> {
match ffi::register_callout( match ffi::register_callout(
device_object, device_object,
filter_engine_handle, filter_engine_handle,
&format!("{}-callout", self.name), &self.name,
&self.description, &self.description,
self.guid, self.guid,
self.layer, self.layer,

View file

@ -37,9 +37,7 @@ impl ClassifyDefer {
} }
ClassifyDefer::Reauthorization(_callout_id, packet_list) => { ClassifyDefer::Reauthorization(_callout_id, packet_list) => {
// There is no way to reset single filter. If another request for filter reset is trigger at the same time it will fail. // There is no way to reset single filter. If another request for filter reset is trigger at the same time it will fail.
if let Err(err) = filter_engine.reset_all_filters() { filter_engine.reset_all_filters()?;
return Err(err);
}
return Ok(packet_list); return Ok(packet_list);
} }
} }
@ -140,7 +138,7 @@ impl<'a> CalloutData<'a> {
packet_list: Option<TransportPacketList>, packet_list: Option<TransportPacketList>,
) -> Result<ClassifyDefer, String> { ) -> Result<ClassifyDefer, String> {
unsafe { unsafe {
let mut completion_context = 0; let mut completion_context: HANDLE = core::ptr::null_mut();
if let Some(completion_handle) = (*self.metadata).get_completion_handle() { if let Some(completion_handle) = (*self.metadata).get_completion_handle() {
let status = FwpsPendOperation0(completion_handle, &mut completion_context); let status = FwpsPendOperation0(completion_handle, &mut completion_context);
check_ntstatus(status)?; check_ntstatus(status)?;

View file

@ -113,9 +113,7 @@ pub(crate) fn register_callout(
check_ntstatus(status)?; check_ntstatus(status)?;
if let Err(err) = callout_add(filter_engine_handle, guid, layer, name, description) { callout_add(filter_engine_handle, guid, layer, name, description)?;
return Err(err);
}
return Ok(callout_id); return Ok(callout_id);
} }

View file

@ -154,10 +154,10 @@ impl FwpsIncomingMetadataValues {
#[allow(dead_code)] #[allow(dead_code)]
#[repr(C)] #[repr(C)]
enum FwpsDiscardModule0 { enum FwpsDiscardModule0 {
FwpsDiscardModuleNetwork = 0, Network = 0,
FwpsDiscardModuleTransport = 1, Transport = 1,
FwpsDiscardModuleGeneral = 2, General = 2,
FwpsDiscardModuleMax = 3, Max = 3,
} }
#[repr(C)] #[repr(C)]

View file

@ -107,9 +107,7 @@ impl FilterEngine {
filter_engine.callouts = Some(boxed_callouts); filter_engine.callouts = Some(boxed_callouts);
} }
if let Err(err) = filter_engine.commit() { filter_engine.commit()?
return Err(err);
}
} }
self.committed = true; self.committed = true;
info!("transaction committed"); info!("transaction committed");
@ -147,9 +145,7 @@ impl FilterEngine {
} }
} }
// Commit transaction. // Commit transaction.
if let Err(err) = filter_engine.commit() { filter_engine.commit()?;
return Err(err);
}
return Ok(()); return Ok(());
} }
@ -192,7 +188,7 @@ impl Drop for FilterEngine {
} }
} }
if self.handle != 0 && self.handle != INVALID_HANDLE_VALUE { if !self.handle.is_null() && self.handle != INVALID_HANDLE_VALUE {
_ = ffi::filter_engine_close(self.handle); _ = ffi::filter_engine_close(self.handle);
} }
} }

View file

@ -85,7 +85,7 @@ impl NetBufferList {
} }
// Allocate space in buffer, if buffer is too small. // Allocate space in buffer, if buffer is too small.
let mut buffer = alloc::vec![0 as u8; data_length as usize]; let mut buffer = alloc::vec![0_u8; data_length as usize];
let ptr = NdisGetDataBuffer(nb, data_length, buffer.as_mut_ptr(), 1, 0); let ptr = NdisGetDataBuffer(nb, data_length, buffer.as_mut_ptr(), 1, 0);
@ -209,7 +209,7 @@ impl Iterator for NetBufferListIter {
} }
} }
pub fn read_packet_partial<'a>(nbl: *mut NET_BUFFER_LIST, buffer: &'a mut [u8]) -> Result<(), ()> { pub fn read_packet_partial(nbl: *mut NET_BUFFER_LIST, buffer: &mut [u8]) -> Result<(), ()> {
unsafe { unsafe {
let Some(nbl) = nbl.as_ref() else { let Some(nbl) = nbl.as_ref() else {
return Err(()); return Err(());

View file

@ -105,9 +105,9 @@ impl Injector {
} }
let mut remote_ip: [u8; 16] = [0; 16]; let mut remote_ip: [u8; 16] = [0; 16];
if ipv6 { if ipv6 {
remote_ip[0..16].copy_from_slice(&remote_ip_slice); remote_ip[0..16].copy_from_slice(remote_ip_slice);
} else { } else {
remote_ip[0..4].copy_from_slice(&remote_ip_slice); remote_ip[0..4].copy_from_slice(remote_ip_slice);
} }
TransportPacketList { TransportPacketList {
@ -163,7 +163,7 @@ impl Injector {
let status = if packet_list.inbound { let status = if packet_list.inbound {
FwpsInjectTransportReceiveAsync0( FwpsInjectTransportReceiveAsync0(
self.transport_inject_handle, self.transport_inject_handle,
0, core::ptr::null_mut(),
core::ptr::null_mut(), core::ptr::null_mut(),
0, 0,
address_family, address_family,
@ -177,7 +177,7 @@ impl Injector {
} else { } else {
FwpsInjectTransportSendAsync1( FwpsInjectTransportSendAsync1(
self.transport_inject_handle, self.transport_inject_handle,
0, core::ptr::null_mut(),
packet_list.endpoint_handle, packet_list.endpoint_handle,
0, 0,
&mut send_params, &mut send_params,
@ -222,7 +222,7 @@ impl Injector {
unsafe { unsafe {
FwpsInjectNetworkReceiveAsync0( FwpsInjectNetworkReceiveAsync0(
inject_handle, inject_handle,
0, core::ptr::null_mut(),
0, 0,
UNSPECIFIED_COMPARTMENT_ID, UNSPECIFIED_COMPARTMENT_ID,
inject_info.interface_index, inject_info.interface_index,
@ -237,7 +237,7 @@ impl Injector {
unsafe { unsafe {
FwpsInjectNetworkSendAsync0( FwpsInjectNetworkSendAsync0(
inject_handle, inject_handle,
0, core::ptr::null_mut(),
0, 0,
UNSPECIFIED_COMPARTMENT_ID, UNSPECIFIED_COMPARTMENT_ID,
nbl, nbl,
@ -269,7 +269,7 @@ impl Injector {
} else { } else {
self.packet_inject_handle_v4 self.packet_inject_handle_v4
}; };
if inject_handle == INVALID_HANDLE_VALUE || inject_handle == 0 { if inject_handle == INVALID_HANDLE_VALUE || inject_handle.is_null() {
return false; return false;
} }
@ -309,19 +309,19 @@ impl Drop for Injector {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
if self.transport_inject_handle != INVALID_HANDLE_VALUE if self.transport_inject_handle != INVALID_HANDLE_VALUE
&& self.transport_inject_handle != 0 && !self.transport_inject_handle.is_null()
{ {
FwpsInjectionHandleDestroy0(self.transport_inject_handle); FwpsInjectionHandleDestroy0(self.transport_inject_handle);
self.transport_inject_handle = INVALID_HANDLE_VALUE; self.transport_inject_handle = INVALID_HANDLE_VALUE;
} }
if self.packet_inject_handle_v4 != INVALID_HANDLE_VALUE if self.packet_inject_handle_v4 != INVALID_HANDLE_VALUE
&& self.packet_inject_handle_v4 != 0 && !self.packet_inject_handle_v4.is_null()
{ {
FwpsInjectionHandleDestroy0(self.packet_inject_handle_v4); FwpsInjectionHandleDestroy0(self.packet_inject_handle_v4);
self.packet_inject_handle_v4 = INVALID_HANDLE_VALUE; self.packet_inject_handle_v4 = INVALID_HANDLE_VALUE;
} }
if self.packet_inject_handle_v6 != INVALID_HANDLE_VALUE if self.packet_inject_handle_v6 != INVALID_HANDLE_VALUE
&& self.packet_inject_handle_v6 != 0 && !self.packet_inject_handle_v6.is_null()
{ {
FwpsInjectionHandleDestroy0(self.packet_inject_handle_v6); FwpsInjectionHandleDestroy0(self.packet_inject_handle_v6);
self.packet_inject_handle_v6 = INVALID_HANDLE_VALUE; self.packet_inject_handle_v6 = INVALID_HANDLE_VALUE;

View file

@ -67,7 +67,7 @@ impl ReadRequest<'_> {
for i in 0..bytes_to_write { for i in 0..bytes_to_write {
self.buffer[self.fill_index + i] = bytes[i]; self.buffer[self.fill_index + i] = bytes[i];
} }
self.fill_index = self.fill_index + bytes_to_write; self.fill_index += bytes_to_write;
bytes_to_write bytes_to_write
} }
@ -94,7 +94,7 @@ impl WriteRequest<'_> {
} }
pub fn get_buffer(&self) -> &[u8] { pub fn get_buffer(&self) -> &[u8] {
&self.buffer self.buffer
} }
pub fn mark_all_as_read(&mut self) { pub fn mark_all_as_read(&mut self) {
@ -155,7 +155,7 @@ impl DeviceControlRequest<'_> {
} }
pub fn get_buffer(&self) -> &[u8] { pub fn get_buffer(&self) -> &[u8] {
&self.buffer self.buffer
} }
pub fn write(&mut self, bytes: &[u8]) -> usize { pub fn write(&mut self, bytes: &[u8]) -> usize {
let mut bytes_to_write: usize = bytes.len(); let mut bytes_to_write: usize = bytes.len();
@ -168,7 +168,7 @@ impl DeviceControlRequest<'_> {
for i in 0..bytes_to_write { for i in 0..bytes_to_write {
self.buffer[self.fill_index + i] = bytes[i]; self.buffer[self.fill_index + i] = bytes[i];
} }
self.fill_index = self.fill_index + bytes_to_write; self.fill_index += bytes_to_write;
bytes_to_write bytes_to_write
} }