mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Add heartbeat check to TCP resolver
This commit is contained in:
parent
bd8d047428
commit
34247b1d82
1 changed files with 79 additions and 12 deletions
|
@ -3,6 +3,8 @@ package resolver
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -26,6 +28,8 @@ type TCPResolver struct {
|
|||
dnsClient *dns.Client
|
||||
|
||||
clientStarted *abool.AtomicBool
|
||||
clientHeartbeat chan struct{}
|
||||
clientCancel func()
|
||||
connInstanceID *uint32
|
||||
queries chan *dns.Msg
|
||||
inFlightQueries map[uint16]*InFlightQuery
|
||||
|
@ -68,10 +72,12 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
|||
Timeout: defaultConnectTimeout,
|
||||
WriteTimeout: tcpWriteTimeout,
|
||||
},
|
||||
connInstanceID: &instanceID,
|
||||
queries: make(chan *dns.Msg, 100),
|
||||
inFlightQueries: make(map[uint16]*InFlightQuery),
|
||||
clientStarted: abool.New(),
|
||||
clientHeartbeat: make(chan struct{}),
|
||||
clientCancel: func() {},
|
||||
connInstanceID: &instanceID,
|
||||
queries: make(chan *dns.Msg, 1000),
|
||||
inFlightQueries: make(map[uint16]*InFlightQuery),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -146,6 +152,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
|||
// submit to client
|
||||
inFlight := tr.submitQuery(ctx, q)
|
||||
if inFlight == nil {
|
||||
tr.checkClientStatus()
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
|
@ -153,6 +160,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
|||
select {
|
||||
case reply = <-inFlight.Response:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
tr.checkClientStatus()
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
|
@ -168,6 +176,21 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
|||
return inFlight.MakeCacheRecord(reply), nil
|
||||
}
|
||||
|
||||
func (tr *TCPResolver) checkClientStatus() {
|
||||
// Get client cancel function before waiting in order to not immediately
|
||||
// cancel a new client.
|
||||
tr.Lock()
|
||||
cancelClient := tr.clientCancel
|
||||
tr.Unlock()
|
||||
|
||||
// Check if the client is alive with the heartbeat, if not shut it down.
|
||||
select {
|
||||
case tr.clientHeartbeat <- struct{}{}:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
cancelClient()
|
||||
}
|
||||
}
|
||||
|
||||
type tcpResolverConnMgr struct {
|
||||
tr *TCPResolver
|
||||
responses chan *dns.Msg
|
||||
|
@ -185,8 +208,14 @@ func (tr *TCPResolver) startClient() {
|
|||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||
mgr.tr.clientStarted.Set()
|
||||
defer mgr.shutdown()
|
||||
mgr.tr.clientStarted.Set()
|
||||
|
||||
// Create additional cancel function for this worker.
|
||||
workerCtx, cancelWorker := context.WithCancel(workerCtx)
|
||||
mgr.tr.Lock()
|
||||
mgr.tr.clientCancel = cancelWorker
|
||||
mgr.tr.Unlock()
|
||||
|
||||
// connection lifecycle loop
|
||||
for {
|
||||
|
@ -314,10 +343,21 @@ func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
|
|||
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress)
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
||||
connCtx, cancelConnCtx = context.WithCancel(context.Background())
|
||||
connClosing = abool.New()
|
||||
|
||||
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr())
|
||||
// Get amount of in waiting queries.
|
||||
mgr.tr.Lock()
|
||||
waitingQueries := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
|
||||
// Log that a connection to the resolver was established.
|
||||
log.Debugf(
|
||||
"resolver: connected to %s (%s) with %d queries waiting",
|
||||
mgr.tr.resolver.GetName(),
|
||||
conn.RemoteAddr(),
|
||||
waitingQueries,
|
||||
)
|
||||
|
||||
// start reader
|
||||
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
|
||||
|
@ -349,6 +389,9 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-mgr.tr.clientHeartbeat:
|
||||
// respond to alive checks
|
||||
|
||||
case <-workerCtx.Done():
|
||||
// module shutdown
|
||||
return false
|
||||
|
@ -373,9 +416,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
|
|||
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
||||
err := conn.WriteMsg(msg)
|
||||
if err != nil {
|
||||
if connClosing.SetToIf(false, true) {
|
||||
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
|
||||
}
|
||||
mgr.logConnectionError(err, conn, connClosing)
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -456,11 +497,37 @@ func (mgr *tcpResolverConnMgr) msgReader(
|
|||
for {
|
||||
msg, err := conn.ReadMsg()
|
||||
if err != nil {
|
||||
if connClosing.SetToIf(false, true) {
|
||||
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
|
||||
}
|
||||
mgr.logConnectionError(err, conn, connClosing)
|
||||
return nil
|
||||
}
|
||||
mgr.responses <- msg
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool) {
|
||||
// Check if we are the first to see an error.
|
||||
if connClosing.SetToIf(false, true) {
|
||||
// Get amount of in flight queries.
|
||||
mgr.tr.Lock()
|
||||
inFlightQueries := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
|
||||
// Log error.
|
||||
if errors.Is(err, io.EOF) {
|
||||
log.Debugf(
|
||||
"resolver: connection to %s (%s) was closed with %d in-flight queries",
|
||||
mgr.tr.resolver.GetName(),
|
||||
conn.RemoteAddr(),
|
||||
inFlightQueries,
|
||||
)
|
||||
} else {
|
||||
log.Warningf(
|
||||
"resolver: write error to %s (%s) with %d in-flight queries: %s",
|
||||
mgr.tr.resolver.GetName(),
|
||||
conn.RemoteAddr(),
|
||||
inFlightQueries,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue