From 55b0ae89446d21f0e7455788f3a108fdc651b768 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 17:15:22 +0200 Subject: [PATCH 01/36] Revamp process attribution of network connections --- firewall/api.go | 12 +- firewall/master.go | 7 +- nameserver/nameserver.go | 16 +- nameserver/takeover.go | 4 +- network/clean.go | 21 +- network/connection.go | 48 ++- network/database.go | 7 + network/iphelper/get.go | 72 ++++ network/iphelper/iphelper.go | 63 +++ {process => network}/iphelper/tables.go | 145 +++---- network/iphelper/tables_test.go | 54 +++ {process => network}/iphelper/test/main.go | 0 .../proc/findpid.go | 8 +- network/proc/tables.go | 218 +++++++++++ network/proc/tables_test.go | 60 +++ network/socket/socket.go | 26 ++ network/state/exists.go | 103 +++++ network/state/lookup.go | 189 +++++++++ network/state/system_linux.go | 37 ++ network/state/system_windows.go | 21 + network/state/tables.go | 66 ++++ network/state/udp.go | 118 ++++++ process/find.go | 149 ++----- process/getpid_linux.go | 13 - process/getpid_windows.go | 13 - process/iphelper/get.go | 260 ------------ process/iphelper/iphelper.go | 79 ---- process/proc/gather.go | 83 ---- process/proc/get.go | 66 ---- process/proc/processfinder_test.go | 18 - process/proc/sockets.go | 370 ------------------ process/proc/sockets_test.go | 40 -- updates/main.go | 5 +- updates/upgrader.go | 39 +- 34 files changed, 1234 insertions(+), 1196 deletions(-) create mode 100644 network/iphelper/get.go create mode 100644 network/iphelper/iphelper.go rename {process => network}/iphelper/tables.go (69%) create mode 100644 network/iphelper/tables_test.go rename {process => network}/iphelper/test/main.go (100%) rename process/proc/processfinder.go => network/proc/findpid.go (95%) create mode 100644 network/proc/tables.go create mode 100644 network/proc/tables_test.go create mode 100644 network/socket/socket.go create mode 100644 network/state/exists.go create mode 100644 network/state/lookup.go create mode 100644 network/state/system_linux.go create mode 100644 network/state/system_windows.go create mode 100644 network/state/tables.go create mode 100644 network/state/udp.go delete mode 100644 process/getpid_linux.go delete mode 100644 process/getpid_windows.go delete mode 100644 process/iphelper/get.go delete mode 100644 process/iphelper/iphelper.go delete mode 100644 process/proc/gather.go delete mode 100644 process/proc/get.go delete mode 100644 process/proc/processfinder_test.go delete mode 100644 process/proc/sockets.go delete mode 100644 process/proc/sockets_test.go diff --git a/firewall/api.go b/firewall/api.go index b73729c8..d0e03f24 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -60,7 +60,17 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er var procsChecked []string // get process - proc, err := process.GetProcessByEndpoints(r.Context(), remoteIP, remotePort, localIP, localPort, packet.TCP) // switch reverse/local to get remote process + proc, _, err := process.GetProcessByEndpoints( + r.Context(), + packet.IPv4, + packet.TCP, + // switch reverse/local to get remote process + remoteIP, + remotePort, + localIP, + localPort, + false, + ) if err != nil { return false, fmt.Errorf("failed to get process: %s", err) } diff --git a/firewall/master.go b/firewall/master.go index a1b30203..69020dad 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -10,6 +10,7 @@ import ( "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/state" "github.com/safing/portmaster/process" "github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile/endpoints" @@ -90,12 +91,14 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { pktInfo := pkt.Info() if conn.Process().Pid >= 0 && pktInfo.Src.Equal(pktInfo.Dst) { // get PID - otherPid, _, err := process.GetPidByEndpoints( + otherPid, _, err := state.Lookup( + pktInfo.Version, + pktInfo.Protocol, pktInfo.RemoteIP(), pktInfo.RemotePort(), pktInfo.LocalIP(), pktInfo.LocalPort(), - pktInfo.Protocol, + pktInfo.Direction, ) if err != nil { log.Warningf("filter: failed to find local peer process PID: %s", err) diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index c68e62a2..1e955c4a 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -6,6 +6,8 @@ import ( "net" "strings" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/log" @@ -167,13 +169,14 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } // start tracer - ctx, tracer := log.AddTracer(ctx) - tracer.Tracef("nameserver: handling new request for %s%s from %s:%d", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) + ctx, tracer := log.AddTracer(context.Background()) + defer tracer.Submit() + tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) // TODO: if there are 3 request for the same domain/type in a row, delete all caches of that domain // get connection - conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, remoteAddr.IP, uint16(remoteAddr.Port)) + conn := network.NewConnectionFromDNSRequest(ctx, q.FQDN, nil, packet.IPv4, remoteAddr.IP, uint16(remoteAddr.Port)) // once we decided on the connection we might need to save it to the database // so we defer that check right now. @@ -191,7 +194,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return default: - log.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) + tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) } }() @@ -242,7 +245,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er tracer.Infof("nameserver: %s handing over to reason-responder: %s", q.FQDN, conn.Reason) reply := responder.ReplyWithDNS(query, conn.Reason, conn.ReasonContext) if err := w.WriteMsg(reply); err != nil { - log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) + tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) } else { tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) } @@ -269,6 +272,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er return nil } + tracer.Trace("nameserver: deciding on resolved dns") rrCache = firewall.DecideOnResolvedDNS(conn, q, rrCache) if rrCache == nil { sendResponse(w, query, conn.Verdict, conn.Reason, conn.ReasonContext) @@ -283,7 +287,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er m.Extra = rrCache.Extra if err := w.WriteMsg(m); err != nil { - log.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) + tracer.Warningf("nameserver: failed to return response %s%s to %s: %s", q.FQDN, q.QType, conn.Process(), err) } else { tracer.Debugf("nameserver: returning response %s%s to %s", q.FQDN, q.QType, conn.Process()) } diff --git a/nameserver/takeover.go b/nameserver/takeover.go index e55aa46c..d5ede695 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -10,7 +10,7 @@ import ( "github.com/safing/portbase/modules" "github.com/safing/portbase/notifications" "github.com/safing/portmaster/network/packet" - "github.com/safing/portmaster/process" + "github.com/safing/portmaster/network/state" ) var ( @@ -58,7 +58,7 @@ func checkForConflictingService() error { } func takeover(resolverIP net.IP) (int, error) { - pid, _, err := process.GetPidByEndpoints(resolverIP, 53, resolverIP, 65535, packet.UDP) + pid, _, err := state.Lookup(0, packet.UDP, resolverIP, 53, nil, 0, false) if err != nil { // there may be nothing listening on :53 return 0, nil diff --git a/network/clean.go b/network/clean.go index 3b1bbac9..efeadce9 100644 --- a/network/clean.go +++ b/network/clean.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/safing/portmaster/network/state" + "github.com/safing/portbase/log" "github.com/safing/portmaster/process" ) @@ -22,8 +24,12 @@ func connectionCleaner(ctx context.Context) error { ticker.Stop() return nil case <-ticker.C: + // clean connections and processes activePIDs := cleanConnections() process.CleanProcessStorage(activePIDs) + + // clean udp connection states + state.CleanUDPStates(ctx) } } } @@ -33,12 +39,9 @@ func cleanConnections() (activePIDs map[int]struct{}) { name := "clean connections" // TODO: change to new fn _ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { - activeIDs := make(map[string]struct{}) - for _, cID := range process.GetActiveConnectionIDs() { - activeIDs[cID] = struct{}{} - } - now := time.Now().Unix() + now := time.Now().UTC() + nowUnix := now.Unix() deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix() // lock both together because we cannot fully guarantee in which map a connection lands @@ -49,20 +52,20 @@ func cleanConnections() (activePIDs map[int]struct{}) { defer dnsConnsLock.Unlock() // network connections - for key, conn := range conns { + for _, conn := range conns { conn.Lock() // delete inactive connections switch { case conn.Ended == 0: // Step 1: check if still active - _, ok := activeIDs[key] - if ok { + exists := state.Exists(conn.IPVersion, conn.IPProtocol, conn.LocalIP, conn.LocalPort, conn.Entity.IP, conn.Entity.Port, now) + if exists { activePIDs[conn.process.Pid] = struct{}{} } else { // Step 2: mark end activePIDs[conn.process.Pid] = struct{}{} - conn.Ended = now + conn.Ended = nowUnix conn.Save() } case conn.Ended < deleteOlderThan: diff --git a/network/connection.go b/network/connection.go index bb5529f1..1b99842b 100644 --- a/network/connection.go +++ b/network/connection.go @@ -25,11 +25,19 @@ type Connection struct { //nolint:maligned // TODO: fix alignment record.Base sync.Mutex - ID string - Scope string - Inbound bool - Entity *intel.Entity // needs locking, instance is never shared - process *process.Process + ID string + Scope string + IPVersion packet.IPVersion + Inbound bool + + // local endpoint + IPProtocol packet.IPProtocol + LocalIP net.IP + LocalPort uint16 + process *process.Process + + // remote endpoint + Entity *intel.Entity // needs locking, instance is never shared Verdict Verdict Reason string @@ -55,9 +63,18 @@ type Connection struct { //nolint:maligned // TODO: fix alignment } // NewConnectionFromDNSRequest returns a new connection based on the given dns request. -func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, localIP net.IP, localPort uint16) *Connection { +func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection { // get Process - proc, err := process.GetProcessByEndpoints(ctx, localIP, localPort, dnsAddress, dnsPort, packet.UDP) + proc, _, err := process.GetProcessByEndpoints( + ctx, + ipVersion, + packet.UDP, + localIP, + localPort, + dnsAddress, // this might not be correct, but it does not matter, as matching only occurs on the local address + dnsPort, + false, // inbound, irrevelant + ) if err != nil { log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) @@ -147,11 +164,18 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { } return &Connection{ - ID: pkt.GetConnectionID(), - Scope: scope, - Inbound: inbound, - Entity: entity, - process: proc, + ID: pkt.GetConnectionID(), + Scope: scope, + IPVersion: pkt.Info().Version, + Inbound: inbound, + // local endpoint + IPProtocol: pkt.Info().Protocol, + LocalIP: pkt.Info().LocalIP(), + LocalPort: pkt.Info().LocalPort(), + process: proc, + // remote endpoint + Entity: entity, + // meta Started: time.Now().Unix(), } } diff --git a/network/database.go b/network/database.go index ee42a5b1..5910e92d 100644 --- a/network/database.go +++ b/network/database.go @@ -57,6 +57,13 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { return conn, nil } } + // case "system": + // if len(splitted) >= 2 { + // switch splitted[1] { + // case "": + // process.Get + // } + // } } return nil, storage.ErrNotFound diff --git a/network/iphelper/get.go b/network/iphelper/get.go new file mode 100644 index 00000000..80f3352f --- /dev/null +++ b/network/iphelper/get.go @@ -0,0 +1,72 @@ +// +build windows + +package iphelper + +import ( + "sync" + + "github.com/safing/portmaster/network/socket" +) + +const ( + unidentifiedProcessID = -1 +) + +var ( + ipHelper *IPHelper + lock sync.RWMutex +) + +// GetTCP4Table returns the system table for IPv4 TCP activity. +func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, nil, err + } + + return ipHelper.getTable(IPv4, TCP) +} + +// GetTCP6Table returns the system table for IPv6 TCP activity. +func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, nil, err + } + + return ipHelper.getTable(IPv6, TCP) +} + +// GetUDP4Table returns the system table for IPv4 UDP activity. +func GetUDP4Table() (binds []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, err + } + + _, binds, err = ipHelper.getTable(IPv4, UDP) + return +} + +// GetUDP6Table returns the system table for IPv6 UDP activity. +func GetUDP6Table() (binds []*socket.BindInfo, err error) { + lock.Lock() + defer lock.Unlock() + + err = checkIPHelper() + if err != nil { + return nil, err + } + + _, binds, err = ipHelper.getTable(IPv6, UDP) + return +} diff --git a/network/iphelper/iphelper.go b/network/iphelper/iphelper.go new file mode 100644 index 00000000..5498879a --- /dev/null +++ b/network/iphelper/iphelper.go @@ -0,0 +1,63 @@ +// +build windows + +package iphelper + +import ( + "errors" + "fmt" + + "github.com/tevino/abool" + "golang.org/x/sys/windows" +) + +var ( + errInvalid = errors.New("IPHelper not initialzed or broken") +) + +// IPHelper represents a subset of the Windows iphlpapi.dll. +type IPHelper struct { + dll *windows.LazyDLL + + getExtendedTCPTable *windows.LazyProc + getExtendedUDPTable *windows.LazyProc + + valid *abool.AtomicBool +} + +func checkIPHelper() (err error) { + if ipHelper == nil { + ipHelper, err = New() + return err + } + return nil +} + +// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded). +func New() (*IPHelper, error) { + + new := &IPHelper{} + new.valid = abool.NewBool(false) + var err error + + // load dll + new.dll = windows.NewLazySystemDLL("iphlpapi.dll") + err = new.dll.Load() + if err != nil { + return nil, err + } + + // load functions + new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable") + err = new.getExtendedTCPTable.Find() + if err != nil { + return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) + } + new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable") + err = new.getExtendedUDPTable.Find() + if err != nil { + return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) + } + + new.valid.Set() + return new, nil +} diff --git a/process/iphelper/tables.go b/network/iphelper/tables.go similarity index 69% rename from process/iphelper/tables.go rename to network/iphelper/tables.go index 8ffecfd7..b2ea8286 100644 --- a/process/iphelper/tables.go +++ b/network/iphelper/tables.go @@ -3,12 +3,15 @@ package iphelper import ( + "encoding/binary" "errors" "fmt" "net" "sync" "unsafe" + "github.com/safing/portmaster/network/socket" + "golang.org/x/sys/windows" ) @@ -22,19 +25,6 @@ const ( winErrInvalidParameter = uintptr(windows.ERROR_INVALID_PARAMETER) ) -// ConnectionEntry describes a connection state table entry. -type ConnectionEntry struct { - localIP net.IP - remoteIP net.IP - localPort uint16 - remotePort uint16 - pid int -} - -func (entry *ConnectionEntry) String() string { - return fmt.Sprintf("PID=%d %s:%d <> %s:%d", entry.pid, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort) -} - type iphelperTCPTable struct { // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa366921(v=vs.85).aspx numEntries uint32 @@ -148,9 +138,9 @@ func increaseBufSize() int { return bufSize } -// GetTables returns the current connection state table of Windows of the given protocol and IP version. -func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connections []*ConnectionEntry, listeners []*ConnectionEntry, err error) { //nolint:gocognit,gocycle // TODO - // docs: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365928(v=vs.85).aspx +// getTable returns the current connection state table of Windows of the given protocol and IP version. +func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { //nolint:gocognit,gocycle // TODO + // docs: https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable if !ipHelper.valid.IsSet() { return nil, nil, errInvalid @@ -220,26 +210,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - if row.localAddr != 0 { - new.localIP = convertIPv4(row.localAddr) - } - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - - // remote if row.state == iphelperTCPStateListen { - listeners = append(listeners, new) + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + PID: int(row.owningPid), + }) } else { - new.remoteIP = convertIPv4(row.remoteAddr) - new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8) - connections = append(connections, new) + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + Remote: socket.Address{ + IP: convertIPv4(row.remoteAddr), + Port: uint16(row.remotePort>>8 | row.remotePort<<8), + }, + PID: int(row.owningPid), + }) } - } case protocol == TCP && ipVersion == IPv6: @@ -248,27 +239,27 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := tcpTable.table[:tcpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localIP = net.IP(row.localAddr[:]) - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - - // remote if row.state == iphelperTCPStateListen { - if new.localIP.Equal(net.IPv6zero) { - new.localIP = nil - } - listeners = append(listeners, new) + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + PID: int(row.owningPid), + }) } else { - new.remoteIP = net.IP(row.remoteAddr[:]) - new.remotePort = uint16(row.remotePort>>8 | row.remotePort<<8) - connections = append(connections, new) + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + Remote: socket.Address{ + IP: net.IP(row.remoteAddr[:]), + Port: uint16(row.remotePort>>8 | row.remotePort<<8), + }, + PID: int(row.owningPid), + }) } - } case protocol == UDP && ipVersion == IPv4: @@ -277,19 +268,13 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - if row.localAddr == 0 { - listeners = append(listeners, new) - } else { - new.localIP = convertIPv4(row.localAddr) - connections = append(connections, new) - } + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: convertIPv4(row.localAddr), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + PID: int(row.owningPid), + }) } case protocol == UDP && ipVersion == IPv6: @@ -298,32 +283,22 @@ func (ipHelper *IPHelper) GetTables(protocol uint8, ipVersion uint8) (connection table := udpTable.table[:udpTable.numEntries] for _, row := range table { - new := &ConnectionEntry{} - - // PID - new.pid = int(row.owningPid) - - // local - new.localIP = net.IP(row.localAddr[:]) - new.localPort = uint16(row.localPort>>8 | row.localPort<<8) - if new.localIP.Equal(net.IPv6zero) { - new.localIP = nil - listeners = append(listeners, new) - } else { - connections = append(connections, new) - } + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: net.IP(row.localAddr[:]), + Port: uint16(row.localPort>>8 | row.localPort<<8), + }, + PID: int(row.owningPid), + }) } } - return connections, listeners, nil + return connections, binds, nil } func convertIPv4(input uint32) net.IP { - return net.IPv4( - uint8(input&0xFF), - uint8(input>>8&0xFF), - uint8(input>>16&0xFF), - uint8(input>>24&0xFF), - ) + addressBuf := make([]byte, 4) + binary.BigEndian.PutUint32(addressBuf, input) + return net.IP(addressBuf) } diff --git a/network/iphelper/tables_test.go b/network/iphelper/tables_test.go new file mode 100644 index 00000000..e996219e --- /dev/null +++ b/network/iphelper/tables_test.go @@ -0,0 +1,54 @@ +// +build windows + +package iphelper + +import ( + "fmt" + "testing" +) + +func TestSockets(t *testing.T) { + connections, listeners, err := GetTCP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 4 connections:") + for _, connection := range connections { + fmt.Printf("%+v\n", connection) + } + fmt.Println("\nTCP 4 listeners:") + for _, listener := range listeners { + fmt.Printf("%+v\n", listener) + } + + connections, listeners, err = GetTCP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 6 connections:") + for _, connection := range connections { + fmt.Printf("%+v\n", connection) + } + fmt.Println("\nTCP 6 listeners:") + for _, listener := range listeners { + fmt.Printf("%+v\n", listener) + } + + binds, err := GetUDP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 4 binds:") + for _, bind := range binds { + fmt.Printf("%+v\n", bind) + } + + binds, err = GetUDP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 6 binds:") + for _, bind := range binds { + fmt.Printf("%+v\n", bind) + } +} diff --git a/process/iphelper/test/main.go b/network/iphelper/test/main.go similarity index 100% rename from process/iphelper/test/main.go rename to network/iphelper/test/main.go diff --git a/process/proc/processfinder.go b/network/proc/findpid.go similarity index 95% rename from process/proc/processfinder.go rename to network/proc/findpid.go index 5e6ed7cc..ce54984e 100644 --- a/process/proc/processfinder.go +++ b/network/proc/findpid.go @@ -13,13 +13,17 @@ import ( "github.com/safing/portbase/log" ) +const ( + unidentifiedProcessID = -1 +) + var ( pidsByUserLock sync.Mutex pidsByUser = make(map[int][]int) ) -// GetPidOfInode returns the pid of the given uid and socket inode. -func GetPidOfInode(uid, inode int) (int, bool) { //nolint:gocognit // TODO +// FindPID returns the pid of the given uid and socket inode. +func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO pidsByUserLock.Lock() defer pidsByUserLock.Unlock() diff --git a/network/proc/tables.go b/network/proc/tables.go new file mode 100644 index 00000000..b5a652a1 --- /dev/null +++ b/network/proc/tables.go @@ -0,0 +1,218 @@ +// +build linux + +package proc + +import ( + "bufio" + "encoding/hex" + "net" + "os" + "strconv" + "strings" + "unicode" + + "github.com/safing/portmaster/network/socket" + + "github.com/safing/portbase/log" +) + +/* + +1. find socket inode + - by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP? + - /proc/net/{tcp|udp}[6] + +2. get list of processes of uid + +3. find socket inode in process fds + - if not found, refresh map of uid->pids + - if not found, check ALL pids: maybe euid != uid + +4. gather process info + +Cache every step! + +*/ + +// Network Related Constants +const ( + TCP4 uint8 = iota + UDP4 + TCP6 + UDP6 + ICMP4 + ICMP6 + + TCP4Data = "/proc/net/tcp" + UDP4Data = "/proc/net/udp" + TCP6Data = "/proc/net/tcp6" + UDP6Data = "/proc/net/udp6" + ICMP4Data = "/proc/net/icmp" + ICMP6Data = "/proc/net/icmp6" + + UnfetchedProcessID = -2 + + tcpListenStateHex = "0A" +) + +// GetTCP4Table returns the system table for IPv4 TCP activity. +func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + return getTableFromSource(TCP4, TCP4Data, convertIPv4) +} + +// GetTCP6Table returns the system table for IPv6 TCP activity. +func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { + return getTableFromSource(TCP6, TCP6Data, convertIPv6) +} + +// GetUDP4Table returns the system table for IPv4 UDP activity. +func GetUDP4Table() (binds []*socket.BindInfo, err error) { + _, binds, err = getTableFromSource(UDP4, UDP4Data, convertIPv4) + return +} + +// GetUDP6Table returns the system table for IPv6 UDP activity. +func GetUDP6Table() (binds []*socket.BindInfo, err error) { + _, binds, err = getTableFromSource(UDP6, UDP6Data, convertIPv6) + return +} + +func getTableFromSource(stack uint8, procFile string, ipConverter func(string) net.IP) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { + + // open file + socketData, err := os.Open(procFile) + if err != nil { + return nil, nil, err + } + defer socketData.Close() + + // file scanner + scanner := bufio.NewScanner(socketData) + scanner.Split(bufio.ScanLines) + + // parse + scanner.Scan() // skip first line + for scanner.Scan() { + line := strings.FieldsFunc(scanner.Text(), procDelimiter) + if len(line) < 14 { + // log.Tracef("process: too short: %s", line) + continue + } + + localIP := ipConverter(line[1]) + if localIP == nil { + continue + } + + localPort, err := strconv.ParseUint(line[2], 16, 16) + if err != nil { + log.Warningf("process: could not parse port: %s", err) + continue + } + + uid, err := strconv.ParseInt(line[11], 10, 32) + // log.Tracef("uid: %s", line[11]) + if err != nil { + log.Warningf("process: could not parse uid %s: %s", line[11], err) + continue + } + + inode, err := strconv.ParseInt(line[13], 10, 32) + // log.Tracef("inode: %s", line[13]) + if err != nil { + log.Warningf("process: could not parse inode %s: %s", line[13], err) + continue + } + + switch stack { + case UDP4, UDP6: + + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + + case TCP4, TCP6: + + if line[5] == tcpListenStateHex { + // listener + + binds = append(binds, &socket.BindInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + } else { + // connection + + remoteIP := ipConverter(line[3]) + if remoteIP == nil { + continue + } + + remotePort, err := strconv.ParseUint(line[4], 16, 16) + if err != nil { + log.Warningf("process: could not parse port: %s", err) + continue + } + + connections = append(connections, &socket.ConnectionInfo{ + Local: socket.Address{ + IP: localIP, + Port: uint16(localPort), + }, + Remote: socket.Address{ + IP: remoteIP, + Port: uint16(remotePort), + }, + PID: UnfetchedProcessID, + UID: int(uid), + Inode: int(inode), + }) + } + } + } + + return connections, binds, nil +} + +func procDelimiter(c rune) bool { + return unicode.IsSpace(c) || c == ':' +} + +func convertIPv4(data string) net.IP { + decoded, err := hex.DecodeString(data) + if err != nil { + log.Warningf("process: could not parse IPv4 %s: %s", data, err) + return nil + } + if len(decoded) != 4 { + log.Warningf("process: decoded IPv4 %s has wrong length", decoded) + return nil + } + ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) + return ip +} + +func convertIPv6(data string) net.IP { + decoded, err := hex.DecodeString(data) + if err != nil { + log.Warningf("process: could not parse IPv6 %s: %s", data, err) + return nil + } + if len(decoded) != 16 { + log.Warningf("process: decoded IPv6 %s has wrong length", decoded) + return nil + } + ip := net.IP(decoded) + return ip +} diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go new file mode 100644 index 00000000..9dc7c1eb --- /dev/null +++ b/network/proc/tables_test.go @@ -0,0 +1,60 @@ +// +build linux + +package proc + +import ( + "fmt" + "testing" +) + +func TestSockets(t *testing.T) { + connections, listeners, err := GetTCP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 4 connections:") + for _, connection := range connections { + pid, _ := FindPID(connection.UID, connection.Inode) + fmt.Printf("%d: %+v\n", pid, connection) + } + fmt.Println("\nTCP 4 listeners:") + for _, listener := range listeners { + pid, _ := FindPID(listener.UID, listener.Inode) + fmt.Printf("%d: %+v\n", pid, listener) + } + + connections, listeners, err = GetTCP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nTCP 6 connections:") + for _, connection := range connections { + pid, _ := FindPID(connection.UID, connection.Inode) + fmt.Printf("%d: %+v\n", pid, connection) + } + fmt.Println("\nTCP 6 listeners:") + for _, listener := range listeners { + pid, _ := FindPID(listener.UID, listener.Inode) + fmt.Printf("%d: %+v\n", pid, listener) + } + + binds, err := GetUDP4Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 4 binds:") + for _, bind := range binds { + pid, _ := FindPID(bind.UID, bind.Inode) + fmt.Printf("%d: %+v\n", pid, bind) + } + + binds, err = GetUDP6Table() + if err != nil { + t.Fatal(err) + } + fmt.Println("\nUDP 6 binds:") + for _, bind := range binds { + pid, _ := FindPID(bind.UID, bind.Inode) + fmt.Printf("%d: %+v\n", pid, bind) + } +} diff --git a/network/socket/socket.go b/network/socket/socket.go new file mode 100644 index 00000000..a599eddf --- /dev/null +++ b/network/socket/socket.go @@ -0,0 +1,26 @@ +package socket + +import "net" + +// ConnectionInfo holds socket information returned by the system. +type ConnectionInfo struct { + Local Address + Remote Address + PID int + UID int + Inode int +} + +// BindInfo holds socket information returned by the system. +type BindInfo struct { + Local Address + PID int + UID int + Inode int +} + +// Address is an IP + Port pair. +type Address struct { + IP net.IP + Port uint16 +} diff --git a/network/state/exists.go b/network/state/exists.go new file mode 100644 index 00000000..9af6979f --- /dev/null +++ b/network/state/exists.go @@ -0,0 +1,103 @@ +package state + +import ( + "net" + "time" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/socket" +) + +const ( + UDPConnectionTTL = 10 * time.Minute +) + +func Exists( + ipVersion packet.IPVersion, + protocol packet.IPProtocol, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, + now time.Time, +) (exists bool) { + + switch { + case ipVersion == packet.IPv4 && protocol == packet.TCP: + tcp4Lock.Lock() + defer tcp4Lock.Unlock() + return existsTCP(tcp4Connections, localIP, localPort, remoteIP, remotePort) + + case ipVersion == packet.IPv6 && protocol == packet.TCP: + tcp6Lock.Lock() + defer tcp6Lock.Unlock() + return existsTCP(tcp6Connections, localIP, localPort, remoteIP, remotePort) + + case ipVersion == packet.IPv4 && protocol == packet.UDP: + udp4Lock.Lock() + defer udp4Lock.Unlock() + return existsUDP(udp4Binds, udp4States, localIP, localPort, remoteIP, remotePort, now) + + case ipVersion == packet.IPv6 && protocol == packet.UDP: + udp6Lock.Lock() + defer udp6Lock.Unlock() + return existsUDP(udp6Binds, udp6States, localIP, localPort, remoteIP, remotePort, now) + + default: + return false + } +} + +func existsTCP( + connections []*socket.ConnectionInfo, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, +) (exists bool) { + + // search connections + for _, socketInfo := range connections { + if localPort == socketInfo.Local.Port && + remotePort == socketInfo.Remote.Port && + remoteIP.Equal(socketInfo.Remote.IP) && + localIP.Equal(socketInfo.Local.IP) { + return true + } + } + + return false +} + +func existsUDP( + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, + now time.Time, +) (exists bool) { + + connThreshhold := now.Add(-UDPConnectionTTL) + + // search binds + for _, socketInfo := range binds { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + + udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort) + switch { + case !ok: + return false + case udpConnState.lastSeen.After(connThreshhold): + return true + default: + return false + } + + } + } + + return false +} diff --git a/network/state/lookup.go b/network/state/lookup.go new file mode 100644 index 00000000..ade151a5 --- /dev/null +++ b/network/state/lookup.go @@ -0,0 +1,189 @@ +package state + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/socket" +) + +// - TCP +// - Outbound: Match listeners (in!), then connections (out!) +// - Inbound: Match listeners (in!), then connections (out!) +// - Clean via connections +// - UDP +// - Any connection: match specific local address or zero IP +// - In or out: save direction of first packet: +// - map[]map[]{direction, lastSeen} +// - only clean if is removed by OS +// - limit to 256 entries? +// - clean after 72hrs? +// - switch direction to outbound if outbound packet is seen? +// - IP: Unidentified Process + +const ( + UnidentifiedProcessID = -1 +) + +// Errors +var ( + ErrConnectionNotFound = errors.New("could not find connection in system state tables") + ErrPIDNotFound = errors.New("could not find pid for socket inode") +) + +var ( + tcp4Lock sync.Mutex + tcp6Lock sync.Mutex + udp4Lock sync.Mutex + udp6Lock sync.Mutex + + waitTime = 3 * time.Millisecond +) + +func LookupWithPacket(pkt packet.Packet) (pid int, inbound bool, err error) { + meta := pkt.Info() + return Lookup( + meta.Version, + meta.Protocol, + meta.LocalIP(), + meta.LocalPort(), + meta.RemoteIP(), + meta.RemotePort(), + meta.Direction, + ) +} + +func Lookup( + ipVersion packet.IPVersion, + protocol packet.IPProtocol, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, + pktInbound bool, +) ( + pid int, + inbound bool, + err error, +) { + + // auto-detect version + if ipVersion == 0 { + if ip := localIP.To4(); ip != nil { + ipVersion = packet.IPv4 + } else { + ipVersion = packet.IPv6 + } + } + + switch { + case ipVersion == packet.IPv4 && protocol == packet.TCP: + tcp4Lock.Lock() + defer tcp4Lock.Unlock() + return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, localIP, localPort) + + case ipVersion == packet.IPv6 && protocol == packet.TCP: + tcp6Lock.Lock() + defer tcp6Lock.Unlock() + return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, localIP, localPort) + + case ipVersion == packet.IPv4 && protocol == packet.UDP: + udp4Lock.Lock() + defer udp4Lock.Unlock() + return searchUDP(udp4Binds, udp4States, updateUDP4Table, localIP, localPort, remoteIP, remotePort, pktInbound) + + case ipVersion == packet.IPv6 && protocol == packet.UDP: + udp6Lock.Lock() + defer udp6Lock.Unlock() + return searchUDP(udp6Binds, udp6States, updateUDP6Table, localIP, localPort, remoteIP, remotePort, pktInbound) + + default: + return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") + } +} + +func searchTCP( + connections []*socket.ConnectionInfo, + listeners []*socket.BindInfo, + updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo), + localIP net.IP, + localPort uint16, +) ( + pid int, + inbound bool, + err error, +) { + + // search until we find something + for i := 0; i < 5; i++ { + // always search listeners first + for _, socketInfo := range listeners { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + return checkBindPID(socketInfo, true) + } + } + + // search connections + for _, socketInfo := range connections { + if localPort == socketInfo.Local.Port && + localIP.Equal(socketInfo.Local.IP) { + return checkConnectionPID(socketInfo, false) + } + } + + // we found nothing, we could have been too fast, give the kernel some time to think + time.Sleep(waitTime) + + // refetch lists + connections, listeners = updateTables() + } + + return UnidentifiedProcessID, false, ErrConnectionNotFound +} + +func searchUDP( + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + updateTable func() []*socket.BindInfo, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, + pktInbound bool, +) ( + pid int, + inbound bool, + err error, +) { + + // search until we find something + for i := 0; i < 5; i++ { + // search binds + for _, socketInfo := range binds { + if localPort == socketInfo.Local.Port && + (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + + // do not check direction if remoteIP/Port is not given + if remotePort == 0 { + return checkBindPID(socketInfo, pktInbound) + } + + // get direction and return + connInbound := getUDPDirection(socketInfo, udpStates, remoteIP, remotePort, pktInbound) + return checkBindPID(socketInfo, connInbound) + } + } + + // we found nothing, we could have been too fast, give the kernel some time to think + time.Sleep(waitTime) + + // refetch lists + binds = updateTable() + } + + return UnidentifiedProcessID, pktInbound, ErrConnectionNotFound +} diff --git a/network/state/system_linux.go b/network/state/system_linux.go new file mode 100644 index 00000000..a08fd86b --- /dev/null +++ b/network/state/system_linux.go @@ -0,0 +1,37 @@ +package state + +import ( + "github.com/safing/portmaster/network/proc" + "github.com/safing/portmaster/network/socket" +) + +var ( + getTCP4Table = proc.GetTCP4Table + getTCP6Table = proc.GetTCP6Table + getUDP4Table = proc.GetUDP4Table + getUDP6Table = proc.GetUDP6Table +) + +func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { + if socketInfo.PID == proc.UnfetchedProcessID { + pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) + if ok { + socketInfo.PID = pid + } else { + socketInfo.PID = UnidentifiedProcessID + } + } + return socketInfo.PID, connInbound, nil +} + +func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { + if socketInfo.PID == proc.UnfetchedProcessID { + pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) + if ok { + socketInfo.PID = pid + } else { + socketInfo.PID = UnidentifiedProcessID + } + } + return socketInfo.PID, connInbound, nil +} diff --git a/network/state/system_windows.go b/network/state/system_windows.go new file mode 100644 index 00000000..a03ea5f6 --- /dev/null +++ b/network/state/system_windows.go @@ -0,0 +1,21 @@ +package state + +import ( + "github.com/safing/portmaster/network/iphelper" + "github.com/safing/portmaster/network/socket" +) + +var ( + getTCP4Table = iphelper.GetTCP4Table + getTCP6Table = iphelper.GetTCP6Table + getUDP4Table = iphelper.GetUDP4Table + getUDP6Table = iphelper.GetUDP6Table +) + +func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { + return socketInfo.PID, connInbound, nil +} + +func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { + return socketInfo.PID, connInbound, nil +} diff --git a/network/state/tables.go b/network/state/tables.go new file mode 100644 index 00000000..59095a16 --- /dev/null +++ b/network/state/tables.go @@ -0,0 +1,66 @@ +package state + +import ( + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network/socket" +) + +var ( + tcp4Connections []*socket.ConnectionInfo + tcp4Listeners []*socket.BindInfo + + tcp6Connections []*socket.ConnectionInfo + tcp6Listeners []*socket.BindInfo + + udp4Binds []*socket.BindInfo + + udp6Binds []*socket.BindInfo +) + +func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { + // FIXME: repeatable once + + connections, listeners, err := getTCP4Table() + if err != nil { + log.Warningf("state: failed to get TCP4 socket table: %s", err) + return + } + + tcp4Connections = connections + tcp4Listeners = listeners + return tcp4Connections, tcp4Listeners +} + +func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { + connections, listeners, err := getTCP6Table() + if err != nil { + log.Warningf("state: failed to get TCP6 socket table: %s", err) + return + } + + tcp6Connections = connections + tcp6Listeners = listeners + return tcp6Connections, tcp6Listeners +} + +func updateUDP4Table() (binds []*socket.BindInfo) { + binds, err := getUDP4Table() + if err != nil { + log.Warningf("state: failed to get UDP4 socket table: %s", err) + return + } + + udp4Binds = binds + return udp4Binds +} + +func updateUDP6Table() (binds []*socket.BindInfo) { + binds, err := getUDP6Table() + if err != nil { + log.Warningf("state: failed to get UDP6 socket table: %s", err) + return + } + + udp6Binds = binds + return udp6Binds +} diff --git a/network/state/udp.go b/network/state/udp.go new file mode 100644 index 00000000..f24ac237 --- /dev/null +++ b/network/state/udp.go @@ -0,0 +1,118 @@ +package state + +import ( + "context" + "net" + "time" + + "github.com/safing/portmaster/network/socket" +) + +type udpState struct { + inbound bool + lastSeen time.Time +} + +const ( + UpdConnStateTTL = 72 * time.Hour + UdpConnStateShortenedTTL = 3 * time.Hour + AggressiveCleaningThreshold = 256 +) + +var ( + udp4States = make(map[string]map[string]*udpState) // locked with udp4Lock + udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock +) + +func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) { + bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] + if ok { + udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)] + return + } + + return nil, false +} + +func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16, pktInbound bool) (connDirection bool) { + localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port) + + bindMap, ok := udpStates[localKey] + if !ok { + bindMap = make(map[string]*udpState) + udpStates[localKey] = bindMap + } + + remoteKey := makeUDPStateKey(remoteIP, remotePort) + udpConnState, ok := bindMap[remoteKey] + if !ok { + bindMap[remoteKey] = &udpState{ + inbound: pktInbound, + lastSeen: time.Now().UTC(), + } + return pktInbound + } + + udpConnState.lastSeen = time.Now().UTC() + return udpConnState.inbound +} + +func CleanUDPStates(ctx context.Context) { + now := time.Now().UTC() + + udp4Lock.Lock() + updateUDP4Table() + cleanStates(ctx, udp4Binds, udp4States, now) + udp4Lock.Unlock() + + udp6Lock.Lock() + updateUDP6Table() + cleanStates(ctx, udp6Binds, udp6States, now) + udp6Lock.Unlock() +} + +func cleanStates( + ctx context.Context, + binds []*socket.BindInfo, + udpStates map[string]map[string]*udpState, + now time.Time, +) { + // compute thresholds + threshold := now.Add(-UpdConnStateTTL) + shortThreshhold := now.Add(-UdpConnStateShortenedTTL) + + // make list of all active keys + bindKeys := make(map[string]struct{}) + for _, socketInfo := range binds { + bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{} + } + + // clean the udp state storage + for localKey, bindMap := range udpStates { + _, active := bindKeys[localKey] + if active { + // clean old entries + for remoteKey, udpConnState := range bindMap { + if udpConnState.lastSeen.Before(threshold) { + delete(bindMap, remoteKey) + } + } + // if there are too many clean more aggressively + if len(bindMap) > AggressiveCleaningThreshold { + for remoteKey, udpConnState := range bindMap { + if udpConnState.lastSeen.Before(shortThreshhold) { + delete(bindMap, remoteKey) + } + } + } + } else { + // delete the whole thing + delete(udpStates, localKey) + } + } +} + +func makeUDPStateKey(ip net.IP, port uint16) string { + // This could potentially go wrong, but as all IPs are created by the same source, everything should be fine. + return string(ip) + string(port) +} diff --git a/process/find.go b/process/find.go index 30f93f2d..aa0a4071 100644 --- a/process/find.go +++ b/process/find.go @@ -5,137 +5,65 @@ import ( "errors" "net" + "github.com/safing/portmaster/network/state" + "github.com/safing/portbase/log" "github.com/safing/portmaster/network/packet" ) // Errors var ( - ErrConnectionNotFound = errors.New("could not find connection in system state tables") - ErrProcessNotFound = errors.New("could not find process in system state tables") + ErrProcessNotFound = errors.New("could not find process in system state tables") ) -// GetPidByPacket returns the pid of the owner of the packet. -func GetPidByPacket(pkt packet.Packet) (pid int, direction bool, err error) { - - var localIP net.IP - var localPort uint16 - var remoteIP net.IP - var remotePort uint16 - if pkt.IsInbound() { - localIP = pkt.Info().Dst - remoteIP = pkt.Info().Src - } else { - localIP = pkt.Info().Src - remoteIP = pkt.Info().Dst - } - if pkt.HasPorts() { - if pkt.IsInbound() { - localPort = pkt.Info().DstPort - remotePort = pkt.Info().SrcPort - } else { - localPort = pkt.Info().SrcPort - remotePort = pkt.Info().DstPort - } - } - - switch { - case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv4: - return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv4: - return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.TCP && pkt.Info().Version == packet.IPv6: - return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - case pkt.Info().Protocol == packet.UDP && pkt.Info().Version == packet.IPv6: - return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, pkt.IsInbound()) - default: - return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") - } - -} - // GetProcessByPacket returns the process that owns the given packet. -func GetProcessByPacket(pkt packet.Packet) (process *Process, direction bool, err error) { - if !enableProcessDetection() { - log.Tracer(pkt.Ctx()).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(pkt.Ctx()), pkt.Info().Direction, nil - } - - log.Tracer(pkt.Ctx()).Tracef("process: getting process and profile by packet") - - var pid int - pid, direction, err = GetPidByPacket(pkt) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to find PID of connection: %s", err) - return nil, direction, err - } - if pid < 0 { - log.Tracer(pkt.Ctx()).Errorf("process: %s", ErrConnectionNotFound.Error()) - return nil, direction, ErrConnectionNotFound - } - - process, err = GetOrFindPrimaryProcess(pkt.Ctx(), pid) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to find (primary) process with PID: %s", err) - return nil, direction, err - } - - err = process.GetProfile(pkt.Ctx()) - if err != nil { - log.Tracer(pkt.Ctx()).Errorf("process: failed to get profile for process %s: %s", process, err) - } - - return process, direction, nil - -} - -// GetPidByEndpoints returns the pid of the owner of the described link. -func GetPidByEndpoints(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (pid int, direction bool, err error) { - - ipVersion := packet.IPv4 - if v4 := localIP.To4(); v4 == nil { - ipVersion = packet.IPv6 - } - - switch { - case protocol == packet.TCP && ipVersion == packet.IPv4: - return getTCP4PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.UDP && ipVersion == packet.IPv4: - return getUDP4PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.TCP && ipVersion == packet.IPv6: - return getTCP6PacketInfo(localIP, localPort, remoteIP, remotePort, false) - case protocol == packet.UDP && ipVersion == packet.IPv6: - return getUDP6PacketInfo(localIP, localPort, remoteIP, remotePort, false) - default: - return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") - } - +func GetProcessByPacket(pkt packet.Packet) (process *Process, inbound bool, err error) { + meta := pkt.Info() + return GetProcessByEndpoints( + pkt.Ctx(), + meta.Version, + meta.Protocol, + meta.LocalIP(), + meta.LocalPort(), + meta.RemoteIP(), + meta.RemotePort(), + meta.Direction, + ) } // GetProcessByEndpoints returns the process that owns the described link. -func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, protocol packet.IPProtocol) (process *Process, err error) { +func GetProcessByEndpoints( + ctx context.Context, + ipVersion packet.IPVersion, + protocol packet.IPProtocol, + localIP net.IP, + localPort uint16, + remoteIP net.IP, + remotePort uint16, + pktInbound bool, +) ( + process *Process, + connInbound bool, + err error, +) { + if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(ctx), nil + return GetUnidentifiedProcess(ctx), pktInbound, nil } - log.Tracer(ctx).Tracef("process: getting process and profile by endpoints") - + log.Tracer(ctx).Tracef("process: getting pid from system network state") var pid int - pid, _, err = GetPidByEndpoints(localIP, localPort, remoteIP, remotePort, protocol) + pid, connInbound, err = state.Lookup(ipVersion, protocol, localIP, localPort, remoteIP, remotePort, pktInbound) if err != nil { log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err) - return nil, err - } - if pid < 0 { - log.Tracer(ctx).Errorf("process: %s", ErrConnectionNotFound.Error()) - return nil, ErrConnectionNotFound + return nil, connInbound, err } process, err = GetOrFindPrimaryProcess(ctx, pid) if err != nil { log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err) - return nil, err + return nil, connInbound, err } err = process.GetProfile(ctx) @@ -143,10 +71,5 @@ func GetProcessByEndpoints(ctx context.Context, localIP net.IP, localPort uint16 log.Tracer(ctx).Errorf("process: failed to get profile for process %s: %s", process, err) } - return process, nil -} - -// GetActiveConnectionIDs returns a list of all active connection IDs. -func GetActiveConnectionIDs() []string { - return getActiveConnectionIDs() + return process, connInbound, nil } diff --git a/process/getpid_linux.go b/process/getpid_linux.go deleted file mode 100644 index 1788f3e9..00000000 --- a/process/getpid_linux.go +++ /dev/null @@ -1,13 +0,0 @@ -package process - -import ( - "github.com/safing/portmaster/process/proc" -) - -var ( - getTCP4PacketInfo = proc.GetTCP4PacketInfo - getTCP6PacketInfo = proc.GetTCP6PacketInfo - getUDP4PacketInfo = proc.GetUDP4PacketInfo - getUDP6PacketInfo = proc.GetUDP6PacketInfo - getActiveConnectionIDs = proc.GetActiveConnectionIDs -) diff --git a/process/getpid_windows.go b/process/getpid_windows.go deleted file mode 100644 index 98b200ea..00000000 --- a/process/getpid_windows.go +++ /dev/null @@ -1,13 +0,0 @@ -package process - -import ( - "github.com/safing/portmaster/process/iphelper" -) - -var ( - getTCP4PacketInfo = iphelper.GetTCP4PacketInfo - getTCP6PacketInfo = iphelper.GetTCP6PacketInfo - getUDP4PacketInfo = iphelper.GetUDP4PacketInfo - getUDP6PacketInfo = iphelper.GetUDP6PacketInfo - getActiveConnectionIDs = iphelper.GetActiveConnectionIDs -) diff --git a/process/iphelper/get.go b/process/iphelper/get.go deleted file mode 100644 index 6487ea06..00000000 --- a/process/iphelper/get.go +++ /dev/null @@ -1,260 +0,0 @@ -// +build windows - -package iphelper - -import ( - "fmt" - "net" - "sync" - "time" -) - -const ( - unidentifiedProcessID = -1 -) - -var ( - tcp4Connections []*ConnectionEntry - tcp4Listeners []*ConnectionEntry - tcp6Connections []*ConnectionEntry - tcp6Listeners []*ConnectionEntry - - udp4Connections []*ConnectionEntry - udp4Listeners []*ConnectionEntry - udp6Connections []*ConnectionEntry - udp6Listeners []*ConnectionEntry - - ipHelper *IPHelper - lock sync.RWMutex - - waitTime = 15 * time.Millisecond -) - -func checkIPHelper() (err error) { - if ipHelper == nil { - ipHelper, err = New() - return err - } - return nil -} - -// GetTCP4PacketInfo returns the pid of the given IPv4/TCP connection. -func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp4Connections, tcp4Listeners, err = ipHelper.GetTables(TCP, IPv4) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(tcp4Connections, tcp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetTCP6PacketInfo returns the pid of the given IPv6/TCP connection. -func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - tcp6Connections, tcp6Listeners, err = ipHelper.GetTables(TCP, IPv6) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(tcp6Connections, tcp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetUDP4PacketInfo returns the pid of the given IPv4/UDP connection. -func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp4Connections, udp4Listeners, err = ipHelper.GetTables(UDP, IPv4) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(udp4Connections, udp4Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -// GetUDP6PacketInfo returns the pid of the given IPv6/UDP connection. -func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - - // search - pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - for i := 0; i < 3; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - - // if unable to find, refresh - lock.Lock() - err = checkIPHelper() - if err == nil { - udp6Connections, udp6Listeners, err = ipHelper.GetTables(UDP, IPv6) - } - lock.Unlock() - if err != nil { - return unidentifiedProcessID, pktDirection, err - } - - // search - pid, _ = search(udp6Connections, udp6Listeners, localIP, remoteIP, localPort, remotePort, pktDirection) - if pid >= 0 { - return pid, pktDirection, nil - } - - time.Sleep(waitTime) - } - - return unidentifiedProcessID, pktDirection, nil -} - -func search(connections, listeners []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16, pktDirection bool) (pid int, direction bool) { //nolint:unparam // TODO: use direction, it may not be used because results caused problems, investigate. - lock.RLock() - defer lock.RUnlock() - - if pktDirection { - // inbound - pid = searchListeners(listeners, localIP, localPort) - if pid >= 0 { - return pid, true - } - pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort) - if pid >= 0 { - return pid, false - } - } else { - // outbound - pid = searchConnections(connections, localIP, remoteIP, localPort, remotePort) - if pid >= 0 { - return pid, false - } - pid = searchListeners(listeners, localIP, localPort) - if pid >= 0 { - return pid, true - } - } - - return unidentifiedProcessID, pktDirection -} - -func searchConnections(list []*ConnectionEntry, localIP, remoteIP net.IP, localPort, remotePort uint16) (pid int) { - - for _, entry := range list { - if localPort == entry.localPort && - remotePort == entry.remotePort && - remoteIP.Equal(entry.remoteIP) && - localIP.Equal(entry.localIP) { - return entry.pid - } - } - - return unidentifiedProcessID -} - -func searchListeners(list []*ConnectionEntry, localIP net.IP, localPort uint16) (pid int) { - - for _, entry := range list { - if localPort == entry.localPort && - (entry.localIP == nil || // nil IP means zero IP, see tables.go - localIP.Equal(entry.localIP)) { - return entry.pid - } - } - - return unidentifiedProcessID -} - -// GetActiveConnectionIDs returns all currently active connection IDs. -func GetActiveConnectionIDs() (connections []string) { - lock.Lock() - defer lock.Unlock() - - for _, entry := range tcp4Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range tcp6Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", TCP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range udp4Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - for _, entry := range udp6Connections { - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", UDP, entry.localIP, entry.localPort, entry.remoteIP, entry.remotePort)) - } - - return -} diff --git a/process/iphelper/iphelper.go b/process/iphelper/iphelper.go deleted file mode 100644 index a7c259da..00000000 --- a/process/iphelper/iphelper.go +++ /dev/null @@ -1,79 +0,0 @@ -// +build windows - -package iphelper - -import ( - "errors" - "fmt" - - "github.com/tevino/abool" - "golang.org/x/sys/windows" -) - -var ( - errInvalid = errors.New("IPHelper not initialzed or broken") -) - -// IPHelper represents a subset of the Windows iphlpapi.dll. -type IPHelper struct { - dll *windows.LazyDLL - - getExtendedTCPTable *windows.LazyProc - getExtendedUDPTable *windows.LazyProc - // getOwnerModuleFromTcpEntry *windows.LazyProc - // getOwnerModuleFromTcp6Entry *windows.LazyProc - // getOwnerModuleFromUdpEntry *windows.LazyProc - // getOwnerModuleFromUdp6Entry *windows.LazyProc - - valid *abool.AtomicBool -} - -// New returns a new IPHelper API (with an instance of iphlpapi.dll loaded). -func New() (*IPHelper, error) { - - new := &IPHelper{} - new.valid = abool.NewBool(false) - var err error - - // load dll - new.dll = windows.NewLazySystemDLL("iphlpapi.dll") - err = new.dll.Load() - if err != nil { - return nil, err - } - - // load functions - new.getExtendedTCPTable = new.dll.NewProc("GetExtendedTcpTable") - err = new.getExtendedTCPTable.Find() - if err != nil { - return nil, fmt.Errorf("could find proc GetExtendedTcpTable: %s", err) - } - new.getExtendedUDPTable = new.dll.NewProc("GetExtendedUdpTable") - err = new.getExtendedUDPTable.Find() - if err != nil { - return nil, fmt.Errorf("could find proc GetExtendedUdpTable: %s", err) - } - // new.getOwnerModuleFromTcpEntry = new.dll.NewProc("GetOwnerModuleFromTcpEntry") - // err = new.getOwnerModuleFromTcpEntry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcpEntry: %s", err) - // } - // new.getOwnerModuleFromTcp6Entry = new.dll.NewProc("GetOwnerModuleFromTcp6Entry") - // err = new.getOwnerModuleFromTcp6Entry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromTcp6Entry: %s", err) - // } - // new.getOwnerModuleFromUdpEntry = new.dll.NewProc("GetOwnerModuleFromUdpEntry") - // err = new.getOwnerModuleFromUdpEntry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdpEntry: %s", err) - // } - // new.getOwnerModuleFromUdp6Entry = new.dll.NewProc("GetOwnerModuleFromUdp6Entry") - // err = new.getOwnerModuleFromUdp6Entry.Find() - // if err != nil { - // return nil, fmt.Errorf("could find proc GetOwnerModuleFromUdp6Entry: %s", err) - // } - - new.valid.Set() - return new, nil -} diff --git a/process/proc/gather.go b/process/proc/gather.go deleted file mode 100644 index 1413b3c9..00000000 --- a/process/proc/gather.go +++ /dev/null @@ -1,83 +0,0 @@ -// +build linux - -package proc - -import ( - "net" - "time" -) - -// PID querying return codes -const ( - Success uint8 = iota - NoSocket - NoProcess -) - -var ( - waitTime = 15 * time.Millisecond -) - -// GetPidOfConnection returns the PID of the given connection. -func GetPidOfConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { - uid, inode, ok := getConnectionSocket(localIP, localPort, protocol) - if !ok { - uid, inode, ok = getListeningSocket(localIP, localPort, protocol) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - uid, inode, ok = getConnectionSocket(localIP, localPort, protocol) - if !ok { - uid, inode, ok = getListeningSocket(localIP, localPort, protocol) - } - } - if !ok { - return unidentifiedProcessID, NoSocket - } - } - - pid, ok = GetPidOfInode(uid, inode) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - pid, ok = GetPidOfInode(uid, inode) - } - if !ok { - return unidentifiedProcessID, NoProcess - } - - return -} - -// GetPidOfIncomingConnection returns the PID of the given incoming connection. -func GetPidOfIncomingConnection(localIP net.IP, localPort uint16, protocol uint8) (pid int, status uint8) { - uid, inode, ok := getListeningSocket(localIP, localPort, protocol) - if !ok { - // for TCP4 and UDP4, also try TCP6 and UDP6, as linux sometimes treats them as a single dual socket, and shows the IPv6 version. - switch protocol { - case TCP4: - uid, inode, ok = getListeningSocket(localIP, localPort, TCP6) - case UDP4: - uid, inode, ok = getListeningSocket(localIP, localPort, UDP6) - } - - if !ok { - return unidentifiedProcessID, NoSocket - } - } - - pid, ok = GetPidOfInode(uid, inode) - for i := 0; i < 3 && !ok; i++ { - // give kernel some time, then try again - // log.Tracef("process: giving kernel some time to think") - time.Sleep(waitTime) - pid, ok = GetPidOfInode(uid, inode) - } - if !ok { - return unidentifiedProcessID, NoProcess - } - - return -} diff --git a/process/proc/get.go b/process/proc/get.go deleted file mode 100644 index 52974b3e..00000000 --- a/process/proc/get.go +++ /dev/null @@ -1,66 +0,0 @@ -// +build linux - -package proc - -import ( - "errors" - "net" -) - -const ( - unidentifiedProcessID = -1 -) - -// GetTCP4PacketInfo searches the network state tables for a TCP4 connection -func GetTCP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP4, localIP, localPort, pktDirection) -} - -// GetTCP6PacketInfo searches the network state tables for a TCP6 connection -func GetTCP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(TCP6, localIP, localPort, pktDirection) -} - -// GetUDP4PacketInfo searches the network state tables for a UDP4 connection -func GetUDP4PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP4, localIP, localPort, pktDirection) -} - -// GetUDP6PacketInfo searches the network state tables for a UDP6 connection -func GetUDP6PacketInfo(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, pktDirection bool) (pid int, direction bool, err error) { - return search(UDP6, localIP, localPort, pktDirection) -} - -func search(protocol uint8, localIP net.IP, localPort uint16, pktDirection bool) (pid int, direction bool, err error) { - - var status uint8 - if pktDirection { - pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) - if pid >= 0 { - return pid, true, nil - } - // pid, status = GetPidOfConnection(localIP, localPort, protocol) - // if pid >= 0 { - // return pid, false, nil - // } - } else { - pid, status = GetPidOfConnection(localIP, localPort, protocol) - if pid >= 0 { - return pid, false, nil - } - // pid, status = GetPidOfIncomingConnection(localIP, localPort, protocol) - // if pid >= 0 { - // return pid, true, nil - // } - } - - switch status { - case NoSocket: - return unidentifiedProcessID, direction, errors.New("could not find socket") - case NoProcess: - return unidentifiedProcessID, direction, errors.New("could not find PID") - default: - return unidentifiedProcessID, direction, nil - } - -} diff --git a/process/proc/processfinder_test.go b/process/proc/processfinder_test.go deleted file mode 100644 index 16d3d181..00000000 --- a/process/proc/processfinder_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build linux - -package proc - -import ( - "log" - "testing" -) - -func TestProcessFinder(t *testing.T) { - - updatePids() - log.Printf("pidsByUser: %v", pidsByUser) - - pid, _ := GetPidOfInode(1000, 112588) - log.Printf("pid: %d", pid) - -} diff --git a/process/proc/sockets.go b/process/proc/sockets.go deleted file mode 100644 index bcdd91d4..00000000 --- a/process/proc/sockets.go +++ /dev/null @@ -1,370 +0,0 @@ -// +build linux - -package proc - -import ( - "bufio" - "encoding/hex" - "fmt" - "net" - "os" - "strconv" - "strings" - "sync" - "unicode" - - "github.com/safing/portbase/log" -) - -/* - -1. find socket inode - - by incoming (listenting sockets) or outgoing (local port + external IP + port) - also local IP? - - /proc/net/{tcp|udp}[6] - -2. get list of processes of uid - -3. find socket inode in process fds - - if not found, refresh map of uid->pids - - if not found, check ALL pids: maybe euid != uid - -4. gather process info - -Cache every step! - -*/ - -// Network Related Constants -const ( - TCP4 uint8 = iota - UDP4 - TCP6 - UDP6 - ICMP4 - ICMP6 - - TCP4Data = "/proc/net/tcp" - UDP4Data = "/proc/net/udp" - TCP6Data = "/proc/net/tcp6" - UDP6Data = "/proc/net/udp6" - ICMP4Data = "/proc/net/icmp" - ICMP6Data = "/proc/net/icmp6" -) - -var ( - // connectionSocketsLock sync.Mutex - // connectionTCP4 = make(map[string][]int) - // connectionUDP4 = make(map[string][]int) - // connectionTCP6 = make(map[string][]int) - // connectionUDP6 = make(map[string][]int) - - listeningSocketsLock sync.Mutex - addressListeningTCP4 = make(map[string][]int) - addressListeningUDP4 = make(map[string][]int) - addressListeningTCP6 = make(map[string][]int) - addressListeningUDP6 = make(map[string][]int) - globalListeningTCP4 = make(map[uint16][]int) - globalListeningUDP4 = make(map[uint16][]int) - globalListeningTCP6 = make(map[uint16][]int) - globalListeningUDP6 = make(map[uint16][]int) -) - -func getConnectionSocket(localIP net.IP, localPort uint16, protocol uint8) (int, int, bool) { - // listeningSocketsLock.Lock() - // defer listeningSocketsLock.Unlock() - - var procFile string - var localIPHex string - switch protocol { - case TCP4: - procFile = TCP4Data - localIPBytes := []byte(localIP.To4()) - localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) - case UDP4: - procFile = UDP4Data - localIPBytes := []byte(localIP.To4()) - localIPHex = strings.ToUpper(hex.EncodeToString([]byte{localIPBytes[3], localIPBytes[2], localIPBytes[1], localIPBytes[0]})) - case TCP6: - procFile = TCP6Data - localIPHex = hex.EncodeToString([]byte(localIP)) - case UDP6: - procFile = UDP6Data - localIPHex = hex.EncodeToString([]byte(localIP)) - } - - localPortHex := fmt.Sprintf("%04X", localPort) - - // log.Tracef("process/proc: searching for PID of: %s:%d (%s:%s)", localIP, localPort, localIPHex, localPortHex) - - // open file - socketData, err := os.Open(procFile) - if err != nil { - log.Warningf("process/proc: could not read %s: %s", procFile, err) - return unidentifiedProcessID, unidentifiedProcessID, false - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - // log.Tracef("line: %s", line) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - - if line[1] != localIPHex { - continue - } - if line[2] != localPortHex { - continue - } - - ok := true - - uid, err := strconv.ParseInt(line[11], 10, 32) - if err != nil { - log.Warningf("process: could not parse uid %s: %s", line[11], err) - uid = -1 - ok = false - } - - inode, err := strconv.ParseInt(line[13], 10, 32) - if err != nil { - log.Warningf("process: could not parse inode %s: %s", line[13], err) - inode = -1 - ok = false - } - - // log.Tracef("process/proc: identified process of %s:%d: socket=%d uid=%d", localIP, localPort, int(inode), int(uid)) - return int(uid), int(inode), ok - - } - - return unidentifiedProcessID, unidentifiedProcessID, false - -} - -func getListeningSocket(localIP net.IP, localPort uint16, protocol uint8) (uid, inode int, ok bool) { - listeningSocketsLock.Lock() - defer listeningSocketsLock.Unlock() - - var addressListening map[string][]int - var globalListening map[uint16][]int - switch protocol { - case TCP4: - addressListening = addressListeningTCP4 - globalListening = globalListeningTCP4 - case UDP4: - addressListening = addressListeningUDP4 - globalListening = globalListeningUDP4 - case TCP6: - addressListening = addressListeningTCP6 - globalListening = globalListeningTCP6 - case UDP6: - addressListening = addressListeningUDP6 - globalListening = globalListeningUDP6 - } - - data, ok := addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] - if !ok { - data, ok = globalListening[localPort] - } - if ok { - return data[0], data[1], true - } - updateListeners(protocol) - data, ok = addressListening[fmt.Sprintf("%s:%d", localIP, localPort)] - if !ok { - data, ok = globalListening[localPort] - } - if ok { - return data[0], data[1], true - } - - return unidentifiedProcessID, unidentifiedProcessID, false -} - -func procDelimiter(c rune) bool { - return unicode.IsSpace(c) || c == ':' -} - -func convertIPv4(data string) net.IP { - decoded, err := hex.DecodeString(data) - if err != nil { - log.Warningf("process: could not parse IPv4 %s: %s", data, err) - return nil - } - if len(decoded) != 4 { - log.Warningf("process: decoded IPv4 %s has wrong length", decoded) - return nil - } - ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) - return ip -} - -func convertIPv6(data string) net.IP { - decoded, err := hex.DecodeString(data) - if err != nil { - log.Warningf("process: could not parse IPv6 %s: %s", data, err) - return nil - } - if len(decoded) != 16 { - log.Warningf("process: decoded IPv6 %s has wrong length", decoded) - return nil - } - ip := net.IP(decoded) - return ip -} - -func updateListeners(protocol uint8) { - switch protocol { - case TCP4: - addressListeningTCP4, globalListeningTCP4 = getListenerMaps(TCP4Data, "00000000", "0A", convertIPv4) - case UDP4: - addressListeningUDP4, globalListeningUDP4 = getListenerMaps(UDP4Data, "00000000", "07", convertIPv4) - case TCP6: - addressListeningTCP6, globalListeningTCP6 = getListenerMaps(TCP6Data, "00000000000000000000000000000000", "0A", convertIPv6) - case UDP6: - addressListeningUDP6, globalListeningUDP6 = getListenerMaps(UDP6Data, "00000000000000000000000000000000", "07", convertIPv6) - } -} - -func getListenerMaps(procFile, zeroIP, socketStatusListening string, ipConverter func(string) net.IP) (map[string][]int, map[uint16][]int) { - addressListening := make(map[string][]int) - globalListening := make(map[uint16][]int) - - // open file - socketData, err := os.Open(procFile) - if err != nil { - log.Warningf("process: could not read %s: %s", procFile, err) - return addressListening, globalListening - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - if line[5] != socketStatusListening { - // skip if not listening - // log.Tracef("process: not listening %s: %s", line, line[5]) - continue - } - - port, err := strconv.ParseUint(line[2], 16, 16) - // log.Tracef("port: %s", line[2]) - if err != nil { - log.Warningf("process: could not parse port %s: %s", line[2], err) - continue - } - - uid, err := strconv.ParseInt(line[11], 10, 32) - // log.Tracef("uid: %s", line[11]) - if err != nil { - log.Warningf("process: could not parse uid %s: %s", line[11], err) - continue - } - - inode, err := strconv.ParseInt(line[13], 10, 32) - // log.Tracef("inode: %s", line[13]) - if err != nil { - log.Warningf("process: could not parse inode %s: %s", line[13], err) - continue - } - - if line[1] == zeroIP { - globalListening[uint16(port)] = []int{int(uid), int(inode)} - } else { - address := ipConverter(line[1]) - if address != nil { - addressListening[fmt.Sprintf("%s:%d", address, port)] = []int{int(uid), int(inode)} - } - } - - } - - return addressListening, globalListening -} - -// GetActiveConnectionIDs returns all connection IDs that are still marked as active by the OS. -func GetActiveConnectionIDs() []string { - var connections []string - - connections = append(connections, getConnectionIDsFromSource(TCP4Data, 6, convertIPv4)...) - connections = append(connections, getConnectionIDsFromSource(UDP4Data, 17, convertIPv4)...) - connections = append(connections, getConnectionIDsFromSource(TCP6Data, 6, convertIPv6)...) - connections = append(connections, getConnectionIDsFromSource(UDP6Data, 17, convertIPv6)...) - - return connections -} - -func getConnectionIDsFromSource(source string, protocol uint16, ipConverter func(string) net.IP) []string { - var connections []string - - // open file - socketData, err := os.Open(source) - if err != nil { - log.Warningf("process: could not read %s: %s", source, err) - return connections - } - defer socketData.Close() - - // file scanner - scanner := bufio.NewScanner(socketData) - scanner.Split(bufio.ScanLines) - - // parse - scanner.Scan() // skip first line - for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) - continue - } - - // skip listeners and closed connections - if line[5] == "0A" || line[5] == "07" { - continue - } - - localIP := ipConverter(line[1]) - if localIP == nil { - continue - } - - localPort, err := strconv.ParseUint(line[2], 16, 16) - if err != nil { - log.Warningf("process: could not parse port: %s", err) - continue - } - - remoteIP := ipConverter(line[3]) - if remoteIP == nil { - continue - } - - remotePort, err := strconv.ParseUint(line[4], 16, 16) - if err != nil { - log.Warningf("process: could not parse port: %s", err) - continue - } - - connections = append(connections, fmt.Sprintf("%d-%s-%d-%s-%d", protocol, localIP, localPort, remoteIP, remotePort)) - } - - return connections -} diff --git a/process/proc/sockets_test.go b/process/proc/sockets_test.go deleted file mode 100644 index 44e8fd34..00000000 --- a/process/proc/sockets_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// +build linux - -package proc - -import ( - "net" - "testing" -) - -func TestSockets(t *testing.T) { - - updateListeners(TCP4) - updateListeners(UDP4) - updateListeners(TCP6) - updateListeners(UDP6) - t.Logf("addressListeningTCP4: %v", addressListeningTCP4) - t.Logf("globalListeningTCP4: %v", globalListeningTCP4) - t.Logf("addressListeningUDP4: %v", addressListeningUDP4) - t.Logf("globalListeningUDP4: %v", globalListeningUDP4) - t.Logf("addressListeningTCP6: %v", addressListeningTCP6) - t.Logf("globalListeningTCP6: %v", globalListeningTCP6) - t.Logf("addressListeningUDP6: %v", addressListeningUDP6) - t.Logf("globalListeningUDP6: %v", globalListeningUDP6) - - getListeningSocket(net.IPv4zero, 53, TCP4) - getListeningSocket(net.IPv4zero, 53, UDP4) - getListeningSocket(net.IPv6zero, 53, TCP6) - getListeningSocket(net.IPv6zero, 53, UDP6) - - // spotify: 192.168.0.102:5353 192.121.140.65:80 - localIP := net.IPv4(192, 168, 127, 10) - uid, inode, ok := getConnectionSocket(localIP, 46634, TCP4) - t.Logf("getConnectionSocket: %d %d %v", uid, inode, ok) - - activeConnectionIDs := GetActiveConnectionIDs() - for _, connID := range activeConnectionIDs { - t.Logf("active: %s", connID) - } - -} diff --git a/updates/main.go b/updates/main.go index ea15e075..84aa132e 100644 --- a/updates/main.go +++ b/updates/main.go @@ -152,7 +152,7 @@ func start() error { err = registry.LoadIndexes() if err != nil { - return err + log.Warningf("updates: failed to load indexes: %s", err) } err = registry.ScanStorage("") @@ -235,8 +235,7 @@ func checkForUpdates(ctx context.Context) (err error) { }() if err = registry.UpdateIndexes(); err != nil { - err = fmt.Errorf("failed to update indexes: %w", err) - return + log.Warningf("updates: failed to update indexes: %s", err) } err = registry.DownloadUpdates(ctx) diff --git a/updates/upgrader.go b/updates/upgrader.go index 1d754f7a..b109c913 100644 --- a/updates/upgrader.go +++ b/updates/upgrader.go @@ -113,15 +113,9 @@ func upgradePortmasterControl() error { return nil } - // check if registry tmp dir is ok - err := registry.TmpDir().Ensure() - if err != nil { - return fmt.Errorf("failed to prep updates tmp dir: %s", err) - } - // update portmaster-control in data root rootControlPath := filepath.Join(filepath.Dir(registry.StorageDir().Path), filename) - err = upgradeFile(rootControlPath, pmCtrlUpdate) + err := upgradeFile(rootControlPath, pmCtrlUpdate) if err != nil { return err } @@ -130,11 +124,11 @@ func upgradePortmasterControl() error { // upgrade parent process, if it's portmaster-control parent, err := processInfo.NewProcess(int32(os.Getppid())) if err != nil { - return fmt.Errorf("could not get parent process for upgrade checks: %s", err) + return fmt.Errorf("could not get parent process for upgrade checks: %w", err) } parentName, err := parent.Name() if err != nil { - return fmt.Errorf("could not get parent process name for upgrade checks: %s", err) + return fmt.Errorf("could not get parent process name for upgrade checks: %w", err) } if parentName != filename { log.Tracef("updates: parent process does not seem to be portmaster-control, name is %s", parentName) @@ -142,7 +136,7 @@ func upgradePortmasterControl() error { } parentPath, err := parent.Exe() if err != nil { - return fmt.Errorf("could not get parent process path for upgrade: %s", err) + return fmt.Errorf("could not get parent process path for upgrade: %w", err) } err = upgradeFile(parentPath, pmCtrlUpdate) if err != nil { @@ -190,7 +184,7 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { // ensure tmp dir is here err = registry.TmpDir().Ensure() if err != nil { - return fmt.Errorf("unable to check updates tmp dir for moving file that needs upgrade: %s", err) + return fmt.Errorf("could not prepare tmp directory for moving file that needs upgrade: %w", err) } // maybe we're on windows and it's in use, try moving @@ -204,17 +198,17 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { ), )) if err != nil { - return fmt.Errorf("unable to move file that needs upgrade: %s", err) + return fmt.Errorf("unable to move file that needs upgrade: %w", err) } } } // copy upgrade - err = copyFile(file.Path(), fileToUpgrade) + err = CopyFile(file.Path(), fileToUpgrade) if err != nil { // try again time.Sleep(1 * time.Second) - err = copyFile(file.Path(), fileToUpgrade) + err = CopyFile(file.Path(), fileToUpgrade) if err != nil { return err } @@ -224,23 +218,30 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { if !onWindows { info, err := os.Stat(fileToUpgrade) if err != nil { - return fmt.Errorf("failed to get file info on %s: %s", fileToUpgrade, err) + return fmt.Errorf("failed to get file info on %s: %w", fileToUpgrade, err) } if info.Mode() != 0755 { err := os.Chmod(fileToUpgrade, 0755) if err != nil { - return fmt.Errorf("failed to set permissions on %s: %s", fileToUpgrade, err) + return fmt.Errorf("failed to set permissions on %s: %w", fileToUpgrade, err) } } } return nil } -func copyFile(srcPath, dstPath string) (err error) { +func CopyFile(srcPath, dstPath string) (err error) { + + // check tmp dir + err = registry.TmpDir().Ensure() + if err != nil { + return fmt.Errorf("could not prepare tmp directory for copying file: %w", err) + } + // open file for writing atomicDstFile, err := renameio.TempFile(registry.TmpDir().Path, dstPath) if err != nil { - return fmt.Errorf("could not create temp file for atomic copy: %s", err) + return fmt.Errorf("could not create temp file for atomic copy: %w", err) } defer atomicDstFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway @@ -260,7 +261,7 @@ func copyFile(srcPath, dstPath string) (err error) { // finalize file err = atomicDstFile.CloseAtomicallyReplace() if err != nil { - return fmt.Errorf("updates: failed to finalize copy to file %s: %s", dstPath, err) + return fmt.Errorf("updates: failed to finalize copy to file %s: %w", dstPath, err) } return nil From cb991e9f021093ec9cae0cc5595e8503f7706eee Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 17:16:08 +0200 Subject: [PATCH 02/36] Fix and improve IP address conversion on windows --- firewall/interception/windowskext/handler.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 99ad9b59..254c046b 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -114,18 +114,15 @@ func Handler(packets chan packet.Packet) { } func convertIPv4(input [4]uint32) net.IP { - return net.IPv4( - uint8(input[0]>>24&0xFF), - uint8(input[0]>>16&0xFF), - uint8(input[0]>>8&0xFF), - uint8(input[0]&0xFF), - ) + addressBuf := make([]byte, 4) + binary.BigEndian.PutUint32(addressBuf, input[0]) + return net.IP(addressBuf) } func convertIPv6(input [4]uint32) net.IP { addressBuf := make([]byte, 16) for i := 0; i < 4; i++ { - binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) + binary.BigEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) } return net.IP(addressBuf) } From 6e9c22d0b54685d31574f18e2c8996b5062101e2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 17:17:17 +0200 Subject: [PATCH 03/36] Stop whitelisting IGMP --- firewall/interception.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/firewall/interception.go b/firewall/interception.go index f99c94c0..a916776d 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -138,7 +138,7 @@ func handlePacket(pkt packet.Packet) { // pkt.RedirToNameserver() // } - // allow ICMP, IGMP and DHCP + // allow ICMP and DHCP // TODO: actually handle these switch meta.Protocol { case packet.ICMP: @@ -149,10 +149,6 @@ func handlePacket(pkt packet.Packet) { log.Debugf("accepting ICMPv6: %s", pkt) _ = pkt.PermanentAccept() return - case packet.IGMP: - log.Debugf("accepting IGMP: %s", pkt) - _ = pkt.PermanentAccept() - return case packet.UDP: if meta.DstPort == 67 || meta.DstPort == 68 { log.Debugf("accepting DHCP: %s", pkt) From c3ca0c4c84f766306d33ce6940af80fb47900490 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 17:17:48 +0200 Subject: [PATCH 04/36] Set dns redirects to Internal for better UX --- firewall/interception.go | 1 + 1 file changed, 1 insertion(+) diff --git a/firewall/interception.go b/firewall/interception.go index a916776d..0e3fc9d1 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -214,6 +214,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { // reroute dns requests to nameserver if conn.Process().Pid != os.Getpid() && pkt.IsOutbound() && pkt.Info().DstPort == 53 && !pkt.Info().Src.Equal(pkt.Info().Dst) { conn.Verdict = network.VerdictRerouteToNameserver + conn.Internal = true conn.StopFirewallHandler() issueVerdict(conn, pkt, 0, true) return From 635d5770d12797fca093c47e4e934893cd9be8aa Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 17:18:42 +0200 Subject: [PATCH 05/36] Change BlockInbound to only affect LAN and Internet --- firewall/master.go | 2 +- profile/config.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firewall/master.go b/firewall/master.go index 69020dad..d7b926c8 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -157,7 +157,7 @@ func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { // check conn type switch conn.Scope { - case network.IncomingHost, network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: + case network.IncomingLAN, network.IncomingInternet, network.IncomingInvalid: if p.BlockInbound() { if conn.Scope == network.IncomingHost { conn.Block("inbound connections blocked") diff --git a/profile/config.go b/profile/config.go index 4464a513..ad1434fd 100644 --- a/profile/config.go +++ b/profile/config.go @@ -326,7 +326,7 @@ Examples: err = config.Register(&config.Option{ Name: "Block Inbound Connections", Key: CfgOptionBlockInboundKey, - Description: "Connections initiated towards your device. This will usually only be the case if you are running a network service or are using peer to peer software.", + Description: "Connections initiated towards your device from the LAN or Internet. This will usually only be the case if you are running a network service or are using peer to peer software.", Order: cfgOptionBlockInboundOrder, OptType: config.OptTypeInt, ExternalOptType: "security level", From 87a55541b2b464c8459fceffc8dfb57d47fc706d Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:40:05 +0200 Subject: [PATCH 06/36] Add endpoint type network scope Also, update default service endpoint list configuration to allow localhost --- profile/config.go | 56 +++++--------- profile/endpoints/endpoint-scopes.go | 112 +++++++++++++++++++++++++++ profile/endpoints/endpoint.go | 4 + profile/endpoints/endpoints_test.go | 19 +++++ 4 files changed, 154 insertions(+), 37 deletions(-) create mode 100644 profile/endpoints/endpoint-scopes.go diff --git a/profile/config.go b/profile/config.go index ad1434fd..60a63070 100644 --- a/profile/config.go +++ b/profile/config.go @@ -121,17 +121,12 @@ func registerConfiguration() error { cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll)) cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit - // Endpoint Filter List - err = config.Register(&config.Option{ - Name: "Endpoint Filter List", - Key: CfgOptionEndpointsKey, - Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.", - Help: `Format: + filterListHelp := `Format: Permission: "+": permit "-": block Host Matching: - IP, CIDR, Country Code, ASN, Filterlist, "*" for any + IP, CIDR, Country Code, ASN, Filterlist, Network Scope, "*" for any Domains: "example.com": exact match ".example.com": exact match + subdomains @@ -144,11 +139,20 @@ func registerConfiguration() error { Examples: + .example.com */HTTP - .example.com - + 192.168.0.1/24 + + 192.168.0.1 + + 192.168.1.1/24 + + Localhost,LAN + - AS123456789 - L:MAL - - AS0 + AT - - *`, + - *` + + // Endpoint Filter List + err = config.Register(&config.Option{ + Name: "Endpoint Filter List", + Key: CfgOptionEndpointsKey, + Description: "Filter outgoing connections by matching the destination endpoint. Network Scope restrictions still apply.", + Help: filterListHelp, Order: cfgOptionEndpointsOrder, OptType: config.OptTypeStringArray, DefaultValue: []string{}, @@ -163,35 +167,13 @@ Examples: // Service Endpoint Filter List err = config.Register(&config.Option{ - Name: "Service Endpoint Filter List", - Key: CfgOptionServiceEndpointsKey, - Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.", - Help: `Format: - Permission: - "+": permit - "-": block - Host Matching: - IP, CIDR, Country Code, ASN, Filterlist, "*" for any - Domains: - "example.com": exact match - ".example.com": exact match + subdomains - "*xample.com": prefix wildcard - "example.*": suffix wildcard - "*example*": prefix and suffix wildcard - Protocol and Port Matching (optional): - / - -Examples: - + .example.com */HTTP - - .example.com - + 192.168.0.1/24 - - L:MAL - - AS0 - + AT - - *`, + Name: "Service Endpoint Filter List", + Key: CfgOptionServiceEndpointsKey, + Description: "Filter incoming connections by matching the source endpoint. Network Scope restrictions and the inbound permission still apply. Also not that the implicit default action of this list is to always block.", + Help: filterListHelp, Order: cfgOptionServiceEndpointsOrder, OptType: config.OptTypeStringArray, - DefaultValue: []string{}, + DefaultValue: []string{"+ Localhost"}, ExternalOptType: "endpoint list", ValidationRegex: `^(\+|\-) [A-z0-9\.:\-*/]+( [A-z0-9/]+)?$`, }) diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go new file mode 100644 index 00000000..1c73aebe --- /dev/null +++ b/profile/endpoints/endpoint-scopes.go @@ -0,0 +1,112 @@ +package endpoints + +import ( + "strings" + + "github.com/safing/portmaster/network/netutils" + + "github.com/safing/portmaster/intel" +) + +const ( + scopeLocalhost = 1 + scopeLocalhostName = "Localhost" + scopeLocalhostMatcher = "localhost" + + scopeLAN = 2 + scopeLANName = "LAN" + scopeLANMatcher = "lan" + + scopeInternet = 4 + scopeInternetName = "Internet" + scopeInternetMatcher = "internet" +) + +// EndpointScope matches network scopes. +type EndpointScope struct { + EndpointBase + + scopes uint8 +} + +// Localhost +// LAN +// Internet + +// Matches checks whether the given entity matches this endpoint definition. +func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { + if entity.IP == nil { + return Undeterminable, nil + } + + classification := netutils.ClassifyIP(entity.IP) + var scope uint8 + switch classification { + case netutils.HostLocal: + scope = scopeLocalhost + case netutils.LinkLocal: + scope = scopeLAN + case netutils.SiteLocal: + scope = scopeLAN + case netutils.Global: + scope = scopeInternet + case netutils.LocalMulticast: + scope = scopeLAN + case netutils.GlobalMulticast: + scope = scopeInternet + } + + if ep.scopes&scope > 0 { + return ep.match(ep, entity, ep.Scopes(), "scope matches") + } + return NoMatch, nil +} + +// Scopes returns the string representation of all scopes. +func (ep *EndpointScope) Scopes() string { + if ep.scopes == 3 || ep.scopes > 4 { + // single scope + switch ep.scopes { + case scopeLocalhost: + return scopeLocalhostName + case scopeLAN: + return scopeLANName + case scopeInternet: + return scopeInternetName + } + } + + // multiple scopes + var s []string + if ep.scopes&scopeLocalhost > 0 { + s = append(s, scopeLocalhostName) + } + if ep.scopes&scopeLAN > 0 { + s = append(s, scopeLANName) + } + if ep.scopes&scopeInternet > 0 { + s = append(s, scopeInternetName) + } + return strings.Join(s, ",") +} + +func (ep *EndpointScope) String() string { + return ep.renderPPP(ep.Scopes()) +} + +func parseTypeScope(fields []string) (Endpoint, error) { + ep := &EndpointScope{} + for _, val := range strings.Split(strings.ToLower(fields[1]), ",") { + switch val { + case scopeLocalhostMatcher: + ep.scopes &= scopeLocalhost + case scopeLANMatcher: + ep.scopes &= scopeLAN + case scopeInternetMatcher: + ep.scopes &= scopeInternet + default: + return nil, nil + } + } + return ep.parsePPP(ep, fields) +} diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 76847ac7..4e73d1d4 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -231,6 +231,10 @@ func parseEndpoint(value string) (endpoint Endpoint, err error) { if endpoint, err = parseTypeASN(fields); endpoint != nil || err != nil { return } + // scopes + if endpoint, err = parseTypeScope(fields); endpoint != nil || err != nil { + return + } // lists if endpoint, err = parseTypeList(fields); endpoint != nil || err != nil { return diff --git a/profile/endpoints/endpoints_test.go b/profile/endpoints/endpoints_test.go index 0eb4e2e1..ad23d352 100644 --- a/profile/endpoints/endpoints_test.go +++ b/profile/endpoints/endpoints_test.go @@ -342,7 +342,26 @@ func TestEndpointMatching(t *testing.T) { IP: net.ParseIP("151.101.1.164"), // nytimes.com }).Init(), NoMatch) + // Scope + + ep, err = parseEndpoint("+ Localhost,LAN") + if err != nil { + t.Fatal(err) + } + + testEndpointMatch(t, ep, (&intel.Entity{ + IP: net.ParseIP("192.168.0.1"), + }).Init(), Permitted) + testEndpointMatch(t, ep, (&intel.Entity{ + IP: net.ParseIP("151.101.1.164"), // nytimes.com + }).Init(), NoMatch) + // Lists + + ep, err = parseEndpoint("+ L:A,B,C") + if err != nil { + t.Fatal(err) + } // TODO: write test for lists matcher } From dd837e40e2713c1b0d587cb537565326211082b3 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:41:19 +0200 Subject: [PATCH 07/36] Create exec dir for safe working dir for processes --- ui/module.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ui/module.go b/ui/module.go index 76a130ef..43bf667a 100644 --- a/ui/module.go +++ b/ui/module.go @@ -3,6 +3,8 @@ package ui import ( "context" + "github.com/safing/portbase/dataroot" + resources "github.com/cookieo9/resources-go" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" @@ -27,6 +29,11 @@ func prep() error { } func start() error { + err := dataroot.Root().ChildDir("exec", 0777).Ensure() + if err != nil { + log.Warningf("ui: failed to create safe exec dir: %s", err) + } + return module.RegisterEventHook("ui", eventReload, "reload assets", reloadUI) } From 53eb309e72fa1220f94a9801749b6a4f0b2ddc96 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:43:06 +0200 Subject: [PATCH 08/36] Add TLS resolver connection reusing and pooling Also, fix caching issues and add more tests --- ...{block_detection.go => block-detection.go} | 0 resolver/clients.go | 176 ++++++++++++++--- resolver/mdns.go | 29 ++- resolver/namerecord.go | 2 +- resolver/namerecord_test.go | 27 +++ resolver/pooling_test.go | 184 ++++++++++++++++++ resolver/resolve.go | 6 +- resolver/resolver.go | 53 ++++- resolver/resolvers.go | 49 +++-- resolver/reverse.go | 4 +- resolver/rrcache_test.go | 41 ++++ 11 files changed, 510 insertions(+), 61 deletions(-) rename resolver/{block_detection.go => block-detection.go} (100%) create mode 100644 resolver/namerecord_test.go create mode 100644 resolver/pooling_test.go create mode 100644 resolver/rrcache_test.go diff --git a/resolver/block_detection.go b/resolver/block-detection.go similarity index 100% rename from resolver/block_detection.go rename to resolver/block-detection.go diff --git a/resolver/clients.go b/resolver/clients.go index 6d1ad4b2..096f2af3 100644 --- a/resolver/clients.go +++ b/resolver/clients.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "crypto/tls" "net" "sync" @@ -9,6 +10,12 @@ import ( "github.com/miekg/dns" ) +const ( + defaultClientTTL = 5 * time.Minute + defaultRequestTimeout = 5 * time.Second + connectionEOLGracePeriod = 10 * time.Second +) + var ( localAddrFactory func(network string) net.Addr ) @@ -27,21 +34,72 @@ func getLocalAddr(network string) net.Addr { return nil } -type clientManager struct { - dnsClient *dns.Client - factory func() *dns.Client +type dnsClientManager struct { + lock sync.Mutex - lock sync.Mutex - refreshAfter time.Time - ttl time.Duration // force refresh of connection to reduce traceability + // set by creator + serverAddress string + ttl time.Duration // force refresh of connection to reduce traceability + factory func() *dns.Client + + // internal + pool []*dnsClient } -func newDNSClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // new client for every request, as we need to randomize the port +type dnsClient struct { + mgr *dnsClientManager + + inUse bool + useUntil time.Time + dead bool + inPool bool + poolIndex int + + client *dns.Client + conn *dns.Conn +} + +// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). +func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { + if dc.conn == nil { + dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) + if err != nil { + return nil, false, err + } + return dc.conn, true, nil + } + return dc.conn, false, nil +} + +func (dc *dnsClient) done() { + dc.mgr.lock.Lock() + defer dc.mgr.lock.Unlock() + + dc.inUse = false +} + +func (dc *dnsClient) destroy() { + dc.mgr.lock.Lock() + dc.inUse = true // block from being used + dc.dead = true // abort cleaning + if dc.inPool { + dc.inPool = false + dc.mgr.pool[dc.poolIndex] = nil + } + dc.mgr.lock.Unlock() + + if dc.conn != nil { + _ = dc.conn.Close() + } +} + +func newDNSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: 0, // new client for every request, as we need to randomize the port factory: func() *dns.Client { return &dns.Client{ - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("udp"), }, @@ -50,25 +108,27 @@ func newDNSClientManager(_ *Resolver) *clientManager { } } -func newTCPClientManager(_ *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTCPClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp", - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + KeepAlive: defaultClientTTL, }, } }, } } -func newTLSClientManager(resolver *Resolver) *clientManager { - return &clientManager{ - ttl: 0, // TODO: build a custom client that can reuse connections to some degree (performance / privacy tradeoff) +func newTLSClientManager(resolver *Resolver) *dnsClientManager { + return &dnsClientManager{ + serverAddress: resolver.ServerAddress, + ttl: defaultClientTTL, factory: func() *dns.Client { return &dns.Client{ Net: "tcp-tls", @@ -77,24 +137,90 @@ func newTLSClientManager(resolver *Resolver) *clientManager { ServerName: resolver.VerifyDomain, // TODO: use portbase rng }, - Timeout: 5 * time.Second, + Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), - KeepAlive: 15 * time.Second, + KeepAlive: defaultClientTTL, }, } }, } } -func (cm *clientManager) getDNSClient() *dns.Client { +func (cm *dnsClientManager) getDNSClient() *dnsClient { cm.lock.Lock() defer cm.lock.Unlock() - if cm.dnsClient == nil || cm.ttl == 0 || time.Now().After(cm.refreshAfter) { - cm.dnsClient = cm.factory() - cm.refreshAfter = time.Now().Add(cm.ttl) + // return new immediately if a new client should be used for every request + if cm.ttl == 0 { + return &dnsClient{ + mgr: cm, + client: cm.factory(), + } } - return cm.dnsClient + // get first unused from pool + now := time.Now().UTC() + for _, dc := range cm.pool { + if dc != nil && !dc.inUse && now.Before(dc.useUntil) { + dc.inUse = true + return dc + } + } + + // no available in pool, create new + newClient := &dnsClient{ + mgr: cm, + inUse: true, + useUntil: now.Add(cm.ttl), + inPool: true, + client: cm.factory(), + } + newClient.startCleaner() + + // find free spot in pool + for poolIndex, dc := range cm.pool { + if dc == nil { + cm.pool[poolIndex] = newClient + newClient.poolIndex = poolIndex + return newClient + } + } + + // append to pool + cm.pool = append(cm.pool, newClient) + newClient.poolIndex = len(cm.pool) - 1 + // TODO: shrink pool again? + + return newClient +} + +// startCleaner waits for EOL of the client and then removes it from the pool. +func (dc *dnsClient) startCleaner() { + // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. + module.StartWorker("dns client cleanup", func(ctx context.Context) error { + select { + case <-time.After(dc.mgr.ttl + time.Second): + dc.mgr.lock.Lock() + cleanNow := dc.dead || !dc.inUse + dc.mgr.lock.Unlock() + + if cleanNow { + dc.destroy() + return nil + } + case <-ctx.Done(): + // give a short time before kill for graceful request completion + time.Sleep(100 * time.Millisecond) + } + + // wait for grace period to end, then kill + select { + case <-time.After(connectionEOLGracePeriod): + case <-ctx.Done(): + } + + dc.destroy() + return nil + }) } diff --git a/resolver/mdns.go b/resolver/mdns.go index a8ba1ee5..b8595d67 100644 --- a/resolver/mdns.go +++ b/resolver/mdns.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/safing/portmaster/network/netutils" + "github.com/miekg/dns" "github.com/safing/portbase/log" @@ -29,10 +31,11 @@ var ( questionsLock sync.Mutex mDNSResolver = &Resolver{ - Server: ServerSourceMDNS, - ServerType: ServerTypeDNS, - Source: ServerSourceMDNS, - Conn: &mDNSResolverConn{}, + Server: ServerSourceMDNS, + ServerType: ServerTypeDNS, + ServerIPScope: netutils.SiteLocal, + Source: ServerSourceMDNS, + Conn: &mDNSResolverConn{}, } ) @@ -189,15 +192,21 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { // get entry from database if saveFullRequest { + // get from database rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype)) + // if we have no cached entry, or it has been updated less more than two seconds ago, or if it expired: + // create new and do not append if err != nil || rrCache.updated < time.Now().Add(-2*time.Second).Unix() || rrCache.TTL < time.Now().Unix() { rrCache = &RRCache{ - Domain: question.Name, - Question: dns.Type(question.Qtype), + Domain: question.Name, + Question: dns.Type(question.Qtype), + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } } } + // add all entries to RRCache for _, entry := range message.Answer { if strings.HasSuffix(entry.Header().Name, ".local.") || domainInScope(entry.Header().Name, localReverseScopes) { if saveFullRequest { @@ -289,9 +298,11 @@ func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error { continue } rrCache = &RRCache{ - Domain: v.Header().Name, - Question: dns.Type(v.Header().Class), - Answer: []dns.RR{v}, + Domain: v.Header().Name, + Question: dns.Type(v.Header().Class), + Answer: []dns.RR{v}, + Server: mDNSResolver.Server, + ServerScope: mDNSResolver.ServerIPScope, } rrCache.Clean(60) err := rrCache.Save() diff --git a/resolver/namerecord.go b/resolver/namerecord.go index d94beaaa..1a594e8f 100644 --- a/resolver/namerecord.go +++ b/resolver/namerecord.go @@ -12,7 +12,7 @@ import ( var ( recordDatabase = database.NewInterface(&database.Options{ AlwaysSetRelativateExpiry: 2592000, // 30 days - CacheSize: 128, + CacheSize: 256, }) ) diff --git a/resolver/namerecord_test.go b/resolver/namerecord_test.go new file mode 100644 index 00000000..f0e21a37 --- /dev/null +++ b/resolver/namerecord_test.go @@ -0,0 +1,27 @@ +package resolver + +import "testing" + +func TestNameRecordStorage(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + r, err := GetNameRecord(testDomain, testQuestion) + if err != nil { + t.Fatal(err) + } + + if r.Domain != testDomain || r.Question != testQuestion { + t.Fatal("mismatch") + } +} diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go new file mode 100644 index 00000000..dc341f33 --- /dev/null +++ b/resolver/pooling_test.go @@ -0,0 +1,184 @@ +package resolver + +import ( + "sync" + "testing" + + "github.com/miekg/dns" +) + +var ( + domainFeed = make(chan string) +) + +func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) { + dnsClient := brc.clientManager.getDNSClient() + + // create query + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // get connection + conn, new, err := dnsClient.getConn() + if err != nil { + t.Fatalf("failed to connect: %s", err) //nolint:staticcheck + } + + // query server + reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) + if err != nil { + t.Fatal(err) //nolint:staticcheck + } + if reply == nil { + t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck + } + + t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl) + dnsClient.done() + wg.Done() +} + +func TestClientPooling(t *testing.T) { + // skip if short - this test depends on the Internet and might fail randomly + if testing.Short() { + t.Skip() + } + + go feedDomains() + + // create separate resolver for this test + resolver, _, err := createResolver("dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", "config") + if err != nil { + t.Fatal(err) + } + brc := resolver.Conn.(*BasicResolverConn) + + wg := &sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(10) + for i := 0; i < 10; i++ { + go testQuery(t, wg, brc, &Query{ + FQDN: <-domainFeed, + QType: dns.Type(dns.TypeA), + }) + } + wg.Wait() + if len(brc.clientManager.pool) != 10 { + t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool)) + } + } +} + +func feedDomains() { + for { + for _, domain := range poolingTestDomains { + domainFeed <- domain + } + } +} + +// Data + +var ( + poolingTestDomains = []string{ + "facebook.com.", + "google.com.", + "youtube.com.", + "twitter.com.", + "instagram.com.", + "linkedin.com.", + "microsoft.com.", + "apple.com.", + "wikipedia.org.", + "plus.google.com.", + "en.wikipedia.org.", + "googletagmanager.com.", + "youtu.be.", + "adobe.com.", + "vimeo.com.", + "pinterest.com.", + "itunes.apple.com.", + "play.google.com.", + "maps.google.com.", + "goo.gl.", + "wordpress.com.", + "blogspot.com.", + "bit.ly.", + "github.com.", + "player.vimeo.com.", + "amazon.com.", + "wordpress.org.", + "docs.google.com.", + "yahoo.com.", + "mozilla.org.", + "tumblr.com.", + "godaddy.com.", + "flickr.com.", + "parked-content.godaddy.com.", + "drive.google.com.", + "support.google.com.", + "apache.org.", + "gravatar.com.", + "europa.eu.", + "qq.com.", + "w3.org.", + "nytimes.com.", + "reddit.com.", + "macromedia.com.", + "get.adobe.com.", + "soundcloud.com.", + "sourceforge.net.", + "sites.google.com.", + "nih.gov.", + "amazonaws.com.", + "t.co.", + "support.microsoft.com.", + "forbes.com.", + "theguardian.com.", + "cnn.com.", + "github.io.", + "bbc.co.uk.", + "dropbox.com.", + "whatsapp.com.", + "medium.com.", + "creativecommons.org.", + "www.ncbi.nlm.nih.gov.", + "httpd.apache.org.", + "archive.org.", + "ec.europa.eu.", + "php.net.", + "apps.apple.com.", + "weebly.com.", + "support.apple.com.", + "weibo.com.", + "wixsite.com.", + "issuu.com.", + "who.int.", + "paypal.com.", + "m.facebook.com.", + "oracle.com.", + "msn.com.", + "gnu.org.", + "tinyurl.com.", + "reuters.com.", + "l.facebook.com.", + "cloudflare.com.", + "wsj.com.", + "washingtonpost.com.", + "domainmarket.com.", + "imdb.com.", + "bbc.com.", + "bing.com.", + "accounts.google.com.", + "vk.com.", + "api.whatsapp.com.", + "opera.com.", + "cdc.gov.", + "slideshare.net.", + "wpa.qq.com.", + "harvard.edu.", + "mit.edu.", + "code.google.com.", + "wikimedia.org.", + } +) diff --git a/resolver/resolve.go b/resolver/resolve.go index f13d07c2..57c0c967 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -114,6 +114,7 @@ func Resolve(ctx context.Context, q *Query) (rrCache *RRCache, err error) { rrCache.MixAnswers() return rrCache, nil } + log.Tracer(ctx).Debugf("resolver: waited for another %s%s query, but cache missed!", q.FQDN, q.QType) // if cache is still empty or non-compliant, go ahead and just query } else { // we are the first! @@ -132,14 +133,14 @@ func checkCache(ctx context.Context, q *Query) *RRCache { if err != nil { if err != database.ErrNotFound { log.Tracer(ctx).Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) - log.Warningf("resolver: getting RRCache %s%s from database failed: %s", q.FQDN, q.QType.String(), err) } return nil } // get resolver that rrCache was resolved with - resolver := getResolverByIDWithLocking(rrCache.Server) + resolver := getActiveResolverByIDWithLocking(rrCache.Server) if resolver == nil { + log.Tracer(ctx).Debugf("resolver: ignoring RRCache %s%s because source server %s has been removed", q.FQDN, q.QType.String(), rrCache.Server) return nil } @@ -165,6 +166,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { }) } + log.Tracer(ctx).Tracef("resolver: using cached RR (expires in %s)", time.Until(time.Unix(rrCache.TTL, 0))) return rrCache } diff --git a/resolver/resolver.go b/resolver/resolver.go index 65155fab..244b0c57 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "errors" "net" "sync" "time" @@ -92,7 +93,7 @@ type BasicResolverConn struct { sync.Mutex // for lastFail resolver *Resolver - clientManager *clientManager + clientManager *dnsClientManager lastFail time.Time } @@ -126,18 +127,41 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // start var reply *dns.Msg + var ttl time.Duration var err error - for i := 0; i < 3; i++ { + var conn *dns.Conn + var new bool + var i int - // log query time - // qStart := time.Now() - reply, _, err = brc.clientManager.getDNSClient().Exchange(dnsQuery, resolver.ServerAddress) - // log.Tracef("resolver: query to %s took %s", resolver.Server, time.Now().Sub(qStart)) + for ; i < 5; i++ { + + // first get connection + dc := brc.clientManager.getDNSClient() + conn, new, err = dc.getConn() + if err != nil { + log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err) + // remove client from pool + dc.destroy() + // try again + continue + } + if new { + log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress) + } else { + log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress) + } + + // query server + reply, ttl, err = dc.client.ExchangeWithConn(dnsQuery, conn) + log.Tracer(ctx).Tracef("resolver: query took %s", ttl) // error handling if err != nil { log.Tracer(ctx).Tracef("resolver: query to %s encountered error: %s", resolver.Server, err) + // remove client from pool + dc.destroy() + // TODO: handle special cases // 1. connect: network is unreachable // 2. timeout @@ -148,13 +172,23 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // temporary error if nerr, ok := err.(net.Error); ok && nerr.Timeout() { log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) + // try again continue } // permanent error break + } else if reply == nil { + // remove client from pool + dc.destroy() + + log.Errorf("resolver: successful query for %s%s to %s, but reply was nil", q.FQDN, q.QType, resolver.Server) + return nil, errors.New("internal error") } + // make client available again + dc.done() + if resolver.IsBlockedUpstream(reply) { return nil, &BlockedUpstreamError{resolver.GetName()} } @@ -166,12 +200,15 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er if err != nil { return nil, err // TODO: mark as failed + } else if reply == nil { + log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1) + return nil, errors.New("internal error") } // hint network environment at successful connection netenv.ReportSuccessfulConnection() - new := &RRCache{ + newRecord := &RRCache{ Domain: q.FQDN, Question: q.QType, Answer: reply.Answer, @@ -182,5 +219,5 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er } // TODO: check if reply.Answer is valid - return new, nil + return newRecord, nil } diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 00ad0d0e..5d65fed3 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -25,7 +25,7 @@ var ( globalResolvers []*Resolver // all (global) resolvers localResolvers []*Resolver // all resolvers that are in site-local or link-local IP ranges localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope - allResolvers map[string]*Resolver // lookup map of all resolvers + activeResolvers map[string]*Resolver // lookup map of all resolvers resolversLock sync.RWMutex dupReqMap = make(map[string]*sync.WaitGroup) @@ -41,11 +41,11 @@ func indexOfScope(domain string, list []*Scope) int { return -1 } -func getResolverByIDWithLocking(server string) *Resolver { - resolversLock.Lock() - defer resolversLock.Unlock() +func getActiveResolverByIDWithLocking(server string) *Resolver { + resolversLock.RLock() + defer resolversLock.RUnlock() - resolver, ok := allResolvers[server] + resolver, ok := activeResolvers[server] if ok { return resolver } @@ -62,7 +62,7 @@ func formatIPAndPort(ip net.IP, port uint16) string { return address } -func clientManagerFactory(serverType string) func(*Resolver) *clientManager { +func clientManagerFactory(serverType string) func(*Resolver) *dnsClientManager { switch serverType { case ServerTypeDNS: return newDNSClientManager @@ -152,8 +152,8 @@ func configureSearchDomains(resolver *Resolver, searches []string) { } } -func getConfiguredResolvers() (resolvers []*Resolver) { - for _, server := range configuredNameServers() { +func getConfiguredResolvers(list []string) (resolvers []*Resolver) { + for _, server := range list { resolver, skip, err := createResolver(server, "config") if err != nil { // TODO(ppacher): module error @@ -199,19 +199,40 @@ func loadResolvers() { defer resolversLock.Unlock() newResolvers := append( - getConfiguredResolvers(), + getConfiguredResolvers(configuredNameServers()), getSystemResolvers()..., ) - // save resolvers - globalResolvers = newResolvers - if len(globalResolvers) == 0 { - log.Criticalf("resolver: no (valid) dns servers found in configuration and system") - // TODO(module error) + if len(newResolvers) == 0 { + msg := "no (valid) dns servers found in (user) configuration or system, falling back to defaults" + log.Warningf("resolver: %s", msg) + module.Warning("no-valid-user-resolvers", msg) + + // load defaults directly, overriding config system + newResolvers = getConfiguredResolvers(defaultNameServers) + if len(newResolvers) == 0 { + msg = "no (valid) dns servers found in configuration or system" + log.Criticalf("resolver: %s", msg) + module.Error("no-valid-default-resolvers", msg) + return + } } + // save resolvers + globalResolvers = newResolvers + + // assing resolvers to scopes setLocalAndScopeResolvers(globalResolvers) + // set active resolvers (for cache validation) + // reset + activeResolvers = make(map[string]*Resolver) + // add + for _, resolver := range newResolvers { + activeResolvers[resolver.Server] = resolver + } + activeResolvers[mDNSResolver.Server] = mDNSResolver + // log global resolvers if len(globalResolvers) > 0 { log.Trace("resolver: loaded global resolvers:") diff --git a/resolver/reverse.go b/resolver/reverse.go index 0487cf44..c236818b 100644 --- a/resolver/reverse.go +++ b/resolver/reverse.go @@ -65,12 +65,12 @@ func ResolveIPAndValidate(ctx context.Context, ip string, securityLevel uint8) ( for _, rr := range rrCache.Answer { switch v := rr.(type) { case *dns.A: - log.Infof("A: %s", v.A.String()) + // log.Debugf("A: %s", v.A.String()) if ip == v.A.String() { return ptrName, nil } case *dns.AAAA: - log.Infof("AAAA: %s", v.AAAA.String()) + // log.Debugf("AAAA: %s", v.AAAA.String()) if ip == v.AAAA.String() { return ptrName, nil } diff --git a/resolver/rrcache_test.go b/resolver/rrcache_test.go new file mode 100644 index 00000000..8aaa3094 --- /dev/null +++ b/resolver/rrcache_test.go @@ -0,0 +1,41 @@ +package resolver + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestCaching(t *testing.T) { + testDomain := "Mk35mMqOWEHXSMk11MYcbjLOjTE8PQvDiAVUxf4BvwtgR.example.com." + testQuestion := "A" + + testNameRecord := &NameRecord{ + Domain: testDomain, + Question: testQuestion, + } + + err := testNameRecord.Save() + if err != nil { + t.Fatal(err) + } + + rrCache, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + err = rrCache.Save() + if err != nil { + t.Fatal(err) + } + + rrCache2, err := GetRRCache(testDomain, dns.Type(dns.TypeA)) + if err != nil { + t.Fatal(err) + } + + if rrCache2.Domain != rrCache.Domain { + t.Fatal("something very is wrong") + } +} From 652518e5273ec8ba97415a5f4dcaf3687e63dd2f Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:43:15 +0200 Subject: [PATCH 09/36] Save failed processes --- process/process.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/process/process.go b/process/process.go index 8ef1ad73..16d04cb3 100644 --- a/process/process.go +++ b/process/process.go @@ -49,7 +49,8 @@ type Process struct { FirstSeen int64 LastSeen int64 - Virtual bool // This process is either merged into another process or is not needed. + Virtual bool // This process is either merged into another process or is not needed. + Error string // Cache errors } // Profile returns the assigned layered profile. @@ -94,6 +95,7 @@ func GetOrFindPrimaryProcess(ctx context.Context, pid int) (*Process, error) { parentProcess, err := loadProcess(ctx, process.ParentPid) if err != nil { log.Tracer(ctx).Tracef("process: could not get parent of %d: %d: %s", process.Pid, process.ParentPid, err) + saveFailedProcess(process.ParentPid, err.Error()) return process, nil } @@ -226,13 +228,7 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { pInfo, err := processInfo.NewProcess(int32(pid)) if err != nil { - // TODO: remove this workaround as soon as NewProcess really returns an error on windows when the process does not exist - // Issue: https://github.com/shirou/gopsutil/issues/729 - _, err = pInfo.Name() - if err != nil { - // process does not exists - return nil, err - } + return nil, err } // UID @@ -375,3 +371,14 @@ func loadProcess(ctx context.Context, pid int) (*Process, error) { new.Save() return new, nil } + +func saveFailedProcess(pid int, err string) { + failed := &Process{ + Pid: pid, + FirstSeen: time.Now().Unix(), + Virtual: true, // not needed + Error: err, + } + + failed.Save() +} From 75d7a91843b0a7113c10d55359fcb7aaaa801e36 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:44:09 +0200 Subject: [PATCH 10/36] Remove intermediate fstree folder from log dirs The logs dir will need to be supplied in a special way anyway --- pmctl/logs.go | 6 +++--- pmctl/run.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pmctl/logs.go b/pmctl/logs.go index 7ee1576d..03d05678 100644 --- a/pmctl/logs.go +++ b/pmctl/logs.go @@ -74,7 +74,7 @@ func finalizeLogFile(logFile *os.File, logFilePath string) { func initControlLogFile() *os.File { // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) @@ -93,7 +93,7 @@ func logControlError(cErr error) { } // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) @@ -114,7 +114,7 @@ func logControlError(cErr error) { //nolint:deadcode,unused // TODO func logControlStack() { // check logging dir - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", "control") + logFileBasePath := filepath.Join(logsRoot.Path, "control") err := logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file folder %s: %s\n", logFileBasePath, err) diff --git a/pmctl/run.go b/pmctl/run.go index 53133697..432bfa3d 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -227,7 +227,7 @@ func execute(opts *Options, args []string) (cont bool, err error) { // log files var logFile, errorFile *os.File - logFileBasePath := filepath.Join(logsRoot.Path, "fstree", opts.ShortIdentifier) + logFileBasePath := filepath.Join(logsRoot.Path, opts.ShortIdentifier) err = logsRoot.EnsureAbsPath(logFileBasePath) if err != nil { log.Printf("failed to check/create log file dir %s: %s\n", logFileBasePath, err) From ca8b36cbc743f9e05510038d44514061dadbfcdc Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2020 22:44:24 +0200 Subject: [PATCH 11/36] Fix FQDN validation and add tests --- network/netutils/cleandns.go | 48 +++++++++++++++++++++++++++++-- network/netutils/cleandns_test.go | 43 +++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 network/netutils/cleandns_test.go diff --git a/network/netutils/cleandns.go b/network/netutils/cleandns.go index f9487664..7df738d0 100644 --- a/network/netutils/cleandns.go +++ b/network/netutils/cleandns.go @@ -2,13 +2,57 @@ package netutils import ( "regexp" + + "github.com/miekg/dns" ) var ( - cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`) + cleanDomainRegex = regexp.MustCompile( + `^` + // match beginning + `(` + // start subdomain group + `(xn--)?` + // idn prefix + `[a-z0-9_-]{1,63}` + // main chunk + `\.` + // ending with a dot + `)*` + // end subdomain group, allow any number of subdomains + `(xn--)?` + // TLD idn prefix + `[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters + `\.` + // ending with a dot + `$`, // match end + ) ) // IsValidFqdn returns whether the given string is a valid fqdn. func IsValidFqdn(fqdn string) bool { - return cleanDomainRegex.MatchString(fqdn) + // root zone + if fqdn == "." { + return true + } + + // check max length + if len(fqdn) > 256 { + return false + } + + // check with regex + if !cleanDomainRegex.MatchString(fqdn) { + return false + } + + // check with miegk/dns + + // IsFqdn checks if a domain name is fully qualified. + if !dns.IsFqdn(fqdn) { + return false + } + + // IsDomainName checks if s is a valid domain name, it returns the number of + // labels and true, when a domain name is valid. Note that non fully qualified + // domain name is considered valid, in this case the last label is counted in + // the number of labels. When false is returned the number of labels is not + // defined. Also note that this function is extremely liberal; almost any + // string is a valid domain name as the DNS is 8 bit protocol. It checks if each + // label fits in 63 characters and that the entire name will fit into the 255 + // octet wire format limit. + _, ok := dns.IsDomainName(fqdn) + return ok } diff --git a/network/netutils/cleandns_test.go b/network/netutils/cleandns_test.go new file mode 100644 index 00000000..4f0dacb0 --- /dev/null +++ b/network/netutils/cleandns_test.go @@ -0,0 +1,43 @@ +package netutils + +import "testing" + +func testDomainValidity(t *testing.T, domain string, isValid bool) { + if IsValidFqdn(domain) != isValid { + t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid) + } +} + +func TestDNSValidation(t *testing.T) { + // valid + testDomainValidity(t, ".", true) + testDomainValidity(t, "at.", true) + testDomainValidity(t, "orf.at.", true) + testDomainValidity(t, "www.orf.at.", true) + testDomainValidity(t, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.u.v.x.y.z.example.org.", true) + testDomainValidity(t, "a_a.com.", true) + testDomainValidity(t, "a-a.com.", true) + testDomainValidity(t, "a_a.com.", true) + testDomainValidity(t, "a-a.com.", true) + testDomainValidity(t, "xn--a.com.", true) + testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasd.com.", true) + + // maybe valid + testDomainValidity(t, "-.com.", true) + testDomainValidity(t, "_.com.", true) + testDomainValidity(t, "a_.com.", true) + testDomainValidity(t, "a-.com.", true) + testDomainValidity(t, "_a.com.", true) + testDomainValidity(t, "-a.com.", true) + + // invalid + testDomainValidity(t, ".com.", false) + testDomainValidity(t, ".com.", false) + testDomainValidity(t, "xn--asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf.com.", false) + testDomainValidity(t, "asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.asdf.as.com.", false) + + // real world examples + testDomainValidity(t, "iuqerfsodp9ifjaposdfjhgosurijfaewrwergwea.com.", true) +} From 886d30278f3f682ba4022552721382ef4eb0ea61 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 16 May 2020 22:43:42 +0200 Subject: [PATCH 12/36] Fix IPv4 parsing from windows state tables --- firewall/interception/windowskext/handler.go | 1 + network/iphelper/tables.go | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 254c046b..857a5bcb 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -113,6 +113,7 @@ func Handler(packets chan packet.Packet) { } } +// convertIPv4 as needed for data from the kernel func convertIPv4(input [4]uint32) net.IP { addressBuf := make([]byte, 4) binary.BigEndian.PutUint32(addressBuf, input[0]) diff --git a/network/iphelper/tables.go b/network/iphelper/tables.go index b2ea8286..2dcaf5c1 100644 --- a/network/iphelper/tables.go +++ b/network/iphelper/tables.go @@ -297,8 +297,9 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so return connections, binds, nil } +// convertIPv4 as needed for iphlpapi.dll func convertIPv4(input uint32) net.IP { addressBuf := make([]byte, 4) - binary.BigEndian.PutUint32(addressBuf, input) + binary.LittleEndian.PutUint32(addressBuf, input) return net.IP(addressBuf) } From e473c0e228988cf9f072f2f9b749d7fa82996c4b Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 16 May 2020 22:44:52 +0200 Subject: [PATCH 13/36] Remove iphelper test package You can get the same thing by doing this in the iphelper pkg dir: GOOS=windows go test -c --- network/iphelper/test/main.go | 62 ----------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 network/iphelper/test/main.go diff --git a/network/iphelper/test/main.go b/network/iphelper/test/main.go deleted file mode 100644 index 5234fbb4..00000000 --- a/network/iphelper/test/main.go +++ /dev/null @@ -1,62 +0,0 @@ -// +build windows - -package main - -import ( - "fmt" - - "github.com/safing/portmaster/process/iphelper" -) - -func main() { - iph, err := iphelper.New() - if err != nil { - panic(err) - } - - fmt.Printf("TCP4\n") - conns, lConns, err := iph.GetTables(iphelper.TCP, iphelper.IPv4) - if err != nil { - panic(err) - } - fmt.Printf("Connections:\n") - for _, conn := range conns { - fmt.Printf("%s\n", conn) - } - fmt.Printf("Listeners:\n") - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nTCP6\n") - conns, lConns, err = iph.GetTables(iphelper.TCP, iphelper.IPv6) - if err != nil { - panic(err) - } - fmt.Printf("Connections:\n") - for _, conn := range conns { - fmt.Printf("%s\n", conn) - } - fmt.Printf("Listeners:\n") - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nUDP4\n") - _, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv4) - if err != nil { - panic(err) - } - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } - - fmt.Printf("\nUDP6\n") - _, lConns, err = iph.GetTables(iphelper.UDP, iphelper.IPv6) - if err != nil { - panic(err) - } - for _, conn := range lConns { - fmt.Printf("%s\n", conn) - } -} From 89317b8848365a12582759d504078a00cd19c7cb Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 14:12:41 +0200 Subject: [PATCH 14/36] Disable time-triggered online check --- netenv/online-status.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/netenv/online-status.go b/netenv/online-status.go index b7542b3b..16a6c50d 100644 --- a/netenv/online-status.go +++ b/netenv/online-status.go @@ -201,20 +201,11 @@ func triggerOnlineStatusInvestigation() { func monitorOnlineStatus(ctx context.Context) error { for { - timeout := 5 * time.Minute - /* - if GetOnlineStatus() != StatusOnline { - timeout = time.Second - log.Debugf("checking online status again in %s because current status is %s", timeout, GetOnlineStatus()) - } - */ // wait for trigger select { case <-ctx.Done(): return nil case <-onlineStatusInvestigationTrigger: - - case <-time.After(timeout): } // enable waiting From 85c7fd4af7d810bc0872117a5f24d30d5a7c469c Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 14:29:04 +0200 Subject: [PATCH 15/36] Improve udp connection attribution for broadcast and multicast packets --- network/state/lookup.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/network/state/lookup.go b/network/state/lookup.go index ade151a5..22baf62d 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) @@ -160,12 +161,20 @@ func searchUDP( err error, ) { + isInboundMulticast := pktInbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast + // TODO: Currently broadcast/multicast scopes are not checked, so we might + // attribute an incoming broadcast/multicast packet to the wrong process if + // there are multiple processes listening on the same local port, but + // binding to different addresses. This highly unusual for clients. + // search until we find something for i := 0; i < 5; i++ { // search binds for _, socketInfo := range binds { if localPort == socketInfo.Local.Port && - (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { + (socketInfo.Local.IP[0] == 0 || // zero IP + isInboundMulticast || // inbound broadcast, multicast + localIP.Equal(socketInfo.Local.IP)) { // do not check direction if remoteIP/Port is not given if remotePort == 0 { From 11d3e15de41b6a31997dcc7e0dfcf55629ae073c Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 15:23:10 +0200 Subject: [PATCH 16/36] Only enable BlockP2P in Extreme level by default --- profile/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profile/config.go b/profile/config.go index 60a63070..58894e86 100644 --- a/profile/config.go +++ b/profile/config.go @@ -295,7 +295,7 @@ Examples: Order: cfgOptionBlockP2POrder, OptType: config.OptTypeInt, ExternalOptType: "security level", - DefaultValue: status.SecurityLevelsAll, + DefaultValue: status.SecurityLevelExtreme, ValidationRegex: "^(4|6|7)$", }) if err != nil { From bdcf499f227f9a3b92b3d03423155af3fd764004 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 15:25:08 +0200 Subject: [PATCH 17/36] Fix domain endpoint reason message --- profile/endpoints/endpoint-domain.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profile/endpoints/endpoint-domain.go b/profile/endpoints/endpoint-domain.go index fbd0dcf9..3cc3450f 100644 --- a/profile/endpoints/endpoint-domain.go +++ b/profile/endpoints/endpoint-domain.go @@ -31,7 +31,7 @@ type EndpointDomain struct { } func (ep *EndpointDomain) check(entity *intel.Entity, domain string) (EPResult, Reason) { - result, reason := ep.match(ep, entity, ep.Domain, "domain matches") + result, reason := ep.match(ep, entity, ep.OriginalValue, "domain matches") switch ep.MatchType { case domainMatchTypeExact: From 3adf52d19c3d20bb46b9666a98bae7f273be2437 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 15:29:28 +0200 Subject: [PATCH 18/36] Lower priority of async dns queries They make take longer if there are network problems --- resolver/resolve.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resolver/resolve.go b/resolver/resolve.go index 57c0c967..385054a9 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -160,7 +160,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Trace("resolver: serving from cache, requesting new") // resolve async - module.StartMediumPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { + module.StartLowPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { _, _ = resolveAndCache(ctx, q) return nil }) From 0036d2567286901e4d254c7c7923331df4b1e3f0 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 15:33:05 +0200 Subject: [PATCH 19/36] Demote error and warning logging when process of packet could not be found --- network/connection.go | 4 ++-- process/find.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/network/connection.go b/network/connection.go index 1b99842b..0a6fb592 100644 --- a/network/connection.go +++ b/network/connection.go @@ -76,7 +76,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri false, // inbound, irrevelant ) if err != nil { - log.Warningf("network: failed to find process of dns request for %s: %s", fqdn, err) + log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) proc = process.GetUnidentifiedProcess(ctx) } @@ -99,7 +99,7 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // get Process proc, inbound, err := process.GetProcessByPacket(pkt) if err != nil { - log.Warningf("network: failed to find process of packet %s: %s", pkt, err) + log.Debugf("network: failed to find process of packet %s: %s", pkt, err) proc = process.GetUnidentifiedProcess(pkt.Ctx()) } diff --git a/process/find.go b/process/find.go index aa0a4071..936d9214 100644 --- a/process/find.go +++ b/process/find.go @@ -56,13 +56,13 @@ func GetProcessByEndpoints( var pid int pid, connInbound, err = state.Lookup(ipVersion, protocol, localIP, localPort, remoteIP, remotePort, pktInbound) if err != nil { - log.Tracer(ctx).Errorf("process: failed to find PID of connection: %s", err) + log.Tracer(ctx).Debugf("process: failed to find PID of connection: %s", err) return nil, connInbound, err } process, err = GetOrFindPrimaryProcess(ctx, pid) if err != nil { - log.Tracer(ctx).Errorf("process: failed to find (primary) process with PID: %s", err) + log.Tracer(ctx).Debugf("process: failed to find (primary) process with PID: %s", err) return nil, connInbound, err } From 7649859ba69441d9d5d7b4423956d333ac2bb925 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 18 May 2020 17:08:32 +0200 Subject: [PATCH 20/36] Switch connection state lookups to use the packet.Info struct Also, rename the Direction attribute on packet.Info to Inbound --- firewall/api.go | 19 ++--- firewall/interception/windowskext/handler.go | 8 +- firewall/master.go | 18 ++--- nameserver/takeover.go | 10 ++- network/connection.go | 20 ++--- network/packet/packet.go | 24 +++--- network/packet/packetinfo.go | 12 +-- network/state/lookup.go | 78 +++++++------------- network/state/udp.go | 9 ++- process/find.go | 36 +-------- 10 files changed, 95 insertions(+), 139 deletions(-) diff --git a/firewall/api.go b/firewall/api.go index d0e03f24..b435e7db 100644 --- a/firewall/api.go +++ b/firewall/api.go @@ -60,16 +60,17 @@ func apiAuthenticator(s *http.Server, r *http.Request) (grantAccess bool, err er var procsChecked []string // get process - proc, _, err := process.GetProcessByEndpoints( + proc, _, err := process.GetProcessByConnection( r.Context(), - packet.IPv4, - packet.TCP, - // switch reverse/local to get remote process - remoteIP, - remotePort, - localIP, - localPort, - false, + &packet.Info{ + Inbound: false, // outbound as we are looking for the process of the source address + Version: packet.IPv4, + Protocol: packet.TCP, + Src: remoteIP, // source as in the process we are looking for + SrcPort: remotePort, // source as in the process we are looking for + Dst: localIP, + DstPort: localPort, + }, ) if err != nil { return false, fmt.Errorf("failed to get process: %s", err) diff --git a/firewall/interception/windowskext/handler.go b/firewall/interception/windowskext/handler.go index 857a5bcb..97238324 100644 --- a/firewall/interception/windowskext/handler.go +++ b/firewall/interception/windowskext/handler.go @@ -62,7 +62,7 @@ func Handler(packets chan packet.Packet) { } info := new.Info() - info.Direction = packetInfo.direction > 0 + info.Inbound = packetInfo.direction > 0 info.InTunnel = false info.Protocol = packet.IPProtocol(packetInfo.protocol) @@ -76,7 +76,7 @@ func Handler(packets chan packet.Packet) { // IPs if info.Version == packet.IPv4 { // IPv4 - if info.Direction { + if info.Inbound { // Inbound info.Src = convertIPv4(packetInfo.remoteIP) info.Dst = convertIPv4(packetInfo.localIP) @@ -87,7 +87,7 @@ func Handler(packets chan packet.Packet) { } } else { // IPv6 - if info.Direction { + if info.Inbound { // Inbound info.Src = convertIPv6(packetInfo.remoteIP) info.Dst = convertIPv6(packetInfo.localIP) @@ -99,7 +99,7 @@ func Handler(packets chan packet.Packet) { } // Ports - if info.Direction { + if info.Inbound { // Inbound info.SrcPort = packetInfo.remotePort info.DstPort = packetInfo.localPort diff --git a/firewall/master.go b/firewall/master.go index d7b926c8..a512b202 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -91,15 +91,15 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { pktInfo := pkt.Info() if conn.Process().Pid >= 0 && pktInfo.Src.Equal(pktInfo.Dst) { // get PID - otherPid, _, err := state.Lookup( - pktInfo.Version, - pktInfo.Protocol, - pktInfo.RemoteIP(), - pktInfo.RemotePort(), - pktInfo.LocalIP(), - pktInfo.LocalPort(), - pktInfo.Direction, - ) + otherPid, _, err := state.Lookup(&packet.Info{ + Inbound: !pktInfo.Inbound, // we want to know the process on the other end + Version: pktInfo.Version, + Protocol: pktInfo.Protocol, + Src: pktInfo.Src, + SrcPort: pktInfo.SrcPort, + Dst: pktInfo.Dst, + DstPort: pktInfo.DstPort, + }) if err != nil { log.Warningf("filter: failed to find local peer process PID: %s", err) } else { diff --git a/nameserver/takeover.go b/nameserver/takeover.go index d5ede695..ecbea5cf 100644 --- a/nameserver/takeover.go +++ b/nameserver/takeover.go @@ -58,7 +58,15 @@ func checkForConflictingService() error { } func takeover(resolverIP net.IP) (int, error) { - pid, _, err := state.Lookup(0, packet.UDP, resolverIP, 53, nil, 0, false) + pid, _, err := state.Lookup(&packet.Info{ + Inbound: true, + Version: 0, // auto-detect + Protocol: packet.UDP, + Src: nil, // do not record direction + SrcPort: 0, // do not record direction + Dst: resolverIP, + DstPort: 53, + }) if err != nil { // there may be nothing listening on :53 return 0, nil diff --git a/network/connection.go b/network/connection.go index 0a6fb592..030c84b6 100644 --- a/network/connection.go +++ b/network/connection.go @@ -65,15 +65,17 @@ type Connection struct { //nolint:maligned // TODO: fix alignment // NewConnectionFromDNSRequest returns a new connection based on the given dns request. func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []string, ipVersion packet.IPVersion, localIP net.IP, localPort uint16) *Connection { // get Process - proc, _, err := process.GetProcessByEndpoints( + proc, _, err := process.GetProcessByConnection( ctx, - ipVersion, - packet.UDP, - localIP, - localPort, - dnsAddress, // this might not be correct, but it does not matter, as matching only occurs on the local address - dnsPort, - false, // inbound, irrevelant + &packet.Info{ + Inbound: false, // outbound as we are looking for the process of the source address + Version: ipVersion, + Protocol: packet.UDP, + Src: localIP, // source as in the process we are looking for + SrcPort: localPort, // source as in the process we are looking for + Dst: nil, // do not record direction + DstPort: 0, // do not record direction + }, ) if err != nil { log.Debugf("network: failed to find process of dns request for %s: %s", fqdn, err) @@ -97,7 +99,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri // NewConnectionFromFirstPacket returns a new connection based on the given packet. func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // get Process - proc, inbound, err := process.GetProcessByPacket(pkt) + proc, inbound, err := process.GetProcessByConnection(pkt.Ctx(), pkt.Info()) if err != nil { log.Debugf("network: failed to find process of packet %s: %s", pkt, err) proc = process.GetUnidentifiedProcess(pkt.Ctx()) diff --git a/network/packet/packet.go b/network/packet/packet.go index 942dd215..8076ed69 100644 --- a/network/packet/packet.go +++ b/network/packet/packet.go @@ -36,22 +36,22 @@ func (pkt *Base) SetPacketInfo(packetInfo Info) { // SetInbound sets a the packet direction to inbound. This must only used when initializing the packet structure. func (pkt *Base) SetInbound() { - pkt.info.Direction = true + pkt.info.Inbound = true } // SetOutbound sets a the packet direction to outbound. This must only used when initializing the packet structure. func (pkt *Base) SetOutbound() { - pkt.info.Direction = false + pkt.info.Inbound = false } // IsInbound checks if the packet is inbound. func (pkt *Base) IsInbound() bool { - return pkt.info.Direction + return pkt.info.Inbound } // IsOutbound checks if the packet is outbound. func (pkt *Base) IsOutbound() bool { - return !pkt.info.Direction + return !pkt.info.Inbound } // HasPorts checks if the packet has a protocol that uses ports. @@ -80,13 +80,13 @@ func (pkt *Base) GetConnectionID() string { func (pkt *Base) createConnectionID() { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { - if pkt.info.Direction { + if pkt.info.Inbound { pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } else { pkt.connID = fmt.Sprintf("%d-%s-%d-%s-%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } } else { - if pkt.info.Direction { + if pkt.info.Inbound { pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } else { pkt.connID = fmt.Sprintf("%d-%s-%s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) @@ -105,7 +105,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I if pkt.info.Protocol != protocol { return false } - if pkt.info.Direction != remote { + if pkt.info.Inbound != remote { if !network.Contains(pkt.info.Src) { return false } @@ -131,7 +131,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I // Remote Src Dst // func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool { - if pkt.info.Direction != endpoint { + if pkt.info.Inbound != endpoint { if network.Contains(pkt.info.Src) { return true } @@ -152,12 +152,12 @@ func (pkt *Base) String() string { // FmtPacket returns the most important information about the packet as a string func (pkt *Base) FmtPacket() string { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("IN %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Dst, pkt.info.DstPort, pkt.info.Src, pkt.info.SrcPort) } return fmt.Sprintf("OUT %s %s:%d <-> %s:%d", pkt.info.Protocol, pkt.info.Src, pkt.info.SrcPort, pkt.info.Dst, pkt.info.DstPort) } - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("IN %s %s <-> %s", pkt.info.Protocol, pkt.info.Dst, pkt.info.Src) } return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) @@ -170,7 +170,7 @@ func (pkt *Base) FmtProtocol() string { // FmtRemoteIP returns the remote IP address as a string func (pkt *Base) FmtRemoteIP() string { - if pkt.info.Direction { + if pkt.info.Inbound { return pkt.info.Src.String() } return pkt.info.Dst.String() @@ -179,7 +179,7 @@ func (pkt *Base) FmtRemoteIP() string { // FmtRemotePort returns the remote port as a string func (pkt *Base) FmtRemotePort() string { if pkt.info.SrcPort != 0 { - if pkt.info.Direction { + if pkt.info.Inbound { return fmt.Sprintf("%d", pkt.info.SrcPort) } return fmt.Sprintf("%d", pkt.info.DstPort) diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index a98fc8a5..3e68e8af 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -6,8 +6,8 @@ import ( // Info holds IP and TCP/UDP header information type Info struct { - Direction bool - InTunnel bool + Inbound bool + InTunnel bool Version IPVersion Protocol IPProtocol @@ -17,7 +17,7 @@ type Info struct { // LocalIP returns the local IP of the packet. func (pi *Info) LocalIP() net.IP { - if pi.Direction { + if pi.Inbound { return pi.Dst } return pi.Src @@ -25,7 +25,7 @@ func (pi *Info) LocalIP() net.IP { // RemoteIP returns the remote IP of the packet. func (pi *Info) RemoteIP() net.IP { - if pi.Direction { + if pi.Inbound { return pi.Src } return pi.Dst @@ -33,7 +33,7 @@ func (pi *Info) RemoteIP() net.IP { // LocalPort returns the local port of the packet. func (pi *Info) LocalPort() uint16 { - if pi.Direction { + if pi.Inbound { return pi.DstPort } return pi.SrcPort @@ -41,7 +41,7 @@ func (pi *Info) LocalPort() uint16 { // RemotePort returns the remote port of the packet. func (pi *Info) RemotePort() uint16 { - if pi.Direction { + if pi.Inbound { return pi.SrcPort } return pi.DstPort diff --git a/network/state/lookup.go b/network/state/lookup.go index 22baf62d..aff461f0 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -2,7 +2,6 @@ package state import ( "errors" - "net" "sync" "time" @@ -44,62 +43,36 @@ var ( waitTime = 3 * time.Millisecond ) -func LookupWithPacket(pkt packet.Packet) (pid int, inbound bool, err error) { - meta := pkt.Info() - return Lookup( - meta.Version, - meta.Protocol, - meta.LocalIP(), - meta.LocalPort(), - meta.RemoteIP(), - meta.RemotePort(), - meta.Direction, - ) -} - -func Lookup( - ipVersion packet.IPVersion, - protocol packet.IPProtocol, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, - pktInbound bool, -) ( - pid int, - inbound bool, - err error, -) { - +func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { // auto-detect version - if ipVersion == 0 { - if ip := localIP.To4(); ip != nil { - ipVersion = packet.IPv4 + if pktInfo.Version == 0 { + if ip := pktInfo.LocalIP().To4(); ip != nil { + pktInfo.Version = packet.IPv4 } else { - ipVersion = packet.IPv6 + pktInfo.Version = packet.IPv6 } } switch { - case ipVersion == packet.IPv4 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: tcp4Lock.Lock() defer tcp4Lock.Unlock() - return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, localIP, localPort) + return searchTCP(tcp4Connections, tcp4Listeners, updateTCP4Tables, pktInfo) - case ipVersion == packet.IPv6 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: tcp6Lock.Lock() defer tcp6Lock.Unlock() - return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, localIP, localPort) + return searchTCP(tcp6Connections, tcp6Listeners, updateTCP6Tables, pktInfo) - case ipVersion == packet.IPv4 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: udp4Lock.Lock() defer udp4Lock.Unlock() - return searchUDP(udp4Binds, udp4States, updateUDP4Table, localIP, localPort, remoteIP, remotePort, pktInbound) + return searchUDP(udp4Binds, udp4States, updateUDP4Table, pktInfo) - case ipVersion == packet.IPv6 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: udp6Lock.Lock() defer udp6Lock.Unlock() - return searchUDP(udp6Binds, udp6States, updateUDP6Table, localIP, localPort, remoteIP, remotePort, pktInbound) + return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo) default: return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") @@ -110,14 +83,16 @@ func searchTCP( connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, updateTables func() ([]*socket.ConnectionInfo, []*socket.BindInfo), - localIP net.IP, - localPort uint16, + pktInfo *packet.Info, ) ( pid int, inbound bool, err error, ) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + // search until we find something for i := 0; i < 5; i++ { // always search listeners first @@ -150,18 +125,17 @@ func searchUDP( binds []*socket.BindInfo, udpStates map[string]map[string]*udpState, updateTable func() []*socket.BindInfo, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, - pktInbound bool, + pktInfo *packet.Info, ) ( pid int, inbound bool, err error, ) { - isInboundMulticast := pktInbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + + isInboundMulticast := pktInfo.Inbound && netutils.ClassifyIP(localIP) == netutils.LocalMulticast // TODO: Currently broadcast/multicast scopes are not checked, so we might // attribute an incoming broadcast/multicast packet to the wrong process if // there are multiple processes listening on the same local port, but @@ -177,12 +151,12 @@ func searchUDP( localIP.Equal(socketInfo.Local.IP)) { // do not check direction if remoteIP/Port is not given - if remotePort == 0 { - return checkBindPID(socketInfo, pktInbound) + if pktInfo.RemotePort() == 0 { + return checkBindPID(socketInfo, pktInfo.Inbound) } // get direction and return - connInbound := getUDPDirection(socketInfo, udpStates, remoteIP, remotePort, pktInbound) + connInbound := getUDPDirection(socketInfo, udpStates, pktInfo) return checkBindPID(socketInfo, connInbound) } } @@ -194,5 +168,5 @@ func searchUDP( binds = updateTable() } - return UnidentifiedProcessID, pktInbound, ErrConnectionNotFound + return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound } diff --git a/network/state/udp.go b/network/state/udp.go index f24ac237..ccfb0815 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/socket" ) @@ -34,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin return nil, false } -func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16, pktInbound bool) (connDirection bool) { +func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) { localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port) bindMap, ok := udpStates[localKey] @@ -43,14 +44,14 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin udpStates[localKey] = bindMap } - remoteKey := makeUDPStateKey(remoteIP, remotePort) + remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort()) udpConnState, ok := bindMap[remoteKey] if !ok { bindMap[remoteKey] = &udpState{ - inbound: pktInbound, + inbound: pktInfo.Inbound, lastSeen: time.Now().UTC(), } - return pktInbound + return pktInfo.Inbound } udpConnState.lastSeen = time.Now().UTC() diff --git a/process/find.go b/process/find.go index 936d9214..a7e214cf 100644 --- a/process/find.go +++ b/process/find.go @@ -3,7 +3,6 @@ package process import ( "context" "errors" - "net" "github.com/safing/portmaster/network/state" @@ -16,45 +15,16 @@ var ( ErrProcessNotFound = errors.New("could not find process in system state tables") ) -// GetProcessByPacket returns the process that owns the given packet. -func GetProcessByPacket(pkt packet.Packet) (process *Process, inbound bool, err error) { - meta := pkt.Info() - return GetProcessByEndpoints( - pkt.Ctx(), - meta.Version, - meta.Protocol, - meta.LocalIP(), - meta.LocalPort(), - meta.RemoteIP(), - meta.RemotePort(), - meta.Direction, - ) -} - // GetProcessByEndpoints returns the process that owns the described link. -func GetProcessByEndpoints( - ctx context.Context, - ipVersion packet.IPVersion, - protocol packet.IPProtocol, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, - pktInbound bool, -) ( - process *Process, - connInbound bool, - err error, -) { - +func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) { if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") - return GetUnidentifiedProcess(ctx), pktInbound, nil + return GetUnidentifiedProcess(ctx), pktInfo.Inbound, nil } log.Tracer(ctx).Tracef("process: getting pid from system network state") var pid int - pid, connInbound, err = state.Lookup(ipVersion, protocol, localIP, localPort, remoteIP, remotePort, pktInbound) + pid, connInbound, err = state.Lookup(pktInfo) if err != nil { log.Tracer(ctx).Debugf("process: failed to find PID of connection: %s", err) return nil, connInbound, err From d11080d997246c822fb8aee658a54cf39c6e4fd2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 09:07:01 +0200 Subject: [PATCH 21/36] Update dependencies --- Gopkg.lock | 107 +++++++++++++++++++++++++++++++++-------------------- Gopkg.toml | 4 ++ 2 files changed, 70 insertions(+), 41 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 7d1613bb..437e76c0 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,12 +2,12 @@ [[projects]] - digest = "1:f82b8ac36058904227087141017bb82f4b0fc58272990a4cdae3e2d6d222644e" + digest = "1:6146fda730c18186631e91e818d995e759e7cbe27644d6871ccd469f6865c686" name = "github.com/StackExchange/wmi" packages = ["."] pruneopts = "" - revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338" - version = "1.0.0" + revision = "cbe66965904dbe8a6cd589e2298e5d8b986bd7dd" + version = "1.1.0" [[projects]] digest = "1:e010d6b45ee6c721df761eae89961c634ceb55feff166a48d15504729309f267" @@ -18,12 +18,12 @@ version = "v1.1.1" [[projects]] - digest = "1:3c753679736345f50125ae993e0a2614da126859921ea7faeecda6d217501ce2" + digest = "1:21caed545a1c7ef7a2627bbb45989f689872ff6d5087d49c31340ce74c36de59" name = "github.com/agext/levenshtein" packages = ["."] pruneopts = "" - revision = "0ded9c86537917af2ff89bc9c78de6bd58477894" - version = "v1.2.2" + revision = "52c14c47d03211d8ac1834e94601635e07c5a6ef" + version = "v1.2.3" [[projects]] branch = "v2.1" @@ -34,12 +34,12 @@ revision = "d27c04069d0d5dfe11c202dacbf745ae8d1ab181" [[projects]] - digest = "1:166e24c91c2732657d2f791d3ee3897e7d85ece7cbb62ad991250e6b51fc1d4c" + digest = "1:f384a8b6f89c502229e9013aa4f89ce5b5b56f09f9a4d601d7f1f026d3564fbf" name = "github.com/coreos/go-iptables" packages = ["iptables"] pruneopts = "" - revision = "78b5fff24e6df8886ef8eca9411f683a884349a5" - version = "v0.4.1" + revision = "f901d6c2a4f2a4df092b98c33366dfba1f93d7a0" + version = "v0.4.5" [[projects]] digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" @@ -61,12 +61,12 @@ version = "v1.2.4" [[projects]] - digest = "1:cc1255e2fef3819bfab3540277001e602892dd431ef9ab5499bcdbc425923d64" + digest = "1:f63933986e63230fc32512ed00bc18ea4dbb0f57b5da18561314928fd20c2ff0" name = "github.com/godbus/dbus" packages = ["."] pruneopts = "" - revision = "2ff6f7ffd60f0f2410b3105864bdd12c7894f844" - version = "v5.0.1" + revision = "37bf87eef99d69c4f1d3528bd66e3a87dc201472" + version = "v5.0.3" [[projects]] digest = "1:e85e59c4152d8576341daf54f40d96c404c264e04941a4a36b97a0f427eb9e5e" @@ -113,20 +113,20 @@ revision = "2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384" [[projects]] - digest = "1:0b6694f306890ddbb69c96a16776510bd24e07436fae3f9b0a4e5b650f1e6fb7" + branch = "master" + digest = "1:c140772b00f0c26cf6627aee32f62d9f9d89dffcda648861266c482c36a5344a" name = "github.com/miekg/dns" packages = ["."] pruneopts = "" - revision = "b13675009d59c97f3721247d9efa8914e1866a5b" - version = "v1.1.15" + revision = "b7703d0fa022e159d01efa2de82e6173d5ec04c8" [[projects]] - digest = "1:3819cd861b7abd7d12dc1ea52ecb998ad1171826a76ecf0aefa09545781091f9" + digest = "1:b962a528cbecf7662bee4d84a600f7a0a6a130368666d7d461757ba4d1341906" name = "github.com/oschwald/maxminddb-golang" packages = ["."] pruneopts = "" - revision = "2905694a1b00c5574f1418a7dbf8a22a7d247559" - version = "v1.3.1" + revision = "6a033e62c03b7dab4c37f7c9eb2ebb3b10e8f13a" + version = "v1.6.0" [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" @@ -145,7 +145,7 @@ version = "v1.2.0" [[projects]] - digest = "1:8bf42eb2ded52ed2678b0716dbfbf30628765bc12b13222c4d5669ba4c1310e4" + digest = "1:16f319cf21ddf49f27b3a2093d68316840dc25ec5c2a0a431a4a4fc01ea707e2" name = "github.com/shirou/gopsutil" packages = [ "cpu", @@ -155,32 +155,24 @@ "process", ] pruneopts = "" - revision = "4c8b404ee5c53b04b04f34b1744a26bf5d2910de" - version = "v2.19.6" + revision = "a81cf97fce2300934e6c625b9917103346c26ba3" + version = "v2.20.4" [[projects]] - branch = "master" - digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b" - name = "github.com/shirou/w32" - packages = ["."] - pruneopts = "" - revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b" - -[[projects]] - digest = "1:0c63b3c7ad6d825a898f28cb854252a3b29d37700c68a117a977263f5ec94efe" + digest = "1:bff75d4f1a2d2c4b8f4b46ff5ac230b80b5fa49276f615900cba09fe4c97e66e" name = "github.com/spf13/cobra" packages = ["."] pruneopts = "" - revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" - version = "v0.0.5" + revision = "a684a6d7f5e37385d954dd3b5a14fc6912c6ab9d" + version = "v1.0.0" [[projects]] - digest = "1:cbaf13cdbfef0e4734ed8a7504f57fe893d471d62a35b982bf6fb3f036449a66" + digest = "1:688428eeb1ca80d92599eb3254bdf91b51d7e232fead3a73844c1f201a281e51" name = "github.com/spf13/pflag" packages = ["."] pruneopts = "" - revision = "298182f68c66c05229eb03ac171abe6e309ee79a" - version = "v1.0.3" + revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab" + version = "v1.0.5" [[projects]] digest = "1:cc4eb6813da8d08694e557fcafae8fcc24f47f61a0717f952da130ca9a486dfc" @@ -208,18 +200,18 @@ [[projects]] branch = "master" - digest = "1:086760278d762dbb0e9a26e09b57f04c89178c86467d8d94fae47d64c222f328" + digest = "1:bf61fa9b53be5ce096004599b957e5957b28a5e421b724250aa06ecb7ee6dc57" name = "golang.org/x/crypto" packages = [ "ed25519", "ed25519/internal/edwards25519", ] pruneopts = "" - revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" + revision = "4b2356b1ed79e6be3deca3737a3db3d132d2847a" [[projects]] branch = "master" - digest = "1:31cd6e3c114e17c5f0c9e8b0bcaa3025ab3c221ce36323c7ce1acaa753d0d0aa" + digest = "1:ea84836e35d7a66c9b8944796295912509c80c921244bc4e098c5417219895f2" name = "golang.org/x/net" packages = [ "bpf", @@ -232,7 +224,7 @@ "publicsuffix", ] pruneopts = "" - revision = "da137c7871d730100384dbcf36e6f8fa493aef5b" + revision = "7e3656a0809f6f95abd88ac65313578f80b00df2" [[projects]] branch = "master" @@ -244,9 +236,10 @@ [[projects]] branch = "master" - digest = "1:2579a16d8afda9c9a475808c13324f5e572852e8927905ffa15bb14e71baba4f" + digest = "1:acb3b56e190190ac9497faf5f0c30c5da4d3e8278d6b7a7042f2aa3332ff7022" name = "golang.org/x/sys" packages = [ + "internal/unsafeheader", "unix", "windows", "windows/registry", @@ -256,7 +249,7 @@ "windows/svc/mgr", ] pruneopts = "" - revision = "04f50cda93cbb67f2afa353c52f342100e80e625" + revision = "bc7a7d42d5c30f4d0fe808715c002826ce2c624e" [[projects]] digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" @@ -283,6 +276,38 @@ revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" version = "v0.3.2" +[[projects]] + branch = "master" + digest = "1:3416c611e00178b07c8fc347ba96398e4d6709fe7d3fab17f0b0fa6f933b4bd1" + name = "golang.org/x/tools" + packages = [ + "go/ast/astutil", + "go/gcexportdata", + "go/internal/gcimporter", + "go/internal/packagesdriver", + "go/packages", + "go/types/typeutil", + "internal/event", + "internal/event/core", + "internal/event/keys", + "internal/event/label", + "internal/gocommand", + "internal/packagesinternal", + ] + pruneopts = "" + revision = "b8469989bc69e50ec6dc4e4513fc3ff9ce48b8af" + +[[projects]] + branch = "master" + digest = "1:9d4ac09a835404ae9306c6e1493cf800ecbb0f3f828f4333b3e055de4c962eea" + name = "golang.org/x/xerrors" + packages = [ + ".", + "internal", + ] + pruneopts = "" + revision = "9bdfabe68543c54f90421aeb9a60ef8061b5b544" + [[projects]] digest = "1:2efc9662a6a1ff28c65c84fc2f9030f13d3afecdb2ecad445f3b0c80e75fc281" name = "gopkg.in/yaml.v2" diff --git a/Gopkg.toml b/Gopkg.toml index 764c45b2..f449d026 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -25,3 +25,7 @@ # unused-packages = true ignored = ["github.com/safing/portbase/*"] + +[[constraint]] + name = "github.com/miekg/dns" + branch = "master" # switch back to semver releases when https://github.com/miekg/dns/pull/1110 is released From ad93b199682b9879ef595b431d90fff8650fce44 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 09:22:50 +0200 Subject: [PATCH 22/36] Switch Exists function of network state pkg to use packet.Info --- network/clean.go | 21 ++++++++++++----- network/state/exists.go | 50 +++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/network/clean.go b/network/clean.go index efeadce9..3d3a8c23 100644 --- a/network/clean.go +++ b/network/clean.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/safing/portmaster/network/packet" + "github.com/safing/portmaster/network/state" "github.com/safing/portbase/log" @@ -42,7 +44,7 @@ func cleanConnections() (activePIDs map[int]struct{}) { now := time.Now().UTC() nowUnix := now.Unix() - deleteOlderThan := time.Now().Add(-deleteConnsAfterEndedThreshold).Unix() + deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() // lock both together because we cannot fully guarantee in which map a connection lands // of course every connection should land in the correct map, but this increases resilience @@ -59,12 +61,19 @@ func cleanConnections() (activePIDs map[int]struct{}) { switch { case conn.Ended == 0: // Step 1: check if still active - exists := state.Exists(conn.IPVersion, conn.IPProtocol, conn.LocalIP, conn.LocalPort, conn.Entity.IP, conn.Entity.Port, now) - if exists { - activePIDs[conn.process.Pid] = struct{}{} - } else { + exists := state.Exists(&packet.Info{ + Inbound: false, // src == local + Version: conn.IPVersion, + Protocol: conn.IPProtocol, + Src: conn.LocalIP, + SrcPort: conn.LocalPort, + Dst: conn.Entity.IP, + DstPort: conn.Entity.Port, + }, now) + activePIDs[conn.process.Pid] = struct{}{} + + if !exists { // Step 2: mark end - activePIDs[conn.process.Pid] = struct{}{} conn.Ended = nowUnix conn.Save() } diff --git a/network/state/exists.go b/network/state/exists.go index 9af6979f..68dd8288 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -1,7 +1,6 @@ package state import ( - "net" "time" "github.com/safing/portmaster/network/packet" @@ -12,49 +11,38 @@ const ( UDPConnectionTTL = 10 * time.Minute ) -func Exists( - ipVersion packet.IPVersion, - protocol packet.IPProtocol, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, - now time.Time, -) (exists bool) { - +func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { switch { - case ipVersion == packet.IPv4 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: tcp4Lock.Lock() defer tcp4Lock.Unlock() - return existsTCP(tcp4Connections, localIP, localPort, remoteIP, remotePort) + return existsTCP(tcp4Connections, pktInfo) - case ipVersion == packet.IPv6 && protocol == packet.TCP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.TCP: tcp6Lock.Lock() defer tcp6Lock.Unlock() - return existsTCP(tcp6Connections, localIP, localPort, remoteIP, remotePort) + return existsTCP(tcp6Connections, pktInfo) - case ipVersion == packet.IPv4 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.UDP: udp4Lock.Lock() defer udp4Lock.Unlock() - return existsUDP(udp4Binds, udp4States, localIP, localPort, remoteIP, remotePort, now) + return existsUDP(udp4Binds, udp4States, pktInfo, now) - case ipVersion == packet.IPv6 && protocol == packet.UDP: + case pktInfo.Version == packet.IPv6 && pktInfo.Protocol == packet.UDP: udp6Lock.Lock() defer udp6Lock.Unlock() - return existsUDP(udp6Binds, udp6States, localIP, localPort, remoteIP, remotePort, now) + return existsUDP(udp6Binds, udp6States, pktInfo, now) default: return false } } -func existsTCP( - connections []*socket.ConnectionInfo, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, -) (exists bool) { +func existsTCP(connections []*socket.ConnectionInfo, pktInfo *packet.Info) (exists bool) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() // search connections for _, socketInfo := range connections { @@ -72,13 +60,15 @@ func existsTCP( func existsUDP( binds []*socket.BindInfo, udpStates map[string]map[string]*udpState, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, + pktInfo *packet.Info, now time.Time, ) (exists bool) { + localIP := pktInfo.LocalIP() + localPort := pktInfo.LocalPort() + remoteIP := pktInfo.RemoteIP() + remotePort := pktInfo.RemotePort() + connThreshhold := now.Add(-UDPConnectionTTL) // search binds From 3f9876fc098054cf7000294740bbacf6b7509fde Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 09:39:03 +0200 Subject: [PATCH 23/36] Expose network system state table to api --- network/database.go | 16 ++++++++------ network/state/info.go | 50 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 network/state/info.go diff --git a/network/database.go b/network/database.go index 5910e92d..460307c0 100644 --- a/network/database.go +++ b/network/database.go @@ -5,6 +5,8 @@ import ( "strings" "sync" + "github.com/safing/portmaster/network/state" + "github.com/safing/portbase/database" "github.com/safing/portbase/database/iterator" "github.com/safing/portbase/database/query" @@ -57,13 +59,13 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { return conn, nil } } - // case "system": - // if len(splitted) >= 2 { - // switch splitted[1] { - // case "": - // process.Get - // } - // } + case "system": + if len(splitted) >= 2 { + switch splitted[1] { + case "state": + return state.GetStateInfo(), nil + } + } } return nil, storage.ErrNotFound diff --git a/network/state/info.go b/network/state/info.go new file mode 100644 index 00000000..292a8ec7 --- /dev/null +++ b/network/state/info.go @@ -0,0 +1,50 @@ +package state + +import ( + "sync" + + "github.com/safing/portbase/database/record" + + "github.com/safing/portmaster/network/socket" +) + +type StateInfo struct { + record.Base + sync.Mutex + + TCP4Connections []*socket.ConnectionInfo + TCP4Listeners []*socket.BindInfo + TCP6Connections []*socket.ConnectionInfo + TCP6Listeners []*socket.BindInfo + UDP4Binds []*socket.BindInfo + UDP6Binds []*socket.BindInfo +} + +func GetStateInfo() *StateInfo { + info := &StateInfo{} + + tcp4Lock.Lock() + updateTCP4Tables() + info.TCP4Connections = tcp4Connections + info.TCP4Listeners = tcp4Listeners + tcp4Lock.Unlock() + + tcp6Lock.Lock() + updateTCP6Tables() + info.TCP6Connections = tcp6Connections + info.TCP6Listeners = tcp6Listeners + tcp6Lock.Unlock() + + udp4Lock.Lock() + updateUDP4Table() + info.UDP4Binds = udp4Binds + udp4Lock.Unlock() + + udp6Lock.Lock() + updateUDP6Table() + info.UDP6Binds = udp6Binds + udp6Lock.Unlock() + + info.UpdateMeta() + return info +} From c146a61704c2ce4671761a367411906b702876cb Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 09:54:50 +0200 Subject: [PATCH 24/36] Improve waiting when searching the system state table --- network/state/lookup.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/network/state/lookup.go b/network/state/lookup.go index aff461f0..da8c2eec 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -40,7 +40,7 @@ var ( udp4Lock sync.Mutex udp6Lock sync.Mutex - waitTime = 3 * time.Millisecond + baseWaitTime = 3 * time.Millisecond ) func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { @@ -94,7 +94,7 @@ func searchTCP( localPort := pktInfo.LocalPort() // search until we find something - for i := 0; i < 5; i++ { + for i := 0; i < 7; i++ { // always search listeners first for _, socketInfo := range listeners { if localPort == socketInfo.Local.Port && @@ -112,7 +112,8 @@ func searchTCP( } // we found nothing, we could have been too fast, give the kernel some time to think - time.Sleep(waitTime) + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) // refetch lists connections, listeners = updateTables() @@ -162,7 +163,8 @@ func searchUDP( } // we found nothing, we could have been too fast, give the kernel some time to think - time.Sleep(waitTime) + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) // refetch lists binds = updateTable() From 65a34561650286ae1b937845b55e537877fb3453 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 16:57:13 +0200 Subject: [PATCH 25/36] Improve block reason in dns response --- intel/block_reason.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/intel/block_reason.go b/intel/block_reason.go index 26bd0a2a..52ad159f 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -71,8 +71,8 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. for _, lm := range br { blockedBy, err := dns.NewRR(fmt.Sprintf( - "%s-blockedBy. 0 IN TXT %q", - strings.TrimRight(lm.Entity, "."), + `%s 0 IN TXT "was blocked by filter lists %s"`, + lm.Entity, strings.Join(lm.ActiveLists, ","), )) if err == nil { @@ -83,8 +83,8 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. if len(lm.InactiveLists) > 0 { wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( - "%s-wouldBeBlockedBy. 0 IN TXT %q", - strings.TrimRight(lm.Entity, "."), + `%s 0 IN TXT "would be blocked by filter lists %s"`, + lm.Entity, strings.Join(lm.InactiveLists, ","), )) if err == nil { From e65ae8b55d441ae21cdba614a8b1f2bb65909553 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 19 May 2020 16:57:55 +0200 Subject: [PATCH 26/36] Implement review suggestions --- nameserver/nameserver.go | 2 +- network/connection.go | 2 +- network/iphelper/get.go | 9 ++--- network/iphelper/tables.go | 21 +++++++---- network/proc/findpid.go | 16 ++++---- network/proc/tables.go | 71 ++++++++++++++++++++++------------- network/proc/tables_test.go | 12 +++--- network/socket/socket.go | 5 +++ network/state/exists.go | 9 ++++- network/state/lookup.go | 10 ++--- network/state/system_linux.go | 14 +------ network/state/tables.go | 22 ++++++----- network/state/udp.go | 29 +++++++------- ui/module.go | 8 ++++ 14 files changed, 130 insertions(+), 100 deletions(-) diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 1e955c4a..a06fbf61 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -169,7 +169,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } // start tracer - ctx, tracer := log.AddTracer(context.Background()) + ctx, tracer := log.AddTracer(ctx) defer tracer.Submit() tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) diff --git a/network/connection.go b/network/connection.go index 030c84b6..2fa5f9df 100644 --- a/network/connection.go +++ b/network/connection.go @@ -37,7 +37,7 @@ type Connection struct { //nolint:maligned // TODO: fix alignment process *process.Process // remote endpoint - Entity *intel.Entity // needs locking, instance is never shared + Entity *intel.Entity Verdict Verdict Reason string diff --git a/network/iphelper/get.go b/network/iphelper/get.go index 80f3352f..e92f929c 100644 --- a/network/iphelper/get.go +++ b/network/iphelper/get.go @@ -8,13 +8,12 @@ import ( "github.com/safing/portmaster/network/socket" ) -const ( - unidentifiedProcessID = -1 -) - var ( ipHelper *IPHelper - lock sync.RWMutex + + // lock locks access to the whole DLL. + // TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here. + lock sync.RWMutex ) // GetTCP4Table returns the system table for IPv4 TCP activity. diff --git a/network/iphelper/tables.go b/network/iphelper/tables.go index 2dcaf5c1..8c1fd8a7 100644 --- a/network/iphelper/tables.go +++ b/network/iphelper/tables.go @@ -214,7 +214,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so binds = append(binds, &socket.BindInfo{ Local: socket.Address{ IP: convertIPv4(row.localAddr), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, PID: int(row.owningPid), }) @@ -222,11 +222,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so connections = append(connections, &socket.ConnectionInfo{ Local: socket.Address{ IP: convertIPv4(row.localAddr), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, Remote: socket.Address{ IP: convertIPv4(row.remoteAddr), - Port: uint16(row.remotePort>>8 | row.remotePort<<8), + Port: convertPort(row.remotePort), }, PID: int(row.owningPid), }) @@ -243,7 +243,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so binds = append(binds, &socket.BindInfo{ Local: socket.Address{ IP: net.IP(row.localAddr[:]), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, PID: int(row.owningPid), }) @@ -251,11 +251,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so connections = append(connections, &socket.ConnectionInfo{ Local: socket.Address{ IP: net.IP(row.localAddr[:]), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, Remote: socket.Address{ IP: net.IP(row.remoteAddr[:]), - Port: uint16(row.remotePort>>8 | row.remotePort<<8), + Port: convertPort(row.remotePort), }, PID: int(row.owningPid), }) @@ -271,7 +271,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so binds = append(binds, &socket.BindInfo{ Local: socket.Address{ IP: convertIPv4(row.localAddr), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, PID: int(row.owningPid), }) @@ -286,7 +286,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so binds = append(binds, &socket.BindInfo{ Local: socket.Address{ IP: net.IP(row.localAddr[:]), - Port: uint16(row.localPort>>8 | row.localPort<<8), + Port: convertPort(row.localPort), }, PID: int(row.owningPid), }) @@ -303,3 +303,8 @@ func convertIPv4(input uint32) net.IP { binary.LittleEndian.PutUint32(addressBuf, input) return net.IP(addressBuf) } + +// convertPort converts ports received from iphlpapi.dll +func convertPort(input uint32) uint16 { + return uint16(input>>8 | input<<8) +} diff --git a/network/proc/findpid.go b/network/proc/findpid.go index ce54984e..6808960e 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -10,11 +10,9 @@ import ( "sync" "syscall" - "github.com/safing/portbase/log" -) + "github.com/safing/portmaster/network/socket" -const ( - unidentifiedProcessID = -1 + "github.com/safing/portbase/log" ) var ( @@ -23,7 +21,7 @@ var ( ) // FindPID returns the pid of the given uid and socket inode. -func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO +func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO pidsByUserLock.Lock() defer pidsByUserLock.Unlock() @@ -42,7 +40,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO var checkedUserPids []int for _, possiblePID := range pids { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } checkedUserPids = append(checkedUserPids, possiblePID) } @@ -61,7 +59,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO // only check if not already checked if sort.SearchInts(checkedUserPids, possiblePID) == len { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } } } @@ -75,13 +73,13 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO if possibleUID != uid { for _, possiblePID := range pids { if findSocketFromPid(possiblePID, inode) { - return possiblePID, true + return possiblePID } } } } - return unidentifiedProcessID, false + return socket.UnidentifiedProcessID } func findSocketFromPid(pid, inode int) bool { diff --git a/network/proc/tables.go b/network/proc/tables.go index b5a652a1..bf4a3eb0 100644 --- a/network/proc/tables.go +++ b/network/proc/tables.go @@ -5,6 +5,7 @@ package proc import ( "bufio" "encoding/hex" + "fmt" "net" "os" "strconv" @@ -43,12 +44,10 @@ const ( ICMP4 ICMP6 - TCP4Data = "/proc/net/tcp" - UDP4Data = "/proc/net/udp" - TCP6Data = "/proc/net/tcp6" - UDP6Data = "/proc/net/udp6" - ICMP4Data = "/proc/net/icmp" - ICMP6Data = "/proc/net/icmp6" + tcp4ProcFile = "/proc/net/tcp" + tcp6ProcFile = "/proc/net/tcp6" + udp4ProcFile = "/proc/net/udp" + udp6ProcFile = "/proc/net/udp6" UnfetchedProcessID = -2 @@ -57,27 +56,47 @@ const ( // GetTCP4Table returns the system table for IPv4 TCP activity. func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { - return getTableFromSource(TCP4, TCP4Data, convertIPv4) + return getTableFromSource(TCP4, tcp4ProcFile) } // GetTCP6Table returns the system table for IPv6 TCP activity. func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { - return getTableFromSource(TCP6, TCP6Data, convertIPv6) + return getTableFromSource(TCP6, tcp6ProcFile) } // GetUDP4Table returns the system table for IPv4 UDP activity. func GetUDP4Table() (binds []*socket.BindInfo, err error) { - _, binds, err = getTableFromSource(UDP4, UDP4Data, convertIPv4) + _, binds, err = getTableFromSource(UDP4, udp4ProcFile) return } // GetUDP6Table returns the system table for IPv6 UDP activity. func GetUDP6Table() (binds []*socket.BindInfo, err error) { - _, binds, err = getTableFromSource(UDP6, UDP6Data, convertIPv6) + _, binds, err = getTableFromSource(UDP6, udp6ProcFile) return } -func getTableFromSource(stack uint8, procFile string, ipConverter func(string) net.IP) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { +const ( + // hint: we split fields by multiple delimiters, see procDelimiter + fieldIndexLocalIP = 1 + fieldIndexLocalPort = 2 + fieldIndexRemoteIP = 3 + fieldIndexRemotePort = 4 + fieldIndexUID = 11 + fieldIndexInode = 13 +) + +func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { + + var ipConverter func(string) net.IP + switch stack { + case TCP4, UDP4: + ipConverter = convertIPv4 + case TCP6, UDP6: + ipConverter = convertIPv6 + default: + return nil, nil, fmt.Errorf("unsupported table stack: %d", stack) + } // open file socketData, err := os.Open(procFile) @@ -91,36 +110,36 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n scanner.Split(bufio.ScanLines) // parse - scanner.Scan() // skip first line + scanner.Scan() // skip first row for scanner.Scan() { - line := strings.FieldsFunc(scanner.Text(), procDelimiter) - if len(line) < 14 { - // log.Tracef("process: too short: %s", line) + fields := strings.FieldsFunc(scanner.Text(), procDelimiter) + if len(fields) < 14 { + // log.Tracef("process: too short: %s", fields) continue } - localIP := ipConverter(line[1]) + localIP := ipConverter(fields[fieldIndexLocalIP]) if localIP == nil { continue } - localPort, err := strconv.ParseUint(line[2], 16, 16) + localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16) if err != nil { log.Warningf("process: could not parse port: %s", err) continue } - uid, err := strconv.ParseInt(line[11], 10, 32) - // log.Tracef("uid: %s", line[11]) + uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32) + // log.Tracef("uid: %s", fields[fieldIndexUID]) if err != nil { - log.Warningf("process: could not parse uid %s: %s", line[11], err) + log.Warningf("process: could not parse uid %s: %s", fields[11], err) continue } - inode, err := strconv.ParseInt(line[13], 10, 32) - // log.Tracef("inode: %s", line[13]) + inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32) + // log.Tracef("inode: %s", fields[fieldIndexInode]) if err != nil { - log.Warningf("process: could not parse inode %s: %s", line[13], err) + log.Warningf("process: could not parse inode %s: %s", fields[13], err) continue } @@ -139,7 +158,7 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n case TCP4, TCP6: - if line[5] == tcpListenStateHex { + if fields[5] == tcpListenStateHex { // listener binds = append(binds, &socket.BindInfo{ @@ -154,12 +173,12 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n } else { // connection - remoteIP := ipConverter(line[3]) + remoteIP := ipConverter(fields[fieldIndexRemoteIP]) if remoteIP == nil { continue } - remotePort, err := strconv.ParseUint(line[4], 16, 16) + remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16) if err != nil { log.Warningf("process: could not parse port: %s", err) continue diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go index 9dc7c1eb..eed12ce8 100644 --- a/network/proc/tables_test.go +++ b/network/proc/tables_test.go @@ -14,12 +14,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 4 connections:") for _, connection := range connections { - pid, _ := FindPID(connection.UID, connection.Inode) + pid := FindPID(connection.UID, connection.Inode) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 4 listeners:") for _, listener := range listeners { - pid, _ := FindPID(listener.UID, listener.Inode) + pid := FindPID(listener.UID, listener.Inode) fmt.Printf("%d: %+v\n", pid, listener) } @@ -29,12 +29,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 6 connections:") for _, connection := range connections { - pid, _ := FindPID(connection.UID, connection.Inode) + pid := FindPID(connection.UID, connection.Inode) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 6 listeners:") for _, listener := range listeners { - pid, _ := FindPID(listener.UID, listener.Inode) + pid := FindPID(listener.UID, listener.Inode) fmt.Printf("%d: %+v\n", pid, listener) } @@ -44,7 +44,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 4 binds:") for _, bind := range binds { - pid, _ := FindPID(bind.UID, bind.Inode) + pid := FindPID(bind.UID, bind.Inode) fmt.Printf("%d: %+v\n", pid, bind) } @@ -54,7 +54,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 6 binds:") for _, bind := range binds { - pid, _ := FindPID(bind.UID, bind.Inode) + pid := FindPID(bind.UID, bind.Inode) fmt.Printf("%d: %+v\n", pid, bind) } } diff --git a/network/socket/socket.go b/network/socket/socket.go index a599eddf..e8dfe1d9 100644 --- a/network/socket/socket.go +++ b/network/socket/socket.go @@ -2,6 +2,11 @@ package socket import "net" +const ( + // UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops. + UnidentifiedProcessID = -1 +) + // ConnectionInfo holds socket information returned by the system. type ConnectionInfo struct { Local Address diff --git a/network/state/exists.go b/network/state/exists.go index 68dd8288..f64e4a64 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -11,7 +11,11 @@ const ( UDPConnectionTTL = 10 * time.Minute ) +// Exists checks if the given connection is present in the system state tables. func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { + + // TODO: create lookup maps before running a flurry of Exists() checks. + switch { case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: tcp4Lock.Lock() @@ -76,7 +80,10 @@ func existsUDP( if localPort == socketInfo.Local.Port && (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { - udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort) + udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{ + IP: remoteIP, + Port: remotePort, + }) switch { case !ok: return false diff --git a/network/state/lookup.go b/network/state/lookup.go index da8c2eec..8346072e 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -24,10 +24,6 @@ import ( // - switch direction to outbound if outbound packet is seen? // - IP: Unidentified Process -const ( - UnidentifiedProcessID = -1 -) - // Errors var ( ErrConnectionNotFound = errors.New("could not find connection in system state tables") @@ -75,7 +71,7 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo) default: - return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") + return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") } } @@ -119,7 +115,7 @@ func searchTCP( connections, listeners = updateTables() } - return UnidentifiedProcessID, false, ErrConnectionNotFound + return socket.UnidentifiedProcessID, false, ErrConnectionNotFound } func searchUDP( @@ -170,5 +166,5 @@ func searchUDP( binds = updateTable() } - return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound + return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound } diff --git a/network/state/system_linux.go b/network/state/system_linux.go index a08fd86b..b902c58c 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -14,24 +14,14 @@ var ( func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { if socketInfo.PID == proc.UnfetchedProcessID { - pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) - if ok { - socketInfo.PID = pid - } else { - socketInfo.PID = UnidentifiedProcessID - } + socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) } return socketInfo.PID, connInbound, nil } func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { if socketInfo.PID == proc.UnfetchedProcessID { - pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) - if ok { - socketInfo.PID = pid - } else { - socketInfo.PID = UnidentifiedProcessID - } + socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) } return socketInfo.PID, connInbound, nil } diff --git a/network/state/tables.go b/network/state/tables.go index 59095a16..2f236cc6 100644 --- a/network/state/tables.go +++ b/network/state/tables.go @@ -18,9 +18,8 @@ var ( ) func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { - // FIXME: repeatable once - - connections, listeners, err := getTCP4Table() + var err error + connections, listeners, err = getTCP4Table() if err != nil { log.Warningf("state: failed to get TCP4 socket table: %s", err) return @@ -28,11 +27,12 @@ func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*sock tcp4Connections = connections tcp4Listeners = listeners - return tcp4Connections, tcp4Listeners + return } func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { - connections, listeners, err := getTCP6Table() + var err error + connections, listeners, err = getTCP6Table() if err != nil { log.Warningf("state: failed to get TCP6 socket table: %s", err) return @@ -40,27 +40,29 @@ func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*sock tcp6Connections = connections tcp6Listeners = listeners - return tcp6Connections, tcp6Listeners + return } func updateUDP4Table() (binds []*socket.BindInfo) { - binds, err := getUDP4Table() + var err error + binds, err = getUDP4Table() if err != nil { log.Warningf("state: failed to get UDP4 socket table: %s", err) return } udp4Binds = binds - return udp4Binds + return } func updateUDP6Table() (binds []*socket.BindInfo) { - binds, err := getUDP6Table() + var err error + binds, err = getUDP6Table() if err != nil { log.Warningf("state: failed to get UDP6 socket table: %s", err) return } udp6Binds = binds - return udp6Binds + return } diff --git a/network/state/udp.go b/network/state/udp.go index ccfb0815..46966961 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -2,7 +2,6 @@ package state import ( "context" - "net" "time" "github.com/safing/portmaster/network/packet" @@ -15,7 +14,7 @@ type udpState struct { } const ( - UpdConnStateTTL = 72 * time.Hour + UdpConnStateTTL = 72 * time.Hour UdpConnStateShortenedTTL = 3 * time.Hour AggressiveCleaningThreshold = 256 ) @@ -25,10 +24,10 @@ var ( udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock ) -func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) { - bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] +func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) { + bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)] if ok { - udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)] + udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)] return } @@ -36,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin } func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) { - localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port) + localKey := makeUDPStateKey(socketInfo.Local) bindMap, ok := udpStates[localKey] if !ok { @@ -44,7 +43,10 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin udpStates[localKey] = bindMap } - remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort()) + remoteKey := makeUDPStateKey(socket.Address{ + IP: pktInfo.RemoteIP(), + Port: pktInfo.RemotePort(), + }) udpConnState, ok := bindMap[remoteKey] if !ok { bindMap[remoteKey] = &udpState{ @@ -79,19 +81,18 @@ func cleanStates( now time.Time, ) { // compute thresholds - threshold := now.Add(-UpdConnStateTTL) + threshold := now.Add(-UdpConnStateTTL) shortThreshhold := now.Add(-UdpConnStateShortenedTTL) - // make list of all active keys + // make lookup map of all active keys bindKeys := make(map[string]struct{}) for _, socketInfo := range binds { - bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{} + bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{} } // clean the udp state storage for localKey, bindMap := range udpStates { - _, active := bindKeys[localKey] - if active { + if _, active := bindKeys[localKey]; active { // clean old entries for remoteKey, udpConnState := range bindMap { if udpConnState.lastSeen.Before(threshold) { @@ -113,7 +114,7 @@ func cleanStates( } } -func makeUDPStateKey(ip net.IP, port uint16) string { +func makeUDPStateKey(address socket.Address) string { // This could potentially go wrong, but as all IPs are created by the same source, everything should be fine. - return string(ip) + string(port) + return string(address.IP) + string(address.Port) } diff --git a/ui/module.go b/ui/module.go index 43bf667a..8fdfed9d 100644 --- a/ui/module.go +++ b/ui/module.go @@ -29,6 +29,14 @@ func prep() error { } func start() error { + // Create a dummy directory to which processes change their working directory + // to. Currently this includes the App and the Notifier. The aim is protect + // all other directories and increase compatibility should any process want + // to read or write something to the current working directory. This can also + // be useful in the future to dump data to for debugging. The permission used + // may seem dangerous, but proper permission on the parent directory provide + // (some) protection. + // Processes must _never_ read from this directory. err := dataroot.Root().ChildDir("exec", 0777).Ensure() if err != nil { log.Warningf("ui: failed to create safe exec dir: %s", err) From f1765a7abbc70b5e027c73fcd2f1e9650c3f7b03 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 14:53:03 +0200 Subject: [PATCH 27/36] Fix linter errors --- network/database.go | 3 ++- network/module.go | 5 ----- network/state/exists.go | 1 + network/state/info.go | 8 +++++--- network/state/lookup.go | 1 + network/state/udp.go | 21 +++++++++++++-------- process/find.go | 2 +- updates/upgrader.go | 1 + 8 files changed, 24 insertions(+), 18 deletions(-) diff --git a/network/database.go b/network/database.go index 460307c0..a44a379c 100644 --- a/network/database.go +++ b/network/database.go @@ -63,7 +63,8 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { if len(splitted) >= 2 { switch splitted[1] { case "state": - return state.GetStateInfo(), nil + return state.GetInfo(), nil + default: } } } diff --git a/network/module.go b/network/module.go index 70f2fd24..a1ee69f5 100644 --- a/network/module.go +++ b/network/module.go @@ -1,17 +1,12 @@ package network import ( - "net" - "github.com/safing/portbase/modules" ) var ( module *modules.Module - dnsAddress = net.IPv4(127, 0, 0, 1) - dnsPort uint16 = 53 - defaultFirewallHandler FirewallHandler ) diff --git a/network/state/exists.go b/network/state/exists.go index f64e4a64..7b308608 100644 --- a/network/state/exists.go +++ b/network/state/exists.go @@ -8,6 +8,7 @@ import ( ) const ( + // UDPConnectionTTL defines the duration after which unseen UDP connections are regarded as ended. UDPConnectionTTL = 10 * time.Minute ) diff --git a/network/state/info.go b/network/state/info.go index 292a8ec7..5d4b0d4d 100644 --- a/network/state/info.go +++ b/network/state/info.go @@ -8,7 +8,8 @@ import ( "github.com/safing/portmaster/network/socket" ) -type StateInfo struct { +// Info holds network state information as provided by the system. +type Info struct { record.Base sync.Mutex @@ -20,8 +21,9 @@ type StateInfo struct { UDP6Binds []*socket.BindInfo } -func GetStateInfo() *StateInfo { - info := &StateInfo{} +// GetInfo returns all system state tables. The returned data must not be modified. +func GetInfo() *Info { + info := &Info{} tcp4Lock.Lock() updateTCP4Tables() diff --git a/network/state/lookup.go b/network/state/lookup.go index 8346072e..5aadf7fa 100644 --- a/network/state/lookup.go +++ b/network/state/lookup.go @@ -39,6 +39,7 @@ var ( baseWaitTime = 3 * time.Millisecond ) +// Lookup looks for the given connection in the system state tables and returns the PID of the associated process and whether the connection is inbound. func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) { // auto-detect version if pktInfo.Version == 0 { diff --git a/network/state/udp.go b/network/state/udp.go index 46966961..f49b1d04 100644 --- a/network/state/udp.go +++ b/network/state/udp.go @@ -14,8 +14,13 @@ type udpState struct { } const ( - UdpConnStateTTL = 72 * time.Hour - UdpConnStateShortenedTTL = 3 * time.Hour + // UDPConnStateTTL is the maximum time a udp connection state is held. + UDPConnStateTTL = 72 * time.Hour + + // UDPConnStateShortenedTTL is a shortened maximum time a udp connection state is held, if there more entries than defined by AggressiveCleaningThreshold. + UDPConnStateShortenedTTL = 3 * time.Hour + + // AggressiveCleaningThreshold defines the soft limit of udp connection state held per udp socket. AggressiveCleaningThreshold = 256 ) @@ -60,29 +65,29 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin return udpConnState.inbound } -func CleanUDPStates(ctx context.Context) { +// CleanUDPStates cleans the udp connection states which save connection directions. +func CleanUDPStates(_ context.Context) { now := time.Now().UTC() udp4Lock.Lock() updateUDP4Table() - cleanStates(ctx, udp4Binds, udp4States, now) + cleanStates(udp4Binds, udp4States, now) udp4Lock.Unlock() udp6Lock.Lock() updateUDP6Table() - cleanStates(ctx, udp6Binds, udp6States, now) + cleanStates(udp6Binds, udp6States, now) udp6Lock.Unlock() } func cleanStates( - ctx context.Context, binds []*socket.BindInfo, udpStates map[string]map[string]*udpState, now time.Time, ) { // compute thresholds - threshold := now.Add(-UdpConnStateTTL) - shortThreshhold := now.Add(-UdpConnStateShortenedTTL) + threshold := now.Add(-UDPConnStateTTL) + shortThreshhold := now.Add(-UDPConnStateShortenedTTL) // make lookup map of all active keys bindKeys := make(map[string]struct{}) diff --git a/process/find.go b/process/find.go index a7e214cf..50070949 100644 --- a/process/find.go +++ b/process/find.go @@ -15,7 +15,7 @@ var ( ErrProcessNotFound = errors.New("could not find process in system state tables") ) -// GetProcessByEndpoints returns the process that owns the described link. +// GetProcessByConnection returns the process that owns the described connection. func GetProcessByConnection(ctx context.Context, pktInfo *packet.Info) (process *Process, connInbound bool, err error) { if !enableProcessDetection() { log.Tracer(ctx).Tracef("process: process detection disabled") diff --git a/updates/upgrader.go b/updates/upgrader.go index b109c913..47f8ea36 100644 --- a/updates/upgrader.go +++ b/updates/upgrader.go @@ -230,6 +230,7 @@ func upgradeFile(fileToUpgrade string, file *updater.File) error { return nil } +// CopyFile atomically copies a file using the update registry's tmp dir. func CopyFile(srcPath, dstPath string) (err error) { // check tmp dir From c48f8e5782d45a8c4b64a53cb887c4dae885b575 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 14:53:14 +0200 Subject: [PATCH 28/36] Fix endpoint scope --- profile/endpoints/endpoint-scopes.go | 28 +++++++++++----------------- profile/endpoints/endpoint.go | 2 +- profile/endpoints/endpoint_test.go | 6 ++++++ profile/endpoints/endpoints_test.go | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/profile/endpoints/endpoint-scopes.go b/profile/endpoints/endpoint-scopes.go index 1c73aebe..ea22126d 100644 --- a/profile/endpoints/endpoint-scopes.go +++ b/profile/endpoints/endpoint-scopes.go @@ -29,10 +29,6 @@ type EndpointScope struct { scopes uint8 } -// Localhost -// LAN -// Internet - // Matches checks whether the given entity matches this endpoint definition. func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { if entity.IP == nil { @@ -64,16 +60,14 @@ func (ep *EndpointScope) Matches(entity *intel.Entity) (EPResult, Reason) { // Scopes returns the string representation of all scopes. func (ep *EndpointScope) Scopes() string { - if ep.scopes == 3 || ep.scopes > 4 { - // single scope - switch ep.scopes { - case scopeLocalhost: - return scopeLocalhostName - case scopeLAN: - return scopeLANName - case scopeInternet: - return scopeInternetName - } + // single scope + switch ep.scopes { + case scopeLocalhost: + return scopeLocalhostName + case scopeLAN: + return scopeLANName + case scopeInternet: + return scopeInternetName } // multiple scopes @@ -99,11 +93,11 @@ func parseTypeScope(fields []string) (Endpoint, error) { for _, val := range strings.Split(strings.ToLower(fields[1]), ",") { switch val { case scopeLocalhostMatcher: - ep.scopes &= scopeLocalhost + ep.scopes ^= scopeLocalhost case scopeLANMatcher: - ep.scopes &= scopeLAN + ep.scopes ^= scopeLAN case scopeInternetMatcher: - ep.scopes &= scopeInternet + ep.scopes ^= scopeInternet default: return nil, nil } diff --git a/profile/endpoints/endpoint.go b/profile/endpoints/endpoint.go index 4e73d1d4..2e0a4e85 100644 --- a/profile/endpoints/endpoint.go +++ b/profile/endpoints/endpoint.go @@ -201,7 +201,7 @@ func invalidDefinitionError(fields []string, msg string) error { return fmt.Errorf(`invalid endpoint definition: "%s" - %s`, strings.Join(fields, " "), msg) } -func parseEndpoint(value string) (endpoint Endpoint, err error) { +func parseEndpoint(value string) (endpoint Endpoint, err error) { //nolint:gocognit fields := strings.Fields(value) if len(fields) < 2 { return nil, fmt.Errorf(`invalid endpoint definition: "%s"`, value) diff --git a/profile/endpoints/endpoint_test.go b/profile/endpoints/endpoint_test.go index d8aabee8..21ef057e 100644 --- a/profile/endpoints/endpoint_test.go +++ b/profile/endpoints/endpoint_test.go @@ -43,6 +43,12 @@ func TestEndpointParsing(t *testing.T) { testParsing(t, "+ AS1234") testParsing(t, "+ AS12345") + // network scope + testParsing(t, "+ Localhost") + testParsing(t, "+ LAN") + testParsing(t, "+ Internet") + testParsing(t, "+ Localhost,LAN,Internet") + // protocol and ports testParsing(t, "+ * TCP/1-1024") testParsing(t, "+ * */DNS") diff --git a/profile/endpoints/endpoints_test.go b/profile/endpoints/endpoints_test.go index ad23d352..dbc3119d 100644 --- a/profile/endpoints/endpoints_test.go +++ b/profile/endpoints/endpoints_test.go @@ -358,7 +358,7 @@ func TestEndpointMatching(t *testing.T) { // Lists - ep, err = parseEndpoint("+ L:A,B,C") + _, err = parseEndpoint("+ L:A,B,C") if err != nil { t.Fatal(err) } From c71dfaab3816506b053dbc47c367fd18d52c2749 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 14:57:17 +0200 Subject: [PATCH 29/36] Fix resolver failing --- resolver/mdns.go | 6 ++-- resolver/resolve.go | 2 +- resolver/resolver.go | 71 +++++++++++++++++++++++++++++++------------- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/resolver/mdns.go b/resolver/mdns.go index b8595d67..5fe889dd 100644 --- a/resolver/mdns.go +++ b/resolver/mdns.go @@ -45,10 +45,10 @@ func (mrc *mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, err return queryMulticastDNS(ctx, q) } -func (mrc *mDNSResolverConn) MarkFailed() {} +func (mrc *mDNSResolverConn) ReportFailure() {} -func (mrc *mDNSResolverConn) LastFail() time.Time { - return time.Time{} +func (mrc *mDNSResolverConn) IsFailing() bool { + return false } type savedQuestion struct { diff --git a/resolver/resolve.go b/resolver/resolve.go index 385054a9..ee76d5fa 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -233,7 +233,7 @@ resolveLoop: for i = 0; i < 2; i++ { for _, resolver := range resolvers { // check if resolver failed recently (on first run) - if i == 0 && resolver.Conn.LastFail().After(lastFailBoundary) { + if i == 0 && resolver.Conn.IsFailing() { log.Tracer(ctx).Tracef("resolver: skipping resolver %s, because it failed recently", resolver) continue } diff --git a/resolver/resolver.go b/resolver/resolver.go index 244b0c57..d44bba8b 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -24,6 +24,11 @@ const ( ServerSourceMDNS = "mdns" ) +var ( + // FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed. + FailThreshold = 5 +) + // Resolver holds information about an active resolver. type Resolver struct { // Server config url (and ID) @@ -84,8 +89,8 @@ func (resolver *Resolver) String() string { // ResolverConn is an interface to implement different types of query backends. type ResolverConn interface { //nolint:go-lint // TODO Query(ctx context.Context, q *Query) (*RRCache, error) - MarkFailed() - LastFail() time.Time + ReportFailure() + IsFailing() bool } // BasicResolverConn implements ResolverConn for standard dns clients. @@ -94,11 +99,13 @@ type BasicResolverConn struct { resolver *Resolver clientManager *dnsClientManager - lastFail time.Time + + lastFail time.Time + fails int } -// MarkFailed marks the resolver as failed. -func (brc *BasicResolverConn) MarkFailed() { +// ReportFailure reports that an error occurred with this resolver. +func (brc *BasicResolverConn) ReportFailure() { if !netenv.Online() { // don't mark failed if we are offline return @@ -106,14 +113,26 @@ func (brc *BasicResolverConn) MarkFailed() { brc.Lock() defer brc.Unlock() - brc.lastFail = time.Now() + now := time.Now().UTC() + failDuration := time.Duration(nameserverRetryRate()) * time.Second + + // reset fail counter if currently not failing + if now.Add(-failDuration).After(brc.lastFail) { + brc.fails = 0 + } + + // update + brc.lastFail = now + brc.fails++ } -// LastFail returns the internal lastfail value while locking the Resolver. -func (brc *BasicResolverConn) LastFail() time.Time { +// IsFailing returns if this resolver is currently failing. +func (brc *BasicResolverConn) IsFailing() bool { brc.Lock() defer brc.Unlock() - return brc.lastFail + + failDuration := time.Duration(nameserverRetryRate()) * time.Second + return brc.fails >= FailThreshold && time.Now().UTC().Add(-failDuration).Before(brc.lastFail) } // Query executes the given query against the resolver. @@ -131,9 +150,9 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er var err error var conn *dns.Conn var new bool - var i int + var tries int - for ; i < 5; i++ { + for ; tries < 3; tries++ { // first get connection dc := brc.clientManager.getDNSClient() @@ -142,13 +161,22 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er log.Tracer(ctx).Tracef("resolver: failed to connect to %s: %s", resolver.Server, err) // remove client from pool dc.destroy() + // report that resolver had an error + brc.ReportFailure() + // hint network environment at failed connection + netenv.ReportFailedConnection() + + // TODO: handle special cases + // 1. connect: network is unreachable + // 2. timeout + // try again continue } if new { - log.Tracer(ctx).Tracef("resolver: created new connection to %s", resolver.ServerAddress) + log.Tracer(ctx).Tracef("resolver: created new connection to %s (%s)", resolver.Name, resolver.ServerAddress) } else { - log.Tracer(ctx).Tracef("resolver: reusing connection to %s", resolver.ServerAddress) + log.Tracer(ctx).Tracef("resolver: reusing connection to %s (%s)", resolver.Name, resolver.ServerAddress) } // query server @@ -162,13 +190,6 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er // remove client from pool dc.destroy() - // TODO: handle special cases - // 1. connect: network is unreachable - // 2. timeout - - // hint network environment at failed connection - netenv.ReportFailedConnection() - // temporary error if nerr, ok := err.(net.Error); ok && nerr.Timeout() { log.Tracer(ctx).Tracef("resolver: retrying to resolve %s%s with %s, error is temporary", q.FQDN, q.QType, resolver.Server) @@ -176,6 +197,14 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er continue } + // report failed if dns (nothing happens at getConn()) + if resolver.ServerType == ServerTypeDNS { + // report that resolver had an error + brc.ReportFailure() + // hint network environment at failed connection + netenv.ReportFailedConnection() + } + // permanent error break } else if reply == nil { @@ -201,7 +230,7 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er return nil, err // TODO: mark as failed } else if reply == nil { - log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), i+1) + log.Errorf("resolver: queried %s for %s%s (%d tries), but reply was nil", q.FQDN, q.QType, resolver.GetName(), tries+1) return nil, errors.New("internal error") } From 36c60a1e3387551b93eba8293838b01c77f43819 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 14:57:33 +0200 Subject: [PATCH 30/36] Reload resolver on config change --- resolver/main.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/resolver/main.go b/resolver/main.go index 9c71f5db..05d20fe3 100644 --- a/resolver/main.go +++ b/resolver/main.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "strings" "time" "github.com/safing/portbase/log" @@ -30,6 +31,7 @@ func start() error { // load resolvers from config and environment loadResolvers() + // reload after network change err := module.RegisterEventHook( "netenv", "network changed", @@ -44,6 +46,27 @@ func start() error { return err } + // reload after config change + prevNameservers := strings.Join(configuredNameServers(), " ") + err = module.RegisterEventHook( + "config", + "config change", + "update nameservers", + func(_ context.Context, _ interface{}) error { + newNameservers := strings.Join(configuredNameServers(), " ") + if newNameservers != prevNameservers { + prevNameservers = newNameservers + + loadResolvers() + log.Debug("resolver: reloaded nameservers due to config change") + } + return nil + }, + ) + if err != nil { + return err + } + module.StartServiceWorker( "mdns handler", 5*time.Second, From c8223f1a630473fe9be6ad30c0d101f3b112fae7 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 14:57:47 +0200 Subject: [PATCH 31/36] Switch resolver pooling to use sync.Pool --- resolver/clients.go | 92 ++++++++++++---------------------------- resolver/pooling_test.go | 17 +++++--- resolver/resolve.go | 9 +--- resolver/resolver.go | 4 +- 4 files changed, 41 insertions(+), 81 deletions(-) diff --git a/resolver/clients.go b/resolver/clients.go index 096f2af3..e3456759 100644 --- a/resolver/clients.go +++ b/resolver/clients.go @@ -12,8 +12,9 @@ import ( const ( defaultClientTTL = 5 * time.Minute - defaultRequestTimeout = 5 * time.Second - connectionEOLGracePeriod = 10 * time.Second + defaultRequestTimeout = 3 * time.Second // dns query + defaultConnectTimeout = 2 * time.Second // tcp/tls + connectionEOLGracePeriod = 7 * time.Second ) var ( @@ -43,23 +44,17 @@ type dnsClientManager struct { factory func() *dns.Client // internal - pool []*dnsClient + pool sync.Pool } type dnsClient struct { - mgr *dnsClientManager - - inUse bool - useUntil time.Time - dead bool - inPool bool - poolIndex int - - client *dns.Client - conn *dns.Conn + mgr *dnsClientManager + client *dns.Client + conn *dns.Conn + useUntil time.Time } -// conn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). +// getConn returns the *dns.Conn and if it's new. This function may only be called between clientManager.getDNSClient() and dnsClient.done(). func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { if dc.conn == nil { dc.conn, err = dc.client.Dial(dc.mgr.serverAddress) @@ -71,23 +66,11 @@ func (dc *dnsClient) getConn() (c *dns.Conn, new bool, err error) { return dc.conn, false, nil } -func (dc *dnsClient) done() { - dc.mgr.lock.Lock() - defer dc.mgr.lock.Unlock() - - dc.inUse = false +func (dc *dnsClient) addToPool() { + dc.mgr.pool.Put(dc) } func (dc *dnsClient) destroy() { - dc.mgr.lock.Lock() - dc.inUse = true // block from being used - dc.dead = true // abort cleaning - if dc.inPool { - dc.inPool = false - dc.mgr.pool[dc.poolIndex] = nil - } - dc.mgr.lock.Unlock() - if dc.conn != nil { _ = dc.conn.Close() } @@ -118,6 +101,7 @@ func newTCPClientManager(resolver *Resolver) *dnsClientManager { Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, KeepAlive: defaultClientTTL, }, } @@ -140,6 +124,7 @@ func newTLSClientManager(resolver *Resolver) *dnsClientManager { Timeout: defaultRequestTimeout, Dialer: &net.Dialer{ LocalAddr: getLocalAddr("tcp"), + Timeout: defaultConnectTimeout, KeepAlive: defaultClientTTL, }, } @@ -159,11 +144,18 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient { } } - // get first unused from pool + // get cached client from pool now := time.Now().UTC() - for _, dc := range cm.pool { - if dc != nil && !dc.inUse && now.Before(dc.useUntil) { - dc.inUse = true + +poolLoop: + for { + dc, ok := cm.pool.Get().(*dnsClient) + switch { + case !ok || dc == nil: // cache empty (probably, pool may always return nil!) + break poolLoop // create new + case now.After(dc.useUntil): + continue // get next + default: return dc } } @@ -171,27 +163,11 @@ func (cm *dnsClientManager) getDNSClient() *dnsClient { // no available in pool, create new newClient := &dnsClient{ mgr: cm, - inUse: true, - useUntil: now.Add(cm.ttl), - inPool: true, client: cm.factory(), + useUntil: now.Add(cm.ttl), } newClient.startCleaner() - // find free spot in pool - for poolIndex, dc := range cm.pool { - if dc == nil { - cm.pool[poolIndex] = newClient - newClient.poolIndex = poolIndex - return newClient - } - } - - // append to pool - cm.pool = append(cm.pool, newClient) - newClient.poolIndex = len(cm.pool) - 1 - // TODO: shrink pool again? - return newClient } @@ -200,26 +176,12 @@ func (dc *dnsClient) startCleaner() { // While a single worker to clean all connections may be slightly more performant, this approach focuses on least as possible locking and is simpler, thus less error prone. module.StartWorker("dns client cleanup", func(ctx context.Context) error { select { - case <-time.After(dc.mgr.ttl + time.Second): - dc.mgr.lock.Lock() - cleanNow := dc.dead || !dc.inUse - dc.mgr.lock.Unlock() - - if cleanNow { - dc.destroy() - return nil - } + case <-time.After(dc.mgr.ttl + connectionEOLGracePeriod): + // destroy case <-ctx.Done(): // give a short time before kill for graceful request completion time.Sleep(100 * time.Millisecond) } - - // wait for grace period to end, then kill - select { - case <-time.After(connectionEOLGracePeriod): - case <-ctx.Done(): - } - dc.destroy() return nil }) diff --git a/resolver/pooling_test.go b/resolver/pooling_test.go index dc341f33..3c03c14c 100644 --- a/resolver/pooling_test.go +++ b/resolver/pooling_test.go @@ -2,6 +2,7 @@ package resolver import ( "sync" + "sync/atomic" "testing" "github.com/miekg/dns" @@ -11,7 +12,7 @@ var ( domainFeed = make(chan string) ) -func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Query) { +func testQuery(t *testing.T, wg *sync.WaitGroup, newCnt *uint32, brc *BasicResolverConn, q *Query) { dnsClient := brc.clientManager.getDNSClient() // create query @@ -23,6 +24,9 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer if err != nil { t.Fatalf("failed to connect: %s", err) //nolint:staticcheck } + if new { + atomic.AddUint32(newCnt, 1) + } // query server reply, ttl, err := dnsClient.client.ExchangeWithConn(dnsQuery, conn) @@ -33,8 +37,8 @@ func testQuery(t *testing.T, wg *sync.WaitGroup, brc *BasicResolverConn, q *Quer t.Fatalf("resolved %s, but reply was empty!", q.FQDN) //nolint:staticcheck } - t.Logf("resolved %s with resolver %d (new=%v) in %s", q.FQDN, dnsClient.poolIndex, new, ttl) - dnsClient.done() + t.Logf("resolved %s [new resolver = %v] in %s", q.FQDN, new, ttl) + dnsClient.addToPool() wg.Done() } @@ -54,17 +58,18 @@ func TestClientPooling(t *testing.T) { brc := resolver.Conn.(*BasicResolverConn) wg := &sync.WaitGroup{} + var newCnt uint32 for i := 0; i < 10; i++ { wg.Add(10) for i := 0; i < 10; i++ { - go testQuery(t, wg, brc, &Query{ + go testQuery(t, wg, &newCnt, brc, &Query{ //nolint:staticcheck FQDN: <-domainFeed, QType: dns.Type(dns.TypeA), }) } wg.Wait() - if len(brc.clientManager.pool) != 10 { - t.Fatalf("unexpected pool size: %d", len(brc.clientManager.pool)) + if newCnt > uint32(10+i) { + t.Fatalf("unexpected pool size: %d (limit is %d)", newCnt, 10+i) } } } diff --git a/resolver/resolve.go b/resolver/resolve.go index ee76d5fa..e5c406ae 100644 --- a/resolver/resolve.go +++ b/resolver/resolve.go @@ -14,8 +14,6 @@ import ( ) var ( - mtAsyncResolve = "async resolve" - // basic errors // ErrNotFound is a basic error that will match all "not found" errors @@ -160,7 +158,7 @@ func checkCache(ctx context.Context, q *Query) *RRCache { log.Tracer(ctx).Trace("resolver: serving from cache, requesting new") // resolve async - module.StartLowPriorityMicroTask(&mtAsyncResolve, func(ctx context.Context) error { + module.StartWorker("resolve async", func(ctx context.Context) error { _, _ = resolveAndCache(ctx, q) return nil }) @@ -220,11 +218,6 @@ func resolveAndCache(ctx context.Context, q *Query) (rrCache *RRCache, err error return nil, ErrNoCompliance } - // prep - lastFailBoundary := time.Now().Add( - -time.Duration(nameserverRetryRate()) * time.Second, - ) - // start resolving var i int diff --git a/resolver/resolver.go b/resolver/resolver.go index d44bba8b..8921e2db 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -215,8 +215,8 @@ func (brc *BasicResolverConn) Query(ctx context.Context, q *Query) (*RRCache, er return nil, errors.New("internal error") } - // make client available again - dc.done() + // make client available (again) + dc.addToPool() if resolver.IsBlockedUpstream(reply) { return nil, &BlockedUpstreamError{resolver.GetName()} From e464ee136cc2e16f46a64905f582abdeb00dc177 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 15:14:48 +0200 Subject: [PATCH 32/36] Fix superfluous decision re-evaluations --- network/connection.go | 3 ++- profile/profile-layered.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/network/connection.go b/network/connection.go index 2fa5f9df..4c5acbcf 100644 --- a/network/connection.go +++ b/network/connection.go @@ -178,7 +178,8 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection { // remote endpoint Entity: entity, // meta - Started: time.Now().Unix(), + Started: time.Now().Unix(), + profileRevisionCounter: proc.Profile().RevisionCnt(), } } diff --git a/profile/profile-layered.go b/profile/profile-layered.go index 45311662..ab0335a2 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -126,6 +126,18 @@ func (lp *LayeredProfile) getValidityFlag() *abool.AtomicBool { return lp.validityFlag } +// RevisionCnt returns the current profile revision counter. +func (lp *LayeredProfile) RevisionCnt() (revisionCounter uint64) { + if lp == nil { + return 0 + } + + lp.lock.Lock() + defer lp.lock.Unlock() + + return lp.revisionCounter +} + // Update checks for updated profiles and replaces any outdated profiles. func (lp *LayeredProfile) Update() (revisionCounter uint64) { lp.lock.Lock() From 1c5474bdcd75bbbd637ed16be6a221dfa5a817cf Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 15:15:07 +0200 Subject: [PATCH 33/36] Change dns requests to be workers instead of microtasks --- nameserver/nameserver.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index a06fbf61..578d4450 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -24,9 +24,8 @@ import ( ) var ( - module *modules.Module - dnsServer *dns.Server - mtDNSRequest = "dns request" + module *modules.Module + dnsServer *dns.Server listenAddress = "0.0.0.0:53" ipv4Localhost = net.IPv4(127, 0, 0, 1) @@ -63,7 +62,7 @@ func prep() error { func start() error { dnsServer = &dns.Server{Addr: listenAddress, Net: "udp"} - dns.HandleFunc(".", handleRequestAsMicroTask) + dns.HandleFunc(".", handleRequestAsWorker) module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { err := dnsServer.ListenAndServe() @@ -97,8 +96,8 @@ func returnServerFailure(w dns.ResponseWriter, query *dns.Msg) { _ = w.WriteMsg(m) } -func handleRequestAsMicroTask(w dns.ResponseWriter, query *dns.Msg) { - err := module.RunMicroTask(&mtDNSRequest, func(ctx context.Context) error { +func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { + err := module.RunWorker("dns request", func(ctx context.Context) error { return handleRequest(ctx, w, query) }) if err != nil { From 467153569125d235bc63a4d6ef577df59d4932be Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 15:35:22 +0200 Subject: [PATCH 34/36] Improve logging --- firewall/master.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/firewall/master.go b/firewall/master.go index a512b202..84389ab6 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -36,7 +36,7 @@ import ( func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation if conn.UpdateAndCheck() { - log.Infof("filter: re-evaluating verdict on %s", conn) + log.Tracer(pkt.Ctx()).Infof("filter: re-evaluating verdict on %s", conn) conn.Verdict = network.VerdictUndecided if conn.Entity != nil { @@ -71,10 +71,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // checkPortmasterConnection allows all connection that originate from // portmaster itself. -func checkPortmasterConnection(conn *network.Connection, _ packet.Packet) bool { +func checkPortmasterConnection(conn *network.Connection, pkt packet.Packet) bool { // grant self if conn.Process().Pid == os.Getpid() { - log.Infof("filter: granting own connection %s", conn) + log.Tracer(pkt.Ctx()).Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept conn.Internal = true return true @@ -101,12 +101,12 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { DstPort: pktInfo.DstPort, }) if err != nil { - log.Warningf("filter: failed to find local peer process PID: %s", err) + log.Tracer(pkt.Ctx()).Warningf("filter: failed to find local peer process PID: %s", err) } else { // get primary process otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid) if err != nil { - log.Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) + log.Tracer(pkt.Ctx()).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") conn.Internal = true @@ -233,7 +233,7 @@ func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { return false } -func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { +func checkFilterLists(conn *network.Connection, pkt packet.Packet) bool { // apply privacy filter lists p := conn.Process().Profile() @@ -245,7 +245,7 @@ func checkFilterLists(conn *network.Connection, _ packet.Packet) bool { case endpoints.NoMatch: // nothing to do default: - log.Debugf("filter: filter lists returned unsupported verdict: %s", result) + log.Tracer(pkt.Ctx()).Debugf("filter: filter lists returned unsupported verdict: %s", result) } return false } From 26fd447700130fa8f4b97ff82d3582190d605f7f Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 15:36:11 +0200 Subject: [PATCH 35/36] Switch default action / asking to release level experimental There are, well, many problems... --- firewall/config.go | 4 ++-- profile/config.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/firewall/config.go b/firewall/config.go index 4a6a3abf..19d1b9c4 100644 --- a/firewall/config.go +++ b/firewall/config.go @@ -48,7 +48,7 @@ func registerConfig() error { Order: CfgOptionAskWithSystemNotificationsOrder, OptType: config.OptTypeBool, ExpertiseLevel: config.ExpertiseLevelUser, - ReleaseLevel: config.ReleaseLevelStable, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: true, }) if err != nil { @@ -62,7 +62,7 @@ func registerConfig() error { Order: CfgOptionAskTimeoutOrder, OptType: config.OptTypeInt, ExpertiseLevel: config.ExpertiseLevelUser, - ReleaseLevel: config.ReleaseLevelStable, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: 60, }) if err != nil { diff --git a/profile/config.go b/profile/config.go index 58894e86..494bdc7a 100644 --- a/profile/config.go +++ b/profile/config.go @@ -94,6 +94,7 @@ func registerConfiguration() error { Description: `The default filter action when nothing else permits or blocks a connection.`, Order: cfgOptionDefaultActionOrder, OptType: config.OptTypeString, + ReleaseLevel: config.ReleaseLevelExperimental, DefaultValue: "permit", ExternalOptType: "string list", ValidationRegex: "^(permit|ask|block)$", From 46411951f617fbceb7b1fea19aed203dc2133e09 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 20 May 2020 16:43:54 +0200 Subject: [PATCH 36/36] Further improve logging and messages --- firewall/interception.go | 2 +- firewall/master.go | 43 ++++++++++++++++++++-------------------- intel/block_reason.go | 6 +++--- nameserver/nameserver.go | 2 +- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/firewall/interception.go b/firewall/interception.go index 0e3fc9d1..c87a8c57 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -230,7 +230,7 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) { } log.Tracer(pkt.Ctx()).Trace("filter: starting decision process") - DecideOnConnection(conn, pkt) + DecideOnConnection(pkt.Ctx(), conn, pkt) conn.Inspecting = false // TODO: enable inspecting again switch { diff --git a/firewall/master.go b/firewall/master.go index 84389ab6..7f194960 100644 --- a/firewall/master.go +++ b/firewall/master.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "fmt" "os" "path/filepath" @@ -33,10 +34,10 @@ import ( // DecideOnConnection makes a decision about a connection. // When called, the connection and profile is already locked. -func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { +func DecideOnConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) { // update profiles and check if communication needs reevaluation if conn.UpdateAndCheck() { - log.Tracer(pkt.Ctx()).Infof("filter: re-evaluating verdict on %s", conn) + log.Tracer(ctx).Infof("filter: re-evaluating verdict on %s", conn) conn.Verdict = network.VerdictUndecided if conn.Entity != nil { @@ -44,7 +45,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { } } - var deciders = []func(*network.Connection, packet.Packet) bool{ + var deciders = []func(context.Context, *network.Connection, packet.Packet) bool{ checkPortmasterConnection, checkSelfCommunication, checkProfileExists, @@ -60,7 +61,7 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { } for _, decider := range deciders { - if decider(conn, pkt) { + if decider(ctx, conn, pkt) { return } } @@ -71,10 +72,10 @@ func DecideOnConnection(conn *network.Connection, pkt packet.Packet) { // checkPortmasterConnection allows all connection that originate from // portmaster itself. -func checkPortmasterConnection(conn *network.Connection, pkt packet.Packet) bool { +func checkPortmasterConnection(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // grant self if conn.Process().Pid == os.Getpid() { - log.Tracer(pkt.Ctx()).Infof("filter: granting own connection %s", conn) + log.Tracer(ctx).Infof("filter: granting own connection %s", conn) conn.Verdict = network.VerdictAccept conn.Internal = true return true @@ -84,7 +85,7 @@ func checkPortmasterConnection(conn *network.Connection, pkt packet.Packet) bool } // checkSelfCommunication checks if the process is communicating with itself. -func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { +func checkSelfCommunication(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // check if process is communicating with itself if pkt != nil { // TODO: evaluate the case where different IPs in the 127/8 net are used. @@ -101,12 +102,12 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { DstPort: pktInfo.DstPort, }) if err != nil { - log.Tracer(pkt.Ctx()).Warningf("filter: failed to find local peer process PID: %s", err) + log.Tracer(ctx).Warningf("filter: failed to find local peer process PID: %s", err) } else { // get primary process - otherProcess, err := process.GetOrFindPrimaryProcess(pkt.Ctx(), otherPid) + otherProcess, err := process.GetOrFindPrimaryProcess(ctx, otherPid) if err != nil { - log.Tracer(pkt.Ctx()).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) + log.Tracer(ctx).Warningf("filter: failed to find load local peer process with PID %d: %s", otherPid, err) } else if otherProcess.Pid == conn.Process().Pid { conn.Accept("connection to self") conn.Internal = true @@ -119,7 +120,7 @@ func checkSelfCommunication(conn *network.Connection, pkt packet.Packet) bool { return false } -func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { +func checkProfileExists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { if conn.Process().Profile() == nil { conn.Block("unknown process or profile") return true @@ -127,7 +128,7 @@ func checkProfileExists(conn *network.Connection, _ packet.Packet) bool { return false } -func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { +func checkEndpointLists(_ context.Context, conn *network.Connection, _ packet.Packet) bool { var result endpoints.EPResult var reason endpoints.Reason @@ -152,7 +153,7 @@ func checkEndpointLists(conn *network.Connection, _ packet.Packet) bool { return false } -func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { +func checkConnectionType(ctx context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() // check conn type @@ -177,7 +178,7 @@ func checkConnectionType(conn *network.Connection, _ packet.Packet) bool { return false } -func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { +func checkConnectionScope(_ context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() // check scopes @@ -216,7 +217,7 @@ func checkConnectionScope(conn *network.Connection, _ packet.Packet) bool { return false } -func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { +func checkBypassPrevention(_ context.Context, conn *network.Connection, _ packet.Packet) bool { if conn.Process().Profile().PreventBypassing() { // check for bypass protection result, reason, reasonCtx := PreventBypassing(conn) @@ -233,7 +234,7 @@ func checkBypassPrevention(conn *network.Connection, _ packet.Packet) bool { return false } -func checkFilterLists(conn *network.Connection, pkt packet.Packet) bool { +func checkFilterLists(ctx context.Context, conn *network.Connection, pkt packet.Packet) bool { // apply privacy filter lists p := conn.Process().Profile() @@ -245,12 +246,12 @@ func checkFilterLists(conn *network.Connection, pkt packet.Packet) bool { case endpoints.NoMatch: // nothing to do default: - log.Tracer(pkt.Ctx()).Debugf("filter: filter lists returned unsupported verdict: %s", result) + log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result) } return false } -func checkInbound(conn *network.Connection, _ packet.Packet) bool { +func checkInbound(_ context.Context, conn *network.Connection, _ packet.Packet) bool { // implicit default=block for inbound if conn.Inbound { conn.Drop("endpoint is not whitelisted (incoming is always default=block)") @@ -259,7 +260,7 @@ func checkInbound(conn *network.Connection, _ packet.Packet) bool { return false } -func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { +func checkDefaultPermit(_ context.Context, conn *network.Connection, _ packet.Packet) bool { // check default action p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionPermit { @@ -269,7 +270,7 @@ func checkDefaultPermit(conn *network.Connection, _ packet.Packet) bool { return false } -func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { +func checkAutoPermitRelated(_ context.Context, conn *network.Connection, _ packet.Packet) bool { p := conn.Process().Profile() if !p.DisableAutoPermit() { related, reason := checkRelation(conn) @@ -281,7 +282,7 @@ func checkAutoPermitRelated(conn *network.Connection, _ packet.Packet) bool { return false } -func checkDefaultAction(conn *network.Connection, pkt packet.Packet) bool { +func checkDefaultAction(_ context.Context, conn *network.Connection, pkt packet.Packet) bool { p := conn.Process().Profile() if p.DefaultAction() == profile.DefaultActionAsk { prompt(conn, pkt) diff --git a/intel/block_reason.go b/intel/block_reason.go index 52ad159f..ad140f4f 100644 --- a/intel/block_reason.go +++ b/intel/block_reason.go @@ -71,9 +71,9 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. for _, lm := range br { blockedBy, err := dns.NewRR(fmt.Sprintf( - `%s 0 IN TXT "was blocked by filter lists %s"`, + `%s 0 IN TXT "blocked by filter lists %s"`, lm.Entity, - strings.Join(lm.ActiveLists, ","), + strings.Join(lm.ActiveLists, ", "), )) if err == nil { rrs = append(rrs, blockedBy) @@ -85,7 +85,7 @@ func (br ListBlockReason) GetExtraRR(_ *dns.Msg, _ string, _ interface{}) []dns. wouldBeBlockedBy, err := dns.NewRR(fmt.Sprintf( `%s 0 IN TXT "would be blocked by filter lists %s"`, lm.Entity, - strings.Join(lm.InactiveLists, ","), + strings.Join(lm.InactiveLists, ", "), )) if err == nil { rrs = append(rrs, wouldBeBlockedBy) diff --git a/nameserver/nameserver.go b/nameserver/nameserver.go index 578d4450..97d51e8f 100644 --- a/nameserver/nameserver.go +++ b/nameserver/nameserver.go @@ -222,7 +222,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er } // check profile before we even get intel and rr - firewall.DecideOnConnection(conn, nil) + firewall.DecideOnConnection(ctx, conn, nil) switch conn.Verdict { case network.VerdictBlock: