diff --git a/network/connection.go b/network/connection.go index eae2617c..63182ca2 100644 --- a/network/connection.go +++ b/network/connection.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "strings" "sync" "time" @@ -172,6 +173,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment StopTunnel() error } + RecvBytes uint64 + SentBytes uint64 + // pkgQueue is used to serialize packet handling for a single // connection and is served by the connections packetHandler. pktQueue chan packet.Packet @@ -264,24 +268,43 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri ipVersion = packet.IPv4 } - // Get Process. - // FIXME: Find direct or redirected connection and grab the PID from there. + // Create packet info for dns request connection. + pi := &packet.Info{ + Inbound: false, // outbound as we are looking for the process of the source address + Version: ipVersion, + Protocol: packet.UDP, + Src: localIP, // source as in the process we are looking for + SrcPort: localPort, // source as in the process we are looking for + Dst: nil, // do not record direction + DstPort: 0, // do not record direction + PID: process.UndefinedProcessID, + } + + // Check if the dns request connection was reported with process info. + dnsRequestConnID := pi.CreateConnectionID() + // Cut the destination, as the dns request may have been redirected and we + // don't know the original destination. + dnsRequestConnIDPrefix, ok := strings.CutSuffix(dnsRequestConnID, "-0") + if !ok { + log.Tracer(ctx).Warningf("network: unexpected connection ID for finding dns requests connection: %s", dnsRequestConnID) + } + // Find matching dns request connection. + dnsRequestConn, ok := conns.findByPrefix(dnsRequestConnIDPrefix) + if ok && dnsRequestConn.PID != process.UndefinedProcessID { + log.Tracer(ctx).Debugf("network: found matching dns request connection %s", dnsRequestConn) + pi.PID = dnsRequestConn.PID + } // Find process by remote IP/Port. - pid, _, _ := process.GetPidOfConnection( - ctx, - &packet.Info{ - Inbound: false, // outbound as we are looking for the process of the source address - Version: ipVersion, - Protocol: packet.UDP, - Src: localIP, // source as in the process we are looking for - SrcPort: localPort, // source as in the process we are looking for - Dst: nil, // do not record direction - DstPort: 0, // do not record direction - PID: process.UndefinedProcessID, - }, - ) - proc, _ := process.GetProcessWithProfile(ctx, pid) + if pi.PID == process.UndefinedProcessID { + pi.PID, _, _ = process.GetPidOfConnection( + ctx, + pi, + ) + } + + // Get process and profile with PID. + proc, _ := process.GetProcessWithProfile(ctx, pi.PID) timestamp := time.Now().Unix() dnsConn := &Connection{ @@ -378,8 +401,7 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection { // GatherConnectionInfo gathers information on the process and remote entity. func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { // Get PID if not yet available. - // FIXME: Only match for UndefinedProcessID when integrations have been updated. - if conn.PID <= 0 { + if conn.PID == process.UndefinedProcessID { // Get process by looking at the system state tables. // Apply direction as reported from the state tables. conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info()) @@ -390,20 +412,22 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { if conn.process == nil { // We got connection from the system. conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID) - if err != nil { + if err == nil { + // Add process/profile metadata for connection. + conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process) + conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt() + + // Inherit internal status of profile. + if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil { + conn.Internal = localProfile.Internal + } + } else { conn.process = nil - err = fmt.Errorf("failed to get process and profile of PID %d: %w", conn.PID, err) - log.Tracer(pkt.Ctx()).Debugf("network: %s", err) - return err - } - - // Add process/profile metadata for connection. - conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process) - conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt() - - // Inherit internal status of profile. - if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil { - conn.Internal = localProfile.Internal + if pkt.InfoOnly() { + log.Tracer(pkt.Ctx()).Debugf("network: failed to get process and profile of PID %d: %s", conn.PID, err) + } else { + log.Tracer(pkt.Ctx()).Warningf("network: failed to get process and profile of PID %d: %s", conn.PID, err) + } } } @@ -435,48 +459,50 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { conn.Scope = IncomingInvalid } } else { + // Outbound direct (possibly P2P) connection. + switch conn.Entity.IPScope { + case netutils.HostLocal: + conn.Scope = PeerHost + case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: + conn.Scope = PeerLAN + case netutils.Global, netutils.GlobalMulticast: + conn.Scope = PeerInternet - // check if we can find a domain for that IP - ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) - if err != nil { - // Try again with the global scope, in case DNS went through the system resolver. - ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) - } - if err == nil { - lastResolvedDomain := ipinfo.MostRecentDomain() - if lastResolvedDomain != nil { - conn.Scope = lastResolvedDomain.Domain - conn.Entity.Domain = lastResolvedDomain.Domain - conn.Entity.CNAME = lastResolvedDomain.CNAMEs - conn.DNSContext = lastResolvedDomain.DNSRequestContext - conn.Resolver = lastResolvedDomain.Resolver - removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain) - } + case netutils.Undefined, netutils.Invalid: + fallthrough + default: + conn.Scope = PeerInvalid } + } + } - // check if destination IP is the captive portal's IP - portal := netenv.GetCaptivePortal() - if pkt.Info().RemoteIP().Equal(portal.IP) { - conn.Scope = portal.Domain - conn.Entity.Domain = portal.Domain + // Find domain and DNS context of entity. + if conn.Entity.Domain == "" && conn.process.Profile() != nil { + // check if we can find a domain for that IP + ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) + if err != nil { + // Try again with the global scope, in case DNS went through the system resolver. + ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String()) + } + if err == nil { + lastResolvedDomain := ipinfo.MostRecentDomain() + if lastResolvedDomain != nil { + conn.Scope = lastResolvedDomain.Domain + conn.Entity.Domain = lastResolvedDomain.Domain + conn.Entity.CNAME = lastResolvedDomain.CNAMEs + conn.DNSContext = lastResolvedDomain.DNSRequestContext + conn.Resolver = lastResolvedDomain.Resolver + removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain) } + } + } - if conn.Scope == "" { - // outbound direct (possibly P2P) connection - switch conn.Entity.IPScope { - case netutils.HostLocal: - conn.Scope = PeerHost - case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: - conn.Scope = PeerLAN - case netutils.Global, netutils.GlobalMulticast: - conn.Scope = PeerInternet - - case netutils.Undefined, netutils.Invalid: - fallthrough - default: - conn.Scope = PeerInvalid - } - } + // Check if destination IP is the captive portal's IP. + if conn.Entity.Domain == "" { + portal := netenv.GetCaptivePortal() + if pkt.Info().RemoteIP().Equal(portal.IP) { + conn.Scope = portal.Domain + conn.Entity.Domain = portal.Domain } } @@ -838,7 +864,7 @@ func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.P case conn.Verdict.Firewall != VerdictUndecided: tracer.Debugf("filter: connection %s fast-tracked", pkt) default: - tracer.Infof("filter: gathered data on connection %s", conn) + tracer.Debugf("filter: gathered data on connection %s", conn) } // Submit trace logs. tracer.Submit() diff --git a/network/connection_store.go b/network/connection_store.go index 86976579..2dec16c9 100644 --- a/network/connection_store.go +++ b/network/connection_store.go @@ -1,6 +1,7 @@ package network import ( + "strings" "sync" ) @@ -37,6 +38,21 @@ func (cs *connectionStore) get(id string) (*Connection, bool) { return conn, ok } +// findByPrefix returns the first connection where the key matches the given prefix. +// If the prefix matches multiple entries, the result is not deterministic. +func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) { + cs.rw.RLock() + defer cs.rw.RUnlock() + + for key, conn := range cs.items { + if strings.HasPrefix(key, prefix) { + return conn, true + } + } + + return nil, false +} + func (cs *connectionStore) clone() map[string]*Connection { cs.rw.RLock() defer cs.rw.RUnlock()