diff --git a/resolver/resolve.go b/resolver/resolve.go
index 316edd27..f4b05c15 100644
--- a/resolver/resolve.go
+++ b/resolver/resolve.go
@@ -51,6 +51,17 @@ const (
 	maxTTL = 24 * 60 * 60 // 24 hours
 )
 
+var (
+	dupReqMap  = make(map[string]*dedupeStatus)
+	dupReqLock sync.Mutex
+)
+
+type dedupeStatus struct {
+	completed  chan struct{}
+	waitUntil  time.Time
+	superseded bool
+}
+
 // BlockedUpstreamError is returned when a DNS request
 // has been blocked by the upstream server.
 type BlockedUpstreamError struct {
@@ -195,7 +206,10 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
 		rrCache.requestingNew = true
 		rrCache.Unlock()
 
-		log.Tracer(ctx).Trace("resolver: serving from cache, requesting new")
+		log.Tracer(ctx).Tracef(
+			"resolver: using expired RR from cache (since %s), refreshing async now",
+			time.Since(time.Unix(rrCache.TTL, 0)),
+		)
 
 		// resolve async
 		module.StartWorker("resolve async", func(ctx context.Context) error {
@@ -205,9 +219,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
 			}
 			return nil
 		})
+
+		return rrCache
 	}
 
-	log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
+	log.Tracer(ctx).Tracef(
+		"resolver: using cached RR (expires in %s)",
+		time.Until(time.Unix(rrCache.TTL, 0)),
+	)
 	return rrCache
 }
 
@@ -215,31 +234,46 @@ func deduplicateRequest(ctx context.Context, q *Query) (finishRequest func()) {
 	// create identifier key
 	dupKey := q.ID()
 
+	// restart here if waiting timed out
+retry:
+
 	dupReqLock.Lock()
 
-	// get duplicate request waitgroup
-	wg, requestActive := dupReqMap[dupKey]
+	// get duplicate request waitgroup
+	status, requestActive := dupReqMap[dupKey]
 
-	// someone else is already on it!
+	// check if the request is active
 	if requestActive {
-		dupReqLock.Unlock()
+		// someone else is already on it!
+		if time.Now().Before(status.waitUntil) {
+			dupReqLock.Unlock()
 
-		// log that we are waiting
-		log.Tracer(ctx).Tracef("resolver: waiting for duplicate query for %s to complete", dupKey)
-		// wait
-		wg.Wait()
-		// done!
-		return nil
+			// log that we are waiting
+			log.Tracer(ctx).Tracef("resolver: waiting for duplicate query for %s to complete", dupKey)
+			// wait
+			select {
+			case <-status.completed:
+				// done!
+				return nil
+			case <-time.After(maxRequestTimeout):
+				// something went wrong with the query, retry
+				goto retry
+			}
+		} else {
+			// but that someone is taking too long
+			status.superseded = true
+		}
 	}
 
 	// we are currently the only one doing a request for this
 
-	// create new waitgroup
-	wg = new(sync.WaitGroup)
-	// add worker (us!)
-	wg.Add(1)
+	// create new status
+	status = &dedupeStatus{
+		completed: make(chan struct{}),
+		waitUntil: time.Now().Add(maxRequestTimeout),
+	}
 
 	// add to registry
-	dupReqMap[dupKey] = wg
+	dupReqMap[dupKey] = status
 
 	dupReqLock.Unlock()
@@ -248,9 +282,11 @@ func deduplicateRequest(ctx context.Context, q *Query) (finishRequest func()) {
 		dupReqLock.Lock()
 		defer dupReqLock.Unlock()
 		// mark request as done
-		wg.Done()
+		close(status.completed)
 		// delete from registry
-		delete(dupReqMap, dupKey)
+		if !status.superseded {
+			delete(dupReqMap, dupKey)
+		}
 	}
 }
 
diff --git a/resolver/resolver-plain.go b/resolver/resolver-plain.go
index 991cf8a1..4417ebc1 100644
--- a/resolver/resolver-plain.go
+++ b/resolver/resolver-plain.go
@@ -14,6 +14,7 @@ var (
 	defaultClientTTL      = 5 * time.Minute
 	defaultRequestTimeout = 3 * time.Second // dns query
 	defaultConnectTimeout = 5 * time.Second // tcp/tls
+	maxRequestTimeout     = 5 * time.Second
 )
 
 // PlainResolver is a resolver using plain DNS.
diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go
index 5f4ca7a4..f2e1e5b7 100644
--- a/resolver/resolver-tcp.go
+++ b/resolver/resolver-tcp.go
@@ -110,7 +110,11 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
 	tr.Unlock()
 
 	// submit msg for writing
-	tr.queries <- msg
+	select {
+	case tr.queries <- msg:
+	case <-time.After(defaultRequestTimeout):
+		return nil
+	}
 
 	return inFlight
 }
@@ -140,8 +144,11 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
 func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
 	// submit to client
 	inFlight := tr.submitQuery(ctx, q)
-	var reply *dns.Msg
+	if inFlight == nil {
+		return nil, ErrTimeout
+	}
 
+	var reply *dns.Msg
 	select {
 	case reply = <-inFlight.Response:
 	case <-time.After(defaultRequestTimeout):
@@ -177,26 +184,26 @@ func (tr *TCPResolver) startClient() {
 }
 
 func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
+	mgr.tr.clientStarted.Set()
+	defer mgr.shutdown()
+
 	// connection lifecycle loop
 	for {
 		// check if we are shutting down
 		select {
 		case <-workerCtx.Done():
-			mgr.shutdown()
 			return nil
 		default:
 		}
 
 		// check if we are failing
 		if mgr.failCnt >= FailThreshold ||
 			mgr.tr.IsFailing() {
-			mgr.shutdown()
 			return nil
 		}
 
 		// wait for work before creating connection
 		proceed := mgr.waitForWork(workerCtx)
 		if !proceed {
-			mgr.shutdown()
 			return nil
 		}
@@ -213,7 +220,6 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
 		// handle queries
 		proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
 		if !proceed {
-			mgr.shutdown()
 			return nil
 		}
 	}
@@ -222,13 +228,15 @@ func (mgr *tcpResolverConnMgr) shutdown() {
 	// reply to all waiting queries
 	mgr.tr.Lock()
+	defer mgr.tr.Unlock()
+
+	mgr.tr.clientStarted.UnSet()               // in lock to guarantee to set before submitQuery proceeds
+	atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
+
 	for id, inFlight := range mgr.tr.inFlightQueries {
 		close(inFlight.Response)
 		delete(mgr.tr.inFlightQueries, id)
 	}
-	mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
-	atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
-	mgr.tr.Unlock()
 
 	// hint network environment at failed connection
 	if mgr.failCnt >= FailThreshold {
diff --git a/resolver/resolvers.go b/resolver/resolvers.go
index 5e4a1fee..5ce0b094 100644
--- a/resolver/resolvers.go
+++ b/resolver/resolvers.go
@@ -28,9 +28,6 @@ var (
 	localScopes     []*Scope             // list of scopes with a list of local resolvers that can resolve the scope
 	activeResolvers map[string]*Resolver // lookup map of all resolvers
 	resolversLock   sync.RWMutex
-
-	dupReqMap  = make(map[string]*sync.WaitGroup)
-	dupReqLock sync.Mutex
 )
 
 func indexOfScope(domain string, list []*Scope) int {
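
For reference, the deduplication scheme introduced in resolve.go above can be reduced to a small standalone sketch: a registry keyed by query ID, a completed channel that waiters block on, a waitUntil deadline after which a waiter stops waiting and takes over, and a superseded flag so the slow owner does not remove the newer registry entry. The sketch below is illustrative only; the deduplicate helper, the key string, and the main function are made up for the example, while dedupeStatus, dupReqMap, dupReqLock, and maxRequestTimeout mirror the names in the diff.

package main

import (
	"fmt"
	"sync"
	"time"
)

const maxRequestTimeout = 5 * time.Second

// dedupeStatus mirrors the struct added in resolve.go above.
type dedupeStatus struct {
	completed  chan struct{} // closed when the owning request finishes
	waitUntil  time.Time     // duplicates stop waiting at this deadline
	superseded bool          // set when a newer request takes over the key
}

var (
	dupReqMap  = make(map[string]*dedupeStatus)
	dupReqLock sync.Mutex
)

// deduplicate returns nil if another goroutine already answered the query,
// or a finish func that the caller must invoke once resolving is done.
func deduplicate(key string) (finish func()) {
	for {
		dupReqLock.Lock()
		status, active := dupReqMap[key]

		if active && time.Now().Before(status.waitUntil) {
			// someone else is on it and still within its deadline: wait
			dupReqLock.Unlock()
			select {
			case <-status.completed:
				return nil // the other request finished, use its cached result
			case <-time.After(maxRequestTimeout):
				continue // owner is too slow, retry and possibly take over
			}
		}

		if active {
			// the previous owner overran its deadline; mark it superseded so
			// it will not delete our fresh registry entry when it finishes
			status.superseded = true
		}

		// we own this key now
		status = &dedupeStatus{
			completed: make(chan struct{}),
			waitUntil: time.Now().Add(maxRequestTimeout),
		}
		dupReqMap[key] = status
		dupReqLock.Unlock()

		return func() {
			dupReqLock.Lock()
			defer dupReqLock.Unlock()
			close(status.completed)
			if !status.superseded {
				delete(dupReqMap, key)
			}
		}
	}
}

func main() {
	if finish := deduplicate("example.com.A"); finish != nil {
		defer finish()
		fmt.Println("resolving example.com ourselves")
		// ... perform the actual lookup and cache the result here ...
	} else {
		fmt.Println("another goroutine resolved example.com for us")
	}
}

The superseded flag is what makes the takeover safe: when the slow owner finally calls its finish func, it still closes its own completed channel but leaves the registry entry of the goroutine that took over untouched.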
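
The resolver-tcp.go change to submitQuery applies the same deadline idea to the write path: a bare channel send blocks forever if the connection manager is no longer draining the queue, so the send is wrapped in a select against time.After. A minimal sketch of that idiom, assuming nothing beyond the standard library (trySubmit, errSubmitTimeout, and the string queue are illustrative names, not taken from the diff):

package main

import (
	"errors"
	"fmt"
	"time"
)

var errSubmitTimeout = errors.New("submit timed out: queue not drained")

// trySubmit enqueues msg, but gives up after timeout so that a dead or
// stalled reader cannot block the caller indefinitely.
func trySubmit(queue chan<- string, msg string, timeout time.Duration) error {
	select {
	case queue <- msg:
		return nil
	case <-time.After(timeout):
		return errSubmitTimeout
	}
}

func main() {
	queue := make(chan string, 1)
	queue <- "pending" // fill the buffer so the next send would block

	if err := trySubmit(queue, "www.example.com. A", 100*time.Millisecond); err != nil {
		fmt.Println("submit failed:", err) // expected here, nobody is reading
	}
}

Signalling the failed submit to the caller (an error here, a nil InFlightQuery in the diff) is what lets Query return ErrTimeout immediately instead of hanging on a dead connection.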