diff --git a/nameserver/nsutil/nsutil.go b/nameserver/nsutil/nsutil.go index 98ac6971..65674672 100644 --- a/nameserver/nsutil/nsutil.go +++ b/nameserver/nsutil/nsutil.go @@ -2,6 +2,7 @@ package nsutil import ( "context" + "encoding/json" "errors" "fmt" "strings" @@ -45,6 +46,11 @@ func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns return rf(ctx, request) } +// MarshalJSON disables JSON marshaling for ResponderFunc. +func (rf ResponderFunc) MarshalJSON() ([]byte, error) { + return json.Marshal(nil) +} + // BlockIP is a ResponderFunc than replies with either 0.0.0.17 or ::17 for // each A or AAAA question respectively. If there is no A or AAAA question, it // defaults to replying with NXDomain. diff --git a/network/database.go b/network/database.go index 6d9a337d..0937e3fb 100644 --- a/network/database.go +++ b/network/database.go @@ -1,6 +1,7 @@ package network import ( + "context" "fmt" "strconv" "strings" @@ -115,13 +116,17 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { // Query returns a an iterator for the supplied query. func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { it := iterator.New() - go s.processQuery(q, it) - // TODO: check local and internal + + module.StartWorker("connection query", func(_ context.Context) error { + s.processQuery(q, it) + return nil + }) return it, nil } func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { + var matches bool pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix()) if !ok { it.Finish(nil) @@ -131,33 +136,42 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { if pid == process.UndefinedProcessID { // processes for _, proc := range process.All() { - proc.Lock() - if q.Matches(proc) { + func() { + proc.Lock() + defer proc.Unlock() + matches = q.Matches(proc) + }() + if matches { it.Next <- proc } - proc.Unlock() } } if scope == "" || scope == "dns" { // dns scopes only for _, dnsConn := range dnsConns.clone() { - dnsConn.Lock() - if q.Matches(dnsConn) { + func() { + dnsConn.Lock() + defer dnsConn.Unlock() + matches = q.Matches(dnsConn) + }() + if matches { it.Next <- dnsConn } - dnsConn.Unlock() } } if scope == "" || scope == "ip" { // connections for _, conn := range conns.clone() { - conn.Lock() - if q.Matches(conn) { + func() { + conn.Lock() + defer conn.Unlock() + matches = q.Matches(conn) + }() + if matches { it.Next <- conn } - conn.Unlock() } } diff --git a/network/module.go b/network/module.go index a2f70ec7..081d6725 100644 --- a/network/module.go +++ b/network/module.go @@ -21,16 +21,16 @@ func SetDefaultFirewallHandler(handler FirewallHandler) { } } +func prep() error { + return registerAPIEndpoints() +} + func start() error { err := registerAsDatabase() if err != nil { return err } - if err := registerAPIEndpoints(); err != nil { - return err - } - if err := registerMetrics(); err != nil { return err }