Removed pooled server and add plain resolver

This commit is contained in:
Daniel 2020-07-21 14:58:14 +02:00
parent b87ba37d4c
commit 7b40e83aec
5 changed files with 112 additions and 284 deletions

View file

@ -2,6 +2,7 @@ package resolver
import (
"context"
"net"
"strings"
"time"
@ -94,3 +95,21 @@ func start() error {
return nil
}
var (
localAddrFactory func(network string) net.Addr
)
// SetLocalAddrFactory supplies the intel package with a function to get permitted local addresses for connections.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
if localAddrFactory == nil {
localAddrFactory = laf
}
}
func getLocalAddr(network string) net.Addr {
if localAddrFactory != nil {
return localAddrFactory(network)
}
return nil
}

View file

@ -0,0 +1,92 @@
package resolver
import (
"context"
"net"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netenv"
)
var (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 5 * time.Second // tcp/tls
)
// PlainResolver is a resolver using plain DNS.
type PlainResolver struct {
BasicResolverConn
}
// NewPlainResolver returns a new TPCResolver.
func NewPlainResolver(resolver *Resolver) *PlainResolver {
return &PlainResolver{
BasicResolverConn: BasicResolverConn{
resolver: resolver,
},
}
}
// Query executes the given query against the resolver.
func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get timeout from context and config
var timeout time.Duration
if deadline, ok := ctx.Deadline(); !ok {
timeout = 0
} else {
timeout = time.Until(deadline)
}
if timeout > defaultRequestTimeout {
timeout = defaultRequestTimeout
}
// create client
dnsClient := &dns.Client{
Timeout: timeout,
Dialer: &net.Dialer{
Timeout: timeout,
LocalAddr: getLocalAddr("udp"),
},
}
// query server
reply, ttl, err := dnsClient.Exchange(dnsQuery, pr.resolver.ServerAddress)
log.Tracer(ctx).Tracef("resolver: query took %s", ttl)
// error handling
if err != nil {
// Hint network environment at failed connection if err is not a timeout.
if nErr, ok := err.(net.Error); ok && !nErr.Timeout() {
netenv.ReportFailedConnection()
}
return nil, err
}
// check if blocked
if pr.resolver.IsBlockedUpstream(reply) {
return nil, &BlockedUpstreamError{pr.resolver.GetName()}
}
// hint network environment at successful connection
netenv.ReportSuccessfulConnection()
newRecord := &RRCache{
Domain: q.FQDN,
Question: q.QType,
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
Server: pr.resolver.Server,
ServerScope: pr.resolver.ServerIPScope,
}
// TODO: check if reply.Answer is valid
return newRecord, nil
}

View file

