diff --git a/firewall/interception/nfq/nfq.go b/firewall/interception/nfq/nfq.go index f8ad2ddf..f18c5115 100644 --- a/firewall/interception/nfq/nfq.go +++ b/firewall/interception/nfq/nfq.go @@ -5,6 +5,7 @@ package nfq import ( "context" + "runtime" "sync/atomic" "time" @@ -19,25 +20,86 @@ import ( // Queue wraps a nfqueue type Queue struct { id uint16 - nf *nfqueue.Nfqueue + afFamily uint8 + nf atomic.Value packets chan pmpacket.Packet cancelSocketCallback context.CancelFunc + restart chan struct{} pendingVerdicts uint64 verdictCompleted chan struct{} } +func (q *Queue) getNfq() *nfqueue.Nfqueue { + return q.nf.Load().(*nfqueue.Nfqueue) +} + // New opens a new nfQueue. func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit afFamily := unix.AF_INET if v6 { afFamily = unix.AF_INET6 } + + ctx, cancel := context.WithCancel(context.Background()) + q := &Queue{ + id: qid, + afFamily: uint8(afFamily), + nf: atomic.Value{}, + restart: make(chan struct{}, 1), + packets: make(chan pmpacket.Packet, 1000), + cancelSocketCallback: cancel, + verdictCompleted: make(chan struct{}, 1), + } + + // Do not retry if the first one fails immediately as it + // might point to a deeper integration error that's not fixable + // with retrying ... + if err := q.open(ctx); err != nil { + return nil, err + } + + go func() { + Wait: + for { + select { + case <-ctx.Done(): + return + case <-q.restart: + runtime.Gosched() + } + + for { + err := q.open(ctx) + if err == nil { + continue Wait + } + + // Wait 100 ms and then try again ... + log.Errorf("Failed to open nfqueue: %s", err) + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + } + } + } + }() + + return q, nil +} + +// open opens a new netlink socket and and creates a new nfqueue. +// Upon success, the new nfqueue is atomically stored in Queue.nf. +// Users must use Queue.getNfq to access it. open does not care about +// any other value or queue that might be stored in Queue.nf at +// the time open is called. +func (q *Queue) open(ctx context.Context) error { cfg := &nfqueue.Config{ - NfQueue: qid, + NfQueue: q.id, MaxPacketLen: 1600, // mtu is normally around 1500, make sure to capture it. MaxQueueLen: 0xffff, - AfFamily: uint8(afFamily), + AfFamily: q.afFamily, Copymode: nfqueue.NfQnlCopyPacket, ReadTimeout: 1000 * time.Millisecond, WriteTimeout: 1000 * time.Millisecond, @@ -45,20 +107,54 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit nf, err := nfqueue.Open(cfg) if err != nil { - return nil, err + return err } - ctx, cancel := context.WithCancel(context.Background()) - q := &Queue{ - id: qid, - nf: nf, - packets: make(chan pmpacket.Packet, 1000), - cancelSocketCallback: cancel, - verdictCompleted: make(chan struct{}, 1), + if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil { + defer nf.Close() + return err } - fn := func(attrs nfqueue.Attribute) int { + q.nf.Store(nf) + return nil +} + +func (q *Queue) handleError(e error) int { + // embedded interface is required to work-around some + // dep-vendoring weirdness + if opError, ok := e.(interface { + Timeout() bool + Temporary() bool + }); ok { + if opError.Timeout() || opError.Temporary() { + c := atomic.LoadUint64(&q.pendingVerdicts) + if c > 0 { + log.Tracef("nfqueue: waiting for %d pending verdicts", c) + + for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here + <-q.verdictCompleted + } + } + + return 0 + } + } + log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error()) + + // Close the existing socket + if nf := q.getNfq(); nf != nil { + nf.Close() + } + + // Trigger a restart of the queue + q.restart <- struct{}{} + + return 1 +} + +func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int { + return func(attrs nfqueue.Attribute) int { if attrs.PacketID == nil { // we need a packet id to set a verdict, // if we don't get an ID there's hardly anything @@ -107,46 +203,16 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit return 0 // continue calling this fn } - - errorFunc := func(e error) int { - // embedded interface is required to work-around some - // dep-vendoring weirdness - if opError, ok := e.(interface { - Timeout() bool - Temporary() bool - }); ok { - if opError.Timeout() || opError.Temporary() { - c := atomic.LoadUint64(&q.pendingVerdicts) - if c > 0 { - log.Tracef("nfqueue: waiting for %d pending verdicts", c) - - for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here - <-q.verdictCompleted - } - } - - return 0 - } - } - log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error()) - - return 1 - } - - if err := q.nf.RegisterWithErrorFunc(ctx, fn, errorFunc); err != nil { - defer q.nf.Close() - return nil, err - } - - return q, nil } // Destroy destroys the queue. Any error encountered is logged. func (q *Queue) Destroy() { q.cancelSocketCallback() - if err := q.nf.Close(); err != nil { - log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err) + if nf := q.getNfq(); nf != nil { + if err := nf.Close(); err != nil { + log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err) + } } } diff --git a/firewall/interception/nfq/packet.go b/firewall/interception/nfq/packet.go index 2399d43e..872a6243 100644 --- a/firewall/interception/nfq/packet.go +++ b/firewall/interception/nfq/packet.go @@ -96,7 +96,7 @@ func (pkt *packet) setMark(mark int) error { }() for { - if err := pkt.queue.nf.SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil { + if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil { // embedded interface is required to work-around some // dep-vendoring weirdness if opErr, ok := err.(interface {