mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Add heartbeat check to TCP resolver
This commit is contained in:
parent
bd8d047428
commit
34247b1d82
1 changed files with 79 additions and 12 deletions
|
@ -3,6 +3,8 @@ package resolver
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -26,6 +28,8 @@ type TCPResolver struct {
|
||||||
dnsClient *dns.Client
|
dnsClient *dns.Client
|
||||||
|
|
||||||
clientStarted *abool.AtomicBool
|
clientStarted *abool.AtomicBool
|
||||||
|
clientHeartbeat chan struct{}
|
||||||
|
clientCancel func()
|
||||||
connInstanceID *uint32
|
connInstanceID *uint32
|
||||||
queries chan *dns.Msg
|
queries chan *dns.Msg
|
||||||
inFlightQueries map[uint16]*InFlightQuery
|
inFlightQueries map[uint16]*InFlightQuery
|
||||||
|
@ -68,10 +72,12 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
||||||
Timeout: defaultConnectTimeout,
|
Timeout: defaultConnectTimeout,
|
||||||
WriteTimeout: tcpWriteTimeout,
|
WriteTimeout: tcpWriteTimeout,
|
||||||
},
|
},
|
||||||
connInstanceID: &instanceID,
|
|
||||||
queries: make(chan *dns.Msg, 100),
|
|
||||||
inFlightQueries: make(map[uint16]*InFlightQuery),
|
|
||||||
clientStarted: abool.New(),
|
clientStarted: abool.New(),
|
||||||
|
clientHeartbeat: make(chan struct{}),
|
||||||
|
clientCancel: func() {},
|
||||||
|
connInstanceID: &instanceID,
|
||||||
|
queries: make(chan *dns.Msg, 1000),
|
||||||
|
inFlightQueries: make(map[uint16]*InFlightQuery),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,6 +152,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
// submit to client
|
// submit to client
|
||||||
inFlight := tr.submitQuery(ctx, q)
|
inFlight := tr.submitQuery(ctx, q)
|
||||||
if inFlight == nil {
|
if inFlight == nil {
|
||||||
|
tr.checkClientStatus()
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,6 +160,7 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
select {
|
select {
|
||||||
case reply = <-inFlight.Response:
|
case reply = <-inFlight.Response:
|
||||||
case <-time.After(defaultRequestTimeout):
|
case <-time.After(defaultRequestTimeout):
|
||||||
|
tr.checkClientStatus()
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,6 +176,21 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
return inFlight.MakeCacheRecord(reply), nil
|
return inFlight.MakeCacheRecord(reply), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tr *TCPResolver) checkClientStatus() {
|
||||||
|
// Get client cancel function before waiting in order to not immediately
|
||||||
|
// cancel a new client.
|
||||||
|
tr.Lock()
|
||||||
|
cancelClient := tr.clientCancel
|
||||||
|
tr.Unlock()
|
||||||
|
|
||||||
|
// Check if the client is alive with the heartbeat, if not shut it down.
|
||||||
|
select {
|
||||||
|
case tr.clientHeartbeat <- struct{}{}:
|
||||||
|
case <-time.After(defaultRequestTimeout):
|
||||||
|
cancelClient()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type tcpResolverConnMgr struct {
|
type tcpResolverConnMgr struct {
|
||||||
tr *TCPResolver
|
tr *TCPResolver
|
||||||
responses chan *dns.Msg
|
responses chan *dns.Msg
|
||||||
|
@ -185,8 +208,14 @@ func (tr *TCPResolver) startClient() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
mgr.tr.clientStarted.Set()
|
|
||||||
defer mgr.shutdown()
|
defer mgr.shutdown()
|
||||||
|
mgr.tr.clientStarted.Set()
|
||||||
|
|
||||||
|
// Create additional cancel function for this worker.
|
||||||
|
workerCtx, cancelWorker := context.WithCancel(workerCtx)
|
||||||
|
mgr.tr.Lock()
|
||||||
|
mgr.tr.clientCancel = cancelWorker
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
// connection lifecycle loop
|
// connection lifecycle loop
|
||||||
for {
|
for {
|
||||||
|
@ -314,10 +343,21 @@ func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
|
||||||
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress)
|
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.GetName(), mgr.tr.resolver.ServerAddress)
|
||||||
return nil, nil, nil, nil
|
return nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
connCtx, cancelConnCtx = context.WithCancel(context.Background())
|
||||||
connClosing = abool.New()
|
connClosing = abool.New()
|
||||||
|
|
||||||
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.GetName(), conn.RemoteAddr())
|
// Get amount of in waiting queries.
|
||||||
|
mgr.tr.Lock()
|
||||||
|
waitingQueries := len(mgr.tr.inFlightQueries)
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
|
// Log that a connection to the resolver was established.
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: connected to %s (%s) with %d queries waiting",
|
||||||
|
mgr.tr.resolver.GetName(),
|
||||||
|
conn.RemoteAddr(),
|
||||||
|
waitingQueries,
|
||||||
|
)
|
||||||
|
|
||||||
// start reader
|
// start reader
|
||||||
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
|
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
|
||||||
|
@ -349,6 +389,9 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-mgr.tr.clientHeartbeat:
|
||||||
|
// respond to alive checks
|
||||||
|
|
||||||
case <-workerCtx.Done():
|
case <-workerCtx.Done():
|
||||||
// module shutdown
|
// module shutdown
|
||||||
return false
|
return false
|
||||||
|
@ -373,9 +416,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
|
||||||
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
||||||
err := conn.WriteMsg(msg)
|
err := conn.WriteMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if connClosing.SetToIf(false, true) {
|
mgr.logConnectionError(err, conn, connClosing)
|
||||||
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -456,11 +497,37 @@ func (mgr *tcpResolverConnMgr) msgReader(
|
||||||
for {
|
for {
|
||||||
msg, err := conn.ReadMsg()
|
msg, err := conn.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if connClosing.SetToIf(false, true) {
|
mgr.logConnectionError(err, conn, connClosing)
|
||||||
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.GetName(), conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
mgr.responses <- msg
|
mgr.responses <- msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool) {
|
||||||
|
// Check if we are the first to see an error.
|
||||||
|
if connClosing.SetToIf(false, true) {
|
||||||
|
// Get amount of in flight queries.
|
||||||
|
mgr.tr.Lock()
|
||||||
|
inFlightQueries := len(mgr.tr.inFlightQueries)
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
|
// Log error.
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: connection to %s (%s) was closed with %d in-flight queries",
|
||||||
|
mgr.tr.resolver.GetName(),
|
||||||
|
conn.RemoteAddr(),
|
||||||
|
inFlightQueries,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
log.Warningf(
|
||||||
|
"resolver: write error to %s (%s) with %d in-flight queries: %s",
|
||||||
|
mgr.tr.resolver.GetName(),
|
||||||
|
conn.RemoteAddr(),
|
||||||
|
inFlightQueries,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue