mirror of
https://github.com/safing/portmaster
synced 2025-09-02 02:29:12 +00:00
Merge pull request #1599 from safing/fix/kext-bug
[service] Fix kext verdict of update command
This commit is contained in:
commit
b7c32ec6de
27 changed files with 370 additions and 199 deletions
|
@ -5,11 +5,13 @@ package windowskext
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/service/process"
|
||||
"github.com/safing/portmaster/windows_kext/kextinterface"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
|
@ -32,8 +34,15 @@ func (v *VersionInfo) String() string {
|
|||
func Handler(ctx context.Context, packets chan packet.Packet, bandwidthUpdate chan *packet.BandwidthUpdate) {
|
||||
for {
|
||||
packetInfo, err := RecvVerdictRequest()
|
||||
|
||||
if errors.Is(err, kextinterface.ErrUnexpectedInfoSize) || errors.Is(err, kextinterface.ErrUnexpectedReadError) {
|
||||
log.Criticalf("unexpected kext info data: %s", err)
|
||||
continue // Depending on the info type this may not affect the functionality. Try to continue reading the next commands.
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warningf("failed to get packet from windows kext: %s", err)
|
||||
// Probably IO error, nothing else we can do.
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -39,9 +39,12 @@ func Start() error {
|
|||
}
|
||||
|
||||
// Start service and open file
|
||||
service.Start(true)
|
||||
kextFile, err = service.OpenFile(1024)
|
||||
err = service.Start(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to start service: %s", err)
|
||||
}
|
||||
|
||||
kextFile, err = service.OpenFile(1024)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open driver: %w", err)
|
||||
}
|
||||
|
@ -130,7 +133,7 @@ func UpdateVerdict(conn *network.Connection) error {
|
|||
LocalPort: conn.LocalPort,
|
||||
RemoteAddress: [4]byte(conn.Entity.IP),
|
||||
RemotePort: conn.Entity.Port,
|
||||
Verdict: uint8(conn.Verdict),
|
||||
Verdict: uint8(getKextVerdictFromConnection(conn)),
|
||||
}
|
||||
|
||||
return kextinterface.SendUpdateV4Command(kextFile, update)
|
||||
|
@ -141,7 +144,7 @@ func UpdateVerdict(conn *network.Connection) error {
|
|||
LocalPort: conn.LocalPort,
|
||||
RemoteAddress: [16]byte(conn.Entity.IP),
|
||||
RemotePort: conn.Entity.Port,
|
||||
Verdict: uint8(conn.Verdict),
|
||||
Verdict: uint8(getKextVerdictFromConnection(conn)),
|
||||
}
|
||||
|
||||
return kextinterface.SendUpdateV6Command(kextFile, update)
|
||||
|
@ -149,6 +152,40 @@ func UpdateVerdict(conn *network.Connection) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getKextVerdictFromConnection(conn *network.Connection) kextinterface.KextVerdict {
|
||||
switch conn.Verdict {
|
||||
case network.VerdictUndecided:
|
||||
return kextinterface.VerdictUndecided
|
||||
case network.VerdictUndeterminable:
|
||||
return kextinterface.VerdictUndeterminable
|
||||
case network.VerdictAccept:
|
||||
if conn.VerdictPermanent {
|
||||
return kextinterface.VerdictPermanentAccept
|
||||
} else {
|
||||
return kextinterface.VerdictAccept
|
||||
}
|
||||
case network.VerdictBlock:
|
||||
if conn.VerdictPermanent {
|
||||
return kextinterface.VerdictPermanentBlock
|
||||
} else {
|
||||
return kextinterface.VerdictBlock
|
||||
}
|
||||
case network.VerdictDrop:
|
||||
if conn.VerdictPermanent {
|
||||
return kextinterface.VerdictPermanentDrop
|
||||
} else {
|
||||
return kextinterface.VerdictDrop
|
||||
}
|
||||
case network.VerdictRerouteToNameserver:
|
||||
return kextinterface.VerdictRerouteToNameserver
|
||||
case network.VerdictRerouteToTunnel:
|
||||
return kextinterface.VerdictRerouteToTunnel
|
||||
case network.VerdictFailed:
|
||||
return kextinterface.VerdictFailed
|
||||
}
|
||||
return kextinterface.VerdictUndeterminable
|
||||
}
|
||||
|
||||
// Returns the kext version.
|
||||
func GetVersion() (*VersionInfo, error) {
|
||||
data, err := kextinterface.ReadVersion(kextFile)
|
||||
|
|
40
windows_kext/driver/Cargo.lock
generated
40
windows_kext/driver/Cargo.lock
generated
|
@ -340,19 +340,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
|
|||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
|
@ -361,38 +362,43 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
|
|
|
@ -22,5 +22,5 @@ hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]
|
|||
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
|
||||
[dependencies.windows-sys]
|
||||
git = "https://github.com/microsoft/windows-rs"
|
||||
rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform"]
|
||||
|
|
|
@ -12,8 +12,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// -----------------------------------------
|
||||
// ALE Auth layers
|
||||
Callout::new(
|
||||
"AleLayerOutboundV4",
|
||||
"ALE layer for outbound connection for ipv4",
|
||||
"Portmaster ALE Outbound IPv4",
|
||||
"Portmaster uses this layer to block/permit outgoing ipv4 connections",
|
||||
0x58545073_f893_454c_bbea_a57bc964f46d,
|
||||
Layer::AleAuthConnectV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
@ -21,8 +21,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
ale_callouts::ale_layer_connect_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"AleLayerOutboundV6",
|
||||
"ALE layer for outbound connections for ipv6",
|
||||
"Portmaster ALE Outbound IPv6",
|
||||
"Portmaster uses this layer to block/permit outgoing ipv6 connections",
|
||||
0x4bd2a080_2585_478d_977c_7f340c6bc3d4,
|
||||
Layer::AleAuthConnectV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
@ -32,8 +32,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// -----------------------------------------
|
||||
// ALE connection end layers
|
||||
Callout::new(
|
||||
"AleEndpointClosureV4",
|
||||
"ALE layer for indicating closing of connection for ipv4",
|
||||
"Portmaster Endpoint Closure IPv4",
|
||||
"Portmaster uses this layer to detect when a IPv4 connection has ended",
|
||||
0x58f02845_ace9_4455_ac80_8a84b86fe566,
|
||||
Layer::AleEndpointClosureV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -41,8 +41,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
ale_callouts::endpoint_closure_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"AleEndpointClosureV6",
|
||||
"ALE layer for indicating closing of connection for ipv6",
|
||||
"Portmaster Endpoint Closure IPv6",
|
||||
"Portmaster uses this layer to detect when a IPv6 connection has ended",
|
||||
0x2bc82359_9dc5_4315_9c93_c89467e283ce,
|
||||
Layer::AleEndpointClosureV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -61,8 +61,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// ale_callouts::ale_resource_monitor,
|
||||
// ),
|
||||
Callout::new(
|
||||
"AleResourceReleaseV4",
|
||||
"Ipv4 Port release monitor",
|
||||
"Portmaster resource release IPv4",
|
||||
"Portmaster uses this layer to detect when a IPv4 port has been released",
|
||||
0x7b513bb3_a0be_4f77_a4bc_03c052abe8d7,
|
||||
Layer::AleResourceReleaseV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -79,8 +79,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// ale_callouts::ale_resource_monitor,
|
||||
// ),
|
||||
Callout::new(
|
||||
"AleResourceReleaseV6",
|
||||
"Ipv6 Port release monitor",
|
||||
"Portmaster resource release IPv6",
|
||||
"Portmaster uses this layer to detect when a IPv6 port has been released",
|
||||
0x6cf36e04_e656_42c3_8cac_a1ce05328bd1,
|
||||
Layer::AleResourceReleaseV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -90,8 +90,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// -----------------------------------------
|
||||
// Stream layer
|
||||
Callout::new(
|
||||
"StreamLayerV4",
|
||||
"Stream layer for ipv4",
|
||||
"Portmaster Stream IPv4",
|
||||
"Portmaster uses this layer for bandwidth statistics of IPv4 TCP connections",
|
||||
0xe2ca13bf_9710_4caa_a45c_e8c78b5ac780,
|
||||
Layer::StreamV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -99,8 +99,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
stream_callouts::stream_layer_tcp_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"StreamLayerV6",
|
||||
"Stream layer for ipv6",
|
||||
"Portmaster Stream IPv6",
|
||||
"Portmaster uses this layer for bandwidth statistics of IPv6 TCP connections",
|
||||
0x66c549b3_11e2_4b27_8f73_856e6fd82baa,
|
||||
Layer::StreamV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -108,8 +108,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
stream_callouts::stream_layer_tcp_v6,
|
||||
),
|
||||
Callout::new(
|
||||
"DatagramDataLayerV4",
|
||||
"DatagramData layer for ipv4",
|
||||
"Portmaster Datagram IPv4",
|
||||
"Portmaster uses this layer for bandwidth statistics of IPv4 UDP connections",
|
||||
0xe7eeeaba_168a_45bb_8747_e1a702feb2c5,
|
||||
Layer::DatagramDataV4,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -117,8 +117,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
stream_callouts::stream_layer_udp_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"DatagramDataLayerV6",
|
||||
"DatagramData layer for ipv4",
|
||||
"Portmaster Datagram IPv6",
|
||||
"Portmaster uses this layer for bandwidth statistics of IPv6 UDP connections",
|
||||
0xb25862cd_f744_4452_b14a_d0c1e5a25b30,
|
||||
Layer::DatagramDataV6,
|
||||
consts::FWP_ACTION_CALLOUT_INSPECTION,
|
||||
|
@ -128,8 +128,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
// -----------------------------------------
|
||||
// Packet layers
|
||||
Callout::new(
|
||||
"IPPacketOutboundV4",
|
||||
"IP packet outbound network layer callout for Ipv4",
|
||||
"Portmaster Packet Outbound IPv4",
|
||||
"Portmaster uses this layer to redirect/block/permit outgoing ipv4 packets",
|
||||
0xf3183afe_dc35_49f1_8ea2_b16b5666dd36,
|
||||
Layer::OutboundIppacketV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
@ -137,8 +137,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
packet_callouts::ip_packet_layer_outbound_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketInboundV4",
|
||||
"IP packet inbound network layer callout for Ipv4",
|
||||
"Portmaster Packet Inbound IPv4",
|
||||
"Portmaster uses this layer to redirect/block/permit inbound ipv4 packets",
|
||||
0xf0369374_203d_4bf0_83d2_b2ad3cc17a50,
|
||||
Layer::InboundIppacketV4,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
@ -146,8 +146,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
packet_callouts::ip_packet_layer_inbound_v4,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketOutboundV6",
|
||||
"IP packet outbound network layer callout for Ipv6",
|
||||
"Portmaster Packet Outbound IPv6",
|
||||
"Portmaster uses this layer to redirect/block/permit outgoing ipv6 packets",
|
||||
0x91daf8bc_0908_4bf8_9f81_2c538ab8f25a,
|
||||
Layer::OutboundIppacketV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
@ -155,8 +155,8 @@ pub fn get_callout_vec() -> Vec<Callout> {
|
|||
packet_callouts::ip_packet_layer_outbound_v6,
|
||||
),
|
||||
Callout::new(
|
||||
"IPPacketInboundV6",
|
||||
"IP packet inbound network layer callout for Ipv6",
|
||||
"Portmaster Packet Inbound IPv6",
|
||||
"Portmaster uses this layer to redirect/block/permit inbound ipv6 packets",
|
||||
0xfe9faf5f_ceb2_4cd9_9995_f2f2b4f5fcc0,
|
||||
Layer::InboundIppacketV6,
|
||||
consts::FWP_ACTION_CALLOUT_TERMINATING,
|
||||
|
|
|
@ -38,10 +38,10 @@ pub extern "system" fn driver_entry(
|
|||
};
|
||||
|
||||
// Set driver functions.
|
||||
driver.set_driver_unload(driver_unload);
|
||||
driver.set_read_fn(driver_read);
|
||||
driver.set_write_fn(driver_write);
|
||||
driver.set_device_control_fn(device_control);
|
||||
driver.set_driver_unload(Some(driver_unload));
|
||||
driver.set_read_fn(Some(driver_read));
|
||||
driver.set_write_fn(Some(driver_write));
|
||||
driver.set_device_control_fn(Some(device_control));
|
||||
|
||||
// Initialize device.
|
||||
unsafe {
|
||||
|
@ -70,10 +70,10 @@ unsafe extern "system" fn driver_unload(_object: *const DRIVER_OBJECT) {
|
|||
|
||||
// driver_read event triggered from user-space on file.Read.
|
||||
unsafe extern "system" fn driver_read(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
_device_object: *const DEVICE_OBJECT,
|
||||
irp: *mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut read_request = ReadRequest::new(irp);
|
||||
let mut read_request = ReadRequest::new(irp.as_mut().unwrap());
|
||||
let Some(device) = get_device() else {
|
||||
read_request.complete();
|
||||
|
||||
|
@ -86,10 +86,10 @@ unsafe extern "system" fn driver_read(
|
|||
|
||||
/// driver_write event triggered from user-space on file.Write.
|
||||
unsafe extern "system" fn driver_write(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
_device_object: *const DEVICE_OBJECT,
|
||||
irp: *mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut write_request = WriteRequest::new(irp);
|
||||
let mut write_request = WriteRequest::new(irp.as_mut().unwrap());
|
||||
let Some(device) = get_device() else {
|
||||
write_request.complete();
|
||||
return write_request.get_status();
|
||||
|
@ -104,10 +104,10 @@ unsafe extern "system" fn driver_write(
|
|||
|
||||
/// device_control event triggered from user-space on file.deviceIOControl.
|
||||
unsafe extern "system" fn device_control(
|
||||
_device_object: &mut DEVICE_OBJECT,
|
||||
irp: &mut IRP,
|
||||
_device_object: *const DEVICE_OBJECT,
|
||||
irp: *mut IRP,
|
||||
) -> NTSTATUS {
|
||||
let mut control_request = DeviceControlRequest::new(irp);
|
||||
let mut control_request = DeviceControlRequest::new(irp.as_mut().unwrap());
|
||||
let Some(device) = get_device() else {
|
||||
control_request.complete();
|
||||
return control_request.get_status();
|
||||
|
|
|
@ -140,7 +140,7 @@ fn ip_packet_layer(
|
|||
} {
|
||||
Ok(key) => key,
|
||||
Err(err) => {
|
||||
crate::warn!("failed to get key from nbl: {}", err);
|
||||
crate::dbg!("failed to get key from nbl: {}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -3,6 +3,7 @@ package kextinterface
|
|||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
@ -18,6 +19,7 @@ const (
|
|||
|
||||
var (
|
||||
ErrUnknownInfoType = errors.New("unknown info type")
|
||||
ErrUnexpectedInfoSize = errors.New("unexpected info size")
|
||||
ErrUnexpectedReadError = errors.New("unexpected read error")
|
||||
)
|
||||
|
||||
|
@ -135,117 +137,215 @@ type Info struct {
|
|||
BandwidthStats *BandwidthStatsArray
|
||||
}
|
||||
|
||||
func RecvInfo(reader io.Reader) (*Info, error) {
|
||||
var infoType byte
|
||||
err := binary.Read(reader, binary.LittleEndian, &infoType)
|
||||
type readHelper struct {
|
||||
infoType byte
|
||||
commandSize uint32
|
||||
|
||||
readSize int
|
||||
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newReadHelper(reader io.Reader) (*readHelper, error) {
|
||||
helper := &readHelper{reader: reader}
|
||||
|
||||
err := binary.Read(reader, binary.LittleEndian, &helper.infoType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read size of data
|
||||
var size uint32
|
||||
err = binary.Read(reader, binary.LittleEndian, &size)
|
||||
err = binary.Read(reader, binary.LittleEndian, &helper.commandSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return helper, nil
|
||||
}
|
||||
|
||||
func (r *readHelper) ReadData(data any) error {
|
||||
err := binary.Read(r, binary.LittleEndian, data)
|
||||
if err != nil {
|
||||
return errors.Join(ErrUnexpectedReadError, err)
|
||||
}
|
||||
|
||||
if err := r.checkOverRead(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Passing size = 0 will read the rest of the command.
|
||||
func (r *readHelper) ReadBytes(size uint32) ([]byte, error) {
|
||||
if uint32(r.readSize) >= r.commandSize {
|
||||
return nil, errors.Join(fmt.Errorf("cannot read more bytes than the command size: %d >= %d", r.readSize, r.commandSize), ErrUnexpectedReadError)
|
||||
}
|
||||
|
||||
if size == 0 {
|
||||
size = r.commandSize - uint32(r.readSize)
|
||||
}
|
||||
|
||||
if r.commandSize < uint32(r.readSize)+size {
|
||||
return nil, ErrUnexpectedInfoSize
|
||||
}
|
||||
|
||||
bytes := make([]byte, size)
|
||||
err := binary.Read(r, binary.LittleEndian, bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
}
|
||||
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func (r *readHelper) ReadUntilTheEnd() {
|
||||
_, _ = r.ReadBytes(0)
|
||||
}
|
||||
|
||||
func (r *readHelper) checkOverRead() error {
|
||||
if uint32(r.readSize) > r.commandSize {
|
||||
return ErrUnexpectedInfoSize
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *readHelper) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
r.readSize += n
|
||||
return
|
||||
}
|
||||
|
||||
func RecvInfo(reader io.Reader) (*Info, error) {
|
||||
helper, err := newReadHelper(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure the whole command is read before return.
|
||||
defer helper.ReadUntilTheEnd()
|
||||
|
||||
// Read data
|
||||
switch infoType {
|
||||
switch helper.infoType {
|
||||
case InfoConnectionIpv4:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoConnectionIpv4")
|
||||
newInfo := ConnectionV4{}
|
||||
var fixedSizeValues connectionV4Internal
|
||||
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues)
|
||||
// Read fixed size values.
|
||||
err = helper.ReadData(&fixedSizeValues)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err, fmt.Errorf("fixed"))
|
||||
}
|
||||
// Read size of payload
|
||||
var size uint32
|
||||
err = binary.Read(reader, binary.LittleEndian, &size)
|
||||
newInfo.connectionV4Internal = fixedSizeValues
|
||||
// Read size of payload.
|
||||
var payloadSize uint32
|
||||
err = helper.ReadData(&payloadSize)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err, fmt.Errorf("payloadsize"))
|
||||
}
|
||||
newInfo := ConnectionV4{connectionV4Internal: fixedSizeValues, Payload: make([]byte, size)}
|
||||
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload)
|
||||
|
||||
// Check if there is payload.
|
||||
if payloadSize > 0 {
|
||||
// Read payload.
|
||||
newInfo.Payload, err = helper.ReadBytes(payloadSize)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err, fmt.Errorf("payload"))
|
||||
}
|
||||
}
|
||||
return &Info{ConnectionV4: &newInfo}, nil
|
||||
}
|
||||
case InfoConnectionIpv6:
|
||||
{
|
||||
var fixedSizeValues connectionV6Internal
|
||||
err = binary.Read(reader, binary.LittleEndian, &fixedSizeValues)
|
||||
parseError := fmt.Errorf("failed to parse InfoConnectionIpv6")
|
||||
newInfo := ConnectionV6{}
|
||||
|
||||
// Read fixed size values.
|
||||
err = helper.ReadData(&newInfo.connectionV6Internal)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read size of payload
|
||||
var size uint32
|
||||
err = binary.Read(reader, binary.LittleEndian, &size)
|
||||
|
||||
// Read size of payload.
|
||||
var payloadSize uint32
|
||||
err = helper.ReadData(&payloadSize)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
newInfo := ConnectionV6{connectionV6Internal: fixedSizeValues, Payload: make([]byte, size)}
|
||||
err = binary.Read(reader, binary.LittleEndian, &newInfo.Payload)
|
||||
|
||||
// Check if there is payload.
|
||||
if payloadSize > 0 {
|
||||
// Read payload.
|
||||
newInfo.Payload, err = helper.ReadBytes(payloadSize)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Info{ConnectionV6: &newInfo}, nil
|
||||
}
|
||||
case InfoConnectionEndEventV4:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV4")
|
||||
var connectionEnd ConnectionEndV4
|
||||
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
|
||||
|
||||
// Read fixed size values.
|
||||
err = helper.ReadData(&connectionEnd)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
return &Info{ConnectionEndV4: &connectionEnd}, nil
|
||||
}
|
||||
case InfoConnectionEndEventV6:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoConnectionEndEventV6")
|
||||
var connectionEnd ConnectionEndV6
|
||||
err = binary.Read(reader, binary.LittleEndian, &connectionEnd)
|
||||
|
||||
// Read fixed size values.
|
||||
err = helper.ReadData(&connectionEnd)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
return &Info{ConnectionEndV6: &connectionEnd}, nil
|
||||
}
|
||||
case InfoLogLine:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoLogLine")
|
||||
logLine := LogLine{}
|
||||
// Read severity
|
||||
err = binary.Read(reader, binary.LittleEndian, &logLine.Severity)
|
||||
err = helper.ReadData(&logLine.Severity)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read string
|
||||
line := make([]byte, size-1) // -1 for the severity enum.
|
||||
err = binary.Read(reader, binary.LittleEndian, &line)
|
||||
bytes, err := helper.ReadBytes(0)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
logLine.Line = string(line)
|
||||
logLine.Line = string(bytes)
|
||||
return &Info{LogLine: &logLine}, nil
|
||||
}
|
||||
case InfoBandwidthStatsV4:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV4")
|
||||
// Read Protocol
|
||||
var protocol uint8
|
||||
err = binary.Read(reader, binary.LittleEndian, &protocol)
|
||||
err = helper.ReadData(&protocol)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read size of array
|
||||
var size uint32
|
||||
err = binary.Read(reader, binary.LittleEndian, &size)
|
||||
err = helper.ReadData(&size)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read array
|
||||
statsArray := make([]BandwidthValueV4, size)
|
||||
for i := range int(size) {
|
||||
err = binary.Read(reader, binary.LittleEndian, &statsArray[i])
|
||||
err = helper.ReadData(&statsArray[i])
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,24 +353,25 @@ func RecvInfo(reader io.Reader) (*Info, error) {
|
|||
}
|
||||
case InfoBandwidthStatsV6:
|
||||
{
|
||||
parseError := fmt.Errorf("failed to parse InfoBandwidthStatsV6")
|
||||
// Read Protocol
|
||||
var protocol uint8
|
||||
err = binary.Read(reader, binary.LittleEndian, &protocol)
|
||||
err = helper.ReadData(&protocol)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read size of array
|
||||
var size uint32
|
||||
err = binary.Read(reader, binary.LittleEndian, &size)
|
||||
err = helper.ReadData(&size)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
// Read array
|
||||
statsArray := make([]BandwidthValueV6, size)
|
||||
for i := range int(size) {
|
||||
err = binary.Read(reader, binary.LittleEndian, &statsArray[i])
|
||||
err = helper.ReadData(&statsArray[i])
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrUnexpectedReadError, err)
|
||||
return nil, errors.Join(parseError, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -278,10 +379,5 @@ func RecvInfo(reader io.Reader) (*Info, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Command not recognized, read until the end of command and return.
|
||||
// During normal operation this should not happen.
|
||||
unknownData := make([]byte, size)
|
||||
_, _ = reader.Read(unknownData)
|
||||
|
||||
return nil, ErrUnknownInfoType
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
|
@ -221,7 +222,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
|
|||
// Check if there is an old service.
|
||||
service, err := windows.OpenService(manager, &driverNameU16[0], windows.SERVICE_ALL_ACCESS)
|
||||
if err == nil {
|
||||
fmt.Println("kext: old driver service was found")
|
||||
log.Warning("kext: old driver service was found")
|
||||
oldService := &KextService{handle: service, driverName: driverName}
|
||||
oldService.Stop(true)
|
||||
err = oldService.Delete()
|
||||
|
@ -234,7 +235,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
|
|||
}
|
||||
|
||||
service = winInvalidHandleValue
|
||||
fmt.Println("kext: old driver service was deleted successfully")
|
||||
log.Warning("kext: old driver service was deleted successfully")
|
||||
}
|
||||
|
||||
driverPathU16, err := syscall.UTF16FromString(driverPath)
|
||||
|
|
|
@ -18,8 +18,18 @@ func TestRustInfoFile(t *testing.T) {
|
|||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
first := true
|
||||
for {
|
||||
info, err := RecvInfo(file)
|
||||
// First info should be with invalid size.
|
||||
// This tests if invalid info data is handled properly.
|
||||
if first {
|
||||
if !errors.Is(err, ErrUnexpectedInfoSize) {
|
||||
t.Errorf("unexpected error: %s\n", err)
|
||||
}
|
||||
first = false
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUnexpectedReadError) {
|
||||
t.Errorf("unexpected error: %s\n", err)
|
||||
|
|
Binary file not shown.
|
@ -441,6 +441,22 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
|
|||
for _ in 0..selected.capacity() {
|
||||
selected.push(enums.choose(&mut rng).unwrap().clone());
|
||||
}
|
||||
// Write wrong size data. To make sure that mismatches between kext and portmaster are handled properly.
|
||||
let mut info = connection_info_v6(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
);
|
||||
info.assert_size();
|
||||
info.0[0] = InfoType::ConnectionIpv4 as u8;
|
||||
file.write_all(&info.0)?;
|
||||
|
||||
for value in selected {
|
||||
file.write_all(&match value {
|
||||
|
@ -548,5 +564,6 @@ fn generate_test_info_file() -> Result<(), std::io::Error> {
|
|||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
|
BIN
windows_kext/protocol/testdata/go_command_test.bin
vendored
BIN
windows_kext/protocol/testdata/go_command_test.bin
vendored
Binary file not shown.
40
windows_kext/wdk/Cargo.lock
generated
40
windows_kext/wdk/Cargo.lock
generated
|
@ -84,19 +84,20 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
|
|||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
|
@ -105,35 +106,40 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=41ad38d8c42c92fd23fe25ba4dca76c2d861ca06#41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
version = "0.52.5"
|
||||
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
|
|
|
@ -16,5 +16,5 @@ features = ["alloc"]
|
|||
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
|
||||
[dependencies.windows-sys]
|
||||
git = "https://github.com/microsoft/windows-rs"
|
||||
rev = "41ad38d8c42c92fd23fe25ba4dca76c2d861ca06"
|
||||
rev = "dffa8b03dc4987c278d82e88015ffe96aa8ac317"
|
||||
features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_SystemServices", "Win32_Foundation", "Win32_Security", "Win32_System_IO", "Win32_System_Kernel", "Win32_System_Power", "Win32_System_WindowsProgramming", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_NetworkManagement_WindowsFilteringPlatform", "Win32_System_Rpc"]
|
||||
|
|
|
@ -10,5 +10,7 @@ see: `wdk/src/driver.rs`
|
|||
see: `wdk/src/irp_helper.rs`
|
||||
|
||||
Open issues need to be resolved:
|
||||
https://github.com/microsoft/wdkmetadata/issues/59
|
||||
https://github.com/microsoft/windows-rs/issues/2805
|
||||
|
||||
Resolved:
|
||||
https://github.com/microsoft/wdkmetadata/issues/59
|
||||
|
|
|
@ -43,8 +43,8 @@ unsafe impl GlobalAlloc for WindowsAllocator {
|
|||
}
|
||||
|
||||
unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
|
||||
let pool = self.alloc(layout);
|
||||
pool
|
||||
|
||||
self.alloc(layout)
|
||||
}
|
||||
|
||||
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use windows_sys::{
|
||||
Wdk::Foundation::{DEVICE_OBJECT, DRIVER_OBJECT, IRP},
|
||||
Win32::Foundation::{HANDLE, NTSTATUS},
|
||||
Wdk::Foundation::{DEVICE_OBJECT, DRIVER_DISPATCH, DRIVER_OBJECT, DRIVER_UNLOAD},
|
||||
Win32::Foundation::HANDLE,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
@ -23,11 +23,6 @@ pub struct Driver {
|
|||
}
|
||||
unsafe impl Sync for Driver {}
|
||||
|
||||
// This is a workaround for current state of wdk bindings.
|
||||
// TODO: replace with official version when they are correct: https://github.com/microsoft/wdkmetadata/issues/59
|
||||
pub type UnloadFnType = unsafe extern "system" fn(driver_object: *const DRIVER_OBJECT);
|
||||
pub type MjFnType = unsafe extern "system" fn(&mut DEVICE_OBJECT, &mut IRP) -> NTSTATUS;
|
||||
|
||||
impl Driver {
|
||||
pub(crate) fn new(
|
||||
driver_object: *mut DRIVER_OBJECT,
|
||||
|
@ -50,54 +45,54 @@ impl Driver {
|
|||
return unsafe { self.device_object.as_mut() };
|
||||
}
|
||||
|
||||
pub fn set_driver_unload(&mut self, driver_unload: UnloadFnType) {
|
||||
pub fn set_driver_unload(&mut self, driver_unload: DRIVER_UNLOAD) {
|
||||
if let Some(driver) = unsafe { self.driver_object.as_mut() } {
|
||||
driver.DriverUnload = Some(unsafe { core::mem::transmute(driver_unload) })
|
||||
driver.DriverUnload = driver_unload
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_read_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_read_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(windows_sys::Wdk::System::SystemServices::IRP_MJ_READ, mj_fn);
|
||||
}
|
||||
|
||||
pub fn set_write_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_write_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(
|
||||
windows_sys::Wdk::System::SystemServices::IRP_MJ_WRITE,
|
||||
mj_fn,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_create_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_create_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(
|
||||
windows_sys::Wdk::System::SystemServices::IRP_MJ_CREATE,
|
||||
mj_fn,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_device_control_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_device_control_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(
|
||||
windows_sys::Wdk::System::SystemServices::IRP_MJ_DEVICE_CONTROL,
|
||||
mj_fn,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_close_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_close_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(
|
||||
windows_sys::Wdk::System::SystemServices::IRP_MJ_CLOSE,
|
||||
mj_fn,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn set_cleanup_fn(&mut self, mj_fn: MjFnType) {
|
||||
pub fn set_cleanup_fn(&mut self, mj_fn: DRIVER_DISPATCH) {
|
||||
self.set_major_fn(
|
||||
windows_sys::Wdk::System::SystemServices::IRP_MJ_CLEANUP,
|
||||
mj_fn,
|
||||
);
|
||||
}
|
||||
|
||||
fn set_major_fn(&mut self, fn_index: u32, mj_fn: MjFnType) {
|
||||
fn set_major_fn(&mut self, fn_index: u32, mj_fn: DRIVER_DISPATCH) {
|
||||
if let Some(driver) = unsafe { self.driver_object.as_mut() } {
|
||||
driver.MajorFunction[fn_index as usize] = Some(unsafe { core::mem::transmute(mj_fn) })
|
||||
driver.MajorFunction[fn_index as usize] = mj_fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -270,7 +270,7 @@ impl WdfObjectAttributes {
|
|||
evt_destroy_callback: None,
|
||||
execution_level: WdfExecutionLevel::InheritFromParent,
|
||||
synchronization_scope: WdfSynchronizationScope::InheritFromParent,
|
||||
parent_object: 0,
|
||||
parent_object: core::ptr::null_mut(),
|
||||
context_size_override: 0,
|
||||
context_type_info: core::ptr::null(),
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::{callout_data::CalloutData, ffi, layer::Layer};
|
||||
use crate::ffi::FwpsCalloutClassifyFn;
|
||||
use alloc::{borrow::ToOwned, format, string::String};
|
||||
use windows_sys::Wdk::Foundation::DEVICE_OBJECT;
|
||||
use windows_sys::{Wdk::Foundation::DEVICE_OBJECT, Win32::Foundation::HANDLE};
|
||||
|
||||
pub enum FilterType {
|
||||
Resettable,
|
||||
|
@ -49,13 +49,13 @@ impl Callout {
|
|||
|
||||
pub fn register_filter(
|
||||
&mut self,
|
||||
filter_engine_handle: isize,
|
||||
filter_engine_handle: HANDLE,
|
||||
sublayer_guid: u128,
|
||||
) -> Result<(), String> {
|
||||
match ffi::register_filter(
|
||||
filter_engine_handle,
|
||||
sublayer_guid,
|
||||
&format!("{}-filter", self.name),
|
||||
&self.name,
|
||||
&self.description,
|
||||
self.guid,
|
||||
self.layer,
|
||||
|
@ -75,14 +75,14 @@ impl Callout {
|
|||
|
||||
pub(crate) fn register_callout(
|
||||
&mut self,
|
||||
filter_engine_handle: isize,
|
||||
filter_engine_handle: HANDLE,
|
||||
device_object: *mut DEVICE_OBJECT,
|
||||
callout_fn: FwpsCalloutClassifyFn,
|
||||
) -> Result<(), String> {
|
||||
match ffi::register_callout(
|
||||
device_object,
|
||||
filter_engine_handle,
|
||||
&format!("{}-callout", self.name),
|
||||
&self.name,
|
||||
&self.description,
|
||||
self.guid,
|
||||
self.layer,
|
||||
|
|
|
@ -37,9 +37,7 @@ impl ClassifyDefer {
|
|||
}
|
||||
ClassifyDefer::Reauthorization(_callout_id, packet_list) => {
|
||||
// There is no way to reset single filter. If another request for filter reset is trigger at the same time it will fail.
|
||||
if let Err(err) = filter_engine.reset_all_filters() {
|
||||
return Err(err);
|
||||
}
|
||||
filter_engine.reset_all_filters()?;
|
||||
return Ok(packet_list);
|
||||
}
|
||||
}
|
||||
|
@ -140,7 +138,7 @@ impl<'a> CalloutData<'a> {
|
|||
packet_list: Option<TransportPacketList>,
|
||||
) -> Result<ClassifyDefer, String> {
|
||||
unsafe {
|
||||
let mut completion_context = 0;
|
||||
let mut completion_context: HANDLE = core::ptr::null_mut();
|
||||
if let Some(completion_handle) = (*self.metadata).get_completion_handle() {
|
||||
let status = FwpsPendOperation0(completion_handle, &mut completion_context);
|
||||
check_ntstatus(status)?;
|
||||
|
|
|
@ -113,9 +113,7 @@ pub(crate) fn register_callout(
|
|||
|
||||
check_ntstatus(status)?;
|
||||
|
||||
if let Err(err) = callout_add(filter_engine_handle, guid, layer, name, description) {
|
||||
return Err(err);
|
||||
}
|
||||
callout_add(filter_engine_handle, guid, layer, name, description)?;
|
||||
|
||||
return Ok(callout_id);
|
||||
}
|
||||
|
|
|
@ -154,10 +154,10 @@ impl FwpsIncomingMetadataValues {
|
|||
#[allow(dead_code)]
|
||||
#[repr(C)]
|
||||
enum FwpsDiscardModule0 {
|
||||
FwpsDiscardModuleNetwork = 0,
|
||||
FwpsDiscardModuleTransport = 1,
|
||||
FwpsDiscardModuleGeneral = 2,
|
||||
FwpsDiscardModuleMax = 3,
|
||||
Network = 0,
|
||||
Transport = 1,
|
||||
General = 2,
|
||||
Max = 3,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
|
|
|
@ -107,9 +107,7 @@ impl FilterEngine {
|
|||
filter_engine.callouts = Some(boxed_callouts);
|
||||
}
|
||||
|
||||
if let Err(err) = filter_engine.commit() {
|
||||
return Err(err);
|
||||
}
|
||||
filter_engine.commit()?
|
||||
}
|
||||
self.committed = true;
|
||||
info!("transaction committed");
|
||||
|
@ -147,9 +145,7 @@ impl FilterEngine {
|
|||
}
|
||||
}
|
||||
// Commit transaction.
|
||||
if let Err(err) = filter_engine.commit() {
|
||||
return Err(err);
|
||||
}
|
||||
filter_engine.commit()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -192,7 +188,7 @@ impl Drop for FilterEngine {
|
|||
}
|
||||
}
|
||||
|
||||
if self.handle != 0 && self.handle != INVALID_HANDLE_VALUE {
|
||||
if !self.handle.is_null() && self.handle != INVALID_HANDLE_VALUE {
|
||||
_ = ffi::filter_engine_close(self.handle);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ impl NetBufferList {
|
|||
}
|
||||
|
||||
// Allocate space in buffer, if buffer is too small.
|
||||
let mut buffer = alloc::vec![0 as u8; data_length as usize];
|
||||
let mut buffer = alloc::vec![0_u8; data_length as usize];
|
||||
|
||||
let ptr = NdisGetDataBuffer(nb, data_length, buffer.as_mut_ptr(), 1, 0);
|
||||
|
||||
|
@ -209,7 +209,7 @@ impl Iterator for NetBufferListIter {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn read_packet_partial<'a>(nbl: *mut NET_BUFFER_LIST, buffer: &'a mut [u8]) -> Result<(), ()> {
|
||||
pub fn read_packet_partial(nbl: *mut NET_BUFFER_LIST, buffer: &mut [u8]) -> Result<(), ()> {
|
||||
unsafe {
|
||||
let Some(nbl) = nbl.as_ref() else {
|
||||
return Err(());
|
||||
|
|
|
@ -105,9 +105,9 @@ impl Injector {
|
|||
}
|
||||
let mut remote_ip: [u8; 16] = [0; 16];
|
||||
if ipv6 {
|
||||
remote_ip[0..16].copy_from_slice(&remote_ip_slice);
|
||||
remote_ip[0..16].copy_from_slice(remote_ip_slice);
|
||||
} else {
|
||||
remote_ip[0..4].copy_from_slice(&remote_ip_slice);
|
||||
remote_ip[0..4].copy_from_slice(remote_ip_slice);
|
||||
}
|
||||
|
||||
TransportPacketList {
|
||||
|
@ -163,7 +163,7 @@ impl Injector {
|
|||
let status = if packet_list.inbound {
|
||||
FwpsInjectTransportReceiveAsync0(
|
||||
self.transport_inject_handle,
|
||||
0,
|
||||
core::ptr::null_mut(),
|
||||
core::ptr::null_mut(),
|
||||
0,
|
||||
address_family,
|
||||
|
@ -177,7 +177,7 @@ impl Injector {
|
|||
} else {
|
||||
FwpsInjectTransportSendAsync1(
|
||||
self.transport_inject_handle,
|
||||
0,
|
||||
core::ptr::null_mut(),
|
||||
packet_list.endpoint_handle,
|
||||
0,
|
||||
&mut send_params,
|
||||
|
@ -222,7 +222,7 @@ impl Injector {
|
|||
unsafe {
|
||||
FwpsInjectNetworkReceiveAsync0(
|
||||
inject_handle,
|
||||
0,
|
||||
core::ptr::null_mut(),
|
||||
0,
|
||||
UNSPECIFIED_COMPARTMENT_ID,
|
||||
inject_info.interface_index,
|
||||
|
@ -237,7 +237,7 @@ impl Injector {
|
|||
unsafe {
|
||||
FwpsInjectNetworkSendAsync0(
|
||||
inject_handle,
|
||||
0,
|
||||
core::ptr::null_mut(),
|
||||
0,
|
||||
UNSPECIFIED_COMPARTMENT_ID,
|
||||
nbl,
|
||||
|
@ -269,7 +269,7 @@ impl Injector {
|
|||
} else {
|
||||
self.packet_inject_handle_v4
|
||||
};
|
||||
if inject_handle == INVALID_HANDLE_VALUE || inject_handle == 0 {
|
||||
if inject_handle == INVALID_HANDLE_VALUE || inject_handle.is_null() {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -309,19 +309,19 @@ impl Drop for Injector {
|
|||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
if self.transport_inject_handle != INVALID_HANDLE_VALUE
|
||||
&& self.transport_inject_handle != 0
|
||||
&& !self.transport_inject_handle.is_null()
|
||||
{
|
||||
FwpsInjectionHandleDestroy0(self.transport_inject_handle);
|
||||
self.transport_inject_handle = INVALID_HANDLE_VALUE;
|
||||
}
|
||||
if self.packet_inject_handle_v4 != INVALID_HANDLE_VALUE
|
||||
&& self.packet_inject_handle_v4 != 0
|
||||
&& !self.packet_inject_handle_v4.is_null()
|
||||
{
|
||||
FwpsInjectionHandleDestroy0(self.packet_inject_handle_v4);
|
||||
self.packet_inject_handle_v4 = INVALID_HANDLE_VALUE;
|
||||
}
|
||||
if self.packet_inject_handle_v6 != INVALID_HANDLE_VALUE
|
||||
&& self.packet_inject_handle_v6 != 0
|
||||
&& !self.packet_inject_handle_v6.is_null()
|
||||
{
|
||||
FwpsInjectionHandleDestroy0(self.packet_inject_handle_v6);
|
||||
self.packet_inject_handle_v6 = INVALID_HANDLE_VALUE;
|
||||
|
|
|
@ -67,7 +67,7 @@ impl ReadRequest<'_> {
|
|||
for i in 0..bytes_to_write {
|
||||
self.buffer[self.fill_index + i] = bytes[i];
|
||||
}
|
||||
self.fill_index = self.fill_index + bytes_to_write;
|
||||
self.fill_index += bytes_to_write;
|
||||
|
||||
bytes_to_write
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ impl WriteRequest<'_> {
|
|||
}
|
||||
|
||||
pub fn get_buffer(&self) -> &[u8] {
|
||||
&self.buffer
|
||||
self.buffer
|
||||
}
|
||||
|
||||
pub fn mark_all_as_read(&mut self) {
|
||||
|
@ -155,7 +155,7 @@ impl DeviceControlRequest<'_> {
|
|||
}
|
||||
|
||||
pub fn get_buffer(&self) -> &[u8] {
|
||||
&self.buffer
|
||||
self.buffer
|
||||
}
|
||||
pub fn write(&mut self, bytes: &[u8]) -> usize {
|
||||
let mut bytes_to_write: usize = bytes.len();
|
||||
|
@ -168,7 +168,7 @@ impl DeviceControlRequest<'_> {
|
|||
for i in 0..bytes_to_write {
|
||||
self.buffer[self.fill_index + i] = bytes[i];
|
||||
}
|
||||
self.fill_index = self.fill_index + bytes_to_write;
|
||||
self.fill_index += bytes_to_write;
|
||||
|
||||
bytes_to_write
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue