mirror of
https://github.com/safing/portmaster
synced 2025-09-01 10:09:11 +00:00
Switch resolver pooling to use sync.Pool
This commit is contained in:
parent
36c60a1e33
commit
c8223f1a63
4 changed files with 41 additions and 81 deletions
|
@ -12,8 +12,9 @@ import (
|
|||
|
||||
const (
|
||||
defaultClientTTL = 5 * time.Minute
|
||||
defaultRequestTimeout = 5 * time.Second
|
||||
connectionEOLGracePeriod = 10 * time.Second
|
||||
defaultRequestTimeout = 3 * time.Second // dns query
|
||||
defaultConnectTimeout = 2 * time.Second // tcp/tls
|
||||
connectionEOLGracePeriod = 7 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -43,23 +44,17 @@ type dnsClientManager struct {
|
|||
factory func() *dns.Client
|
||||
|
||||
// internal
|
||||
pool []*dnsClient
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
type dnsClient struct {
|
||||
mgr *dnsClientManager
|
||||
|
||||
inUse bool
|
||||
useUntil time.Time
|
||||
dead bool
|
||||
inPool bool
|
||||
poolIndex int
|
||||
|
||||
client *dns.Client
|
||||
conn *dns.Conn
|
||||
mgr *dnsClientManager
|
||||
client *dns.Client
|
||||
conn *dns.Conn
|
||||
useUntil time.Time
|
||||
}
|
||||
|
||||
// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
||||
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
||||
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
||||
if dc.conn == nil {
|
||||
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
|
||||
|
@ -71,23 +66,11 @@ func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
|||
return dc.conn, false, nil
|
||||
}
|
||||
|
||||
func (dc *dnsClient) done() {
|
||||
dc.mgr.lock.Lock()
|
||||
defer dc.mgr.lock.Unlock()
|
||||
|
||||
dc.inUse = false
|
||||
func (dc *dnsClient) addToPool() {
|
||||
dc.mgr.pool.Put(dc)
|
||||
}
|
||||
|
||||
func (dc *dnsClient) destroy() {
|
||||
dc.mgr.lock.Lock()
|
||||
dc.inUse = true // block from being used
|
||||
dc.dead = true // abort cleaning
|
||||
if dc.inPool {
|
||||
dc.inPool = false
|
||||
dc.mgr.pool[dc.poolIndex] = nil
|
||||
}
|
||||
dc.mgr.lock.Unlock()
|
||||
|
||||
if dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
|
@ -118,6 +101,7 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
|||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
Timeout: defaultConnectTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
|
@ -140,6 +124,7 @@ func newTLSClientManager(resolver *Resolver) *dnsClientManager {
|
|||
Timeout: defaultRequestTimeout,
|
||||
Dialer: &net.Dialer{
|
||||
LocalAddr: getLocalAddr("tcp"),
|
||||
Timeout: defaultConnectTimeout,
|
||||
KeepAlive: defaultClientTTL,
|
||||
},
|
||||
}
|
||||
|
@ -159,11 +144,18 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient {
|
|||
}
|
||||
}
|
||||
|
||||
// get first unused from pool
|
||||
// get cached client from pool
|
||||
now := time.Now().UTC()
|
||||
for _, dc := range cm.pool {
|
||||
if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
|
||||
dc.inUse = true
|
||||
|
||||
poolLoop:
|
||||
for {
|
||||
dc, ok := cm.pool.Get().(*dnsClient)
|
||||
switch {
|
||||
case !ok || dc == nil: // cache empty (probably, pool may always return nil!)
|
||||
break poolLoop // create new
|
||||
case now.After(dc.useUntil):
|
||||
continue // get next
|
||||
default:
|
||||
return dc
|
||||
}
|
||||
}
|
||||
|
@ -171,27 +163,11 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient {
|
|||
// no available in pool, create new
|
||||
newClient := &dnsClient{
|
||||
mgr: cm,
|
||||
inUse: true,
|
||||
useUntil: now.Add(cm.ttl),
|
||||
inPool: true,
|
||||
client: cm.factory(),
|
||||
useUntil: now.Add(cm.ttl),
|
||||
}
|
||||
newClient.startCleaner()
|
||||
|
||||
// find free spot in pool
|
||||
for poolIndex, dc := range cm.pool {
|
||||
if dc == nil {
|
||||
cm.pool[poolIndex] = newClient
|
||||
newClient.poolIndex = poolIndex
|
||||
return newClient
|
||||
}
|
||||
}
|
||||
|
||||
// append to pool
|
||||
cm.pool = append(cm.pool, newClient)
|
||||
newClient.poolIndex = len(cm.pool) - 1
|
||||
// TODO: shrink pool again?
|
||||
|
||||
return newClient
|
||||
}
|
||||
|
||||
|
@ -200,26 +176,12 @@ func (dc *dnsClient) startCleaner() {
|
|||
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
|
||||
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
|
||||
select {
|
||||
case <-time.After(dc.mgr.ttl + time.Second):
|
||||
dc.mgr.lock.Lock()
|
||||
cleanNow := dc.dead || !dc.inUse
|
||||
dc.mgr.lock.Unlock()
|
||||
|
||||
if cleanNow {
|
||||
dc.destroy()
|
||||
return nil
|
||||
}
|
||||
case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod):
|
||||
// destroy
|
||||
case <-ctx.Done():
|
||||
// give a short time before kill for graceful request completion
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// wait for grace period to end, then kill
|
||||
select {
|
||||
case <-time.After(connectionEOLGracePeriod):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
dc.destroy()
|
||||
return nil
|
||||
})
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -11,7 +12,7 @@ var (
|
|||
domainFeed = make(chan string)
|
||||
)
|
||||
|
||||
func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) {
|
||||
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
|
||||
dnsClient := brc.clientManager.getDNSClient()
|
||||
|
||||
// create query
|
||||
|
@ -23,6 +24,9 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer
|
|||
if err != nil {
|
||||
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
|
||||
}
|
||||
if new {
|
||||
atomic.AddUint32(newCnt, 1)
|
||||
}
|
||||
|
||||
// query server
|
||||
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
|
||||
|
@ -33,8 +37,8 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer
|
|||
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
|
||||
}
|
||||
|
||||
t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl)
|
||||
dnsClient.done()
|
||||
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
|
||||
dnsClient.addToPool()
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
|
@ -54,17 +58,18 @@ func TestClientPooling(t *testing.T) {
|
|||
brc := resolver.Conn.(*BasicResolverConn)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
var newCnt uint32
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go testQuery(t, wg, brc, &Query{
|
||||
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
|
||||
FQDN: <-domainFeed,
|
||||
QType: dns.Type(dns.TypeA),
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
if len(brc.clientManager.pool) != 10 {
|
||||
t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool))
|
||||
if newCnt > uint32(10+i) {
|
||||
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,8 +14,6 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
mtAsyncResolve = "async resolve"
|
||||
|
||||
// basic errors
|
||||
|
||||
// ErrNotFound is a basic error that will match all "not found" errors
|
||||
|
@ -160,7 +158,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache {
|
|||
log.Tracer(ctx).Trace("resolver: serving from cache, requesting new")
|
||||
|
||||
// resolve async
|
||||
module.StartLowPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error {
|
||||
module.StartWorker("resolve async", func(ctx context.Context) error {
|
||||
_, _ = resolveAndCache(ctx, q)
|
||||
return nil
|
||||
})
|
||||
|
@ -220,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error
|
|||
return nil, ErrNoCompliance
|
||||
}
|
||||
|
||||
// prep
|
||||
lastFailBoundary := time.Now().Add(
|
||||
-time.Duration(nameserverRetryRate()) * time.Second,
|
||||
)
|
||||
|
||||
// start resolving
|
||||
|
||||
var i int
|
||||
|
|
|
@ -215,8 +215,8 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er
|
|||
return nil, errors.New("internal error")
|
||||
}
|
||||
|
||||
// make client available again
|
||||
dc.done()
|
||||
// make client available (again)
|
||||
dc.addToPool()
|
||||
|
||||
if resolver.IsBlockedUpstream(reply) {
|
||||
return nil, &BlockedUpstreamError{resolver.GetName()}
|
||||
|
|
Loading…
Add table
Reference in a new issue