Implement review suggestions

This commit is contained in:
Daniel 2020-05-19 16:57:55 +02:00
parent 65a3456165
commit e65ae8b55d
14 changed files with 130 additions and 100 deletions

View file

@ -169,7 +169,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, query *dns.Msg) er
} }
// start tracer // start tracer
ctx, tracer := log.AddTracer(context.Background()) ctx, tracer := log.AddTracer(ctx)
defer tracer.Submit() defer tracer.Submit()
tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port) tracer.Tracef("nameserver: handling new request for %s%s from %s:%d, getting connection", q.FQDN, q.QType, remoteAddr.IP, remoteAddr.Port)

View file

@ -37,7 +37,7 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
process *process.Process process *process.Process
// remote endpoint // remote endpoint
Entity *intel.Entity // needs locking, instance is never shared Entity *intel.Entity
Verdict Verdict Verdict Verdict
Reason string Reason string

View file

@ -8,13 +8,12 @@ import (
"github.com/safing/portmaster/network/socket" "github.com/safing/portmaster/network/socket"
) )
const (
unidentifiedProcessID = -1
)
var ( var (
ipHelper *IPHelper ipHelper *IPHelper
lock sync.RWMutex
// lock locks access to the whole DLL.
// TODO: It's unproven if we can access the iphlpapi.dll concurrently, especially as we might be encountering various versions of the DLL. In the future, we could possibly investigate and improve performance here.
lock sync.RWMutex
) )
// GetTCP4Table returns the system table for IPv4 TCP activity. // GetTCP4Table returns the system table for IPv4 TCP activity.

View file

