mirror of
https://github.com/safing/portmaster
synced 2025-09-02 10:39:22 +00:00
Simplify TCP resolver
This commit is contained in:
parent
27d41d51d6
commit
9624995c6e
2 changed files with 275 additions and 405 deletions
|
@ -4,9 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -16,68 +16,79 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
tcpConnectionEstablishmentTimeout = 3 * time.Second
|
||||||
tcpWriteTimeout = 2 * time.Second
|
tcpWriteTimeout = 2 * time.Second
|
||||||
ignoreQueriesAfter = 10 * time.Minute
|
heartbeatTimeout = 5 * time.Second
|
||||||
heartbeatTimeout = 15 * time.Second
|
ignoreQueriesAfter = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPResolver is a resolver using just a single tcp connection with pipelining.
|
// TCPResolver is a resolver using just a single tcp connection with pipelining.
|
||||||
type TCPResolver struct {
|
type TCPResolver struct {
|
||||||
BasicResolverConn
|
BasicResolverConn
|
||||||
|
|
||||||
clientTTL time.Duration
|
// dnsClient holds the connection configuration of the resolver.
|
||||||
dnsClient *dns.Client
|
dnsClient *dns.Client
|
||||||
|
// resolverConn holds a connection to the DNS server, including query management.
|
||||||
clientStarted *abool.AtomicBool
|
resolverConn *tcpResolverConn
|
||||||
clientHeartbeat chan struct{}
|
// resolverConnInstanceID holds the current ID of the resolverConn.
|
||||||
stopClient func()
|
resolverConnInstanceID int
|
||||||
connInstanceID *uint32
|
|
||||||
queries chan *dns.Msg
|
|
||||||
inFlightQueries map[uint16]*InFlightQuery
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// InFlightQuery represents an in flight query of a TCPResolver.
|
// tcpResolverConn represents a single connection to an upstream DNS server.
|
||||||
type InFlightQuery struct {
|
type tcpResolverConn struct {
|
||||||
|
// ctx is the context of the tcpResolverConn.
|
||||||
|
ctx context.Context
|
||||||
|
// cancelCtx cancels ctx.
|
||||||
|
cancelCtx func()
|
||||||
|
// id is the ID assigned to the resolver conn.
|
||||||
|
id int
|
||||||
|
// conn is the connection to the DNS server.
|
||||||
|
conn *dns.Conn
|
||||||
|
// resolverInfo holds information about the resolver to enhance error messages.
|
||||||
|
resolverInfo *ResolverInfo
|
||||||
|
// queries is used to submit queries to be sent to the connected DNS server.
|
||||||
|
queries chan *tcpQuery
|
||||||
|
// responses is used to hand the responses from the reader to the handler.
|
||||||
|
responses chan *dns.Msg
|
||||||
|
// inFlightQueries holds all in-flight queries of this connection.
|
||||||
|
inFlightQueries map[uint16]*tcpQuery
|
||||||
|
// heartbeat is an alive-checking channel from which the resolver conn must
|
||||||
|
// always read asap.
|
||||||
|
heartbeat chan struct{}
|
||||||
|
// abandoned signifies if the resolver conn has been abandoned.
|
||||||
|
abandoned *abool.AtomicBool
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpQuery holds the query information for a tcpResolverConn.
|
||||||
|
type tcpQuery struct {
|
||||||
Query *Query
|
Query *Query
|
||||||
Msg *dns.Msg
|
|
||||||
Response chan *dns.Msg
|
Response chan *dns.Msg
|
||||||
Resolver *Resolver
|
|
||||||
Started time.Time
|
|
||||||
ConnInstanceID uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeCacheRecord creates an RCache record from a reply.
|
// MakeCacheRecord creates an RRCache record from a reply.
|
||||||
func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache {
|
func (tq *tcpQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo) *RRCache {
|
||||||
return &RRCache{
|
return &RRCache{
|
||||||
Domain: ifq.Query.FQDN,
|
Domain: tq.Query.FQDN,
|
||||||
Question: ifq.Query.QType,
|
Question: tq.Query.QType,
|
||||||
RCode: reply.Rcode,
|
RCode: reply.Rcode,
|
||||||
Answer: reply.Answer,
|
Answer: reply.Answer,
|
||||||
Ns: reply.Ns,
|
Ns: reply.Ns,
|
||||||
Extra: reply.Extra,
|
Extra: reply.Extra,
|
||||||
Resolver: ifq.Resolver.Info.Copy(),
|
Resolver: resolverInfo.Copy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPResolver returns a new TCPResolver.
|
// NewTCPResolver returns a new TCPResolver.
|
||||||
func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
||||||
var instanceID uint32
|
|
||||||
newResolver := &TCPResolver{
|
newResolver := &TCPResolver{
|
||||||
BasicResolverConn: BasicResolverConn{
|
BasicResolverConn: BasicResolverConn{
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
},
|
},
|
||||||
clientTTL: defaultClientTTL,
|
|
||||||
dnsClient: &dns.Client{
|
dnsClient: &dns.Client{
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
Timeout: defaultConnectTimeout,
|
Timeout: defaultConnectTimeout,
|
||||||
WriteTimeout: tcpWriteTimeout,
|
WriteTimeout: tcpWriteTimeout,
|
||||||
},
|
},
|
||||||
clientStarted: abool.New(),
|
|
||||||
clientHeartbeat: make(chan struct{}),
|
|
||||||
stopClient: func() {},
|
|
||||||
connInstanceID: &instanceID,
|
|
||||||
queries: make(chan *dns.Msg, 1000),
|
|
||||||
inFlightQueries: make(map[uint16]*InFlightQuery),
|
|
||||||
}
|
}
|
||||||
newResolver.BasicResolverConn.init()
|
newResolver.BasicResolverConn.init()
|
||||||
return newResolver
|
return newResolver
|
||||||
|
@ -94,45 +105,214 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) {
|
||||||
// make sure client is started
|
|
||||||
tr.startClient()
|
|
||||||
|
|
||||||
// create msg
|
|
||||||
msg := &dns.Msg{}
|
|
||||||
msg.SetQuestion(q.FQDN, uint16(q.QType))
|
|
||||||
|
|
||||||
// save to waitlist
|
|
||||||
inFlight := &InFlightQuery{
|
|
||||||
Query: q,
|
|
||||||
Msg: msg,
|
|
||||||
Response: make(chan *dns.Msg),
|
|
||||||
Resolver: tr.resolver,
|
|
||||||
Started: time.Now().UTC(),
|
|
||||||
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
|
|
||||||
}
|
|
||||||
tr.Lock()
|
tr.Lock()
|
||||||
// check for existing query
|
defer tr.Unlock()
|
||||||
tr.ensureUniqueID(msg)
|
|
||||||
// add query to in flight registry
|
|
||||||
tr.inFlightQueries[msg.Id] = inFlight
|
|
||||||
tr.Unlock()
|
|
||||||
|
|
||||||
// submit msg for writing
|
// Check if we have a resolver.
|
||||||
|
if tr.resolverConn != nil && tr.resolverConn.abandoned.IsNotSet() {
|
||||||
|
// If there is one, check if it's alive!
|
||||||
select {
|
select {
|
||||||
case tr.queries <- msg:
|
case tr.resolverConn.heartbeat <- struct{}{}:
|
||||||
|
return tr.resolverConn, nil
|
||||||
|
case <-time.After(heartbeatTimeout):
|
||||||
|
log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new one if no active one is available.
|
||||||
|
|
||||||
|
// Refresh the dialer in order to set an authenticated local address.
|
||||||
|
tr.dnsClient.Dialer = &net.Dialer{
|
||||||
|
LocalAddr: getLocalAddr("tcp"),
|
||||||
|
Timeout: tcpConnectionEstablishmentTimeout,
|
||||||
|
KeepAlive: defaultClientTTL,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to server.
|
||||||
|
var err error
|
||||||
|
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName())
|
||||||
|
return nil, fmt.Errorf("%w: failed to connect to %s: %s", ErrFailure, tr.resolver.Info.DescriptiveName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log that a connection to the resolver was established.
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: connected to %s",
|
||||||
|
tr.resolver.Info.DescriptiveName(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create resolver connection.
|
||||||
|
tr.resolverConnInstanceID++
|
||||||
|
resolverConn := &tcpResolverConn{
|
||||||
|
id: tr.resolverConnInstanceID,
|
||||||
|
conn: conn,
|
||||||
|
resolverInfo: tr.resolver.Info,
|
||||||
|
queries: make(chan *tcpQuery, 10),
|
||||||
|
responses: make(chan *dns.Msg, 10),
|
||||||
|
inFlightQueries: make(map[uint16]*tcpQuery, 10),
|
||||||
|
heartbeat: make(chan struct{}),
|
||||||
|
abandoned: abool.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start worker.
|
||||||
|
module.StartWorker("dns client", resolverConn.handler)
|
||||||
|
|
||||||
|
// Set resolver conn for reuse.
|
||||||
|
tr.resolverConn = resolverConn
|
||||||
|
|
||||||
|
// Hint network environment at successful connection.
|
||||||
|
netenv.ReportSuccessfulConnection()
|
||||||
|
|
||||||
|
return resolverConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query executes the given query against the resolver.
|
||||||
|
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
|
// Get resolver connection.
|
||||||
|
resolverConn, err := tr.getOrCreateResolverConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create query request.
|
||||||
|
tq := &tcpQuery{
|
||||||
|
Query: q,
|
||||||
|
Response: make(chan *dns.Msg),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit query request to live connection.
|
||||||
|
select {
|
||||||
|
case resolverConn.queries <- tq:
|
||||||
case <-time.After(defaultRequestTimeout):
|
case <-time.After(defaultRequestTimeout):
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for reply.
|
||||||
|
var reply *dns.Msg
|
||||||
|
select {
|
||||||
|
case reply = <-tq.Response:
|
||||||
|
case <-time.After(defaultRequestTimeout):
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a reply.
|
||||||
|
if reply == nil {
|
||||||
|
// Resolver is shutting down. The Portmaster may be shutting down, or
|
||||||
|
// there is a connection error.
|
||||||
|
return nil, ErrFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the reply was blocked upstream.
|
||||||
|
if tr.resolver.IsBlockedUpstream(reply) {
|
||||||
|
return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create RRCache from reply and return it.
|
||||||
|
return tq.MakeCacheRecord(reply, tr.resolver.Info), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trc *tcpResolverConn) shutdown() {
|
||||||
|
// Set abandoned status and close connection to the DNS server.
|
||||||
|
if trc.abandoned.SetToIf(false, true) {
|
||||||
|
_ = trc.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close all response channels for in-flight queries.
|
||||||
|
for _, tq := range trc.inFlightQueries {
|
||||||
|
close(tq.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respond to any incoming queries for some time in order to not leave them
|
||||||
|
// hanging longer than necessary.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case tq := <-trc.queries:
|
||||||
|
close(tq.Response)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trc *tcpResolverConn) handler(workerCtx context.Context) error {
|
||||||
|
// Set up context and cleanup.
|
||||||
|
trc.ctx, trc.cancelCtx = context.WithCancel(workerCtx)
|
||||||
|
defer trc.shutdown()
|
||||||
|
|
||||||
|
// Set up variables.
|
||||||
|
var readyToRecycle bool
|
||||||
|
ttlTimer := time.After(defaultClientTTL)
|
||||||
|
|
||||||
|
// Start connection reader.
|
||||||
|
module.StartWorker("dns client reader", trc.reader)
|
||||||
|
|
||||||
|
// Handle requests.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-trc.heartbeat:
|
||||||
|
// Respond to alive checks.
|
||||||
|
|
||||||
|
case <-trc.ctx.Done():
|
||||||
|
// Respond to module shutdown or conn error.
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case <-ttlTimer:
|
||||||
|
// Recycle the connection after the TTL is reached.
|
||||||
|
readyToRecycle = true
|
||||||
|
// Send dummy response to trigger the check.
|
||||||
|
select {
|
||||||
|
case trc.responses <- nil:
|
||||||
|
default:
|
||||||
|
// The response queue is full.
|
||||||
|
// The check will be triggered by another response.
|
||||||
|
}
|
||||||
|
|
||||||
|
case tq := <-trc.queries:
|
||||||
|
// Handle DNS query request.
|
||||||
|
|
||||||
|
// Create dns request message.
|
||||||
|
msg := &dns.Msg{}
|
||||||
|
msg.SetQuestion(tq.Query.FQDN, uint16(tq.Query.QType))
|
||||||
|
|
||||||
|
// Assign a unique message ID.
|
||||||
|
trc.assignUniqueID(msg)
|
||||||
|
|
||||||
|
// Add query to in flight registry.
|
||||||
|
trc.inFlightQueries[msg.Id] = tq
|
||||||
|
|
||||||
|
// Write query to connected DNS server.
|
||||||
|
_ = trc.conn.SetWriteDeadline(time.Now().Add(tcpWriteTimeout))
|
||||||
|
err := trc.conn.WriteMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
trc.logConnectionError(err, false)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return inFlight
|
case msg := <-trc.responses:
|
||||||
|
if msg != nil {
|
||||||
|
trc.handleQueryResponse(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked.
|
// If we are ready to recycle and we have no in-flight queries, we can
|
||||||
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
// shutdown the connection and create a new one for the next query.
|
||||||
|
if readyToRecycle {
|
||||||
|
if len(trc.inFlightQueries) == 0 {
|
||||||
|
log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignUniqueID makes sure that ID assigned to msg is unique.
|
||||||
|
func (trc *tcpResolverConn) assignUniqueID(msg *dns.Msg) {
|
||||||
// try a random ID 10000 times
|
// try a random ID 10000 times
|
||||||
for i := 0; i < 10000; i++ { // don't try forever
|
for i := 0; i < 10000; i++ { // don't try forever
|
||||||
_, exists := tr.inFlightQueries[msg.Id]
|
_, exists := trc.inFlightQueries[msg.Id]
|
||||||
if !exists {
|
if !exists {
|
||||||
return // we are unique, yay!
|
return // we are unique, yay!
|
||||||
}
|
}
|
||||||
|
@ -141,7 +321,7 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
||||||
// go through the complete space
|
// go through the complete space
|
||||||
var id uint16
|
var id uint16
|
||||||
for ; id <= (1<<16)-1; id++ { // don't try forever
|
for ; id <= (1<<16)-1; id++ { // don't try forever
|
||||||
_, exists := tr.inFlightQueries[id]
|
_, exists := trc.inFlightQueries[id]
|
||||||
if !exists {
|
if !exists {
|
||||||
msg.Id = id
|
msg.Id = id
|
||||||
return // we are unique, yay!
|
return // we are unique, yay!
|
||||||
|
@ -149,390 +329,80 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes the given query against the resolver.
|
func (trc *tcpResolverConn) handleQueryResponse(msg *dns.Msg) {
|
||||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
// Get in flight from registry.
|
||||||
// submit to client
|
tq, ok := trc.inFlightQueries[msg.Id]
|
||||||
inFlight := tr.submitQuery(ctx, q)
|
|
||||||
if inFlight == nil {
|
|
||||||
tr.checkClientStatus()
|
|
||||||
return nil, ErrTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply *dns.Msg
|
|
||||||
select {
|
|
||||||
case reply = <-inFlight.Response:
|
|
||||||
case <-time.After(defaultRequestTimeout):
|
|
||||||
tr.checkClientStatus()
|
|
||||||
return nil, ErrTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
if reply == nil {
|
|
||||||
// Resolver is shutting down, could be server failure or we are offline
|
|
||||||
return nil, ErrFailure
|
|
||||||
}
|
|
||||||
|
|
||||||
if tr.resolver.IsBlockedUpstream(reply) {
|
|
||||||
return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()}
|
|
||||||
}
|
|
||||||
|
|
||||||
return inFlight.MakeCacheRecord(reply), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr *TCPResolver) checkClientStatus() {
|
|
||||||
// Get client cancel function before waiting in order to not immediately
|
|
||||||
// cancel a new client.
|
|
||||||
tr.Lock()
|
|
||||||
stopClient := tr.stopClient
|
|
||||||
tr.Unlock()
|
|
||||||
|
|
||||||
// Check if the client is alive with the heartbeat, if not shut it down.
|
|
||||||
select {
|
|
||||||
case tr.clientHeartbeat <- struct{}{}:
|
|
||||||
case <-time.After(heartbeatTimeout):
|
|
||||||
log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.Info.DescriptiveName())
|
|
||||||
stopClient()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpResolverConnMgr struct {
|
|
||||||
tr *TCPResolver
|
|
||||||
responses chan *dns.Msg
|
|
||||||
failCnt int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr *TCPResolver) startClient() {
|
|
||||||
if tr.clientStarted.SetToIf(false, true) {
|
|
||||||
mgr := &tcpResolverConnMgr{
|
|
||||||
tr: tr,
|
|
||||||
responses: make(chan *dns.Msg, 100),
|
|
||||||
}
|
|
||||||
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
|
||||||
defer mgr.shutdown()
|
|
||||||
mgr.tr.clientStarted.Set()
|
|
||||||
|
|
||||||
// Create additional cancel function for this worker.
|
|
||||||
clientCtx, stopClient := context.WithCancel(workerCtx)
|
|
||||||
mgr.tr.Lock()
|
|
||||||
mgr.tr.stopClient = stopClient
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
|
|
||||||
// connection lifecycle loop
|
|
||||||
for {
|
|
||||||
// check if we are shutting down
|
|
||||||
select {
|
|
||||||
case <-clientCtx.Done():
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if we are failing
|
|
||||||
if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for work before creating connection
|
|
||||||
proceed := mgr.waitForWork(clientCtx)
|
|
||||||
if !proceed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// create connection
|
|
||||||
conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection()
|
|
||||||
if conn == nil {
|
|
||||||
mgr.failCnt++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// hint network environment at successful connection
|
|
||||||
netenv.ReportSuccessfulConnection()
|
|
||||||
|
|
||||||
// handle queries
|
|
||||||
proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx)
|
|
||||||
if !proceed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) shutdown() {
|
|
||||||
// reply to all waiting queries
|
|
||||||
mgr.tr.Lock()
|
|
||||||
defer mgr.tr.Unlock()
|
|
||||||
|
|
||||||
mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
|
|
||||||
atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter
|
|
||||||
|
|
||||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
|
||||||
close(inFlight.Response)
|
|
||||||
delete(mgr.tr.inFlightQueries, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hint network environment at failed connection
|
|
||||||
if mgr.failCnt >= FailThreshold {
|
|
||||||
netenv.ReportFailedConnection()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) {
|
|
||||||
// wait until there is something to do
|
|
||||||
mgr.tr.Lock()
|
|
||||||
waiting := len(mgr.tr.inFlightQueries)
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
if waiting > 0 {
|
|
||||||
// queue abandoned queries
|
|
||||||
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
|
||||||
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
|
|
||||||
mgr.tr.Lock()
|
|
||||||
defer mgr.tr.Unlock()
|
|
||||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
|
||||||
if inFlight.Started.Before(ignoreBefore) {
|
|
||||||
// remove old queries
|
|
||||||
close(inFlight.Response)
|
|
||||||
delete(mgr.tr.inFlightQueries, id)
|
|
||||||
} else if inFlight.ConnInstanceID != currentConnInstanceID {
|
|
||||||
inFlight.ConnInstanceID = currentConnInstanceID
|
|
||||||
// re-inject queries that died with a previously failed connection
|
|
||||||
select {
|
|
||||||
case mgr.tr.queries <- inFlight.Msg:
|
|
||||||
default:
|
|
||||||
log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Info.DescriptiveName())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for first query
|
|
||||||
select {
|
|
||||||
case <-clientCtx.Done():
|
|
||||||
return false
|
|
||||||
case msg := <-mgr.tr.queries:
|
|
||||||
// re-insert query, we will handle it later
|
|
||||||
module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
|
|
||||||
select {
|
|
||||||
case mgr.tr.queries <- msg:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Info.DescriptiveName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) establishConnection() (
|
|
||||||
conn *dns.Conn,
|
|
||||||
connClosing *abool.AtomicBool,
|
|
||||||
connCtx context.Context,
|
|
||||||
cancelConnCtx context.CancelFunc,
|
|
||||||
) {
|
|
||||||
// refresh dialer to set an authenticated local address
|
|
||||||
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
|
|
||||||
mgr.tr.dnsClient.Dialer = &net.Dialer{
|
|
||||||
LocalAddr: getLocalAddr("tcp"),
|
|
||||||
Timeout: defaultConnectTimeout,
|
|
||||||
KeepAlive: defaultClientTTL,
|
|
||||||
}
|
|
||||||
// connect
|
|
||||||
var err error
|
|
||||||
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("resolver: failed to connect to %s", mgr.tr.resolver.Info.DescriptiveName())
|
|
||||||
return nil, nil, nil, nil
|
|
||||||
}
|
|
||||||
connCtx, cancelConnCtx = context.WithCancel(context.Background())
|
|
||||||
connClosing = abool.New()
|
|
||||||
|
|
||||||
// Get amount of in waiting queries.
|
|
||||||
mgr.tr.Lock()
|
|
||||||
waitingQueries := len(mgr.tr.inFlightQueries)
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
|
|
||||||
// Log that a connection to the resolver was established.
|
|
||||||
log.Debugf(
|
|
||||||
"resolver: connected to %s with %d queries waiting",
|
|
||||||
mgr.tr.resolver.Info.DescriptiveName(),
|
|
||||||
waitingQueries,
|
|
||||||
)
|
|
||||||
|
|
||||||
// start reader
|
|
||||||
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error {
|
|
||||||
return mgr.msgReader(conn, connClosing, cancelConnCtx)
|
|
||||||
})
|
|
||||||
|
|
||||||
return conn, connClosing, connCtx, cancelConnCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
|
|
||||||
clientCtx context.Context,
|
|
||||||
conn *dns.Conn,
|
|
||||||
connClosing *abool.AtomicBool,
|
|
||||||
connCtx context.Context,
|
|
||||||
cancelConnCtx context.CancelFunc,
|
|
||||||
) (proceed bool) {
|
|
||||||
var readyToRecycle bool
|
|
||||||
ttlTimer := time.After(mgr.tr.clientTTL)
|
|
||||||
|
|
||||||
// clean up connection
|
|
||||||
defer func() {
|
|
||||||
connClosing.Set() // silence connection errors
|
|
||||||
cancelConnCtx()
|
|
||||||
_ = conn.Close()
|
|
||||||
|
|
||||||
// increase instance counter
|
|
||||||
atomic.AddUint32(mgr.tr.connInstanceID, 1)
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-mgr.tr.clientHeartbeat:
|
|
||||||
// respond to alive checks
|
|
||||||
|
|
||||||
case <-clientCtx.Done():
|
|
||||||
// module shutdown
|
|
||||||
return false
|
|
||||||
|
|
||||||
case <-connCtx.Done():
|
|
||||||
// connection error
|
|
||||||
return true
|
|
||||||
|
|
||||||
case <-ttlTimer:
|
|
||||||
// connection TTL reached, rebuild connection
|
|
||||||
// but handle all in flight queries first
|
|
||||||
readyToRecycle = true
|
|
||||||
// trigger check
|
|
||||||
select {
|
|
||||||
case mgr.responses <- nil:
|
|
||||||
default:
|
|
||||||
// queue is full, check will be triggered anyway
|
|
||||||
}
|
|
||||||
|
|
||||||
case msg := <-mgr.tr.queries:
|
|
||||||
// write query
|
|
||||||
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
|
||||||
err := conn.WriteMsg(msg)
|
|
||||||
if err != nil {
|
|
||||||
mgr.logConnectionError(err, conn, connClosing, false)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
case msg := <-mgr.responses:
|
|
||||||
if msg != nil {
|
|
||||||
mgr.handleQueryResponse(conn, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if readyToRecycle {
|
|
||||||
// check to see if we can recycle the connection
|
|
||||||
mgr.tr.Lock()
|
|
||||||
activeQueries := len(mgr.tr.inFlightQueries)
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
if activeQueries == 0 {
|
|
||||||
log.Debugf("resolver: recycling conn to %s", mgr.tr.resolver.Info.DescriptiveName())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
|
|
||||||
// handle query from resolver
|
|
||||||
mgr.tr.Lock()
|
|
||||||
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
|
|
||||||
if ok {
|
if ok {
|
||||||
delete(mgr.tr.inFlightQueries, msg.Id)
|
delete(trc.inFlightQueries, msg.Id)
|
||||||
}
|
} else {
|
||||||
mgr.tr.Unlock()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
log.Debugf(
|
log.Debugf(
|
||||||
"resolver: received possibly unsolicited reply from %s: txid=%d q=%+v",
|
"resolver: received possibly unsolicited reply from %s: txid=%d q=%+v",
|
||||||
mgr.tr.resolver.Info.DescriptiveName(),
|
trc.resolverInfo.DescriptiveName(),
|
||||||
msg.Id,
|
msg.Id,
|
||||||
msg.Question,
|
msg.Question,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send response to waiting query handler.
|
||||||
select {
|
select {
|
||||||
case inFlight.Response <- msg:
|
case tq.Response <- msg:
|
||||||
mgr.failCnt = 0 // reset fail counter
|
|
||||||
// responded!
|
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// no one is listening for that response.
|
// No one is listening for that response.
|
||||||
}
|
}
|
||||||
|
|
||||||
// if caching is disabled we're done
|
// If caching is disabled for this query, we are done.
|
||||||
if inFlight.Query.NoCaching {
|
if tq.Query.NoCaching {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist to database
|
// Otherwise, we can persist the answer in case the request is repeated.
|
||||||
rrCache := inFlight.MakeCacheRecord(msg)
|
rrCache := tq.MakeCacheRecord(msg, trc.resolverInfo)
|
||||||
rrCache.Clean(minTTL)
|
rrCache.Clean(minTTL)
|
||||||
err := rrCache.Save()
|
err := rrCache.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warningf(
|
log.Warningf(
|
||||||
"resolver: failed to cache RR for %s%s: %s",
|
"resolver: failed to cache RR for %s: %s",
|
||||||
inFlight.Query.FQDN,
|
tq.Query.ID(),
|
||||||
inFlight.Query.QType.String(),
|
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) msgReader(
|
func (trc *tcpResolverConn) reader(workerCtx context.Context) error {
|
||||||
conn *dns.Conn,
|
defer trc.cancelCtx()
|
||||||
connClosing *abool.AtomicBool,
|
|
||||||
cancelConnCtx context.CancelFunc,
|
|
||||||
) error {
|
|
||||||
defer cancelConnCtx()
|
|
||||||
for {
|
for {
|
||||||
msg, err := conn.ReadMsg()
|
msg, err := trc.conn.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.logConnectionError(err, conn, connClosing, true)
|
trc.logConnectionError(err, true)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
mgr.responses <- msg
|
trc.responses <- msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool, reading bool) {
|
func (trc *tcpResolverConn) logConnectionError(err error, reading bool) {
|
||||||
// Check if we are the first to see an error.
|
// Check if we are the first to see an error.
|
||||||
if connClosing.SetToIf(false, true) {
|
if trc.abandoned.SetToIf(false, true) {
|
||||||
// Get amount of in flight queries.
|
|
||||||
mgr.tr.Lock()
|
|
||||||
inFlightQueries := len(mgr.tr.inFlightQueries)
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
|
|
||||||
// Log error.
|
// Log error.
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, io.EOF):
|
case errors.Is(err, io.EOF):
|
||||||
log.Debugf(
|
log.Debugf(
|
||||||
"resolver: connection to %s was closed with %d in-flight queries",
|
"resolver: connection to %s was closed",
|
||||||
mgr.tr.resolver.Info.DescriptiveName(),
|
trc.resolverInfo.DescriptiveName(),
|
||||||
inFlightQueries,
|
|
||||||
)
|
)
|
||||||
case reading:
|
case reading:
|
||||||
log.Warningf(
|
log.Warningf(
|
||||||
"resolver: read error from %s with %d in-flight queries: %s",
|
"resolver: read error from %s: %s",
|
||||||
mgr.tr.resolver.Info.DescriptiveName(),
|
trc.resolverInfo.DescriptiveName(),
|
||||||
inFlightQueries,
|
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
default:
|
default:
|
||||||
log.Warningf(
|
log.Warningf(
|
||||||
"resolver: write error to %s with %d in-flight queries: %s",
|
"resolver: write error to %s: %s",
|
||||||
mgr.tr.resolver.Info.DescriptiveName(),
|
trc.resolverInfo.DescriptiveName(),
|
||||||
inFlightQueries,
|
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,7 +128,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
|
||||||
|
|
||||||
blockType := query.Get("blockedif")
|
blockType := query.Get("blockedif")
|
||||||
if blockType == "" {
|
if blockType == "" {
|
||||||
blockType = BlockDetectionRefused
|
blockType = BlockDetectionZeroIP
|
||||||
}
|
}
|
||||||
|
|
||||||
switch blockType {
|
switch blockType {
|
||||||
|
|
Loading…
Add table
Reference in a new issue