diff --git a/network/proc/findpid.go b/network/proc/findpid.go index 6808960e..3b123f4f 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -16,21 +16,48 @@ import ( ) var ( + // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. pidsByUserLock sync.Mutex pidsByUser = make(map[int][]int) ) -// FindPID returns the pid of the given uid and socket inode. -func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO +// FindConnectionPID returns the pid of the given socket info. +func FindConnectionPID(socketInfo *socket.ConnectionInfo) (pid int) { pidsByUserLock.Lock() defer pidsByUserLock.Unlock() + if socketInfo.PID != socket.UnidentifiedProcessID { + return socket.UnidentifiedProcessID + } + + pid = findPID(socketInfo.UID, socketInfo.Inode) + socketInfo.PID = pid + return pid +} + +// FindBindPID returns the pid of the given socket info. +func FindBindPID(socketInfo *socket.BindInfo) (pid int) { + pidsByUserLock.Lock() + defer pidsByUserLock.Unlock() + + if socketInfo.PID != socket.UnidentifiedProcessID { + return socket.UnidentifiedProcessID + } + + pid = findPID(socketInfo.UID, socketInfo.Inode) + socketInfo.PID = pid + return pid +} + +// findPID returns the pid of the given uid and socket inode. +func findPID(uid, inode int) (pid int) { //nolint:gocognit // TODO + pidsUpdated := false // get pids of user, update if missing pids, ok := pidsByUser[uid] if !ok { - // log.Trace("process: no processes of user, updating table") + // log.Trace("proc: no processes of user, updating table") updatePids() pidsUpdated = true pids, ok = pidsByUser[uid] @@ -46,7 +73,7 @@ func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO } // if we fail on the first run and have not updated, update and check the ones we haven't tried so far. if !pidsUpdated { - // log.Trace("process: socket not found in any process of user, updating table") + // log.Trace("proc: socket not found in any process of user, updating table") // update updatePids() // sort for faster search @@ -67,7 +94,7 @@ func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO } // check all other pids - // log.Trace("process: socket not found in any process of user, checking all pids") + // log.Trace("proc: socket not found in any process of user, checking all pids") // TODO: find best order for pidsByUser for best performance for possibleUID, pids := range pidsByUser { if possibleUID != uid { @@ -93,7 +120,7 @@ func findSocketFromPid(pid, inode int) bool { link, err := os.Readlink(fmt.Sprintf("/proc/%d/fd/%s", pid, entry)) if err != nil { if !os.IsNotExist(err) { - log.Warningf("process: failed to read link /proc/%d/fd/%s: %s", pid, entry, err) + log.Warningf("proc: failed to read link /proc/%d/fd/%s: %s", pid, entry, err) } continue } @@ -122,12 +149,12 @@ entryLoop: statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) if err != nil { - log.Warningf("process: could not stat /proc/%d: %s", pid, err) + log.Warningf("proc: could not stat /proc/%d: %s", pid, err) continue entryLoop } sys, ok := statData.Sys().(*syscall.Stat_t) if !ok { - log.Warningf("process: unable to parse /proc/%d: wrong type", pid) + log.Warningf("proc: unable to parse /proc/%d: wrong type", pid) continue entryLoop } @@ -152,14 +179,14 @@ func readDirNames(dir string) (names []string) { file, err := os.Open(dir) if err != nil { if !os.IsNotExist(err) { - log.Warningf("process: could not open directory %s: %s", dir, err) + log.Warningf("proc: could not open directory %s: %s", dir, err) } return } defer file.Close() names, err = file.Readdirnames(0) if err != nil { - log.Warningf("process: could not get entries from directory %s: %s", dir, err) + log.Warningf("proc: could not get entries from directory %s: %s", dir, err) return []string{} } return diff --git a/network/proc/tables.go b/network/proc/tables.go index bf4a3eb0..45055d36 100644 --- a/network/proc/tables.go +++ b/network/proc/tables.go @@ -49,8 +49,6 @@ const ( udp4ProcFile = "/proc/net/udp" udp6ProcFile = "/proc/net/udp6" - UnfetchedProcessID = -2 - tcpListenStateHex = "0A" ) @@ -114,7 +112,7 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con for scanner.Scan() { fields := strings.FieldsFunc(scanner.Text(), procDelimiter) if len(fields) < 14 { - // log.Tracef("process: too short: %s", fields) + // log.Tracef("proc: too short: %s", fields) continue } @@ -125,21 +123,21 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16) if err != nil { - log.Warningf("process: could not parse port: %s", err) + log.Warningf("proc: could not parse port: %s", err) continue } uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32) // log.Tracef("uid: %s", fields[fieldIndexUID]) if err != nil { - log.Warningf("process: could not parse uid %s: %s", fields[11], err) + log.Warningf("proc: could not parse uid %s: %s", fields[11], err) continue } inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32) // log.Tracef("inode: %s", fields[fieldIndexInode]) if err != nil { - log.Warningf("process: could not parse inode %s: %s", fields[13], err) + log.Warningf("proc: could not parse inode %s: %s", fields[13], err) continue } @@ -151,7 +149,7 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con IP: localIP, Port: uint16(localPort), }, - PID: UnfetchedProcessID, + PID: socket.UnidentifiedProcessID, UID: int(uid), Inode: int(inode), }) @@ -166,7 +164,7 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con IP: localIP, Port: uint16(localPort), }, - PID: UnfetchedProcessID, + PID: socket.UnidentifiedProcessID, UID: int(uid), Inode: int(inode), }) @@ -180,7 +178,7 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16) if err != nil { - log.Warningf("process: could not parse port: %s", err) + log.Warningf("proc: could not parse port: %s", err) continue } @@ -193,7 +191,7 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con IP: remoteIP, Port: uint16(remotePort), }, - PID: UnfetchedProcessID, + PID: socket.UnidentifiedProcessID, UID: int(uid), Inode: int(inode), }) @@ -211,11 +209,11 @@ func procDelimiter(c rune) bool { func convertIPv4(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { - log.Warningf("process: could not parse IPv4 %s: %s", data, err) + log.Warningf("proc: could not parse IPv4 %s: %s", data, err) return nil } if len(decoded) != 4 { - log.Warningf("process: decoded IPv4 %s has wrong length", decoded) + log.Warningf("proc: decoded IPv4 %s has wrong length", decoded) return nil } ip := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) @@ -225,11 +223,11 @@ func convertIPv4(data string) net.IP { func convertIPv6(data string) net.IP { decoded, err := hex.DecodeString(data) if err != nil { - log.Warningf("process: could not parse IPv6 %s: %s", data, err) + log.Warningf("proc: could not parse IPv6 %s: %s", data, err) return nil } if len(decoded) != 16 { - log.Warningf("process: decoded IPv6 %s has wrong length", decoded) + log.Warningf("proc: decoded IPv6 %s has wrong length", decoded) return nil } ip := net.IP(decoded) diff --git a/network/proc/tables_test.go b/network/proc/tables_test.go index eed12ce8..bcce8498 100644 --- a/network/proc/tables_test.go +++ b/network/proc/tables_test.go @@ -14,12 +14,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 4 connections:") for _, connection := range connections { - pid := FindPID(connection.UID, connection.Inode) + pid := FindConnectionPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 4 listeners:") for _, listener := range listeners { - pid := FindPID(listener.UID, listener.Inode) + pid := FindBindPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -29,12 +29,12 @@ func TestSockets(t *testing.T) { } fmt.Println("\nTCP 6 connections:") for _, connection := range connections { - pid := FindPID(connection.UID, connection.Inode) + pid := FindConnectionPID(connection) fmt.Printf("%d: %+v\n", pid, connection) } fmt.Println("\nTCP 6 listeners:") for _, listener := range listeners { - pid := FindPID(listener.UID, listener.Inode) + pid := FindBindPID(listener) fmt.Printf("%d: %+v\n", pid, listener) } @@ -44,7 +44,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 4 binds:") for _, bind := range binds { - pid := FindPID(bind.UID, bind.Inode) + pid := FindBindPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } @@ -54,7 +54,7 @@ func TestSockets(t *testing.T) { } fmt.Println("\nUDP 6 binds:") for _, bind := range binds { - pid := FindPID(bind.UID, bind.Inode) + pid := FindBindPID(bind) fmt.Printf("%d: %+v\n", pid, bind) } } diff --git a/network/state/system_linux.go b/network/state/system_linux.go index b902c58c..c4b67777 100644 --- a/network/state/system_linux.go +++ b/network/state/system_linux.go @@ -1,6 +1,8 @@ package state import ( + "time" + "github.com/safing/portmaster/network/proc" "github.com/safing/portmaster/network/socket" ) @@ -13,15 +15,41 @@ var ( ) func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { - if socketInfo.PID == proc.UnfetchedProcessID { - socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) + for i := 0; i <= lookupRetries; i++ { + // look for PID + pid = proc.FindConnectionPID(socketInfo) + if pid != socket.UnidentifiedProcessID { + // if we found a PID, return + break + } + + // every time, except for the last iteration + if i < lookupRetries { + // we found no PID, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + } } - return socketInfo.PID, connInbound, nil + + return pid, connInbound, nil } func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { - if socketInfo.PID == proc.UnfetchedProcessID { - socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode) + for i := 0; i <= lookupRetries; i++ { + // look for PID + pid = proc.FindBindPID(socketInfo) + if pid != socket.UnidentifiedProcessID { + // if we found a PID, return + break + } + + // every time, except for the last iteration + if i < lookupRetries { + // we found no PID, we could have been too fast, give the kernel some time to think + // back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total + time.Sleep(time.Duration(i+1) * baseWaitTime) + } } - return socketInfo.PID, connInbound, nil + + return pid, connInbound, nil }