diff --git a/firewall/api.go b/firewall/api.go index 267ad5a5..5e6f81bf 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -40,6 +40,7 @@ var ( dataRoot *utils.DirStructure apiPortSet bool + apiIP net.IP apiPort uint16 ) @@ -50,7 +51,7 @@ func prepAPIAuth() error { func startAPIAuth() { var err error - _, apiPort, err = parseHostPort(apiListenAddress()) + apiIP, apiPort, err = parseHostPort(apiListenAddress()) if err != nil { log.Warningf("filter: failed to parse API address for improved api auth mechanism: %s", err) return diff --git a/firewall/interception.go b/firewall/interception.go index fd8a6887..02615f0d 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -2,12 +2,13 @@ package firewall import ( "context" + "errors" "net" "os" "sync/atomic" "time" - "github.com/safing/portmaster/netenv" + "github.com/tevino/abool" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -26,6 +27,10 @@ import ( var ( interceptionModule *modules.Module + nameserverIPMatcher func(ip net.IP) bool + nameserverIPMatcherSet = abool.New() + nameserverIPMatcherReady = abool.New() + packetsAccepted = new(uint64) packetsBlocked = new(uint64) packetsDropped = new(uint64) @@ -59,6 +64,18 @@ func interceptionStop() error { return interception.Stop() } +// SetNameserverIPMatcher sets a function that is used to match the internal +// nameserver IP(s). Can only bet set once. +func SetNameserverIPMatcher(fn func(ip net.IP) bool) error { + if !nameserverIPMatcherSet.SetToIf(false, true) { + return errors.New("nameserver IP matcher already set") + } + + nameserverIPMatcher = fn + nameserverIPMatcherReady.Set() + return nil +} + func handlePacket(ctx context.Context, pkt packet.Packet) { if fastTrackedPermit(pkt) { return @@ -90,12 +107,22 @@ func handlePacket(ctx context.Context, pkt packet.Packet) { func fastTrackedPermit(pkt packet.Packet) (handled bool) { meta := pkt.Info() - // Check for blocked IP + // Check if connection was already blocked. if meta.Dst.Equal(blockedIPv4) || meta.Dst.Equal(blockedIPv6) { _ = pkt.PermanentBlock() return true } + // Some programs do a network self-check where they connects to the same + // IP/Port to test network capabilities. + // Eg. dig: https://gitlab.isc.org/isc-projects/bind9/-/issues/1140 + if meta.SrcPort == meta.DstPort && + meta.Src.Equal(meta.Dst) { + log.Debugf("filter: fast-track network self-check: %s", pkt) + _ = pkt.PermanentAccept() + return true + } + switch meta.Protocol { case packet.ICMP: // Always permit ICMP. @@ -135,39 +162,42 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) { case apiPort: // Always allow direct access to the Portmaster API. + // Portmaster API is TCP only. + if meta.Protocol != packet.TCP { + return false + } + // Check if the api port is even set. if !apiPortSet { return false } - // Portmaster API must be TCP - if meta.Protocol != packet.TCP { - return false - } - - fallthrough - case 53: - // Always allow direct local access to own services. - // DNS is both UDP and TCP. - - // Only allow to own IPs. - dstIsMe, err := netenv.IsMyIP(meta.Dst) - if err != nil { - log.Warningf("filter: failed to check if IP %s is local: %s", meta.Dst, err) - } - if !dstIsMe { + // Must be destined for the API IP. + if !meta.Dst.Equal(apiIP) { return false } // Log and permit. - switch meta.DstPort { - case 53: - log.Debugf("filter: fast-track accepting local dns: %s", pkt) - case apiPort: - log.Debugf("filter: fast-track accepting api connection: %s", pkt) - default: + log.Debugf("filter: fast-track accepting api connection: %s", pkt) + _ = pkt.PermanentAccept() + return true + + case 53: + // Always allow direct access to the Portmaster Nameserver. + // DNS is both UDP and TCP. + + // Check if a nameserver IP matcher is set. + if !nameserverIPMatcherReady.IsSet() { return false } + + // Check if packet is destined for a nameserver IP. + if !nameserverIPMatcher(meta.Dst) { + return false + } + + // Log and permit. + log.Debugf("filter: fast-track accepting local dns: %s", pkt) _ = pkt.PermanentAccept() return true } @@ -191,8 +221,12 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { return } - // reroute dns requests to nameserver - if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { + // Redirect rogue dns requests to the Portmaster. + if pkt.IsOutbound() && + pkt.Info().DstPort == 53 && + conn.Process().Pid != os.Getpid() && + nameserverIPMatcherReady.IsSet() && + !nameserverIPMatcher(pkt.Info().Dst) { conn.Verdict = network.VerdictRerouteToNameserver conn.Reason.Msg = "redirecting rogue dns query" conn.Internal = true diff --git a/intel/entity.go b/intel/entity.go index 16d83707..c98126f0 100644 --- a/intel/entity.go +++ b/intel/entity.go @@ -292,6 +292,7 @@ func (e *Entity) getDomainLists(ctx context.Context) { return } + var err error e.loadDomainListOnce.Do(func() { var domainsToInspect = []string{domain} @@ -314,10 +315,10 @@ func (e *Entity) getDomainLists(ctx context.Context) { for _, d := range domains { log.Tracer(ctx).Tracef("intel: loading domain list for %s", d) - list, err := filterlists.LookupDomain(d) + var list []string + list, err = filterlists.LookupDomain(d) if err != nil { log.Tracer(ctx).Errorf("intel: failed to get domain blocklists for %s: %s", d, err) - e.loadDomainListOnce = sync.Once{} return } @@ -325,6 +326,10 @@ func (e *Entity) getDomainLists(ctx context.Context) { } e.domainListLoaded = true }) + + if err != nil { + e.loadDomainListOnce = sync.Once{} + } } func splitDomain(domain string) []string { diff --git a/intel/filterlists/lookup.go b/intel/filterlists/lookup.go index 827aeab9..a975c281 100644 --- a/intel/filterlists/lookup.go +++ b/intel/filterlists/lookup.go @@ -2,9 +2,7 @@ package filterlists import ( "errors" - "fmt" "net" - "strings" "github.com/safing/portbase/database" "github.com/safing/portbase/log" @@ -55,15 +53,16 @@ func LookupCountry(country string) ([]string, error) { // LookupDomain returns a list of sources that mark the domain // as blocked. If domain is not stored in the cache database -// a nil slice is returned. +// a nil slice is returned. The caller is responsible for making +// sure that the given domain is valid and canonical. func LookupDomain(domain string) ([]string, error) { - // make sure we only fully qualified domains - // ending in a dot. - domain = strings.ToLower(domain) - if domain[len(domain)-1] != '.' { - domain += "." + switch domain { + case "", ".": + // Return no lists for empty domains and the root zone. + return nil, nil + default: + return lookupBlockLists("domain", domain) } - return lookupBlockLists("domain", domain) } // LookupASNString returns a list of sources that mark the ASN @@ -89,7 +88,7 @@ func LookupIP(ip net.IP) ([]string, error) { func LookupIPString(ipStr string) ([]string, error) { ip := net.ParseIP(ipStr) if ip == nil { - return nil, fmt.Errorf("invalid IP") + return nil, errors.New("invalid IP") } return LookupIP(ip) diff --git a/nameserver/config.go b/nameserver/config.go index 12e7b546..f6e9153a 100644 --- a/nameserver/config.go +++ b/nameserver/config.go @@ -2,6 +2,7 @@ package nameserver import ( "flag" + "runtime" "github.com/safing/portbase/config" "github.com/safing/portbase/log" @@ -15,9 +16,16 @@ const ( var ( nameserverAddressFlag string nameserverAddressConfig config.StringOption + + defaultNameserverAddress = "localhost:53" ) func init() { + // On Windows, packets are redirected to the same interface. + if runtime.GOOS == "windows" { + defaultNameserverAddress = "0.0.0.0:53" + } + flag.StringVar(&nameserverAddressFlag, "nameserver-address", "", "override nameserver listen address") } @@ -45,7 +53,7 @@ func registerConfig() error { ExpertiseLevel: config.ExpertiseLevelDeveloper, ReleaseLevel: config.ReleaseLevelStable, DefaultValue: getDefaultNameserverAddress(), - ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", + ValidationRegex: "^(localhost|[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}|\\[[:0-9A-Fa-f]+\\]):[0-9]{1,5}$", RequiresRestart: true, Annotations: config.Annotations{ config.DisplayOrderAnnotation: 514, diff --git a/nameserver/module.go b/nameserver/module.go new file mode 100644 index 00000000..cc4da029 --- /dev/null +++ b/nameserver/module.go @@ -0,0 +1,169 @@ +package nameserver + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/modules/subsystems" + "github.com/safing/portmaster/firewall" + "github.com/safing/portmaster/netenv" + + "github.com/miekg/dns" +) + +var ( + module *modules.Module + stopListener func() error +) + +func init() { + module = modules.Register("nameserver", prep, start, stop, "core", "resolver") + subsystems.Register( + "dns", + "Secure DNS", + "DNS resolver with scoping and DNS-over-TLS", + module, + "config:dns/", + nil, + ) +} + +func prep() error { + return registerConfig() +} + +func start() error { + logFlagOverrides() + + ip1, ip2, port, err := getListenAddresses(nameserverAddressConfig()) + if err != nil { + return fmt.Errorf("failed to parse nameserver listen address: %w", err) + } + + // Start listener(s). + if ip2 == nil { + // Start a single listener. + dnsServer := startListener(ip1, port) + stopListener = dnsServer.Shutdown + + // Set nameserver matcher in firewall to fast-track dns queries. + if ip1.Equal(net.IPv4zero) || ip1.Equal(net.IPv6zero) { + // Fast track dns queries destined for any of the local IPs. + return firewall.SetNameserverIPMatcher(func(ip net.IP) bool { + dstIsMe, err := netenv.IsMyIP(ip) + if err != nil { + log.Warningf("nameserver: failed to check if IP %s is local: %s", ip, err) + } + return dstIsMe + }) + } else { + return firewall.SetNameserverIPMatcher(func(ip net.IP) bool { + return ip.Equal(ip1) + }) + } + + } else { + // Dual listener. + dnsServer1 := startListener(ip1, port) + dnsServer2 := startListener(ip2, port) + stopListener = func() error { + // Shutdown both listeners. + err1 := dnsServer1.Shutdown() + err2 := dnsServer2.Shutdown() + // Return first error. + if err1 != nil { + return err1 + } + return err2 + } + + // Fast track dns queries destined for one of the listener IPs. + return firewall.SetNameserverIPMatcher(func(ip net.IP) bool { + return ip.Equal(ip1) || ip.Equal(ip2) + }) + } +} + +func startListener(ip net.IP, port uint16) *dns.Server { + // Create DNS server. + dnsServer := &dns.Server{ + Addr: net.JoinHostPort( + ip.String(), + strconv.Itoa(int(port)), + ), + Net: "udp", + } + dns.HandleFunc(".", handleRequestAsWorker) + + // Start DNS server as service worker. + log.Infof("nameserver: starting to listen on %s", dnsServer.Addr) + module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { + err := dnsServer.ListenAndServe() + if err != nil { + // check if we are shutting down + if module.IsStopping() { + return nil + } + // is something blocking our port? + checkErr := checkForConflictingService(ip, port) + if checkErr != nil { + return checkErr + } + } + return err + }) + + return dnsServer +} + +func stop() error { + if stopListener != nil { + return stopListener() + } + return nil +} + +func getListenAddresses(listenAddress string) (ip1, ip2 net.IP, port uint16, err error) { + // Split host and port. + ipString, portString, err := net.SplitHostPort(listenAddress) + if err != nil { + return nil, nil, 0, fmt.Errorf( + "failed to parse address %s: %w", + listenAddress, + err, + ) + } + + // Parse the IP address. If the want to listen on localhost, we need to + // listen separately for IPv4 and IPv6. + if ipString == "localhost" { + ip1 = net.IPv4(127, 0, 0, 17) + ip2 = net.IPv6loopback + } else { + ip1 = net.ParseIP(ipString) + if ip1 == nil { + return nil, nil, 0, fmt.Errorf( + "failed to parse IP %s from %s", + ipString, + listenAddress, + ) + } + } + + // Parse the port. + port64, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, nil, 0, fmt.Errorf( + "failed to parse port %s from %s: %w", + portString, + listenAddress, + err, + ) + } + + return ip1, ip2, uint16(port64), nil +} diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index ee902fc4..7d988c8f 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -6,76 +6,18 @@ import ( "net" "strings" - "github.com/safing/portmaster/network/packet" - - "github.com/safing/portbase/modules/subsystems" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/firewall" "github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/netenv" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" + "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/resolver" "github.com/miekg/dns" ) -var ( - module *modules.Module - dnsServer *dns.Server - - defaultNameserverAddress = "0.0.0.0:53" -) - -func init() { - module = modules.Register("nameserver", prep, start, stop, "core", "resolver") - subsystems.Register( - "dns", - "Secure DNS", - "DNS resolver with scoping and DNS-over-TLS", - module, - "config:dns/", - nil, - ) -} - -func prep() error { - return registerConfig() -} - -func start() error { - logFlagOverrides() - dnsServer = &dns.Server{Addr: nameserverAddressConfig(), Net: "udp"} - dns.HandleFunc(".", handleRequestAsWorker) - - module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { - err := dnsServer.ListenAndServe() - if err != nil { - // check if we are shutting down - if module.IsStopping() { - return nil - } - // is something blocking our port? - checkErr := checkForConflictingService() - if checkErr != nil { - return checkErr - } - } - return err - }) - - return nil -} - -func stop() error { - if dnsServer != nil { - return dnsServer.Shutdown() - } - return nil -} - func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { err := module.RunWorker("dns request", func(ctx context.Context) error { return handleRequest(ctx, w, query) diff --git a/nameserver/takeover.go b/nameserver/takeover.go index 7609c6a0..636d29a4 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -4,7 +4,7 @@ import ( "fmt" "net" "os" - "time" + "strconv" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -14,39 +14,46 @@ import ( ) var ( - otherResolverIPs = []net.IP{ + commonResolverIPs = []net.IP{ + net.IPv4zero, net.IPv4(127, 0, 0, 1), // default net.IPv4(127, 0, 0, 53), // some resolvers on Linux + net.IPv6zero, + net.IPv6loopback, } ) -func checkForConflictingService() error { - var pid int - var err error +func checkForConflictingService(ip net.IP, port uint16) error { + // Evaluate which IPs to check. + var ipsToCheck []net.IP + if ip.Equal(net.IPv4zero) || ip.Equal(net.IPv6zero) { + ipsToCheck = commonResolverIPs + } else { + ipsToCheck = []net.IP{ip} + } - // check multiple IPs for other resolvers - for _, resolverIP := range otherResolverIPs { - pid, err = takeover(resolverIP) - if err == nil && pid != 0 { + // Check if there is another resolver when need to take over. + var killed int + for _, resolverIP := range ipsToCheck { + pid, err := takeover(resolverIP, port) + switch { + case err != nil: + // Log the error and let the worker try again. + log.Infof("nameserver: could not stop conflicting service: %s", err) + return nil + case pid != 0: + // Conflicting service identified and killed! + killed = pid break } } - // handle returns - if err != nil { - log.Infof("nameserver: could not stop conflicting service: %s", err) - // leave original service-worker error intact - return nil - } - if pid == 0 { - // no conflicting service identified + + // Check if something was killed. + if killed == 0 { return nil } - // we killed something! - - // wait for a short duration for the other service to shut down - time.Sleep(10 * time.Millisecond) - + // Notify the user that we killed something. notifications.Notify(¬ifications.Notification{ EventID: "namserver:stopped-conflicting-service", Type: notifications.Info, @@ -54,15 +61,15 @@ func checkForConflictingService() error { Category: "Secure DNS", Message: fmt.Sprintf( "The Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", - pid, + killed, ), }) - // restart via service-worker logic - return fmt.Errorf("%w: stopped conflicting name service with pid %d", modules.ErrRestartNow, pid) + // Restart nameserver via service-worker logic. + return fmt.Errorf("%w: stopped conflicting name service with pid %d", modules.ErrRestartNow, killed) } -func takeover(resolverIP net.IP) (int, error) { +func takeover(resolverIP net.IP, resolverPort uint16) (int, error) { pid, _, err := state.Lookup(&packet.Info{ Inbound: true, Version: 0, // auto-detect @@ -70,13 +77,18 @@ func takeover(resolverIP net.IP) (int, error) { Src: nil, // do not record direction SrcPort: 0, // do not record direction Dst: resolverIP, - DstPort: 53, + DstPort: resolverPort, }) if err != nil { // there may be nothing listening on :53 return 0, nil } + // Just don't, uh, kill ourselves... + if pid == os.Getpid() { + return 0, nil + } + proc, err := os.FindProcess(pid) if err != nil { // huh. gone already? I guess we'll wait then... @@ -92,5 +104,14 @@ func takeover(resolverIP net.IP) (int, error) { } } + log.Warningf( + "nameserver: killed conflicting service with PID %d over %s", + pid, + net.JoinHostPort( + resolverIP.String(), + strconv.Itoa(int(resolverPort)), + ), + ) + return pid, nil }