Improve recovery handling in TCP resolver

Daniel 2020-10-16 15:00:34 +02:00
parent 5911e3b089
commit 0e0f716499


@@ -18,6 +18,7 @@ import (
 const (
 	tcpWriteTimeout    = 1 * time.Second
 	ignoreQueriesAfter = 10 * time.Minute
+	heartbeatTimeout   = 15 * time.Second
 )
 
 // TCPResolver is a resolver using just a single tcp connection with pipelining.
@@ -29,7 +30,7 @@ type TCPResolver struct {
 	clientStarted   *abool.AtomicBool
 	clientHeartbeat chan struct{}
-	clientCancel    func()
+	stopClient      func()
 	connInstanceID  *uint32
 	queries         chan *dns.Msg
 	inFlightQueries map[uint16]*InFlightQuery
@@ -75,9 +76,9 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
 		},
 		clientStarted:   abool.New(),
 		clientHeartbeat: make(chan struct{}),
-		clientCancel:    func() {},
+		stopClient:      func() {},
 		connInstanceID:  &instanceID,
-		queries:         make(chan *dns.Msg, 100),
+		queries:         make(chan *dns.Msg, 1000),
 		inFlightQueries: make(map[uint16]*InFlightQuery),
 	}
 }
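The query buffer also grows from 100 to 1000 slots, so producers can keep enqueueing while the single pipelined connection recovers. A minimal sketch of the idea (not the resolver's actual submit path; the non-blocking send here is purely illustrative):

package main

import "github.com/miekg/dns"

func main() {
	// A buffered channel decouples query producers from the single
	// pipelined TCP connection; 1000 slots absorb bursts while the
	// connection manager reconnects.
	queries := make(chan *dns.Msg, 1000)

	msg := new(dns.Msg)
	msg.SetQuestion("example.com.", dns.TypeA)

	select {
	case queries <- msg:
		// enqueued; the connection manager will write it out
	default:
		// buffer full: the client is stalled or overloaded
	}
}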
@@ -181,15 +182,15 @@ func (tr *TCPResolver) checkClientStatus() {
 	// Get client cancel function before waiting in order to not immediately
 	// cancel a new client.
 	tr.Lock()
-	cancelClient := tr.clientCancel
+	stopClient := tr.stopClient
 	tr.Unlock()
 
 	// Check if the client is alive with the heartbeat, if not shut it down.
 	select {
 	case tr.clientHeartbeat <- struct{}{}:
-	case <-time.After(defaultRequestTimeout):
+	case <-time.After(heartbeatTimeout):
 		log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.GetName())
-		cancelClient()
+		stopClient()
 	}
 }
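Two details make this check safe: clientHeartbeat is unbuffered, so the send only completes while the client loop is actively receiving, and stopClient is copied under the lock before blocking, so a check started against an old client cannot cancel its freshly started replacement. A self-contained sketch of the probe pattern (channel and goroutine names are illustrative):

package main

import (
	"fmt"
	"time"
)

const heartbeatTimeout = 15 * time.Second

func main() {
	// Unbuffered: a send only completes while the client loop is
	// actively receiving, so a successful send proves liveness.
	heartbeat := make(chan struct{})
	stop := make(chan struct{})

	// Client loop: answers heartbeats between doing real work.
	go func() {
		for {
			select {
			case <-heartbeat:
				// respond to alive checks
			case <-stop:
				return
			}
		}
	}()

	// Monitor: probe the client, stop it if it does not answer in time.
	select {
	case heartbeat <- struct{}{}:
		fmt.Println("client is alive")
	case <-time.After(heartbeatTimeout):
		fmt.Println("heartbeat failed, stopping client")
		close(stop)
	}
}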
@@ -214,16 +215,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
 	mgr.tr.clientStarted.Set()
 
 	// Create additional cancel function for this worker.
-	workerCtx, cancelWorker := context.WithCancel(workerCtx)
+	clientCtx, stopClient := context.WithCancel(workerCtx)
 	mgr.tr.Lock()
-	mgr.tr.clientCancel = cancelWorker
+	mgr.tr.stopClient = stopClient
 	mgr.tr.Unlock()
 
 	// connection lifecycle loop
 	for {
 		// check if we are shutting down
 		select {
-		case <-workerCtx.Done():
+		case <-clientCtx.Done():
 			return nil
 		default:
 		}
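Rather than shadowing workerCtx as before, the worker now derives a distinct clientCtx and stores its cancel function as stopClient. The heartbeat check can then tear down just the current client, while a real module shutdown still propagates from workerCtx to the child. A minimal standard-library sketch:

package main

import (
	"context"
	"fmt"
)

func main() {
	// workerCtx: lifetime of the module worker (module shutdown).
	workerCtx, cancelWorker := context.WithCancel(context.Background())
	defer cancelWorker()

	// clientCtx: lifetime of one dns client instance; stopClient is
	// what the heartbeat check calls on failure.
	clientCtx, stopClient := context.WithCancel(workerCtx)

	// Canceling the child stops the client only...
	stopClient()
	fmt.Println("client stopped:", clientCtx.Err() != nil)
	// ...while the worker keeps running and may start a fresh client.
	fmt.Println("worker still up:", workerCtx.Err() == nil)
}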
@@ -234,7 +235,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
 		}
 
 		// wait for work before creating connection
-		proceed := mgr.waitForWork(workerCtx)
+		proceed := mgr.waitForWork(clientCtx)
 		if !proceed {
 			return nil
 		}
@@ -250,7 +251,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
 		netenv.ReportSuccessfulConnection()
 
 		// handle queries
-		proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
+		proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx)
 		if !proceed {
 			return nil
 		}
@@ -276,7 +277,7 @@ func (mgr *tcpResolverConnMgr) shutdown() {
 	}
 }
 
-func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
+func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) {
 	// wait until there is something to do
 	mgr.tr.Lock()
 	waiting := len(mgr.tr.inFlightQueries)
@@ -308,7 +309,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b
 
 	// wait for first query
 	select {
-	case <-workerCtx.Done():
+	case <-clientCtx.Done():
 		return false
 	case msg := <-mgr.tr.queries:
 		// re-insert query, we will handle it later
@@ -362,7 +363,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
 	)
 
 	// start reader
-	module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
+	module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error {
 		return mgr.msgReader(conn, connClosing, cancelConnCtx)
 	})
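The closure's context parameter is only renamed here; the reader never uses it, since it naturally ends when the connection dies. A hypothetical reader loop illustrating that shape (this is not the project's msgReader, whose body is not part of this diff):

package sketch

import (
	"github.com/miekg/dns"
	"github.com/tevino/abool"
)

// readLoop needs no context: conn.ReadMsg fails once the connection
// dies, and the reader reports that by canceling the connection context.
func readLoop(conn *dns.Conn, connClosing *abool.AtomicBool, cancelConnCtx func()) error {
	defer cancelConnCtx()
	for {
		msg, err := conn.ReadMsg()
		if err != nil {
			// Only surface the error if the close was unexpected.
			if connClosing.SetToIf(false, true) {
				return err
			}
			return nil
		}
		_ = msg // hand off to the matching in-flight query
	}
}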
@@ -370,7 +371,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
 }
 
 func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
-	workerCtx context.Context,
+	clientCtx context.Context,
 	conn *dns.Conn,
 	connClosing *abool.AtomicBool,
 	connCtx context.Context,
@@ -394,7 +395,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
 		case <-mgr.tr.clientHeartbeat:
 			// respond to alive checks
-		case <-workerCtx.Done():
+		case <-clientCtx.Done():
 			// module shutdown
 			return false
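queryHandler thus watches three separate signals: the heartbeat channel for liveness probes, clientCtx for a client stop or module shutdown, and connCtx for the end of this particular connection. A sketch of the nesting, assuming each connection context is derived from the client context (the hunks above do not show its parent):

package main

import "context"

func main() {
	// Outermost to innermost; canceling an outer context also cancels
	// everything derived from it.
	workerCtx, cancelWorker := context.WithCancel(context.Background()) // module worker
	defer cancelWorker()

	clientCtx, stopClient := context.WithCancel(workerCtx) // one dns client
	defer stopClient()

	// Assumption: one context per tcp connection, derived from the client,
	// so a dead connection triggers a reconnect, not a client restart.
	connCtx, cancelConnCtx := context.WithCancel(clientCtx)
	defer cancelConnCtx()

	select {
	case <-connCtx.Done(): // connection ended: reconnect
	case <-clientCtx.Done(): // client stopped: exit handler
	default: // keep handling queries and heartbeats
	}
}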