Add TLS resolver connection reusing and pooling

Also, fix caching issues and add more tests
This commit is contained in:
Daniel 2020-05-15 22:43:06 +02:00
parent dd837e40e2
commit 53eb309e72
11 changed files with 510 additions and 61 deletions

View file

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"sync" "sync"
@ -9,6 +10,12 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
const (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 5 * time.Second
connectionEOLGracePeriod = 10 * time.Second
)
var ( var (
localAddrFactory func(network string) net.Addr localAddrFactory func(network string) net.Addr
) )
@ -27,21 +34,72 @@ func getLocalAddr(network string) net.Addr {
return nil return nil
} }
type clientManager struct { type dnsClientManager struct {
dnsClient *dns.Client lock sync.Mutex
factory func() *dns.Client
lock sync.Mutex // set by creator
refreshAfter time.Time serverAddress string
ttl time.Duration // force refresh of connection to reduce traceability ttl time.Duration // force refresh of connection to reduce traceability
factory func() *dns.Client
// internal
pool []*dnsClient
} }
func newDNSClientManager(_ *Resolver) *clientManager { type dnsClient struct {
return &clientManager{ mgr *dnsClientManager
ttl: 0, // new client for every request, as we need to randomize the port
inUse bool
useUntil time.Time
dead bool
inPool bool
poolIndex int
client *dns.Client
conn *dns.Conn
}
// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
if err != nil {
return nil, false, err
}
return dc.conn, true, nil
}
return dc.conn, false, nil
}
func (dc *dnsClient) done() {
dc.mgr.lock.Lock()
defer dc.mgr.lock.Unlock()
dc.inUse = false
}
func (dc *dnsClient) destroy() {
dc.mgr.lock.Lock()
dc.inUse = true // block from being used
dc.dead = true // abort cleaning
if dc.inPool {
dc.inPool = false
dc.mgr.pool[dc.poolIndex] = nil
}
dc.mgr.lock.Unlock()
if dc.conn != nil {
_ = dc.conn.Close()
}
}
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
serverAddress: resolver.ServerAddress,
ttl: 0, // new client for every request, as we need to randomize the port
factory: func() *dns.Client { factory: func() *dns.Client {
return &dns.Client{ return &dns.Client{
Timeout: 5 * time.Second, Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{ Dialer: &net.Dialer{
LocalAddr: getLocalAddr("udp"), LocalAddr: getLocalAddr("udp"),
}, },
@ -50,25 +108,27 @@ func newDNSClientManager(_ *Resolver) *clientManager {
} }
} }
func newTCPClientManager(_ *Resolver) *clientManager { func newTCPClientManager(resolver *Resolver) *dnsClientManager {
return &clientManager{ return &dnsClientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client { factory: func() *dns.Client {
return &dns.Client{ return &dns.Client{
Net: "tcp", Net: "tcp",
Timeout: 5 * time.Second, Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{ Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"), LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second, KeepAlive: defaultClientTTL,
}, },
} }
}, },
} }
} }
func newTLSClientManager(resolver *Resolver) *clientManager { func newTLSClientManager(resolver *Resolver) *dnsClientManager {
return &clientManager{ return &dnsClientManager{
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) serverAddress: resolver.ServerAddress,
ttl: defaultClientTTL,
factory: func() *dns.Client { factory: func() *dns.Client {
return &dns.Client{ return &dns.Client{
Net: "tcp-tls", Net: "tcp-tls",
@ -77,24 +137,90 @@ func newTLSClientManager(resolver *Resolver) *clientManager {
ServerName: resolver.VerifyDomain, ServerName: resolver.VerifyDomain,
// TODO: use portbase rng // TODO: use portbase rng
}, },
Timeout: 5 * time.Second, Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{ Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"), LocalAddr: getLocalAddr("tcp"),
KeepAlive: 15 * time.Second, KeepAlive: defaultClientTTL,
}, },
} }
}, },
} }
} }
func (cm *clientManager) getDNSClient() *dns.Client { func (cm *dnsClientManager) getDNSClient() *dnsClient {
cm.lock.Lock() cm.lock.Lock()
defer cm.lock.Unlock() defer cm.lock.Unlock()
if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) { // return new immediately if a new client should be used for every request
cm.dnsClient = cm.factory() if cm.ttl == 0 {
cm.refreshAfter = time.Now().Add(cm.ttl) return &dnsClient{
mgr: cm,
client: cm.factory(),
}
} }
return cm.dnsClient // get first unused from pool
now := time.Now().UTC()
for _, dc := range cm.pool {
if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
dc.inUse = true
return dc
}
}
// no available in pool, create new
newClient := &dnsClient{
mgr: cm,
inUse: true,
useUntil: now.Add(cm.ttl),
inPool: true,
client: cm.factory(),
}
newClient.startCleaner()
// find free spot in pool
for poolIndex, dc := range cm.pool {
if dc == nil {
cm.pool[poolIndex] = newClient
newClient.poolIndex = poolIndex
return newClient
}
}
// append to pool
cm.pool = append(cm.pool, newClient)
newClient.poolIndex = len(cm.pool) - 1
// TODO: shrink pool again?
return newClient
}
// startCleaner waits for EOL of the client and then removes it from the pool.
func (dc *dnsClient) startCleaner() {
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
select {
case <-time.After(dc.mgr.ttl + time.Second):
dc.mgr.lock.Lock()
cleanNow := dc.dead || !dc.inUse
dc.mgr.lock.Unlock()
if cleanNow {
dc.destroy()
return nil
}
case <-ctx.Done():
// give a short time before kill for graceful request completion
time.Sleep(100 * time.Millisecond)
}
// wait for grace period to end, then kill
select {
case <-time.After(connectionEOLGracePeriod):
case <-ctx.Done():
}
dc.destroy()
return nil
})
} }

View file

@ -9,6 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/safing/portmaster/network/netutils"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -29,10 +31,11 @@ var (
questionsLock sync.Mutex questionsLock sync.Mutex
mDNSResolver = &Resolver{ mDNSResolver = &Resolver{
Server: ServerSourceMDNS, Server: ServerSourceMDNS,
ServerType: ServerTypeDNS, ServerType: ServerTypeDNS,
Source: ServerSourceMDNS, ServerIPScope: netutils.SiteLocal,
Conn: &mDNSResolverConn{}, Source: ServerSourceMDNS,
Conn: &mDNSResolverConn{},
} }
) )
@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
// get entry from database // get entry from database
if saveFullRequest { if saveFullRequest {
// get from database
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
// if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired:
// create new and do not append
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() { if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
rrCache = &RRCache{ rrCache = &RRCache{
Domain: question.Name, Domain: question.Name,
Question: dns.Type(question.Qtype), Question: dns.Type(question.Qtype),
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
} }
} }
} }
// add all entries to RRCache
for _, entry := range message.Answer { for _, entry := range message.Answer {
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) {
if saveFullRequest { if saveFullRequest {
@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
continue continue
} }
rrCache = &RRCache{ rrCache = &RRCache{
Domain: v.Header().Name, Domain: v.Header().Name,
Question: dns.Type(v.Header().Class), Question: dns.Type(v.Header().Class),
Answer: []dns.RR{v}, Answer: []dns.RR{v},
Server: mDNSResolver.Server,
ServerScope: mDNSResolver.ServerIPScope,
} }
rrCache.Clean(60) rrCache.Clean(60)
err := rrCache.Save() err := rrCache.Save()

View file

@ -12,7 +12,7 @@ import (
var ( var (
recordDatabase = database.NewInterface(&database.Options{ recordDatabase = database.NewInterface(&database.Options{
AlwaysSetRelativateExpiry: 2592000, // 30 days AlwaysSetRelativateExpiry: 2592000, // 30 days
CacheSize: 128, CacheSize: 256,
}) })
) )

View file

@ -0,0 +1,27 @@
package resolver
import "testing"
func TestNameRecordStorage(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
r, err := GetNameRecord(testDomain, testQuestion)
if err != nil {
t.Fatal(err)
}
if r.Domain != testDomain || r.Question != testQuestion {
t.Fatal("mismatch")
}
}

184
resolver/pooling_test.go Normal file
View file

@ -0,0 +1,184 @@
package resolver
import (
"sync"
"testing"
"github.com/miekg/dns"
)
var (
domainFeed = make(chan string)
)
func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Fatal(err) //nolint:staticcheck
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl)
dnsClient.done()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
go feedDomains()
// create separate resolver for this test
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
if err != nil {
t.Fatal(err)
}
brc := resolver.Conn.(*BasicResolverConn)
wg := &sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(10)
for i := 0; i < 10; i++ {
go testQuery(t, wg, brc, &Query{
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if len(brc.clientManager.pool) != 10 {
t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool))
}
}
}
func feedDomains() {
for {
for _, domain := range poolingTestDomains {
domainFeed <- domain
}
}
}
// Data
var (
poolingTestDomains = []string{
"facebook.com.",
"google.com.",
"youtube.com.",
"twitter.com.",
"instagram.com.",
"linkedin.com.",
"microsoft.com.",
"apple.com.",
"wikipedia.org.",
"plus.google.com.",
"en.wikipedia.org.",
"googletagmanager.com.",
"youtu.be.",
"adobe.com.",
"vimeo.com.",
"pinterest.com.",
"itunes.apple.com.",
"play.google.com.",
"maps.google.com.",
"goo.gl.",
"wordpress.com.",
"blogspot.com.",
"bit.ly.",
"github.com.",
"player.vimeo.com.",
"amazon.com.",
"wordpress.org.",
"docs.google.com.",
"yahoo.com.",
"mozilla.org.",
"tumblr.com.",
"godaddy.com.",
"flickr.com.",
"parked-content.godaddy.com.",
"drive.google.com.",
"support.google.com.",
"apache.org.",
"gravatar.com.",
"europa.eu.",
"qq.com.",
"w3.org.",
"nytimes.com.",
"reddit.com.",
"macromedia.com.",
"get.adobe.com.",
"soundcloud.com.",
"sourceforge.net.",
"sites.google.com.",
"nih.gov.",
"amazonaws.com.",
"t.co.",
"support.microsoft.com.",
"forbes.com.",
"theguardian.com.",
"cnn.com.",
"github.io.",
"bbc.co.uk.",
"dropbox.com.",
"whatsapp.com.",
"medium.com.",
"creativecommons.org.",
"www.ncbi.nlm.nih.gov.",
"httpd.apache.org.",
"archive.org.",
"ec.europa.eu.",
"php.net.",
"apps.apple.com.",
"weebly.com.",
"support.apple.com.",
"weibo.com.",
"wixsite.com.",
"issuu.com.",
"who.int.",
"paypal.com.",
"m.facebook.com.",
"oracle.com.",
"msn.com.",
"gnu.org.",
"tinyurl.com.",
"reuters.com.",
"l.facebook.com.",
"cloudflare.com.",
"wsj.com.",
"washingtonpost.com.",
"domainmarket.com.",
"imdb.com.",
"bbc.com.",
"bing.com.",
"accounts.google.com.",
"vk.com.",
"api.whatsapp.com.",
"opera.com.",
"cdc.gov.",
"slideshare.net.",
"wpa.qq.com.",
"harvard.edu.",
"mit.edu.",
"code.google.com.",
"wikimedia.org.",
}
)

View file

@ -114,6 +114,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) {
rrCache.MixAnswers() rrCache.MixAnswers()
return rrCache, nil return rrCache, nil
} }
log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType)
// if cache is still empty or non-compliant, go ahead and just query // if cache is still empty or non-compliant, go ahead and just query
} else { } else {
// we are the first! // we are the first!
@ -132,14 +133,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
if err != nil { if err != nil {
if err != database.ErrNotFound { if err != database.ErrNotFound {
log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
} }
return nil return nil
} }
// get resolver that rrCache was resolved with // get resolver that rrCache was resolved with
resolver := getResolverByIDWithLocking(rrCache.Server) resolver := getActiveResolverByIDWithLocking(rrCache.Server)
if resolver == nil { if resolver == nil {
log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server)
return nil return nil
} }
@ -165,6 +166,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
}) })
} }
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
return rrCache return rrCache
} }

