Upgrade nameserver to take over the DNS port when in use by another process

This commit is contained in:
Daniel 2019-05-22 16:04:41 +02:00
parent 7043f05144
commit b8374f044a
2 changed files with 84 additions and 19 deletions

View file

@ -5,7 +5,7 @@ package nameserver
import ( import (
"context" "context"
"net" "net"
"time" "runtime"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -23,8 +23,17 @@ var (
localhostIPs []dns.RR localhostIPs []dns.RR
) )
var (
listenAddress = "127.0.0.1:53"
localhostIP = net.IPv4(127, 0, 0, 1)
)
func init() { func init() {
modules.Register("nameserver", prep, start, nil, "intel") modules.Register("nameserver", prep, start, nil, "intel")
if runtime.GOOS == "windows" {
listenAddress = "0.0.0.0:53"
}
} }
func prep() error { func prep() error {
@ -44,7 +53,7 @@ func prep() error {
} }
func start() error { func start() error {
server := &dns.Server{Addr: "0.0.0.0:53", Net: "udp"} server := &dns.Server{Addr: listenAddress, Net: "udp"}
dns.HandleFunc(".", handleRequest) dns.HandleFunc(".", handleRequest)
go run(server) go run(server)
return nil return nil
@ -55,8 +64,7 @@ func run(server *dns.Server) {
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
log.Errorf("nameserver: server failed: %s", err) log.Errorf("nameserver: server failed: %s", err)
log.Info("nameserver: restarting server in 10 seconds") checkForConflictingService(err)
time.Sleep(10 * time.Second)
} }
} }
} }
@ -95,6 +103,9 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype) log.Warningf("nameserver: could not get remote address of request for %s%s, ignoring", fqdn, qtype)
return return
} }
if !remoteAddr.IP.Equal(localhostIP) {
// if request is not coming from 127.0.0.1, check if it's really local
localAddr, ok := w.RemoteAddr().(*net.UDPAddr) localAddr, ok := w.RemoteAddr().(*net.UDPAddr)
if !ok { if !ok {
log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype) log.Warningf("nameserver: could not get local address of request for %s%s, ignoring", fqdn, qtype)
@ -106,6 +117,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype) log.Warningf("nameserver: external request for %s%s, ignoring", fqdn, qtype)
return return
} }
}
// check if valid domain name // check if valid domain name
if !netutils.IsValidFqdn(fqdn) { if !netutils.IsValidFqdn(fqdn) {
@ -121,9 +133,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
// TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain
// get connection // get connection
// start = time.Now()
comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), fqdn) comm, err := network.GetCommunicationByDNSRequest(ctx, remoteAddr.IP, uint16(remoteAddr.Port), fqdn)
// log.Tracef("nameserver: took %s to get comms (and maybe process)", time.Since(start))
if err != nil { if err != nil {
log.ErrorTracef(ctx, "nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err) log.ErrorTracef(ctx, "nameserver: could not identify process of %s:%d, returning nxdomain: %s", remoteAddr.IP, remoteAddr.Port, err)
nxDomain(w, query) nxDomain(w, query)
@ -141,9 +151,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
} }
// check profile before we even get intel and rr // check profile before we even get intel and rr
// start = time.Now()
firewall.DecideOnCommunicationBeforeIntel(comm, fqdn) firewall.DecideOnCommunicationBeforeIntel(comm, fqdn)
// log.Tracef("nameserver: took %s to make decision", time.Since(start))
if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop { if comm.GetVerdict() == network.VerdictBlock || comm.GetVerdict() == network.VerdictDrop {
log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm) log.InfoTracef(ctx, "nameserver: %s denied before intel, returning nxdomain", comm)
@ -152,9 +160,7 @@ func handleRequest(w dns.ResponseWriter, query *dns.Msg) {
} }
// get intel and RRs // get intel and RRs
// start = time.Now()
domainIntel, rrCache := intel.GetIntelAndRRs(ctx, fqdn, qtype, comm.Process().ProfileSet().SecurityLevel()) domainIntel, rrCache := intel.GetIntelAndRRs(ctx, fqdn, qtype, comm.Process().ProfileSet().SecurityLevel())
// log.Tracef("nameserver: took %s to get intel and RRs", time.Since(start))
if rrCache == nil { if rrCache == nil {
// TODO: analyze nxdomain requests, malware could be trying DGA-domains // TODO: analyze nxdomain requests, malware could be trying DGA-domains
log.WarningTracef(ctx, "nameserver: %s requested %s%s, is nxdomain", comm.Process(), fqdn, qtype) log.WarningTracef(ctx, "nameserver: %s requested %s%s, is nxdomain", comm.Process(), fqdn, qtype)

59
nameserver/takeover.go Normal file
View file

@ -0,0 +1,59 @@
package nameserver
import (
"fmt"
"net"
"os"
"time"
"github.com/Safing/portbase/log"
"github.com/Safing/portbase/notifications"
"github.com/Safing/portmaster/network/packet"
"github.com/Safing/portmaster/process"
)
func checkForConflictingService(err error) {
pid, err := takeover()
if err != nil || pid == 0 {
log.Info("nameserver: restarting server in 10 seconds")
time.Sleep(10 * time.Second)
return
}
log.Infof("nameserver: stopped conflicting name service with pid %d", pid)
// notify user
(&notifications.Notification{
ID: "nameserver-stopped-conflicting-service",
Message: fmt.Sprintf("Portmaster stopped a conflicting name service (pid %d) to gain required system integration.", pid),
}).Init().Save()
// wait for a short duration for the other service to shut down
time.Sleep(100 * time.Millisecond)
}
func takeover() (int, error) {
pid, _, err := process.GetPidByEndpoints(net.IPv4(127, 0, 0, 1), 53, net.IPv4(127, 0, 0, 1), 65535, packet.UDP)
if err != nil {
// there may be nothing listening on :53
log.Tracef("nameserver: expected conflicting name service, but could not find anything listenting on :53")
return 0, nil
}
proc, err := os.FindProcess(pid)
if err != nil {
// huh. gone already? I guess we'll wait then...
return 0, err
}
err = proc.Signal(os.Interrupt)
if err != nil {
err = proc.Kill()
if err != nil {
log.Errorf("nameserver: failed to stop conflicting service (pid %d): %s", pid, err)
return 0, err
}
}
return pid, nil
}