diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index b3eb35e9..c777b71c 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -3,6 +3,8 @@ package resolver import ( "context" "crypto/tls" + "errors" + "io" "net" "sync/atomic" "time" @@ -26,6 +28,8 @@ type TCPResolver struct { dnsClient *dns.Client clientStarted *abool.AtomicBool + clientHeartbeat chan struct{} + clientCancel func() connInstanceID *uint32 queries chan *dns.Msg inFlightQueries map[uint16]*InFlightQuery @@ -68,10 +72,12 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver { Timeout: defaultConnectTimeout, WriteTimeout: tcpWriteTimeout, }, - connInstanceID: &instanceID, - queries: make(chan *dns.Msg, 100), - inFlightQueries: make(map[uint16]*InFlightQuery), clientStarted: abool.New(), + clientHeartbeat: make(chan struct{}), + clientCancel: func() {}, + connInstanceID: &instanceID, + queries: make(chan *dns.Msg, 1000), + inFlightQueries: make(map[uint16]*InFlightQuery), } } @@ -146,6 +152,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { // submit to client inFlight := tr.submitQuery(ctx, q) if inFlight == nil { + tr.checkClientStatus() return nil, ErrTimeout } @@ -153,6 +160,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { select { case reply = <-inFlight.Response: case <-time.After(defaultRequestTimeout): + tr.checkClientStatus() return nil, ErrTimeout } @@ -168,6 +176,21 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { return inFlight.MakeCacheRecord(reply), nil } +func (tr *TCPResolver) checkClientStatus() { + // Get client cancel function before waiting in order to not immediately + // cancel a new client. + tr.Lock() + cancelClient := tr.clientCancel + tr.Unlock() + + // Check if the client is alive with the heartbeat, if not shut it down. + select { + case tr.clientHeartbeat <- struct{}{}: + case <-time.After(defaultRequestTimeout): + cancelClient() + } +} + type tcpResolverConnMgr struct { tr *TCPResolver responses chan *dns.Msg @@ -185,8 +208,14 @@ func (tr *TCPResolver) startClient() { } func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { - mgr.tr.clientStarted.Set() defer mgr.shutdown() + mgr.tr.clientStarted.Set() + + // Create additional cancel function for this worker. + workerCtx, cancelWorker := context.WithCancel(workerCtx) + mgr.tr.Lock() + mgr.tr.clientCancel = cancelWorker + mgr.tr.Unlock() // connection lifecycle loop for { @@ -314,10 +343,21 @@ func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) ( log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress) return nil, nil, nil, nil } - connCtx, cancelConnCtx = context.WithCancel(workerCtx) + connCtx, cancelConnCtx = context.WithCancel(context.Background()) connClosing = abool.New() - log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr()) + // Get amount of in waiting queries. + mgr.tr.Lock() + waitingQueries := len(mgr.tr.inFlightQueries) + mgr.tr.Unlock() + + // Log that a connection to the resolver was established. + log.Debugf( + "resolver: connected to %s (%s) with %d queries waiting", + mgr.tr.resolver.GetName(), + conn.RemoteAddr(), + waitingQueries, + ) // start reader module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error { @@ -349,6 +389,9 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context for { select { + case <-mgr.tr.clientHeartbeat: + // respond to alive checks + case <-workerCtx.Done(): // module shutdown return false @@ -373,9 +416,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout)) err := conn.WriteMsg(msg) if err != nil { - if connClosing.SetToIf(false, true) { - log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err) - } + mgr.logConnectionError(err, conn, connClosing) return true } @@ -456,11 +497,37 @@ func (mgr *tcpResolverConnMgr) msgReader( for { msg, err := conn.ReadMsg() if err != nil { - if connClosing.SetToIf(false, true) { - log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err) - } + mgr.logConnectionError(err, conn, connClosing) return nil } mgr.responses <- msg } } + +func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool) { + // Check if we are the first to see an error. + if connClosing.SetToIf(false, true) { + // Get amount of in flight queries. + mgr.tr.Lock() + inFlightQueries := len(mgr.tr.inFlightQueries) + mgr.tr.Unlock() + + // Log error. + if errors.Is(err, io.EOF) { + log.Debugf( + "resolver: connection to %s (%s) was closed with %d in-flight queries", + mgr.tr.resolver.GetName(), + conn.RemoteAddr(), + inFlightQueries, + ) + } else { + log.Warningf( + "resolver: write error to %s (%s) with %d in-flight queries: %s", + mgr.tr.resolver.GetName(), + conn.RemoteAddr(), + inFlightQueries, + err, + ) + } + } +}