Clean up network/* packages, revamp online status detection

This commit is contained in:
Daniel 2019-10-25 13:33:36 +02:00
parent c72f956fe8
commit fdb5f6fcf7
27 changed files with 738 additions and 268 deletions

View file

@ -7,6 +7,7 @@ import (
"github.com/safing/portmaster/network/netutils"
)
// GetAssignedAddresses returns the assigned IPv4 and IPv6 addresses of the host.
func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
@ -25,6 +26,7 @@ func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) {
return
}
// GetAssignedGlobalAddresses returns the assigned global IPv4 and IPv6 addresses of the host.
func GetAssignedGlobalAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) {
allv4, allv6, err := GetAssignedAddresses()
if err != nil {

View file

@ -156,7 +156,7 @@ func getNameserversFromDbus() ([]Nameserver, error) {
return nameservers, nil
}
func getConnectivityStateFromDbus() (uint8, error) {
func getConnectivityStateFromDbus() (OnlineStatus, error) {
var err error
dbusConnLock.Lock()
@ -187,18 +187,18 @@ func getConnectivityStateFromDbus() (uint8, error) {
switch connectivityState {
case 0:
return UNKNOWN, nil
return StatusUnknown, nil
case 1:
return OFFLINE, nil
return StatusOffline, nil
case 2:
return PORTAL, nil
return StatusPortal, nil
case 3:
return LIMITED, nil
return StatusLimited, nil
case 4:
return ONLINE, nil
return StatusOnline, nil
}
return UNKNOWN, nil
return StatusUnknown, nil
}
func getNetworkManagerProperty(conn *dbus.Conn, objectPath dbus.ObjectPath, property string) (dbus.Variant, error) {

View file

@ -8,5 +8,5 @@ func getNameserversFromDbus() ([]Nameserver, error) {
}
func getConnectivityStateFromDbus() (uint8, error) {
return UNKNOWN, nil
return StatusUnknown, nil
}

View file

@ -1,8 +1,16 @@
package environment
import "testing"
import (
"os"
"testing"
)
func TestDbus(t *testing.T) {
if _, err := os.Stat("/var/run/dbus/system_bus_socket"); os.IsNotExist(err) {
t.Logf("skipping dbus tests, as dbus does not seem to be installed: %s", err)
return
}
nameservers, err := getNameserversFromDbus()
if err != nil {
t.Errorf("getNameserversFromDbus failed: %s", err)

View file

@ -0,0 +1,21 @@
package environment
import "net"
var (
localAddrFactory func(network string) net.Addr
)
// SetLocalAddrFactory supplies the environment package with a function to get permitted local addresses for connections.
func SetLocalAddrFactory(laf func(network string) net.Addr) {
if localAddrFactory == nil {
localAddrFactory = laf
}
}
func getLocalAddr(network string) net.Addr {
if localAddrFactory != nil {
return localAddrFactory(network)
}
return nil
}

View file

@ -1,15 +1,9 @@
package environment
import (
"bytes"
"crypto/sha1"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
)
// TODO: find a good way to identify a network
@ -22,25 +16,11 @@ import (
// this info might already be included in the interfaces api provided by golang!
const (
UNKNOWN uint8 = iota
OFFLINE
LIMITED // local network only
PORTAL // there seems to be an internet connection, but we are being intercepted
ONLINE
)
const (
connectivityRecheck = 2 * time.Second
interfacesRecheck = 2 * time.Second
gatewaysRecheck = 2 * time.Second
nameserversRecheck = 2 * time.Second
gatewaysRecheck = 2 * time.Second
nameserversRecheck = 2 * time.Second
)
var (
connectivity uint8
connectivityLock sync.Mutex
connectivityExpires = time.Now()
// interfaces = make(map[*net.IP]net.Flags)
// interfacesLock sync.Mutex
// interfacesExpires = time.Now()
@ -52,114 +32,10 @@ var (
nameservers = make([]Nameserver, 0)
nameserversLock sync.Mutex
nameserversExpires = time.Now()
lastNetworkChange *int64
lastNetworkChecksum []byte
)
// Nameserver describes a system assigned namserver.
type Nameserver struct {
IP net.IP
Search []string
}
func init() {
lnc := int64(0)
lastNetworkChange = &lnc
go func() {
time.Sleep(1 * time.Second)
Connectivity()
}()
go monitorNetworkChanges()
}
// Connectivity returns the current state of connectivity to the network/Internet
func Connectivity() uint8 {
// locking
connectivityLock.Lock()
defer connectivityLock.Unlock()
// cache
if connectivityExpires.After(time.Now()) {
return connectivity
}
// logic
// TODO: implement more methods
status, err := getConnectivityStateFromDbus()
if err != nil {
log.Warningf("environment: could not get connectivity: %s", err)
setConnectivity(UNKNOWN)
return UNKNOWN
}
setConnectivity(status)
return status
}
func setConnectivity(status uint8) {
if connectivity != status {
connectivity = status
connectivityExpires = time.Now().Add(connectivityRecheck)
var connectivityName string
switch connectivity {
case UNKNOWN:
connectivityName = "unknown"
case OFFLINE:
connectivityName = "offline"
case LIMITED:
connectivityName = "limited"
case PORTAL:
connectivityName = "portal"
case ONLINE:
connectivityName = "online"
default:
connectivityName = "invalid"
}
log.Infof("environment: connectivity changed to %s", connectivityName)
}
}
// ConnectionSucceeded should be called when a module was able to successfully connect to the internet (do not call too often)
func ConnectionSucceeded() {
connectivityLock.Lock()
defer connectivityLock.Unlock()
setConnectivity(ONLINE)
}
func monitorNetworkChanges() {
// TODO: make more elegant solution
for {
time.Sleep(2 * time.Second)
hasher := sha1.New()
interfaces, err := net.Interfaces()
if err != nil {
log.Warningf("environment: failed to get interfaces: %s", err)
continue
}
for _, iface := range interfaces {
io.WriteString(hasher, iface.Name)
// log.Tracef("adding: %s", iface.Name)
io.WriteString(hasher, iface.Flags.String())
// log.Tracef("adding: %s", iface.Flags.String())
addrs, err := iface.Addrs()
if err != nil {
log.Warningf("environment: failed to get addrs from interface %s: %s", iface.Name, err)
continue
}
for _, addr := range addrs {
io.WriteString(hasher, addr.String())
// log.Tracef("adding: %s", addr.String())
}
}
newChecksum := hasher.Sum(nil)
if !bytes.Equal(lastNetworkChecksum, newChecksum) {
if len(lastNetworkChecksum) == 0 {
lastNetworkChecksum = newChecksum
continue
}
lastNetworkChecksum = newChecksum
atomic.StoreInt64(lastNetworkChange, time.Now().Unix())
log.Info("environment: network changed")
triggerNetworkChanged()
}
}
}

View file

@ -57,7 +57,7 @@ func Gateways() []*net.IP {
continue
}
if len(decoded) != 4 {
log.Warningf("environment: decoded gateway %s from /proc/net/route has wrong length")
log.Warningf("environment: decoded gateway %s from /proc/net/route has wrong length", decoded)
continue
}
gate := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0])
@ -90,7 +90,7 @@ func Gateways() []*net.IP {
continue
}
if len(decoded) != 16 {
log.Warningf("environment: decoded gateway %s from /proc/net/ipv6_route has wrong length")
log.Warningf("environment: decoded gateway %s from /proc/net/ipv6_route has wrong length", decoded)
continue
}
gate := net.IP(decoded)
@ -134,7 +134,6 @@ func Nameservers() []Nameserver {
resolvconfNameservers, err := getNameserversFromResolvconf()
if err != nil {
log.Warningf("environment: could not get nameservers from resolvconf: %s", err)
resolvconfNameservers = make([]Nameserver, 0)
} else {
nameservers = addNameservers(nameservers, resolvconfNameservers)
}
@ -178,7 +177,7 @@ func getNameserversFromResolvconf() ([]Nameserver, error) {
}
// build array
var nameservers []Nameserver
nameservers := make([]Nameserver, 0, len(servers))
for _, server := range servers {
nameservers = append(nameservers, Nameserver{
IP: server,

View file

@ -6,9 +6,6 @@ import "testing"
func TestEnvironment(t *testing.T) {
connectivityTest := Connectivity()
t.Logf("connectivity: %v", connectivityTest)
nameserversTest, err := getNameserversFromResolvconf()
if err != nil {
t.Errorf("failed to get namerservers from resolvconf: %s", err)

View file

@ -1,23 +0,0 @@
package environment
import (
"sync"
)
var (
networkChangedEventCh = make(chan struct{}, 0)
networkChangedEventLock sync.Mutex
)
func triggerNetworkChanged() {
networkChangedEventLock.Lock()
defer networkChangedEventLock.Unlock()
close(networkChangedEventCh)
networkChangedEventCh = make(chan struct{}, 0)
}
func NetworkChanged() <-chan struct{} {
networkChangedEventLock.Lock()
defer networkChangedEventLock.Unlock()
return networkChangedEventCh
}

View file

@ -1,27 +0,0 @@
package environment
import (
"sync"
"sync/atomic"
)
type EnvironmentInterface struct {
lastNetworkChange int64
lock sync.Mutex
}
func NewInterface() *EnvironmentInterface {
return &EnvironmentInterface{
lastNetworkChange: 0,
}
}
func (env *EnvironmentInterface) NetworkChanged() bool {
env.lock.Lock()
defer env.lock.Unlock()
lnc := atomic.LoadInt64(lastNetworkChange)
if lnc > env.lastNetworkChange {
return true
}
return false
}

View file

@ -1,7 +1,6 @@
package environment
import (
"errors"
"fmt"
"log"
"net"
@ -14,20 +13,21 @@ import (
"golang.org/x/net/ipv4"
)
// TODO: reference forking
// TODO: Create IPv6 version of GetApproximateInternetLocation
// GetApproximateInternetLocation returns the IP-address of the nearest ping-answering internet node
//nolint:gocognit // TODO
func GetApproximateInternetLocation() (net.IP, error) {
// TODO: first check if we have a public IP
// net.InterfaceAddrs()
// Traceroute example
var dst net.IPAddr
dst.IP = net.IPv4(8, 8, 8, 8)
dst := net.IPAddr{
IP: net.IPv4(1, 1, 1, 1),
}
c, err := net.ListenPacket("ip4:1", "0.0.0.0") // ICMP for IPv4
c, err := net.ListenPacket("ip4:icmp", "0.0.0.0") // ICMP for IPv4
if err != nil {
return nil, err
}
@ -42,9 +42,8 @@ func GetApproximateInternetLocation() (net.IP, error) {
wm := icmp.Message{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
// TODO: think of something better and not suspicious
Data: []byte("HELLO-R-U-THERE"),
ID: os.Getpid() & 0xffff,
Data: []byte{0},
},
}
rb := make([]byte, 1500)
@ -96,7 +95,7 @@ next:
case ipv4.ICMPTypeTimeExceeded:
ip := net.ParseIP(peer.String())
if ip == nil {
return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String()))
return nil, fmt.Errorf("failed to parse IP: %s", peer.String())
}
if !netutils.IPIsLAN(ip) {
return ip, nil

View file

@ -0,0 +1,42 @@
package environment
import (
"errors"
"github.com/safing/portbase/modules"
)
const (
networkChangedEvent = "network changed"
onlineStatusChangedEvent = "online status changed"
)
var (
module *modules.Module
)
func InitSubModule(m *modules.Module) {
module = m
module.RegisterEvent(networkChangedEvent)
module.RegisterEvent(onlineStatusChangedEvent)
}
func StartSubModule() error {
if module == nil {
return errors.New("not initialized")
}
module.StartServiceWorker(
"monitor network changes",
0,
monitorNetworkChanges,
)
module.StartServiceWorker(
"monitor online status",
0,
monitorOnlineStatus,
)
return nil
}

View file

@ -0,0 +1,91 @@
package environment
import (
"bytes"
"context"
"crypto/sha1" //nolint:gosec // not used for security
"io"
"net"
"time"
"github.com/safing/portbase/log"
)
var (
networkChangeCheckTrigger = make(chan struct{}, 1)
)
func triggerNetworkChangeCheck() {
select {
case networkChangeCheckTrigger <- struct{}{}:
default:
}
}
func monitorNetworkChanges(ctx context.Context) error {
var lastNetworkChecksum []byte
serviceLoop:
for {
trigger := false
// wait for trigger
if GetOnlineStatus() == StatusOnline {
select {
case <-ctx.Done():
return nil
case <-networkChangeCheckTrigger:
case <-time.After(1 * time.Minute):
trigger = true
}
} else {
select {
case <-ctx.Done():
return nil
case <-networkChangeCheckTrigger:
case <-time.After(1 * time.Second):
trigger = true
}
}
// check network for changes
// create hashsum of current network config
hasher := sha1.New() //nolint:gosec // not used for security
interfaces, err := net.Interfaces()
if err != nil {
log.Warningf("environment: failed to get interfaces: %s", err)
continue
}
for _, iface := range interfaces {
_, _ = io.WriteString(hasher, iface.Name)
// log.Tracef("adding: %s", iface.Name)
_, _ = io.WriteString(hasher, iface.Flags.String())
// log.Tracef("adding: %s", iface.Flags.String())
addrs, err := iface.Addrs()
if err != nil {
log.Warningf("environment: failed to get addrs from interface %s: %s", iface.Name, err)
continue
}
for _, addr := range addrs {
_, _ = io.WriteString(hasher, addr.String())
// log.Tracef("adding: %s", addr.String())
}
}
newChecksum := hasher.Sum(nil)
// compare checksum with last
if !bytes.Equal(lastNetworkChecksum, newChecksum) {
if len(lastNetworkChecksum) == 0 {
lastNetworkChecksum = newChecksum
continue serviceLoop
}
lastNetworkChecksum = newChecksum
if trigger {
triggerOnlineStatusInvestigation()
}
module.TriggerEvent(networkChangedEvent, nil)
}
}
}

View file

@ -0,0 +1,19 @@
Intel:
- First ever request: use first resolver as selected
- If resolver fails:
- stop all requesting
- get network status
- if failed: do nothing, return offline error
- check list front to back, use first resolver that resolves one.one.one.one correctly
NetEnv:
- check for intercepted HTTP Request requests
- if fails on:
- connection establishment: OFFLINE
-
- check for intercepted HTTPS Request requests
- check for intercepted DNS requests

View file

@ -0,0 +1,353 @@
package environment
import (
"context"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/network/netutils"
"github.com/tevino/abool"
)
// OnlineStatus represent a state of connectivity to the Internet.
type OnlineStatus uint8
// Online Status Values
const (
StatusUnknown OnlineStatus = 0
StatusOffline OnlineStatus = 1
StatusLimited OnlineStatus = 2 // local network only
StatusPortal OnlineStatus = 3 // there seems to be an internet connection, but we are being intercepted, possibly by a captive portal
StatusSemiOnline OnlineStatus = 4 // we seem to online, but without full connectivity
StatusOnline OnlineStatus = 5
)
// Online Status and Resolver
const (
HTTPTestURL = "http://detectportal.firefox.com/success.txt"
HTTPExpectedContent = "success"
HTTPSTestURL = "https://one.one.one.one/"
ResolverTestFqdn = "one.one.one.one."
ResolverTestRRType = dns.TypeA
ResolverTestExpectedResponse = "1.1.1.1"
)
var (
parsedHTTPTestURL *url.URL
parsedHTTPSTestURL *url.URL
)
func init() {
var err error
parsedHTTPTestURL, err = url.Parse(HTTPTestURL)
if err != nil {
panic(err)
}
parsedHTTPSTestURL, err = url.Parse(HTTPSTestURL)
if err != nil {
panic(err)
}
}
// IsOnlineStatusTestDomain checks whether the given fqdn is used for testing online status.
func IsOnlineStatusTestDomain(domain string) bool {
switch domain {
case "detectportal.firefox.com.":
return true
case "one.one.one.one.":
return true
}
return false
}
// GetResolverTestingRequestData returns request information that should be used to test DNS resolvers for availability and basic correct behaviour.
func GetResolverTestingRequestData() (fqdn string, rrType uint16, expectedResponse string) {
return ResolverTestFqdn, ResolverTestRRType, ResolverTestExpectedResponse
}
func (os OnlineStatus) String() string {
switch os {
default:
return "Unknown"
case StatusOffline:
return "Offline"
case StatusLimited:
return "Limited"
case StatusPortal:
return "Portal"
case StatusSemiOnline:
return "SemiOnline"
case StatusOnline:
return "Online"
}
}
var (
onlineStatus *int32
onlineStatusQuickCheck = abool.NewBool(false)
onlineStatusInvestigationTrigger = make(chan struct{}, 1)
onlineStatusInvestigationInProgress = abool.NewBool(false)
onlineStatusInvestigationWg sync.WaitGroup
captivePortalURL string
captivePortalLock sync.Mutex
)
func init() {
var onlineStatusValue int32
onlineStatus = &onlineStatusValue
}
// Online returns true if online status is either SemiOnline or Online.
func Online() bool {
return onlineStatusQuickCheck.IsSet()
}
// GetOnlineStatus returns the current online stats.
func GetOnlineStatus() OnlineStatus {
return OnlineStatus(atomic.LoadInt32(onlineStatus))
}
// CheckAndGetOnlineStatus triggers a new online status check and returns the result
func CheckAndGetOnlineStatus() OnlineStatus {
// trigger new investigation
triggerOnlineStatusInvestigation()
// wait for completion
onlineStatusInvestigationWg.Wait()
// return current status
return GetOnlineStatus()
}
func updateOnlineStatus(status OnlineStatus, portalURL, comment string) {
changed := false
// status
currentStatus := atomic.LoadInt32(onlineStatus)
if status != OnlineStatus(currentStatus) && atomic.CompareAndSwapInt32(onlineStatus, currentStatus, int32(status)) {
// status changed!
onlineStatusQuickCheck.SetTo(
status == StatusOnline || status == StatusSemiOnline,
)
changed = true
}
// captive portal
captivePortalLock.Lock()
defer captivePortalLock.Unlock()
if portalURL != captivePortalURL {
captivePortalURL = portalURL
changed = true
}
// trigger event
if changed {
module.TriggerEvent(onlineStatusChangedEvent, nil)
if status == StatusPortal {
log.Infof(`network: setting online status to %s at "%s" (%s)`, status, captivePortalURL, comment)
} else {
log.Infof("network: setting online status to %s (%s)", status, comment)
}
triggerNetworkChangeCheck()
}
}
// GetCaptivePortalURL returns the current captive portal url as a string.
func GetCaptivePortalURL() string {
captivePortalLock.Lock()
defer captivePortalLock.Unlock()
return captivePortalURL
}
// ReportSuccessfulConnection hints the online status monitoring system that a connection attempt was successful.
func ReportSuccessfulConnection() {
if !onlineStatusQuickCheck.IsSet() {
triggerOnlineStatusInvestigation()
}
}
// ReportFailedConnection hints the online status monitoring system that a connection attempt has failed. This function has extremely low overhead and may be called as much as wanted.
func ReportFailedConnection() {
if onlineStatusQuickCheck.IsSet() {
triggerOnlineStatusInvestigation()
}
}
func triggerOnlineStatusInvestigation() {
if onlineStatusInvestigationInProgress.SetToIf(false, true) {
onlineStatusInvestigationWg.Add(1)
}
select {
case onlineStatusInvestigationTrigger <- struct{}{}:
default:
}
}
func monitorOnlineStatus(ctx context.Context) error {
for {
// wait for trigger
if GetOnlineStatus() == StatusOnline {
select {
case <-ctx.Done():
return nil
case <-onlineStatusInvestigationTrigger:
case <-time.After(1 * time.Minute):
}
} else {
select {
case <-ctx.Done():
return nil
case <-onlineStatusInvestigationTrigger:
case <-time.After(1 * time.Second):
}
}
// enable waiting
if onlineStatusInvestigationInProgress.SetToIf(false, true) {
onlineStatusInvestigationWg.Add(1)
}
checkOnlineStatus(ctx)
// finished!
onlineStatusInvestigationWg.Done()
onlineStatusInvestigationInProgress.UnSet()
}
}
func checkOnlineStatus(ctx context.Context) {
// TODO: implement more methods
/*status, err := getConnectivityStateFromDbus()
if err != nil {
log.Warningf("environment: could not get connectivity: %s", err)
setConnectivity(StatusUnknown)
return StatusUnknown
}*/
// 1) check for addresses
ipv4, ipv6, err := GetAssignedAddresses()
if err != nil {
log.Warningf("network: failed to get assigned network addresses: %s", err)
} else {
var lan bool
for _, ip := range ipv4 {
switch netutils.ClassifyIP(ip) {
case netutils.SiteLocal:
lan = true
case netutils.Global:
// we _are_ the Internet ;)
updateOnlineStatus(StatusOnline, "", "global IPv4 interface detected")
return
}
}
for _, ip := range ipv6 {
switch netutils.ClassifyIP(ip) {
case netutils.SiteLocal, netutils.Global:
// IPv6 global addresses are also used in local networks
lan = true
}
}
if !lan {
updateOnlineStatus(StatusOffline, "", "no local or global interfaces detected")
return
}
}
// 2) try a http request
// TODO: find (array of) alternatives to detectportal.firefox.com
// TODO: find something about usage terms of detectportal.firefox.com
client := &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
LocalAddr: getLocalAddr("tcp"),
DualStack: true,
}).DialContext,
DisableKeepAlives: true,
DisableCompression: true,
WriteBufferSize: 1024,
ReadBufferSize: 1024,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: 5 * time.Second,
}
request := (&http.Request{
Method: "GET",
URL: parsedHTTPTestURL,
Close: true,
}).WithContext(ctx)
response, err := client.Do(request)
if err != nil {
updateOnlineStatus(StatusLimited, "", "http request failed")
return
}
defer response.Body.Close()
// check location
portalURL, err := response.Location()
if err == nil {
updateOnlineStatus(StatusPortal, portalURL.String(), "http request succeeded with redirect")
return
}
// read the body
data, err := ioutil.ReadAll(response.Body)
if err != nil {
log.Warningf("network: failed to read http body of captive portal testing response: %s", err)
// assume we are online nonetheless
updateOnlineStatus(StatusOnline, "", "http request succeeded, albeit failing later")
return
}
// check body contents
if strings.TrimSpace(string(data)) == HTTPExpectedContent {
updateOnlineStatus(StatusOnline, "", "http request succeeded")
} else {
// something is interfering with the website content
// this might be a weird captive portal, just direct the user there
updateOnlineStatus(StatusPortal, "detectportal.firefox.com", "http request succeeded, response content not as expected")
}
// 3) try a https request
request = (&http.Request{
Method: "HEAD",
URL: parsedHTTPSTestURL,
Close: true,
}).WithContext(ctx)
// only test if we can get the headers
response, err = client.Do(request)
if err != nil {
// if we fail, something is really weird
updateOnlineStatus(StatusSemiOnline, "", "http request failed")
return
}
defer response.Body.Close()
// finally
updateOnlineStatus(StatusOnline, "", "all checks successful")
}

