Implement review suggestions

This commit is contained in:
Daniel 2020-05-19 16:57:55 +02:00
parent 65a3456165
commit e65ae8b55d
14 changed files with 130 additions and 100 deletions

View file

@ -169,7 +169,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
}
// start tracer
ctx, tracer := log.AddTracer(context.Background())
ctx, tracer := log.AddTracer(ctx)
defer tracer.Submit()
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)

View file

@ -37,7 +37,7 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
process *process.Process
// remote endpoint
Entity *intel.Entity // needs locking, instance is never shared
Entity *intel.Entity
Verdict Verdict
Reason string

View file

@ -8,13 +8,12 @@ import (
"github.com/safing/portmaster/network/socket"
)
const (
unidentifiedProcessID = -1
)
var (
ipHelper *IPHelper
lock sync.RWMutex
// lock locks access to the whole DLL.
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
lock sync.RWMutex
)
// GetTCP4Table returns the system table for IPv4 TCP activity.

View file

@ -214,7 +214,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
@ -222,11 +222,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: convertIPv4(row.remoteAddr),
Port: uint16(row.remotePort>>8 | row.remotePort<<8),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
@ -243,7 +243,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
@ -251,11 +251,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
Remote: socket.Address{
IP: net.IP(row.remoteAddr[:]),
Port: uint16(row.remotePort>>8 | row.remotePort<<8),
Port: convertPort(row.remotePort),
},
PID: int(row.owningPid),
})
@ -271,7 +271,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
@ -286,7 +286,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{
Local: socket.Address{
IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8),
Port: convertPort(row.localPort),
},
PID: int(row.owningPid),
})
@ -303,3 +303,8 @@ func convertIPv4(input uint32) net.IP {
binary.LittleEndian.PutUint32(addressBuf, input)
return net.IP(addressBuf)
}
// convertPort converts ports received from iphlpapi.dll
func convertPort(input uint32) uint16 {
return uint16(input>>8 | input<<8)
}

View file

