diff --git a/network/clean.go b/network/clean.go index efeadce9..3d3a8c23 100644 --- a/network/clean.go +++ b/network/clean.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/state" "github.com/safing/portbase/log" @@ -42,7 +44,7 @@ func cleanConnections() (activePIDs map[int]struct{}) { now := time.Now().UTC() nowUnix := now.Unix() - deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix() + deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() // lock both together because we cannot fully guarantee in which map a connection lands // of course every connection should land in the correct map, but this increases resilience @@ -59,12 +61,19 @@ func cleanConnections() (activePIDs map[int]struct{}) { switch { case conn.Ended == 0: // Step 1: check if still active - exists := state.Exists(conn.IPVersion, conn.IPProtocol, conn.LocalIP, conn.LocalPort, conn.Entity.IP, conn.Entity.Port, now) - if exists { - activePIDs[conn.process.Pid] = struct{}{} - } else { + exists := state.Exists(&packet.Info{ + Inbound: false, // src == local + Version: conn.IPVersion, + Protocol: conn.IPProtocol, + Src: conn.LocalIP, + SrcPort: conn.LocalPort, + Dst: conn.Entity.IP, + DstPort: conn.Entity.Port, + }, now) + activePIDs[conn.process.Pid] = struct{}{} + + if !exists { // Step 2: mark end - activePIDs[conn.process.Pid] = struct{}{} conn.Ended = nowUnix conn.Save() } diff --git a/network/state/exists.go b/network/state/exists.go index 9af6979f..68dd8288 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -1,7 +1,6 @@ package state import ( - "net" "time" "github.com/safing/portmaster/network/packet" @@ -12,49 +11,38 @@ const ( UDPConnectionTTL = 10 * time.Minute ) -func Exists( - ipVersion packet.IPVersion, - protocol packet.IPProtocol, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, - now time.Time, -) (exists bool) { - +func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { switch { - case ipVersion == packet.IPv4 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: tcp4Lock.Lock() defer tcp4Lock.Unlock() - return existsTCP(tcp4Connections, localIP, localPort, remoteIP, remotePort) + return existsTCP(tcp4Connections, pktInfo) - case ipVersion == packet.IPv6 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: tcp6Lock.Lock() defer tcp6Lock.Unlock() - return existsTCP(tcp6Connections, localIP, localPort, remoteIP, remotePort) + return existsTCP(tcp6Connections, pktInfo) - case ipVersion == packet.IPv4 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: udp4Lock.Lock() defer udp4Lock.Unlock() - return existsUDP(udp4Binds, udp4States, localIP, localPort, remoteIP, remotePort, now) + return existsUDP(udp4Binds, udp4States, pktInfo, now) - case ipVersion == packet.IPv6 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: udp6Lock.Lock() defer udp6Lock.Unlock() - return existsUDP(udp6Binds, udp6States, localIP, localPort, remoteIP, remotePort, now) + return existsUDP(udp6Binds, udp6States, pktInfo, now) default: return false } } -func existsTCP( - connections []*socket.ConnectionInfo, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, -) (exists bool) { +func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() // search connections for _, socketInfo := range connections { @@ -72,13 +60,15 @@ func existsTCP( func existsUDP( binds []*socket.BindInfo, udpStates map[string]map[string]*udpState, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, + pktInfo *packet.Info, now time.Time, ) (exists bool) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + connThreshhold := now.Add(-UDPConnectionTTL) // search binds