DoT support for domain name only

Configed resolvers skip ther own domains
This commit is contained in:
Vladimir Stoilov 2022-07-20 16:06:24 +02:00
parent 35b4ee2a29
commit bdc3792d21
4 changed files with 133 additions and 138 deletions

View file

@ -6,8 +6,10 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -65,6 +67,14 @@ func NewHTTPSResolver(resolver *Resolver) *HttpsResolver {
// Query executes the given query against the resolver. // Query executes the given query against the resolver.
func (hr *HttpsResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { func (hr *HttpsResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// Do not resolve domain names that are needed to initialize a resolver
if hr.resolver.Info.IP == nil {
if _, ok := resolverInitDomains[q.FQDN[:len(q.FQDN)-1]]; ok {
return nil, ErrContinue
}
}
dnsQuery := new(dns.Msg) dnsQuery := new(dns.Msg)
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
@ -75,10 +85,10 @@ func (hr *HttpsResolver) Query(ctx context.Context, q *Query) (*RRCache, error)
} }
b64dns := base64.RawStdEncoding.EncodeToString(buf) b64dns := base64.RawStdEncoding.EncodeToString(buf)
host := hr.resolver.VerifyDomain // Set the host, if we dont have IP address just use the domain
host := hr.resolver.ServerAddress
if hr.resolver.ServerAddress != "" { if host == "" {
host = hr.resolver.ServerAddress host = net.JoinHostPort(hr.resolver.VerifyDomain, strconv.Itoa(int(hr.resolver.Info.Port)))
} }
// Build and execute http reuqest // Build and execute http reuqest

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strconv"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -142,8 +143,14 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve
KeepAlive: defaultClientTTL, KeepAlive: defaultClientTTL,
} }
// Set the host, if we dont have IP address just use the domain
host := tr.resolver.ServerAddress
if host == "" {
host = net.JoinHostPort(tr.resolver.VerifyDomain, strconv.Itoa(int(tr.resolver.Info.Port)))
}
// Connect to server. // Connect to server.
conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) conn, err := tr.dnsClient.Dial(host)
if err != nil { if err != nil {
// Hint network environment at failed connection. // Hint network environment at failed connection.
netenv.ReportFailedConnection() netenv.ReportFailedConnection()
@ -185,6 +192,13 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve
// Query executes the given query against the resolver. // Query executes the given query against the resolver.
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// Do not resolve domain names that are needed to initialize a resolver
if tr.resolver.Info.IP == nil && tr.dnsClient.TLSConfig != nil {
if _, ok := resolverInitDomains[q.FQDN[:len(q.FQDN)-1]]; ok {
return nil, ErrContinue
}
}
// Get resolver connection. // Get resolver connection.
resolverConn, err := tr.getOrCreateResolverConn(ctx) resolverConn, err := tr.getOrCreateResolverConn(ctx)
if err != nil { if err != nil {

View file

@ -92,6 +92,9 @@ type ResolverInfo struct { //nolint:golint,maligned // TODO
// IP is the IP address of the resolver // IP is the IP address of the resolver
IP net.IP IP net.IP
// Domain of the dns server if it has one
Domain string
// IPScope is the network scope of the IP address. // IPScope is the network scope of the IP address.
IPScope netutils.IPScope IPScope netutils.IPScope
@ -112,6 +115,20 @@ func (info *ResolverInfo) ID() string {
info.id = ServerTypeMDNS info.id = ServerTypeMDNS
case ServerTypeEnv: case ServerTypeEnv:
info.id = ServerTypeEnv info.id = ServerTypeEnv
case ServerTypeDoH:
info.id = fmt.Sprintf(
"https://%s:%d#%s",
info.Domain,
info.Port,
info.Source,
)
case ServerTypeDoT:
info.id = fmt.Sprintf(
"dot://%s:%d#%s",
info.Domain,
info.Port,
info.Source,
)
default: default:
info.id = fmt.Sprintf( info.id = fmt.Sprintf(
"%s://%s:%d#%s", "%s://%s:%d#%s",

View file

@ -1,7 +1,6 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@ -12,7 +11,6 @@ import (
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
"github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/utils" "github.com/safing/portbase/utils"
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
@ -38,12 +36,13 @@ const (
) )
var ( var (
globalResolvers []*Resolver // all (global) resolvers globalResolvers []*Resolver // all (global) resolvers
localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges
systemResolvers []*Resolver // all resolvers that were assigned by the system systemResolvers []*Resolver // all resolvers that were assigned by the system
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
activeResolvers map[string]*Resolver // lookup map of all resolvers activeResolvers map[string]*Resolver // lookup map of all resolvers
resolversLock sync.RWMutex resolverInitDomains map[string]bool // a set with all domains of the dns resolvers
resolversLock sync.RWMutex
) )
func indexOfScope(domain string, list []*Scope) int { func indexOfScope(domain string, list []*Scope) int {
@ -97,6 +96,10 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
return nil, false, err return nil, false, err
} }
if resolverInitDomains == nil {
resolverInitDomains = make(map[string]bool)
}
switch u.Scheme { switch u.Scheme {
case ServerTypeDNS, ServerTypeDoT, ServerTypeDoH, ServerTypeTCP: case ServerTypeDNS, ServerTypeDoT, ServerTypeDoH, ServerTypeTCP:
case HttpsProtocol: case HttpsProtocol:
@ -105,84 +108,51 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme) return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme)
} }
// Check if we are using domain name and if it's in a valid scheme
ip := net.ParseIP(u.Hostname())
hostnameIsDomaion := (ip == nil)
if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT {
return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname())
}
path := u.Path // Used for DoH
// Add default port for scheme if it is missing.
port, err := parsePortFromURL(u)
if err != nil {
return nil, false, err
}
// Get parameters and check if keys exist.
query := u.Query() query := u.Query()
err = checkURLParameterValidity(u.Scheme, hostnameIsDomaion, query)
if err != nil {
return nil, false, err
}
// Get IP address and domain name from paramters. // Create Resolver object
serverAddress := ""
serverIPParamter := query.Get(parameterIP)
verifyDomain := query.Get(parameterVerify)
if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH {
switch {
case hostnameIsDomaion && serverIPParamter != "": // domain and ip as parameter
ip = net.ParseIP(serverIPParamter)
serverAddress = net.JoinHostPort(serverIPParamter, strconv.Itoa(int(port)))
verifyDomain = u.Hostname()
case !hostnameIsDomaion && verifyDomain != "": // ip and domain as parameter
serverAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
case hostnameIsDomaion && verifyDomain == "" && serverIPParamter == "": // only domain
verifyDomain = u.Hostname()
}
} else {
serverAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
}
// Check block detection type.
blockType := query.Get(parameterBlockedIf)
if blockType == "" {
blockType = BlockDetectionZeroIP
}
switch blockType {
case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP:
default:
return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)")
}
// Get ip scope if we have ip
scope := netutils.Global
if ip != nil {
scope = netutils.GetIPScope(ip)
// Skip localhost resolvers from the OS, but not if configured.
if scope.IsLocalhost() && source == ServerSourceOperatingSystem {
return nil, true, nil // skip
}
}
// Build resolver.
newResolver := &Resolver{ newResolver := &Resolver{
ConfigURL: resolverURL, ConfigURL: resolverURL,
Info: &ResolverInfo{ Info: &ResolverInfo{
Name: query.Get(parameterName), Name: query.Get(parameterName),
Type: u.Scheme, Type: u.Scheme,
Source: source, Source: source,
IP: ip, IP: nil,
IPScope: scope, Domain: "",
Port: port, IPScope: netutils.Global,
Port: 0,
}, },
ServerAddress: serverAddress, ServerAddress: "",
VerifyDomain: verifyDomain, VerifyDomain: "",
Path: path, Path: u.Path, // Used for DoH
UpstreamBlockDetection: blockType, UpstreamBlockDetection: "",
}
// Get parameters and check if keys exist.
err = checkAndSetResolverParamters(u, newResolver)
if err != nil {
return nil, false, err
}
// Check block detection type.
newResolver.UpstreamBlockDetection = query.Get(parameterBlockedIf)
if newResolver.UpstreamBlockDetection == "" {
newResolver.UpstreamBlockDetection = BlockDetectionZeroIP
}
switch newResolver.UpstreamBlockDetection {
case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP:
default:
return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)")
}
// Get ip scope if we have ip
if newResolver.Info.IP != nil {
newResolver.Info.IPScope = netutils.GetIPScope(newResolver.Info.IP)
// Skip localhost resolvers from the OS, but not if configured.
if newResolver.Info.IPScope.IsLocalhost() && source == ServerSourceOperatingSystem {
return nil, true, nil // skip
}
} }
// Parse search domains. // Parse search domains.
@ -209,7 +179,24 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
return newResolver, false, nil return newResolver, false, nil
} }
func checkURLParameterValidity(scheme string, hostnameIsDomaion bool, query url.Values) error { func checkAndSetResolverParamters(u *url.URL, resolver *Resolver) error {
// Check if we are using domain name and if it's in a valid scheme
ip := net.ParseIP(u.Hostname())
hostnameIsDomaion := (ip == nil)
if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT {
return fmt.Errorf("resolver IP %q invalid", u.Hostname())
}
// Add default port for scheme if it is missing.
port, err := parsePortFromURL(u)
if err != nil {
return err
}
resolver.Info.Port = port
query := u.Query()
for key := range query { for key := range query {
switch key { switch key {
case parameterName, case parameterName,
@ -226,78 +213,45 @@ func checkURLParameterValidity(scheme string, hostnameIsDomaion bool, query url.
} }
} }
verifyDomain := query.Get(parameterVerify) resolver.VerifyDomain = query.Get(parameterVerify)
paramterServerIP := query.Get(parameterIP) paramterServerIP := query.Get(parameterIP)
if scheme == ServerTypeDoT || scheme == ServerTypeDoH { if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH {
// Check if IP and Domain are set correctly
switch { switch {
case hostnameIsDomaion && verifyDomain != "": case hostnameIsDomaion && resolver.VerifyDomain != "":
return fmt.Errorf("cannot set the domain name via both the hostname in the URL and the verify parameter") return fmt.Errorf("cannot set the domain name via both the hostname in the URL and the verify parameter")
case !hostnameIsDomaion && verifyDomain == "": case !hostnameIsDomaion && resolver.VerifyDomain == "":
return fmt.Errorf("verify parameter must be set when using ip as domain") return fmt.Errorf("verify parameter must be set when using ip as domain")
case !hostnameIsDomaion && paramterServerIP != "": case !hostnameIsDomaion && paramterServerIP != "":
return fmt.Errorf("cannot set the IP address via both the hostname in the URL and the ip parameter") return fmt.Errorf("cannot set the IP address via both the hostname in the URL and the ip parameter")
} }
// Parse and set IP and Domain to the resolver
switch {
case hostnameIsDomaion && paramterServerIP != "": // domain and ip as parameter
resolver.Info.IP = net.ParseIP(paramterServerIP)
resolver.ServerAddress = net.JoinHostPort(paramterServerIP, strconv.Itoa(int(resolver.Info.Port)))
resolver.VerifyDomain = u.Hostname()
case !hostnameIsDomaion && resolver.VerifyDomain != "": // ip and domain as parameter
resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port)))
case hostnameIsDomaion && resolver.VerifyDomain == "" && paramterServerIP == "": // only domain
resolver.VerifyDomain = u.Hostname()
}
resolver.Info.Domain = resolver.VerifyDomain
resolverInitDomains[resolver.Info.Domain] = true
} else { } else {
if verifyDomain != "" { if resolver.VerifyDomain != "" {
return fmt.Errorf("domain verification is only supported by DoT and DoH servers") return fmt.Errorf("domain verification is only supported by DoT and DoH servers")
} }
resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port)))
} }
return nil return nil
} }
func resolveDomainIP(ctx context.Context, domain string) ([]net.IP, error) {
fqdn := domain
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
query := &Query{
FQDN: fqdn,
QType: dns.Type(dns.TypeA),
}
for _, resolver := range activeResolvers {
rr, err := resolver.Conn.Query(ctx, query)
if err != nil {
log.Error(err.Error())
continue
}
return rr.ExportAllARecords(), nil
}
nameserves := netenv.Nameservers()
if len(nameserves) == 0 {
return nil, fmt.Errorf("unable to resolve domain %s", domain)
}
client := new(dns.Client)
message := new(dns.Msg)
message.SetQuestion(fqdn, dns.TypeA)
message.RecursionDesired = true
ip := net.JoinHostPort(nameserves[0].IP.String(), "53")
reply, _, err := client.Exchange(message, ip)
if err != nil {
return nil, err
}
newRecord := &RRCache{
Domain: query.FQDN,
Question: query.QType,
RCode: reply.Rcode,
Answer: reply.Answer,
Ns: reply.Ns,
Extra: reply.Extra,
}
return newRecord.ExportAllARecords(), nil
}
func parsePortFromURL(url *url.URL) (uint16, error) { func parsePortFromURL(url *url.URL) (uint16, error) {
var port uint16 var port uint16
hostPort := url.Port() hostPort := url.Port()