From 7b40e83aecd154d5fb9ee5bffbc278833c72f00f Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 21 Jul 2020 14:58:14 +0200 Subject: [PATCH] Removed pooled server and add plain resolver --- resolver/main.go | 19 ++++ resolver/resolver-plain.go | 92 +++++++++++++++ resolver/resolver-pooled.go | 187 ------------------------------- resolver/resolver-pooled_test.go | 82 -------------- resolver/resolvers.go | 16 +-- 5 files changed, 112 insertions(+), 284 deletions(-) create mode 100644 resolver/resolver-plain.go delete mode 100644 resolver/resolver-pooled.go delete mode 100644 resolver/resolver-pooled_test.go diff --git a/resolver/main.go b/resolver/main.go index a68bd76e..c586d0df 100644 --- a/resolver/main.go +++ b/resolver/main.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "net" "strings" "time" @@ -94,3 +95,21 @@ func start() error { return nil } + +var ( + localAddrFactory func(network string) net.Addr +) + +// SetLocalAddrFactory supplies the intel package with a function to get permitted local addresses for connections. +func SetLocalAddrFactory(laf func(network string) net.Addr) { + if localAddrFactory == nil { + localAddrFactory = laf + } +} + +func getLocalAddr(network string) net.Addr { + if localAddrFactory != nil { + return localAddrFactory(network) + } + return nil +} diff --git a/resolver/resolver-plain.go b/resolver/resolver-plain.go new file mode 100644 index 00000000..991cf8a1 --- /dev/null +++ b/resolver/resolver-plain.go @@ -0,0 +1,92 @@ +package resolver + +import ( + "context" + "net" + "time" + + "github.com/miekg/dns" + "github.com/safing/portbase/log" + "github.com/safing/portmaster/netenv" +) + +var ( + defaultClientTTL = 5 * time.Minute + defaultRequestTimeout = 3 * time.Second // dns query + defaultConnectTimeout = 5 * time.Second // tcp/tls +) + +// PlainResolver is a resolver using plain DNS. +type PlainResolver struct { + BasicResolverConn +} + +// NewPlainResolver returns a new TPCResolver. +func NewPlainResolver(resolver *Resolver) *PlainResolver { + return &PlainResolver{ + BasicResolverConn: BasicResolverConn{ + resolver: resolver, + }, + } +} + +// Query executes the given query against the resolver. +func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // get timeout from context and config + var timeout time.Duration + if deadline, ok := ctx.Deadline(); !ok { + timeout = 0 + } else { + timeout = time.Until(deadline) + } + if timeout > defaultRequestTimeout { + timeout = defaultRequestTimeout + } + + // create client + dnsClient := &dns.Client{ + Timeout: timeout, + Dialer: &net.Dialer{ + Timeout: timeout, + LocalAddr: getLocalAddr("udp"), + }, + } + + // query server + reply, ttl, err := dnsClient.Exchange(dnsQuery, pr.resolver.ServerAddress) + log.Tracer(ctx).Tracef("resolver: query took %s", ttl) + // error handling + if err != nil { + // Hint network environment at failed connection if err is not a timeout. + if nErr, ok := err.(net.Error); ok && !nErr.Timeout() { + netenv.ReportFailedConnection() + } + + return nil, err + } + + // check if blocked + if pr.resolver.IsBlockedUpstream(reply) { + return nil, &BlockedUpstreamError{pr.resolver.GetName()} + } + + // hint network environment at successful connection + netenv.ReportSuccessfulConnection() + + newRecord := &RRCache{ + Domain: q.FQDN, + Question: q.QType, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Server: pr.resolver.Server, + ServerScope: pr.resolver.ServerIPScope, + } + + // TODO: check if reply.Answer is valid + return newRecord, nil +} diff --git a/resolver/resolver-pooled.go b/resolver/resolver-pooled.go deleted file mode 100644 index aed6c5a9..00000000 --- a/resolver/resolver-pooled.go +++ /dev/null @@ -1,187 +0,0 @@ -package resolver - -import ( - "context" - "crypto/tls" - "net" - "sync" - "time" - - "github.com/miekg/dns" - "github.com/safing/portbase/utils" -) - -var ( - defaultClientTTL = 5 * time.Minute - defaultRequestTimeout = 3 * time.Second // dns query - defaultConnectTimeout = 5 * time.Second // tcp/tls - connectionEOLGracePeriod = 7 * time.Second - - localAddrFactory func(network string) net.Addr -) - -// SetLocalAddrFactory supplies the intel package with a function to get permitted local addresses for connections. -func SetLocalAddrFactory(laf func(network string) net.Addr) { - if localAddrFactory == nil { - localAddrFactory = laf - } -} - -func getLocalAddr(network string) net.Addr { - if localAddrFactory != nil { - return localAddrFactory(network) - } - return nil -} - -type dnsClientManager struct { - lock sync.Mutex - - // set by creator - resolver *Resolver - ttl time.Duration // force refresh of connection to reduce traceability - factory func() *dns.Client - - // internal - pool utils.StablePool -} - -type dnsClient struct { - mgr *dnsClientManager - client *dns.Client - conn *dns.Conn - useUntil time.Time -} - -// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). -func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { - if dc.conn == nil { - dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress) - if err != nil { - return nil, false, err - } - return dc.conn, true, nil - } - return dc.conn, false, nil -} - -func (dc *dnsClient) addToPool() { - dc.mgr.pool.Put(dc) -} - -func (dc *dnsClient) destroy() { - if dc.conn != nil { - _ = dc.conn.Close() - } -} - -func newDNSClientManager(resolver *Resolver) *dnsClientManager { - return &dnsClientManager{ - resolver: resolver, - ttl: 0, // new client for every request, as we need to randomize the port - factory: func() *dns.Client { - return &dns.Client{ - Timeout: defaultRequestTimeout, - Dialer: &net.Dialer{ - LocalAddr: getLocalAddr("udp"), - }, - } - }, - } -} - -func newTCPClientManager(resolver *Resolver) *dnsClientManager { - return &dnsClientManager{ - resolver: resolver, - ttl: defaultClientTTL, - factory: func() *dns.Client { - return &dns.Client{ - Net: "tcp", - Timeout: defaultRequestTimeout, - Dialer: &net.Dialer{ - LocalAddr: getLocalAddr("tcp"), - Timeout: defaultConnectTimeout, - KeepAlive: defaultClientTTL, - }, - } - }, - } -} - -func newTLSClientManager(resolver *Resolver) *dnsClientManager { - return &dnsClientManager{ - resolver: resolver, - ttl: defaultClientTTL, - factory: func() *dns.Client { - return &dns.Client{ - Net: "tcp-tls", - TLSConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - ServerName: resolver.VerifyDomain, - // TODO: use portbase rng - }, - Timeout: defaultRequestTimeout, - Dialer: &net.Dialer{ - LocalAddr: getLocalAddr("tcp"), - Timeout: defaultConnectTimeout, - KeepAlive: defaultClientTTL, - }, - } - }, - } -} - -func (cm *dnsClientManager) getDNSClient() *dnsClient { - cm.lock.Lock() - defer cm.lock.Unlock() - - // return new immediately if a new client should be used for every request - if cm.ttl == 0 { - return &dnsClient{ - mgr: cm, - client: cm.factory(), - } - } - - // get cached client from pool - now := time.Now().UTC() - -poolLoop: - for { - dc, ok := cm.pool.Get().(*dnsClient) - switch { - case !ok || dc == nil: // cache empty (probably, pool may always return nil!) - break poolLoop // create new - case now.After(dc.useUntil): - continue // get next - default: - return dc - } - } - - // no available in pool, create new - newClient := &dnsClient{ - mgr: cm, - client: cm.factory(), - useUntil: now.Add(cm.ttl), - } - newClient.startCleaner() - - return newClient -} - -// startCleaner waits for EOL of the client and then removes it from the pool. -func (dc *dnsClient) startCleaner() { - // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. - module.StartWorker("dns client cleanup", func(ctx context.Context) error { - select { - case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod): - // destroy - case <-ctx.Done(): - // give a short time before kill for graceful request completion - time.Sleep(100 * time.Millisecond) - } - dc.destroy() - return nil - }) -} diff --git a/resolver/resolver-pooled_test.go b/resolver/resolver-pooled_test.go deleted file mode 100644 index b7da984a..00000000 --- a/resolver/resolver-pooled_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package resolver - -import ( - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/miekg/dns" -) - -func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { - dnsClient := brc.clientManager.getDNSClient() - - // create query - dnsQuery := new(dns.Msg) - dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) - - // get connection - conn, new, err := dnsClient.getConn() - if err != nil { - t.Logf("failed to connect: %s", err) //nolint:staticcheck - wg.Done() - return - } - if new { - atomic.AddUint32(newCnt, 1) - } - - // query server - reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) - if err != nil { - t.Logf("client failed: %s", err) //nolint:staticcheck - wg.Done() - return - } - if reply == nil { - t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck - } - - t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) - dnsClient.addToPool() - wg.Done() -} - -func TestClientPooling(t *testing.T) { - // skip if short - this test depends on the Internet and might fail randomly - if testing.Short() { - t.Skip() - } - - // create separate resolver for this test - resolver, _, err := createResolver(testResolver, "config") - if err != nil { - t.Fatal(err) - } - brc := &BasicResolverConn{ - clientManager: clientManagerFactory(resolver.ServerType)(resolver), - resolver: resolver, - } - resolver.Conn = brc - - started := time.Now() - - wg := &sync.WaitGroup{} - var newCnt uint32 - for i := 0; i < 10; i++ { - wg.Add(10) - for j := 0; j < 10; j++ { - go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck - FQDN: <-domainFeed, - QType: dns.Type(dns.TypeA), - }) - } - wg.Wait() - if newCnt > uint32(10+i) { - t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) - } - } - - t.Logf("time taken: %s", time.Since(started)) -} diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 0dfa3881..d60caf00 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -69,22 +69,8 @@ func resolverConnFactory(resolver *Resolver) ResolverConn { return NewTCPResolver(resolver) case ServerTypeDoT: return NewTCPResolver(resolver).UseTLS() - default: - return &BasicResolverConn{ - clientManager: clientManagerFactory(resolver.ServerType)(resolver), - resolver: resolver, - } - } -} - -func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager { - switch serverType { case ServerTypeDNS: - return newDNSClientManager - case ServerTypeDoT: - return newTLSClientManager - case ServerTypeTCP: - return newTCPClientManager + return NewPlainResolver(resolver) default: return nil }