Implement review suggestions

This commit is contained in:
Daniel 2020-10-19 13:26:53 +02:00
parent 8502d83e27
commit a560262818
3 changed files with 72 additions and 34 deletions

View file

@ -5,7 +5,6 @@ package proc
import ( import (
"fmt" "fmt"
"os" "os"
"sync"
"time" "time"
"github.com/safing/portmaster/network/socket" "github.com/safing/portmaster/network/socket"
@ -14,8 +13,6 @@ import (
) )
var ( var (
socketInfoLock sync.RWMutex
baseWaitTime = 3 * time.Millisecond baseWaitTime = 3 * time.Millisecond
lookupRetries = 3 lookupRetries = 3
) )
@ -24,9 +21,7 @@ var (
// This also acts as a getter for socket.*Info.PID, as locking for that occurs here. // This also acts as a getter for socket.*Info.PID, as locking for that occurs here.
func GetPID(socketInfo socket.Info) (pid int) { func GetPID(socketInfo socket.Info) (pid int) {
// Get currently assigned PID to the socket info. // Get currently assigned PID to the socket info.
socketInfoLock.RLock()
currentPid := socketInfo.GetPID() currentPid := socketInfo.GetPID()
socketInfoLock.RUnlock()
// If the current PID already is valid (ie. not unidentified), return it immediately. // If the current PID already is valid (ie. not unidentified), return it immediately.
if currentPid != socket.UnidentifiedProcessID { if currentPid != socket.UnidentifiedProcessID {
@ -34,12 +29,11 @@ func GetPID(socketInfo socket.Info) (pid int) {
} }
// Find PID for the given UID and inode. // Find PID for the given UID and inode.
pid = findPID(socketInfo.GetUID(), socketInfo.GetInode()) // uid, inode := socketInfo.GetUIDandInode()
pid = findPID(socketInfo.GetUIDandInode())
// Set the newly found PID on the socket info. // Set the newly found PID on the socket info.
socketInfoLock.Lock()
socketInfo.SetPID(pid) socketInfo.SetPID(pid)
socketInfoLock.Unlock()
// Return found PID. // Return found PID.
return pid return pid
@ -63,9 +57,11 @@ func findPID(uid, inode int) (pid int) {
// If we have found PIDs, search them. // If we have found PIDs, search them.
if ok { if ok {
for _, pid = range pids { // Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
if findSocketFromPid(pid, socketName) { // be searched for.
return pid for i := len(pids) - 1; i > 0; i-- {
if findSocketFromPid(pids[i], socketName) {
return pids[i]
} }
} }
} }
@ -76,9 +72,11 @@ func findPID(uid, inode int) (pid int) {
updatePids() updatePids()
pids, ok = getPidsByUser(uid) pids, ok = getPidsByUser(uid)
if ok { if ok {
for _, pid = range pids { // Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
if findSocketFromPid(pid, socketName) { // be searched for.
return pid for i := len(pids) - 1; i > 0; i-- {
if findSocketFromPid(pids[i], socketName) {
return pids[i]
} }
} }
} }

View file

@ -66,14 +66,7 @@ func updatePids() {
pidCnt++ pidCnt++
} }
// Reverse slice orders, because higher PIDs will be more likely to be searched for. // log.Tracef("proc: updated PID table with %d entries", pidCnt)
for _, slice := range newPidsByUser {
for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 {
slice[i], slice[j] = slice[j], slice[i]
}
}
log.Tracef("proc: updated PID table with %d entries", pidCnt)
pidsByUserLock.Lock() pidsByUserLock.Lock()
defer pidsByUserLock.Unlock() defer pidsByUserLock.Unlock()

View file

@ -1,6 +1,9 @@
package socket package socket
import "net" import (
"net"
"sync"
)
const ( const (
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops. // UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
@ -9,6 +12,8 @@ const (
// ConnectionInfo holds socket information returned by the system. // ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct { type ConnectionInfo struct {
sync.Mutex
Local Address Local Address
Remote Address Remote Address
PID int PID int
@ -18,6 +23,8 @@ type ConnectionInfo struct {
// BindInfo holds socket information returned by the system. // BindInfo holds socket information returned by the system.
type BindInfo struct { type BindInfo struct {
sync.Mutex
Local Address Local Address
PID int PID int
UID int UID int
@ -35,29 +42,69 @@ type Info interface {
GetPID() int GetPID() int
SetPID(int) SetPID(int)
GetUID() int GetUID() int
GetInode() int GetUIDandInode() (int, int)
} }
// GetPID returns the PID. // GetPID returns the PID.
func (i *ConnectionInfo) GetPID() int { return i.PID } func (i *ConnectionInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value. // SetPID sets the PID to the given value.
func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid } func (i *ConnectionInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID. // GetUID returns the UID.
func (i *ConnectionInfo) GetUID() int { return i.UID } func (i *ConnectionInfo) GetUID() int {
i.Lock()
defer i.Unlock()
// GetInode returns the Inode. return i.UID
func (i *ConnectionInfo) GetInode() int { return i.Inode } }
// GetUIDandInode returns the UID and Inode.
func (i *ConnectionInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}
// GetPID returns the PID. // GetPID returns the PID.
func (i *BindInfo) GetPID() int { return i.PID } func (i *BindInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value. // SetPID sets the PID to the given value.
func (i *BindInfo) SetPID(pid int) { i.PID = pid } func (i *BindInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID. // GetUID returns the UID.
func (i *BindInfo) GetUID() int { return i.UID } func (i *BindInfo) GetUID() int {
i.Lock()
defer i.Unlock()
// GetInode returns the Inode. return i.UID
func (i *BindInfo) GetInode() int { return i.Inode } }
// GetUIDandInode returns the UID and Inode.
func (i *BindInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}