Improve recovery handling in TCP resolver

This commit is contained in:
Daniel 2020-10-16 15:00:34 +02:00
parent 5911e3b089
commit 0e0f716499

View file

@ -18,6 +18,7 @@ import (
const (
tcpWriteTimeout = 1 * time.Second
ignoreQueriesAfter = 10 * time.Minute
heartbeatTimeout = 15 * time.Second
)
// TCPResolver is a resolver using just a single tcp connection with pipelining.
@ -29,7 +30,7 @@ type TCPResolver struct {
clientStarted *abool.AtomicBool
clientHeartbeat chan struct{}
clientCancel func()
stopClient func()
connInstanceID *uint32
queries chan *dns.Msg
inFlightQueries map[uint16]*InFlightQuery
@ -75,9 +76,9 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
},
clientStarted: abool.New(),
clientHeartbeat: make(chan struct{}),
clientCancel: func() {},
stopClient: func() {},
connInstanceID: &instanceID,
queries: make(chan *dns.Msg, 100),
queries: make(chan *dns.Msg, 1000),
inFlightQueries: make(map[uint16]*InFlightQuery),
}
}
@ -181,15 +182,15 @@ func (tr *TCPResolver) checkClientStatus() {
// Get client cancel function before waiting in order to not immediately
// cancel a new client.
tr.Lock()
cancelClient := tr.clientCancel
stopClient := tr.stopClient
tr.Unlock()
// Check if the client is alive with the heartbeat, if not shut it down.
select {
case tr.clientHeartbeat <- struct{}{}:
case <-time.After(defaultRequestTimeout):
case <-time.After(heartbeatTimeout):
log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.GetName())
cancelClient()
stopClient()
}
}
@ -214,16 +215,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
mgr.tr.clientStarted.Set()
// Create additional cancel function for this worker.
workerCtx, cancelWorker := context.WithCancel(workerCtx)
clientCtx, stopClient := context.WithCancel(workerCtx)
mgr.tr.Lock()
mgr.tr.clientCancel = cancelWorker
mgr.tr.stopClient = stopClient
mgr.tr.Unlock()
// connection lifecycle loop
for {
// check if we are shutting down
select {
case <-workerCtx.Done():
case <-clientCtx.Done():
return nil
default:
}
@ -234,7 +235,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
}
// wait for work before creating connection
proceed := mgr.waitForWork(workerCtx)
proceed := mgr.waitForWork(clientCtx)
if !proceed {
return nil
}
@ -250,7 +251,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
netenv.ReportSuccessfulConnection()
// handle queries
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx)
if !proceed {
return nil
}
@ -276,7 +277,7 @@ func (mgr *tcpResolverConnMgr) shutdown() {
}
}
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) {
// wait until there is something to do
mgr.tr.Lock()
waiting := len(mgr.tr.inFlightQueries)
@ -308,7 +309,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b
// wait for first query
select {
case <-workerCtx.Done():
case <-clientCtx.Done():
return false
case msg := <-mgr.tr.queries:
// re-insert query, we will handle it later
@ -362,7 +363,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
)
// start reader
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error {
return mgr.msgReader(conn, connClosing, cancelConnCtx)
})
@ -370,7 +371,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
}
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
workerCtx context.Context,
clientCtx context.Context,
conn *dns.Conn,
connClosing *abool.AtomicBool,
connCtx context.Context,
@ -394,7 +395,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
case <-mgr.tr.clientHeartbeat:
// respond to alive checks
case <-workerCtx.Done():
case <-clientCtx.Done():
// module shutdown
return false