Move some Resolver information to ResolverInfo and propagate it

This commit is contained in:
Daniel 2021-03-20 22:19:27 +01:00
parent 43cfba8445
commit 20383226f8
13 changed files with 275 additions and 180 deletions

View file

@ -5,6 +5,7 @@ import (
"net"
"net/url"
"sort"
"strconv"
"strings"
"sync"
@ -61,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string {
}
func resolverConnFactory(resolver *Resolver) ResolverConn {
switch resolver.ServerType {
switch resolver.Info.Type {
case ServerTypeTCP:
return NewTCPResolver(resolver)
case ServerTypeDoT:
@ -82,26 +83,36 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
switch u.Scheme {
case ServerTypeDNS, ServerTypeDoT, ServerTypeTCP:
default:
return nil, false, fmt.Errorf("invalid DNS resolver scheme %q", u.Scheme)
return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme)
}
ip := net.ParseIP(u.Hostname())
if ip == nil {
return nil, false, fmt.Errorf("invalid resolver IP")
return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname())
}
// Add default port for scheme if it is missing.
if u.Port() == "" {
switch u.Scheme {
case ServerTypeDNS, ServerTypeTCP:
u.Host += ":53"
case ServerTypeDoT:
u.Host += ":853"
var port uint16
hostPort := u.Port()
switch {
case hostPort != "":
parsedPort, err := strconv.ParseUint(hostPort, 10, 16)
if err != nil {
return nil, false, fmt.Errorf("resolver port %q invalid", u.Port())
}
port = uint16(parsedPort)
case u.Scheme == ServerTypeDNS, u.Scheme == ServerTypeTCP:
port = 53
case u.Scheme == ServerTypeDoH:
port = 443
case u.Scheme == ServerTypeDoT:
port = 853
default:
return nil, false, fmt.Errorf("missing port in %q", u.Host)
}
scope := netutils.ClassifyIP(ip)
if scope == netutils.HostLocal {
scope := netutils.GetIPScope(ip)
if scope.IsLocalhost() {
return nil, true, nil // skip
}
@ -127,24 +138,20 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
}
new := &Resolver{
Server: resolverURL,
ServerType: u.Scheme,
ServerAddress: u.Host,
ServerIP: ip,
ServerIPScope: scope,
Source: source,
ConfigURL: resolverURL,
Info: &ResolverInfo{
Name: query.Get("name"),
Type: u.Scheme,
Source: source,
IP: ip,
IPScope: scope,
Port: port,
},
ServerAddress: net.JoinHostPort(ip.String(), strconv.Itoa(int(port))),
VerifyDomain: verifyDomain,
Name: query.Get("name"),
UpstreamBlockDetection: blockType,
}
u.RawQuery = "" // Remove options from parsed URL
if new.Name != "" {
new.ServerInfo = fmt.Sprintf("%s (%s, from %s)", new.Name, u, source)
} else {
new.ServerInfo = fmt.Sprintf("%s (from %s)", u, source)
}
new.Conn = resolverConnFactory(new)
return new, false, nil
}
@ -195,7 +202,7 @@ func getSystemResolvers() (resolvers []*Resolver) {
continue
}
if netutils.IPIsLAN(nameserver.IP) {
if resolver.Info.IPScope.IsLAN() {
configureSearchDomains(resolver, nameserver.Search)
}
@ -244,16 +251,16 @@ func loadResolvers() {
activeResolvers = make(map[string]*Resolver)
// add
for _, resolver := range newResolvers {
activeResolvers[resolver.Server] = resolver
activeResolvers[resolver.Info.ID()] = resolver
}
activeResolvers[mDNSResolver.Server] = mDNSResolver
activeResolvers[envResolver.Server] = envResolver
activeResolvers[mDNSResolver.Info.ID()] = mDNSResolver
activeResolvers[envResolver.Info.ID()] = envResolver
// log global resolvers
if len(globalResolvers) > 0 {
log.Trace("resolver: loaded global resolvers:")
for _, resolver := range globalResolvers {
log.Tracef("resolver: %s", resolver.Server)
log.Tracef("resolver: %s", resolver.ConfigURL)
}
} else {
log.Warning("resolver: no global resolvers loaded")
@ -263,7 +270,7 @@ func loadResolvers() {
if len(localResolvers) > 0 {
log.Trace("resolver: loaded local resolvers:")
for _, resolver := range localResolvers {
log.Tracef("resolver: %s", resolver.Server)
log.Tracef("resolver: %s", resolver.ConfigURL)
}
} else {
log.Info("resolver: no local resolvers loaded")
@ -273,7 +280,7 @@ func loadResolvers() {
if len(systemResolvers) > 0 {
log.Trace("resolver: loaded system/network-assigned resolvers:")
for _, resolver := range systemResolvers {
log.Tracef("resolver: %s", resolver.Server)
log.Tracef("resolver: %s", resolver.ConfigURL)
}
} else {
log.Info("resolver: no system/network-assigned resolvers loaded")
@ -285,7 +292,7 @@ func loadResolvers() {
for _, scope := range localScopes {
var scopeServers []string
for _, resolver := range scope.Resolvers {
scopeServers = append(scopeServers, resolver.Server)
scopeServers = append(scopeServers, resolver.ConfigURL)
}
log.Tracef("resolver: %s: %s", scope.Domain, strings.Join(scopeServers, ", "))
}
@ -306,11 +313,11 @@ func setScopedResolvers(resolvers []*Resolver) {
localScopes = make([]*Scope, 0)
for _, resolver := range resolvers {
if resolver.ServerIP != nil && netutils.IPIsLAN(resolver.ServerIP) {
if resolver.Info.IPScope.IsLAN() {
localResolvers = append(localResolvers, resolver)
}
if resolver.Source == ServerSourceOperatingSystem {
if resolver.Info.Source == ServerSourceOperatingSystem {
systemResolvers = append(systemResolvers, resolver)
}