@ -1,187 +0,0 @@
package resolver
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/utils"
)
var (
defaultClientTTL = 5 * time.Minute
defaultRequestTimeout = 3 * time.Second // dns query
defaultConnectTimeout = 5 * time.Second // tcp/tls
connectionEOLGracePeriod = 7 * time.Second
localAddrFactory func(network string) net.Addr
)
// SetLocalAddrFactory supplies the intel package with a function to get permitted local addresses for connections.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
if localAddrFactory == nil {
localAddrFactory = laf
}
}
func getLocalAddr(network string) net.Addr {
if localAddrFactory != nil {
return localAddrFactory(network)
}
return nil
}
type dnsClientManager struct {
lock sync.Mutex
// set by creator
resolver *Resolver
ttl time.Duration // force refresh of connection to reduce traceability
factory func() *dns.Client
// internal
pool utils.StablePool
}
type dnsClient struct {
mgr *dnsClientManager
client *dns.Client
conn *dns.Conn
useUntil time.Time
}
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
if dc.conn == nil {
dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress)
if err != nil {
return nil, false, err
}
return dc.conn, true, nil
}
return dc.conn, false, nil
}
func (dc *dnsClient) addToPool() {
dc.mgr.pool.Put(dc)
}
func (dc *dnsClient) destroy() {
if dc.conn != nil {
_ = dc.conn.Close()
}
}
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
resolver: resolver,
ttl: 0, // new client for every request, as we need to randomize the port
factory: func() *dns.Client {
return &dns.Client{
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("udp"),
},
}
},
}
}
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
resolver: resolver,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp",
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
return &dnsClientManager{
resolver: resolver,
ttl: defaultClientTTL,
factory: func() *dns.Client {
return &dns.Client{
Net: "tcp-tls",
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: resolver.VerifyDomain,
// TODO: use portbase rng
},
Timeout: defaultRequestTimeout,
Dialer: &net.Dialer{
LocalAddr: getLocalAddr("tcp"),
Timeout: defaultConnectTimeout,
KeepAlive: defaultClientTTL,
},
}
},
}
}
func (cm *dnsClientManager) getDNSClient() *dnsClient {
cm.lock.Lock()
defer cm.lock.Unlock()
// return new immediately if a new client should be used for every request
if cm.ttl == 0 {
return &dnsClient{
mgr: cm,
client: cm.factory(),
}
}
// get cached client from pool
now := time.Now().UTC()
poolLoop:
for {
dc, ok := cm.pool.Get().(*dnsClient)
switch {
case !ok || dc == nil: // cache empty (probably, pool may always return nil!)
break poolLoop // create new
case now.After(dc.useUntil):
continue // get next
default:
return dc
}
}
// no available in pool, create new
newClient := &dnsClient{
mgr: cm,
client: cm.factory(),
useUntil: now.Add(cm.ttl),
}
newClient.startCleaner()
return newClient
}
// startCleaner waits for EOL of the client and then removes it from the pool.
func (dc *dnsClient) startCleaner() {
// While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone.
module.StartWorker("dns client cleanup", func(ctx context.Context) error {
select {
case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod):
// destroy
case <-ctx.Done():
// give a short time before kill for graceful request completion
time.Sleep(100 * time.Millisecond)
}
dc.destroy()
return nil
})
}

View file

@ -1,82 +0,0 @@
package resolver
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
)
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
dnsClient := brc.clientManager.getDNSClient()
// create query
dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
// get connection
conn, new, err := dnsClient.getConn()
if err != nil {
t.Logf("failed to connect: %s", err) //nolint:staticcheck
wg.Done()
return
}
if new {
atomic.AddUint32(newCnt, 1)
}
// query server
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
if err != nil {
t.Logf("client failed: %s", err) //nolint:staticcheck
wg.Done()
return
}
if reply == nil {
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
}
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
dnsClient.addToPool()
wg.Done()
}
func TestClientPooling(t *testing.T) {
// skip if short - this test depends on the Internet and might fail randomly
if testing.Short() {
t.Skip()
}
// create separate resolver for this test
resolver, _, err := createResolver(testResolver, "config")
if err != nil {
t.Fatal(err)
}
brc := &BasicResolverConn{
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
resolver: resolver,
}
resolver.Conn = brc
started := time.Now()
wg := &sync.WaitGroup{}
var newCnt uint32
for i := 0; i < 10; i++ {
wg.Add(10)
for j := 0; j < 10; j++ {
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
FQDN: <-domainFeed,
QType: dns.Type(dns.TypeA),
})
}
wg.Wait()
if newCnt > uint32(10+i) {
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
}
}
t.Logf("time taken: %s", time.Since(started))
}

View file

@ -69,22 +69,8 @@ func resolverConnFactory(resolver *Resolver) ResolverConn {
return NewTCPResolver(resolver)
case ServerTypeDoT:
return NewTCPResolver(resolver).UseTLS()
default:
return &BasicResolverConn{
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
resolver: resolver,
}
}
}
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
switch serverType {
case ServerTypeDNS:
return newDNSClientManager
case ServerTypeDoT:
return newTLSClientManager
case ServerTypeTCP:
return newTCPClientManager
return NewPlainResolver(resolver)
default:
return nil
}