From f7320d760d992cb1ebafb9c8481da5eb3f554766 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jun 2020 15:21:05 +0200 Subject: [PATCH] Add TCP/TLS pipelining dns resolver --- .gitignore | 3 + core/pmtesting/testing.go | 4 +- resolver/main_test.go | 122 +++++++ resolver/pooling_test.go | 189 ----------- resolver/resolve.go | 2 + resolver/{mdns.go => resolver-mdns.go} | 0 resolver/{clients.go => resolver-pooled.go} | 25 +- resolver/resolver-pooled_test.go | 83 +++++ resolver/resolver-tcp.go | 359 ++++++++++++++++++++ resolver/resolver-tcp_test.go | 72 ++++ resolver/resolvers.go | 21 +- resolver/{resolver-scopes.go => scopes.go} | 0 12 files changed, 672 insertions(+), 208 deletions(-) delete mode 100644 resolver/pooling_test.go rename resolver/{mdns.go => resolver-mdns.go} (100%) rename resolver/{clients.go => resolver-pooled.go} (87%) create mode 100644 resolver/resolver-pooled_test.go create mode 100644 resolver/resolver-tcp.go create mode 100644 resolver/resolver-tcp_test.go rename resolver/{resolver-scopes.go => scopes.go} (100%) diff --git a/.gitignore b/.gitignore index 0ca9ee77..3d3e9052 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,9 @@ dist # vendor dir vendor +# testing +testing + # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a diff --git a/core/pmtesting/testing.go b/core/pmtesting/testing.go index a3091586..fdfff913 100644 --- a/core/pmtesting/testing.go +++ b/core/pmtesting/testing.go @@ -54,7 +54,9 @@ func TestMain(m *testing.M, module *modules.Module) { // are shutdown. func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) { // enable module for testing - module.Enable() + if module != nil { + module.Enable() + } // switch databases to memory only base.DefaultDatabaseStorageType = "hashmap" diff --git a/resolver/main_test.go b/resolver/main_test.go index 07537926..5d4862f8 100644 --- a/resolver/main_test.go +++ b/resolver/main_test.go @@ -6,6 +6,128 @@ import ( "github.com/safing/portmaster/core/pmtesting" ) +var ( + domainFeed = make(chan string) +) + func TestMain(m *testing.M) { pmtesting.TestMain(m, module) } + +func init() { + go feedDomains() +} + +func feedDomains() { + for { + for _, domain := range testDomains { + domainFeed <- domain + } + } +} + +// Data + +var ( + testDomains = []string{ + "facebook.com.", + "google.com.", + "youtube.com.", + "twitter.com.", + "instagram.com.", + "linkedin.com.", + "microsoft.com.", + "apple.com.", + "wikipedia.org.", + "plus.google.com.", + "en.wikipedia.org.", + "googletagmanager.com.", + "youtu.be.", + "adobe.com.", + "vimeo.com.", + "pinterest.com.", + "itunes.apple.com.", + "play.google.com.", + "maps.google.com.", + "goo.gl.", + "wordpress.com.", + "blogspot.com.", + "bit.ly.", + "github.com.", + "player.vimeo.com.", + "amazon.com.", + "wordpress.org.", + "docs.google.com.", + "yahoo.com.", + "mozilla.org.", + "tumblr.com.", + "godaddy.com.", + "flickr.com.", + "parked-content.godaddy.com.", + "drive.google.com.", + "support.google.com.", + "apache.org.", + "gravatar.com.", + "europa.eu.", + "qq.com.", + "w3.org.", + "nytimes.com.", + "reddit.com.", + "macromedia.com.", + "get.adobe.com.", + "soundcloud.com.", + "sourceforge.net.", + "sites.google.com.", + "nih.gov.", + "amazonaws.com.", + "t.co.", + "support.microsoft.com.", + "forbes.com.", + "theguardian.com.", + "cnn.com.", + "github.io.", + "bbc.co.uk.", + "dropbox.com.", + "whatsapp.com.", + "medium.com.", + "creativecommons.org.", + "www.ncbi.nlm.nih.gov.", + "httpd.apache.org.", + "archive.org.", + "ec.europa.eu.", + "php.net.", + "apps.apple.com.", + "weebly.com.", + "support.apple.com.", + "weibo.com.", + "wixsite.com.", + "issuu.com.", + "who.int.", + "paypal.com.", + "m.facebook.com.", + "oracle.com.", + "msn.com.", + "gnu.org.", + "tinyurl.com.", + "reuters.com.", + "l.facebook.com.", + "cloudflare.com.", + "wsj.com.", + "washingtonpost.com.", + "domainmarket.com.", + "imdb.com.", + "bbc.com.", + "bing.com.", + "accounts.google.com.", + "vk.com.", + "api.whatsapp.com.", + "opera.com.", + "cdc.gov.", + "slideshare.net.", + "wpa.qq.com.", + "harvard.edu.", + "mit.edu.", + "code.google.com.", + "wikimedia.org.", + } +) diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go deleted file mode 100644 index 3c03c14c..00000000 --- a/resolver/pooling_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package resolver - -import ( - "sync" - "sync/atomic" - "testing" - - "github.com/miekg/dns" -) - -var ( - domainFeed = make(chan string) -) - -func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { - dnsClient := brc.clientManager.getDNSClient() - - // create query - dnsQuery := new(dns.Msg) - dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) - - // get connection - conn, new, err := dnsClient.getConn() - if err != nil { - t.Fatalf("failed to connect: %s", err) //nolint:staticcheck - } - if new { - atomic.AddUint32(newCnt, 1) - } - - // query server - reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) - if err != nil { - t.Fatal(err) //nolint:staticcheck - } - if reply == nil { - t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck - } - - t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) - dnsClient.addToPool() - wg.Done() -} - -func TestClientPooling(t *testing.T) { - // skip if short - this test depends on the Internet and might fail randomly - if testing.Short() { - t.Skip() - } - - go feedDomains() - - // create separate resolver for this test - resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") - if err != nil { - t.Fatal(err) - } - brc := resolver.Conn.(*BasicResolverConn) - - wg := &sync.WaitGroup{} - var newCnt uint32 - for i := 0; i < 10; i++ { - wg.Add(10) - for i := 0; i < 10; i++ { - go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck - FQDN: <-domainFeed, - QType: dns.Type(dns.TypeA), - }) - } - wg.Wait() - if newCnt > uint32(10+i) { - t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) - } - } -} - -func feedDomains() { - for { - for _, domain := range poolingTestDomains { - domainFeed <- domain - } - } -} - -// Data - -var ( - poolingTestDomains = []string{ - "facebook.com.", - "google.com.", - "youtube.com.", - "twitter.com.", - "instagram.com.", - "linkedin.com.", - "microsoft.com.", - "apple.com.", - "wikipedia.org.", - "plus.google.com.", - "en.wikipedia.org.", - "googletagmanager.com.", - "youtu.be.", - "adobe.com.", - "vimeo.com.", - "pinterest.com.", - "itunes.apple.com.", - "play.google.com.", - "maps.google.com.", - "goo.gl.", - "wordpress.com.", - "blogspot.com.", - "bit.ly.", - "github.com.", - "player.vimeo.com.", - "amazon.com.", - "wordpress.org.", - "docs.google.com.", - "yahoo.com.", - "mozilla.org.", - "tumblr.com.", - "godaddy.com.", - "flickr.com.", - "parked-content.godaddy.com.", - "drive.google.com.", - "support.google.com.", - "apache.org.", - "gravatar.com.", - "europa.eu.", - "qq.com.", - "w3.org.", - "nytimes.com.", - "reddit.com.", - "macromedia.com.", - "get.adobe.com.", - "soundcloud.com.", - "sourceforge.net.", - "sites.google.com.", - "nih.gov.", - "amazonaws.com.", - "t.co.", - "support.microsoft.com.", - "forbes.com.", - "theguardian.com.", - "cnn.com.", - "github.io.", - "bbc.co.uk.", - "dropbox.com.", - "whatsapp.com.", - "medium.com.", - "creativecommons.org.", - "www.ncbi.nlm.nih.gov.", - "httpd.apache.org.", - "archive.org.", - "ec.europa.eu.", - "php.net.", - "apps.apple.com.", - "weebly.com.", - "support.apple.com.", - "weibo.com.", - "wixsite.com.", - "issuu.com.", - "who.int.", - "paypal.com.", - "m.facebook.com.", - "oracle.com.", - "msn.com.", - "gnu.org.", - "tinyurl.com.", - "reuters.com.", - "l.facebook.com.", - "cloudflare.com.", - "wsj.com.", - "washingtonpost.com.", - "domainmarket.com.", - "imdb.com.", - "bbc.com.", - "bing.com.", - "accounts.google.com.", - "vk.com.", - "api.whatsapp.com.", - "opera.com.", - "cdc.gov.", - "slideshare.net.", - "wpa.qq.com.", - "harvard.edu.", - "mit.edu.", - "code.google.com.", - "wikimedia.org.", - } -) diff --git a/resolver/resolve.go b/resolver/resolve.go index e9032f71..4dad1b3a 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -22,6 +22,8 @@ var ( ErrBlocked = errors.New("query was blocked") // ErrLocalhost is returned to *.localhost queries ErrLocalhost = errors.New("query for localhost") + // ErrTimeout is returned when a query times out + ErrTimeout = errors.New("query timed out") // detailed errors diff --git a/resolver/mdns.go b/resolver/resolver-mdns.go similarity index 100% rename from resolver/mdns.go rename to resolver/resolver-mdns.go diff --git a/resolver/clients.go b/resolver/resolver-pooled.go similarity index 87% rename from resolver/clients.go rename to resolver/resolver-pooled.go index e3456759..6a21e951 100644 --- a/resolver/clients.go +++ b/resolver/resolver-pooled.go @@ -8,12 +8,13 @@ import ( "time" "github.com/miekg/dns" + "github.com/safing/portbase/utils" ) const ( defaultClientTTL = 5 * time.Minute defaultRequestTimeout = 3 * time.Second // dns query - defaultConnectTimeout = 2 * time.Second // tcp/tls + defaultConnectTimeout = 5 * time.Second // tcp/tls connectionEOLGracePeriod = 7 * time.Second ) @@ -39,12 +40,12 @@ type dnsClientManager struct { lock sync.Mutex // set by creator - serverAddress string - ttl time.Duration // force refresh of connection to reduce traceability - factory func() *dns.Client + resolver *Resolver + ttl time.Duration // force refresh of connection to reduce traceability + factory func() *dns.Client // internal - pool sync.Pool + pool utils.StablePool } type dnsClient struct { @@ -57,7 +58,7 @@ type dnsClient struct { // getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { if dc.conn == nil { - dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) + dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress) if err != nil { return nil, false, err } @@ -78,8 +79,8 @@ func (dc *dnsClient) destroy() { func newDNSClientManager(resolver *Resolver) *dnsClientManager { return &dnsClientManager{ - serverAddress: resolver.ServerAddress, - ttl: 0, // new client for every request, as we need to randomize the port + resolver: resolver, + ttl: 0, // new client for every request, as we need to randomize the port factory: func() *dns.Client { return &dns.Client{ Timeout: defaultRequestTimeout, @@ -93,8 +94,8 @@ func newDNSClientManager(resolver *Resolver) *dnsClientManager { func newTCPClientManager(resolver *Resolver) *dnsClientManager { return &dnsClientManager{ - serverAddress: resolver.ServerAddress, - ttl: defaultClientTTL, + resolver: resolver, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp", @@ -111,8 +112,8 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager { func newTLSClientManager(resolver *Resolver) *dnsClientManager { return &dnsClientManager{ - serverAddress: resolver.ServerAddress, - ttl: defaultClientTTL, + resolver: resolver, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp-tls", diff --git a/resolver/resolver-pooled_test.go b/resolver/resolver-pooled_test.go new file mode 100644 index 00000000..5b460584 --- /dev/null +++ b/resolver/resolver-pooled_test.go @@ -0,0 +1,83 @@ +package resolver + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" +) + +func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { + dnsClient := brc.clientManager.getDNSClient() + + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // get connection + conn, new, err := dnsClient.getConn() + if err != nil { + t.Logf("failed to connect: %s", err) //nolint:staticcheck + wg.Done() + return + } + if new { + atomic.AddUint32(newCnt, 1) + } + + // query server + reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) + if err != nil { + t.Logf("client failed: %s", err) //nolint:staticcheck + wg.Done() + return + } + if reply == nil { + t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck + } + + t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) + dnsClient.addToPool() + wg.Done() +} + +func TestClientPooling(t *testing.T) { + // skip if short - this test depends on the Internet and might fail randomly + if testing.Short() { + t.Skip() + } + + // create separate resolver for this test + resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") + // resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config") + if err != nil { + t.Fatal(err) + } + brc := &BasicResolverConn{ + clientManager: clientManagerFactory(resolver.ServerType)(resolver), + resolver: resolver, + } + resolver.Conn = brc + + started := time.Now() + + wg := &sync.WaitGroup{} + var newCnt uint32 + for i := 0; i < 10; i++ { + wg.Add(10) + for j := 0; j < 10; j++ { + go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck + FQDN: <-domainFeed, + QType: dns.Type(dns.TypeA), + }) + } + wg.Wait() + if newCnt > uint32(10+i) { + t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) + } + } + + t.Logf("time taken: %s", time.Since(started)) +} diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go new file mode 100644 index 00000000..f4300f39 --- /dev/null +++ b/resolver/resolver-tcp.go @@ -0,0 +1,359 @@ +package resolver + +import ( + "context" + "crypto/tls" + "net" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/netenv" + "github.com/tevino/abool" +) + +const ( + tcpWriteTimeout = 1 * time.Second + ignoreQueriesAfter = 10 * time.Minute +) + +// TCPResolver is a resolver using just a single tcp connection with pipelining. +type TCPResolver struct { + BasicResolverConn + + clientTTL time.Duration + dnsClient *dns.Client + dnsConnection *dns.Conn + connInstanceID *uint32 + + queries chan *dns.Msg + inFlightQueries map[uint16]*InFlightQuery + clientStarted *abool.AtomicBool +} + +// InFlightQuery represents an in flight query of a TCPResolver. +type InFlightQuery struct { + Query *Query + Msg *dns.Msg + Response chan *dns.Msg + Resolver *Resolver + Started time.Time + ConnInstanceID uint32 +} + +// MakeCacheRecord creates an RCache record from a reply. +func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache { + return &RRCache{ + Domain: ifq.Query.FQDN, + Question: ifq.Query.QType, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Server: ifq.Resolver.Server, + ServerScope: ifq.Resolver.ServerIPScope, + } +} + +// NewTCPResolver returns a new TPCResolver. +func NewTCPResolver(resolver *Resolver) *TCPResolver { + var instanceID uint32 + return &TCPResolver{ + BasicResolverConn: BasicResolverConn{ + resolver: resolver, + }, + clientTTL: defaultClientTTL, + dnsClient: &dns.Client{ + Net: "tcp", + Timeout: defaultConnectTimeout, + WriteTimeout: tcpWriteTimeout, + }, + connInstanceID: &instanceID, + queries: make(chan *dns.Msg, 100), + inFlightQueries: make(map[uint16]*InFlightQuery), + clientStarted: abool.New(), + } +} + +// UseTLS enabled TLS for the TCPResolver. TLS settings must be correctly configured in the Resolver. +func (tr *TCPResolver) UseTLS() *TCPResolver { + tr.dnsClient.Net = "tcp-tls" + tr.dnsClient.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: tr.resolver.VerifyDomain, + // TODO: use portbase rng + } + return tr +} + +func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO + connTimer := time.NewTimer(tr.clientTTL) + connClosing := abool.New() + var connCtx context.Context + var cancelConnCtx func() + var recycleConn bool + var shuttingDown bool + var incoming = make(chan *dns.Msg, 100) + + // enable client restarting after crash + defer tr.clientStarted.UnSet() + +connMgmt: + for { + // cleanup old connection + if tr.dnsConnection != nil { + connClosing.Set() + _ = tr.dnsConnection.Close() + cancelConnCtx() + + tr.dnsConnection = nil + atomic.AddUint32(tr.connInstanceID, 1) + } + + // check if we are shutting down or failing + if shuttingDown || tr.IsFailing() { + // reply to all waiting queries + tr.Lock() + for id, inFlight := range tr.inFlightQueries { + close(inFlight.Response) + delete(tr.inFlightQueries, id) + } + tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds + tr.Unlock() + + // hint network environment at failed connection + netenv.ReportFailedConnection() + + cancelConnCtx() // The linter said so. Don't even... + return nil + } + + // wait until there is something to do + tr.Lock() + waiting := len(tr.inFlightQueries) + tr.Unlock() + if waiting > 0 { + // queue abandoned queries + ignoreBefore := time.Now().Add(-ignoreQueriesAfter) + currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID) + tr.Lock() + for id, inFlight := range tr.inFlightQueries { + if inFlight.Started.Before(ignoreBefore) { + // remove + delete(tr.inFlightQueries, id) + } else if inFlight.ConnInstanceID != currentConnInstanceID { + inFlight.ConnInstanceID = currentConnInstanceID + // re-inject + select { + case tr.queries <- inFlight.Msg: + default: + log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name) + } + } + } + tr.Unlock() + } else { + // wait for first query + select { + case <-workerCtx.Done(): + // abort + shuttingDown = true + continue connMgmt + case msg := <-tr.queries: + // re-insert, we will handle later + select { + case tr.queries <- msg: + default: + log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name) + } + } + } + + // create connection + connCtx, cancelConnCtx = context.WithCancel(workerCtx) + // refresh dialer for authenticated local address + tr.dnsClient.Dialer = &net.Dialer{ + LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, + KeepAlive: defaultClientTTL, + } + // connect + c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) + if err != nil { + tr.ReportFailure() + log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress) + continue connMgmt + } + tr.dnsConnection = c + log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr()) + + // hint network environment at successful connection + netenv.ReportSuccessfulConnection() + + // reset timer + connTimer.Stop() + select { + case <-connTimer.C: // try to empty the timer + default: + } + connTimer.Reset(tr.clientTTL) + recycleConn = false + + // start reader + module.StartWorker("dns client reader", func(ctx context.Context) error { + conn := tr.dnsConnection + for { + msg, err := conn.ReadMsg() + if err != nil { + if connClosing.SetToIf(false, true) { + cancelConnCtx() + tr.ReportFailure() + log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err) + } + return nil + } + incoming <- msg + } + }) + + // query management + for { + select { + case <-workerCtx.Done(): + // shutting down + shuttingDown = true + continue connMgmt + case <-connCtx.Done(): + // connection error + continue connMgmt + case <-connTimer.C: + // client TTL expired, recycle connection + recycleConn = true + // trigger check + select { + case incoming <- nil: + default: + // quere is full anyway, do nothing + } + + case msg := <-tr.queries: + // write query + _ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout)) + err := tr.dnsConnection.WriteMsg(msg) + if err != nil { + if connClosing.SetToIf(false, true) { + cancelConnCtx() + tr.ReportFailure() + log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err) + } + continue connMgmt + } + + case msg := <-incoming: + + if msg != nil { + // handle query from resolver + tr.Lock() + inFlight, ok := tr.inFlightQueries[msg.Id] + if ok { + delete(tr.inFlightQueries, msg.Id) + } + tr.Unlock() + + if ok { + select { + case inFlight.Response <- msg: + // responded! + default: + // save to cache, if enabled + if !inFlight.Query.NoCaching { + // persist to database + rrCache := inFlight.MakeCacheRecord(msg) + rrCache.Clean(600) + err = rrCache.Save() + if err != nil { + log.Warningf( + "resolver: failed to cache RR for %s%s: %s", + inFlight.Query.FQDN, + inFlight.Query.QType.String(), + err, + ) + } + } + } + } else { + log.Debugf( + "resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v", + tr.resolver.Name, + tr.dnsConnection.RemoteAddr(), + msg.Id, + msg.Question, + ) + } + } + + // check if we have finished all queries and want to recycle conn + if recycleConn { + tr.Lock() + activeQueries := len(tr.inFlightQueries) + tr.Unlock() + if activeQueries == 0 { + log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr()) + continue connMgmt + } + } + + } + } + + } +} + +func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery { + // create msg + msg := &dns.Msg{} + msg.SetQuestion(q.FQDN, uint16(q.QType)) + + // save to waitlist + inFlight := &InFlightQuery{ + Query: q, + Msg: msg, + Response: make(chan *dns.Msg), + Resolver: tr.resolver, + Started: time.Now().UTC(), + ConnInstanceID: atomic.LoadUint32(tr.connInstanceID), + } + tr.Lock() + tr.inFlightQueries[msg.Id] = inFlight + tr.Unlock() + + // submit msg for writing + tr.queries <- msg + + // make sure client is started + if tr.clientStarted.SetToIf(false, true) { + module.StartWorker("dns client", tr.client) + } + + return inFlight +} + +// Query executes the given query against the resolver. +func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + // submit to client + inFlight := tr.submitQuery(ctx, q) + var reply *dns.Msg + + select { + case reply = <-inFlight.Response: + case <-time.After(defaultRequestTimeout): + tr.ReportFailure() + return nil, ErrTimeout + } + + if tr.resolver.IsBlockedUpstream(reply) { + return nil, &BlockedUpstreamError{tr.resolver.GetName()} + } + + return inFlight.MakeCacheRecord(reply), nil +} diff --git a/resolver/resolver-tcp_test.go b/resolver/resolver-tcp_test.go new file mode 100644 index 00000000..082322ae --- /dev/null +++ b/resolver/resolver-tcp_test.go @@ -0,0 +1,72 @@ +package resolver + +import ( + "context" + "fmt" + "io" + "os" + "runtime/pprof" + "sync" + "testing" + "time" + + "github.com/miekg/dns" +) + +func testTCPQuery(t *testing.T, wg *sync.WaitGroup, rc ResolverConn, q *Query) { + start := time.Now() + _, err := rc.Query(context.TODO(), q) + if err != nil { + t.Logf("client failed: %s", err) //nolint:staticcheck + wg.Done() + return + } + t.Logf("resolved %s in %s", q.FQDN, time.Since(start)) + wg.Done() +} + +func TestTCPResolver(t *testing.T) { + // skip if short - this test depends on the Internet and might fail randomly + if testing.Short() { + t.Skip() + } + + go func() { + time.Sleep(15 * time.Second) + fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN =====") + printStackTo(os.Stderr) + os.Exit(1) + }() + + // create separate resolver for this test + resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") + // resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config") + if err != nil { + t.Fatal(err) + } + + started := time.Now() + + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go testTCPQuery(t, wg, resolver.Conn, &Query{ //nolint:staticcheck + FQDN: <-domainFeed, + QType: dns.Type(dns.TypeA), + }) + } + wg.Wait() + + t.Logf("time taken: %s", time.Since(started)) +} + +func printStackTo(writer io.Writer) { + fmt.Fprintln(writer, "=== PRINTING TRACES ===") + fmt.Fprintln(writer, "=== GOROUTINES ===") + _ = pprof.Lookup("goroutine").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== BLOCKING ===") + _ = pprof.Lookup("block").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== MUTEXES ===") + _ = pprof.Lookup("mutex").WriteTo(writer, 1) + fmt.Fprintln(writer, "=== END TRACES ===") +} diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 5d65fed3..5b657378 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -62,6 +62,20 @@ func formatIPAndPort(ip net.IP, port uint16) string { return address } +func resolverConnFactory(resolver *Resolver) ResolverConn { + switch resolver.ServerType { + case ServerTypeTCP: + return NewTCPResolver(resolver) + case ServerTypeDoT: + return NewTCPResolver(resolver).UseTLS() + default: + return &BasicResolverConn{ + clientManager: clientManagerFactory(resolver.ServerType)(resolver), + resolver: resolver, + } + } +} + func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager { switch serverType { case ServerTypeDNS: @@ -129,12 +143,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { UpstreamBlockDetection: blockType, } - newConn := &BasicResolverConn{ - clientManager: clientManagerFactory(u.Scheme)(new), - resolver: new, - } - - new.Conn = newConn + new.Conn = resolverConnFactory(new) return new, false, nil } diff --git a/resolver/resolver-scopes.go b/resolver/scopes.go similarity index 100% rename from resolver/resolver-scopes.go rename to resolver/scopes.go