@ -10,11 +10,9 @@ import (
"sync"
"syscall"
"github.com/safing/portbase/log"
)
"github.com/safing/portmaster/network/socket"
const (
unidentifiedProcessID = -1
"github.com/safing/portbase/log"
)
var (
@ -23,7 +21,7 @@ var (
)
// FindPID returns the pid of the given uid and socket inode.
func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
@ -42,7 +40,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
var checkedUserPids []int
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
checkedUserPids = append(checkedUserPids, possiblePID)
}
@ -61,7 +59,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
// only check if not already checked
if sort.SearchInts(checkedUserPids, possiblePID) == len {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
}
}
@ -75,13 +73,13 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
if possibleUID != uid {
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID, true
return possiblePID
}
}
}
}
return unidentifiedProcessID, false
return socket.UnidentifiedProcessID
}
func findSocketFromPid(pid, inode int) bool {

View file

@ -5,6 +5,7 @@ package proc
import (
"bufio"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
@ -43,12 +44,10 @@ const (
ICMP4
ICMP6
TCP4Data = "/proc/net/tcp"
UDP4Data = "/proc/net/udp"
TCP6Data = "/proc/net/tcp6"
UDP6Data = "/proc/net/udp6"
ICMP4Data = "/proc/net/icmp"
ICMP6Data = "/proc/net/icmp6"
tcp4ProcFile = "/proc/net/tcp"
tcp6ProcFile = "/proc/net/tcp6"
udp4ProcFile = "/proc/net/udp"
udp6ProcFile = "/proc/net/udp6"
UnfetchedProcessID = -2
@ -57,27 +56,47 @@ const (
// GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, TCP4Data, convertIPv4)
return getTableFromSource(TCP4, tcp4ProcFile)
}
// GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, TCP6Data, convertIPv6)
return getTableFromSource(TCP6, tcp6ProcFile)
}
// GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, UDP4Data, convertIPv4)
_, binds, err = getTableFromSource(UDP4, udp4ProcFile)
return
}
// GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, UDP6Data, convertIPv6)
_, binds, err = getTableFromSource(UDP6, udp6ProcFile)
return
}
func getTableFromSource(stack uint8, procFile string, ipConverter func(string) net.IP) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
const (
// hint: we split fields by multiple delimiters, see procDelimiter
fieldIndexLocalIP = 1
fieldIndexLocalPort = 2
fieldIndexRemoteIP = 3
fieldIndexRemotePort = 4
fieldIndexUID = 11
fieldIndexInode = 13
)
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP
switch stack {
case TCP4, UDP4:
ipConverter = convertIPv4
case TCP6, UDP6:
ipConverter = convertIPv6
default:
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
}
// open file
socketData, err := os.Open(procFile)
@ -91,36 +110,36 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
scanner.Split(bufio.ScanLines)
// parse
scanner.Scan() // skip first line
scanner.Scan() // skip first row
for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(line) < 14 {
// log.Tracef("process: too short: %s", line)
fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(fields) < 14 {
// log.Tracef("process: too short: %s", fields)
continue
}
localIP := ipConverter(line[1])
localIP := ipConverter(fields[fieldIndexLocalIP])
if localIP == nil {
continue
}
localPort, err := strconv.ParseUint(line[2], 16, 16)
localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue
}
uid, err := strconv.ParseInt(line[11], 10, 32)
// log.Tracef("uid: %s", line[11])
uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
// log.Tracef("uid: %s", fields[fieldIndexUID])
if err != nil {
log.Warningf("process: could not parse uid %s: %s", line[11], err)
log.Warningf("process: could not parse uid %s: %s", fields[11], err)
continue
}
inode, err := strconv.ParseInt(line[13], 10, 32)
// log.Tracef("inode: %s", line[13])
inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
// log.Tracef("inode: %s", fields[fieldIndexInode])
if err != nil {
log.Warningf("process: could not parse inode %s: %s", line[13], err)
log.Warningf("process: could not parse inode %s: %s", fields[13], err)
continue
}
@ -139,7 +158,7 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
case TCP4, TCP6:
if line[5] == tcpListenStateHex {
if fields[5] == tcpListenStateHex {
// listener
binds = append(binds, &socket.BindInfo{
@ -154,12 +173,12 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
} else {
// connection
remoteIP := ipConverter(line[3])
remoteIP := ipConverter(fields[fieldIndexRemoteIP])
if remoteIP == nil {
continue
}
remotePort, err := strconv.ParseUint(line[4], 16, 16)
remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
if err != nil {
log.Warningf("process: could not parse port: %s", err)
continue

View file

@ -14,12 +14,12 @@ func TestSockets(t *testing.T) {
}
fmt.Println("\nTCP 4 connections:")
for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode)
pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode)
pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
@ -29,12 +29,12 @@ func TestSockets(t *testing.T) {
}
fmt.Println("\nTCP 6 connections:")
for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode)
pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection)
}
fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode)
pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener)
}
@ -44,7 +44,7 @@ func TestSockets(t *testing.T) {
}
fmt.Println("\nUDP 4 binds:")
for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode)
pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
@ -54,7 +54,7 @@ func TestSockets(t *testing.T) {
}
fmt.Println("\nUDP 6 binds:")
for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode)
pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind)
}
}

View file

@ -2,6 +2,11 @@ package socket
import "net"
const (
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
UnidentifiedProcessID = -1
)
// ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct {
Local Address

View file

@ -11,7 +11,11 @@ const (
UDPConnectionTTL = 10 * time.Minute
)
// Exists checks if the given connection is present in the system state tables.
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// TODO: create lookup maps before running a flurry of Exists() checks.
switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
tcp4Lock.Lock()
@ -76,7 +80,10 @@ func existsUDP(
if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort)
udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
IP: remoteIP,
Port: remotePort,
})
switch {
case !ok:
return false

View file

@ -24,10 +24,6 @@ import (
// - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process
const (
UnidentifiedProcessID = -1
)
// Errors
var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables")
@ -75,7 +71,7 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
}
}
@ -119,7 +115,7 @@ func searchTCP(
connections, listeners = updateTables()
}
return UnidentifiedProcessID, false, ErrConnectionNotFound
return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
}
func searchUDP(
@ -170,5 +166,5 @@ func searchUDP(
binds = updateTable()
}
return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
}

