Merge pull request #57 from Snawoot/concurrent_bootstrap_dns

Concurrent bootstrap dns
This commit is contained in:
Snawoot 2024-08-01 12:52:16 +03:00 committed by GitHub
commit 0ec231b994
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 88 additions and 77 deletions

75
main.go
View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@ -8,10 +9,12 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"os" "os"
"strings" "strings"
@ -43,6 +46,37 @@ func arg_fail(msg string) {
os.Exit(2) os.Exit(2)
} }
type CSVArg struct {
values []string
}
func (a *CSVArg) String() string {
if len(a.values) == 0 {
return ""
}
buf := new(bytes.Buffer)
wr := csv.NewWriter(buf)
wr.Write(a.values)
wr.Flush()
return strings.TrimRight(buf.String(), "\n")
}
func (a *CSVArg) Set(line string) error {
rd := csv.NewReader(strings.NewReader(line))
rd.FieldsPerRecord = -1
rd.TrimLeadingSpace = true
values, err := rd.Read()
if err == io.EOF {
a.values = nil
return nil
}
if err != nil {
return fmt.Errorf("unable to parse comma-separated argument: %w", err)
}
a.values = values
return nil
}
type CLIArgs struct { type CLIArgs struct {
country string country string
listCountries bool listCountries bool
@ -55,15 +89,28 @@ type CLIArgs struct {
apiLogin string apiLogin string
apiPassword string apiPassword string
apiAddress string apiAddress string
bootstrapDNS string bootstrapDNS *CSVArg
refresh time.Duration refresh time.Duration
refreshRetry time.Duration refreshRetry time.Duration
certChainWorkaround bool certChainWorkaround bool
caFile string caFile string
} }
func parse_args() CLIArgs { func parse_args() *CLIArgs {
var args CLIArgs args := &CLIArgs{
bootstrapDNS: &CSVArg{
values: []string{
"https://1.1.1.3/dns-query",
"https://8.8.8.8/dns-query",
"https://dns.google/dns-query",
"https://security.cloudflare-dns.com/dns-query",
"https://wikimedia-dns.org/dns-query",
"https://dns.adguard-dns.com/dns-query",
"https://dns.quad9.net/dns-query",
"https://doh.cleanbrowsing.org/doh/adult-filter/",
},
},
}
flag.StringVar(&args.country, "country", "EU", "desired proxy location") flag.StringVar(&args.country, "country", "EU", "desired proxy location")
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit") flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
flag.BoolVar(&args.listProxies, "list-proxies", false, "output proxy list and exit") flag.BoolVar(&args.listProxies, "list-proxies", false, "output proxy list and exit")
@ -78,10 +125,10 @@ func parse_args() CLIArgs {
flag.StringVar(&args.apiLogin, "api-login", "se0316", "SurfEasy API login") flag.StringVar(&args.apiLogin, "api-login", "se0316", "SurfEasy API login")
flag.StringVar(&args.apiPassword, "api-password", "SILrMEPBmJuhomxWkfm3JalqHX2Eheg1YhlEZiMh8II", "SurfEasy API password") flag.StringVar(&args.apiPassword, "api-password", "SILrMEPBmJuhomxWkfm3JalqHX2Eheg1YhlEZiMh8II", "SurfEasy API password")
flag.StringVar(&args.apiAddress, "api-address", "", fmt.Sprintf("override IP address of %s", API_DOMAIN)) flag.StringVar(&args.apiAddress, "api-address", "", fmt.Sprintf("override IP address of %s", API_DOMAIN))
flag.StringVar(&args.bootstrapDNS, "bootstrap-dns", "https://1.1.1.3/dns-query", flag.Var(args.bootstrapDNS, "bootstrap-dns",
"DNS/DoH/DoT/DoQ resolver for initial discovering of SurfEasy API address. "+ "comma-separated list of DNS/DoH/DoT/DoQ resolvers for initial discovery of SurfEasy API address. "+
"See https://github.com/ameshkov/dnslookup/ for upstream DNS URL format. "+ "See https://github.com/ameshkov/dnslookup/ for upstream DNS URL format. "+
"Examples: https://1.1.1.1/dns-query, quic://dns.adguard.com") "Examples: https://1.1.1.1/dns-query,quic://dns.adguard.com")
flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval") flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval")
flag.DurationVar(&args.refreshRetry, "refresh-retry", 5*time.Second, "login refresh retry interval") flag.DurationVar(&args.refreshRetry, "refresh-retry", 5*time.Second, "login refresh retry interval")
flag.BoolVar(&args.certChainWorkaround, "certchain-workaround", true, flag.BoolVar(&args.certChainWorkaround, "certchain-workaround", true,
@ -147,26 +194,34 @@ func run() int {
} }
seclientDialer := dialer seclientDialer := dialer
if args.apiAddress != "" || args.bootstrapDNS != "" { if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 {
var apiAddress string var apiAddress string
if args.apiAddress != "" { if args.apiAddress != "" {
apiAddress = args.apiAddress apiAddress = args.apiAddress
mainLogger.Info("Using fixed API host IP address = %s", apiAddress) mainLogger.Info("Using fixed API host IP address = %s", apiAddress)
} else { } else {
resolver, err := NewResolver(args.bootstrapDNS, args.timeout) resolver, err := NewResolver(args.bootstrapDNS.values, args.timeout)
if err != nil { if err != nil {
mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) mainLogger.Critical("Unable to instantiate DNS resolver: %v", err)
return 4 return 4
} }
mainLogger.Info("Discovering API IP address...") mainLogger.Info("Discovering API IP address...")
addrs := resolver.ResolveA(API_DOMAIN) addrs, err := func() ([]netip.Addr, error) {
ctx, cancel := context.WithTimeout(context.Background(), args.timeout)
defer cancel()
return resolver.LookupNetIP(ctx, "ip4", API_DOMAIN)
}()
if err != nil {
mainLogger.Critical("Unable to resolve API server address: %v", err)
return 14
}
if len(addrs) == 0 { if len(addrs) == 0 {
mainLogger.Critical("Unable to resolve %s with specified bootstrap DNS", API_DOMAIN) mainLogger.Critical("Unable to resolve %s with specified bootstrap DNS", API_DOMAIN)
return 14 return 14
} }
apiAddress = addrs[0] apiAddress = addrs[0].String()
mainLogger.Info("Discovered address of API host = %s", apiAddress) mainLogger.Info("Discovered address of API host = %s", apiAddress)
} }
seclientDialer = NewFixedDialer(apiAddress, dialer) seclientDialer = NewFixedDialer(apiAddress, dialer)

View file

@ -1,82 +1,38 @@
package main package main
import ( import (
"github.com/AdguardTeam/dnsproxy/upstream" "context"
"github.com/miekg/dns" "fmt"
"net/netip"
"time" "time"
"github.com/AdguardTeam/dnsproxy/upstream"
) )
type Resolver struct { type Resolver struct {
upstream upstream.Upstream resolvers upstream.ParallelResolver
timeout time.Duration
} }
const DOT = 0x2e func NewResolver(addresses []string, timeout time.Duration) (*Resolver, error) {
resolvers := make([]upstream.Resolver, 0, len(addresses))
func NewResolver(address string, timeout time.Duration) (*Resolver, error) { opts := &upstream.Options{
opts := &upstream.Options{Timeout: timeout} Timeout: timeout,
u, err := upstream.AddressToUpstream(address, opts) }
for _, addr := range addresses {
u, err := upstream.AddressToUpstream(addr, opts)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("unable to construct upstream resolver from string %q: %w",
addr, err)
} }
return &Resolver{upstream: u}, nil resolvers = append(resolvers, &upstream.UpstreamResolver{Upstream: u})
}
return &Resolver{
resolvers: resolvers,
timeout: timeout,
}, nil
} }
func (r *Resolver) ResolveA(domain string) []string { func (r *Resolver) LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error) {
res := make([]string, 0) return r.resolvers.LookupNetIP(ctx, network, host)
if len(domain) == 0 {
return res
}
if domain[len(domain)-1] != DOT {
domain = domain + "."
}
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: domain, Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := r.upstream.Exchange(&req)
if err != nil {
return res
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.A); ok {
res = append(res, a.A.String())
}
}
return res
}
func (r *Resolver) ResolveAAAA(domain string) []string {
res := make([]string, 0)
if len(domain) == 0 {
return res
}
if domain[len(domain)-1] != DOT {
domain = domain + "."
}
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: domain, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET},
}
reply, err := r.upstream.Exchange(&req)
if err != nil {
return res
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.AAAA); ok {
res = append(res, a.AAAA.String())
}
}
return res
}
func (r *Resolver) Resolve(domain string) []string {
res := r.ResolveA(domain)
if len(res) == 0 {
res = r.ResolveAAAA(domain)
}
return res
} }