Update linux integration code to re-create the nfqueues on failure

This commit is contained in:
Patrick Pacher 2020-12-14 09:42:04 +01:00
parent 7bd94d13d2
commit 44e1b97c30
2 changed files with 113 additions and 47 deletions

View file

@ -5,6 +5,7 @@ package nfq
import ( import (
"context" "context"
"runtime"
"sync/atomic" "sync/atomic"
"time" "time"
@ -19,25 +20,86 @@ import (
// Queue wraps a nfqueue // Queue wraps a nfqueue
type Queue struct { type Queue struct {
id uint16 id uint16
nf *nfqueue.Nfqueue afFamily uint8
nf atomic.Value
packets chan pmpacket.Packet packets chan pmpacket.Packet
cancelSocketCallback context.CancelFunc cancelSocketCallback context.CancelFunc
restart chan struct{}
pendingVerdicts uint64 pendingVerdicts uint64
verdictCompleted chan struct{} verdictCompleted chan struct{}
} }
func (q *Queue) getNfq() *nfqueue.Nfqueue {
return q.nf.Load().(*nfqueue.Nfqueue)
}
// New opens a new nfQueue. // New opens a new nfQueue.
func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
afFamily := unix.AF_INET afFamily := unix.AF_INET
if v6 { if v6 {
afFamily = unix.AF_INET6 afFamily = unix.AF_INET6
} }
ctx, cancel := context.WithCancel(context.Background())
q := &Queue{
id: qid,
afFamily: uint8(afFamily),
nf: atomic.Value{},
restart: make(chan struct{}, 1),
packets: make(chan pmpacket.Packet, 1000),
cancelSocketCallback: cancel,
verdictCompleted: make(chan struct{}, 1),
}
// Do not retry if the first one fails immediately as it
// might point to a deeper integration error that's not fixable
// with retrying ...
if err := q.open(ctx); err != nil {
return nil, err
}
go func() {
Wait:
for {
select {
case <-ctx.Done():
return
case <-q.restart:
runtime.Gosched()
}
for {
err := q.open(ctx)
if err == nil {
continue Wait
}
// Wait 100 ms and then try again ...
log.Errorf("Failed to open nfqueue: %s", err)
select {
case <-ctx.Done():
return
case <-time.After(100 * time.Millisecond):
}
}
}
}()
return q, nil
}
// open opens a new netlink socket and and creates a new nfqueue.
// Upon success, the new nfqueue is atomically stored in Queue.nf.
// Users must use Queue.getNfq to access it. open does not care about
// any other value or queue that might be stored in Queue.nf at
// the time open is called.
func (q *Queue) open(ctx context.Context) error {
cfg := &nfqueue.Config{ cfg := &nfqueue.Config{
NfQueue: qid, NfQueue: q.id,
MaxPacketLen: 1600, // mtu is normally around 1500, make sure to capture it. MaxPacketLen: 1600, // mtu is normally around 1500, make sure to capture it.
MaxQueueLen: 0xffff, MaxQueueLen: 0xffff,
AfFamily: uint8(afFamily), AfFamily: q.afFamily,
Copymode: nfqueue.NfQnlCopyPacket, Copymode: nfqueue.NfQnlCopyPacket,
ReadTimeout: 1000 * time.Millisecond, ReadTimeout: 1000 * time.Millisecond,
WriteTimeout: 1000 * time.Millisecond, WriteTimeout: 1000 * time.Millisecond,
@ -45,20 +107,54 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
nf, err := nfqueue.Open(cfg) nf, err := nfqueue.Open(cfg)
if err != nil { if err != nil {
return nil, err return err
} }
ctx, cancel := context.WithCancel(context.Background()) if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil {
q := &Queue{ defer nf.Close()
id: qid, return err
nf: nf,
packets: make(chan pmpacket.Packet, 1000),
cancelSocketCallback: cancel,
verdictCompleted: make(chan struct{}, 1),
} }
fn := func(attrs nfqueue.Attribute) int { q.nf.Store(nf)
return nil
}
func (q *Queue) handleError(e error) int {
// embedded interface is required to work-around some
// dep-vendoring weirdness
if opError, ok := e.(interface {
Timeout() bool
Temporary() bool
}); ok {
if opError.Timeout() || opError.Temporary() {
c := atomic.LoadUint64(&q.pendingVerdicts)
if c > 0 {
log.Tracef("nfqueue: waiting for %d pending verdicts", c)
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
<-q.verdictCompleted
}
}
return 0
}
}
log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error())
// Close the existing socket
if nf := q.getNfq(); nf != nil {
nf.Close()
}
// Trigger a restart of the queue
q.restart <- struct{}{}
return 1
}
func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int {
return func(attrs nfqueue.Attribute) int {
if attrs.PacketID == nil { if attrs.PacketID == nil {
// we need a packet id to set a verdict, // we need a packet id to set a verdict,
// if we don't get an ID there's hardly anything // if we don't get an ID there's hardly anything
@ -107,48 +203,18 @@ func New(qid uint16, v6 bool) (*Queue, error) { //nolint:gocognit
return 0 // continue calling this fn return 0 // continue calling this fn
} }
errorFunc := func(e error) int {
// embedded interface is required to work-around some
// dep-vendoring weirdness
if opError, ok := e.(interface {
Timeout() bool
Temporary() bool
}); ok {
if opError.Timeout() || opError.Temporary() {
c := atomic.LoadUint64(&q.pendingVerdicts)
if c > 0 {
log.Tracef("nfqueue: waiting for %d pending verdicts", c)
for atomic.LoadUint64(&q.pendingVerdicts) > 0 { // must NOT use c here
<-q.verdictCompleted
}
}
return 0
}
}
log.Errorf("nfqueue: encountered error while receiving packets: %s\n", e.Error())
return 1
}
if err := q.nf.RegisterWithErrorFunc(ctx, fn, errorFunc); err != nil {
defer q.nf.Close()
return nil, err
}
return q, nil
} }
// Destroy destroys the queue. Any error encountered is logged. // Destroy destroys the queue. Any error encountered is logged.
func (q *Queue) Destroy() { func (q *Queue) Destroy() {
q.cancelSocketCallback() q.cancelSocketCallback()
if err := q.nf.Close(); err != nil { if nf := q.getNfq(); nf != nil {
if err := nf.Close(); err != nil {
log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err) log.Errorf("nfqueue: failed to close queue %d: %s", q.id, err)
} }
} }
}
// PacketChannel returns the packet channel. // PacketChannel returns the packet channel.
func (q *Queue) PacketChannel() <-chan pmpacket.Packet { func (q *Queue) PacketChannel() <-chan pmpacket.Packet {

View file

@ -96,7 +96,7 @@ func (pkt *packet) setMark(mark int) error {
}() }()
for { for {
if err := pkt.queue.nf.SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil { if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil {
// embedded interface is required to work-around some // embedded interface is required to work-around some
// dep-vendoring weirdness // dep-vendoring weirdness
if opErr, ok := err.(interface { if opErr, ok := err.(interface {