Listen for context end better in resolver

This commit is contained in:
Daniel 2021-09-17 22:04:41 +02:00
parent 8b489f4c26
commit a6ce021dbd
3 changed files with 27 additions and 4 deletions

View file

@ -32,6 +32,10 @@ var (
ErrFailure = errors.New("query failed") ErrFailure = errors.New("query failed")
// ErrContinue is returned when the resolver has no answer, and the next resolver should be asked // ErrContinue is returned when the resolver has no answer, and the next resolver should be asked
ErrContinue = errors.New("resolver has no answer") ErrContinue = errors.New("resolver has no answer")
// ErrCancelled is returned when the request was cancelled.
ErrCancelled = errors.New("request cancelled")
// ErrShuttingDown is returned when the resolver is shutting down.
ErrShuttingDown = errors.New("resolver is shutting down")
// detailed errors // detailed errors
@ -275,6 +279,8 @@ retry:
case <-time.After(maxRequestTimeout): case <-time.After(maxRequestTimeout):
// something went wrong with the query, retry // something went wrong with the query, retry
goto retry goto retry
case <-ctx.Done():
return nil
} }
} else { } else {
// but that someone is taking too long // but that someone is taking too long
@ -331,7 +337,7 @@ resolveLoop:
for i = 0; i < 2; i++ { for i = 0; i < 2; i++ {
for _, resolver := range resolvers { for _, resolver := range resolvers {
if module.IsStopping() { if module.IsStopping() {
return nil, errors.New("shutting down") return nil, ErrShuttingDown
} }
// check if resolver failed recently (on first run) // check if resolver failed recently (on first run)
@ -364,6 +370,10 @@ resolveLoop:
resolver.Conn.ReportFailure() resolver.Conn.ReportFailure()
log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.Info.ID()) log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.Info.ID())
continue continue
case errors.Is(err, ErrCancelled):
return nil, err
case errors.Is(err, ErrShuttingDown):
return nil, err
default: default:
resolver.Conn.ReportFailure() resolver.Conn.ReportFailure()
log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.Info.ID(), err) log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.Info.ID(), err)

View file

@ -418,6 +418,8 @@ func queryMulticastDNS(ctx context.Context, q *Query) (*RRCache, error) {
if err != nil { if err != nil {
return rrCache, nil return rrCache, nil
} }
case <-ctx.Done():
return nil, ErrCancelled
} }
// Respond with NXDomain. // Respond with NXDomain.

View file

@ -105,7 +105,7 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
return tr return tr
} }
func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) { func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolverConn, error) {
tr.Lock() tr.Lock()
defer tr.Unlock() defer tr.Unlock()
@ -117,6 +117,10 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
return tr.resolverConn, nil return tr.resolverConn, nil
case <-time.After(heartbeatTimeout): case <-time.After(heartbeatTimeout):
log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName()) log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
case <-ctx.Done():
return nil, ErrCancelled
case <-module.Stopping():
return nil, ErrShuttingDown
} }
} }
@ -130,7 +134,6 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
} }
// Connect to server. // Connect to server.
var err error
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
if err != nil { if err != nil {
log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName()) log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName())
@ -171,7 +174,7 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
// Query executes the given query against the resolver. // Query executes the given query against the resolver.
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// Get resolver connection. // Get resolver connection.
resolverConn, err := tr.getOrCreateResolverConn() resolverConn, err := tr.getOrCreateResolverConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -185,6 +188,10 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// Submit query request to live connection. // Submit query request to live connection.
select { select {
case resolverConn.queries <- tq: case resolverConn.queries <- tq:
case <-ctx.Done():
return nil, ErrCancelled
case <-module.Stopping():
return nil, ErrShuttingDown
case <-time.After(defaultRequestTimeout): case <-time.After(defaultRequestTimeout):
return nil, ErrTimeout return nil, ErrTimeout
} }
@ -193,6 +200,10 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
var reply *dns.Msg var reply *dns.Msg
select { select {
case reply = <-tq.Response: case reply = <-tq.Response:
case <-ctx.Done():
return nil, ErrCancelled
case <-module.Stopping():
return nil, ErrShuttingDown
case <-time.After(defaultRequestTimeout): case <-time.After(defaultRequestTimeout):
return nil, ErrTimeout return nil, ErrTimeout
} }