diff --git a/firewall/interception.go b/firewall/interception.go index a2146e97..04f7f4af 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -55,12 +55,12 @@ func interceptionStop() error { return interception.Stop() } -func handlePacket(pkt packet.Packet) { +func handlePacket(ctx context.Context, pkt packet.Packet) { if fastTrackedPermit(pkt) { return } - traceCtx, tracer := log.AddTracer(context.Background()) + traceCtx, tracer := log.AddTracer(ctx) if tracer != nil { pkt.SetCtx(traceCtx) tracer.Tracef("filter: handling packet: %s", pkt) @@ -318,7 +318,10 @@ func packetHandler(ctx context.Context) error { case <-ctx.Done(): return nil case pkt := <-interception.Packets: - handlePacket(pkt) + interceptionModule.StartWorker("initial packet handler", func(ctx context.Context) error { + handlePacket(ctx, pkt) + return nil + }) } } } diff --git a/network/proc/findpid.go b/network/proc/findpid.go index 0610e361..6187f048 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -5,10 +5,8 @@ package proc import ( "fmt" "os" - "sort" - "strconv" "sync" - "syscall" + "time" "github.com/safing/portmaster/network/socket" @@ -16,88 +14,91 @@ import ( ) var ( - // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. - pidsByUserLock sync.Mutex - pidsByUser = make(map[int][]int) + socketInfoLock sync.RWMutex + + baseWaitTime = 3 * time.Millisecond + lookupRetries = 3 ) // GetPID returns the already existing pid of the given socket info or searches for it. // This also acts as a getter for socket.*Info.PID, as locking for that occurs here. func GetPID(socketInfo socket.Info) (pid int) { - pidsByUserLock.Lock() - defer pidsByUserLock.Unlock() + // Get currently assigned PID to the socket info. + socketInfoLock.RLock() + currentPid := socketInfo.GetPID() + socketInfoLock.RUnlock() - if socketInfo.GetPID() != socket.UnidentifiedProcessID { - return socketInfo.GetPID() + // If the current PID already is valid (ie. not unidentified), return it immediately. + if currentPid != socket.UnidentifiedProcessID { + return currentPid } + // Find PID for the given UID and inode. pid = findPID(socketInfo.GetUID(), socketInfo.GetInode()) + + // Set the newly found PID on the socket info. + socketInfoLock.Lock() socketInfo.SetPID(pid) + socketInfoLock.Unlock() + + // Return found PID. return pid } // findPID returns the pid of the given uid and socket inode. -func findPID(uid, inode int) (pid int) { //nolint:gocognit // TODO +func findPID(uid, inode int) (pid int) { + socketName := fmt.Sprintf("socket:[%d]", inode) - pidsUpdated := false + for i := 0; i <= lookupRetries; i++ { + var pidsUpdated bool - // get pids of user, update if missing - pids, ok := pidsByUser[uid] - if !ok { - // log.Trace("proc: no processes of user, updating table") - updatePids() - pidsUpdated = true - pids, ok = pidsByUser[uid] - } - if ok { - // if user has pids, start checking them first - var checkedUserPids []int - for _, possiblePID := range pids { - if findSocketFromPid(possiblePID, inode) { - return possiblePID - } - checkedUserPids = append(checkedUserPids, possiblePID) - } - // if we fail on the first run and have not updated, update and check the ones we haven't tried so far. - if !pidsUpdated { - // log.Trace("proc: socket not found in any process of user, updating table") - // update + // Get all pids for the given uid. + pids, ok := getPidsByUser(uid) + if !ok { + // If we cannot find the uid in the map, update it. updatePids() - // sort for faster search - for i, j := 0, len(checkedUserPids)-1; i < j; i, j = i+1, j-1 { - checkedUserPids[i], checkedUserPids[j] = checkedUserPids[j], checkedUserPids[i] + pidsUpdated = true + pids, ok = getPidsByUser(uid) + } + + // If we have found PIDs, search them. + if ok { + for _, pid = range pids { + if findSocketFromPid(pid, socketName) { + return pid + } } - len := len(checkedUserPids) - // check unchecked pids - for _, possiblePID := range pids { - // only check if not already checked - if sort.SearchInts(checkedUserPids, possiblePID) == len { - if findSocketFromPid(possiblePID, inode) { - return possiblePID + } + + // If we still cannot find our socket, and haven't yet updated the PID map, + // do this and then check again immediately. + if !pidsUpdated { + updatePids() + pids, ok = getPidsByUser(uid) + if ok { + for _, pid = range pids { + if findSocketFromPid(pid, socketName) { + return pid } } } } - } - // check all other pids - // log.Trace("proc: socket not found in any process of user, checking all pids") - // TODO: find best order for pidsByUser for best performance - for possibleUID, pids := range pidsByUser { - if possibleUID != uid { - for _, possiblePID := range pids { - if findSocketFromPid(possiblePID, inode) { - return possiblePID - } - } + // We have updated the PID map, but still cannot find anything. + // So, there is nothing we can other than wait a little for the kernel to + // populate the information. + + // Wait after each try, except for the last iteration + if i < lookupRetries { + // Wait in back-off fashion - with 3ms baseWaitTime: 3, 6, 9 - 18ms in total. + time.Sleep(time.Duration(i+1) * baseWaitTime) } } return socket.UnidentifiedProcessID } -func findSocketFromPid(pid, inode int) bool { - socketName := fmt.Sprintf("socket:[%d]", inode) +func findSocketFromPid(pid int, socketName string) bool { entries := readDirNames(fmt.Sprintf("/proc/%d/fd", pid)) if len(entries) == 0 { return false @@ -119,49 +120,6 @@ func findSocketFromPid(pid, inode int) bool { return false } -func updatePids() { - pidsByUser = make(map[int][]int) - - entries := readDirNames("/proc") - if len(entries) == 0 { - return - } - -entryLoop: - for _, entry := range entries { - pid, err := strconv.ParseInt(entry, 10, 32) - if err != nil { - continue entryLoop - } - - statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) - if err != nil { - log.Warningf("proc: could not stat /proc/%d: %s", pid, err) - continue entryLoop - } - sys, ok := statData.Sys().(*syscall.Stat_t) - if !ok { - log.Warningf("proc: unable to parse /proc/%d: wrong type", pid) - continue entryLoop - } - - pids, ok := pidsByUser[int(sys.Uid)] - if ok { - pidsByUser[int(sys.Uid)] = append(pids, int(pid)) - } else { - pidsByUser[int(sys.Uid)] = []int{int(pid)} - } - - } - - for _, slice := range pidsByUser { - for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 { - slice[i], slice[j] = slice[j], slice[i] - } - } - -} - // readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every // resulting directory name, which we don't need. This function will be called a lot, so we should // refrain from unnecessary work. diff --git a/network/proc/pids_by_user.go b/network/proc/pids_by_user.go new file mode 100644 index 00000000..6ebc001a --- /dev/null +++ b/network/proc/pids_by_user.go @@ -0,0 +1,82 @@ +package proc + +import ( + "fmt" + "os" + "strconv" + "sync" + "syscall" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" +) + +var ( + // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. + pidsByUser = make(map[int][]int) + pidsByUserLock sync.RWMutex + fetchPidsByUser utils.OnceAgain +) + +// getPidsByUser returns the cached PIDs for the given UID. +func getPidsByUser(uid int) (pids []int, ok bool) { + pidsByUserLock.RLock() + defer pidsByUserLock.RUnlock() + + pids, ok = pidsByUser[uid] + return +} + +// updatePids fetches and creates a new pidsByUser map using utils.OnceAgain. +func updatePids() { + fetchPidsByUser.Do(func() { + newPidsByUser := make(map[int][]int) + pidCnt := 0 + + entries := readDirNames("/proc") + if len(entries) == 0 { + log.Warning("proc: found no PIDs in /proc") + return + } + + entryLoop: + for _, entry := range entries { + pid, err := strconv.ParseInt(entry, 10, 32) + if err != nil { + continue entryLoop + } + + statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) + if err != nil { + log.Warningf("proc: could not stat /proc/%d: %s", pid, err) + continue entryLoop + } + sys, ok := statData.Sys().(*syscall.Stat_t) + if !ok { + log.Warningf("proc: unable to parse /proc/%d: wrong type", pid) + continue entryLoop + } + + pids, ok := newPidsByUser[int(sys.Uid)] + if ok { + newPidsByUser[int(sys.Uid)] = append(pids, int(pid)) + } else { + newPidsByUser[int(sys.Uid)] = []int{int(pid)} + } + pidCnt++ + } + + // Reverse slice orders, because higher PIDs will be more likely to be searched for. + for _, slice := range newPidsByUser { + for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 { + slice[i], slice[j] = slice[j], slice[i] + } + } + + log.Tracef("proc: updated PID table with %d entries", pidCnt) + + pidsByUserLock.Lock() + defer pidsByUserLock.Unlock() + pidsByUser = newPidsByUser + }) +} diff --git a/network/proc/tables.go b/network/proc/tables.go index 45055d36..94f3198b 100644 --- a/network/proc/tables.go +++ b/network/proc/tables.go @@ -35,7 +35,7 @@ Cache every step! */ -// Network Related Constants +// Network Related Constants. const ( TCP4 uint8 = iota UDP4