Use reported PIDs for DNS requests and improve data gathering process

This commit is contained in:
Daniel 2023-07-20 13:37:01 +02:00
parent 5d7caeb4bb
commit 41c5266315
2 changed files with 111 additions and 69 deletions

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -172,6 +173,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
StopTunnel() error StopTunnel() error
} }
RecvBytes uint64
SentBytes uint64
// pkgQueue is used to serialize packet handling for a single // pkgQueue is used to serialize packet handling for a single
// connection and is served by the connections packetHandler. // connection and is served by the connections packetHandler.
pktQueue chan packet.Packet pktQueue chan packet.Packet
@ -264,24 +268,43 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
ipVersion = packet.IPv4 ipVersion = packet.IPv4
} }
// Get Process. // Create packet info for dns request connection.
// FIXME: Find direct or redirected connection and grab the PID from there. pi := &packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: ipVersion,
Protocol: packet.UDP,
Src: localIP, // source as in the process we are looking for
SrcPort: localPort, // source as in the process we are looking for
Dst: nil, // do not record direction
DstPort: 0, // do not record direction
PID: process.UndefinedProcessID,
}
// Check if the dns request connection was reported with process info.
dnsRequestConnID := pi.CreateConnectionID()
// Cut the destination, as the dns request may have been redirected and we
// don't know the original destination.
dnsRequestConnIDPrefix, ok := strings.CutSuffix(dnsRequestConnID, "<nil>-0")
if !ok {
log.Tracer(ctx).Warningf("network: unexpected connection ID for finding dns requests connection: %s", dnsRequestConnID)
}
// Find matching dns request connection.
dnsRequestConn, ok := conns.findByPrefix(dnsRequestConnIDPrefix)
if ok && dnsRequestConn.PID != process.UndefinedProcessID {
log.Tracer(ctx).Debugf("network: found matching dns request connection %s", dnsRequestConn)
pi.PID = dnsRequestConn.PID
}
// Find process by remote IP/Port. // Find process by remote IP/Port.
pid, _, _ := process.GetPidOfConnection( if pi.PID == process.UndefinedProcessID {
ctx, pi.PID, _, _ = process.GetPidOfConnection(
&packet.Info{ ctx,
Inbound: false, // outbound as we are looking for the process of the source address pi,
Version: ipVersion, )
Protocol: packet.UDP, }
Src: localIP, // source as in the process we are looking for
SrcPort: localPort, // source as in the process we are looking for // Get process and profile with PID.
Dst: nil, // do not record direction proc, _ := process.GetProcessWithProfile(ctx, pi.PID)
DstPort: 0, // do not record direction
PID: process.UndefinedProcessID,
},
)
proc, _ := process.GetProcessWithProfile(ctx, pid)
timestamp := time.Now().Unix() timestamp := time.Now().Unix()
dnsConn := &Connection{ dnsConn := &Connection{
@ -378,8 +401,7 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection {
// GatherConnectionInfo gathers information on the process and remote entity. // GatherConnectionInfo gathers information on the process and remote entity.
func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
// Get PID if not yet available. // Get PID if not yet available.
// FIXME: Only match for UndefinedProcessID when integrations have been updated. if conn.PID == process.UndefinedProcessID {
if conn.PID <= 0 {
// Get process by looking at the system state tables. // Get process by looking at the system state tables.
// Apply direction as reported from the state tables. // Apply direction as reported from the state tables.
conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info()) conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info())
@ -390,20 +412,22 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
if conn.process == nil { if conn.process == nil {
// We got connection from the system. // We got connection from the system.
conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID) conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID)
if err != nil { if err == nil {
// Add process/profile metadata for connection.
conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process)
conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt()
// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
}
} else {
conn.process = nil conn.process = nil
err = fmt.Errorf("failed to get process and profile of PID %d: %w", conn.PID, err) if pkt.InfoOnly() {
log.Tracer(pkt.Ctx()).Debugf("network: %s", err) log.Tracer(pkt.Ctx()).Debugf("network: failed to get process and profile of PID %d: %s", conn.PID, err)
return err } else {
} log.Tracer(pkt.Ctx()).Warningf("network: failed to get process and profile of PID %d: %s", conn.PID, err)
}
// Add process/profile metadata for connection.
conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process)
conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt()
// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
} }
} }
@ -435,48 +459,50 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
conn.Scope = IncomingInvalid conn.Scope = IncomingInvalid
} }
} else { } else {
// Outbound direct (possibly P2P) connection.
switch conn.Entity.IPScope {
case netutils.HostLocal:
conn.Scope = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
conn.Scope = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
conn.Scope = PeerInternet
// check if we can find a domain for that IP case netutils.Undefined, netutils.Invalid:
ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String()) fallthrough
if err != nil { default:
// Try again with the global scope, in case DNS went through the system resolver. conn.Scope = PeerInvalid
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
conn.Scope = lastResolvedDomain.Domain
conn.Entity.Domain = lastResolvedDomain.Domain
conn.Entity.CNAME = lastResolvedDomain.CNAMEs
conn.DNSContext = lastResolvedDomain.DNSRequestContext
conn.Resolver = lastResolvedDomain.Resolver
removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain)
}
} }
}
}
// check if destination IP is the captive portal's IP // Find domain and DNS context of entity.
portal := netenv.GetCaptivePortal() if conn.Entity.Domain == "" && conn.process.Profile() != nil {
if pkt.Info().RemoteIP().Equal(portal.IP) { // check if we can find a domain for that IP
conn.Scope = portal.Domain ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
conn.Entity.Domain = portal.Domain if err != nil {
// Try again with the global scope, in case DNS went through the system resolver.
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
conn.Scope = lastResolvedDomain.Domain
conn.Entity.Domain = lastResolvedDomain.Domain
conn.Entity.CNAME = lastResolvedDomain.CNAMEs
conn.DNSContext = lastResolvedDomain.DNSRequestContext
conn.Resolver = lastResolvedDomain.Resolver
removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain)
} }
}
}
if conn.Scope == "" { // Check if destination IP is the captive portal's IP.
// outbound direct (possibly P2P) connection if conn.Entity.Domain == "" {
switch conn.Entity.IPScope { portal := netenv.GetCaptivePortal()
case netutils.HostLocal: if pkt.Info().RemoteIP().Equal(portal.IP) {
conn.Scope = PeerHost conn.Scope = portal.Domain
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: conn.Entity.Domain = portal.Domain
conn.Scope = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
conn.Scope = PeerInternet
case netutils.Undefined, netutils.Invalid:
fallthrough
default:
conn.Scope = PeerInvalid
}
}
} }
} }
@ -838,7 +864,7 @@ func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.P
case conn.Verdict.Firewall != VerdictUndecided: case conn.Verdict.Firewall != VerdictUndecided:
tracer.Debugf("filter: connection %s fast-tracked", pkt) tracer.Debugf("filter: connection %s fast-tracked", pkt)
default: default:
tracer.Infof("filter: gathered data on connection %s", conn) tracer.Debugf("filter: gathered data on connection %s", conn)
} }
// Submit trace logs. // Submit trace logs.
tracer.Submit() tracer.Submit()

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"strings"
"sync" "sync"
) )
@ -37,6 +38,21 @@ func (cs *connectionStore) get(id string) (*Connection, bool) {
return conn, ok return conn, ok
} }
// findByPrefix returns the first connection where the key matches the given prefix.
// If the prefix matches multiple entries, the result is not deterministic.
func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) {
cs.rw.RLock()
defer cs.rw.RUnlock()
for key, conn := range cs.items {
if strings.HasPrefix(key, prefix) {
return conn, true
}
}
return nil, false
}
func (cs *connectionStore) clone() map[string]*Connection { func (cs *connectionStore) clone() map[string]*Connection {
cs.rw.RLock() cs.rw.RLock()
defer cs.rw.RUnlock() defer cs.rw.RUnlock()