Add support for new query API

This commit is contained in:
Patrick Pacher 2022-04-26 14:59:27 +02:00
parent e21eb16a6b
commit d098f1c137
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
10 changed files with 1154 additions and 43 deletions

View file

@ -44,6 +44,8 @@ type (
// are actually supposed to do. // are actually supposed to do.
// //
Database struct { Database struct {
Schema *orm.TableSchema
l sync.Mutex l sync.Mutex
conn *sqlite.Conn conn *sqlite.Conn
} }
@ -62,6 +64,8 @@ type (
// as long as the connection is still active and might be, although unlikely, // as long as the connection is still active and might be, although unlikely,
// reused afterwards. // reused afterwards.
ID string `sqlite:"id,primary"` ID string `sqlite:"id,primary"`
ProfileID string `sqlite:"profile"`
Path string `sqlite:"path"`
Type string `sqlite:"type,varchar(8)"` Type string `sqlite:"type,varchar(8)"`
External bool `sqlite:"external"` External bool `sqlite:"external"`
IPVersion packet.IPVersion `sqlite:"ip_version"` IPVersion packet.IPVersion `sqlite:"ip_version"`
@ -78,8 +82,8 @@ type (
Longitude float64 `sqlite:"longitude"` Longitude float64 `sqlite:"longitude"`
Scope netutils.IPScope `sqlite:"scope"` Scope netutils.IPScope `sqlite:"scope"`
Verdict network.Verdict `sqlite:"verdict"` Verdict network.Verdict `sqlite:"verdict"`
Started time.Time `sqlite:"started,text"` Started time.Time `sqlite:"started,text,time"`
Ended *time.Time `sqlite:"ended,text"` Ended *time.Time `sqlite:"ended,text,time"`
Tunneled bool `sqlite:"tunneled"` Tunneled bool `sqlite:"tunneled"`
Encrypted bool `sqlite:"encrypted"` Encrypted bool `sqlite:"encrypted"`
Internal bool `sqlite:"internal"` Internal bool `sqlite:"internal"`
@ -107,7 +111,15 @@ func New(path string) (*Database, error) {
return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err)
} }
return &Database{conn: c}, nil schema, err := orm.GenerateTableSchema("connections", Conn{})
if err != nil {
return nil, err
}
return &Database{
Schema: schema,
conn: c,
}, nil
} }
// NewInMemory is like New but creates a new in-memory database and // NewInMemory is like New but creates a new in-memory database and
@ -133,13 +145,8 @@ func NewInMemory() (*Database, error) {
// any data-migrations. Once the history module is implemented this should // any data-migrations. Once the history module is implemented this should
// become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration
func (db *Database) ApplyMigrations() error { func (db *Database) ApplyMigrations() error {
schema, err := orm.GenerateTableSchema("connections", Conn{})
if err != nil {
return fmt.Errorf("failed to generate table schema for conncetions: %w", err)
}
// get the create-table SQL statement from the infered schema // get the create-table SQL statement from the infered schema
sql := schema.CreateStatement(false) sql := db.Schema.CreateStatement(false)
// execute the SQL // execute the SQL
if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil { if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil {
@ -284,7 +291,7 @@ func (db *Database) Save(ctx context.Context, conn Conn) error {
return nil return nil
}, },
}); err != nil { }); err != nil {
log.Errorf("netquery: failed to execute: %s", err) log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values)
return err return err
} }

View file

@ -96,7 +96,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
model, err := convertConnection(conn) model, err := convertConnection(conn)
if err != nil { if err != nil {
log.Errorf("netquery: failed to convert connection %s to sqlite model: %w", conn.ID, err) log.Errorf("netquery: failed to convert connection %s to sqlite model: %s", conn.ID, err)
continue continue
} }
@ -104,7 +104,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
log.Infof("netquery: persisting create/update to connection %s", conn.ID) log.Infof("netquery: persisting create/update to connection %s", conn.ID)
if err := mng.store.Save(ctx, *model); err != nil { if err := mng.store.Save(ctx, *model); err != nil {
log.Errorf("netquery: failed to save connection %s in sqlite database: %w", conn.ID, err) log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err)
continue continue
} }
@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect
// push an update for the connection // push an update for the connection
if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil { if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil {
log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err) log.Errorf("netquery: failed to push update for conn %s via database system: %s", conn.ID, err)
} }
count++ count++
@ -170,6 +170,8 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
Internal: conn.Internal, Internal: conn.Internal,
Inbound: conn.Inbound, Inbound: conn.Inbound,
Type: ConnectionTypeToString[conn.Type], Type: ConnectionTypeToString[conn.Type],
ProfileID: conn.ProcessContext.ProfileName,
Path: conn.ProcessContext.BinaryPath,
} }
if conn.Ended > 0 { if conn.Ended > 0 {

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/safing/portbase/api"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/query" "github.com/safing/portbase/database/query"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
@ -29,6 +31,7 @@ func init() {
mod.Prepare, mod.Prepare,
mod.Start, mod.Start,
mod.Stop, mod.Stop,
"api",
"network", "network",
"database", "database",
) )
@ -55,6 +58,25 @@ func (m *Module) Prepare() error {
m.feed = make(chan *network.Connection, 1000) m.feed = make(chan *network.Connection, 1000)
queryHander := &QueryHandler{
Database: m.sqlStore,
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
}
// FIXME(ppacher): use appropriate permissions for this
if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/query",
MimeType: "application/json",
Read: api.PermitAnyone,
Write: api.PermitAnyone,
BelongsTo: m.Module,
HandlerFunc: queryHander.ServeHTTP,
Name: "Query In-Memory Database",
Description: "Query the in-memory sqlite database",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
return nil return nil
} }
@ -100,7 +122,7 @@ func (mod *Module) Start() error {
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold)) count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
if err != nil { if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %w", err) log.Errorf("netquery: failed to count number of rows in memory: %s", err)
} else { } else {
log.Infof("netquery: successfully removed %d old rows", count) log.Infof("netquery: successfully removed %d old rows", count)
} }
@ -116,7 +138,7 @@ func (mod *Module) Start() error {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
count, err := mod.sqlStore.CountRows(ctx) count, err := mod.sqlStore.CountRows(ctx)
if err != nil { if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %w", err) log.Errorf("netquery: failed to count number of rows in memory: %s", err)
} else { } else {
log.Infof("netquery: currently holding %d rows in memory", count) log.Infof("netquery: currently holding %d rows in memory", count)
} }

