From 5371350b3dd93b78d8fb836467a1ea082a187183 Mon Sep 17 00:00:00 2001
From: Daniel <dhaavi@users.noreply.github.com>
Date: Mon, 13 May 2024 15:37:10 +0200
Subject: [PATCH] Add new ICMP filter handler, fix cleaning of ICMP connections

---
 service/firewall/module.go         |   2 +-
 service/firewall/packet_handler.go | 117 ++++++++++++++++++++++++-----
 service/network/clean.go           |  52 +++++++++----
 service/network/connection.go      |  35 ++++++---
 4 files changed, 165 insertions(+), 41 deletions(-)

diff --git a/service/firewall/module.go b/service/firewall/module.go
index 168ee7b8..3b80fc88 100644
--- a/service/firewall/module.go
+++ b/service/firewall/module.go
@@ -49,7 +49,7 @@ func init() {
 }
 
 func prep() error {
-	network.SetDefaultFirewallHandler(verdictHandler)
+	network.SetDefaultFirewallHandler(defaultFirewallHandler)
 
 	// Reset connections every time configuration changes
 	// this will be triggered on spn enable/disable
diff --git a/service/firewall/packet_handler.go b/service/firewall/packet_handler.go
index 766ff2b0..bfb8473c 100644
--- a/service/firewall/packet_handler.go
+++ b/service/firewall/packet_handler.go
@@ -22,7 +22,6 @@ import (
 	"github.com/safing/portmaster/service/network"
 	"github.com/safing/portmaster/service/network/netutils"
 	"github.com/safing/portmaster/service/network/packet"
-	"github.com/safing/portmaster/service/network/reference"
 	"github.com/safing/portmaster/service/process"
 	"github.com/safing/portmaster/spn/access"
 )
@@ -227,7 +226,6 @@ func fastTrackedPermit(conn *network.Connection, pkt packet.Packet) (verdict net
 		meta.Src.Equal(meta.Dst) {
 		log.Tracer(pkt.Ctx()).Debugf("filter: fast-track network self-check: %s", pkt)
 		return network.VerdictAccept, true
-
 	}
 
 	switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
@@ -374,6 +372,8 @@ func fastTrackedPermit(conn *network.Connection, pkt packet.Packet) (verdict net
 }
 
 func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
+	conn.SaveWhenFinished()
+
 	fastTrackedVerdict, permanent := fastTrackedPermit(conn, pkt)
 	if fastTrackedVerdict != network.VerdictUndecided {
 		// Set verdict on connection.
@@ -402,6 +402,8 @@ func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
 }
 
 func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
+	conn.SaveWhenFinished()
+
 	// Get process info
 	_ = conn.GatherConnectionInfo(pkt)
 	// Errors are informational and are logged to the context.
@@ -412,11 +414,20 @@ func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
 	}
 
 	// Continue to filter handler, when connection data is complete.
-	conn.UpdateFirewallHandler(filterHandler)
-	filterHandler(conn, pkt)
+	switch conn.IPProtocol { //nolint:exhaustive
+	case packet.ICMP, packet.ICMPv6:
+		conn.UpdateFirewallHandler(icmpFilterHandler)
+		icmpFilterHandler(conn, pkt)
+
+	default:
+		conn.UpdateFirewallHandler(filterHandler)
+		filterHandler(conn, pkt)
+	}
 }
 
 func filterHandler(conn *network.Connection, pkt packet.Packet) {
+	conn.SaveWhenFinished()
+
 	// Skip if data is not complete or packet is info-only.
 	if !conn.DataIsComplete() || pkt.InfoOnly() {
 		return
@@ -469,11 +480,12 @@ func filterHandler(conn *network.Connection, pkt packet.Packet) {
 	switch {
 	case conn.Inspecting:
 		log.Tracer(pkt.Ctx()).Trace("filter: start inspecting")
-		conn.SetFirewallHandler(inspectAndVerdictHandler)
+		conn.UpdateFirewallHandler(inspectAndVerdictHandler)
 		inspectAndVerdictHandler(conn, pkt)
+
 	default:
 		conn.StopFirewallHandler()
-		issueVerdict(conn, pkt, 0, true)
+		verdictHandler(conn, pkt)
 	}
 }
 
@@ -529,6 +541,18 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
 	}
 }
 
+// defaultFirewallHandler is used when no other firewall handler is set on a connection.
+func defaultFirewallHandler(conn *network.Connection, pkt packet.Packet) {
+	switch conn.IPProtocol { //nolint:exhaustive
+	case packet.ICMP, packet.ICMPv6:
+		// Always use the ICMP handler for ICMP connections.
+		icmpFilterHandler(conn, pkt)
+
+	default:
+		verdictHandler(conn, pkt)
+	}
+}
+
 func verdictHandler(conn *network.Connection, pkt packet.Packet) {
 	// Ignore info-only packets in this handler.
 	if pkt.InfoOnly() {
@@ -556,6 +580,73 @@ func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) {
 	issueVerdict(conn, pkt, 0, true)
 }
 
+func icmpFilterHandler(conn *network.Connection, pkt packet.Packet) {
+	// Load packet data.
+	err := pkt.LoadPacketData()
+	if err != nil {
+		log.Tracer(pkt.Ctx()).Debugf("filter: failed to load ICMP packet data: %s", err)
+		issueVerdict(conn, pkt, network.VerdictDrop, false)
+		return
+	}
+
+	// Submit to ICMP listener.
+	submitted := netenv.SubmitPacketToICMPListener(pkt)
+	if submitted {
+		issueVerdict(conn, pkt, network.VerdictDrop, false)
+		return
+	}
+
+	// Handle echo request and replies regularly.
+	// Other ICMP packets are considered system business.
+	icmpLayers := pkt.Layers().LayerClass(layers.LayerClassIPControl)
+	switch icmpLayer := icmpLayers.(type) {
+	case *layers.ICMPv4:
+		switch icmpLayer.TypeCode.Type() {
+		case layers.ICMPv4TypeEchoRequest,
+			layers.ICMPv4TypeEchoReply:
+			// Continue
+		default:
+			issueVerdict(conn, pkt, network.VerdictAccept, false)
+			return
+		}
+
+	case *layers.ICMPv6:
+		switch icmpLayer.TypeCode.Type() {
+		case layers.ICMPv6TypeEchoRequest,
+			layers.ICMPv6TypeEchoReply:
+			// Continue
+
+		default:
+			issueVerdict(conn, pkt, network.VerdictAccept, false)
+			return
+		}
+	}
+
+	// Check if we already have a verdict.
+	switch conn.Verdict { //nolint:exhaustive
+	case network.VerdictUndecided, network.VerdictUndeterminable:
+		// Apply privacy filter and check tunneling.
+		FilterConnection(pkt.Ctx(), conn, pkt, true, false)
+
+		// Save and propagate changes.
+		conn.SaveWhenFinished()
+	}
+
+	// Outbound direction has priority.
+	if conn.Inbound && conn.Ended == 0 && pkt.IsOutbound() {
+		// Change direction from inbound to outbound on first outbound ICMP packet.
+		conn.Inbound = false
+
+		// Apply privacy filter and check tunneling.
+		FilterConnection(pkt.Ctx(), conn, pkt, true, false)
+
+		// Save and propagate changes.
+		conn.SaveWhenFinished()
+	}
+
+	issueVerdict(conn, pkt, 0, false)
+}
+
 func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.Verdict, allowPermanent bool) {
 	// Check if packed was already fast-tracked by the OS integration.
 	if pkt.FastTrackedByIntegration() {
@@ -563,17 +654,9 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
 	}
 
 	// Enable permanent verdict.
-	if allowPermanent && !conn.VerdictPermanent {
-		switch {
-		case !permanentVerdicts():
-			// Permanent verdicts are disabled by configuration.
-		case conn.Entity != nil && reference.IsICMP(conn.Entity.Protocol):
-		case pkt != nil && reference.IsICMP(uint8(pkt.Info().Protocol)):
-			// ICMP is handled differently based on payload, so we cannot use persistent verdicts.
-		default:
-			conn.VerdictPermanent = true
-			conn.SaveWhenFinished()
-		}
+	if allowPermanent && !conn.VerdictPermanent && permanentVerdicts() {
+		conn.VerdictPermanent = true
+		conn.SaveWhenFinished()
 	}
 
 	// do not allow to circumvent decision: e.g. to ACCEPT packets from a DROP-ed connection
diff --git a/service/network/clean.go b/service/network/clean.go
index 9901b00b..c2777164 100644
--- a/service/network/clean.go
+++ b/service/network/clean.go
@@ -11,6 +11,14 @@ import (
 )
 
 const (
+	// EndConnsAfterInactiveFor defines the amount of time after not seen
+	// connections of unsupported protocols are marked as ended.
+	EndConnsAfterInactiveFor = 5 * time.Minute
+
+	// EndICMPConnsAfterInactiveFor defines the amount of time after not seen
+	// ICMP "connections" are marked as ended.
+	EndICMPConnsAfterInactiveFor = 1 * time.Minute
+
 	// DeleteConnsAfterEndedThreshold defines the amount of time after which
 	// ended connections should be removed from the internal connection state.
 	DeleteConnsAfterEndedThreshold = 10 * time.Minute
@@ -48,7 +56,9 @@ func cleanConnections() (activePIDs map[int]struct{}) {
 	_ = module.RunMicroTask("clean connections", 0, func(ctx context.Context) error {
 		now := time.Now().UTC()
 		nowUnix := now.Unix()
-		ignoreNewer := nowUnix - 1
+		ignoreNewer := nowUnix - 2
+		endNotSeenSince := now.Add(-EndConnsAfterInactiveFor).Unix()
+		endICMPNotSeenSince := now.Add(-EndICMPConnsAfterInactiveFor).Unix()
 		deleteOlderThan := now.Add(-DeleteConnsAfterEndedThreshold).Unix()
 		deleteIncompleteOlderThan := now.Add(-DeleteIncompleteConnsAfterStartedThreshold).Unix()
 
@@ -68,22 +78,37 @@ func cleanConnections() (activePIDs map[int]struct{}) {
 					// Remove connection from state.
 					conn.delete()
 				}
+
 			case conn.Ended == 0:
 				// Step 1: check if still active
-				exists := state.Exists(&packet.Info{
-					Inbound:  false, // src == local
-					Version:  conn.IPVersion,
-					Protocol: conn.IPProtocol,
-					Src:      conn.LocalIP,
-					SrcPort:  conn.LocalPort,
-					Dst:      conn.Entity.IP,
-					DstPort:  conn.Entity.Port,
-					PID:      process.UndefinedProcessID,
-					SeenAt:   time.Unix(conn.Started, 0), // State tables will be updated if older than this.
-				}, now)
+				var connActive bool
+				switch conn.IPProtocol { //nolint:exhaustive
+				case packet.TCP, packet.UDP:
+					connActive = state.Exists(&packet.Info{
+						Inbound:  false, // src == local
+						Version:  conn.IPVersion,
+						Protocol: conn.IPProtocol,
+						Src:      conn.LocalIP,
+						SrcPort:  conn.LocalPort,
+						Dst:      conn.Entity.IP,
+						DstPort:  conn.Entity.Port,
+						PID:      process.UndefinedProcessID,
+						SeenAt:   time.Unix(conn.Started, 0), // State tables will be updated if older than this.
+					}, now)
+					// Update last seen value for permanent verdict connections.
+					if connActive && conn.VerdictPermanent {
+						conn.lastSeen.Store(nowUnix)
+					}
+
+				case packet.ICMP, packet.ICMPv6:
+					connActive = conn.lastSeen.Load() > endICMPNotSeenSince
+
+				default:
+					connActive = conn.lastSeen.Load() > endNotSeenSince
+				}
 
 				// Step 2: mark as ended
-				if !exists {
+				if !connActive {
 					conn.Ended = nowUnix
 
 					// Stop the firewall handler, in case one is running.
@@ -97,6 +122,7 @@ func cleanConnections() (activePIDs map[int]struct{}) {
 				if conn.process != nil {
 					activePIDs[conn.process.Pid] = struct{}{}
 				}
+
 			case conn.Ended < deleteOlderThan:
 				// Step 3: delete
 				// DEBUG:
diff --git a/service/network/connection.go b/service/network/connection.go
index 2459ea14..457a2283 100644
--- a/service/network/connection.go
+++ b/service/network/connection.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/tevino/abool"
@@ -180,6 +181,11 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
 	// BytesSent holds the observed sent bytes of the connection.
 	BytesSent uint64
 
+	// lastSeen holds the timestamp when the connection was last seen.
+	// If permanent verdicts are enabled and bandwidth reporting is not active,
+	// this value will likely not be correct.
+	lastSeen atomic.Int64
+
 	// prompt holds the active prompt for this connection, if there is one.
 	prompt *notifications.Notification
 	// promptLock locks the prompt separately from the connection.
@@ -340,6 +346,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
 		Ended:          timestamp,
 		dataComplete:   abool.NewBool(true),
 	}
+	dnsConn.lastSeen.Store(timestamp)
 
 	// Inherit internal status of profile.
 	if localProfile := proc.Profile().LocalProfile(); localProfile != nil {
@@ -383,6 +390,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
 		Ended:          timestamp,
 		dataComplete:   abool.NewBool(true),
 	}
+	dnsConn.lastSeen.Store(timestamp)
 
 	// Inherit internal status of profile.
 	if localProfile := remoteHost.Profile().LocalProfile(); localProfile != nil {
@@ -418,6 +426,7 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection {
 		PID:          info.PID,
 		dataComplete: abool.NewBool(false),
 	}
+	conn.lastSeen.Store(conn.Started)
 
 	// Bullshit check Started timestamp.
 	if conn.Started < tooOldTimestamp {
@@ -569,6 +578,7 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
 		conn.dataComplete.Set()
 	}
 
+	conn.SaveWhenFinished()
 	return nil
 }
 
@@ -859,6 +869,9 @@ func (conn *Connection) StopFirewallHandler() {
 
 // HandlePacket queues packet of Link for handling.
 func (conn *Connection) HandlePacket(pkt packet.Packet) {
+	// Update last seen timestamp.
+	conn.lastSeen.Store(time.Now().Unix())
+
 	conn.pktQueueLock.Lock()
 	defer conn.pktQueueLock.Unlock()
 
@@ -994,17 +1007,19 @@ func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.P
 	// Record metrics.
 	packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt)
 
-	// Log result and submit trace.
-	switch {
-	case conn.DataIsComplete():
-		tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg)
-	case conn.Verdict != VerdictUndecided:
-		tracer.Debugf("filter: connection %s fast-tracked", pkt)
-	default:
-		tracer.Debugf("filter: gathered data on connection %s", conn)
+	// Log result and submit trace, when there are any changes.
+	if conn.saveWhenFinished {
+		switch {
+		case conn.DataIsComplete():
+			tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg)
+		case conn.Verdict != VerdictUndecided:
+			tracer.Debugf("filter: connection %s fast-tracked", pkt)
+		default:
+			tracer.Debugf("filter: gathered data on connection %s", conn)
+		}
+		// Submit trace logs.
+		tracer.Submit()
 	}
-	// Submit trace logs.
-	tracer.Submit()
 
 	// Push changes, if there are any.
 	if conn.saveWhenFinished {