diff --git a/netquery/database.go b/netquery/database.go index cf5450a0..25f9e571 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -44,6 +44,8 @@ type ( // are actually supposed to do. // Database struct { + Schema *orm.TableSchema + l sync.Mutex conn *sqlite.Conn } @@ -62,6 +64,8 @@ type ( // as long as the connection is still active and might be, although unlikely, // reused afterwards. ID string `sqlite:"id,primary"` + ProfileID string `sqlite:"profile"` + Path string `sqlite:"path"` Type string `sqlite:"type,varchar(8)"` External bool `sqlite:"external"` IPVersion packet.IPVersion `sqlite:"ip_version"` @@ -78,8 +82,8 @@ type ( Longitude float64 `sqlite:"longitude"` Scope netutils.IPScope `sqlite:"scope"` Verdict network.Verdict `sqlite:"verdict"` - Started time.Time `sqlite:"started,text"` - Ended *time.Time `sqlite:"ended,text"` + Started time.Time `sqlite:"started,text,time"` + Ended *time.Time `sqlite:"ended,text,time"` Tunneled bool `sqlite:"tunneled"` Encrypted bool `sqlite:"encrypted"` Internal bool `sqlite:"internal"` @@ -107,7 +111,15 @@ func New(path string) (*Database, error) { return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) } - return &Database{conn: c}, nil + schema, err := orm.GenerateTableSchema("connections", Conn{}) + if err != nil { + return nil, err + } + + return &Database{ + Schema: schema, + conn: c, + }, nil } // NewInMemory is like New but creates a new in-memory database and @@ -133,13 +145,8 @@ func NewInMemory() (*Database, error) { // any data-migrations. Once the history module is implemented this should // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration func (db *Database) ApplyMigrations() error { - schema, err := orm.GenerateTableSchema("connections", Conn{}) - if err != nil { - return fmt.Errorf("failed to generate table schema for conncetions: %w", err) - } - // get the create-table SQL statement from the infered schema - sql := schema.CreateStatement(false) + sql := db.Schema.CreateStatement(false) // execute the SQL if err := sqlitex.ExecuteTransient(db.conn, sql, nil); err != nil { @@ -284,7 +291,7 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { return nil }, }); err != nil { - log.Errorf("netquery: failed to execute: %s", err) + log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) return err } diff --git a/netquery/manager.go b/netquery/manager.go index 647e82ed..16a78021 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -96,7 +96,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect model, err := convertConnection(conn) if err != nil { - log.Errorf("netquery: failed to convert connection %s to sqlite model: %w", conn.ID, err) + log.Errorf("netquery: failed to convert connection %s to sqlite model: %s", conn.ID, err) continue } @@ -104,7 +104,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect log.Infof("netquery: persisting create/update to connection %s", conn.ID) if err := mng.store.Save(ctx, *model); err != nil { - log.Errorf("netquery: failed to save connection %s in sqlite database: %w", conn.ID, err) + log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err) continue } @@ -116,7 +116,7 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect // push an update for the connection if err := mng.pushConnUpdate(ctx, *cloned, *model); err != nil { - 
log.Errorf("netquery: failed to push update for conn %s via database system: %w", conn.ID, err) + log.Errorf("netquery: failed to push update for conn %s via database system: %s", conn.ID, err) } count++ @@ -170,6 +170,8 @@ func convertConnection(conn *network.Connection) (*Conn, error) { Internal: conn.Internal, Inbound: conn.Inbound, Type: ConnectionTypeToString[conn.Type], + ProfileID: conn.ProcessContext.ProfileName, + Path: conn.ProcessContext.BinaryPath, } if conn.Ended > 0 { diff --git a/netquery/module_api.go b/netquery/module_api.go index 6bc9cb4f..a344ec3b 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/safing/portbase/api" + "github.com/safing/portbase/config" "github.com/safing/portbase/database" "github.com/safing/portbase/database/query" "github.com/safing/portbase/log" @@ -29,6 +31,7 @@ func init() { mod.Prepare, mod.Start, mod.Stop, + "api", "network", "database", ) @@ -55,6 +58,25 @@ func (m *Module) Prepare() error { m.feed = make(chan *network.Connection, 1000) + queryHander := &QueryHandler{ + Database: m.sqlStore, + IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), + } + + // FIXME(ppacher): use appropriate permissions for this + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "netquery/query", + MimeType: "application/json", + Read: api.PermitAnyone, + Write: api.PermitAnyone, + BelongsTo: m.Module, + HandlerFunc: queryHander.ServeHTTP, + Name: "Query In-Memory Database", + Description: "Query the in-memory sqlite database", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + return nil } @@ -100,7 +122,7 @@ func (mod *Module) Start() error { case <-time.After(10 * time.Second): count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold)) if err != nil { - log.Errorf("netquery: failed to count number of rows in memory: %w", err) + log.Errorf("netquery: failed to count number of rows in memory: %s", err) } else { log.Infof("netquery: successfully removed %d old rows", count) } @@ -116,7 +138,7 @@ func (mod *Module) Start() error { case <-time.After(5 * time.Second): count, err := mod.sqlStore.CountRows(ctx) if err != nil { - log.Errorf("netquery: failed to count number of rows in memory: %w", err) + log.Errorf("netquery: failed to count number of rows in memory: %s", err) } else { log.Infof("netquery: currently holding %d rows in memory", count) } diff --git a/netquery/orm/decoder.go b/netquery/orm/decoder.go index 5e0d1f7b..76359755 100644 --- a/netquery/orm/decoder.go +++ b/netquery/orm/decoder.go @@ -31,7 +31,7 @@ var ( // preconfigured timezone; UTC by default) or as INTEGER (the user can choose between // unixepoch and unixnano-epoch where the nano variant is not offically supported by // SQLITE). - sqliteTimeFormat = "2006-01-02 15:04:05" + SqliteTimeFormat = "2006-01-02 15:04:05" ) type ( @@ -209,7 +209,7 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc { case sqlite.TypeText: // stored ISO8601 but does not have any timezone information // assigned so we always treat it as loc here. 
-			t, err := time.ParseInLocation(sqliteTimeFormat, stmt.ColumnText(colIdx), loc)
+			t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
 			if err != nil {
 				return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
 			}
diff --git a/netquery/orm/encoder.go b/netquery/orm/encoder.go
index fc4e772c..0b60f6ca 100644
--- a/netquery/orm/encoder.go
+++ b/netquery/orm/encoder.go
@@ -103,6 +103,13 @@ func encodeBasic() EncodeFunc {
 			kind = valType.Kind()

 			if val.IsNil() {
+				if !col.Nullable {
+					// a nil pointer for a column that is not marked
+					// as nullable is a programming error, so fail
+					// loudly instead of silently storing NULL.
+					panic("nil pointer for not-null field")
+				}
+
 				return nil, true, nil
 			}
@@ -133,7 +140,7 @@ }

 func DatetimeEncoder(loc *time.Location) EncodeFunc {
-	return func(colDev *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
+	return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
 		// if fieldType holds a pointer we need to dereference the value
 		ft := valType.String()
 		if valType.Kind() == reflect.Ptr {
@@ -142,44 +149,71 @@ }

 		// we only care about "time.Time" here
-		if ft != "time.Time" {
+		var t time.Time
+		if ft == "time.Time" {
+			// handle the zero time as a NULL.
+			if !val.IsValid() || val.IsZero() {
+				return nil, true, nil
+			}
+
+			var ok bool
+			valInterface := val.Interface()
+			t, ok = valInterface.(time.Time)
+			if !ok {
+				return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
+			}
+
+		} else if valType.Kind() == reflect.String && colDef.IsTime {
+			var err error
+			t, err = time.Parse(time.RFC3339, val.String())
+			if err != nil {
+				return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
+			}
+
+		} else {
+			// not a time value; let other encode hooks handle it
 			return nil, false, nil
 		}
-		// handle the zero time as a NULL.
-		if !val.IsValid() || val.IsZero() {
-			return nil, true, nil
-		}
-
-		valInterface := val.Interface()
-		t, ok := valInterface.(time.Time)
-		if !ok {
-			return nil, false, fmt.Errorf("cannot convert reflect value to time.Time")
-		}
-
-		switch colDev.Type {
+		switch colDef.Type {
 		case sqlite.TypeInteger:
-			if colDev.UnixNano {
+			if colDef.UnixNano {
 				return t.UnixNano(), true, nil
 			}

 			return t.Unix(), true, nil
+
 		case sqlite.TypeText:
-			str := t.In(loc).Format(sqliteTimeFormat)
+			str := t.In(loc).Format(SqliteTimeFormat)

 			return str, true, nil
 		}

-		return nil, false, fmt.Errorf("cannot store time.Time in %s", colDev.Type)
+		return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
 	}
 }

-func runEncodeHooks(colDev *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
+func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
 	if valType == nil {
+		if !colDef.Nullable {
+			switch colDef.Type {
+			case sqlite.TypeBlob:
+				return []byte{}, true, nil
+			case sqlite.TypeFloat:
+				return 0.0, true, nil
+			case sqlite.TypeText:
+				return "", true, nil
+			case sqlite.TypeInteger:
+				return 0, true, nil
+			default:
+				return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
+			}
+		}
+
 		return nil, true, nil
 	}

 	for _, fn := range hooks {
-		res, end, err := fn(colDev, valType, val)
+		res, end, err := fn(colDef, valType, val)
 		if err != nil {
 			return res, false, err
 		}
diff --git a/netquery/orm/encoder_test.go b/netquery/orm/encoder_test.go
index 8b802a78..aff28580 100644
--- a/netquery/orm/encoder_test.go
+++ b/netquery/orm/encoder_test.go
@@ -89,7 +89,7 @@ func Test_EncodeAsMap(t *testing.T) {
 			},
 			map[string]interface{}{
 				"TinInt":    refTime.UnixNano(),
-				"TinString": refTime.Format(sqliteTimeFormat),
+				"TinString": refTime.Format(SqliteTimeFormat),
 			},
 		},
 		{
@@ -107,7 +107,7 @@ func Test_EncodeAsMap(t *testing.T) {
 			},
 			map[string]interface{}{
 				"TinInt":    refTime.UnixNano(),
-				"TinString": refTime.Format(sqliteTimeFormat),
+				"TinString": refTime.Format(SqliteTimeFormat),
 				"Tnil1":     nil,
 				"Tnil2":     nil,
 			},
@@ -143,7 +143,7 @@ func Test_EncodeValue(t *testing.T) {
 				Type:   sqlite.TypeText,
 			},
 			refTime,
-			refTime.Format(sqliteTimeFormat),
+			refTime.Format(SqliteTimeFormat),
 		},
 		{
 			"Special value time.Time as unix-epoch",
@@ -189,13 +189,14 @@ func Test_EncodeValue(t *testing.T) {
 				Type:   sqlite.TypeText,
 			},
 			&refTime,
-			refTime.Format(sqliteTimeFormat),
+			refTime.Format(SqliteTimeFormat),
 		},
 		{
 			"Special value untyped nil",
 			ColumnDef{
-				IsTime: true,
-				Type:   sqlite.TypeText,
+				Nullable: true,
+				IsTime:   true,
+				Type:     sqlite.TypeText,
 			},
 			nil,
 			nil,
@@ -209,12 +210,47 @@ func Test_EncodeValue(t *testing.T) {
 			(*time.Time)(nil),
 			nil,
 		},
+		{
+			"Time formatted as string",
+			ColumnDef{
+				IsTime: true,
+				Type:   sqlite.TypeText,
+			},
+			refTime.In(time.Local).Format(time.RFC3339),
+			refTime.Format(SqliteTimeFormat),
+		},
+		{
+			"Nullable integer",
+			ColumnDef{
+				Type:     sqlite.TypeInteger,
+				Nullable: true,
+			},
+			nil,
+			nil,
+		},
+		{
+			"Not-Null integer",
+			ColumnDef{
+				Name: "test",
+				Type: sqlite.TypeInteger,
+			},
+			nil,
+			0,
+		},
+		{
+			"Not-Null string",
+			ColumnDef{
+				Type: sqlite.TypeText,
+			},
+			nil,
+			"",
+		},
 	}

 	for idx := range cases {
 		c := cases[idx]
 		t.Run(c.Desc, func(t *testing.T) {
-			// t.Parallel()
+			//t.Parallel()
 			res, err := EncodeValue(ctx, &c.Column, c.Input, DefaultEncodeConfig)
 			assert.NoError(t, err)
diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go
index 9289b06d..68783c3f 100644
--- a/netquery/orm/schema_builder.go
+++ b/netquery/orm/schema_builder.go
@@ -53,6 +53,15 @@ type (
 	}
 )

+func (ts TableSchema) GetColumnDef(name string) *ColumnDef {
+	for idx := range ts.Columns {
+		if ts.Columns[idx].Name == name {
+			return &ts.Columns[idx]
+		}
+	}
+	return nil
+}
+
 func (ts TableSchema) CreateStatement(ifNotExists bool) string {
 	sql := "CREATE TABLE"
 	if ifNotExists {
diff --git a/netquery/query.go b/netquery/query.go
new file mode 100644
index 00000000..2fdbf38f
--- /dev/null
+++ b/netquery/query.go
@@ -0,0 +1,464 @@
+package netquery
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"sort"
+	"strings"
+
+	"github.com/hashicorp/go-multierror"
+	"github.com/safing/portmaster/netquery/orm"
+)
+
+type (
+	Query map[string][]Matcher
+
+	Matcher struct {
+		Equal    interface{}   `json:"$eq,omitempty"`
+		NotEqual interface{}   `json:"$ne,omitempty"`
+		In       []interface{} `json:"$in,omitempty"`
+		NotIn    []interface{} `json:"$notIn,omitempty"`
+		Like     string        `json:"$like,omitempty"`
+	}
+
+	Count struct {
+		As       string `json:"as"`
+		Field    string `json:"field"`
+		Distinct bool   `json:"distinct"`
+	}
+
+	Select struct {
+		Field string `json:"field"`
+		Count *Count `json:"$count"`
+	}
+
+	Selects []Select
+
+	QueryRequestPayload struct {
+		Select  Selects  `json:"select"`
+		Query   Query    `json:"query"`
+		OrderBy OrderBys `json:"orderBy"`
+		GroupBy []string `json:"groupBy"`
+
+		selectedFields    []string
+		whitelistedFields []string
+	}
+
+	OrderBy struct {
+		Field string `json:"field"`
+		Desc  bool   `json:"desc"`
+	}
+
+	OrderBys []OrderBy
+)
+
+func (query *Query) UnmarshalJSON(blob []byte) error {
+	if *query == nil {
+		*query = make(Query)
+	}
+
+	var model map[string]json.RawMessage
+
+	if err := json.Unmarshal(blob, &model); err != nil {
+		return err
+	}
+
+	for columnName, rawColumnQuery := range model {
+		if len(rawColumnQuery) == 0 {
+			continue
+		}
+
+		switch rawColumnQuery[0] {
+		case '{':
+			m, err := parseMatcher(rawColumnQuery)
+			if err != nil {
+				return err
+			}
+
+			(*query)[columnName] = []Matcher{*m}
+
+		case '[':
+			var rawMatchers []json.RawMessage
+			if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
+				return err
+			}
+
+			(*query)[columnName] = make([]Matcher, len(rawMatchers))
+			for idx, val := range rawMatchers {
+				// this should not happen
+				if len(val) == 0 {
+					continue
+				}
+
+				// if val starts with a { we have a matcher definition
+				if val[0] == '{' {
+					m, err := parseMatcher(val)
+					if err != nil {
+						return err
+					}
+					(*query)[columnName][idx] = *m
+
+					continue
+				} else if val[0] == '[' {
+					return fmt.Errorf("invalid token [ in query for column %s", columnName)
+				}
+
+				// val is a dedicated JSON primitive and not an object or array
+				// so we treat that as an EQUAL condition.
+				var x interface{}
+				if err := json.Unmarshal(val, &x); err != nil {
+					return err
+				}
+
+				(*query)[columnName][idx] = Matcher{
+					Equal: x,
+				}
+			}
+
+		default:
+			// value is a JSON primitive and not an object or array
+			// so we treat that as an EQUAL condition.
+			var x interface{}
+			if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
+				return err
+			}
+
+			(*query)[columnName] = []Matcher{
+				{Equal: x},
+			}
+		}
+	}
+
+	return nil
+}
+
+func parseMatcher(raw json.RawMessage) (*Matcher, error) {
+	var m Matcher
+	if err := json.Unmarshal(raw, &m); err != nil {
+		return nil, err
+	}
+
+	if err := m.Validate(); err != nil {
+		return nil, fmt.Errorf("invalid query matcher: %w", err)
+	}
+
+	return &m, nil
+}
+
+func (match Matcher) Validate() error {
+	found := 0
+
+	if match.Equal != nil {
+		found++
+	}
+
+	if match.NotEqual != nil {
+		found++
+	}
+
+	if match.In != nil {
+		found++
+	}
+
+	if match.NotIn != nil {
+		found++
+	}
+
+	if match.Like != "" {
+		found++
+	}
+
+	if found == 0 {
+		return fmt.Errorf("no conditions specified")
+	}
+
+	return nil
+}
+
+func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+	var (
+		queryParts []string
+		params     = make(map[string]interface{})
+		errs       = new(multierror.Error)
+		key        = fmt.Sprintf("%s%d", colDef.Name, idx)
+	)
+
+	add := func(operator, suffix string, values ...interface{}) {
+		var placeholder []string
+
+		for idx, value := range values {
+			encodedValue, err := orm.EncodeValue(ctx, &colDef, value, encoderConfig)
+			if err != nil {
+				errs.Errors = append(errs.Errors,
+					fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
+				)
+				return
+			}
+
+			uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
+			placeholder = append(placeholder, uniqKey)
+			params[uniqKey] = encodedValue
+		}
+
+		if len(placeholder) == 1 {
+			queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
+		} else {
+			queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))
+		}
+	}
+
+	if match.Equal != nil {
+		add("=", "eq", match.Equal)
+	}
+
+	if match.NotEqual != nil {
+		add("!=", "ne", match.NotEqual)
+	}
+
+	if match.In != nil {
+		add("IN", "in", match.In...)
+	}
+
+	if match.NotIn != nil {
+		add("NOT IN", "notin", match.NotIn...)
+	}
+
+	if match.Like != "" {
+		add("LIKE", "like", match.Like)
+	}
+
+	if len(queryParts) == 0 {
+		// this is an empty matcher without a single condition.
+		// we convert that to a no-op TRUE value
+		return "( 1 = 1 )", nil, errs.ErrorOrNil()
+	}
+
+	if len(queryParts) == 1 {
+		return queryParts[0], params, errs.ErrorOrNil()
+	}
+
+	return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
+}
+
+func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+	if len(query) == 0 {
+		return "", nil, nil
+	}
+
+	// create a lookup map to validate column names
+	lm := make(map[string]orm.ColumnDef, len(m.Columns))
+	for _, col := range m.Columns {
+		lm[col.Name] = col
+	}
+
+	paramMap := make(map[string]interface{})
+	columnStmts := make([]string, 0, len(query))
+
+	// get all keys and sort them so we get a stable output
+	queryKeys := make([]string, 0, len(query))
+	for column := range query {
+		queryKeys = append(queryKeys, column)
+	}
+	sort.Strings(queryKeys)
+
+	// actually create the WHERE clause parts for each
+	// column in query.
+	errs := new(multierror.Error)
+	for _, column := range queryKeys {
+		values := query[column]
+		colDef, ok := lm[column]
+		if !ok {
+			errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
+
+			continue
+		}
+
+		queryParts := make([]string, len(values))
+		for idx, val := range values {
+			matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig)
+			if err != nil {
+				errs.Errors = append(errs.Errors,
+					fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
+				)
+
+				continue
+			}
+
+			// merge parameters up into the superior parameter map
+			for key, val := range params {
+				if _, ok := paramMap[key]; ok {
+					// this is solely a developer mistake when implementing
+					// a matcher, so no forgiving ...
+					panic("sqlite parameter collision")
+				}
+
+				paramMap[key] = val
+			}
+
+			queryParts[idx] = matcherQuery
+		}
+
+		columnStmts = append(columnStmts,
+			fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
+		)
+	}
+
+	whereClause := strings.Join(columnStmts, " AND ")
+
+	return whereClause, paramMap, errs.ErrorOrNil()
+}
+
+func (sel *Selects) UnmarshalJSON(blob []byte) error {
+	if len(blob) == 0 {
+		return io.ErrUnexpectedEOF
+	}
+
+	// if we are looking at a slice directly decode into
+	// a []Select
+	if blob[0] == '[' {
+		var result []Select
+		if err := json.Unmarshal(blob, &result); err != nil {
+			return err
+		}
+
+		*sel = result
+
+		return nil
+	}
+
+	// if it's an object decode into a single select
+	if blob[0] == '{' {
+		var result Select
+		if err := json.Unmarshal(blob, &result); err != nil {
+			return err
+		}
+
+		*sel = []Select{result}
+
+		return nil
+	}
+
+	// otherwise this is just the field name
+	var field string
+	if err := json.Unmarshal(blob, &field); err != nil {
+		return err
+	}
+
+	*sel = []Select{{Field: field}}
+
+	return nil
+}
+
+func (sel *Select) UnmarshalJSON(blob []byte) error {
+	if len(blob) == 0 {
+		return io.ErrUnexpectedEOF
+	}
+
+	// if we have an object at hand decode the select
+	// directly
+	if blob[0] == '{' {
+		var res struct {
+			Field string `json:"field"`
+			Count *Count `json:"$count"`
+		}
+
+		if err := json.Unmarshal(blob, &res); err != nil {
+			return err
+		}
+
+		sel.Count = res.Count
+		sel.Field = res.Field
+
+		if sel.Count != nil && sel.Count.As != "" {
+			if !charOnlyRegexp.MatchString(sel.Count.As) {
+				return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
+			}
+		}
+
+		return nil
+	}
+
+	var x string
+	if err := json.Unmarshal(blob, &x); err != nil {
+		return err
+	}
+
+	sel.Field = x
+
+	return nil
+}
+
+func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
+	if len(blob) == 0 {
+		return io.ErrUnexpectedEOF
+	}
+
+	if blob[0] == '[' {
+		var result []OrderBy
+		if err := json.Unmarshal(blob, &result); err != nil {
+			return err
+		}
+
+		*orderBys = result
+
+		return nil
+	}
+
+	if blob[0] == '{' {
+		var result OrderBy
+		if err := json.Unmarshal(blob, &result); err != nil {
+			return err
+		}
+
+		*orderBys = []OrderBy{result}
+
+		return nil
+	}
+
+	var field string
+	if err := json.Unmarshal(blob, &field); err != nil {
+		return err
+	}
+
+	*orderBys = []OrderBy{
+		{
+			Field: field,
+			Desc:  false,
+		},
+	}
+
+	return nil
+}
+
+func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
+	if len(blob) == 0 {
+		return io.ErrUnexpectedEOF
+	}
+
+	if blob[0] == '{' {
+		var res struct {
+			Field string `json:"field"`
+			Desc  bool   `json:"desc"`
+		}
+
+		if err := json.Unmarshal(blob, &res); err != nil {
+			return err
+		}
+
+		orderBy.Desc = res.Desc
+		orderBy.Field = res.Field
+
+		return nil
+	}
+
+	var field string
+	if err := json.Unmarshal(blob, &field); err != nil {
+		return err
+	}
+
+	orderBy.Field = field
+	orderBy.Desc = false
+
+	return nil
+}
diff --git a/netquery/query_handler.go b/netquery/query_handler.go
new file mode 100644
index 00000000..a1e4353f
--- /dev/null
+++ b/netquery/query_handler.go
@@ -0,0 +1,293 @@
+package netquery
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/safing/portbase/log"
+	"github.com/safing/portmaster/netquery/orm"
+)
+
+var (
+	// charOnlyRegexp is anchored so aliases must consist of letters only;
+	// an unanchored pattern would accept any value that merely contains one.
+	charOnlyRegexp = regexp.MustCompile("^[a-zA-Z]+$")
+)
+
+type (
+
+	// QueryHandler implements http.Handler and allows to perform SQL
+	// query and aggregate functions on Database.
+	QueryHandler struct {
+		IsDevMode func() bool
+		Database  *Database
+	}
+)
+
+func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+	start := time.Now()
+	requestPayload, err := qh.parseRequest(req)
+	if err != nil {
+		http.Error(resp, err.Error(), http.StatusBadRequest)
+
+		return
+	}
+
+	queryParsed := time.Since(start)
+
+	query, paramMap, err := requestPayload.generateSQL(req.Context(), qh.Database.Schema)
+	if err != nil {
+		http.Error(resp, err.Error(), http.StatusBadRequest)
+
+		return
+	}
+
+	sqlQueryBuilt := time.Since(start)
+
+	// actually execute the query against the database and collect the result
+	var result []map[string]interface{}
+	if err := qh.Database.Execute(
+		req.Context(),
+		query,
+		orm.WithNamedArgs(paramMap),
+		orm.WithResult(&result),
+	); err != nil {
+		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
+
+		return
+	}
+	sqlQueryFinished := time.Since(start)
+
+	// send the HTTP status code
+	resp.WriteHeader(http.StatusOK)
+
+	// prepare the result encoder.
+	enc := json.NewEncoder(resp)
+	enc.SetEscapeHTML(false)
+	enc.SetIndent("", " ")
+
+	// prepare the result body that, in dev mode, contains
+	// some diagnostics data about the query
+	var resultBody map[string]interface{}
+	if qh.IsDevMode() {
+		resultBody = map[string]interface{}{
+			"sql_prep_stmt": query,
+			"sql_params":    paramMap,
+			"query":         requestPayload.Query,
+			"orderBy":       requestPayload.OrderBy,
+			"groupBy":       requestPayload.GroupBy,
+			"selects":       requestPayload.Select,
+			"times": map[string]interface{}{
+				"start_time":           start,
+				"query_parsed_after":   queryParsed.String(),
+				"query_built_after":    sqlQueryBuilt.String(),
+				"query_executed_after": sqlQueryFinished.String(),
+			},
+		}
+	} else {
+		resultBody = make(map[string]interface{})
+	}
+	resultBody["results"] = result
+
+	// and finally stream the response
+	if err := enc.Encode(resultBody); err != nil {
+		// we failed to encode the JSON body to resp so we likely either already sent a
+		// few bytes or the pipe was already closed. In either case, trying to send the
+		// error using http.Error() is pointless. We just log it out here and that's all
+		// we can do.
+ log.Errorf("failed to encode JSON response: %s", err) + + return + } +} + +func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { + var body io.Reader + + switch req.Method { + case http.MethodPost, http.MethodPut: + body = req.Body + case http.MethodGet: + body = strings.NewReader(req.URL.Query().Get("q")) + default: + return nil, fmt.Errorf("invalid HTTP method") + } + + var requestPayload QueryRequestPayload + blob, err := ioutil.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read body" + err.Error()) + } + + body = bytes.NewReader(blob) + + dec := json.NewDecoder(body) + dec.DisallowUnknownFields() + + if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("invalid query: %w", err) + } + + return &requestPayload, nil +} + +func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { + if err := req.prepareSelectedFields(schema); err != nil { + return "", nil, fmt.Errorf("perparing selected fields: %w", err) + } + + // build the SQL where clause from the payload query + whereClause, paramMap, err := req.Query.toSQLWhereClause( + ctx, + schema, + orm.DefaultEncodeConfig, + ) + if err != nil { + return "", nil, fmt.Errorf("ganerating where clause: %w", err) + } + + // build the actual SQL query statement + // FIXME(ppacher): add support for group-by and sort-by + + groupByClause, err := req.generateGroupByClause(schema) + if err != nil { + return "", nil, fmt.Errorf("generating group-by clause: %w", err) + } + + orderByClause, err := req.generateOrderByClause(schema) + if err != nil { + return "", nil, fmt.Errorf("generating order-by clause: %w", err) + } + + selectClause := req.generateSelectClause() + query := `SELECT ` + selectClause + ` FROM connections` + if whereClause != "" { + query += " WHERE " + whereClause + } + query += " " + groupByClause + " " + orderByClause + + return query, paramMap, nil +} + +func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error { + for _, s := range req.Select { + var field string + if s.Count != nil { + field = s.Count.Field + } else { + field = s.Field + } + + colName := "*" + if field != "*" || s.Count == nil { + var err error + + colName, err = req.validateColumnName(schema, field) + if err != nil { + return err + } + } + + if s.Count != nil { + var as = s.Count.As + if as == "" { + as = fmt.Sprintf("%s_count", colName) + } + var distinct = "" + if s.Count.Distinct { + distinct = "DISTINCT " + } + req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as)) + req.whitelistedFields = append(req.whitelistedFields, as) + } else { + req.selectedFields = append(req.selectedFields, colName) + } + } + + return nil +} + +func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { + if len(req.GroupBy) == 0 { + return "", nil + } + + var groupBys = make([]string, len(req.GroupBy)) + + for idx, name := range req.GroupBy { + colName, err := req.validateColumnName(schema, name) + if err != nil { + return "", err + } + + groupBys[idx] = colName + } + + groupByClause := "GROUP BY " + strings.Join(groupBys, ", ") + + // if there are no explicitly selected fields we default to the + // group-by columns as that's what's expected most of the time anyway... + if len(req.selectedFields) == 0 { + req.selectedFields = append(req.selectedFields, groupBys...) 
+ } + + return groupByClause, nil +} + +func (req *QueryRequestPayload) generateSelectClause() string { + var selectClause = "*" + if len(req.selectedFields) > 0 { + selectClause = strings.Join(req.selectedFields, ", ") + } + + return selectClause +} + +func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) { + var orderBys = make([]string, len(req.OrderBy)) + for idx, sort := range req.OrderBy { + colName, err := req.validateColumnName(schema, sort.Field) + if err != nil { + return "", err + } + + if sort.Desc { + orderBys[idx] = fmt.Sprintf("%s DESC", colName) + } else { + orderBys[idx] = fmt.Sprintf("%s ASC", colName) + } + } + + return "ORDER BY " + strings.Join(orderBys, ", "), nil +} + +func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) { + colDef := schema.GetColumnDef(field) + if colDef != nil { + return colDef.Name, nil + } + + for _, selected := range req.whitelistedFields { + if field == selected { + return field, nil + } + } + + for _, selected := range req.selectedFields { + if field == selected { + return field, nil + } + } + + return "", fmt.Errorf("column name %s not allowed", field) +} + +// compile time check +var _ http.Handler = new(QueryHandler) diff --git a/netquery/query_test.go b/netquery/query_test.go new file mode 100644 index 00000000..f8e8b3e4 --- /dev/null +++ b/netquery/query_test.go @@ -0,0 +1,244 @@ +package netquery + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/safing/portmaster/netquery/orm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_UnmarshalQuery(t *testing.T) { + var cases = []struct { + Name string + Input string + Expected Query + Error error + }{ + { + "Parse a simple query", + `{ "domain": ["example.com", "example.at"] }`, + Query{ + "domain": []Matcher{ + { + Equal: "example.com", + }, + { + Equal: "example.at", + }, + }, + }, + nil, + }, + { + "Parse a more complex query", + ` + { + "domain": [ + { + "$in": [ + "example.at", + "example.com" + ] + }, + { + "$like": "microsoft.%" + } + ], + "path": [ + "/bin/ping", + { + "$notin": [ + "/sbin/ping", + "/usr/sbin/ping" + ] + } + ] + } + `, + Query{ + "domain": []Matcher{ + { + In: []interface{}{ + "example.at", + "example.com", + }, + }, + { + Like: "microsoft.%", + }, + }, + "path": []Matcher{ + { + Equal: "/bin/ping", + }, + { + NotIn: []interface{}{ + "/sbin/ping", + "/usr/sbin/ping", + }, + }, + }, + }, + nil, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + var q Query + err := json.Unmarshal([]byte(c.Input), &q) + + if c.Error != nil { + if assert.Error(t, err) { + assert.Equal(t, c.Error.Error(), err.Error()) + } + } else { + assert.NoError(t, err) + assert.Equal(t, c.Expected, q) + } + }) + } +} + +func Test_QueryBuilder(t *testing.T) { + now := time.Now() + + var cases = []struct { + N string + Q Query + R string + P map[string]interface{} + E error + }{ + { + "No filter", + nil, + "", + nil, + nil, + }, + { + "Simple, one-column filter", + Query{"domain": []Matcher{ + { + Equal: "example.com", + }, + { + Equal: "example.at", + }, + }}, + "( domain = :domain0eq0 OR domain = :domain1eq0 )", + map[string]interface{}{ + ":domain0eq0": "example.com", + ":domain1eq0": "example.at", + }, + nil, + }, + { + "Two column filter", + Query{ + "domain": []Matcher{ + { + Equal: "example.com", + }, + }, + "path": []Matcher{ + { + Equal: "/bin/curl", + }, + { + Equal: "/bin/ping", + }, + }, + }, 
+ "( domain = :domain0eq0 ) AND ( path = :path0eq0 OR path = :path1eq0 )", + map[string]interface{}{ + ":domain0eq0": "example.com", + ":path0eq0": "/bin/curl", + ":path1eq0": "/bin/ping", + }, + nil, + }, + { + "Time based filter", + Query{ + "started": []Matcher{ + { + Equal: now.Format(time.RFC3339), + }, + }, + }, + "( started = :started0eq0 )", + map[string]interface{}{ + ":started0eq0": now.In(time.UTC).Format(orm.SqliteTimeFormat), + }, + nil, + }, + { + "Invalid column access", + Query{ + "forbiddenField": []Matcher{{}}, + }, + "", + nil, + fmt.Errorf("1 error occurred:\n\t* column forbiddenField is not allowed\n\n"), + }, + { + "Complex example", + Query{ + "domain": []Matcher{ + { + In: []interface{}{"example.at", "example.com"}, + }, + { + Like: "microsoft.%", + }, + }, + "path": []Matcher{ + { + NotIn: []interface{}{ + "/bin/ping", + "/sbin/ping", + "/usr/bin/ping", + }, + }, + }, + }, + "( domain IN ( :domain0in0, :domain0in1 ) OR domain LIKE :domain1like0 ) AND ( path NOT IN ( :path0notin0, :path0notin1, :path0notin2 ) )", + map[string]interface{}{ + ":domain0in0": "example.at", + ":domain0in1": "example.com", + ":domain1like0": "microsoft.%", + ":path0notin0": "/bin/ping", + ":path0notin1": "/sbin/ping", + ":path0notin2": "/usr/bin/ping", + }, + nil, + }, + } + + tbl, err := orm.GenerateTableSchema("connections", Conn{}) + require.NoError(t, err) + + for idx, c := range cases { + t.Run(c.N, func(t *testing.T) { + //t.Parallel() + str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig) + + if c.E != nil { + if assert.Error(t, err) { + assert.Equal(t, c.E.Error(), err.Error(), "test case %d", idx) + } + } else { + assert.NoError(t, err, "test case %d", idx) + assert.Equal(t, c.P, params, "test case %d", idx) + assert.Equal(t, c.R, str, "test case %d", idx) + } + }) + } +}