View file

@ -0,0 +1,12 @@
package environment
import (
"context"
"testing"
)
func TestCheckOnlineStatus(t *testing.T) {
checkOnlineStatus(context.Background())
t.Logf("online status: %s", GetOnlineStatus())
t.Logf("captive portal: %s", GetCaptivePortalURL())
}

View file

@ -4,46 +4,52 @@ import (
"fmt"
"sync"
"github.com/tevino/abool"
maxminddb "github.com/oschwald/maxminddb-golang"
"github.com/safing/portbase/log"
"github.com/safing/portbase/updater"
"github.com/safing/portmaster/updates"
)
var (
dbCityFile *updater.File
dbASNFile *updater.File
dbFileLock sync.Mutex
dbCity *maxminddb.Reader
dbASN *maxminddb.Reader
dbLock sync.Mutex
dbLock sync.Mutex
dbInUse = false // only activate if used for first time
dbDoReload = true // if database should be reloaded
dbInUse = abool.NewBool(false) // only activate if used for first time
dbDoReload = abool.NewBool(true) // if database should be reloaded
)
// ReloadDatabases reloads the geoip database, if they are in use.
func ReloadDatabases() error {
dbLock.Lock()
defer dbLock.Unlock()
// don't do anything if the database isn't actually used
if !dbInUse {
if !dbInUse.IsSet() {
return nil
}
dbDoReload = true
dbFileLock.Lock()
defer dbFileLock.Unlock()
dbLock.Lock()
defer dbLock.Unlock()
dbDoReload.Set()
return doReload()
}
func prepDatabaseForUse() error {
dbInUse = true
dbInUse.Set()
return doReload()
}
func doReload() error {
// reload if needed
if dbDoReload {
defer func() {
dbDoReload = false
}()
if dbDoReload.SetToIf(true, false) {
closeDBs()
return openDBs()
}
@ -53,7 +59,7 @@ func doReload() error {
func openDBs() error {
var err error
file, err := updates.GetFile("intel/geoip-city.mmdb")
file, err := updates.GetFile("intel/geoip/geoip-city.mmdb")
if err != nil {
return fmt.Errorf("could not get GeoIP City database file: %s", err)
}
@ -61,7 +67,8 @@ func openDBs() error {
if err != nil {
return err
}
file, err = updates.GetFile("intel/geoip-asn.mmdb")
file, err = updates.GetFile("intel/geoip/geoip-asn.mmdb")
if err != nil {
return fmt.Errorf("could not get GeoIP ASN database file: %s", err)
}
@ -73,8 +80,8 @@ func openDBs() error {
}
func handleError(err error) {
log.Warningf("network/geoip: lookup failed, reloading databases...")
dbDoReload = true
log.Errorf("network/geoip: lookup failed, reloading databases: %s", err)
dbDoReload.Set()
}
func closeDBs() {

View file

@ -8,7 +8,7 @@ import (
)
const (
earthCircumferenceKm float64 = 40100 // earth circumference in km
earthCircumferenceInKm float64 = 40100 // earth circumference in km
)
// Location holds information regarding the geographical and network location of an IP address
@ -42,7 +42,7 @@ type Location struct {
// Conclusion:
// - Ignore location data completely if accuracy_radius > 500
// EstimateNetworkProximity aims to calculate a distance value between 0 and 100.
// EstimateNetworkProximity aims to calculate the distance between two network locations. Returns a proximity value between 0 (far away) and 100 (nearby).
func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
// Distance Value:
// 0: other side of the Internet
@ -50,12 +50,10 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
// Weighting:
// coordinate distance: 0-50
// continent match: 10
// continent match: 15
// country match: 10
// AS owner match: 15
// AS network match: 15
//
// We prioritize AS information over country information, as it is more accurate and we expect better privacy if we already are in the destination AS.
// AS network match: 10
// coordinate distance: 0-50
fromCoords := haversine.Coord{Lat: l.Coordinates.Latitude, Lon: l.Coordinates.Longitude}
@ -69,19 +67,19 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
accuracy = to.Coordinates.AccuracyRadius
}
if km <= 10 && accuracy <= 200 {
if km <= 10 && accuracy <= 100 {
proximity += 50
} else {
distanceIn50Percent := ((earthCircumferenceKm - km) / earthCircumferenceKm) * 50
distanceIn50Percent := ((earthCircumferenceInKm - km) / earthCircumferenceInKm) * 50
// apply penalty for values high values (targeting >100)
// apply penalty for locations with low accuracy (targeting accuracy radius >100)
accuracyModifier := 1 - float64(accuracy)/1000
proximity += int(distanceIn50Percent * accuracyModifier)
}
// continent match: 10
// continent match: 15
if l.Continent.Code == to.Continent.Code {
proximity += 10
proximity += 15
// country match: 10
if l.Country.ISOCode == to.Country.ISOCode {
proximity += 10
@ -91,16 +89,16 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) {
// AS owner match: 15
if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization {
proximity += 15
// AS network match: 15
// AS network match: 10
if l.AutonomousSystemNumber == to.AutonomousSystemNumber {
proximity += 15
proximity += 10
}
}
return
return //nolint:nakedreturn
}
// PrimitiveNetworkProximity calculates the numerical distance between two IP addresses. Returns a proximity value between 0 (far away) and 100 (nearby).
func PrimitiveNetworkProximity(from net.IP, to net.IP, ipVersion uint8) int {
var diff float64
@ -128,7 +126,7 @@ func PrimitiveNetworkProximity(from net.IP, to net.IP, ipVersion uint8) int {
switch ipVersion {
case 4:
diff = diff / 256
diff /= 256
return int((1 - diff/16777216) * 100)
case 6:
return int((1 - diff/18446744073709552000) * 100)

View file

@ -3,9 +3,16 @@ package geoip
import (
"net"
"testing"
"github.com/safing/portmaster/updates"
)
func TestLocationLookup(t *testing.T) {
err := updates.InitForTesting()
if err != nil {
t.Fatal(err)
}
ip1 := net.ParseIP("81.2.69.142")
loc1, err := GetLocation(ip1)
if err != nil {

58
network/geoip/module.go Normal file
View file

@ -0,0 +1,58 @@
package geoip
import (
"context"
"fmt"
"time"
"github.com/safing/portbase/modules"
)
var (
module *modules.Module
)
func init() {
module = modules.Register("geoip", nil, start, nil, "updates")
}
func start() error {
err := prepDatabaseForUse()
if err != nil {
return fmt.Errorf("goeip: failed to load databases: %s", err)
}
module.RegisterEventHook(
"updates",
"resource update",
"upgrade databases",
upgradeDatabases,
)
// TODO: replace with update subscription
module.NewTask("update databases", func(ctx context.Context, task *modules.Task) {
dbFileLock.Lock()
defer dbFileLock.Unlock()
}).Repeat(10 * time.Minute).MaxDelay(1 * time.Hour)
return nil
}
func upgradeDatabases(_ context.Context, _ interface{}) error {
dbFileLock.Lock()
reload := false
if dbCityFile != nil && dbCityFile.UpgradeAvailable() {
reload = true
}
if dbASNFile != nil && dbASNFile.UpgradeAvailable() {
reload = true
}
dbFileLock.Unlock()
if reload {
return ReloadDatabases()
}
return nil
}

View file

@ -98,7 +98,7 @@ func (link *Link) HandlePacket(pkt packet.Packet) {
link.pktQueue <- pkt
return
}
log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link)
log.Warningf("network: link %s does not have a firewallHandler, dropping packet", link)
pkt.Drop()
}
@ -175,14 +175,18 @@ func (link *Link) packetHandler() {
if pkt == nil {
return
}
// get handler
link.Lock()
fwH := link.firewallHandler
handler := link.firewallHandler
link.Unlock()
if fwH != nil {
fwH(pkt, link)
// execute handler or verdict
if handler != nil {
handler(pkt, link)
} else {
link.ApplyVerdict(pkt)
}
// submit trace logs
log.Tracer(pkt.Ctx()).Submit()
}
}
@ -311,10 +315,10 @@ func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) {
// CreateLinkFromPacket creates a new Link based on Packet.
func CreateLinkFromPacket(pkt packet.Packet) *Link {
link := &Link{
ID: pkt.GetLinkID(),
Verdict: VerdictUndecided,
Started: time.Now().Unix(),
RemoteAddress: pkt.FmtRemoteAddress(),
ID: pkt.GetLinkID(),
Verdict: VerdictUndecided,
Started: time.Now().Unix(),
RemoteAddress: pkt.FmtRemoteAddress(),
saveWhenFinished: true,
}
return link

View file

@ -2,13 +2,25 @@ package network
import (
"github.com/safing/portbase/modules"
"github.com/safing/portmaster/network/environment"
)
var (
module *modules.Module
)
func init() {
modules.Register("network", nil, start, nil, "core")
module = modules.Register("network", nil, start, nil, "core")
environment.InitSubModule(module)
}
func start() error {
err := registerAsDatabase()
if err != nil {
return err
}
go cleaner()
return registerAsDatabase()
return environment.StartSubModule()
}

View file

@ -5,8 +5,7 @@ import (
)
var (
// cleanDomainRegex = regexp.MustCompile("^(((?!-))(xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$")
cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$")
cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`)
)
// IsValidFqdn returns whether the given string is a valid fqdn.

View file

@ -10,7 +10,7 @@ const (
Global
LocalMulticast
GlobalMulticast
Invalid
Invalid int8 = -1
)
// ClassifyIP returns the classification for the given IP address.
@ -77,9 +77,7 @@ func IPIsLocalhost(ip net.IP) bool {
// IPIsLAN returns true if the given IP is a site-local or link-local address.
func IPIsLAN(ip net.IP) bool {
switch ClassifyIP(ip) {
case SiteLocal:
return true
case LinkLocal:
case SiteLocal, LinkLocal:
return true
default:
return false

View file

@ -7,32 +7,34 @@ import (
"github.com/google/gopacket/tcpassembly"
)
// SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly
type SimpleStreamAssemblerManager struct {
InitLock sync.Mutex
lastAssembler *SimpleStreamAssembler
}
// New returns a new stream assembler.
func (m *SimpleStreamAssemblerManager) New(net, transport gopacket.Flow) tcpassembly.Stream {
assembler := new(SimpleStreamAssembler)
m.lastAssembler = assembler
return assembler
}
// GetLastAssembler returns the newest created stream assembler.
func (m *SimpleStreamAssemblerManager) GetLastAssembler() *SimpleStreamAssembler {
// defer func() {
// m.lastAssembler = nil
// }()
return m.lastAssembler
}
// SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly
type SimpleStreamAssembler struct {
Cumulated []byte
CumulatedLen int
Complete bool
}
// NewSimpleStreamAssembler returns a new SimpleStreamAssembler.
func NewSimpleStreamAssembler() *SimpleStreamAssembler {
return new(SimpleStreamAssembler)
return &SimpleStreamAssembler{}
}
// Reassembled implements tcpassembly.Stream's Reassembled function.

View file

@ -5,13 +5,17 @@ import (
"fmt"
)
// Basic Types
type (
IPVersion uint8
// IPVersion represents an IP version.
IPVersion uint8
// IPProtocol represents an IP protocol.
IPProtocol uint8
Verdict uint8
Endpoint bool
// Verdict describes the decision on a packet.
Verdict uint8
)
// Basic Constants
const (
IPv4 = IPVersion(4)
IPv6 = IPVersion(6)
@ -19,18 +23,15 @@ const (
InBound = true
OutBound = false
Local = true
Remote = false
// convenience
ICMP = IPProtocol(1)
IGMP = IPProtocol(2)
RAW = IPProtocol(255)
TCP = IPProtocol(6)
UDP = IPProtocol(17)
ICMP = IPProtocol(1)
ICMPv6 = IPProtocol(58)
RAW = IPProtocol(255)
)
// Verdicts
const (
DROP Verdict = iota
BLOCK
@ -42,10 +43,11 @@ const (
)
var (
// ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system.
ErrFailedToLoadPayload = errors.New("could not load packet payload")
)
// Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16
// ByteSize returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16
func (v IPVersion) ByteSize() int {
switch v {
case IPv4:
@ -56,6 +58,7 @@ func (v IPVersion) ByteSize() int {
return 0
}
// String returns the string representation of the IP version: "IPv4" or "IPv6".
func (v IPVersion) String() string {
switch v {
case IPv4:
@ -66,6 +69,7 @@ func (v IPVersion) String() string {
return fmt.Sprintf("<unknown ip version, %d>", uint8(v))
}
// String returns the string representation (abbreviation) of the protocol.
func (p IPProtocol) String() string {
switch p {
case RAW:
@ -84,12 +88,24 @@ func (p IPProtocol) String() string {
return fmt.Sprintf("<unknown protocol, %d>", uint8(p))
}
// String returns the string representation of the verdict.
func (v Verdict) String() string {
switch v {
case DROP:
return "DROP"
case BLOCK:
return "BLOCK"
case ACCEPT:
return "ACCEPT"
case STOLEN:
return "STOLEN"
case QUEUE:
return "QUEUE"
case REPEAT:
return "REPEAT"
case STOP:
return "STOP"
default:
return fmt.Sprintf("<unsupported verdict, %d>", uint8(v))
}
return fmt.Sprintf("<unsupported verdict, %d>", uint8(v))
}

View file

@ -10,9 +10,9 @@ type Info struct {
InTunnel bool
Version IPVersion
Src, Dst net.IP
Protocol IPProtocol
SrcPort, DstPort uint16
Src, Dst net.IP
}
// LocalIP returns the local IP of the packet.