Add TCP/TLS pipelining dns resolver

This commit is contained in:
Daniel 2020-06-16 15:21:05 +02:00
parent fe3b61f1a3
commit f7320d760d
12 changed files with 672 additions and 208 deletions

3
.gitignore vendored
View file

@ -12,6 +12,9 @@ dist
# vendor dir
vendor
# testing
testing
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a

View file

@ -54,7 +54,9 @@ func TestMain(m *testing.M, module *modules.Module) {
// are shutdown.
func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) {
// enable module for testing
if module != nil {
module.Enable()
}
// switch databases to memory only
base.DefaultDatabaseStorageType = "hashmap"

View file

@ -6,6 +6,128 @@ import (
"github.com/safing/portmaster/core/pmtesting"
)
var (
domainFeed = make(chan string)
)
func TestMain(m *testing.M) {
pmtesting.TestMain(m, module)
}
func init() {
go feedDomains()
}
func feedDomains() {
for {
for _, domain := range testDomains {
domainFeed <- domain
}
}
}
// Data
var (
testDomains = []string{
"facebook.com.",
"google.com.",
"youtube.com.",
"twitter.com.",
"instagram.com.",
"linkedin.com.",
"microsoft.com.",
"apple.com.",
"wikipedia.org.",
"plus.google.com.",
"en.wikipedia.org.",
"googletagmanager.com.",
"youtu.be.",
"adobe.com.",
"vimeo.com.",
"pinterest.com.",
"itunes.apple.com.",
"play.google.com.",
"maps.google.com.",
"goo.gl.",
"wordpress.com.",
"blogspot.com.",
"bit.ly.",
"github.com.",
"player.vimeo.com.",
"amazon.com.",
"wordpress.org.",
"docs.google.com.",
"yahoo.com.",
"mozilla.org.",
"tumblr.com.",
"godaddy.com.",
"flickr.com.",
"parked-content.godaddy.com.",
"drive.google.com.",
"support.google.com.",
"apache.org.",
"gravatar.com.",
"europa.eu.",
"qq.com.",
"w3.org.",
"nytimes.com.",
"reddit.com.",
"macromedia.com.",
"get.adobe.com.",
"soundcloud.com.",
"sourceforge.net.",
"sites.google.com.",
"nih.gov.",
"amazonaws.com.",
"t.co.",
"support.microsoft.com.",
"forbes.com.",
"theguardian.com.",
"cnn.com.",
"github.io.",
"bbc.co.uk.",
"dropbox.com.",
"whatsapp.com.",
"medium.com.",
"creativecommons.org.",
"www.ncbi.nlm.nih.gov.",
"httpd.apache.org.",
"archive.org.",
"ec.europa.eu.",
"php.net.",
"apps.apple.com.",
"weebly.com.",
"support.apple.com.",
"weibo.com.",
"wixsite.com.",
"issuu.com.",
"who.int.",
"paypal.com.",
"m.facebook.com.",
"oracle.com.",
"msn.com.",
"gnu.org.",
"tinyurl.com.",
"reuters.com.",
"l.facebook.com.",
"cloudflare.com.",
"wsj.com.",
"washingtonpost.com.",
"domainmarket.com.",
"imdb.com.",
"bbc.com.",
"bing.com.",
"accounts.google.com.",
"vk.com.",
"api.whatsapp.com.",
"opera.com.",
"cdc.gov.",
"slideshare.net.",
"wpa.qq.com.",
"harvard.edu.",
"mit.edu.",
"code.google.com.",
"wikimedia.org.",
}
)

View file

@ -1,189 +0,0 @@
package resolver
import (
"sync"
"sync/atomic"
"testing"
"github.com/miekg/dns"
)
var (
domainFeed = make(chan string)
)
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
}
if new {
atomic.AddUint32(newCnt, 1)
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Fatal(err) //nolint:staticcheck
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
dnsClient.addToPool()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
go feedDomains()
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
if err != nil {
t.Fatal(err)
}
brc := resolver.Conn.(*BasicResolverConn)
wg := &sync.WaitGroup{}
var newCnt uint32
for i := 0; i < 10; i++ {
wg.Add(10)
for i := 0; i < 10; i++ {
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if newCnt > uint32(10+i) {
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
}
}
}
func feedDomains() {
for {
for _, domain := range poolingTestDomains {
domainFeed <- domain
}
}
}
// Data
var (
poolingTestDomains = []string{
"facebook.com.",
"google.com.",
"youtube.com.",
"twitter.com.",
"instagram.com.",
"linkedin.com.",
"microsoft.com.",
"apple.com.",
"wikipedia.org.",
"plus.google.com.",
"en.wikipedia.org.",
"googletagmanager.com.",
"youtu.be.",
"adobe.com.",
"vimeo.com.",
"pinterest.com.",
"itunes.apple.com.",
"play.google.com.",
"maps.google.com.",
"goo.gl.",
"wordpress.com.",
"blogspot.com.",
"bit.ly.",
"github.com.",
"player.vimeo.com.",
"amazon.com.",
"wordpress.org.",
"docs.google.com.",
"yahoo.com.",
"mozilla.org.",
"tumblr.com.",
"godaddy.com.",
"flickr.com.",
"parked-content.godaddy.com.",
"drive.google.com.",
"support.google.com.",
"apache.org.",
"gravatar.com.",
"europa.eu.",
"qq.com.",
"w3.org.",
"nytimes.com.",
"reddit.com.",
"macromedia.com.",
"get.adobe.com.",
"soundcloud.com.",
"sourceforge.net.",
"sites.google.com.",
"nih.gov.",
"amazonaws.com.",
"t.co.",
"support.microsoft.com.",
"forbes.com.",
"theguardian.com.",
"cnn.com.",
"github.io.",
"bbc.co.uk.",
"dropbox.com.",
"whatsapp.com.",
"medium.com.",
"creativecommons.org.",
"www.ncbi.nlm.nih.gov.",
"httpd.apache.org.",
"archive.org.",
"ec.europa.eu.",
"php.net.",
"apps.apple.com.",
"weebly.com.",
"support.apple.com.",
"weibo.com.",
"wixsite.com.",
"issuu.com.",
"who.int.",
"paypal.com.",
"m.facebook.com.",
"oracle.com.",
"msn.com.",
"gnu.org.",
"tinyurl.com.",
"reuters.com.",
"l.facebook.com.",
"cloudflare.com.",
"wsj.com.",
"washingtonpost.com.",
"domainmarket.com.",
"imdb.com.",
"bbc.com.",
"bing.com.",
"accounts.google.com.",
"vk.com.",
"api.whatsapp.com.",
"opera.com.",
"cdc.gov.",
"slideshare.net.",
"wpa.qq.com.",
"harvard.edu.",
"mit.edu.",
"code.google.com.",
"wikimedia.org.",
}
)

View file

@ -22,6 +22,8 @@ var (
ErrBlocked = errors.New("query was blocked")
// ErrLocalhost is returned to *.localhost queries
ErrLocalhost = errors.New("query for localhost")
// ErrTimeout is returned when a query times out
ErrTimeout = errors.New("query timed out")
// detailed errors

View file

@ -8,12 +8,13 @@ import (
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/utils"
)
const (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 2 * time.Second // tcp/tls
defaultConnectTimeout = 5 * time.Second // tcp/tls
connectionEOLGracePeriod = 7 * time.Second
)
@ -39,12 +40,12 @@ type dnsClientManager struct {
lock sync.Mutex
// set by creator
serverAddress string
resolver *Resolver
ttl time.Duration // force refresh of connection to reduce traceability
factory func() *dns.Client
// internal
pool sync.Pool
pool utils.StablePool
}
type dnsClient struct {
@ -57,7 +58,7 @@ type dnsClient struct {
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress)
if err != nil {
return nil, false, err
}
@ -78,7 +79,7 @@ func (dc *dnsClient) destroy() {
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
resolver: resolver,
ttl: 0, // new client for every request, as we need to randomize the port
factory: func() *dns.Client {
return &dns.Client{
@ -93,7 +94,7 @@ func newDNSClientManager(resolver *Resolver) *dnsClientManager {
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
resolver: resolver,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
@ -111,7 +112,7 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager {
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
resolver: resolver,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{

View file

@ -0,0 +1,83 @@
package resolver
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
)
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Logf("failed to connect: %s", err) //nolint:staticcheck
wg.Done()
return
}
if new {
atomic.AddUint32(newCnt, 1)
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Logf("client failed: %s", err) //nolint:staticcheck
wg.Done()
return
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
dnsClient.addToPool()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
if err != nil {
t.Fatal(err)
}
brc := &BasicResolverConn{
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
resolver: resolver,
}
resolver.Conn = brc
started := time.Now()
wg := &sync.WaitGroup{}
var newCnt uint32
for i := 0; i < 10; i++ {
wg.Add(10)
for j := 0; j < 10; j++ {
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if newCnt > uint32(10+i) {
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
}
}
t.Logf("time taken: %s", time.Since(started))
}

359
resolver/resolver-tcp.go Normal file
View file

@ -0,0 +1,359 @@
package resolver
import (
"context"
"crypto/tls"
"net"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netenv"
"github.com/tevino/abool"
)
const (
tcpWriteTimeout = 1 * time.Second
ignoreQueriesAfter = 10 * time.Minute
)
// TCPResolver is a resolver using just a single tcp connection with pipelining.
type TCPResolver struct {
BasicResolverConn
clientTTL time.Duration
dnsClient *dns.Client
dnsConnection *dns.Conn
connInstanceID *uint32
queries chan *dns.Msg
inFlightQueries map[uint16]*InFlightQuery
clientStarted *abool.AtomicBool
}
// InFlightQuery represents an in flight query of a TCPResolver.
type InFlightQuery struct {
Query *Query
Msg *dns.Msg
Response chan *dns.Msg
Resolver *Resolver
Started time.Time
ConnInstanceID uint32
}
// MakeCacheRecord creates an RCache record from a reply.
func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache {
return &RRCache{
Domain: ifq.Query.FQDN,
Question: ifq.Query.QType,
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
Server: ifq.Resolver.Server,
ServerScope: ifq.Resolver.ServerIPScope,
}
}
// NewTCPResolver returns a new TPCResolver.
func NewTCPResolver(resolver *Resolver) *TCPResolver {
var instanceID uint32
return &TCPResolver{
BasicResolverConn: BasicResolverConn{
resolver: resolver,
},
clientTTL: defaultClientTTL,
dnsClient: &dns.Client{
Net: "tcp",
Timeout: defaultConnectTimeout,
WriteTimeout: tcpWriteTimeout,
},
connInstanceID: &instanceID,
queries: make(chan *dns.Msg, 100),
inFlightQueries: make(map[uint16]*InFlightQuery),
clientStarted: abool.New(),
}
}
// UseTLS enabled TLS for the TCPResolver. TLS settings must be correctly configured in the Resolver.
func (tr *TCPResolver) UseTLS() *TCPResolver {
tr.dnsClient.Net = "tcp-tls"
tr.dnsClient.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: tr.resolver.VerifyDomain,
// TODO: use portbase rng
}
return tr
}
func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO
connTimer := time.NewTimer(tr.clientTTL)
connClosing := abool.New()
var connCtx context.Context
var cancelConnCtx func()
var recycleConn bool
var shuttingDown bool
var incoming = make(chan *dns.Msg, 100)
// enable client restarting after crash
defer tr.clientStarted.UnSet()
connMgmt:
for {
// cleanup old connection
if tr.dnsConnection != nil {
connClosing.Set()
_ = tr.dnsConnection.Close()
cancelConnCtx()
tr.dnsConnection = nil
atomic.AddUint32(tr.connInstanceID, 1)
}
// check if we are shutting down or failing
if shuttingDown || tr.IsFailing() {
// reply to all waiting queries
tr.Lock()
for id, inFlight := range tr.inFlightQueries {
close(inFlight.Response)
delete(tr.inFlightQueries, id)
}
tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
tr.Unlock()
// hint network environment at failed connection
netenv.ReportFailedConnection()
cancelConnCtx() // The linter said so. Don't even...
return nil
}
// wait until there is something to do
tr.Lock()
waiting := len(tr.inFlightQueries)
tr.Unlock()
if waiting > 0 {
// queue abandoned queries
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID)
tr.Lock()
for id, inFlight := range tr.inFlightQueries {
if inFlight.Started.Before(ignoreBefore) {
// remove
delete(tr.inFlightQueries, id)
} else if inFlight.ConnInstanceID != currentConnInstanceID {
inFlight.ConnInstanceID = currentConnInstanceID
// re-inject
select {
case tr.queries <- inFlight.Msg:
default:
log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name)
}
}
}
tr.Unlock()
} else {
// wait for first query
select {
case <-workerCtx.Done():
// abort
shuttingDown = true
continue connMgmt
case msg := <-tr.queries:
// re-insert, we will handle later
select {
case tr.queries <- msg:
default:
log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name)
}
}
}
// create connection
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
// refresh dialer for authenticated local address
tr.dnsClient.Dialer = &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
}
// connect
c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
if err != nil {
tr.ReportFailure()
log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress)
continue connMgmt
}
tr.dnsConnection = c
log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
// hint network environment at successful connection
netenv.ReportSuccessfulConnection()
// reset timer
connTimer.Stop()
select {
case <-connTimer.C: // try to empty the timer
default:
}
connTimer.Reset(tr.clientTTL)
recycleConn = false
// start reader
module.StartWorker("dns client reader", func(ctx context.Context) error {
conn := tr.dnsConnection
for {
msg, err := conn.ReadMsg()
if err != nil {
if connClosing.SetToIf(false, true) {
cancelConnCtx()
tr.ReportFailure()
log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
}
return nil
}
incoming <- msg
}
})
// query management
for {
select {
case <-workerCtx.Done():
// shutting down
shuttingDown = true
continue connMgmt
case <-connCtx.Done():
// connection error
continue connMgmt
case <-connTimer.C:
// client TTL expired, recycle connection
recycleConn = true
// trigger check
select {
case incoming <- nil:
default:
// quere is full anyway, do nothing
}
case msg := <-tr.queries:
// write query
_ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout))
err := tr.dnsConnection.WriteMsg(msg)
if err != nil {
if connClosing.SetToIf(false, true) {
cancelConnCtx()
tr.ReportFailure()
log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
}
continue connMgmt
}
case msg := <-incoming:
if msg != nil {
// handle query from resolver
tr.Lock()
inFlight, ok := tr.inFlightQueries[msg.Id]
if ok {
delete(tr.inFlightQueries, msg.Id)
}
tr.Unlock()
if ok {
select {
case inFlight.Response <- msg:
// responded!
default:
// save to cache, if enabled
if !inFlight.Query.NoCaching {
// persist to database
rrCache := inFlight.MakeCacheRecord(msg)
rrCache.Clean(600)
err = rrCache.Save()
if err != nil {
log.Warningf(
"resolver: failed to cache RR for %s%s: %s",
inFlight.Query.FQDN,
inFlight.Query.QType.String(),
err,
)
}
}
}
} else {
log.Debugf(
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
tr.resolver.Name,
tr.dnsConnection.RemoteAddr(),
msg.Id,
msg.Question,
)
}
}
// check if we have finished all queries and want to recycle conn
if recycleConn {
tr.Lock()
activeQueries := len(tr.inFlightQueries)
tr.Unlock()
if activeQueries == 0 {
log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
continue connMgmt
}
}
}
}
}
}
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
// create msg
msg := &dns.Msg{}
msg.SetQuestion(q.FQDN, uint16(q.QType))
// save to waitlist
inFlight := &InFlightQuery{
Query: q,
Msg: msg,
Response: make(chan *dns.Msg),
Resolver: tr.resolver,
Started: time.Now().UTC(),
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
}
tr.Lock()
tr.inFlightQueries[msg.Id] = inFlight
tr.Unlock()
// submit msg for writing
tr.queries <- msg
// make sure client is started
if tr.clientStarted.SetToIf(false, true) {
module.StartWorker("dns client", tr.client)
}
return inFlight
}
// Query executes the given query against the resolver.
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// submit to client
inFlight := tr.submitQuery(ctx, q)
var reply *dns.Msg
select {
case reply = <-inFlight.Response:
case <-time.After(defaultRequestTimeout):
tr.ReportFailure()
return nil, ErrTimeout
}
if tr.resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{tr.resolver.GetName()}
}
return inFlight.MakeCacheRecord(reply), nil
}

