Fix linter errors from netquery implementation

This commit is contained in:
Daniel 2022-07-22 14:25:16 +02:00
parent 1889c68d27
commit 90d30c14a5
16 changed files with 163 additions and 132 deletions

View file

@ -7,12 +7,12 @@ import (
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/modules/subsystems"
"github.com/safing/portmaster/updates"
_ "github.com/safing/portmaster/broadcasts" _ "github.com/safing/portmaster/broadcasts"
_ "github.com/safing/portmaster/netenv" _ "github.com/safing/portmaster/netenv"
_ "github.com/safing/portmaster/netquery" _ "github.com/safing/portmaster/netquery"
_ "github.com/safing/portmaster/status" _ "github.com/safing/portmaster/status"
_ "github.com/safing/portmaster/ui" _ "github.com/safing/portmaster/ui"
"github.com/safing/portmaster/updates"
) )
const ( const (

View file

@ -14,6 +14,7 @@ import (
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
) )
// ChartHandler handles requests for connection charts.
type ChartHandler struct { type ChartHandler struct {
Database *Database Database *Database
} }
@ -55,14 +56,14 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
enc.SetIndent("", " ") enc.SetIndent("", " ")
enc.Encode(map[string]interface{}{ _ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
"results": result, "results": result,
"query": query, "query": query,
"params": paramMap, "params": paramMap,
}) })
} }
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
var body io.Reader var body io.Reader
switch req.Method { switch req.Method {

View file

@ -9,13 +9,14 @@ import (
"sync" "sync"
"time" "time"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
"github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/netutils"
"github.com/safing/portmaster/network/packet" "github.com/safing/portmaster/network/packet"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
) )
// InMemory is the "file path" to open a new in-memory database. // InMemory is the "file path" to open a new in-memory database.
@ -36,7 +37,7 @@ var ConnectionTypeToString = map[network.ConnectionType]string{
type ( type (
// Database represents a SQLite3 backed connection database. // Database represents a SQLite3 backed connection database.
// It's use is tailored for persistance and querying of network.Connection. // It's use is tailored for persistence and querying of network.Connection.
// Access to the underlying SQLite database is synchronized. // Access to the underlying SQLite database is synchronized.
// //
// TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing // TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing
@ -57,7 +58,7 @@ type (
// //
// Use ConvertConnection from this package to convert a network.Connection to this // Use ConvertConnection from this package to convert a network.Connection to this
// representation. // representation.
Conn struct { Conn struct { //nolint:maligned
// ID is a device-unique identifier for the connection. It is built // ID is a device-unique identifier for the connection. It is built
// from network.Connection by hashing the connection ID and the start // from network.Connection by hashing the connection ID and the start
// time. We cannot just use the network.Connection.ID because it is only unique // time. We cannot just use the network.Connection.ID because it is only unique
@ -93,11 +94,11 @@ type (
ProfileRevision int `sqlite:"profile_revision"` ProfileRevision int `sqlite:"profile_revision"`
ExitNode *string `sqlite:"exit_node"` ExitNode *string `sqlite:"exit_node"`
// FIXME(ppacher): support "NOT" in search query to get rid of the following helper fields // TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL" SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL"
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL" Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
// FIXME(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here // TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
ProfileName string `sqlite:"profile_name"` ProfileName string `sqlite:"profile_name"`
} }
) )
@ -153,9 +154,9 @@ func NewInMemory() (*Database, error) {
// to bring db up-to-date with the built-in schema. // to bring db up-to-date with the built-in schema.
// TODO(ppacher): right now this only applies the current schema and ignores // TODO(ppacher): right now this only applies the current schema and ignores
// any data-migrations. Once the history module is implemented this should // any data-migrations. Once the history module is implemented this should
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
func (db *Database) ApplyMigrations() error { func (db *Database) ApplyMigrations() error {
// get the create-table SQL statement from the infered schema // get the create-table SQL statement from the inferred schema
sql := db.Schema.CreateStatement(false) sql := db.Schema.CreateStatement(false)
// execute the SQL // execute the SQL
@ -234,7 +235,7 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database // dumpTo is a simple helper method that dumps all rows stored in the SQLite database
// as JSON to w. // as JSON to w.
// Any error aborts dumping rows and is returned. // Any error aborts dumping rows and is returned.
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused
db.l.Lock() db.l.Lock()
defer db.l.Unlock() defer db.l.Unlock()

View file

@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
} }
} }
func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error { func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error {
blob, err := json.Marshal(conn) blob, err := json.Marshal(conn)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal connection: %w", err) return fmt.Errorf("failed to marshal connection: %w", err)
@ -173,17 +173,19 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
c.Type = "dns" c.Type = "dns"
case network.IPConnection: case network.IPConnection:
c.Type = "ip" c.Type = "ip"
case network.Undefined:
c.Type = ""
} }
switch conn.Verdict { switch conn.Verdict {
case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel: case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
accepted := true accepted := true
c.Allowed = &accepted c.Allowed = &accepted
case network.VerdictUndecided, network.VerdictUndeterminable: case network.VerdictBlock, network.VerdictDrop:
c.Allowed = nil
default:
allowed := false allowed := false
c.Allowed = &allowed c.Allowed = &allowed
case network.VerdictUndecided, network.VerdictUndeterminable, network.VerdictFailed:
c.Allowed = nil
} }
if conn.Ended > 0 { if conn.Ended > 0 {

View file

@ -15,7 +15,7 @@ import (
"github.com/safing/portmaster/network" "github.com/safing/portmaster/network"
) )
type Module struct { type module struct {
*modules.Module *modules.Module
db *database.Interface db *database.Interface
@ -25,19 +25,19 @@ type Module struct {
} }
func init() { func init() {
mod := new(Module) m := new(module)
mod.Module = modules.Register( m.Module = modules.Register(
"netquery", "netquery",
mod.Prepare, m.prepare,
mod.Start, m.start,
mod.Stop, m.stop,
"api", "api",
"network", "network",
"database", "database",
) )
} }
func (m *Module) Prepare() error { func (m *module) prepare() error {
var err error var err error
m.db = database.NewInterface(&database.Options{ m.db = database.NewInterface(&database.Options{
@ -66,7 +66,6 @@ func (m *Module) Prepare() error {
Database: m.sqlStore, Database: m.sqlStore,
} }
// FIXME(ppacher): use appropriate permissions for this
if err := api.RegisterEndpoint(api.Endpoint{ if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/query", Path: "netquery/query",
MimeType: "application/json", MimeType: "application/json",
@ -96,13 +95,15 @@ func (m *Module) Prepare() error {
return nil return nil
} }
func (mod *Module) Start() error { func (m *module) start() error {
mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error { m.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
sub, err := mod.db.Subscribe(query.New("network:")) sub, err := m.db.Subscribe(query.New("network:"))
if err != nil { if err != nil {
return fmt.Errorf("failed to subscribe to network tree: %w", err) return fmt.Errorf("failed to subscribe to network tree: %w", err)
} }
defer sub.Cancel() defer func() {
_ = sub.Cancel()
}()
for { for {
select { select {
@ -120,24 +121,24 @@ func (mod *Module) Start() error {
continue continue
} }
mod.feed <- conn m.feed <- conn
} }
} }
}) })
mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error { m.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
mod.mng.HandleFeed(ctx, mod.feed) m.mng.HandleFeed(ctx, m.feed)
return nil return nil
}) })
mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error { m.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold) threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
count, err := mod.sqlStore.Cleanup(ctx, threshold) count, err := m.sqlStore.Cleanup(ctx, threshold)
if err != nil { if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %s", err) log.Errorf("netquery: failed to count number of rows in memory: %s", err)
} else { } else {
@ -147,19 +148,21 @@ func (mod *Module) Start() error {
} }
}) })
// for debugging, we provide a simple direct SQL query interface using // For debugging, provide a simple direct SQL query interface using
// the runtime database // the runtime database.
// FIXME: Expose only in dev mode. // Only expose in development mode.
_, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry) if config.GetAsBool(config.CfgDevModeKey, false)() {
if err != nil { _, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry)
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) if err != nil {
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
}
} }
return nil return nil
} }
func (mod *Module) Stop() error { func (m *module) stop() error {
close(mod.feed) close(m.feed)
return nil return nil
} }

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -30,7 +29,7 @@ var (
// TEXT or REAL. // TEXT or REAL.
// This package provides support for time.Time being stored as TEXT (using a // This package provides support for time.Time being stored as TEXT (using a
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between // preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
// unixepoch and unixnano-epoch where the nano variant is not offically supported by // unixepoch and unixnano-epoch where the nano variant is not officially supported by
// SQLITE). // SQLITE).
SqliteTimeFormat = "2006-01-02 15:04:05" SqliteTimeFormat = "2006-01-02 15:04:05"
) )
@ -54,6 +53,7 @@ type (
// DecodeFunc is called for each non-basic type during decoding. // DecodeFunc is called for each non-basic type during decoding.
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
// DecodeConfig holds decoding functions.
DecodeConfig struct { DecodeConfig struct {
DecodeHooks []DecodeFunc DecodeHooks []DecodeFunc
} }
@ -170,7 +170,8 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType) return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
} }
//log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue) // Debugging:
// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
// convert it to the target type if conversion is possible // convert it to the target type if conversion is possible
newValue := reflect.ValueOf(columnValue) newValue := reflect.ValueOf(columnValue)
@ -189,8 +190,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to // time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
// decide between Unix or UnixNano epoch timestamps. // decide between Unix or UnixNano epoch timestamps.
// //
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing // TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
//
func DatetimeDecoder(loc *time.Location) DecodeFunc { func DatetimeDecoder(loc *time.Location) DecodeFunc {
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) { return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
// if we have the column definition available we // if we have the column definition available we
@ -203,11 +203,11 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
// we only care about "time.Time" here // we only care about "time.Time" here
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) { if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
log.Printf("not decoding %s %v", outType, colDef) // log.Printf("not decoding %s %v", outType, colDef)
return nil, false, nil return nil, false, nil
} }
switch stmt.ColumnType(colIdx) { switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
case sqlite.TypeInteger: case sqlite.TypeInteger:
// stored as unix-epoch, if unixnano is set in the struct field tag // stored as unix-epoch, if unixnano is set in the struct field tag
// we parse it with nano-second resolution // we parse it with nano-second resolution
@ -242,7 +242,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
} }
} }
func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error { func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
if *mp == nil { if *mp == nil {
*mp = make(map[string]interface{}) *mp = make(map[string]interface{})
} }
@ -292,7 +292,7 @@ func decodeBasic() DecodeFunc {
if colDef != nil { if colDef != nil {
valueKind = normalizeKind(colDef.GoType.Kind()) valueKind = normalizeKind(colDef.GoType.Kind())
// if we have a column defintion we try to convert the value to // if we have a column definition we try to convert the value to
// the actual Go-type that was used in the model. // the actual Go-type that was used in the model.
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage // this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
// or that type aliases like (type myInt int) are decoded into myInt instead of int // or that type aliases like (type myInt int) are decoded into myInt instead of int
@ -314,7 +314,7 @@ func decodeBasic() DecodeFunc {
}() }()
} }
log.Printf("decoding %s into kind %s", colName, valueKind) // log.Printf("decoding %s into kind %s", colName, valueKind)
if colType == sqlite.TypeNull { if colType == sqlite.TypeNull {
if colDef != nil && colDef.Nullable { if colDef != nil && colDef.Nullable {
@ -330,7 +330,7 @@ func decodeBasic() DecodeFunc {
} }
} }
switch valueKind { switch valueKind { //nolint:exhaustive
case reflect.String: case reflect.String:
if colType != sqlite.TypeText { if colType != sqlite.TypeText {
return nil, false, errInvalidType return nil, false, errInvalidType
@ -455,7 +455,7 @@ func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.S
return nil, nil return nil, nil
} }
// getKind returns the kind of value but normalized Int, Uint and Float varaints // getKind returns the kind of value but normalized Int, Uint and Float variants.
// to their base type. // to their base type.
func getKind(val reflect.Value) reflect.Kind { func getKind(val reflect.Value) reflect.Kind {
kind := val.Kind() kind := val.Kind()
@ -475,6 +475,7 @@ func normalizeKind(kind reflect.Kind) reflect.Kind {
} }
} }
// DefaultDecodeConfig holds the default decoding configuration.
var DefaultDecodeConfig = DecodeConfig{ var DefaultDecodeConfig = DecodeConfig{
DecodeHooks: []DecodeFunc{ DecodeHooks: []DecodeFunc{
DatetimeDecoder(time.UTC), DatetimeDecoder(time.UTC),

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"log"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -21,14 +20,14 @@ type testStmt struct {
func (ts testStmt) ColumnCount() int { return len(ts.columns) } func (ts testStmt) ColumnCount() int { return len(ts.columns) }
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] } func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] } func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
// compile time check // Compile time check.
var _ Stmt = new(testStmt) var _ Stmt = new(testStmt)
type exampleFieldTypes struct { type exampleFieldTypes struct {
@ -98,10 +97,11 @@ func (etn *exampleTimeNano) Equal(other interface{}) bool {
return etn.T.Equal(oetn.T) return etn.T.Equal(oetn.T)
} }
func Test_Decoder(t *testing.T) { func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
ctx := context.TODO() t.Parallel()
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) ctx := context.TODO()
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
cases := []struct { cases := []struct {
Desc string Desc string
@ -433,8 +433,7 @@ func Test_Decoder(t *testing.T) {
nil, nil,
&exampleInterface{}, &exampleInterface{},
func() interface{} { func() interface{} {
var x interface{} var x interface{} = "value2"
x = "value2"
return &exampleInterface{ return &exampleInterface{
I: "value1", I: "value1",
@ -546,12 +545,10 @@ func Test_Decoder(t *testing.T) {
}, },
} }
for idx := range cases { for idx := range cases { //nolint:paralleltest
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
//t.Parallel() // log.Println(c.Desc)
log.Println(c.Desc)
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig) err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
if fn, ok := c.Expected.(func() interface{}); ok { if fn, ok := c.Expected.(func() interface{}); ok {
c.Expected = fn() c.Expected = fn()

View file

@ -10,8 +10,10 @@ import (
) )
type ( type (
// EncodeFunc is called for each non-basic type during encoding.
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
// EncodeConfig holds encoding functions.
EncodeConfig struct { EncodeConfig struct {
EncodeHooks []EncodeFunc EncodeHooks []EncodeFunc
} }
@ -69,6 +71,7 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
return res, nil return res, nil
} }
// EncodeValue encodes the given value.
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) { func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
fieldValue := reflect.ValueOf(val) fieldValue := reflect.ValueOf(val)
fieldType := reflect.TypeOf(val) fieldType := reflect.TypeOf(val)
@ -115,7 +118,7 @@ func encodeBasic() EncodeFunc {
val = val.Elem() val = val.Elem()
} }
switch normalizeKind(kind) { switch normalizeKind(kind) { //nolint:exhaustive
case reflect.String, case reflect.String,
reflect.Float64, reflect.Float64,
reflect.Bool, reflect.Bool,
@ -138,6 +141,7 @@ func encodeBasic() EncodeFunc {
} }
} }
// DatetimeEncoder returns a new datetime encoder for the given time zone.
func DatetimeEncoder(loc *time.Location) EncodeFunc { func DatetimeEncoder(loc *time.Location) EncodeFunc {
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
// if fieldType holds a pointer we need to dereference the value // if fieldType holds a pointer we need to dereference the value
@ -149,7 +153,8 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
// we only care about "time.Time" here // we only care about "time.Time" here
var t time.Time var t time.Time
if ft == "time.Time" { switch {
case ft == "time.Time":
// handle the zero time as a NULL. // handle the zero time as a NULL.
if !val.IsValid() || val.IsZero() { if !val.IsValid() || val.IsZero() {
return nil, true, nil return nil, true, nil
@ -162,19 +167,19 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time") return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
} }
} else if valType.Kind() == reflect.String && colDef.IsTime { case valType.Kind() == reflect.String && colDef.IsTime:
var err error var err error
t, err = time.Parse(time.RFC3339, val.String()) t, err = time.Parse(time.RFC3339, val.String())
if err != nil { if err != nil {
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err) return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
} }
} else { default:
// we don't care ... // we don't care ...
return nil, false, nil return nil, false, nil
} }
switch colDef.Type { switch colDef.Type { //nolint:exhaustive
case sqlite.TypeInteger: case sqlite.TypeInteger:
if colDef.UnixNano { if colDef.UnixNano {
return t.UnixNano(), true, nil return t.UnixNano(), true, nil
@ -194,7 +199,7 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) { func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
if valType == nil { if valType == nil {
if !colDef.Nullable { if !colDef.Nullable {
switch colDef.Type { switch colDef.Type { //nolint:exhaustive
case sqlite.TypeBlob: case sqlite.TypeBlob:
return []byte{}, true, nil return []byte{}, true, nil
case sqlite.TypeFloat: case sqlite.TypeFloat:
@ -225,6 +230,7 @@ func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value,
return nil, false, nil return nil, false, nil
} }
// DefaultEncodeConfig holds the default encoding configuration.
var DefaultEncodeConfig = EncodeConfig{ var DefaultEncodeConfig = EncodeConfig{
EncodeHooks: []EncodeFunc{ EncodeHooks: []EncodeFunc{
DatetimeEncoder(time.UTC), DatetimeEncoder(time.UTC),

View file

@ -9,9 +9,11 @@ import (
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
func Test_EncodeAsMap(t *testing.T) { func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
t.Parallel()
ctx := context.TODO() ctx := context.TODO()
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
cases := []struct { cases := []struct {
Desc string Desc string
@ -114,11 +116,9 @@ func Test_EncodeAsMap(t *testing.T) {
}, },
} }
for idx := range cases { for idx := range cases { //nolint:paralleltest
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
// t.Parallel()
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.Expected, res) assert.Equal(t, c.Expected, res)
@ -126,9 +126,11 @@ func Test_EncodeAsMap(t *testing.T) {
} }
} }
func Test_EncodeValue(t *testing.T) { func TestEncodeValue(t *testing.T) { //nolint:tparallel
t.Parallel()
ctx := context.TODO() ctx := context.TODO()
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
cases := []struct { cases := []struct {
Desc string Desc string
@ -247,11 +249,9 @@ func Test_EncodeValue(t *testing.T) {
}, },
} }
for idx := range cases { for idx := range cases { //nolint:paralleltest
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
//t.Parallel()
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig) res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.Output, res) assert.Equal(t, c.Output, res)

View file

@ -57,6 +57,8 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
} }
} }
// WithSchema returns a query option that adds the given table
// schema to the query.
func WithSchema(tbl TableSchema) QueryOption { func WithSchema(tbl TableSchema) QueryOption {
return func(opts *queryOpts) { return func(opts *queryOpts) {
opts.Schema = tbl opts.Schema = tbl
@ -139,9 +141,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
valElemType = valType.Elem() valElemType = valType.Elem()
opts.ResultFunc = func(stmt *sqlite.Stmt) error { opts.ResultFunc = func(stmt *sqlite.Stmt) error {
var currentField reflect.Value currentField := reflect.New(valElemType)
currentField = reflect.New(valElemType)
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
return err return err

View file

@ -10,10 +10,9 @@ import (
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
var ( var errSkipStructField = errors.New("struct field should be skipped")
errSkipStructField = errors.New("struct field should be skipped")
)
// Struct Tags.
var ( var (
TagUnixNano = "unixnano" TagUnixNano = "unixnano"
TagPrimaryKey = "primary" TagPrimaryKey = "primary"
@ -36,12 +35,14 @@ var sqlTypeMap = map[sqlite.ColumnType]string{
} }
type ( type (
// TableSchema defines a SQL table schema.
TableSchema struct { TableSchema struct {
Name string Name string
Columns []ColumnDef Columns []ColumnDef
} }
ColumnDef struct { // ColumnDef defines a SQL column.
ColumnDef struct { //nolint:maligned
Name string Name string
Nullable bool Nullable bool
Type sqlite.ColumnType Type sqlite.ColumnType
@ -54,6 +55,7 @@ type (
} }
) )
// GetColumnDef returns the column definition with the given name.
func (ts TableSchema) GetColumnDef(name string) *ColumnDef { func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
for _, def := range ts.Columns { for _, def := range ts.Columns {
if def.Name == name { if def.Name == name {
@ -63,6 +65,7 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
return nil return nil
} }
// CreateStatement build the CREATE SQL statement for the table.
func (ts TableSchema) CreateStatement(ifNotExists bool) string { func (ts TableSchema) CreateStatement(ifNotExists bool) string {
sql := "CREATE TABLE" sql := "CREATE TABLE"
if ifNotExists { if ifNotExists {
@ -81,6 +84,7 @@ func (ts TableSchema) CreateStatement(ifNotExists bool) string {
return sql return sql
} }
// AsSQL builds the SQL column definition.
func (def ColumnDef) AsSQL() string { func (def ColumnDef) AsSQL() string {
sql := def.Name + " " sql := def.Name + " "
@ -103,6 +107,7 @@ func (def ColumnDef) AsSQL() string {
return sql return sql
} }
// GenerateTableSchema generates a table schema from the given struct.
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) { func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
ts := &TableSchema{ ts := &TableSchema{
Name: name, Name: name,
@ -149,7 +154,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
def.GoType = ft def.GoType = ft
kind := normalizeKind(ft.Kind()) kind := normalizeKind(ft.Kind())
switch kind { switch kind { //nolint:exhaustive
case reflect.Int: case reflect.Int:
def.Type = sqlite.TypeInteger def.Type = sqlite.TypeInteger
@ -190,7 +195,7 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
if len(parts) > 1 { if len(parts) > 1 {
for _, k := range parts[1:] { for _, k := range parts[1:] {
switch k { switch k {
// column modifieres // column modifiers
case TagPrimaryKey: case TagPrimaryKey:
def.PrimaryKey = true def.PrimaryKey = true
case TagAutoIncrement: case TagAutoIncrement:

View file

@ -6,7 +6,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_SchemaBuilder(t *testing.T) { func TestSchemaBuilder(t *testing.T) {
t.Parallel()
cases := []struct { cases := []struct {
Name string Name string
Model interface{} Model interface{}

View file

@ -5,15 +5,19 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"sort" "sort"
"strings" "strings"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/safing/portmaster/netquery/orm"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
"github.com/safing/portmaster/netquery/orm"
) )
// Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
//nolint:golint
type ( type (
Query map[string][]Matcher Query map[string][]Matcher
@ -43,8 +47,6 @@ type (
Distinct bool `json:"distinct"` Distinct bool `json:"distinct"`
} }
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
Select struct { Select struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count,omitempty"` Count *Count `json:"$count,omitempty"`
@ -91,6 +93,7 @@ type (
} }
) )
// UnmarshalJSON unmarshals a Query from json.
func (query *Query) UnmarshalJSON(blob []byte) error { func (query *Query) UnmarshalJSON(blob []byte) error {
if *query == nil { if *query == nil {
*query = make(Query) *query = make(Query)
@ -202,13 +205,14 @@ func parseMatcher(raw json.RawMessage) (*Matcher, error) {
} }
if err := m.Validate(); err != nil { if err := m.Validate(); err != nil {
return nil, fmt.Errorf("invalid query matcher: %s", err) return nil, fmt.Errorf("invalid query matcher: %w", err)
} }
log.Printf("parsed matcher %s: %+v", string(raw), m)
return &m, nil
// log.Printf("parsed matcher %s: %+v", string(raw), m)
return &m, nil
} }
// Validate validates the matcher.
func (match Matcher) Validate() error { func (match Matcher) Validate() error {
found := 0 found := 0
@ -239,9 +243,9 @@ func (match Matcher) Validate() error {
return nil return nil
} }
func (text TextSearch) toSQLConditionClause(ctx context.Context, schema *orm.TableSchema, suffix string, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
var ( var (
queryParts []string queryParts = make([]string, 0, len(text.Fields))
params = make(map[string]interface{}) params = make(map[string]interface{})
) )
@ -379,7 +383,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
// merge parameters up into the superior parameter map // merge parameters up into the superior parameter map
for key, val := range params { for key, val := range params {
if _, ok := paramMap[key]; ok { if _, ok := paramMap[key]; ok {
// is is soley a developer mistake when implementing a matcher so no forgiving ... // This is solely a developer mistake when implementing a matcher so no forgiving ...
panic("sqlite parameter collision") panic("sqlite parameter collision")
} }
@ -399,6 +403,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
return whereClause, paramMap, errs.ErrorOrNil() return whereClause, paramMap, errs.ErrorOrNil()
} }
// UnmarshalJSON unmarshals a Selects from json.
func (sel *Selects) UnmarshalJSON(blob []byte) error { func (sel *Selects) UnmarshalJSON(blob []byte) error {
if len(blob) == 0 { if len(blob) == 0 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
@ -438,6 +443,7 @@ func (sel *Selects) UnmarshalJSON(blob []byte) error {
return nil return nil
} }
// UnmarshalJSON unmarshals a Select from json.
func (sel *Select) UnmarshalJSON(blob []byte) error { func (sel *Select) UnmarshalJSON(blob []byte) error {
if len(blob) == 0 { if len(blob) == 0 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
@ -481,6 +487,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
return nil return nil
} }
// UnmarshalJSON unmarshals a OrderBys from json.
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error { func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
if len(blob) == 0 { if len(blob) == 0 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
@ -523,6 +530,7 @@ func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
return nil return nil
} }
// UnmarshalJSON unmarshals a OrderBy from json.
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error { func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
if len(blob) == 0 { if len(blob) == 0 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF

View file

@ -17,9 +17,7 @@ import (
"github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/netquery/orm"
) )
var ( var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
)
type ( type (
@ -109,7 +107,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
} }
} }
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl
var body io.Reader var body io.Reader
switch req.Method { switch req.Method {
@ -230,11 +228,11 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
switch { switch {
case s.Count != nil: case s.Count != nil:
var as = s.Count.As as := s.Count.As
if as == "" { if as == "" {
as = fmt.Sprintf("%s_count", colName) as = fmt.Sprintf("%s_count", colName)
} }
var distinct = "" distinct := ""
if s.Count.Distinct { if s.Count.Distinct {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
@ -278,8 +276,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
return "", nil return "", nil
} }
var groupBys = make([]string, len(req.GroupBy)) groupBys := make([]string, len(req.GroupBy))
for idx, name := range req.GroupBy { for idx, name := range req.GroupBy {
colName, err := req.validateColumnName(schema, name) colName, err := req.validateColumnName(schema, name)
if err != nil { if err != nil {
@ -288,7 +285,6 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
groupBys[idx] = colName groupBys[idx] = colName
} }
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ") groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
// if there are no explicitly selected fields we default to the // if there are no explicitly selected fields we default to the
@ -301,7 +297,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
} }
func (req *QueryRequestPayload) generateSelectClause() string { func (req *QueryRequestPayload) generateSelectClause() string {
var selectClause = "*" selectClause := "*"
if len(req.selectedFields) > 0 { if len(req.selectedFields) > 0 {
selectClause = strings.Join(req.selectedFields, ", ") selectClause = strings.Join(req.selectedFields, ", ")
} }
@ -314,7 +310,7 @@ func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (
return "", nil return "", nil
} }
var orderBys = make([]string, len(req.OrderBy)) orderBys := make([]string, len(req.OrderBy))
for idx, sort := range req.OrderBy { for idx, sort := range req.OrderBy {
colName, err := req.validateColumnName(schema, sort.Field) colName, err := req.validateColumnName(schema, sort.Field)
if err != nil { if err != nil {
@ -352,5 +348,5 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
return "", fmt.Errorf("column name %q not allowed", field) return "", fmt.Errorf("column name %q not allowed", field)
} }
// compile time check // Compile time check.
var _ http.Handler = new(QueryHandler) var _ http.Handler = new(QueryHandler)

View file

@ -7,13 +7,16 @@ import (
"testing" "testing"
"time" "time"
"github.com/safing/portmaster/netquery/orm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/safing/portmaster/netquery/orm"
) )
func Test_UnmarshalQuery(t *testing.T) { func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel
var cases = []struct { t.Parallel()
cases := []struct {
Name string Name string
Input string Input string
Expected Query Expected Query
@ -88,7 +91,8 @@ func Test_UnmarshalQuery(t *testing.T) {
}, },
} }
for _, c := range cases { for _, testCase := range cases { //nolint:paralleltest
c := testCase
t.Run(c.Name, func(t *testing.T) { t.Run(c.Name, func(t *testing.T) {
var q Query var q Query
err := json.Unmarshal([]byte(c.Input), &q) err := json.Unmarshal([]byte(c.Input), &q)
@ -105,10 +109,11 @@ func Test_UnmarshalQuery(t *testing.T) {
} }
} }
func Test_QueryBuilder(t *testing.T) { func TestQueryBuilder(t *testing.T) { //nolint:tparallel
now := time.Now() t.Parallel()
var cases = []struct { now := time.Now()
cases := []struct {
N string N string
Q Query Q Query
R string R string
@ -186,7 +191,7 @@ func Test_QueryBuilder(t *testing.T) {
}, },
"", "",
nil, nil,
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint
}, },
{ {
"Complex example", "Complex example",
@ -225,19 +230,20 @@ func Test_QueryBuilder(t *testing.T) {
tbl, err := orm.GenerateTableSchema("connections", Conn{}) tbl, err := orm.GenerateTableSchema("connections", Conn{})
require.NoError(t, err) require.NoError(t, err)
for idx, c := range cases { for idx, testCase := range cases { //nolint:paralleltest
cID := idx
c := testCase
t.Run(c.N, func(t *testing.T) { t.Run(c.N, func(t *testing.T) {
//t.Parallel()
str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig) str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
if c.E != nil { if c.E != nil {
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx) assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID)
} }
} else { } else {
assert.NoError(t, err, "test case %d", idx) assert.NoError(t, err, "test case %d", cID)
assert.Equal(t, c.P, params, "test case %d", idx) assert.Equal(t, c.P, params, "test case %d", cID)
assert.Equal(t, c.R, str, "test case %d", idx) assert.Equal(t, c.R, str, "test case %d", cID)
} }
}) })
} }

View file

@ -11,8 +11,11 @@ import (
) )
const ( const (
cleanerTickDuration = 5 * time.Second // DeleteConnsAfterEndedThreshold defines the amount of time after which
// ended connections should be removed from the internal connection state.
DeleteConnsAfterEndedThreshold = 10 * time.Minute DeleteConnsAfterEndedThreshold = 10 * time.Minute
cleanerTickDuration = 5 * time.Second
) )
func connectionCleaner(ctx context.Context) error { func connectionCleaner(ctx context.Context) error {