use core::{fmt::Display, time::Duration};

use crate::connection::Connection;
use alloc::vec::Vec;
use hashbrown::HashMap;
use smoltcp::wire::{IpAddress, IpProtocol};

/// Identifies a connection by protocol plus its local and remote endpoints.
///
/// The derived comparison traits compare fields in declaration order, so the
/// field order below is significant for `PartialOrd`/`Ord`.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
pub struct Key {
    /// IP protocol of the connection.
    pub(crate) protocol: IpProtocol,
    /// Address of the local endpoint.
    pub(crate) local_address: IpAddress,
    /// Port of the local endpoint.
    pub(crate) local_port: u16,
    /// Address of the remote endpoint.
    pub(crate) remote_address: IpAddress,
    /// Port of the remote endpoint.
    pub(crate) remote_port: u16,
}

impl Display for Key {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "p: {} l: {}:{} r: {}:{}",
            self.protocol,
            self.local_address,
            self.local_port,
            self.remote_address,
            self.remote_port
        )
    }
}

impl Key {
    /// Returns the protocol and the *local* port as a tuple.
    pub fn small(&self) -> (IpProtocol, u16) {
        (self.protocol, self.local_port)
    }

    /// Returns true if the local address is an IPv6 address.
    pub fn is_ipv6(&self) -> bool {
        matches!(self.local_address, IpAddress::Ipv6(_))
    }

    /// Returns true if the local address is a loopback address.
    pub fn is_loopback(&self) -> bool {
        match self.local_address {
            IpAddress::Ipv4(ip) => ip.is_loopback(),
            IpAddress::Ipv6(ip) => ip.is_loopback(),
        }
    }

    /// Returns a new key with the local and remote endpoints swapped.
    ///
    /// The protocol is carried over unchanged.
    #[allow(dead_code)]
    pub fn reverse(&self) -> Key {
        Self {
            protocol: self.protocol,
            local_address: self.remote_address,
            local_port: self.remote_port,
            remote_address: self.local_address,
            remote_port: self.local_port,
        }
    }
}

/// Connections grouped into buckets keyed by `(protocol, local port)`.
///
/// Each bucket is a `Vec` of the connections sharing that pair.
pub struct ConnectionMap<T: Connection>(HashMap<(IpProtocol, u16), Vec<T>>);

impl<T: Connection + Clone> ConnectionMap<T> {
    /// Creates an empty connection map.
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Stores `conn` in the bucket for its `(protocol, local port)` pair,
    /// creating the bucket if it does not exist yet.
    pub fn add(&mut self, conn: T) {
        let key = conn.get_key().small();
        // Entry API: a single lookup instead of `get_mut` followed by `insert`.
        self.0.entry(key).or_default().push(conn);
    }

    /// Returns a mutable reference to the first connection in the key's bucket
    /// for which `remote_equals(key)` holds, refreshing its last-accessed time.
    ///
    /// Returns `None` if there is no bucket for the key or no connection matches.
    pub fn get_mut(&mut self, key: &Key) -> Option<&mut T> {
        let conn = self
            .0
            .get_mut(&key.small())?
            .iter_mut()
            .find(|conn| conn.remote_equals(key))?;

        conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
        Some(conn)
    }

    /// Applies `read_connection` to the first connection in the key's bucket
    /// for which `remote_equals(key)` or `redirect_equals(key)` holds, after
    /// refreshing its last-accessed time.
    ///
    /// Returns `None` if nothing matches; otherwise returns whatever
    /// `read_connection` returns.
    pub fn read<C>(&self, key: &Key, read_connection: fn(&T) -> Option<C>) -> Option<C> {
        let conn = self
            .0
            .get(&key.small())?
            .iter()
            // `remote_equals` is tried first; `redirect_equals` only if it fails.
            .find(|conn| conn.remote_equals(key) || conn.redirect_equals(key))?;

        conn.set_last_accessed_time(wdk::utils::get_system_timestamp_ms());
        read_connection(conn)
    }

    /// Ends the first connection in the key's bucket for which
    /// `remote_equals(&key)` holds and returns a clone of it.
    ///
    /// The connection stays in the map; see `clean_ended_connections` for removal.
    /// Returns `None` if nothing matches.
    pub fn end(&mut self, key: Key) -> Option<T> {
        let conn = self
            .0
            .get_mut(&key.small())?
            .iter_mut()
            .find(|conn| conn.remote_equals(&key))?;

        conn.end(wdk::utils::get_system_timestamp_ms());
        Some(conn.clone())
    }

    /// Ends every not-yet-ended connection in the bucket for `key` and returns
    /// clones of the connections that were ended by this call.
    ///
    /// Returns `None` if there is no bucket for `key`; the returned vector may
    /// be empty if every connection in the bucket had already ended.
    pub fn end_all_on_port(&mut self, key: (IpProtocol, u16)) -> Option<Vec<T>> {
        let connections = self.0.get_mut(&key)?;
        let mut ended = Vec::with_capacity(connections.len());

        for conn in connections.iter_mut().filter(|conn| !conn.has_ended()) {
            conn.end(wdk::utils::get_system_timestamp_ms());
            ended.push(conn.clone());
        }

        Some(ended)
    }

    /// Removes all connections from the map.
    pub fn clear(&mut self) {
        self.0.clear();
    }

    /// Drops stale connections and then removes buckets that became empty.
    ///
    /// A connection is dropped when either
    /// - it has ended and its end time is more than one minute in the past, or
    /// - its last-accessed time is more than ten minutes in the past.
    pub fn clean_ended_connections(&mut self) {
        // The millisecond counts are tiny, so narrowing from `u128` cannot truncate.
        const ONE_MINUTE_MS: u64 = Duration::from_secs(60).as_millis() as u64;
        const TEN_MINUTES_MS: u64 = Duration::from_secs(60 * 10).as_millis() as u64;

        let now = wdk::utils::get_system_timestamp_ms();
        // Saturate instead of underflowing: if the timestamp is smaller than the
        // window, the cutoff clamps to 0 and that rule simply expires nothing
        // (a plain `-` would panic in debug builds or wrap to a huge cutoff in
        // release builds, evicting every connection).
        let before_ten_minutes = now.saturating_sub(TEN_MINUTES_MS);
        let before_one_minute = now.saturating_sub(ONE_MINUTE_MS);

        self.0.retain(|_, connections| {
            connections.retain(|c| {
                if c.has_ended() && c.get_end_time() < before_one_minute {
                    // Ended more than 1 minute ago
                    return false;
                }

                // Keep only if it was active within the last 10 minutes.
                c.get_last_accessed_time() >= before_ten_minutes
            });

            // Discard the bucket once its last connection is gone.
            !connections.is_empty()
        });
    }

    /// Returns the total number of connections across all buckets.
    #[allow(dead_code)]
    pub fn get_count(&self) -> usize {
        self.0.values().map(Vec::len).sum()
    }

    /// Iterates over the `(protocol, local port)` buckets and their connections.
    pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, (IpProtocol, u16), Vec<T>> {
        self.0.iter()
    }
}