View file

@ -0,0 +1,72 @@
package resolver
import (
"context"
"fmt"
"io"
"os"
"runtime/pprof"
"sync"
"testing"
"time"
"github.com/miekg/dns"
)
func testTCPQuery(t *testing.T, wg *sync.WaitGroup, rc ResolverConn, q *Query) {
start := time.Now()
_, err := rc.Query(context.TODO(), q)
if err != nil {
t.Logf("client failed: %s", err) //nolint:staticcheck
wg.Done()
return
}
t.Logf("resolved %s in %s", q.FQDN, time.Since(start))
wg.Done()
}
func TestTCPResolver(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
go func() {
time.Sleep(15 * time.Second)
fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN =====")
printStackTo(os.Stderr)
os.Exit(1)
}()
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
if err != nil {
t.Fatal(err)
}
started := time.Now()
wg := &sync.WaitGroup{}
wg.Add(100)
for i := 0; i < 100; i++ {
go testTCPQuery(t, wg, resolver.Conn, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
t.Logf("time taken: %s", time.Since(started))
}
func printStackTo(writer io.Writer) {
fmt.Fprintln(writer, "=== PRINTING TRACES ===")
fmt.Fprintln(writer, "=== GOROUTINES ===")
_ = pprof.Lookup("goroutine").WriteTo(writer, 1)
fmt.Fprintln(writer, "=== BLOCKING ===")
_ = pprof.Lookup("block").WriteTo(writer, 1)
fmt.Fprintln(writer, "=== MUTEXES ===")
_ = pprof.Lookup("mutex").WriteTo(writer, 1)
fmt.Fprintln(writer, "=== END TRACES ===")
}

View file

@ -62,6 +62,20 @@ func formatIPAndPort(ip net.IP, port uint16) string {
return address
}
func resolverConnFactory(resolver *Resolver) ResolverConn {
switch resolver.ServerType {
case ServerTypeTCP:
return NewTCPResolver(resolver)
case ServerTypeDoT:
return NewTCPResolver(resolver).UseTLS()
default:
return &BasicResolverConn{
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
resolver: resolver,
}
}
}
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType {
case ServerTypeDNS:
@ -129,12 +143,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
UpstreamBlockDetection: blockType,
}
newConn := &BasicResolverConn{
clientManager: clientManagerFactory(u.Scheme)(new),
resolver: new,
}
new.Conn = newConn
new.Conn = resolverConnFactory(new)
return new, false, nil
}