Merge pull request #229 from safing/feature/switch-to-localhost-nameserver

Switch to localhost nameserver
This commit is contained in:
Daniel 2021-01-19 15:13:10 +01:00 committed by GitHub
commit 3f8c99517f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 305 additions and 126 deletions

View file

@ -40,6 +40,7 @@ var (
dataRoot *utils.DirStructure dataRoot *utils.DirStructure
apiPortSet bool apiPortSet bool
apiIP net.IP
apiPort uint16 apiPort uint16
) )
@ -50,7 +51,7 @@ func prepAPIAuth() error {
func startAPIAuth() { func startAPIAuth() {
var err error var err error
_, apiPort, err = parseHostPort(apiListenAddress()) apiIP, apiPort, err = parseHostPort(apiListenAddress())
if err != nil { if err != nil {
log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err) log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err)
return return

View file

@ -2,12 +2,13 @@ package firewall
import ( import (
"context" "context"
"errors"
"net" "net"
"os" "os"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/safing/portmaster/netenv" "github.com/tevino/abool"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
@ -26,6 +27,10 @@ import (
var ( var (
interceptionModule *modules.Module interceptionModule *modules.Module
nameserverIPMatcher func(ip net.IP) bool
nameserverIPMatcherSet = abool.New()
nameserverIPMatcherReady = abool.New()
packetsAccepted = new(uint64) packetsAccepted = new(uint64)
packetsBlocked = new(uint64) packetsBlocked = new(uint64)
packetsDropped = new(uint64) packetsDropped = new(uint64)
@ -59,6 +64,18 @@ func interceptionStop() error {
return interception.Stop() return interception.Stop()
} }
// SetNameserverIPMatcher sets a function that is used to match the internal
// nameserver IP(s). Can only bet set once.
func SetNameserverIPMatcher(fn func(ip net.IP) bool) error {
if !nameserverIPMatcherSet.SetToIf(false, true) {
return errors.New("nameserver IP matcher already set")
}
nameserverIPMatcher = fn
nameserverIPMatcherReady.Set()
return nil
}
func handlePacket(ctx context.Context, pkt packet.Packet) { func handlePacket(ctx context.Context, pkt packet.Packet) {
if fastTrackedPermit(pkt) { if fastTrackedPermit(pkt) {
return return
@ -90,12 +107,22 @@ func handlePacket(ctx context.Context, pkt packet.Packet) {
func fastTrackedPermit(pkt packet.Packet) (handled bool) { func fastTrackedPermit(pkt packet.Packet) (handled bool) {
meta := pkt.Info() meta := pkt.Info()
// Check for blocked IP // Check if connection was already blocked.
if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) { if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) {
_ = pkt.PermanentBlock() _ = pkt.PermanentBlock()
return true return true
} }
// Some programs do a network self-check where they connects to the same
// IP/Port to test network capabilities.
// Eg. dig: https://gitlab.isc.org/isc-projects/bind9/-/issues/1140
if meta.SrcPort == meta.DstPort &&
meta.Src.Equal(meta.Dst) {
log.Debugf("filter: fast-track network self-check: %s", pkt)
_ = pkt.PermanentAccept()
return true
}
switch meta.Protocol { switch meta.Protocol {
case packet.ICMP: case packet.ICMP:
// Always permit ICMP. // Always permit ICMP.
@ -135,39 +162,42 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
case apiPort: case apiPort:
// Always allow direct access to the Portmaster API. // Always allow direct access to the Portmaster API.
// Portmaster API is TCP only.
if meta.Protocol != packet.TCP {
return false
}
// Check if the api port is even set. // Check if the api port is even set.
if !apiPortSet { if !apiPortSet {
return false return false
} }
// Portmaster API must be TCP // Must be destined for the API IP.
if meta.Protocol != packet.TCP { if !meta.Dst.Equal(apiIP) {
return false
}
fallthrough
case 53:
// Always allow direct local access to own services.
// DNS is both UDP and TCP.
// Only allow to own IPs.
dstIsMe, err := netenv.IsMyIP(meta.Dst)
if err != nil {
log.Warningf("filter: failed to check if IP %s is local: %s", meta.Dst, err)
}
if !dstIsMe {
return false return false
} }
// Log and permit. // Log and permit.
switch meta.DstPort { log.Debugf("filter: fast-track accepting api connection: %s", pkt)
case 53: _ = pkt.PermanentAccept()
log.Debugf("filter: fast-track accepting local dns: %s", pkt) return true
case apiPort:
log.Debugf("filter: fast-track accepting api connection: %s", pkt) case 53:
default: // Always allow direct access to the Portmaster Nameserver.
// DNS is both UDP and TCP.
// Check if a nameserver IP matcher is set.
if !nameserverIPMatcherReady.IsSet() {
return false return false
} }
// Check if packet is destined for a nameserver IP.
if !nameserverIPMatcher(meta.Dst) {
return false
}
// Log and permit.
log.Debugf("filter: fast-track accepting local dns: %s", pkt)
_ = pkt.PermanentAccept() _ = pkt.PermanentAccept()
return true return true
} }
@ -191,8 +221,12 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
return return
} }
// reroute dns requests to nameserver // Redirect rogue dns requests to the Portmaster.
if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { if pkt.IsOutbound() &&
pkt.Info().DstPort == 53 &&
conn.Process().Pid != os.Getpid() &&
nameserverIPMatcherReady.IsSet() &&
!nameserverIPMatcher(pkt.Info().Dst) {
conn.Verdict = network.VerdictRerouteToNameserver conn.Verdict = network.VerdictRerouteToNameserver
conn.Reason.Msg = "redirecting rogue dns query" conn.Reason.Msg = "redirecting rogue dns query"
conn.Internal = true conn.Internal = true

View file

@ -292,6 +292,7 @@ func (e *Entity) getDomainLists(ctx context.Context) {
return return
} }
var err error
e.loadDomainListOnce.Do(func() { e.loadDomainListOnce.Do(func() {
var domainsToInspect = []string{domain} var domainsToInspect = []string{domain}
@ -314,10 +315,10 @@ func (e *Entity) getDomainLists(ctx context.Context) {
for _, d := range domains { for _, d := range domains {
log.Tracer(ctx).Tracef("intel: loading domain list for %s", d) log.Tracer(ctx).Tracef("intel: loading domain list for %s", d)
list, err := filterlists.LookupDomain(d) var list []string
list, err = filterlists.LookupDomain(d)
if err != nil { if err != nil {
log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err) log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err)
e.loadDomainListOnce = sync.Once{}
return return
} }
@ -325,6 +326,10 @@ func (e *Entity) getDomainLists(ctx context.Context) {
} }
e.domainListLoaded = true e.domainListLoaded = true
}) })
if err != nil {
e.loadDomainListOnce = sync.Once{}
}
} }
func splitDomain(domain string) []string { func splitDomain(domain string) []string {

View file

@ -2,9 +2,7 @@ package filterlists
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"strings"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -55,15 +53,16 @@ func LookupCountry(country string) ([]string, error) {
// LookupDomain returns a list of sources that mark the domain // LookupDomain returns a list of sources that mark the domain
// as blocked. If domain is not stored in the cache database // as blocked. If domain is not stored in the cache database
// a nil slice is returned. // a nil slice is returned. The caller is responsible for making
// sure that the given domain is valid and canonical.
func LookupDomain(domain string) ([]string, error) { func LookupDomain(domain string) ([]string, error) {
// make sure we only fully qualified domains switch domain {
// ending in a dot. case "", ".":
domain = strings.ToLower(domain) // Return no lists for empty domains and the root zone.
if domain[len(domain)-1] != '.' { return nil, nil
domain += "." default:
return lookupBlockLists("domain", domain)
} }
return lookupBlockLists("domain", domain)
} }
// LookupASNString returns a list of sources that mark the ASN // LookupASNString returns a list of sources that mark the ASN
@ -89,7 +88,7 @@ func LookupIP(ip net.IP) ([]string, error) {
func LookupIPString(ipStr string) ([]string, error) { func LookupIPString(ipStr string) ([]string, error) {
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
if ip == nil { if ip == nil {
return nil, fmt.Errorf("invalid IP") return nil, errors.New("invalid IP")
} }
return LookupIP(ip) return LookupIP(ip)

View file

@ -2,6 +2,7 @@ package nameserver
import ( import (
"flag" "flag"
"runtime"
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -15,9 +16,16 @@ const (
var ( var (
nameserverAddressFlag string nameserverAddressFlag string
nameserverAddressConfig config.StringOption nameserverAddressConfig config.StringOption
defaultNameserverAddress = "localhost:53"
) )
func init() { func init() {
// On Windows, packets are redirected to the same interface.
if runtime.GOOS == "windows" {
defaultNameserverAddress = "0.0.0.0:53"
}
flag.StringVar(&nameserverAddressFlag, "nameserver-address", "", "override nameserver listen address") flag.StringVar(&nameserverAddressFlag, "nameserver-address", "", "override nameserver listen address")
} }
@ -45,7 +53,7 @@ func registerConfig() error {
ExpertiseLevel: config.ExpertiseLevelDeveloper, ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelStable, ReleaseLevel: config.ReleaseLevelStable,
DefaultValue: getDefaultNameserverAddress(), DefaultValue: getDefaultNameserverAddress(),
ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", ValidationRegex: "^(localhost|[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}|\\[[:0-9A-Fa-f]+\\]):[0-9]{1,5}$",
RequiresRestart: true, RequiresRestart: true,
Annotations: config.Annotations{ Annotations: config.Annotations{
config.DisplayOrderAnnotation: 514, config.DisplayOrderAnnotation: 514,

169
nameserver/module.go Normal file
View file

@ -0,0 +1,169 @@
package nameserver
import (
"context"
"fmt"
"net"
"strconv"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portmaster/firewall"
"github.com/safing/portmaster/netenv"
"github.com/miekg/dns"
)
var (
module *modules.Module
stopListener func() error
)
func init() {
module = modules.Register("nameserver", prep, start, stop, "core", "resolver")
subsystems.Register(
"dns",
"Secure DNS",
"DNS resolver with scoping and DNS-over-TLS",
module,
"config:dns/",
nil,
)
}
func prep() error {
return registerConfig()
}
func start() error {
logFlagOverrides()
ip1, ip2, port, err := getListenAddresses(nameserverAddressConfig())
if err != nil {
return fmt.Errorf("failed to parse nameserver listen address: %w", err)
}
// Start listener(s).
if ip2 == nil {
// Start a single listener.
dnsServer := startListener(ip1, port)
stopListener = dnsServer.Shutdown
// Set nameserver matcher in firewall to fast-track dns queries.
if ip1.Equal(net.IPv4zero) || ip1.Equal(net.IPv6zero) {
// Fast track dns queries destined for any of the local IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
dstIsMe, err := netenv.IsMyIP(ip)
if err != nil {
log.Warningf("nameserver: failed to check if IP %s is local: %s", ip, err)
}
return dstIsMe
})
} else {
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1)
})
}
} else {
// Dual listener.
dnsServer1 := startListener(ip1, port)
dnsServer2 := startListener(ip2, port)
stopListener = func() error {
// Shutdown both listeners.
err1 := dnsServer1.Shutdown()
err2 := dnsServer2.Shutdown()
// Return first error.
if err1 != nil {
return err1
}
return err2
}
// Fast track dns queries destined for one of the listener IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1) || ip.Equal(ip2)
})
}
}
func startListener(ip net.IP, port uint16) *dns.Server {
// Create DNS server.
dnsServer := &dns.Server{
Addr: net.JoinHostPort(
ip.String(),
strconv.Itoa(int(port)),
),
Net: "udp",
}
dns.HandleFunc(".", handleRequestAsWorker)
// Start DNS server as service worker.
log.Infof("nameserver: starting to listen on %s", dnsServer.Addr)
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
err := dnsServer.ListenAndServe()
if err != nil {
// check if we are shutting down
if module.IsStopping() {
return nil
}
// is something blocking our port?
checkErr := checkForConflictingService(ip, port)
if checkErr != nil {
return checkErr
}
}
return err
})
return dnsServer
}
func stop() error {
if stopListener != nil {
return stopListener()
}
return nil
}
func getListenAddresses(listenAddress string) (ip1, ip2 net.IP, port uint16, err error) {
// Split host and port.
ipString, portString, err := net.SplitHostPort(listenAddress)
if err != nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse address %s: %w",
listenAddress,
err,
)
}
// Parse the IP address. If the want to listen on localhost, we need to
// listen separately for IPv4 and IPv6.
if ipString == "localhost" {
ip1 = net.IPv4(127, 0, 0, 17)
ip2 = net.IPv6loopback
} else {
ip1 = net.ParseIP(ipString)
if ip1 == nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse IP %s from %s",
ipString,
listenAddress,
)
}
}
// Parse the port.
port64, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, nil, 0, fmt.Errorf(
"failed to parse port %s from %s: %w",
portString,
listenAddress,
err,
)
}
return ip1, ip2, uint16(port64), nil
}

View file

@ -6,76 +6,18 @@ import (
"net" "net"
"strings" "strings"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/firewall" "github.com/safing/portmaster/firewall"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/resolver" "github.com/safing/portmaster/resolver"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
var (
module *modules.Module
dnsServer *dns.Server
defaultNameserverAddress = "0.0.0.0:53"
)
func init() {
module = modules.Register("nameserver", prep, start, stop, "core", "resolver")
subsystems.Register(
"dns",
"Secure DNS",
"DNS resolver with scoping and DNS-over-TLS",
module,
"config:dns/",
nil,
)
}
func prep() error {
return registerConfig()
}
func start() error {
logFlagOverrides()
dnsServer = &dns.Server{Addr: nameserverAddressConfig(), Net: "udp"}
dns.HandleFunc(".", handleRequestAsWorker)
module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error {
err := dnsServer.ListenAndServe()
if err != nil {
// check if we are shutting down
if module.IsStopping() {
return nil
}
// is something blocking our port?
checkErr := checkForConflictingService()
if checkErr != nil {
return checkErr
}
}
return err
})
return nil
}
func stop() error {
if dnsServer != nil {
return dnsServer.Shutdown()
}
return nil
}
func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
err := module.RunWorker("dns request", func(ctx context.Context) error { err := module.RunWorker("dns request", func(ctx context.Context) error {
return handleRequest(ctx, w, query) return handleRequest(ctx, w, query)

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"time" "strconv"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
@ -14,39 +14,46 @@ import (
) )
var ( var (
otherResolverIPs = []net.IP{ commonResolverIPs = []net.IP{
net.IPv4zero,
net.IPv4(127, 0, 0, 1), // default net.IPv4(127, 0, 0, 1), // default
net.IPv4(127, 0, 0, 53), // some resolvers on Linux net.IPv4(127, 0, 0, 53), // some resolvers on Linux
net.IPv6zero,
net.IPv6loopback,
} }
) )
func checkForConflictingService() error { func checkForConflictingService(ip net.IP, port uint16) error {
var pid int // Evaluate which IPs to check.
var err error var ipsToCheck []net.IP
if ip.Equal(net.IPv4zero) || ip.Equal(net.IPv6zero) {
ipsToCheck = commonResolverIPs
} else {
ipsToCheck = []net.IP{ip}
}
// check multiple IPs for other resolvers // Check if there is another resolver when need to take over.
for _, resolverIP := range otherResolverIPs { var killed int
pid, err = takeover(resolverIP) for _, resolverIP := range ipsToCheck {
if err == nil && pid != 0 { pid, err := takeover(resolverIP, port)
switch {
case err != nil:
// Log the error and let the worker try again.
log.Infof("nameserver: could not stop conflicting service: %s", err)
return nil
case pid != 0:
// Conflicting service identified and killed!
killed = pid
break break
} }
} }
// handle returns
if err != nil { // Check if something was killed.
log.Infof("nameserver: could not stop conflicting service: %s", err) if killed == 0 {
// leave original service-worker error intact
return nil
}
if pid == 0 {
// no conflicting service identified
return nil return nil
} }
// we killed something! // Notify the user that we killed something.
// wait for a short duration for the other service to shut down
time.Sleep(10 * time.Millisecond)
notifications.Notify(&notifications.Notification{ notifications.Notify(&notifications.Notification{
EventID: "namserver:stopped-conflicting-service", EventID: "namserver:stopped-conflicting-service",
Type: notifications.Info, Type: notifications.Info,
@ -54,15 +61,15 @@ func checkForConflictingService() error {
Category: "Secure DNS", Category: "Secure DNS",
Message: fmt.Sprintf( Message: fmt.Sprintf(
"The Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", "The Portmaster stopped a conflicting name service (pid %d) to gain required system integration.",
pid, killed,
), ),
}) })
// restart via service-worker logic // Restart nameserver via service-worker logic.
return fmt.Errorf("%w: stopped conflicting name service with pid %d", modules.ErrRestartNow, pid) return fmt.Errorf("%w: stopped conflicting name service with pid %d", modules.ErrRestartNow, killed)
} }
func takeover(resolverIP net.IP) (int, error) { func takeover(resolverIP net.IP, resolverPort uint16) (int, error) {
pid, _, err := state.Lookup(&packet.Info{ pid, _, err := state.Lookup(&packet.Info{
Inbound: true, Inbound: true,
Version: 0, // auto-detect Version: 0, // auto-detect
@ -70,13 +77,18 @@ func takeover(resolverIP net.IP) (int, error) {
Src: nil, // do not record direction Src: nil, // do not record direction
SrcPort: 0, // do not record direction SrcPort: 0, // do not record direction
Dst: resolverIP, Dst: resolverIP,
DstPort: 53, DstPort: resolverPort,
}) })
if err != nil { if err != nil {
// there may be nothing listening on :53 // there may be nothing listening on :53
return 0, nil return 0, nil
} }
// Just don't, uh, kill ourselves...
if pid == os.Getpid() {
return 0, nil
}
proc, err := os.FindProcess(pid) proc, err := os.FindProcess(pid)
if err != nil { if err != nil {
// huh. gone already? I guess we'll wait then... // huh. gone already? I guess we'll wait then...
@ -92,5 +104,14 @@ func takeover(resolverIP net.IP) (int, error) {
} }
} }
log.Warningf(
"nameserver: killed conflicting service with PID %d over %s",
pid,
net.JoinHostPort(
resolverIP.String(),
strconv.Itoa(int(resolverPort)),
),
)
return pid, nil return pid, nil
} }