diff --git a/firewall/interception.go b/firewall/interception.go index a2146e97..04f7f4af 100644 --- a/firewall/interception.go +++ b/firewall/interception.go @@ -55,12 +55,12 @@ func interceptionStop() error { return interception.Stop() } -func handlePacket(pkt packet.Packet) { +func handlePacket(ctx context.Context, pkt packet.Packet) { if fastTrackedPermit(pkt) { return } - traceCtx, tracer := log.AddTracer(context.Background()) + traceCtx, tracer := log.AddTracer(ctx) if tracer != nil { pkt.SetCtx(traceCtx) tracer.Tracef("filter: handling packet: %s", pkt) @@ -318,7 +318,10 @@ func packetHandler(ctx context.Context) error { case <-ctx.Done(): return nil case pkt := <-interception.Packets: - handlePacket(pkt) + interceptionModule.StartWorker("initial packet handler", func(ctx context.Context) error { + handlePacket(ctx, pkt) + return nil + }) } } } diff --git a/intel/geoip/location.go b/intel/geoip/location.go index f62d754a..ebe9928e 100644 --- a/intel/geoip/location.go +++ b/intel/geoip/location.go @@ -49,50 +49,57 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) { // 100: same network/datacenter // Weighting: - // coordinate distance: 0-50 - // continent match: 15 - // country match: 10 - // AS owner match: 15 - // AS network match: 10 + // continent match: 25 + // country match: 20 + // AS owner match: 25 + // AS network match: 20 + // coordinate distance: 0-10 - // coordinate distance: 0-50 + // continent match: 25 + if l.Continent.Code == to.Continent.Code { + proximity += 25 + // country match: 20 + if l.Country.ISOCode == to.Country.ISOCode { + proximity += 20 + } + } + + // AS owner match: 25 + if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization { + proximity += 25 + // AS network match: 20 + if l.AutonomousSystemNumber == to.AutonomousSystemNumber { + proximity += 20 + } + } + + // coordinate distance: 0-10 fromCoords := haversine.Coord{Lat: l.Coordinates.Latitude, Lon: l.Coordinates.Longitude} toCoords := haversine.Coord{Lat: to.Coordinates.Latitude, Lon: to.Coordinates.Longitude} _, km := haversine.Distance(fromCoords, toCoords) - // proximity distance by accuracy - // get worst accuracy rating + // adjust accuracy value accuracy := l.Coordinates.AccuracyRadius - if to.Coordinates.AccuracyRadius > accuracy { + switch { + case l.Coordinates.Latitude == 0 && l.Coordinates.Longitude == 0: + fallthrough + case to.Coordinates.Latitude == 0 && to.Coordinates.Longitude == 0: + // If we don't have any on any side coordinates, set accuracy to worst + // effective value. + accuracy = 1000 + case to.Coordinates.AccuracyRadius > accuracy: + // If the destination accuracy is worse, use that one. accuracy = to.Coordinates.AccuracyRadius } if km <= 10 && accuracy <= 100 { - proximity += 50 + proximity += 10 } else { - distanceIn50Percent := ((earthCircumferenceInKm - km) / earthCircumferenceInKm) * 50 + distanceInPercent := (earthCircumferenceInKm - km) * 100 / earthCircumferenceInKm // apply penalty for locations with low accuracy (targeting accuracy radius >100) accuracyModifier := 1 - float64(accuracy)/1000 - proximity += int(distanceIn50Percent * accuracyModifier) - } - - // continent match: 15 - if l.Continent.Code == to.Continent.Code { - proximity += 15 - // country match: 10 - if l.Country.ISOCode == to.Country.ISOCode { - proximity += 10 - } - } - - // AS owner match: 15 - if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization { - proximity += 15 - // AS network match: 10 - if l.AutonomousSystemNumber == to.AutonomousSystemNumber { - proximity += 10 - } + proximity += int(distanceInPercent * 0.10 * accuracyModifier) } return //nolint:nakedret diff --git a/network/proc/findpid.go b/network/proc/findpid.go index 0610e361..37e9b0ef 100644 --- a/network/proc/findpid.go +++ b/network/proc/findpid.go @@ -5,10 +5,7 @@ package proc import ( "fmt" "os" - "sort" - "strconv" - "sync" - "syscall" + "time" "github.com/safing/portmaster/network/socket" @@ -16,88 +13,89 @@ import ( ) var ( - // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. - pidsByUserLock sync.Mutex - pidsByUser = make(map[int][]int) + baseWaitTime = 3 * time.Millisecond + lookupRetries = 3 ) // GetPID returns the already existing pid of the given socket info or searches for it. // This also acts as a getter for socket.*Info.PID, as locking for that occurs here. func GetPID(socketInfo socket.Info) (pid int) { - pidsByUserLock.Lock() - defer pidsByUserLock.Unlock() + // Get currently assigned PID to the socket info. + currentPid := socketInfo.GetPID() - if socketInfo.GetPID() != socket.UnidentifiedProcessID { - return socketInfo.GetPID() + // If the current PID already is valid (ie. not unidentified), return it immediately. + if currentPid != socket.UnidentifiedProcessID { + return currentPid } - pid = findPID(socketInfo.GetUID(), socketInfo.GetInode()) + // Find PID for the given UID and inode. + pid = findPID(socketInfo.GetUIDandInode()) + + // Set the newly found PID on the socket info. socketInfo.SetPID(pid) + + // Return found PID. return pid } // findPID returns the pid of the given uid and socket inode. -func findPID(uid, inode int) (pid int) { //nolint:gocognit // TODO +func findPID(uid, inode int) (pid int) { + socketName := fmt.Sprintf("socket:[%d]", inode) - pidsUpdated := false + for i := 0; i <= lookupRetries; i++ { + var pidsUpdated bool - // get pids of user, update if missing - pids, ok := pidsByUser[uid] - if !ok { - // log.Trace("proc: no processes of user, updating table") - updatePids() - pidsUpdated = true - pids, ok = pidsByUser[uid] - } - if ok { - // if user has pids, start checking them first - var checkedUserPids []int - for _, possiblePID := range pids { - if findSocketFromPid(possiblePID, inode) { - return possiblePID - } - checkedUserPids = append(checkedUserPids, possiblePID) - } - // if we fail on the first run and have not updated, update and check the ones we haven't tried so far. - if !pidsUpdated { - // log.Trace("proc: socket not found in any process of user, updating table") - // update + // Get all pids for the given uid. + pids, ok := getPidsByUser(uid) + if !ok { + // If we cannot find the uid in the map, update it. updatePids() - // sort for faster search - for i, j := 0, len(checkedUserPids)-1; i < j; i, j = i+1, j-1 { - checkedUserPids[i], checkedUserPids[j] = checkedUserPids[j], checkedUserPids[i] + pidsUpdated = true + pids, ok = getPidsByUser(uid) + } + + // If we have found PIDs, search them. + if ok { + // Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to + // be searched for. + for i := len(pids) - 1; i >= 0; i-- { + if findSocketFromPid(pids[i], socketName) { + return pids[i] + } } - len := len(checkedUserPids) - // check unchecked pids - for _, possiblePID := range pids { - // only check if not already checked - if sort.SearchInts(checkedUserPids, possiblePID) == len { - if findSocketFromPid(possiblePID, inode) { - return possiblePID + } + + // If we still cannot find our socket, and haven't yet updated the PID map, + // do this and then check again immediately. + if !pidsUpdated { + updatePids() + pids, ok = getPidsByUser(uid) + if ok { + // Look through the PIDs in reverse order, because higher/newer PIDs will be more likely to + // be searched for. + for i := len(pids) - 1; i >= 0; i-- { + if findSocketFromPid(pids[i], socketName) { + return pids[i] } } } } - } - // check all other pids - // log.Trace("proc: socket not found in any process of user, checking all pids") - // TODO: find best order for pidsByUser for best performance - for possibleUID, pids := range pidsByUser { - if possibleUID != uid { - for _, possiblePID := range pids { - if findSocketFromPid(possiblePID, inode) { - return possiblePID - } - } + // We have updated the PID map, but still cannot find anything. + // So, there is nothing we can other than wait a little for the kernel to + // populate the information. + + // Wait after each try, except for the last iteration + if i < lookupRetries { + // Wait in back-off fashion - with 3ms baseWaitTime: 3, 6, 9 - 18ms in total. + time.Sleep(time.Duration(i+1) * baseWaitTime) } } return socket.UnidentifiedProcessID } -func findSocketFromPid(pid, inode int) bool { - socketName := fmt.Sprintf("socket:[%d]", inode) +func findSocketFromPid(pid int, socketName string) bool { entries := readDirNames(fmt.Sprintf("/proc/%d/fd", pid)) if len(entries) == 0 { return false @@ -119,49 +117,6 @@ func findSocketFromPid(pid, inode int) bool { return false } -func updatePids() { - pidsByUser = make(map[int][]int) - - entries := readDirNames("/proc") - if len(entries) == 0 { - return - } - -entryLoop: - for _, entry := range entries { - pid, err := strconv.ParseInt(entry, 10, 32) - if err != nil { - continue entryLoop - } - - statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) - if err != nil { - log.Warningf("proc: could not stat /proc/%d: %s", pid, err) - continue entryLoop - } - sys, ok := statData.Sys().(*syscall.Stat_t) - if !ok { - log.Warningf("proc: unable to parse /proc/%d: wrong type", pid) - continue entryLoop - } - - pids, ok := pidsByUser[int(sys.Uid)] - if ok { - pidsByUser[int(sys.Uid)] = append(pids, int(pid)) - } else { - pidsByUser[int(sys.Uid)] = []int{int(pid)} - } - - } - - for _, slice := range pidsByUser { - for i, j := 0, len(slice)-1; i < j; i, j = i+1, j-1 { - slice[i], slice[j] = slice[j], slice[i] - } - } - -} - // readDirNames only reads the directory names. Using ioutil.ReadDir() would call `lstat` on every // resulting directory name, which we don't need. This function will be called a lot, so we should // refrain from unnecessary work. diff --git a/network/proc/pids_by_user.go b/network/proc/pids_by_user.go new file mode 100644 index 00000000..0c155fd8 --- /dev/null +++ b/network/proc/pids_by_user.go @@ -0,0 +1,77 @@ +// +build linux + +package proc + +import ( + "fmt" + "os" + "strconv" + "sync" + "syscall" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" +) + +var ( + // pidsByUserLock is also used for locking the socketInfo.PID on all socket.*Info structs. + pidsByUser = make(map[int][]int) + pidsByUserLock sync.RWMutex + fetchPidsByUser utils.OnceAgain +) + +// getPidsByUser returns the cached PIDs for the given UID. +func getPidsByUser(uid int) (pids []int, ok bool) { + pidsByUserLock.RLock() + defer pidsByUserLock.RUnlock() + + pids, ok = pidsByUser[uid] + return +} + +// updatePids fetches and creates a new pidsByUser map using utils.OnceAgain. +func updatePids() { + fetchPidsByUser.Do(func() { + newPidsByUser := make(map[int][]int) + pidCnt := 0 + + entries := readDirNames("/proc") + if len(entries) == 0 { + log.Warning("proc: found no PIDs in /proc") + return + } + + entryLoop: + for _, entry := range entries { + pid, err := strconv.ParseInt(entry, 10, 32) + if err != nil { + continue entryLoop + } + + statData, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) + if err != nil { + log.Warningf("proc: could not stat /proc/%d: %s", pid, err) + continue entryLoop + } + sys, ok := statData.Sys().(*syscall.Stat_t) + if !ok { + log.Warningf("proc: unable to parse /proc/%d: wrong type", pid) + continue entryLoop + } + + pids, ok := newPidsByUser[int(sys.Uid)] + if ok { + newPidsByUser[int(sys.Uid)] = append(pids, int(pid)) + } else { + newPidsByUser[int(sys.Uid)] = []int{int(pid)} + } + pidCnt++ + } + + // log.Tracef("proc: updated PID table with %d entries", pidCnt) + + pidsByUserLock.Lock() + defer pidsByUserLock.Unlock() + pidsByUser = newPidsByUser + }) +} diff --git a/network/proc/tables.go b/network/proc/tables.go index 45055d36..94f3198b 100644 --- a/network/proc/tables.go +++ b/network/proc/tables.go @@ -35,7 +35,7 @@ Cache every step! */ -// Network Related Constants +// Network Related Constants. const ( TCP4 uint8 = iota UDP4 diff --git a/network/socket/socket.go b/network/socket/socket.go index 2bab6277..24d03518 100644 --- a/network/socket/socket.go +++ b/network/socket/socket.go @@ -1,6 +1,9 @@ package socket -import "net" +import ( + "net" + "sync" +) const ( // UnidentifiedProcessID is originally defined in the process pkg, but duplicated here because of import loops. @@ -9,6 +12,8 @@ const ( // ConnectionInfo holds socket information returned by the system. type ConnectionInfo struct { + sync.Mutex + Local Address Remote Address PID int @@ -18,6 +23,8 @@ type ConnectionInfo struct { // BindInfo holds socket information returned by the system. type BindInfo struct { + sync.Mutex + Local Address PID int UID int @@ -35,32 +42,72 @@ type Info interface { GetPID() int SetPID(int) GetUID() int - GetInode() int + GetUIDandInode() (int, int) } // GetPID returns the PID. -func (i *ConnectionInfo) GetPID() int { return i.PID } +func (i *ConnectionInfo) GetPID() int { + i.Lock() + defer i.Unlock() + + return i.PID +} // SetPID sets the PID to the given value. -func (i *ConnectionInfo) SetPID(pid int) { i.PID = pid } +func (i *ConnectionInfo) SetPID(pid int) { + i.Lock() + defer i.Unlock() + + i.PID = pid +} // GetUID returns the UID. -func (i *ConnectionInfo) GetUID() int { return i.UID } +func (i *ConnectionInfo) GetUID() int { + i.Lock() + defer i.Unlock() -// GetInode returns the Inode. -func (i *ConnectionInfo) GetInode() int { return i.Inode } + return i.UID +} + +// GetUIDandInode returns the UID and Inode. +func (i *ConnectionInfo) GetUIDandInode() (int, int) { + i.Lock() + defer i.Unlock() + + return i.UID, i.Inode +} // GetPID returns the PID. -func (i *BindInfo) GetPID() int { return i.PID } +func (i *BindInfo) GetPID() int { + i.Lock() + defer i.Unlock() + + return i.PID +} // SetPID sets the PID to the given value. -func (i *BindInfo) SetPID(pid int) { i.PID = pid } +func (i *BindInfo) SetPID(pid int) { + i.Lock() + defer i.Unlock() + + i.PID = pid +} // GetUID returns the UID. -func (i *BindInfo) GetUID() int { return i.UID } +func (i *BindInfo) GetUID() int { + i.Lock() + defer i.Unlock() -// GetInode returns the Inode. -func (i *BindInfo) GetInode() int { return i.Inode } + return i.UID +} + +// GetUIDandInode returns the UID and Inode. +func (i *BindInfo) GetUIDandInode() (int, int) { + i.Lock() + defer i.Unlock() + + return i.UID, i.Inode +} // compile time checks var _ Info = new(ConnectionInfo) diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index 76c54cb8..c0703caa 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -18,6 +18,7 @@ import ( const ( tcpWriteTimeout = 1 * time.Second ignoreQueriesAfter = 10 * time.Minute + heartbeatTimeout = 15 * time.Second ) // TCPResolver is a resolver using just a single tcp connection with pipelining. @@ -29,7 +30,7 @@ type TCPResolver struct { clientStarted *abool.AtomicBool clientHeartbeat chan struct{} - clientCancel func() + stopClient func() connInstanceID *uint32 queries chan *dns.Msg inFlightQueries map[uint16]*InFlightQuery @@ -75,9 +76,9 @@ func NewTCPResolver(resolver *Resolver) *TCPResolver { }, clientStarted: abool.New(), clientHeartbeat: make(chan struct{}), - clientCancel: func() {}, + stopClient: func() {}, connInstanceID: &instanceID, - queries: make(chan *dns.Msg, 100), + queries: make(chan *dns.Msg, 1000), inFlightQueries: make(map[uint16]*InFlightQuery), } } @@ -181,15 +182,15 @@ func (tr *TCPResolver) checkClientStatus() { // Get client cancel function before waiting in order to not immediately // cancel a new client. tr.Lock() - cancelClient := tr.clientCancel + stopClient := tr.stopClient tr.Unlock() // Check if the client is alive with the heartbeat, if not shut it down. select { case tr.clientHeartbeat <- struct{}{}: - case <-time.After(defaultRequestTimeout): + case <-time.After(heartbeatTimeout): log.Warningf("resolver: heartbeat failed for %s dns client, stopping", tr.resolver.GetName()) - cancelClient() + stopClient() } } @@ -214,16 +215,16 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { mgr.tr.clientStarted.Set() // Create additional cancel function for this worker. - workerCtx, cancelWorker := context.WithCancel(workerCtx) + clientCtx, stopClient := context.WithCancel(workerCtx) mgr.tr.Lock() - mgr.tr.clientCancel = cancelWorker + mgr.tr.stopClient = stopClient mgr.tr.Unlock() // connection lifecycle loop for { // check if we are shutting down select { - case <-workerCtx.Done(): + case <-clientCtx.Done(): return nil default: } @@ -234,7 +235,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { } // wait for work before creating connection - proceed := mgr.waitForWork(workerCtx) + proceed := mgr.waitForWork(clientCtx) if !proceed { return nil } @@ -250,7 +251,7 @@ func (mgr *tcpResolverConnMgr) run(workerCtx context.Context) error { netenv.ReportSuccessfulConnection() // handle queries - proceed = mgr.queryHandler(workerCtx, conn, connClosing, connCtx, cancelConnCtx) + proceed = mgr.queryHandler(clientCtx, conn, connClosing, connCtx, cancelConnCtx) if !proceed { return nil } @@ -276,7 +277,7 @@ func (mgr *tcpResolverConnMgr) shutdown() { } } -func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed bool) { +func (mgr *tcpResolverConnMgr) waitForWork(clientCtx context.Context) (proceed bool) { // wait until there is something to do mgr.tr.Lock() waiting := len(mgr.tr.inFlightQueries) @@ -308,7 +309,7 @@ func (mgr *tcpResolverConnMgr) waitForWork(workerCtx context.Context) (proceed b // wait for first query select { - case <-workerCtx.Done(): + case <-clientCtx.Done(): return false case msg := <-mgr.tr.queries: // re-insert query, we will handle it later @@ -362,7 +363,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() ( ) // start reader - module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(workerCtx context.Context) error { + module.StartServiceWorker("dns client reader", 10*time.Millisecond, func(clientCtx context.Context) error { return mgr.msgReader(conn, connClosing, cancelConnCtx) }) @@ -370,7 +371,7 @@ func (mgr *tcpResolverConnMgr) establishConnection() ( } func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context _is_ the first parameter. - workerCtx context.Context, + clientCtx context.Context, conn *dns.Conn, connClosing *abool.AtomicBool, connCtx context.Context, @@ -394,7 +395,7 @@ func (mgr *tcpResolverConnMgr) queryHandler( //nolint:golint // context.Context case <-mgr.tr.clientHeartbeat: // respond to alive checks - case <-workerCtx.Done(): + case <-clientCtx.Done(): // module shutdown return false