Fix tests and linters

This commit is contained in:
Daniel 2022-02-02 12:48:42 +01:00
parent f2fcad4d11
commit 60d8664e7b
171 changed files with 944 additions and 874 deletions

View file

@ -1,6 +1,6 @@
package main package main
import ( import ( //nolint:gci,nolintlint
"os" "os"
"github.com/safing/portbase/info" "github.com/safing/portbase/info"
@ -8,7 +8,7 @@ import (
"github.com/safing/portbase/run" "github.com/safing/portbase/run"
"github.com/safing/spn/conf" "github.com/safing/spn/conf"
// include packages here // Include packages here.
_ "github.com/safing/portbase/modules/subsystems" _ "github.com/safing/portbase/modules/subsystems"
_ "github.com/safing/portmaster/core" _ "github.com/safing/portmaster/core"
_ "github.com/safing/portmaster/firewall" _ "github.com/safing/portmaster/firewall"
@ -22,7 +22,7 @@ func main() {
info.Set("Portmaster", "0.7.21", "AGPLv3", true) info.Set("Portmaster", "0.7.21", "AGPLv3", true)
// Configure metrics. // Configure metrics.
metrics.SetNamespace("portmaster") _ = metrics.SetNamespace("portmaster")
// enable SPN client mode // enable SPN client mode
conf.EnableClient(true) conf.EnableClient(true)

View file

@ -1,4 +1,4 @@
// +build !windows // go:build !windows
package main package main
@ -8,5 +8,4 @@ func attachToParentConsole() (attached bool, err error) {
return true, nil return true, nil
} }
func hideWindow(cmd *exec.Cmd) { func hideWindow(cmd *exec.Cmd) {}
}

View file

@ -86,7 +86,7 @@ func createInstanceLock(lockFilePath string) error {
// create lock file // create lock file
// TODO: Investigate required permissions. // TODO: Investigate required permissions.
err = ioutil.WriteFile(lockFilePath, []byte(fmt.Sprintf("%d", os.Getpid())), 0666) //nolint:gosec err = ioutil.WriteFile(lockFilePath, []byte(fmt.Sprintf("%d", os.Getpid())), 0o0666) //nolint:gosec
if err != nil { if err != nil {
return err return err
} }

View file

@ -8,15 +8,16 @@ import (
"runtime" "runtime"
"time" "time"
"github.com/spf13/cobra"
"github.com/safing/portbase/container" "github.com/safing/portbase/container"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd" "github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/info" "github.com/safing/portbase/info"
"github.com/spf13/cobra"
) )
func initializeLogFile(logFilePath string, identifier string, version string) *os.File { func initializeLogFile(logFilePath string, identifier string, version string) *os.File {
logFile, err := os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE, 0444) logFile, err := os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE, 0o0444)
if err != nil { if err != nil {
log.Printf("failed to create log file %s: %s\n", logFilePath, err) log.Printf("failed to create log file %s: %s\n", logFilePath, err)
return nil return nil
@ -107,7 +108,9 @@ func logControlError(cErr error) {
if errorFile == nil { if errorFile == nil {
return return
} }
defer errorFile.Close() defer func() {
_ = errorFile.Close()
}()
fmt.Fprintln(errorFile, cErr.Error()) fmt.Fprintln(errorFile, cErr.Error())
} }

View file

@ -11,15 +11,14 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/safing/portmaster/updates/helper" "github.com/spf13/cobra"
"github.com/safing/portbase/dataroot" "github.com/safing/portbase/dataroot"
"github.com/safing/portbase/info" "github.com/safing/portbase/info"
portlog "github.com/safing/portbase/log" portlog "github.com/safing/portbase/log"
"github.com/safing/portbase/updater" "github.com/safing/portbase/updater"
"github.com/safing/portbase/utils" "github.com/safing/portbase/utils"
"github.com/safing/portmaster/updates/helper"
"github.com/spf13/cobra"
) )
var ( var (
@ -29,7 +28,7 @@ var (
dataRoot *utils.DirStructure dataRoot *utils.DirStructure
logsRoot *utils.DirStructure logsRoot *utils.DirStructure
// create registry // Create registry.
registry = &updater.ResourceRegistry{ registry = &updater.ResourceRegistry{
Name: "updates", Name: "updates",
UpdateURLs: []string{ UpdateURLs: []string{
@ -153,14 +152,14 @@ func configureRegistry(mustLoadIndex bool) error {
// Remove left over quotes. // Remove left over quotes.
dataDir = strings.Trim(dataDir, `\"`) dataDir = strings.Trim(dataDir, `\"`)
// Initialize data root. // Initialize data root.
err := dataroot.Initialize(dataDir, 0755) err := dataroot.Initialize(dataDir, 0o0755)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize data root: %s", err) return fmt.Errorf("failed to initialize data root: %w", err)
} }
dataRoot = dataroot.Root() dataRoot = dataroot.Root()
// Initialize registry. // Initialize registry.
err = registry.Initialize(dataRoot.ChildDir("updates", 0755)) err = registry.Initialize(dataRoot.ChildDir("updates", 0o0755))
if err != nil { if err != nil {
return err return err
} }
@ -170,10 +169,10 @@ func configureRegistry(mustLoadIndex bool) error {
func ensureLoggingDir() error { func ensureLoggingDir() error {
// set up logs root // set up logs root
logsRoot = dataRoot.ChildDir("logs", 0777) logsRoot = dataRoot.ChildDir("logs", 0o0777)
err := logsRoot.Ensure() err := logsRoot.Ensure()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize logs root (%q): %s", logsRoot.Path, err) return fmt.Errorf("failed to initialize logs root (%q): %w", logsRoot.Path, err)
} }
// warn about CTRL-C on windows // warn about CTRL-C on windows

View file

@ -1,13 +1,15 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"strings" "strings"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/safing/portmaster/firewall/interception"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/safing/portmaster/firewall/interception"
) )
var recoverIPTablesCmd = &cobra.Command{ var recoverIPTablesCmd = &cobra.Command{
@ -19,8 +21,10 @@ var recoverIPTablesCmd = &cobra.Command{
// we don't get the errno of the actual error and need to parse the // we don't get the errno of the actual error and need to parse the
// output instead. Make sure it's always english by setting LC_ALL=C // output instead. Make sure it's always english by setting LC_ALL=C
currentLocale := os.Getenv("LC_ALL") currentLocale := os.Getenv("LC_ALL")
os.Setenv("LC_ALL", "C") _ = os.Setenv("LC_ALL", "C")
defer os.Setenv("LC_ALL", currentLocale) defer func() {
_ = os.Setenv("LC_ALL", currentLocale)
}()
err := interception.DeactivateNfqueueFirewall() err := interception.DeactivateNfqueueFirewall()
if err == nil { if err == nil {
@ -29,8 +33,8 @@ var recoverIPTablesCmd = &cobra.Command{
// we don't want to show ErrNotExists to the user // we don't want to show ErrNotExists to the user
// as that only means portmaster did the cleanup itself. // as that only means portmaster did the cleanup itself.
mr, ok := err.(*multierror.Error) var mr *multierror.Error
if !ok { if !errors.As(err, &mr) {
return err return err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -12,9 +13,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/safing/portmaster/updates/helper"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tevino/abool" "github.com/tevino/abool"
"github.com/safing/portmaster/updates/helper"
) )
const ( const (
@ -223,11 +225,11 @@ func fixExecPerm(path string) error {
return fmt.Errorf("failed to stat %s: %w", path, err) return fmt.Errorf("failed to stat %s: %w", path, err)
} }
if info.Mode() == 0755 { if info.Mode() == 0o0755 {
return nil return nil
} }
if err := os.Chmod(path, 0755); err != nil { if err := os.Chmod(path, 0o0755); err != nil { //nolint:gosec // Set execution rights.
return fmt.Errorf("failed to chmod %s: %w", path, err) return fmt.Errorf("failed to chmod %s: %w", path, err)
} }
@ -367,7 +369,7 @@ func execute(opts *Options, args []string) (cont bool, err error) {
case <-time.After(3 * time.Minute): // portmaster core prints stack if not able to shutdown in 3 minutes, give it one more ... case <-time.After(3 * time.Minute): // portmaster core prints stack if not able to shutdown in 3 minutes, give it one more ...
err = exc.Process.Kill() err = exc.Process.Kill()
if err != nil { if err != nil {
return false, fmt.Errorf("failed to kill %s: %s", opts.Identifier, err) return false, fmt.Errorf("failed to kill %s: %w", opts.Identifier, err)
} }
return false, fmt.Errorf("killed %s", opts.Identifier) return false, fmt.Errorf("killed %s", opts.Identifier)
} }
@ -402,7 +404,8 @@ func parseExitError(err error) (restart bool, errWithCtx error) {
return false, nil return false, nil
} }
if exErr, ok := err.(*exec.ExitError); ok { var exErr *exec.ExitError
if errors.As(err, &exErr) {
switch exErr.ProcessState.ExitCode() { switch exErr.ProcessState.ExitCode() {
case 0: case 0:
return false, fmt.Errorf("clean exit with error: %w", err) return false, fmt.Errorf("clean exit with error: %w", err)

View file

@ -4,8 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/safing/portmaster/updates/helper"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/safing/portmaster/updates/helper"
) )
func init() { func init() {
@ -35,7 +36,7 @@ func show(opts *Options, cmdArgs []string) error {
helper.PlatformIdentifier(opts.Identifier), helper.PlatformIdentifier(opts.Identifier),
) )
if err != nil { if err != nil {
return fmt.Errorf("could not get component: %s", err) return fmt.Errorf("could not get component: %w", err)
} }
fmt.Printf("%s %s\n", file.Path(), strings.Join(args, " ")) fmt.Printf("%s %s\n", file.Path(), strings.Join(args, " "))

View file

@ -5,10 +5,16 @@ import (
) )
var ( var (
startupComplete = make(chan struct{}) // signal that the start procedure completed (is never closed, just signaled once) // startupComplete signals that the start procedure completed.
shuttingDown = make(chan struct{}) // signal that we are shutting down (will be closed, may not be closed directly, use initiateShutdown) // The channel is not closed, just signaled once.
//nolint:unused // false positive on linux, currently used by windows only startupComplete = make(chan struct{})
shutdownError error // protected by shutdownLock
// shuttingDown signals that we are shutting down.
// The channel will be closed, but may not be closed directly - only via initiateShutdown.
shuttingDown = make(chan struct{})
// shutdownError is protected by shutdownLock.
shutdownError error //nolint:unused,errname // Not what the linter thinks it is. Currently used on windows only.
shutdownLock sync.Mutex shutdownLock sync.Mutex
) )

View file

@ -5,9 +5,10 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/spf13/cobra"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/updates/helper" "github.com/safing/portmaster/updates/helper"
"github.com/spf13/cobra"
) )
var reset bool var reset bool

View file

@ -8,8 +8,9 @@ import (
"strings" "strings"
"text/tabwriter" "text/tabwriter"
"github.com/safing/portbase/info"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/safing/portbase/info"
) )
var ( var (
@ -64,9 +65,7 @@ var (
fmt.Fprintf(tw, " %s\t%s\n", identifier, res.SelectedVersion.VersionNumber) fmt.Fprintf(tw, " %s\t%s\n", identifier, res.SelectedVersion.VersionNumber)
} }
tw.Flush() return tw.Flush()
return nil
}, },
} }
) )

View file

@ -12,9 +12,7 @@ import (
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
) )
const ( const dnsResolver = "1.1.1.1:53"
dnsResolver = "1.1.1.1:53"
)
var ( var (
url string url string
@ -72,7 +70,9 @@ func makeHTTPRequest(i int) {
log.Errorf("http request #%d failed after %s: %s", i, time.Since(start).Round(time.Millisecond), err) log.Errorf("http request #%d failed after %s: %s", i, time.Since(start).Round(time.Millisecond), err)
return return
} }
defer resp.Body.Close() defer func() {
_ = resp.Body.Close()
}()
log.Infof("http response #%d after %s: %d", i, time.Since(start).Round(time.Millisecond), resp.StatusCode) log.Infof("http response #%d after %s: %d", i, time.Since(start).Round(time.Millisecond), resp.StatusCode)
} }

View file

@ -5,9 +5,10 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/spf13/cobra"
"github.com/safing/portbase/updater" "github.com/safing/portbase/updater"
"github.com/safing/portbase/utils" "github.com/safing/portbase/utils"
"github.com/spf13/cobra"
) )
var ( var (
@ -30,7 +31,7 @@ var rootCmd = &cobra.Command{
} }
registry = &updater.ResourceRegistry{} registry = &updater.ResourceRegistry{}
err = registry.Initialize(utils.NewDirStructure(absDistPath, 0755)) err = registry.Initialize(utils.NewDirStructure(absDistPath, 0o0755))
if err != nil { if err != nil {
return err return err
} }

View file

@ -8,9 +8,9 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/safing/portbase/updater"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/safing/portbase/updater"
) )
var ( var (
@ -44,13 +44,13 @@ func release(cmd *cobra.Command, args []string) error {
// Check if we want to reset instead. // Check if we want to reset instead.
if resetPreReleases { if resetPreReleases {
return removeFilesFromIndex(getChannelVersions(channel, preReleaseFrom, true)) return removeFilesFromIndex(getChannelVersions(preReleaseFrom, true))
} }
// Write new index. // Write new index.
err := writeIndex( err := writeIndex(
channel, channel,
getChannelVersions(channel, preReleaseFrom, false), getChannelVersions(preReleaseFrom, false),
) )
if err != nil { if err != nil {
return err return err
@ -95,7 +95,7 @@ func writeIndex(channel string, versions map[string]string) error {
} }
// Write new index to disk. // Write new index to disk.
err = ioutil.WriteFile(indexFilePath, versionData, 0644) //nolint:gosec // 0644 is intended err = ioutil.WriteFile(indexFilePath, versionData, 0o0644) //nolint:gosec // 0644 is intended
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +129,7 @@ func removeFilesFromIndex(versions map[string]string) error {
return nil return nil
} }
func getChannelVersions(channel string, prereleaseFrom string, storagePath bool) map[string]string { func getChannelVersions(prereleaseFrom string, storagePath bool) map[string]string {
if prereleaseFrom != "" { if prereleaseFrom != "" {
registry.AddIndex(updater.Index{ registry.AddIndex(updater.Index{
Path: prereleaseFrom + ".json", Path: prereleaseFrom + ".json",

View file

@ -7,6 +7,7 @@ import (
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
) )
// SubmitSystemIntegrationCheckPacket submit a packet for the system integrity check.
func SubmitSystemIntegrationCheckPacket(p packet.Packet) { func SubmitSystemIntegrationCheckPacket(p packet.Packet) {
select { select {
case systemIntegrationCheckPackets <- p: case systemIntegrationCheckPackets <- p:
@ -14,6 +15,7 @@ func SubmitSystemIntegrationCheckPacket(p packet.Packet) {
} }
} }
// SubmitDNSCheckDomain submits a subdomain for the dns check.
func SubmitDNSCheckDomain(subdomain string) (respondWith net.IP) { func SubmitDNSCheckDomain(subdomain string) (respondWith net.IP) {
// Submit queried domain. // Submit queried domain.
select { select {
@ -27,10 +29,12 @@ func SubmitDNSCheckDomain(subdomain string) (respondWith net.IP) {
return dnsCheckAnswer return dnsCheckAnswer
} }
// ReportSecureDNSBypassIssue reports a DNS bypassing issue for the given process.
func ReportSecureDNSBypassIssue(p *process.Process) { func ReportSecureDNSBypassIssue(p *process.Process) {
secureDNSBypassIssue.notify(p) secureDNSBypassIssue.notify(p)
} }
// ReportMultiPeerUDPTunnelIssue reports a multi-peer UDP tunnel for the given process.
func ReportMultiPeerUDPTunnelIssue(p *process.Process) { func ReportMultiPeerUDPTunnelIssue(p *process.Process) {
multiPeerUDPTunnelIssue.notify(p) multiPeerUDPTunnelIssue.notify(p)
} }

View file

@ -4,11 +4,12 @@ import (
"context" "context"
"time" "time"
"github.com/tevino/abool"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/resolver" "github.com/safing/portmaster/resolver"
"github.com/tevino/abool"
) )
var ( var (
@ -43,7 +44,7 @@ func prep() error {
func start() error { func start() error {
selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc). selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc).
Repeat(1 * time.Minute). Repeat(5 * time.Minute).
MaxDelay(selfcheckTaskRetryAfter). MaxDelay(selfcheckTaskRetryAfter).
Schedule(time.Now().Add(selfcheckTaskRetryAfter)) Schedule(time.Now().Add(selfcheckTaskRetryAfter))
@ -98,6 +99,9 @@ func selfcheckTaskFunc(ctx context.Context, task *modules.Task) error {
return nil return nil
} }
// SelfCheckIsFailing returns whether the self check is currently failing.
// This returns true after the first check fails, and does not wait for the
// failing threshold to be met.
func SelfCheckIsFailing() bool { func SelfCheckIsFailing() bool {
return selfCheckIsFailing.IsSet() return selfCheckIsFailing.IsSet()
} }

View file

@ -7,18 +7,17 @@ import (
"sync" "sync"
"time" "time"
"github.com/safing/portmaster/profile"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/notifications" "github.com/safing/portbase/notifications"
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
"github.com/safing/portmaster/profile"
) )
type baseIssue struct { type baseIssue struct {
id string id string //nolint:structcheck // Inherited.
title string title string //nolint:structcheck // Inherited.
message string message string //nolint:structcheck // Inherited.
level notifications.Type level notifications.Type //nolint:structcheck // Inherited.
} }
type systemIssue baseIssue type systemIssue baseIssue

View file

@ -18,22 +18,24 @@ import (
var ( var (
selfcheckLock sync.Mutex selfcheckLock sync.Mutex
SystemIntegrationCheckDstIP = net.IPv4(127, 65, 67, 75) // SystemIntegrationCheckDstIP is the IP address to send a packet to for the
// system integration test.
SystemIntegrationCheckDstIP = net.IPv4(127, 65, 67, 75)
// SystemIntegrationCheckProtocol is the IP protocol to use for the system
// integration test.
SystemIntegrationCheckProtocol = packet.AnyHostInternalProtocol61 SystemIntegrationCheckProtocol = packet.AnyHostInternalProtocol61
systemIntegrationCheckDialNet = fmt.Sprintf("ip4:%d", uint8(SystemIntegrationCheckProtocol)) systemIntegrationCheckDialNet = fmt.Sprintf("ip4:%d", uint8(SystemIntegrationCheckProtocol))
systemIntegrationCheckDialIP = SystemIntegrationCheckDstIP.String() systemIntegrationCheckDialIP = SystemIntegrationCheckDstIP.String()
systemIntegrationCheckPackets = make(chan packet.Packet, 1) systemIntegrationCheckPackets = make(chan packet.Packet, 1)
systemIntegrationCheckWaitDuration = 3 * time.Second systemIntegrationCheckWaitDuration = 10 * time.Second
// DNSCheckInternalDomainScope is the domain scope to use for dns checks.
DNSCheckInternalDomainScope = ".self-check." + resolver.InternalSpecialUseDomain DNSCheckInternalDomainScope = ".self-check." + resolver.InternalSpecialUseDomain
dnsCheckReceivedDomain = make(chan string, 1) dnsCheckReceivedDomain = make(chan string, 1)
dnsCheckWaitDuration = 3 * time.Second dnsCheckWaitDuration = 10 * time.Second
dnsCheckAnswerLock sync.Mutex dnsCheckAnswerLock sync.Mutex
dnsCheckAnswer net.IP dnsCheckAnswer net.IP
DNSTestDomain = "one.one.one.one."
DNSTestExpectedIP = net.IPv4(1, 1, 1, 1)
) )
func selfcheck(ctx context.Context) (issue *systemIssue, err error) { func selfcheck(ctx context.Context) (issue *systemIssue, err error) {

View file

@ -3,14 +3,12 @@ package base
import ( import (
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
// database module // Dependencies.
_ "github.com/safing/portbase/database/dbmodule" _ "github.com/safing/portbase/database/dbmodule"
// module dependencies
_ "github.com/safing/portbase/database/storage/bbolt" _ "github.com/safing/portbase/database/storage/bbolt"
) )
// Default Values (changeable for testing) // Default Values (changeable for testing).
var ( var (
DefaultDatabaseStorageType = "bbolt" DefaultDatabaseStorageType = "bbolt"
) )

View file

@ -11,7 +11,7 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
// Default Values (changeable for testing) // Default Values (changeable for testing).
var ( var (
DefaultAPIListenAddress = "127.0.0.1:817" DefaultAPIListenAddress = "127.0.0.1:817"
@ -56,7 +56,7 @@ func globalPrep() error {
} }
// initialize structure // initialize structure
err := dataroot.Initialize(dataDir, 0755) err := dataroot.Initialize(dataDir, 0o0755)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,19 +1,14 @@
package base package base
import ( import (
_ "github.com/safing/portbase/config"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/metrics" "github.com/safing/portbase/metrics"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
// module dependencies
_ "github.com/safing/portbase/config"
_ "github.com/safing/portbase/metrics"
_ "github.com/safing/portbase/rng" _ "github.com/safing/portbase/rng"
) )
var ( var module *modules.Module
module *modules.Module
)
func init() { func init() {
module = modules.Register("base", nil, start, nil, "database", "config", "rng", "metrics") module = modules.Register("base", nil, start, nil, "database", "config", "rng", "metrics")

View file

@ -8,9 +8,7 @@ import (
"runtime/pprof" "runtime/pprof"
) )
var ( var cpuProfile string
cpuProfile string
)
func init() { func init() {
flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to `file`") flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to `file`")
@ -25,10 +23,10 @@ func startProfiling() {
func cpuProfiler(ctx context.Context) error { func cpuProfiler(ctx context.Context) error {
f, err := os.Create(cpuProfile) f, err := os.Create(cpuProfile)
if err != nil { if err != nil {
return fmt.Errorf("could not create CPU profile: %s", err) return fmt.Errorf("could not create CPU profile: %w", err)
} }
if err := pprof.StartCPUProfile(f); err != nil { if err := pprof.StartCPUProfile(f); err != nil {
return fmt.Errorf("could not start CPU profile: %s", err) return fmt.Errorf("could not start CPU profile: %w", err)
} }
// wait for shutdown // wait for shutdown
@ -37,7 +35,7 @@ func cpuProfiler(ctx context.Context) error {
pprof.StopCPUProfile() pprof.StopCPUProfile()
err = f.Close() err = f.Close()
if err != nil { if err != nil {
return fmt.Errorf("failed to close CPU profile file: %s", err) return fmt.Errorf("failed to close CPU profile file: %w", err)
} }
return nil return nil
} }

View file

@ -7,12 +7,10 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/modules/subsystems"
"github.com/safing/portmaster/updates"
// module dependencies
_ "github.com/safing/portmaster/netenv" _ "github.com/safing/portmaster/netenv"
_ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/status"
_ "github.com/safing/portmaster/ui" _ "github.com/safing/portmaster/ui"
"github.com/safing/portmaster/updates"
) )
const ( const (
@ -65,7 +63,7 @@ func prep() error {
func start() error { func start() error {
if err := startPlatformSpecific(); err != nil { if err := startPlatformSpecific(); err != nil {
return fmt.Errorf("failed to start plattform-specific components: %s", err) return fmt.Errorf("failed to start plattform-specific components: %w", err)
} }
registerLogCleaner() registerLogCleaner()

View file

@ -1,4 +1,4 @@
// +build !windows // go:build !windows
package core package core

View file

@ -24,18 +24,14 @@ import (
"runtime/pprof" "runtime/pprof"
"testing" "testing"
_ "github.com/safing/portbase/database/storage/hashmap"
"github.com/safing/portbase/dataroot" "github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portmaster/core/base" "github.com/safing/portmaster/core/base"
// module dependencies
_ "github.com/safing/portbase/database/storage/hashmap"
) )
var ( var printStackOnExit bool
printStackOnExit bool
)
func init() { func init() {
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
@ -73,7 +69,7 @@ func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, befor
// tmp dir for data root (db & config) // tmp dir for data root (db & config)
tmpDir := filepath.Join(os.TempDir(), "portmaster-testing") tmpDir := filepath.Join(os.TempDir(), "portmaster-testing")
// initialize data dir // initialize data dir
err := dataroot.Initialize(tmpDir, 0755) err := dataroot.Initialize(tmpDir, 0o0755)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
os.Exit(1) os.Exit(1)

View file

@ -3,6 +3,8 @@ package dga
import "testing" import "testing"
func TestLmsScoreOfDomain(t *testing.T) { func TestLmsScoreOfDomain(t *testing.T) {
t.Parallel()
testDomain(t, "g.symcd.com.", 100, 100) testDomain(t, "g.symcd.com.", 100, 100)
testDomain(t, "www.google.com.", 100, 100) testDomain(t, "www.google.com.", 100, 100)
testDomain(t, "55ttt5.12abc3.test.com.", 68, 69) testDomain(t, "55ttt5.12abc3.test.com.", 68, 69)
@ -10,6 +12,8 @@ func TestLmsScoreOfDomain(t *testing.T) {
} }
func testDomain(t *testing.T, domain string, min, max float64) { func testDomain(t *testing.T, domain string, min, max float64) {
t.Helper()
score := LmsScoreOfDomain(domain) score := LmsScoreOfDomain(domain)
if score < min || score > max { if score < min || score > max {
t.Errorf("domain %s has scored %.2f, but should be between %.0f and %.0f", domain, score, min, max) t.Errorf("domain %s has scored %.2f, but should be between %.0f and %.0f", domain, score, min, max)

View file

@ -11,16 +11,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/updates"
"github.com/safing/portbase/api" "github.com/safing/portbase/api"
"github.com/safing/portbase/dataroot" "github.com/safing/portbase/dataroot"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/utils" "github.com/safing/portbase/utils"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
"github.com/safing/portmaster/updates"
) )
const ( const (
@ -79,13 +77,13 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er
// get local IP/Port // get local IP/Port
localIP, localPort, err := parseHostPort(s.Addr) localIP, localPort, err := parseHostPort(s.Addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get local IP/Port: %s", err) return nil, fmt.Errorf("failed to get local IP/Port: %w", err)
} }
// get remote IP/Port // get remote IP/Port
remoteIP, remotePort, err := parseHostPort(r.RemoteAddr) remoteIP, remotePort, err := parseHostPort(r.RemoteAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get remote IP/Port: %s", err) return nil, fmt.Errorf("failed to get remote IP/Port: %w", err)
} }
// Check if the request is even local. // Check if the request is even local.
@ -151,11 +149,12 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo
// Go up up to two levels, if we don't match the path. // Go up up to two levels, if we don't match the path.
checkLevels := 2 checkLevels := 2
checkLevelsLoop:
for i := 0; i < checkLevels+1; i++ { for i := 0; i < checkLevels+1; i++ {
// Check for eligible path. // Check for eligible path.
switch proc.Pid { switch proc.Pid {
case process.UnidentifiedProcessID, process.SystemProcessID: case process.UnidentifiedProcessID, process.SystemProcessID:
break break checkLevelsLoop
default: // normal process default: // normal process
// Check if the requesting process is in database root / updates dir. // Check if the requesting process is in database root / updates dir.
if strings.HasPrefix(proc.Path, authenticatedPath) { if strings.HasPrefix(proc.Path, authenticatedPath) {

View file

@ -5,16 +5,13 @@ import (
"strings" "strings"
"github.com/safing/portmaster/compat" "github.com/safing/portmaster/compat"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/profile/endpoints"
) )
var ( var resolverFilterLists = []string{"17-DNS"}
resolverFilterLists = []string{"17-DNS"}
)
// PreventBypassing checks if the connection should be denied or permitted // PreventBypassing checks if the connection should be denied or permitted
// based on some bypass protection checks. // based on some bypass protection checks.
@ -27,7 +24,7 @@ func PreventBypassing(ctx context.Context, conn *network.Connection) (endpoints.
} }
// Block direct connections to known DNS resolvers. // Block direct connections to known DNS resolvers.
switch packet.IPProtocol(conn.Entity.Protocol) { switch packet.IPProtocol(conn.Entity.Protocol) { //nolint:exhaustive // Checking for specific values only.
case packet.ICMP, packet.ICMPv6: case packet.ICMP, packet.ICMPv6:
// Make an exception for ICMP, as these IPs are also often used for debugging. // Make an exception for ICMP, as these IPs are also often used for debugging.
default: default:

View file

@ -2,10 +2,12 @@ package firewall
import ( import (
"context" "context"
"errors"
"net" "net"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
@ -22,7 +24,6 @@ func filterDNSSection(
resolverScope netutils.IPScope, resolverScope netutils.IPScope,
sysResolver bool, sysResolver bool,
) ([]dns.RR, []string, int, string) { ) ([]dns.RR, []string, int, string) {
// Will be filled 1:1 most of the time. // Will be filled 1:1 most of the time.
goodEntries := make([]dns.RR, 0, len(entries)) goodEntries := make([]dns.RR, 0, len(entries))
@ -275,7 +276,7 @@ func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
} }
// Resolve all CNAMEs in the correct order and add the to the record. // Resolve all CNAMEs in the correct order and add the to the record.
var domain = q.FQDN domain := q.FQDN
for { for {
nextDomain, isCNAME := cnames[domain] nextDomain, isCNAME := cnames[domain]
if !isCNAME { if !isCNAME {
@ -294,7 +295,7 @@ func UpdateIPsAndCNAMEs(q *resolver.Query, rrCache *resolver.RRCache, conn *netw
ipString := ip.String() ipString := ip.String()
info, err := resolver.GetIPInfo(profileID, ipString) info, err := resolver.GetIPInfo(profileID, ipString)
if err != nil { if err != nil {
if err != database.ErrNotFound { if !errors.Is(err, database.ErrNotFound) {
log.Errorf("nameserver: failed to search for IP info record: %s", err) log.Errorf("nameserver: failed to search for IP info record: %s", err)
} }

View file

@ -2,13 +2,12 @@ package firewall
import ( import (
"github.com/safing/portbase/config" "github.com/safing/portbase/config"
"github.com/safing/portbase/modules/subsystems"
"github.com/safing/spn/captain"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
// module dependencies // Dependency.
_ "github.com/safing/portmaster/core" _ "github.com/safing/portmaster/core"
"github.com/safing/spn/captain"
) )
var ( var (

View file

@ -9,28 +9,25 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/safing/portmaster/compat"
"github.com/safing/spn/captain"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/safing/portmaster/netenv"
"golang.org/x/sync/singleflight"
"github.com/tevino/abool" "github.com/tevino/abool"
"golang.org/x/sync/singleflight"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portmaster/compat"
// Dependency.
_ "github.com/safing/portmaster/core/base"
"github.com/safing/portmaster/firewall/inspection" "github.com/safing/portmaster/firewall/inspection"
"github.com/safing/portmaster/firewall/interception" "github.com/safing/portmaster/firewall/interception"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"github.com/safing/spn/captain"
"github.com/safing/spn/crew" "github.com/safing/spn/crew"
"github.com/safing/spn/sluice" "github.com/safing/spn/sluice"
// module dependencies
_ "github.com/safing/portmaster/core/base"
) )
var ( var (
@ -141,14 +138,14 @@ func getConnection(pkt packet.Packet) (*network.Connection, error) {
return conn, nil return conn, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get connection: %s", err) return nil, fmt.Errorf("failed to get connection: %w", err)
} }
if newConn == nil { if newConn == nil {
return nil, errors.New("connection getter returned nil") return nil, errors.New("connection getter returned nil")
} }
// Transform and log result. // Transform and log result.
conn := newConn.(*network.Connection) conn := newConn.(*network.Connection) //nolint:forcetypeassert // Can only be a *network.Connection.
sharedIndicator := "" sharedIndicator := ""
if shared { if shared {
sharedIndicator = " (shared)" sharedIndicator = " (shared)"
@ -188,7 +185,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
return true return true
} }
switch meta.Protocol { switch meta.Protocol { //nolint:exhaustive // Checking for specific values only.
case packet.ICMP, packet.ICMPv6: case packet.ICMP, packet.ICMPv6:
// Load packet data. // Load packet data.
err := pkt.LoadPacketData() err := pkt.LoadPacketData()
@ -243,7 +240,7 @@ func fastTrackedPermit(pkt packet.Packet) (handled bool) {
} }
// DHCP is only valid in local network scopes. // DHCP is only valid in local network scopes.
switch netutils.ClassifyIP(meta.Dst) { switch netutils.ClassifyIP(meta.Dst) { //nolint:exhaustive // Checking for specific values only.
case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast: case netutils.HostLocal, netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
default: default:
return false return false
@ -430,7 +427,6 @@ func initialHandler(conn *network.Connection, pkt packet.Packet) {
conn.StopFirewallHandler() conn.StopFirewallHandler()
issueVerdict(conn, pkt, 0, true) issueVerdict(conn, pkt, 0, true)
} }
} }
func defaultHandler(conn *network.Connection, pkt packet.Packet) { func defaultHandler(conn *network.Connection, pkt packet.Packet) {
@ -494,6 +490,9 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
case network.VerdictFailed: case network.VerdictFailed:
atomic.AddUint64(packetsFailed, 1) atomic.AddUint64(packetsFailed, 1)
err = pkt.Drop() err = pkt.Drop()
case network.VerdictUndecided, network.VerdictUndeterminable:
log.Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt)
fallthrough
default: default:
atomic.AddUint64(packetsDropped, 1) atomic.AddUint64(packetsDropped, 1)
err = pkt.Drop() err = pkt.Drop()

View file

@ -25,7 +25,7 @@ func Start() error {
return nil return nil
} }
var inputPackets = Packets inputPackets := Packets
if packetMetricsDestination != "" { if packetMetricsDestination != "" {
go metrics.writeMetrics() go metrics.writeMetrics()
inputPackets = make(chan packet.Packet) inputPackets = make(chan packet.Packet)

View file

@ -58,7 +58,9 @@ func (pm *packetMetrics) writeMetrics() {
log.Errorf("Failed to create packet metrics file: %s", err) log.Errorf("Failed to create packet metrics file: %s", err)
return return
} }
defer f.Close() defer func() {
_ = f.Close()
}()
for { for {
select { select {

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
// Package nfq contains a nfqueue library experiment. // Package nfq contains a nfqueue library experiment.
package nfq package nfq
@ -10,15 +10,15 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/safing/portbase/log" "github.com/florianl/go-nfqueue"
pmpacket "github.com/safing/portmaster/network/packet"
"github.com/tevino/abool" "github.com/tevino/abool"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/florianl/go-nfqueue" "github.com/safing/portbase/log"
pmpacket "github.com/safing/portmaster/network/packet"
) )
// Queue wraps a nfqueue // Queue wraps a nfqueue.
type Queue struct { type Queue struct {
id uint16 id uint16
afFamily uint8 afFamily uint8
@ -32,7 +32,7 @@ type Queue struct {
} }
func (q *Queue) getNfq() *nfqueue.Nfqueue { func (q *Queue) getNfq() *nfqueue.Nfqueue {
return q.nf.Load().(*nfqueue.Nfqueue) return q.nf.Load().(*nfqueue.Nfqueue) //nolint:forcetypeassert // TODO: Check.
} }
// New opens a new nfQueue. // New opens a new nfQueue.
@ -112,7 +112,7 @@ func (q *Queue) open(ctx context.Context) error {
} }
if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil { if err := nf.RegisterWithErrorFunc(ctx, q.packetHandler(ctx), q.handleError); err != nil {
defer nf.Close() _ = nf.Close()
return err return err
} }
@ -124,7 +124,7 @@ func (q *Queue) open(ctx context.Context) error {
func (q *Queue) handleError(e error) int { func (q *Queue) handleError(e error) int {
// embedded interface is required to work-around some // embedded interface is required to work-around some
// dep-vendoring weirdness // dep-vendoring weirdness
if opError, ok := e.(interface { if opError, ok := e.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
Timeout() bool Timeout() bool
Temporary() bool Temporary() bool
}); ok { }); ok {
@ -153,7 +153,7 @@ func (q *Queue) handleError(e error) int {
// Close the existing socket // Close the existing socket
if nf := q.getNfq(); nf != nil { if nf := q.getNfq(); nf != nil {
nf.Close() _ = nf.Close()
} }
// Trigger a restart of the queue // Trigger a restart of the queue

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
package nfq package nfq
@ -8,9 +8,9 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/florianl/go-nfqueue"
"github.com/tevino/abool" "github.com/tevino/abool"
"github.com/florianl/go-nfqueue"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
pmpacket "github.com/safing/portmaster/network/packet" pmpacket "github.com/safing/portmaster/network/packet"
) )
@ -104,7 +104,7 @@ func (pkt *packet) setMark(mark int) error {
if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil { if err := pkt.queue.getNfq().SetVerdictWithMark(pkt.pktID, nfqueue.NfAccept, mark); err != nil {
// embedded interface is required to work-around some // embedded interface is required to work-around some
// dep-vendoring weirdness // dep-vendoring weirdness
if opErr, ok := err.(interface { if opErr, ok := err.(interface { //nolint:errorlint // TODO: Check if we can remove workaround.
Timeout() bool Timeout() bool
Temporary() bool Temporary() bool
}); ok { }); ok {

View file

@ -44,7 +44,6 @@ type nfQueue interface {
} }
func init() { func init() {
v4chains = []string{ v4chains = []string{
"mangle C170", "mangle C170",
"mangle C171", "mangle C171",
@ -128,7 +127,6 @@ func init() {
// Reverse because we'd like to insert in a loop // Reverse because we'd like to insert in a loop
_ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs) _ = sort.Reverse(sort.StringSlice(v4once)) // silence vet (sort is used just like in the docs)
_ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs) _ = sort.Reverse(sort.StringSlice(v6once)) // silence vet (sort is used just like in the docs)
} }
func activateNfqueueFirewall() error { func activateNfqueueFirewall() error {
@ -241,7 +239,7 @@ func StartNfqueueInterception(packets chan<- packet.Packet) (err error) {
err = activateNfqueueFirewall() err = activateNfqueueFirewall()
if err != nil { if err != nil {
_ = Stop() _ = Stop()
return fmt.Errorf("could not initialize nfqueue: %s", err) return fmt.Errorf("could not initialize nfqueue: %w", err)
} }
out4Queue, err = nfq.New(17040, false) out4Queue, err = nfq.New(17040, false)
@ -288,7 +286,7 @@ func StopNfqueueInterception() error {
err := DeactivateNfqueueFirewall() err := DeactivateNfqueueFirewall()
if err != nil { if err != nil {
return fmt.Errorf("interception: error while deactivating nfqueue: %s", err) return fmt.Errorf("interception: error while deactivating nfqueue: %w", err)
} }
return nil return nil

View file

@ -6,11 +6,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/safing/portmaster/detection/dga" "github.com/agext/levenshtein"
"github.com/safing/portmaster/netenv"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/detection/dga"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
@ -18,8 +19,6 @@ import (
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
"github.com/safing/portmaster/profile" "github.com/safing/portmaster/profile"
"github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/profile/endpoints"
"github.com/agext/levenshtein"
) )
// Call order: // Call order:
@ -215,6 +214,8 @@ func checkEndpointLists(ctx context.Context, conn *network.Connection, p *profil
case endpoints.Permitted: case endpoints.Permitted:
conn.AcceptWithContext(reason.String(), optionKey, reason.Context()) conn.AcceptWithContext(reason.String(), optionKey, reason.Context())
return true return true
case endpoints.NoMatch:
return false
} }
return false return false
@ -236,6 +237,8 @@ func checkEndpointListsForSystemResolverDNSRequests(ctx context.Context, conn *n
case endpoints.Permitted: case endpoints.Permitted:
conn.AcceptWithContext(reason.String(), profile.CfgOptionEndpointsKey, reason.Context()) conn.AcceptWithContext(reason.String(), profile.CfgOptionEndpointsKey, reason.Context())
return true return true
case endpoints.NoMatch:
return false
} }
} }
} }
@ -345,7 +348,9 @@ func checkConnectionScope(_ context.Context, conn *network.Connection, p *profil
conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound conn.Block("Localhost access blocked", profile.CfgOptionBlockScopeLocalKey) // Block Outbound / Drop Inbound
return true return true
} }
default: // netutils.Unknown and netutils.Invalid case netutils.Undefined, netutils.Invalid:
fallthrough
default:
conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound conn.Deny("invalid IP", noReasonOptionKey) // Block Outbound / Drop Inbound
return true return true
} }
@ -358,14 +363,19 @@ func checkBypassPrevention(ctx context.Context, conn *network.Connection, p *pro
// check for bypass protection // check for bypass protection
result, reason, reasonCtx := PreventBypassing(ctx, conn) result, reason, reasonCtx := PreventBypassing(ctx, conn)
switch result { switch result {
case endpoints.Denied: case endpoints.Denied, endpoints.MatchError:
// Also block on MatchError to be on the safe side.
// PreventBypassing does not use any data that needs to be loaded, so it should not fail anyway.
conn.BlockWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx) conn.BlockWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
return true return true
case endpoints.Permitted: case endpoints.Permitted:
conn.AcceptWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx) conn.AcceptWithContext("bypass prevention: "+reason, profile.CfgOptionPreventBypassingKey, reasonCtx)
return true return true
case endpoints.NoMatch:
return false
} }
} }
return false return false
} }
@ -378,6 +388,8 @@ func checkFilterLists(ctx context.Context, conn *network.Connection, p *profile.
return true return true
case endpoints.NoMatch: case endpoints.NoMatch:
// nothing to do // nothing to do
case endpoints.Permitted, endpoints.MatchError:
fallthrough
default: default:
log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result) log.Tracer(ctx).Debugf("filter: filter lists returned unsupported verdict: %s", result)
} }

View file

@ -1,16 +1,14 @@
package firewall package firewall
import ( import (
"fmt"
"net"
"strconv" "strconv"
"sync" "sync"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"fmt"
"net"
"github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/resolver" "github.com/safing/portmaster/resolver"
) )

View file

@ -16,7 +16,7 @@ import (
) )
const ( const (
// notification action IDs // notification action IDs.
allowDomainAll = "allow-domain-all" allowDomainAll = "allow-domain-all"
allowDomainDistinct = "allow-domain-distinct" allowDomainDistinct = "allow-domain-distinct"
blockDomainAll = "block-domain-all" blockDomainAll = "block-domain-all"

View file

@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
) )

View file

@ -8,12 +8,13 @@ import (
"strings" "strings"
"sync" "sync"
"golang.org/x/net/publicsuffix"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/intel/filterlists" "github.com/safing/portmaster/intel/filterlists"
"github.com/safing/portmaster/intel/geoip" "github.com/safing/portmaster/intel/geoip"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/status" "github.com/safing/portmaster/status"
"golang.org/x/net/publicsuffix"
) )
// Entity describes a remote endpoint in many different ways. // Entity describes a remote endpoint in many different ways.
@ -21,7 +22,7 @@ import (
// functions performs locking. The caller MUST ENSURE // functions performs locking. The caller MUST ENSURE
// proper locking and synchronization when accessing // proper locking and synchronization when accessing
// any properties of Entity. // any properties of Entity.
type Entity struct { type Entity struct { //nolint:maligned
sync.Mutex sync.Mutex
// lists exist for most entity information and // lists exist for most entity information and
@ -319,7 +320,7 @@ func (e *Entity) getDomainLists(ctx context.Context) {
log.Tracer(ctx).Tracef("intel: loading domain list for %s", domain) log.Tracer(ctx).Tracef("intel: loading domain list for %s", domain)
e.loadDomainListOnce.Do(func() { e.loadDomainListOnce.Do(func() {
var domainsToInspect = []string{domain} domainsToInspect := []string{domain}
if e.checkCNAMEs && len(e.CNAME) > 0 { if e.checkCNAMEs && len(e.CNAME) > 0 {
log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME) log.Tracer(ctx).Tracef("intel: CNAME filtering enabled, checking %v too", e.CNAME)

View file

@ -6,9 +6,10 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/tannerryan/ring"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/tannerryan/ring"
) )
var defaultFilter = newScopedBloom() var defaultFilter = newScopedBloom()

View file

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/safing/portbase/database"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
) )

View file

@ -9,12 +9,13 @@ import (
"sync" "sync"
"time" "time"
"golang.org/x/sync/errgroup"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/updater" "github.com/safing/portbase/updater"
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
"golang.org/x/sync/errgroup"
) )
const ( const (
@ -46,13 +47,11 @@ var (
filterListsLoaded chan struct{} filterListsLoaded chan struct{}
) )
var ( var cache = database.NewInterface(&database.Options{
cache = database.NewInterface(&database.Options{ Local: true,
Local: true, Internal: true,
Internal: true, CacheSize: 2 ^ 8,
CacheSize: 2 ^ 8, })
})
)
// getFileFunc is the function used to get a file from // getFileFunc is the function used to get a file from
// the updater. It's basically updates.GetFile and used // the updater. It's basically updates.GetFile and used
@ -85,7 +84,9 @@ func processListFile(ctx context.Context, filter *scopedBloom, file *updater.Fil
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer func() {
_ = f.Close()
}()
values := make(chan *listEntry, 100) values := make(chan *listEntry, 100)
records := make(chan record.Record, 100) records := make(chan record.Record, 100)

View file

@ -4,6 +4,7 @@ import (
"compress/gzip" "compress/gzip"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
@ -57,7 +58,7 @@ func decodeFile(ctx context.Context, r io.Reader, ch chan<- *listEntry) error {
entryCount++ entryCount++
length, readErr := binary.ReadUvarint(reader) length, readErr := binary.ReadUvarint(reader)
if readErr != nil { if readErr != nil {
if readErr == io.EOF { if errors.Is(readErr, io.EOF) {
return nil return nil
} }
return fmt.Errorf("failed to load varint entity length: %w", readErr) return fmt.Errorf("failed to load varint entity length: %w", readErr)
@ -66,7 +67,7 @@ func decodeFile(ctx context.Context, r io.Reader, ch chan<- *listEntry) error {
blob := make([]byte, length) blob := make([]byte, length)
_, readErr = io.ReadFull(reader, blob) _, readErr = io.ReadFull(reader, blob)
if readErr != nil { if readErr != nil {
if readErr == io.EOF { if errors.Is(readErr, io.EOF) {
// there shouldn't be an EOF here because // there shouldn't be an EOF here because
// we actually got a length above. Return // we actually got a length above. Return
// ErrUnexpectedEOF instead of just EOF. // ErrUnexpectedEOF instead of just EOF.

View file

@ -163,7 +163,7 @@ func getListIndexFromCache() (*ListIndexFile, error) {
} }
var ( var (
// listIndexUpdate must only be used by updateListIndex // listIndexUpdate must only be used by updateListIndex.
listIndexUpdate *updater.File listIndexUpdate *updater.File
listIndexUpdateLock sync.Mutex listIndexUpdateLock sync.Mutex
) )
@ -232,9 +232,8 @@ func updateListIndex() error {
// a slice of distinct source IDs. // a slice of distinct source IDs.
func ResolveListIDs(ids []string) ([]string, error) { func ResolveListIDs(ids []string) ([]string, error) {
index, err := getListIndexFromCache() index, err := getListIndexFromCache()
if err != nil { if err != nil {
if err == database.ErrNotFound { if errors.Is(err, database.ErrNotFound) {
if err := updateListIndex(); err != nil { if err := updateListIndex(); err != nil {
return nil, err return nil, err
} }

View file

@ -33,7 +33,7 @@ func lookupBlockLists(entity, value string) ([]string, error) {
// log.Debugf("intel/filterlists: searching for entries with %s", key) // log.Debugf("intel/filterlists: searching for entries with %s", key)
entry, err := getEntityRecordByKey(key) entry, err := getEntityRecordByKey(key)
if err != nil { if err != nil {
if err == database.ErrNotFound { if errors.Is(err, database.ErrNotFound) {
return nil, nil return nil, nil
} }
log.Errorf("intel/filterlists: failed to get entries for key %s: %s", key, err) log.Errorf("intel/filterlists: failed to get entries for key %s: %s", key, err)
@ -103,7 +103,6 @@ func LookupIPv4String(ipv4 string) ([]string, error) {
// LookupIPv4 is like LookupIPv4String but accepts a net.IP. // LookupIPv4 is like LookupIPv4String but accepts a net.IP.
func LookupIPv4(ipv4 net.IP) ([]string, error) { func LookupIPv4(ipv4 net.IP) ([]string, error) {
ip := ipv4.To4() ip := ipv4.To4()
if ip == nil { if ip == nil {
return nil, errors.New("invalid IPv4") return nil, errors.New("invalid IPv4")

View file

@ -4,22 +4,20 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/tevino/abool"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
"github.com/safing/portmaster/updates" "github.com/safing/portmaster/updates"
"github.com/tevino/abool"
) )
var ( var module *modules.Module
module *modules.Module
)
const ( const (
filterlistsDisabled = "filterlists:disabled" filterlistsDisabled = "filterlists:disabled"
filterlistsUpdateFailed = "filterlists:update-failed" filterlistsUpdateFailed = "filterlists:update-failed"
filterlistsStaleDataSurvived = "filterlists:staledata" filterlistsStaleDataSurvived = "filterlists:staledata"
filterlistsUpdateInProgress = "filterlists:update-in-progress"
) )
// booleans mainly used to decouple the module // booleans mainly used to decouple the module

View file

@ -24,17 +24,17 @@ func getEntityRecordByKey(key string) (*entityRecord, error) {
} }
if r.IsWrapped() { if r.IsWrapped() {
new := &entityRecord{} newER := &entityRecord{}
if err := record.Unwrap(r, new); err != nil { if err := record.Unwrap(r, newER); err != nil {
return nil, err return nil, err
} }
return new, nil return newER, nil
} }
e, ok := r.(*entityRecord) newER, ok := r.(*entityRecord)
if !ok { if !ok {
return nil, fmt.Errorf("record not of type *entityRecord, but %T", r) return nil, fmt.Errorf("record not of type *entityRecord, but %T", r)
} }
return e, nil return newER, nil
} }

View file

@ -2,17 +2,19 @@ package filterlists
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sort" "sort"
"time" "time"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/tevino/abool"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/query" "github.com/safing/portbase/database/query"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/updater" "github.com/safing/portbase/updater"
"github.com/tevino/abool"
) )
var updateInProgress = abool.New() var updateInProgress = abool.New()
@ -21,7 +23,6 @@ var updateInProgress = abool.New()
// error state is correctly set or resolved. // error state is correctly set or resolved.
func tryListUpdate(ctx context.Context) error { func tryListUpdate(ctx context.Context) error {
err := performUpdate(ctx) err := performUpdate(ctx)
if err != nil { if err != nil {
// Check if the module already has a failure status set. If not, set a // Check if the module already has a failure status set. If not, set a
// generic one with the returned error. // generic one with the returned error.
@ -123,7 +124,7 @@ func performUpdate(ctx context.Context) error {
module.Warning( module.Warning(
filterlistsStaleDataSurvived, filterlistsStaleDataSurvived,
"Filter Lists May Overblock", "Filter Lists May Overblock",
fmt.Sprintf("The Portmaster failed to delete outdated filter list data. Filtering capabilities are fully available, but overblocking may occur. Error: %s", err.Error()), fmt.Sprintf("The Portmaster failed to delete outdated filter list data. Filtering capabilities are fully available, but overblocking may occur. Error: %s", err.Error()), //nolint:misspell // overblocking != overclocking
) )
return fmt.Errorf("failed to cleanup stale cache records: %w", err) return fmt.Errorf("failed to cleanup stale cache records: %w", err)
} }
@ -137,7 +138,7 @@ func performUpdate(ctx context.Context) error {
log.Infof("intel/filterlists: successfully migrated cache database to %s", highestVersion.Version()) log.Infof("intel/filterlists: successfully migrated cache database to %s", highestVersion.Version())
} }
// The list update suceeded, resolve any states. // The list update succeeded, resolve any states.
module.Resolve("") module.Resolve("")
return nil return nil
} }
@ -178,7 +179,7 @@ func getUpgradableFiles() ([]*updater.File, error) {
if intermediateFile == nil || intermediateFile.UpgradeAvailable() || !cacheDBInUse { if intermediateFile == nil || intermediateFile.UpgradeAvailable() || !cacheDBInUse {
var err error var err error
intermediateFile, err = getFile(intermediateListFilePath) intermediateFile, err = getFile(intermediateListFilePath)
if err != nil && err != updater.ErrNotFound { if err != nil && !errors.Is(err, updater.ErrNotFound) {
return nil, err return nil, err
} }
@ -191,7 +192,7 @@ func getUpgradableFiles() ([]*updater.File, error) {
if urgentFile == nil || urgentFile.UpgradeAvailable() || !cacheDBInUse { if urgentFile == nil || urgentFile.UpgradeAvailable() || !cacheDBInUse {
var err error var err error
urgentFile, err = getFile(urgentListFilePath) urgentFile, err = getFile(urgentListFilePath)
if err != nil && err != updater.ErrNotFound { if err != nil && !errors.Is(err, updater.ErrNotFound) {
return nil, err return nil, err
} }
@ -216,7 +217,7 @@ func resolveUpdateOrder(updateOrder []*updater.File) ([]*updater.File, error) {
var err error var err error
cacheDBVersion, err = getCacheDatabaseVersion() cacheDBVersion, err = getCacheDatabaseVersion()
if err != nil { if err != nil {
if err != database.ErrNotFound { if !errors.Is(err, database.ErrNotFound) {
log.Errorf("intel/filterlists: failed to get cache database version: %s", err) log.Errorf("intel/filterlists: failed to get cache database version: %s", err)
} }
cacheDBVersion, _ = version.NewSemver("v0.0.0") cacheDBVersion, _ = version.NewSemver("v0.0.0")
@ -247,12 +248,14 @@ func resolveUpdateOrder(updateOrder []*updater.File) ([]*updater.File, error) {
type byAscVersion []*updater.File type byAscVersion []*updater.File
func (fs byAscVersion) Len() int { return len(fs) } func (fs byAscVersion) Len() int { return len(fs) }
func (fs byAscVersion) Less(i, j int) bool { func (fs byAscVersion) Less(i, j int) bool {
vi, _ := version.NewSemver(fs[i].Version()) vi, _ := version.NewSemver(fs[i].Version())
vj, _ := version.NewSemver(fs[j].Version()) vj, _ := version.NewSemver(fs[j].Version())
return vi.LessThan(vj) return vi.LessThan(vj)
} }
func (fs byAscVersion) Swap(i, j int) { func (fs byAscVersion) Swap(i, j int) {
fi := fs[i] fi := fs[i]
fj := fs[j] fj := fs[j]

View file

@ -58,7 +58,7 @@ func (ub *updateBroadcaster) ReplaceDatabase(db *geoIPDB) {
defer ub.rw.Unlock() defer ub.rw.Unlock()
if ub.db != nil { if ub.db != nil {
ub.db.Close() _ = ub.db.Close()
} }
ub.db = db ub.db = db
ub.notifyWaiters() ub.notifyWaiters()
@ -101,7 +101,7 @@ type updateWorker struct {
// waiting nil is returned. // waiting nil is returned.
func (upd *updateWorker) GetReader(v6 bool, wait bool) *maxminddb.Reader { func (upd *updateWorker) GetReader(v6 bool, wait bool) *maxminddb.Reader {
// check which updateBroadcaster we need to use // check which updateBroadcaster we need to use
var ub *updateBroadcaster = &upd.v4 ub := &upd.v4
if v6 { if v6 {
ub = &upd.v6 ub = &upd.v6
} }

View file

@ -5,8 +5,9 @@ import (
"net" "net"
"strings" "strings"
"github.com/safing/portbase/utils"
"github.com/umahmood/haversine" "github.com/umahmood/haversine"
"github.com/safing/portbase/utils"
) )
const ( const (
@ -27,6 +28,7 @@ type Location struct {
AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"` AutonomousSystemOrganization string `maxminddb:"autonomous_system_organization"`
} }
// Coordinates holds geographic coordinates and their estimated accuracy.
type Coordinates struct { type Coordinates struct {
AccuracyRadius uint16 `maxminddb:"accuracy_radius"` AccuracyRadius uint16 `maxminddb:"accuracy_radius"`
Latitude float64 `maxminddb:"latitude"` Latitude float64 `maxminddb:"latitude"`
@ -199,6 +201,8 @@ var unknownASOrgNames = []string{
"undefined", // Programmatic unknown value. "undefined", // Programmatic unknown value.
} }
// ASOrgUnknown return whether the given AS Org string actually is meant to
// mean that the AS Org is unknown.
func ASOrgUnknown(asOrg string) bool { func ASOrgUnknown(asOrg string) bool {
return utils.StringInSlice( return utils.StringInSlice(
unknownASOrgNames, unknownASOrgNames,

View file

@ -6,6 +6,8 @@ import (
) )
func TestPrimitiveNetworkProximity(t *testing.T) { func TestPrimitiveNetworkProximity(t *testing.T) {
t.Parallel()
ip4_1 := net.ParseIP("1.1.1.1") ip4_1 := net.ParseIP("1.1.1.1")
ip4_2 := net.ParseIP("1.1.1.2") ip4_2 := net.ParseIP("1.1.1.2")
ip4_3 := net.ParseIP("255.255.255.0") ip4_3 := net.ParseIP("255.255.255.0")

View file

@ -6,6 +6,8 @@ import (
) )
func TestLocationLookup(t *testing.T) { func TestLocationLookup(t *testing.T) {
t.Parallel()
ip1 := net.ParseIP("81.2.69.142") ip1 := net.ParseIP("81.2.69.142")
loc1, err := GetLocation(ip1) loc1, err := GetLocation(ip1)
if err != nil { if err != nil {
@ -53,8 +55,8 @@ func TestLocationLookup(t *testing.T) {
dist3 := loc1.EstimateNetworkProximity(loc3) dist3 := loc1.EstimateNetworkProximity(loc3)
dist4 := loc1.EstimateNetworkProximity(loc4) dist4 := loc1.EstimateNetworkProximity(loc4)
t.Logf("proximity %s <> %s: %d", ip1, ip2, dist1) t.Logf("proximity %s <> %s: %.2f", ip1, ip2, dist1)
t.Logf("proximity %s <> %s: %d", ip2, ip3, dist2) t.Logf("proximity %s <> %s: %.2f", ip2, ip3, dist2)
t.Logf("proximity %s <> %s: %d", ip1, ip3, dist3) t.Logf("proximity %s <> %s: %.2f", ip1, ip3, dist3)
t.Logf("proximity %s <> %s: %d", ip1, ip4, dist4) t.Logf("proximity %s <> %s: %.2f", ip1, ip4, dist4)
} }

View file

@ -4,10 +4,8 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
var ( // Module of this package. Export needed for testing of the endpoints package.
// Module of this package. Export needed for testing of the endpoints package. var Module *modules.Module
Module *modules.Module
)
func init() { func init() {
Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists") Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists")

View file

@ -4,9 +4,7 @@ import (
"context" "context"
) )
var ( var reverseResolver func(ctx context.Context, ip string, securityLevel uint8) (domain string, err error)
reverseResolver func(ctx context.Context, ip string, securityLevel uint8) (domain string, err error)
)
// SetReverseResolver allows the resolver module to register a function to allow reverse resolving IPs to domains. // SetReverseResolver allows the resolver module to register a function to allow reverse resolving IPs to domains.
func SetReverseResolver(fn func(ctx context.Context, ip string, securityLevel uint8) (domain string, err error)) { func SetReverseResolver(fn func(ctx context.Context, ip string, securityLevel uint8) (domain string, err error)) {

View file

@ -8,10 +8,8 @@ import (
"github.com/safing/portmaster/core" "github.com/safing/portmaster/core"
) )
// Config Keys // CfgDefaultNameserverAddressKey is the config key for the listen address..
const ( const CfgDefaultNameserverAddressKey = "dns/listenAddress"
CfgDefaultNameserverAddressKey = "dns/listenAddress"
)
var ( var (
defaultNameserverAddress = "localhost:53" defaultNameserverAddress = "localhost:53"

View file

@ -36,7 +36,7 @@ var (
failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag() failingQueriesNetworkChangedFlag = netenv.GetNetworkChangedFlag()
) )
func checkIfQueryIsFailing(q *resolver.Query) (failingErr error, failingUntil *time.Time) { func checkIfQueryIsFailing(q *resolver.Query) (failingUntil *time.Time, failingErr error) {
// If the network changed, reset the failed queries. // If the network changed, reset the failed queries.
if failingQueriesNetworkChangedFlag.IsSet() { if failingQueriesNetworkChangedFlag.IsSet() {
failingQueriesNetworkChangedFlag.Refresh() failingQueriesNetworkChangedFlag.Refresh()
@ -45,7 +45,7 @@ func checkIfQueryIsFailing(q *resolver.Query) (failingErr error, failingUntil *t
defer failingQueriesLock.Unlock() defer failingQueriesLock.Unlock()
// Compiler optimized map reset. // Compiler optimized map reset.
for key, _ := range failingQueries { for key := range failingQueries {
delete(failingQueries, key) delete(failingQueries, key)
} }
@ -72,7 +72,7 @@ func checkIfQueryIsFailing(q *resolver.Query) (failingErr error, failingUntil *t
} }
// Return failing error and until when it's valid. // Return failing error and until when it's valid.
return failing.Err, &failing.Until return &failing.Until, failing.Err
} }
func addFailingQuery(q *resolver.Query, err error) { func addFailingQuery(q *resolver.Query, err error) {

View file

@ -7,13 +7,13 @@ import (
"os" "os"
"strconv" "strconv"
"github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/modules/subsystems"
"github.com/safing/portmaster/firewall" "github.com/safing/portmaster/firewall"
"github.com/safing/portmaster/netenv" "github.com/safing/portmaster/netenv"
"github.com/miekg/dns"
) )
var ( var (
@ -69,32 +69,31 @@ func start() error {
} }
return dstIsMe return dstIsMe
}) })
} else {
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1)
})
} }
} else {
// Dual listener.
dnsServer1 := startListener(ip1, port)
dnsServer2 := startListener(ip2, port)
stopListener = func() error {
// Shutdown both listeners.
err1 := dnsServer1.Shutdown()
err2 := dnsServer2.Shutdown()
// Return first error.
if err1 != nil {
return err1
}
return err2
}
// Fast track dns queries destined for one of the listener IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool { return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1) || ip.Equal(ip2) return ip.Equal(ip1)
}) })
} }
// Dual listener.
dnsServer1 := startListener(ip1, port)
dnsServer2 := startListener(ip2, port)
stopListener = func() error {
// Shutdown both listeners.
err1 := dnsServer1.Shutdown()
err2 := dnsServer2.Shutdown()
// Return first error.
if err1 != nil {
return err1
}
return err2
}
// Fast track dns queries destined for one of the listener IPs.
return firewall.SetNameserverIPMatcher(func(ip net.IP) bool {
return ip.Equal(ip1) || ip.Equal(ip2)
})
} }
func startListener(ip net.IP, port uint16) *dns.Server { func startListener(ip net.IP, port uint16) *dns.Server {

View file

@ -8,6 +8,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/firewall" "github.com/safing/portmaster/firewall"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
@ -15,8 +17,6 @@ import (
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/resolver" "github.com/safing/portmaster/resolver"
"github.com/miekg/dns"
) )
var hostname string var hostname string
@ -30,7 +30,7 @@ func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:gocognit // TODO func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) error { //nolint:maintidx // TODO
// Record metrics. // Record metrics.
startTime := time.Now() startTime := time.Now()
defer requestsHistogram.UpdateDuration(startTime) defer requestsHistogram.UpdateDuration(startTime)
@ -113,7 +113,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
// will fail with a very high probability, it is beneficial to just kill the // will fail with a very high probability, it is beneficial to just kill the
// query for some time before doing any expensive work. // query for some time before doing any expensive work.
defer cleanFailingQueries(10, 3) defer cleanFailingQueries(10, 3)
failingErr, failingUntil := checkIfQueryIsFailing(q) failingUntil, failingErr := checkIfQueryIsFailing(q)
if failingErr != nil { if failingErr != nil {
remainingFailingDuration := time.Until(*failingUntil) remainingFailingDuration := time.Until(*failingUntil)
tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr) tracer.Debugf("nameserver: returning previous error for %s: %s", q.ID(), failingErr)
@ -205,6 +205,8 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
network.SaveOpenDNSRequest(q, rrCache, conn) network.SaveOpenDNSRequest(q, rrCache, conn)
firewall.UpdateIPsAndCNAMEs(q, rrCache, conn) firewall.UpdateIPsAndCNAMEs(q, rrCache, conn)
case network.VerdictUndeterminable:
fallthrough
default: default:
tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn) tracer.Warningf("nameserver: unexpected verdict %s for connection %s, not saving", conn.Verdict, conn)
} }
@ -224,7 +226,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
} }
// Check if there is a Verdict to act upon. // Check if there is a Verdict to act upon.
switch conn.Verdict { switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed: case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
tracer.Infof( tracer.Infof(
"nameserver: returning %s response for %s to %s", "nameserver: returning %s response for %s to %s",
@ -289,7 +291,7 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
} }
// Check if there is a Verdict to act upon. // Check if there is a Verdict to act upon.
switch conn.Verdict { switch conn.Verdict { //nolint:exhaustive // Only checking for specific values.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed: case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed:
tracer.Infof( tracer.Infof(
"nameserver: returning %s response for %s to %s", "nameserver: returning %s response for %s to %s",

View file

@ -9,13 +9,12 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
) )
var ( // ErrNilRR is returned when a parsed RR is nil.
// ErrNilRR is returned when a parsed RR is nil. var ErrNilRR = errors.New("is nil")
ErrNilRR = errors.New("is nil")
)
// Responder defines the interface that any block/deny reason interface // Responder defines the interface that any block/deny reason interface
// may implement to support sending custom DNS responses for a given reason. // may implement to support sending custom DNS responses for a given reason.

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
) )

View file

@ -13,15 +13,13 @@ import (
"github.com/safing/portmaster/network/state" "github.com/safing/portmaster/network/state"
) )
var ( var commonResolverIPs = []net.IP{
commonResolverIPs = []net.IP{ net.IPv4zero,
net.IPv4zero, net.IPv4(127, 0, 0, 1), // default
net.IPv4(127, 0, 0, 1), // default net.IPv4(127, 0, 0, 53), // some resolvers on Linux
net.IPv4(127, 0, 0, 53), // some resolvers on Linux net.IPv6zero,
net.IPv6zero, net.IPv6loopback,
net.IPv6loopback, }
}
)
func checkForConflictingService(ip net.IP, port uint16) error { func checkForConflictingService(ip net.IP, port uint16) error {
// Evaluate which IPs to check. // Evaluate which IPs to check.
@ -34,6 +32,7 @@ func checkForConflictingService(ip net.IP, port uint16) error {
// Check if there is another resolver when need to take over. // Check if there is another resolver when need to take over.
var killed int var killed int
ipsToCheckLoop:
for _, resolverIP := range ipsToCheck { for _, resolverIP := range ipsToCheck {
pid, err := takeover(resolverIP, port) pid, err := takeover(resolverIP, port)
switch { switch {
@ -44,7 +43,7 @@ func checkForConflictingService(ip net.IP, port uint16) error {
case pid != 0: case pid != 0:
// Conflicting service identified and killed! // Conflicting service identified and killed!
killed = pid killed = pid
break break ipsToCheckLoop
} }
} }
@ -92,7 +91,7 @@ func takeover(resolverIP net.IP, resolverPort uint16) (int, error) {
}, true) }, true)
if err != nil { if err != nil {
// there may be nothing listening on :53 // there may be nothing listening on :53
return 0, nil return 0, nil //nolint:nilerr // Treat lookup error as "not found".
} }
// Just don't, uh, kill ourselves... // Just don't, uh, kill ourselves...

View file

@ -55,7 +55,7 @@ var (
myNetworks []*net.IPNet myNetworks []*net.IPNet
myNetworksLock sync.Mutex myNetworksLock sync.Mutex
myNetworksNetworkChangedFlag = GetNetworkChangedFlag() myNetworksNetworkChangedFlag = GetNetworkChangedFlag()
myNetworksRefreshError error myNetworksRefreshError error //nolint:errname // Not what the linter thinks this is for.
myNetworksRefreshFailingUntil time.Time myNetworksRefreshFailingUntil time.Time
) )
@ -63,7 +63,7 @@ var (
// Broadcast or multicast addresses will never match, even if valid in in use. // Broadcast or multicast addresses will never match, even if valid in in use.
func IsMyIP(ip net.IP) (yes bool, err error) { func IsMyIP(ip net.IP) (yes bool, err error) {
// Check for IPs that don't need extra checks. // Check for IPs that don't need extra checks.
switch netutils.GetIPScope(ip) { switch netutils.GetIPScope(ip) { //nolint:exhaustive // Only looking for specific values.
case netutils.HostLocal: case netutils.HostLocal:
return true, nil return true, nil
case netutils.LocalMulticast, netutils.GlobalMulticast: case netutils.LocalMulticast, netutils.GlobalMulticast:
@ -90,7 +90,7 @@ func IsMyIP(ip net.IP) (yes bool, err error) {
// Check if there was a recent error on the previous refresh. // Check if there was a recent error on the previous refresh.
if myNetworksRefreshError != nil && time.Now().Before(myNetworksRefreshFailingUntil) { if myNetworksRefreshError != nil && time.Now().Before(myNetworksRefreshFailingUntil) {
return false, fmt.Errorf("failed to previously refresh interface addresses: %s", myNetworksRefreshError) return false, fmt.Errorf("failed to previously refresh interface addresses: %w", myNetworksRefreshError)
} }
// Refresh assigned networks. // Refresh assigned networks.
@ -101,7 +101,7 @@ func IsMyIP(ip net.IP) (yes bool, err error) {
// literally over thousand goroutines wanting to try this again. // literally over thousand goroutines wanting to try this again.
myNetworksRefreshError = err myNetworksRefreshError = err
myNetworksRefreshFailingUntil = time.Now().Add(1 * time.Second) myNetworksRefreshFailingUntil = time.Now().Add(1 * time.Second)
return false, fmt.Errorf("failed to refresh interface addresses: %s", err) return false, fmt.Errorf("failed to refresh interface addresses: %w", err)
} }
myNetworks = make([]*net.IPNet, 0, len(interfaceNetworks)) myNetworks = make([]*net.IPNet, 0, len(interfaceNetworks))
for _, ifNet := range interfaceNetworks { for _, ifNet := range interfaceNetworks {

View file

@ -5,6 +5,8 @@ import (
) )
func TestGetAssignedAddresses(t *testing.T) { func TestGetAssignedAddresses(t *testing.T) {
t.Parallel()
ipv4, ipv6, err := GetAssignedAddresses() ipv4, ipv6, err := GetAssignedAddresses()
t.Logf("all v4: %v", ipv4) t.Logf("all v4: %v", ipv4)
t.Logf("all v6: %v", ipv6) t.Logf("all v6: %v", ipv6)
@ -17,6 +19,8 @@ func TestGetAssignedAddresses(t *testing.T) {
} }
func TestGetAssignedGlobalAddresses(t *testing.T) { func TestGetAssignedGlobalAddresses(t *testing.T) {
t.Parallel()
ipv4, ipv6, err := GetAssignedGlobalAddresses() ipv4, ipv6, err := GetAssignedGlobalAddresses()
t.Logf("all global v4: %v", ipv4) t.Logf("all global v4: %v", ipv4)
t.Logf("all global v6: %v", ipv6) t.Logf("all global v6: %v", ipv6)

View file

@ -1,4 +1,4 @@
// +build !server // go:build !server
package netenv package netenv
@ -8,9 +8,9 @@ import (
"net" "net"
"sync" "sync"
"github.com/safing/portbase/log"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/safing/portbase/log"
) )
var ( var (
@ -36,7 +36,7 @@ func getNameserversFromDbus() ([]Nameserver, error) { //nolint:gocognit // TODO
primaryConnectionVariant, err := getNetworkManagerProperty(dbusConn, dbus.ObjectPath("/org/freedesktop/NetworkManager"), "org.freedesktop.NetworkManager.PrimaryConnection") primaryConnectionVariant, err := getNetworkManagerProperty(dbusConn, dbus.ObjectPath("/org/freedesktop/NetworkManager"), "org.freedesktop.NetworkManager.PrimaryConnection")
if err != nil { if err != nil {
return nil, fmt.Errorf("dbus: failed to access NetworkManager.PrimaryConnection: %s", err) return nil, fmt.Errorf("dbus: failed to access NetworkManager.PrimaryConnection: %w", err)
} }
primaryConnection, ok := primaryConnectionVariant.Value().(dbus.ObjectPath) primaryConnection, ok := primaryConnectionVariant.Value().(dbus.ObjectPath)
if !ok { if !ok {
@ -45,7 +45,7 @@ func getNameserversFromDbus() ([]Nameserver, error) { //nolint:gocognit // TODO
activeConnectionsVariant, err := getNetworkManagerProperty(dbusConn, dbus.ObjectPath("/org/freedesktop/NetworkManager"), "org.freedesktop.NetworkManager.ActiveConnections") activeConnectionsVariant, err := getNetworkManagerProperty(dbusConn, dbus.ObjectPath("/org/freedesktop/NetworkManager"), "org.freedesktop.NetworkManager.ActiveConnections")
if err != nil { if err != nil {
return nil, fmt.Errorf("dbus: failed to access NetworkManager.ActiveConnections: %s", err) return nil, fmt.Errorf("dbus: failed to access NetworkManager.ActiveConnections: %w", err)
} }
activeConnections, ok := activeConnectionsVariant.Value().([]dbus.ObjectPath) activeConnections, ok := activeConnectionsVariant.Value().([]dbus.ObjectPath)
if !ok { if !ok {
@ -60,18 +60,18 @@ func getNameserversFromDbus() ([]Nameserver, error) { //nolint:gocognit // TODO
} }
for _, activeConnection := range sortedConnections { for _, activeConnection := range sortedConnections {
new, err := dbusGetInterfaceNameservers(dbusConn, activeConnection, 4) newNameservers, err := dbusGetInterfaceNameservers(dbusConn, activeConnection, 4)
if err != nil { if err != nil {
log.Warningf("failed to get nameserver: %s", err) log.Warningf("failed to get nameserver: %s", err)
} else { } else {
ns = append(ns, new...) ns = append(ns, newNameservers...)
} }
new, err = dbusGetInterfaceNameservers(dbusConn, activeConnection, 6) newNameservers, err = dbusGetInterfaceNameservers(dbusConn, activeConnection, 6)
if err != nil { if err != nil {
log.Warningf("failed to get nameserver: %s", err) log.Warningf("failed to get nameserver: %s", err)
} else { } else {
ns = append(ns, new...) ns = append(ns, newNameservers...)
} }
} }
@ -87,7 +87,7 @@ func dbusGetInterfaceNameservers(dbusConn *dbus.Conn, interfaceObject dbus.Objec
// Get Interface Configuration. // Get Interface Configuration.
ipConfigVariant, err := getNetworkManagerProperty(dbusConn, interfaceObject, ipConfigPropertyKey) ipConfigVariant, err := getNetworkManagerProperty(dbusConn, interfaceObject, ipConfigPropertyKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to access %s:%s: %s", interfaceObject, ipConfigPropertyKey, err) return nil, fmt.Errorf("failed to access %s:%s: %w", interfaceObject, ipConfigPropertyKey, err)
} }
ipConfig, ok := ipConfigVariant.Value().(dbus.ObjectPath) ipConfig, ok := ipConfigVariant.Value().(dbus.ObjectPath)
if !ok { if !ok {
@ -102,7 +102,7 @@ func dbusGetInterfaceNameservers(dbusConn *dbus.Conn, interfaceObject dbus.Objec
// Get Nameserver IPs // Get Nameserver IPs
nameserverIPsVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversIPsPropertyKey) nameserverIPsVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversIPsPropertyKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to access %s:%s: %s", ipConfig, nameserversIPsPropertyKey, err) return nil, fmt.Errorf("failed to access %s:%s: %w", ipConfig, nameserversIPsPropertyKey, err)
} }
var nameserverIPs []net.IP var nameserverIPs []net.IP
switch ipVersion { switch ipVersion {
@ -134,7 +134,7 @@ func dbusGetInterfaceNameservers(dbusConn *dbus.Conn, interfaceObject dbus.Objec
// Get Nameserver Domains // Get Nameserver Domains
nameserverDomainsVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversDomainsPropertyKey) nameserverDomainsVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversDomainsPropertyKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to access %s:%s: %s", ipConfig, nameserversDomainsPropertyKey, err) return nil, fmt.Errorf("failed to access %s:%s: %w", ipConfig, nameserversDomainsPropertyKey, err)
} }
nameserverDomains, ok := nameserverDomainsVariant.Value().([]string) nameserverDomains, ok := nameserverDomainsVariant.Value().([]string)
if !ok { if !ok {
@ -144,7 +144,7 @@ func dbusGetInterfaceNameservers(dbusConn *dbus.Conn, interfaceObject dbus.Objec
// Get Nameserver Searches // Get Nameserver Searches
nameserverSearchesVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversSearchesPropertyKey) nameserverSearchesVariant, err := getNetworkManagerProperty(dbusConn, ipConfig, nameserversSearchesPropertyKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to access %s:%s: %s", ipConfig, nameserversSearchesPropertyKey, err) return nil, fmt.Errorf("failed to access %s:%s: %w", ipConfig, nameserversSearchesPropertyKey, err)
} }
nameserverSearches, ok := nameserverSearchesVariant.Value().([]string) nameserverSearches, ok := nameserverSearchesVariant.Value().([]string)
if !ok { if !ok {
@ -152,7 +152,7 @@ func dbusGetInterfaceNameservers(dbusConn *dbus.Conn, interfaceObject dbus.Objec
} }
ns := make([]Nameserver, 0, len(nameserverIPs)) ns := make([]Nameserver, 0, len(nameserverIPs))
searchDomains := append(nameserverDomains, nameserverSearches...) searchDomains := append(nameserverDomains, nameserverSearches...) //nolint:gocritic
for _, nameserverIP := range nameserverIPs { for _, nameserverIP := range nameserverIPs {
ns = append(ns, Nameserver{ ns = append(ns, Nameserver{
IP: nameserverIP, IP: nameserverIP,

View file

@ -6,6 +6,8 @@ import (
) )
func TestDbus(t *testing.T) { func TestDbus(t *testing.T) {
t.Parallel()
if testing.Short() { if testing.Short() {
t.Skip("skipping test in short mode because it fails in the CI") t.Skip("skipping test in short mode because it fails in the CI")
} }

View file

@ -2,9 +2,7 @@ package netenv
import "net" import "net"
var ( var localAddrFactory func(network string) net.Addr
localAddrFactory func(network string) net.Addr
)
// SetLocalAddrFactory supplies the environment package with a function to get permitted local addresses for connections. // SetLocalAddrFactory supplies the environment package with a function to get permitted local addresses for connections.
func SetLocalAddrFactory(laf func(network string) net.Addr) { func SetLocalAddrFactory(laf func(network string) net.Addr) {

View file

@ -43,7 +43,9 @@ func Gateways() []net.IP {
log.Warningf("environment: could not read /proc/net/route: %s", err) log.Warningf("environment: could not read /proc/net/route: %s", err)
return gateways return gateways
} }
defer route.Close() defer func() {
_ = route.Close()
}()
// file scanner // file scanner
scanner := bufio.NewScanner(route) scanner := bufio.NewScanner(route)
@ -76,7 +78,9 @@ func Gateways() []net.IP {
log.Warningf("environment: could not read /proc/net/ipv6_route: %s", err) log.Warningf("environment: could not read /proc/net/ipv6_route: %s", err)
return gateways return gateways
} }
defer v6route.Close() defer func() {
_ = v6route.Close()
}()
// file scanner // file scanner
scanner = bufio.NewScanner(v6route) scanner = bufio.NewScanner(v6route)
@ -149,7 +153,9 @@ func getNameserversFromResolvconf() ([]Nameserver, error) {
log.Warningf("environment: could not read /etc/resolv.conf: %s", err) log.Warningf("environment: could not read /etc/resolv.conf: %s", err)
return nil, err return nil, err
} }
defer resolvconf.Close() defer func() {
_ = resolvconf.Close()
}()
// file scanner // file scanner
scanner := bufio.NewScanner(resolvconf) scanner := bufio.NewScanner(resolvconf)
@ -186,7 +192,6 @@ func getNameserversFromResolvconf() ([]Nameserver, error) {
}) })
} }
return nameservers, nil return nameservers, nil
} }
func addNameservers(nameservers, newNameservers []Nameserver) []Nameserver { func addNameservers(nameservers, newNameservers []Nameserver) []Nameserver {

View file

@ -3,6 +3,8 @@ package netenv
import "testing" import "testing"
func TestLinuxEnvironment(t *testing.T) { func TestLinuxEnvironment(t *testing.T) {
t.Parallel()
nameserversTest, err := getNameserversFromResolvconf() nameserversTest, err := getNameserversFromResolvconf()
if err != nil { if err != nil {
t.Errorf("failed to get namerservers from resolvconf: %s", err) t.Errorf("failed to get namerservers from resolvconf: %s", err)

View file

@ -3,6 +3,8 @@ package netenv
import "testing" import "testing"
func TestEnvironment(t *testing.T) { func TestEnvironment(t *testing.T) {
t.Parallel()
nameserversTest := Nameservers() nameserversTest := Nameservers()
t.Logf("nameservers: %+v", nameserversTest) t.Logf("nameservers: %+v", nameserversTest)

View file

@ -9,10 +9,10 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/google/gopacket/layers"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"github.com/google/gopacket/layers"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portbase/rng" "github.com/safing/portbase/rng"
"github.com/safing/portmaster/intel/geoip" "github.com/safing/portmaster/intel/geoip"
@ -41,10 +41,12 @@ func prepLocation() (err error) {
return err return err
} }
// DeviceLocations holds multiple device locations.
type DeviceLocations struct { type DeviceLocations struct {
All []*DeviceLocation All []*DeviceLocation
} }
// Best returns the best (most accurate) device location.
func (dl *DeviceLocations) Best() *DeviceLocation { func (dl *DeviceLocations) Best() *DeviceLocation {
if len(dl.All) > 0 { if len(dl.All) > 0 {
return dl.All[0] return dl.All[0]
@ -52,6 +54,7 @@ func (dl *DeviceLocations) Best() *DeviceLocation {
return nil return nil
} }
// BestV4 returns the best (most accurate) IPv4 device location.
func (dl *DeviceLocations) BestV4() *DeviceLocation { func (dl *DeviceLocations) BestV4() *DeviceLocation {
for _, loc := range dl.All { for _, loc := range dl.All {
if loc.IPVersion == packet.IPv4 { if loc.IPVersion == packet.IPv4 {
@ -61,6 +64,7 @@ func (dl *DeviceLocations) BestV4() *DeviceLocation {
return nil return nil
} }
// BestV6 returns the best (most accurate) IPv6 device location.
func (dl *DeviceLocations) BestV6() *DeviceLocation { func (dl *DeviceLocations) BestV6() *DeviceLocation {
for _, loc := range dl.All { for _, loc := range dl.All {
if loc.IPVersion == packet.IPv6 { if loc.IPVersion == packet.IPv6 {
@ -129,6 +133,7 @@ func (dl *DeviceLocation) IsMoreAccurateThan(other *DeviceLocation) bool {
return false return false
} }
// LocationOrNil or returns the geoip location, or nil if not present.
func (dl *DeviceLocation) LocationOrNil() *geoip.Location { func (dl *DeviceLocation) LocationOrNil() *geoip.Location {
if dl == nil { if dl == nil {
return nil return nil
@ -147,8 +152,10 @@ func (dl *DeviceLocation) String() string {
} }
} }
// DeviceLocationSource is a location source.
type DeviceLocationSource string type DeviceLocationSource string
// Location Sources.
const ( const (
SourceInterface DeviceLocationSource = "interface" SourceInterface DeviceLocationSource = "interface"
SourcePeer DeviceLocationSource = "peer" SourcePeer DeviceLocationSource = "peer"
@ -158,6 +165,7 @@ const (
SourceOther DeviceLocationSource = "other" SourceOther DeviceLocationSource = "other"
) )
// Accuracy returns the location accuracy of the source.
func (dls DeviceLocationSource) Accuracy() int { func (dls DeviceLocationSource) Accuracy() int {
switch dls { switch dls {
case SourceInterface: case SourceInterface:
@ -183,6 +191,7 @@ func (a sortLocationsByAccuracy) Len() int { return len(a) }
func (a sortLocationsByAccuracy) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a sortLocationsByAccuracy) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a sortLocationsByAccuracy) Less(i, j int) bool { return !a[j].IsMoreAccurateThan(a[i]) } func (a sortLocationsByAccuracy) Less(i, j int) bool { return !a[j].IsMoreAccurateThan(a[i]) }
// SetInternetLocation provides the location management system with a possible Internet location.
func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLocation, ok bool) { func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLocation, ok bool) {
// Check if IP is global. // Check if IP is global.
if netutils.GetIPScope(ip) != netutils.Global { if netutils.GetIPScope(ip) != netutils.Global {
@ -206,9 +215,8 @@ func SetInternetLocation(ip net.IP, source DeviceLocationSource) (dl *DeviceLoca
if err != nil { if err != nil {
log.Warningf("netenv: failed to get geolocation data of %s (from %s): %s", ip, source, err) log.Warningf("netenv: failed to get geolocation data of %s (from %s): %s", ip, source, err)
return nil, false return nil, false
} else {
loc.Location = geoLoc
} }
loc.Location = geoLoc
addLocation(loc) addLocation(loc)
return loc, true return loc, true
@ -242,7 +250,8 @@ func addLocation(dl *DeviceLocation) {
sort.Sort(sortLocationsByAccuracy(locations.All)) sort.Sort(sortLocationsByAccuracy(locations.All))
} }
// DEPRECATED: Please use GetInternetLocation instead. // GetApproximateInternetLocation returns the approximate Internet location.
// Deprecated: Please use GetInternetLocation instead.
func GetApproximateInternetLocation() (net.IP, error) { func GetApproximateInternetLocation() (net.IP, error) {
loc, ok := GetInternetLocation() loc, ok := GetInternetLocation()
if !ok || loc.Best() == nil { if !ok || loc.Best() == nil {
@ -251,6 +260,7 @@ func GetApproximateInternetLocation() (net.IP, error) {
return loc.Best().IP, nil return loc.Best().IP, nil
} }
// GetInternetLocation returns the possible device locations.
func GetInternetLocation() (deviceLocations *DeviceLocations, ok bool) { func GetInternetLocation() (deviceLocations *DeviceLocations, ok bool) {
gettingLocationsLock.Lock() gettingLocationsLock.Lock()
defer gettingLocationsLock.Unlock() defer gettingLocationsLock.Unlock()
@ -323,7 +333,7 @@ func getLocationFromInterfaces() (v4ok, v6ok bool) {
func getLocationFromUPnP() (ok bool) { func getLocationFromUPnP() (ok bool) {
// Endoint: urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress // Endoint: urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress
// A first test showed that a router did offer that endpoint, but did not // A first test showed that a router did offer that endpoint, but did not
// return an IP addres. // return an IP address.
return false return false
} }
*/ */
@ -332,14 +342,14 @@ func getLocationFromTraceroute() (dl *DeviceLocation, err error) {
// Create connection. // Create connection.
conn, err := net.ListenPacket("ip4:icmp", "") conn, err := net.ListenPacket("ip4:icmp", "")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open icmp conn: %s", err) return nil, fmt.Errorf("failed to open icmp conn: %w", err)
} }
v4Conn := ipv4.NewPacketConn(conn) v4Conn := ipv4.NewPacketConn(conn)
// Generate a random ID for the ICMP packets. // Generate a random ID for the ICMP packets.
generatedID, err := rng.Number(0xFFFF) // uint16 generatedID, err := rng.Number(0xFFFF) // uint16
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate icmp msg ID: %s", err) return nil, fmt.Errorf("failed to generate icmp msg ID: %w", err)
} }
msgID := int(generatedID) msgID := int(generatedID)
var msgSeq int var msgSeq int
@ -368,28 +378,27 @@ nextHop:
for j := 1; j <= 2; j++ { // Try every hop twice. for j := 1; j <= 2; j++ { // Try every hop twice.
// Increase sequence number. // Increase sequence number.
msgSeq++ msgSeq++
pingMessage.Body.(*icmp.Echo).Seq = msgSeq pingMessage.Body.(*icmp.Echo).Seq = msgSeq //nolint:forcetypeassert // Can only be *icmp.Echo.
// Make packet data. // Make packet data.
pingPacket, err := pingMessage.Marshal(nil) pingPacket, err := pingMessage.Marshal(nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to build icmp packet: %s", err) return nil, fmt.Errorf("failed to build icmp packet: %w", err)
} }
// Set TTL on IP packet. // Set TTL on IP packet.
err = v4Conn.SetTTL(i) err = v4Conn.SetTTL(i)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set icmp packet TTL: %s", err) return nil, fmt.Errorf("failed to set icmp packet TTL: %w", err)
} }
// Send ICMP packet. // Send ICMP packet.
if _, err := conn.WriteTo(pingPacket, locationTestingIPv4Addr); err != nil { if _, err := conn.WriteTo(pingPacket, locationTestingIPv4Addr); err != nil {
if neterr, ok := err.(*net.OpError); ok { var opErr *net.OpError
if neterr.Err == syscall.ENOBUFS { if errors.As(err, &opErr) && errors.Is(opErr.Err, syscall.ENOBUFS) {
continue continue
}
} }
return nil, fmt.Errorf("failed to send icmp packet: %s", err) return nil, fmt.Errorf("failed to send icmp packet: %w", err)
} }
// Listen for replies of the ICMP packet. // Listen for replies of the ICMP packet.
@ -433,7 +442,7 @@ nextHop:
continue listen continue listen
} }
// Check if the ID and sequence match. // Check if the ID and sequence match.
if originalEcho.ID != int(msgID) { if originalEcho.ID != msgID {
continue listen continue listen
} }
if originalEcho.Seq < minSeq { if originalEcho.Seq < minSeq {
@ -469,8 +478,8 @@ nextHop:
} }
func recvICMP(currentHop int, icmpPacketsViaFirewall chan packet.Packet) ( func recvICMP(currentHop int, icmpPacketsViaFirewall chan packet.Packet) (
remoteIP net.IP, imcpPacket *layers.ICMPv4, ok bool) { remoteIP net.IP, imcpPacket *layers.ICMPv4, ok bool,
) {
for { for {
select { select {
case pkt := <-icmpPacketsViaFirewall: case pkt := <-icmpPacketsViaFirewall:
@ -496,7 +505,7 @@ func recvICMP(currentHop int, icmpPacketsViaFirewall chan packet.Packet) (
} }
} }
func getLocationFromTimezone(ipVersion packet.IPVersion) (ok bool) { func getLocationFromTimezone(ipVersion packet.IPVersion) (ok bool) { //nolint:unparam // This is documentation.
// Create base struct. // Create base struct.
tzLoc := &DeviceLocation{ tzLoc := &DeviceLocation{
IPVersion: ipVersion, IPVersion: ipVersion,

View file

@ -1,9 +1,9 @@
//+build !windows // go:build !windows
package netenv package netenv
import "net" import "net"
func newICMPListener(_ string) (net.PacketConn, error) { func newICMPListener(_ string) (net.PacketConn, error) { //nolint:unused,deadcode // TODO: clean with Windows code later.
return net.ListenPacket("ip4:icmp", "0.0.0.0") return net.ListenPacket("ip4:icmp", "0.0.0.0")
} }

View file

@ -5,15 +5,15 @@ import (
"testing" "testing"
) )
var ( var privileged bool
privileged bool
)
func init() { func init() {
flag.BoolVar(&privileged, "privileged", false, "run tests that require root/admin privileges") flag.BoolVar(&privileged, "privileged", false, "run tests that require root/admin privileges")
} }
func TestGetInternetLocation(t *testing.T) { func TestGetInternetLocation(t *testing.T) {
t.Parallel()
if testing.Short() { if testing.Short() {
t.Skip() t.Skip()
} }

View file

@ -4,16 +4,14 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
// Event Names // Event Names.
const ( const (
ModuleName = "netenv" ModuleName = "netenv"
NetworkChangedEvent = "network changed" NetworkChangedEvent = "network changed"
OnlineStatusChangedEvent = "online status changed" OnlineStatusChangedEvent = "online status changed"
) )
var ( var module *modules.Module
module *modules.Module
)
func init() { func init() {
module = modules.Register(ModuleName, prep, start, nil) module = modules.Register(ModuleName, prep, start, nil)

View file

@ -3,7 +3,7 @@ package netenv
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha1" //nolint:gosec // not used for security "crypto/sha1"
"io" "io"
"net" "net"
"time" "time"
@ -17,6 +17,7 @@ var (
networkChangedBroadcastFlag = utils.NewBroadcastFlag() networkChangedBroadcastFlag = utils.NewBroadcastFlag()
) )
// GetNetworkChangedFlag returns a flag to be notified about a network change.
func GetNetworkChangedFlag() *utils.Flag { func GetNetworkChangedFlag() *utils.Flag {
return networkChangedBroadcastFlag.NewFlag() return networkChangedBroadcastFlag.NewFlag()
} }

View file

@ -2,6 +2,7 @@ package netenv
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -21,7 +22,7 @@ import (
// OnlineStatus represent a state of connectivity to the Internet. // OnlineStatus represent a state of connectivity to the Internet.
type OnlineStatus uint8 type OnlineStatus uint8
// Online Status Values // Online Status Values.
const ( const (
StatusUnknown OnlineStatus = 0 StatusUnknown OnlineStatus = 0
StatusOffline OnlineStatus = 1 StatusOffline OnlineStatus = 1
@ -31,7 +32,7 @@ const (
StatusOnline OnlineStatus = 5 StatusOnline OnlineStatus = 5
) )
// Online Status and Resolver // Online Status and Resolver.
var ( var (
PortalTestIP = net.IPv4(192, 0, 2, 1) PortalTestIP = net.IPv4(192, 0, 2, 1)
PortalTestURL = fmt.Sprintf("http://%s/", PortalTestIP) PortalTestURL = fmt.Sprintf("http://%s/", PortalTestIP)
@ -124,8 +125,6 @@ func IsConnectivityDomain(domain string) bool {
func (os OnlineStatus) String() string { func (os OnlineStatus) String() string {
switch os { switch os {
default:
return "Unknown"
case StatusOffline: case StatusOffline:
return "Offline" return "Offline"
case StatusLimited: case StatusLimited:
@ -136,6 +135,10 @@ func (os OnlineStatus) String() string {
return "SemiOnline" return "SemiOnline"
case StatusOnline: case StatusOnline:
return "Online" return "Online"
case StatusUnknown:
fallthrough
default:
return "Unknown"
} }
} }
@ -175,7 +178,7 @@ func GetOnlineStatus() OnlineStatus {
return OnlineStatus(atomic.LoadInt32(onlineStatus)) return OnlineStatus(atomic.LoadInt32(onlineStatus))
} }
// CheckAndGetOnlineStatus triggers a new online status check and returns the result // CheckAndGetOnlineStatus triggers a new online status check and returns the result.
func CheckAndGetOnlineStatus() OnlineStatus { func CheckAndGetOnlineStatus() OnlineStatus {
// trigger new investigation // trigger new investigation
triggerOnlineStatusInvestigation() triggerOnlineStatusInvestigation()
@ -225,7 +228,7 @@ func notifyOnlineStatus(status OnlineStatus) {
var eventID, title, message string var eventID, title, message string
// Check if status is worth notifying. // Check if status is worth notifying.
switch status { switch status { //nolint:exhaustive // Checking for selection only.
case StatusOffline: case StatusOffline:
eventID = "netenv:online-status:offline" eventID = "netenv:online-status:offline"
title = "Device is Offline" title = "Device is Offline"
@ -419,7 +422,7 @@ func checkOnlineStatus(ctx context.Context) {
} else { } else {
var lan bool var lan bool
for _, ip := range ipv4 { for _, ip := range ipv4 {
switch netutils.GetIPScope(ip) { switch netutils.GetIPScope(ip) { //nolint:exhaustive // Checking to specific values only.
case netutils.SiteLocal: case netutils.SiteLocal:
lan = true lan = true
case netutils.Global: case netutils.Global:
@ -429,7 +432,7 @@ func checkOnlineStatus(ctx context.Context) {
} }
} }
for _, ip := range ipv6 { for _, ip := range ipv6 {
switch netutils.GetIPScope(ip) { switch netutils.GetIPScope(ip) { //nolint:exhaustive // Checking to specific values only.
case netutils.SiteLocal, netutils.Global: case netutils.SiteLocal, netutils.Global:
// IPv6 global addresses are also used in local networks // IPv6 global addresses are also used in local networks
lan = true lan = true
@ -470,14 +473,16 @@ func checkOnlineStatus(ctx context.Context) {
response, err := client.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
nErr, ok := err.(net.Error) var netErr net.Error
if !ok || !nErr.Timeout() { if !errors.As(err, &netErr) || !netErr.Timeout() {
// Timeout is the expected error when there is no portal // Timeout is the expected error when there is no portal
log.Debugf("netenv: http portal test failed: %s", err) log.Debugf("netenv: http portal test failed: %s", err)
// TODO: discern between errors to detect StatusLimited // TODO: discern between errors to detect StatusLimited
} }
} else { } else {
defer response.Body.Close() defer func() {
_ = response.Body.Close()
}()
// Got a response, something is messing with the request // Got a response, something is messing with the request
// check location // check location

View file

@ -6,6 +6,8 @@ import (
) )
func TestCheckOnlineStatus(t *testing.T) { func TestCheckOnlineStatus(t *testing.T) {
t.Parallel()
checkOnlineStatus(context.Background()) checkOnlineStatus(context.Background())
t.Logf("online status: %s", GetOnlineStatus()) t.Logf("online status: %s", GetOnlineStatus())
t.Logf("captive portal: %+v", GetCaptivePortal()) t.Logf("captive portal: %+v", GetCaptivePortal())

View file

@ -89,6 +89,7 @@ func debugInfo(ar *api.Request) (data []byte, err error) {
return di.Bytes(), nil return di.Bytes(), nil
} }
// AddNetworkDebugData adds the network debug data of the given profile to the debug data.
func AddNetworkDebugData(di *debug.Info, profile, where string) { func AddNetworkDebugData(di *debug.Info, profile, where string) {
// Prepend where prefix to query if necessary. // Prepend where prefix to query if necessary.
if where != "" && !strings.HasPrefix(where, "where ") { if where != "" && !strings.HasPrefix(where, "where ") {
@ -99,7 +100,7 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) {
q, err := query.ParseQuery("query network: " + where) q, err := query.ParseQuery("query network: " + where)
if err != nil { if err != nil {
di.AddSection( di.AddSection(
fmt.Sprintf("Network: Debug Failed"), "Network: Debug Failed",
debug.NoFlags, debug.NoFlags,
fmt.Sprintf("Failed to build query: %s", err), fmt.Sprintf("Failed to build query: %s", err),
) )
@ -110,7 +111,7 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) {
it, err := dbController.Query(q, true, true) it, err := dbController.Query(q, true, true)
if err != nil { if err != nil {
di.AddSection( di.AddSection(
fmt.Sprintf("Network: Debug Failed"), "Network: Debug Failed",
debug.NoFlags, debug.NoFlags,
fmt.Sprintf("Failed to run query: %s", err), fmt.Sprintf("Failed to run query: %s", err),
) )
@ -118,9 +119,11 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) {
} }
// Collect matching connections. // Collect matching connections.
var debugConns []*Connection var ( //nolint:prealloc // We don't know the size.
var accepted int debugConns []*Connection
var total int accepted int
total int
)
for maybeConn := range it.Next { for maybeConn := range it.Next {
// Switch to correct type. // Switch to correct type.
conn, ok := maybeConn.(*Connection) conn, ok := maybeConn.(*Connection)
@ -149,7 +152,7 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) {
// Count. // Count.
total++ total++
switch conn.Verdict { switch conn.Verdict { //nolint:exhaustive
case VerdictAccept, case VerdictAccept,
VerdictRerouteToNameserver, VerdictRerouteToNameserver,
VerdictRerouteToTunnel: VerdictRerouteToTunnel:

View file

@ -9,12 +9,16 @@ import (
) )
func TestDebugInfoLineFormatting(t *testing.T) { func TestDebugInfoLineFormatting(t *testing.T) {
t.Parallel()
for _, conn := range connectionTestData { for _, conn := range connectionTestData {
fmt.Println(conn.debugInfoLine()) fmt.Println(conn.debugInfoLine())
} }
} }
func TestDebugInfoFormatting(t *testing.T) { func TestDebugInfoFormatting(t *testing.T) {
t.Parallel()
fmt.Println(buildNetworkDebugInfoData(connectionTestData)) fmt.Println(buildNetworkDebugInfoData(connectionTestData))
} }

View file

@ -4,11 +4,9 @@ import (
"context" "context"
"time" "time"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/state"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/packet"
"github.com/safing/portmaster/network/state"
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
) )
@ -41,7 +39,6 @@ func cleanConnections() (activePIDs map[int]struct{}) {
name := "clean connections" // TODO: change to new fn name := "clean connections" // TODO: change to new fn
_ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error { _ = module.RunMediumPriorityMicroTask(&name, func(ctx context.Context) error {
now := time.Now().UTC() now := time.Now().UTC()
nowUnix := now.Unix() nowUnix := now.Unix()
deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix() deleteOlderThan := now.Add(-deleteConnsAfterEndedThreshold).Unix()

View file

@ -44,13 +44,14 @@ type ProcessContext struct {
Source string Source string
} }
// ConnectionType is a type of connection.
type ConnectionType int8 type ConnectionType int8
// Connection Types.
const ( const (
Undefined ConnectionType = iota Undefined ConnectionType = iota
IPConnection IPConnection
DNSRequest DNSRequest
// ProxyRequest
) )
// Connection describes a distinct physical network connection // Connection describes a distinct physical network connection
@ -280,6 +281,7 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
return dnsConn return dnsConn
} }
// NewConnectionFromExternalDNSRequest returns a connection for an external DNS request.
func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, connID string, remoteIP net.IP) (*Connection, error) { func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cnames []string, connID string, remoteIP net.IP) (*Connection, error) {
remoteHost, err := process.GetNetworkHost(ctx, remoteIP) remoteHost, err := process.GetNetworkHost(ctx, remoteIP)
if err != nil { if err != nil {
@ -336,7 +338,6 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
var dnsContext *resolver.DNSRequestContext var dnsContext *resolver.DNSRequestContext
if inbound { if inbound {
switch entity.IPScope { switch entity.IPScope {
case netutils.HostLocal: case netutils.HostLocal:
scope = IncomingHost scope = IncomingHost
@ -345,12 +346,11 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
case netutils.Global, netutils.GlobalMulticast: case netutils.Global, netutils.GlobalMulticast:
scope = IncomingInternet scope = IncomingInternet
case netutils.Invalid: case netutils.Undefined, netutils.Invalid:
fallthrough fallthrough
default: default:
scope = IncomingInvalid scope = IncomingInvalid
} }
} else { } else {
// check if we can find a domain for that IP // check if we can find a domain for that IP
@ -379,7 +379,6 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
} }
if scope == "" { if scope == "" {
// outbound direct (possibly P2P) connection // outbound direct (possibly P2P) connection
switch entity.IPScope { switch entity.IPScope {
case netutils.HostLocal: case netutils.HostLocal:
@ -389,12 +388,11 @@ func NewConnectionFromFirstPacket(pkt packet.Packet) *Connection {
case netutils.Global, netutils.GlobalMulticast: case netutils.Global, netutils.GlobalMulticast:
scope = PeerInternet scope = PeerInternet
case netutils.Invalid: case netutils.Undefined, netutils.Invalid:
fallthrough fallthrough
default: default:
scope = PeerInvalid scope = PeerInvalid
} }
} }
} }
@ -547,10 +545,10 @@ func (conn *Connection) Save() {
if !conn.KeyIsSet() { if !conn.KeyIsSet() {
if conn.Type == DNSRequest { if conn.Type == DNSRequest {
conn.SetKey(makeKey(conn.process.Pid, "dns", conn.ID)) conn.SetKey(makeKey(conn.process.Pid, dbScopeDNS, conn.ID))
dnsConns.add(conn) dnsConns.add(conn)
} else { } else {
conn.SetKey(makeKey(conn.process.Pid, "ip", conn.ID)) conn.SetKey(makeKey(conn.process.Pid, dbScopeIP, conn.ID))
conns.add(conn) conns.add(conn)
} }
} }
@ -597,7 +595,7 @@ func (conn *Connection) StopFirewallHandler() {
conn.pktQueue <- nil conn.pktQueue <- nil
} }
// HandlePacket queues packet of Link for handling // HandlePacket queues packet of Link for handling.
func (conn *Connection) HandlePacket(pkt packet.Packet) { func (conn *Connection) HandlePacket(pkt packet.Packet) {
conn.Lock() conn.Lock()
defer conn.Unlock() defer conn.Unlock()
@ -611,7 +609,7 @@ func (conn *Connection) HandlePacket(pkt packet.Packet) {
} }
} }
// packetHandler sequentially handles queued packets // packetHandler sequentially handles queued packets.
func (conn *Connection) packetHandler() { func (conn *Connection) packetHandler() {
for pkt := range conn.pktQueue { for pkt := range conn.pktQueue {
if pkt == nil { if pkt == nil {
@ -649,8 +647,8 @@ func (conn *Connection) GetActiveInspectors() []bool {
} }
// SetActiveInspectors sets the list of active inspectors. // SetActiveInspectors sets the list of active inspectors.
func (conn *Connection) SetActiveInspectors(new []bool) { func (conn *Connection) SetActiveInspectors(newInspectors []bool) {
conn.activeInspectors = new conn.activeInspectors = newInspectors
} }
// GetInspectorData returns the list of inspector data. // GetInspectorData returns the list of inspector data.
@ -659,8 +657,8 @@ func (conn *Connection) GetInspectorData() map[uint8]interface{} {
} }
// SetInspectorData set the list of inspector data. // SetInspectorData set the list of inspector data.
func (conn *Connection) SetInspectorData(new map[uint8]interface{}) { func (conn *Connection) SetInspectorData(newInspectorData map[uint8]interface{}) {
conn.inspectorData = new conn.inspectorData = newInspectorData
} }
// String returns a string representation of conn. // String returns a string representation of conn.

View file

@ -48,7 +48,7 @@ func (cs *connectionStore) clone() map[string]*Connection {
return m return m
} }
func (cs *connectionStore) len() int { func (cs *connectionStore) len() int { //nolint:unused // TODO: Clean up if still unused.
cs.rw.RLock() cs.rw.RLock()
defer cs.rw.RUnlock() defer cs.rw.RUnlock()

View file

@ -14,6 +14,12 @@ import (
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
) )
const (
dbScopeNone = ""
dbScopeDNS = "dns"
dbScopeIP = "ip"
)
var ( var (
dbController *database.Controller dbController *database.Controller
@ -43,7 +49,7 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) {
// Split into segments. // Split into segments.
segments := strings.Split(key, "/") segments := strings.Split(key, "/")
// Check for valid prefix. // Check for valid prefix.
if !strings.HasPrefix("tree", segments[0]) { if segments[0] != "tree" {
return 0, "", "", false return 0, "", "", false
} }
@ -57,7 +63,7 @@ func parseDBKey(key string) (pid int, scope, id string, ok bool) {
scope = segments[2] scope = segments[2]
// Sanity check. // Sanity check.
switch scope { switch scope {
case "dns", "ip", "": case dbScopeNone, dbScopeDNS, dbScopeIP:
// Parsed id matches possible values. // Parsed id matches possible values.
// The empty string is for matching a trailing slash for in query prefix. // The empty string is for matching a trailing slash for in query prefix.
// TODO: For queries, also prefixes of these values are valid. // TODO: For queries, also prefixes of these values are valid.
@ -96,15 +102,15 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
} }
switch scope { switch scope {
case "dns": case dbScopeDNS:
if r, ok := dnsConns.get(id); ok { if r, ok := dnsConns.get(id); ok {
return r, nil return r, nil
} }
case "ip": case dbScopeIP:
if r, ok := conns.get(id); ok { if r, ok := conns.get(id); ok {
return r, nil return r, nil
} }
case "": case dbScopeNone:
if proc, ok := process.GetProcessFromStorage(pid); ok { if proc, ok := process.GetProcessFromStorage(pid); ok {
return proc, nil return proc, nil
} }
@ -147,7 +153,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
} }
} }
if scope == "" || scope == "dns" { if scope == dbScopeNone || scope == dbScopeDNS {
// dns scopes only // dns scopes only
for _, dnsConn := range dnsConns.clone() { for _, dnsConn := range dnsConns.clone() {
func() { func() {
@ -161,7 +167,7 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
} }
} }
if scope == "" || scope == "ip" { if scope == dbScopeNone || scope == dbScopeIP {
// connections // connections
for _, conn := range conns.clone() { for _, conn := range conns.clone() {
func() { func() {

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/nameserver/nsutil" "github.com/safing/portmaster/nameserver/nsutil"
"github.com/safing/portmaster/process" "github.com/safing/portmaster/process"
@ -17,16 +18,15 @@ import (
var ( var (
openDNSRequests = make(map[string]*Connection) // key: <pid>/fqdn openDNSRequests = make(map[string]*Connection) // key: <pid>/fqdn
openDNSRequestsLock sync.Mutex openDNSRequestsLock sync.Mutex
// scope prefix
unidentifiedProcessScopePrefix = strconv.Itoa(process.UnidentifiedProcessID) + "/"
) )
const ( const (
// write open dns requests every // writeOpenDNSRequestsTickDuration defines the interval in which open dns
// requests are written.
writeOpenDNSRequestsTickDuration = 5 * time.Second writeOpenDNSRequestsTickDuration = 5 * time.Second
// duration after which DNS requests without a following connection are logged // openDNSRequestLimit defines the duration after which DNS requests without
// a following connection are logged.
openDNSRequestLimit = 3 * time.Second openDNSRequestLimit = 3 * time.Second
) )
@ -122,6 +122,9 @@ func (conn *Connection) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
return nil // Do not respond to request. return nil // Do not respond to request.
case VerdictFailed: case VerdictFailed:
return nsutil.BlockIP().ReplyWithDNS(ctx, request) return nsutil.BlockIP().ReplyWithDNS(ctx, request)
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default: default:
reply := nsutil.ServerFailure().ReplyWithDNS(ctx, request) reply := nsutil.ServerFailure().ReplyWithDNS(ctx, request)
nsutil.AddMessagesToReply(ctx, reply, log.ErrorLevel, "INTERNAL ERROR: incorrect use of Connection DNS Responder") nsutil.AddMessagesToReply(ctx, reply, log.ErrorLevel, "INTERNAL ERROR: incorrect use of Connection DNS Responder")
@ -136,6 +139,10 @@ func (conn *Connection) GetExtraRRs(ctx context.Context, request *dns.Msg) []dns
switch conn.Verdict { switch conn.Verdict {
case VerdictFailed: case VerdictFailed:
level = log.ErrorLevel level = log.ErrorLevel
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictBlock, VerdictDrop,
VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default: default:
level = log.InfoLevel level = log.InfoLevel
} }

View file

@ -119,7 +119,7 @@ func (conn *Connection) addToMetrics() {
} }
// Check the verdict. // Check the verdict.
switch conn.Verdict { switch conn.Verdict { //nolint:exhaustive // Not critical.
case VerdictBlock, VerdictDrop: case VerdictBlock, VerdictDrop:
blockedOutConnCounter.Inc() blockedOutConnCounter.Inc()
conn.addedToMetrics = true conn.addedToMetrics = true

View file

@ -19,7 +19,7 @@ func IPFromAddr(addr net.Addr) (net.IP, error) {
// Parse via string. // Parse via string.
host, _, err := net.SplitHostPort(addr.String()) host, _, err := net.SplitHostPort(addr.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to split host and port of %q: %s", addr, err) return nil, fmt.Errorf("failed to split host and port of %q: %w", addr, err)
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip == nil { if ip == nil {

View file

@ -8,19 +8,17 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
var ( var cleanDomainRegex = regexp.MustCompile(
cleanDomainRegex = regexp.MustCompile( `^` + // match beginning
`^` + // match beginning `(` + // start subdomain group
`(` + // start subdomain group `(xn--)?` + // idn prefix
`(xn--)?` + // idn prefix `[a-z0-9_-]{1,63}` + // main chunk
`[a-z0-9_-]{1,63}` + // main chunk `\.` + // ending with a dot
`\.` + // ending with a dot `)*` + // end subdomain group, allow any number of subdomains
`)*` + // end subdomain group, allow any number of subdomains `(xn--)?` + // TLD idn prefix
`(xn--)?` + // TLD idn prefix `[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters
`[a-z0-9_-]{2,63}` + // TLD main chunk with at least two characters `\.` + // ending with a dot
`\.` + // ending with a dot `$`, // match end
`$`, // match end
)
) )
// IsValidFqdn returns whether the given string is a valid fqdn. // IsValidFqdn returns whether the given string is a valid fqdn.

View file

@ -3,12 +3,16 @@ package netutils
import "testing" import "testing"
func testDomainValidity(t *testing.T, domain string, isValid bool) { func testDomainValidity(t *testing.T, domain string, isValid bool) {
t.Helper()
if IsValidFqdn(domain) != isValid { if IsValidFqdn(domain) != isValid {
t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid) t.Errorf("domain %s failed check: was valid=%v, expected valid=%v", domain, IsValidFqdn(domain), isValid)
} }
} }
func TestDNSValidation(t *testing.T) { func TestDNSValidation(t *testing.T) {
t.Parallel()
// valid // valid
testDomainValidity(t, ".", true) testDomainValidity(t, ".", true)
testDomainValidity(t, "at.", true) testDomainValidity(t, "at.", true)

View file

@ -93,7 +93,7 @@ func (scope IPScope) IsLocalhost() bool {
// IsLAN returns true if the scope is site-local or link-local. // IsLAN returns true if the scope is site-local or link-local.
func (scope IPScope) IsLAN() bool { func (scope IPScope) IsLAN() bool {
switch scope { switch scope { //nolint:exhaustive // Looking for something specific.
case SiteLocal, LinkLocal, LocalMulticast: case SiteLocal, LinkLocal, LocalMulticast:
return true return true
default: default:
@ -103,7 +103,7 @@ func (scope IPScope) IsLAN() bool {
// IsGlobal returns true if the scope is global. // IsGlobal returns true if the scope is global.
func (scope IPScope) IsGlobal() bool { func (scope IPScope) IsGlobal() bool {
switch scope { switch scope { //nolint:exhaustive // Looking for something specific.
case Global, GlobalMulticast: case Global, GlobalMulticast:
return true return true
default: default:

View file

@ -6,6 +6,8 @@ import (
) )
func TestIPScope(t *testing.T) { func TestIPScope(t *testing.T) {
t.Parallel()
testScope(t, net.IPv4(71, 87, 113, 211), Global) testScope(t, net.IPv4(71, 87, 113, 211), Global)
testScope(t, net.IPv4(127, 0, 0, 1), HostLocal) testScope(t, net.IPv4(127, 0, 0, 1), HostLocal)
testScope(t, net.IPv4(127, 255, 255, 1), HostLocal) testScope(t, net.IPv4(127, 255, 255, 1), HostLocal)
@ -17,6 +19,8 @@ func TestIPScope(t *testing.T) {
} }
func testScope(t *testing.T, ip net.IP, expectedScope IPScope) { func testScope(t *testing.T, ip net.IP, expectedScope IPScope) {
t.Helper()
c := GetIPScope(ip) c := GetIPScope(ip)
if c != expectedScope { if c != expectedScope {
t.Errorf("%s is %s, expected %s", ip, scopeName(c), scopeName(expectedScope)) t.Errorf("%s is %s, expected %s", ip, scopeName(c), scopeName(expectedScope))

View file

@ -7,7 +7,7 @@ import (
"github.com/google/gopacket/tcpassembly" "github.com/google/gopacket/tcpassembly"
) )
// SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly // SimpleStreamAssemblerManager is a simple manager for github.com/google/gopacket/tcpassembly.
type SimpleStreamAssemblerManager struct { type SimpleStreamAssemblerManager struct {
InitLock sync.Mutex InitLock sync.Mutex
lastAssembler *SimpleStreamAssembler lastAssembler *SimpleStreamAssembler
@ -25,7 +25,7 @@ func (m *SimpleStreamAssemblerManager) GetLastAssembler() *SimpleStreamAssembler
return m.lastAssembler return m.lastAssembler
} }
// SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly // SimpleStreamAssembler is a simple assembler for github.com/google/gopacket/tcpassembly.
type SimpleStreamAssembler struct { type SimpleStreamAssembler struct {
Cumulated []byte Cumulated []byte
CumulatedLen int CumulatedLen int

View file

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
) )
// Basic Types // Basic Types.
type ( type (
// IPVersion represents an IP version. // IPVersion represents an IP version.
IPVersion uint8 IPVersion uint8
@ -15,7 +15,7 @@ type (
Verdict uint8 Verdict uint8
) )
// Basic Constants // Basic Constants.
const ( const (
IPv4 = IPVersion(4) IPv4 = IPVersion(4)
IPv6 = IPVersion(6) IPv6 = IPVersion(6)
@ -34,7 +34,7 @@ const (
AnyHostInternalProtocol61 = IPProtocol(61) AnyHostInternalProtocol61 = IPProtocol(61)
) )
// Verdicts // Verdicts.
const ( const (
DROP Verdict = iota DROP Verdict = iota
BLOCK BLOCK
@ -45,12 +45,10 @@ const (
STOP STOP
) )
var ( // ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system.
// ErrFailedToLoadPayload is returned by GetPayload if it failed for an unspecified reason, or is not implemented on the current system. var ErrFailedToLoadPayload = errors.New("could not load packet payload")
ErrFailedToLoadPayload = errors.New("could not load packet payload")
)
// ByteSize returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 // ByteSize returns the byte size of the ip (IPv4 = 4 bytes, IPv6 = 16).
func (v IPVersion) ByteSize() int { func (v IPVersion) ByteSize() int {
switch v { switch v {
case IPv4: case IPv4:
@ -89,8 +87,11 @@ func (p IPProtocol) String() string {
return "ICMPv6" return "ICMPv6"
case IGMP: case IGMP:
return "IGMP" return "IGMP"
case AnyHostInternalProtocol61:
fallthrough
default:
return fmt.Sprintf("<unknown protocol, %d>", uint8(p))
} }
return fmt.Sprintf("<unknown protocol, %d>", uint8(p))
} }
// String returns the string representation of the verdict. // String returns the string representation of the verdict.

View file

@ -71,8 +71,11 @@ func (pkt *Base) HasPorts() bool {
return true return true
case UDP, UDPLite: case UDP, UDPLite:
return true return true
case ICMP, ICMPv6, IGMP, RAW, AnyHostInternalProtocol61:
fallthrough
default:
return false
} }
return false
} }
// LoadPacketData loads packet data from the integration, if not yet done. // LoadPacketData loads packet data from the integration, if not yet done.
@ -125,7 +128,7 @@ func (pkt *Base) createConnectionID() {
// IN OUT // IN OUT
// Local Dst Src // Local Dst Src
// Remote Src Dst // Remote Src Dst
// //.
func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool { func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.IPNet, port uint16) bool {
if pkt.info.Protocol != protocol { if pkt.info.Protocol != protocol {
return false return false
@ -154,7 +157,7 @@ func (pkt *Base) MatchesAddress(remote bool, protocol IPProtocol, network *net.I
// IN OUT // IN OUT
// Local Dst Src // Local Dst Src
// Remote Src Dst // Remote Src Dst
// //.
func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool { func (pkt *Base) MatchesIP(endpoint bool, network *net.IPNet) bool {
if pkt.info.Inbound != endpoint { if pkt.info.Inbound != endpoint {
if network.Contains(pkt.info.Src) { if network.Contains(pkt.info.Src) {
@ -174,7 +177,7 @@ func (pkt *Base) String() string {
return pkt.FmtPacket() return pkt.FmtPacket()
} }
// FmtPacket returns the most important information about the packet as a string // FmtPacket returns the most important information about the packet as a string.
func (pkt *Base) FmtPacket() string { func (pkt *Base) FmtPacket() string {
if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP { if pkt.info.Protocol == TCP || pkt.info.Protocol == UDP {
if pkt.info.Inbound { if pkt.info.Inbound {
@ -188,12 +191,12 @@ func (pkt *Base) FmtPacket() string {
return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst) return fmt.Sprintf("OUT %s %s <-> %s", pkt.info.Protocol, pkt.info.Src, pkt.info.Dst)
} }
// FmtProtocol returns the protocol as a string // FmtProtocol returns the protocol as a string.
func (pkt *Base) FmtProtocol() string { func (pkt *Base) FmtProtocol() string {
return pkt.info.Protocol.String() return pkt.info.Protocol.String()
} }
// FmtRemoteIP returns the remote IP address as a string // FmtRemoteIP returns the remote IP address as a string.
func (pkt *Base) FmtRemoteIP() string { func (pkt *Base) FmtRemoteIP() string {
if pkt.info.Inbound { if pkt.info.Inbound {
return pkt.info.Src.String() return pkt.info.Src.String()
@ -201,7 +204,7 @@ func (pkt *Base) FmtRemoteIP() string {
return pkt.info.Dst.String() return pkt.info.Dst.String()
} }
// FmtRemotePort returns the remote port as a string // FmtRemotePort returns the remote port as a string.
func (pkt *Base) FmtRemotePort() string { func (pkt *Base) FmtRemotePort() string {
if pkt.info.SrcPort != 0 { if pkt.info.SrcPort != 0 {
if pkt.info.Inbound { if pkt.info.Inbound {
@ -212,14 +215,14 @@ func (pkt *Base) FmtRemotePort() string {
return "-" return "-"
} }
// FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string // FmtRemoteAddress returns the full remote address (protocol, IP, port) as a string.
func (pkt *Base) FmtRemoteAddress() string { func (pkt *Base) FmtRemoteAddress() string {
return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort()) return fmt.Sprintf("%s:%s:%s", pkt.info.Protocol.String(), pkt.FmtRemoteIP(), pkt.FmtRemotePort())
} }
// Packet is an interface to a network packet to provide object behaviour the same across all systems // Packet is an interface to a network packet to provide object behaviour the same across all systems.
type Packet interface { type Packet interface {
// VERDICTS // Verdicts.
Accept() error Accept() error
Block() error Block() error
Drop() error Drop() error
@ -230,7 +233,7 @@ type Packet interface {
RerouteToTunnel() error RerouteToTunnel() error
FastTrackedByIntegration() bool FastTrackedByIntegration() bool
// INFO // Info.
SetCtx(context.Context) SetCtx(context.Context)
Ctx() context.Context Ctx() context.Context
Info() *Info Info() *Info
@ -242,17 +245,17 @@ type Packet interface {
HasPorts() bool HasPorts() bool
GetConnectionID() string GetConnectionID() string
// PAYLOAD // Payload.
LoadPacketData() error LoadPacketData() error
Layers() gopacket.Packet Layers() gopacket.Packet
Raw() []byte Raw() []byte
Payload() []byte Payload() []byte
// MATCHING // Matching.
MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool MatchesAddress(bool, IPProtocol, *net.IPNet, uint16) bool
MatchesIP(bool, *net.IPNet) bool MatchesIP(bool, *net.IPNet) bool
// FORMATTING // Formatting.
String() string String() string
FmtPacket() string FmtPacket() string
FmtProtocol() string FmtProtocol() string

View file

@ -4,7 +4,7 @@ import (
"net" "net"
) )
// Info holds IP and TCP/UDP header information // Info holds IP and TCP/UDP header information.
type Info struct { type Info struct {
Inbound bool Inbound bool
InTunnel bool InTunnel bool

View file

@ -135,7 +135,7 @@ func Parse(packetData []byte, pktBase *Base) (err error) {
parseIPv6, parseIPv6,
parseTCP, parseTCP,
parseUDP, parseUDP,
//parseUDPLite, // we don't yet support udplite // parseUDPLite, // We don't yet support udplite.
parseICMPv4, parseICMPv4,
parseICMPv6, parseICMPv6,
parseIGMP, parseIGMP,

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
package proc package proc
@ -7,9 +7,8 @@ import (
"os" "os"
"time" "time"
"github.com/safing/portmaster/network/socket"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/socket"
) )
var ( var (
@ -128,7 +127,10 @@ func readDirNames(dir string) (names []string) {
} }
return return
} }
defer file.Close() defer func() {
_ = file.Close()
}()
names, err = file.Readdirnames(0) names, err = file.Readdirnames(0)
if err != nil { if err != nil {
log.Warningf("proc: could not get entries from directory %s: %s", dir, err) log.Warningf("proc: could not get entries from directory %s: %s", dir, err)

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
package proc package proc

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
package proc package proc
@ -12,9 +12,8 @@ import (
"strings" "strings"
"unicode" "unicode"
"github.com/safing/portmaster/network/socket"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/network/socket"
) )
/* /*
@ -85,7 +84,6 @@ const (
) )
func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) { func getTableFromSource(stack uint8, procFile string) (connections []*socket.ConnectionInfo, binds []*socket.BindInfo, err error) {
var ipConverter func(string) net.IP var ipConverter func(string) net.IP
switch stack { switch stack {
case TCP4, UDP4: case TCP4, UDP4:
@ -101,7 +99,9 @@ func getTableFromSource(stack uint8, procFile string) (connections []*socket.Con
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer socketData.Close() defer func() {
_ = socketData.Close()
}()
// file scanner // file scanner
scanner := bufio.NewScanner(socketData) scanner := bufio.NewScanner(socketData)

View file

@ -1,4 +1,4 @@
// +build linux // go:build linux
package proc package proc
@ -8,6 +8,8 @@ import (
) )
func TestSockets(t *testing.T) { func TestSockets(t *testing.T) {
t.Parallel()
connections, listeners, err := GetTCP4Table() connections, listeners, err := GetTCP4Table()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

Some files were not shown because too many files have changed in this diff Show more