From 214de8f41243bdf1d8bb649b5e2476ed2ac3ab15 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Mon, 4 Nov 2024 13:26:09 +0200 Subject: [PATCH] resolving dialer --- dialer/resolver.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ main.go | 45 +++++++++----------------------------- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/dialer/resolver.go b/dialer/resolver.go index a4c4f82..84e2c0a 100644 --- a/dialer/resolver.go +++ b/dialer/resolver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "net/netip" "time" @@ -50,3 +51,56 @@ func (r *Resolver) Close() error { } return res } + +type LookupNetIPer interface { + LookupNetIP(context.Context, string, string) ([]netip.Addr, error) +} + +type ResolvingDialer struct { + lookup LookupNetIPer + next ContextDialer +} + +func NewResolvingDialer(lookup LookupNetIPer, next ContextDialer) *ResolvingDialer { + return &ResolvingDialer{ + lookup: lookup, + next: next, + } +} + +func (d *ResolvingDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *ResolvingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("failed to extract host and port from %s: %w", address, err) + } + + var resolveNetwork string + switch network { + case "udp4", "tcp4", "ip4": + resolveNetwork = "ip4" + case "udp6", "tcp6", "ip6": + resolveNetwork = "ip6" + case "udp", "tcp", "ip": + resolveNetwork = "ip" + default: + return nil, fmt.Errorf("resolving dial %q: unsupported network %q", address, network) + } + resolved, err := d.lookup.LookupNetIP(ctx, resolveNetwork, host) + if err != nil { + return nil, fmt.Errorf("dial failed on address lookup: %w", err) + } + + var conn net.Conn + for _, ip := range resolved { + conn, err = d.next.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + if err == nil { + return conn, nil + } + } + + return nil, fmt.Errorf("failed to dial %s: %w", address, err) +} diff --git a/main.go b/main.go index 9f774e8..9e319c9 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,6 @@ import ( "log" "net" "net/http" - "net/netip" "net/url" "os" "strings" @@ -204,41 +203,17 @@ func run() int { } seclientDialer := d - if args.apiAddress != "" || len(args.bootstrapDNS.values) > 0 { - var apiAddress string - if args.apiAddress != "" { - apiAddress = args.apiAddress - mainLogger.Info("Using fixed API host IP address = %s", apiAddress) - } else { - resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout) - if err != nil { - mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) - return 4 - } - - mainLogger.Info("Discovering API IP address...") - addrs, err := func() ([]netip.Addr, error) { - ctx, cancel := context.WithTimeout(context.Background(), args.timeout) - defer cancel() - defer func() { - resolver = nil - }() - defer resolver.Close() - return resolver.LookupNetIP(ctx, "ip4", API_DOMAIN) - }() - if err != nil { - mainLogger.Critical("Unable to resolve API server address: %v", err) - return 14 - } - if len(addrs) == 0 { - mainLogger.Critical("Unable to resolve %s with specified bootstrap DNS", API_DOMAIN) - return 14 - } - - apiAddress = addrs[0].String() - mainLogger.Info("Discovered address of API host = %s", apiAddress) + if args.apiAddress != "" { + mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress) + seclientDialer = dialer.NewFixedDialer(args.apiAddress, d) + } else if len(args.bootstrapDNS.values) > 0 { + resolver, err := dialer.NewResolver(args.bootstrapDNS.values, args.timeout) + if err != nil { + mainLogger.Critical("Unable to instantiate DNS resolver: %v", err) + return 4 } - seclientDialer = dialer.NewFixedDialer(apiAddress, d) + defer resolver.Close() + seclientDialer = dialer.NewResolvingDialer(resolver, d) } // Dialing w/o SNI, receiving self-signed certificate, so skip verification.