mirror of
https://github.com/safing/portmaster
synced 2025-09-01 10:09:11 +00:00
Fix linter errors from netquery implementation
This commit is contained in:
parent
1889c68d27
commit
90d30c14a5
16 changed files with 163 additions and 132 deletions
|
@ -7,12 +7,12 @@ import (
|
||||||
|
|
||||||
"github.com/safing/portbase/modules"
|
"github.com/safing/portbase/modules"
|
||||||
"github.com/safing/portbase/modules/subsystems"
|
"github.com/safing/portbase/modules/subsystems"
|
||||||
"github.com/safing/portmaster/updates"
|
|
||||||
_ "github.com/safing/portmaster/broadcasts"
|
_ "github.com/safing/portmaster/broadcasts"
|
||||||
_ "github.com/safing/portmaster/netenv"
|
_ "github.com/safing/portmaster/netenv"
|
||||||
_ "github.com/safing/portmaster/netquery"
|
_ "github.com/safing/portmaster/netquery"
|
||||||
_ "github.com/safing/portmaster/status"
|
_ "github.com/safing/portmaster/status"
|
||||||
_ "github.com/safing/portmaster/ui"
|
_ "github.com/safing/portmaster/ui"
|
||||||
|
"github.com/safing/portmaster/updates"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ChartHandler handles requests for connection charts.
|
||||||
type ChartHandler struct {
|
type ChartHandler struct {
|
||||||
Database *Database
|
Database *Database
|
||||||
}
|
}
|
||||||
|
@ -55,14 +56,14 @@ func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
enc.SetIndent("", " ")
|
enc.SetIndent("", " ")
|
||||||
|
|
||||||
enc.Encode(map[string]interface{}{
|
_ = enc.Encode(map[string]interface{}{ //nolint:errchkjson
|
||||||
"results": result,
|
"results": result,
|
||||||
"query": query,
|
"query": query,
|
||||||
"params": paramMap,
|
"params": paramMap,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) {
|
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) { //nolint:dupl
|
||||||
var body io.Reader
|
var body io.Reader
|
||||||
|
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
|
|
|
@ -9,13 +9,14 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"zombiezen.com/go/sqlite"
|
||||||
|
"zombiezen.com/go/sqlite/sqlitex"
|
||||||
|
|
||||||
"github.com/safing/portbase/log"
|
"github.com/safing/portbase/log"
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
"github.com/safing/portmaster/network"
|
"github.com/safing/portmaster/network"
|
||||||
"github.com/safing/portmaster/network/netutils"
|
"github.com/safing/portmaster/network/netutils"
|
||||||
"github.com/safing/portmaster/network/packet"
|
"github.com/safing/portmaster/network/packet"
|
||||||
"zombiezen.com/go/sqlite"
|
|
||||||
"zombiezen.com/go/sqlite/sqlitex"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// InMemory is the "file path" to open a new in-memory database.
|
// InMemory is the "file path" to open a new in-memory database.
|
||||||
|
@ -36,7 +37,7 @@ var ConnectionTypeToString = map[network.ConnectionType]string{
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// Database represents a SQLite3 backed connection database.
|
// Database represents a SQLite3 backed connection database.
|
||||||
// It's use is tailored for persistance and querying of network.Connection.
|
// It's use is tailored for persistence and querying of network.Connection.
|
||||||
// Access to the underlying SQLite database is synchronized.
|
// Access to the underlying SQLite database is synchronized.
|
||||||
//
|
//
|
||||||
// TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing
|
// TODO(ppacher): somehow I'm receiving SIGBUS or SIGSEGV when no doing
|
||||||
|
@ -57,7 +58,7 @@ type (
|
||||||
//
|
//
|
||||||
// Use ConvertConnection from this package to convert a network.Connection to this
|
// Use ConvertConnection from this package to convert a network.Connection to this
|
||||||
// representation.
|
// representation.
|
||||||
Conn struct {
|
Conn struct { //nolint:maligned
|
||||||
// ID is a device-unique identifier for the connection. It is built
|
// ID is a device-unique identifier for the connection. It is built
|
||||||
// from network.Connection by hashing the connection ID and the start
|
// from network.Connection by hashing the connection ID and the start
|
||||||
// time. We cannot just use the network.Connection.ID because it is only unique
|
// time. We cannot just use the network.Connection.ID because it is only unique
|
||||||
|
@ -93,11 +94,11 @@ type (
|
||||||
ProfileRevision int `sqlite:"profile_revision"`
|
ProfileRevision int `sqlite:"profile_revision"`
|
||||||
ExitNode *string `sqlite:"exit_node"`
|
ExitNode *string `sqlite:"exit_node"`
|
||||||
|
|
||||||
// FIXME(ppacher): support "NOT" in search query to get rid of the following helper fields
|
// TODO(ppacher): support "NOT" in search query to get rid of the following helper fields
|
||||||
SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL"
|
SPNUsed bool `sqlite:"spn_used"` // could use "exit_node IS NOT NULL" or "exit IS NULL"
|
||||||
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
|
Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL"
|
||||||
|
|
||||||
// FIXME(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
|
// TODO(ppacher): we need to profile here for "suggestion" support. It would be better to keep a table of profiles in sqlite and use joins here
|
||||||
ProfileName string `sqlite:"profile_name"`
|
ProfileName string `sqlite:"profile_name"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -153,9 +154,9 @@ func NewInMemory() (*Database, error) {
|
||||||
// to bring db up-to-date with the built-in schema.
|
// to bring db up-to-date with the built-in schema.
|
||||||
// TODO(ppacher): right now this only applies the current schema and ignores
|
// TODO(ppacher): right now this only applies the current schema and ignores
|
||||||
// any data-migrations. Once the history module is implemented this should
|
// any data-migrations. Once the history module is implemented this should
|
||||||
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration
|
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration.
|
||||||
func (db *Database) ApplyMigrations() error {
|
func (db *Database) ApplyMigrations() error {
|
||||||
// get the create-table SQL statement from the infered schema
|
// get the create-table SQL statement from the inferred schema
|
||||||
sql := db.Schema.CreateStatement(false)
|
sql := db.Schema.CreateStatement(false)
|
||||||
|
|
||||||
// execute the SQL
|
// execute the SQL
|
||||||
|
@ -234,7 +235,7 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro
|
||||||
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
|
// dumpTo is a simple helper method that dumps all rows stored in the SQLite database
|
||||||
// as JSON to w.
|
// as JSON to w.
|
||||||
// Any error aborts dumping rows and is returned.
|
// Any error aborts dumping rows and is returned.
|
||||||
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
|
func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:unused
|
||||||
db.l.Lock()
|
db.l.Lock()
|
||||||
defer db.l.Unlock()
|
defer db.l.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn Conn) error {
|
func (mng *Manager) pushConnUpdate(_ context.Context, meta record.Meta, conn Conn) error {
|
||||||
blob, err := json.Marshal(conn)
|
blob, err := json.Marshal(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal connection: %w", err)
|
return fmt.Errorf("failed to marshal connection: %w", err)
|
||||||
|
@ -173,17 +173,19 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
|
||||||
c.Type = "dns"
|
c.Type = "dns"
|
||||||
case network.IPConnection:
|
case network.IPConnection:
|
||||||
c.Type = "ip"
|
c.Type = "ip"
|
||||||
|
case network.Undefined:
|
||||||
|
c.Type = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
switch conn.Verdict {
|
switch conn.Verdict {
|
||||||
case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
|
case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
|
||||||
accepted := true
|
accepted := true
|
||||||
c.Allowed = &accepted
|
c.Allowed = &accepted
|
||||||
case network.VerdictUndecided, network.VerdictUndeterminable:
|
case network.VerdictBlock, network.VerdictDrop:
|
||||||
c.Allowed = nil
|
|
||||||
default:
|
|
||||||
allowed := false
|
allowed := false
|
||||||
c.Allowed = &allowed
|
c.Allowed = &allowed
|
||||||
|
case network.VerdictUndecided, network.VerdictUndeterminable, network.VerdictFailed:
|
||||||
|
c.Allowed = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.Ended > 0 {
|
if conn.Ended > 0 {
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
"github.com/safing/portmaster/network"
|
"github.com/safing/portmaster/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Module struct {
|
type module struct {
|
||||||
*modules.Module
|
*modules.Module
|
||||||
|
|
||||||
db *database.Interface
|
db *database.Interface
|
||||||
|
@ -25,19 +25,19 @@ type Module struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mod := new(Module)
|
m := new(module)
|
||||||
mod.Module = modules.Register(
|
m.Module = modules.Register(
|
||||||
"netquery",
|
"netquery",
|
||||||
mod.Prepare,
|
m.prepare,
|
||||||
mod.Start,
|
m.start,
|
||||||
mod.Stop,
|
m.stop,
|
||||||
"api",
|
"api",
|
||||||
"network",
|
"network",
|
||||||
"database",
|
"database",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Module) Prepare() error {
|
func (m *module) prepare() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
m.db = database.NewInterface(&database.Options{
|
m.db = database.NewInterface(&database.Options{
|
||||||
|
@ -66,7 +66,6 @@ func (m *Module) Prepare() error {
|
||||||
Database: m.sqlStore,
|
Database: m.sqlStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(ppacher): use appropriate permissions for this
|
|
||||||
if err := api.RegisterEndpoint(api.Endpoint{
|
if err := api.RegisterEndpoint(api.Endpoint{
|
||||||
Path: "netquery/query",
|
Path: "netquery/query",
|
||||||
MimeType: "application/json",
|
MimeType: "application/json",
|
||||||
|
@ -96,13 +95,15 @@ func (m *Module) Prepare() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mod *Module) Start() error {
|
func (m *module) start() error {
|
||||||
mod.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
|
m.StartServiceWorker("netquery-feeder", time.Second, func(ctx context.Context) error {
|
||||||
sub, err := mod.db.Subscribe(query.New("network:"))
|
sub, err := m.db.Subscribe(query.New("network:"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to subscribe to network tree: %w", err)
|
return fmt.Errorf("failed to subscribe to network tree: %w", err)
|
||||||
}
|
}
|
||||||
defer sub.Cancel()
|
defer func() {
|
||||||
|
_ = sub.Cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -120,24 +121,24 @@ func (mod *Module) Start() error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
mod.feed <- conn
|
m.feed <- conn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
mod.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
|
m.StartServiceWorker("netquery-persister", time.Second, func(ctx context.Context) error {
|
||||||
mod.mng.HandleFeed(ctx, mod.feed)
|
m.mng.HandleFeed(ctx, m.feed)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
mod.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
|
m.StartServiceWorker("netquery-row-cleaner", time.Second, func(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
case <-time.After(10 * time.Second):
|
case <-time.After(10 * time.Second):
|
||||||
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
|
threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
|
||||||
count, err := mod.sqlStore.Cleanup(ctx, threshold)
|
count, err := m.sqlStore.Cleanup(ctx, threshold)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
|
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -147,19 +148,21 @@ func (mod *Module) Start() error {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// for debugging, we provide a simple direct SQL query interface using
|
// For debugging, provide a simple direct SQL query interface using
|
||||||
// the runtime database
|
// the runtime database.
|
||||||
// FIXME: Expose only in dev mode.
|
// Only expose in development mode.
|
||||||
_, err := NewRuntimeQueryRunner(mod.sqlStore, "netquery/query/", runtime.DefaultRegistry)
|
if config.GetAsBool(config.CfgDevModeKey, false)() {
|
||||||
if err != nil {
|
_, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry)
|
||||||
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set up runtime SQL query runner: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mod *Module) Stop() error {
|
func (m *module) stop() error {
|
||||||
close(mod.feed)
|
close(m.feed)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -30,7 +29,7 @@ var (
|
||||||
// TEXT or REAL.
|
// TEXT or REAL.
|
||||||
// This package provides support for time.Time being stored as TEXT (using a
|
// This package provides support for time.Time being stored as TEXT (using a
|
||||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||||
// unixepoch and unixnano-epoch where the nano variant is not offically supported by
|
// unixepoch and unixnano-epoch where the nano variant is not officially supported by
|
||||||
// SQLITE).
|
// SQLITE).
|
||||||
SqliteTimeFormat = "2006-01-02 15:04:05"
|
SqliteTimeFormat = "2006-01-02 15:04:05"
|
||||||
)
|
)
|
||||||
|
@ -54,6 +53,7 @@ type (
|
||||||
// DecodeFunc is called for each non-basic type during decoding.
|
// DecodeFunc is called for each non-basic type during decoding.
|
||||||
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
|
DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
|
||||||
|
|
||||||
|
// DecodeConfig holds decoding functions.
|
||||||
DecodeConfig struct {
|
DecodeConfig struct {
|
||||||
DecodeHooks []DecodeFunc
|
DecodeHooks []DecodeFunc
|
||||||
}
|
}
|
||||||
|
@ -170,7 +170,8 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
|
||||||
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
|
||||||
}
|
}
|
||||||
|
|
||||||
//log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
// Debugging:
|
||||||
|
// log.Printf("valueTypeName: %s fieldName = %s value-orig = %s value = %s (%v) newValue = %s", value.Type().String(), fieldName, target.FieldByName(fieldName).Type(), value.Type(), value, columnValue)
|
||||||
|
|
||||||
// convert it to the target type if conversion is possible
|
// convert it to the target type if conversion is possible
|
||||||
newValue := reflect.ValueOf(columnValue)
|
newValue := reflect.ValueOf(columnValue)
|
||||||
|
@ -189,8 +190,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte
|
||||||
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
// time.Time. For INTEGER storage classes, it supports 'unixnano' struct tag value to
|
||||||
// decide between Unix or UnixNano epoch timestamps.
|
// decide between Unix or UnixNano epoch timestamps.
|
||||||
//
|
//
|
||||||
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
|
// TODO(ppacher): update comment about loc parameter and TEXT storage class parsing.
|
||||||
//
|
|
||||||
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||||
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
|
return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
|
||||||
// if we have the column definition available we
|
// if we have the column definition available we
|
||||||
|
@ -203,11 +203,11 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||||
|
|
||||||
// we only care about "time.Time" here
|
// we only care about "time.Time" here
|
||||||
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
|
if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
|
||||||
log.Printf("not decoding %s %v", outType, colDef)
|
// log.Printf("not decoding %s %v", outType, colDef)
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch stmt.ColumnType(colIdx) {
|
switch stmt.ColumnType(colIdx) { //nolint:exhaustive // Only selecting specific types.
|
||||||
case sqlite.TypeInteger:
|
case sqlite.TypeInteger:
|
||||||
// stored as unix-epoch, if unixnano is set in the struct field tag
|
// stored as unix-epoch, if unixnano is set in the struct field tag
|
||||||
// we parse it with nano-second resolution
|
// we parse it with nano-second resolution
|
||||||
|
@ -242,7 +242,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
|
||||||
if *mp == nil {
|
if *mp == nil {
|
||||||
*mp = make(map[string]interface{})
|
*mp = make(map[string]interface{})
|
||||||
}
|
}
|
||||||
|
@ -292,7 +292,7 @@ func decodeBasic() DecodeFunc {
|
||||||
if colDef != nil {
|
if colDef != nil {
|
||||||
valueKind = normalizeKind(colDef.GoType.Kind())
|
valueKind = normalizeKind(colDef.GoType.Kind())
|
||||||
|
|
||||||
// if we have a column defintion we try to convert the value to
|
// if we have a column definition we try to convert the value to
|
||||||
// the actual Go-type that was used in the model.
|
// the actual Go-type that was used in the model.
|
||||||
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
|
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
|
||||||
// or that type aliases like (type myInt int) are decoded into myInt instead of int
|
// or that type aliases like (type myInt int) are decoded into myInt instead of int
|
||||||
|
@ -314,7 +314,7 @@ func decodeBasic() DecodeFunc {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("decoding %s into kind %s", colName, valueKind)
|
// log.Printf("decoding %s into kind %s", colName, valueKind)
|
||||||
|
|
||||||
if colType == sqlite.TypeNull {
|
if colType == sqlite.TypeNull {
|
||||||
if colDef != nil && colDef.Nullable {
|
if colDef != nil && colDef.Nullable {
|
||||||
|
@ -330,7 +330,7 @@ func decodeBasic() DecodeFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch valueKind {
|
switch valueKind { //nolint:exhaustive
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if colType != sqlite.TypeText {
|
if colType != sqlite.TypeText {
|
||||||
return nil, false, errInvalidType
|
return nil, false, errInvalidType
|
||||||
|
@ -455,7 +455,7 @@ func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.S
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getKind returns the kind of value but normalized Int, Uint and Float varaints
|
// getKind returns the kind of value but normalized Int, Uint and Float variants.
|
||||||
// to their base type.
|
// to their base type.
|
||||||
func getKind(val reflect.Value) reflect.Kind {
|
func getKind(val reflect.Value) reflect.Kind {
|
||||||
kind := val.Kind()
|
kind := val.Kind()
|
||||||
|
@ -475,6 +475,7 @@ func normalizeKind(kind reflect.Kind) reflect.Kind {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultDecodeConfig holds the default decoding configuration.
|
||||||
var DefaultDecodeConfig = DecodeConfig{
|
var DefaultDecodeConfig = DecodeConfig{
|
||||||
DecodeHooks: []DecodeFunc{
|
DecodeHooks: []DecodeFunc{
|
||||||
DatetimeDecoder(time.UTC),
|
DatetimeDecoder(time.UTC),
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -21,14 +20,14 @@ type testStmt struct {
|
||||||
|
|
||||||
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
func (ts testStmt) ColumnCount() int { return len(ts.columns) }
|
||||||
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
func (ts testStmt) ColumnName(i int) string { return ts.columns[i] }
|
||||||
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) }
|
func (ts testStmt) ColumnBool(i int) bool { return ts.values[i].(bool) } //nolint:forcetypeassert
|
||||||
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) }
|
func (ts testStmt) ColumnText(i int) string { return ts.values[i].(string) } //nolint:forcetypeassert
|
||||||
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) }
|
func (ts testStmt) ColumnFloat(i int) float64 { return ts.values[i].(float64) } //nolint:forcetypeassert
|
||||||
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) }
|
func (ts testStmt) ColumnInt(i int) int { return ts.values[i].(int) } //nolint:forcetypeassert
|
||||||
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) }
|
func (ts testStmt) ColumnReader(i int) *bytes.Reader { return bytes.NewReader(ts.values[i].([]byte)) } //nolint:forcetypeassert
|
||||||
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
func (ts testStmt) ColumnType(i int) sqlite.ColumnType { return ts.types[i] }
|
||||||
|
|
||||||
// compile time check
|
// Compile time check.
|
||||||
var _ Stmt = new(testStmt)
|
var _ Stmt = new(testStmt)
|
||||||
|
|
||||||
type exampleFieldTypes struct {
|
type exampleFieldTypes struct {
|
||||||
|
@ -98,10 +97,11 @@ func (etn *exampleTimeNano) Equal(other interface{}) bool {
|
||||||
return etn.T.Equal(oetn.T)
|
return etn.T.Equal(oetn.T)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Decoder(t *testing.T) {
|
func TestDecoder(t *testing.T) { //nolint:maintidx,tparallel
|
||||||
ctx := context.TODO()
|
t.Parallel()
|
||||||
|
|
||||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
ctx := context.TODO()
|
||||||
|
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Desc string
|
Desc string
|
||||||
|
@ -433,8 +433,7 @@ func Test_Decoder(t *testing.T) {
|
||||||
nil,
|
nil,
|
||||||
&exampleInterface{},
|
&exampleInterface{},
|
||||||
func() interface{} {
|
func() interface{} {
|
||||||
var x interface{}
|
var x interface{} = "value2"
|
||||||
x = "value2"
|
|
||||||
|
|
||||||
return &exampleInterface{
|
return &exampleInterface{
|
||||||
I: "value1",
|
I: "value1",
|
||||||
|
@ -546,12 +545,10 @@ func Test_Decoder(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx := range cases {
|
for idx := range cases { //nolint:paralleltest
|
||||||
c := cases[idx]
|
c := cases[idx]
|
||||||
t.Run(c.Desc, func(t *testing.T) {
|
t.Run(c.Desc, func(t *testing.T) {
|
||||||
//t.Parallel()
|
// log.Println(c.Desc)
|
||||||
|
|
||||||
log.Println(c.Desc)
|
|
||||||
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
|
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
|
||||||
if fn, ok := c.Expected.(func() interface{}); ok {
|
if fn, ok := c.Expected.(func() interface{}); ok {
|
||||||
c.Expected = fn()
|
c.Expected = fn()
|
||||||
|
|
|
@ -10,8 +10,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
// EncodeFunc is called for each non-basic type during encoding.
|
||||||
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
||||||
|
|
||||||
|
// EncodeConfig holds encoding functions.
|
||||||
EncodeConfig struct {
|
EncodeConfig struct {
|
||||||
EncodeHooks []EncodeFunc
|
EncodeHooks []EncodeFunc
|
||||||
}
|
}
|
||||||
|
@ -69,6 +71,7 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EncodeValue encodes the given value.
|
||||||
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
||||||
fieldValue := reflect.ValueOf(val)
|
fieldValue := reflect.ValueOf(val)
|
||||||
fieldType := reflect.TypeOf(val)
|
fieldType := reflect.TypeOf(val)
|
||||||
|
@ -115,7 +118,7 @@ func encodeBasic() EncodeFunc {
|
||||||
val = val.Elem()
|
val = val.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
switch normalizeKind(kind) {
|
switch normalizeKind(kind) { //nolint:exhaustive
|
||||||
case reflect.String,
|
case reflect.String,
|
||||||
reflect.Float64,
|
reflect.Float64,
|
||||||
reflect.Bool,
|
reflect.Bool,
|
||||||
|
@ -138,6 +141,7 @@ func encodeBasic() EncodeFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DatetimeEncoder returns a new datetime encoder for the given time zone.
|
||||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||||
// if fieldType holds a pointer we need to dereference the value
|
// if fieldType holds a pointer we need to dereference the value
|
||||||
|
@ -149,7 +153,8 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
|
|
||||||
// we only care about "time.Time" here
|
// we only care about "time.Time" here
|
||||||
var t time.Time
|
var t time.Time
|
||||||
if ft == "time.Time" {
|
switch {
|
||||||
|
case ft == "time.Time":
|
||||||
// handle the zero time as a NULL.
|
// handle the zero time as a NULL.
|
||||||
if !val.IsValid() || val.IsZero() {
|
if !val.IsValid() || val.IsZero() {
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
|
@ -162,19 +167,19 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if valType.Kind() == reflect.String && colDef.IsTime {
|
case valType.Kind() == reflect.String && colDef.IsTime:
|
||||||
var err error
|
var err error
|
||||||
t, err = time.Parse(time.RFC3339, val.String())
|
t, err = time.Parse(time.RFC3339, val.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
default:
|
||||||
// we don't care ...
|
// we don't care ...
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch colDef.Type {
|
switch colDef.Type { //nolint:exhaustive
|
||||||
case sqlite.TypeInteger:
|
case sqlite.TypeInteger:
|
||||||
if colDef.UnixNano {
|
if colDef.UnixNano {
|
||||||
return t.UnixNano(), true, nil
|
return t.UnixNano(), true, nil
|
||||||
|
@ -194,7 +199,7 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||||
if valType == nil {
|
if valType == nil {
|
||||||
if !colDef.Nullable {
|
if !colDef.Nullable {
|
||||||
switch colDef.Type {
|
switch colDef.Type { //nolint:exhaustive
|
||||||
case sqlite.TypeBlob:
|
case sqlite.TypeBlob:
|
||||||
return []byte{}, true, nil
|
return []byte{}, true, nil
|
||||||
case sqlite.TypeFloat:
|
case sqlite.TypeFloat:
|
||||||
|
@ -225,6 +230,7 @@ func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value,
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultEncodeConfig holds the default encoding configuration.
|
||||||
var DefaultEncodeConfig = EncodeConfig{
|
var DefaultEncodeConfig = EncodeConfig{
|
||||||
EncodeHooks: []EncodeFunc{
|
EncodeHooks: []EncodeFunc{
|
||||||
DatetimeEncoder(time.UTC),
|
DatetimeEncoder(time.UTC),
|
||||||
|
|
|
@ -9,9 +9,11 @@ import (
|
||||||
"zombiezen.com/go/sqlite"
|
"zombiezen.com/go/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_EncodeAsMap(t *testing.T) {
|
func TestEncodeAsMap(t *testing.T) { //nolint:tparallel
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Desc string
|
Desc string
|
||||||
|
@ -114,11 +116,9 @@ func Test_EncodeAsMap(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx := range cases {
|
for idx := range cases { //nolint:paralleltest
|
||||||
c := cases[idx]
|
c := cases[idx]
|
||||||
t.Run(c.Desc, func(t *testing.T) {
|
t.Run(c.Desc, func(t *testing.T) {
|
||||||
// t.Parallel()
|
|
||||||
|
|
||||||
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
|
res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, c.Expected, res)
|
assert.Equal(t, c.Expected, res)
|
||||||
|
@ -126,9 +126,11 @@ func Test_EncodeAsMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_EncodeValue(t *testing.T) {
|
func TestEncodeValue(t *testing.T) { //nolint:tparallel
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
|
refTime := time.Date(2022, time.February, 15, 9, 51, 0, 0, time.UTC)
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Desc string
|
Desc string
|
||||||
|
@ -247,11 +249,9 @@ func Test_EncodeValue(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx := range cases {
|
for idx := range cases { //nolint:paralleltest
|
||||||
c := cases[idx]
|
c := cases[idx]
|
||||||
t.Run(c.Desc, func(t *testing.T) {
|
t.Run(c.Desc, func(t *testing.T) {
|
||||||
//t.Parallel()
|
|
||||||
|
|
||||||
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, c.Output, res)
|
assert.Equal(t, c.Output, res)
|
||||||
|
|
|
@ -57,6 +57,8 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSchema returns a query option that adds the given table
|
||||||
|
// schema to the query.
|
||||||
func WithSchema(tbl TableSchema) QueryOption {
|
func WithSchema(tbl TableSchema) QueryOption {
|
||||||
return func(opts *queryOpts) {
|
return func(opts *queryOpts) {
|
||||||
opts.Schema = tbl
|
opts.Schema = tbl
|
||||||
|
@ -139,9 +141,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
|
||||||
valElemType = valType.Elem()
|
valElemType = valType.Elem()
|
||||||
|
|
||||||
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
opts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||||
var currentField reflect.Value
|
currentField := reflect.New(valElemType)
|
||||||
|
|
||||||
currentField = reflect.New(valElemType)
|
|
||||||
|
|
||||||
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -10,10 +10,9 @@ import (
|
||||||
"zombiezen.com/go/sqlite"
|
"zombiezen.com/go/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var errSkipStructField = errors.New("struct field should be skipped")
|
||||||
errSkipStructField = errors.New("struct field should be skipped")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
// Struct Tags.
|
||||||
var (
|
var (
|
||||||
TagUnixNano = "unixnano"
|
TagUnixNano = "unixnano"
|
||||||
TagPrimaryKey = "primary"
|
TagPrimaryKey = "primary"
|
||||||
|
@ -36,12 +35,14 @@ var sqlTypeMap = map[sqlite.ColumnType]string{
|
||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
// TableSchema defines a SQL table schema.
|
||||||
TableSchema struct {
|
TableSchema struct {
|
||||||
Name string
|
Name string
|
||||||
Columns []ColumnDef
|
Columns []ColumnDef
|
||||||
}
|
}
|
||||||
|
|
||||||
ColumnDef struct {
|
// ColumnDef defines a SQL column.
|
||||||
|
ColumnDef struct { //nolint:maligned
|
||||||
Name string
|
Name string
|
||||||
Nullable bool
|
Nullable bool
|
||||||
Type sqlite.ColumnType
|
Type sqlite.ColumnType
|
||||||
|
@ -54,6 +55,7 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetColumnDef returns the column definition with the given name.
|
||||||
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||||
for _, def := range ts.Columns {
|
for _, def := range ts.Columns {
|
||||||
if def.Name == name {
|
if def.Name == name {
|
||||||
|
@ -63,6 +65,7 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateStatement build the CREATE SQL statement for the table.
|
||||||
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
||||||
sql := "CREATE TABLE"
|
sql := "CREATE TABLE"
|
||||||
if ifNotExists {
|
if ifNotExists {
|
||||||
|
@ -81,6 +84,7 @@ func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
||||||
return sql
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsSQL builds the SQL column definition.
|
||||||
func (def ColumnDef) AsSQL() string {
|
func (def ColumnDef) AsSQL() string {
|
||||||
sql := def.Name + " "
|
sql := def.Name + " "
|
||||||
|
|
||||||
|
@ -103,6 +107,7 @@ func (def ColumnDef) AsSQL() string {
|
||||||
return sql
|
return sql
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateTableSchema generates a table schema from the given struct.
|
||||||
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) {
|
||||||
ts := &TableSchema{
|
ts := &TableSchema{
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -149,7 +154,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
|
||||||
def.GoType = ft
|
def.GoType = ft
|
||||||
kind := normalizeKind(ft.Kind())
|
kind := normalizeKind(ft.Kind())
|
||||||
|
|
||||||
switch kind {
|
switch kind { //nolint:exhaustive
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
def.Type = sqlite.TypeInteger
|
def.Type = sqlite.TypeInteger
|
||||||
|
|
||||||
|
@ -190,7 +195,7 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error {
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
for _, k := range parts[1:] {
|
for _, k := range parts[1:] {
|
||||||
switch k {
|
switch k {
|
||||||
// column modifieres
|
// column modifiers
|
||||||
case TagPrimaryKey:
|
case TagPrimaryKey:
|
||||||
def.PrimaryKey = true
|
def.PrimaryKey = true
|
||||||
case TagAutoIncrement:
|
case TagAutoIncrement:
|
||||||
|
|
|
@ -6,7 +6,9 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_SchemaBuilder(t *testing.T) {
|
func TestSchemaBuilder(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Name string
|
Name string
|
||||||
Model interface{}
|
Model interface{}
|
||||||
|
|
|
@ -5,15 +5,19 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
|
||||||
"zombiezen.com/go/sqlite"
|
"zombiezen.com/go/sqlite"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Collection of Query and Matcher types.
|
||||||
|
// NOTE: whenever adding support for new operators make sure
|
||||||
|
// to update UnmarshalJSON as well.
|
||||||
|
//nolint:golint
|
||||||
type (
|
type (
|
||||||
Query map[string][]Matcher
|
Query map[string][]Matcher
|
||||||
|
|
||||||
|
@ -43,8 +47,6 @@ type (
|
||||||
Distinct bool `json:"distinct"`
|
Distinct bool `json:"distinct"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: whenever adding support for new operators make sure
|
|
||||||
// to update UnmarshalJSON as well.
|
|
||||||
Select struct {
|
Select struct {
|
||||||
Field string `json:"field"`
|
Field string `json:"field"`
|
||||||
Count *Count `json:"$count,omitempty"`
|
Count *Count `json:"$count,omitempty"`
|
||||||
|
@ -91,6 +93,7 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a Query from json.
|
||||||
func (query *Query) UnmarshalJSON(blob []byte) error {
|
func (query *Query) UnmarshalJSON(blob []byte) error {
|
||||||
if *query == nil {
|
if *query == nil {
|
||||||
*query = make(Query)
|
*query = make(Query)
|
||||||
|
@ -202,13 +205,14 @@ func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Validate(); err != nil {
|
if err := m.Validate(); err != nil {
|
||||||
return nil, fmt.Errorf("invalid query matcher: %s", err)
|
return nil, fmt.Errorf("invalid query matcher: %w", err)
|
||||||
}
|
}
|
||||||
log.Printf("parsed matcher %s: %+v", string(raw), m)
|
|
||||||
return &m, nil
|
|
||||||
|
|
||||||
|
// log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||||
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate validates the matcher.
|
||||||
func (match Matcher) Validate() error {
|
func (match Matcher) Validate() error {
|
||||||
found := 0
|
found := 0
|
||||||
|
|
||||||
|
@ -239,9 +243,9 @@ func (match Matcher) Validate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (text TextSearch) toSQLConditionClause(ctx context.Context, schema *orm.TableSchema, suffix string, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||||
var (
|
var (
|
||||||
queryParts []string
|
queryParts = make([]string, 0, len(text.Fields))
|
||||||
params = make(map[string]interface{})
|
params = make(map[string]interface{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -379,7 +383,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
|
||||||
// merge parameters up into the superior parameter map
|
// merge parameters up into the superior parameter map
|
||||||
for key, val := range params {
|
for key, val := range params {
|
||||||
if _, ok := paramMap[key]; ok {
|
if _, ok := paramMap[key]; ok {
|
||||||
// is is soley a developer mistake when implementing a matcher so no forgiving ...
|
// This is solely a developer mistake when implementing a matcher so no forgiving ...
|
||||||
panic("sqlite parameter collision")
|
panic("sqlite parameter collision")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,6 +403,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.T
|
||||||
return whereClause, paramMap, errs.ErrorOrNil()
|
return whereClause, paramMap, errs.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a Selects from json.
|
||||||
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||||
if len(blob) == 0 {
|
if len(blob) == 0 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
|
@ -438,6 +443,7 @@ func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a Select from json.
|
||||||
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||||
if len(blob) == 0 {
|
if len(blob) == 0 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
|
@ -481,6 +487,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a OrderBys from json.
|
||||||
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||||
if len(blob) == 0 {
|
if len(blob) == 0 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
|
@ -523,6 +530,7 @@ func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a OrderBy from json.
|
||||||
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
||||||
if len(blob) == 0 {
|
if len(blob) == 0 {
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
|
|
|
@ -17,9 +17,7 @@ import (
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||||
charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
|
||||||
|
@ -109,7 +107,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) {
|
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl
|
||||||
var body io.Reader
|
var body io.Reader
|
||||||
|
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
|
@ -230,11 +228,11 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case s.Count != nil:
|
case s.Count != nil:
|
||||||
var as = s.Count.As
|
as := s.Count.As
|
||||||
if as == "" {
|
if as == "" {
|
||||||
as = fmt.Sprintf("%s_count", colName)
|
as = fmt.Sprintf("%s_count", colName)
|
||||||
}
|
}
|
||||||
var distinct = ""
|
distinct := ""
|
||||||
if s.Count.Distinct {
|
if s.Count.Distinct {
|
||||||
distinct = "DISTINCT "
|
distinct = "DISTINCT "
|
||||||
}
|
}
|
||||||
|
@ -278,8 +276,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var groupBys = make([]string, len(req.GroupBy))
|
groupBys := make([]string, len(req.GroupBy))
|
||||||
|
|
||||||
for idx, name := range req.GroupBy {
|
for idx, name := range req.GroupBy {
|
||||||
colName, err := req.validateColumnName(schema, name)
|
colName, err := req.validateColumnName(schema, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -288,7 +285,6 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
||||||
|
|
||||||
groupBys[idx] = colName
|
groupBys[idx] = colName
|
||||||
}
|
}
|
||||||
|
|
||||||
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
||||||
|
|
||||||
// if there are no explicitly selected fields we default to the
|
// if there are no explicitly selected fields we default to the
|
||||||
|
@ -301,7 +297,7 @@ func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *QueryRequestPayload) generateSelectClause() string {
|
func (req *QueryRequestPayload) generateSelectClause() string {
|
||||||
var selectClause = "*"
|
selectClause := "*"
|
||||||
if len(req.selectedFields) > 0 {
|
if len(req.selectedFields) > 0 {
|
||||||
selectClause = strings.Join(req.selectedFields, ", ")
|
selectClause = strings.Join(req.selectedFields, ", ")
|
||||||
}
|
}
|
||||||
|
@ -314,7 +310,7 @@ func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var orderBys = make([]string, len(req.OrderBy))
|
orderBys := make([]string, len(req.OrderBy))
|
||||||
for idx, sort := range req.OrderBy {
|
for idx, sort := range req.OrderBy {
|
||||||
colName, err := req.validateColumnName(schema, sort.Field)
|
colName, err := req.validateColumnName(schema, sort.Field)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -352,5 +348,5 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
|
||||||
return "", fmt.Errorf("column name %q not allowed", field)
|
return "", fmt.Errorf("column name %q not allowed", field)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compile time check
|
// Compile time check.
|
||||||
var _ http.Handler = new(QueryHandler)
|
var _ http.Handler = new(QueryHandler)
|
||||||
|
|
|
@ -7,13 +7,16 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_UnmarshalQuery(t *testing.T) {
|
func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel
|
||||||
var cases = []struct {
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
Name string
|
Name string
|
||||||
Input string
|
Input string
|
||||||
Expected Query
|
Expected Query
|
||||||
|
@ -88,7 +91,8 @@ func Test_UnmarshalQuery(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range cases {
|
for _, testCase := range cases { //nolint:paralleltest
|
||||||
|
c := testCase
|
||||||
t.Run(c.Name, func(t *testing.T) {
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
var q Query
|
var q Query
|
||||||
err := json.Unmarshal([]byte(c.Input), &q)
|
err := json.Unmarshal([]byte(c.Input), &q)
|
||||||
|
@ -105,10 +109,11 @@ func Test_UnmarshalQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_QueryBuilder(t *testing.T) {
|
func TestQueryBuilder(t *testing.T) { //nolint:tparallel
|
||||||
now := time.Now()
|
t.Parallel()
|
||||||
|
|
||||||
var cases = []struct {
|
now := time.Now()
|
||||||
|
cases := []struct {
|
||||||
N string
|
N string
|
||||||
Q Query
|
Q Query
|
||||||
R string
|
R string
|
||||||
|
@ -186,7 +191,7 @@ func Test_QueryBuilder(t *testing.T) {
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
nil,
|
nil,
|
||||||
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"),
|
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), //nolint:golint
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Complex example",
|
"Complex example",
|
||||||
|
@ -225,19 +230,20 @@ func Test_QueryBuilder(t *testing.T) {
|
||||||
tbl, err := orm.GenerateTableSchema("connections", Conn{})
|
tbl, err := orm.GenerateTableSchema("connections", Conn{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for idx, c := range cases {
|
for idx, testCase := range cases { //nolint:paralleltest
|
||||||
|
cID := idx
|
||||||
|
c := testCase
|
||||||
t.Run(c.N, func(t *testing.T) {
|
t.Run(c.N, func(t *testing.T) {
|
||||||
//t.Parallel()
|
|
||||||
str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
|
str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
|
||||||
|
|
||||||
if c.E != nil {
|
if c.E != nil {
|
||||||
if assert.Error(t, err) {
|
if assert.Error(t, err) {
|
||||||
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx)
|
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err, "test case %d", idx)
|
assert.NoError(t, err, "test case %d", cID)
|
||||||
assert.Equal(t, c.P, params, "test case %d", idx)
|
assert.Equal(t, c.P, params, "test case %d", cID)
|
||||||
assert.Equal(t, c.R, str, "test case %d", idx)
|
assert.Equal(t, c.R, str, "test case %d", cID)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cleanerTickDuration = 5 * time.Second
|
// DeleteConnsAfterEndedThreshold defines the amount of time after which
|
||||||
|
// ended connections should be removed from the internal connection state.
|
||||||
DeleteConnsAfterEndedThreshold = 10 * time.Minute
|
DeleteConnsAfterEndedThreshold = 10 * time.Minute
|
||||||
|
|
||||||
|
cleanerTickDuration = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
func connectionCleaner(ctx context.Context) error {
|
func connectionCleaner(ctx context.Context) error {
|
||||||
|
|
Loading…
Add table
Reference in a new issue