View file

@ -31,7 +31,7 @@ var (
// preconfigured timezone; UTC by default) or as INTEGER (the user can choose between // preconfigured timezone; UTC by default) or as INTEGER (the user can choose between
// unixepoch and unixnano-epoch where the nano variant is not offically supported by // unixepoch and unixnano-epoch where the nano variant is not offically supported by
// SQLITE). // SQLITE).
sqliteTimeFormat = "2006-01-02 15:04:05" SqliteTimeFormat = "2006-01-02 15:04:05"
) )
type ( type (
@ -209,7 +209,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
case sqlite.TypeText: case sqlite.TypeText:
// stored ISO8601 but does not have any timezone information // stored ISO8601 but does not have any timezone information
// assigned so we always treat it as loc here. // assigned so we always treat it as loc here.
t, err := time.ParseInLocation(sqliteTimeFormat, stmt.ColumnText(colIdx), loc) t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err) return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
} }

View file

@ -103,6 +103,13 @@ func encodeBasic() EncodeFunc {
kind = valType.Kind() kind = valType.Kind()
if val.IsNil() { if val.IsNil() {
if !col.Nullable {
// we need to set the zero value here since the column
// is not marked as nullable
//return reflect.New(valType).Elem().Interface(), true, nil
panic("nil pointer for not-null field")
}
return nil, true, nil return nil, true, nil
} }
@ -133,7 +140,7 @@ func encodeBasic() EncodeFunc {
} }
func DatetimeEncoder(loc *time.Location) EncodeFunc { func DatetimeEncoder(loc *time.Location) EncodeFunc {
return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) { return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
// if fieldType holds a pointer we need to dereference the value // if fieldType holds a pointer we need to dereference the value
ft := valType.String() ft := valType.String()
if valType.Kind() == reflect.Ptr { if valType.Kind() == reflect.Ptr {
@ -142,44 +149,71 @@ func DatetimeEncoder(loc *time.Location) EncodeFunc {
} }
// we only care about "time.Time" here // we only care about "time.Time" here
if ft != "time.Time" { var t time.Time
return nil, false, nil if ft == "time.Time" {
}
// handle the zero time as a NULL. // handle the zero time as a NULL.
if !val.IsValid() || val.IsZero() { if !val.IsValid() || val.IsZero() {
return nil, true, nil return nil, true, nil
} }
var ok bool
valInterface := val.Interface() valInterface := val.Interface()
t, ok := valInterface.(time.Time) t, ok = valInterface.(time.Time)
if !ok { if !ok {
return nil, false, fmt.Errorf("cannot convert reflect value to time.Time") return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
} }
switch colDev.Type { } else if valType.Kind() == reflect.String && colDef.IsTime {
var err error
t, err = time.Parse(time.RFC3339, val.String())
if err != nil {
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
}
} else {
// we don't care ...
return nil, false, nil
}
switch colDef.Type {
case sqlite.TypeInteger: case sqlite.TypeInteger:
if colDev.UnixNano { if colDef.UnixNano {
return t.UnixNano(), true, nil return t.UnixNano(), true, nil
} }
return t.Unix(), true, nil return t.Unix(), true, nil
case sqlite.TypeText: case sqlite.TypeText:
str := t.In(loc).Format(sqliteTimeFormat) str := t.In(loc).Format(SqliteTimeFormat)
return str, true, nil return str, true, nil
} }
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDev.Type) return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
} }
} }
func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) { func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
if valType == nil { if valType == nil {
if !colDef.Nullable {
switch colDef.Type {
case sqlite.TypeBlob:
return []byte{}, true, nil
case sqlite.TypeFloat:
return 0.0, true, nil
case sqlite.TypeText:
return "", true, nil
case sqlite.TypeInteger:
return 0, true, nil
default:
return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
}
}
return nil, true, nil return nil, true, nil
} }
for _, fn := range hooks { for _, fn := range hooks {
res, end, err := fn(colDev, valType, val) res, end, err := fn(colDef, valType, val)
if err != nil { if err != nil {
return res, false, err return res, false, err
} }

View file

@ -89,7 +89,7 @@ func Test_EncodeAsMap(t *testing.T) {
}, },
map[string]interface{}{ map[string]interface{}{
"TinInt": refTime.UnixNano(), "TinInt": refTime.UnixNano(),
"TinString": refTime.Format(sqliteTimeFormat), "TinString": refTime.Format(SqliteTimeFormat),
}, },
}, },
{ {
@ -107,7 +107,7 @@ func Test_EncodeAsMap(t *testing.T) {
}, },
map[string]interface{}{ map[string]interface{}{
"TinInt": refTime.UnixNano(), "TinInt": refTime.UnixNano(),
"TinString": refTime.Format(sqliteTimeFormat), "TinString": refTime.Format(SqliteTimeFormat),
"Tnil1": nil, "Tnil1": nil,
"Tnil2": nil, "Tnil2": nil,
}, },
@ -143,7 +143,7 @@ func Test_EncodeValue(t *testing.T) {
Type: sqlite.TypeText, Type: sqlite.TypeText,
}, },
refTime, refTime,
refTime.Format(sqliteTimeFormat), refTime.Format(SqliteTimeFormat),
}, },
{ {
"Special value time.Time as unix-epoch", "Special value time.Time as unix-epoch",
@ -189,11 +189,12 @@ func Test_EncodeValue(t *testing.T) {
Type: sqlite.TypeText, Type: sqlite.TypeText,
}, },
&refTime, &refTime,
refTime.Format(sqliteTimeFormat), refTime.Format(SqliteTimeFormat),
}, },
{ {
"Special value untyped nil", "Special value untyped nil",
ColumnDef{ ColumnDef{
Nullable: true,
IsTime: true, IsTime: true,
Type: sqlite.TypeText, Type: sqlite.TypeText,
}, },
@ -209,12 +210,47 @@ func Test_EncodeValue(t *testing.T) {
(*time.Time)(nil), (*time.Time)(nil),
nil, nil,
}, },
{
"Time formated as string",
ColumnDef{
IsTime: true,
Type: sqlite.TypeText,
},
refTime.In(time.Local).Format(time.RFC3339),
refTime.Format(SqliteTimeFormat),
},
{
"Nullable integer",
ColumnDef{
Type: sqlite.TypeInteger,
Nullable: true,
},
nil,
nil,
},
{
"Not-Null integer",
ColumnDef{
Name: "test",
Type: sqlite.TypeInteger,
},
nil,
0,
},
{
"Not-Null string",
ColumnDef{
Type: sqlite.TypeText,
},
nil,
"",
},
} }
for idx := range cases { for idx := range cases {
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
// t.Parallel() //t.Parallel()
res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig) res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -53,6 +53,15 @@ type (
} }
) )
// GetColumnDef returns the column definition with the given name or nil
// if the schema does not contain such a column. The returned pointer
// refers to a copy, so mutating it does not alter the schema itself.
func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
	for _, candidate := range ts.Columns {
		if candidate.Name != name {
			continue
		}
		col := candidate
		return &col
	}
	return nil
}
func (ts TableSchema) CreateStatement(ifNotExists bool) string { func (ts TableSchema) CreateStatement(ifNotExists bool) string {
sql := "CREATE TABLE" sql := "CREATE TABLE"
if ifNotExists { if ifNotExists {

464
netquery/query.go Normal file
View file

@ -0,0 +1,464 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"sort"
"strings"
"github.com/hashicorp/go-multierror"
"github.com/safing/portmaster/netquery/orm"
)
type (
Query map[string][]Matcher
Matcher struct {
Equal interface{} `json:"$eq,omitempty"`
NotEqual interface{} `json:"$ne,omitempty"`
In []interface{} `json:"$in,omitempty"`
NotIn []interface{} `json:"$notIn,omitempty"`
Like string `json:"$like,omitempty"`
}
Count struct {
As string `json:"as"`
Field string `json:"field"`
Distinct bool `json:"distict"`
}
Select struct {
Field string `json:"field"`
Count *Count `json:"$count"`
}
Selects []Select
QueryRequestPayload struct {
Select Selects `json:"select"`
Query Query `json:"query"`
OrderBy []OrderBy `json:"orderBy"`
GroupBy []string `json:"groupBy"`
selectedFields []string
whitelistedFields []string
}
OrderBy struct {
Field string `json:"field"`
Desc bool `json:"desc"`
}
OrderBys []OrderBy
)
// UnmarshalJSON implements json.Unmarshaler for Query. Each key of the
// JSON object is a column name; its value may be a single matcher
// object, a list of matchers/primitives (OR-combined later), or a bare
// JSON primitive which is treated as an $eq condition.
func (query *Query) UnmarshalJSON(blob []byte) error {
	if *query == nil {
		*query = make(Query)
	}
	var model map[string]json.RawMessage
	if err := json.Unmarshal(blob, &model); err != nil {
		return err
	}
	// dispatch on the first byte of each raw column value to decide
	// between object, array and primitive forms.
	for columnName, rawColumnQuery := range model {
		if len(rawColumnQuery) == 0 {
			continue
		}
		switch rawColumnQuery[0] {
		case '{':
			m, err := parseMatcher(rawColumnQuery)
			if err != nil {
				return err
			}
			(*query)[columnName] = []Matcher{*m}
		case '[':
			var rawMatchers []json.RawMessage
			if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
				return err
			}
			(*query)[columnName] = make([]Matcher, len(rawMatchers))
			for idx, val := range rawMatchers {
				// this should not happen
				if len(val) == 0 {
					continue
				}
				// if val starts with a { we have a matcher definition
				if val[0] == '{' {
					m, err := parseMatcher(val)
					if err != nil {
						return err
					}
					(*query)[columnName][idx] = *m
					continue
				} else if val[0] == '[' {
					// nested arrays have no meaning in this query language
					return fmt.Errorf("invalid token [ in query for column %s", columnName)
				}
				// val is a dedicated JSON primitive and not an object or array
				// so we treat that as an EQUAL condition.
				var x interface{}
				if err := json.Unmarshal(val, &x); err != nil {
					return err
				}
				(*query)[columnName][idx] = Matcher{
					Equal: x,
				}
			}
		default:
			// value is a JSON primitive and not an object or array
			// so we treat that as an EQUAL condition.
			var x interface{}
			if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
				return err
			}
			(*query)[columnName] = []Matcher{
				{Equal: x},
			}
		}
	}
	return nil
}
// parseMatcher unmarshals a single matcher object from raw and validates
// that it specifies at least one condition.
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
	var m Matcher
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	// wrap with %w so callers can inspect the cause via errors.Is/As.
	if err := m.Validate(); err != nil {
		return nil, fmt.Errorf("invalid query matcher: %w", err)
	}
	// NOTE(review): debug leftover using the stdlib logger instead of the
	// project logger; consider removing before release.
	log.Printf("parsed matcher %s: %+v", string(raw), m)
	return &m, nil
}
// Validate ensures that the matcher specifies at least one condition
// ($eq, $ne, $in, $notIn or $like). It returns an error for an empty
// matcher.
func (match Matcher) Validate() error {
	conditions := 0
	for _, present := range []bool{
		match.Equal != nil,
		match.NotEqual != nil,
		match.In != nil,
		match.NotIn != nil,
		match.Like != "",
	} {
		if present {
			conditions++
		}
	}
	if conditions == 0 {
		return fmt.Errorf("no conditions specified")
	}
	return nil
}
// toSQLConditionClause converts the matcher into a SQL condition clause
// for colDef, returning the clause text, a map of named SQL parameters
// it references, and any encoding errors (collected as a multierror).
// idx is used to make parameter names unique per matcher; conjunction
// joins multiple conditions of this matcher (e.g. "AND").
func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	var (
		queryParts []string
		params     = make(map[string]interface{})
		errs       = new(multierror.Error)
		key        = fmt.Sprintf("%s%d", colDef.Name, idx)
	)
	// add appends one "<col> <operator> <placeholder(s)>" part; it encodes
	// each value through the ORM so e.g. time values match the column type.
	// NOTE: it mutates queryParts/params/errs from the enclosing scope.
	add := func(operator, suffix string, values ...interface{}) {
		var placeholder []string
		for idx, value := range values {
			encodedValue, err := orm.EncodeValue(ctx, &colDef, value, encoderConfig)
			if err != nil {
				errs.Errors = append(errs.Errors,
					fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
				)
				return
			}
			// parameter names are ":<col><matcher-idx><op-suffix><value-idx>"
			uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
			placeholder = append(placeholder, uniqKey)
			params[uniqKey] = encodedValue
		}
		if len(placeholder) == 1 {
			queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
		} else {
			queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))
		}
	}
	if match.Equal != nil {
		add("=", "eq", match.Equal)
	}
	if match.NotEqual != nil {
		add("!=", "ne", match.NotEqual)
	}
	if match.In != nil {
		add("IN", "in", match.In...)
	}
	if match.NotIn != nil {
		add("NOT IN", "notin", match.NotIn...)
	}
	if match.Like != "" {
		add("LIKE", "like", match.Like)
	}
	if len(queryParts) == 0 {
		// this is an empty matcher without a single condition.
		// we convert that to a no-op TRUE value
		return "( 1 = 1 )", nil, errs.ErrorOrNil()
	}
	if len(queryParts) == 1 {
		return queryParts[0], params, errs.ErrorOrNil()
	}
	return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
}
// toSQLWhereClause converts the query into the body of a SQL WHERE
// clause (without the WHERE keyword) plus the named parameters it uses.
// Matchers for the same column are OR-combined; different columns are
// AND-combined. Column names are validated against the table schema m.
func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	if len(query) == 0 {
		return "", nil, nil
	}
	// create a lookup map to validate column names
	lm := make(map[string]orm.ColumnDef, len(m.Columns))
	for _, col := range m.Columns {
		lm[col.Name] = col
	}
	paramMap := make(map[string]interface{})
	columnStmts := make([]string, 0, len(query))
	// get all keys and sort them so we get a stable output
	queryKeys := make([]string, 0, len(query))
	for column := range query {
		queryKeys = append(queryKeys, column)
	}
	sort.Strings(queryKeys)
	// actually create the WHERE clause parts for each
	// column in query.
	errs := new(multierror.Error)
	for _, column := range queryKeys {
		values := query[column]
		colDef, ok := lm[column]
		if !ok {
			// reject columns that are not part of the schema
			errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
			continue
		}
		queryParts := make([]string, len(values))
		for idx, val := range values {
			matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig)
			if err != nil {
				errs.Errors = append(errs.Errors,
					fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
				)
				continue
			}
			// merge parameters up into the superior parameter map
			for key, val := range params {
				if _, ok := paramMap[key]; ok {
					// this is solely a developer mistake when implementing a matcher so no forgiving ...
					panic("sqlite parameter collision")
				}
				paramMap[key] = val
			}
			queryParts[idx] = matcherQuery
		}
		columnStmts = append(columnStmts,
			fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
		)
	}
	whereClause := strings.Join(columnStmts, " AND ")
	return whereClause, paramMap, errs.ErrorOrNil()
}
// UnmarshalJSON implements json.Unmarshaler for Selects. It accepts a
// list of select definitions, a single select object or a plain field
// name and normalizes all three forms into a []Select.
func (sel *Selects) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}
	// if we are looking at a slice directly decode into
	// a []Select
	if blob[0] == '[' {
		var result []Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		(*sel) = result
		return nil
	}
	// if it's an object decode into a single select
	if blob[0] == '{' {
		var result Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		*sel = []Select{result}
		return nil
	}
	// otherwise this is just the field name. BUG FIX: the parsed field
	// was previously discarded, so a bare string select was a no-op.
	var field string
	if err := json.Unmarshal(blob, &field); err != nil {
		return err
	}
	*sel = []Select{{Field: field}}
	return nil
}
// UnmarshalJSON implements json.Unmarshaler for Select. It accepts
// either a full select object (with an optional $count aggregation) or
// a plain string that is treated as the field name.
func (sel *Select) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}
	// a plain JSON string is shorthand for selecting a single field.
	if blob[0] != '{' {
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}
		sel.Field = fieldName
		return nil
	}
	// full select definition; decode into an anonymous struct to avoid
	// recursing into this UnmarshalJSON implementation.
	var payload struct {
		Field string `json:"field"`
		Count *Count `json:"$count"`
	}
	if err := json.Unmarshal(blob, &payload); err != nil {
		return err
	}
	sel.Field = payload.Field
	sel.Count = payload.Count
	// restrict the $count alias to letters so it is safe to interpolate
	// into the generated SQL.
	if sel.Count != nil && sel.Count.As != "" {
		if !charOnlyRegexp.MatchString(sel.Count.As) {
			return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
		}
	}
	return nil
}
// UnmarshalJSON implements json.Unmarshaler for OrderBys. It accepts a
// list of order-by definitions, a single definition object or a plain
// field name (sorted ascending).
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}
	switch blob[0] {
	case '[':
		// list of order-by definitions.
		var list []OrderBy
		if err := json.Unmarshal(blob, &list); err != nil {
			return err
		}
		*orderBys = list
		return nil
	case '{':
		// single order-by definition.
		var single OrderBy
		if err := json.Unmarshal(blob, &single); err != nil {
			return err
		}
		*orderBys = []OrderBy{single}
		return nil
	default:
		// bare field name; defaults to ascending order.
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}
		*orderBys = []OrderBy{
			{
				Field: fieldName,
				Desc:  false,
			},
		}
		return nil
	}
}
// UnmarshalJSON implements json.Unmarshaler for OrderBy. It accepts
// either an object with "field" and "desc" keys or a plain string that
// is treated as the field name with ascending order.
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}
	// plain string: field name with ascending sort.
	if blob[0] != '{' {
		var fieldName string
		if err := json.Unmarshal(blob, &fieldName); err != nil {
			return err
		}
		orderBy.Field = fieldName
		orderBy.Desc = false
		return nil
	}
	// object form; decode into an anonymous struct to avoid recursing
	// into this UnmarshalJSON implementation.
	var payload struct {
		Field string `json:"field"`
		Desc  bool   `json:"desc"`
	}
	if err := json.Unmarshal(blob, &payload); err != nil {
		return err
	}
	orderBy.Field = payload.Field
	orderBy.Desc = payload.Desc
	return nil
}