View file

@ -2,6 +2,7 @@ package resolver
import ( import (
"context" "context"
"errors"
"net" "net"
"sync" "sync"
"time" "time"
@ -92,7 +93,7 @@ type BasicResolverConn struct {
sync.Mutex // for lastFail sync.Mutex // for lastFail
resolver *Resolver resolver *Resolver
clientManager *clientManager clientManager *dnsClientManager
lastFail time.Time lastFail time.Time
} }
@ -126,18 +127,41 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
// start // start
var reply *dns.Msg var reply *dns.Msg
var ttl time.Duration
var err error var err error
for i := 0; i < 3; i++ { var conn *dns.Conn
var new bool
var i int
// log query time for ; i < 5; i++ {
// qStart := time.Now()
reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress) // first get connection
// log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) dc := brc.clientManager.getDNSClient()
conn, new, err = dc.getConn()
if err != nil {
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// try again
continue
}
if new {
log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress)
} else {
log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress)
}
// query server
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
// error handling // error handling
if err != nil { if err != nil {
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err) log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
// remove client from pool
dc.destroy()
// TODO: handle special cases // TODO: handle special cases
// 1. connect: network is unreachable // 1. connect: network is unreachable
// 2. timeout // 2. timeout
@ -148,13 +172,23 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
// temporary error // temporary error
if nerr, ok := err.(net.Error); ok && nerr.Timeout() { if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
// try again
continue continue
} }
// permanent error // permanent error
break break
} else if reply == nil {
// remove client from pool
dc.destroy()
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
return nil, errors.New("internal error")
} }
// make client available again
dc.done()
if resolver.IsBlockedUpstream(reply) { if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()} return nil, &BlockedUpstreamError{resolver.GetName()}
} }
@ -166,12 +200,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
if err != nil { if err != nil {
return nil, err return nil, err
// TODO: mark as failed // TODO: mark as failed
} else if reply == nil {
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1)
return nil, errors.New("internal error")
} }
// hint network environment at successful connection // hint network environment at successful connection
netenv.ReportSuccessfulConnection() netenv.ReportSuccessfulConnection()
new := &RRCache{ newRecord := &RRCache{
Domain: q.FQDN, Domain: q.FQDN,
Question: q.QType, Question: q.QType,
Answer: reply.Answer, Answer: reply.Answer,
@ -182,5 +219,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
} }
// TODO: check if reply.Answer is valid // TODO: check if reply.Answer is valid
return new, nil return newRecord, nil
} }

View file

@ -25,7 +25,7 @@ var (
globalResolvers []*Resolver // all (global) resolvers globalResolvers []*Resolver // all (global) resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
allResolvers map[string]*Resolver // lookup map of all resolvers activeResolvers map[string]*Resolver // lookup map of all resolvers
resolversLock sync.RWMutex resolversLock sync.RWMutex
dupReqMap = make(map[string]*sync.WaitGroup) dupReqMap = make(map[string]*sync.WaitGroup)
@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int {
return -1 return -1
} }
func getResolverByIDWithLocking(server string) *Resolver { func getActiveResolverByIDWithLocking(server string) *Resolver {
resolversLock.Lock() resolversLock.RLock()
defer resolversLock.Unlock() defer resolversLock.RUnlock()
resolver, ok := allResolvers[server] resolver, ok := activeResolvers[server]
if ok { if ok {
return resolver return resolver
} }
@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
return address return address
} }
func clientManagerFactory(serverType string) func(*Resolver) *clientManager { func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType { switch serverType {
case ServerTypeDNS: case ServerTypeDNS:
return newDNSClientManager return newDNSClientManager
@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) {
} }
} }
func getConfiguredResolvers() (resolvers []*Resolver) { func getConfiguredResolvers(list []string) (resolvers []*Resolver) {
for _, server := range configuredNameServers() { for _, server := range list {
resolver, skip, err := createResolver(server, "config") resolver, skip, err := createResolver(server, "config")
if err != nil { if err != nil {
// TODO(ppacher): module error // TODO(ppacher): module error
@ -199,19 +199,40 @@ func loadResolvers() {
defer resolversLock.Unlock() defer resolversLock.Unlock()
newResolvers := append( newResolvers := append(
getConfiguredResolvers(), getConfiguredResolvers(configuredNameServers()),
getSystemResolvers()..., getSystemResolvers()...,
) )
// save resolvers if len(newResolvers) == 0 {
globalResolvers = newResolvers msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults"
if len(globalResolvers) == 0 { log.Warningf("resolver: %s", msg)
log.Criticalf("resolver: no (valid) dns servers found in configuration and system") module.Warning("no-valid-user-resolvers", msg)
// TODO(module error)
// load defaults directly, overriding config system
newResolvers = getConfiguredResolvers(defaultNameServers)
if len(newResolvers) == 0 {
msg = "no (valid) dns servers found in configuration or system"
log.Criticalf("resolver: %s", msg)
module.Error("no-valid-default-resolvers", msg)
return
}
} }
// save resolvers
globalResolvers = newResolvers
// assing resolvers to scopes
setLocalAndScopeResolvers(globalResolvers) setLocalAndScopeResolvers(globalResolvers)
// set active resolvers (for cache validation)
// reset
activeResolvers = make(map[string]*Resolver)
// add
for _, resolver := range newResolvers {
activeResolvers[resolver.Server] = resolver
}
activeResolvers[mDNSResolver.Server] = mDNSResolver
// log global resolvers // log global resolvers
if len(globalResolvers) > 0 { if len(globalResolvers) > 0 {
log.Trace("resolver: loaded global resolvers:") log.Trace("resolver: loaded global resolvers:")

View file

@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (
for _, rr := range rrCache.Answer { for _, rr := range rrCache.Answer {
switch v := rr.(type) { switch v := rr.(type) {
case *dns.A: case *dns.A:
log.Infof("A: %s", v.A.String()) // log.Debugf("A: %s", v.A.String())
if ip == v.A.String() { if ip == v.A.String() {
return ptrName, nil return ptrName, nil
} }
case *dns.AAAA: case *dns.AAAA:
log.Infof("AAAA: %s", v.AAAA.String()) // log.Debugf("AAAA: %s", v.AAAA.String())
if ip == v.AAAA.String() { if ip == v.AAAA.String() {
return ptrName, nil return ptrName, nil
} }

41
resolver/rrcache_test.go Normal file
View file

@ -0,0 +1,41 @@
package resolver
import (
"testing"
"github.com/miekg/dns"
)
func TestCaching(t *testing.T) {
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
testQuestion := "A"
testNameRecord := &NameRecord{
Domain: testDomain,
Question: testQuestion,
}
err := testNameRecord.Save()
if err != nil {
t.Fatal(err)
}
rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
err = rrCache.Save()
if err != nil {
t.Fatal(err)
}
rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
if err != nil {
t.Fatal(err)
}
if rrCache2.Domain != rrCache.Domain {
t.Fatal("something very is wrong")
}
}