Get/Create connections in single-inflight lock

This commit is contained in:
Daniel 2021-04-07 16:43:13 +02:00
parent c8bb071e29
commit 9ff824967e

View file

@ -3,12 +3,14 @@ package firewall
import (
"context"
"errors"
"fmt"
"net"
"os"
"sync/atomic"
"time"
"github.com/safing/portmaster/netenv"
"golang.org/x/sync/singleflight"
"github.com/tevino/abool"
@ -102,20 +104,60 @@ func handlePacket(ctx context.Context, pkt packet.Packet) {
}
pkt.SetCtx(traceCtx)
// associate packet to link and handle
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
tracer.Tracef("filter: assigned to connection %s", conn.ID)
} else {
conn = network.NewConnectionFromFirstPacket(pkt)
tracer.Tracef("filter: created new connection %s", conn.ID)
conn.SetFirewallHandler(initialHandler)
// Get connection of packet.
conn, err := getConnection(pkt)
if err != nil {
tracer.Errorf("filter: packet %s dropped: %s", pkt, err)
_ = pkt.Drop()
return
}
// handle packet
conn.HandlePacket(pkt)
}
var getConnectionSingleInflight singleflight.Group
func getConnection(pkt packet.Packet) (*network.Connection, error) {
created := false
// Create or get connection in single inflight lock in order to prevent duplicates.
newConn, err, shared := getConnectionSingleInflight.Do(pkt.GetConnectionID(), func() (interface{}, error) {
// First, check for an existing connection.
conn, ok := network.GetConnection(pkt.GetConnectionID())
if ok {
return conn, nil
}
// Else create new one from the packet.
conn = network.NewConnectionFromFirstPacket(pkt)
conn.SetFirewallHandler(initialHandler)
created = true
return conn, nil
})
if err != nil {
return nil, fmt.Errorf("failed to get connection: %s", err)
}
if newConn == nil {
return nil, errors.New("connection getter returned nil")
}
// Transform and log result.
conn := newConn.(*network.Connection)
switch {
case created && shared:
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s (shared)", conn.ID)
case created:
log.Tracer(pkt.Ctx()).Tracef("filter: created new connection %s", conn.ID)
case shared:
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s (shared)", conn.ID)
default:
log.Tracer(pkt.Ctx()).Tracef("filter: assigned connection %s", conn.ID)
}
return conn, nil
}
// fastTrackedPermit quickly permits certain network criticial or internal connections.
func fastTrackedPermit(pkt packet.Packet) (handled bool) {
meta := pkt.Info()