[service] Fix check for invalid kext handle ()

* [service] Fix check for invalid kext handle

* [windows_kext] Use BTreeMap as cache structure

* [windows_kext] Fix synchronization bug

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update windows_kext/kextinterface/kext_file.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Vladimir Stoilov 2024-10-16 12:19:08 +03:00 committed by GitHub
parent cfd877757d
commit 355f74318d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 86 additions and 196 deletions

View file

@ -2,18 +2,6 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "ahash"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
@ -57,7 +45,6 @@ checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216"
name = "driver"
version = "0.0.0"
dependencies = [
"hashbrown",
"num",
"num-derive",
"num-traits",
@ -76,15 +63,6 @@ dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
dependencies = [
"ahash",
]
[[package]]
name = "heapless"
version = "0.7.17"
@ -217,12 +195,6 @@ dependencies = [
"syn",
]
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "proc-macro2"
version = "1.0.78"
@ -316,12 +288,6 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "wdk"
version = "0.0.0"
@ -399,23 +365,3 @@ source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e
name = "windows_x86_64_msvc"
version = "0.52.5"
source = "git+https://github.com/microsoft/windows-rs?rev=dffa8b03dc4987c278d82e88015ffe96aa8ac317#dffa8b03dc4987c278d82e88015ffe96aa8ac317"
[[package]]
name = "zerocopy"
version = "0.7.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View file

@ -17,7 +17,6 @@ num = { version = "0.4", default-features = false }
num-derive = { version = "0.4", default-features = false }
num-traits = { version = "0.2", default-features = false }
smoltcp = { version = "0.10", default-features = false, features = ["proto-ipv4", "proto-ipv6"] }
hashbrown = { version = "0.14.3", default-features = false, features = ["ahash"]}
# WARNING: Do not update. The version was choosen for a reason. See wdk/README.md for more detiels.
[dependencies.windows-sys]

View file

@ -1,14 +1,10 @@
use alloc::collections::BTreeMap;
use protocol::info::{BandwidthValueV4, BandwidthValueV6, Info};
use smoltcp::wire::{IpProtocol, Ipv4Address, Ipv6Address};
use wdk::rw_spin_lock::RwSpinLock;
use crate::driver_hashmap::DeviceHashMap;
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
pub struct Key<Address>
where
Address: Eq + PartialEq,
{
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
pub struct Key<Address: Ord> {
pub local_ip: Address,
pub local_port: u16,
pub remote_ip: Address,
@ -25,32 +21,32 @@ enum Direction {
Rx(usize),
}
pub struct Bandwidth {
stats_tcp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_tcp_v4: BTreeMap<Key<Ipv4Address>, Value>,
stats_tcp_v4_lock: RwSpinLock,
stats_tcp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_tcp_v6: BTreeMap<Key<Ipv6Address>, Value>,
stats_tcp_v6_lock: RwSpinLock,
stats_udp_v4: DeviceHashMap<Key<Ipv4Address>, Value>,
stats_udp_v4: BTreeMap<Key<Ipv4Address>, Value>,
stats_udp_v4_lock: RwSpinLock,
stats_udp_v6: DeviceHashMap<Key<Ipv6Address>, Value>,
stats_udp_v6: BTreeMap<Key<Ipv6Address>, Value>,
stats_udp_v6_lock: RwSpinLock,
}
impl Bandwidth {
pub fn new() -> Self {
Self {
stats_tcp_v4: DeviceHashMap::new(),
stats_tcp_v4: BTreeMap::new(),
stats_tcp_v4_lock: RwSpinLock::default(),
stats_tcp_v6: DeviceHashMap::new(),
stats_tcp_v6: BTreeMap::new(),
stats_tcp_v6_lock: RwSpinLock::default(),
stats_udp_v4: DeviceHashMap::new(),
stats_udp_v4: BTreeMap::new(),
stats_udp_v4_lock: RwSpinLock::default(),
stats_udp_v6: DeviceHashMap::new(),
stats_udp_v6: BTreeMap::new(),
stats_udp_v6_lock: RwSpinLock::default(),
}
}
@ -62,7 +58,7 @@ impl Bandwidth {
if self.stats_tcp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v4, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_tcp_v4, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@ -89,7 +85,7 @@ impl Bandwidth {
if self.stats_tcp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_tcp_v6, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@ -116,7 +112,7 @@ impl Bandwidth {
if self.stats_udp_v4.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_udp_v4, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_udp_v4, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@ -140,10 +136,10 @@ impl Bandwidth {
let stats_map;
{
let _guard = self.stats_udp_v6_lock.write_lock();
if self.stats_tcp_v6.is_empty() {
if self.stats_udp_v6.is_empty() {
return None;
}
stats_map = core::mem::replace(&mut self.stats_tcp_v6, DeviceHashMap::new());
stats_map = core::mem::replace(&mut self.stats_udp_v6, BTreeMap::new());
}
let mut values = alloc::vec::Vec::with_capacity(stats_map.len());
@ -235,8 +231,8 @@ impl Bandwidth {
);
}
fn update<Address: Eq + PartialEq + core::hash::Hash>(
map: &mut DeviceHashMap<Key<Address>, Value>,
fn update<Address: Ord>(
map: &mut BTreeMap<Key<Address>, Value>,
lock: &mut RwSpinLock,
key: Key<Address>,
bytes: Direction,

View file

@ -1,10 +1,8 @@
use core::time::Duration;
use crate::{
connection::{Connection, ConnectionV4, ConnectionV6, RedirectInfo, Verdict},
connection_map::{ConnectionMap, Key},
};
use alloc::{format, string::String, vec::Vec};
use alloc::vec::Vec;
use smoltcp::wire::IpProtocol;
use wdk::rw_spin_lock::RwSpinLock;
@ -128,73 +126,4 @@ impl ConnectionCache {
return size;
}
#[allow(dead_code)]
pub fn get_full_cache_info(&self) -> String {
let mut info = String::new();
let now = wdk::utils::get_system_timestamp_ms();
{
let _guard = self.lock_v4.read_lock();
for ((protocol, port), connections) in self.connections_v4.iter() {
info.push_str(&format!("{} -> {}\n", protocol, port,));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}
{
let _guard = self.lock_v6.read_lock();
for ((protocol, port), connections) in self.connections_v6.iter() {
info.push_str(&format!("{} -> {} \n", protocol, port));
for conn in connections {
let active_time_seconds =
Duration::from_millis(now - conn.get_last_accessed_time()).as_secs();
info.push_str(&format!(
"\t{}:{} -> {}:{} {} last active {}m {}s ago",
conn.local_address,
conn.local_port,
conn.remote_address,
conn.remote_port,
conn.verdict,
active_time_seconds / 60,
active_time_seconds % 60
));
if conn.has_ended() {
let end_time_seconds =
Duration::from_millis(now - conn.get_end_time()).as_secs();
info.push_str(&format!(
"\t ended {}m {}s ago",
end_time_seconds / 60,
end_time_seconds % 60
));
}
info.push('\n');
}
}
}
return info;
}
}

View file

@ -1,8 +1,7 @@
use core::{fmt::Display, time::Duration};
use crate::connection::Connection;
use alloc::vec::Vec;
use hashbrown::HashMap;
use alloc::{collections::BTreeMap, vec::Vec};
use smoltcp::wire::{IpAddress, IpProtocol};
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
@ -63,11 +62,11 @@ impl Key {
}
}
pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);
pub struct ConnectionMap<T: Connection>(BTreeMap<(IpProtocol, u16), Vec<T>>);
impl<T: Connection + Clone> ConnectionMap<T> {
pub fn new() -> Self {
Self(HashMap::new())
Self(BTreeMap::new())
}
pub fn add(&mut self, conn: T) {
@ -164,7 +163,6 @@ impl<T: Connection + Clone> ConnectionMap<T> {
self.0.retain(|_, v| !v.is_empty());
}
#[allow(dead_code)]
pub fn get_count(&self) -> usize {
let mut count = 0;
for conn in self.0.values() {
@ -172,8 +170,4 @@ impl<T: Connection + Clone> ConnectionMap<T> {
}
return count;
}
pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
self.0.iter()
}
}

View file

@ -1,25 +0,0 @@
use core::ops::{Deref, DerefMut};
use hashbrown::HashMap;
pub struct DeviceHashMap<Key, Value>(Option<HashMap<Key, Value>>);
impl<Key, Value> DeviceHashMap<Key, Value> {
pub fn new() -> Self {
Self(Some(HashMap::new()))
}
}
impl<Key, Value> Deref for DeviceHashMap<Key, Value> {
type Target = HashMap<Key, Value>;
fn deref(&self) -> &Self::Target {
self.0.as_ref().unwrap()
}
}
impl<Key, Value> DerefMut for DeviceHashMap<Key, Value> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut().unwrap()
}
}

View file

@ -13,7 +13,6 @@ mod connection;
mod connection_cache;
mod connection_map;
mod device;
mod driver_hashmap;
mod entry;
mod id_cache;
pub mod logger;

View file

@ -4,6 +4,8 @@ use wdk::filter_engine::{callout_data::CalloutData, layer, net_buffer::NetBuffer
use crate::{bandwidth, connection::Direction};
pub fn stream_layer_tcp_v4(data: CalloutData) {
type Fields = layer::FieldsStreamV4;
let Some(device) = crate::entry::get_device() else {
return;
};
@ -16,7 +18,6 @@ pub fn stream_layer_tcp_v4(data: CalloutData) {
} else {
return;
};
type Fields = layer::FieldsStreamV4;
let local_ip = Ipv4Address::from_bytes(
&data
.get_value_u32(Fields::IpLocalAddress as usize)
@ -56,6 +57,8 @@ pub fn stream_layer_tcp_v4(data: CalloutData) {
}
pub fn stream_layer_tcp_v6(data: CalloutData) {
type Fields = layer::FieldsStreamV6;
let Some(device) = crate::entry::get_device() else {
return;
};
@ -68,16 +71,18 @@ pub fn stream_layer_tcp_v6(data: CalloutData) {
} else {
return;
};
type Fields = layer::FieldsStreamV6;
if data_length == 0 {
return;
}
let local_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpLocalAddress as usize));
let local_port = data.get_value_u16(Fields::IpLocalPort as usize);
let remote_ip =
Ipv6Address::from_bytes(data.get_value_byte_array16(Fields::IpRemoteAddress as usize));
let remote_port = data.get_value_u16(Fields::IpRemotePort as usize);
match direction {
Direction::Outbound => {
device.bandwidth_stats.update_tcp_v6_tx(
@ -105,6 +110,8 @@ pub fn stream_layer_tcp_v6(data: CalloutData) {
}
pub fn stream_layer_udp_v4(data: CalloutData) {
type Fields = layer::FieldsDatagramDataV4;
let Some(device) = crate::entry::get_device() else {
return;
};
@ -112,7 +119,6 @@ pub fn stream_layer_udp_v4(data: CalloutData) {
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
type Fields = layer::FieldsDatagramDataV4;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;
@ -157,6 +163,8 @@ pub fn stream_layer_udp_v4(data: CalloutData) {
}
pub fn stream_layer_udp_v6(data: CalloutData) {
type Fields = layer::FieldsDatagramDataV6;
let Some(device) = crate::entry::get_device() else {
return;
};
@ -164,7 +172,6 @@ pub fn stream_layer_udp_v6(data: CalloutData) {
for nbl in NetBufferListIter::new(data.get_layer_data() as _) {
data_length += nbl.get_data_length() as usize;
}
type Fields = layer::FieldsDatagramDataV6;
let mut direction = Direction::Inbound;
if data.get_value_u8(Fields::Direction as usize) == 0 {
direction = Direction::Outbound;

View file

@ -38,7 +38,7 @@ var (
)
const (
winInvalidHandleValue = windows.Handle(^uintptr(0)) // Max value
winInvalidHandleValue = windows.InvalidHandle
stopServiceTimeoutDuration = time.Duration(30 * time.Second)
)
@ -48,7 +48,7 @@ type KextService struct {
}
func (s *KextService) isValid() bool {
return s != nil && s.handle != winInvalidHandleValue && s.handle != 0
return s != nil && s.handle != windows.InvalidHandle && s.handle != 0
}
func (s *KextService) isRunning() (bool, error) {
@ -99,7 +99,7 @@ func (s *KextService) Start(wait bool) error {
_ = windows.ControlService(s.handle, windows.SERVICE_CONTROL_STOP, &status)
_ = windows.DeleteService(s.handle)
_ = windows.CloseServiceHandle(s.handle)
s.handle = winInvalidHandleValue
s.handle = windows.InvalidHandle
return err
}
}
@ -158,7 +158,7 @@ func (s *KextService) Delete() error {
return fmt.Errorf("failed to close service handle: %s", err)
}
s.handle = winInvalidHandleValue
s.handle = windows.InvalidHandle
return nil
}
@ -234,7 +234,7 @@ func CreateKextService(driverName string, driverPath string) (*KextService, erro
return nil, err
}
service = winInvalidHandleValue
service = windows.InvalidHandle
log.Warning("kext: old driver service was deleted successfully")
}

View file

@ -4,6 +4,8 @@
package kextinterface
import (
"fmt"
"golang.org/x/sys/windows"
)
@ -13,7 +15,16 @@ type KextFile struct {
read_slice []byte
}
// Read tries to read the supplied buffer length from the driver.
// The data from the driver is read in chunks `len(f.buffer)` and the extra data is cached for the next call.
// The performance penalty of calling the function with small buffers is very small.
// The function will block until the next info packet is received from the kext.
func (f *KextFile) Read(buffer []byte) (int, error) {
if err := f.IsValid(); err != nil {
return 0, fmt.Errorf("failed to read: %w", err)
}
// If no data is available from previous calls, read from kext.
if f.read_slice == nil || len(f.read_slice) == 0 {
err := f.refill_read_buffer()
if err != nil {
@ -22,14 +33,19 @@ func (f *KextFile) Read(buffer []byte) (int, error) {
}
if len(f.read_slice) >= len(buffer) {
// Write all requested bytes.
// There is enough data to fill the requested buffer.
copy(buffer, f.read_slice[0:len(buffer)])
// Move the slice to contain the remaining data.
f.read_slice = f.read_slice[len(buffer):]
} else {
// Write all available bytes and read again.
// There is not enough data to fill the requested buffer.
// Write everything available.
copy(buffer[0:len(f.read_slice)], f.read_slice)
copiedBytes := len(f.read_slice)
f.read_slice = nil
// Read again.
_, err := f.Read(buffer[copiedBytes:])
if err != nil {
return 0, err
@ -51,20 +67,33 @@ func (f *KextFile) refill_read_buffer() error {
return nil
}
// Write sends the buffer bytes to the kext. The function will block until the whole buffer is written to the kext.
func (f *KextFile) Write(buffer []byte) (int, error) {
if err := f.IsValid(); err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
var count uint32 = 0
overlapped := &windows.Overlapped{}
err := windows.WriteFile(f.handle, buffer, &count, overlapped)
return int(count), err
}
// Close closes the handle to the kext. This will cancel all active Reads and Writes.
func (f *KextFile) Close() error {
if err := f.IsValid(); err != nil {
return fmt.Errorf("failed to close: %w", err)
}
err := windows.CloseHandle(f.handle)
f.handle = winInvalidHandleValue
f.handle = windows.InvalidHandle
return err
}
// deviceIOControl exists for compatibility with the old kext.
func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (*windows.Overlapped, error) {
if err := f.IsValid(); err != nil {
return nil, fmt.Errorf("failed to send io control: %w", err)
}
// Prepare the input data
var inDataPtr *byte = nil
var inDataSize uint32 = 0
if inData != nil {
@ -72,6 +101,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
inDataSize = uint32(len(inData))
}
// Prepare the output data
var outDataPtr *byte = nil
var outDataSize uint32 = 0
if outData != nil {
@ -79,6 +109,7 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
outDataSize = uint32(len(outData))
}
// Make the request to the kext.
overlapped := &windows.Overlapped{}
err := windows.DeviceIoControl(f.handle,
code,
@ -92,6 +123,20 @@ func (f *KextFile) deviceIOControl(code uint32, inData []byte, outData []byte) (
return overlapped, nil
}
// GetHandle returns the handle of the kext.
func (f *KextFile) GetHandle() windows.Handle {
return f.handle
}
// IsValid checks if kext file holds a valid handle to the kext driver.
func (f *KextFile) IsValid() error {
if f == nil {
return fmt.Errorf("nil kext file")
}
if f.handle == windows.Handle(0) || f.handle == windows.InvalidHandle {
return fmt.Errorf("invalid handle")
}
return nil
}