diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 1c0eae4d..9caf3b35 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -5,7 +5,7 @@ package nameserver import ( "context" "net" - "time" + "runtime" "github.com/miekg/dns" @@ -23,8 +23,17 @@ var ( localhostIPs []dns.RR ) +var ( + listenAddress = "127.0.0.1:53" + localhostIP = net.IPv4(127, 0, 0, 1) +) + func init() { modules.Register("nameserver", prep, start, nil, "intel") + + if runtime.GOOS == "windows" { + listenAddress = "0.0.0.0:53" + } } func prep() error { @@ -44,7 +53,7 @@ func prep() error { } func start() error { - server := &dns.Server{Addr: "0.0.0.0:53", Net: "udp"} + server := &dns.Server{Addr: listenAddress, Net: "udp"} dns.HandleFunc(".", handleRequest) go run(server) return nil @@ -55,8 +64,7 @@ func run(server *dns.Server) { err := server.ListenAndServe() if err != nil { log.Errorf("nameserver: server failed: %s", err) - log.Info("nameserver: restarting server in 10 seconds") - time.Sleep(10 * time.Second) + checkForConflictingService(err) } } } @@ -95,16 +103,20 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype) return } - localAddr, ok := w.RemoteAddr().(*net.UDPAddr) - if !ok { - log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) - return - } + if !remoteAddr.IP.Equal(localhostIP) { + // if request is not coming from 127.0.0.1, check if it's really local - // ignore external request - if !remoteAddr.IP.Equal(localAddr.IP) { - log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) - return + localAddr, ok := w.RemoteAddr().(*net.UDPAddr) + if !ok { + log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) + return + } + + // ignore external request + if !remoteAddr.IP.Equal(localAddr.IP) { + log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) + return + } } // check if valid domain name @@ -121,9 +133,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - // start = time.Now() comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), fqdn) - // log.Tracef("nameserver: took %s to get comms (and maybe process)", time.Since(start)) if err != nil { log.ErrorTracef(ctx, "nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err) nxDomain(w, query) @@ -141,9 +151,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { } // check profile before we even get intel and rr - // start = time.Now() firewall.DecideOnCommunicationBeforeIntel(comm, fqdn) - // log.Tracef("nameserver: took %s to make decision", time.Since(start)) if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop { log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm) @@ -152,9 +160,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) { } // get intel and RRs - // start = time.Now() domainIntel, rrCache := intel.GetIntelAndRRs(ctx, fqdn, qtype, comm.Process().ProfileSet().SecurityLevel()) - // log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start)) if rrCache == nil { // TODO: analyze nxdomain requests, malware could be trying DGA-domains log.WarningTracef(ctx, "nameserver: %s requested %s%s, is nxdomain", comm.Process(), fqdn, qtype) diff --git a/nameserver/takeover.go b/nameserver/takeover.go new file mode 100644 index 00000000..d2ef510c --- /dev/null +++ b/nameserver/takeover.go @@ -0,0 +1,59 @@ +package nameserver + +import ( + "fmt" + "net" + "os" + "time" + + "github.com/Safing/portbase/log" + "github.com/Safing/portbase/notifications" + "github.com/Safing/portmaster/network/packet" + "github.com/Safing/portmaster/process" +) + +func checkForConflictingService(err error) { + pid, err := takeover() + if err != nil || pid == 0 { + log.Info("nameserver: restarting server in 10 seconds") + time.Sleep(10 * time.Second) + return + } + + log.Infof("nameserver: stopped conflicting name service with pid %d", pid) + + // notify user + (¬ifications.Notification{ + ID: "nameserver-stopped-conflicting-service", + Message: fmt.Sprintf("Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", pid), + }).Init().Save() + + // wait for a short duration for the other service to shut down + time.Sleep(100 * time.Millisecond) +} + +func takeover() (int, error) { + pid, _, err := process.GetPidByEndpoints(net.IPv4(127, 0, 0, 1), 53, net.IPv4(127, 0, 0, 1), 65535, packet.UDP) + if err != nil { + // there may be nothing listening on :53 + log.Tracef("nameserver: expected conflicting name service, but could not find anything listenting on :53") + return 0, nil + } + + proc, err := os.FindProcess(pid) + if err != nil { + // huh. gone already? I guess we'll wait then... + return 0, err + } + + err = proc.Signal(os.Interrupt) + if err != nil { + err = proc.Kill() + if err != nil { + log.Errorf("nameserver: failed to stop conflicting service (pid %d): %s", pid, err) + return 0, err + } + } + + return pid, nil +}