Implement further feedback on the TCP Resolver

This commit is contained in:
Daniel 2020-07-14 12:54:05 +02:00
parent 73da96fe98
commit 383c019d0c

View file

@ -104,13 +104,7 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
}
tr.Lock()
// check for existing query
for i := 0; i < 10; i++ { // don't try forever
_, exists := tr.inFlightQueries[msg.Id]
if !exists {
break // we are unique, yay!
}
msg.Id = dns.Id() // regenerate ID
}
tr.ensureUniqueID(msg)
// add query to in flight registry
tr.inFlightQueries[msg.Id] = inFlight
tr.Unlock()
@ -121,6 +115,27 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
return inFlight
}
// ensureUniqueID makes sure that the ID assigned to msg is not already used
// by an in-flight query. The TCPResolver must be locked by the caller.
//
// It first retries with random IDs and, if that keeps colliding, falls back
// to a linear scan of the complete 16-bit ID space. If every one of the
// 65536 possible IDs is currently in use, msg.Id is left unchanged and the
// function returns instead of spinning forever while holding the lock.
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
	// try a random ID 10000 times
	for i := 0; i < 10000; i++ { // don't try forever
		if _, exists := tr.inFlightQueries[msg.Id]; !exists {
			return // we are unique, yay!
		}
		msg.Id = dns.Id() // regenerate ID
	}

	// go through the complete space
	//
	// The counter is an int on purpose: a uint16 counter can never exceed
	// (1<<16)-1, so a `id <= (1<<16)-1` condition would always hold and the
	// counter would wrap around to 0, looping forever when the space is full.
	for i := 0; i < 1<<16; i++ {
		id := uint16(i)
		if _, exists := tr.inFlightQueries[id]; !exists {
			msg.Id = id
			return // we are unique, yay!
		}
	}
}
// Query executes the given query against the resolver.
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
// submit to client
@ -148,12 +163,6 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
type tcpResolverConnMgr struct {
tr *TCPResolver
workerCtx context.Context
conn *dns.Conn
connCtx context.Context
cancelConnCtx func()
connTimer *time.Timer
connClosing *abool.AtomicBool
responses chan *dns.Msg
failCnt int
}
@ -162,8 +171,6 @@ func (tr *TCPResolver) startClient() {
if tr.clientStarted.SetToIf(false, true) {
mgr := &tcpResolverConnMgr{
tr: tr,
connTimer: time.NewTimer(tr.clientTTL),
connClosing: abool.New(),
responses: make(chan *dns.Msg, 100),
}
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
@ -171,8 +178,6 @@ func (tr *TCPResolver) startClient() {
}
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
mgr.workerCtx = workerCtx
// connection lifecycle loop
for {
// check if we are failing
@ -181,19 +186,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
return nil
}
// clean up anything that is left over
mgr.cleanupConnection()
// wait for work before creating connection
proceed := mgr.waitForWork()
proceed := mgr.waitForWork(workerCtx)
if !proceed {
mgr.shutdown()
return nil
}
// create connection
success := mgr.establishConnection()
if !success {
conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx)
if conn == nil {
mgr.failCnt++
continue
}
@ -202,7 +204,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
netenv.ReportSuccessfulConnection()
// handle queries
proceed = mgr.queryHandler()
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
if !proceed {
mgr.shutdown()
return nil
@ -210,26 +212,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
}
}
func (mgr *tcpResolverConnMgr) cleanupConnection() {
// cleanup old connection
if mgr.conn != nil {
mgr.connClosing.Set() // silence connection errors
_ = mgr.conn.Close()
if mgr.cancelConnCtx != nil {
mgr.cancelConnCtx()
}
// delete old connection
mgr.conn = nil
// increase instance counter
atomic.AddUint32(mgr.tr.connInstanceID, 1)
}
}
func (mgr *tcpResolverConnMgr) shutdown() {
mgr.cleanupConnection()
// reply to all waiting queries
mgr.tr.Lock()
for id, inFlight := range mgr.tr.inFlightQueries {
@ -246,7 +229,7 @@ func (mgr *tcpResolverConnMgr) shutdown() {
}
}
func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
// wait until there is something to do
mgr.tr.Lock()
waiting := len(mgr.tr.inFlightQueries)
@ -256,6 +239,7 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
mgr.tr.Lock()
defer mgr.tr.Unlock()
for id, inFlight := range mgr.tr.inFlightQueries {
if inFlight.Started.Before(ignoreBefore) {
// remove old queries
@ -272,29 +256,34 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
}
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
}
mgr.tr.Unlock()
} else {
return true
}
// wait for first query
select {
case <-mgr.workerCtx.Done():
case <-workerCtx.Done():
return false
case msg := <-mgr.tr.queries:
// re-insert query, we will handle it later
module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
select {
case mgr.tr.queries <- msg:
default:
case <-time.After(2 * time.Second):
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name)
}
}
return nil
})
}
return true
}
func (mgr *tcpResolverConnMgr) establishConnection() (success bool) {
// create connection
mgr.connCtx, mgr.cancelConnCtx = context.WithCancel(mgr.workerCtx)
mgr.connClosing = abool.New()
func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
conn *dns.Conn,
connClosing *abool.AtomicBool,
connCtx context.Context,
cancelConnCtx context.CancelFunc,
) {
// refresh dialer to set an authenticated local address
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
mgr.tr.dnsClient.Dialer = &net.Dialer{
@ -303,42 +292,56 @@ func (mgr *tcpResolverConnMgr) establishConnection() (success bool) {
KeepAlive: defaultClientTTL,
}
// connect
c, err := mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
var err error
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
if err != nil {
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress)
return false
return nil, nil, nil, nil
}
mgr.conn = c
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr())
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
connClosing = abool.New()
// reset timer
mgr.connTimer.Stop()
select {
case <-mgr.connTimer.C: // try to empty the timer
default:
}
mgr.connTimer.Reset(mgr.tr.clientTTL)
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
// start reader
module.StartServiceWorker("dns client reader", 10*time.Millisecond, mgr.msgReader)
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
return mgr.msgReader(conn, connClosing, cancelConnCtx)
})
return true
return conn, connClosing, connCtx, cancelConnCtx
}
func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
workerCtx context.Context,
conn *dns.Conn,
connClosing *abool.AtomicBool,
connCtx context.Context,
cancelConnCtx context.CancelFunc,
) (proceed bool) {
var readyToRecycle bool
ttlTimer := time.After(mgr.tr.clientTTL)
// clean up connection
defer func() {
connClosing.Set() // silence connection errors
cancelConnCtx()
_ = conn.Close()
// increase instance counter
atomic.AddUint32(mgr.tr.connInstanceID, 1)
}()
for {
select {
case <-mgr.workerCtx.Done():
case <-workerCtx.Done():
// module shutdown
return false
case <-mgr.connCtx.Done():
case <-connCtx.Done():
// connection error
return true
case <-mgr.connTimer.C:
case <-ttlTimer:
// connection TTL reached, rebuild connection
// but handle all in flight queries first
readyToRecycle = true
@ -351,18 +354,36 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
case msg := <-mgr.tr.queries:
// write query
_ = mgr.conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
err := mgr.conn.WriteMsg(msg)
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
err := conn.WriteMsg(msg)
if err != nil {
if mgr.connClosing.SetToIf(false, true) {
mgr.cancelConnCtx()
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err)
if connClosing.SetToIf(false, true) {
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
}
return true
}
case msg := <-mgr.responses:
if msg != nil { // nil messages only trigger the recycle check
if msg != nil {
mgr.handleQueryResponse(conn, msg)
}
if readyToRecycle {
// check to see if we can recycle the connection
mgr.tr.Lock()
activeQueries := len(mgr.tr.inFlightQueries)
mgr.tr.Unlock()
if activeQueries == 0 {
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
return true
}
}
}
}
}
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
// handle query from resolver
mgr.tr.Lock()
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
@ -371,14 +392,31 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
}
mgr.tr.Unlock()
if ok {
if !ok {
log.Debugf(
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
mgr.tr.resolver.Name,
conn.RemoteAddr(),
msg.Id,
msg.Question,
)
return
}
select {
case inFlight.Response <- msg:
mgr.failCnt = 0 // reset fail counter
// responded!
return
default:
// save to cache, if enabled
if !inFlight.Query.NoCaching {
// no one is listening for that response.
}
// if caching is disabled we're done
if inFlight.Query.NoCaching {
return
}
// persist to database
rrCache := inFlight.MakeCacheRecord(msg)
rrCache.Clean(600)
@ -391,46 +429,19 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
err,
)
}
}
}
} else {
log.Debugf(
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
mgr.tr.resolver.Name,
mgr.conn.RemoteAddr(),
msg.Id,
msg.Question,
)
}
}
if readyToRecycle {
// check to see if we can recycle the connection
mgr.tr.Lock()
activeQueries := len(mgr.tr.inFlightQueries)
mgr.tr.Unlock()
if activeQueries == 0 {
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr())
return true
}
}
}
}
}
func (mgr *tcpResolverConnMgr) msgReader(workerCtx context.Context) error {
// copy values from manager
conn := mgr.conn
cancelConnCtx := mgr.cancelConnCtx
connClosing := mgr.connClosing
func (mgr *tcpResolverConnMgr) msgReader(
conn *dns.Conn,
connClosing *abool.AtomicBool,
cancelConnCtx context.CancelFunc,
) error {
defer cancelConnCtx()
for {
msg, err := conn.ReadMsg()
if err != nil {
if connClosing.SetToIf(false, true) {
cancelConnCtx()
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err)
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
}
return nil
}