View file

@ -14,24 +14,14 @@ var (
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
}
return socketInfo.PID, connInbound, nil
}

View file

@ -18,9 +18,8 @@ var (
)
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
// FIXME: repeatable once
connections, listeners, err := getTCP4Table()
var err error
connections, listeners, err = getTCP4Table()
if err != nil {
log.Warningf("state: failed to get TCP4 socket table: %s", err)
return
@ -28,11 +27,12 @@ func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp4Connections = connections
tcp4Listeners = listeners
return tcp4Connections, tcp4Listeners
return
}
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
connections, listeners, err := getTCP6Table()
var err error
connections, listeners, err = getTCP6Table()
if err != nil {
log.Warningf("state: failed to get TCP6 socket table: %s", err)
return
@ -40,27 +40,29 @@ func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp6Connections = connections
tcp6Listeners = listeners
return tcp6Connections, tcp6Listeners
return
}
func updateUDP4Table() (binds []*socket.BindInfo) {
binds, err := getUDP4Table()
var err error
binds, err = getUDP4Table()
if err != nil {
log.Warningf("state: failed to get UDP4 socket table: %s", err)
return
}
udp4Binds = binds
return udp4Binds
return
}
func updateUDP6Table() (binds []*socket.BindInfo) {
binds, err := getUDP6Table()
var err error
binds, err = getUDP6Table()
if err != nil {
log.Warningf("state: failed to get UDP6 socket table: %s", err)
return
}
udp6Binds = binds
return udp6Binds
return
}

View file

@ -2,7 +2,6 @@ package state
import (
"context"
"net"
"time"
"github.com/safing/portmaster/network/packet"
@ -15,7 +14,7 @@ type udpState struct {
}
const (
UpdConnStateTTL = 72 * time.Hour
UdpConnStateTTL = 72 * time.Hour
UdpConnStateShortenedTTL = 3 * time.Hour
AggressiveCleaningThreshold = 256
)
@ -25,10 +24,10 @@ var (
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
)
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)]
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)]
udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
return
}
@ -36,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin
}
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)
localKey := makeUDPStateKey(socketInfo.Local)
bindMap, ok := udpStates[localKey]
if !ok {
@ -44,7 +43,10 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin
udpStates[localKey] = bindMap
}
remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort())
remoteKey := makeUDPStateKey(socket.Address{
IP: pktInfo.RemoteIP(),
Port: pktInfo.RemotePort(),
})
udpConnState, ok := bindMap[remoteKey]
if !ok {
bindMap[remoteKey] = &udpState{
@ -79,19 +81,18 @@ func cleanStates(
now time.Time,
) {
// compute thresholds
threshold := now.Add(-UpdConnStateTTL)
threshold := now.Add(-UdpConnStateTTL)
shortThreshhold := now.Add(-UdpConnStateShortenedTTL)
// make list of all active keys
// make lookup map of all active keys
bindKeys := make(map[string]struct{})
for _, socketInfo := range binds {
bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{}
bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
}
// clean the udp state storage
for localKey, bindMap := range udpStates {
_, active := bindKeys[localKey]
if active {
if _, active := bindKeys[localKey]; active {
// clean old entries
for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) {
@ -113,7 +114,7 @@ func cleanStates(
}
}
func makeUDPStateKey(ip net.IP, port uint16) string {
func makeUDPStateKey(address socket.Address) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(ip) + string(port)
return string(address.IP) + string(address.Port)
}

View file

@ -29,6 +29,14 @@ func prep() error {
}
func start() error {
// Create a dummy directory to which processes change their working directory
// to. Currently this includes the App and the Notifier. The aim is protect
// all other directories and increase compatibility should any process want
// to read or write something to the current working directory. This can also
// be useful in the future to dump data to for debugging. The permission used
// may seem dangerous, but proper permission on the parent directory provide
// (some) protection.
// Processes must _never_ read from this directory.
err := dataroot.Root().ChildDir("exec", 0777).Ensure()
if err != nil {
log.Warningf("ui: failed to create safe exec dir: %s", err)