diff --git a/.golangci.yml b/.golangci.yml index b4c851be..6c348ac6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -31,6 +31,7 @@ linters: - whitespace - wrapcheck - wsl + - nolintlint linters-settings: revive: diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index b7f1d1ec..f3c5e0fd 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -19,7 +19,7 @@ import ( //nolint:gci,nolintlint func main() { // set information - info.Set("Portmaster", "0.8.13", "AGPLv3", true) + info.Set("Portmaster", "0.9.0", "AGPLv3", true) // Configure metrics. _ = metrics.SetNamespace("portmaster") diff --git a/cmds/portmaster-start/main.go b/cmds/portmaster-start/main.go index 6afc831d..d9f53816 100644 --- a/cmds/portmaster-start/main.go +++ b/cmds/portmaster-start/main.go @@ -77,7 +77,7 @@ func main() { cobra.OnInitialize(initCobra) // set meta info - info.Set("Portmaster Start", "0.8.8", "AGPLv3", false) + info.Set("Portmaster Start", "0.9.0", "AGPLv3", false) // catch interrupt for clean shutdown signalCh := make(chan os.Signal, 2) diff --git a/cmds/portmaster-start/run.go b/cmds/portmaster-start/run.go index 8881b2d6..97dd38f7 100644 --- a/cmds/portmaster-start/run.go +++ b/cmds/portmaster-start/run.go @@ -121,6 +121,21 @@ func getExecArgs(opts *Options, cmdArgs []string) []string { if stdinSignals { args = append(args, "--input-signals") } + + if opts.Identifier == "app/portmaster-app.zip" { + // see https://www.freedesktop.org/software/systemd/man/pam_systemd.html#type= + if xdgSessionType := os.Getenv("XDG_SESSION_TYPE"); xdgSessionType == "wayland" { + // we're running the Portmaster UI App under Wayland so make sure we add some arguments + // required by Electron + args = append(args, + []string{ + "--enable-features=UseOzonePlatform", + "--ozone-platform=wayland", + }..., + ) + } + } + args = append(args, cmdArgs...) 
return args } diff --git a/cmds/portmaster-start/service_windows.go b/cmds/portmaster-start/service_windows.go index 077dffbf..f43b632a 100644 --- a/cmds/portmaster-start/service_windows.go +++ b/cmds/portmaster-start/service_windows.go @@ -14,7 +14,6 @@ import ( "github.com/spf13/cobra" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/debug" - "golang.org/x/sys/windows/svc/eventlog" ) var ( @@ -102,13 +101,6 @@ func runService(_ *cobra.Command, opts *Options, cmdArgs []string) error { svcRun = debug.Run } - // open eventlog - elog, err := eventlog.Open(serviceName) - if err != nil { - return fmt.Errorf("failed to open eventlog: %s", err) - } - defer elog.Close() - runWg.Add(2) finishWg.Add(1) @@ -134,7 +126,6 @@ func runService(_ *cobra.Command, opts *Options, cmdArgs []string) error { err = getShutdownError() if err != nil { log.Printf("%s service experienced an error: %s\n", serviceName, err) - _ = elog.Error(1, fmt.Sprintf("%s experienced an error: %s", serviceName, err)) } return err diff --git a/core/core.go b/core/core.go index 42645f0f..14be809f 100644 --- a/core/core.go +++ b/core/core.go @@ -7,12 +7,12 @@ import ( "github.com/safing/portbase/modules" "github.com/safing/portbase/modules/subsystems" - "github.com/safing/portmaster/updates" _ "github.com/safing/portmaster/broadcasts" _ "github.com/safing/portmaster/netenv" _ "github.com/safing/portmaster/netquery" _ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/ui" + "github.com/safing/portmaster/updates" ) const ( diff --git a/go.mod b/go.mod index 95a6a374..45638ee2 100644 --- a/go.mod +++ b/go.mod @@ -15,24 +15,17 @@ require ( github.com/miekg/dns v1.1.50 github.com/oschwald/maxminddb-golang v1.9.0 github.com/safing/portbase v0.14.5 - github.com/safing/spn v0.4.12 + github.com/safing/spn v0.4.13 github.com/shirou/gopsutil v3.21.11+incompatible github.com/spf13/cobra v1.5.0 - github.com/stretchr/testify v1.7.1 + github.com/stretchr/testify v1.8.0 
github.com/tannerryan/ring v1.1.2 github.com/tevino/abool v1.2.0 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 golang.org/x/net v0.0.0-20220708220712-1185a9018129 golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f - golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e - zombiezen.com/go/sqlite v0.10.0 - github.com/stretchr/testify v1.8.0 - github.com/tannerryan/ring v1.1.2 - github.com/tevino/abool v1.2.0 - github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 - golang.org/x/net v0.0.0-20220622184535-263ec571b305 - golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f - golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 + golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 + zombiezen.com/go/sqlite v0.10.1 ) require ( @@ -86,7 +79,7 @@ require ( golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/tools v0.1.11 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.16.17 // indirect modernc.org/mathutil v1.4.1 // indirect modernc.org/memory v1.1.1 // indirect diff --git a/go.sum b/go.sum index 168ff329..0b6989af 100644 --- a/go.sum +++ b/go.sum @@ -869,6 +869,8 @@ github.com/safing/spn v0.4.7/go.mod h1:NoSG9K0OK9hrPC76yqWFS6RtvbqZdIc/KGOsC4T3h github.com/safing/spn v0.4.11/go.mod h1:nro/I6b2JnafeeqoMsQRqf6TaQeL9uLLZkUREtxLVDE= github.com/safing/spn v0.4.12 h1:Tw7TUZEZR4yZy7L+ICRCketDk5L5x0s0pvrSUHFaKs4= github.com/safing/spn v0.4.12/go.mod h1:AUNgBrRwCcspC98ljptDnrPuHLn/BHSG+rSprV/5Wlc= +github.com/safing/spn v0.4.13 h1:5NXWUl/2EWyotrQhW3tD+3DYw7hEqQk0n0lHa+w4eFo= +github.com/safing/spn v0.4.13/go.mod h1:rBeimIc1FHQOhX7lTh/LaFGRotmnwZIDWUSsPyeIDog= github.com/sagikazarmark/crypt v0.3.0/go.mod h1:uD/D+6UF4SrIR1uGEv7bBNkNqLGqUr43MRiaGWX1Nig= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod 
h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= @@ -1230,6 +1232,8 @@ golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220621193019-9d032be2e588/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6flUMtzCwzlvkxEQtHHB0= +golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1391,6 +1395,8 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -1748,5 +1754,5 @@ 
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/structured-merge-diff/v4 v4.0.2/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= sigs.k8s.io/structured-merge-diff/v4 v4.1.2/go.mod h1:j/nl6xW8vLS49O8YvXW1ocPhZawJtm+Yrr7PPRQ0Vg4= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= -zombiezen.com/go/sqlite v0.10.0 h1:hegW0Y8c/fSJ2VjbjBeiKJaQOISNr4EUTx1VZx94Q9Y= -zombiezen.com/go/sqlite v0.10.0/go.mod h1:tOd9u3peffVYnXOedepSJmX92n/mbqf594wcJ+29jf8= +zombiezen.com/go/sqlite v0.10.1 h1:PSgVSHeIVOGKbX7ZIQNXGKn3wcqM6JBnT4yS1OLjWbM= +zombiezen.com/go/sqlite v0.10.1/go.mod h1:tOd9u3peffVYnXOedepSJmX92n/mbqf594wcJ+29jf8= diff --git a/netquery/chart_handler.go b/netquery/chart_handler.go index aaff4892..04db0c8a 100644 --- a/netquery/chart_handler.go +++ b/netquery/chart_handler.go @@ -14,6 +14,7 @@ import ( "github.com/safing/portmaster/netquery/orm" ) +// ChartHandler handles requests for connection charts. type ChartHandler struct { Database *Database } @@ -55,14 +56,14 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { enc.SetEscapeHTML(false) enc.SetIndent("", " ") - enc.Encode(map[string]interface{}{ + _ = enc.Encode(map[string]interface{}{ //nolint:errchkjson "results": result, "query": query, "params": paramMap, }) } -func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { +func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl var body io.Reader switch req.Method { diff --git a/netquery/database.go b/netquery/database.go index 57aa6f6c..966f4090 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -9,13 +9,14 @@ import ( "sync" "time" + "zombiezen.com/go/sqlite" + "zombiezen.com/go/sqlite/sqlitex" + "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/network" 
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" - "zombiezen.com/go/sqlite" - "zombiezen.com/go/sqlite/sqlitex" ) // InMemory is the "file path" to open a new in-memory database. @@ -36,7 +37,7 @@ var ConnectionTypeToString = map[network.ConnectionType]string{ type ( // Database represents a SQLite3 backed connection database. - // It's use is tailored for persistance and querying of network.Connection. + // It's use is tailored for persistence and querying of network.Connection. // Access to the underlying SQLite database is synchronized. // // TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing @@ -57,7 +58,7 @@ type ( // // Use ConvertConnection from this package to convert a network.Connection to this // representation. - Conn struct { + Conn struct { //nolint:maligned // ID is a device-unique identifier for the connection. It is built // from network.Connection by hashing the connection ID and the start // time. We cannot just use the network.Connection.ID because it is only unique @@ -93,11 +94,11 @@ type ( ProfileRevision int `sqlite:"profile_revision"` ExitNode *string `sqlite:"exit_node"` - // FIXME(ppacher): support "NOT" in search query to get rid of the following helper fields + // TODO(ppacher): support "NOT" in search query to get rid of the following helper fields SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL" Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL" - // FIXME(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here + // TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here ProfileName string `sqlite:"profile_name"` } ) @@ -153,9 +154,9 @@ func NewInMemory() (*Database, error) { // to bring db up-to-date with the built-in schema. 
// TODO(ppacher): right now this only applies the current schema and ignores // any data-migrations. Once the history module is implemented this should -// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration +// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration. func (db *Database) ApplyMigrations() error { - // get the create-table SQL statement from the infered schema + // get the create-table SQL statement from the inferred schema sql := db.Schema.CreateStatement(false) // execute the SQL @@ -234,7 +235,7 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro // dumpTo is a simple helper method that dumps all rows stored in the SQLite database // as JSON to w. // Any error aborts dumping rows and is returned. -func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { +func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused db.l.Lock() defer db.l.Unlock() diff --git a/netquery/manager.go b/netquery/manager.go index b6649483..cc6e5056 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect } } -func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error { +func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error { blob, err := json.Marshal(conn) if err != nil { return fmt.Errorf("failed to marshal connection: %w", err) @@ -173,17 +173,19 @@ func convertConnection(conn *network.Connection) (*Conn, error) { c.Type = "dns" case network.IPConnection: c.Type = "ip" + case network.Undefined: + c.Type = "" } switch conn.Verdict { case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel: accepted := true c.Allowed = &accepted - case network.VerdictUndecided, network.VerdictUndeterminable: - c.Allowed = nil - default: + case 
network.VerdictBlock, network.VerdictDrop: allowed := false c.Allowed = &allowed + case network.VerdictUndecided, network.VerdictUndeterminable, network.VerdictFailed: + c.Allowed = nil } if conn.Ended > 0 { diff --git a/netquery/module_api.go b/netquery/module_api.go index e5127b90..cb7d08e9 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -15,7 +15,7 @@ import ( "github.com/safing/portmaster/network" ) -type Module struct { +type module struct { *modules.Module db *database.Interface @@ -25,19 +25,19 @@ type Module struct { } func init() { - mod := new(Module) - mod.Module = modules.Register( + m := new(module) + m.Module = modules.Register( "netquery", - mod.Prepare, - mod.Start, - mod.Stop, + m.prepare, + m.start, + m.stop, "api", "network", "database", ) } -func (m *Module) Prepare() error { +func (m *module) prepare() error { var err error m.db = database.NewInterface(&database.Options{ @@ -66,7 +66,6 @@ func (m *Module) Prepare() error { Database: m.sqlStore, } - // FIXME(ppacher): use appropriate permissions for this if err := api.RegisterEndpoint(api.Endpoint{ Path: "netquery/query", MimeType: "application/json", @@ -96,13 +95,15 @@ func (m *Module) Prepare() error { return nil } -func (mod *Module) Start() error { - mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error { - sub, err := mod.db.Subscribe(query.New("network:")) +func (m *module) start() error { + m.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error { + sub, err := m.db.Subscribe(query.New("network:")) if err != nil { return fmt.Errorf("failed to subscribe to network tree: %w", err) } - defer sub.Cancel() + defer func() { + _ = sub.Cancel() + }() for { select { @@ -120,24 +121,24 @@ func (mod *Module) Start() error { continue } - mod.feed <- conn + m.feed <- conn } } }) - mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error { - mod.mng.HandleFeed(ctx, mod.feed) + 
m.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error { + m.mng.HandleFeed(ctx, m.feed) return nil }) - mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error { + m.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error { for { select { case <-ctx.Done(): return nil case <-time.After(10 * time.Second): threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold) - count, err := mod.sqlStore.Cleanup(ctx, threshold) + count, err := m.sqlStore.Cleanup(ctx, threshold) if err != nil { log.Errorf("netquery: failed to count number of rows in memory: %s", err) } else { @@ -147,19 +148,21 @@ func (mod *Module) Start() error { } }) - // for debugging, we provide a simple direct SQL query interface using - // the runtime database - // FIXME: Expose only in dev mode. - _, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry) - if err != nil { - return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) + // For debugging, provide a simple direct SQL query interface using + // the runtime database. + // Only expose in development mode. + if config.GetAsBool(config.CfgDevModeKey, false)() { + _, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry) + if err != nil { + return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) + } } return nil } -func (mod *Module) Stop() error { - close(mod.feed) +func (m *module) stop() error { + close(m.feed) return nil } diff --git a/netquery/orm/decoder.go b/netquery/orm/decoder.go index 6cc16f97..0387f9de 100644 --- a/netquery/orm/decoder.go +++ b/netquery/orm/decoder.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "reflect" "strings" "time" @@ -30,7 +29,7 @@ var ( // TEXT or REAL. 
// This package provides support for time.Time being stored as TEXT (using a // preconfigured timezone; UTC by default) or as INTEGER (the user can choose between - // unixepoch and unixnano-epoch where the nano variant is not offically supported by + // unixepoch and unixnano-epoch where the nano variant is not officially supported by // SQLITE). SqliteTimeFormat = "2006-01-02 15:04:05" ) @@ -54,6 +53,7 @@ type ( // DecodeFunc is called for each non-basic type during decoding. DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) + // DecodeConfig holds decoding functions. DecodeConfig struct { DecodeHooks []DecodeFunc } @@ -170,7 +170,8 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte return fmt.Errorf("cannot decode column %d (type=%s)", i, colType) } - //log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue) + // Debugging: + // log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue) // convert it to the target type if conversion is possible newValue := reflect.ValueOf(columnValue) @@ -189,8 +190,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte // time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to // decide between Unix or UnixNano epoch timestamps. // -// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing -// +// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing. 
func DatetimeDecoder(loc *time.Location) DecodeFunc { return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) { // if we have the column definition available we @@ -203,11 +203,11 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc { // we only care about "time.Time" here if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) { - log.Printf("not decoding %s %v", outType, colDef) + // log.Printf("not decoding %s %v", outType, colDef) return nil, false, nil } - switch stmt.ColumnType(colIdx) { + switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types. case sqlite.TypeInteger: // stored as unix-epoch, if unixnano is set in the struct field tag // we parse it with nano-second resolution @@ -242,7 +242,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc { } } -func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error { +func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error { if *mp == nil { *mp = make(map[string]interface{}) } @@ -292,7 +292,7 @@ func decodeBasic() DecodeFunc { if colDef != nil { valueKind = normalizeKind(colDef.GoType.Kind()) - // if we have a column defintion we try to convert the value to + // if we have a column definition we try to convert the value to // the actual Go-type that was used in the model. 
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage // or that type aliases like (type myInt int) are decoded into myInt instead of int @@ -314,7 +314,7 @@ func decodeBasic() DecodeFunc { }() } - log.Printf("decoding %s into kind %s", colName, valueKind) + // log.Printf("decoding %s into kind %s", colName, valueKind) if colType == sqlite.TypeNull { if colDef != nil && colDef.Nullable { @@ -330,7 +330,7 @@ func decodeBasic() DecodeFunc { } } - switch valueKind { + switch valueKind { //nolint:exhaustive case reflect.String: if colType != sqlite.TypeText { return nil, false, errInvalidType @@ -455,7 +455,7 @@ func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.S return nil, nil } -// getKind returns the kind of value but normalized Int, Uint and Float varaints +// getKind returns the kind of value but normalized Int, Uint and Float variants. // to their base type. func getKind(val reflect.Value) reflect.Kind { kind := val.Kind() @@ -475,6 +475,7 @@ func normalizeKind(kind reflect.Kind) reflect.Kind { } } +// DefaultDecodeConfig holds the default decoding configuration. 
var DefaultDecodeConfig = DecodeConfig{ DecodeHooks: []DecodeFunc{ DatetimeDecoder(time.UTC), diff --git a/netquery/orm/decoder_test.go b/netquery/orm/decoder_test.go index 5abd324e..62616938 100644 --- a/netquery/orm/decoder_test.go +++ b/netquery/orm/decoder_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "log" "reflect" "testing" "time" @@ -21,14 +20,14 @@ type testStmt struct { func (ts testStmt) ColumnCount() int { return len(ts.columns) } func (ts testStmt) ColumnName(i int) string { return ts.columns[i] } -func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } -func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } -func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } -func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } -func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } +func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert +func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert +func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert +func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert +func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] } -// compile time check +// Compile time check. 
var _ Stmt = new(testStmt) type exampleFieldTypes struct { @@ -98,10 +97,11 @@ func (etn *exampleTimeNano) Equal(other interface{}) bool { return etn.T.Equal(oetn.T) } -func Test_Decoder(t *testing.T) { - ctx := context.TODO() +func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel + t.Parallel() - refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) + ctx := context.TODO() + refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC) cases := []struct { Desc string @@ -433,8 +433,7 @@ func Test_Decoder(t *testing.T) { nil, &exampleInterface{}, func() interface{} { - var x interface{} - x = "value2" + var x interface{} = "value2" return &exampleInterface{ I: "value1", @@ -546,12 +545,10 @@ func Test_Decoder(t *testing.T) { }, } - for idx := range cases { + for idx := range cases { //nolint:paralleltest c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - //t.Parallel() - - log.Println(c.Desc) + // log.Println(c.Desc) err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig) if fn, ok := c.Expected.(func() interface{}); ok { c.Expected = fn() diff --git a/netquery/orm/encoder.go b/netquery/orm/encoder.go index 0bcea756..7961f088 100644 --- a/netquery/orm/encoder.go +++ b/netquery/orm/encoder.go @@ -10,8 +10,10 @@ import ( ) type ( + // EncodeFunc is called for each non-basic type during encoding. EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) + // EncodeConfig holds encoding functions. EncodeConfig struct { EncodeHooks []EncodeFunc } @@ -69,6 +71,7 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode return res, nil } +// EncodeValue encodes the given value. 
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) { fieldValue := reflect.ValueOf(val) fieldType := reflect.TypeOf(val) @@ -115,7 +118,7 @@ func encodeBasic() EncodeFunc { val = val.Elem() } - switch normalizeKind(kind) { + switch normalizeKind(kind) { //nolint:exhaustive case reflect.String, reflect.Float64, reflect.Bool, @@ -138,6 +141,7 @@ func encodeBasic() EncodeFunc { } } +// DatetimeEncoder returns a new datetime encoder for the given time zone. func DatetimeEncoder(loc *time.Location) EncodeFunc { return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { // if fieldType holds a pointer we need to dereference the value @@ -149,7 +153,8 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc { // we only care about "time.Time" here var t time.Time - if ft == "time.Time" { + switch { + case ft == "time.Time": // handle the zero time as a NULL. if !val.IsValid() || val.IsZero() { return nil, true, nil @@ -162,19 +167,19 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc { return nil, false, fmt.Errorf("cannot convert reflect value to time.Time") } - } else if valType.Kind() == reflect.String && colDef.IsTime { + case valType.Kind() == reflect.String && colDef.IsTime: var err error t, err = time.Parse(time.RFC3339, val.String()) if err != nil { return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err) } - } else { + default: // we don't care ... 
return nil, false, nil } - switch colDef.Type { + switch colDef.Type { //nolint:exhaustive case sqlite.TypeInteger: if colDef.UnixNano { return t.UnixNano(), true, nil @@ -194,7 +199,7 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc { func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) { if valType == nil { if !colDef.Nullable { - switch colDef.Type { + switch colDef.Type { //nolint:exhaustive case sqlite.TypeBlob: return []byte{}, true, nil case sqlite.TypeFloat: @@ -225,6 +230,7 @@ func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, return nil, false, nil } +// DefaultEncodeConfig holds the default encoding configuration. var DefaultEncodeConfig = EncodeConfig{ EncodeHooks: []EncodeFunc{ DatetimeEncoder(time.UTC), diff --git a/netquery/orm/encoder_test.go b/netquery/orm/encoder_test.go index aff28580..e5142962 100644 --- a/netquery/orm/encoder_test.go +++ b/netquery/orm/encoder_test.go @@ -9,9 +9,11 @@ import ( "zombiezen.com/go/sqlite" ) -func Test_EncodeAsMap(t *testing.T) { +func TestEncodeAsMap(t *testing.T) { //nolint:tparallel + t.Parallel() + ctx := context.TODO() - refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) + refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC) cases := []struct { Desc string @@ -114,11 +116,9 @@ func Test_EncodeAsMap(t *testing.T) { }, } - for idx := range cases { + for idx := range cases { //nolint:paralleltest c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - // t.Parallel() - res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) assert.NoError(t, err) assert.Equal(t, c.Expected, res) @@ -126,9 +126,11 @@ func Test_EncodeAsMap(t *testing.T) { } } -func Test_EncodeValue(t *testing.T) { +func TestEncodeValue(t *testing.T) { //nolint:tparallel + t.Parallel() + ctx := context.TODO() - refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) + refTime := 
time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC) cases := []struct { Desc string @@ -247,11 +249,9 @@ func Test_EncodeValue(t *testing.T) { }, } - for idx := range cases { + for idx := range cases { //nolint:paralleltest c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - //t.Parallel() - res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig) assert.NoError(t, err) assert.Equal(t, c.Output, res) diff --git a/netquery/orm/query_runner.go b/netquery/orm/query_runner.go index 88eceefd..2d5f01a3 100644 --- a/netquery/orm/query_runner.go +++ b/netquery/orm/query_runner.go @@ -57,6 +57,8 @@ func WithNamedArgs(args map[string]interface{}) QueryOption { } } +// WithSchema returns a query option that adds the given table +// schema to the query. func WithSchema(tbl TableSchema) QueryOption { return func(opts *queryOpts) { opts.Schema = tbl @@ -139,9 +141,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q valElemType = valType.Elem() opts.ResultFunc = func(stmt *sqlite.Stmt) error { - var currentField reflect.Value - - currentField = reflect.New(valElemType) + currentField := reflect.New(valElemType) if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { return err diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index 5d533be5..508b7b18 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -10,10 +10,9 @@ import ( "zombiezen.com/go/sqlite" ) -var ( - errSkipStructField = errors.New("struct field should be skipped") -) +var errSkipStructField = errors.New("struct field should be skipped") +// Struct Tags. var ( TagUnixNano = "unixnano" TagPrimaryKey = "primary" @@ -36,12 +35,14 @@ var sqlTypeMap = map[sqlite.ColumnType]string{ } type ( + // TableSchema defines a SQL table schema. TableSchema struct { Name string Columns []ColumnDef } - ColumnDef struct { + // ColumnDef defines a SQL column. 
+ ColumnDef struct { //nolint:maligned Name string Nullable bool Type sqlite.ColumnType @@ -54,6 +55,7 @@ type ( } ) +// GetColumnDef returns the column definition with the given name. func (ts TableSchema) GetColumnDef(name string) *ColumnDef { for _, def := range ts.Columns { if def.Name == name { @@ -63,6 +65,7 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef { return nil } +// CreateStatement build the CREATE SQL statement for the table. func (ts TableSchema) CreateStatement(ifNotExists bool) string { sql := "CREATE TABLE" if ifNotExists { @@ -81,6 +84,7 @@ func (ts TableSchema) CreateStatement(ifNotExists bool) string { return sql } +// AsSQL builds the SQL column definition. func (def ColumnDef) AsSQL() string { sql := def.Name + " " @@ -103,6 +107,7 @@ func (def ColumnDef) AsSQL() string { return sql } +// GenerateTableSchema generates a table schema from the given struct. func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) { ts := &TableSchema{ Name: name, @@ -149,7 +154,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) { def.GoType = ft kind := normalizeKind(ft.Kind()) - switch kind { + switch kind { //nolint:exhaustive case reflect.Int: def.Type = sqlite.TypeInteger @@ -190,7 +195,7 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error { if len(parts) > 1 { for _, k := range parts[1:] { switch k { - // column modifieres + // column modifiers case TagPrimaryKey: def.PrimaryKey = true case TagAutoIncrement: diff --git a/netquery/orm/schema_builder_test.go b/netquery/orm/schema_builder_test.go index 7012076d..734da981 100644 --- a/netquery/orm/schema_builder_test.go +++ b/netquery/orm/schema_builder_test.go @@ -6,7 +6,9 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_SchemaBuilder(t *testing.T) { +func TestSchemaBuilder(t *testing.T) { + t.Parallel() + cases := []struct { Name string Model interface{} diff --git a/netquery/query.go b/netquery/query.go index 
58c6cdee..7ccda769 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -5,15 +5,19 @@ import ( "encoding/json" "fmt" "io" - "log" "sort" "strings" "github.com/hashicorp/go-multierror" - "github.com/safing/portmaster/netquery/orm" "zombiezen.com/go/sqlite" + + "github.com/safing/portmaster/netquery/orm" ) +// Collection of Query and Matcher types. +// NOTE: whenever adding support for new operators make sure +// to update UnmarshalJSON as well. +//nolint:golint type ( Query map[string][]Matcher @@ -43,8 +47,6 @@ type ( Distinct bool `json:"distinct"` } - // NOTE: whenever adding support for new operators make sure - // to update UnmarshalJSON as well. Select struct { Field string `json:"field"` Count *Count `json:"$count,omitempty"` @@ -91,6 +93,7 @@ type ( } ) +// UnmarshalJSON unmarshals a Query from json. func (query *Query) UnmarshalJSON(blob []byte) error { if *query == nil { *query = make(Query) @@ -202,13 +205,14 @@ func parseMatcher(raw json.RawMessage) (*Matcher, error) { } if err := m.Validate(); err != nil { - return nil, fmt.Errorf("invalid query matcher: %s", err) + return nil, fmt.Errorf("invalid query matcher: %w", err) } - log.Printf("parsed matcher %s: %+v", string(raw), m) - return &m, nil + // log.Printf("parsed matcher %s: %+v", string(raw), m) + return &m, nil } +// Validate validates the matcher. 
func (match Matcher) Validate() error { found := 0 @@ -239,9 +243,9 @@ func (match Matcher) Validate() error { return nil } -func (text TextSearch) toSQLConditionClause(ctx context.Context, schema *orm.TableSchema, suffix string, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { +func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) { var ( - queryParts []string + queryParts = make([]string, 0, len(text.Fields)) params = make(map[string]interface{}) ) @@ -379,7 +383,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T // merge parameters up into the superior parameter map for key, val := range params { if _, ok := paramMap[key]; ok { - // is is soley a developer mistake when implementing a matcher so no forgiving ... + // This is solely a developer mistake when implementing a matcher so no forgiving ... panic("sqlite parameter collision") } @@ -399,6 +403,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T return whereClause, paramMap, errs.ErrorOrNil() } +// UnmarshalJSON unmarshals a Selects from json. func (sel *Selects) UnmarshalJSON(blob []byte) error { if len(blob) == 0 { return io.ErrUnexpectedEOF @@ -438,6 +443,7 @@ func (sel *Selects) UnmarshalJSON(blob []byte) error { return nil } +// UnmarshalJSON unmarshals a Select from json. func (sel *Select) UnmarshalJSON(blob []byte) error { if len(blob) == 0 { return io.ErrUnexpectedEOF @@ -481,6 +487,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { return nil } +// UnmarshalJSON unmarshals a OrderBys from json. func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error { if len(blob) == 0 { return io.ErrUnexpectedEOF @@ -523,6 +530,7 @@ func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error { return nil } +// UnmarshalJSON unmarshals a OrderBy from json. 
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error { if len(blob) == 0 { return io.ErrUnexpectedEOF diff --git a/netquery/query_handler.go b/netquery/query_handler.go index 1b2b5411..baa1df1d 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -17,9 +17,7 @@ import ( "github.com/safing/portmaster/netquery/orm" ) -var ( - charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+") -) +var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+") type ( @@ -109,7 +107,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { } } -func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { +func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl var body io.Reader switch req.Method { @@ -230,11 +228,11 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem switch { case s.Count != nil: - var as = s.Count.As + as := s.Count.As if as == "" { as = fmt.Sprintf("%s_count", colName) } - var distinct = "" + distinct := "" if s.Count.Distinct { distinct = "DISTINCT " } @@ -278,8 +276,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) ( return "", nil } - var groupBys = make([]string, len(req.GroupBy)) - + groupBys := make([]string, len(req.GroupBy)) for idx, name := range req.GroupBy { colName, err := req.validateColumnName(schema, name) if err != nil { @@ -288,7 +285,6 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) ( groupBys[idx] = colName } - groupByClause := "GROUP BY " + strings.Join(groupBys, ", ") // if there are no explicitly selected fields we default to the @@ -301,7 +297,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) ( } func (req *QueryRequestPayload) generateSelectClause() string { - var selectClause = "*" + selectClause := "*" if len(req.selectedFields) > 0 { selectClause = strings.Join(req.selectedFields, ", ") } @@ 
-314,7 +310,7 @@ func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) ( return "", nil } - var orderBys = make([]string, len(req.OrderBy)) + orderBys := make([]string, len(req.OrderBy)) for idx, sort := range req.OrderBy { colName, err := req.validateColumnName(schema, sort.Field) if err != nil { @@ -352,5 +348,5 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel return "", fmt.Errorf("column name %q not allowed", field) } -// compile time check +// Compile time check. var _ http.Handler = new(QueryHandler) diff --git a/netquery/query_test.go b/netquery/query_test.go index 7d9a0393..afd65b4f 100644 --- a/netquery/query_test.go +++ b/netquery/query_test.go @@ -7,13 +7,16 @@ import ( "testing" "time" - "github.com/safing/portmaster/netquery/orm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/safing/portmaster/netquery/orm" ) -func Test_UnmarshalQuery(t *testing.T) { - var cases = []struct { +func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel + t.Parallel() + + cases := []struct { Name string Input string Expected Query @@ -88,7 +91,8 @@ func Test_UnmarshalQuery(t *testing.T) { }, } - for _, c := range cases { + for _, testCase := range cases { //nolint:paralleltest + c := testCase t.Run(c.Name, func(t *testing.T) { var q Query err := json.Unmarshal([]byte(c.Input), &q) @@ -105,10 +109,11 @@ func Test_UnmarshalQuery(t *testing.T) { } } -func Test_QueryBuilder(t *testing.T) { - now := time.Now() +func TestQueryBuilder(t *testing.T) { //nolint:tparallel + t.Parallel() - var cases = []struct { + now := time.Now() + cases := []struct { N string Q Query R string @@ -186,7 +191,7 @@ func Test_QueryBuilder(t *testing.T) { }, "", nil, - fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), + fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint }, { "Complex example", @@ -225,19 +230,20 @@ func 
Test_QueryBuilder(t *testing.T) { tbl, err := orm.GenerateTableSchema("connections", Conn{}) require.NoError(t, err) - for idx, c := range cases { + for idx, testCase := range cases { //nolint:paralleltest + cID := idx + c := testCase t.Run(c.N, func(t *testing.T) { - //t.Parallel() str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig) if c.E != nil { if assert.Error(t, err) { - assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx) + assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID) } } else { - assert.NoError(t, err, "test case %d", idx) - assert.Equal(t, c.P, params, "test case %d", idx) - assert.Equal(t, c.R, str, "test case %d", idx) + assert.NoError(t, err, "test case %d", cID) + assert.Equal(t, c.P, params, "test case %d", cID) + assert.Equal(t, c.R, str, "test case %d", cID) } }) } diff --git a/network/clean.go b/network/clean.go index 6ff1eb38..3d8b23cc 100644 --- a/network/clean.go +++ b/network/clean.go @@ -11,8 +11,11 @@ import ( ) const ( - cleanerTickDuration = 5 * time.Second + // DeleteConnsAfterEndedThreshold defines the amount of time after which + // ended connections should be removed from the internal connection state. 
DeleteConnsAfterEndedThreshold = 10 * time.Minute + + cleanerTickDuration = 5 * time.Second ) func connectionCleaner(ctx context.Context) error { diff --git a/resolver/config.go b/resolver/config.go index ff17cb99..19f299d3 100644 --- a/resolver/config.go +++ b/resolver/config.go @@ -112,7 +112,7 @@ The format is: "protocol://ip:port?parameter=value¶meter=value" ExpertiseLevel: config.ExpertiseLevelUser, ReleaseLevel: config.ReleaseLevelStable, DefaultValue: defaultNameServers, - ValidationRegex: fmt.Sprintf("^(%s|%s|%s)://.*", ServerTypeDoT, ServerTypeDNS, ServerTypeTCP), + ValidationRegex: fmt.Sprintf("^(%s|%s|%s|%s|%s|%s)://.*", ServerTypeDoT, ServerTypeDoH, ServerTypeDNS, ServerTypeTCP, HTTPSProtocol, TLSProtocol), ValidationFunc: validateNameservers, Annotations: config.Annotations{ config.DisplayHintAnnotation: config.DisplayHintOrdered, @@ -123,32 +123,31 @@ The format is: "protocol://ip:port?parameter=value¶meter=value" Name: "Cloudflare (with Malware Filter)", Action: config.QuickReplace, Value: []string{ - "dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", - "dot://1.0.0.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip", + "dot://cloudflare-dns.com?ip=1.1.1.2&name=Cloudflare&blockedif=zeroip", + "dot://cloudflare-dns.com?ip=1.0.0.2&name=Cloudflare&blockedif=zeroip", }, }, { Name: "Quad9", Action: config.QuickReplace, Value: []string{ - "dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", - "dot://149.112.112.112:853?verify=dns.quad9.net&name=Quad9&blockedif=empty", + "dot://dns.quad9.net?ip=9.9.9.9&name=Quad9&blockedif=empty", + "dot://dns.quad9.net?ip=149.112.112.112&name=Quad9&blockedif=empty", }, }, { Name: "AdGuard", Action: config.QuickReplace, Value: []string{ - "dot://94.140.14.14:853?verify=dns.adguard.com&name=AdGuard&blockedif=zeroip", - "dot://94.140.15.15:853?verify=dns.adguard.com&name=AdGuard&blockedif=zeroip", + 
"dot://dns.adguard.com?ip=94.140.14.14&name=AdGuard&blockedif=zeroip", + "dot://dns.adguard.com?ip=94.140.15.15&name=AdGuard&blockedif=zeroip", }, }, { Name: "Foundation for Applied Privacy", Action: config.QuickReplace, Value: []string{ - "dot://94.130.106.88:853?verify=dot1.applied-privacy.net&name=AppliedPrivacy", - "dot://94.130.106.88:443?verify=dot1.applied-privacy.net&name=AppliedPrivacy", + "dot://dot1.applied-privacy.net?ip=94.130.106.88&name=AppliedPrivacy", }, }, }, diff --git a/resolver/resolver-https.go b/resolver/resolver-https.go new file mode 100644 index 00000000..9776c609 --- /dev/null +++ b/resolver/resolver-https.go @@ -0,0 +1,120 @@ +package resolver + +import ( + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "github.com/miekg/dns" +) + +// HTTPSResolver is a resolver that sends DNS queries over HTTPS (DoH). +type HTTPSResolver struct { + BasicResolverConn + Client *http.Client +} + +// HTTPSQuery holds the query information for an HTTPSResolver. +type HTTPSQuery struct { + Query *Query + Response chan *dns.Msg +} + +// MakeCacheRecord creates an RRCache record from a reply. +func (tq *HTTPSQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo) *RRCache { + return &RRCache{ + Domain: tq.Query.FQDN, + Question: tq.Query.QType, + RCode: reply.Rcode, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Resolver: resolverInfo.Copy(), + } +} + +// NewHTTPSResolver returns a new HTTPSResolver.
+func NewHTTPSResolver(resolver *Resolver) *HTTPSResolver { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: resolver.Info.Domain, + // TODO: use portbase rng + }, + IdleConnTimeout: 3 * time.Minute, + } + + client := &http.Client{Transport: tr} + newResolver := &HTTPSResolver{ + BasicResolverConn: BasicResolverConn{ + resolver: resolver, + }, + Client: client, + } + newResolver.BasicResolverConn.init() + return newResolver +} + +// Query executes the given query against the resolver. +func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { + dnsQuery := new(dns.Msg) + dnsQuery.SetQuestion(q.FQDN, uint16(q.QType)) + + // Pack query and convert to base64url string as required by RFC 8484 + buf, err := dnsQuery.Pack() + if err != nil { + return nil, err + } + b64dns := base64.RawURLEncoding.EncodeToString(buf) + + // Build and execute http request + url := &url.URL{ + Scheme: "https", + Host: hr.resolver.ServerAddress, + Path: hr.resolver.Path, + ForceQuery: true, + RawQuery: fmt.Sprintf("dns=%s", b64dns), + } + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) + if err != nil { + return nil, err + } + + resp, err := hr.Client.Do(request) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // Try to read the result + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + reply := new(dns.Msg) + + err = reply.Unpack(body) + if err != nil { + return nil, err + } + + newRecord := &RRCache{ + Domain: q.FQDN, + Question: q.QType, + RCode: reply.Rcode, + Answer: reply.Answer, + Ns: reply.Ns, + Extra: reply.Extra, + Resolver: hr.resolver.Info.Copy(), + } + + // TODO: check if reply.Answer is valid + return newRecord, nil +} diff --git a/resolver/resolver-tcp.go b/resolver/resolver-tcp.go index cb37ec94..746d6c01 100644 --- a/resolver/resolver-tcp.go +++ b/resolver/resolver-tcp.go @@ -99,7 +99,7 @@ func (tr *TCPResolver) UseTLS() *TCPResolver 
{ tr.dnsClient.Net = "tcp-tls" tr.dnsClient.TLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, - ServerName: tr.resolver.VerifyDomain, + ServerName: tr.resolver.Info.Domain, // TODO: use portbase rng } return tr diff --git a/resolver/resolver.go b/resolver/resolver.go index 1f12c8f2..764b24fa 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -30,6 +30,12 @@ const ( ServerSourceEnv = "env" ) +// DNS Resolver alias +const ( + HTTPSProtocol = "https" + TLSProtocol = "tls" +) + // FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed. var FailThreshold = 20 @@ -61,9 +67,9 @@ type Resolver struct { UpstreamBlockDetection string // Special Options - VerifyDomain string - Search []string - SearchOnly bool + Search []string + SearchOnly bool + Path string // logic interface Conn ResolverConn `json:"-"` @@ -87,6 +93,9 @@ type ResolverInfo struct { //nolint:golint,maligned // TODO // IP is the IP address of the resolver IP net.IP + // Domain of the dns server if it has one + Domain string + // IPScope is the network scope of the IP address. 
IPScope netutils.IPScope @@ -107,6 +116,20 @@ func (info *ResolverInfo) ID() string { info.id = ServerTypeMDNS case ServerTypeEnv: info.id = ServerTypeEnv + case ServerTypeDoH: + info.id = fmt.Sprintf( + "https://%s:%d#%s", + info.Domain, + info.Port, + info.Source, + ) + case ServerTypeDoT: + info.id = fmt.Sprintf( + "dot://%s:%d#%s", + info.Domain, + info.Port, + info.Source, + ) default: info.id = fmt.Sprintf( "%s://%s:%d#%s", @@ -135,6 +158,12 @@ func (info *ResolverInfo) DescriptiveName() string { info.Name, info.ID(), ) + case info.Domain != "": + return fmt.Sprintf( + "%s (%s)", + info.Domain, + info.ID(), + ) default: return fmt.Sprintf( "%s (%s)", @@ -155,6 +184,7 @@ func (info *ResolverInfo) Copy() *ResolverInfo { Type: info.Type, Source: info.Source, IP: info.IP, + Domain: info.Domain, IPScope: info.IPScope, Port: info.Port, id: info.id, diff --git a/resolver/resolvers.go b/resolver/resolvers.go index 4360d836..bc6ecbf6 100644 --- a/resolver/resolvers.go +++ b/resolver/resolvers.go @@ -12,6 +12,7 @@ import ( "golang.org/x/net/publicsuffix" + "github.com/miekg/dns" "github.com/safing/portbase/log" "github.com/safing/portbase/utils" "github.com/safing/portmaster/netenv" @@ -29,9 +30,11 @@ type Scope struct { const ( parameterName = "name" parameterVerify = "verify" + parameterIP = "ip" parameterBlockedIf = "blockedif" parameterSearch = "search" parameterSearchOnly = "search-only" + parameterPath = "path" ) var ( @@ -41,7 +44,9 @@ var ( localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope activeResolvers map[string]*Resolver // lookup map of all resolvers currentResolverConfig []string // current active resolver config, to detect changes - resolversLock sync.RWMutex + resolverInitDomains map[string]struct{} // a set with all domains of the dns resolvers + + resolversLock sync.RWMutex ) func indexOfScope(domain string, list []*Scope) int { @@ -80,6 +85,8 @@ func resolverConnFactory(resolver *Resolver) ResolverConn 
{ return NewTCPResolver(resolver) case ServerTypeDoT: return NewTCPResolver(resolver).UseTLS() + case ServerTypeDoH: + return NewHTTPSResolver(resolver) case ServerTypeDNS: return NewPlainResolver(resolver) default: @@ -93,93 +100,64 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return nil, false, err } + if resolverInitDomains == nil { + resolverInitDomains = make(map[string]struct{}) + } + switch u.Scheme { - case ServerTypeDNS, ServerTypeDoT, ServerTypeTCP: + case ServerTypeDNS, ServerTypeDoT, ServerTypeDoH, ServerTypeTCP: + case HTTPSProtocol: + u.Scheme = ServerTypeDoH + case TLSProtocol: + u.Scheme = ServerTypeDoT default: return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme) } - ip := net.ParseIP(u.Hostname()) - if ip == nil { - return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname()) - } - - // Add default port for scheme if it is missing. - var port uint16 - hostPort := u.Port() - switch { - case hostPort != "": - parsedPort, err := strconv.ParseUint(hostPort, 10, 16) - if err != nil { - return nil, false, fmt.Errorf("resolver port %q invalid", u.Port()) - } - port = uint16(parsedPort) - case u.Scheme == ServerTypeDNS, u.Scheme == ServerTypeTCP: - port = 53 - case u.Scheme == ServerTypeDoH: - port = 443 - case u.Scheme == ServerTypeDoT: - port = 853 - default: - return nil, false, fmt.Errorf("missing port in %q", u.Host) - } - - scope := netutils.GetIPScope(ip) - // Skip localhost resolvers from the OS, but not if configured. - if scope.IsLocalhost() && source == ServerSourceOperatingSystem { - return nil, true, nil // skip - } - - // Get parameters and check if keys exist. query := u.Query() - for key := range query { - switch key { - case parameterName, - parameterVerify, - parameterBlockedIf, - parameterSearch, - parameterSearchOnly: - // Known key, continue. - default: - // Unknown key, abort. 
- return nil, false, fmt.Errorf(`unknown parameter "%s"`, key) - } - } - // Check domain verification config. - verifyDomain := query.Get(parameterVerify) - if verifyDomain != "" && u.Scheme != ServerTypeDoT { - return nil, false, fmt.Errorf("domain verification only supported in DOT") - } - if verifyDomain == "" && u.Scheme == ServerTypeDoT { - return nil, false, fmt.Errorf("DOT must have a verify query parameter set") - } - - // Check block detection type. - blockType := query.Get(parameterBlockedIf) - if blockType == "" { - blockType = BlockDetectionZeroIP - } - switch blockType { - case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: - default: - return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") - } - - // Build resolver. + // Create Resolver object newResolver := &Resolver{ ConfigURL: resolverURL, Info: &ResolverInfo{ Name: query.Get(parameterName), Type: u.Scheme, Source: source, - IP: ip, - IPScope: scope, - Port: port, + IP: nil, + Domain: "", + IPScope: netutils.Global, + Port: 0, }, - ServerAddress: net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), - VerifyDomain: verifyDomain, - UpstreamBlockDetection: blockType, + ServerAddress: "", + Path: u.Path, // Used for DoH + UpstreamBlockDetection: "", + } + + // Get parameters and check if keys exist. + err = checkAndSetResolverParameters(u, newResolver) + if err != nil { + return nil, false, err + } + + // Check block detection type. 
+ newResolver.UpstreamBlockDetection = query.Get(parameterBlockedIf) + if newResolver.UpstreamBlockDetection == "" { + newResolver.UpstreamBlockDetection = BlockDetectionZeroIP + } + + switch newResolver.UpstreamBlockDetection { + case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP: + default: + return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)") + } + + // Get ip scope if we have ip + if newResolver.Info.IP != nil { + newResolver.Info.IPScope = netutils.GetIPScope(newResolver.Info.IP) + // Skip localhost resolvers from the OS, but not if configured. + if newResolver.Info.IPScope.IsLocalhost() && source == ServerSourceOperatingSystem { + return nil, true, nil // skip + } } // Parse search domains. @@ -206,6 +184,108 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) { return newResolver, false, nil } +func checkAndSetResolverParameters(u *url.URL, resolver *Resolver) error { + // Check if we are using domain name and if it's in a valid scheme + ip := net.ParseIP(u.Hostname()) + hostnameIsDomain := (ip == nil) + if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT { + return fmt.Errorf("resolver IP %q is invalid", u.Hostname()) + } + + // Add default port for scheme if it is missing. + port, err := parsePortFromURL(u) + if err != nil { + return err + } + resolver.Info.Port = port + resolver.Info.IP = ip + + query := u.Query() + + for key := range query { + switch key { + case parameterName, + parameterVerify, + parameterIP, + parameterBlockedIf, + parameterSearch, + parameterSearchOnly, + parameterPath: + // Known key, continue. + default: + // Unknown key, abort. 
+ return fmt.Errorf(`unknown parameter %q`, key) + } + } + + resolver.Info.Domain = query.Get(parameterVerify) + parameterServerIP := query.Get(parameterIP) + + if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH { + // Check if IP and Domain are set correctly + switch { + case hostnameIsDomain && resolver.Info.Domain != "": + return fmt.Errorf("cannot set the domain name via both the hostname in the URL and the verify parameter") + case !hostnameIsDomain && resolver.Info.Domain == "": + return fmt.Errorf("verify parameter must be set when using ip as domain") + case !hostnameIsDomain && parameterServerIP != "": + return fmt.Errorf("cannot set the IP address via both the hostname in the URL and the ip parameter") + } + + // Parse and set IP and Domain to the resolver + switch { + case hostnameIsDomain && parameterServerIP != "": // domain and ip as parameter + resolver.Info.IP = net.ParseIP(parameterServerIP) + resolver.ServerAddress = net.JoinHostPort(parameterServerIP, strconv.Itoa(int(resolver.Info.Port))) + resolver.Info.Domain = u.Hostname() + case !hostnameIsDomain && resolver.Info.Domain != "": // ip and domain as parameter + resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port))) + case hostnameIsDomain && resolver.Info.Domain == "" && parameterServerIP == "": // only domain + resolver.Info.Domain = u.Hostname() + resolver.ServerAddress = net.JoinHostPort(resolver.Info.Domain, strconv.Itoa(int(port))) + } + + if ip == nil { + resolverInitDomains[dns.Fqdn(resolver.Info.Domain)] = struct{}{} + } + + } else { + if resolver.Info.Domain != "" { + return fmt.Errorf("domain verification is only supported by DoT and DoH servers") + } + resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port))) + } + + return nil +} + +func parsePortFromURL(url *url.URL) (uint16, error) { + var port uint16 + hostPort := url.Port() + if hostPort != "" { + // There is a port in the url + parsedPort, err 
:= strconv.ParseUint(hostPort, 10, 16) + if err != nil { + return 0, fmt.Errorf("invalid port %q", url.Port()) + } + port = uint16(parsedPort) + } else { + // set the default port for the protocol + switch { + case url.Scheme == ServerTypeDNS, url.Scheme == ServerTypeTCP: + port = 53 + case url.Scheme == ServerTypeDoH: + port = 443 + case url.Scheme == ServerTypeDoT: + port = 853 + default: + return 0, fmt.Errorf("cannot determine port for %q", url.Scheme) + } + } + + return port, nil +} + func configureSearchDomains(resolver *Resolver, searches []string, hardfail bool) error { resolver.Search = make([]string, 0, len(searches)) diff --git a/resolver/scopes.go b/resolver/scopes.go index 883bc0ad..07769fd4 100644 --- a/resolver/scopes.go +++ b/resolver/scopes.go @@ -220,6 +220,13 @@ addNextResolver: } } + // the domains from the configured resolvers should not be resolved with the same resolvers + if resolver.Info.Source == ServerSourceConfigured && resolver.Info.IP == nil { + if _, ok := resolverInitDomains[q.FQDN]; ok { + continue addNextResolver + } + } + // add compliant and unique resolvers to selected resolvers selected = append(selected, resolver) }