mirror of
https://github.com/safing/portmaster
synced 2025-09-03 19:19:15 +00:00
Implement further feedback on the TCP Resolver
This commit is contained in:
parent
73da96fe98
commit
383c019d0c
1 changed files with 144 additions and 133 deletions
|
@ -104,13 +104,7 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
||||||
}
|
}
|
||||||
tr.Lock()
|
tr.Lock()
|
||||||
// check for existing query
|
// check for existing query
|
||||||
for i := 0; i < 10; i++ { // don't try forever
|
tr.ensureUniqueID(msg)
|
||||||
_, exists := tr.inFlightQueries[msg.Id]
|
|
||||||
if !exists {
|
|
||||||
break // we are unique, yay!
|
|
||||||
}
|
|
||||||
msg.Id = dns.Id() // regenerate ID
|
|
||||||
}
|
|
||||||
// add query to in flight registry
|
// add query to in flight registry
|
||||||
tr.inFlightQueries[msg.Id] = inFlight
|
tr.inFlightQueries[msg.Id] = inFlight
|
||||||
tr.Unlock()
|
tr.Unlock()
|
||||||
|
@ -121,6 +115,27 @@ func (tr *TCPResolver) submitQuery(_ context.Context, q *Query) *InFlightQuery {
|
||||||
return inFlight
|
return inFlight
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureUniqueID makes sure that ID assigned to msg is unique. TCPResolver must be locked.
|
||||||
|
func (tr *TCPResolver) ensureUniqueID(msg *dns.Msg) {
|
||||||
|
// try a random ID 10000 times
|
||||||
|
for i := 0; i < 10000; i++ { // don't try forever
|
||||||
|
_, exists := tr.inFlightQueries[msg.Id]
|
||||||
|
if !exists {
|
||||||
|
return // we are unique, yay!
|
||||||
|
}
|
||||||
|
msg.Id = dns.Id() // regenerate ID
|
||||||
|
}
|
||||||
|
// go through the complete space
|
||||||
|
var id uint16
|
||||||
|
for ; id <= (1<<16)-1; id++ { // don't try forever
|
||||||
|
_, exists := tr.inFlightQueries[id]
|
||||||
|
if !exists {
|
||||||
|
msg.Id = id
|
||||||
|
return // we are unique, yay!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Query executes the given query against the resolver.
|
// Query executes the given query against the resolver.
|
||||||
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
// submit to client
|
// submit to client
|
||||||
|
@ -148,12 +163,6 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||||
|
|
||||||
type tcpResolverConnMgr struct {
|
type tcpResolverConnMgr struct {
|
||||||
tr *TCPResolver
|
tr *TCPResolver
|
||||||
workerCtx context.Context
|
|
||||||
conn *dns.Conn
|
|
||||||
connCtx context.Context
|
|
||||||
cancelConnCtx func()
|
|
||||||
connTimer *time.Timer
|
|
||||||
connClosing *abool.AtomicBool
|
|
||||||
responses chan *dns.Msg
|
responses chan *dns.Msg
|
||||||
failCnt int
|
failCnt int
|
||||||
}
|
}
|
||||||
|
@ -162,8 +171,6 @@ func (tr *TCPResolver) startClient() {
|
||||||
if tr.clientStarted.SetToIf(false, true) {
|
if tr.clientStarted.SetToIf(false, true) {
|
||||||
mgr := &tcpResolverConnMgr{
|
mgr := &tcpResolverConnMgr{
|
||||||
tr: tr,
|
tr: tr,
|
||||||
connTimer: time.NewTimer(tr.clientTTL),
|
|
||||||
connClosing: abool.New(),
|
|
||||||
responses: make(chan *dns.Msg, 100),
|
responses: make(chan *dns.Msg, 100),
|
||||||
}
|
}
|
||||||
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
|
module.StartServiceWorker("dns client", 10*time.Millisecond, mgr.run)
|
||||||
|
@ -171,8 +178,6 @@ func (tr *TCPResolver) startClient() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
mgr.workerCtx = workerCtx
|
|
||||||
|
|
||||||
// connection lifecycle loop
|
// connection lifecycle loop
|
||||||
for {
|
for {
|
||||||
// check if we are failing
|
// check if we are failing
|
||||||
|
@ -181,19 +186,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up anything that is left over
|
|
||||||
mgr.cleanupConnection()
|
|
||||||
|
|
||||||
// wait for work before creating connection
|
// wait for work before creating connection
|
||||||
proceed := mgr.waitForWork()
|
proceed := mgr.waitForWork(workerCtx)
|
||||||
if !proceed {
|
if !proceed {
|
||||||
mgr.shutdown()
|
mgr.shutdown()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// create connection
|
// create connection
|
||||||
success := mgr.establishConnection()
|
conn, connClosing, connCtx, cancelConnCtx := mgr.establishConnection(workerCtx)
|
||||||
if !success {
|
if conn == nil {
|
||||||
mgr.failCnt++
|
mgr.failCnt++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -202,7 +204,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
netenv.ReportSuccessfulConnection()
|
netenv.ReportSuccessfulConnection()
|
||||||
|
|
||||||
// handle queries
|
// handle queries
|
||||||
proceed = mgr.queryHandler()
|
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
|
||||||
if !proceed {
|
if !proceed {
|
||||||
mgr.shutdown()
|
mgr.shutdown()
|
||||||
return nil
|
return nil
|
||||||
|
@ -210,26 +212,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) cleanupConnection() {
|
|
||||||
// cleanup old connection
|
|
||||||
if mgr.conn != nil {
|
|
||||||
mgr.connClosing.Set() // silence connection errors
|
|
||||||
_ = mgr.conn.Close()
|
|
||||||
if mgr.cancelConnCtx != nil {
|
|
||||||
mgr.cancelConnCtx()
|
|
||||||
}
|
|
||||||
|
|
||||||
// delete old connection
|
|
||||||
mgr.conn = nil
|
|
||||||
|
|
||||||
// increase instance counter
|
|
||||||
atomic.AddUint32(mgr.tr.connInstanceID, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) shutdown() {
|
func (mgr *tcpResolverConnMgr) shutdown() {
|
||||||
mgr.cleanupConnection()
|
|
||||||
|
|
||||||
// reply to all waiting queries
|
// reply to all waiting queries
|
||||||
mgr.tr.Lock()
|
mgr.tr.Lock()
|
||||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||||
|
@ -246,7 +229,7 @@ func (mgr *tcpResolverConnMgr) shutdown() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
|
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
|
||||||
// wait until there is something to do
|
// wait until there is something to do
|
||||||
mgr.tr.Lock()
|
mgr.tr.Lock()
|
||||||
waiting := len(mgr.tr.inFlightQueries)
|
waiting := len(mgr.tr.inFlightQueries)
|
||||||
|
@ -256,6 +239,7 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
|
||||||
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
ignoreBefore := time.Now().Add(-ignoreQueriesAfter)
|
||||||
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
|
currentConnInstanceID := atomic.LoadUint32(mgr.tr.connInstanceID)
|
||||||
mgr.tr.Lock()
|
mgr.tr.Lock()
|
||||||
|
defer mgr.tr.Unlock()
|
||||||
for id, inFlight := range mgr.tr.inFlightQueries {
|
for id, inFlight := range mgr.tr.inFlightQueries {
|
||||||
if inFlight.Started.Before(ignoreBefore) {
|
if inFlight.Started.Before(ignoreBefore) {
|
||||||
// remove old queries
|
// remove old queries
|
||||||
|
@ -272,29 +256,34 @@ func (mgr *tcpResolverConnMgr) waitForWork() (proceed bool) {
|
||||||
}
|
}
|
||||||
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
|
// in-flight queries that match the connection instance ID are not changed. They are already in the queue.
|
||||||
}
|
}
|
||||||
mgr.tr.Unlock()
|
return true
|
||||||
} else {
|
}
|
||||||
|
|
||||||
// wait for first query
|
// wait for first query
|
||||||
select {
|
select {
|
||||||
case <-mgr.workerCtx.Done():
|
case <-workerCtx.Done():
|
||||||
return false
|
return false
|
||||||
case msg := <-mgr.tr.queries:
|
case msg := <-mgr.tr.queries:
|
||||||
// re-insert query, we will handle it later
|
// re-insert query, we will handle it later
|
||||||
|
module.StartWorker("reinject triggering dns query", func(ctx context.Context) error {
|
||||||
select {
|
select {
|
||||||
case mgr.tr.queries <- msg:
|
case mgr.tr.queries <- msg:
|
||||||
default:
|
case <-time.After(2 * time.Second):
|
||||||
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name)
|
log.Warningf("resolver: failed to re-inject waking query to %s", mgr.tr.resolver.Name)
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) establishConnection() (success bool) {
|
func (mgr *tcpResolverConnMgr) establishConnection(workerCtx context.Context) (
|
||||||
// create connection
|
conn *dns.Conn,
|
||||||
mgr.connCtx, mgr.cancelConnCtx = context.WithCancel(mgr.workerCtx)
|
connClosing *abool.AtomicBool,
|
||||||
mgr.connClosing = abool.New()
|
connCtx context.Context,
|
||||||
|
cancelConnCtx context.CancelFunc,
|
||||||
|
) {
|
||||||
// refresh dialer to set an authenticated local address
|
// refresh dialer to set an authenticated local address
|
||||||
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
|
// TODO: lock dnsClient (only manager should run at any time, so this should not be an issue)
|
||||||
mgr.tr.dnsClient.Dialer = &net.Dialer{
|
mgr.tr.dnsClient.Dialer = &net.Dialer{
|
||||||
|
@ -303,42 +292,56 @@ func (mgr *tcpResolverConnMgr) establishConnection() (success bool) {
|
||||||
KeepAlive: defaultClientTTL,
|
KeepAlive: defaultClientTTL,
|
||||||
}
|
}
|
||||||
// connect
|
// connect
|
||||||
c, err := mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
|
var err error
|
||||||
|
conn, err = mgr.tr.dnsClient.Dial(mgr.tr.resolver.ServerAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress)
|
log.Debugf("resolver: failed to connect to %s (%s)", mgr.tr.resolver.Name, mgr.tr.resolver.ServerAddress)
|
||||||
return false
|
return nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
mgr.conn = c
|
connCtx, cancelConnCtx = context.WithCancel(workerCtx)
|
||||||
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr())
|
connClosing = abool.New()
|
||||||
|
|
||||||
// reset timer
|
log.Debugf("resolver: connected to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
|
||||||
mgr.connTimer.Stop()
|
|
||||||
select {
|
|
||||||
case <-mgr.connTimer.C: // try to empty the timer
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
mgr.connTimer.Reset(mgr.tr.clientTTL)
|
|
||||||
|
|
||||||
// start reader
|
// start reader
|
||||||
module.StartServiceWorker("dns client reader", 10*time.Millisecond, mgr.msgReader)
|
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
|
||||||
|
return mgr.msgReader(conn, connClosing, cancelConnCtx)
|
||||||
|
})
|
||||||
|
|
||||||
return true
|
return conn, connClosing, connCtx, cancelConnCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
|
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
|
||||||
|
workerCtx context.Context,
|
||||||
|
conn *dns.Conn,
|
||||||
|
connClosing *abool.AtomicBool,
|
||||||
|
connCtx context.Context,
|
||||||
|
cancelConnCtx context.CancelFunc,
|
||||||
|
) (proceed bool) {
|
||||||
var readyToRecycle bool
|
var readyToRecycle bool
|
||||||
|
ttlTimer := time.After(mgr.tr.clientTTL)
|
||||||
|
|
||||||
|
// clean up connection
|
||||||
|
defer func() {
|
||||||
|
connClosing.Set() // silence connection errors
|
||||||
|
cancelConnCtx()
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
// increase instance counter
|
||||||
|
atomic.AddUint32(mgr.tr.connInstanceID, 1)
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-mgr.workerCtx.Done():
|
case <-workerCtx.Done():
|
||||||
// module shutdown
|
// module shutdown
|
||||||
return false
|
return false
|
||||||
|
|
||||||
case <-mgr.connCtx.Done():
|
case <-connCtx.Done():
|
||||||
// connection error
|
// connection error
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-mgr.connTimer.C:
|
case <-ttlTimer:
|
||||||
// connection TTL reached, rebuild connection
|
// connection TTL reached, rebuild connection
|
||||||
// but handle all in flight queries first
|
// but handle all in flight queries first
|
||||||
readyToRecycle = true
|
readyToRecycle = true
|
||||||
|
@ -351,18 +354,36 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
|
||||||
|
|
||||||
case msg := <-mgr.tr.queries:
|
case msg := <-mgr.tr.queries:
|
||||||
// write query
|
// write query
|
||||||
_ = mgr.conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
_ = conn.SetWriteDeadline(time.Now().Add(mgr.tr.dnsClient.WriteTimeout))
|
||||||
err := mgr.conn.WriteMsg(msg)
|
err := conn.WriteMsg(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if mgr.connClosing.SetToIf(false, true) {
|
if connClosing.SetToIf(false, true) {
|
||||||
mgr.cancelConnCtx()
|
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
|
||||||
log.Warningf("resolver: write error to %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err)
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
case msg := <-mgr.responses:
|
case msg := <-mgr.responses:
|
||||||
if msg != nil { // nil messages only trigger the recycle check
|
if msg != nil {
|
||||||
|
mgr.handleQueryResponse(conn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if readyToRecycle {
|
||||||
|
// check to see if we can recycle the connection
|
||||||
|
mgr.tr.Lock()
|
||||||
|
activeQueries := len(mgr.tr.inFlightQueries)
|
||||||
|
mgr.tr.Unlock()
|
||||||
|
if activeQueries == 0 {
|
||||||
|
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, conn.RemoteAddr())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *tcpResolverConnMgr) handleQueryResponse(conn *dns.Conn, msg *dns.Msg) {
|
||||||
// handle query from resolver
|
// handle query from resolver
|
||||||
mgr.tr.Lock()
|
mgr.tr.Lock()
|
||||||
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
|
inFlight, ok := mgr.tr.inFlightQueries[msg.Id]
|
||||||
|
@ -371,14 +392,31 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
|
||||||
}
|
}
|
||||||
mgr.tr.Unlock()
|
mgr.tr.Unlock()
|
||||||
|
|
||||||
if ok {
|
if !ok {
|
||||||
|
log.Debugf(
|
||||||
|
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
|
||||||
|
mgr.tr.resolver.Name,
|
||||||
|
conn.RemoteAddr(),
|
||||||
|
msg.Id,
|
||||||
|
msg.Question,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case inFlight.Response <- msg:
|
case inFlight.Response <- msg:
|
||||||
mgr.failCnt = 0 // reset fail counter
|
mgr.failCnt = 0 // reset fail counter
|
||||||
// responded!
|
// responded!
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
// save to cache, if enabled
|
// no one is listening for that response.
|
||||||
if !inFlight.Query.NoCaching {
|
}
|
||||||
|
|
||||||
|
// if caching is disabled we're done
|
||||||
|
if inFlight.Query.NoCaching {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// persist to database
|
// persist to database
|
||||||
rrCache := inFlight.MakeCacheRecord(msg)
|
rrCache := inFlight.MakeCacheRecord(msg)
|
||||||
rrCache.Clean(600)
|
rrCache.Clean(600)
|
||||||
|
@ -391,46 +429,19 @@ func (mgr *tcpResolverConnMgr) queryHandler() (proceed bool) { //nolint:gocognit
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf(
|
|
||||||
"resolver: received possibly unsolicited reply from %s (%s): txid=%d q=%+v",
|
|
||||||
mgr.tr.resolver.Name,
|
|
||||||
mgr.conn.RemoteAddr(),
|
|
||||||
msg.Id,
|
|
||||||
msg.Question,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if readyToRecycle {
|
|
||||||
// check to see if we can recycle the connection
|
|
||||||
mgr.tr.Lock()
|
|
||||||
activeQueries := len(mgr.tr.inFlightQueries)
|
|
||||||
mgr.tr.Unlock()
|
|
||||||
if activeQueries == 0 {
|
|
||||||
log.Debugf("resolver: recycling conn to %s (%s)", mgr.tr.resolver.Name, mgr.conn.RemoteAddr())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *tcpResolverConnMgr) msgReader(workerCtx context.Context) error {
|
func (mgr *tcpResolverConnMgr) msgReader(
|
||||||
// copy values from manager
|
conn *dns.Conn,
|
||||||
conn := mgr.conn
|
connClosing *abool.AtomicBool,
|
||||||
cancelConnCtx := mgr.cancelConnCtx
|
cancelConnCtx context.CancelFunc,
|
||||||
connClosing := mgr.connClosing
|
) error {
|
||||||
|
defer cancelConnCtx()
|
||||||
for {
|
for {
|
||||||
msg, err := conn.ReadMsg()
|
msg, err := conn.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if connClosing.SetToIf(false, true) {
|
if connClosing.SetToIf(false, true) {
|
||||||
cancelConnCtx()
|
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, conn.RemoteAddr(), err)
|
||||||
log.Warningf("resolver: read error from %s (%s): %s", mgr.tr.resolver.Name, mgr.conn.RemoteAddr(), err)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue