safing-portmaster/resolver/clients.go
Daniel 53eb309e72 Add TLS resolver connection reusing and pooling
Also, fix caching issues and add more tests
2020-05-15 22:43:06 +02:00

package resolver

import (
	"context"
	"crypto/tls"
	"net"
	"sync"
	"time"

	"github.com/miekg/dns"
)

const (
	defaultClientTTL         = 5 * time.Minute
	defaultRequestTimeout    = 5 * time.Second
	connectionEOLGracePeriod = 10 * time.Second
)

var (
	localAddrFactory func(network string) net.Addr
)

// SetLocalAddrFactory supplies the resolver package with a function to get permitted local addresses for connections.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
	if localAddrFactory == nil {
		localAddrFactory = laf
	}
}

func getLocalAddr(network string) net.Addr {
	if localAddrFactory != nil {
		return localAddrFactory(network)
	}
	return nil
}

type dnsClientManager struct {
	lock sync.Mutex

	// set by creator
	serverAddress string
	ttl           time.Duration // force refresh of connection to reduce traceability
	factory       func() *dns.Client

	// internal
	pool []*dnsClient
}

type dnsClient struct {
	mgr       *dnsClientManager
	inUse     bool
	useUntil  time.Time
	dead      bool
	inPool    bool
	poolIndex int
	client    *dns.Client
	conn      *dns.Conn
}

// getConn returns the *dns.Conn and whether it is new. This function may only be called between dnsClientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
	if dc.conn == nil {
		dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
		if err != nil {
			return nil, false, err
		}
		return dc.conn, true, nil
	}
	return dc.conn, false, nil
}

// done marks the client as no longer in use.
func (dc *dnsClient) done() {
	dc.mgr.lock.Lock()
	defer dc.mgr.lock.Unlock()
	dc.inUse = false
}

// destroy removes the client from the pool and closes its connection.
func (dc *dnsClient) destroy() {
	dc.mgr.lock.Lock()
	dc.inUse = true // block from being used
	dc.dead = true  // abort cleaning
	if dc.inPool {
		dc.inPool = false
		dc.mgr.pool[dc.poolIndex] = nil
	}
	dc.mgr.lock.Unlock()

	if dc.conn != nil {
		_ = dc.conn.Close()
	}
}

// newDNSClientManager returns a client manager for plain DNS over UDP.
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
	return &dnsClientManager{
		serverAddress: resolver.ServerAddress,
		ttl:           0, // new client for every request, as we need to randomize the port
		factory: func() *dns.Client {
			return &dns.Client{
				Timeout: defaultRequestTimeout,
				Dialer: &net.Dialer{
					LocalAddr: getLocalAddr("udp"),
				},
			}
		},
	}
}

// newTCPClientManager returns a client manager for DNS over TCP.
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
	return &dnsClientManager{
		serverAddress: resolver.ServerAddress,
		ttl:           defaultClientTTL,
		factory: func() *dns.Client {
			return &dns.Client{
				Net:     "tcp",
				Timeout: defaultRequestTimeout,
				Dialer: &net.Dialer{
					LocalAddr: getLocalAddr("tcp"),
					KeepAlive: defaultClientTTL,
				},
			}
		},
	}
}

// newTLSClientManager returns a client manager for DNS over TLS.
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
	return &dnsClientManager{
		serverAddress: resolver.ServerAddress,
		ttl:           defaultClientTTL,
		factory: func() *dns.Client {
			return &dns.Client{
				Net: "tcp-tls",
				TLSConfig: &tls.Config{
					MinVersion: tls.VersionTLS12,
					ServerName: resolver.VerifyDomain,
					// TODO: use portbase rng
				},
				Timeout: defaultRequestTimeout,
				Dialer: &net.Dialer{
					LocalAddr: getLocalAddr("tcp"),
					KeepAlive: defaultClientTTL,
				},
			}
		},
	}
}
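
// Illustrative sketch (assumption, not part of the original file): how the
// three constructors above might be selected based on a resolver's transport.
// The function name and protocol strings are assumptions; only the
// constructors themselves are real.
func exampleNewClientManager(resolver *Resolver, protocol string) *dnsClientManager {
	switch protocol {
	case "tcp":
		// Pooled TCP connections, refreshed every defaultClientTTL.
		return newTCPClientManager(resolver)
	case "dot":
		// DNS over TLS: pooled TLS connections verified against resolver.VerifyDomain.
		return newTLSClientManager(resolver)
	default:
		// Plain DNS over UDP: a fresh client (and source port) for every request.
		return newDNSClientManager(resolver)
	}
}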

// getDNSClient returns an unused client from the pool, creating a new one if necessary.
func (cm *dnsClientManager) getDNSClient() *dnsClient {
	cm.lock.Lock()
	defer cm.lock.Unlock()

	// return a new client immediately if a new client should be used for every request
	if cm.ttl == 0 {
		return &dnsClient{
			mgr:    cm,
			client: cm.factory(),
		}
	}

	// get first unused client from pool
	now := time.Now().UTC()
	for _, dc := range cm.pool {
		if dc != nil && !dc.inUse && now.Before(dc.useUntil) {
			dc.inUse = true
			return dc
		}
	}

	// no client available in pool, create a new one
	newClient := &dnsClient{
		mgr:      cm,
		inUse:    true,
		useUntil: now.Add(cm.ttl),
		inPool:   true,
		client:   cm.factory(),
	}
	newClient.startCleaner()

	// find a free spot in the pool
	for poolIndex, dc := range cm.pool {
		if dc == nil {
			cm.pool[poolIndex] = newClient
			newClient.poolIndex = poolIndex
			return newClient
		}
	}

	// append to pool
	cm.pool = append(cm.pool, newClient)
	newClient.poolIndex = len(cm.pool) - 1

	// TODO: shrink pool again?
	return newClient
}
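
// Minimal usage sketch (assumption, not part of the original file): the
// intended call pattern of getDNSClient, getConn, done and destroy when
// issuing a query. Client.ExchangeWithConn is assumed to be available in the
// used github.com/miekg/dns version; the error handling policy is illustrative.
func exampleQuery(cm *dnsClientManager, q *dns.Msg) (*dns.Msg, error) {
	dc := cm.getDNSClient()

	conn, _, err := dc.getConn()
	if err != nil {
		// The connection could not be established; remove the client.
		dc.destroy()
		return nil, err
	}

	reply, _, err := dc.client.ExchangeWithConn(q, conn)
	if err != nil {
		// The connection is likely broken: remove the client from the pool
		// and close its connection instead of returning it.
		dc.destroy()
		return nil, err
	}

	// Mark the client as unused so it can be picked up from the pool again.
	dc.done()
	return reply, nil
}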

// startCleaner waits for the EOL of the client and then removes it from the pool.
func (dc *dnsClient) startCleaner() {
	// While a single worker cleaning all connections might be slightly more performant, this approach uses as little locking as possible and is simpler, and thus less error-prone.
	module.StartWorker("dns client cleanup", func(ctx context.Context) error {
		select {
		case <-time.After(dc.mgr.ttl + time.Second):
			dc.mgr.lock.Lock()
			cleanNow := dc.dead || !dc.inUse
			dc.mgr.lock.Unlock()
			if cleanNow {
				dc.destroy()
				return nil
			}
		case <-ctx.Done():
			// give requests a short time to complete gracefully before killing the connection
			time.Sleep(100 * time.Millisecond)
		}

		// wait for the grace period to end, then kill the connection
		select {
		case <-time.After(connectionEOLGracePeriod):
		case <-ctx.Done():
		}
		dc.destroy()
		return nil
	})
}