mirror of
https://github.com/safing/portmaster
synced 2025-09-14 16:59:40 +00:00
Listen for context end better in resolver
This commit is contained in:
parent
8b489f4c26
commit
a6ce021dbd
3 changed files with 27 additions and 4 deletions
|
@ -32,6 +32,10 @@ var (
|
||||||
ErrFailure = errors.New("query failed")
|
ErrFailure = errors.New("query failed")
|
||||||
// ErrContinue is returned when the resolver has no answer, and the next resolver should be asked
|
// ErrContinue is returned when the resolver has no answer, and the next resolver should be asked
|
||||||
ErrContinue = errors.New("resolver has no answer")
|
ErrContinue = errors.New("resolver has no answer")
|
||||||
|
// ErrCancelled is returned when the request was cancelled.
|
||||||
|
ErrCancelled = errors.New("request cancelled")
|
||||||
|
// ErrShuttingDown is returned when the resolver is shutting down.
|
||||||
|
ErrShuttingDown = errors.New("resolver is shutting down")
|
||||||
|
|
||||||
// detailed errors
|
// detailed errors
|
||||||
|
|
||||||
|
@ -275,6 +279,8 @@ retry:
|
||||||
case <-time.After(maxRequestTimeout):
|
case <-time.After(maxRequestTimeout):
|
||||||
// something went wrong with the query, retry
|
// something went wrong with the query, retry
|
||||||
goto retry
|
goto retry
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// but that someone is taking too long
|
// but that someone is taking too long
|
||||||
|
@ -331,7 +337,7 @@ resolveLoop:
|
||||||
for i = 0; i < 2; i++ {
|
for i = 0; i < 2; i++ {
|
||||||
for _, resolver := range resolvers {
|
for _, resolver := range resolvers {
|
||||||
if module.IsStopping() {
|
if module.IsStopping() {
|
||||||
return nil, errors.New("shutting down")
|
return nil, ErrShuttingDown
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if resolver failed recently (on first run)
|
// check if resolver failed recently (on first run)
|
||||||
|
@ -364,6 +370,10 @@ resolveLoop:
|
||||||
resolver.Conn.ReportFailure()
|
resolver.Conn.ReportFailure()
|
||||||
log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.Info.ID())
|
log.Tracer(ctx).Debugf("resolver: query to %s timed out", resolver.Info.ID())
|
||||||
continue
|
continue
|
||||||
|
case errors.Is(err, ErrCancelled):
|
||||||
|
return nil, err
|
||||||
|
case errors.Is(err, ErrShuttingDown):
|
||||||
|
return nil, err
|
||||||
default:
|
default:
|
||||||
resolver.Conn.ReportFailure()
|
resolver.Conn.ReportFailure()
|
||||||
log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.Info.ID(), err)
|
log.Tracer(ctx).Debugf("resolver: query to %s failed: %s", resolver.Info.ID(), err)
|
||||||
|
|
|
@ -418,6 +418,8 @@ func queryMulticastDNS(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return rrCache, nil
|
return rrCache, nil
|
||||||
}
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
}
|
}
|
||||||
|
|
||||||
// Respond with NXDomain.
|
// Respond with NXDomain.
|
||||||
|
|
|
@ -105,7 +105,7 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolverConn, error) {
|
||||||
tr.Lock()
|
tr.Lock()
|
||||||
defer tr.Unlock()
|
defer tr.Unlock()
|
||||||
|
|
||||||
|
@ -117,6 +117,10 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
||||||
return tr.resolverConn, nil
|
return tr.resolverConn, nil
|
||||||
case <-time.After(heartbeatTimeout):
|
case <-time.After(heartbeatTimeout):
|
||||||
log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
|
log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
case <-module.Stopping():
|
||||||
|
return nil, ErrShuttingDown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +134,6 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to server.
|
// Connect to server.
|
||||||
var err error
|
|
||||||
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName())
|
log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName())
|
||||||
|
@ -171,7 +174,7 @@ func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
||||||
// Query executes the given query against the resolver.
|
// Query executes the given query against the resolver.
|
||||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
// Get resolver connection.
|
// Get resolver connection.
|
||||||
resolverConn, err := tr.getOrCreateResolverConn()
|
resolverConn, err := tr.getOrCreateResolverConn(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -185,6 +188,10 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
// Submit query request to live connection.
|
// Submit query request to live connection.
|
||||||
select {
|
select {
|
||||||
case resolverConn.queries <- tq:
|
case resolverConn.queries <- tq:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
case <-module.Stopping():
|
||||||
|
return nil, ErrShuttingDown
|
||||||
case <-time.After(defaultRequestTimeout):
|
case <-time.After(defaultRequestTimeout):
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
@ -193,6 +200,10 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
var reply *dns.Msg
|
var reply *dns.Msg
|
||||||
select {
|
select {
|
||||||
case reply = <-tq.Response:
|
case reply = <-tq.Response:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
case <-module.Stopping():
|
||||||
|
return nil, ErrShuttingDown
|
||||||
case <-time.After(defaultRequestTimeout):
|
case <-time.After(defaultRequestTimeout):
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue