mirror of
https://github.com/safing/portmaster
synced 2025-04-08 04:59:10 +00:00
179 lines
5.3 KiB
Rust
179 lines
5.3 KiB
Rust
use core::{fmt::Display, time::Duration};
|
|
|
|
use crate::connection::Connection;
|
|
use alloc::vec::Vec;
|
|
use hashbrown::HashMap;
|
|
use smoltcp::wire::{IpAddress, IpProtocol};
|
|
|
|
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
|
|
pub struct Key {
|
|
pub(crate) protocol: IpProtocol,
|
|
pub(crate) local_address: IpAddress,
|
|
pub(crate) local_port: u16,
|
|
pub(crate) remote_address: IpAddress,
|
|
pub(crate) remote_port: u16,
|
|
}
|
|
|
|
impl Display for Key {
|
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|
write!(
|
|
f,
|
|
"p: {} l: {}:{} r: {}:{}",
|
|
self.protocol,
|
|
self.local_address,
|
|
self.local_port,
|
|
self.remote_address,
|
|
self.remote_port
|
|
)
|
|
}
|
|
}
|
|
|
|
impl Key {
|
|
/// Returns the protocol and port as a tuple.
|
|
pub fn small(&self) -> (IpProtocol, u16) {
|
|
(self.protocol, self.local_port)
|
|
}
|
|
|
|
/// Returns true if the local address is an IPv4 address.
|
|
pub fn is_ipv6(&self) -> bool {
|
|
match self.local_address {
|
|
IpAddress::Ipv4(_) => false,
|
|
IpAddress::Ipv6(_) => true,
|
|
}
|
|
}
|
|
|
|
/// Returns true if the local address is a loopback address.
|
|
pub fn is_loopback(&self) -> bool {
|
|
match self.local_address {
|
|
IpAddress::Ipv4(ip) => ip.is_loopback(),
|
|
IpAddress::Ipv6(ip) => ip.is_loopback(),
|
|
}
|
|
}
|
|
|
|
/// Returns a new key with the local and remote addresses and ports reversed.
|
|
#[allow(dead_code)]
|
|
pub fn reverse(&self) -> Key {
|
|
Key {
|
|
protocol: self.protocol,
|
|
local_address: self.remote_address,
|
|
local_port: self.remote_port,
|
|
remote_address: self.local_address,
|
|
remote_port: self.local_port,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Tracked connections, bucketed by their `(protocol, local_port)` pair.
///
/// Each bucket holds every connection that shares that pair. Lookups within a
/// bucket are linear scans.
pub struct ConnectionMap<T>(HashMap<(IpProtocol, u16), Vec<T>>)
where
    T: Connection;
|
|
|
|
impl<T: Connection + Clone> ConnectionMap<T> {
|
|
pub fn new() -> Self {
|
|
Self(HashMap::new())
|
|
}
|
|
|
|
pub fn add(&mut self, conn: T) {
|
|
let key = conn.get_key().small();
|
|
if let Some(connections) = self.0.get_mut(&key) {
|
|
connections.push(conn);
|
|
} else {
|
|
self.0.insert(key, alloc::vec![conn]);
|
|
}
|
|
}
|
|
|
|
pub fn get_mut(&mut self, key: &Key) -> Option<&mut T> {
|
|
if let Some(connections) = self.0.get_mut(&key.small()) {
|
|
for conn in connections {
|
|
if conn.remote_equals(key) {
|
|
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
|
return Some(conn);
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
pub fn read<C>(&self, key: &Key, read_connection: fn(&T) -> Option<C>) -> Option<C> {
|
|
if let Some(connections) = self.0.get(&key.small()) {
|
|
for conn in connections {
|
|
if conn.remote_equals(key) {
|
|
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
|
return read_connection(conn);
|
|
}
|
|
if conn.redirect_equals(key) {
|
|
conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
|
|
return read_connection(conn);
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
pub fn end(&mut self, key: Key) -> Option<T> {
|
|
if let Some(connections) = self.0.get_mut(&key.small()) {
|
|
for conn in connections.iter_mut() {
|
|
if conn.remote_equals(&key) {
|
|
conn.end(wdk::utils::get_system_timestamp_ms());
|
|
return Some(conn.clone());
|
|
}
|
|
}
|
|
}
|
|
return None;
|
|
}
|
|
|
|
pub fn end_all_on_port(&mut self, key: (IpProtocol, u16)) -> Option<Vec<T>> {
|
|
if let Some(connections) = self.0.get_mut(&key) {
|
|
let mut vec = Vec::with_capacity(connections.len());
|
|
for conn in connections.iter_mut() {
|
|
if !conn.has_ended() {
|
|
conn.end(wdk::utils::get_system_timestamp_ms());
|
|
vec.push(conn.clone());
|
|
}
|
|
}
|
|
return Some(vec);
|
|
}
|
|
return None;
|
|
}
|
|
|
|
pub fn clear(&mut self) {
|
|
self.0.clear();
|
|
}
|
|
|
|
pub fn clean_ended_connections(&mut self) {
|
|
let now = wdk::utils::get_system_timestamp_ms();
|
|
const TEN_MINUETS: u64 = Duration::from_secs(60 * 10).as_millis() as u64;
|
|
let before_ten_minutes = now - TEN_MINUETS;
|
|
let before_one_minute = now - Duration::from_secs(60).as_millis() as u64;
|
|
|
|
for (_, connections) in self.0.iter_mut() {
|
|
connections.retain(|c| {
|
|
if c.has_ended() && c.get_end_time() < before_one_minute {
|
|
// Ended more than 1 minute ago
|
|
return false;
|
|
}
|
|
|
|
if c.get_last_accessed_time() < before_ten_minutes {
|
|
// Last active more than 10 minutes ago
|
|
return false;
|
|
}
|
|
|
|
// Keep
|
|
return true;
|
|
});
|
|
}
|
|
self.0.retain(|_, v| !v.is_empty());
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn get_count(&self) -> usize {
|
|
let mut count = 0;
|
|
for conn in self.0.values() {
|
|
count += conn.len();
|
|
}
|
|
return count;
|
|
}
|
|
|
|
pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
|
|
self.0.iter()
|
|
}
|
|
}
|