diff --git a/firewall/packet_handler.go b/firewall/packet_handler.go index 64a22d86..0ddf3a9a 100644 --- a/firewall/packet_handler.go +++ b/firewall/packet_handler.go @@ -23,6 +23,7 @@ import ( "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/reference" + "github.com/safing/portmaster/process" "github.com/safing/spn/access" ) @@ -140,12 +141,12 @@ func handlePacket(pkt packet.Packet) { } // fastTrackedPermit quickly permits certain network critical or internal connections. -func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bool) { +func fastTrackedPermit(conn *network.Connection, pkt packet.Packet) (verdict network.Verdict, permanent bool) { meta := pkt.Info() // Check if packed was already fast-tracked by the OS integration. if pkt.FastTrackedByIntegration() { - log.Debugf("filter: fast-tracked by OS integration: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-tracked by OS integration: %s", pkt) return network.VerdictAccept, true } @@ -159,7 +160,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo // Eg. dig: https://gitlab.isc.org/isc-projects/bind9/-/issues/1140 if meta.SrcPort == meta.DstPort && meta.Src.Equal(meta.Dst) { - log.Debugf("filter: fast-track network self-check: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track network self-check: %s", pkt) return network.VerdictAccept, true } @@ -169,7 +170,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo // Load packet data. err := pkt.LoadPacketData() if err != nil { - log.Debugf("filter: failed to load ICMP packet data: %s", err) + log.Tracer(pkt.Ctx()).Debugf("filter: failed to load ICMP packet data: %s", err) return network.VerdictAccept, true } @@ -179,7 +180,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo // If the packet was submitted to the listener, we must not do a // permanent accept, because then we won't see any future packets of that // connection and thus cannot continue to submit them. - log.Debugf("filter: fast-track tracing ICMP/v6: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track tracing ICMP/v6: %s", pkt) return network.VerdictAccept, false } @@ -202,7 +203,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo } // Permit all ICMP/v6 packets that are not echo requests or replies. - log.Debugf("filter: fast-track accepting ICMP/v6: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting ICMP/v6: %s", pkt) return network.VerdictAccept, true case packet.UDP, packet.TCP: @@ -224,7 +225,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo } // Log and permit. - log.Debugf("filter: fast-track accepting DHCP: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting DHCP: %s", pkt) return network.VerdictAccept, true case apiPort: @@ -249,14 +250,14 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo isMe, err := netenv.IsMyIP(meta.Src) switch { case err != nil: - log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) + log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) return network.VerdictUndecided, false case !isMe: return network.VerdictUndecided, false } // Log and permit. - log.Debugf("filter: fast-track accepting api connection: %s", pkt) + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting api connection: %s", pkt) return network.VerdictAccept, true case 53: @@ -277,15 +278,24 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo isMe, err := netenv.IsMyIP(meta.Src) switch { case err != nil: - log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) + log.Tracer(pkt.Ctx()).Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err) return network.VerdictUndecided, false case !isMe: return network.VerdictUndecided, false } // Log and permit. - log.Debugf("filter: fast-track accepting local dns: %s", pkt) - return network.VerdictAccept, true + log.Tracer(pkt.Ctx()).Debugf("filter: fast-track accepting local dns: %s", pkt) + + // Add to DNS request connections to attribute DNS request if outgoing. + if pkt.IsOutbound() { + // Assign PID from packet directly, as processing stops after fast-track. + conn.PID = pkt.Info().PID + network.SaveDNSRequestConnection(conn, pkt) + } + + // Accept local DNS, but only make permanent if we have the PID too. + return network.VerdictAccept, conn.PID != process.UndefinedProcessID } case compat.SystemIntegrationCheckProtocol: @@ -299,7 +309,7 @@ func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bo } func fastTrackHandler(conn *network.Connection, pkt packet.Packet) { - fastTrackedVerdict, permanent := fastTrackedPermit(pkt) + fastTrackedVerdict, permanent := fastTrackedPermit(conn, pkt) if fastTrackedVerdict != network.VerdictUndecided { // Set verdict on connection. conn.Verdict.Active = fastTrackedVerdict @@ -375,6 +385,10 @@ func filterHandler(conn *network.Connection, pkt packet.Packet) { conn.SetVerdict(network.VerdictRerouteToNameserver, "redirecting rogue dns query", "", nil) conn.Internal = true log.Tracer(pkt.Ctx()).Infof("filter: redirecting dns query %s to Portmaster", conn) + + // Add to DNS request connections to attribute DNS request. + network.SaveDNSRequestConnection(conn, pkt) + // End directly, as no other processing is necessary. conn.StopFirewallHandler() finalizeVerdict(conn) diff --git a/network/clean.go b/network/clean.go index 1a690757..d4fb50f3 100644 --- a/network/clean.go +++ b/network/clean.go @@ -118,6 +118,9 @@ func cleanConnections() (activePIDs map[int]struct{}) { conn.Unlock() } + // rerouted dns requests + cleanDNSRequestConnections() + return nil }) diff --git a/network/connection.go b/network/connection.go index fd2cc35c..b1d1958b 100644 --- a/network/connection.go +++ b/network/connection.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "strings" "sync" "time" @@ -293,18 +292,22 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri } // Check if the dns request connection was reported with process info. - dnsRequestConnID := pi.CreateConnectionID() - // Cut the destination, as the dns request may have been redirected and we - // don't know the original destination. - dnsRequestConnIDPrefix, ok := strings.CutSuffix(dnsRequestConnID, "-0") - if !ok { - log.Tracer(ctx).Warningf("network: unexpected connection ID for finding dns requests connection: %s", dnsRequestConnID) - } - // Find matching dns request connection. - dnsRequestConn, ok := conns.findByPrefix(dnsRequestConnIDPrefix) - if ok && dnsRequestConn.PID != process.UndefinedProcessID { - log.Tracer(ctx).Debugf("network: found matching dns request connection %s", dnsRequestConn) + var proc *process.Process + dnsRequestConn, ok := GetDNSRequestConnection(pi) + switch { + case !ok: + // No dns request connection found. + case dnsRequestConn.PID < 0: + // Process is not identified or is special. + case dnsRequestConn.Ended > 0 && dnsRequestConn.Ended < time.Now().Unix()-3: + // Connection has already ended (too long ago). + log.Tracer(ctx).Debugf("network: found ended dns request connection %s for dns request for %s", dnsRequestConn, fqdn) + default: + log.Tracer(ctx).Debugf("network: found matching dns request connection %s", dnsRequestConn.String()) + // Inherit PID. pi.PID = dnsRequestConn.PID + // Inherit process struct itself, as the PID may already be re-used. + proc = dnsRequestConn.process } // Find process by remote IP/Port. @@ -316,7 +319,9 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri } // Get process and profile with PID. - proc, _ := process.GetProcessWithProfile(ctx, pi.PID) + if proc == nil { + proc, _ = process.GetProcessWithProfile(ctx, pi.PID) + } timestamp := time.Now().Unix() dnsConn := &Connection{ @@ -1017,6 +1022,8 @@ func (conn *Connection) SetInspectorData(newInspectorData map[uint8]interface{}) // String returns a string representation of conn. func (conn *Connection) String() string { switch { + case conn.process == nil || conn.Entity == nil: + return conn.ID case conn.Inbound: return fmt.Sprintf("%s <- %s", conn.process, conn.Entity.IP) case conn.Entity.Domain != "": diff --git a/network/connection_store.go b/network/connection_store.go index 2dec16c9..7a6a61ef 100644 --- a/network/connection_store.go +++ b/network/connection_store.go @@ -40,7 +40,7 @@ func (cs *connectionStore) get(id string) (*Connection, bool) { // findByPrefix returns the first connection where the key matches the given prefix. // If the prefix matches multiple entries, the result is not deterministic. -func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) { +func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) { //nolint:unused cs.rw.RLock() defer cs.rw.RUnlock() diff --git a/network/dns.go b/network/dns.go index 2a1628fc..355fb65d 100644 --- a/network/dns.go +++ b/network/dns.go @@ -12,12 +12,16 @@ import ( "github.com/safing/portbase/log" "github.com/safing/portmaster/nameserver/nsutil" + "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/process" "github.com/safing/portmaster/resolver" ) var ( - openDNSRequests = make(map[string]*Connection) // key: /fqdn + dnsRequestConnections = make(map[string]*Connection) // key: -- + dnsRequestConnectionsLock sync.RWMutex + + openDNSRequests = make(map[string]*Connection) // key: / openDNSRequestsLock sync.Mutex supportedDomainToIPRecordTypes = []uint16{ @@ -38,6 +42,82 @@ const ( openDNSRequestLimit = 3 * time.Second ) +func getDNSRequestConnectionKey(packetInfo *packet.Info) (id string, ok bool) { + // We only support protocols with ports. + if packetInfo.SrcPort == 0 { + return "", false + } + + return fmt.Sprintf("%d-%s-%d", packetInfo.Protocol, packetInfo.Src, packetInfo.SrcPort), true +} + +// SaveDNSRequestConnection saves a dns request connection for later retrieval. +func SaveDNSRequestConnection(conn *Connection, pkt packet.Packet) { + // Check connection. + if conn.PID == process.UndefinedProcessID { + log.Tracer(pkt.Ctx()).Tracef("network: not saving dns request connection because the PID is undefined") + return + } + + // Create key. + key, ok := getDNSRequestConnectionKey(pkt.Info()) + if !ok { + log.Tracer(pkt.Ctx()).Debugf("network: not saving dns request connection %s because the protocol is not supported", pkt) + return + } + + // Add or update DNS request connection. + log.Tracer(pkt.Ctx()).Tracef("network: saving %s with PID %d as dns request connection for fast DNS request attribution", pkt, conn.PID) + dnsRequestConnectionsLock.Lock() + defer dnsRequestConnectionsLock.Unlock() + dnsRequestConnections[key] = conn +} + +// GetDNSRequestConnection returns a saved dns request connection. +func GetDNSRequestConnection(packetInfo *packet.Info) (conn *Connection, ok bool) { + // Make key. + key, ok := getDNSRequestConnectionKey(packetInfo) + if !ok { + return nil, false + } + + // Get and return + dnsRequestConnectionsLock.RLock() + defer dnsRequestConnectionsLock.RUnlock() + + conn, ok = dnsRequestConnections[key] + return +} + +// deleteDNSRequestConnection removes a connection from the dns request connections. +func deleteDNSRequestConnection(packetInfo *packet.Info) { //nolint:unused,deadcode + dnsRequestConnectionsLock.Lock() + defer dnsRequestConnectionsLock.Unlock() + + key, ok := getDNSRequestConnectionKey(packetInfo) + if ok { + delete(dnsRequestConnections, key) + } +} + +// cleanDNSRequestConnections deletes old DNS request connections. +func cleanDNSRequestConnections() { + deleteOlderThan := time.Now().Unix() - 3 + + dnsRequestConnectionsLock.Lock() + defer dnsRequestConnectionsLock.Unlock() + + for key, conn := range dnsRequestConnections { + conn.Lock() + + if conn.Ended > 0 && conn.Ended < deleteOlderThan { + delete(dnsRequestConnections, key) + } + + conn.Unlock() + } +} + // IsSupportDNSRecordType returns whether the given DSN RR type is supported // by the network package, as in the requests are specially handled and can be // "merged" into the resulting connection. diff --git a/process/process.go b/process/process.go index e982c901..934c195c 100644 --- a/process/process.go +++ b/process/process.go @@ -112,6 +112,16 @@ func (p *Process) IsIdentified() bool { } } +// IsLocal returns whether the process has been identified as a local process. +func (p *Process) IsLocal() bool { + // Check if process exists. + if p == nil { + return false + } + + return p.Pid >= 0 +} + // Equal returns if the two processes are both identified and have the same PID. func (p *Process) Equal(other *Process) bool { return p.IsIdentified() && other.IsIdentified() && p.Pid == other.Pid