mirror of
https://github.com/safing/portmaster
synced 2025-09-01 18:19:12 +00:00
Merge branch 'develop' into feature/my-profile-endpoint
This commit is contained in:
commit
ab68a07459
29 changed files with 517 additions and 244 deletions
|
@ -31,6 +31,7 @@ linters:
|
|||
- whitespace
|
||||
- wrapcheck
|
||||
- wsl
|
||||
- nolintlint
|
||||
|
||||
linters-settings:
|
||||
revive:
|
||||
|
|
|
@ -19,7 +19,7 @@ import ( //nolint:gci,nolintlint
|
|||
|
||||
func main() {
|
||||
// set information
|
||||
info.Set("Portmaster", "0.8.13", "AGPLv3", true)
|
||||
info.Set("Portmaster", "0.9.0", "AGPLv3", true)
|
||||
|
||||
// Configure metrics.
|
||||
_ = metrics.SetNamespace("portmaster")
|
||||
|
|
|
@ -77,7 +77,7 @@ func main() {
|
|||
cobra.OnInitialize(initCobra)
|
||||
|
||||
// set meta info
|
||||
info.Set("Portmaster Start", "0.8.8", "AGPLv3", false)
|
||||
info.Set("Portmaster Start", "0.9.0", "AGPLv3", false)
|
||||
|
||||
// catch interrupt for clean shutdown
|
||||
signalCh := make(chan os.Signal, 2)
|
||||
|
|
|
@ -121,6 +121,21 @@ func getExecArgs(opts *Options, cmdArgs []string) []string {
|
|||
if stdinSignals {
|
||||
args = append(args, "--input-signals")
|
||||
}
|
||||
|
||||
if opts.Identifier == "app/portmaster-app.zip" {
|
||||
// see https://www.freedesktop.org/software/systemd/man/pam_systemd.html#type=
|
||||
if xdgSessionType := os.Getenv("XDG_SESSION_TYPE"); xdgSessionType == "wayland" {
|
||||
// we're running the Portmaster UI App under Wayland so make sure we add some arguments
|
||||
// required by Electron
|
||||
args = append(args,
|
||||
[]string{
|
||||
"--enable-features=UseOzonePlatform",
|
||||
"--ozone-platform=wayland",
|
||||
}...,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
args = append(args, cmdArgs...)
|
||||
return args
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/debug"
|
||||
"golang.org/x/sys/windows/svc/eventlog"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -102,13 +101,6 @@ func runService(_ *cobra.Command, opts *Options, cmdArgs []string) error {
|
|||
svcRun = debug.Run
|
||||
}
|
||||
|
||||
// open eventlog
|
||||
elog, err := eventlog.Open(serviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open eventlog: %s", err)
|
||||
}
|
||||
defer elog.Close()
|
||||
|
||||
runWg.Add(2)
|
||||
finishWg.Add(1)
|
||||
|
||||
|
@ -134,7 +126,6 @@ func runService(_ *cobra.Command, opts *Options, cmdArgs []string) error {
|
|||
err = getShutdownError()
|
||||
if err != nil {
|
||||
log.Printf("%s service experienced an error: %s\n", serviceName, err)
|
||||
_ = elog.Error(1, fmt.Sprintf("%s experienced an error: %s", serviceName, err))
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
|
@ -7,12 +7,12 @@ import (
|
|||
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
"github.com/safing/portmaster/updates"
|
||||
_ "github.com/safing/portmaster/broadcasts"
|
||||
_ "github.com/safing/portmaster/netenv"
|
||||
_ "github.com/safing/portmaster/netquery"
|
||||
_ "github.com/safing/portmaster/status"
|
||||
_ "github.com/safing/portmaster/ui"
|
||||
"github.com/safing/portmaster/updates"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
17
go.mod
17
go.mod
|
@ -15,24 +15,17 @@ require (
|
|||
github.com/miekg/dns v1.1.50
|
||||
github.com/oschwald/maxminddb-golang v1.9.0
|
||||
github.com/safing/portbase v0.14.5
|
||||
github.com/safing/spn v0.4.12
|
||||
github.com/safing/spn v0.4.13
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/spf13/cobra v1.5.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/tannerryan/ring v1.1.2
|
||||
github.com/tevino/abool v1.2.0
|
||||
github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26
|
||||
golang.org/x/net v0.0.0-20220708220712-1185a9018129
|
||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f
|
||||
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e
|
||||
zombiezen.com/go/sqlite v0.10.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/tannerryan/ring v1.1.2
|
||||
github.com/tevino/abool v1.2.0
|
||||
github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26
|
||||
golang.org/x/net v0.0.0-20220622184535-263ec571b305
|
||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8
|
||||
zombiezen.com/go/sqlite v0.10.1
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -86,7 +79,7 @@ require (
|
|||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
|
||||
golang.org/x/tools v0.1.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.16.17 // indirect
|
||||
modernc.org/mathutil v1.4.1 // indirect
|
||||
modernc.org/memory v1.1.1 // indirect
|
||||
|
|
10
go.sum
10
go.sum
|
@ -869,6 +869,8 @@ github.com/safing/spn v0.4.7/go.mod h1:NoSG9K0OK9hrPC76yqWFS6RtvbqZdIc/KGOsC4T3h
|
|||
github.com/safing/spn v0.4.11/go.mod h1:nro/I6b2JnafeeqoMsQRqf6TaQeL9uLLZkUREtxLVDE=
|
||||
github.com/safing/spn v0.4.12 h1:Tw7TUZEZR4yZy7L+ICRCketDk5L5x0s0pvrSUHFaKs4=
|
||||
github.com/safing/spn v0.4.12/go.mod h1:AUNgBrRwCcspC98ljptDnrPuHLn/BHSG+rSprV/5Wlc=
|
||||
github.com/safing/spn v0.4.13 h1:5NXWUl/2EWyotrQhW3tD+3DYw7hEqQk0n0lHa+w4eFo=
|
||||
github.com/safing/spn v0.4.13/go.mod h1:rBeimIc1FHQOhX7lTh/LaFGRotmnwZIDWUSsPyeIDog=
|
||||
github.com/sagikazarmark/crypt v0.3.0/go.mod h1:uD/D+6UF4SrIR1uGEv7bBNkNqLGqUr43MRiaGWX1Nig=
|
||||
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
|
@ -1230,6 +1232,8 @@ golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su
|
|||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220621193019-9d032be2e588/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6flUMtzCwzlvkxEQtHHB0=
|
||||
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
@ -1391,6 +1395,8 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
@ -1748,5 +1754,5 @@ rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
|||
sigs.k8s.io/structured-merge-diff/v4 v4.0.2/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw=
|
||||
sigs.k8s.io/structured-merge-diff/v4 v4.1.2/go.mod h1:j/nl6xW8vLS49O8YvXW1ocPhZawJtm+Yrr7PPRQ0Vg4=
|
||||
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
|
||||
zombiezen.com/go/sqlite v0.10.0 h1:hegW0Y8c/fSJ2VjbjBeiKJaQOISNr4EUTx1VZx94Q9Y=
|
||||
zombiezen.com/go/sqlite v0.10.0/go.mod h1:tOd9u3peffVYnXOedepSJmX92n/mbqf594wcJ+29jf8=
|
||||
zombiezen.com/go/sqlite v0.10.1 h1:PSgVSHeIVOGKbX7ZIQNXGKn3wcqM6JBnT4yS1OLjWbM=
|
||||
zombiezen.com/go/sqlite v0.10.1/go.mod h1:tOd9u3peffVYnXOedepSJmX92n/mbqf594wcJ+29jf8=
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
// ChartHandler handles requests for connection charts.
|
||||
type ChartHandler struct {
|
||||
Database *Database
|
||||
}
|
||||
|
@ -55,14 +56,14 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
|||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
|
||||
enc.Encode(map[string]interface{}{
|
||||
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||
"results": result,
|
||||
"query": query,
|
||||
"params": paramMap,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) {
|
||||
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
|
||||
var body io.Reader
|
||||
|
||||
switch req.Method {
|
||||
|
|
|
@ -9,13 +9,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"github.com/safing/portmaster/network"
|
||||
"github.com/safing/portmaster/network/netutils"
|
||||
"github.com/safing/portmaster/network/packet"
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
// InMemory is the "file path" to open a new in-memory database.
|
||||
|
@ -36,7 +37,7 @@ var ConnectionTypeToString = map[network.ConnectionType]string{
|
|||
|
||||
type (
|
||||
// Database represents a SQLite3 backed connection database.
|
||||
// It's use is tailored for persistance and querying of network.Connection.
|
||||
// It's use is tailored for persistence and querying of network.Connection.
|
||||
// Access to the underlying SQLite database is synchronized.
|
||||
//
|
||||
// TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing
|
||||
|
@ -57,7 +58,7 @@ type (
|
|||
//
|
||||
// Use ConvertConnection from this package to convert a network.Connection to this
|
||||
// representation.
|
||||
Conn struct {
|
||||
Conn struct { //nolint:maligned
|
||||
// ID is a device-unique identifier for the connection. It is built
|
||||
// from network.Connection by hashing the connection ID and the start
|
||||
// time. We cannot just use the network.Connection.ID because it is only unique
|
||||
|
@ -93,11 +94,11 @@ type (
|
|||
ProfileRevision int `sqlite:"profile_revision"`
|
||||
ExitNode *string `sqlite:"exit_node"`
|
||||
|
||||
// FIXME(ppacher): support "NOT" in search query to get rid of the following helper fields
|
||||
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
|
||||
SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL"
|
||||
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
|
||||
|
||||
// FIXME(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
|
||||
// TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
|
||||
ProfileName string `sqlite:"profile_name"`
|
||||
}
|
||||
)
|
||||
|
@ -153,9 +154,9 @@ func NewInMemory() (*Database, error) {
|
|||
// to bring db up-to-date with the built-in schema.
|
||||
// TODO(ppacher): right now this only applies the current schema and ignores
|
||||
// any data-migrations. Once the history module is implemented this should
|
||||
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration
|
||||
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
|
||||
func (db *Database) ApplyMigrations() error {
|
||||
// get the create-table SQL statement from the infered schema
|
||||
// get the create-table SQL statement from the inferred schema
|
||||
sql := db.Schema.CreateStatement(false)
|
||||
|
||||
// execute the SQL
|
||||
|
@ -234,7 +235,7 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
|
|||
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
|
||||
// as JSON to w.
|
||||
// Any error aborts dumping rows and is returned.
|
||||
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
|
||||
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused
|
||||
db.l.Lock()
|
||||
defer db.l.Unlock()
|
||||
|
||||
|
|
|
@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
|||
}
|
||||
}
|
||||
|
||||
func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error {
|
||||
func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error {
|
||||
blob, err := json.Marshal(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal connection: %w", err)
|
||||
|
@ -173,17 +173,19 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
|
|||
c.Type = "dns"
|
||||
case network.IPConnection:
|
||||
c.Type = "ip"
|
||||
case network.Undefined:
|
||||
c.Type = ""
|
||||
}
|
||||
|
||||
switch conn.Verdict {
|
||||
case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
|
||||
accepted := true
|
||||
c.Allowed = &accepted
|
||||
case network.VerdictUndecided, network.VerdictUndeterminable:
|
||||
c.Allowed = nil
|
||||
default:
|
||||
case network.VerdictBlock, network.VerdictDrop:
|
||||
allowed := false
|
||||
c.Allowed = &allowed
|
||||
case network.VerdictUndecided, network.VerdictUndeterminable, network.VerdictFailed:
|
||||
c.Allowed = nil
|
||||
}
|
||||
|
||||
if conn.Ended > 0 {
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/safing/portmaster/network"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
type module struct {
|
||||
*modules.Module
|
||||
|
||||
db *database.Interface
|
||||
|
@ -25,19 +25,19 @@ type Module struct {
|
|||
}
|
||||
|
||||
func init() {
|
||||
mod := new(Module)
|
||||
mod.Module = modules.Register(
|
||||
m := new(module)
|
||||
m.Module = modules.Register(
|
||||
"netquery",
|
||||
mod.Prepare,
|
||||
mod.Start,
|
||||
mod.Stop,
|
||||
m.prepare,
|
||||
m.start,
|
||||
m.stop,
|
||||
"api",
|
||||
"network",
|
||||
"database",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Module) Prepare() error {
|
||||
func (m *module) prepare() error {
|
||||
var err error
|
||||
|
||||
m.db = database.NewInterface(&database.Options{
|
||||
|
@ -66,7 +66,6 @@ func (m *Module) Prepare() error {
|
|||
Database: m.sqlStore,
|
||||
}
|
||||
|
||||
// FIXME(ppacher): use appropriate permissions for this
|
||||
if err := api.RegisterEndpoint(api.Endpoint{
|
||||
Path: "netquery/query",
|
||||
MimeType: "application/json",
|
||||
|
@ -96,13 +95,15 @@ func (m *Module) Prepare() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (mod *Module) Start() error {
|
||||
mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
|
||||
sub, err := mod.db.Subscribe(query.New("network:"))
|
||||
func (m *module) start() error {
|
||||
m.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
|
||||
sub, err := m.db.Subscribe(query.New("network:"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to network tree: %w", err)
|
||||
}
|
||||
defer sub.Cancel()
|
||||
defer func() {
|
||||
_ = sub.Cancel()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
|
@ -120,24 +121,24 @@ func (mod *Module) Start() error {
|
|||
continue
|
||||
}
|
||||
|
||||
mod.feed <- conn
|
||||
m.feed <- conn
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
|
||||
mod.mng.HandleFeed(ctx, mod.feed)
|
||||
m.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
|
||||
m.mng.HandleFeed(ctx, m.feed)
|
||||
return nil
|
||||
})
|
||||
|
||||
mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
|
||||
m.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(10 * time.Second):
|
||||
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
|
||||
count, err := mod.sqlStore.Cleanup(ctx, threshold)
|
||||
count, err := m.sqlStore.Cleanup(ctx, threshold)
|
||||
if err != nil {
|
||||
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
|
||||
} else {
|
||||
|
@ -147,19 +148,21 @@ func (mod *Module) Start() error {
|
|||
}
|
||||
})
|
||||
|
||||
// for debugging, we provide a simple direct SQL query interface using
|
||||
// the runtime database
|
||||
// FIXME: Expose only in dev mode.
|
||||
_, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
|
||||
// For debugging, provide a simple direct SQL query interface using
|
||||
// the runtime database.
|
||||
// Only expose in development mode.
|
||||
if config.GetAsBool(config.CfgDevModeKey, false)() {
|
||||
_, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mod *Module) Stop() error {
|
||||
close(mod.feed)
|
||||
func (m *module) stop() error {
|
||||
close(m.feed)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -30,7 +29,7 @@ var (
|
|||
// TEXT or REAL.
|
||||
// This package provides support for time.Time being stored as TEXT (using a
|
||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||
// unixepoch and unixnano-epoch where the nano variant is not offically supported by
|
||||
// unixepoch and unixnano-epoch where the nano variant is not officially supported by
|
||||
// SQLITE).
|
||||
SqliteTimeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
@ -54,6 +53,7 @@ type (
|
|||
// DecodeFunc is called for each non-basic type during decoding.
|
||||
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// DecodeConfig holds decoding functions.
|
||||
DecodeConfig struct {
|
||||
DecodeHooks []DecodeFunc
|
||||
}
|
||||
|
@ -170,7 +170,8 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
|
|||
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
||||
}
|
||||
|
||||
//log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||
// Debugging:
|
||||
// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||
|
||||
// convert it to the target type if conversion is possible
|
||||
newValue := reflect.ValueOf(columnValue)
|
||||
|
@ -189,8 +190,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
|
|||
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
||||
// decide between Unix or UnixNano epoch timestamps.
|
||||
//
|
||||
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
|
||||
//
|
||||
// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
|
||||
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
|
||||
// if we have the column definition available we
|
||||
|
@ -203,11 +203,11 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
|||
|
||||
// we only care about "time.Time" here
|
||||
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
|
||||
log.Printf("not decoding %s %v", outType, colDef)
|
||||
// log.Printf("not decoding %s %v", outType, colDef)
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch stmt.ColumnType(colIdx) {
|
||||
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
|
||||
case sqlite.TypeInteger:
|
||||
// stored as unix-epoch, if unixnano is set in the struct field tag
|
||||
// we parse it with nano-second resolution
|
||||
|
@ -242,7 +242,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
||||
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
||||
if *mp == nil {
|
||||
*mp = make(map[string]interface{})
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ func decodeBasic() DecodeFunc {
|
|||
if colDef != nil {
|
||||
valueKind = normalizeKind(colDef.GoType.Kind())
|
||||
|
||||
// if we have a column defintion we try to convert the value to
|
||||
// if we have a column definition we try to convert the value to
|
||||
// the actual Go-type that was used in the model.
|
||||
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
|
||||
// or that type aliases like (type myInt int) are decoded into myInt instead of int
|
||||
|
@ -314,7 +314,7 @@ func decodeBasic() DecodeFunc {
|
|||
}()
|
||||
}
|
||||
|
||||
log.Printf("decoding %s into kind %s", colName, valueKind)
|
||||
// log.Printf("decoding %s into kind %s", colName, valueKind)
|
||||
|
||||
if colType == sqlite.TypeNull {
|
||||
if colDef != nil && colDef.Nullable {
|
||||
|
@ -330,7 +330,7 @@ func decodeBasic() DecodeFunc {
|
|||
}
|
||||
}
|
||||
|
||||
switch valueKind {
|
||||
switch valueKind { //nolint:exhaustive
|
||||
case reflect.String:
|
||||
if colType != sqlite.TypeText {
|
||||
return nil, false, errInvalidType
|
||||
|
@ -455,7 +455,7 @@ func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.S
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// getKind returns the kind of value but normalized Int, Uint and Float varaints
|
||||
// getKind returns the kind of value but normalized Int, Uint and Float variants.
|
||||
// to their base type.
|
||||
func getKind(val reflect.Value) reflect.Kind {
|
||||
kind := val.Kind()
|
||||
|
@ -475,6 +475,7 @@ func normalizeKind(kind reflect.Kind) reflect.Kind {
|
|||
}
|
||||
}
|
||||
|
||||
// DefaultDecodeConfig holds the default decoding configuration.
|
||||
var DefaultDecodeConfig = DecodeConfig{
|
||||
DecodeHooks: []DecodeFunc{
|
||||
DatetimeDecoder(time.UTC),
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -21,14 +20,14 @@ type testStmt struct {
|
|||
|
||||
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
||||
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) }
|
||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) }
|
||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) }
|
||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) }
|
||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) }
|
||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
|
||||
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
||||
|
||||
// compile time check
|
||||
// Compile time check.
|
||||
var _ Stmt = new(testStmt)
|
||||
|
||||
type exampleFieldTypes struct {
|
||||
|
@ -98,10 +97,11 @@ func (etn *exampleTimeNano) Equal(other interface{}) bool {
|
|||
return etn.T.Equal(oetn.T)
|
||||
}
|
||||
|
||||
func Test_Decoder(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
|
||||
t.Parallel()
|
||||
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
|
@ -433,8 +433,7 @@ func Test_Decoder(t *testing.T) {
|
|||
nil,
|
||||
&exampleInterface{},
|
||||
func() interface{} {
|
||||
var x interface{}
|
||||
x = "value2"
|
||||
var x interface{} = "value2"
|
||||
|
||||
return &exampleInterface{
|
||||
I: "value1",
|
||||
|
@ -546,12 +545,10 @@ func Test_Decoder(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
|
||||
log.Println(c.Desc)
|
||||
// log.Println(c.Desc)
|
||||
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
|
||||
if fn, ok := c.Expected.(func() interface{}); ok {
|
||||
c.Expected = fn()
|
||||
|
|
|
@ -10,8 +10,10 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
// EncodeFunc is called for each non-basic type during encoding.
|
||||
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
||||
|
||||
// EncodeConfig holds encoding functions.
|
||||
EncodeConfig struct {
|
||||
EncodeHooks []EncodeFunc
|
||||
}
|
||||
|
@ -69,6 +71,7 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
|
|||
return res, nil
|
||||
}
|
||||
|
||||
// EncodeValue encodes the given value.
|
||||
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
||||
fieldValue := reflect.ValueOf(val)
|
||||
fieldType := reflect.TypeOf(val)
|
||||
|
@ -115,7 +118,7 @@ func encodeBasic() EncodeFunc {
|
|||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch normalizeKind(kind) {
|
||||
switch normalizeKind(kind) { //nolint:exhaustive
|
||||
case reflect.String,
|
||||
reflect.Float64,
|
||||
reflect.Bool,
|
||||
|
@ -138,6 +141,7 @@ func encodeBasic() EncodeFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// DatetimeEncoder returns a new datetime encoder for the given time zone.
|
||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||
// if fieldType holds a pointer we need to dereference the value
|
||||
|
@ -149,7 +153,8 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
|||
|
||||
// we only care about "time.Time" here
|
||||
var t time.Time
|
||||
if ft == "time.Time" {
|
||||
switch {
|
||||
case ft == "time.Time":
|
||||
// handle the zero time as a NULL.
|
||||
if !val.IsValid() || val.IsZero() {
|
||||
return nil, true, nil
|
||||
|
@ -162,19 +167,19 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
|||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||
}
|
||||
|
||||
} else if valType.Kind() == reflect.String && colDef.IsTime {
|
||||
case valType.Kind() == reflect.String && colDef.IsTime:
|
||||
var err error
|
||||
t, err = time.Parse(time.RFC3339, val.String())
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
default:
|
||||
// we don't care ...
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch colDef.Type {
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeInteger:
|
||||
if colDef.UnixNano {
|
||||
return t.UnixNano(), true, nil
|
||||
|
@ -194,7 +199,7 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
|||
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||
if valType == nil {
|
||||
if !colDef.Nullable {
|
||||
switch colDef.Type {
|
||||
switch colDef.Type { //nolint:exhaustive
|
||||
case sqlite.TypeBlob:
|
||||
return []byte{}, true, nil
|
||||
case sqlite.TypeFloat:
|
||||
|
@ -225,6 +230,7 @@ func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value,
|
|||
return nil, false, nil
|
||||
}
|
||||
|
||||
// DefaultEncodeConfig holds the default encoding configuration.
|
||||
var DefaultEncodeConfig = EncodeConfig{
|
||||
EncodeHooks: []EncodeFunc{
|
||||
DatetimeEncoder(time.UTC),
|
||||
|
|
|
@ -9,9 +9,11 @@ import (
|
|||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
func Test_EncodeAsMap(t *testing.T) {
|
||||
func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
|
@ -114,11 +116,9 @@ func Test_EncodeAsMap(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
|
||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Expected, res)
|
||||
|
@ -126,9 +126,11 @@ func Test_EncodeAsMap(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_EncodeValue(t *testing.T) {
|
||||
func TestEncodeValue(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.TODO()
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
||||
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||
|
||||
cases := []struct {
|
||||
Desc string
|
||||
|
@ -247,11 +249,9 @@ func Test_EncodeValue(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for idx := range cases {
|
||||
for idx := range cases { //nolint:paralleltest
|
||||
c := cases[idx]
|
||||
t.Run(c.Desc, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
|
||||
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.Output, res)
|
||||
|
|
|
@ -57,6 +57,8 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSchema returns a query option that adds the given table
|
||||
// schema to the query.
|
||||
func WithSchema(tbl TableSchema) QueryOption {
|
||||
return func(opts *queryOpts) {
|
||||
opts.Schema = tbl
|
||||
|
@ -139,9 +141,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
|
|||
valElemType = valType.Elem()
|
||||
|
||||
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
var currentField reflect.Value
|
||||
|
||||
currentField = reflect.New(valElemType)
|
||||
currentField := reflect.New(valElemType)
|
||||
|
||||
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||
return err
|
||||
|
|
|
@ -10,10 +10,9 @@ import (
|
|||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
errSkipStructField = errors.New("struct field should be skipped")
|
||||
)
|
||||
var errSkipStructField = errors.New("struct field should be skipped")
|
||||
|
||||
// Struct Tags.
|
||||
var (
|
||||
TagUnixNano = "unixnano"
|
||||
TagPrimaryKey = "primary"
|
||||
|
@ -36,12 +35,14 @@ var sqlTypeMap = map[sqlite.ColumnType]string{
|
|||
}
|
||||
|
||||
type (
|
||||
// TableSchema defines a SQL table schema.
|
||||
TableSchema struct {
|
||||
Name string
|
||||
Columns []ColumnDef
|
||||
}
|
||||
|
||||
ColumnDef struct {
|
||||
// ColumnDef defines a SQL column.
|
||||
ColumnDef struct { //nolint:maligned
|
||||
Name string
|
||||
Nullable bool
|
||||
Type sqlite.ColumnType
|
||||
|
@ -54,6 +55,7 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
// GetColumnDef returns the column definition with the given name.
|
||||
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||
for _, def := range ts.Columns {
|
||||
if def.Name == name {
|
||||
|
@ -63,6 +65,7 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CreateStatement build the CREATE SQL statement for the table.
|
||||
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
||||
sql := "CREATE TABLE"
|
||||
if ifNotExists {
|
||||
|
@ -81,6 +84,7 @@ func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
|||
return sql
|
||||
}
|
||||
|
||||
// AsSQL builds the SQL column definition.
|
||||
func (def ColumnDef) AsSQL() string {
|
||||
sql := def.Name + " "
|
||||
|
||||
|
@ -103,6 +107,7 @@ func (def ColumnDef) AsSQL() string {
|
|||
return sql
|
||||
}
|
||||
|
||||
// GenerateTableSchema generates a table schema from the given struct.
|
||||
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
||||
ts := &TableSchema{
|
||||
Name: name,
|
||||
|
@ -149,7 +154,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
|||
def.GoType = ft
|
||||
kind := normalizeKind(ft.Kind())
|
||||
|
||||
switch kind {
|
||||
switch kind { //nolint:exhaustive
|
||||
case reflect.Int:
|
||||
def.Type = sqlite.TypeInteger
|
||||
|
||||
|
@ -190,7 +195,7 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
|||
if len(parts) > 1 {
|
||||
for _, k := range parts[1:] {
|
||||
switch k {
|
||||
// column modifieres
|
||||
// column modifiers
|
||||
case TagPrimaryKey:
|
||||
def.PrimaryKey = true
|
||||
case TagAutoIncrement:
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_SchemaBuilder(t *testing.T) {
|
||||
func TestSchemaBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Model interface{}
|
||||
|
|
|
@ -5,15 +5,19 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"zombiezen.com/go/sqlite"
|
||||
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
// Collection of Query and Matcher types.
|
||||
// NOTE: whenever adding support for new operators make sure
|
||||
// to update UnmarshalJSON as well.
|
||||
//nolint:golint
|
||||
type (
|
||||
Query map[string][]Matcher
|
||||
|
||||
|
@ -43,8 +47,6 @@ type (
|
|||
Distinct bool `json:"distinct"`
|
||||
}
|
||||
|
||||
// NOTE: whenever adding support for new operators make sure
|
||||
// to update UnmarshalJSON as well.
|
||||
Select struct {
|
||||
Field string `json:"field"`
|
||||
Count *Count `json:"$count,omitempty"`
|
||||
|
@ -91,6 +93,7 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
// UnmarshalJSON unmarshals a Query from json.
|
||||
func (query *Query) UnmarshalJSON(blob []byte) error {
|
||||
if *query == nil {
|
||||
*query = make(Query)
|
||||
|
@ -202,13 +205,14 @@ func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
|||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid query matcher: %s", err)
|
||||
return nil, fmt.Errorf("invalid query matcher: %w", err)
|
||||
}
|
||||
log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||
return &m, nil
|
||||
|
||||
// log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// Validate validates the matcher.
|
||||
func (match Matcher) Validate() error {
|
||||
found := 0
|
||||
|
||||
|
@ -239,9 +243,9 @@ func (match Matcher) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (text TextSearch) toSQLConditionClause(ctx context.Context, schema *orm.TableSchema, suffix string, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||
var (
|
||||
queryParts []string
|
||||
queryParts = make([]string, 0, len(text.Fields))
|
||||
params = make(map[string]interface{})
|
||||
)
|
||||
|
||||
|
@ -379,7 +383,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
|
|||
// merge parameters up into the superior parameter map
|
||||
for key, val := range params {
|
||||
if _, ok := paramMap[key]; ok {
|
||||
// is is soley a developer mistake when implementing a matcher so no forgiving ...
|
||||
// This is solely a developer mistake when implementing a matcher so no forgiving ...
|
||||
panic("sqlite parameter collision")
|
||||
}
|
||||
|
||||
|
@ -399,6 +403,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
|
|||
return whereClause, paramMap, errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Selects from json.
|
||||
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
|
@ -438,6 +443,7 @@ func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a Select from json.
|
||||
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
|
@ -481,6 +487,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBys from json.
|
||||
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
|
@ -523,6 +530,7 @@ func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals a OrderBy from json.
|
||||
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
||||
if len(blob) == 0 {
|
||||
return io.ErrUnexpectedEOF
|
||||
|
|
|
@ -17,9 +17,7 @@ import (
|
|||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
var (
|
||||
charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||
)
|
||||
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||
|
||||
type (
|
||||
|
||||
|
@ -109,7 +107,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) {
|
||||
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl
|
||||
var body io.Reader
|
||||
|
||||
switch req.Method {
|
||||
|
@ -230,11 +228,11 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
|||
|
||||
switch {
|
||||
case s.Count != nil:
|
||||
var as = s.Count.As
|
||||
as := s.Count.As
|
||||
if as == "" {
|
||||
as = fmt.Sprintf("%s_count", colName)
|
||||
}
|
||||
var distinct = ""
|
||||
distinct := ""
|
||||
if s.Count.Distinct {
|
||||
distinct = "DISTINCT "
|
||||
}
|
||||
|
@ -278,8 +276,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
|||
return "", nil
|
||||
}
|
||||
|
||||
var groupBys = make([]string, len(req.GroupBy))
|
||||
|
||||
groupBys := make([]string, len(req.GroupBy))
|
||||
for idx, name := range req.GroupBy {
|
||||
colName, err := req.validateColumnName(schema, name)
|
||||
if err != nil {
|
||||
|
@ -288,7 +285,6 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
|||
|
||||
groupBys[idx] = colName
|
||||
}
|
||||
|
||||
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
||||
|
||||
// if there are no explicitly selected fields we default to the
|
||||
|
@ -301,7 +297,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
|||
}
|
||||
|
||||
func (req *QueryRequestPayload) generateSelectClause() string {
|
||||
var selectClause = "*"
|
||||
selectClause := "*"
|
||||
if len(req.selectedFields) > 0 {
|
||||
selectClause = strings.Join(req.selectedFields, ", ")
|
||||
}
|
||||
|
@ -314,7 +310,7 @@ func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (
|
|||
return "", nil
|
||||
}
|
||||
|
||||
var orderBys = make([]string, len(req.OrderBy))
|
||||
orderBys := make([]string, len(req.OrderBy))
|
||||
for idx, sort := range req.OrderBy {
|
||||
colName, err := req.validateColumnName(schema, sort.Field)
|
||||
if err != nil {
|
||||
|
@ -352,5 +348,5 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
|
|||
return "", fmt.Errorf("column name %q not allowed", field)
|
||||
}
|
||||
|
||||
// compile time check
|
||||
// Compile time check.
|
||||
var _ http.Handler = new(QueryHandler)
|
||||
|
|
|
@ -7,13 +7,16 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/safing/portmaster/netquery/orm"
|
||||
)
|
||||
|
||||
func Test_UnmarshalQuery(t *testing.T) {
|
||||
var cases = []struct {
|
||||
func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Input string
|
||||
Expected Query
|
||||
|
@ -88,7 +91,8 @@ func Test_UnmarshalQuery(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
for _, testCase := range cases { //nolint:paralleltest
|
||||
c := testCase
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
var q Query
|
||||
err := json.Unmarshal([]byte(c.Input), &q)
|
||||
|
@ -105,10 +109,11 @@ func Test_UnmarshalQuery(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_QueryBuilder(t *testing.T) {
|
||||
now := time.Now()
|
||||
func TestQueryBuilder(t *testing.T) { //nolint:tparallel
|
||||
t.Parallel()
|
||||
|
||||
var cases = []struct {
|
||||
now := time.Now()
|
||||
cases := []struct {
|
||||
N string
|
||||
Q Query
|
||||
R string
|
||||
|
@ -186,7 +191,7 @@ func Test_QueryBuilder(t *testing.T) {
|
|||
},
|
||||
"",
|
||||
nil,
|
||||
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"),
|
||||
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint
|
||||
},
|
||||
{
|
||||
"Complex example",
|
||||
|
@ -225,19 +230,20 @@ func Test_QueryBuilder(t *testing.T) {
|
|||
tbl, err := orm.GenerateTableSchema("connections", Conn{})
|
||||
require.NoError(t, err)
|
||||
|
||||
for idx, c := range cases {
|
||||
for idx, testCase := range cases { //nolint:paralleltest
|
||||
cID := idx
|
||||
c := testCase
|
||||
t.Run(c.N, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
|
||||
|
||||
if c.E != nil {
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx)
|
||||
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "test case %d", idx)
|
||||
assert.Equal(t, c.P, params, "test case %d", idx)
|
||||
assert.Equal(t, c.R, str, "test case %d", idx)
|
||||
assert.NoError(t, err, "test case %d", cID)
|
||||
assert.Equal(t, c.P, params, "test case %d", cID)
|
||||
assert.Equal(t, c.R, str, "test case %d", cID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -11,8 +11,11 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
cleanerTickDuration = 5 * time.Second
|
||||
// DeleteConnsAfterEndedThreshold defines the amount of time after which
|
||||
// ended connections should be removed from the internal connection state.
|
||||
DeleteConnsAfterEndedThreshold = 10 * time.Minute
|
||||
|
||||
cleanerTickDuration = 5 * time.Second
|
||||
)
|
||||
|
||||
func connectionCleaner(ctx context.Context) error {
|
||||
|
|
|
@ -112,7 +112,7 @@ The format is: "protocol://ip:port?parameter=value¶meter=value"
|
|||
ExpertiseLevel: config.ExpertiseLevelUser,
|
||||
ReleaseLevel: config.ReleaseLevelStable,
|
||||
DefaultValue: defaultNameServers,
|
||||
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)://.*", ServerTypeDoT, ServerTypeDNS, ServerTypeTCP),
|
||||
ValidationRegex: fmt.Sprintf("^(%s|%s|%s|%s|%s|%s)://.*", ServerTypeDoT, ServerTypeDoH, ServerTypeDNS, ServerTypeTCP, HTTPSProtocol, TLSProtocol),
|
||||
ValidationFunc: validateNameservers,
|
||||
Annotations: config.Annotations{
|
||||
config.DisplayHintAnnotation: config.DisplayHintOrdered,
|
||||
|
@ -123,32 +123,31 @@ The format is: "protocol://ip:port?parameter=value¶meter=value"
|
|||
Name: "Cloudflare (with Malware Filter)",
|
||||
Action: config.QuickReplace,
|
||||
Value: []string{
|
||||
"dot://1.1.1.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip",
|
||||
"dot://1.0.0.2:853?verify=cloudflare-dns.com&name=Cloudflare&blockedif=zeroip",
|
||||
"dot://cloudflare-dns.com?ip=1.1.1.2&name=Cloudflare&blockedif=zeroip",
|
||||
"dot://cloudflare-dns.com?ip=1.0.0.2&name=Cloudflare&blockedif=zeroip",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Quad9",
|
||||
Action: config.QuickReplace,
|
||||
Value: []string{
|
||||
"dot://9.9.9.9:853?verify=dns.quad9.net&name=Quad9&blockedif=empty",
|
||||
"dot://149.112.112.112:853?verify=dns.quad9.net&name=Quad9&blockedif=empty",
|
||||
"dot://dns.quad9.net?ip=9.9.9.9&name=Quad9&blockedif=empty",
|
||||
"dot://dns.quad9.net?ip=149.112.112.112&name=Quad9&blockedif=empty",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "AdGuard",
|
||||
Action: config.QuickReplace,
|
||||
Value: []string{
|
||||
"dot://94.140.14.14:853?verify=dns.adguard.com&name=AdGuard&blockedif=zeroip",
|
||||
"dot://94.140.15.15:853?verify=dns.adguard.com&name=AdGuard&blockedif=zeroip",
|
||||
"dot://dns.adguard.com?ip=94.140.14.14&name=AdGuard&blockedif=zeroip",
|
||||
"dot://dns.adguard.com?ip=94.140.15.15&name=AdGuard&blockedif=zeroip",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Foundation for Applied Privacy",
|
||||
Action: config.QuickReplace,
|
||||
Value: []string{
|
||||
"dot://94.130.106.88:853?verify=dot1.applied-privacy.net&name=AppliedPrivacy",
|
||||
"dot://94.130.106.88:443?verify=dot1.applied-privacy.net&name=AppliedPrivacy",
|
||||
"dot://dot1.applied-privacy.net?ip=94.130.106.88&name=AppliedPrivacy",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
120
resolver/resolver-https.go
Normal file
120
resolver/resolver-https.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// HTTPSResolver is a resolver using just a single tcp connection with pipelining.
|
||||
type HTTPSResolver struct {
|
||||
BasicResolverConn
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
// HTTPSQuery holds the query information for a hTTPSResolverConn.
|
||||
type HTTPSQuery struct {
|
||||
Query *Query
|
||||
Response chan *dns.Msg
|
||||
}
|
||||
|
||||
// MakeCacheRecord creates an RRCache record from a reply.
|
||||
func (tq *HTTPSQuery) MakeCacheRecord(reply *dns.Msg, resolverInfo *ResolverInfo) *RRCache {
|
||||
return &RRCache{
|
||||
Domain: tq.Query.FQDN,
|
||||
Question: tq.Query.QType,
|
||||
RCode: reply.Rcode,
|
||||
Answer: reply.Answer,
|
||||
Ns: reply.Ns,
|
||||
Extra: reply.Extra,
|
||||
Resolver: resolverInfo.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPSResolver returns a new HTTPSResolver.
|
||||
func NewHTTPSResolver(resolver *Resolver) *HTTPSResolver {
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: resolver.Info.Domain,
|
||||
// TODO: use portbase rng
|
||||
},
|
||||
IdleConnTimeout: 3 * time.Minute,
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: tr}
|
||||
newResolver := &HTTPSResolver{
|
||||
BasicResolverConn: BasicResolverConn{
|
||||
resolver: resolver,
|
||||
},
|
||||
Client: client,
|
||||
}
|
||||
newResolver.BasicResolverConn.init()
|
||||
return newResolver
|
||||
}
|
||||
|
||||
// Query executes the given query against the resolver.
|
||||
func (hr *HTTPSResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
|
||||
dnsQuery := new(dns.Msg)
|
||||
dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
|
||||
|
||||
// Pack query and convert to base64 string
|
||||
buf, err := dnsQuery.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b64dns := base64.RawStdEncoding.EncodeToString(buf)
|
||||
|
||||
// Build and execute http reuqest
|
||||
url := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: hr.resolver.ServerAddress,
|
||||
Path: hr.resolver.Path,
|
||||
ForceQuery: true,
|
||||
RawQuery: fmt.Sprintf("dns=%s", b64dns),
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := hr.Client.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Try to read the result
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply := new(dns.Msg)
|
||||
|
||||
err = reply.Unpack(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRecord := &RRCache{
|
||||
Domain: q.FQDN,
|
||||
Question: q.QType,
|
||||
RCode: reply.Rcode,
|
||||
Answer: reply.Answer,
|
||||
Ns: reply.Ns,
|
||||
Extra: reply.Extra,
|
||||
Resolver: hr.resolver.Info.Copy(),
|
||||
}
|
||||
|
||||
// TODO: check if reply.Answer is valid
|
||||
return newRecord, nil
|
||||
}
|
|
@ -99,7 +99,7 @@ func (tr *TCPResolver) UseTLS() *TCPResolver {
|
|||
tr.dnsClient.Net = "tcp-tls"
|
||||
tr.dnsClient.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: tr.resolver.VerifyDomain,
|
||||
ServerName: tr.resolver.Info.Domain,
|
||||
// TODO: use portbase rng
|
||||
}
|
||||
return tr
|
||||
|
|
|
@ -30,6 +30,12 @@ const (
|
|||
ServerSourceEnv = "env"
|
||||
)
|
||||
|
||||
// DNS Resolver alias
|
||||
const (
|
||||
HTTPSProtocol = "https"
|
||||
TLSProtocol = "tls"
|
||||
)
|
||||
|
||||
// FailThreshold is amount of errors a resolvers must experience in order to be regarded as failed.
|
||||
var FailThreshold = 20
|
||||
|
||||
|
@ -61,9 +67,9 @@ type Resolver struct {
|
|||
UpstreamBlockDetection string
|
||||
|
||||
// Special Options
|
||||
VerifyDomain string
|
||||
Search []string
|
||||
SearchOnly bool
|
||||
Search []string
|
||||
SearchOnly bool
|
||||
Path string
|
||||
|
||||
// logic interface
|
||||
Conn ResolverConn `json:"-"`
|
||||
|
@ -87,6 +93,9 @@ type ResolverInfo struct { //nolint:golint,maligned // TODO
|
|||
// IP is the IP address of the resolver
|
||||
IP net.IP
|
||||
|
||||
// Domain of the dns server if it has one
|
||||
Domain string
|
||||
|
||||
// IPScope is the network scope of the IP address.
|
||||
IPScope netutils.IPScope
|
||||
|
||||
|
@ -107,6 +116,20 @@ func (info *ResolverInfo) ID() string {
|
|||
info.id = ServerTypeMDNS
|
||||
case ServerTypeEnv:
|
||||
info.id = ServerTypeEnv
|
||||
case ServerTypeDoH:
|
||||
info.id = fmt.Sprintf(
|
||||
"https://%s:%d#%s",
|
||||
info.Domain,
|
||||
info.Port,
|
||||
info.Source,
|
||||
)
|
||||
case ServerTypeDoT:
|
||||
info.id = fmt.Sprintf(
|
||||
"dot://%s:%d#%s",
|
||||
info.Domain,
|
||||
info.Port,
|
||||
info.Source,
|
||||
)
|
||||
default:
|
||||
info.id = fmt.Sprintf(
|
||||
"%s://%s:%d#%s",
|
||||
|
@ -135,6 +158,12 @@ func (info *ResolverInfo) DescriptiveName() string {
|
|||
info.Name,
|
||||
info.ID(),
|
||||
)
|
||||
case info.Domain != "":
|
||||
return fmt.Sprintf(
|
||||
"%s (%s)",
|
||||
info.Domain,
|
||||
info.ID(),
|
||||
)
|
||||
default:
|
||||
return fmt.Sprintf(
|
||||
"%s (%s)",
|
||||
|
@ -155,6 +184,7 @@ func (info *ResolverInfo) Copy() *ResolverInfo {
|
|||
Type: info.Type,
|
||||
Source: info.Source,
|
||||
IP: info.IP,
|
||||
Domain: info.Domain,
|
||||
IPScope: info.IPScope,
|
||||
Port: info.Port,
|
||||
id: info.id,
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/utils"
|
||||
"github.com/safing/portmaster/netenv"
|
||||
|
@ -29,9 +30,11 @@ type Scope struct {
|
|||
const (
|
||||
parameterName = "name"
|
||||
parameterVerify = "verify"
|
||||
parameterIP = "ip"
|
||||
parameterBlockedIf = "blockedif"
|
||||
parameterSearch = "search"
|
||||
parameterSearchOnly = "search-only"
|
||||
parameterPath = "path"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -41,7 +44,9 @@ var (
|
|||
localScopes []*Scope // list of scopes with a list of local resolvers that can resolve the scope
|
||||
activeResolvers map[string]*Resolver // lookup map of all resolvers
|
||||
currentResolverConfig []string // current active resolver config, to detect changes
|
||||
resolversLock sync.RWMutex
|
||||
resolverInitDomains map[string]struct{} // a set with all domains of the dns resolvers
|
||||
|
||||
resolversLock sync.RWMutex
|
||||
)
|
||||
|
||||
func indexOfScope(domain string, list []*Scope) int {
|
||||
|
@ -80,6 +85,8 @@ func resolverConnFactory(resolver *Resolver) ResolverConn {
|
|||
return NewTCPResolver(resolver)
|
||||
case ServerTypeDoT:
|
||||
return NewTCPResolver(resolver).UseTLS()
|
||||
case ServerTypeDoH:
|
||||
return NewHTTPSResolver(resolver)
|
||||
case ServerTypeDNS:
|
||||
return NewPlainResolver(resolver)
|
||||
default:
|
||||
|
@ -93,93 +100,64 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
|
|||
return nil, false, err
|
||||
}
|
||||
|
||||
if resolverInitDomains == nil {
|
||||
resolverInitDomains = make(map[string]struct{})
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case ServerTypeDNS, ServerTypeDoT, ServerTypeTCP:
|
||||
case ServerTypeDNS, ServerTypeDoT, ServerTypeDoH, ServerTypeTCP:
|
||||
case HTTPSProtocol:
|
||||
u.Scheme = ServerTypeDoH
|
||||
case TLSProtocol:
|
||||
u.Scheme = ServerTypeDoT
|
||||
default:
|
||||
return nil, false, fmt.Errorf("DNS resolver scheme %q invalid", u.Scheme)
|
||||
}
|
||||
|
||||
ip := net.ParseIP(u.Hostname())
|
||||
if ip == nil {
|
||||
return nil, false, fmt.Errorf("resolver IP %q invalid", u.Hostname())
|
||||
}
|
||||
|
||||
// Add default port for scheme if it is missing.
|
||||
var port uint16
|
||||
hostPort := u.Port()
|
||||
switch {
|
||||
case hostPort != "":
|
||||
parsedPort, err := strconv.ParseUint(hostPort, 10, 16)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("resolver port %q invalid", u.Port())
|
||||
}
|
||||
port = uint16(parsedPort)
|
||||
case u.Scheme == ServerTypeDNS, u.Scheme == ServerTypeTCP:
|
||||
port = 53
|
||||
case u.Scheme == ServerTypeDoH:
|
||||
port = 443
|
||||
case u.Scheme == ServerTypeDoT:
|
||||
port = 853
|
||||
default:
|
||||
return nil, false, fmt.Errorf("missing port in %q", u.Host)
|
||||
}
|
||||
|
||||
scope := netutils.GetIPScope(ip)
|
||||
// Skip localhost resolvers from the OS, but not if configured.
|
||||
if scope.IsLocalhost() && source == ServerSourceOperatingSystem {
|
||||
return nil, true, nil // skip
|
||||
}
|
||||
|
||||
// Get parameters and check if keys exist.
|
||||
query := u.Query()
|
||||
for key := range query {
|
||||
switch key {
|
||||
case parameterName,
|
||||
parameterVerify,
|
||||
parameterBlockedIf,
|
||||
parameterSearch,
|
||||
parameterSearchOnly:
|
||||
// Known key, continue.
|
||||
default:
|
||||
// Unknown key, abort.
|
||||
return nil, false, fmt.Errorf(`unknown parameter "%s"`, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Check domain verification config.
|
||||
verifyDomain := query.Get(parameterVerify)
|
||||
if verifyDomain != "" && u.Scheme != ServerTypeDoT {
|
||||
return nil, false, fmt.Errorf("domain verification only supported in DOT")
|
||||
}
|
||||
if verifyDomain == "" && u.Scheme == ServerTypeDoT {
|
||||
return nil, false, fmt.Errorf("DOT must have a verify query parameter set")
|
||||
}
|
||||
|
||||
// Check block detection type.
|
||||
blockType := query.Get(parameterBlockedIf)
|
||||
if blockType == "" {
|
||||
blockType = BlockDetectionZeroIP
|
||||
}
|
||||
switch blockType {
|
||||
case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP:
|
||||
default:
|
||||
return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)")
|
||||
}
|
||||
|
||||
// Build resolver.
|
||||
// Create Resolver object
|
||||
newResolver := &Resolver{
|
||||
ConfigURL: resolverURL,
|
||||
Info: &ResolverInfo{
|
||||
Name: query.Get(parameterName),
|
||||
Type: u.Scheme,
|
||||
Source: source,
|
||||
IP: ip,
|
||||
IPScope: scope,
|
||||
Port: port,
|
||||
IP: nil,
|
||||
Domain: "",
|
||||
IPScope: netutils.Global,
|
||||
Port: 0,
|
||||
},
|
||||
ServerAddress: net.JoinHostPort(ip.String(), strconv.Itoa(int(port))),
|
||||
VerifyDomain: verifyDomain,
|
||||
UpstreamBlockDetection: blockType,
|
||||
ServerAddress: "",
|
||||
Path: u.Path, // Used for DoH
|
||||
UpstreamBlockDetection: "",
|
||||
}
|
||||
|
||||
// Get parameters and check if keys exist.
|
||||
err = checkAndSetResolverParamters(u, newResolver)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// Check block detection type.
|
||||
newResolver.UpstreamBlockDetection = query.Get(parameterBlockedIf)
|
||||
if newResolver.UpstreamBlockDetection == "" {
|
||||
newResolver.UpstreamBlockDetection = BlockDetectionZeroIP
|
||||
}
|
||||
|
||||
switch newResolver.UpstreamBlockDetection {
|
||||
case BlockDetectionDisabled, BlockDetectionEmptyAnswer, BlockDetectionRefused, BlockDetectionZeroIP:
|
||||
default:
|
||||
return nil, false, fmt.Errorf("invalid value for upstream block detection (blockedif=)")
|
||||
}
|
||||
|
||||
// Get ip scope if we have ip
|
||||
if newResolver.Info.IP != nil {
|
||||
newResolver.Info.IPScope = netutils.GetIPScope(newResolver.Info.IP)
|
||||
// Skip localhost resolvers from the OS, but not if configured.
|
||||
if newResolver.Info.IPScope.IsLocalhost() && source == ServerSourceOperatingSystem {
|
||||
return nil, true, nil // skip
|
||||
}
|
||||
}
|
||||
|
||||
// Parse search domains.
|
||||
|
@ -206,6 +184,108 @@ func createResolver(resolverURL, source string) (*Resolver, bool, error) {
|
|||
return newResolver, false, nil
|
||||
}
|
||||
|
||||
func checkAndSetResolverParamters(u *url.URL, resolver *Resolver) error {
|
||||
// Check if we are using domain name and if it's in a valid scheme
|
||||
ip := net.ParseIP(u.Hostname())
|
||||
hostnameIsDomaion := (ip == nil)
|
||||
if ip == nil && u.Scheme != ServerTypeDoH && u.Scheme != ServerTypeDoT {
|
||||
return fmt.Errorf("resolver IP %q is invalid", u.Hostname())
|
||||
}
|
||||
|
||||
// Add default port for scheme if it is missing.
|
||||
port, err := parsePortFromURL(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.Info.Port = port
|
||||
resolver.Info.IP = ip
|
||||
|
||||
query := u.Query()
|
||||
|
||||
for key := range query {
|
||||
switch key {
|
||||
case parameterName,
|
||||
parameterVerify,
|
||||
parameterIP,
|
||||
parameterBlockedIf,
|
||||
parameterSearch,
|
||||
parameterSearchOnly,
|
||||
parameterPath:
|
||||
// Known key, continue.
|
||||
default:
|
||||
// Unknown key, abort.
|
||||
return fmt.Errorf(`unknown parameter "%q"`, key)
|
||||
}
|
||||
}
|
||||
|
||||
resolver.Info.Domain = query.Get(parameterVerify)
|
||||
paramterServerIP := query.Get(parameterIP)
|
||||
|
||||
if u.Scheme == ServerTypeDoT || u.Scheme == ServerTypeDoH {
|
||||
// Check if IP and Domain are set correctly
|
||||
switch {
|
||||
case hostnameIsDomaion && resolver.Info.Domain != "":
|
||||
return fmt.Errorf("cannot set the domain name via both the hostname in the URL and the verify parameter")
|
||||
case !hostnameIsDomaion && resolver.Info.Domain == "":
|
||||
return fmt.Errorf("verify parameter must be set when using ip as domain")
|
||||
case !hostnameIsDomaion && paramterServerIP != "":
|
||||
return fmt.Errorf("cannot set the IP address via both the hostname in the URL and the ip parameter")
|
||||
}
|
||||
|
||||
// Parse and set IP and Domain to the resolver
|
||||
switch {
|
||||
case hostnameIsDomaion && paramterServerIP != "": // domain and ip as parameter
|
||||
resolver.Info.IP = net.ParseIP(paramterServerIP)
|
||||
resolver.ServerAddress = net.JoinHostPort(paramterServerIP, strconv.Itoa(int(resolver.Info.Port)))
|
||||
resolver.Info.Domain = u.Hostname()
|
||||
case !hostnameIsDomaion && resolver.Info.Domain != "": // ip and domain as parameter
|
||||
resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port)))
|
||||
case hostnameIsDomaion && resolver.Info.Domain == "" && paramterServerIP == "": // only domain
|
||||
resolver.Info.Domain = u.Hostname()
|
||||
resolver.ServerAddress = net.JoinHostPort(resolver.Info.Domain, strconv.Itoa(int(port)))
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
resolverInitDomains[dns.Fqdn(resolver.Info.Domain)] = struct{}{}
|
||||
}
|
||||
|
||||
} else {
|
||||
if resolver.Info.Domain != "" {
|
||||
return fmt.Errorf("domain verification is only supported by DoT and DoH servers")
|
||||
}
|
||||
resolver.ServerAddress = net.JoinHostPort(ip.String(), strconv.Itoa(int(resolver.Info.Port)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePortFromURL(url *url.URL) (uint16, error) {
|
||||
var port uint16
|
||||
hostPort := url.Port()
|
||||
if hostPort != "" {
|
||||
// There is a port in the url
|
||||
parsedPort, err := strconv.ParseUint(hostPort, 10, 16)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port %q", url.Port())
|
||||
}
|
||||
port = uint16(parsedPort)
|
||||
} else {
|
||||
// set the default port for the protocol
|
||||
switch {
|
||||
case url.Scheme == ServerTypeDNS, url.Scheme == ServerTypeTCP:
|
||||
port = 53
|
||||
case url.Scheme == ServerTypeDoH:
|
||||
port = 443
|
||||
case url.Scheme == ServerTypeDoT:
|
||||
port = 853
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot determine port for %q", url.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func configureSearchDomains(resolver *Resolver, searches []string, hardfail bool) error {
|
||||
resolver.Search = make([]string, 0, len(searches))
|
||||
|
||||
|
|
|
@ -220,6 +220,13 @@ addNextResolver:
|
|||
}
|
||||
}
|
||||
|
||||
// the domains from the configured resolvers should not be resolved with the same resolvers
|
||||
if resolver.Info.Source == ServerSourceConfigured && resolver.Info.IP == nil {
|
||||
if _, ok := resolverInitDomains[q.FQDN]; ok {
|
||||
continue addNextResolver
|
||||
}
|
||||
}
|
||||
|
||||
// add compliant and unique resolvers to selected resolvers
|
||||
selected = append(selected, resolver)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue