From 3d69216c27ca82e5b826b4d5c0a795f20d842b7c Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 30 Jan 2023 11:33:03 +0100 Subject: [PATCH] Force resolvers to reconnect after connecting to SPN --- resolver/main.go | 13 +++++++++ resolver/resolver-env.go | 2 ++ resolver/resolver-https.go | 57 +++++++++++++++++++++++++++++--------- resolver/resolver-mdns.go | 2 ++ resolver/resolver-plain.go | 4 +++ resolver/resolver-tcp.go | 26 ++++++++++++++--- resolver/resolver.go | 1 + resolver/resolvers.go | 15 ++++++++++ 8 files changed, 103 insertions(+), 17 deletions(-) diff --git a/resolver/main.go b/resolver/main.go index 01d5d623..e9efde17 100644 --- a/resolver/main.go +++ b/resolver/main.go @@ -57,6 +57,19 @@ func start() error { return err } + // Force resolvers to reconnect when SPN has connected. + if err := module.RegisterEventHook( + "captain", + "spn connect", // Defined by captain.SPNConnectedEvent + "force resolver reconnect", + func(ctx context.Context, _ any) error { + ForceResolverReconnect(ctx) + return nil + }, + ); err != nil { + return err + } + // reload after config change prevNameservers := strings.Join(configuredNameServers(), " ") err = module.RegisterEventHook( diff --git a/resolver/resolver-env.go b/resolver/resolver-env.go index b8710ff3..d976d311 100644 --- a/resolver/resolver-env.go +++ b/resolver/resolver-env.go @@ -150,6 +150,8 @@ func (er *envResolverConn) IsFailing() bool { func (er *envResolverConn) ResetFailure() {} +func (er *envResolverConn) ForceReconnect(_ context.Context) {} + // QueryPortmasterEnv queries the environment resolver directly. func QueryPortmasterEnv(ctx context.Context, q *Query) (*RRCache, error) { return envResolver.Conn.Query(ctx, q) diff --git a/resolver/resolver-https.go b/resolver/resolver-https.go index 36e30f15..a99fa794 100644 --- a/resolver/resolver-https.go +++ b/resolver/resolver-https.go @@ -8,15 +8,18 @@ import ( "io" "net/http" "net/url" + "sync" "time" "github.com/miekg/dns" + "github.com/safing/portbase/log" ) // HTTPSResolver is a resolver using just a single tcp connection with pipelining. type HTTPSResolver struct { BasicResolverConn - Client *http.Client + client *http.Client + clientLock sync.RWMutex } // HTTPSQuery holds the query information for a hTTPSResolverConn. @@ -40,23 +43,13 @@ func (tq *HTTPSQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo // NewHTTPSResolver returns a new HTTPSResolver. func NewHTTPSResolver(resolver *Resolver) *HTTPSResolver { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - ServerName: resolver.Info.Domain, - // TODO: use portbase rng - }, - IdleConnTimeout: 3 * time.Minute, - } - - client := &http.Client{Transport: tr} newResolver := &HTTPSResolver{ BasicResolverConn: BasicResolverConn{ resolver: resolver, }, - Client: client, } newResolver.BasicResolverConn.init() + newResolver.refreshClient() return newResolver } @@ -86,7 +79,13 @@ func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error) return nil, err } - resp, err := hr.Client.Do(request) + // Lock client for usage. + hr.clientLock.RLock() + defer hr.clientLock.RUnlock() + + // TODO: Check age of client and force a refresh similar to the TCP resolver. + + resp, err := hr.client.Do(request) if err != nil { return nil, err } @@ -124,3 +123,35 @@ func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error) // TODO: check if reply.Answer is valid return newRecord, nil } + +// ForceReconnect forces the resolver to re-establish the connection to the server. +func (hr *HTTPSResolver) ForceReconnect(ctx context.Context) { + hr.refreshClient() + log.Tracer(ctx).Tracef("resolver: created new HTTP client for %s", hr.resolver) +} + +func (hr *HTTPSResolver) refreshClient() { + // Lock client for changing. + hr.clientLock.Lock() + defer hr.clientLock.Unlock() + + // Attempt to close connection of previous client. + if hr.client != nil { + hr.client.CloseIdleConnections() + } + + // Create new client. + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: hr.resolver.Info.Domain, + // TODO: use portbase rng + }, + IdleConnTimeout: 1 * time.Minute, + TLSHandshakeTimeout: defaultConnectTimeout, + } + hr.client = &http.Client{ + Transport: tr, + Timeout: maxRequestTimeout, + } +} diff --git a/resolver/resolver-mdns.go b/resolver/resolver-mdns.go index 29677350..2e01122a 100644 --- a/resolver/resolver-mdns.go +++ b/resolver/resolver-mdns.go @@ -56,6 +56,8 @@ func (mrc *mDNSResolverConn) IsFailing() bool { func (mrc *mDNSResolverConn) ResetFailure() {} +func (mrc *mDNSResolverConn) ForceReconnect(_ context.Context) {} + type savedQuestion struct { question dns.Question expires time.Time diff --git a/resolver/resolver-plain.go b/resolver/resolver-plain.go index 2ddcff90..992366b4 100644 --- a/resolver/resolver-plain.go +++ b/resolver/resolver-plain.go @@ -96,3 +96,7 @@ func (pr *PlainResolver) Query(ctx context.Context, q *Query) (*RRCache, error) // TODO: check if reply.Answer is valid return newRecord, nil } + +// ForceReconnect forces the resolver to re-establish the connection to the server. +// Does nothing for PlainResolver, as every request uses its own connection. +func (pr *PlainResolver) ForceReconnect(_ context.Context) {} diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index 746d6c01..f61779b1 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -236,11 +236,29 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { return tq.MakeCacheRecord(reply, tr.resolver.Info), nil } +// ForceReconnect forces the resolver to re-establish the connection to the server. +func (tr *TCPResolver) ForceReconnect(ctx context.Context) { + tr.Lock() + defer tr.Unlock() + + // Do nothing if no connection is available. + if tr.resolverConn == nil { + return + } + + // Set the abandoned to force a new connection on next request. + // This will leave the previous connection and handler running until all requests are handled. + tr.resolverConn.abandoned.Set() + + log.Tracer(ctx).Tracef("resolver: marked %s for reconnecting", tr.resolver) +} + +// shutdown cleanly shuts down the resolver connection. +// Must only be called once. func (trc *tcpResolverConn) shutdown() { // Set abandoned status and close connection to the DNS server. - if trc.abandoned.SetToIf(false, true) { - _ = trc.conn.Close() - } + trc.abandoned.Set() + _ = trc.conn.Close() // Close all response channels for in-flight queries. for _, tq := range trc.inFlightQueries { @@ -320,7 +338,7 @@ func (trc *tcpResolverConn) handler(workerCtx context.Context) error { // If we are ready to recycle and we have no in-flight queries, we can // shutdown the connection and create a new one for the next query. - if readyToRecycle { + if readyToRecycle || trc.abandoned.IsSet() { if len(trc.inFlightQueries) == 0 { log.Debugf("resolver: recycling connection to %s", trc.resolverInfo.DescriptiveName()) return nil diff --git a/resolver/resolver.go b/resolver/resolver.go index 1e965688..1d17f1ef 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -212,6 +212,7 @@ type ResolverConn interface { //nolint:golint // TODO ReportFailure() IsFailing() bool ResetFailure() + ForceReconnect(ctx context.Context) } // BasicResolverConn implements ResolverConn for standard dns clients. diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 1da8cbd0..198f7543 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -570,3 +570,18 @@ func IsResolverAddress(ip net.IP, port uint16) bool { return false } + +// ForceResolverReconnect forces all resolvers to reconnect. +func ForceResolverReconnect(ctx context.Context) { + resolversLock.RLock() + defer resolversLock.RUnlock() + + ctx, tracer := log.AddTracer(ctx) + defer tracer.Submit() + + tracer.Trace("resolver: forcing all active resolvers to reconnect") + for _, r := range globalResolvers { + r.Conn.ForceReconnect(ctx) + } + tracer.Info("resolver: all active resolvers were forced to reconnect") +}