293
netquery/query_handler.go Normal file
View file

@ -0,0 +1,293 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strings"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portmaster/netquery/orm"
)
var (
charOnlyRegexp = regexp.MustCompile("[a-zA-Z]+")
)
type (
// QueryHandler implements http.Handler and allows to perform SQL
// query and aggregate functions on Database.
QueryHandler struct {
IsDevMode func() bool
Database *Database
}
)
// ServeHTTP implements http.Handler. It parses the query payload from
// the request, translates it to SQL, executes it against the in-memory
// database and streams the result as JSON. In dev mode the response
// additionally contains the generated SQL, parameters and timings.
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	start := time.Now()
	requestPayload, err := qh.parseRequest(req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}
	queryParsed := time.Since(start)
	query, paramMap, err := requestPayload.generateSQL(req.Context(), qh.Database.Schema)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)
		return
	}
	sqlQueryBuilt := time.Since(start)
	// actually execute the query against the database and collect the result
	var result []map[string]interface{}
	if err := qh.Database.Execute(
		req.Context(),
		query,
		orm.WithNamedArgs(paramMap),
		orm.WithResult(&result),
	); err != nil {
		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
		return
	}
	sqlQueryFinished := time.Since(start)
	// send the HTTP status code; after this point we can no longer report
	// errors via http.Error.
	resp.WriteHeader(http.StatusOK)
	// prepare the result encoder.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")
	// prepare the result body that, in dev mode, contains
	// some diagnostics data about the query
	var resultBody map[string]interface{}
	if qh.IsDevMode() {
		resultBody = map[string]interface{}{
			"sql_prep_stmt": query,
			"sql_params": paramMap,
			"query": requestPayload.Query,
			"orderBy": requestPayload.OrderBy,
			"groupBy": requestPayload.GroupBy,
			"selects": requestPayload.Select,
			"times": map[string]interface{}{
				"start_time": start,
				"query_parsed_after": queryParsed.String(),
				"query_built_after": sqlQueryBuilt.String(),
				"query_executed_after": sqlQueryFinished.String(),
			},
		}
	} else {
		resultBody = make(map[string]interface{})
	}
	resultBody["results"] = result
	// and finally stream the response
	if err := enc.Encode(resultBody); err != nil {
		// we failed to encode the JSON body to resp so we likely either already sent a
		// few bytes or the pipe was already closed. In either case, trying to send the
		// error using http.Error() is non-sense. We just log it out here and that's all
		// we can do.
		log.Errorf("failed to encode JSON response: %s", err)
		return
	}
}
// parseRequest extracts the query payload from req. POST/PUT requests
// carry the payload in the body; GET requests carry it in the "q" query
// parameter. An empty payload is valid and yields a zero-value payload.
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) {
	var body io.Reader
	switch req.Method {
	case http.MethodPost, http.MethodPut:
		body = req.Body
	case http.MethodGet:
		body = strings.NewReader(req.URL.Query().Get("q"))
	default:
		return nil, fmt.Errorf("invalid HTTP method")
	}
	var requestPayload QueryRequestPayload
	blob, err := ioutil.ReadAll(body)
	if err != nil {
		// BUG FIX: message previously concatenated the error without a
		// separator and did not wrap it.
		return nil, fmt.Errorf("failed to read body: %w", err)
	}
	// BUG FIX: the decoder was previously created but json.Unmarshal was
	// called instead, so DisallowUnknownFields never took effect. Decode
	// through the decoder so typos in the payload surface as errors.
	dec := json.NewDecoder(bytes.NewReader(blob))
	dec.DisallowUnknownFields()
	// io.EOF means an empty payload which we treat as a zero-value query.
	if err := dec.Decode(&requestPayload); err != nil && !errors.Is(err, io.EOF) {
		return nil, fmt.Errorf("invalid query: %w", err)
	}
	return &requestPayload, nil
}
// generateSQL builds the complete SELECT statement and its named
// parameter map for the request payload, validating all referenced
// column names against schema.
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	if err := req.prepareSelectedFields(schema); err != nil {
		// typo fix: "perparing" -> "preparing"
		return "", nil, fmt.Errorf("preparing selected fields: %w", err)
	}
	// build the SQL where clause from the payload query
	whereClause, paramMap, err := req.Query.toSQLWhereClause(
		ctx,
		schema,
		orm.DefaultEncodeConfig,
	)
	if err != nil {
		// typo fix: "ganerating" -> "generating"
		return "", nil, fmt.Errorf("generating where clause: %w", err)
	}
	groupByClause, err := req.generateGroupByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating group-by clause: %w", err)
	}
	orderByClause, err := req.generateOrderByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating order-by clause: %w", err)
	}
	selectClause := req.generateSelectClause()
	query := `SELECT ` + selectClause + ` FROM connections`
	if whereClause != "" {
		query += " WHERE " + whereClause
	}
	query += " " + groupByClause + " " + orderByClause
	return query, paramMap, nil
}
// prepareSelectedFields validates all fields referenced by the request's
// select definitions against schema and fills req.selectedFields (the
// SQL select expressions) and req.whitelistedFields (aliases that may be
// referenced by group-by/order-by).
func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error {
	for _, s := range req.Select {
		var field string
		if s.Count != nil {
			field = s.Count.Field
		} else {
			field = s.Field
		}
		colName := "*"
		// NOTE(review): this condition validates "*" whenever Count is nil,
		// so a bare "*" select is rejected; presumably only COUNT(*) is
		// meant to be allowed — confirm intent.
		if field != "*" || s.Count == nil {
			var err error
			colName, err = req.validateColumnName(schema, field)
			if err != nil {
				return err
			}
		}
		if s.Count != nil {
			// default alias is "<column>_count"; the alias is whitelisted so
			// it can be used in group-by/order-by.
			var as = s.Count.As
			if as == "" {
				as = fmt.Sprintf("%s_count", colName)
			}
			var distinct = ""
			if s.Count.Distinct {
				distinct = "DISTINCT "
			}
			req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as))
			req.whitelistedFields = append(req.whitelistedFields, as)
		} else {
			req.selectedFields = append(req.selectedFields, colName)
		}
	}
	return nil
}
// generateGroupByClause builds the GROUP BY part of the query or returns
// an empty string when no grouping was requested. When no fields were
// explicitly selected it defaults the selection to the grouped columns.
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
	if len(req.GroupBy) == 0 {
		return "", nil
	}
	groupBys := make([]string, 0, len(req.GroupBy))
	for _, name := range req.GroupBy {
		colName, err := req.validateColumnName(schema, name)
		if err != nil {
			return "", err
		}
		groupBys = append(groupBys, colName)
	}
	// if there are no explicitly selected fields we default to the
	// group-by columns as that's what's expected most of the time anyway...
	if len(req.selectedFields) == 0 {
		req.selectedFields = append(req.selectedFields, groupBys...)
	}
	return "GROUP BY " + strings.Join(groupBys, ", "), nil
}
// generateSelectClause returns the comma-separated list of select
// expressions or "*" when no fields were selected.
func (req *QueryRequestPayload) generateSelectClause() string {
	if len(req.selectedFields) == 0 {
		return "*"
	}
	return strings.Join(req.selectedFields, ", ")
}
// generateOrderByClause builds the ORDER BY part of the query. It
// returns an empty string when no ordering was requested so callers can
// append the result unconditionally.
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
	// BUG FIX: without this guard an empty OrderBy list produced the bare
	// string "ORDER BY " which generateSQL appends unconditionally,
	// yielding invalid SQL.
	if len(req.OrderBy) == 0 {
		return "", nil
	}
	var orderBys = make([]string, len(req.OrderBy))
	for idx, sort := range req.OrderBy {
		colName, err := req.validateColumnName(schema, sort.Field)
		if err != nil {
			return "", err
		}
		if sort.Desc {
			orderBys[idx] = fmt.Sprintf("%s DESC", colName)
		} else {
			orderBys[idx] = fmt.Sprintf("%s ASC", colName)
		}
	}
	return "ORDER BY " + strings.Join(orderBys, ", "), nil
}
// validateColumnName resolves field to a safe column name. A field is
// valid if it is a schema column or one of the already-prepared
// whitelisted aliases / select expressions; anything else is rejected to
// prevent SQL injection via identifiers.
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
	if colDef := schema.GetColumnDef(field); colDef != nil {
		return colDef.Name, nil
	}
	candidates := make([]string, 0, len(req.whitelistedFields)+len(req.selectedFields))
	candidates = append(candidates, req.whitelistedFields...)
	candidates = append(candidates, req.selectedFields...)
	for _, allowed := range candidates {
		if field == allowed {
			return field, nil
		}
	}
	return "", fmt.Errorf("column name %s not allowed", field)
}
// compile time check
var _ http.Handler = new(QueryHandler)