@ -214,7 +214,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{ binds = append(binds, &socket.BindInfo{
Local: socket.Address{ Local: socket.Address{
IP: convertIPv4(row.localAddr), IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -222,11 +222,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
connections = append(connections, &socket.ConnectionInfo{ connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{ Local: socket.Address{
IP: convertIPv4(row.localAddr), IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
Remote: socket.Address{ Remote: socket.Address{
IP: convertIPv4(row.remoteAddr), IP: convertIPv4(row.remoteAddr),
Port: uint16(row.remotePort>>8 | row.remotePort<<8), Port: convertPort(row.remotePort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -243,7 +243,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{ binds = append(binds, &socket.BindInfo{
Local: socket.Address{ Local: socket.Address{
IP: net.IP(row.localAddr[:]), IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -251,11 +251,11 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
connections = append(connections, &socket.ConnectionInfo{ connections = append(connections, &socket.ConnectionInfo{
Local: socket.Address{ Local: socket.Address{
IP: net.IP(row.localAddr[:]), IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
Remote: socket.Address{ Remote: socket.Address{
IP: net.IP(row.remoteAddr[:]), IP: net.IP(row.remoteAddr[:]),
Port: uint16(row.remotePort>>8 | row.remotePort<<8), Port: convertPort(row.remotePort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -271,7 +271,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{ binds = append(binds, &socket.BindInfo{
Local: socket.Address{ Local: socket.Address{
IP: convertIPv4(row.localAddr), IP: convertIPv4(row.localAddr),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -286,7 +286,7 @@ func (ipHelper *IPHelper) getTable(ipVersion, protocol uint8) (connections []*so
binds = append(binds, &socket.BindInfo{ binds = append(binds, &socket.BindInfo{
Local: socket.Address{ Local: socket.Address{
IP: net.IP(row.localAddr[:]), IP: net.IP(row.localAddr[:]),
Port: uint16(row.localPort>>8 | row.localPort<<8), Port: convertPort(row.localPort),
}, },
PID: int(row.owningPid), PID: int(row.owningPid),
}) })
@ -303,3 +303,8 @@ func convertIPv4(input uint32) net.IP {
binary.LittleEndian.PutUint32(addressBuf, input) binary.LittleEndian.PutUint32(addressBuf, input)
return net.IP(addressBuf) return net.IP(addressBuf)
} }
// convertPort converts ports received from iphlpapi.dll
func convertPort(input uint32) uint16 {
return uint16(input>>8 | input<<8)
}

View file

@ -10,11 +10,9 @@ import (
"sync" "sync"
"syscall" "syscall"
"github.com/safing/portbase/log" "github.com/safing/portmaster/network/socket"
)
const ( "github.com/safing/portbase/log"
unidentifiedProcessID = -1
) )
var ( var (
@ -23,7 +21,7 @@ var (
) )
// FindPID returns the pid of the given uid and socket inode. // FindPID returns the pid of the given uid and socket inode.
func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO func FindPID(uid, inode int) (pid int) { //nolint:gocognit // TODO
pidsByUserLock.Lock() pidsByUserLock.Lock()
defer pidsByUserLock.Unlock() defer pidsByUserLock.Unlock()
@ -42,7 +40,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
var checkedUserPids []int var checkedUserPids []int
for _, possiblePID := range pids { for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) { if findSocketFromPid(possiblePID, inode) {
return possiblePID, true return possiblePID
} }
checkedUserPids = append(checkedUserPids, possiblePID) checkedUserPids = append(checkedUserPids, possiblePID)
} }
@ -61,7 +59,7 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
// only check if not already checked // only check if not already checked
if sort.SearchInts(checkedUserPids, possiblePID) == len { if sort.SearchInts(checkedUserPids, possiblePID) == len {
if findSocketFromPid(possiblePID, inode) { if findSocketFromPid(possiblePID, inode) {
return possiblePID, true return possiblePID
} }
} }
} }
@ -75,13 +73,13 @@ func FindPID(uid, inode int) (pid int, ok bool) { //nolint:gocognit // TODO
if possibleUID != uid { if possibleUID != uid {
for _, possiblePID := range pids { for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) { if findSocketFromPid(possiblePID, inode) {
return possiblePID, true return possiblePID
} }
} }
} }
} }
return unidentifiedProcessID, false return socket.UnidentifiedProcessID
} }
func findSocketFromPid(pid, inode int) bool { func findSocketFromPid(pid, inode int) bool {

View file

@ -5,6 +5,7 @@ package proc
import ( import (
"bufio" "bufio"
"encoding/hex" "encoding/hex"
"fmt"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -43,12 +44,10 @@ const (
ICMP4 ICMP4
ICMP6 ICMP6
TCP4Data = "/proc/net/tcp" tcp4ProcFile = "/proc/net/tcp"
UDP4Data = "/proc/net/udp" tcp6ProcFile = "/proc/net/tcp6"
TCP6Data = "/proc/net/tcp6" udp4ProcFile = "/proc/net/udp"
UDP6Data = "/proc/net/udp6" udp6ProcFile = "/proc/net/udp6"
ICMP4Data = "/proc/net/icmp"
ICMP6Data = "/proc/net/icmp6"
UnfetchedProcessID = -2 UnfetchedProcessID = -2
@ -57,27 +56,47 @@ const (
// GetTCP4Table returns the system table for IPv4 TCP activity. // GetTCP4Table returns the system table for IPv4 TCP activity.
func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { func GetTCP4Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP4, TCP4Data, convertIPv4) return getTableFromSource(TCP4, tcp4ProcFile)
} }
// GetTCP6Table returns the system table for IPv6 TCP activity. // GetTCP6Table returns the system table for IPv6 TCP activity.
func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) { func GetTCP6Table() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo, err error) {
return getTableFromSource(TCP6, TCP6Data, convertIPv6) return getTableFromSource(TCP6, tcp6ProcFile)
} }
// GetUDP4Table returns the system table for IPv4 UDP activity. // GetUDP4Table returns the system table for IPv4 UDP activity.
func GetUDP4Table() (binds []*socket.BindInfo, err error) { func GetUDP4Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP4, UDP4Data, convertIPv4) _, binds, err = getTableFromSource(UDP4, udp4ProcFile)
return return
} }
// GetUDP6Table returns the system table for IPv6 UDP activity. // GetUDP6Table returns the system table for IPv6 UDP activity.
func GetUDP6Table() (binds []*socket.BindInfo, err error) { func GetUDP6Table() (binds []*socket.BindInfo, err error) {
_, binds, err = getTableFromSource(UDP6, UDP6Data, convertIPv6) _, binds, err = getTableFromSource(UDP6, udp6ProcFile)
return return
} }
func getTableFromSource(stack uint8, procFile string, ipConverter func(string) net.IP) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { const (
// hint: we split fields by multiple delimiters, see procDelimiter
fieldIndexLocalIP = 1
fieldIndexLocalPort = 2
fieldIndexRemoteIP = 3
fieldIndexRemotePort = 4
fieldIndexUID = 11
fieldIndexInode = 13
)
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP
switch stack {
case TCP4, UDP4:
ipConverter = convertIPv4
case TCP6, UDP6:
ipConverter = convertIPv6
default:
return nil, nil, fmt.Errorf("unsupported table stack: %d", stack)
}
// open file // open file
socketData, err := os.Open(procFile) socketData, err := os.Open(procFile)
@ -91,36 +110,36 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
// parse // parse
scanner.Scan() // skip first line scanner.Scan() // skip first row
for scanner.Scan() { for scanner.Scan() {
line := strings.FieldsFunc(scanner.Text(), procDelimiter) fields := strings.FieldsFunc(scanner.Text(), procDelimiter)
if len(line) < 14 { if len(fields) < 14 {
// log.Tracef("process: too short: %s", line) // log.Tracef("process: too short: %s", fields)
continue continue
} }
localIP := ipConverter(line[1]) localIP := ipConverter(fields[fieldIndexLocalIP])
if localIP == nil { if localIP == nil {
continue continue
} }
localPort, err := strconv.ParseUint(line[2], 16, 16) localPort, err := strconv.ParseUint(fields[fieldIndexLocalPort], 16, 16)
if err != nil { if err != nil {
log.Warningf("process: could not parse port: %s", err) log.Warningf("process: could not parse port: %s", err)
continue continue
} }
uid, err := strconv.ParseInt(line[11], 10, 32) uid, err := strconv.ParseInt(fields[fieldIndexUID], 10, 32)
// log.Tracef("uid: %s", line[11]) // log.Tracef("uid: %s", fields[fieldIndexUID])
if err != nil { if err != nil {
log.Warningf("process: could not parse uid %s: %s", line[11], err) log.Warningf("process: could not parse uid %s: %s", fields[11], err)
continue continue
} }
inode, err := strconv.ParseInt(line[13], 10, 32) inode, err := strconv.ParseInt(fields[fieldIndexInode], 10, 32)
// log.Tracef("inode: %s", line[13]) // log.Tracef("inode: %s", fields[fieldIndexInode])
if err != nil { if err != nil {
log.Warningf("process: could not parse inode %s: %s", line[13], err) log.Warningf("process: could not parse inode %s: %s", fields[13], err)
continue continue
} }
@ -139,7 +158,7 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
case TCP4, TCP6: case TCP4, TCP6:
if line[5] == tcpListenStateHex { if fields[5] == tcpListenStateHex {
// listener // listener
binds = append(binds, &socket.BindInfo{ binds = append(binds, &socket.BindInfo{
@ -154,12 +173,12 @@ func getTableFromSource(stack uint8, procFile string, ipConverter func(string) n
} else { } else {
// connection // connection
remoteIP := ipConverter(line[3]) remoteIP := ipConverter(fields[fieldIndexRemoteIP])
if remoteIP == nil { if remoteIP == nil {
continue continue
} }
remotePort, err := strconv.ParseUint(line[4], 16, 16) remotePort, err := strconv.ParseUint(fields[fieldIndexRemotePort], 16, 16)
if err != nil { if err != nil {
log.Warningf("process: could not parse port: %s", err) log.Warningf("process: could not parse port: %s", err)
continue continue

View file

@ -14,12 +14,12 @@ func TestSockets(t *testing.T) {
} }
fmt.Println("\nTCP 4 connections:") fmt.Println("\nTCP 4 connections:")
for _, connection := range connections { for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode) pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection) fmt.Printf("%d: %+v\n", pid, connection)
} }
fmt.Println("\nTCP 4 listeners:") fmt.Println("\nTCP 4 listeners:")
for _, listener := range listeners { for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode) pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener) fmt.Printf("%d: %+v\n", pid, listener)
} }
@ -29,12 +29,12 @@ func TestSockets(t *testing.T) {
} }
fmt.Println("\nTCP 6 connections:") fmt.Println("\nTCP 6 connections:")
for _, connection := range connections { for _, connection := range connections {
pid, _ := FindPID(connection.UID, connection.Inode) pid := FindPID(connection.UID, connection.Inode)
fmt.Printf("%d: %+v\n", pid, connection) fmt.Printf("%d: %+v\n", pid, connection)
} }
fmt.Println("\nTCP 6 listeners:") fmt.Println("\nTCP 6 listeners:")
for _, listener := range listeners { for _, listener := range listeners {
pid, _ := FindPID(listener.UID, listener.Inode) pid := FindPID(listener.UID, listener.Inode)
fmt.Printf("%d: %+v\n", pid, listener) fmt.Printf("%d: %+v\n", pid, listener)
} }
@ -44,7 +44,7 @@ func TestSockets(t *testing.T) {
} }
fmt.Println("\nUDP 4 binds:") fmt.Println("\nUDP 4 binds:")
for _, bind := range binds { for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode) pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind) fmt.Printf("%d: %+v\n", pid, bind)
} }
@ -54,7 +54,7 @@ func TestSockets(t *testing.T) {
} }
fmt.Println("\nUDP 6 binds:") fmt.Println("\nUDP 6 binds:")
for _, bind := range binds { for _, bind := range binds {
pid, _ := FindPID(bind.UID, bind.Inode) pid := FindPID(bind.UID, bind.Inode)
fmt.Printf("%d: %+v\n", pid, bind) fmt.Printf("%d: %+v\n", pid, bind)
} }
} }

View file

@ -2,6 +2,11 @@ package socket
import "net" import "net"
const (
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
UnidentifiedProcessID = -1
)
// ConnectionInfo holds socket information returned by the system. // ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct { type ConnectionInfo struct {
Local Address Local Address

View file

@ -11,7 +11,11 @@ const (
UDPConnectionTTL = 10 * time.Minute UDPConnectionTTL = 10 * time.Minute
) )
// Exists checks if the given connection is present in the system state tables.
func Exists(pktInfo *packet.Info, now time.Time) (exists bool) { func Exists(pktInfo *packet.Info, now time.Time) (exists bool) {
// TODO: create lookup maps before running a flurry of Exists() checks.
switch { switch {
case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP: case pktInfo.Version == packet.IPv4 && pktInfo.Protocol == packet.TCP:
tcp4Lock.Lock() tcp4Lock.Lock()
@ -76,7 +80,10 @@ func existsUDP(
if localPort == socketInfo.Local.Port && if localPort == socketInfo.Local.Port &&
(socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) { (socketInfo.Local.IP[0] == 0 || localIP.Equal(socketInfo.Local.IP)) {
udpConnState, ok := getUDPConnState(socketInfo, udpStates, remoteIP, remotePort) udpConnState, ok := getUDPConnState(socketInfo, udpStates, socket.Address{
IP: remoteIP,
Port: remotePort,
})
switch { switch {
case !ok: case !ok:
return false return false

View file

@ -24,10 +24,6 @@ import (
// - switch direction to outbound if outbound packet is seen? // - switch direction to outbound if outbound packet is seen?
// - IP: Unidentified Process // - IP: Unidentified Process
const (
UnidentifiedProcessID = -1
)
// Errors // Errors
var ( var (
ErrConnectionNotFound = errors.New("could not find connection in system state tables") ErrConnectionNotFound = errors.New("could not find connection in system state tables")
@ -75,7 +71,7 @@ func Lookup(pktInfo *packet.Info) (pid int, inbound bool, err error) {
return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo) return searchUDP(udp6Binds, udp6States, updateUDP6Table, pktInfo)
default: default:
return UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process") return socket.UnidentifiedProcessID, false, errors.New("unsupported protocol for finding process")
} }
} }
@ -119,7 +115,7 @@ func searchTCP(
connections, listeners = updateTables() connections, listeners = updateTables()
} }
return UnidentifiedProcessID, false, ErrConnectionNotFound return socket.UnidentifiedProcessID, false, ErrConnectionNotFound
} }
func searchUDP( func searchUDP(
@ -170,5 +166,5 @@ func searchUDP(
binds = updateTable() binds = updateTable()
} }
return UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound return socket.UnidentifiedProcessID, pktInfo.Inbound, ErrConnectionNotFound
} }

View file

@ -14,24 +14,14 @@ var (
func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) { func checkConnectionPID(socketInfo *socket.ConnectionInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID { if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
} }
return socketInfo.PID, connInbound, nil return socketInfo.PID, connInbound, nil
} }
func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) { func checkBindPID(socketInfo *socket.BindInfo, connInbound bool) (pid int, inbound bool, err error) {
if socketInfo.PID == proc.UnfetchedProcessID { if socketInfo.PID == proc.UnfetchedProcessID {
pid, ok := proc.FindPID(socketInfo.UID, socketInfo.Inode) socketInfo.PID = proc.FindPID(socketInfo.UID, socketInfo.Inode)
if ok {
socketInfo.PID = pid
} else {
socketInfo.PID = UnidentifiedProcessID
}
} }
return socketInfo.PID, connInbound, nil return socketInfo.PID, connInbound, nil
} }

View file

@ -18,9 +18,8 @@ var (
) )
func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
// FIXME: repeatable once var err error
connections, listeners, err = getTCP4Table()
connections, listeners, err := getTCP4Table()
if err != nil { if err != nil {
log.Warningf("state: failed to get TCP4 socket table: %s", err) log.Warningf("state: failed to get TCP4 socket table: %s", err)
return return
@ -28,11 +27,12 @@ func updateTCP4Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp4Connections = connections tcp4Connections = connections
tcp4Listeners = listeners tcp4Listeners = listeners
return tcp4Connections, tcp4Listeners return
} }
func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) { func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*socket.BindInfo) {
connections, listeners, err := getTCP6Table() var err error
connections, listeners, err = getTCP6Table()
if err != nil { if err != nil {
log.Warningf("state: failed to get TCP6 socket table: %s", err) log.Warningf("state: failed to get TCP6 socket table: %s", err)
return return
@ -40,27 +40,29 @@ func updateTCP6Tables() (connections []*socket.ConnectionInfo, listeners []*sock
tcp6Connections = connections tcp6Connections = connections
tcp6Listeners = listeners tcp6Listeners = listeners
return tcp6Connections, tcp6Listeners return
} }
func updateUDP4Table() (binds []*socket.BindInfo) { func updateUDP4Table() (binds []*socket.BindInfo) {
binds, err := getUDP4Table() var err error
binds, err = getUDP4Table()
if err != nil { if err != nil {
log.Warningf("state: failed to get UDP4 socket table: %s", err) log.Warningf("state: failed to get UDP4 socket table: %s", err)
return return
} }
udp4Binds = binds udp4Binds = binds
return udp4Binds return
} }
func updateUDP6Table() (binds []*socket.BindInfo) { func updateUDP6Table() (binds []*socket.BindInfo) {
binds, err := getUDP6Table() var err error
binds, err = getUDP6Table()
if err != nil { if err != nil {
log.Warningf("state: failed to get UDP6 socket table: %s", err) log.Warningf("state: failed to get UDP6 socket table: %s", err)
return return
} }
udp6Binds = binds udp6Binds = binds
return udp6Binds return
} }

View file

@ -2,7 +2,6 @@ package state
import ( import (
"context" "context"
"net"
"time" "time"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
@ -15,7 +14,7 @@ type udpState struct {
} }
const ( const (
UpdConnStateTTL = 72 * time.Hour UdpConnStateTTL = 72 * time.Hour
UdpConnStateShortenedTTL = 3 * time.Hour UdpConnStateShortenedTTL = 3 * time.Hour
AggressiveCleaningThreshold = 256 AggressiveCleaningThreshold = 256
) )
@ -25,10 +24,10 @@ var (
udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock udp6States = make(map[string]map[string]*udpState) // locked with udp6Lock
) )
func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteIP net.IP, remotePort uint16) (udpConnState *udpState, ok bool) { func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, remoteAddress socket.Address) (udpConnState *udpState, ok bool) {
bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] bindMap, ok := udpStates[makeUDPStateKey(socketInfo.Local)]
if ok { if ok {
udpConnState, ok = bindMap[makeUDPStateKey(remoteIP, remotePort)] udpConnState, ok = bindMap[makeUDPStateKey(remoteAddress)]
return return
} }
@ -36,7 +35,7 @@ func getUDPConnState(socketInfo *socket.BindInfo, udpStates map[string]map[strin
} }
func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) { func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[string]*udpState, pktInfo *packet.Info) (connDirection bool) {
localKey := makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port) localKey := makeUDPStateKey(socketInfo.Local)
bindMap, ok := udpStates[localKey] bindMap, ok := udpStates[localKey]
if !ok { if !ok {
@ -44,7 +43,10 @@ func getUDPDirection(socketInfo *socket.BindInfo, udpStates map[string]map[strin
udpStates[localKey] = bindMap udpStates[localKey] = bindMap
} }
remoteKey := makeUDPStateKey(pktInfo.RemoteIP(), pktInfo.RemotePort()) remoteKey := makeUDPStateKey(socket.Address{
IP: pktInfo.RemoteIP(),
Port: pktInfo.RemotePort(),
})
udpConnState, ok := bindMap[remoteKey] udpConnState, ok := bindMap[remoteKey]
if !ok { if !ok {
bindMap[remoteKey] = &udpState{ bindMap[remoteKey] = &udpState{
@ -79,19 +81,18 @@ func cleanStates(
now time.Time, now time.Time,
) { ) {
// compute thresholds // compute thresholds
threshold := now.Add(-UpdConnStateTTL) threshold := now.Add(-UdpConnStateTTL)
shortThreshhold := now.Add(-UdpConnStateShortenedTTL) shortThreshhold := now.Add(-UdpConnStateShortenedTTL)
// make list of all active keys // make lookup map of all active keys
bindKeys := make(map[string]struct{}) bindKeys := make(map[string]struct{})
for _, socketInfo := range binds { for _, socketInfo := range binds {
bindKeys[makeUDPStateKey(socketInfo.Local.IP, socketInfo.Local.Port)] = struct{}{} bindKeys[makeUDPStateKey(socketInfo.Local)] = struct{}{}
} }
// clean the udp state storage // clean the udp state storage
for localKey, bindMap := range udpStates { for localKey, bindMap := range udpStates {
_, active := bindKeys[localKey] if _, active := bindKeys[localKey]; active {
if active {
// clean old entries // clean old entries
for remoteKey, udpConnState := range bindMap { for remoteKey, udpConnState := range bindMap {
if udpConnState.lastSeen.Before(threshold) { if udpConnState.lastSeen.Before(threshold) {
@ -113,7 +114,7 @@ func cleanStates(
} }
} }
func makeUDPStateKey(ip net.IP, port uint16) string { func makeUDPStateKey(address socket.Address) string {
// This could potentially go wrong, but as all IPs are created by the same source, everything should be fine. // This could potentially go wrong, but as all IPs are created by the same source, everything should be fine.
return string(ip) + string(port) return string(address.IP) + string(address.Port)
} }

View file

@ -29,6 +29,14 @@ func prep() error {
} }
func start() error { func start() error {
// Create a dummy directory to which processes change their working directory
// to. Currently this includes the App and the Notifier. The aim is protect
// all other directories and increase compatibility should any process want
// to read or write something to the current working directory. This can also
// be useful in the future to dump data to for debugging. The permission used
// may seem dangerous, but proper permission on the parent directory provide
// (some) protection.
// Processes must _never_ read from this directory.
err := dataroot.Root().ChildDir("exec", 0777).Ensure() err := dataroot.Root().ChildDir("exec", 0777).Ensure()
if err != nil { if err != nil {
log.Warningf("ui: failed to create safe exec dir: %s", err) log.Warningf("ui: failed to create safe exec dir: %s", err)