Add more timeouts to blocking calls in resolver

This commit is contained in:
Daniel 2020-08-14 11:28:55 +02:00
parent 05f57262e9
commit 4d667afd1d
4 changed files with 73 additions and 31 deletions

View file

@ -51,6 +51,17 @@ const (
maxTTL = 24 * 60 * 60 // 24 hours maxTTL = 24 * 60 * 60 // 24 hours
) )
var (
dupReqMap = make(map[string]*dedupeStatus)
dupReqLock sync.Mutex
)
type dedupeStatus struct {
completed chan struct{}
waitUntil time.Time
superseded bool
}
// BlockedUpstreamError is returned when a DNS request // BlockedUpstreamError is returned when a DNS request
// has been blocked by the upstream server. // has been blocked by the upstream server.
type BlockedUpstreamError struct { type BlockedUpstreamError struct {
@ -195,7 +206,10 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
rrCache.requestingNew = true rrCache.requestingNew = true
rrCache.Unlock() rrCache.Unlock()
log.Tracer(ctx).Trace("resolver: serving from cache, requesting new") log.Tracer(ctx).Tracef(
"resolver: using expired RR from cache (since %s), refreshing async now",
time.Since(time.Unix(rrCache.TTL, 0)),
)
// resolve async // resolve async
module.StartWorker("resolve async", func(ctx context.Context) error { module.StartWorker("resolve async", func(ctx context.Context) error {
@ -205,9 +219,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
} }
return nil return nil
}) })
return rrCache
} }
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0))) log.Tracer(ctx).Tracef(
"resolver: using cached RR (expires in %s)",
time.Until(time.Unix(rrCache.TTL, 0)),
)
return rrCache return rrCache
} }
@ -215,31 +234,46 @@ func deduplicateRequest(ctx context.Context, q *Query) (finishRequest func()) {
// create identifier key // create identifier key
dupKey := q.ID() dupKey := q.ID()
// restart here if waiting timed out
retry:
dupReqLock.Lock() dupReqLock.Lock()
// get duplicate request waitgroup // get duplicate request waitgroup
wg, requestActive := dupReqMap[dupKey] status, requestActive := dupReqMap[dupKey]
// someone else is already on it! // check if the request is active
if requestActive { if requestActive {
// someone else is already on it!
if time.Now().Before(status.waitUntil) {
dupReqLock.Unlock() dupReqLock.Unlock()
// log that we are waiting // log that we are waiting
log.Tracer(ctx).Tracef("resolver: waiting for duplicate query for %s to complete", dupKey) log.Tracer(ctx).Tracef("resolver: waiting for duplicate query for %s to complete", dupKey)
// wait // wait
wg.Wait() select {
case <-status.completed:
// done! // done!
return nil return nil
case <-time.After(maxRequestTimeout):
// something went wrong with the query, retry
goto retry
}
} else {
// but that someone is taking too long
status.superseded = true
}
} }
// we are currently the only one doing a request for this // we are currently the only one doing a request for this
// create new waitgroup // create new status
wg = new(sync.WaitGroup) status = &dedupeStatus{
// add worker (us!) completed: make(chan struct{}),
wg.Add(1) waitUntil: time.Now().Add(maxRequestTimeout),
}
// add to registry // add to registry
dupReqMap[dupKey] = wg dupReqMap[dupKey] = status
dupReqLock.Unlock() dupReqLock.Unlock()
@ -248,11 +282,13 @@ func deduplicateRequest(ctx context.Context, q *Query) (finishRequest func()) {
dupReqLock.Lock() dupReqLock.Lock()
defer dupReqLock.Unlock() defer dupReqLock.Unlock()
// mark request as done // mark request as done
wg.Done() close(status.completed)
// delete from registry // delete from registry
if !status.superseded {
delete(dupReqMap, dupKey) delete(dupReqMap, dupKey)
} }
} }
}
func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error) { //nolint:gocognit func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error) { //nolint:gocognit
// get resolvers // get resolvers

View file

@ -14,6 +14,7 @@ var (
defaultClientTTL = 5 * time.Minute defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 5 * time.Second // tcp/tls defaultConnectTimeout = 5 * time.Second // tcp/tls
maxRequestTimeout = 5 * time.Second
) )
// PlainResolver is a resolver using plain DNS. // PlainResolver is a resolver using plain DNS.

View file

@ -110,7 +110,11 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
tr.Unlock() tr.Unlock()
// submit msg for writing // submit msg for writing
tr.queries <- msg select {
case tr.queries <- msg:
case <-time.After(defaultRequestTimeout):
return nil
}
return inFlight return inFlight
} }
@ -140,8 +144,11 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// submit to client // submit to client
inFlight := tr.submitQuery(ctx, q) inFlight := tr.submitQuery(ctx, q)
var reply *dns.Msg if inFlight == nil {
return nil, ErrTimeout
}
var reply *dns.Msg
select { select {
case reply = <-inFlight.Response: case reply = <-inFlight.Response:
case <-time.After(defaultRequestTimeout): case <-time.After(defaultRequestTimeout):
@ -177,26 +184,26 @@ func (tr *TCPResolver) startClient() {
} }
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
mgr.tr.clientStarted.Set()
defer mgr.shutdown()
// connection lifecycle loop // connection lifecycle loop
for { for {
// check if we are shutting down // check if we are shutting down
select { select {
case <-workerCtx.Done(): case <-workerCtx.Done():
mgr.shutdown()
return nil return nil
default: default:
} }
// check if we are failing // check if we are failing
if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() { if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() {
mgr.shutdown()
return nil return nil
} }
// wait for work before creating connection // wait for work before creating connection
proceed := mgr.waitForWork(workerCtx) proceed := mgr.waitForWork(workerCtx)
if !proceed { if !proceed {
mgr.shutdown()
return nil return nil
} }
@ -213,7 +220,6 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
// handle queries // handle queries
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx) proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
if !proceed { if !proceed {
mgr.shutdown()
return nil return nil
} }
} }
@ -222,13 +228,15 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
func (mgr *tcpResolverConnMgr) shutdown() { func (mgr *tcpResolverConnMgr) shutdown() {
// reply to all waiting queries // reply to all waiting queries
mgr.tr.Lock() mgr.tr.Lock()
defer mgr.tr.Unlock()
mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
for id, inFlight := range mgr.tr.inFlightQueries { for id, inFlight := range mgr.tr.inFlightQueries {
close(inFlight.Response) close(inFlight.Response)
delete(mgr.tr.inFlightQueries, id) delete(mgr.tr.inFlightQueries, id)
} }
mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
mgr.tr.Unlock()
// hint network environment at failed connection // hint network environment at failed connection
if mgr.failCnt >= FailThreshold { if mgr.failCnt >= FailThreshold {

View file

@ -28,9 +28,6 @@ var (
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
activeResolvers map[string]*Resolver // lookup map of all resolvers activeResolvers map[string]*Resolver // lookup map of all resolvers
resolversLock sync.RWMutex resolversLock sync.RWMutex
dupReqMap = make(map[string]*sync.WaitGroup)
dupReqLock sync.Mutex
) )
func indexOfScope(domain string, list []*Scope) int { func indexOfScope(domain string, list []*Scope) int {