244
netquery/query_test.go Normal file
View file

@ -0,0 +1,244 @@
package netquery
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/safing/portmaster/netquery/orm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_UnmarshalQuery verifies that Query.UnmarshalJSON normalizes all
// supported input shapes (bare primitives, matcher objects and lists of
// both) into the expected []Matcher slices.
func Test_UnmarshalQuery(t *testing.T) {
	var cases = []struct {
		Name     string
		Input    string
		Expected Query
		Error    error
	}{
		{
			"Parse a simple query",
			`{ "domain": ["example.com", "example.at"] }`,
			Query{
				"domain": []Matcher{
					{
						Equal: "example.com",
					},
					{
						Equal: "example.at",
					},
				},
			},
			nil,
		},
		{
			"Parse a more complex query",
			`
			{
				"domain": [
					{
						"$in": [
							"example.at",
							"example.com"
						]
					},
					{
						"$like": "microsoft.%"
					}
				],
				"path": [
					"/bin/ping",
					{
						"$notin": [
							"/sbin/ping",
							"/usr/sbin/ping"
						]
					}
				]
			}
			`,
			Query{
				"domain": []Matcher{
					{
						In: []interface{}{
							"example.at",
							"example.com",
						},
					},
					{
						Like: "microsoft.%",
					},
				},
				"path": []Matcher{
					{
						Equal: "/bin/ping",
					},
					{
						NotIn: []interface{}{
							"/sbin/ping",
							"/usr/sbin/ping",
						},
					},
				},
			},
			nil,
		},
	}
	for _, c := range cases {
		t.Run(c.Name, func(t *testing.T) {
			var q Query
			err := json.Unmarshal([]byte(c.Input), &q)
			// cases may either expect a specific error or a decoded query.
			if c.Error != nil {
				if assert.Error(t, err) {
					assert.Equal(t, c.Error.Error(), err.Error())
				}
			} else {
				assert.NoError(t, err)
				assert.Equal(t, c.Expected, q)
			}
		})
	}
}
// Test_QueryBuilder verifies Query.toSQLWhereClause: the generated SQL
// text, the named parameter map (including ORM-encoded time values) and
// the rejection of columns not present in the Conn table schema.
func Test_QueryBuilder(t *testing.T) {
	now := time.Now()
	var cases = []struct {
		N string // case name
		Q Query  // input query
		R string // expected WHERE clause body
		P map[string]interface{} // expected named parameters
		E error  // expected error, if any
	}{
		{
			"No filter",
			nil,
			"",
			nil,
			nil,
		},
		{
			"Simple, one-column filter",
			Query{"domain": []Matcher{
				{
					Equal: "example.com",
				},
				{
					Equal: "example.at",
				},
			}},
			"( domain = :domain0eq0 OR domain = :domain1eq0 )",
			map[string]interface{}{
				":domain0eq0": "example.com",
				":domain1eq0": "example.at",
			},
			nil,
		},
		{
			"Two column filter",
			Query{
				"domain": []Matcher{
					{
						Equal: "example.com",
					},
				},
				"path": []Matcher{
					{
						Equal: "/bin/curl",
					},
					{
						Equal: "/bin/ping",
					},
				},
			},
			"( domain = :domain0eq0 ) AND ( path = :path0eq0 OR path = :path1eq0 )",
			map[string]interface{}{
				":domain0eq0": "example.com",
				":path0eq0":   "/bin/curl",
				":path1eq0":   "/bin/ping",
			},
			nil,
		},
		{
			"Time based filter",
			Query{
				"started": []Matcher{
					{
						Equal: now.Format(time.RFC3339),
					},
				},
			},
			"( started = :started0eq0 )",
			// the time value is encoded to UTC in the sqlite text format
			map[string]interface{}{
				":started0eq0": now.In(time.UTC).Format(orm.SqliteTimeFormat),
			},
			nil,
		},
		{
			"Invalid column access",
			Query{
				"forbiddenField": []Matcher{{}},
			},
			"",
			nil,
			fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"),
		},
		{
			"Complex example",
			Query{
				"domain": []Matcher{
					{
						In: []interface{}{"example.at", "example.com"},
					},
					{
						Like: "microsoft.%",
					},
				},
				"path": []Matcher{
					{
						NotIn: []interface{}{
							"/bin/ping",
							"/sbin/ping",
							"/usr/bin/ping",
						},
					},
				},
			},
			"( domain IN ( :domain0in0, :domain0in1 ) OR domain LIKE :domain1like0 ) AND ( path NOT IN ( :path0notin0, :path0notin1, :path0notin2 ) )",
			map[string]interface{}{
				":domain0in0":   "example.at",
				":domain0in1":   "example.com",
				":domain1like0": "microsoft.%",
				":path0notin0":  "/bin/ping",
				":path0notin1":  "/sbin/ping",
				":path0notin2":  "/usr/bin/ping",
			},
			nil,
		},
	}
	// the schema is derived from the actual Conn model so column
	// validation matches production behavior.
	tbl, err := orm.GenerateTableSchema("connections", Conn{})
	require.NoError(t, err)
	for idx, c := range cases {
		t.Run(c.N, func(t *testing.T) {
			//t.Parallel()
			str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig)
			if c.E != nil {
				if assert.Error(t, err) {
					assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx)
				}
			} else {
				assert.NoError(t, err, "test case %d", idx)
				assert.Equal(t, c.P, params, "test case %d", idx)
				assert.Equal(t, c.R, str, "test case %d", idx)
			}
		})
	}
}