mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Simplify TCP resolver
This commit is contained in:
parent
27d41d51d6
commit
9624995c6e
2 changed files with 275 additions and 405 deletions
|
@ -4,9 +4,9 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -16,68 +16,79 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
tcpWriteTimeout = 2 * time.Second
|
||||
ignoreQueriesAfter = 10 * time.Minute
|
||||
heartbeatTimeout = 15 * time.Second
|
||||
tcpConnectionEstablishmentTimeout = 3 * time.Second
|
||||
tcpWriteTimeout = 2 * time.Second
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
ignoreQueriesAfter = 5 * time.Minute
|
||||
)
|
||||
|
||||
// TCPResolver is a resolver using just a single tcp connection with pipelining.
|
||||
type TCPResolver struct {
|
||||
BasicResolverConn
|
||||
|
||||
clientTTL time.Duration
|
||||
// dnsClient holds the connection configuration of the resolver.
|
||||
dnsClient *dns.Client
|
||||
|
||||
clientStarted *abool.AtomicBool
|
||||
clientHeartbeat chan struct{}
|
||||
stopClient func()
|
||||
connInstanceID *uint32
|
||||
queries chan *dns.Msg
|
||||
inFlightQueries map[uint16]*InFlightQuery
|
||||
// resolverConn holds a connection to the DNS server, including query management.
|
||||
resolverConn *tcpResolverConn
|
||||
// resolverConnInstanceID holds the current ID of the resolverConn.
|
||||
resolverConnInstanceID int
|
||||
}
|
||||
|
||||
// InFlightQuery represents an in flight query of a TCPResolver.
|
||||
type InFlightQuery struct {
|
||||
Query *Query
|
||||
Msg *dns.Msg
|
||||
Response chan *dns.Msg
|
||||
Resolver *Resolver
|
||||
Started time.Time
|
||||
ConnInstanceID uint32
|
||||
// tcpResolverConn represents a single connection to an upstream DNS server.
|
||||
type tcpResolverConn struct {
|
||||
// ctx is the context of the tcpResolverConn.
|
||||
ctx context.Context
|
||||
// cancelCtx cancels ctx.
|
||||
cancelCtx func()
|
||||
// id is the ID assigned to the resolver conn.
|
||||
id int
|
||||
// conn is the connection to the DNS server.
|
||||
conn *dns.Conn
|
||||
// resolverInfo holds information about the resolver to enhance error messages.
|
||||
resolverInfo *ResolverInfo
|
||||
// queries is used to submit queries to be sent to the connected DNS server.
|
||||
queries chan *tcpQuery
|
||||
// responses is used to hand the responses from the reader to the handler.
|
||||
responses chan *dns.Msg
|
||||
// inFlightQueries holds all in-flight queries of this connection.
|
||||
inFlightQueries map[uint16]*tcpQuery
|
||||
// heartbeat is an alive-checking channel from which the resolver conn must
|
||||
// always read asap.
|
||||
heartbeat chan struct{}
|
||||
// abandoned signifies if the resolver conn has been abandoned.
|
||||
abandoned *abool.AtomicBool
|
||||
}
|
||||
|
||||
// MakeCacheRecord creates an RCache record from a reply.
|
||||
func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache {
|
||||
// tcpQuery holds the query information for a tcpResolverConn.
|
||||
type tcpQuery struct {
|
||||
Query *Query
|
||||
Response chan *dns.Msg
|
||||
}
|
||||
|
||||
// MakeCacheRecord creates an RRCache record from a reply.
|
||||
func (tq *tcpQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo) *RRCache {
|
||||
return &RRCache{
|
||||
Domain: ifq.Query.FQDN,
|
||||
Question: ifq.Query.QType,
|
||||
Domain: tq.Query.FQDN,
|
||||
Question: tq.Query.QType,
|
||||
RCode: reply.Rcode,
|
||||
Answer: reply.Answer,
|
||||
Ns: reply.Ns,
|
||||
Extra: reply.Extra,
|
||||
Resolver: ifq.Resolver.Info.Copy(),
|
||||
Resolver: resolverInfo.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTCPResolver returns a new TCPResolver.
|
||||
func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
||||
var instanceID uint32
|
||||
newResolver := &TCPResolver{
|
||||
BasicResolverConn: BasicResolverConn{
|
||||
resolver: resolver,
|
||||
},
|
||||
clientTTL: defaultClientTTL,
|
||||
dnsClient: &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: defaultConnectTimeout,
|
||||
WriteTimeout: tcpWriteTimeout,
|
||||
},
|
||||
clientStarted: abool.New(),
|
||||
clientHeartbeat: make(chan struct{}),
|
||||
stopClient: func() {},
|
||||
connInstanceID: &instanceID,
|
||||
queries: make(chan *dns.Msg, 1000),
|
||||
inFlightQueries: make(map[uint16]*InFlightQuery),
|
||||
}
|
||||
newResolver.BasicResolverConn.init()
|
||||
return newResolver
|
||||
|
@ -94,45 +105,214 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
|
|||
return tr
|
||||
}
|
||||
|
||||
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
||||
// make sure client is started
|
||||
tr.startClient()
|
||||
|
||||
// create msg
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(q.FQDN, uint16(q.QType))
|
||||
|
||||
// save to waitlist
|
||||
inFlight := &InFlightQuery{
|
||||
Query: q,
|
||||
Msg: msg,
|
||||
Response: make(chan *dns.Msg),
|
||||
Resolver: tr.resolver,
|
||||
Started: time.Now().UTC(),
|
||||
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
|
||||
}
|
||||
func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
||||
tr.Lock()
|
||||
// check for existing query
|
||||
tr.ensureUniqueID(msg)
|
||||
// add query to in flight registry
|
||||
tr.inFlightQueries[msg.Id] = inFlight
|
||||
tr.Unlock()
|
||||
defer tr.Unlock()
|
||||
|
||||
// submit msg for writing
|
||||
select {
|
||||
case tr.queries <- msg:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
return nil
|
||||
// Check if we have a resolver.
|
||||
if tr.resolverConn != nil && tr.resolverConn.abandoned.IsNotSet() {
|
||||
// If there is one, check if it's alive!
|
||||
select {
|
||||
case tr.resolverConn.heartbeat <- struct{}{}:
|
||||
return tr.resolverConn, nil
|
||||
case <-time.After(heartbeatTimeout):
|
||||
log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
|
||||
}
|
||||
}
|
||||
|
||||
return inFlight
|
||||
// Create a new connection if no active one is available.
|
||||
|
||||
// Refresh the dialer in order to set an authenticated local address.
|
||||
tr.dnsClient.Dialer = &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
Timeout: tcpConnectionEstablishmentTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
}
|
||||
|
||||
// Connect to server.
|
||||
var err error
|
||||
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
||||
if err != nil {
|
||||
log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName())
|
||||
return nil, fmt.Errorf("%w: failed to connect to %s: %s", ErrFailure, tr.resolver.Info.DescriptiveName(), err)
|
||||
}
|
||||
|
||||
// Log that a connection to the resolver was established.
|
||||
log.Debugf(
|
||||
"resolver: connected to %s",
|
||||
tr.resolver.Info.DescriptiveName(),
|
||||
)
|
||||
|
||||
// Create resolver connection.
|
||||
tr.resolverConnInstanceID++
|
||||
resolverConn := &tcpResolverConn{
|
||||
id: tr.resolverConnInstanceID,
|
||||
conn: conn,
|
||||
resolverInfo: tr.resolver.Info,
|
||||
queries: make(chan *tcpQuery, 10),
|
||||
responses: make(chan *dns.Msg, 10),
|
||||
inFlightQueries: make(map[uint16]*tcpQuery, 10),
|
||||
heartbeat: make(chan struct{}),
|
||||
abandoned: abool.New(),
|
||||
}
|
||||
|
||||
// Start worker.
|
||||
module.StartWorker("dns client", resolverConn.handler)
|
||||
|
||||
// Set resolver conn for reuse.
|
||||
tr.resolverConn = resolverConn
|
||||
|
||||
// Hint network environment at successful connection.
|
||||
netenv.ReportSuccessfulConnection()
|
||||
|
||||
return resolverConn, nil
|
||||
}
|
||||
|
||||
// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked.
|
||||
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
||||
// Query executes the given query against the resolver.
|
||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||
// Get resolver connection.
|
||||
resolverConn, err := tr.getOrCreateResolverConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create query request.
|
||||
tq := &tcpQuery{
|
||||
Query: q,
|
||||
Response: make(chan *dns.Msg),
|
||||
}
|
||||
|
||||
// Submit query request to live connection.
|
||||
select {
|
||||
case resolverConn.queries <- tq:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
// Wait for reply.
|
||||
var reply *dns.Msg
|
||||
select {
|
||||
case reply = <-tq.Response:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
// Check if we have a reply.
|
||||
if reply == nil {
|
||||
// Resolver is shutting down. The Portmaster may be shutting down, or
|
||||
// there is a connection error.
|
||||
return nil, ErrFailure
|
||||
}
|
||||
|
||||
// Check if the reply was blocked upstream.
|
||||
if tr.resolver.IsBlockedUpstream(reply) {
|
||||
return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()}
|
||||
}
|
||||
|
||||
// Create RRCache from reply and return it.
|
||||
return tq.MakeCacheRecord(reply, tr.resolver.Info), nil
|
||||
}
|
||||
|
||||
func (trc *tcpResolverConn) shutdown() {
|
||||
// Set abandoned status and close connection to the DNS server.
|
||||
if trc.abandoned.SetToIf(false, true) {
|
||||
_ = trc.conn.Close()
|
||||
}
|
||||
|
||||
// Close all response channels for in-flight queries.
|
||||
for _, tq := range trc.inFlightQueries {
|
||||
close(tq.Response)
|
||||
}
|
||||
|
||||
// Respond to any incoming queries for some time in order to not leave them
|
||||
// hanging longer than necessary.
|
||||
for {
|
||||
select {
|
||||
case tq := <-trc.queries:
|
||||
close(tq.Response)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (trc *tcpResolverConn) handler(workerCtx context.Context) error {
|
||||
// Set up context and cleanup.
|
||||
trc.ctx, trc.cancelCtx = context.WithCancel(workerCtx)
|
||||
defer trc.shutdown()
|
||||
|
||||
// Set up variables.
|
||||
var readyToRecycle bool
|
||||
ttlTimer := time.After(defaultClientTTL)
|
||||
|
||||
// Start connection reader.
|
||||
module.StartWorker("dns client reader", trc.reader)
|
||||
|
||||
// Handle requests.
|
||||
for {
|
||||
select {
|
||||
case <-trc.heartbeat:
|
||||
// Respond to alive checks.
|
||||
|
||||
case <-trc.ctx.Done():
|
||||
// Respond to module shutdown or conn error.
|
||||
return nil
|
||||
|
||||
case <-ttlTimer:
|
||||
// Recycle the connection after the TTL is reached.
|
||||
readyToRecycle = true
|
||||
// Send dummy response to trigger the check.
|
||||
select {
|
||||
case trc.responses <- nil:
|
||||
default:
|
||||
// The response queue is full.
|
||||
// The check will be triggered by another response.
|
||||
}
|
||||
|
||||
case tq := <-trc.queries:
|
||||
// Handle DNS query request.
|
||||
|
||||
// Create dns request message.
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(tq.Query.FQDN, uint16(tq.Query.QType))
|
||||
|
||||
// Assign a unique message ID.
|
||||
trc.assignUniqueID(msg)
|
||||
|
||||
// Add query to in flight registry.
|
||||
trc.inFlightQueries[msg.Id] = tq
|
||||
|
||||
// Write query to connected DNS server.
|
||||
_ = trc.conn.SetWriteDeadline(time.Now().Add(tcpWriteTimeout))
|
||||
err := trc.conn.WriteMsg(msg)
|
||||
if err != nil {
|
||||
trc.logConnectionError(err, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
case msg := <-trc.responses:
|
||||
if msg != nil {
|
||||
trc.handleQueryResponse(msg)
|
||||
}
|
||||
|
||||
// If we are ready to recycle and we have no in-flight queries, we can
|
||||
// shutdown the connection and create a new one for the next query.
|
||||
if readyToRecycle {
|
||||
if len(trc.inFlightQueries) == 0 {
|
||||
log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assignUniqueID makes sure that ID assigned to msg is unique.
|
||||
func (trc *tcpResolverConn) assignUniqueID(msg *dns.Msg) {
|
||||
// try a random ID 10000 times
|
||||
for i := 0; i < 10000; i++ { // don't try forever
|
||||
_, exists := tr.inFlightQueries[msg.Id]
|
||||
_, exists := trc.inFlightQueries[msg.Id]
|
||||
if !exists {
|
||||
return // we are unique, yay!
|
||||
}
|
||||
|
@ -141,7 +321,7 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
|||
// go through the complete space
|
||||
var id uint16
|
||||
for ; id <= (1<<16)-1; id++ { // don't try forever
|
||||
_, exists := tr.inFlightQueries[id]
|
||||
_, exists := trc.inFlightQueries[id]
|
||||
if !exists {
|
||||
msg.Id = id
|
||||
return // we are unique, yay!
|
||||
|
@ -149,390 +329,80 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
// Query executes the given query against the resolver.
|
||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||
// submit to client
|
||||
inFlight := tr.submitQuery(ctx, q)
|
||||
if inFlight == nil {
|
||||
tr.checkClientStatus()
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
var reply *dns.Msg
|
||||
select {
|
||||
case reply = <-inFlight.Response:
|
||||
case <-time.After(defaultRequestTimeout):
|
||||
tr.checkClientStatus()
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
if reply == nil {
|
||||
// Resolver is shutting down, could be server failure or we are offline
|
||||
return nil, ErrFailure
|
||||
}
|
||||
|
||||
if tr.resolver.IsBlockedUpstream(reply) {
|
||||
return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()}
|
||||
}
|
||||
|
||||
return inFlight.MakeCacheRecord(reply), nil
|
||||
}
|
||||
|
||||
func (tr *TCPResolver) checkClientStatus() {
|
||||
// Get client cancel function before waiting in order to not immediately
|
||||
// cancel a new client.
|
||||
tr.Lock()
|
||||
stopClient := tr.stopClient
|
||||
tr.Unlock()
|
||||
|
||||
// Check if the client is alive with the heartbeat, if not shut it down.
|
||||
select {
|
||||
case tr.clientHeartbeat <- struct{}{}:
|
||||
case <-time.After(heartbeatTimeout):
|
||||
log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.Info.DescriptiveName())
|
||||
stopClient()
|
||||
}
|
||||
}
|
||||
|
||||
type tcpResolverConnMgr struct {
|
||||
tr *TCPResolver
|
||||
responses chan *dns.Msg
|
||||
failCnt int
|
||||
}
|
||||
|
||||
func (tr *TCPResolver) startClient() {
|
||||
if tr.clientStarted.SetToIf(false, true) {
|
||||
mgr := &tcpResolverConnMgr{
|
||||
tr: tr,
|
||||
responses: make(chan *dns.Msg, 100),
|
||||
}
|
||||
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||
defer mgr.shutdown()
|
||||
mgr.tr.clientStarted.Set()
|
||||
|
||||
// Create additional cancel function for this worker.
|
||||
clientCtx, stopClient := context.WithCancel(workerCtx)
|
||||
mgr.tr.Lock()
|
||||
mgr.tr.stopClient = stopClient
|
||||
mgr.tr.Unlock()
|
||||
|
||||
// connection lifecycle loop
|
||||
for {
|
||||
// check if we are shutting down
|
||||
select {
|
||||
case <-clientCtx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// check if we are failing
|
||||
if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wait for work before creating connection
|
||||
proceed := mgr.waitForWork(clientCtx)
|
||||
if !proceed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// create connection
|
||||
conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection()
|
||||
if conn == nil {
|
||||
mgr.failCnt++
|
||||
continue
|
||||
}
|
||||
|
||||
// hint network environment at successful connection
|
||||
netenv.ReportSuccessfulConnection()
|
||||
|
||||
// handle queries
|
||||
proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx)
|
||||
if !proceed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) shutdown() {
|
||||
// reply to all waiting queries
|
||||
mgr.tr.Lock()
|
||||
defer mgr.tr.Unlock()
|
||||
|
||||
mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
|
||||
atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
|
||||
|
||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||
close(inFlight.Response)
|
||||
delete(mgr.tr.inFlightQueries, id)
|
||||
}
|
||||
|
||||
// hint network environment at failed connection
|
||||
if mgr.failCnt >= FailThreshold {
|
||||
netenv.ReportFailedConnection()
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) {
|
||||
// wait until there is something to do
|
||||
mgr.tr.Lock()
|
||||
waiting := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
if waiting > 0 {
|
||||
// queue abandoned queries
|
||||
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
||||
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
|
||||
mgr.tr.Lock()
|
||||
defer mgr.tr.Unlock()
|
||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||
if inFlight.Started.Before(ignoreBefore) {
|
||||
// remove old queries
|
||||
close(inFlight.Response)
|
||||
delete(mgr.tr.inFlightQueries, id)
|
||||
} else if inFlight.ConnInstanceID != currentConnInstanceID {
|
||||
inFlight.ConnInstanceID = currentConnInstanceID
|
||||
// re-inject queries that died with a previously failed connection
|
||||
select {
|
||||
case mgr.tr.queries <- inFlight.Msg:
|
||||
default:
|
||||
log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Info.DescriptiveName())
|
||||
}
|
||||
}
|
||||
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// wait for first query
|
||||
select {
|
||||
case <-clientCtx.Done():
|
||||
return false
|
||||
case msg := <-mgr.tr.queries:
|
||||
// re-insert query, we will handle it later
|
||||
module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
|
||||
select {
|
||||
case mgr.tr.queries <- msg:
|
||||
case <-time.After(2 * time.Second):
|
||||
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Info.DescriptiveName())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) establishConnection() (
|
||||
conn *dns.Conn,
|
||||
connClosing *abool.AtomicBool,
|
||||
connCtx context.Context,
|
||||
cancelConnCtx context.CancelFunc,
|
||||
) {
|
||||
// refresh dialer to set an authenticated local address
|
||||
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
|
||||
mgr.tr.dnsClient.Dialer = &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
Timeout: defaultConnectTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
}
|
||||
// connect
|
||||
var err error
|
||||
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
|
||||
if err != nil {
|
||||
log.Debugf("resolver: failed to connect to %s", mgr.tr.resolver.Info.DescriptiveName())
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
connCtx, cancelConnCtx = context.WithCancel(context.Background())
|
||||
connClosing = abool.New()
|
||||
|
||||
// Get amount of in waiting queries.
|
||||
mgr.tr.Lock()
|
||||
waitingQueries := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
|
||||
// Log that a connection to the resolver was established.
|
||||
log.Debugf(
|
||||
"resolver: connected to %s with %d queries waiting",
|
||||
mgr.tr.resolver.Info.DescriptiveName(),
|
||||
waitingQueries,
|
||||
)
|
||||
|
||||
// start reader
|
||||
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error {
|
||||
return mgr.msgReader(conn, connClosing, cancelConnCtx)
|
||||
})
|
||||
|
||||
return conn, connClosing, connCtx, cancelConnCtx
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
|
||||
clientCtx context.Context,
|
||||
conn *dns.Conn,
|
||||
connClosing *abool.AtomicBool,
|
||||
connCtx context.Context,
|
||||
cancelConnCtx context.CancelFunc,
|
||||
) (proceed bool) {
|
||||
var readyToRecycle bool
|
||||
ttlTimer := time.After(mgr.tr.clientTTL)
|
||||
|
||||
// clean up connection
|
||||
defer func() {
|
||||
connClosing.Set() // silence connection errors
|
||||
cancelConnCtx()
|
||||
_ = conn.Close()
|
||||
|
||||
// increase instance counter
|
||||
atomic.AddUint32(mgr.tr.connInstanceID, 1)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-mgr.tr.clientHeartbeat:
|
||||
// respond to alive checks
|
||||
|
||||
case <-clientCtx.Done():
|
||||
// module shutdown
|
||||
return false
|
||||
|
||||
case <-connCtx.Done():
|
||||
// connection error
|
||||
return true
|
||||
|
||||
case <-ttlTimer:
|
||||
// connection TTL reached, rebuild connection
|
||||
// but handle all in flight queries first
|
||||
readyToRecycle = true
|
||||
// trigger check
|
||||
select {
|
||||
case mgr.responses <- nil:
|
||||
default:
|
||||
// queue is full, check will be triggered anyway
|
||||
}
|
||||
|
||||
case msg := <-mgr.tr.queries:
|
||||
// write query
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
||||
err := conn.WriteMsg(msg)
|
||||
if err != nil {
|
||||
mgr.logConnectionError(err, conn, connClosing, false)
|
||||
return true
|
||||
}
|
||||
|
||||
case msg := <-mgr.responses:
|
||||
if msg != nil {
|
||||
mgr.handleQueryResponse(conn, msg)
|
||||
}
|
||||
|
||||
if readyToRecycle {
|
||||
// check to see if we can recycle the connection
|
||||
mgr.tr.Lock()
|
||||
activeQueries := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
if activeQueries == 0 {
|
||||
log.Debugf("resolver: recycling conn to %s", mgr.tr.resolver.Info.DescriptiveName())
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
|
||||
// handle query from resolver
|
||||
mgr.tr.Lock()
|
||||
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
|
||||
func (trc *tcpResolverConn) handleQueryResponse(msg *dns.Msg) {
|
||||
// Get the in-flight query from the registry.
|
||||
tq, ok := trc.inFlightQueries[msg.Id]
|
||||
if ok {
|
||||
delete(mgr.tr.inFlightQueries, msg.Id)
|
||||
}
|
||||
mgr.tr.Unlock()
|
||||
|
||||
if !ok {
|
||||
delete(trc.inFlightQueries, msg.Id)
|
||||
} else {
|
||||
log.Debugf(
|
||||
"resolver: received possibly unsolicited reply from %s: txid=%d q=%+v",
|
||||
mgr.tr.resolver.Info.DescriptiveName(),
|
||||
trc.resolverInfo.DescriptiveName(),
|
||||
msg.Id,
|
||||
msg.Question,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Send response to waiting query handler.
|
||||
select {
|
||||
case inFlight.Response <- msg:
|
||||
mgr.failCnt = 0 // reset fail counter
|
||||
// responded!
|
||||
case tq.Response <- msg:
|
||||
return
|
||||
default:
|
||||
// no one is listening for that response.
|
||||
// No one is listening for that response.
|
||||
}
|
||||
|
||||
// if caching is disabled we're done
|
||||
if inFlight.Query.NoCaching {
|
||||
// If caching is disabled for this query, we are done.
|
||||
if tq.Query.NoCaching {
|
||||
return
|
||||
}
|
||||
|
||||
// persist to database
|
||||
rrCache := inFlight.MakeCacheRecord(msg)
|
||||
// Otherwise, we can persist the answer in case the request is repeated.
|
||||
rrCache := tq.MakeCacheRecord(msg, trc.resolverInfo)
|
||||
rrCache.Clean(minTTL)
|
||||
err := rrCache.Save()
|
||||
if err != nil {
|
||||
log.Warningf(
|
||||
"resolver: failed to cache RR for %s%s: %s",
|
||||
inFlight.Query.FQDN,
|
||||
inFlight.Query.QType.String(),
|
||||
"resolver: failed to cache RR for %s: %s",
|
||||
tq.Query.ID(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) msgReader(
|
||||
conn *dns.Conn,
|
||||
connClosing *abool.AtomicBool,
|
||||
cancelConnCtx context.CancelFunc,
|
||||
) error {
|
||||
defer cancelConnCtx()
|
||||
func (trc *tcpResolverConn) reader(workerCtx context.Context) error {
|
||||
defer trc.cancelCtx()
|
||||
|
||||
for {
|
||||
msg, err := conn.ReadMsg()
|
||||
msg, err := trc.conn.ReadMsg()
|
||||
if err != nil {
|
||||
mgr.logConnectionError(err, conn, connClosing, true)
|
||||
trc.logConnectionError(err, true)
|
||||
return nil
|
||||
}
|
||||
mgr.responses <- msg
|
||||
trc.responses <- msg
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool, reading bool) {
|
||||
func (trc *tcpResolverConn) logConnectionError(err error, reading bool) {
|
||||
// Check if we are the first to see an error.
|
||||
if connClosing.SetToIf(false, true) {
|
||||
// Get amount of in flight queries.
|
||||
mgr.tr.Lock()
|
||||
inFlightQueries := len(mgr.tr.inFlightQueries)
|
||||
mgr.tr.Unlock()
|
||||
|
||||
if trc.abandoned.SetToIf(false, true) {
|
||||
// Log error.
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
log.Debugf(
|
||||
"resolver: connection to %s was closed with %d in-flight queries",
|
||||
mgr.tr.resolver.Info.DescriptiveName(),
|
||||
inFlightQueries,
|
||||
"resolver: connection to %s was closed",
|
||||
trc.resolverInfo.DescriptiveName(),
|
||||
)
|
||||
case reading:
|
||||
log.Warningf(
|
||||
"resolver: read error from %s with %d in-flight queries: %s",
|
||||
mgr.tr.resolver.Info.DescriptiveName(),
|
||||
inFlightQueries,
|
||||
"resolver: read error from %s: %s",
|
||||
trc.resolverInfo.DescriptiveName(),
|
||||
err,
|
||||
)
|
||||
default:
|
||||
log.Warningf(
|
||||
"resolver: write error to %s with %d in-flight queries: %s",
|
||||
mgr.tr.resolver.Info.DescriptiveName(),
|
||||
inFlightQueries,
|
||||
"resolver: write error to %s: %s",
|
||||
trc.resolverInfo.DescriptiveName(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
|
|||
|
||||
blockType := query.Get("blockedif")
|
||||
if blockType == "" {
|
||||
blockType = BlockDetectionRefused
|
||||
blockType = BlockDetectionZeroIP
|
||||
}
|
||||
|
||||
switch blockType {
|
||||
|
|
Loading…
Add table
Reference in a new issue