diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index 63a0c6ef..057fd034 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -4,9 +4,9 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" - "sync/atomic" "time" "github.com/miekg/dns" @@ -16,68 +16,79 @@ import ( ) const ( - tcpWriteTimeout = 2 * time.Second - ignoreQueriesAfter = 10 * time.Minute - heartbeatTimeout = 15 * time.Second + tcpConnectionEstablishmentTimeout = 3 * time.Second + tcpWriteTimeout = 2 * time.Second + heartbeatTimeout = 5 * time.Second + ignoreQueriesAfter = 5 * time.Minute ) // TCPResolver is a resolver using just a single tcp connection with pipelining. type TCPResolver struct { BasicResolverConn - clientTTL time.Duration + // dnsClient holds the connection configuration of the resolver. dnsClient *dns.Client - - clientStarted *abool.AtomicBool - clientHeartbeat chan struct{} - stopClient func() - connInstanceID *uint32 - queries chan *dns.Msg - inFlightQueries map[uint16]*InFlightQuery + // resolverConn holds a connection to the DNS server, including query management. + resolverConn *tcpResolverConn + // resolverConnInstanceID holds the current ID of the resolverConn. + resolverConnInstanceID int } -// InFlightQuery represents an in flight query of a TCPResolver. -type InFlightQuery struct { - Query *Query - Msg *dns.Msg - Response chan *dns.Msg - Resolver *Resolver - Started time.Time - ConnInstanceID uint32 +// tcpResolverConn represents a single connection to an upstream DNS server. +type tcpResolverConn struct { + // ctx is the context of the tcpResolverConn. + ctx context.Context + // cancelCtx cancels ctx. + cancelCtx func() + // id is the ID assigned to the resolver conn. + id int + // conn is the connection to the DNS server. + conn *dns.Conn + // resolverInfo holds information about the resolver to enhance error messages. + resolverInfo *ResolverInfo + // queries is used to submit queries to be sent to the connected DNS server. 
+ queries chan *tcpQuery + // responses is used to hand the responses from the reader to the handler. + responses chan *dns.Msg + // inFlightQueries holds all in-flight queries of this connection. + inFlightQueries map[uint16]*tcpQuery + // heartbeat is an alive-checking channel from which the resolver conn must + // always read asap. + heartbeat chan struct{} + // abandoned signifies if the resolver conn has been abandoned. + abandoned *abool.AtomicBool } -// MakeCacheRecord creates an RCache record from a reply. -func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache { +// tcpQuery holds the query information for a tcpResolverConn. +type tcpQuery struct { + Query *Query + Response chan *dns.Msg +} + +// MakeCacheRecord creates an RRCache record from a reply. +func (tq *tcpQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo) *RRCache { return &RRCache{ - Domain: ifq.Query.FQDN, - Question: ifq.Query.QType, + Domain: tq.Query.FQDN, + Question: tq.Query.QType, RCode: reply.Rcode, Answer: reply.Answer, Ns: reply.Ns, Extra: reply.Extra, - Resolver: ifq.Resolver.Info.Copy(), + Resolver: resolverInfo.Copy(), } } // NewTCPResolver returns a new TPCResolver. 
func NewTCPResolver(resolver *Resolver) *TCPResolver { - var instanceID uint32 newResolver := &TCPResolver{ BasicResolverConn: BasicResolverConn{ resolver: resolver, }, - clientTTL: defaultClientTTL, dnsClient: &dns.Client{ Net: "tcp", Timeout: defaultConnectTimeout, WriteTimeout: tcpWriteTimeout, }, - clientStarted: abool.New(), - clientHeartbeat: make(chan struct{}), - stopClient: func() {}, - connInstanceID: &instanceID, - queries: make(chan *dns.Msg, 1000), - inFlightQueries: make(map[uint16]*InFlightQuery), } newResolver.BasicResolverConn.init() return newResolver @@ -94,45 +105,214 @@ func (tr *TCPResolver) UseTLS() *TCPResolver { return tr } -func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { - // make sure client is started - tr.startClient() - - // create msg - msg := &dns.Msg{} - msg.SetQuestion(q.FQDN, uint16(q.QType)) - - // save to waitlist - inFlight := &InFlightQuery{ - Query: q, - Msg: msg, - Response: make(chan *dns.Msg), - Resolver: tr.resolver, - Started: time.Now().UTC(), - ConnInstanceID: atomic.LoadUint32(tr.connInstanceID), - } +func (tr *TCPResolver) getOrCreateResolverConn() (*tcpResolverConn, error) { tr.Lock() - // check for existing query - tr.ensureUniqueID(msg) - // add query to in flight registry - tr.inFlightQueries[msg.Id] = inFlight - tr.Unlock() + defer tr.Unlock() - // submit msg for writing - select { - case tr.queries <- msg: - case <-time.After(defaultRequestTimeout): - return nil + // Check if we have a resolver conn that has not been abandoned. + if tr.resolverConn != nil && tr.resolverConn.abandoned.IsNotSet() { + // If there is one, check if it's alive! + select { + case tr.resolverConn.heartbeat <- struct{}{}: + return tr.resolverConn, nil + case <-time.After(heartbeatTimeout): + log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName()) + } } - return inFlight + // Create a new connection if no active one is available. + + // Refresh the dialer in order to set an authenticated local address. 
+ tr.dnsClient.Dialer = &net.Dialer{ + LocalAddr: getLocalAddr("tcp"), + Timeout: tcpConnectionEstablishmentTimeout, + KeepAlive: defaultClientTTL, + } + + // Connect to server. + var err error + conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) + if err != nil { + log.Debugf("resolver: failed to connect to %s", tr.resolver.Info.DescriptiveName()) + return nil, fmt.Errorf("%w: failed to connect to %s: %s", ErrFailure, tr.resolver.Info.DescriptiveName(), err) + } + + // Log that a connection to the resolver was established. + log.Debugf( + "resolver: connected to %s", + tr.resolver.Info.DescriptiveName(), + ) + + // Create resolver connection. + tr.resolverConnInstanceID++ + resolverConn := &tcpResolverConn{ + id: tr.resolverConnInstanceID, + conn: conn, + resolverInfo: tr.resolver.Info, + queries: make(chan *tcpQuery, 10), + responses: make(chan *dns.Msg, 10), + inFlightQueries: make(map[uint16]*tcpQuery, 10), + heartbeat: make(chan struct{}), + abandoned: abool.New(), + } + + // Start worker. + module.StartWorker("dns client", resolverConn.handler) + + // Set resolver conn for reuse. + tr.resolverConn = resolverConn + + // Hint network environment at successful connection. + netenv.ReportSuccessfulConnection() + + return resolverConn, nil } -// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked. -func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) { +// Query executes the given query against the resolver. +func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + // Get resolver connection. + resolverConn, err := tr.getOrCreateResolverConn() + if err != nil { + return nil, err + } + + // Create query request. + tq := &tcpQuery{ + Query: q, + Response: make(chan *dns.Msg), + } + + // Submit query request to live connection. + select { + case resolverConn.queries <- tq: + case <-time.After(defaultRequestTimeout): + return nil, ErrTimeout + } + + // Wait for reply. 
+ var reply *dns.Msg + select { + case reply = <-tq.Response: + case <-time.After(defaultRequestTimeout): + return nil, ErrTimeout + } + + // Check if we have a reply. + if reply == nil { + // Resolver is shutting down. The Portmaster may be shutting down, or + // there is a connection error. + return nil, ErrFailure + } + + // Check if the reply was blocked upstream. + if tr.resolver.IsBlockedUpstream(reply) { + return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()} + } + + // Create RRCache from reply and return it. + return tq.MakeCacheRecord(reply, tr.resolver.Info), nil +} + +func (trc *tcpResolverConn) shutdown() { + // Mark the conn as abandoned first (silences the reader's error log), then + // always close it: logConnectionError may have set abandoned already. + trc.abandoned.Set() + _ = trc.conn.Close() + + // Close all response channels for in-flight queries. + for _, tq := range trc.inFlightQueries { + close(tq.Response) + } + + // Respond to any incoming queries for some time in order to not leave them + // hanging longer than necessary. + for { + select { + case tq := <-trc.queries: + close(tq.Response) + case <-time.After(100 * time.Millisecond): + return + } + } +} + +func (trc *tcpResolverConn) handler(workerCtx context.Context) error { + // Set up context and cleanup. + trc.ctx, trc.cancelCtx = context.WithCancel(workerCtx) + defer trc.shutdown() + + // Set up variables. + var readyToRecycle bool + ttlTimer := time.After(defaultClientTTL) + + // Start connection reader. + module.StartWorker("dns client reader", trc.reader) + + // Handle requests. + for { + select { + case <-trc.heartbeat: + // Respond to alive checks. + + case <-trc.ctx.Done(): + // Respond to module shutdown or conn error. + return nil + + case <-ttlTimer: + // Recycle the connection after the TTL is reached. + readyToRecycle = true + // Send dummy response to trigger the check. + select { + case trc.responses <- nil: + default: + // The response queue is full. + // The check will be triggered by another response. 
+ } + + case tq := <-trc.queries: + // Handle DNS query request. + + // Create dns request message. + msg := &dns.Msg{} + msg.SetQuestion(tq.Query.FQDN, uint16(tq.Query.QType)) + + // Assign a unique message ID. + trc.assignUniqueID(msg) + + // Add query to in flight registry. + trc.inFlightQueries[msg.Id] = tq + + // Write query to connected DNS server. + _ = trc.conn.SetWriteDeadline(time.Now().Add(tcpWriteTimeout)) + err := trc.conn.WriteMsg(msg) + if err != nil { + trc.logConnectionError(err, false) + return nil + } + + case msg := <-trc.responses: + if msg != nil { + trc.handleQueryResponse(msg) + } + + // If we are ready to recycle and we have no in-flight queries, we can + // shutdown the connection and create a new one for the next query. + if readyToRecycle { + if len(trc.inFlightQueries) == 0 { + log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName()) + return nil + } + } + + } + } +} + +// assignUniqueID makes sure that ID assigned to msg is unique. +func (trc *tcpResolverConn) assignUniqueID(msg *dns.Msg) { // try a random ID 10000 times for i := 0; i < 10000; i++ { // don't try forever - _, exists := tr.inFlightQueries[msg.Id] + _, exists := trc.inFlightQueries[msg.Id] if !exists { return // we are unique, yay! } @@ -141,7 +321,7 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) { // go through the complete space var id uint16 for ; id <= (1<<16)-1; id++ { // don't try forever - _, exists := tr.inFlightQueries[id] + _, exists := trc.inFlightQueries[id] if !exists { msg.Id = id return // we are unique, yay! @@ -149,390 +329,80 @@ func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) { } } -// Query executes the given query against the resolver. 
-func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { - // submit to client - inFlight := tr.submitQuery(ctx, q) - if inFlight == nil { - tr.checkClientStatus() - return nil, ErrTimeout - } - - var reply *dns.Msg - select { - case reply = <-inFlight.Response: - case <-time.After(defaultRequestTimeout): - tr.checkClientStatus() - return nil, ErrTimeout - } - - if reply == nil { - // Resolver is shutting down, could be server failure or we are offline - return nil, ErrFailure - } - - if tr.resolver.IsBlockedUpstream(reply) { - return nil, &BlockedUpstreamError{tr.resolver.Info.DescriptiveName()} - } - - return inFlight.MakeCacheRecord(reply), nil -} - -func (tr *TCPResolver) checkClientStatus() { - // Get client cancel function before waiting in order to not immediately - // cancel a new client. - tr.Lock() - stopClient := tr.stopClient - tr.Unlock() - - // Check if the client is alive with the heartbeat, if not shut it down. - select { - case tr.clientHeartbeat <- struct{}{}: - case <-time.After(heartbeatTimeout): - log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.Info.DescriptiveName()) - stopClient() - } -} - -type tcpResolverConnMgr struct { - tr *TCPResolver - responses chan *dns.Msg - failCnt int -} - -func (tr *TCPResolver) startClient() { - if tr.clientStarted.SetToIf(false, true) { - mgr := &tcpResolverConnMgr{ - tr: tr, - responses: make(chan *dns.Msg, 100), - } - module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run) - } -} - -func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { - defer mgr.shutdown() - mgr.tr.clientStarted.Set() - - // Create additional cancel function for this worker. 
- clientCtx, stopClient := context.WithCancel(workerCtx) - mgr.tr.Lock() - mgr.tr.stopClient = stopClient - mgr.tr.Unlock() - - // connection lifecycle loop - for { - // check if we are shutting down - select { - case <-clientCtx.Done(): - return nil - default: - } - - // check if we are failing - if mgr.failCnt >= FailThreshold || mgr.tr.IsFailing() { - return nil - } - - // wait for work before creating connection - proceed := mgr.waitForWork(clientCtx) - if !proceed { - return nil - } - - // create connection - conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection() - if conn == nil { - mgr.failCnt++ - continue - } - - // hint network environment at successful connection - netenv.ReportSuccessfulConnection() - - // handle queries - proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx) - if !proceed { - return nil - } - } -} - -func (mgr *tcpResolverConnMgr) shutdown() { - // reply to all waiting queries - mgr.tr.Lock() - defer mgr.tr.Unlock() - - mgr.tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds - atomic.AddUint32(mgr.tr.connInstanceID, 1) // increase instance counter - - for id, inFlight := range mgr.tr.inFlightQueries { - close(inFlight.Response) - delete(mgr.tr.inFlightQueries, id) - } - - // hint network environment at failed connection - if mgr.failCnt >= FailThreshold { - netenv.ReportFailedConnection() - } -} - -func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) { - // wait until there is something to do - mgr.tr.Lock() - waiting := len(mgr.tr.inFlightQueries) - mgr.tr.Unlock() - if waiting > 0 { - // queue abandoned queries - ignoreBefore := time.Now().Add(-ignoreQueriesAfter) - currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID) - mgr.tr.Lock() - defer mgr.tr.Unlock() - for id, inFlight := range mgr.tr.inFlightQueries { - if inFlight.Started.Before(ignoreBefore) { - // remove old queries - close(inFlight.Response) - 
delete(mgr.tr.inFlightQueries, id) - } else if inFlight.ConnInstanceID != currentConnInstanceID { - inFlight.ConnInstanceID = currentConnInstanceID - // re-inject queries that died with a previously failed connection - select { - case mgr.tr.queries <- inFlight.Msg: - default: - log.Warningf("resolver: failed to re-inject abandoned query to %s", mgr.tr.resolver.Info.DescriptiveName()) - } - } - // in-flight queries that match the connection instance ID are not changed. They are already in the queue. - } - return true - } - - // wait for first query - select { - case <-clientCtx.Done(): - return false - case msg := <-mgr.tr.queries: - // re-insert query, we will handle it later - module.StartWorker("reinject triggering dns query", func(ctx context.Context) error { - select { - case mgr.tr.queries <- msg: - case <-time.After(2 * time.Second): - log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Info.DescriptiveName()) - } - return nil - }) - } - - return true -} - -func (mgr *tcpResolverConnMgr) establishConnection() ( - conn *dns.Conn, - connClosing *abool.AtomicBool, - connCtx context.Context, - cancelConnCtx context.CancelFunc, -) { - // refresh dialer to set an authenticated local address - // TODO: lock dnsClient (only manager should run at any time, so this should not be an issue) - mgr.tr.dnsClient.Dialer = &net.Dialer{ - LocalAddr: getLocalAddr("tcp"), - Timeout: defaultConnectTimeout, - KeepAlive: defaultClientTTL, - } - // connect - var err error - conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress) - if err != nil { - log.Debugf("resolver: failed to connect to %s", mgr.tr.resolver.Info.DescriptiveName()) - return nil, nil, nil, nil - } - connCtx, cancelConnCtx = context.WithCancel(context.Background()) - connClosing = abool.New() - - // Get amount of in waiting queries. - mgr.tr.Lock() - waitingQueries := len(mgr.tr.inFlightQueries) - mgr.tr.Unlock() - - // Log that a connection to the resolver was established. 
- log.Debugf( - "resolver: connected to %s with %d queries waiting", - mgr.tr.resolver.Info.DescriptiveName(), - waitingQueries, - ) - - // start reader - module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error { - return mgr.msgReader(conn, connClosing, cancelConnCtx) - }) - - return conn, connClosing, connCtx, cancelConnCtx -} - -func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter. - clientCtx context.Context, - conn *dns.Conn, - connClosing *abool.AtomicBool, - connCtx context.Context, - cancelConnCtx context.CancelFunc, -) (proceed bool) { - var readyToRecycle bool - ttlTimer := time.After(mgr.tr.clientTTL) - - // clean up connection - defer func() { - connClosing.Set() // silence connection errors - cancelConnCtx() - _ = conn.Close() - - // increase instance counter - atomic.AddUint32(mgr.tr.connInstanceID, 1) - }() - - for { - select { - case <-mgr.tr.clientHeartbeat: - // respond to alive checks - - case <-clientCtx.Done(): - // module shutdown - return false - - case <-connCtx.Done(): - // connection error - return true - - case <-ttlTimer: - // connection TTL reached, rebuild connection - // but handle all in flight queries first - readyToRecycle = true - // trigger check - select { - case mgr.responses <- nil: - default: - // queue is full, check will be triggered anyway - } - - case msg := <-mgr.tr.queries: - // write query - _ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout)) - err := conn.WriteMsg(msg) - if err != nil { - mgr.logConnectionError(err, conn, connClosing, false) - return true - } - - case msg := <-mgr.responses: - if msg != nil { - mgr.handleQueryResponse(conn, msg) - } - - if readyToRecycle { - // check to see if we can recycle the connection - mgr.tr.Lock() - activeQueries := len(mgr.tr.inFlightQueries) - mgr.tr.Unlock() - if activeQueries == 0 { - log.Debugf("resolver: recycling conn to %s", 
mgr.tr.resolver.Info.DescriptiveName()) - return true - } - } - - } - } -} - -func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) { - // handle query from resolver - mgr.tr.Lock() - inFlight, ok := mgr.tr.inFlightQueries[msg.Id] +func (trc *tcpResolverConn) handleQueryResponse(msg *dns.Msg) { + // Look up the in-flight query registered under this message ID. + tq, ok := trc.inFlightQueries[msg.Id] if ok { - delete(mgr.tr.inFlightQueries, msg.Id) - } - mgr.tr.Unlock() - - if !ok { + delete(trc.inFlightQueries, msg.Id) + } else { log.Debugf( "resolver: received possibly unsolicited reply from %s: txid=%d q=%+v", - mgr.tr.resolver.Info.DescriptiveName(), + trc.resolverInfo.DescriptiveName(), msg.Id, msg.Question, ) return } + // Send response to waiting query handler. select { - case inFlight.Response <- msg: - mgr.failCnt = 0 // reset fail counter - // responded! + case tq.Response <- msg: return default: - // no one is listening for that response. + // No one is listening for that response. } - // if caching is disabled we're done - if inFlight.Query.NoCaching { + // If caching is disabled for this query, we are done. + if tq.Query.NoCaching { return } - // persist to database - rrCache := inFlight.MakeCacheRecord(msg) + // Otherwise, nobody took the answer; persist it in case the request is repeated. 
+ rrCache := tq.MakeCacheRecord(msg, trc.resolverInfo) rrCache.Clean(minTTL) err := rrCache.Save() if err != nil { log.Warningf( - "resolver: failed to cache RR for %s%s: %s", - inFlight.Query.FQDN, - inFlight.Query.QType.String(), + "resolver: failed to cache RR for %s: %s", + tq.Query.ID(), err, ) } } -func (mgr *tcpResolverConnMgr) msgReader( - conn *dns.Conn, - connClosing *abool.AtomicBool, - cancelConnCtx context.CancelFunc, -) error { - defer cancelConnCtx() +func (trc *tcpResolverConn) reader(workerCtx context.Context) error { + defer trc.cancelCtx() + for { - msg, err := conn.ReadMsg() + msg, err := trc.conn.ReadMsg() if err != nil { - mgr.logConnectionError(err, conn, connClosing, true) + trc.logConnectionError(err, true) return nil } - mgr.responses <- msg + trc.responses <- msg } } -func (mgr *tcpResolverConnMgr) logConnectionError(err error, conn *dns.Conn, connClosing *abool.AtomicBool, reading bool) { +func (trc *tcpResolverConn) logConnectionError(err error, reading bool) { // Check if we are the first to see an error. - if connClosing.SetToIf(false, true) { - // Get amount of in flight queries. - mgr.tr.Lock() - inFlightQueries := len(mgr.tr.inFlightQueries) - mgr.tr.Unlock() - + if trc.abandoned.SetToIf(false, true) { // Log error. 
switch { case errors.Is(err, io.EOF): log.Debugf( - "resolver: connection to %s was closed with %d in-flight queries", - mgr.tr.resolver.Info.DescriptiveName(), - inFlightQueries, + "resolver: connection to %s was closed", + trc.resolverInfo.DescriptiveName(), ) case reading: log.Warningf( - "resolver: read error from %s with %d in-flight queries: %s", - mgr.tr.resolver.Info.DescriptiveName(), - inFlightQueries, + "resolver: read error from %s: %s", + trc.resolverInfo.DescriptiveName(), err, ) default: log.Warningf( - "resolver: write error to %s with %d in-flight queries: %s", - mgr.tr.resolver.Info.DescriptiveName(), - inFlightQueries, + "resolver: write error to %s: %s", + trc.resolverInfo.DescriptiveName(), err, ) } diff --git a/resolver/resolvers.go b/resolver/resolvers.go index ca556d60..c235b85f 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -128,7 +128,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { blockType := query.Get("blockedif") if blockType == "" { - blockType = BlockDetectionRefused + blockType = BlockDetectionZeroIP } switch blockType {