diff --git a/resolver/resolver-https.go b/resolver/resolver-https.go index 7fd6c14b..4bc23478 100644 --- a/resolver/resolver-https.go +++ b/resolver/resolver-https.go @@ -6,8 +6,10 @@ import ( "encoding/base64" "fmt" "io/ioutil" + "net" "net/http" "net/url" + "strconv" "github.com/miekg/dns" ) @@ -65,6 +67,14 @@ func NewHTTPSResolver(resolver *Resolver) *HttpsResolver { // Query executes the given query against the resolver. func (hr *HttpsResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + + // Do not resolve domain names that are needed to initialize a resolver + if hr.resolver.Info.IP == nil { + if _, ok := resolverInitDomains[q.FQDN[:len(q.FQDN)-1]]; ok { + return nil, ErrContinue + } + } + dnsQuery := new(dns.Msg) dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) @@ -75,10 +85,10 @@ func (hr *HttpsResolver) Query(ctx context.Context, q *Query) (*RRCache, error) } b64dns := base64.RawStdEncoding.EncodeToString(buf) - host := hr.resolver.VerifyDomain - - if hr.resolver.ServerAddress != "" { - host = hr.resolver.ServerAddress + // Set the host, if we dont have IP address just use the domain + host := hr.resolver.ServerAddress + if host == "" { + host = net.JoinHostPort(hr.resolver.VerifyDomain, strconv.Itoa(int(hr.resolver.Info.Port))) } // Build and execute http reuqest diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index cb37ec94..47d53e6a 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "strconv" "time" "github.com/miekg/dns" @@ -142,8 +143,14 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve KeepAlive: defaultClientTTL, } + // Set the host, if we dont have IP address just use the domain + host := tr.resolver.ServerAddress + if host == "" { + host = net.JoinHostPort(tr.resolver.VerifyDomain, strconv.Itoa(int(tr.resolver.Info.Port))) + } + // Connect to server. - conn, err := tr.dnsClient.Dial(tr.resolver.ServerAddress) + conn, err := tr.dnsClient.Dial(host) if err != nil { // Hint network environment at failed connection. netenv.ReportFailedConnection() @@ -185,6 +192,13 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve // Query executes the given query against the resolver. func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + // Do not resolve domain names that are needed to initialize a resolver + if tr.resolver.Info.IP == nil && tr.dnsClient.TLSConfig != nil { + if _, ok := resolverInitDomains[q.FQDN[:len(q.FQDN)-1]]; ok { + return nil, ErrContinue + } + } + // Get resolver connection. resolverConn, err := tr.getOrCreateResolverConn(ctx) if err != nil { diff --git a/resolver/resolver.go b/resolver/resolver.go index ddc58373..cd016c9b 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -92,6 +92,9 @@ type ResolverInfo struct { //nolint:golint,maligned // TODO // IP is the IP address of the resolver IP net.IP + // Domain of the dns server if it has one + Domain string + // IPScope is the network scope of the IP address. IPScope netutils.IPScope @@ -112,6 +115,20 @@ func (info *ResolverInfo) ID() string { info.id = ServerTypeMDNS case ServerTypeEnv: info.id = ServerTypeEnv + case ServerTypeDoH: + info.id = fmt.Sprintf( + "https://%s:%d#%s", + info.Domain, + info.Port, + info.Source, + ) + case ServerTypeDoT: + info.id = fmt.Sprintf( + "dot://%s:%d#%s", + info.Domain, + info.Port, + info.Source, + ) default: info.id = fmt.Sprintf( "%s://%s:%d#%s", diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 35101586..5c7167ba 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -1,7 +1,6 @@ package resolver import ( - "context" "fmt" "net" "net/url" @@ -12,7 +11,6 @@ import ( "golang.org/x/net/publicsuffix" - "github.com/miekg/dns" "github.com/safing/portbase/log" "github.com/safing/portbase/utils" "github.com/safing/portmaster/netenv" @@ -38,12 +36,13 @@ const ( ) var ( - globalResolvers []*Resolver // all (global) resolvers - localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges - systemResolvers []*Resolver // all resolvers that were assigned by the system - localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope - activeResolvers map[string]*Resolver // lookup map of all resolvers - resolversLock sync.RWMutex + globalResolvers []*Resolver // all (global) resolvers + localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges + systemResolvers []*Resolver // all resolvers that were assigned by the system + localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope + activeResolvers map[string]*Resolver // lookup map of all resolvers + resolverInitDomains map[string]bool // a set with all domains of the dns resolvers + resolversLock sync.RWMutex ) func indexOfScope(domain string, list []*Scope) int { @@ -97,6 +96,10 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return nil, false, err } + if resolverInitDomains == nil { + resolverInitDomains = make(map[string]bool) + } + switch u.Scheme { case ServerTypeDNS, ServerTypeDoT, ServerTypeDoH, ServerTypeTCP: case HttpsProtocol: @@ -105,84 +108,51 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme) } - // Check if we are using domain name and if it's in a valid scheme - ip := net.ParseIP(u.Hostname()) - hostnameIsDomaion := (ip == nil) - if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT { - return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname()) - } - - path := u.Path // Used for DoH - - // Add default port for scheme if it is missing. - port, err := parsePortFromURL(u) - if err != nil { - return nil, false, err - } - - // Get parameters and check if keys exist. query := u.Query() - err = checkURLParameterValidity(u.Scheme, hostnameIsDomaion, query) - if err != nil { - return nil, false, err - } - // Get IP address and domain name from paramters. - serverAddress := "" - serverIPParamter := query.Get(parameterIP) - verifyDomain := query.Get(parameterVerify) - - if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH { - switch { - case hostnameIsDomaion && serverIPParamter != "": // domain and ip as parameter - ip = net.ParseIP(serverIPParamter) - serverAddress = net.JoinHostPort(serverIPParamter, strconv.Itoa(int(port))) - verifyDomain = u.Hostname() - case !hostnameIsDomaion && verifyDomain != "": // ip and domain as parameter - serverAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) - case hostnameIsDomaion && verifyDomain == "" && serverIPParamter == "": // only domain - verifyDomain = u.Hostname() - } - } else { - serverAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) - } - - // Check block detection type. - blockType := query.Get(parameterBlockedIf) - if blockType == "" { - blockType = BlockDetectionZeroIP - } - switch blockType { - case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: - default: - return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") - } - - // Get ip scope if we have ip - scope := netutils.Global - if ip != nil { - scope = netutils.GetIPScope(ip) - // Skip localhost resolvers from the OS, but not if configured. - if scope.IsLocalhost() && source == ServerSourceOperatingSystem { - return nil, true, nil // skip - } - } - - // Build resolver. + // Create Resolver object newResolver := &Resolver{ ConfigURL: resolverURL, Info: &ResolverInfo{ Name: query.Get(parameterName), Type: u.Scheme, Source: source, - IP: ip, - IPScope: scope, - Port: port, + IP: nil, + Domain: "", + IPScope: netutils.Global, + Port: 0, }, - ServerAddress: serverAddress, - VerifyDomain: verifyDomain, - Path: path, - UpstreamBlockDetection: blockType, + ServerAddress: "", + VerifyDomain: "", + Path: u.Path, // Used for DoH + UpstreamBlockDetection: "", + } + + // Get parameters and check if keys exist. + err = checkAndSetResolverParamters(u, newResolver) + if err != nil { + return nil, false, err + } + + // Check block detection type. + newResolver.UpstreamBlockDetection = query.Get(parameterBlockedIf) + if newResolver.UpstreamBlockDetection == "" { + newResolver.UpstreamBlockDetection = BlockDetectionZeroIP + } + + switch newResolver.UpstreamBlockDetection { + case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: + default: + return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") + } + + // Get ip scope if we have ip + if newResolver.Info.IP != nil { + newResolver.Info.IPScope = netutils.GetIPScope(newResolver.Info.IP) + // Skip localhost resolvers from the OS, but not if configured. + if newResolver.Info.IPScope.IsLocalhost() && source == ServerSourceOperatingSystem { + return nil, true, nil // skip + } } // Parse search domains. @@ -209,7 +179,24 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return newResolver, false, nil } -func checkURLParameterValidity(scheme string, hostnameIsDomaion bool, query url.Values) error { +func checkAndSetResolverParamters(u *url.URL, resolver *Resolver) error { + + // Check if we are using domain name and if it's in a valid scheme + ip := net.ParseIP(u.Hostname()) + hostnameIsDomaion := (ip == nil) + if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT { + return fmt.Errorf("resolver IP %q invalid", u.Hostname()) + } + + // Add default port for scheme if it is missing. + port, err := parsePortFromURL(u) + if err != nil { + return err + } + resolver.Info.Port = port + + query := u.Query() + for key := range query { switch key { case parameterName, @@ -226,78 +213,45 @@ func checkURLParameterValidity(scheme string, hostnameIsDomaion bool, query url. } } - verifyDomain := query.Get(parameterVerify) + resolver.VerifyDomain = query.Get(parameterVerify) paramterServerIP := query.Get(parameterIP) - if scheme == ServerTypeDoT || scheme == ServerTypeDoH { + if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH { + // Check if IP and Domain are set correctly switch { - case hostnameIsDomaion && verifyDomain != "": + case hostnameIsDomaion && resolver.VerifyDomain != "": return fmt.Errorf("cannot set the domain name via both the hostname in the URL and the verify parameter") - case !hostnameIsDomaion && verifyDomain == "": + case !hostnameIsDomaion && resolver.VerifyDomain == "": return fmt.Errorf("verify parameter must be set when using ip as domain") case !hostnameIsDomaion && paramterServerIP != "": return fmt.Errorf("cannot set the IP address via both the hostname in the URL and the ip parameter") } + + // Parse and set IP and Domain to the resolver + switch { + case hostnameIsDomaion && paramterServerIP != "": // domain and ip as parameter + resolver.Info.IP = net.ParseIP(paramterServerIP) + resolver.ServerAddress = net.JoinHostPort(paramterServerIP, strconv.Itoa(int(resolver.Info.Port))) + resolver.VerifyDomain = u.Hostname() + case !hostnameIsDomaion && resolver.VerifyDomain != "": // ip and domain as parameter + resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port))) + case hostnameIsDomaion && resolver.VerifyDomain == "" && paramterServerIP == "": // only domain + resolver.VerifyDomain = u.Hostname() + } + + resolver.Info.Domain = resolver.VerifyDomain + resolverInitDomains[resolver.Info.Domain] = true } else { - if verifyDomain != "" { + if resolver.VerifyDomain != "" { return fmt.Errorf("domain verification is only supported by DoT and DoH servers") } + resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port))) } return nil } -func resolveDomainIP(ctx context.Context, domain string) ([]net.IP, error) { - fqdn := domain - if !strings.HasSuffix(fqdn, ".") { - fqdn += "." - } - query := &Query{ - FQDN: fqdn, - QType: dns.Type(dns.TypeA), - } - - for _, resolver := range activeResolvers { - rr, err := resolver.Conn.Query(ctx, query) - if err != nil { - log.Error(err.Error()) - continue - } - - return rr.ExportAllARecords(), nil - } - - nameserves := netenv.Nameservers() - if len(nameserves) == 0 { - return nil, fmt.Errorf("unable to resolve domain %s", domain) - } - - client := new(dns.Client) - - message := new(dns.Msg) - message.SetQuestion(fqdn, dns.TypeA) - message.RecursionDesired = true - ip := net.JoinHostPort(nameserves[0].IP.String(), "53") - - reply, _, err := client.Exchange(message, ip) - - if err != nil { - return nil, err - } - - newRecord := &RRCache{ - Domain: query.FQDN, - Question: query.QType, - RCode: reply.Rcode, - Answer: reply.Answer, - Ns: reply.Ns, - Extra: reply.Extra, - } - - return newRecord.ExportAllARecords(), nil -} - func parsePortFromURL(url *url.URL) (uint16, error) { var port uint16 hostPort := url.Port()