mirror of
https://github.com/safing/portmaster
synced 2025-09-01 10:09:11 +00:00
Add TLS resolver connection reusing and pooling
Also, fix caching issues and add more tests
This commit is contained in:
parent
dd837e40e2
commit
53eb309e72
11 changed files with 510 additions and 61 deletions
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -9,6 +10,12 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClientTTL = 5 * time.Minute
|
||||
defaultRequestTimeout = 5 * time.Second
|
||||
connectionEOLGracePeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
localAddrFactory func(network string) net.Addr
|
||||
)
|
||||
|
@ -27,21 +34,72 @@ func getLocalAddr(network string) net.Addr {
|
|||
return nil
|
||||
}
|
||||
|
||||
type clientManager struct {
|
||||
dnsClient *dns.Client
|
||||
factory func() *dns.Client
|
||||
type dnsClientManager struct {
|
||||
lock sync.Mutex
|
||||
|
||||
lock sync.Mutex
|
||||
refreshAfter time.Time
|
||||
ttl time.Duration // force refresh of connection to reduce traceability
|
||||
// set by creator
|
||||
serverAddress string
|
||||
ttl time.Duration // force refresh of connection to reduce traceability
|
||||
factory func() *dns.Client
|
||||
|
||||
// internal
|
||||
pool []*dnsClient
|
||||
}
|
||||
|
||||
func newDNSClientManager(_ *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // new client for every request, as we need to randomize the port
|
||||
type dnsClient struct {
|
||||
mgr *dnsClientManager
|
||||
|
||||
inUse bool
|
||||
useUntil time.Time
|
||||
dead bool
|
||||
inPool bool
|
||||
poolIndex int
|
||||
|
||||
client *dns.Client
|
||||
conn *dns.Conn
|
||||
}
|
||||
|
||||
// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
||||
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
||||
if dc.conn == nil {
|
||||
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return dc.conn, true, nil
|
||||
}
|
||||
return dc.conn, false, nil
|
||||
}
|
||||
|
||||
func (dc *dnsClient) done() {
|
||||
dc.mgr.lock.Lock()
|
||||
defer dc.mgr.lock.Unlock()
|
||||
|
||||
dc.inUse = false
|
||||
}
|
||||
|
||||
func (dc *dnsClient) destroy() {
|
||||
dc.mgr.lock.Lock()
|
||||
dc.inUse = true // block from being used
|
||||
dc.dead = true // abort cleaning
|
||||
if dc.inPool {
|
||||
dc.inPool = false
|
||||
dc.mgr.pool[dc.poolIndex] = nil
|
||||
}
|
||||
dc.mgr.lock.Unlock()
|
||||
|
||||
if dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: 0, // new client for every request, as we need to randomize the port
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("udp"),
|
||||
},
|
||||
|
@ -50,25 +108,27 @@ func newDNSClientManager(_ *Resolver) *clientManager {
|
|||
}
|
||||
}
|
||||
|
||||
func newTCPClientManager(_ *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
|
||||
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: defaultClientTTL,
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
KeepAlive: 15 * time.Second,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTLSClientManager(resolver *Resolver) *clientManager {
|
||||
return &clientManager{
|
||||
ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff)
|
||||
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
|
||||
return &dnsClientManager{
|
||||
serverAddress: resolver.ServerAddress,
|
||||
ttl: defaultClientTTL,
|
||||
factory: func() *dns.Client {
|
||||
return &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
|
@ -77,24 +137,90 @@ func newTLSClientManager(resolver *Resolver) *clientManager {
|
|||
ServerName: resolver.VerifyDomain,
|
||||
// TODO: use portbase rng
|
||||
},
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
KeepAlive: 15 * time.Second,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *clientManager) getDNSClient() *dns.Client {
|
||||
func (cm *dnsClientManager) getDNSClient() *dnsClient {
|
||||
cm.lock.Lock()
|
||||
defer cm.lock.Unlock()
|
||||
|
||||
if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) {
|
||||
cm.dnsClient = cm.factory()
|
||||
cm.refreshAfter = time.Now().Add(cm.ttl)
|
||||
// return new immediately if a new client should be used for every request
|
||||
if cm.ttl == 0 {
|
||||
return &dnsClient{
|
||||
mgr: cm,
|
||||
client: cm.factory(),
|
||||
}
|
||||
}
|
||||
|
||||
return cm.dnsClient
|
||||
// get first unused from pool
|
||||
now := time.Now().UTC()
|
||||
for _, dc := range cm.pool {
|
||||
if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
|
||||
dc.inUse = true
|
||||
return dc
|
||||
}
|
||||
}
|
||||
|
||||
// no available in pool, create new
|
||||
newClient := &dnsClient{
|
||||
mgr: cm,
|
||||
inUse: true,
|
||||
useUntil: now.Add(cm.ttl),
|
||||
inPool: true,
|
||||
client: cm.factory(),
|
||||
}
|
||||
newClient.startCleaner()
|
||||
|
||||
// find free spot in pool
|
||||
for poolIndex, dc := range cm.pool {
|
||||
if dc == nil {
|
||||
cm.pool[poolIndex] = newClient
|
||||
newClient.poolIndex = poolIndex
|
||||
return newClient
|
||||
}
|
||||
}
|
||||
|
||||
// append to pool
|
||||
cm.pool = append(cm.pool, newClient)
|
||||
newClient.poolIndex = len(cm.pool) - 1
|
||||
// TODO: shrink pool again?
|
||||
|
||||
return newClient
|
||||
}
|
||||
|
||||
// startCleaner waits for EOL of the client and then removes it from the pool.
|
||||
func (dc *dnsClient) startCleaner() {
|
||||
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
|
||||
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
|
||||
select {
|
||||
case <-time.After(dc.mgr.ttl + time.Second):
|
||||
dc.mgr.lock.Lock()
|
||||
cleanNow := dc.dead || !dc.inUse
|
||||
dc.mgr.lock.Unlock()
|
||||
|
||||
if cleanNow {
|
||||
dc.destroy()
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// give a short time before kill for graceful request completion
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// wait for grace period to end, then kill
|
||||
select {
|
||||
case <-time.After(connectionEOLGracePeriod):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
dc.destroy()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -29,10 +31,11 @@ var (
|
|||
questionsLock sync.Mutex
|
||||
|
||||
mDNSResolver = &Resolver{
|
||||
Server: ServerSourceMDNS,
|
||||
ServerType: ServerTypeDNS,
|
||||
Source: ServerSourceMDNS,
|
||||
Conn: &mDNSResolverConn{},
|
||||
Server: ServerSourceMDNS,
|
||||
ServerType: ServerTypeDNS,
|
||||
ServerIPScope: netutils.SiteLocal,
|
||||
Source: ServerSourceMDNS,
|
||||
Conn: &mDNSResolverConn{},
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
|
|||
|
||||
// get entry from database
|
||||
if saveFullRequest {
|
||||
// get from database
|
||||
rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
|
||||
// if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired:
|
||||
// create new and do not append
|
||||
if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() {
|
||||
rrCache = &RRCache{
|
||||
Domain: question.Name,
|
||||
Question: dns.Type(question.Qtype),
|
||||
Domain: question.Name,
|
||||
Question: dns.Type(question.Qtype),
|
||||
Server: mDNSResolver.Server,
|
||||
ServerScope: mDNSResolver.ServerIPScope,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add all entries to RRCache
|
||||
for _, entry := range message.Answer {
|
||||
if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) {
|
||||
if saveFullRequest {
|
||||
|
@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
|
|||
continue
|
||||
}
|
||||
rrCache = &RRCache{
|
||||
Domain: v.Header().Name,
|
||||
Question: dns.Type(v.Header().Class),
|
||||
Answer: []dns.RR{v},
|
||||
Domain: v.Header().Name,
|
||||
Question: dns.Type(v.Header().Class),
|
||||
Answer: []dns.RR{v},
|
||||
Server: mDNSResolver.Server,
|
||||
ServerScope: mDNSResolver.ServerIPScope,
|
||||
}
|
||||
rrCache.Clean(60)
|
||||
err := rrCache.Save()
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
var (
|
||||
recordDatabase = database.NewInterface(&database.Options{
|
||||
AlwaysSetRelativateExpiry: 2592000, // 30 days
|
||||
CacheSize: 128,
|
||||
CacheSize: 256,
|
||||
})
|
||||
)
|
||||
|
||||
|
|
27
resolver/namerecord_test.go
Normal file
27
resolver/namerecord_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package resolver
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNameRecordStorage(t *testing.T) {
|
||||
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
|
||||
testQuestion := "A"
|
||||
|
||||
testNameRecord := &NameRecord{
|
||||
Domain: testDomain,
|
||||
Question: testQuestion,
|
||||
}
|
||||
|
||||
err := testNameRecord.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := GetNameRecord(testDomain, testQuestion)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Domain != testDomain || r.Question != testQuestion {
|
||||
t.Fatal("mismatch")
|
||||
}
|
||||
}
|
184
resolver/pooling_test.go
Normal file
184
resolver/pooling_test.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
domainFeed = make(chan string)
|
||||
)
|
||||
|
||||
func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) {
|
||||
dnsClient := brc.clientManager.getDNSClient()
|
||||
|
||||
// create query
|
||||
dnsQuery := new(dns.Msg)
|
||||
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
|
||||
|
||||
// get connection
|
||||
conn, new, err := dnsClient.getConn()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
|
||||
}
|
||||
|
||||
// query server
|
||||
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
|
||||
if err != nil {
|
||||
t.Fatal(err) //nolint:staticcheck
|
||||
}
|
||||
if reply == nil {
|
||||
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
|
||||
}
|
||||
|
||||
t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl)
|
||||
dnsClient.done()
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
func TestClientPooling(t *testing.T) {
|
||||
// skip if short - this test depends on the Internet and might fail randomly
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
go feedDomains()
|
||||
|
||||
// create separate resolver for this test
|
||||
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
brc := resolver.Conn.(*BasicResolverConn)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go testQuery(t, wg, brc, &Query{
|
||||
FQDN: <-domainFeed,
|
||||
QType: dns.Type(dns.TypeA),
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
if len(brc.clientManager.pool) != 10 {
|
||||
t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func feedDomains() {
|
||||
for {
|
||||
for _, domain := range poolingTestDomains {
|
||||
domainFeed <- domain
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data
|
||||
|
||||
var (
|
||||
poolingTestDomains = []string{
|
||||
"facebook.com.",
|
||||
"google.com.",
|
||||
"youtube.com.",
|
||||
"twitter.com.",
|
||||
"instagram.com.",
|
||||
"linkedin.com.",
|
||||
"microsoft.com.",
|
||||
"apple.com.",
|
||||
"wikipedia.org.",
|
||||
"plus.google.com.",
|
||||
"en.wikipedia.org.",
|
||||
"googletagmanager.com.",
|
||||
"youtu.be.",
|
||||
"adobe.com.",
|
||||
"vimeo.com.",
|
||||
"pinterest.com.",
|
||||
"itunes.apple.com.",
|
||||
"play.google.com.",
|
||||
"maps.google.com.",
|
||||
"goo.gl.",
|
||||
"wordpress.com.",
|
||||
"blogspot.com.",
|
||||
"bit.ly.",
|
||||
"github.com.",
|
||||
"player.vimeo.com.",
|
||||
"amazon.com.",
|
||||
"wordpress.org.",
|
||||
"docs.google.com.",
|
||||
"yahoo.com.",
|
||||
"mozilla.org.",
|
||||
"tumblr.com.",
|
||||
"godaddy.com.",
|
||||
"flickr.com.",
|
||||
"parked-content.godaddy.com.",
|
||||
"drive.google.com.",
|
||||
"support.google.com.",
|
||||
"apache.org.",
|
||||
"gravatar.com.",
|
||||
"europa.eu.",
|
||||
"qq.com.",
|
||||
"w3.org.",
|
||||
"nytimes.com.",
|
||||
"reddit.com.",
|
||||
"macromedia.com.",
|
||||
"get.adobe.com.",
|
||||
"soundcloud.com.",
|
||||
"sourceforge.net.",
|
||||
"sites.google.com.",
|
||||
"nih.gov.",
|
||||
"amazonaws.com.",
|
||||
"t.co.",
|
||||
"support.microsoft.com.",
|
||||
"forbes.com.",
|
||||
"theguardian.com.",
|
||||
"cnn.com.",
|
||||
"github.io.",
|
||||
"bbc.co.uk.",
|
||||
"dropbox.com.",
|
||||
"whatsapp.com.",
|
||||
"medium.com.",
|
||||
"creativecommons.org.",
|
||||
"www.ncbi.nlm.nih.gov.",
|
||||
"httpd.apache.org.",
|
||||
"archive.org.",
|
||||
"ec.europa.eu.",
|
||||
"php.net.",
|
||||
"apps.apple.com.",
|
||||
"weebly.com.",
|
||||
"support.apple.com.",
|
||||
"weibo.com.",
|
||||
"wixsite.com.",
|
||||
"issuu.com.",
|
||||
"who.int.",
|
||||
"paypal.com.",
|
||||
"m.facebook.com.",
|
||||
"oracle.com.",
|
||||
"msn.com.",
|
||||
"gnu.org.",
|
||||
"tinyurl.com.",
|
||||
"reuters.com.",
|
||||
"l.facebook.com.",
|
||||
"cloudflare.com.",
|
||||
"wsj.com.",
|
||||
"washingtonpost.com.",
|
||||
"domainmarket.com.",
|
||||
"imdb.com.",
|
||||
"bbc.com.",
|
||||
"bing.com.",
|
||||
"accounts.google.com.",
|
||||
"vk.com.",
|
||||
"api.whatsapp.com.",
|
||||
"opera.com.",
|
||||
"cdc.gov.",
|
||||
"slideshare.net.",
|
||||
"wpa.qq.com.",
|
||||
"harvard.edu.",
|
||||
"mit.edu.",
|
||||
"code.google.com.",
|
||||
"wikimedia.org.",
|
||||
}
|
||||
)
|
|
@ -114,6 +114,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) {
|
|||
rrCache.MixAnswers()
|
||||
return rrCache, nil
|
||||
}
|
||||
log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType)
|
||||
// if cache is still empty or non-compliant, go ahead and just query
|
||||
} else {
|
||||
// we are the first!
|
||||
|
@ -132,14 +133,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
if err != nil {
|
||||
if err != database.ErrNotFound {
|
||||
log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
|
||||
log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get resolver that rrCache was resolved with
|
||||
resolver := getResolverByIDWithLocking(rrCache.Server)
|
||||
resolver := getActiveResolverByIDWithLocking(rrCache.Server)
|
||||
if resolver == nil {
|
||||
log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -165,6 +166,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
})
|
||||
}
|
||||
|
||||
log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0)))
|
||||
return rrCache
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -92,7 +93,7 @@ type BasicResolverConn struct {
|
|||
sync.Mutex // for lastFail
|
||||
|
||||
resolver *Resolver
|
||||
clientManager *clientManager
|
||||
clientManager *dnsClientManager
|
||||
lastFail time.Time
|
||||
}
|
||||
|
||||
|
@ -126,18 +127,41 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
|
||||
// start
|
||||
var reply *dns.Msg
|
||||
var ttl time.Duration
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
var conn *dns.Conn
|
||||
var new bool
|
||||
var i int
|
||||
|
||||
// log query time
|
||||
// qStart := time.Now()
|
||||
reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress)
|
||||
// log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart))
|
||||
for ; i < 5; i++ {
|
||||
|
||||
// first get connection
|
||||
dc := brc.clientManager.getDNSClient()
|
||||
conn, new, err = dc.getConn()
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err)
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
// try again
|
||||
continue
|
||||
}
|
||||
if new {
|
||||
log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress)
|
||||
} else {
|
||||
log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress)
|
||||
}
|
||||
|
||||
// query server
|
||||
reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn)
|
||||
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
|
||||
|
||||
// error handling
|
||||
if err != nil {
|
||||
log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err)
|
||||
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
|
||||
// TODO: handle special cases
|
||||
// 1. connect: network is unreachable
|
||||
// 2. timeout
|
||||
|
@ -148,13 +172,23 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
// temporary error
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server)
|
||||
// try again
|
||||
continue
|
||||
}
|
||||
|
||||
// permanent error
|
||||
break
|
||||
} else if reply == nil {
|
||||
// remove client from pool
|
||||
dc.destroy()
|
||||
|
||||
log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server)
|
||||
return nil, errors.New("internal error")
|
||||
}
|
||||
|
||||
// make client available again
|
||||
dc.done()
|
||||
|
||||
if resolver.IsBlockedUpstream(reply) {
|
||||
return nil, &BlockedUpstreamError{resolver.GetName()}
|
||||
}
|
||||
|
@ -166,12 +200,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
if err != nil {
|
||||
return nil, err
|
||||
// TODO: mark as failed
|
||||
} else if reply == nil {
|
||||
log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1)
|
||||
return nil, errors.New("internal error")
|
||||
}
|
||||
|
||||
// hint network environment at successful connection
|
||||
netenv.ReportSuccessfulConnection()
|
||||
|
||||
new := &RRCache{
|
||||
newRecord := &RRCache{
|
||||
Domain: q.FQDN,
|
||||
Question: q.QType,
|
||||
Answer: reply.Answer,
|
||||
|
@ -182,5 +219,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
}
|
||||
|
||||
// TODO: check if reply.Answer is valid
|
||||
return new, nil
|
||||
return newRecord, nil
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ var (
|
|||
globalResolvers []*Resolver // all (global) resolvers
|
||||
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
|
||||
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
|
||||
allResolvers map[string]*Resolver // lookup map of all resolvers
|
||||
activeResolvers map[string]*Resolver // lookup map of all resolvers
|
||||
resolversLock sync.RWMutex
|
||||
|
||||
dupReqMap = make(map[string]*sync.WaitGroup)
|
||||
|
@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func getResolverByIDWithLocking(server string) *Resolver {
|
||||
resolversLock.Lock()
|
||||
defer resolversLock.Unlock()
|
||||
func getActiveResolverByIDWithLocking(server string) *Resolver {
|
||||
resolversLock.RLock()
|
||||
defer resolversLock.RUnlock()
|
||||
|
||||
resolver, ok := allResolvers[server]
|
||||
resolver, ok := activeResolvers[server]
|
||||
if ok {
|
||||
return resolver
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
|
|||
return address
|
||||
}
|
||||
|
||||
func clientManagerFactory(serverType string) func(*Resolver) *clientManager {
|
||||
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
|
||||
switch serverType {
|
||||
case ServerTypeDNS:
|
||||
return newDNSClientManager
|
||||
|
@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func getConfiguredResolvers() (resolvers []*Resolver) {
|
||||
for _, server := range configuredNameServers() {
|
||||
func getConfiguredResolvers(list []string) (resolvers []*Resolver) {
|
||||
for _, server := range list {
|
||||
resolver, skip, err := createResolver(server, "config")
|
||||
if err != nil {
|
||||
// TODO(ppacher): module error
|
||||
|
@ -199,19 +199,40 @@ func loadResolvers() {
|
|||
defer resolversLock.Unlock()
|
||||
|
||||
newResolvers := append(
|
||||
getConfiguredResolvers(),
|
||||
getConfiguredResolvers(configuredNameServers()),
|
||||
getSystemResolvers()...,
|
||||
)
|
||||
|
||||
// save resolvers
|
||||
globalResolvers = newResolvers
|
||||
if len(globalResolvers) == 0 {
|
||||
log.Criticalf("resolver: no (valid) dns servers found in configuration and system")
|
||||
// TODO(module error)
|
||||
if len(newResolvers) == 0 {
|
||||
msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults"
|
||||
log.Warningf("resolver: %s", msg)
|
||||
module.Warning("no-valid-user-resolvers", msg)
|
||||
|
||||
// load defaults directly, overriding config system
|
||||
newResolvers = getConfiguredResolvers(defaultNameServers)
|
||||
if len(newResolvers) == 0 {
|
||||
msg = "no (valid) dns servers found in configuration or system"
|
||||
log.Criticalf("resolver: %s", msg)
|
||||
module.Error("no-valid-default-resolvers", msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// save resolvers
|
||||
globalResolvers = newResolvers
|
||||
|
||||
// assing resolvers to scopes
|
||||
setLocalAndScopeResolvers(globalResolvers)
|
||||
|
||||
// set active resolvers (for cache validation)
|
||||
// reset
|
||||
activeResolvers = make(map[string]*Resolver)
|
||||
// add
|
||||
for _, resolver := range newResolvers {
|
||||
activeResolvers[resolver.Server] = resolver
|
||||
}
|
||||
activeResolvers[mDNSResolver.Server] = mDNSResolver
|
||||
|
||||
// log global resolvers
|
||||
if len(globalResolvers) > 0 {
|
||||
log.Trace("resolver: loaded global resolvers:")
|
||||
|
|
|
@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) (
|
|||
for _, rr := range rrCache.Answer {
|
||||
switch v := rr.(type) {
|
||||
case *dns.A:
|
||||
log.Infof("A: %s", v.A.String())
|
||||
// log.Debugf("A: %s", v.A.String())
|
||||
if ip == v.A.String() {
|
||||
return ptrName, nil
|
||||
}
|
||||
case *dns.AAAA:
|
||||
log.Infof("AAAA: %s", v.AAAA.String())
|
||||
// log.Debugf("AAAA: %s", v.AAAA.String())
|
||||
if ip == v.AAAA.String() {
|
||||
return ptrName, nil
|
||||
}
|
||||
|
|
41
resolver/rrcache_test.go
Normal file
41
resolver/rrcache_test.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com."
|
||||
testQuestion := "A"
|
||||
|
||||
testNameRecord := &NameRecord{
|
||||
Domain: testDomain,
|
||||
Question: testQuestion,
|
||||
}
|
||||
|
||||
err := testNameRecord.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = rrCache.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if rrCache2.Domain != rrCache.Domain {
|
||||
t.Fatal("something very is wrong")
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue