Merge branch 'develop' into feature/ui-revamp

This commit is contained in:
Patrick Pacher 2020-10-20 13:00:28 +02:00
commit 997787e3f2
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
7 changed files with 253 additions and 163 deletions

View file

@ -55,12 +55,12 @@ func interceptionStop() error {
return interception.Stop()
}
func handlePacket(pkt packet.Packet) {
func handlePacket(ctx context.Context, pkt packet.Packet) {
if fastTrackedPermit(pkt) {
return
}
traceCtx, tracer := log.AddTracer(context.Background())
traceCtx, tracer := log.AddTracer(ctx)
if tracer != nil {
pkt.SetCtx(traceCtx)
tracer.Tracef("filter: handling packet: %s", pkt)
@ -318,7 +318,10 @@ func packetHandler(ctx context.Context) error {
case <-ctx.Done():
return nil
case pkt := <-interception.Packets:
handlePacket(pkt)
interceptionModule.StartWorker("initial packet handler", func(ctx context.Context) error {
handlePacket(ctx, pkt)
return nil
})
}
}
}

View file

@ -49,50 +49,57 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
// 100: same network/datacenter
// Weighting:
// coordinate distance: 0-50
// continent match: 15
// country match: 10
// AS owner match: 15
// AS network match: 10
// continent match: 25
// country match: 20
// AS owner match: 25
// AS network match: 20
// coordinate distance: 0-10
// coordinate distance: 0-50
// continent match: 25
if l.Continent.Code == to.Continent.Code {
proximity += 25
// country match: 20
if l.Country.ISOCode == to.Country.ISOCode {
proximity += 20
}
}
// AS owner match: 25
if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization {
proximity += 25
// AS network match: 20
if l.AutonomousSystemNumber == to.AutonomousSystemNumber {
proximity += 20
}
}
// coordinate distance: 0-10
fromCoords := haversine.Coord{Lat: l.Coordinates.Latitude, Lon: l.Coordinates.Longitude}
toCoords := haversine.Coord{Lat: to.Coordinates.Latitude, Lon: to.Coordinates.Longitude}
_, km := haversine.Distance(fromCoords, toCoords)
// proximity distance by accuracy
// get worst accuracy rating
// adjust accuracy value
accuracy := l.Coordinates.AccuracyRadius
if to.Coordinates.AccuracyRadius > accuracy {
switch {
case l.Coordinates.Latitude == 0 && l.Coordinates.Longitude == 0:
fallthrough
case to.Coordinates.Latitude == 0 && to.Coordinates.Longitude == 0:
// If we don't have any on any side coordinates, set accuracy to worst
// effective value.
accuracy = 1000
case to.Coordinates.AccuracyRadius > accuracy:
// If the destination accuracy is worse, use that one.
accuracy = to.Coordinates.AccuracyRadius
}
if km <= 10 && accuracy <= 100 {
proximity += 50
proximity += 10
} else {
distanceIn50Percent := ((earthCircumferenceInKm - km) / earthCircumferenceInKm) * 50
distanceInPercent := (earthCircumferenceInKm - km) * 100 / earthCircumferenceInKm
// apply penalty for locations with low accuracy (targeting accuracy radius >100)
accuracyModifier := 1 - float64(accuracy)/1000
proximity += int(distanceIn50Percent * accuracyModifier)
}
// continent match: 15
if l.Continent.Code == to.Continent.Code {
proximity += 15
// country match: 10
if l.Country.ISOCode == to.Country.ISOCode {
proximity += 10
}
}
// AS owner match: 15
if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization {
proximity += 15
// AS network match: 10
if l.AutonomousSystemNumber == to.AutonomousSystemNumber {
proximity += 10
}
proximity += int(distanceInPercent * 0.10 * accuracyModifier)
}
return //nolint:nakedret

View file

@ -5,10 +5,7 @@ package proc
import (
"fmt"
"os"
"sort"
"strconv"
"sync"
"syscall"
"time"
"github.com/safing/portmaster/network/socket"
@ -16,88 +13,89 @@ import (
)
var (
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
pidsByUserLock sync.Mutex
pidsByUser = make(map[int][]int)
baseWaitTime = 3 * time.Millisecond
lookupRetries = 3
)
// GetPID returns the already existing pid of the given socket info or searches for it.
// This also acts as a getter for socket.*Info.PID, as locking for that occurs here.
func GetPID(socketInfo socket.Info) (pid int) {
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
// Get currently assigned PID to the socket info.
currentPid := socketInfo.GetPID()
if socketInfo.GetPID() != socket.UnidentifiedProcessID {
return socketInfo.GetPID()
// If the current PID already is valid (ie. not unidentified), return it immediately.
if currentPid != socket.UnidentifiedProcessID {
return currentPid
}
pid = findPID(socketInfo.GetUID(), socketInfo.GetInode())
// Find PID for the given UID and inode.
pid = findPID(socketInfo.GetUIDandInode())
// Set the newly found PID on the socket info.
socketInfo.SetPID(pid)
// Return found PID.
return pid
}
// findPID returns the pid of the given uid and socket inode.
func findPID(uid, inode int) (pid int) { //nolint:gocognit // TODO
func findPID(uid, inode int) (pid int) {
socketName := fmt.Sprintf("socket:[%d]", inode)
pidsUpdated := false
for i := 0; i <= lookupRetries; i++ {
var pidsUpdated bool
// get pids of user, update if missing
pids, ok := pidsByUser[uid]
if !ok {
// log.Trace("proc: no processes of user, updating table")
updatePids()
pidsUpdated = true
pids, ok = pidsByUser[uid]
}
if ok {
// if user has pids, start checking them first
var checkedUserPids []int
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID
}
checkedUserPids = append(checkedUserPids, possiblePID)
}
// if we fail on the first run and have not updated, update and check the ones we haven't tried so far.
if !pidsUpdated {
// log.Trace("proc: socket not found in any process of user, updating table")
// update
// Get all pids for the given uid.
pids, ok := getPidsByUser(uid)
if !ok {
// If we cannot find the uid in the map, update it.
updatePids()
// sort for faster search
for i, j := 0, len(checkedUserPids)-1; i < j; i, j = i+1, j-1 {
checkedUserPids[i], checkedUserPids[j] = checkedUserPids[j], checkedUserPids[i]
pidsUpdated = true
pids, ok = getPidsByUser(uid)
}
// If we have found PIDs, search them.
if ok {
// Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
// be searched for.
for i := len(pids) - 1; i >= 0; i-- {
if findSocketFromPid(pids[i], socketName) {
return pids[i]
}
}
len := len(checkedUserPids)
// check unchecked pids
for _, possiblePID := range pids {
// only check if not already checked
if sort.SearchInts(checkedUserPids, possiblePID) == len {
if findSocketFromPid(possiblePID, inode) {
return possiblePID
}
// If we still cannot find our socket, and haven't yet updated the PID map,
// do this and then check again immediately.
if !pidsUpdated {
updatePids()
pids, ok = getPidsByUser(uid)
if ok {
// Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to
// be searched for.
for i := len(pids) - 1; i >= 0; i-- {
if findSocketFromPid(pids[i], socketName) {
return pids[i]
}
}
}
}
}
// check all other pids
// log.Trace("proc: socket not found in any process of user, checking all pids")
// TODO: find best order for pidsByUser for best performance
for possibleUID, pids := range pidsByUser {
if possibleUID != uid {
for _, possiblePID := range pids {
if findSocketFromPid(possiblePID, inode) {
return possiblePID
}
}
// We have updated the PID map, but still cannot find anything.
// So, there is nothing we can other than wait a little for the kernel to
// populate the information.
// Wait after each try, except for the last iteration
if i < lookupRetries {
// Wait in back-off fashion - with 3ms baseWaitTime: 3, 6, 9 - 18ms in total.
time.Sleep(time.Duration(i+1) * baseWaitTime)
}
}
return socket.UnidentifiedProcessID
}
func findSocketFromPid(pid, inode int) bool {
socketName := fmt.Sprintf("socket:[%d]", inode)
func findSocketFromPid(pid int, socketName string) bool {
entries := readDirNames(fmt.Sprintf("/proc/%d/fd", pid))
if len(entries) == 0 {
return false
@ -119,49 +117,6 @@ func findSocketFromPid(pid, inode int) bool {
return false
}
func updatePids() {
pidsByUser = make(map[int][]int)
entries := readDirNames("/proc")
if len(entries) == 0 {
return
}
entryLoop:
for _, entry := range entries {
pid, err := strconv.ParseInt(entry, 10, 32)
if err != nil {
continue entryLoop
}
statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid))
if err != nil {
log.Warningf("proc: could not stat /proc/%d: %s", pid, err)
continue entryLoop
}
sys, ok := statData.Sys().(*syscall.Stat_t)
if !ok {
log.Warningf("proc: unable to parse /proc/%d: wrong type", pid)
continue entryLoop
}
pids, ok := pidsByUser[int(sys.Uid)]
if ok {
pidsByUser[int(sys.Uid)] = append(pids, int(pid))
} else {
pidsByUser[int(sys.Uid)] = []int{int(pid)}
}
}
for _, slice := range pidsByUser {
for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 {
slice[i], slice[j] = slice[j], slice[i]
}
}
}
// readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every
// resulting directory name, which we don't need. This function will be called a lot, so we should
// refrain from unnecessary work.

View file

@ -0,0 +1,77 @@
// +build linux
package proc
import (
"fmt"
"os"
"strconv"
"sync"
"syscall"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
var (
// pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs.
pidsByUser = make(map[int][]int)
pidsByUserLock sync.RWMutex
fetchPidsByUser utils.OnceAgain
)
// getPidsByUser returns the cached PIDs for the given UID.
func getPidsByUser(uid int) (pids []int, ok bool) {
pidsByUserLock.RLock()
defer pidsByUserLock.RUnlock()
pids, ok = pidsByUser[uid]
return
}
// updatePids fetches and creates a new pidsByUser map using utils.OnceAgain.
func updatePids() {
fetchPidsByUser.Do(func() {
newPidsByUser := make(map[int][]int)
pidCnt := 0
entries := readDirNames("/proc")
if len(entries) == 0 {
log.Warning("proc: found no PIDs in /proc")
return
}
entryLoop:
for _, entry := range entries {
pid, err := strconv.ParseInt(entry, 10, 32)
if err != nil {
continue entryLoop
}
statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid))
if err != nil {
log.Warningf("proc: could not stat /proc/%d: %s", pid, err)
continue entryLoop
}
sys, ok := statData.Sys().(*syscall.Stat_t)
if !ok {
log.Warningf("proc: unable to parse /proc/%d: wrong type", pid)
continue entryLoop
}
pids, ok := newPidsByUser[int(sys.Uid)]
if ok {
newPidsByUser[int(sys.Uid)] = append(pids, int(pid))
} else {
newPidsByUser[int(sys.Uid)] = []int{int(pid)}
}
pidCnt++
}
// log.Tracef("proc: updated PID table with %d entries", pidCnt)
pidsByUserLock.Lock()
defer pidsByUserLock.Unlock()
pidsByUser = newPidsByUser
})
}

View file

@ -35,7 +35,7 @@ Cache every step!
*/
// Network Related Constants
// Network Related Constants.
const (
TCP4 uint8 = iota
UDP4

View file

@ -1,6 +1,9 @@
package socket
import "net"
import (
"net"
"sync"
)
const (
// UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops.
@ -9,6 +12,8 @@ const (
// ConnectionInfo holds socket information returned by the system.
type ConnectionInfo struct {
sync.Mutex
Local Address
Remote Address
PID int
@ -18,6 +23,8 @@ type ConnectionInfo struct {
// BindInfo holds socket information returned by the system.
type BindInfo struct {
sync.Mutex
Local Address
PID int
UID int
@ -35,32 +42,72 @@ type Info interface {
GetPID() int
SetPID(int)
GetUID() int
GetInode() int
GetUIDandInode() (int, int)
}
// GetPID returns the PID.
func (i *ConnectionInfo) GetPID() int { return i.PID }
func (i *ConnectionInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value.
func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid }
func (i *ConnectionInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID.
func (i *ConnectionInfo) GetUID() int { return i.UID }
func (i *ConnectionInfo) GetUID() int {
i.Lock()
defer i.Unlock()
// GetInode returns the Inode.
func (i *ConnectionInfo) GetInode() int { return i.Inode }
return i.UID
}
// GetUIDandInode returns the UID and Inode.
func (i *ConnectionInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}
// GetPID returns the PID.
func (i *BindInfo) GetPID() int { return i.PID }
func (i *BindInfo) GetPID() int {
i.Lock()
defer i.Unlock()
return i.PID
}
// SetPID sets the PID to the given value.
func (i *BindInfo) SetPID(pid int) { i.PID = pid }
func (i *BindInfo) SetPID(pid int) {
i.Lock()
defer i.Unlock()
i.PID = pid
}
// GetUID returns the UID.
func (i *BindInfo) GetUID() int { return i.UID }
func (i *BindInfo) GetUID() int {
i.Lock()
defer i.Unlock()
// GetInode returns the Inode.
func (i *BindInfo) GetInode() int { return i.Inode }
return i.UID
}
// GetUIDandInode returns the UID and Inode.
func (i *BindInfo) GetUIDandInode() (int, int) {
i.Lock()
defer i.Unlock()
return i.UID, i.Inode
}
// compile time checks
var _ Info = new(ConnectionInfo)

View file

@ -18,6 +18,7 @@ import (
const (
tcpWriteTimeout = 1 * time.Second
ignoreQueriesAfter = 10 * time.Minute
heartbeatTimeout = 15 * time.Second
)
// TCPResolver is a resolver using just a single tcp connection with pipelining.
@ -29,7 +30,7 @@ type TCPResolver struct {
clientStarted *abool.AtomicBool
clientHeartbeat chan struct{}
clientCancel func()
stopClient func()
connInstanceID *uint32
queries chan *dns.Msg
inFlightQueries map[uint16]*InFlightQuery
@ -75,9 +76,9 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver {
},
clientStarted: abool.New(),
clientHeartbeat: make(chan struct{}),
clientCancel: func() {},
stopClient: func() {},
connInstanceID: &instanceID,
queries: make(chan *dns.Msg, 100),
queries: make(chan *dns.Msg, 1000),
inFlightQueries: make(map[uint16]*InFlightQuery),
}
}
@ -181,15 +182,15 @@ func (tr *TCPResolver) checkClientStatus() {
// Get client cancel function before waiting in order to not immediately
// cancel a new client.
tr.Lock()
cancelClient := tr.clientCancel
stopClient := tr.stopClient
tr.Unlock()
// Check if the client is alive with the heartbeat, if not shut it down.
select {
case tr.clientHeartbeat <- struct{}{}:
case <-time.After(defaultRequestTimeout):
case <-time.After(heartbeatTimeout):
log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.GetName())
cancelClient()
stopClient()
}
}
@ -214,16 +215,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
mgr.tr.clientStarted.Set()
// Create additional cancel function for this worker.
workerCtx, cancelWorker := context.WithCancel(workerCtx)
clientCtx, stopClient := context.WithCancel(workerCtx)
mgr.tr.Lock()
mgr.tr.clientCancel = cancelWorker
mgr.tr.stopClient = stopClient
mgr.tr.Unlock()
// connection lifecycle loop
for {
// check if we are shutting down
select {
case <-workerCtx.Done():
case <-clientCtx.Done():
return nil
default:
}
@ -234,7 +235,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
}
// wait for work before creating connection
proceed := mgr.waitForWork(workerCtx)
proceed := mgr.waitForWork(clientCtx)
if !proceed {
return nil
}
@ -250,7 +251,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error {
netenv.ReportSuccessfulConnection()
// handle queries
proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx)
proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx)
if !proceed {
return nil
}
@ -276,7 +277,7 @@ func (mgr *tcpResolverConnMgr) shutdown() {
}
}
func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) {
func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) {
// wait until there is something to do
mgr.tr.Lock()
waiting := len(mgr.tr.inFlightQueries)
@ -308,7 +309,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b
// wait for first query
select {
case <-workerCtx.Done():
case <-clientCtx.Done():
return false
case msg := <-mgr.tr.queries:
// re-insert query, we will handle it later
@ -362,7 +363,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
)
// start reader
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error {
module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error {
return mgr.msgReader(conn, connClosing, cancelConnCtx)
})
@ -370,7 +371,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() (
}
func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter.
workerCtx context.Context,
clientCtx context.Context,
conn *dns.Conn,
connClosing *abool.AtomicBool,
connCtx context.Context,
@ -394,7 +395,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context
case <-mgr.tr.clientHeartbeat:
// respond to alive checks
case <-workerCtx.Done():
case <-clientCtx.Done():
// module shutdown
return false