diff --git a/network/environment/addresses.go b/network/environment/addresses.go index 9a9f8689..53ac900a 100644 --- a/network/environment/addresses.go +++ b/network/environment/addresses.go @@ -7,6 +7,7 @@ import ( "github.com/safing/portmaster/network/netutils" ) +// GetAssignedAddresses returns the assigned IPv4 and IPv6 addresses of the host. func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { addrs, err := net.InterfaceAddrs() if err != nil { @@ -25,6 +26,7 @@ func GetAssignedAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { return } +// GetAssignedGlobalAddresses returns the assigned global IPv4 and IPv6 addresses of the host. func GetAssignedGlobalAddresses() (ipv4 []net.IP, ipv6 []net.IP, err error) { allv4, allv6, err := GetAssignedAddresses() if err != nil { diff --git a/network/environment/dbus_linux.go b/network/environment/dbus_linux.go index 8260d343..d967e82c 100644 --- a/network/environment/dbus_linux.go +++ b/network/environment/dbus_linux.go @@ -156,7 +156,7 @@ func getNameserversFromDbus() ([]Nameserver, error) { return nameservers, nil } -func getConnectivityStateFromDbus() (uint8, error) { +func getConnectivityStateFromDbus() (OnlineStatus, error) { var err error dbusConnLock.Lock() @@ -187,18 +187,18 @@ func getConnectivityStateFromDbus() (uint8, error) { switch connectivityState { case 0: - return UNKNOWN, nil + return StatusUnknown, nil case 1: - return OFFLINE, nil + return StatusOffline, nil case 2: - return PORTAL, nil + return StatusPortal, nil case 3: - return LIMITED, nil + return StatusLimited, nil case 4: - return ONLINE, nil + return StatusOnline, nil } - return UNKNOWN, nil + return StatusUnknown, nil } func getNetworkManagerProperty(conn *dbus.Conn, objectPath dbus.ObjectPath, property string) (dbus.Variant, error) { diff --git a/network/environment/dbus_linux_mock.go b/network/environment/dbus_linux_mock.go index d03f8248..c5f326ef 100644 --- a/network/environment/dbus_linux_mock.go +++ b/network/environment/dbus_linux_mock.go @@ -8,5 +8,5 @@ func getNameserversFromDbus() ([]Nameserver, error) { } func getConnectivityStateFromDbus() (uint8, error) { - return UNKNOWN, nil + return StatusUnknown, nil } diff --git a/network/environment/dbus_linux_test.go b/network/environment/dbus_linux_test.go index a274297c..cd2dbdb4 100644 --- a/network/environment/dbus_linux_test.go +++ b/network/environment/dbus_linux_test.go @@ -1,8 +1,16 @@ package environment -import "testing" +import ( + "os" + "testing" +) func TestDbus(t *testing.T) { + if _, err := os.Stat("/var/run/dbus/system_bus_socket"); os.IsNotExist(err) { + t.Logf("skipping dbus tests, as dbus does not seem to be installed: %s", err) + return + } + nameservers, err := getNameserversFromDbus() if err != nil { t.Errorf("getNameserversFromDbus failed: %s", err) diff --git a/network/environment/dialing.go b/network/environment/dialing.go new file mode 100644 index 00000000..14293414 --- /dev/null +++ b/network/environment/dialing.go @@ -0,0 +1,21 @@ +package environment + +import "net" + +var ( + localAddrFactory func(network string) net.Addr +) + +// SetLocalAddrFactory supplies the environment package with a function to get permitted local addresses for connections. +func SetLocalAddrFactory(laf func(network string) net.Addr) { + if localAddrFactory == nil { + localAddrFactory = laf + } +} + +func getLocalAddr(network string) net.Addr { + if localAddrFactory != nil { + return localAddrFactory(network) + } + return nil +} diff --git a/network/environment/environment.go b/network/environment/environment.go index 701ee8c0..2828297f 100644 --- a/network/environment/environment.go +++ b/network/environment/environment.go @@ -1,15 +1,9 @@ package environment import ( - "bytes" - "crypto/sha1" - "io" "net" "sync" - "sync/atomic" "time" - - "github.com/safing/portbase/log" ) // TODO: find a good way to identify a network @@ -22,25 +16,11 @@ import ( // this info might already be included in the interfaces api provided by golang! const ( - UNKNOWN uint8 = iota - OFFLINE - LIMITED // local network only - PORTAL // there seems to be an internet connection, but we are being intercepted - ONLINE -) - -const ( - connectivityRecheck = 2 * time.Second - interfacesRecheck = 2 * time.Second - gatewaysRecheck = 2 * time.Second - nameserversRecheck = 2 * time.Second + gatewaysRecheck = 2 * time.Second + nameserversRecheck = 2 * time.Second ) var ( - connectivity uint8 - connectivityLock sync.Mutex - connectivityExpires = time.Now() - // interfaces = make(map[*net.IP]net.Flags) // interfacesLock sync.Mutex // interfacesExpires = time.Now() @@ -52,114 +32,10 @@ var ( nameservers = make([]Nameserver, 0) nameserversLock sync.Mutex nameserversExpires = time.Now() - - lastNetworkChange *int64 - lastNetworkChecksum []byte ) +// Nameserver describes a system assigned namserver. type Nameserver struct { IP net.IP Search []string } - -func init() { - lnc := int64(0) - lastNetworkChange = &lnc - go func() { - time.Sleep(1 * time.Second) - Connectivity() - }() - - go monitorNetworkChanges() -} - -// Connectivity returns the current state of connectivity to the network/Internet -func Connectivity() uint8 { - // locking - connectivityLock.Lock() - defer connectivityLock.Unlock() - // cache - if connectivityExpires.After(time.Now()) { - return connectivity - } - // logic - // TODO: implement more methods - status, err := getConnectivityStateFromDbus() - if err != nil { - log.Warningf("environment: could not get connectivity: %s", err) - setConnectivity(UNKNOWN) - return UNKNOWN - } - setConnectivity(status) - return status -} - -func setConnectivity(status uint8) { - if connectivity != status { - connectivity = status - connectivityExpires = time.Now().Add(connectivityRecheck) - - var connectivityName string - switch connectivity { - case UNKNOWN: - connectivityName = "unknown" - case OFFLINE: - connectivityName = "offline" - case LIMITED: - connectivityName = "limited" - case PORTAL: - connectivityName = "portal" - case ONLINE: - connectivityName = "online" - default: - connectivityName = "invalid" - } - log.Infof("environment: connectivity changed to %s", connectivityName) - } -} - -// ConnectionSucceeded should be called when a module was able to successfully connect to the internet (do not call too often) -func ConnectionSucceeded() { - connectivityLock.Lock() - defer connectivityLock.Unlock() - setConnectivity(ONLINE) -} - -func monitorNetworkChanges() { - // TODO: make more elegant solution - for { - time.Sleep(2 * time.Second) - hasher := sha1.New() - interfaces, err := net.Interfaces() - if err != nil { - log.Warningf("environment: failed to get interfaces: %s", err) - continue - } - for _, iface := range interfaces { - io.WriteString(hasher, iface.Name) - // log.Tracef("adding: %s", iface.Name) - io.WriteString(hasher, iface.Flags.String()) - // log.Tracef("adding: %s", iface.Flags.String()) - addrs, err := iface.Addrs() - if err != nil { - log.Warningf("environment: failed to get addrs from interface %s: %s", iface.Name, err) - continue - } - for _, addr := range addrs { - io.WriteString(hasher, addr.String()) - // log.Tracef("adding: %s", addr.String()) - } - } - newChecksum := hasher.Sum(nil) - if !bytes.Equal(lastNetworkChecksum, newChecksum) { - if len(lastNetworkChecksum) == 0 { - lastNetworkChecksum = newChecksum - continue - } - lastNetworkChecksum = newChecksum - atomic.StoreInt64(lastNetworkChange, time.Now().Unix()) - log.Info("environment: network changed") - triggerNetworkChanged() - } - } -} diff --git a/network/environment/environment_linux.go b/network/environment/environment_linux.go index 3c2d3d14..6aa89653 100644 --- a/network/environment/environment_linux.go +++ b/network/environment/environment_linux.go @@ -57,7 +57,7 @@ func Gateways() []*net.IP { continue } if len(decoded) != 4 { - log.Warningf("environment: decoded gateway %s from /proc/net/route has wrong length") + log.Warningf("environment: decoded gateway %s from /proc/net/route has wrong length", decoded) continue } gate := net.IPv4(decoded[3], decoded[2], decoded[1], decoded[0]) @@ -90,7 +90,7 @@ func Gateways() []*net.IP { continue } if len(decoded) != 16 { - log.Warningf("environment: decoded gateway %s from /proc/net/ipv6_route has wrong length") + log.Warningf("environment: decoded gateway %s from /proc/net/ipv6_route has wrong length", decoded) continue } gate := net.IP(decoded) @@ -134,7 +134,6 @@ func Nameservers() []Nameserver { resolvconfNameservers, err := getNameserversFromResolvconf() if err != nil { log.Warningf("environment: could not get nameservers from resolvconf: %s", err) - resolvconfNameservers = make([]Nameserver, 0) } else { nameservers = addNameservers(nameservers, resolvconfNameservers) } @@ -178,7 +177,7 @@ func getNameserversFromResolvconf() ([]Nameserver, error) { } // build array - var nameservers []Nameserver + nameservers := make([]Nameserver, 0, len(servers)) for _, server := range servers { nameservers = append(nameservers, Nameserver{ IP: server, diff --git a/network/environment/environment_test.go b/network/environment/environment_test.go index c4db1685..da918e9b 100644 --- a/network/environment/environment_test.go +++ b/network/environment/environment_test.go @@ -6,9 +6,6 @@ import "testing" func TestEnvironment(t *testing.T) { - connectivityTest := Connectivity() - t.Logf("connectivity: %v", connectivityTest) - nameserversTest, err := getNameserversFromResolvconf() if err != nil { t.Errorf("failed to get namerservers from resolvconf: %s", err) diff --git a/network/environment/events.go b/network/environment/events.go deleted file mode 100644 index 35488ea4..00000000 --- a/network/environment/events.go +++ /dev/null @@ -1,23 +0,0 @@ -package environment - -import ( - "sync" -) - -var ( - networkChangedEventCh = make(chan struct{}, 0) - networkChangedEventLock sync.Mutex -) - -func triggerNetworkChanged() { - networkChangedEventLock.Lock() - defer networkChangedEventLock.Unlock() - close(networkChangedEventCh) - networkChangedEventCh = make(chan struct{}, 0) -} - -func NetworkChanged() <-chan struct{} { - networkChangedEventLock.Lock() - defer networkChangedEventLock.Unlock() - return networkChangedEventCh -} diff --git a/network/environment/interface.go b/network/environment/interface.go deleted file mode 100644 index 338eee87..00000000 --- a/network/environment/interface.go +++ /dev/null @@ -1,27 +0,0 @@ -package environment - -import ( - "sync" - "sync/atomic" -) - -type EnvironmentInterface struct { - lastNetworkChange int64 - lock sync.Mutex -} - -func NewInterface() *EnvironmentInterface { - return &EnvironmentInterface{ - lastNetworkChange: 0, - } -} - -func (env *EnvironmentInterface) NetworkChanged() bool { - env.lock.Lock() - defer env.lock.Unlock() - lnc := atomic.LoadInt64(lastNetworkChange) - if lnc > env.lastNetworkChange { - return true - } - return false -} diff --git a/network/environment/location.go b/network/environment/location.go index 806009f3..959768c2 100644 --- a/network/environment/location.go +++ b/network/environment/location.go @@ -1,7 +1,6 @@ package environment import ( - "errors" "fmt" "log" "net" @@ -14,20 +13,21 @@ import ( "golang.org/x/net/ipv4" ) -// TODO: reference forking // TODO: Create IPv6 version of GetApproximateInternetLocation // GetApproximateInternetLocation returns the IP-address of the nearest ping-answering internet node +//nolint:gocognit // TODO func GetApproximateInternetLocation() (net.IP, error) { // TODO: first check if we have a public IP // net.InterfaceAddrs() // Traceroute example - var dst net.IPAddr - dst.IP = net.IPv4(8, 8, 8, 8) + dst := net.IPAddr{ + IP: net.IPv4(1, 1, 1, 1), + } - c, err := net.ListenPacket("ip4:1", "0.0.0.0") // ICMP for IPv4 + c, err := net.ListenPacket("ip4:icmp", "0.0.0.0") // ICMP for IPv4 if err != nil { return nil, err } @@ -42,9 +42,8 @@ func GetApproximateInternetLocation() (net.IP, error) { wm := icmp.Message{ Type: ipv4.ICMPTypeEcho, Code: 0, Body: &icmp.Echo{ - ID: os.Getpid() & 0xffff, - // TODO: think of something better and not suspicious - Data: []byte("HELLO-R-U-THERE"), + ID: os.Getpid() & 0xffff, + Data: []byte{0}, }, } rb := make([]byte, 1500) @@ -96,7 +95,7 @@ next: case ipv4.ICMPTypeTimeExceeded: ip := net.ParseIP(peer.String()) if ip == nil { - return nil, errors.New(fmt.Sprintf("failed to parse IP: %s", peer.String())) + return nil, fmt.Errorf("failed to parse IP: %s", peer.String()) } if !netutils.IPIsLAN(ip) { return ip, nil diff --git a/network/environment/main.go b/network/environment/main.go new file mode 100644 index 00000000..31ee8d61 --- /dev/null +++ b/network/environment/main.go @@ -0,0 +1,42 @@ +package environment + +import ( + "errors" + + "github.com/safing/portbase/modules" +) + +const ( + networkChangedEvent = "network changed" + onlineStatusChangedEvent = "online status changed" +) + +var ( + module *modules.Module +) + +func InitSubModule(m *modules.Module) { + module = m + module.RegisterEvent(networkChangedEvent) + module.RegisterEvent(onlineStatusChangedEvent) +} + +func StartSubModule() error { + if module == nil { + return errors.New("not initialized") + } + + module.StartServiceWorker( + "monitor network changes", + 0, + monitorNetworkChanges, + ) + + module.StartServiceWorker( + "monitor online status", + 0, + monitorOnlineStatus, + ) + + return nil +} diff --git a/network/environment/network-change.go b/network/environment/network-change.go new file mode 100644 index 00000000..808c66fb --- /dev/null +++ b/network/environment/network-change.go @@ -0,0 +1,91 @@ +package environment + +import ( + "bytes" + "context" + "crypto/sha1" //nolint:gosec // not used for security + "io" + "net" + "time" + + "github.com/safing/portbase/log" +) + +var ( + networkChangeCheckTrigger = make(chan struct{}, 1) +) + +func triggerNetworkChangeCheck() { + select { + case networkChangeCheckTrigger <- struct{}{}: + default: + } +} + +func monitorNetworkChanges(ctx context.Context) error { + var lastNetworkChecksum []byte + +serviceLoop: + for { + trigger := false + + // wait for trigger + if GetOnlineStatus() == StatusOnline { + select { + case <-ctx.Done(): + return nil + case <-networkChangeCheckTrigger: + case <-time.After(1 * time.Minute): + trigger = true + } + } else { + select { + case <-ctx.Done(): + return nil + case <-networkChangeCheckTrigger: + case <-time.After(1 * time.Second): + trigger = true + } + } + + // check network for changes + // create hashsum of current network config + hasher := sha1.New() //nolint:gosec // not used for security + interfaces, err := net.Interfaces() + if err != nil { + log.Warningf("environment: failed to get interfaces: %s", err) + continue + } + for _, iface := range interfaces { + _, _ = io.WriteString(hasher, iface.Name) + // log.Tracef("adding: %s", iface.Name) + _, _ = io.WriteString(hasher, iface.Flags.String()) + // log.Tracef("adding: %s", iface.Flags.String()) + addrs, err := iface.Addrs() + if err != nil { + log.Warningf("environment: failed to get addrs from interface %s: %s", iface.Name, err) + continue + } + for _, addr := range addrs { + _, _ = io.WriteString(hasher, addr.String()) + // log.Tracef("adding: %s", addr.String()) + } + } + newChecksum := hasher.Sum(nil) + + // compare checksum with last + if !bytes.Equal(lastNetworkChecksum, newChecksum) { + if len(lastNetworkChecksum) == 0 { + lastNetworkChecksum = newChecksum + continue serviceLoop + } + lastNetworkChecksum = newChecksum + + if trigger { + triggerOnlineStatusInvestigation() + } + module.TriggerEvent(networkChangedEvent, nil) + } + + } +} diff --git a/network/environment/notes.md b/network/environment/notes.md new file mode 100644 index 00000000..b909ed9b --- /dev/null +++ b/network/environment/notes.md @@ -0,0 +1,19 @@ + + +Intel: +- First ever request: use first resolver as selected +- If resolver fails: + - stop all requesting + - get network status + - if failed: do nothing, return offline error + - check list front to back, use first resolver that resolves one.one.one.one correctly + +NetEnv: +- check for intercepted HTTP Request requests +- if fails on: + - connection establishment: OFFLINE + - +- check for intercepted HTTPS Request requests + + +- check for intercepted DNS requests diff --git a/network/environment/online-status.go b/network/environment/online-status.go new file mode 100644 index 00000000..e1393f7a --- /dev/null +++ b/network/environment/online-status.go @@ -0,0 +1,353 @@ +package environment + +import ( + "context" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + + "github.com/safing/portbase/log" + "github.com/safing/portmaster/network/netutils" + + "github.com/tevino/abool" +) + +// OnlineStatus represent a state of connectivity to the Internet. +type OnlineStatus uint8 + +// Online Status Values +const ( + StatusUnknown OnlineStatus = 0 + StatusOffline OnlineStatus = 1 + StatusLimited OnlineStatus = 2 // local network only + StatusPortal OnlineStatus = 3 // there seems to be an internet connection, but we are being intercepted, possibly by a captive portal + StatusSemiOnline OnlineStatus = 4 // we seem to online, but without full connectivity + StatusOnline OnlineStatus = 5 +) + +// Online Status and Resolver +const ( + HTTPTestURL = "http://detectportal.firefox.com/success.txt" + HTTPExpectedContent = "success" + HTTPSTestURL = "https://one.one.one.one/" + + ResolverTestFqdn = "one.one.one.one." + ResolverTestRRType = dns.TypeA + ResolverTestExpectedResponse = "1.1.1.1" +) + +var ( + parsedHTTPTestURL *url.URL + parsedHTTPSTestURL *url.URL +) + +func init() { + var err error + + parsedHTTPTestURL, err = url.Parse(HTTPTestURL) + if err != nil { + panic(err) + } + + parsedHTTPSTestURL, err = url.Parse(HTTPSTestURL) + if err != nil { + panic(err) + } +} + +// IsOnlineStatusTestDomain checks whether the given fqdn is used for testing online status. +func IsOnlineStatusTestDomain(domain string) bool { + switch domain { + case "detectportal.firefox.com.": + return true + case "one.one.one.one.": + return true + } + + return false +} + +// GetResolverTestingRequestData returns request information that should be used to test DNS resolvers for availability and basic correct behaviour. +func GetResolverTestingRequestData() (fqdn string, rrType uint16, expectedResponse string) { + return ResolverTestFqdn, ResolverTestRRType, ResolverTestExpectedResponse +} + +func (os OnlineStatus) String() string { + switch os { + default: + return "Unknown" + case StatusOffline: + return "Offline" + case StatusLimited: + return "Limited" + case StatusPortal: + return "Portal" + case StatusSemiOnline: + return "SemiOnline" + case StatusOnline: + return "Online" + } +} + +var ( + onlineStatus *int32 + onlineStatusQuickCheck = abool.NewBool(false) + + onlineStatusInvestigationTrigger = make(chan struct{}, 1) + onlineStatusInvestigationInProgress = abool.NewBool(false) + onlineStatusInvestigationWg sync.WaitGroup + + captivePortalURL string + captivePortalLock sync.Mutex +) + +func init() { + var onlineStatusValue int32 + onlineStatus = &onlineStatusValue +} + +// Online returns true if online status is either SemiOnline or Online. +func Online() bool { + return onlineStatusQuickCheck.IsSet() +} + +// GetOnlineStatus returns the current online stats. +func GetOnlineStatus() OnlineStatus { + return OnlineStatus(atomic.LoadInt32(onlineStatus)) +} + +// CheckAndGetOnlineStatus triggers a new online status check and returns the result +func CheckAndGetOnlineStatus() OnlineStatus { + // trigger new investigation + triggerOnlineStatusInvestigation() + // wait for completion + onlineStatusInvestigationWg.Wait() + // return current status + return GetOnlineStatus() +} + +func updateOnlineStatus(status OnlineStatus, portalURL, comment string) { + changed := false + + // status + currentStatus := atomic.LoadInt32(onlineStatus) + if status != OnlineStatus(currentStatus) && atomic.CompareAndSwapInt32(onlineStatus, currentStatus, int32(status)) { + // status changed! + onlineStatusQuickCheck.SetTo( + status == StatusOnline || status == StatusSemiOnline, + ) + changed = true + } + + // captive portal + captivePortalLock.Lock() + defer captivePortalLock.Unlock() + if portalURL != captivePortalURL { + captivePortalURL = portalURL + changed = true + } + + // trigger event + if changed { + module.TriggerEvent(onlineStatusChangedEvent, nil) + if status == StatusPortal { + log.Infof(`network: setting online status to %s at "%s" (%s)`, status, captivePortalURL, comment) + } else { + log.Infof("network: setting online status to %s (%s)", status, comment) + } + triggerNetworkChangeCheck() + } +} + +// GetCaptivePortalURL returns the current captive portal url as a string. +func GetCaptivePortalURL() string { + captivePortalLock.Lock() + defer captivePortalLock.Unlock() + + return captivePortalURL +} + +// ReportSuccessfulConnection hints the online status monitoring system that a connection attempt was successful. +func ReportSuccessfulConnection() { + if !onlineStatusQuickCheck.IsSet() { + triggerOnlineStatusInvestigation() + } +} + +// ReportFailedConnection hints the online status monitoring system that a connection attempt has failed. This function has extremely low overhead and may be called as much as wanted. +func ReportFailedConnection() { + if onlineStatusQuickCheck.IsSet() { + triggerOnlineStatusInvestigation() + } +} + +func triggerOnlineStatusInvestigation() { + if onlineStatusInvestigationInProgress.SetToIf(false, true) { + onlineStatusInvestigationWg.Add(1) + } + + select { + case onlineStatusInvestigationTrigger <- struct{}{}: + default: + } +} + +func monitorOnlineStatus(ctx context.Context) error { + for { + // wait for trigger + if GetOnlineStatus() == StatusOnline { + select { + case <-ctx.Done(): + return nil + case <-onlineStatusInvestigationTrigger: + case <-time.After(1 * time.Minute): + } + } else { + select { + case <-ctx.Done(): + return nil + case <-onlineStatusInvestigationTrigger: + case <-time.After(1 * time.Second): + } + } + + // enable waiting + if onlineStatusInvestigationInProgress.SetToIf(false, true) { + onlineStatusInvestigationWg.Add(1) + } + + checkOnlineStatus(ctx) + + // finished! + onlineStatusInvestigationWg.Done() + onlineStatusInvestigationInProgress.UnSet() + } +} + +func checkOnlineStatus(ctx context.Context) { + // TODO: implement more methods + /*status, err := getConnectivityStateFromDbus() + if err != nil { + log.Warningf("environment: could not get connectivity: %s", err) + setConnectivity(StatusUnknown) + return StatusUnknown + }*/ + + // 1) check for addresses + + ipv4, ipv6, err := GetAssignedAddresses() + if err != nil { + log.Warningf("network: failed to get assigned network addresses: %s", err) + } else { + var lan bool + for _, ip := range ipv4 { + switch netutils.ClassifyIP(ip) { + case netutils.SiteLocal: + lan = true + case netutils.Global: + // we _are_ the Internet ;) + updateOnlineStatus(StatusOnline, "", "global IPv4 interface detected") + return + } + } + for _, ip := range ipv6 { + switch netutils.ClassifyIP(ip) { + case netutils.SiteLocal, netutils.Global: + // IPv6 global addresses are also used in local networks + lan = true + } + } + if !lan { + updateOnlineStatus(StatusOffline, "", "no local or global interfaces detected") + return + } + } + + // 2) try a http request + + // TODO: find (array of) alternatives to detectportal.firefox.com + // TODO: find something about usage terms of detectportal.firefox.com + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + LocalAddr: getLocalAddr("tcp"), + DualStack: true, + }).DialContext, + DisableKeepAlives: true, + DisableCompression: true, + WriteBufferSize: 1024, + ReadBufferSize: 1024, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: 5 * time.Second, + } + + request := (&http.Request{ + Method: "GET", + URL: parsedHTTPTestURL, + Close: true, + }).WithContext(ctx) + + response, err := client.Do(request) + if err != nil { + updateOnlineStatus(StatusLimited, "", "http request failed") + return + } + defer response.Body.Close() + + // check location + portalURL, err := response.Location() + if err == nil { + updateOnlineStatus(StatusPortal, portalURL.String(), "http request succeeded with redirect") + return + } + + // read the body + data, err := ioutil.ReadAll(response.Body) + if err != nil { + log.Warningf("network: failed to read http body of captive portal testing response: %s", err) + // assume we are online nonetheless + updateOnlineStatus(StatusOnline, "", "http request succeeded, albeit failing later") + return + } + + // check body contents + if strings.TrimSpace(string(data)) == HTTPExpectedContent { + updateOnlineStatus(StatusOnline, "", "http request succeeded") + } else { + // something is interfering with the website content + // this might be a weird captive portal, just direct the user there + updateOnlineStatus(StatusPortal, "detectportal.firefox.com", "http request succeeded, response content not as expected") + } + + // 3) try a https request + + request = (&http.Request{ + Method: "HEAD", + URL: parsedHTTPSTestURL, + Close: true, + }).WithContext(ctx) + + // only test if we can get the headers + response, err = client.Do(request) + if err != nil { + // if we fail, something is really weird + updateOnlineStatus(StatusSemiOnline, "", "http request failed") + return + } + defer response.Body.Close() + + // finally + updateOnlineStatus(StatusOnline, "", "all checks successful") +} diff --git a/network/environment/online-status_test.go b/network/environment/online-status_test.go new file mode 100644 index 00000000..afb72d7c --- /dev/null +++ b/network/environment/online-status_test.go @@ -0,0 +1,12 @@ +package environment + +import ( + "context" + "testing" +) + +func TestCheckOnlineStatus(t *testing.T) { + checkOnlineStatus(context.Background()) + t.Logf("online status: %s", GetOnlineStatus()) + t.Logf("captive portal: %s", GetCaptivePortalURL()) +} diff --git a/network/geoip/database.go b/network/geoip/database.go index bff7f217..96bf62d1 100644 --- a/network/geoip/database.go +++ b/network/geoip/database.go @@ -4,46 +4,52 @@ import ( "fmt" "sync" + "github.com/tevino/abool" + maxminddb "github.com/oschwald/maxminddb-golang" "github.com/safing/portbase/log" + "github.com/safing/portbase/updater" "github.com/safing/portmaster/updates" ) var ( + dbCityFile *updater.File + dbASNFile *updater.File + dbFileLock sync.Mutex + dbCity *maxminddb.Reader dbASN *maxminddb.Reader + dbLock sync.Mutex - dbLock sync.Mutex - dbInUse = false // only activate if used for first time - dbDoReload = true // if database should be reloaded + dbInUse = abool.NewBool(false) // only activate if used for first time + dbDoReload = abool.NewBool(true) // if database should be reloaded ) +// ReloadDatabases reloads the geoip database, if they are in use. func ReloadDatabases() error { - dbLock.Lock() - defer dbLock.Unlock() - // don't do anything if the database isn't actually used - if !dbInUse { + if !dbInUse.IsSet() { return nil } - dbDoReload = true + dbFileLock.Lock() + defer dbFileLock.Unlock() + dbLock.Lock() + defer dbLock.Unlock() + + dbDoReload.Set() return doReload() } func prepDatabaseForUse() error { - dbInUse = true + dbInUse.Set() return doReload() } func doReload() error { // reload if needed - if dbDoReload { - defer func() { - dbDoReload = false - }() - + if dbDoReload.SetToIf(true, false) { closeDBs() return openDBs() } @@ -53,7 +59,7 @@ func doReload() error { func openDBs() error { var err error - file, err := updates.GetFile("intel/geoip-city.mmdb") + file, err := updates.GetFile("intel/geoip/geoip-city.mmdb") if err != nil { return fmt.Errorf("could not get GeoIP City database file: %s", err) } @@ -61,7 +67,8 @@ func openDBs() error { if err != nil { return err } - file, err = updates.GetFile("intel/geoip-asn.mmdb") + + file, err = updates.GetFile("intel/geoip/geoip-asn.mmdb") if err != nil { return fmt.Errorf("could not get GeoIP ASN database file: %s", err) } @@ -73,8 +80,8 @@ func openDBs() error { } func handleError(err error) { - log.Warningf("network/geoip: lookup failed, reloading databases...") - dbDoReload = true + log.Errorf("network/geoip: lookup failed, reloading databases: %s", err) + dbDoReload.Set() } func closeDBs() { diff --git a/network/geoip/location.go b/network/geoip/location.go index a23efe19..8c341ed4 100644 --- a/network/geoip/location.go +++ b/network/geoip/location.go @@ -8,7 +8,7 @@ import ( ) const ( - earthCircumferenceKm float64 = 40100 // earth circumference in km + earthCircumferenceInKm float64 = 40100 // earth circumference in km ) // Location holds information regarding the geographical and network location of an IP address @@ -42,7 +42,7 @@ type Location struct { // Conclusion: // - Ignore location data completely if accuracy_radius > 500 -// EstimateNetworkProximity aims to calculate a distance value between 0 and 100. +// EstimateNetworkProximity aims to calculate the distance between two network locations. Returns a proximity value between 0 (far away) and 100 (nearby). func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) { // Distance Value: // 0: other side of the Internet @@ -50,12 +50,10 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) { // Weighting: // coordinate distance: 0-50 - // continent match: 10 + // continent match: 15 // country match: 10 // AS owner match: 15 - // AS network match: 15 - // - // We prioritize AS information over country information, as it is more accurate and we expect better privacy if we already are in the destination AS. + // AS network match: 10 // coordinate distance: 0-50 fromCoords := haversine.Coord{Lat: l.Coordinates.Latitude, Lon: l.Coordinates.Longitude} @@ -69,19 +67,19 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) { accuracy = to.Coordinates.AccuracyRadius } - if km <= 10 && accuracy <= 200 { + if km <= 10 && accuracy <= 100 { proximity += 50 } else { - distanceIn50Percent := ((earthCircumferenceKm - km) / earthCircumferenceKm) * 50 + distanceIn50Percent := ((earthCircumferenceInKm - km) / earthCircumferenceInKm) * 50 - // apply penalty for values high values (targeting >100) + // apply penalty for locations with low accuracy (targeting accuracy radius >100) accuracyModifier := 1 - float64(accuracy)/1000 proximity += int(distanceIn50Percent * accuracyModifier) } - // continent match: 10 + // continent match: 15 if l.Continent.Code == to.Continent.Code { - proximity += 10 + proximity += 15 // country match: 10 if l.Country.ISOCode == to.Country.ISOCode { proximity += 10 @@ -91,16 +89,16 @@ func (l *Location) EstimateNetworkProximity(to *Location) (proximity int) { // AS owner match: 15 if l.AutonomousSystemOrganization == to.AutonomousSystemOrganization { proximity += 15 - // AS network match: 15 + // AS network match: 10 if l.AutonomousSystemNumber == to.AutonomousSystemNumber { - proximity += 15 + proximity += 10 } } - return - + return //nolint:nakedreturn } +// PrimitiveNetworkProximity calculates the numerical distance between two IP addresses. Returns a proximity value between 0 (far away) and 100 (nearby). func PrimitiveNetworkProximity(from net.IP, to net.IP, ipVersion uint8) int { var diff float64 @@ -128,7 +126,7 @@ func PrimitiveNetworkProximity(from net.IP, to net.IP, ipVersion uint8) int { switch ipVersion { case 4: - diff = diff / 256 + diff /= 256 return int((1 - diff/16777216) * 100) case 6: return int((1 - diff/18446744073709552000) * 100) diff --git a/network/geoip/lookup_test.go b/network/geoip/lookup_test.go index 26cfc069..956a0923 100644 --- a/network/geoip/lookup_test.go +++ b/network/geoip/lookup_test.go @@ -3,9 +3,16 @@ package geoip import ( "net" "testing" + + "github.com/safing/portmaster/updates" ) func TestLocationLookup(t *testing.T) { + err := updates.InitForTesting() + if err != nil { + t.Fatal(err) + } + ip1 := net.ParseIP("81.2.69.142") loc1, err := GetLocation(ip1) if err != nil { diff --git a/network/geoip/module.go b/network/geoip/module.go new file mode 100644 index 00000000..05b33a91 --- /dev/null +++ b/network/geoip/module.go @@ -0,0 +1,58 @@ +package geoip + +import ( + "context" + "fmt" + "time" + + "github.com/safing/portbase/modules" +) + +var ( + module *modules.Module +) + +func init() { + module = modules.Register("geoip", nil, start, nil, "updates") +} + +func start() error { + err := prepDatabaseForUse() + if err != nil { + return fmt.Errorf("goeip: failed to load databases: %s", err) + } + + module.RegisterEventHook( + "updates", + "resource update", + "upgrade databases", + upgradeDatabases, + ) + + // TODO: replace with update subscription + module.NewTask("update databases", func(ctx context.Context, task *modules.Task) { + + dbFileLock.Lock() + defer dbFileLock.Unlock() + + }).Repeat(10 * time.Minute).MaxDelay(1 * time.Hour) + + return nil +} + +func upgradeDatabases(_ context.Context, _ interface{}) error { + dbFileLock.Lock() + reload := false + if dbCityFile != nil && dbCityFile.UpgradeAvailable() { + reload = true + } + if dbASNFile != nil && dbASNFile.UpgradeAvailable() { + reload = true + } + dbFileLock.Unlock() + + if reload { + return ReloadDatabases() + } + return nil +} diff --git a/network/link.go b/network/link.go index a1cbd04f..dcc52db3 100644 --- a/network/link.go +++ b/network/link.go @@ -98,7 +98,7 @@ func (link *Link) HandlePacket(pkt packet.Packet) { link.pktQueue <- pkt return } - log.Criticalf("network: link %s does not have a firewallHandler, dropping packet", link) + log.Warningf("network: link %s does not have a firewallHandler, dropping packet", link) pkt.Drop() } @@ -175,14 +175,18 @@ func (link *Link) packetHandler() { if pkt == nil { return } + // get handler link.Lock() - fwH := link.firewallHandler + handler := link.firewallHandler link.Unlock() - if fwH != nil { - fwH(pkt, link) + // execute handler or verdict + if handler != nil { + handler(pkt, link) } else { link.ApplyVerdict(pkt) } + // submit trace logs + log.Tracer(pkt.Ctx()).Submit() } } @@ -311,10 +315,10 @@ func GetOrCreateLinkByPacket(pkt packet.Packet) (*Link, bool) { // CreateLinkFromPacket creates a new Link based on Packet. func CreateLinkFromPacket(pkt packet.Packet) *Link { link := &Link{ - ID: pkt.GetLinkID(), - Verdict: VerdictUndecided, - Started: time.Now().Unix(), - RemoteAddress: pkt.FmtRemoteAddress(), + ID: pkt.GetLinkID(), + Verdict: VerdictUndecided, + Started: time.Now().Unix(), + RemoteAddress: pkt.FmtRemoteAddress(), saveWhenFinished: true, } return link diff --git a/network/module.go b/network/module.go index 1bd8ba9f..ab4d46f3 100644 --- a/network/module.go +++ b/network/module.go @@ -2,13 +2,25 @@ package network import ( "github.com/safing/portbase/modules" + "github.com/safing/portmaster/network/environment" +) + +var ( + module *modules.Module ) func init() { - modules.Register("network", nil, start, nil, "core") + module = modules.Register("network", nil, start, nil, "core") + environment.InitSubModule(module) } func start() error { + err := registerAsDatabase() + if err != nil { + return err + } + go cleaner() - return registerAsDatabase() + + return environment.StartSubModule() } diff --git a/network/netutils/cleandns.go b/network/netutils/cleandns.go index 33cec6cc..f9487664 100644 --- a/network/netutils/cleandns.go +++ b/network/netutils/cleandns.go @@ -5,8 +5,7 @@ import ( ) var ( - // cleanDomainRegex = regexp.MustCompile("^(((?!-))(xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$") - cleanDomainRegex = regexp.MustCompile("^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\\.[a-z]{2,}\\.)$") + cleanDomainRegex = regexp.MustCompile(`^((xn--)?[a-z0-9-_]{0,61}[a-z0-9]{1,1}\.)*(xn--)?([a-z0-9-]{1,61}|[a-z0-9-]{1,30}\.[a-z]{2,}\.)$`) ) // IsValidFqdn returns whether the given string is a valid fqdn. diff --git a/network/netutils/ip.go b/network/netutils/ip.go index 651a5d95..586dc185 100644 --- a/network/netutils/ip.go +++ b/network/netutils/ip.go @@ -10,7 +10,7 @@ const ( Global LocalMulticast GlobalMulticast - Invalid + Invalid int8 = -1 ) // ClassifyIP returns the classification for the given IP address. @@ -77,9 +77,7 @@ func IPIsLocalhost(ip net.IP) bool { // IPIsLAN returns true if the given IP is a site-local or link-local address. func IPIsLAN(ip net.IP) bool { switch ClassifyIP(ip) { - case SiteLocal: - return true - case LinkLocal: + case SiteLocal, LinkLocal: return true default: return false diff --git a/network/netutils/tcpassembly.go b/network/netutils/tcpassembly.go index 0d5d8762..43fa8bc9 100644 --- a/network/netutils/tcpassembly.go +++ b/network/netutils/tcpassembly.go @@ -7,32 +7,34 @@ import ( "github.com/google/gopacket/tcpassembly" ) +// SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly type SimpleStreamAssemblerManager struct { InitLock sync.Mutex lastAssembler *SimpleStreamAssembler } +// New returns a new stream assembler. func (m *SimpleStreamAssemblerManager) New(net, transport gopacket.Flow) tcpassembly.Stream { assembler := new(SimpleStreamAssembler) m.lastAssembler = assembler return assembler } +// GetLastAssembler returns the newest created stream assembler. func (m *SimpleStreamAssemblerManager) GetLastAssembler() *SimpleStreamAssembler { - // defer func() { - // m.lastAssembler = nil - // }() return m.lastAssembler } +// SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly type SimpleStreamAssembler struct { Cumulated []byte CumulatedLen int Complete bool } +// NewSimpleStreamAssembler returns a new SimpleStreamAssembler. func NewSimpleStreamAssembler() *SimpleStreamAssembler { - return new(SimpleStreamAssembler) + return &SimpleStreamAssembler{} } // Reassembled implements tcpassembly.Stream's Reassembled function. diff --git a/network/packet/const.go b/network/packet/const.go index 09955577..6107d2e7 100644 --- a/network/packet/const.go +++ b/network/packet/const.go @@ -5,13 +5,17 @@ import ( "fmt" ) +// Basic Types type ( - IPVersion uint8 + // IPVersion represents an IP version. + IPVersion uint8 + // IPProtocol represents an IP protocol. IPProtocol uint8 - Verdict uint8 - Endpoint bool + // Verdict describes the decision on a packet. + Verdict uint8 ) +// Basic Constants const ( IPv4 = IPVersion(4) IPv6 = IPVersion(6) @@ -19,18 +23,15 @@ const ( InBound = true OutBound = false - Local = true - Remote = false - - // convenience + ICMP = IPProtocol(1) IGMP = IPProtocol(2) - RAW = IPProtocol(255) TCP = IPProtocol(6) UDP = IPProtocol(17) - ICMP = IPProtocol(1) ICMPv6 = IPProtocol(58) + RAW = IPProtocol(255) ) +// Verdicts const ( DROP Verdict = iota BLOCK @@ -42,10 +43,11 @@ const ( ) var ( + // ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system. ErrFailedToLoadPayload = errors.New("could not load packet payload") ) -// Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 +// ByteSize returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 func (v IPVersion) ByteSize() int { switch v { case IPv4: @@ -56,6 +58,7 @@ func (v IPVersion) ByteSize() int { return 0 } +// String returns the string representation of the IP version: "IPv4" or "IPv6". func (v IPVersion) String() string { switch v { case IPv4: @@ -66,6 +69,7 @@ func (v IPVersion) String() string { return fmt.Sprintf("", uint8(v)) } +// String returns the string representation (abbreviation) of the protocol. func (p IPProtocol) String() string { switch p { case RAW: @@ -84,12 +88,24 @@ func (p IPProtocol) String() string { return fmt.Sprintf("", uint8(p)) } +// String returns the string representation of the verdict. func (v Verdict) String() string { switch v { case DROP: return "DROP" + case BLOCK: + return "BLOCK" case ACCEPT: return "ACCEPT" + case STOLEN: + return "STOLEN" + case QUEUE: + return "QUEUE" + case REPEAT: + return "REPEAT" + case STOP: + return "STOP" + default: + return fmt.Sprintf("", uint8(v)) } - return fmt.Sprintf("", uint8(v)) } diff --git a/network/packet/packetinfo.go b/network/packet/packetinfo.go index 74698156..a98fc8a5 100644 --- a/network/packet/packetinfo.go +++ b/network/packet/packetinfo.go @@ -10,9 +10,9 @@ type Info struct { InTunnel bool Version IPVersion - Src, Dst net.IP Protocol IPProtocol SrcPort, DstPort uint16 + Src, Dst net.IP } // LocalIP returns the local IP of the packet.