diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 09b152e3..e76e1be4 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -27,6 +27,14 @@ type Scope struct { Resolvers []*Resolver } +const ( + parameterName = "name" + parameterVerify = "verify" + parameterBlockedIf = "blockedif" + parameterSearch = "search" + parameterSearchOnly = "search-only" +) + var ( globalResolvers []*Resolver // all (global) resolvers localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges @@ -122,31 +130,47 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return nil, true, nil // skip } + // Get parameters and check if keys exist. query := u.Query() - verifyDomain := query.Get("verify") + for key := range query { + switch key { + case parameterName, + parameterVerify, + parameterBlockedIf, + parameterSearch, + parameterSearchOnly: + // Known key, continue. + default: + // Unknown key, abort. + return nil, false, fmt.Errorf(`unknown parameter "%s"`, key) + } + } + + // Check domain verification config. + verifyDomain := query.Get(parameterVerify) if verifyDomain != "" && u.Scheme != ServerTypeDoT { return nil, false, fmt.Errorf("domain verification only supported in DOT") } - if verifyDomain == "" && u.Scheme == ServerTypeDoT { return nil, false, fmt.Errorf("DOT must have a verify query parameter set") } - blockType := query.Get("blockedif") + // Check block detection type. + blockType := query.Get(parameterBlockedIf) if blockType == "" { blockType = BlockDetectionZeroIP } - switch blockType { case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: default: return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") } + // Build resolver. newResolver := &Resolver{ ConfigURL: resolverURL, Info: &ResolverInfo{ - Name: query.Get("name"), + Name: query.Get(parameterName), Type: u.Scheme, Source: source, IP: ip, @@ -159,7 +183,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { } // Parse search domains. - searchDomains := query.Get("search") + searchDomains := query.Get(parameterSearch) if searchDomains != "" { err = configureSearchDomains(newResolver, strings.Split(searchDomains, ","), true) if err != nil { @@ -168,14 +192,13 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { } // Check if searchOnly is set and valid. - if query.Has("searchOnly") { + if query.Has(parameterSearchOnly) { newResolver.SearchOnly = true - - if query.Get("searchOnly") != "" { - return nil, false, errors.New("searchOnly may only be used as an empty parameter") + if query.Get(parameterSearchOnly) != "" { + return nil, false, fmt.Errorf("%s may only be used as an empty parameter", parameterSearchOnly) } if len(newResolver.Search) == 0 { - return nil, false, errors.New("cannot use searchOnly without search scopes") + return nil, false, fmt.Errorf("cannot use %s without search scopes", parameterSearchOnly) } } @@ -186,7 +209,7 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { func configureSearchDomains(resolver *Resolver, searches []string, hardfail bool) error { // Check all search domains. for i, value := range searches { - trimmedDomain := strings.Trim(value, ".") + trimmedDomain := strings.ToLower(strings.Trim(value, ".")) err := checkSearchScope(trimmedDomain) if err != nil { if hardfail {