Revamp connection handling flow to fix race condition and support info-only packets

This commit is contained in:
Daniel 2023-06-21 15:31:45 +02:00
parent 83b084959e
commit 8a09ba6045
22 changed files with 527 additions and 349 deletions

View file

@ -141,7 +141,11 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
authenticatedPath += string(filepath.Separator)
// Get process of request.
proc, _, err := process.GetProcessByConnection(ctx, pktInfo)
pid, _, _ := process.GetPidOfConnection(ctx, pktInfo)
if pid < 0 {
return false, fmt.Errorf(deniedMsgUnidentified, api.ErrAPIAccessDeniedMessage) //nolint:stylecheck // message for user
}
proc, err := process.GetOrFindProcess(ctx, pid)
if err != nil {
log.Tracer(ctx).Debugf("filter: failed to get process of api request: %s", err)
originalPid = process.UnidentifiedProcessID

View file

@ -11,7 +11,6 @@ import (
"github.com/google/gopacket/layers"
"github.com/tevino/abool"
"golang.org/x/sync/singleflight"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
@ -54,7 +53,7 @@ func init() {
// TODO: Move interception module to own package (dir).
interceptionModule = modules.Register("interception", interceptionPrep, interceptionStart, interceptionStop, "base", "updates", "network", "notifications", "profiles")
network.SetDefaultFirewallHandler(defaultHandler)
network.SetDefaultFirewallHandler(verdictHandler)
}
func interceptionPrep() error {
@ -120,6 +119,11 @@ func resetAllConnectionVerdicts() {
// Re-evaluate all connections.
var changedVerdicts int
for _, conn := range network.GetAllConnections() {
// Skip incomplete connections.
if !conn.DataIsComplete() {
continue
}
func() {
conn.Lock()
defer conn.Unlock()
@ -167,15 +171,10 @@ func resetAllConnectionVerdicts() {
func interceptionStart() error {
getConfig()
if err := registerMetrics(); err != nil {
return err
}
startAPIAuth()
interceptionModule.StartWorker("stat logger", statLogger)
interceptionModule.StartWorker("packet handler", packetHandler)
interceptionModule.StartServiceWorker("stat logger", 0, statLogger)
interceptionModule.StartServiceWorker("packet handler", 0, packetHandler)
return interception.Start()
}
@ -196,92 +195,38 @@ func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
return nil
}
func handlePacket(ctx context.Context, pkt packet.Packet) {
// Record metrics.
startTime := time.Now()
defer packetHandlingHistogram.UpdateDuration(startTime)
if fastTrackedPermit(pkt) {
func handlePacket(pkt packet.Packet) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
// Add packet to connection handler queue or apply verdict directly.
conn.HandlePacket(pkt)
return
}
// Add context tracer and set context on packet.
traceCtx, tracer := log.AddTracer(ctx)
if tracer != nil {
// The trace is submitted in `network.Connection.packetHandler()`.
tracer.Tracef("filter: handling packet: %s", pkt)
}
pkt.SetCtx(traceCtx)
// Else create new incomplete connection from the packet and start the new handler.
conn = network.NewIncompleteConnection(pkt)
conn.Lock()
defer conn.Unlock()
conn.SetFirewallHandler(fastTrackHandler)
// Get connection of packet.
conn, err := getConnection(pkt)
if err != nil {
tracer.Errorf("filter: packet %s dropped: %s", pkt, err)
_ = pkt.Drop()
return
}
// handle packet
// Let the new connection handler worker handle the packet.
conn.HandlePacket(pkt)
}
var getConnectionSingleInflight singleflight.Group
func getConnection(pkt packet.Packet) (*network.Connection, error) {
created := false
// Create or get connection in single inflight lock in order to prevent duplicates.
newConn, err, shared := getConnectionSingleInflight.Do(pkt.GetConnectionID(), func() (interface{}, error) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
return conn, nil
}
// Else create new one from the packet.
conn = network.NewConnectionFromFirstPacket(pkt)
conn.Lock()
defer conn.Unlock()
conn.SetFirewallHandler(initialHandler)
created = true
return conn, nil
})
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
if newConn == nil {
return nil, errors.New("connection getter returned nil")
}
// Transform and log result.
conn := newConn.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection.
sharedIndicator := ""
if shared {
sharedIndicator = " (shared)"
}
if created {
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s%s", conn.ID, sharedIndicator)
} else {
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s%s", conn.ID, sharedIndicator)
}
return conn, nil
}
// fastTrackedPermit quickly permits certain network critical or internal connections.
func fastTrackedPermit(pkt packet.Packet) (handled bool) {
func fastTrackedPermit(pkt packet.Packet) (verdict network.Verdict, permanent bool) {
meta := pkt.Info()
// Check if packed was already fast-tracked by the OS integration.
if pkt.FastTrackedByIntegration() {
log.Debugf("filter: fast-tracked by OS integration: %s", pkt)
return true
return network.VerdictAccept, true
}
// Check if connection was already blocked.
if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) {
_ = pkt.PermanentBlock()
return true
return network.VerdictBlock, true
}
// Some programs do a network self-check where they connects to the same
@ -290,8 +235,8 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
if meta.SrcPort == meta.DstPort &&
meta.Src.Equal(meta.Dst) {
log.Debugf("filter: fast-track network self-check: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
@ -300,8 +245,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
err := pkt.LoadPacketData()
if err != nil {
log.Debugf("filter: failed to load ICMP packet data: %s", err)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
// Submit to ICMP listener.
@ -311,8 +255,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// permanent accept, because then we won't see any future packets of that
// connection and thus cannot continue to submit them.
log.Debugf("filter: fast-track tracing ICMP/v6: %s", pkt)
_ = pkt.Accept()
return true
return network.VerdictAccept, false
}
// Handle echo request and replies regularly.
@ -323,20 +266,19 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv4TypeEchoRequest,
layers.ICMPv4TypeEchoReply:
return false
return network.VerdictUndecided, false
}
case *layers.ICMPv6:
switch icmpLayer.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest,
layers.ICMPv6TypeEchoReply:
return false
return network.VerdictUndecided, false
}
}
// Permit all ICMP/v6 packets that are not echo requests or replies.
log.Debugf("filter: fast-track accepting ICMP/v6: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case packet.UDP, packet.TCP:
switch meta.DstPort {
@ -346,37 +288,36 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// DHCP and DHCPv6 must be UDP.
if meta.Protocol != packet.UDP {
return false
return network.VerdictUndecided, false
}
// DHCP is only valid in local network scopes.
switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only.
case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
default:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting DHCP: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case apiPort:
// Always allow direct access to the Portmaster API.
// Portmaster API is TCP only.
if meta.Protocol != packet.TCP {
return false
return network.VerdictUndecided, false
}
// Check if the api port is even set.
if !apiPortSet {
return false
return network.VerdictUndecided, false
}
// Must be destined for the API IP.
if !meta.Dst.Equal(apiIP) {
return false
return network.VerdictUndecided, false
}
// Only fast-track local requests.
@ -384,15 +325,14 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
return network.VerdictUndecided, false
case !isMe:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting api connection: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
case 53:
// Always allow direct access to the Portmaster Nameserver.
@ -400,12 +340,12 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
// Check if a nameserver IP matcher is set.
if !nameserverIPMatcherReady.IsSet() {
return false
return network.VerdictUndecided, false
}
// Check if packet is destined for a nameserver IP.
if !nameserverIPMatcher(meta.Dst) {
return false
return network.VerdictUndecided, false
}
// Only fast-track local requests.
@ -413,32 +353,76 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
switch {
case err != nil:
log.Debugf("filter: failed to check if %s is own IP for fast-track: %s", meta.Src, err)
return false
return network.VerdictUndecided, false
case !isMe:
return false
return network.VerdictUndecided, false
}
// Log and permit.
log.Debugf("filter: fast-track accepting local dns: %s", pkt)
_ = pkt.PermanentAccept()
return true
return network.VerdictAccept, true
}
case compat.SystemIntegrationCheckProtocol:
if pkt.Info().Dst.Equal(compat.SystemIntegrationCheckDstIP) {
compat.SubmitSystemIntegrationCheckPacket(pkt)
_ = pkt.Drop()
return true
return network.VerdictDrop, false
}
}
return false
return network.VerdictUndecided, false
}
func initialHandler(conn *network.Connection, pkt packet.Packet) {
func fastTrackHandler(conn *network.Connection, pkt packet.Packet) {
fastTrackedVerdict, permanent := fastTrackedPermit(pkt)
if fastTrackedVerdict != network.VerdictUndecided {
// Set verdict on connection.
conn.Verdict.Active = fastTrackedVerdict
conn.Verdict.Firewall = fastTrackedVerdict
// Apply verdict to (real) packet.
if !pkt.InfoOnly() {
issueVerdict(conn, pkt, fastTrackedVerdict, permanent)
}
// Stop handler if permanent.
if permanent {
conn.SetVerdict(fastTrackedVerdict, "fast-tracked", "", nil)
conn.Verdict.Worst = fastTrackedVerdict
// Do not finalize verdict, as we are missing necessary data.
conn.StopFirewallHandler()
}
// Do not continue to next handler.
return
}
// If packet is not fast-tracked, continue with gathering more information.
conn.UpdateFirewallHandler(gatherDataHandler)
gatherDataHandler(conn, pkt)
}
func gatherDataHandler(conn *network.Connection, pkt packet.Packet) {
// Get process info
_ = conn.GatherConnectionInfo(pkt)
// Errors are informational and are logged to the context.
// Run this handler again if data is not yet complete.
if !conn.DataIsComplete() {
return
}
// Continue to filter handler, when connection data is complete.
conn.UpdateFirewallHandler(filterHandler)
filterHandler(conn, pkt)
}
func filterHandler(conn *network.Connection, pkt packet.Packet) {
// Skip if data is not complete.
if !conn.DataIsComplete() {
return
}
filterConnection := true
log.Tracer(pkt.Ctx()).Trace("filter: handing over to connection-based handler")
// Check for special (internal) connection cases.
switch {
case !conn.Inbound && localPortIsPreAuthenticated(conn.Entity.Protocol, conn.LocalPort):
@ -480,8 +464,8 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
switch {
case conn.Inspecting:
log.Tracer(pkt.Ctx()).Trace("filter: start inspecting")
conn.SetFirewallHandler(inspectThenVerdict)
inspectThenVerdict(conn, pkt)
conn.SetFirewallHandler(inspectAndVerdictHandler)
inspectAndVerdictHandler(conn, pkt)
default:
conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true)
@ -490,6 +474,11 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
// FilterConnection runs all the filtering (and tunneling) procedures.
func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet, checkFilter, checkTunnel bool) {
// Skip if data is not complete.
if !conn.DataIsComplete() {
return
}
if checkFilter {
if filterEnabled() {
log.Tracer(ctx).Trace("filter: starting decision process")
@ -537,12 +526,11 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
}
}
func defaultHandler(conn *network.Connection, pkt packet.Packet) {
// TODO: `pkt` has an active trace log, which we currently don't submit.
func verdictHandler(conn *network.Connection, pkt packet.Packet) {
issueVerdict(conn, pkt, 0, true)
}
func inspectThenVerdict(conn *network.Connection, pkt packet.Packet) {
func inspectAndVerdictHandler(conn *network.Connection, pkt packet.Packet) {
pktVerdict, continueInspection := inspection.RunInspectors(conn, pkt)
if continueInspection {
issueVerdict(conn, pkt, pktVerdict, false)
@ -689,10 +677,11 @@ func packetHandler(ctx context.Context) error {
case <-ctx.Done():
return nil
case pkt := <-interception.Packets:
interceptionModule.StartWorker("initial packet handler", func(workerCtx context.Context) error {
handlePacket(workerCtx, pkt)
return nil
})
if pkt != nil {
handlePacket(pkt)
} else {
return errors.New("received nil packet from interception")
}
}
}
}

View file

@ -13,6 +13,12 @@ type infoPacket struct {
pmpacket.Base
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *infoPacket) InfoOnly() bool {
return true
}
// LoadPacketData does nothing on Linux, as data is always fully parsed.
func (pkt *infoPacket) LoadPacketData() error {
return fmt.Errorf("can't load data in info only packet")

View file

@ -5,11 +5,13 @@ import (
"encoding/binary"
"errors"
"net"
"time"
"unsafe"
"github.com/cilium/ebpf/link"
"github.com/cilium/ebpf/ringbuf"
"github.com/cilium/ebpf/rlimit"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
)
@ -17,6 +19,7 @@ import (
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" -type Event bpf program/monitor.c
var stopper chan struct{}
// StartEBPFWorker starts the ebpf worker.
func StartEBPFWorker(ch chan packet.Packet) {
stopper = make(chan struct{})
go func() {
@ -32,7 +35,7 @@ func StartEBPFWorker(ch chan packet.Packet) {
log.Errorf("ebpf: failed to load ebpf object: %s", err)
return
}
defer objs.Close()
defer objs.Close() //nolint:errcheck
// Create a link to the tcp_connect program.
linkTCPConnect, err := link.AttachTracing(link.TracingOptions{
@ -42,7 +45,7 @@ func StartEBPFWorker(ch chan packet.Packet) {
log.Errorf("ebpf: failed to attach to tcp_v4_connect: %s ", err)
return
}
defer linkTCPConnect.Close()
defer linkTCPConnect.Close() //nolint:errcheck
// Create a link to the udp_v4_connect program.
linkUDPV4, err := link.AttachTracing(link.TracingOptions{
@ -52,7 +55,7 @@ func StartEBPFWorker(ch chan packet.Packet) {
log.Errorf("ebpf: failed to attach to udp_v4_connect: %s ", err)
return
}
defer linkUDPV4.Close()
defer linkUDPV4.Close() //nolint:errcheck
// Create a link to the udp_v6_connect program.
linkUDPV6, err := link.AttachTracing(link.TracingOptions{
@ -62,14 +65,14 @@ func StartEBPFWorker(ch chan packet.Packet) {
log.Errorf("ebpf: failed to attach to udp_v6_connect: %s ", err)
return
}
defer linkUDPV6.Close()
defer linkUDPV6.Close() //nolint:errcheck
rd, err := ringbuf.NewReader(objs.bpfMaps.Events)
if err != nil {
log.Errorf("ebpf: failed to open ring buffer: %s", err)
return
}
defer rd.Close()
defer rd.Close() //nolint:errcheck
go func() {
<-stopper
@ -107,7 +110,8 @@ func StartEBPFWorker(ch chan packet.Packet) {
DstPort: event.Dport,
Src: arrayToIP(event.Saddr, packet.IPVersion(event.IpVersion)),
Dst: arrayToIP(event.Daddr, packet.IPVersion(event.IpVersion)),
PID: event.Pid,
PID: int(event.Pid),
SeenAt: time.Now(),
}
if isEventValid(event) {
log.Debugf("ebpf: PID: %d conn: %s:%d -> %s:%d %s %s", info.PID, info.LocalIP(), info.LocalPort(), info.RemoteIP(), info.RemotePort(), info.Version.String(), info.Protocol.String())
@ -123,6 +127,7 @@ func StartEBPFWorker(ch chan packet.Packet) {
}()
}
// StopEBPFWorker stops the ebpf worker.
func StopEBPFWorker() {
close(stopper)
}
@ -148,11 +153,12 @@ func isEventValid(event bpfEvent) bool {
return true
}
// arrayToIP converts IP number array to net.IP
// arrayToIP converts IP number array to net.IP.
func arrayToIP(ipNum [4]uint32, ipVersion packet.IPVersion) net.IP {
if ipVersion == packet.IPv4 {
// FIXME: maybe convertIPv4 from windowskext package
return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 4)
} else {
return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 16)
}
// FIXME: maybe use convertIPv6 from windowskext package
return unsafe.Slice((*byte)(unsafe.Pointer(&ipNum)), 16)
}

View file

@ -16,6 +16,7 @@ import (
"github.com/safing/portbase/log"
pmpacket "github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/process"
)
// Queue wraps a nfqueue.
@ -175,10 +176,11 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
pkt := &packet{
pktID: *attrs.PacketID,
queue: q,
received: time.Now(),
verdictSet: make(chan struct{}),
verdictPending: abool.New(),
}
pkt.Info().PID = process.UndefinedProcessID
pkt.Info().SeenAt = time.Now()
if attrs.Payload == nil {
// There is not payload.
@ -194,11 +196,11 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
select {
case q.packets <- pkt:
log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
case <-ctx.Done():
return 0
case <-time.After(time.Second):
log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.received))
log.Warningf("nfqueue: failed to queue packet (%s since it was handed over by the kernel)", time.Since(pkt.Info().SeenAt))
}
go func() {
@ -206,7 +208,7 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
case <-pkt.verdictSet:
case <-time.After(20 * time.Second):
log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.received))
log.Warningf("nfqueue: no verdict set for packet %s (%s -> %s) after %s, dropping", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt))
if err := pkt.Drop(); err != nil {
log.Warningf("nfqueue: failed to apply default-drop to unveridcted packet %s (%s -> %s)", pkt.ID(), pkt.Info().Src, pkt.Info().Dst)
}

View file

@ -55,7 +55,6 @@ func markToString(mark int) string {
type packet struct {
pmpacket.Base
pktID uint32
received time.Time
queue *Queue
verdictSet chan struct{}
verdictPending *abool.AtomicBool
@ -118,7 +117,7 @@ func (pkt *packet) setMark(mark int) error {
}
break
}
log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.received))
log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.Info().SeenAt))
return nil
}

View file

@ -8,8 +8,11 @@ import (
"errors"
"fmt"
"net"
"time"
"unsafe"
"github.com/safing/portmaster/process"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
@ -103,21 +106,28 @@ func Handler(packets chan packet.Packet) {
verdictRequest: packetInfo,
verdictSet: abool.NewBool(false),
}
info := new.Info()
info.Inbound = packetInfo.direction > 0
info.InTunnel = false
info.Protocol = packet.IPProtocol(packetInfo.protocol)
info.PID = packetInfo.pid
info.PID = int(packetInfo.pid)
info.SeenAt = time.Now()
// IP version
// Check PID
if info.PID == 0 {
// Windows does not have zero PIDs.
// Set to UndefinedProcessID.
info.PID = process.UndefinedProcessID
}
// Set IP version
if packetInfo.ipV6 == 1 {
info.Version = packet.IPv6
} else {
info.Version = packet.IPv4
}
// IPs
// Set IPs
if info.Version == packet.IPv4 {
// IPv4
if info.Inbound {
@ -142,7 +152,7 @@ func Handler(packets chan packet.Packet) {
}
}
// Ports
// Set Ports
if info.Inbound {
// Inbound
info.SrcPort = packetInfo.remotePort

View file

@ -1,3 +1,4 @@
//go:build windows
// +build windows
package windowskext
@ -23,6 +24,12 @@ type Packet struct {
lock sync.Mutex
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *Packet) InfoOnly() bool {
return pkt.verdictRequest.flags&VerdictRequestFlagSocketAuth > 0
}
// FastTrackedByIntegration returns whether the packet has been fast-track
// accepted by the OS integration.
func (pkt *Packet) FastTrackedByIntegration() bool {

View file

@ -25,19 +25,6 @@ import (
"github.com/safing/portmaster/profile/endpoints"
)
// Call order:
//
// DNS Query:
// 1. DecideOnConnection
// is called when a DNS query is made, may set verdict to Undeterminable to permit a DNS reply.
// is called with a nil packet.
// 2. DecideOnResolvedDNS
// is called to (possibly) filter out A/AAAA records that the filter would deny later.
//
// Network Connection:
// 3. DecideOnConnection
// is called with the first packet of a network connection.
const noReasonOptionKey = ""
type deciderFn func(context.Context, *network.Connection, *profile.LayeredProfile, packet.Packet) bool

View file

@ -1,21 +0,0 @@
package firewall
import (
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/metrics"
)
var packetHandlingHistogram *metrics.Histogram
func registerMetrics() (err error) {
packetHandlingHistogram, err = metrics.NewHistogram(
"firewall/handling/duration/seconds",
nil,
&metrics.Options{
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelExpert,
})
return err
}

View file

@ -87,6 +87,9 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
if !ok {
return
}
if !conn.DataIsComplete() {
continue
}
model, err := convertConnection(conn)
if err != nil {

View file

@ -51,6 +51,11 @@ func cleanConnections() (activePIDs map[int]struct{}) {
// delete inactive connections
switch {
case !conn.DataIsComplete():
// Step 0: delete old incomplete connections
if conn.Started < deleteOlderThan {
conn.delete()
}
case conn.Ended == 0:
// Step 1: check if still active
exists := state.Exists(&packet.Info{

View file

@ -7,6 +7,8 @@ import (
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/intel"
@ -102,6 +104,8 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
// set for connections created from DNS requests. LocalPort is
// considered immutable once a connection object has been created.
LocalPort uint16
// PID holds the PID of the owning process.
PID int
// Entity describes the remote entity that the connection has been
// established to. The entity might be changed or information might
// be added to it during the livetime of a connection. Access to
@ -168,6 +172,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
StopTunnel() error
}
// pkgQueue is used to serialize packet handling for a single
// connection and is served by the connections packetHandler.
pktQueue chan packet.Packet
// pktQueueActive signifies whether the packet queue is active and may be written to.
pktQueueActive bool
// pktQueueLock locks access to pktQueueActive and writing to pktQueue.
pktQueueLock sync.Mutex
// dataComplete signifies that all information about the connection is
// available and an actual packet has been seen.
// As long as this flag is not set, the connection may not be evaluated for
// a verdict and may not be sent to the UI.
dataComplete *abool.AtomicBool
// Internal is set to true if the connection is attributed as an
// Portmaster internal connection. Internal may be set at different
// points and access to it must be guarded by the connection lock.
@ -175,9 +192,6 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
// process holds a reference to the actor process. That is, the
// process instance that initiated the connection.
process *process.Process
// pkgQueue is used to serialize packet handling for a single
// connection and is served by the connections packetHandler.
pktQueue chan packet.Packet
// firewallHandler is the firewall handler that is called for
// each packet sent to pktQueue.
firewallHandler FirewallHandler
@ -250,8 +264,11 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
ipVersion = packet.IPv4
}
// get Process
proc, _, err := process.GetProcessByConnection(
// Get Process.
// FIXME: Find direct or redirected connection and grab the PID from there.
// Find process by remote IP/Port.
pid, _, _ := process.GetPidOfConnection(
ctx,
&packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
@ -261,18 +278,17 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
SrcPort: localPort, // source as in the process we are looking for
Dst: nil, // do not record direction
DstPort: 0, // do not record direction
PID: process.UndefinedProcessID,
},
)
if err != nil {
log.Tracer(ctx).Debugf("network: failed to find process of dns request for %s: %s", fqdn, err)
proc = process.GetUnidentifiedProcess(ctx)
}
proc, _ := process.GetProcessWithProfile(ctx, pid)
timestamp := time.Now().Unix()
dnsConn := &Connection{
ID: connID,
Type: DNSRequest,
Scope: fqdn,
PID: proc.Pid,
Entity: &intel.Entity{
Domain: fqdn,
CNAME: cnames,
@ -281,6 +297,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
ProcessContext: getProcessContext(ctx, proc),
Started: timestamp,
Ended: timestamp,
dataComplete: abool.NewBool(true),
}
// Inherit internal status of profile.
@ -292,6 +309,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
// query. Blocked requests are saved immediately, accepted ones are only
// saved if they are not "used" by a connection.
dnsConn.UpdateMeta()
return dnsConn
}
@ -308,6 +326,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
Type: DNSRequest,
External: true,
Scope: fqdn,
PID: process.NetworkHostProcessID,
Entity: &intel.Entity{
Domain: fqdn,
CNAME: cnames,
@ -316,6 +335,7 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
ProcessContext: getProcessContext(ctx, remoteHost),
Started: timestamp,
Ended: timestamp,
dataComplete: abool.NewBool(true),
}
// Inherit internal status of profile.
@ -327,131 +347,152 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname
// query. Blocked requests are saved immediately, accepted ones are only
// saved if they are not "used" by a connection.
dnsConn.UpdateMeta()
return dnsConn, nil
}
// NewConnectionFromFirstPacket returns a new connection based on the given packet.
func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
// get Process
proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info())
if err != nil {
log.Tracer(pkt.Ctx()).Debugf("network: failed to find process of packet %s: %s", pkt, err)
if inbound && !netutils.ClassifyIP(pkt.Info().Dst).IsLocalhost() {
proc = process.GetUnsolicitedProcess(pkt.Ctx())
} else {
proc = process.GetUnidentifiedProcess(pkt.Ctx())
}
// NewIncompleteConnection creates a new incomplete connection with only minimal information.
func NewIncompleteConnection(pkt packet.Packet) *Connection {
info := pkt.Info()
// Create new connection object.
// We do not yet know the direction of the connection for sure, so we can only set minimal information.
conn := &Connection{
ID: pkt.GetConnectionID(),
Type: IPConnection,
IPVersion: info.Version,
IPProtocol: info.Protocol,
Started: info.SeenAt.Unix(),
PID: info.PID,
dataComplete: abool.NewBool(false),
}
// Create the (remote) entity.
entity := &intel.Entity{
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().RemotePort(),
// Save connection to internal state in order to mitigate creation of
// duplicates. Do not propagate yet, as data is not yet complete.
conn.UpdateMeta()
conns.add(conn)
return conn
}
// GatherConnectionInfo gathers information on the process and remote entity.
func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
// Get PID if not yet available.
// FIXME: Only match for UndefinedProcessID when integrations have been updated.
if conn.PID <= 0 {
// Get process by looking at the system state tables.
// Apply direction as reported from the state tables.
conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info())
// Errors are informational and are logged to the context.
}
entity.SetIP(pkt.Info().RemoteIP())
entity.SetDstPort(pkt.Info().DstPort)
var scope string
var resolverInfo *resolver.ResolverInfo
var dnsContext *resolver.DNSRequestContext
if inbound {
switch entity.IPScope {
case netutils.HostLocal:
scope = IncomingHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
scope = IncomingLAN
case netutils.Global, netutils.GlobalMulticast:
scope = IncomingInternet
case netutils.Undefined, netutils.Invalid:
fallthrough
default:
scope = IncomingInvalid
}
} else {
// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(proc.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
// Get Process and Profile.
if conn.process == nil {
// We got connection from the system.
conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID)
if err != nil {
// Try again with the global scope, in case DNS went through the system resolver.
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
scope = lastResolvedDomain.Domain
entity.Domain = lastResolvedDomain.Domain
entity.CNAME = lastResolvedDomain.CNAMEs
dnsContext = lastResolvedDomain.DNSRequestContext
resolverInfo = lastResolvedDomain.Resolver
removeOpenDNSRequest(proc.Pid, lastResolvedDomain.Domain)
}
conn.process = nil
err = fmt.Errorf("failed to get process and profile of PID %d: %w", conn.PID, err)
log.Tracer(pkt.Ctx()).Debugf("network: %s", err)
return err
}
// check if destination IP is the captive portal's IP
portal := netenv.GetCaptivePortal()
if pkt.Info().RemoteIP().Equal(portal.IP) {
scope = portal.Domain
entity.Domain = portal.Domain
}
// Add process/profile metadata for connection.
conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process)
conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt()
if scope == "" {
// outbound direct (possibly P2P) connection
switch entity.IPScope {
// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
}
}
// Create remote entity.
if conn.Entity == nil {
// Remote
conn.Entity = &intel.Entity{
Protocol: uint8(pkt.Info().Protocol),
Port: pkt.Info().RemotePort(),
}
conn.Entity.SetIP(pkt.Info().RemoteIP())
conn.Entity.SetDstPort(pkt.Info().DstPort)
// Local
conn.SetLocalIP(pkt.Info().LocalIP())
conn.LocalPort = pkt.Info().LocalPort()
if conn.Inbound {
switch conn.Entity.IPScope {
case netutils.HostLocal:
scope = PeerHost
conn.Scope = IncomingHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
scope = PeerLAN
conn.Scope = IncomingLAN
case netutils.Global, netutils.GlobalMulticast:
scope = PeerInternet
conn.Scope = IncomingInternet
case netutils.Undefined, netutils.Invalid:
fallthrough
default:
scope = PeerInvalid
conn.Scope = IncomingInvalid
}
} else {
// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
if err != nil {
// Try again with the global scope, in case DNS went through the system resolver.
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
conn.Scope = lastResolvedDomain.Domain
conn.Entity.Domain = lastResolvedDomain.Domain
conn.Entity.CNAME = lastResolvedDomain.CNAMEs
conn.DNSContext = lastResolvedDomain.DNSRequestContext
conn.Resolver = lastResolvedDomain.Resolver
removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain)
}
}
// check if destination IP is the captive portal's IP
portal := netenv.GetCaptivePortal()
if pkt.Info().RemoteIP().Equal(portal.IP) {
conn.Scope = portal.Domain
conn.Entity.Domain = portal.Domain
}
if conn.Scope == "" {
// outbound direct (possibly P2P) connection
switch conn.Entity.IPScope {
case netutils.HostLocal:
conn.Scope = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
conn.Scope = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
conn.Scope = PeerInternet
case netutils.Undefined, netutils.Invalid:
fallthrough
default:
conn.Scope = PeerInvalid
}
}
}
}
// Create new connection object.
newConn := &Connection{
ID: pkt.GetConnectionID(),
Type: IPConnection,
Scope: scope,
IPVersion: pkt.Info().Version,
Inbound: inbound,
// local endpoint
IPProtocol: pkt.Info().Protocol,
LocalPort: pkt.Info().LocalPort(),
ProcessContext: getProcessContext(pkt.Ctx(), proc),
DNSContext: dnsContext,
process: proc,
// remote endpoint
Entity: entity,
// resolver used to resolve dns request
Resolver: resolverInfo,
// meta
Started: time.Now().Unix(),
ProfileRevisionCounter: proc.Profile().RevisionCnt(),
}
newConn.SetLocalIP(pkt.Info().LocalIP())
// Inherit internal status of profile.
if localProfile := proc.Profile().LocalProfile(); localProfile != nil {
newConn.Internal = localProfile.Internal
// Data collection is only complete with a packet.
if pkt.InfoOnly() {
return nil
}
// Save connection to internal state in order to mitigate creation of
// duplicates. Do not propagate yet, as there is no verdict yet.
conns.add(newConn)
return newConn
// If we have all data and have seen an actual packet, the connection data is complete.
conn.dataComplete.Set()
return nil
}
// GetConnection fetches a Connection from the database.
func GetConnection(id string) (*Connection, bool) {
return conns.get(id)
func GetConnection(connID string) (*Connection, bool) {
return conns.get(connID)
}
// GetAllConnections Gets all connection.
@ -563,6 +604,14 @@ func (conn *Connection) VerdictVerb() string {
)
}
// DataIsComplete returns whether all information about the connection is
// available and an actual packet has been seen.
// As long as this flag is not set, the connection may not be evaluated for
// a verdict and may not be sent to the UI.
func (conn *Connection) DataIsComplete() bool {
return conn.dataComplete.IsSet()
}
// Process returns the connection's process.
func (conn *Connection) Process() *process.Process {
return conn.process
@ -579,9 +628,13 @@ func (conn *Connection) SaveWhenFinished() {
// Callers must make sure to lock the connection itself before calling
// Save().
func (conn *Connection) Save() {
conn.addToMetrics()
conn.UpdateMeta()
// Do not save/update until data is complete.
if !conn.DataIsComplete() {
return
}
if !conn.KeyIsSet() {
if conn.Type == DNSRequest {
conn.SetKey(makeKey(conn.process.Pid, dbScopeDNS, conn.ID))
@ -592,6 +645,8 @@ func (conn *Connection) Save() {
}
}
conn.addToMetrics()
// notify database controller
dbController.PushUpdate(conn)
}
@ -610,29 +665,61 @@ func (conn *Connection) delete() {
}
conn.Meta().Delete()
dbController.PushUpdate(conn)
// Notify database controller if data is complete and thus connection was previously exposed.
if conn.DataIsComplete() {
dbController.PushUpdate(conn)
}
}
// SetFirewallHandler sets the firewall handler for this link, and starts a
// worker to handle the packets.
// The caller needs to hold a lock on the connection.
// Cannot be called with "nil" handler. Call StopFirewallHandler() instead.
func (conn *Connection) SetFirewallHandler(handler FirewallHandler) {
if conn.firewallHandler == nil {
conn.pktQueue = make(chan packet.Packet, 100)
if handler == nil {
return
}
// Start packet handler worker when first handler is set.
if conn.firewallHandler == nil {
// start handling
module.StartWorker("packet handler", conn.packetHandlerWorker)
}
// Set new handler.
conn.firewallHandler = handler
// Initialize packet queue, if needed.
conn.pktQueueLock.Lock()
defer conn.pktQueueLock.Unlock()
if !conn.pktQueueActive {
conn.pktQueue = make(chan packet.Packet, 100)
conn.pktQueueActive = true
}
}
// UpdateFirewallHandler sets the firewall handler if it already set and the
// given handler is not nil.
// The caller needs to hold a lock on the connection.
func (conn *Connection) UpdateFirewallHandler(handler FirewallHandler) {
if handler != nil && conn.firewallHandler != nil {
conn.firewallHandler = handler
}
}
// StopFirewallHandler unsets the firewall handler and stops the handler worker.
// The caller needs to hold a lock on the connection.
func (conn *Connection) StopFirewallHandler() {
conn.pktQueueLock.Lock()
defer conn.pktQueueLock.Unlock()
// Unset the firewall handler to revert to the default handler.
conn.firewallHandler = nil
// Signal the packet handler worker that it can stop.
close(conn.pktQueue)
conn.pktQueueActive = false
// Unset the packet queue so that it can be freed.
conn.pktQueue = nil
@ -640,15 +727,25 @@ func (conn *Connection) StopFirewallHandler() {
// HandlePacket queues packet of Link for handling.
func (conn *Connection) HandlePacket(pkt packet.Packet) {
conn.Lock()
defer conn.Unlock()
conn.pktQueueLock.Lock()
defer conn.pktQueueLock.Unlock()
// execute handler or verdict
if conn.firewallHandler != nil {
conn.pktQueue <- pkt
// TODO: drop if overflowing?
if conn.pktQueueActive {
select {
case conn.pktQueue <- pkt:
default:
log.Debugf(
"filter: dropping packet %s, as there is no space in the connection's handling queue",
pkt,
)
_ = pkt.Drop()
}
} else {
defaultFirewallHandler(conn, pkt)
// Record metrics.
packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt)
}
}
@ -656,7 +753,12 @@ func (conn *Connection) HandlePacket(pkt packet.Packet) {
func (conn *Connection) packetHandlerWorker(ctx context.Context) error {
// Copy packet queue, so we can remove the reference from the connection
// when we stop the firewall handler.
pktQueue := conn.pktQueue
var pktQueue chan packet.Packet
func() {
conn.pktQueueLock.Lock()
defer conn.pktQueueLock.Unlock()
pktQueue = conn.pktQueue
}()
for {
select {
@ -664,21 +766,27 @@ func (conn *Connection) packetHandlerWorker(ctx context.Context) error {
if pkt == nil {
return nil
}
packetHandlerHandleConn(conn, pkt)
packetHandlerHandleConn(ctx, conn, pkt)
case <-ctx.Done():
conn.Lock()
defer conn.Unlock()
conn.firewallHandler = nil
return nil
}
}
}
func packetHandlerHandleConn(conn *Connection, pkt packet.Packet) {
func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.Packet) {
conn.Lock()
defer conn.Unlock()
// Create tracing context.
// Add context tracer and set context on packet.
traceCtx, tracer := log.AddTracer(ctx)
if tracer != nil {
// The trace is submitted in `network.Connection.packetHandler()`.
tracer.Tracef("filter: handling packet: %s", pkt)
}
pkt.SetCtx(traceCtx)
// Handle packet with appropriate handler.
if conn.firewallHandler != nil {
conn.firewallHandler(conn, pkt)
@ -686,13 +794,22 @@ func packetHandlerHandleConn(conn *Connection, pkt packet.Packet) {
defaultFirewallHandler(conn, pkt)
}
// Log verdict.
log.Tracer(pkt.Ctx()).Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg)
// Submit trace logs.
log.Tracer(pkt.Ctx()).Submit()
// Record metrics.
packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt)
// Save() itself does not touch any changing data.
// Must not be locked - would deadlock with cleaner functions.
// Log result and submit trace.
switch {
case conn.DataIsComplete():
tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg)
case conn.Verdict.Firewall != VerdictUndecided:
tracer.Debugf("filter: connection %s fast-tracked", conn)
default:
tracer.Infof("filter: gathered data on connection %s", conn)
}
// Submit trace logs.
tracer.Submit()
// Push changes, if there are any.
if conn.saveWhenFinished {
conn.saveWhenFinished = false
conn.Save()

View file

@ -11,6 +11,7 @@ import (
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/process"
"github.com/safing/spn/navigator"
"github.com/tevino/abool"
)
// NewDefaultConnection creates a new connection with default values except local and remote IPs and protocols.
@ -25,6 +26,7 @@ func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, rem
LocalIP: localIP,
LocalIPScope: netutils.Global,
LocalPort: localPort,
PID: process.UnidentifiedProcessID,
Entity: &intel.Entity{
Protocol: uint8(protocol),
IP: remoteIP,
@ -35,6 +37,7 @@ func NewDefaultConnection(localIP net.IP, localPort uint16, remoteIP net.IP, rem
VerdictPermanent: false,
Tunneled: true,
Encrypted: false,
DataComplete: abool.NewBool(true),
Internal: false,
addedToMetrics: true, // Metrics are not needed for now. This will mark the Connection to be ignored.
process: process.GetUnidentifiedProcess(context.Background()),

View file

@ -90,12 +90,12 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
switch scope {
case dbScopeDNS:
if r, ok := dnsConns.get(id); ok {
return r, nil
if c, ok := dnsConns.get(id); ok && c.DataIsComplete() {
return c, nil
}
case dbScopeIP:
if r, ok := conns.get(id); ok {
return r, nil
if c, ok := conns.get(id); ok && c.DataIsComplete() {
return c, nil
}
case dbScopeNone:
if proc, ok := process.GetProcessFromStorage(pid); ok {
@ -143,11 +143,16 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
if scope == dbScopeNone || scope == dbScopeDNS {
// dns scopes only
for _, dnsConn := range dnsConns.clone() {
if !dnsConn.DataIsComplete() {
continue
}
func() {
dnsConn.Lock()
defer dnsConn.Unlock()
matches = q.Matches(dnsConn)
}()
if matches {
it.Next <- dnsConn
}
@ -157,11 +162,16 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
if scope == dbScopeNone || scope == dbScopeIP {
// connections
for _, conn := range conns.clone() {
if !conn.DataIsComplete() {
continue
}
func() {
conn.Lock()
defer conn.Unlock()
matches = q.Matches(conn)
}()
if matches {
it.Next <- conn
}

View file

@ -8,6 +8,7 @@ import (
)
var (
packetHandlingHistogram *metrics.Histogram
blockedOutConnCounter *metrics.Counter
encryptedAndTunneledOutConnCounter *metrics.Counter
encryptedOutConnCounter *metrics.Counter
@ -15,8 +16,21 @@ var (
outConnCounter *metrics.Counter
)
func registerMetrics() error {
_, err := metrics.NewGauge(
func registerMetrics() (err error) {
// This needed to be moved here, because every packet is now handled by the
// connection handler worker.
packetHandlingHistogram, err = metrics.NewHistogram(
"firewall/handling/duration/seconds",
nil,
&metrics.Options{
Permission: api.PermitUser,
ExpertiseLevel: config.ExpertiseLevelExpert,
})
if err != nil {
return err
}
_, err = metrics.NewGauge(
"network/connections/active/total",
nil,
func() float64 {

View file

@ -15,6 +15,8 @@ func GetMulticastRequestConn(responseConn *Connection, responseFromNet *net.IPNe
// Find requesting multicast/broadcast connection.
for _, conn := range conns.clone() {
switch {
case !conn.DataIsComplete():
// Ignore connection with incomplete data.
case conn.Inbound:
// Ignore incoming connections.
case conn.Ended != 0:

View file

@ -24,6 +24,12 @@ func (pkt *Base) FastTrackedByIntegration() bool {
return false
}
// InfoOnly returns whether the packet is informational only and does not
// represent an actual packet.
func (pkt *Base) InfoOnly() bool {
return false
}
// SetCtx sets the packet context.
func (pkt *Base) SetCtx(ctx context.Context) {
pkt.ctx = ctx
@ -107,6 +113,7 @@ func (pkt *Base) GetConnectionID() string {
}
func (pkt *Base) createConnectionID() {
// TODO: make this ID not depend on the packet direction for better support for forwarded packets.
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
if pkt.info.Inbound {
pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort)
@ -236,6 +243,7 @@ type Packet interface {
RerouteToNameserver() error
RerouteToTunnel() error
FastTrackedByIntegration() bool
InfoOnly() bool
// Info.
SetCtx(context.Context)

View file

@ -2,6 +2,7 @@ package packet
import (
"net"
"time"
)
// Info holds IP and TCP/UDP header information.
@ -14,7 +15,8 @@ type Info struct {
SrcPort, DstPort uint16
Src, Dst net.IP
PID uint32
PID int
SeenAt time.Time
}
// LocalIP returns the local IP of the packet.

View file

@ -28,13 +28,18 @@ nextPort:
// Check if the generated port is unused.
nextConnection:
for _, conn := range allConns {
// Skip connection if the protocol does not match the protocol of interest.
if conn.Entity.Protocol != protocol {
switch {
case !conn.DataIsComplete():
// Skip connection if the data is not complete.
continue nextConnection
}
// Skip port if the local port is in dangerous proximity.
// Consecutive port numbers are very common.
if conn.LocalPort <= port && conn.LocalPort >= portRangeStart {
case conn.Entity.Protocol != protocol:
// Skip connection if the protocol does not match the protocol of interest.
continue nextConnection
case conn.LocalPort <= port && conn.LocalPort >= portRangeStart:
// Skip port if the local port is in dangerous proximity.
// Consecutive port numbers are very common.
continue nextPort
}
}

View file

@ -14,43 +14,19 @@ import (
"github.com/safing/portmaster/profile"
)
// GetProcessByConnection returns the process that owns the described connection.
func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) {
// GetProcessWithProfile returns the process, including the profile.
// Always returns valid data.
// Errors are logged and returned for information or special handling purposes.
func GetProcessWithProfile(ctx context.Context, pid int) (process *Process, err error) {
if !enableProcessDetection() {
log.Tracer(ctx).Tracef("process: process detection disabled")
return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil
}
// Use fast search for inbound packets, as the listening socket should
// already be there for a while now.
fastSearch := pktInfo.Inbound
var pid int
if pktInfo.PID == 0 {
log.Tracer(ctx).Tracef("process: getting pid from system network state")
pid, connInbound, err = state.Lookup(pktInfo, fastSearch)
if err != nil {
log.Tracer(ctx).Tracef("process: failed to find PID of connection: %s", err)
return nil, pktInfo.Inbound, err
}
} else {
log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)")
pid = int(pktInfo.PID)
}
// Fallback to special profiles if PID could not be found.
if pid == UndefinedProcessID {
if connInbound {
pid = UnsolicitedProcessID
} else {
pid = UnidentifiedProcessID
}
return GetUnidentifiedProcess(ctx), nil
}
process, err = GetOrFindProcess(ctx, pid)
if err != nil {
log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err)
return nil, connInbound, err
log.Tracer(ctx).Warningf("process: failed to find process with PID: %s", err)
return GetUnidentifiedProcess(ctx), err
}
changed, err := process.GetProfile(ctx)
@ -62,7 +38,46 @@ func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process
process.Save()
}
return process, connInbound, nil
return process, nil
}
// GetPidOfConnection returns the PID of the process that owns the described connection.
// Always returns valid data.
// Errors are logged and returned for information or special handling purposes.
func GetPidOfConnection(ctx context.Context, pktInfo *packet.Info) (pid int, connInbound bool, err error) {
if !enableProcessDetection() {
return UnidentifiedProcessID, pktInfo.Inbound, nil
}
// Use fast search for inbound packets, as the listening socket should
// already be there for a while now.
fastSearch := pktInfo.Inbound
connInbound = pktInfo.Inbound
// FIXME: Only match for UndefinedProcessID when integrations have been updated.
if pktInfo.PID <= 0 {
log.Tracer(ctx).Tracef("process: getting pid from system network state")
pid, connInbound, err = state.Lookup(pktInfo, fastSearch)
if err != nil {
err = fmt.Errorf("failed to find PID of connection: %w", err)
log.Tracer(ctx).Tracef("process: %s", err)
pid = UndefinedProcessID
}
} else {
log.Tracer(ctx).Tracef("process: pid already set in packet (by ebpf or kext)")
pid = pktInfo.PID
}
// Fallback to special profiles if PID could not be found.
if pid == UndefinedProcessID {
if connInbound && !netutils.ClassifyIP(pktInfo.Dst).IsLocalhost() {
pid = UnsolicitedProcessID
} else {
pid = UnidentifiedProcessID
}
}
return pid, connInbound, err
}
// GetNetworkHost returns a *Process that represents a host on the network.
@ -113,7 +128,12 @@ func GetProcessByRequestOrigin(ar *api.Request) (*Process, error) {
SrcPort: remotePort, // source as in the process we are looking for
}
proc, _, err := GetProcessByConnection(ar.Context(), pkt)
pid, _, err := GetPidOfConnection(ar.Context(), pkt)
if err != nil {
return nil, err
}
proc, err := GetProcessWithProfile(ar.Context(), pid)
if err != nil {
return nil, err
}

View file

@ -217,7 +217,7 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (*
}
// UID
// net yet implemented for windows
// TODO: implemented for windows
if onLinux {
var uids []int32
uids, err = pInfo.UidsWithContext(ctx)