Add TLS resolver connection reuse and pooling

Also, fix caching issues and add more tests
Daniel 2020-05-15 22:43:06 +02:00
parent dd837e40e2
commit 53eb309e72
11 changed files with 510 additions and 61 deletions


@@ -1,6 +1,7 @@
package resolver
import (
"context"
"crypto/tls"
"net"
"sync"
@@ -9,6 +10,12 @@ import (
"github.com/miekg/dns"
)
const (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 5 * time.Second
connectionEOLGracePeriod = 10 * time.Second
)
var (
localAddrFactory func(network string) net.Addr
)
@@ -27,21 +34,72 @@ func getLocalAddr(network string) net.Addr {
return nil
}
type clientManager struct {
dnsClient *dns.Client
factory func() *dns.Client
type dnsClientManager struct {
lock sync.Mutex
refreshAfter time.Time
ttl time.Duration // force refresh of connection to reduce traceability
// set by creator
serverAddress string
ttl time.Duration // force refresh of connection to reduce traceability
factory func() *dns.Client
// internal
pool []*dnsClient
}
func newDNSClientManager(_ *Resolver) *clientManager {
return &clientManager{
ttl: 0, // new client for every request, as we need to randomize the port
type dnsClient struct {
mgr *dnsClientManager
inUse bool
useUntil time.Time
dead bool
inPool bool
poolIndex int
client *dns.Client
conn *dns.Conn
}
// getConn returns the *dns.Conn and whether it is new. This function may only be called between dnsClientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
if err != nil {
return nil, false, err
}
return dc.conn, true, nil
}
return dc.conn, false, nil
}
func (dc *dnsClient) done() {
dc.mgr.lock.Lock()
defer dc.mgr.lock.Unlock()
dc.inUse = false
}
func (dc *dnsClient) destroy() {
dc.mgr.lock.Lock()
dc.inUse = true // block from being used
dc.dead = true // abort cleaning
if dc.inPool {
dc.inPool = false
dc.mgr.pool[dc.poolIndex] = nil
}
dc.mgr.lock.Unlock()
if dc.conn != nil {
_ = dc.conn.Close()
}
}
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: 0, // new client for every request, as we need to randomize the port
factory: func() *dns.Client {
return &dns.Client{
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("udp"),
},
@@ -50,25 +108,27 @@ func newDNSClientManager(_ *Resolver) *clientManager {
}
}
func newTCPClientManager(_ *Resolver) *clientManager {
return &clientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp",
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func newTLSClientManager(resolver *Resolver) *clientManager {
return &clientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp-tls",
@@ -77,24 +137,90 @@ func newTLSClientManager(resolver *Resolver) *clientManager {
ServerName: resolver.VerifyDomain,
// TODO: use portbase rng
},
Timeout: 5 * time.Second,
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func (cm *clientManager) getDNSClient() *dns.Client {
func (cm *dnsClientManager) getDNSClient() *dnsClient {
cm.lock.Lock()
defer cm.lock.Unlock()
if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) {
cm.dnsClient = cm.factory()
cm.refreshAfter = time.Now().Add(cm.ttl)
// immediately return a new client if a new one should be used for every request
if cm.ttl == 0 {
return &dnsClient{
mgr: cm,
client: cm.factory(),
}
}
return cm.dnsClient
// get first unused from pool
now := time.Now().UTC()
for _, dc := range cm.pool {
if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
dc.inUse = true
return dc
}
}
// none available in pool, create a new one
newClient := &dnsClient{
mgr: cm,
inUse: true,
useUntil: now.Add(cm.ttl),
inPool: true,
client: cm.factory(),
}
newClient.startCleaner()
// find free spot in pool
for poolIndex, dc := range cm.pool {
if dc == nil {
cm.pool[poolIndex] = newClient
newClient.poolIndex = poolIndex
return newClient
}
}
// append to pool
cm.pool = append(cm.pool, newClient)
newClient.poolIndex = len(cm.pool) - 1
// TODO: shrink pool again?
return newClient
}
// startCleaner waits for EOL of the client and then removes it from the pool.
func (dc *dnsClient) startCleaner() {
// While a single worker cleaning all connections might be slightly more performant, this approach keeps locking to a minimum and is simpler, and thus less error-prone.
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
select {
case <-time.After(dc.mgr.ttl + time.Second):
dc.mgr.lock.Lock()
cleanNow := dc.dead || !dc.inUse
dc.mgr.lock.Unlock()
if cleanNow {
dc.destroy()
return nil
}
case <-ctx.Done():
// give the request a short time to complete gracefully before killing the connection
time.Sleep(100 * time.Millisecond)
}
// wait for grace period to end, then kill
select {
case <-time.After(connectionEOLGracePeriod):
case <-ctx.Done():
}
dc.destroy()
return nil
})
}
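
Taken together, the new client lifecycle can be summarized as follows. The sketch below is not part of the commit; it assumes the resolver package context from the file above, mirrors the pattern used by BasicResolverConn.Query and pooling_test.go further down, and leaves out retries and logging. The helper name exchangeOnce is illustrative only.

// A condensed sketch of how the new pooled client is meant to be driven.
func exchangeOnce(cm *dnsClientManager, query *dns.Msg) (*dns.Msg, error) {
	// reserve a client: a one-shot client if ttl == 0, otherwise one from the pool
	dc := cm.getDNSClient()

	// dial lazily; an existing connection is reused when available
	conn, _, err := dc.getConn()
	if err != nil {
		dc.destroy() // broken clients are closed and dropped from the pool
		return nil, err
	}

	// exchange the query over the (possibly reused) connection
	reply, _, err := dc.client.ExchangeWithConn(query, conn)
	if err != nil {
		dc.destroy()
		return nil, err
	}

	// hand the client back so it can be reused until its TTL expires
	dc.done()
	return reply, nil
}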


@@ -9,6 +9,8 @@ import (
"sync"
"time"
"github.com/safing/portmaster/network/netutils"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
@@ -29,10 +31,11 @@ var (
questionsLock sync.Mutex
mDNSResolver = &Resolver{
Server: ServerSourceMDNS,
ServerType: ServerTypeDNS,
Source: ServerSourceMDNS,
Conn: &mDNSResolverConn{},
Server: ServerSourceMDNS,
ServerType: ServerTypeDNS,
ServerIPScope: netutils.SiteLocal,
Source: ServerSourceMDNS,
Conn: &mDNSResolverConn{},
}
)
@@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
// get entry from database
if saveFullRequest {
// get from database
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
// if we have no cached entry, or it was last updated more than two seconds ago, or if it has expired:
// create new and do not append
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
rrCache = &RRCache{
Domain: question.Name,
Question: dns.Type(question.Qtype),
Domain: question.Name,
Question: dns.Type(question.Qtype),
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
}
}
}
// add all entries to RRCache
for _, entry := range message.Answer {
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) {
if saveFullRequest {
@@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
continue
}
rrCache = &RRCache{
Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
Domain: v.Header().Name,
Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v},
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
}
rrCache.Clean(60)
err := rrCache.Save()


@@ -12,7 +12,7 @@ import (
var (
recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days
CacheSize: 128,
CacheSize: 256,
})
)


@@ -0,0 +1,27 @@
package resolver
import "testing"
func TestNameRecordStorage(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
r, err := GetNameRecord(testDomain, testQuestion)
if err != nil {
t.Fatal(err)
}
if r.Domain != testDomain || r.Question != testQuestion {
t.Fatal("mismatch")
}
}

resolver/pooling_test.go (new file, 184 lines)

@@ -0,0 +1,184 @@
package resolver
import (
"sync"
"testing"
"github.com/miekg/dns"
)
var (
domainFeed = make(chan string)
)
func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Fatal(err) //nolint:staticcheck
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl)
dnsClient.done()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
go feedDomains()
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
if err != nil {
t.Fatal(err)
}
brc := resolver.Conn.(*BasicResolverConn)
wg := &sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(10)
for i := 0; i < 10; i++ {
go testQuery(t, wg, brc, &Query{
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if len(brc.clientManager.pool) != 10 {
t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool))
}
}
}
func feedDomains() {
for {
for _, domain := range poolingTestDomains {
domainFeed <- domain
}
}
}
// Data
var (
poolingTestDomains = []string{
"facebook.com.",
"google.com.",
"youtube.com.",
"twitter.com.",
"instagram.com.",
"linkedin.com.",
"microsoft.com.",
"apple.com.",
"wikipedia.org.",
"plus.google.com.",
"en.wikipedia.org.",
"googletagmanager.com.",
"youtu.be.",
"adobe.com.",
"vimeo.com.",
"pinterest.com.",
"itunes.apple.com.",
"play.google.com.",
"maps.google.com.",
"goo.gl.",
"wordpress.com.",
"blogspot.com.",
"bit.ly.",
"github.com.",
"player.vimeo.com.",
"amazon.com.",
"wordpress.org.",
"docs.google.com.",
"yahoo.com.",
"mozilla.org.",
"tumblr.com.",
"godaddy.com.",
"flickr.com.",
"parked-content.godaddy.com.",
"drive.google.com.",
"support.google.com.",
"apache.org.",
"gravatar.com.",
"europa.eu.",
"qq.com.",
"w3.org.",
"nytimes.com.",
"reddit.com.",
"macromedia.com.",
"get.adobe.com.",
"soundcloud.com.",
"sourceforge.net.",
"sites.google.com.",
"nih.gov.",
"amazonaws.com.",
"t.co.",
"support.microsoft.com.",
"forbes.com.",
"theguardian.com.",
"cnn.com.",
"github.io.",
"bbc.co.uk.",
"dropbox.com.",
"whatsapp.com.",
"medium.com.",
"creativecommons.org.",
"www.ncbi.nlm.nih.gov.",
"httpd.apache.org.",
"archive.org.",
"ec.europa.eu.",
"php.net.",
"apps.apple.com.",
"weebly.com.",
"support.apple.com.",
"weibo.com.",
"wixsite.com.",
"issuu.com.",
"who.int.",
"paypal.com.",
"m.facebook.com.",
"oracle.com.",
"msn.com.",
"gnu.org.",
"tinyurl.com.",
"reuters.com.",
"l.facebook.com.",
"cloudflare.com.",
"wsj.com.",
"washingtonpost.com.",
"domainmarket.com.",
"imdb.com.",
"bbc.com.",
"bing.com.",
"accounts.google.com.",
"vk.com.",
"api.whatsapp.com.",
"opera.com.",
"cdc.gov.",
"slideshare.net.",
"wpa.qq.com.",
"harvard.edu.",
"mit.edu.",
"code.google.com.",
"wikimedia.org.",
}
)


@@ -114,6 +114,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) {
rrCache.MixAnswers()
return rrCache, nil
}
log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType)
// if cache is still empty or non-compliant, go ahead and just query
} else {
// we are the first!
@@ -132,14 +133,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
if err != nil {
if err != database.ErrNotFound {
log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
}
return nil
}
// get resolver that rrCache was resolved with
resolver := getResolverByIDWithLocking(rrCache.Server)
resolver := getActiveResolverByIDWithLocking(rrCache.Server)
if resolver == nil {
log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server)
return nil
}
@@ -165,6 +166,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
})
}
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
return rrCache
}


@@ -2,6 +2,7 @@ package resolver
import (
"context"
"errors"
"net"
"sync"
"time"
@@ -92,7 +93,7 @@ type BasicResolverConn struct {
sync.Mutex // for lastFail
resolver *Resolver
clientManager *clientManager
clientManager *dnsClientManager
lastFail time.Time
}
@@ -126,18 +127,41 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
// start
var reply *dns.Msg
var ttl time.Duration
var err error
for i := 0; i < 3; i++ {
var conn *dns.Conn
var new bool
var i int
// log query time
// qStart := time.Now()
reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress)
// log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
for ; i < 5; i++ {
// first get connection
dc := brc.clientManager.getDNSClient()
conn, new, err = dc.getConn()
if err != nil {
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// try again
continue
}
if new {
log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress)
} else {
log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress)
}
// query server
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
// error handling
if err != nil {
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// TODO: handle special cases
// 1. connect: network is unreachable
// 2. timeout
@@ -148,13 +172,23 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
// temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
// try again
continue
}
// permanent error
break
} else if reply == nil {
// remove client from pool
dc.destroy()
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
return nil, errors.New("internal error")
}
// make client available again
dc.done()
if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()}
}
@@ -166,12 +200,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
if err != nil {
return nil, err
// TODO: mark as failed
} else if reply == nil {
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1)
return nil, errors.New("internal error")
}
// report the successful connection to the network environment
netenv.ReportSuccessfulConnection()
new := &RRCache{
newRecord := &RRCache{
Domain: q.FQDN,
Question: q.QType,
Answer: reply.Answer,
@@ -182,5 +219,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
}
// TODO: check if reply.Answer is valid
return new, nil
return newRecord, nil
}


@@ -25,7 +25,7 @@ var (
globalResolvers []*Resolver // all (global) resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
allResolvers map[string]*Resolver // lookup map of all resolvers
activeResolvers map[string]*Resolver // lookup map of all active resolvers
resolversLock sync.RWMutex
dupReqMap = make(map[string]*sync.WaitGroup)
@@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int {
return -1
}
func getResolverByIDWithLocking(server string) *Resolver {
resolversLock.Lock()
defer resolversLock.Unlock()
func getActiveResolverByIDWithLocking(server string) *Resolver {
resolversLock.RLock()
defer resolversLock.RUnlock()
resolver, ok := allResolvers[server]
resolver, ok := activeResolvers[server]
if ok {
return resolver
}
@@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
return address
}
func clientManagerFactory(serverType string) func(*Resolver) *clientManager {
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType {
case ServerTypeDNS:
return newDNSClientManager
@@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) {
}
}
func getConfiguredResolvers() (resolvers []*Resolver) {
for _, server := range configuredNameServers() {
func getConfiguredResolvers(list []string) (resolvers []*Resolver) {
for _, server := range list {
resolver, skip, err := createResolver(server, "config")
if err != nil {
// TODO(ppacher): module error
@@ -199,19 +199,40 @@ func loadResolvers() {
defer resolversLock.Unlock()
newResolvers := append(
getConfiguredResolvers(),
getConfiguredResolvers(configuredNameServers()),
getSystemResolvers()...,
)
// save resolvers
globalResolvers = newResolvers
if len(globalResolvers) == 0 {
log.Criticalf("resolver: no (valid) dns servers found in configuration and system")
// TODO(module error)
if len(newResolvers) == 0 {
msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults"
log.Warningf("resolver: %s", msg)
module.Warning("no-valid-user-resolvers", msg)
// load defaults directly, overriding config system
newResolvers = getConfiguredResolvers(defaultNameServers)
if len(newResolvers) == 0 {
msg = "no (valid) dns servers found in configuration or system"
log.Criticalf("resolver: %s", msg)
module.Error("no-valid-default-resolvers", msg)
return
}
}
// save resolvers
globalResolvers = newResolvers
// assign resolvers to scopes
setLocalAndScopeResolvers(globalResolvers)
// set active resolvers (for cache validation)
// reset
activeResolvers = make(map[string]*Resolver)
// add
for _, resolver := range newResolvers {
activeResolvers[resolver.Server] = resolver
}
activeResolvers[mDNSResolver.Server] = mDNSResolver
// log global resolvers
if len(globalResolvers) > 0 {
log.Trace("resolver: loaded global resolvers:")


@@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (
for _, rr := range rrCache.Answer {
switch v := rr.(type) {
case *dns.A:
log.Infof("A: %s", v.A.String())
// log.Debugf("A: %s", v.A.String())
if ip == v.A.String() {
return ptrName, nil
}
case *dns.AAAA:
log.Infof("AAAA: %s", v.AAAA.String())
// log.Debugf("AAAA: %s", v.AAAA.String())
if ip == v.AAAA.String() {
return ptrName, nil
}

resolver/rrcache_test.go (new file, 41 lines)

@@ -0,0 +1,41 @@
package resolver
import (
"testing"
"github.com/miekg/dns"
)
func TestCaching(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
err = rrCache.Save()
if err != nil {
t.Fatal(err)
}
rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
if rrCache2.Domain != rrCache.Domain {
t.Fatal("something very is wrong")
}
}
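
Assuming the usual repository layout (the test files above live under resolver/), the new tests can be run with standard Go tooling; note that TestClientPooling queries a live resolver and is therefore skipped in short mode:

go test ./resolver -run 'TestNameRecordStorage|TestClientPooling|TestCaching' -v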