Improve device location system with more safeguards

This commit is contained in:
Daniel 2022-06-09 14:36:03 +02:00
parent b392a1e8ff
commit 86d4f64d42
2 changed files with 80 additions and 73 deletions

View file

@ -55,7 +55,7 @@ func registerAPIEndpoints() error {
Read: api.PermitUser, Read: api.PermitUser,
BelongsTo: module, BelongsTo: module,
StructFunc: func(ar *api.Request) (i interface{}, err error) { StructFunc: func(ar *api.Request) (i interface{}, err error) {
return getLocationFromTraceroute() return getLocationFromTraceroute(&DeviceLocations{})
}, },
Name: "Get Approximate Internet Location via Traceroute", Name: "Get Approximate Internet Location via Traceroute",
Description: "Returns an approximation of where the device is on the Internet using a the traceroute technique.", Description: "Returns an approximation of where the device is on the Internet using a the traceroute technique.",

View file

@ -47,16 +47,16 @@ type DeviceLocations struct {
} }
// Best returns the best (most accurate) device location. // Best returns the best (most accurate) device location.
func (dl *DeviceLocations) Best() *DeviceLocation { func (dls *DeviceLocations) Best() *DeviceLocation {
if len(dl.All) > 0 { if len(dls.All) > 0 {
return dl.All[0] return dls.All[0]
} }
return nil return nil
} }
// BestV4 returns the best (most accurate) IPv4 device location. // BestV4 returns the best (most accurate) IPv4 device location.
func (dl *DeviceLocations) BestV4() *DeviceLocation { func (dls *DeviceLocations) BestV4() *DeviceLocation {
for _, loc := range dl.All { for _, loc := range dls.All {
if loc.IPVersion == packet.IPv4 { if loc.IPVersion == packet.IPv4 {
return loc return loc
} }
@ -65,8 +65,8 @@ func (dl *DeviceLocations) BestV4() *DeviceLocation {
} }
// BestV6 returns the best (most accurate) IPv6 device location. // BestV6 returns the best (most accurate) IPv6 device location.
func (dl *DeviceLocations) BestV6() *DeviceLocation { func (dls *DeviceLocations) BestV6() *DeviceLocation {
for _, loc := range dl.All { for _, loc := range dls.All {
if loc.IPVersion == packet.IPv6 { if loc.IPVersion == packet.IPv6 {
return loc return loc
} }
@ -74,11 +74,8 @@ func (dl *DeviceLocations) BestV6() *DeviceLocation {
return nil return nil
} }
func copyDeviceLocations() *DeviceLocations { // Copy creates a copy of the locations, but not the individual entries.
locationsLock.Lock() func (dls *DeviceLocations) Copy() *DeviceLocations {
defer locationsLock.Unlock()
// Create a copy of the locations, but not the entries.
cp := &DeviceLocations{ cp := &DeviceLocations{
All: make([]*DeviceLocation, len(locations.All)), All: make([]*DeviceLocation, len(locations.All)),
} }
@ -87,6 +84,32 @@ func copyDeviceLocations() *DeviceLocations {
return cp return cp
} }
// AddLocation adds a location.
func (dls *DeviceLocations) AddLocation(dl *DeviceLocation) {
if dls == nil {
return
}
// Add to locations, if better.
var exists bool
for i, existing := range dls.All {
if (dl.IP == nil && existing.IP == nil) || dl.IP.Equal(existing.IP) {
exists = true
if dl.IsMoreAccurateThan(existing) {
// Replace
dls.All[i] = dl
break
}
}
}
if !exists {
dls.All = append(dls.All, dl)
}
// Sort locations.
sort.Sort(sortLocationsByAccuracy(dls.All))
}
// DeviceLocation represents a single IP and metadata. It must not be changed // DeviceLocation represents a single IP and metadata. It must not be changed
// once created. // once created.
type DeviceLocation struct { type DeviceLocation struct {
@ -147,6 +170,12 @@ func (dl *DeviceLocation) String() string {
return "<none>" return "<none>"
case dl.Location == nil: case dl.Location == nil:
return dl.IP.String() return dl.IP.String()
case dl.Source == SourceTimezone:
return fmt.Sprintf(
"TZ(%.0f/%.0f)",
dl.Location.Coordinates.Latitude,
dl.Location.Coordinates.Longitude,
)
default: default:
return fmt.Sprintf("%s (AS%d in %s)", dl.IP, dl.Location.AutonomousSystemNumber, dl.Location.Country.ISOCode) return fmt.Sprintf("%s (AS%d in %s)", dl.IP, dl.Location.AutonomousSystemNumber, dl.Location.Country.ISOCode)
} }
@ -193,6 +222,14 @@ func (a sortLocationsByAccuracy) Less(i, j int) bool { return !a[j].IsMoreAccura
// SetInternetLocation provides the location management system with a possible Internet location. // SetInternetLocation provides the location management system with a possible Internet location.
func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLocation, ok bool) { func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLocation, ok bool) {
locationsLock.Lock()
defer locationsLock.Unlock()
return locations.AddIP(ip, source)
}
// AddIP adds a new location based on the given IP.
func (dls *DeviceLocations) AddIP(ip net.IP, source DeviceLocationSource) (dl *DeviceLocation, ok bool) {
// Check if IP is global. // Check if IP is global.
if netutils.GetIPScope(ip) != netutils.Global { if netutils.GetIPScope(ip) != netutils.Global {
return nil, false return nil, false
@ -222,38 +259,10 @@ func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLoca
} }
loc.Location = geoLoc loc.Location = geoLoc
addLocation(loc) dls.AddLocation(loc)
return loc, true return loc, true
} }
func addLocation(dl *DeviceLocation) {
if dl == nil {
return
}
locationsLock.Lock()
defer locationsLock.Unlock()
// Add to locations, if better.
var exists bool
for i, existing := range locations.All {
if (dl.IP == nil && existing.IP == nil) || dl.IP.Equal(existing.IP) {
exists = true
if dl.IsMoreAccurateThan(existing) {
// Replace
locations.All[i] = dl
break
}
}
}
if !exists {
locations.All = append(locations.All, dl)
}
// Sort locations.
sort.Sort(sortLocationsByAccuracy(locations.All))
}
// GetApproximateInternetLocation returns the approximate Internet location. // GetApproximateInternetLocation returns the approximate Internet location.
// Deprecated: Please use GetInternetLocation instead. // Deprecated: Please use GetInternetLocation instead.
func GetApproximateInternetLocation() (net.IP, error) { func GetApproximateInternetLocation() (net.IP, error) {
@ -271,30 +280,21 @@ func GetInternetLocation() (deviceLocations *DeviceLocations, ok bool) {
// Check if the network changed, if not, return cache. // Check if the network changed, if not, return cache.
if !locationNetworkChangedFlag.IsSet() { if !locationNetworkChangedFlag.IsSet() {
return copyDeviceLocations(), true locationsLock.Lock()
defer locationsLock.Unlock()
return locations.Copy(), true
} }
locationNetworkChangedFlag.Refresh() locationNetworkChangedFlag.Refresh()
// Reset locations. // Create new location list.
func() { dls := &DeviceLocations{}
locationsLock.Lock()
defer locationsLock.Unlock()
locations = &DeviceLocations{}
}()
// Get all assigned addresses.
v4s, v6s, err := GetAssignedAddresses()
if err != nil {
log.Warningf("netenv: failed to get assigned addresses for device location: %s", err)
return nil, false
}
// Check interfaces for global addresses. // Check interfaces for global addresses.
v4ok, v6ok := getLocationFromInterfaces() v4ok, v6ok := getLocationFromInterfaces(dls)
// Try other methods for missing locations. // Try other methods for missing locations.
if len(v4s) > 0 && !v4ok { if !v4ok {
_, err = getLocationFromTraceroute() _, err := getLocationFromTraceroute(dls)
if err != nil { if err != nil {
log.Warningf("netenv: failed to get IPv4 device location from traceroute: %s", err) log.Warningf("netenv: failed to get IPv4 device location from traceroute: %s", err)
} else { } else {
@ -303,35 +303,43 @@ func GetInternetLocation() (deviceLocations *DeviceLocations, ok bool) {
// Get location from timezone as final fallback. // Get location from timezone as final fallback.
if !v4ok { if !v4ok {
getLocationFromTimezone(packet.IPv4) getLocationFromTimezone(dls, packet.IPv4)
} }
} }
if len(v6s) > 0 && !v6ok { if !v6ok && IPv6Enabled() {
// TODO: Find more ways to get IPv6 device location // TODO: Find more ways to get IPv6 device location
// Get location from timezone as final fallback. // Get location from timezone as final fallback.
getLocationFromTimezone(packet.IPv6) getLocationFromTimezone(dls, packet.IPv6)
} }
// As a last guard, make sure there is at least one location in the list.
if len(dls.All) == 0 {
getLocationFromTimezone(dls, packet.IPv4)
}
// Set new locations.
locationsLock.Lock()
defer locationsLock.Unlock()
locations = dls
// Return gathered locations. // Return gathered locations.
cp := copyDeviceLocations() return locations.Copy(), true
return cp, true
} }
func getLocationFromInterfaces() (v4ok, v6ok bool) { func getLocationFromInterfaces(dls *DeviceLocations) (v4ok, v6ok bool) {
globalIPv4, globalIPv6, err := GetAssignedGlobalAddresses() globalIPv4, globalIPv6, err := GetAssignedGlobalAddresses()
if err != nil { if err != nil {
log.Warningf("netenv: location: failed to get assigned global addresses: %s", err) log.Warningf("netenv: location: failed to get assigned global addresses: %s", err)
return false, false return false, false
} }
for _, ip := range globalIPv4 { for _, ip := range globalIPv4 {
if _, ok := SetInternetLocation(ip, SourceInterface); ok { if _, ok := dls.AddIP(ip, SourceInterface); ok {
v4ok = true v4ok = true
} }
} }
for _, ip := range globalIPv6 { for _, ip := range globalIPv6 {
if _, ok := SetInternetLocation(ip, SourceInterface); ok { if _, ok := dls.AddIP(ip, SourceInterface); ok {
v6ok = true v6ok = true
} }
} }
@ -349,7 +357,7 @@ func getLocationFromUPnP() (ok bool) {
} }
*/ */
func getLocationFromTraceroute() (dl *DeviceLocation, err error) { func getLocationFromTraceroute(dls *DeviceLocations) (dl *DeviceLocation, err error) {
// Create connection. // Create connection.
conn, err := net.ListenPacket("ip4:icmp", "") conn, err := net.ListenPacket("ip4:icmp", "")
if err != nil { if err != nil {
@ -470,7 +478,7 @@ nextHop:
// We have received a valid time exceeded error. // We have received a valid time exceeded error.
// If message came from a global unicast, us it! // If message came from a global unicast, us it!
if netutils.GetIPScope(remoteIP) == netutils.Global { if netutils.GetIPScope(remoteIP) == netutils.Global {
dl, ok := SetInternetLocation(remoteIP, SourceTraceroute) dl, ok := dls.AddIP(remoteIP, SourceTraceroute)
if !ok { if !ok {
return nil, errors.New("invalid IP address") return nil, errors.New("invalid IP address")
} }
@ -516,7 +524,7 @@ func recvICMP(currentHop int, icmpPacketsViaFirewall chan packet.Packet) (
} }
} }
func getLocationFromTimezone(ipVersion packet.IPVersion) (ok bool) { //nolint:unparam // This is documentation. func getLocationFromTimezone(dls *DeviceLocations, ipVersion packet.IPVersion) {
// Create base struct. // Create base struct.
tzLoc := &DeviceLocation{ tzLoc := &DeviceLocation{
IPVersion: ipVersion, IPVersion: ipVersion,
@ -531,6 +539,5 @@ func getLocationFromTimezone(ipVersion packet.IPVersion) (ok bool) { //nolint:un
tzLoc.Location.Coordinates.Latitude = 48 tzLoc.Location.Coordinates.Latitude = 48
tzLoc.Location.Coordinates.Longitude = float64(offsetSeconds) / 43200 * 180 tzLoc.Location.Coordinates.Longitude = float64(offsetSeconds) / 43200 * 180
addLocation(tzLoc) dls.AddLocation(tzLoc)
return true
} }