Merge pull request #452 from safing/fix/patch-set-10

Fix dns request marshaling and improve network db
This commit is contained in:
Daniel 2021-11-18 09:28:36 +01:00 committed by GitHub
commit c50087630e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 15 deletions

View file

@ -2,6 +2,7 @@ package nsutil
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -45,6 +46,11 @@ func (rf ResponderFunc) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
return rf(ctx, request) return rf(ctx, request)
} }
// MarshalJSON disables JSON marshaling for ResponderFunc.
func (rf ResponderFunc) MarshalJSON() ([]byte, error) {
return json.Marshal(nil)
}
// BlockIP is a ResponderFunc than replies with either 0.0.0.17 or ::17 for // BlockIP is a ResponderFunc than replies with either 0.0.0.17 or ::17 for
// each A or AAAA question respectively. If there is no A or AAAA question, it // each A or AAAA question respectively. If there is no A or AAAA question, it
// defaults to replying with NXDomain. // defaults to replying with NXDomain.

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -115,13 +116,17 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
// Query returns a an iterator for the supplied query. // Query returns a an iterator for the supplied query.
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
it := iterator.New() it := iterator.New()
go s.processQuery(q, it)
// TODO: check local and internal module.StartWorker("connection query", func(_ context.Context) error {
s.processQuery(q, it)
return nil
})
return it, nil return it, nil
} }
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
var matches bool
pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix()) pid, scope, _, ok := parseDBKey(q.DatabaseKeyPrefix())
if !ok { if !ok {
it.Finish(nil) it.Finish(nil)
@ -131,33 +136,42 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
if pid == process.UndefinedProcessID { if pid == process.UndefinedProcessID {
// processes // processes
for _, proc := range process.All() { for _, proc := range process.All() {
func() {
proc.Lock() proc.Lock()
if q.Matches(proc) { defer proc.Unlock()
matches = q.Matches(proc)
}()
if matches {
it.Next <- proc it.Next <- proc
} }
proc.Unlock()
} }
} }
if scope == "" || scope == "dns" { if scope == "" || scope == "dns" {
// dns scopes only // dns scopes only
for _, dnsConn := range dnsConns.clone() { for _, dnsConn := range dnsConns.clone() {
func() {
dnsConn.Lock() dnsConn.Lock()
if q.Matches(dnsConn) { defer dnsConn.Unlock()
matches = q.Matches(dnsConn)
}()
if matches {
it.Next <- dnsConn it.Next <- dnsConn
} }
dnsConn.Unlock()
} }
} }
if scope == "" || scope == "ip" { if scope == "" || scope == "ip" {
// connections // connections
for _, conn := range conns.clone() { for _, conn := range conns.clone() {
func() {
conn.Lock() conn.Lock()
if q.Matches(conn) { defer conn.Unlock()
matches = q.Matches(conn)
}()
if matches {
it.Next <- conn it.Next <- conn
} }
conn.Unlock()
} }
} }

View file

@ -21,16 +21,16 @@ func SetDefaultFirewallHandler(handler FirewallHandler) {
} }
} }
func prep() error {
return registerAPIEndpoints()
}
func start() error { func start() error {
err := registerAsDatabase() err := registerAsDatabase()
if err != nil { if err != nil {
return err return err
} }
if err := registerAPIEndpoints(); err != nil {
return err
}
if err := registerMetrics(); err != nil { if err := registerMetrics(); err != nil {
return err return err
} }