safing-portmaster/windows_kext/driver/src/connection_map.rs
2024-06-04 14:55:29 +03:00

179 lines
5.3 KiB
Rust

use core::{fmt::Display, time::Duration};
use crate::connection::Connection;
use alloc::vec::Vec;
use hashbrown::HashMap;
use smoltcp::wire::{IpAddress, IpProtocol};
/// Identifies a single connection by IP protocol and by its local and remote endpoints.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields in declaration order.
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Key {
    /// IP protocol of the connection.
    pub(crate) protocol: IpProtocol,
    /// Address on the local side.
    pub(crate) local_address: IpAddress,
    /// Port on the local side.
    pub(crate) local_port: u16,
    /// Address on the remote side.
    pub(crate) remote_address: IpAddress,
    /// Port on the remote side.
    pub(crate) remote_port: u16,
}
impl Display for Key {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"p: {} l: {}:{} r: {}:{}",
self.protocol,
self.local_address,
self.local_port,
self.remote_address,
self.remote_port
)
}
}
impl Key {
    /// Returns the `(protocol, local port)` pair.
    ///
    /// This is the bucket key used by [`ConnectionMap`]; it deliberately ignores addresses
    /// and the remote port.
    pub fn small(&self) -> (IpProtocol, u16) {
        (self.protocol, self.local_port)
    }

    /// Returns true if the local address is an IPv6 address (false for IPv4).
    pub fn is_ipv6(&self) -> bool {
        matches!(self.local_address, IpAddress::Ipv6(_))
    }

    /// Returns true if the local address is a loopback address.
    pub fn is_loopback(&self) -> bool {
        match self.local_address {
            IpAddress::Ipv4(ip) => ip.is_loopback(),
            IpAddress::Ipv6(ip) => ip.is_loopback(),
        }
    }

    /// Returns a new key with the local and remote addresses and ports swapped.
    ///
    /// The protocol is kept unchanged.
    // Allowed because this helper may have no callers in the current build.
    #[allow(dead_code)]
    pub fn reverse(&self) -> Key {
        Key {
            protocol: self.protocol,
            local_address: self.remote_address,
            local_port: self.remote_port,
            remote_address: self.local_address,
            remote_port: self.local_port,
        }
    }
}
/// Connections bucketed by `(protocol, local port)` (see [`Key::small`]).
///
/// Several connections can share a bucket. Each bucket is a `Vec` that is searched linearly
/// for the full [`Key`].
pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);

impl<T: Connection + Clone> ConnectionMap<T> {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Appends `conn` to the bucket for its `(protocol, local port)`.
    ///
    /// The bucket is created if it does not exist yet.
    pub fn add(&mut self, conn: T) {
        let key = conn.get_key().small();
        // One hash lookup instead of `get_mut` followed by `insert`.
        self.0.entry(key).or_insert_with(Vec::new).push(conn);
    }

    /// Returns the first connection whose remote side matches `key`.
    ///
    /// Matching uses `remote_equals`. On a hit, the connection's last-accessed time is
    /// refreshed before it is returned.
    pub fn get_mut(&mut self, key: &Key) -> Option<&mut T> {
        let conn = self
            .0
            .get_mut(&key.small())?
            .iter_mut()
            .find(|conn| conn.remote_equals(key))?;
        conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
        Some(conn)
    }

    /// Finds the first connection matching `key` and returns `read_connection` applied to it.
    ///
    /// A connection matches if either `remote_equals` or `redirect_equals` holds. On a hit,
    /// the connection's last-accessed time is refreshed. Returns `None` if no connection
    /// matches, or whatever `read_connection` returns otherwise.
    pub fn read<C>(&self, key: &Key, read_connection: fn(&T) -> Option<C>) -> Option<C> {
        let conn = self
            .0
            .get(&key.small())?
            .iter()
            .find(|conn| conn.remote_equals(key) || conn.redirect_equals(key))?;
        conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
        read_connection(conn)
    }

    /// Marks the first connection whose remote side matches `key` as ended.
    ///
    /// The connection is stamped with the current time and a clone of it is returned.
    /// Returns `None` if no connection matches. The entry stays in the map until
    /// [`Self::clean_ended_connections`] removes it.
    pub fn end(&mut self, key: Key) -> Option<T> {
        let conn = self
            .0
            .get_mut(&key.small())?
            .iter_mut()
            .find(|conn| conn.remote_equals(&key))?;
        conn.end(wdk::utils::get_system_timestamp_ms());
        Some(conn.clone())
    }

    /// Ends every not-yet-ended connection in the `(protocol, local port)` bucket `key`.
    ///
    /// Returns clones of the connections that were ended by this call. The vector may be
    /// empty. Returns `None` if the bucket does not exist.
    pub fn end_all_on_port(&mut self, key: (IpProtocol, u16)) -> Option<Vec<T>> {
        let connections = self.0.get_mut(&key)?;
        let mut ended = Vec::with_capacity(connections.len());
        for conn in connections.iter_mut().filter(|conn| !conn.has_ended()) {
            conn.end(wdk::utils::get_system_timestamp_ms());
            ended.push(conn.clone());
        }
        Some(ended)
    }

    /// Removes all connections.
    pub fn clear(&mut self) {
        self.0.clear();
    }

    /// Drops stale connections and then drops any bucket left empty.
    ///
    /// A connection is stale if either:
    /// - it ended more than one minute ago, or
    /// - it was last accessed more than ten minutes ago.
    pub fn clean_ended_connections(&mut self) {
        const TEN_MINUTES: u64 = Duration::from_secs(60 * 10).as_millis() as u64;
        const ONE_MINUTE: u64 = Duration::from_secs(60).as_millis() as u64;

        let now = wdk::utils::get_system_timestamp_ms();
        // `saturating_sub` guards against `now` being smaller than the window. Plain
        // subtraction would underflow there: it panics with overflow checks enabled, or
        // otherwise wraps to a huge cutoff that evicts every connection.
        let before_ten_minutes = now.saturating_sub(TEN_MINUTES);
        let before_one_minute = now.saturating_sub(ONE_MINUTE);

        for connections in self.0.values_mut() {
            connections.retain(|c| {
                let ended_long_ago = c.has_ended() && c.get_end_time() < before_one_minute;
                let idle_too_long = c.get_last_accessed_time() < before_ten_minutes;
                !(ended_long_ago || idle_too_long)
            });
        }
        self.0.retain(|_, v| !v.is_empty());
    }

    /// Returns the total number of connections across all buckets.
    // Allowed because this helper may have no callers in the current build.
    #[allow(dead_code)]
    pub fn get_count(&self) -> usize {
        self.0.values().map(Vec::len).sum()
    }

    /// Iterates over `((protocol, local port), connections)` buckets.
    pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
        self.0.iter()
    }
}