mirror of
https://github.com/safing/portmaster
synced 2025-09-02 10:39:22 +00:00
Add support for new query API
This commit is contained in:
parent
e21eb16a6b
commit
d098f1c137
10 changed files with 1154 additions and 43 deletions
|
@ -44,6 +44,8 @@ type (
|
||||||
// are actually supposed to do.
|
// are actually supposed to do.
|
||||||
//
|
//
|
||||||
Database struct {
|
Database struct {
|
||||||
|
Schema *orm.TableSchema
|
||||||
|
|
||||||
l sync.Mutex
|
l sync.Mutex
|
||||||
conn *sqlite.Conn
|
conn *sqlite.Conn
|
||||||
}
|
}
|
||||||
|
@ -62,6 +64,8 @@ type (
|
||||||
// as long as the connection is still active and might be, although unlikely,
|
// as long as the connection is still active and might be, although unlikely,
|
||||||
// reused afterwards.
|
// reused afterwards.
|
||||||
ID string `sqlite:"id,primary"`
|
ID string `sqlite:"id,primary"`
|
||||||
|
ProfileID string `sqlite:"profile"`
|
||||||
|
Path string `sqlite:"path"`
|
||||||
Type string `sqlite:"type,varchar(8)"`
|
Type string `sqlite:"type,varchar(8)"`
|
||||||
External bool `sqlite:"external"`
|
External bool `sqlite:"external"`
|
||||||
IPVersion packet.IPVersion `sqlite:"ip_version"`
|
IPVersion packet.IPVersion `sqlite:"ip_version"`
|
||||||
|
@ -78,8 +82,8 @@ type (
|
||||||
Longitude float64 `sqlite:"longitude"`
|
Longitude float64 `sqlite:"longitude"`
|
||||||
Scope netutils.IPScope `sqlite:"scope"`
|
Scope netutils.IPScope `sqlite:"scope"`
|
||||||
Verdict network.Verdict `sqlite:"verdict"`
|
Verdict network.Verdict `sqlite:"verdict"`
|
||||||
Started time.Time `sqlite:"started,text"`
|
Started time.Time `sqlite:"started,text,time"`
|
||||||
Ended *time.Time `sqlite:"ended,text"`
|
Ended *time.Time `sqlite:"ended,text,time"`
|
||||||
Tunneled bool `sqlite:"tunneled"`
|
Tunneled bool `sqlite:"tunneled"`
|
||||||
Encrypted bool `sqlite:"encrypted"`
|
Encrypted bool `sqlite:"encrypted"`
|
||||||
Internal bool `sqlite:"internal"`
|
Internal bool `sqlite:"internal"`
|
||||||
|
@ -107,7 +111,15 @@ func New(path string) (*Database, error) {
|
||||||
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
|
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Database{conn: c}, nil
|
schema, err := orm.GenerateTableSchema("connections", Conn{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Database{
|
||||||
|
Schema: schema,
|
||||||
|
conn: c,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInMemory is like New but creates a new in-memory database and
|
// NewInMemory is like New but creates a new in-memory database and
|
||||||
|
@ -133,13 +145,8 @@ func NewInMemory() (*Database, error) {
|
||||||
// any data-migrations. Once the history module is implemented this should
|
// any data-migrations. Once the history module is implemented this should
|
||||||
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration
|
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration
|
||||||
func (db *Database) ApplyMigrations() error {
|
func (db *Database) ApplyMigrations() error {
|
||||||
schema, err := orm.GenerateTableSchema("connections", Conn{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate table schema for conncetions: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the create-table SQL statement from the infered schema
|
// get the create-table SQL statement from the infered schema
|
||||||
sql := schema.CreateStatement(false)
|
sql := db.Schema.CreateStatement(false)
|
||||||
|
|
||||||
// execute the SQL
|
// execute the SQL
|
||||||
if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil {
|
if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil {
|
||||||
|
@ -284,7 +291,7 @@ func (db *Database) Save(ctx context.Context, conn Conn) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Errorf("netquery: failed to execute: %s", err)
|
log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
||||||
|
|
||||||
model, err := convertConnection(conn)
|
model, err := convertConnection(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("netquery: failed to convert connection %s to sqlite model: %w", conn.ID, err)
|
log.Errorf("netquery: failed to convert connection %s to sqlite model: %s", conn.ID, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
||||||
log.Infof("netquery: persisting create/update to connection %s", conn.ID)
|
log.Infof("netquery: persisting create/update to connection %s", conn.ID)
|
||||||
|
|
||||||
if err := mng.store.Save(ctx, *model); err != nil {
|
if err := mng.store.Save(ctx, *model); err != nil {
|
||||||
log.Errorf("netquery: failed to save connection %s in sqlite database: %w", conn.ID, err)
|
log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
|
||||||
|
|
||||||
// push an update for the connection
|
// push an update for the connection
|
||||||
if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
|
if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
|
||||||
log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err)
|
log.Errorf("netquery: failed to push update for conn %s via database system: %s", conn.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
count++
|
count++
|
||||||
|
@ -170,6 +170,8 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
|
||||||
Internal: conn.Internal,
|
Internal: conn.Internal,
|
||||||
Inbound: conn.Inbound,
|
Inbound: conn.Inbound,
|
||||||
Type: ConnectionTypeToString[conn.Type],
|
Type: ConnectionTypeToString[conn.Type],
|
||||||
|
ProfileID: conn.ProcessContext.ProfileName,
|
||||||
|
Path: conn.ProcessContext.BinaryPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.Ended > 0 {
|
if conn.Ended > 0 {
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/safing/portbase/api"
|
||||||
|
"github.com/safing/portbase/config"
|
||||||
"github.com/safing/portbase/database"
|
"github.com/safing/portbase/database"
|
||||||
"github.com/safing/portbase/database/query"
|
"github.com/safing/portbase/database/query"
|
||||||
"github.com/safing/portbase/log"
|
"github.com/safing/portbase/log"
|
||||||
|
@ -29,6 +31,7 @@ func init() {
|
||||||
mod.Prepare,
|
mod.Prepare,
|
||||||
mod.Start,
|
mod.Start,
|
||||||
mod.Stop,
|
mod.Stop,
|
||||||
|
"api",
|
||||||
"network",
|
"network",
|
||||||
"database",
|
"database",
|
||||||
)
|
)
|
||||||
|
@ -55,6 +58,25 @@ func (m *Module) Prepare() error {
|
||||||
|
|
||||||
m.feed = make(chan *network.Connection, 1000)
|
m.feed = make(chan *network.Connection, 1000)
|
||||||
|
|
||||||
|
queryHander := &QueryHandler{
|
||||||
|
Database: m.sqlStore,
|
||||||
|
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME(ppacher): use appropriate permissions for this
|
||||||
|
if err := api.RegisterEndpoint(api.Endpoint{
|
||||||
|
Path: "netquery/query",
|
||||||
|
MimeType: "application/json",
|
||||||
|
Read: api.PermitAnyone,
|
||||||
|
Write: api.PermitAnyone,
|
||||||
|
BelongsTo: m.Module,
|
||||||
|
HandlerFunc: queryHander.ServeHTTP,
|
||||||
|
Name: "Query In-Memory Database",
|
||||||
|
Description: "Query the in-memory sqlite database",
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to register API endpoint: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,7 +122,7 @@ func (mod *Module) Start() error {
|
||||||
case <-time.After(10 * time.Second):
|
case <-time.After(10 * time.Second):
|
||||||
count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
|
count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("netquery: failed to count number of rows in memory: %w", err)
|
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("netquery: successfully removed %d old rows", count)
|
log.Infof("netquery: successfully removed %d old rows", count)
|
||||||
}
|
}
|
||||||
|
@ -116,7 +138,7 @@ func (mod *Module) Start() error {
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
count, err := mod.sqlStore.CountRows(ctx)
|
count, err := mod.sqlStore.CountRows(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("netquery: failed to count number of rows in memory: %w", err)
|
log.Errorf("netquery: failed to count number of rows in memory: %s", err)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("netquery: currently holding %d rows in memory", count)
|
log.Infof("netquery: currently holding %d rows in memory", count)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ var (
|
||||||
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
|
||||||
// unixepoch and unixnano-epoch where the nano variant is not offically supported by
|
// unixepoch and unixnano-epoch where the nano variant is not offically supported by
|
||||||
// SQLITE).
|
// SQLITE).
|
||||||
sqliteTimeFormat = "2006-01-02 15:04:05"
|
SqliteTimeFormat = "2006-01-02 15:04:05"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -209,7 +209,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
|
||||||
case sqlite.TypeText:
|
case sqlite.TypeText:
|
||||||
// stored ISO8601 but does not have any timezone information
|
// stored ISO8601 but does not have any timezone information
|
||||||
// assigned so we always treat it as loc here.
|
// assigned so we always treat it as loc here.
|
||||||
t, err := time.ParseInLocation(sqliteTimeFormat, stmt.ColumnText(colIdx), loc)
|
t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
|
return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,6 +103,13 @@ func encodeBasic() EncodeFunc {
|
||||||
kind = valType.Kind()
|
kind = valType.Kind()
|
||||||
|
|
||||||
if val.IsNil() {
|
if val.IsNil() {
|
||||||
|
if !col.Nullable {
|
||||||
|
// we need to set the zero value here since the column
|
||||||
|
// is not marked as nullable
|
||||||
|
//return reflect.New(valType).Elem().Interface(), true, nil
|
||||||
|
panic("nil pointer for not-null field")
|
||||||
|
}
|
||||||
|
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,7 +140,7 @@ func encodeBasic() EncodeFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
||||||
// if fieldType holds a pointer we need to dereference the value
|
// if fieldType holds a pointer we need to dereference the value
|
||||||
ft := valType.String()
|
ft := valType.String()
|
||||||
if valType.Kind() == reflect.Ptr {
|
if valType.Kind() == reflect.Ptr {
|
||||||
|
@ -142,44 +149,71 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// we only care about "time.Time" here
|
// we only care about "time.Time" here
|
||||||
if ft != "time.Time" {
|
var t time.Time
|
||||||
return nil, false, nil
|
if ft == "time.Time" {
|
||||||
}
|
|
||||||
|
|
||||||
// handle the zero time as a NULL.
|
// handle the zero time as a NULL.
|
||||||
if !val.IsValid() || val.IsZero() {
|
if !val.IsValid() || val.IsZero() {
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
valInterface := val.Interface()
|
valInterface := val.Interface()
|
||||||
t, ok := valInterface.(time.Time)
|
t, ok = valInterface.(time.Time)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch colDev.Type {
|
} else if valType.Kind() == reflect.String && colDef.IsTime {
|
||||||
|
var err error
|
||||||
|
t, err = time.Parse(time.RFC3339, val.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// we don't care ...
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch colDef.Type {
|
||||||
case sqlite.TypeInteger:
|
case sqlite.TypeInteger:
|
||||||
if colDev.UnixNano {
|
if colDef.UnixNano {
|
||||||
return t.UnixNano(), true, nil
|
return t.UnixNano(), true, nil
|
||||||
}
|
}
|
||||||
return t.Unix(), true, nil
|
return t.Unix(), true, nil
|
||||||
|
|
||||||
case sqlite.TypeText:
|
case sqlite.TypeText:
|
||||||
str := t.In(loc).Format(sqliteTimeFormat)
|
str := t.In(loc).Format(SqliteTimeFormat)
|
||||||
|
|
||||||
return str, true, nil
|
return str, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDev.Type)
|
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
||||||
if valType == nil {
|
if valType == nil {
|
||||||
|
if !colDef.Nullable {
|
||||||
|
switch colDef.Type {
|
||||||
|
case sqlite.TypeBlob:
|
||||||
|
return []byte{}, true, nil
|
||||||
|
case sqlite.TypeFloat:
|
||||||
|
return 0.0, true, nil
|
||||||
|
case sqlite.TypeText:
|
||||||
|
return "", true, nil
|
||||||
|
case sqlite.TypeInteger:
|
||||||
|
return 0, true, nil
|
||||||
|
default:
|
||||||
|
return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil, true, nil
|
return nil, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fn := range hooks {
|
for _, fn := range hooks {
|
||||||
res, end, err := fn(colDev, valType, val)
|
res, end, err := fn(colDef, valType, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, false, err
|
return res, false, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,7 @@ func Test_EncodeAsMap(t *testing.T) {
|
||||||
},
|
},
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"TinInt": refTime.UnixNano(),
|
"TinInt": refTime.UnixNano(),
|
||||||
"TinString": refTime.Format(sqliteTimeFormat),
|
"TinString": refTime.Format(SqliteTimeFormat),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -107,7 +107,7 @@ func Test_EncodeAsMap(t *testing.T) {
|
||||||
},
|
},
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"TinInt": refTime.UnixNano(),
|
"TinInt": refTime.UnixNano(),
|
||||||
"TinString": refTime.Format(sqliteTimeFormat),
|
"TinString": refTime.Format(SqliteTimeFormat),
|
||||||
"Tnil1": nil,
|
"Tnil1": nil,
|
||||||
"Tnil2": nil,
|
"Tnil2": nil,
|
||||||
},
|
},
|
||||||
|
@ -143,7 +143,7 @@ func Test_EncodeValue(t *testing.T) {
|
||||||
Type: sqlite.TypeText,
|
Type: sqlite.TypeText,
|
||||||
},
|
},
|
||||||
refTime,
|
refTime,
|
||||||
refTime.Format(sqliteTimeFormat),
|
refTime.Format(SqliteTimeFormat),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Special value time.Time as unix-epoch",
|
"Special value time.Time as unix-epoch",
|
||||||
|
@ -189,11 +189,12 @@ func Test_EncodeValue(t *testing.T) {
|
||||||
Type: sqlite.TypeText,
|
Type: sqlite.TypeText,
|
||||||
},
|
},
|
||||||
&refTime,
|
&refTime,
|
||||||
refTime.Format(sqliteTimeFormat),
|
refTime.Format(SqliteTimeFormat),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Special value untyped nil",
|
"Special value untyped nil",
|
||||||
ColumnDef{
|
ColumnDef{
|
||||||
|
Nullable: true,
|
||||||
IsTime: true,
|
IsTime: true,
|
||||||
Type: sqlite.TypeText,
|
Type: sqlite.TypeText,
|
||||||
},
|
},
|
||||||
|
@ -209,12 +210,47 @@ func Test_EncodeValue(t *testing.T) {
|
||||||
(*time.Time)(nil),
|
(*time.Time)(nil),
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"Time formated as string",
|
||||||
|
ColumnDef{
|
||||||
|
IsTime: true,
|
||||||
|
Type: sqlite.TypeText,
|
||||||
|
},
|
||||||
|
refTime.In(time.Local).Format(time.RFC3339),
|
||||||
|
refTime.Format(SqliteTimeFormat),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Nullable integer",
|
||||||
|
ColumnDef{
|
||||||
|
Type: sqlite.TypeInteger,
|
||||||
|
Nullable: true,
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Not-Null integer",
|
||||||
|
ColumnDef{
|
||||||
|
Name: "test",
|
||||||
|
Type: sqlite.TypeInteger,
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Not-Null string",
|
||||||
|
ColumnDef{
|
||||||
|
Type: sqlite.TypeText,
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
"",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx := range cases {
|
for idx := range cases {
|
||||||
c := cases[idx]
|
c := cases[idx]
|
||||||
t.Run(c.Desc, func(t *testing.T) {
|
t.Run(c.Desc, func(t *testing.T) {
|
||||||
// t.Parallel()
|
//t.Parallel()
|
||||||
|
|
||||||
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -53,6 +53,15 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
|
||||||
|
for _, def := range ts.Columns {
|
||||||
|
if def.Name == name {
|
||||||
|
return &def
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
func (ts TableSchema) CreateStatement(ifNotExists bool) string {
|
||||||
sql := "CREATE TABLE"
|
sql := "CREATE TABLE"
|
||||||
if ifNotExists {
|
if ifNotExists {
|
||||||
|
|
464
netquery/query.go
Normal file
464
netquery/query.go
Normal file
|
@ -0,0 +1,464 @@
|
||||||
|
package netquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Query map[string][]Matcher
|
||||||
|
|
||||||
|
Matcher struct {
|
||||||
|
Equal interface{} `json:"$eq,omitempty"`
|
||||||
|
NotEqual interface{} `json:"$ne,omitempty"`
|
||||||
|
In []interface{} `json:"$in,omitempty"`
|
||||||
|
NotIn []interface{} `json:"$notIn,omitempty"`
|
||||||
|
Like string `json:"$like,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
Count struct {
|
||||||
|
As string `json:"as"`
|
||||||
|
Field string `json:"field"`
|
||||||
|
Distinct bool `json:"distict"`
|
||||||
|
}
|
||||||
|
|
||||||
|
Select struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Count *Count `json:"$count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
Selects []Select
|
||||||
|
|
||||||
|
QueryRequestPayload struct {
|
||||||
|
Select Selects `json:"select"`
|
||||||
|
Query Query `json:"query"`
|
||||||
|
OrderBy []OrderBy `json:"orderBy"`
|
||||||
|
GroupBy []string `json:"groupBy"`
|
||||||
|
|
||||||
|
selectedFields []string
|
||||||
|
whitelistedFields []string
|
||||||
|
}
|
||||||
|
|
||||||
|
OrderBy struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Desc bool `json:"desc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
OrderBys []OrderBy
|
||||||
|
)
|
||||||
|
|
||||||
|
func (query *Query) UnmarshalJSON(blob []byte) error {
|
||||||
|
if *query == nil {
|
||||||
|
*query = make(Query)
|
||||||
|
}
|
||||||
|
|
||||||
|
var model map[string]json.RawMessage
|
||||||
|
|
||||||
|
if err := json.Unmarshal(blob, &model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for columnName, rawColumnQuery := range model {
|
||||||
|
if len(rawColumnQuery) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rawColumnQuery[0] {
|
||||||
|
case '{':
|
||||||
|
m, err := parseMatcher(rawColumnQuery)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
(*query)[columnName] = []Matcher{*m}
|
||||||
|
|
||||||
|
case '[':
|
||||||
|
var rawMatchers []json.RawMessage
|
||||||
|
if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
(*query)[columnName] = make([]Matcher, len(rawMatchers))
|
||||||
|
for idx, val := range rawMatchers {
|
||||||
|
// this should not happen
|
||||||
|
if len(val) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if val starts with a { we have a matcher definition
|
||||||
|
if val[0] == '{' {
|
||||||
|
m, err := parseMatcher(val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
(*query)[columnName][idx] = *m
|
||||||
|
|
||||||
|
continue
|
||||||
|
} else if val[0] == '[' {
|
||||||
|
return fmt.Errorf("invalid token [ in query for column %s", columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// val is a dedicated JSON primitive and not an object or array
|
||||||
|
// so we treat that as an EQUAL condition.
|
||||||
|
var x interface{}
|
||||||
|
if err := json.Unmarshal(val, &x); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
(*query)[columnName][idx] = Matcher{
|
||||||
|
Equal: x,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// value is a JSON primitive and not an object or array
|
||||||
|
// so we treat that as an EQUAL condition.
|
||||||
|
var x interface{}
|
||||||
|
if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
(*query)[columnName] = []Matcher{
|
||||||
|
{Equal: x},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
||||||
|
var m Matcher
|
||||||
|
if err := json.Unmarshal(raw, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid query matcher: %s", err)
|
||||||
|
}
|
||||||
|
log.Printf("parsed matcher %s: %+v", string(raw), m)
|
||||||
|
return &m, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (match Matcher) Validate() error {
|
||||||
|
found := 0
|
||||||
|
|
||||||
|
if match.Equal != nil {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.NotEqual != nil {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.In != nil {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.NotIn != nil {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Like != "" {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
|
||||||
|
if found == 0 {
|
||||||
|
return fmt.Errorf("no conditions specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||||
|
var (
|
||||||
|
queryParts []string
|
||||||
|
params = make(map[string]interface{})
|
||||||
|
errs = new(multierror.Error)
|
||||||
|
key = fmt.Sprintf("%s%d", colDef.Name, idx)
|
||||||
|
)
|
||||||
|
|
||||||
|
add := func(operator, suffix string, values ...interface{}) {
|
||||||
|
var placeholder []string
|
||||||
|
|
||||||
|
for idx, value := range values {
|
||||||
|
encodedValue, err := orm.EncodeValue(ctx, &colDef, value, encoderConfig)
|
||||||
|
if err != nil {
|
||||||
|
errs.Errors = append(errs.Errors,
|
||||||
|
fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
|
||||||
|
placeholder = append(placeholder, uniqKey)
|
||||||
|
params[uniqKey] = encodedValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(placeholder) == 1 {
|
||||||
|
queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
|
||||||
|
} else {
|
||||||
|
queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Equal != nil {
|
||||||
|
add("=", "eq", match.Equal)
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.NotEqual != nil {
|
||||||
|
add("!=", "ne", match.NotEqual)
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.In != nil {
|
||||||
|
add("IN", "in", match.In...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.NotIn != nil {
|
||||||
|
add("NOT IN", "notin", match.NotIn...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Like != "" {
|
||||||
|
add("LIKE", "like", match.Like)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(queryParts) == 0 {
|
||||||
|
// this is an empty matcher without a single condition.
|
||||||
|
// we convert that to a no-op TRUE value
|
||||||
|
return "( 1 = 1 )", nil, errs.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(queryParts) == 1 {
|
||||||
|
return queryParts[0], params, errs.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
||||||
|
if len(query) == 0 {
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a lookup map to validate column names
|
||||||
|
lm := make(map[string]orm.ColumnDef, len(m.Columns))
|
||||||
|
for _, col := range m.Columns {
|
||||||
|
lm[col.Name] = col
|
||||||
|
}
|
||||||
|
|
||||||
|
paramMap := make(map[string]interface{})
|
||||||
|
columnStmts := make([]string, 0, len(query))
|
||||||
|
|
||||||
|
// get all keys and sort them so we get a stable output
|
||||||
|
queryKeys := make([]string, 0, len(query))
|
||||||
|
for column := range query {
|
||||||
|
queryKeys = append(queryKeys, column)
|
||||||
|
}
|
||||||
|
sort.Strings(queryKeys)
|
||||||
|
|
||||||
|
// actually create the WHERE clause parts for each
|
||||||
|
// column in query.
|
||||||
|
errs := new(multierror.Error)
|
||||||
|
for _, column := range queryKeys {
|
||||||
|
values := query[column]
|
||||||
|
colDef, ok := lm[column]
|
||||||
|
if !ok {
|
||||||
|
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParts := make([]string, len(values))
|
||||||
|
for idx, val := range values {
|
||||||
|
matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig)
|
||||||
|
if err != nil {
|
||||||
|
errs.Errors = append(errs.Errors,
|
||||||
|
fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
|
||||||
|
)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge parameters up into the superior parameter map
|
||||||
|
for key, val := range params {
|
||||||
|
if _, ok := paramMap[key]; ok {
|
||||||
|
// is is soley a developer mistake when implementing a matcher so no forgiving ...
|
||||||
|
panic("sqlite parameter collision")
|
||||||
|
}
|
||||||
|
|
||||||
|
paramMap[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParts[idx] = matcherQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
columnStmts = append(columnStmts,
|
||||||
|
fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause := strings.Join(columnStmts, " AND ")
|
||||||
|
|
||||||
|
return whereClause, paramMap, errs.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
||||||
|
if len(blob) == 0 {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we are looking at a slice directly decode into
|
||||||
|
// a []Select
|
||||||
|
if blob[0] == '[' {
|
||||||
|
var result []Select
|
||||||
|
if err := json.Unmarshal(blob, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
(*sel) = result
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if it's an object decode into a single select
|
||||||
|
if blob[0] == '{' {
|
||||||
|
var result Select
|
||||||
|
if err := json.Unmarshal(blob, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*sel = []Select{result}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise this is just the field name
|
||||||
|
var field string
|
||||||
|
if err := json.Unmarshal(blob, &field); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
||||||
|
if len(blob) == 0 {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we have an object at hand decode the select
|
||||||
|
// directly
|
||||||
|
if blob[0] == '{' {
|
||||||
|
var res struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Count *Count `json:"$count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(blob, &res); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sel.Count = res.Count
|
||||||
|
sel.Field = res.Field
|
||||||
|
|
||||||
|
if sel.Count != nil && sel.Count.As != "" {
|
||||||
|
if !charOnlyRegexp.MatchString(sel.Count.As) {
|
||||||
|
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var x string
|
||||||
|
if err := json.Unmarshal(blob, &x); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sel.Field = x
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
||||||
|
if len(blob) == 0 {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if blob[0] == '[' {
|
||||||
|
var result []OrderBy
|
||||||
|
if err := json.Unmarshal(blob, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*orderBys = result
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if blob[0] == '{' {
|
||||||
|
var result OrderBy
|
||||||
|
if err := json.Unmarshal(blob, &result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*orderBys = []OrderBy{result}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var field string
|
||||||
|
if err := json.Unmarshal(blob, &field); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*orderBys = []OrderBy{
|
||||||
|
{
|
||||||
|
Field: field,
|
||||||
|
Desc: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
||||||
|
if len(blob) == 0 {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if blob[0] == '{' {
|
||||||
|
var res struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Desc bool `json:"desc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(blob, &res); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBy.Desc = res.Desc
|
||||||
|
orderBy.Field = res.Field
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var field string
|
||||||
|
if err := json.Unmarshal(blob, &field); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBy.Field = field
|
||||||
|
orderBy.Desc = false
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
293
netquery/query_handler.go
Normal file
293
netquery/query_handler.go
Normal file
|
@ -0,0 +1,293 @@
|
||||||
|
package netquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/safing/portbase/log"
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
|
||||||
|
// QueryHandler implements http.Handler and allows to perform SQL
|
||||||
|
// query and aggregate functions on Database.
|
||||||
|
QueryHandler struct {
|
||||||
|
IsDevMode func() bool
|
||||||
|
Database *Database
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
requestPayload, err := qh.parseRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParsed := time.Since(start)
|
||||||
|
|
||||||
|
query, paramMap, err := requestPayload.generateSQL(req.Context(), qh.Database.Schema)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlQueryBuilt := time.Since(start)
|
||||||
|
|
||||||
|
// actually execute the query against the database and collect the result
|
||||||
|
var result []map[string]interface{}
|
||||||
|
if err := qh.Database.Execute(
|
||||||
|
req.Context(),
|
||||||
|
query,
|
||||||
|
orm.WithNamedArgs(paramMap),
|
||||||
|
orm.WithResult(&result),
|
||||||
|
); err != nil {
|
||||||
|
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sqlQueryFinished := time.Since(start)
|
||||||
|
|
||||||
|
// send the HTTP status code
|
||||||
|
resp.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// prepare the result encoder.
|
||||||
|
enc := json.NewEncoder(resp)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
|
||||||
|
// prepare the result body that, in dev mode, contains
|
||||||
|
// some diagnostics data about the query
|
||||||
|
var resultBody map[string]interface{}
|
||||||
|
if qh.IsDevMode() {
|
||||||
|
resultBody = map[string]interface{}{
|
||||||
|
"sql_prep_stmt": query,
|
||||||
|
"sql_params": paramMap,
|
||||||
|
"query": requestPayload.Query,
|
||||||
|
"orderBy": requestPayload.OrderBy,
|
||||||
|
"groupBy": requestPayload.GroupBy,
|
||||||
|
"selects": requestPayload.Select,
|
||||||
|
"times": map[string]interface{}{
|
||||||
|
"start_time": start,
|
||||||
|
"query_parsed_after": queryParsed.String(),
|
||||||
|
"query_built_after": sqlQueryBuilt.String(),
|
||||||
|
"query_executed_after": sqlQueryFinished.String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resultBody = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
resultBody["results"] = result
|
||||||
|
|
||||||
|
// and finally stream the response
|
||||||
|
if err := enc.Encode(resultBody); err != nil {
|
||||||
|
// we failed to encode the JSON body to resp so we likely either already sent a
|
||||||
|
// few bytes or the pipe was already closed. In either case, trying to send the
|
||||||
|
// error using http.Error() is non-sense. We just log it out here and that's all
|
||||||
|
// we can do.
|
||||||
|
log.Errorf("failed to encode JSON response: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) {
|
||||||
|
var body io.Reader
|
||||||
|
|
||||||
|
switch req.Method {
|
||||||
|
case http.MethodPost, http.MethodPut:
|
||||||
|
body = req.Body
|
||||||
|
case http.MethodGet:
|
||||||
|
body = strings.NewReader(req.URL.Query().Get("q"))
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid HTTP method")
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestPayload QueryRequestPayload
|
||||||
|
blob, err := ioutil.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read body" + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
body = bytes.NewReader(blob)
|
||||||
|
|
||||||
|
dec := json.NewDecoder(body)
|
||||||
|
dec.DisallowUnknownFields()
|
||||||
|
|
||||||
|
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return nil, fmt.Errorf("invalid query: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &requestPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||||
|
if err := req.prepareSelectedFields(schema); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("perparing selected fields: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the SQL where clause from the payload query
|
||||||
|
whereClause, paramMap, err := req.Query.toSQLWhereClause(
|
||||||
|
ctx,
|
||||||
|
schema,
|
||||||
|
orm.DefaultEncodeConfig,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("ganerating where clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the actual SQL query statement
|
||||||
|
// FIXME(ppacher): add support for group-by and sort-by
|
||||||
|
|
||||||
|
groupByClause, err := req.generateGroupByClause(schema)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating group-by clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderByClause, err := req.generateOrderByClause(schema)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating order-by clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
selectClause := req.generateSelectClause()
|
||||||
|
query := `SELECT ` + selectClause + ` FROM connections`
|
||||||
|
if whereClause != "" {
|
||||||
|
query += " WHERE " + whereClause
|
||||||
|
}
|
||||||
|
query += " " + groupByClause + " " + orderByClause
|
||||||
|
|
||||||
|
return query, paramMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error {
|
||||||
|
for _, s := range req.Select {
|
||||||
|
var field string
|
||||||
|
if s.Count != nil {
|
||||||
|
field = s.Count.Field
|
||||||
|
} else {
|
||||||
|
field = s.Field
|
||||||
|
}
|
||||||
|
|
||||||
|
colName := "*"
|
||||||
|
if field != "*" || s.Count == nil {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
colName, err = req.validateColumnName(schema, field)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Count != nil {
|
||||||
|
var as = s.Count.As
|
||||||
|
if as == "" {
|
||||||
|
as = fmt.Sprintf("%s_count", colName)
|
||||||
|
}
|
||||||
|
var distinct = ""
|
||||||
|
if s.Count.Distinct {
|
||||||
|
distinct = "DISTINCT "
|
||||||
|
}
|
||||||
|
req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as))
|
||||||
|
req.whitelistedFields = append(req.whitelistedFields, as)
|
||||||
|
} else {
|
||||||
|
req.selectedFields = append(req.selectedFields, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
||||||
|
if len(req.GroupBy) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var groupBys = make([]string, len(req.GroupBy))
|
||||||
|
|
||||||
|
for idx, name := range req.GroupBy {
|
||||||
|
colName, err := req.validateColumnName(schema, name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
groupBys[idx] = colName
|
||||||
|
}
|
||||||
|
|
||||||
|
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
||||||
|
|
||||||
|
// if there are no explicitly selected fields we default to the
|
||||||
|
// group-by columns as that's what's expected most of the time anyway...
|
||||||
|
if len(req.selectedFields) == 0 {
|
||||||
|
req.selectedFields = append(req.selectedFields, groupBys...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupByClause, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateSelectClause() string {
|
||||||
|
var selectClause = "*"
|
||||||
|
if len(req.selectedFields) > 0 {
|
||||||
|
selectClause = strings.Join(req.selectedFields, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectClause
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
|
||||||
|
var orderBys = make([]string, len(req.OrderBy))
|
||||||
|
for idx, sort := range req.OrderBy {
|
||||||
|
colName, err := req.validateColumnName(schema, sort.Field)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if sort.Desc {
|
||||||
|
orderBys[idx] = fmt.Sprintf("%s DESC", colName)
|
||||||
|
} else {
|
||||||
|
orderBys[idx] = fmt.Sprintf("%s ASC", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "ORDER BY " + strings.Join(orderBys, ", "), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
|
||||||
|
colDef := schema.GetColumnDef(field)
|
||||||
|
if colDef != nil {
|
||||||
|
return colDef.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, selected := range req.whitelistedFields {
|
||||||
|
if field == selected {
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, selected := range req.selectedFields {
|
||||||
|
if field == selected {
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("column name %s not allowed", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// compile time check
|
||||||
|
var _ http.Handler = new(QueryHandler)
|
244
netquery/query_test.go
Normal file
244
netquery/query_test.go
Normal file
|
@ -0,0 +1,244 @@
|
||||||
|
package netquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_UnmarshalQuery(t *testing.T) {
|
||||||
|
var cases = []struct {
|
||||||
|
Name string
|
||||||
|
Input string
|
||||||
|
Expected Query
|
||||||
|
Error error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"Parse a simple query",
|
||||||
|
`{ "domain": ["example.com", "example.at"] }`,
|
||||||
|
Query{
|
||||||
|
"domain": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Equal: "example.at",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Parse a more complex query",
|
||||||
|
`
|
||||||
|
{
|
||||||
|
"domain": [
|
||||||
|
{
|
||||||
|
"$in": [
|
||||||
|
"example.at",
|
||||||
|
"example.com"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$like": "microsoft.%"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"path": [
|
||||||
|
"/bin/ping",
|
||||||
|
{
|
||||||
|
"$notin": [
|
||||||
|
"/sbin/ping",
|
||||||
|
"/usr/sbin/ping"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
Query{
|
||||||
|
"domain": []Matcher{
|
||||||
|
{
|
||||||
|
In: []interface{}{
|
||||||
|
"example.at",
|
||||||
|
"example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Like: "microsoft.%",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"path": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: "/bin/ping",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NotIn: []interface{}{
|
||||||
|
"/sbin/ping",
|
||||||
|
"/usr/sbin/ping",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
|
var q Query
|
||||||
|
err := json.Unmarshal([]byte(c.Input), &q)
|
||||||
|
|
||||||
|
if c.Error != nil {
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Equal(t, c.Error.Error(), err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, c.Expected, q)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_QueryBuilder(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
var cases = []struct {
|
||||||
|
N string
|
||||||
|
Q Query
|
||||||
|
R string
|
||||||
|
P map[string]interface{}
|
||||||
|
E error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"No filter",
|
||||||
|
nil,
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Simple, one-column filter",
|
||||||
|
Query{"domain": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Equal: "example.at",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
"( domain = :domain0eq0 OR domain = :domain1eq0 )",
|
||||||
|
map[string]interface{}{
|
||||||
|
":domain0eq0": "example.com",
|
||||||
|
":domain1eq0": "example.at",
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Two column filter",
|
||||||
|
Query{
|
||||||
|
"domain": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: "example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"path": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: "/bin/curl",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Equal: "/bin/ping",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"( domain = :domain0eq0 ) AND ( path = :path0eq0 OR path = :path1eq0 )",
|
||||||
|
map[string]interface{}{
|
||||||
|
":domain0eq0": "example.com",
|
||||||
|
":path0eq0": "/bin/curl",
|
||||||
|
":path1eq0": "/bin/ping",
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Time based filter",
|
||||||
|
Query{
|
||||||
|
"started": []Matcher{
|
||||||
|
{
|
||||||
|
Equal: now.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"( started = :started0eq0 )",
|
||||||
|
map[string]interface{}{
|
||||||
|
":started0eq0": now.In(time.UTC).Format(orm.SqliteTimeFormat),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Invalid column access",
|
||||||
|
Query{
|
||||||
|
"forbiddenField": []Matcher{{}},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Complex example",
|
||||||
|
Query{
|
||||||
|
"domain": []Matcher{
|
||||||
|
{
|
||||||
|
In: []interface{}{"example.at", "example.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Like: "microsoft.%",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"path": []Matcher{
|
||||||
|
{
|
||||||
|
NotIn: []interface{}{
|
||||||
|
"/bin/ping",
|
||||||
|
"/sbin/ping",
|
||||||
|
"/usr/bin/ping",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"( domain IN ( :domain0in0, :domain0in1 ) OR domain LIKE :domain1like0 ) AND ( path NOT IN ( :path0notin0, :path0notin1, :path0notin2 ) )",
|
||||||
|
map[string]interface{}{
|
||||||
|
":domain0in0": "example.at",
|
||||||
|
":domain0in1": "example.com",
|
||||||
|
":domain1like0": "microsoft.%",
|
||||||
|
":path0notin0": "/bin/ping",
|
||||||
|
":path0notin1": "/sbin/ping",
|
||||||
|
":path0notin2": "/usr/bin/ping",
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tbl, err := orm.GenerateTableSchema("connections", Conn{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for idx, c := range cases {
|
||||||
|
t.Run(c.N, func(t *testing.T) {
|
||||||
|
//t.Parallel()
|
||||||
|
str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig)
|
||||||
|
|
||||||
|
if c.E != nil {
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "test case %d", idx)
|
||||||
|
assert.Equal(t, c.P, params, "test case %d", idx)
|
||||||
|
assert.Equal(t, c.R, str, "test case %d", idx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue