Switch resolver pooling to use sync.Pool

Author: Daniel
Date:   2020-05-20 14:57:47 +02:00
Parent: 36c60a1e33
Commit: c8223f1a63

4 changed files with 41 additions and 81 deletions
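For context on the primitive this commit switches to: sync.Pool holds temporary objects that the runtime may drop at any garbage collection, and Get returns nil once the pool is empty when no New function is configured, which is why the manager below keeps its own factory and TTL check. The following is a minimal, self-contained sketch of that Get/Put behavior, not the Portmaster code; the pooledConn type and its TTL field are illustrative.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// pooledConn is an illustrative placeholder for a reusable client wrapper.
type pooledConn struct {
	useUntil time.Time
}

func main() {
	var pool sync.Pool // no New func: Get returns nil when the pool is empty

	// Get on an empty pool yields a nil interface value.
	if v := pool.Get(); v == nil {
		fmt.Println("pool empty, caller must create a new object")
	}

	// Put an object back so a later Get can (usually) reuse it; the runtime
	// may still drop pooled items at any garbage collection.
	pool.Put(&pooledConn{useUntil: time.Now().Add(5 * time.Minute)})

	if pc, ok := pool.Get().(*pooledConn); ok && pc != nil {
		fmt.Println("reused object, valid until", pc.useUntil)
	}
}
```

Because pooled items may vanish at any time, sync.Pool only suits objects that are cheap to re-create, which is what lets this commit drop the slice, mutex bookkeeping, and pool-index tracking seen in the removed code.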

Changed file 1 of 4

@@ -12,8 +12,9 @@ import (
const (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 5 * time.Second
connectionEOLGracePeriod = 10 * time.Second
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 2 * time.Second // tcp/tls
connectionEOLGracePeriod = 7 * time.Second
)
var (
@@ -43,23 +44,17 @@ type dnsClientManager struct {
factory func() *dns.Client
// internal
pool []*dnsClient
pool sync.Pool
}
type dnsClient struct {
mgr *dnsClientManager
inUse bool
useUntil time.Time
dead bool
inPool bool
poolIndex int
client *dns.Client
conn *dns.Conn
mgr *dnsClientManager
client *dns.Client
conn *dns.Conn
useUntil time.Time
}
// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
@@ -71,23 +66,11 @@ func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
return dc.conn, false, nil
}
func (dc *dnsClient) done() {
dc.mgr.lock.Lock()
defer dc.mgr.lock.Unlock()
dc.inUse = false
func (dc *dnsClient) addToPool() {
dc.mgr.pool.Put(dc)
}
func (dc *dnsClient) destroy() {
dc.mgr.lock.Lock()
dc.inUse = true // block from being used
dc.dead = true // abort cleaning
if dc.inPool {
dc.inPool = false
dc.mgr.pool[dc.poolIndex] = nil
}
dc.mgr.lock.Unlock()
if dc.conn != nil {
_ = dc.conn.Close()
}
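The getConn/addToPool hunks above rely on lazy dialing: a client only opens its TCP/TLS connection the first time a connection is requested, a client reused from the pool hands back the connection it already has, and destroy simply closes the connection and never puts the client back. A rough sketch of that lazy-dial shape with the miekg/dns client follows; the lazyClient type, timeout, and resolver address are illustrative, and the pooling around it is omitted.

```go
package main

import (
	"fmt"
	"time"

	"github.com/miekg/dns"
)

// lazyClient is an illustrative stand-in for the slimmed-down dnsClient.
type lazyClient struct {
	client *dns.Client
	conn   *dns.Conn
}

// getConn dials on first use and returns the cached connection afterwards,
// reporting whether the connection is newly established.
func (lc *lazyClient) getConn(server string) (conn *dns.Conn, isNew bool, err error) {
	if lc.conn == nil {
		lc.conn, err = lc.client.Dial(server)
		if err != nil {
			return nil, false, err
		}
		return lc.conn, true, nil
	}
	return lc.conn, false, nil
}

func main() {
	lc := &lazyClient{client: &dns.Client{Net: "tcp", Timeout: 3 * time.Second}}

	// The resolver address is illustrative only.
	if conn, isNew, err := lc.getConn("9.9.9.9:53"); err == nil {
		defer conn.Close()
		fmt.Println("new connection:", isNew) // true on first use, false on reuse
	} else {
		fmt.Println("dial failed:", err)
	}
}
```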
@@ -118,6 +101,7 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager {
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
@@ -140,6 +124,7 @@ func newTLSClientManager(resolver *Resolver) *dnsClientManager {
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
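Both client managers now give the embedded net.Dialer its own connect timeout, separate from the dns.Client request timeout that bounds the query itself, while keep-alive matches the client TTL so pooled connections stay open for their whole lifetime. A hedged configuration sketch using the constants added at the top of the file; the local-address lookup (getLocalAddr) is left out here.

```go
package main

import (
	"fmt"
	"net"
	"time"

	"github.com/miekg/dns"
)

func main() {
	// Values mirror the constants introduced above; getLocalAddr is omitted here.
	client := &dns.Client{
		Net:     "tcp",
		Timeout: 3 * time.Second, // bounds the DNS exchange itself
		Dialer: &net.Dialer{
			Timeout:   2 * time.Second, // bounds TCP/TLS connection setup only
			KeepAlive: 5 * time.Minute, // keep pooled connections alive for the client TTL
		},
	}
	fmt.Println("client configured:", client.Net, client.Timeout)
}
```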
@@ -159,11 +144,18 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient {
}
}
// get first unused from pool
// get cached client from pool
now := time.Now().UTC()
for _, dc := range cm.pool {
if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
dc.inUse = true
poolLoop:
for {
dc, ok := cm.pool.Get().(*dnsClient)
switch {
case !ok || dc == nil: // pool empty (probably; Get may return nil at any time)
break poolLoop // create new
case now.After(dc.useUntil):
continue // get next
default:
return dc
}
}
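One subtlety in the loop above: without a New function, Get returns a nil interface when the pool is empty, so the type assertion yields ok == false; checking dc == nil as well also covers a typed nil pointer that was put into the pool. A small standalone demonstration of both cases, using an illustrative stub type rather than the real dnsClient:

```go
package main

import (
	"fmt"
	"sync"
)

type dnsClientStub struct{} // illustrative stand-in, not the real dnsClient

func main() {
	var pool sync.Pool

	// Empty pool: Get returns a nil interface, so the assertion reports ok == false.
	dc, ok := pool.Get().(*dnsClientStub)
	fmt.Println(dc == nil, ok) // true false

	// A typed nil pointer asserts successfully but is still nil, hence the extra check.
	pool.Put((*dnsClientStub)(nil))
	dc, ok = pool.Get().(*dnsClientStub)
	fmt.Println(dc == nil, ok) // typically true true (the pool may drop items at any time)
}
```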
@@ -171,27 +163,11 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient {
// none available in pool, create new
newClient := &dnsClient{
mgr: cm,
inUse: true,
useUntil: now.Add(cm.ttl),
inPool: true,
client: cm.factory(),
useUntil: now.Add(cm.ttl),
}
newClient.startCleaner()
// find free spot in pool
for poolIndex, dc := range cm.pool {
if dc == nil {
cm.pool[poolIndex] = newClient
newClient.poolIndex = poolIndex
return newClient
}
}
// append to pool
cm.pool = append(cm.pool, newClient)
newClient.poolIndex = len(cm.pool) - 1
// TODO: shrink pool again?
return newClient
}
@@ -200,26 +176,12 @@ func (dc *dnsClient) startCleaner() {
// While a single worker cleaning all connections might be slightly more performant, this approach uses as little locking as possible and is simpler, and thus less error-prone.
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
select {
case <-time.After(dc.mgr.ttl + time.Second):
dc.mgr.lock.Lock()
cleanNow := dc.dead || !dc.inUse
dc.mgr.lock.Unlock()
if cleanNow {
dc.destroy()
return nil
}
case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod):
// destroy
case <-ctx.Done():
// give a short time before kill for graceful request completion
time.Sleep(100 * time.Millisecond)
}
// wait for grace period to end, then kill
select {
case <-time.After(connectionEOLGracePeriod):
case <-ctx.Done():
}
dc.destroy()
return nil
})
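The per-client cleaner keeps locking to a minimum, as the comment above notes: each client gets a tiny worker that waits for its TTL plus a grace period, or reacts to shutdown after a short pause for in-flight requests, and then tears the connection down. Outside the module/worker framework, the same shape can be sketched with a plain goroutine; the durations and the destroy callback below are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// startCleaner tears a pooled object down once its TTL plus a grace period has
// passed, or shortly after shutdown is requested. Sketch only; the real code
// runs inside a module worker.
func startCleaner(ctx context.Context, ttl, grace time.Duration, destroy func()) {
	go func() {
		select {
		case <-time.After(ttl + grace):
			// lifetime over: fall through and destroy
		case <-ctx.Done():
			// shutting down: allow a moment for in-flight requests to finish
			time.Sleep(100 * time.Millisecond)
		}
		destroy()
	}()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	startCleaner(ctx, 50*time.Millisecond, 20*time.Millisecond, func() {
		fmt.Println("connection closed")
	})
	time.Sleep(200 * time.Millisecond) // let the cleaner fire in this demo
}
```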

Changed file 2 of 4

@@ -2,6 +2,7 @@ package resolver
import (
"sync"
"sync/atomic"
"testing"
"github.com/miekg/dns"
@@ -11,7 +12,7 @@
domainFeed = make(chan string)
)
func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) {
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
@@ -23,6 +24,9 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer
if err != nil {
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
}
if new {
atomic.AddUint32(newCnt, 1)
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
@@ -33,8 +37,8 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl)
dnsClient.done()
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
dnsClient.addToPool()
wg.Done()
}
@@ -54,17 +58,18 @@ func TestClientPooling(t *testing.T) {
brc := resolver.Conn.(*BasicResolverConn)
wg := &sync.WaitGroup{}
var newCnt uint32
for i := 0; i < 10; i++ {
wg.Add(10)
for i := 0; i < 10; i++ {
go testQuery(t, wg, brc, &Query{
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if len(brc.clientManager.pool) != 10 {
t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool))
if newCnt > uint32(10+i) {
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
}
}
}
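Since sync.Pool does not expose its size, the reworked test stops asserting on pool length and instead counts freshly dialed connections with an atomic counter shared across the query goroutines, then checks the count against a per-round bound. A condensed sketch of that counting pattern; the goroutine body stands in for testQuery.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var newCnt uint32
	wg := &sync.WaitGroup{}

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			atomic.AddUint32(&newCnt, 1) // stand-in for "a new connection was dialed"
		}()
	}
	wg.Wait()

	// Reading after wg.Wait() is safe: Wait happens after every Done.
	fmt.Println("new connections:", atomic.LoadUint32(&newCnt))
}
```

The loosened bound in the test (newCnt may grow by one per round) appears to leave room for an occasional fresh dial in later rounds, since sync.Pool gives no guarantee that a pooled client survives until it is reused.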

Changed file 3 of 4

@@ -14,8 +14,6 @@ import (
)
var (
mtAsyncResolve = "async resolve"
// basic errors
// ErrNotFound is a basic error that will match all "not found" errors
@@ -160,7 +158,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
log.Tracer(ctx).Trace("resolver: serving from cache, requesting new")
// resolve async
module.StartLowPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error {
module.StartWorker("resolve async", func(ctx context.Context) error {
_, _ = resolveAndCache(ctx, q)
return nil
})
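The async refresh keeps its serve-then-update behavior; only the scheduling primitive changes from a low-priority micro task to a named worker. Below is a generic sketch of that serve-from-cache-and-refresh shape using a plain goroutine, independent of the module framework; the cache, lookup, and resolve helpers here are placeholders, not the resolver's real API.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

var (
	cacheLock sync.Mutex
	cache     = map[string]string{"example.com.": "93.184.216.34"} // toy RR cache
)

// resolve stands in for the real upstream query.
func resolve(ctx context.Context, fqdn string) (string, error) {
	return "93.184.216.34", nil
}

// lookup serves the cached answer immediately and refreshes it in the background,
// mirroring the serve-then-refresh shape of checkCache above.
func lookup(ctx context.Context, fqdn string) (string, error) {
	cacheLock.Lock()
	answer, ok := cache[fqdn]
	cacheLock.Unlock()
	if !ok {
		return resolve(ctx, fqdn)
	}
	go func() { // stand-in for the named worker started above
		if fresh, err := resolve(context.Background(), fqdn); err == nil {
			cacheLock.Lock()
			cache[fqdn] = fresh
			cacheLock.Unlock()
		}
	}()
	return answer, nil
}

func main() {
	answer, _ := lookup(context.Background(), "example.com.")
	fmt.Println(answer)
	time.Sleep(10 * time.Millisecond) // let the background refresh finish in this demo
}
```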
@@ -220,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
return nil, ErrNoCompliance
}
// prep
lastFailBoundary := time.Now().Add(
-time.Duration(nameserverRetryRate()) * time.Second,
)
// start resolving
var i int

Changed file 4 of 4

@@ -215,8 +215,8 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
return nil, errors.New("internal error")
}
// make client available again
dc.done()
// make client available (again)
dc.addToPool()
if resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{resolver.GetName()}
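Taken together, a pooled query now follows a short lifecycle: take a client from the manager, get or lazily dial its connection, exchange the message, then hand the client back via addToPool, or destroy it on failure so it is simply never returned to the pool. The sketch below shows only the exchange step with miekg/dns, outside the pooling code; the resolver address, domain, and error handling are illustrative.

```go
package main

import (
	"fmt"
	"time"

	"github.com/miekg/dns"
)

func main() {
	client := &dns.Client{Net: "tcp", Timeout: 3 * time.Second}

	// Dial once, then reuse the connection for the exchange; with pooling in
	// place, this conn would come from getConn() and go back via addToPool().
	conn, err := client.Dial("9.9.9.9:53") // resolver address is illustrative
	if err != nil {
		fmt.Println("connect failed:", err)
		return
	}
	defer conn.Close()

	msg := new(dns.Msg)
	msg.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)

	reply, rtt, err := client.ExchangeWithConn(msg, conn)
	if err != nil {
		// On failure a pooled client would be destroyed rather than reused.
		fmt.Println("query failed:", err)
		return
	}
	fmt.Println("answers:", len(reply.Answer), "rtt:", rtt)
}
```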