mirror of
https://github.com/safing/portmaster
synced 2025-09-01 10:09:11 +00:00
Add TCP/TLS pipelining dns resolver
This commit is contained in:
parent
fe3b61f1a3
commit
f7320d760d
12 changed files with 672 additions and 208 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -12,6 +12,9 @@ dist
|
||||||
# vendor dir
|
# vendor dir
|
||||||
vendor
|
vendor
|
||||||
|
|
||||||
|
# testing
|
||||||
|
testing
|
||||||
|
|
||||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||||
*.o
|
*.o
|
||||||
*.a
|
*.a
|
||||||
|
|
|
@ -54,7 +54,9 @@ func TestMain(m *testing.M, module *modules.Module) {
|
||||||
// are shutdown.
|
// are shutdown.
|
||||||
func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) {
|
func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) {
|
||||||
// enable module for testing
|
// enable module for testing
|
||||||
module.Enable()
|
if module != nil {
|
||||||
|
module.Enable()
|
||||||
|
}
|
||||||
|
|
||||||
// switch databases to memory only
|
// switch databases to memory only
|
||||||
base.DefaultDatabaseStorageType = "hashmap"
|
base.DefaultDatabaseStorageType = "hashmap"
|
||||||
|
|
|
@ -6,6 +6,128 @@ import (
|
||||||
"github.com/safing/portmaster/core/pmtesting"
|
"github.com/safing/portmaster/core/pmtesting"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
domainFeed = make(chan string)
|
||||||
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
pmtesting.TestMain(m, module)
|
pmtesting.TestMain(m, module)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go feedDomains()
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedDomains() {
|
||||||
|
for {
|
||||||
|
for _, domain := range testDomains {
|
||||||
|
domainFeed <- domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data
|
||||||
|
|
||||||
|
var (
|
||||||
|
testDomains = []string{
|
||||||
|
"facebook.com.",
|
||||||
|
"google.com.",
|
||||||
|
"youtube.com.",
|
||||||
|
"twitter.com.",
|
||||||
|
"instagram.com.",
|
||||||
|
"linkedin.com.",
|
||||||
|
"microsoft.com.",
|
||||||
|
"apple.com.",
|
||||||
|
"wikipedia.org.",
|
||||||
|
"plus.google.com.",
|
||||||
|
"en.wikipedia.org.",
|
||||||
|
"googletagmanager.com.",
|
||||||
|
"youtu.be.",
|
||||||
|
"adobe.com.",
|
||||||
|
"vimeo.com.",
|
||||||
|
"pinterest.com.",
|
||||||
|
"itunes.apple.com.",
|
||||||
|
"play.google.com.",
|
||||||
|
"maps.google.com.",
|
||||||
|
"goo.gl.",
|
||||||
|
"wordpress.com.",
|
||||||
|
"blogspot.com.",
|
||||||
|
"bit.ly.",
|
||||||
|
"github.com.",
|
||||||
|
"player.vimeo.com.",
|
||||||
|
"amazon.com.",
|
||||||
|
"wordpress.org.",
|
||||||
|
"docs.google.com.",
|
||||||
|
"yahoo.com.",
|
||||||
|
"mozilla.org.",
|
||||||
|
"tumblr.com.",
|
||||||
|
"godaddy.com.",
|
||||||
|
"flickr.com.",
|
||||||
|
"parked-content.godaddy.com.",
|
||||||
|
"drive.google.com.",
|
||||||
|
"support.google.com.",
|
||||||
|
"apache.org.",
|
||||||
|
"gravatar.com.",
|
||||||
|
"europa.eu.",
|
||||||
|
"qq.com.",
|
||||||
|
"w3.org.",
|
||||||
|
"nytimes.com.",
|
||||||
|
"reddit.com.",
|
||||||
|
"macromedia.com.",
|
||||||
|
"get.adobe.com.",
|
||||||
|
"soundcloud.com.",
|
||||||
|
"sourceforge.net.",
|
||||||
|
"sites.google.com.",
|
||||||
|
"nih.gov.",
|
||||||
|
"amazonaws.com.",
|
||||||
|
"t.co.",
|
||||||
|
"support.microsoft.com.",
|
||||||
|
"forbes.com.",
|
||||||
|
"theguardian.com.",
|
||||||
|
"cnn.com.",
|
||||||
|
"github.io.",
|
||||||
|
"bbc.co.uk.",
|
||||||
|
"dropbox.com.",
|
||||||
|
"whatsapp.com.",
|
||||||
|
"medium.com.",
|
||||||
|
"creativecommons.org.",
|
||||||
|
"www.ncbi.nlm.nih.gov.",
|
||||||
|
"httpd.apache.org.",
|
||||||
|
"archive.org.",
|
||||||
|
"ec.europa.eu.",
|
||||||
|
"php.net.",
|
||||||
|
"apps.apple.com.",
|
||||||
|
"weebly.com.",
|
||||||
|
"support.apple.com.",
|
||||||
|
"weibo.com.",
|
||||||
|
"wixsite.com.",
|
||||||
|
"issuu.com.",
|
||||||
|
"who.int.",
|
||||||
|
"paypal.com.",
|
||||||
|
"m.facebook.com.",
|
||||||
|
"oracle.com.",
|
||||||
|
"msn.com.",
|
||||||
|
"gnu.org.",
|
||||||
|
"tinyurl.com.",
|
||||||
|
"reuters.com.",
|
||||||
|
"l.facebook.com.",
|
||||||
|
"cloudflare.com.",
|
||||||
|
"wsj.com.",
|
||||||
|
"washingtonpost.com.",
|
||||||
|
"domainmarket.com.",
|
||||||
|
"imdb.com.",
|
||||||
|
"bbc.com.",
|
||||||
|
"bing.com.",
|
||||||
|
"accounts.google.com.",
|
||||||
|
"vk.com.",
|
||||||
|
"api.whatsapp.com.",
|
||||||
|
"opera.com.",
|
||||||
|
"cdc.gov.",
|
||||||
|
"slideshare.net.",
|
||||||
|
"wpa.qq.com.",
|
||||||
|
"harvard.edu.",
|
||||||
|
"mit.edu.",
|
||||||
|
"code.google.com.",
|
||||||
|
"wikimedia.org.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
|
@ -1,189 +0,0 @@
|
||||||
package resolver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
domainFeed = make(chan string)
|
|
||||||
)
|
|
||||||
|
|
||||||
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
|
|
||||||
dnsClient := brc.clientManager.getDNSClient()
|
|
||||||
|
|
||||||
// create query
|
|
||||||
dnsQuery := new(dns.Msg)
|
|
||||||
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
|
|
||||||
|
|
||||||
// get connection
|
|
||||||
conn, new, err := dnsClient.getConn()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to connect: %s", err) //nolint:staticcheck
|
|
||||||
}
|
|
||||||
if new {
|
|
||||||
atomic.AddUint32(newCnt, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// query server
|
|
||||||
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err) //nolint:staticcheck
|
|
||||||
}
|
|
||||||
if reply == nil {
|
|
||||||
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
|
|
||||||
dnsClient.addToPool()
|
|
||||||
wg.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientPooling(t *testing.T) {
|
|
||||||
// skip if short - this test depends on the Internet and might fail randomly
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip()
|
|
||||||
}
|
|
||||||
|
|
||||||
go feedDomains()
|
|
||||||
|
|
||||||
// create separate resolver for this test
|
|
||||||
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
brc := resolver.Conn.(*BasicResolverConn)
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
var newCnt uint32
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
wg.Add(10)
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
|
|
||||||
FQDN: <-domainFeed,
|
|
||||||
QType: dns.Type(dns.TypeA),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if newCnt > uint32(10+i) {
|
|
||||||
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func feedDomains() {
|
|
||||||
for {
|
|
||||||
for _, domain := range poolingTestDomains {
|
|
||||||
domainFeed <- domain
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Data
|
|
||||||
|
|
||||||
var (
|
|
||||||
poolingTestDomains = []string{
|
|
||||||
"facebook.com.",
|
|
||||||
"google.com.",
|
|
||||||
"youtube.com.",
|
|
||||||
"twitter.com.",
|
|
||||||
"instagram.com.",
|
|
||||||
"linkedin.com.",
|
|
||||||
"microsoft.com.",
|
|
||||||
"apple.com.",
|
|
||||||
"wikipedia.org.",
|
|
||||||
"plus.google.com.",
|
|
||||||
"en.wikipedia.org.",
|
|
||||||
"googletagmanager.com.",
|
|
||||||
"youtu.be.",
|
|
||||||
"adobe.com.",
|
|
||||||
"vimeo.com.",
|
|
||||||
"pinterest.com.",
|
|
||||||
"itunes.apple.com.",
|
|
||||||
"play.google.com.",
|
|
||||||
"maps.google.com.",
|
|
||||||
"goo.gl.",
|
|
||||||
"wordpress.com.",
|
|
||||||
"blogspot.com.",
|
|
||||||
"bit.ly.",
|
|
||||||
"github.com.",
|
|
||||||
"player.vimeo.com.",
|
|
||||||
"amazon.com.",
|
|
||||||
"wordpress.org.",
|
|
||||||
"docs.google.com.",
|
|
||||||
"yahoo.com.",
|
|
||||||
"mozilla.org.",
|
|
||||||
"tumblr.com.",
|
|
||||||
"godaddy.com.",
|
|
||||||
"flickr.com.",
|
|
||||||
"parked-content.godaddy.com.",
|
|
||||||
"drive.google.com.",
|
|
||||||
"support.google.com.",
|
|
||||||
"apache.org.",
|
|
||||||
"gravatar.com.",
|
|
||||||
"europa.eu.",
|
|
||||||
"qq.com.",
|
|
||||||
"w3.org.",
|
|
||||||
"nytimes.com.",
|
|
||||||
"reddit.com.",
|
|
||||||
"macromedia.com.",
|
|
||||||
"get.adobe.com.",
|
|
||||||
"soundcloud.com.",
|
|
||||||
"sourceforge.net.",
|
|
||||||
"sites.google.com.",
|
|
||||||
"nih.gov.",
|
|
||||||
"amazonaws.com.",
|
|
||||||
"t.co.",
|
|
||||||
"support.microsoft.com.",
|
|
||||||
"forbes.com.",
|
|
||||||
"theguardian.com.",
|
|
||||||
"cnn.com.",
|
|
||||||
"github.io.",
|
|
||||||
"bbc.co.uk.",
|
|
||||||
"dropbox.com.",
|
|
||||||
"whatsapp.com.",
|
|
||||||
"medium.com.",
|
|
||||||
"creativecommons.org.",
|
|
||||||
"www.ncbi.nlm.nih.gov.",
|
|
||||||
"httpd.apache.org.",
|
|
||||||
"archive.org.",
|
|
||||||
"ec.europa.eu.",
|
|
||||||
"php.net.",
|
|
||||||
"apps.apple.com.",
|
|
||||||
"weebly.com.",
|
|
||||||
"support.apple.com.",
|
|
||||||
"weibo.com.",
|
|
||||||
"wixsite.com.",
|
|
||||||
"issuu.com.",
|
|
||||||
"who.int.",
|
|
||||||
"paypal.com.",
|
|
||||||
"m.facebook.com.",
|
|
||||||
"oracle.com.",
|
|
||||||
"msn.com.",
|
|
||||||
"gnu.org.",
|
|
||||||
"tinyurl.com.",
|
|
||||||
"reuters.com.",
|
|
||||||
"l.facebook.com.",
|
|
||||||
"cloudflare.com.",
|
|
||||||
"wsj.com.",
|
|
||||||
"washingtonpost.com.",
|
|
||||||
"domainmarket.com.",
|
|
||||||
"imdb.com.",
|
|
||||||
"bbc.com.",
|
|
||||||
"bing.com.",
|
|
||||||
"accounts.google.com.",
|
|
||||||
"vk.com.",
|
|
||||||
"api.whatsapp.com.",
|
|
||||||
"opera.com.",
|
|
||||||
"cdc.gov.",
|
|
||||||
"slideshare.net.",
|
|
||||||
"wpa.qq.com.",
|
|
||||||
"harvard.edu.",
|
|
||||||
"mit.edu.",
|
|
||||||
"code.google.com.",
|
|
||||||
"wikimedia.org.",
|
|
||||||
}
|
|
||||||
)
|
|
|
@ -22,6 +22,8 @@ var (
|
||||||
ErrBlocked = errors.New("query was blocked")
|
ErrBlocked = errors.New("query was blocked")
|
||||||
// ErrLocalhost is returned to *.localhost queries
|
// ErrLocalhost is returned to *.localhost queries
|
||||||
ErrLocalhost = errors.New("query for localhost")
|
ErrLocalhost = errors.New("query for localhost")
|
||||||
|
// ErrTimeout is returned when a query times out
|
||||||
|
ErrTimeout = errors.New("query timed out")
|
||||||
|
|
||||||
// detailed errors
|
// detailed errors
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/safing/portbase/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultClientTTL = 5 * time.Minute
|
defaultClientTTL = 5 * time.Minute
|
||||||
defaultRequestTimeout = 3 * time.Second // dns query
|
defaultRequestTimeout = 3 * time.Second // dns query
|
||||||
defaultConnectTimeout = 2 * time.Second // tcp/tls
|
defaultConnectTimeout = 5 * time.Second // tcp/tls
|
||||||
connectionEOLGracePeriod = 7 * time.Second
|
connectionEOLGracePeriod = 7 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,12 +40,12 @@ type dnsClientManager struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
|
|
||||||
// set by creator
|
// set by creator
|
||||||
serverAddress string
|
resolver *Resolver
|
||||||
ttl time.Duration // force refresh of connection to reduce traceability
|
ttl time.Duration // force refresh of connection to reduce traceability
|
||||||
factory func() *dns.Client
|
factory func() *dns.Client
|
||||||
|
|
||||||
// internal
|
// internal
|
||||||
pool sync.Pool
|
pool utils.StablePool
|
||||||
}
|
}
|
||||||
|
|
||||||
type dnsClient struct {
|
type dnsClient struct {
|
||||||
|
@ -57,7 +58,7 @@ type dnsClient struct {
|
||||||
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done().
|
||||||
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) {
|
||||||
if dc.conn == nil {
|
if dc.conn == nil {
|
||||||
dc.conn, err = dc.client.Dial(dc.mgr.serverAddress)
|
dc.conn, err = dc.client.Dial(dc.mgr.resolver.ServerAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
@ -78,8 +79,8 @@ func (dc *dnsClient) destroy() {
|
||||||
|
|
||||||
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
|
func newDNSClientManager(resolver *Resolver) *dnsClientManager {
|
||||||
return &dnsClientManager{
|
return &dnsClientManager{
|
||||||
serverAddress: resolver.ServerAddress,
|
resolver: resolver,
|
||||||
ttl: 0, // new client for every request, as we need to randomize the port
|
ttl: 0, // new client for every request, as we need to randomize the port
|
||||||
factory: func() *dns.Client {
|
factory: func() *dns.Client {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Timeout: defaultRequestTimeout,
|
Timeout: defaultRequestTimeout,
|
||||||
|
@ -93,8 +94,8 @@ func newDNSClientManager(resolver *Resolver) *dnsClientManager {
|
||||||
|
|
||||||
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
||||||
return &dnsClientManager{
|
return &dnsClientManager{
|
||||||
serverAddress: resolver.ServerAddress,
|
resolver: resolver,
|
||||||
ttl: defaultClientTTL,
|
ttl: defaultClientTTL,
|
||||||
factory: func() *dns.Client {
|
factory: func() *dns.Client {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
|
@ -111,8 +112,8 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager {
|
||||||
|
|
||||||
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
|
func newTLSClientManager(resolver *Resolver) *dnsClientManager {
|
||||||
return &dnsClientManager{
|
return &dnsClientManager{
|
||||||
serverAddress: resolver.ServerAddress,
|
resolver: resolver,
|
||||||
ttl: defaultClientTTL,
|
ttl: defaultClientTTL,
|
||||||
factory: func() *dns.Client {
|
factory: func() *dns.Client {
|
||||||
return &dns.Client{
|
return &dns.Client{
|
||||||
Net: "tcp-tls",
|
Net: "tcp-tls",
|
83
resolver/resolver-pooled_test.go
Normal file
83
resolver/resolver-pooled_test.go
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) {
|
||||||
|
dnsClient := brc.clientManager.getDNSClient()
|
||||||
|
|
||||||
|
// create query
|
||||||
|
dnsQuery := new(dns.Msg)
|
||||||
|
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
|
||||||
|
|
||||||
|
// get connection
|
||||||
|
conn, new, err := dnsClient.getConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed to connect: %s", err) //nolint:staticcheck
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if new {
|
||||||
|
atomic.AddUint32(newCnt, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query server
|
||||||
|
reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("client failed: %s", err) //nolint:staticcheck
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reply == nil {
|
||||||
|
t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl)
|
||||||
|
dnsClient.addToPool()
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientPooling(t *testing.T) {
|
||||||
|
// skip if short - this test depends on the Internet and might fail randomly
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
// create separate resolver for this test
|
||||||
|
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
||||||
|
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
brc := &BasicResolverConn{
|
||||||
|
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
|
||||||
|
resolver: resolver,
|
||||||
|
}
|
||||||
|
resolver.Conn = brc
|
||||||
|
|
||||||
|
started := time.Now()
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
var newCnt uint32
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(10)
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck
|
||||||
|
FQDN: <-domainFeed,
|
||||||
|
QType: dns.Type(dns.TypeA),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
if newCnt > uint32(10+i) {
|
||||||
|
t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("time taken: %s", time.Since(started))
|
||||||
|
}
|
359
resolver/resolver-tcp.go
Normal file
359
resolver/resolver-tcp.go
Normal file
|
@ -0,0 +1,359 @@
|
||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/safing/portbase/log"
|
||||||
|
"github.com/safing/portmaster/netenv"
|
||||||
|
"github.com/tevino/abool"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpWriteTimeout = 1 * time.Second
|
||||||
|
ignoreQueriesAfter = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPResolver is a resolver using just a single tcp connection with pipelining.
|
||||||
|
type TCPResolver struct {
|
||||||
|
BasicResolverConn
|
||||||
|
|
||||||
|
clientTTL time.Duration
|
||||||
|
dnsClient *dns.Client
|
||||||
|
dnsConnection *dns.Conn
|
||||||
|
connInstanceID *uint32
|
||||||
|
|
||||||
|
queries chan *dns.Msg
|
||||||
|
inFlightQueries map[uint16]*InFlightQuery
|
||||||
|
clientStarted *abool.AtomicBool
|
||||||
|
}
|
||||||
|
|
||||||
|
// InFlightQuery represents an in flight query of a TCPResolver.
|
||||||
|
type InFlightQuery struct {
|
||||||
|
Query *Query
|
||||||
|
Msg *dns.Msg
|
||||||
|
Response chan *dns.Msg
|
||||||
|
Resolver *Resolver
|
||||||
|
Started time.Time
|
||||||
|
ConnInstanceID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeCacheRecord creates an RCache record from a reply.
|
||||||
|
func (ifq *InFlightQuery) MakeCacheRecord(reply *dns.Msg) *RRCache {
|
||||||
|
return &RRCache{
|
||||||
|
Domain: ifq.Query.FQDN,
|
||||||
|
Question: ifq.Query.QType,
|
||||||
|
Answer: reply.Answer,
|
||||||
|
Ns: reply.Ns,
|
||||||
|
Extra: reply.Extra,
|
||||||
|
Server: ifq.Resolver.Server,
|
||||||
|
ServerScope: ifq.Resolver.ServerIPScope,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPResolver returns a new TPCResolver.
|
||||||
|
func NewTCPResolver(resolver *Resolver) *TCPResolver {
|
||||||
|
var instanceID uint32
|
||||||
|
return &TCPResolver{
|
||||||
|
BasicResolverConn: BasicResolverConn{
|
||||||
|
resolver: resolver,
|
||||||
|
},
|
||||||
|
clientTTL: defaultClientTTL,
|
||||||
|
dnsClient: &dns.Client{
|
||||||
|
Net: "tcp",
|
||||||
|
Timeout: defaultConnectTimeout,
|
||||||
|
WriteTimeout: tcpWriteTimeout,
|
||||||
|
},
|
||||||
|
connInstanceID: &instanceID,
|
||||||
|
queries: make(chan *dns.Msg, 100),
|
||||||
|
inFlightQueries: make(map[uint16]*InFlightQuery),
|
||||||
|
clientStarted: abool.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseTLS enabled TLS for the TCPResolver. TLS settings must be correctly configured in the Resolver.
|
||||||
|
func (tr *TCPResolver) UseTLS() *TCPResolver {
|
||||||
|
tr.dnsClient.Net = "tcp-tls"
|
||||||
|
tr.dnsClient.TLSConfig = &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: tr.resolver.VerifyDomain,
|
||||||
|
// TODO: use portbase rng
|
||||||
|
}
|
||||||
|
return tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TCPResolver) client(workerCtx context.Context) error { //nolint:gocognit,gocyclo // TODO
|
||||||
|
connTimer := time.NewTimer(tr.clientTTL)
|
||||||
|
connClosing := abool.New()
|
||||||
|
var connCtx context.Context
|
||||||
|
var cancelConnCtx func()
|
||||||
|
var recycleConn bool
|
||||||
|
var shuttingDown bool
|
||||||
|
var incoming = make(chan *dns.Msg, 100)
|
||||||
|
|
||||||
|
// enable client restarting after crash
|
||||||
|
defer tr.clientStarted.UnSet()
|
||||||
|
|
||||||
|
connMgmt:
|
||||||
|
for {
|
||||||
|
// cleanup old connection
|
||||||
|
if tr.dnsConnection != nil {
|
||||||
|
connClosing.Set()
|
||||||
|
_ = tr.dnsConnection.Close()
|
||||||
|
cancelConnCtx()
|
||||||
|
|
||||||
|
tr.dnsConnection = nil
|
||||||
|
atomic.AddUint32(tr.connInstanceID, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if we are shutting down or failing
|
||||||
|
if shuttingDown || tr.IsFailing() {
|
||||||
|
// reply to all waiting queries
|
||||||
|
tr.Lock()
|
||||||
|
for id, inFlight := range tr.inFlightQueries {
|
||||||
|
close(inFlight.Response)
|
||||||
|
delete(tr.inFlightQueries, id)
|
||||||
|
}
|
||||||
|
tr.clientStarted.UnSet() // in lock to guarantee to set before submitQuery proceeds
|
||||||
|
tr.Unlock()
|
||||||
|
|
||||||
|
// hint network environment at failed connection
|
||||||
|
netenv.ReportFailedConnection()
|
||||||
|
|
||||||
|
cancelConnCtx() // The linter said so. Don't even...
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until there is something to do
|
||||||
|
tr.Lock()
|
||||||
|
waiting := len(tr.inFlightQueries)
|
||||||
|
tr.Unlock()
|
||||||
|
if waiting > 0 {
|
||||||
|
// queue abandoned queries
|
||||||
|
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
||||||
|
currentConnInstanceID := atomic.LoadUint32(tr.connInstanceID)
|
||||||
|
tr.Lock()
|
||||||
|
for id, inFlight := range tr.inFlightQueries {
|
||||||
|
if inFlight.Started.Before(ignoreBefore) {
|
||||||
|
// remove
|
||||||
|
delete(tr.inFlightQueries, id)
|
||||||
|
} else if inFlight.ConnInstanceID != currentConnInstanceID {
|
||||||
|
inFlight.ConnInstanceID = currentConnInstanceID
|
||||||
|
// re-inject
|
||||||
|
select {
|
||||||
|
case tr.queries <- inFlight.Msg:
|
||||||
|
default:
|
||||||
|
log.Warningf("resolver: failed to re-inject abandoned query to %s", tr.resolver.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tr.Unlock()
|
||||||
|
} else {
|
||||||
|
// wait for first query
|
||||||
|
select {
|
||||||
|
case <-workerCtx.Done():
|
||||||
|
// abort
|
||||||
|
shuttingDown = true
|
||||||
|
continue connMgmt
|
||||||
|
case msg := <-tr.queries:
|
||||||
|
// re-insert, we will handle later
|
||||||
|
select {
|
||||||
|
case tr.queries <- msg:
|
||||||
|
default:
|
||||||
|
log.Warningf("resolver: failed to re-inject waking query to %s", tr.resolver.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create connection
|
||||||
|
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
||||||
|
// refresh dialer for authenticated local address
|
||||||
|
tr.dnsClient.Dialer = &net.Dialer{
|
||||||
|
LocalAddr: getLocalAddr("tcp"),
|
||||||
|
Timeout: defaultConnectTimeout,
|
||||||
|
KeepAlive: defaultClientTTL,
|
||||||
|
}
|
||||||
|
// connect
|
||||||
|
c, err := tr.dnsClient.Dial(tr.resolver.ServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
tr.ReportFailure()
|
||||||
|
log.Debugf("resolver: failed to connect to %s (%s)", tr.resolver.Name, tr.resolver.ServerAddress)
|
||||||
|
continue connMgmt
|
||||||
|
}
|
||||||
|
tr.dnsConnection = c
|
||||||
|
log.Debugf("resolver: connected to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
|
||||||
|
|
||||||
|
// hint network environment at successful connection
|
||||||
|
netenv.ReportSuccessfulConnection()
|
||||||
|
|
||||||
|
// reset timer
|
||||||
|
connTimer.Stop()
|
||||||
|
select {
|
||||||
|
case <-connTimer.C: // try to empty the timer
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
connTimer.Reset(tr.clientTTL)
|
||||||
|
recycleConn = false
|
||||||
|
|
||||||
|
// start reader
|
||||||
|
module.StartWorker("dns client reader", func(ctx context.Context) error {
|
||||||
|
conn := tr.dnsConnection
|
||||||
|
for {
|
||||||
|
msg, err := conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
if connClosing.SetToIf(false, true) {
|
||||||
|
cancelConnCtx()
|
||||||
|
tr.ReportFailure()
|
||||||
|
log.Warningf("resolver: read error from %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
incoming <- msg
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// query management
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-workerCtx.Done():
|
||||||
|
// shutting down
|
||||||
|
shuttingDown = true
|
||||||
|
continue connMgmt
|
||||||
|
case <-connCtx.Done():
|
||||||
|
// connection error
|
||||||
|
continue connMgmt
|
||||||
|
case <-connTimer.C:
|
||||||
|
// client TTL expired, recycle connection
|
||||||
|
recycleConn = true
|
||||||
|
// trigger check
|
||||||
|
select {
|
||||||
|
case incoming <- nil:
|
||||||
|
default:
|
||||||
|
// quere is full anyway, do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
case msg := <-tr.queries:
|
||||||
|
// write query
|
||||||
|
_ = tr.dnsConnection.SetWriteDeadline(time.Now().Add(tr.dnsClient.WriteTimeout))
|
||||||
|
err := tr.dnsConnection.WriteMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
if connClosing.SetToIf(false, true) {
|
||||||
|
cancelConnCtx()
|
||||||
|
tr.ReportFailure()
|
||||||
|
log.Warningf("resolver: write error to %s (%s): %s", tr.resolver.Name, tr.dnsConnection.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
continue connMgmt
|
||||||
|
}
|
||||||
|
|
||||||
|
case msg := <-incoming:
|
||||||
|
|
||||||
|
if msg != nil {
|
||||||
|
// handle query from resolver
|
||||||
|
tr.Lock()
|
||||||
|
inFlight, ok := tr.inFlightQueries[msg.Id]
|
||||||
|
if ok {
|
||||||
|
delete(tr.inFlightQueries, msg.Id)
|
||||||
|
}
|
||||||
|
tr.Unlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case inFlight.Response <- msg:
|
||||||
|
// responded!
|
||||||
|
default:
|
||||||
|
// save to cache, if enabled
|
||||||
|
if !inFlight.Query.NoCaching {
|
||||||
|
// persist to database
|
||||||
|
rrCache := inFlight.MakeCacheRecord(msg)
|
||||||
|
rrCache.Clean(600)
|
||||||
|
err = rrCache.Save()
|
||||||
|
if err != nil {
|
||||||
|
log.Warningf(
|
||||||
|
"resolver: failed to cache RR for %s%s: %s",
|
||||||
|
inFlight.Query.FQDN,
|
||||||
|
inFlight.Query.QType.String(),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
|
||||||
|
tr.resolver.Name,
|
||||||
|
tr.dnsConnection.RemoteAddr(),
|
||||||
|
msg.Id,
|
||||||
|
msg.Question,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if we have finished all queries and want to recycle conn
|
||||||
|
if recycleConn {
|
||||||
|
tr.Lock()
|
||||||
|
activeQueries := len(tr.inFlightQueries)
|
||||||
|
tr.Unlock()
|
||||||
|
if activeQueries == 0 {
|
||||||
|
log.Debugf("resolver: recycling conn to %s (%s)", tr.resolver.Name, tr.dnsConnection.RemoteAddr())
|
||||||
|
continue connMgmt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
||||||
|
// create msg
|
||||||
|
msg := &dns.Msg{}
|
||||||
|
msg.SetQuestion(q.FQDN, uint16(q.QType))
|
||||||
|
|
||||||
|
// save to waitlist
|
||||||
|
inFlight := &InFlightQuery{
|
||||||
|
Query: q,
|
||||||
|
Msg: msg,
|
||||||
|
Response: make(chan *dns.Msg),
|
||||||
|
Resolver: tr.resolver,
|
||||||
|
Started: time.Now().UTC(),
|
||||||
|
ConnInstanceID: atomic.LoadUint32(tr.connInstanceID),
|
||||||
|
}
|
||||||
|
tr.Lock()
|
||||||
|
tr.inFlightQueries[msg.Id] = inFlight
|
||||||
|
tr.Unlock()
|
||||||
|
|
||||||
|
// submit msg for writing
|
||||||
|
tr.queries <- msg
|
||||||
|
|
||||||
|
// make sure client is started
|
||||||
|
if tr.clientStarted.SetToIf(false, true) {
|
||||||
|
module.StartWorker("dns client", tr.client)
|
||||||
|
}
|
||||||
|
|
||||||
|
return inFlight
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query executes the given query against the resolver.
|
||||||
|
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
|
// submit to client
|
||||||
|
inFlight := tr.submitQuery(ctx, q)
|
||||||
|
var reply *dns.Msg
|
||||||
|
|
||||||
|
select {
|
||||||
|
case reply = <-inFlight.Response:
|
||||||
|
case <-time.After(defaultRequestTimeout):
|
||||||
|
tr.ReportFailure()
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if tr.resolver.IsBlockedUpstream(reply) {
|
||||||
|
return nil, &BlockedUpstreamError{tr.resolver.GetName()}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inFlight.MakeCacheRecord(reply), nil
|
||||||
|
}
|
72
resolver/resolver-tcp_test.go
Normal file
72
resolver/resolver-tcp_test.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"runtime/pprof"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testTCPQuery(t *testing.T, wg *sync.WaitGroup, rc ResolverConn, q *Query) {
|
||||||
|
start := time.Now()
|
||||||
|
_, err := rc.Query(context.TODO(), q)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("client failed: %s", err) //nolint:staticcheck
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Logf("resolved %s in %s", q.FQDN, time.Since(start))
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPResolver(t *testing.T) {
|
||||||
|
// skip if short - this test depends on the Internet and might fail randomly
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(15 * time.Second)
|
||||||
|
fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN =====")
|
||||||
|
printStackTo(os.Stderr)
|
||||||
|
os.Exit(1)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// create separate resolver for this test
|
||||||
|
resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config")
|
||||||
|
// resolver, _, err := createResolver("dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", "config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
started := time.Now()
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
go testTCPQuery(t, wg, resolver.Conn, &Query{ //nolint:staticcheck
|
||||||
|
FQDN: <-domainFeed,
|
||||||
|
QType: dns.Type(dns.TypeA),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
t.Logf("time taken: %s", time.Since(started))
|
||||||
|
}
|
||||||
|
|
||||||
|
func printStackTo(writer io.Writer) {
|
||||||
|
fmt.Fprintln(writer, "=== PRINTING TRACES ===")
|
||||||
|
fmt.Fprintln(writer, "=== GOROUTINES ===")
|
||||||
|
_ = pprof.Lookup("goroutine").WriteTo(writer, 1)
|
||||||
|
fmt.Fprintln(writer, "=== BLOCKING ===")
|
||||||
|
_ = pprof.Lookup("block").WriteTo(writer, 1)
|
||||||
|
fmt.Fprintln(writer, "=== MUTEXES ===")
|
||||||
|
_ = pprof.Lookup("mutex").WriteTo(writer, 1)
|
||||||
|
fmt.Fprintln(writer, "=== END TRACES ===")
|
||||||
|
}
|
|
@ -62,6 +62,20 @@ func formatIPAndPort(ip net.IP, port uint16) string {
|
||||||
return address
|
return address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolverConnFactory(resolver *Resolver) ResolverConn {
|
||||||
|
switch resolver.ServerType {
|
||||||
|
case ServerTypeTCP:
|
||||||
|
return NewTCPResolver(resolver)
|
||||||
|
case ServerTypeDoT:
|
||||||
|
return NewTCPResolver(resolver).UseTLS()
|
||||||
|
default:
|
||||||
|
return &BasicResolverConn{
|
||||||
|
clientManager: clientManagerFactory(resolver.ServerType)(resolver),
|
||||||
|
resolver: resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
|
func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager {
|
||||||
switch serverType {
|
switch serverType {
|
||||||
case ServerTypeDNS:
|
case ServerTypeDNS:
|
||||||
|
@ -129,12 +143,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
|
||||||
UpstreamBlockDetection: blockType,
|
UpstreamBlockDetection: blockType,
|
||||||
}
|
}
|
||||||
|
|
||||||
newConn := &BasicResolverConn{
|
new.Conn = resolverConnFactory(new)
|
||||||
clientManager: clientManagerFactory(u.Scheme)(new),
|
|
||||||
resolver: new,
|
|
||||||
}
|
|
||||||
|
|
||||||
new.Conn = newConn
|
|
||||||
return new, false, nil
|
return new, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue