diff --git a/main.go b/main.go index 14a99b6..b3930e1 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "log" "net" "net/http" + "net/netip" "net/url" "os" "strings" @@ -197,20 +198,28 @@ func run() int { apiAddress = args.apiAddress mainLogger.Info("Using fixed API host IP address = %s", apiAddress) } else { - resolver, err := NewResolver(args.bootstrapDNS.values[0], args.timeout) + resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout) if err != nil { mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) return 4 } mainLogger.Info("Discovering API IP address...") - addrs := resolver.ResolveA(API_DOMAIN) + addrs, err := func() ([]netip.Addr, error) { + ctx, cancel := context.WithTimeout(context.Background(), args.timeout) + defer cancel() + return resolver.LookupNetIP(ctx, "ip4", API_DOMAIN) + }() + if err != nil { + mainLogger.Critical("Unable to resolve API server address: %v", err) + return 14 + } if len(addrs) == 0 { mainLogger.Critical("Unable to resolve %s with specified bootstrap DNS", API_DOMAIN) return 14 } - apiAddress = addrs[0] + apiAddress = addrs[0].String() mainLogger.Info("Discovered address of API host = %s", apiAddress) } seclientDialer = NewFixedDialer(apiAddress, dialer) diff --git a/resolver.go b/resolver.go index 9f2318e..f460e4e 100644 --- a/resolver.go +++ b/resolver.go @@ -1,83 +1,38 @@ package main import ( + "context" + "fmt" + "net/netip" "time" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/miekg/dns" ) type Resolver struct { - upstream upstream.Upstream + resolvers upstream.ParallelResolver + timeout time.Duration } -const DOT = 0x2e - -func NewResolver(address string, timeout time.Duration) (*Resolver, error) { - opts := &upstream.Options{Timeout: timeout} - u, err := upstream.AddressToUpstream(address, opts) - if err != nil { - return nil, err +func NewResolver(addresses []string, timeout time.Duration) (*Resolver, error) { + resolvers := make([]upstream.Resolver, 0, len(addresses)) + opts := &upstream.Options{ + Timeout: timeout, } - return &Resolver{upstream: u}, nil -} - -func (r *Resolver) ResolveA(domain string) []string { - res := make([]string, 0) - if len(domain) == 0 { - return res - } - if domain[len(domain)-1] != DOT { - domain = domain + "." - } - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: domain, Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - reply, err := r.upstream.Exchange(&req) - if err != nil { - return res - } - for _, rr := range reply.Answer { - if a, ok := rr.(*dns.A); ok { - res = append(res, a.A.String()) + for _, addr := range addresses { + u, err := upstream.AddressToUpstream(addr, opts) + if err != nil { + return nil, fmt.Errorf("unable to construct upstream resolver from string %q: %w", + addr, err) } + resolvers = append(resolvers, &upstream.UpstreamResolver{Upstream: u}) } - return res + return &Resolver{ + resolvers: resolvers, + timeout: timeout, + }, nil } -func (r *Resolver) ResolveAAAA(domain string) []string { - res := make([]string, 0) - if len(domain) == 0 { - return res - } - if domain[len(domain)-1] != DOT { - domain = domain + "." - } - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: domain, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}, - } - reply, err := r.upstream.Exchange(&req) - if err != nil { - return res - } - for _, rr := range reply.Answer { - if a, ok := rr.(*dns.AAAA); ok { - res = append(res, a.AAAA.String()) - } - } - return res -} - -func (r *Resolver) Resolve(domain string) []string { - res := r.ResolveA(domain) - if len(res) == 0 { - res = r.ResolveAAAA(domain) - } - return res +func (r *Resolver) LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error) { + return r.resolvers.LookupNetIP(ctx, network, host) }