mirror of
https://github.com/safing/portmaster
synced 2025-09-02 18:49:14 +00:00
Create interface for socket info structs, fix caching bug
From PR Review https://github.com/safing/portmaster/pull/72
This commit is contained in:
parent
3d0e01383f
commit
7c6c4552aa
5 changed files with 54 additions and 52 deletions
|
@ -21,31 +21,18 @@ var (
|
||||||
pidsByUser = make(map[int][]int)
|
pidsByUser = make(map[int][]int)
|
||||||
)
|
)
|
||||||
|
|
||||||
// FindConnectionPID returns the pid of the given socket info.
|
// GetPID returns the already existing pid of the given socket info or searches for it.
|
||||||
func FindConnectionPID(socketInfo *socket.ConnectionInfo) (pid int) {
|
// This also acts as a getter for socket.*Info.PID, as locking for that occurs here.
|
||||||
|
func GetPID(socketInfo socket.Info) (pid int) {
|
||||||
pidsByUserLock.Lock()
|
pidsByUserLock.Lock()
|
||||||
defer pidsByUserLock.Unlock()
|
defer pidsByUserLock.Unlock()
|
||||||
|
|
||||||
if socketInfo.PID != socket.UnidentifiedProcessID {
|
if socketInfo.GetPID() != socket.UnidentifiedProcessID {
|
||||||
return socket.UnidentifiedProcessID
|
return socketInfo.GetPID()
|
||||||
}
|
}
|
||||||
|
|
||||||
pid = findPID(socketInfo.UID, socketInfo.Inode)
|
pid = findPID(socketInfo.GetUID(), socketInfo.GetInode())
|
||||||
socketInfo.PID = pid
|
socketInfo.SetPID(pid)
|
||||||
return pid
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindBindPID returns the pid of the given socket info.
|
|
||||||
func FindBindPID(socketInfo *socket.BindInfo) (pid int) {
|
|
||||||
pidsByUserLock.Lock()
|
|
||||||
defer pidsByUserLock.Unlock()
|
|
||||||
|
|
||||||
if socketInfo.PID != socket.UnidentifiedProcessID {
|
|
||||||
return socket.UnidentifiedProcessID
|
|
||||||
}
|
|
||||||
|
|
||||||
pid = findPID(socketInfo.UID, socketInfo.Inode)
|
|
||||||
socketInfo.PID = pid
|
|
||||||
return pid
|
return pid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,6 +162,9 @@ entryLoop:
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every
|
||||||
|
// resulting directory name, which we don't need. This function will be called a lot, so we should
|
||||||
|
// refrain from unnecessary work.
|
||||||
func readDirNames(dir string) (names []string) {
|
func readDirNames(dir string) (names []string) {
|
||||||
file, err := os.Open(dir)
|
file, err := os.Open(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -14,12 +14,12 @@ func TestSockets(t *testing.T) {
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 4 connections:")
|
fmt.Println("\nTCP 4 connections:")
|
||||||
for _, connection := range connections {
|
for _, connection := range connections {
|
||||||
pid := FindConnectionPID(connection)
|
pid := GetPID(connection)
|
||||||
fmt.Printf("%d: %+v\n", pid, connection)
|
fmt.Printf("%d: %+v\n", pid, connection)
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 4 listeners:")
|
fmt.Println("\nTCP 4 listeners:")
|
||||||
for _, listener := range listeners {
|
for _, listener := range listeners {
|
||||||
pid := FindBindPID(listener)
|
pid := GetPID(listener)
|
||||||
fmt.Printf("%d: %+v\n", pid, listener)
|
fmt.Printf("%d: %+v\n", pid, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,12 +29,12 @@ func TestSockets(t *testing.T) {
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 6 connections:")
|
fmt.Println("\nTCP 6 connections:")
|
||||||
for _, connection := range connections {
|
for _, connection := range connections {
|
||||||
pid := FindConnectionPID(connection)
|
pid := GetPID(connection)
|
||||||
fmt.Printf("%d: %+v\n", pid, connection)
|
fmt.Printf("%d: %+v\n", pid, connection)
|
||||||
}
|
}
|
||||||
fmt.Println("\nTCP 6 listeners:")
|
fmt.Println("\nTCP 6 listeners:")
|
||||||
for _, listener := range listeners {
|
for _, listener := range listeners {
|
||||||
pid := FindBindPID(listener)
|
pid := GetPID(listener)
|
||||||
fmt.Printf("%d: %+v\n", pid, listener)
|
fmt.Printf("%d: %+v\n", pid, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ func TestSockets(t *testing.T) {
|
||||||
}
|
}
|
||||||
fmt.Println("\nUDP 4 binds:")
|
fmt.Println("\nUDP 4 binds:")
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
pid := FindBindPID(bind)
|
pid := GetPID(bind)
|
||||||
fmt.Printf("%d: %+v\n", pid, bind)
|
fmt.Printf("%d: %+v\n", pid, bind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ func TestSockets(t *testing.T) {
|
||||||
}
|
}
|
||||||
fmt.Println("\nUDP 6 binds:")
|
fmt.Println("\nUDP 6 binds:")
|
||||||
for _, bind := range binds {
|
for _, bind := range binds {
|
||||||
pid := FindBindPID(bind)
|
pid := GetPID(bind)
|
||||||
fmt.Printf("%d: %+v\n", pid, bind)
|
fmt.Printf("%d: %+v\n", pid, bind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,3 +29,35 @@ type Address struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Port uint16
|
Port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Info is a generic interface to both ConnectionInfo and BindInfo.
|
||||||
|
type Info interface {
|
||||||
|
GetPID() int
|
||||||
|
SetPID(int)
|
||||||
|
GetUID() int
|
||||||
|
GetInode() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPID returns the PID.
|
||||||
|
func (i *ConnectionInfo) GetPID() int { return i.PID }
|
||||||
|
|
||||||
|
// SetPID sets the PID to the given value.
|
||||||
|
func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid }
|
||||||
|
|
||||||
|
// GetUID returns the UID.
|
||||||
|
func (i *ConnectionInfo) GetUID() int { return i.UID }
|
||||||
|
|
||||||
|
// GetInode returns the Inode.
|
||||||
|
func (i *ConnectionInfo) GetInode() int { return i.Inode }
|
||||||
|
|
||||||
|
// GetPID returns the PID.
|
||||||
|
func (i *BindInfo) GetPID() int { return i.PID }
|
||||||
|
|
||||||
|
// SetPID sets the PID to the given value.
|
||||||
|
func (i *BindInfo) SetPID(pid int) { i.PID = pid }
|
||||||
|
|
||||||
|
// GetUID returns the UID.
|
||||||
|
func (i *BindInfo) GetUID() int { return i.UID }
|
||||||
|
|
||||||
|
// GetInode returns the Inode.
|
||||||
|
func (i *BindInfo) GetInode() int { return i.Inode }
|
||||||
|
|
|
@ -81,7 +81,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) (
|
||||||
if localPort == socketInfo.Local.Port &&
|
if localPort == socketInfo.Local.Port &&
|
||||||
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
|
||||||
table.lock.RUnlock()
|
table.lock.RUnlock()
|
||||||
return checkBindPID(socketInfo, true)
|
return checkPID(socketInfo, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ func (table *tcpTable) lookup(pktInfo *packet.Info) (
|
||||||
if localPort == socketInfo.Local.Port &&
|
if localPort == socketInfo.Local.Port &&
|
||||||
localIP.Equal(socketInfo.Local.IP) {
|
localIP.Equal(socketInfo.Local.IP) {
|
||||||
table.lock.RUnlock()
|
table.lock.RUnlock()
|
||||||
return checkConnectionPID(socketInfo, false)
|
return checkPID(socketInfo, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,12 +138,12 @@ func (table *udpTable) lookup(pktInfo *packet.Info) (
|
||||||
|
|
||||||
// do not check direction if remoteIP/Port is not given
|
// do not check direction if remoteIP/Port is not given
|
||||||
if pktInfo.RemotePort() == 0 {
|
if pktInfo.RemotePort() == 0 {
|
||||||
return checkBindPID(socketInfo, pktInfo.Inbound)
|
return checkPID(socketInfo, pktInfo.Inbound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get direction and return
|
// get direction and return
|
||||||
connInbound := table.getDirection(socketInfo, pktInfo)
|
connInbound := table.getDirection(socketInfo, pktInfo)
|
||||||
return checkBindPID(socketInfo, connInbound)
|
return checkPID(socketInfo, connInbound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,30 +14,10 @@ var (
|
||||||
getUDP6Table = proc.GetUDP6Table
|
getUDP6Table = proc.GetUDP6Table
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
|
func checkPID(socketInfo socket.Info, connInbound bool) (pid int, inbound bool, err error) {
|
||||||
for i := 0; i <= lookupRetries; i++ {
|
for i := 0; i <= lookupRetries; i++ {
|
||||||
// look for PID
|
// look for PID
|
||||||
pid = proc.FindConnectionPID(socketInfo)
|
pid = proc.GetPID(socketInfo)
|
||||||
if pid != socket.UnidentifiedProcessID {
|
|
||||||
// if we found a PID, return
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// every time, except for the last iteration
|
|
||||||
if i < lookupRetries {
|
|
||||||
// we found no PID, we could have been too fast, give the kernel some time to think
|
|
||||||
// back off timer: with 3ms baseWaitTime: 3, 6, 9, 12, 15, 18, 21ms - 84ms in total
|
|
||||||
time.Sleep(time.Duration(i+1) * baseWaitTime)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pid, connInbound, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
|
|
||||||
for i := 0; i <= lookupRetries; i++ {
|
|
||||||
// look for PID
|
|
||||||
pid = proc.FindBindPID(socketInfo)
|
|
||||||
if pid != socket.UnidentifiedProcessID {
|
if pid != socket.UnidentifiedProcessID {
|
||||||
// if we found a PID, return
|
// if we found a PID, return
|
||||||
break
|
break
|
||||||
|
|
Loading…
Add table
Reference in a new issue