Add hearbeat check to TCP resolver

This commit is contained in:
Daniel 2020-09-22 16:08:17 +02:00
parent bd8d047428
commit 34247b1d82

View file

@ -3,6 +3,8 @@ package resolver
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"sync/atomic"
"time"
@ -26,6 +28,8 @@ type TCPResolver struct {
dnsClient *dns.Client
clientStarted *abool.AtomicBool
clientHeartbeat chan struct{}
clientCancel func()
connInstanceID *uint32
queries chan *dns.Msg
inFlightQueries map[uint16]*InFlightQuery
@ -68,10 +72,12 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
Timeout: defaultConnectTimeout,
WriteTimeout: tcpWriteTimeout,
},
connInstanceID: &instanceID,
queries: make(chan *dns.Msg, 100),
inFlightQueries: make(map[uint16]*InFlightQuery),
clientStarted: abool.New(),
clientHeartbeat: make(chan struct{}),
clientCancel: func() {},
connInstanceID: &instanceID,
queries: make(chan *dns.Msg, 1000),
inFlightQueries: make(map[uint16]*InFlightQuery),
}
}
@ -146,6 +152,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// submit to client
inFlight := tr.submitQuery(ctx, q)
if inFlight == nil {
tr.checkClientStatus()
return nil, ErrTimeout
}
@ -153,6 +160,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
select {
case reply = <-inFlight.Response:
case <-time.After(defaultRequestTimeout):
tr.checkClientStatus()
return nil, ErrTimeout
}
@ -168,6 +176,21 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
return inFlight.MakeCacheRecord(reply), nil
}
func (tr *TCPResolver) checkClientStatus() {
// Get client cancel function before waiting in order to not immediately
// cancel a new client.
tr.Lock()
cancelClient := tr.clientCancel
tr.Unlock()
// Check if the client is alive with the heartbeat, if not shut it down.
select {
case tr.clientHeartbeat <- struct{}{}:
case <-time.After(defaultRequestTimeout):
cancelClient()
}
}
type tcpResolverConnMgr struct {
tr *TCPResolver
responses chan *dns.Msg
@ -185,8 +208,14 @@ func (tr *TCPResolver) startClient() {
}
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
mgr.tr.clientStarted.Set()
defer mgr.shutdown()
mgr.tr.clientStarted.Set()
// Create additional cancel function for this worker.
workerCtx, cancelWorker := context.WithCancel(workerCtx)
mgr.tr.Lock()
mgr.tr.clientCancel = cancelWorker
mgr.tr.Unlock()
// connection lifecycle loop
for {
@ -314,10 +343,21 @@ func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress)
return nil, nil, nil, nil
}
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
connCtx, cancelConnCtx = context.WithCancel(context.Background())
connClosing = abool.New()
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr())
// Get amount of in waiting queries.
mgr.tr.Lock()
waitingQueries := len(mgr.tr.inFlightQueries)
mgr.tr.Unlock()
// Log that a connection to the resolver was established.
log.Debugf(
"resolver: connected to %s (%s) with %d queries waiting",
mgr.tr.resolver.GetName(),
conn.RemoteAddr(),
waitingQueries,
)
// start reader
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
@ -349,6 +389,9 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
for {
select {
case <-mgr.tr.clientHeartbeat:
// respond to alive checks
case <-workerCtx.Done():
// module shutdown
return false
@ -373,9 +416,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
err := conn.WriteMsg(msg)
if err != nil {
if connClosing.SetToIf(false, true) {
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
}
mgr.logConnectionError(err, conn, connClosing)
return true
}
@ -456,11 +497,37 @@ func (mgr *tcpResolverConnMgr) msgReader(
for {
msg, err := conn.ReadMsg()
if err != nil {
if connClosing.SetToIf(false, true) {
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
}
mgr.logConnectionError(err, conn, connClosing)
return nil
}
mgr.responses <- msg
}
}
func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool) {
// Check if we are the first to see an error.
if connClosing.SetToIf(false, true) {
// Get amount of in flight queries.
mgr.tr.Lock()
inFlightQueries := len(mgr.tr.inFlightQueries)
mgr.tr.Unlock()
// Log error.
if errors.Is(err, io.EOF) {
log.Debugf(
"resolver: connection to %s (%s) was closed with %d in-flight queries",
mgr.tr.resolver.GetName(),
conn.RemoteAddr(),
inFlightQueries,
)
} else {
log.Warningf(
"resolver: write error to %s (%s) with %d in-flight queries: %s",
mgr.tr.resolver.GetName(),
conn.RemoteAddr(),
inFlightQueries,
err,
)
}
}
}