safing-portmaster/netquery/query.go

564 lines
12 KiB
Go

package netquery
import (
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"github.com/hashicorp/go-multierror"
"zombiezen.com/go/sqlite"
"github.com/safing/portmaster/netquery/orm"
)
// Collection of Query and Matcher types.
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
//nolint:golint
type (
// Query maps a column name to the matchers that apply to that column.
// toSQLWhereClause joins the matchers of one column with OR and the
// resulting per-column groups with AND.
Query map[string][]Matcher
// MatchType describes a match operation by the name of its operator.
MatchType interface {
Operator() string
}
// Equal is an arbitrary value.
// NOTE(review): not referenced by the code in this part of the file;
// presumably the value of an equality match - confirm against callers.
Equal interface{}
// Matcher holds the conditions applied to a single column. Unset fields
// (nil, or "" for Like) are ignored; Validate requires at least one to
// be set. Several set fields are joined with the conjunction handed to
// toSQLConditionClause.
Matcher struct {
// Equal is rendered as "column = value".
Equal interface{} `json:"$eq,omitempty"`
// NotEqual is rendered as "column != value".
NotEqual interface{} `json:"$ne,omitempty"`
// In is rendered as "column IN ( values... )".
In []interface{} `json:"$in,omitempty"`
// NotIn is rendered as "column NOT IN ( values... )".
NotIn []interface{} `json:"$notIn,omitempty"`
// Like is rendered as "column LIKE value".
Like string `json:"$like,omitempty"`
}
// Count is the payload of the "$count" key of a Select.
// NOTE(review): how the fields are turned into SQL is not visible in this
// part of the file.
Count struct {
// As, when non-empty, is checked against charOnlyRegexp by
// Select.UnmarshalJSON.
As string `json:"as"`
Field string `json:"field"`
Distinct bool `json:"distinct"`
}
// Sum is the payload of the "$sum" key of a Select.
// NOTE(review): how the fields are turned into SQL is not visible in this
// part of the file.
Sum struct {
Condition Query `json:"condition"`
As string `json:"as"`
Distinct bool `json:"distinct"`
}
// Select is one entry of the select list of a query request. In JSON it
// is either a bare string, which sets Field, or an object with the keys
// listed below.
Select struct {
Field string `json:"field"`
Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct"`
}
// Selects is a list of Select entries. In JSON it may be written as an
// array of entries or as a single entry object.
Selects []Select
// TextSearch matches Value as a substring ("LIKE %Value%") against each
// column listed in Fields; the per-column conditions are joined with OR.
TextSearch struct {
Fields []string `json:"fields"`
Value string `json:"value"`
}
// QueryRequestPayload is the JSON payload of a query request.
QueryRequestPayload struct {
Select Selects `json:"select"`
Query Query `json:"query"`
OrderBy OrderBys `json:"orderBy"`
GroupBy []string `json:"groupBy"`
TextSearch *TextSearch `json:"textSearch"`
// Pagination is embedded, so "pageSize" and "page" are top-level keys
// of the payload.
Pagination
// The unexported fields below are never touched by encoding/json.
// NOTE(review): they are not used in this part of the file; presumably
// scratch state filled while the request is prepared - confirm.
selectedFields []string
whitelistedFields []string
paramMap map[string]interface{}
}
// QueryActiveConnectionChartPayload is the JSON payload of an active
// connection chart request.
QueryActiveConnectionChartPayload struct {
Query Query `json:"query"`
TextSearch *TextSearch `json:"textSearch"`
}
// OrderBy names a field to sort by and the sort direction. In JSON it is
// either a bare string, which sets Field and leaves Desc false, or an
// object.
OrderBy struct {
Field string `json:"field"`
Desc bool `json:"desc"`
}
// OrderBys is a list of OrderBy entries. In JSON it may be written as an
// array, a single object or a single field name.
OrderBys []OrderBy
// Pagination selects one page of a result set.
Pagination struct {
// PageSize is the number of rows per page. Values that are not in the
// range 1..100 are replaced by 100 in toSQLLimitOffsetClause.
PageSize int `json:"pageSize"`
// Page is the zero-based page index; an OFFSET of Page times the page
// size is emitted when it is greater than zero.
Page int `json:"page"`
}
)
// UnmarshalJSON decodes a JSON object into the query. Every key of the object
// is a column name and its value takes one of three shapes:
//
//   - an object, which is parsed as a single Matcher,
//   - an array whose elements are Matcher objects or primitives (nested
//     arrays are rejected),
//   - a primitive, which is shorthand for an equality match.
//
// A nil Query is initialized before decoding. Columns already present are kept
// unless the payload names them again. When an error is returned the query may
// have been partially updated.
func (query *Query) UnmarshalJSON(blob []byte) error {
	if *query == nil {
		*query = make(Query)
	}

	var columns map[string]json.RawMessage
	if err := json.Unmarshal(blob, &columns); err != nil {
		return err
	}

	// equalTo wraps a primitive JSON value into an implicit equality matcher.
	equalTo := func(raw json.RawMessage) (Matcher, error) {
		var value interface{}
		if err := json.Unmarshal(raw, &value); err != nil {
			return Matcher{}, err
		}
		return Matcher{Equal: value}, nil
	}

	for name, raw := range columns {
		if len(raw) == 0 {
			continue
		}

		// A single matcher object.
		if raw[0] == '{' {
			matcher, err := parseMatcher(raw)
			if err != nil {
				return err
			}
			(*query)[name] = []Matcher{*matcher}
			continue
		}

		// A bare primitive.
		if raw[0] != '[' {
			matcher, err := equalTo(raw)
			if err != nil {
				return err
			}
			(*query)[name] = []Matcher{matcher}
			continue
		}

		// A list of matcher objects and/or primitives.
		var elements []json.RawMessage
		if err := json.Unmarshal(raw, &elements); err != nil {
			return err
		}

		// The slice is registered before it is filled; both names share the
		// same backing array, so the writes below are visible through the map.
		matchers := make([]Matcher, len(elements))
		(*query)[name] = matchers

		for idx, element := range elements {
			switch {
			case len(element) == 0:
				// Not expected from the decoder; the slot keeps its zero Matcher.
			case element[0] == '{':
				matcher, err := parseMatcher(element)
				if err != nil {
					return err
				}
				matchers[idx] = *matcher
			case element[0] == '[':
				return fmt.Errorf("invalid token [ in query for column %s", name)
			default:
				matcher, err := equalTo(element)
				if err != nil {
					return err
				}
				matchers[idx] = matcher
			}
		}
	}

	return nil
}
// toSQLLimitOffsetClause renders the pagination as an SQL "LIMIT n" clause,
// followed by " OFFSET m" when a page after the first one is requested.
//
// The page size defaults to, and is capped at, 100 rows so that a single
// request cannot load thousands of results into memory at once.
//
// TODO(ppacher): right now we only support LIMIT and OFFSET for pagination but that
// has an issue that loading the same page twice might yield different results due to
// new records shifting the result slice. To overcome this, return a "PageToken" to the
// user that includes the time the initial query was created so paginated queries can
// ensure new records don't end up in the result set.
func (page *Pagination) toSQLLimitOffsetClause() string {
	size := page.PageSize
	if size < 1 || size > 100 {
		size = 100
	}

	// The first page (index 0) and negative indices need no offset.
	if page.Page <= 0 {
		return fmt.Sprintf("LIMIT %d", size)
	}

	return fmt.Sprintf("LIMIT %d OFFSET %d", size, page.Page*size)
}
// parseMatcher decodes a JSON object into a Matcher and validates it.
// Decoding errors are returned as they are; validation errors are wrapped
// with the prefix "invalid query matcher".
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
	matcher := new(Matcher)

	if err := json.Unmarshal(raw, matcher); err != nil {
		return nil, err
	}

	if err := matcher.Validate(); err != nil {
		return nil, fmt.Errorf("invalid query matcher: %w", err)
	}

	return matcher, nil
}
// Validate reports an error when the matcher does not carry any condition.
// A condition counts as present when its field is non-nil or, for Like,
// non-empty; an empty but non-nil In or NotIn list therefore counts as well.
func (match Matcher) Validate() error {
	present := [...]bool{
		match.Equal != nil,
		match.NotEqual != nil,
		match.In != nil,
		match.NotIn != nil,
		match.Like != "",
	}

	for _, isSet := range present {
		if isSet {
			return nil
		}
	}

	return fmt.Errorf("no conditions specified")
}
// toSQLConditionClause renders the text search as an SQL condition of the
// form "( colA LIKE :t<suffix> OR colB LIKE :t<suffix> )" together with the
// parameter map that binds :t<suffix> to "%<Value>%".
//
// Every entry of Fields must name a column of schema whose type is
// sqlite.TypeText, otherwise an error is returned. When Fields is empty, an
// empty clause and nil parameters are returned.
//
// NOTE: LIKE wildcards ("%" and "_") contained in Value are passed through
// unescaped and therefore act as wildcards.
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
	var (
		queryParts = make([]string, 0, len(text.Fields))
		params     = make(map[string]interface{})
	)

	// All columns are compared against the same pattern, so a single bound
	// parameter is shared between them.
	key := fmt.Sprintf(":t%s", suffix)
	params[key] = fmt.Sprintf("%%%s%%", text.Value)

	for _, field := range text.Fields {
		colDef := schema.GetColumnDef(field)
		if colDef == nil {
			// colDef is nil in this branch, so the requested field name is
			// reported; dereferencing colDef here would panic.
			return "", nil, fmt.Errorf("column %s is not allowed in text-search", field)
		}

		if colDef.Type != sqlite.TypeText {
			return "", nil, fmt.Errorf("type of column %s cannot be used in text-search", colDef.Name)
		}

		queryParts = append(queryParts, fmt.Sprintf("%s LIKE %s", colDef.Name, key))
	}

	if len(queryParts) == 0 {
		return "", nil, nil
	}

	return "( " + strings.Join(queryParts, " OR ") + " )", params, nil
}
// toSQLConditionClause renders the conditions of the matcher for the column
// described by colDef.
//
// Conditions are emitted in the fixed order $eq, $ne, $in, $notIn, $like.
// Every value is encoded with orm.EncodeValue and bound to a named parameter
// ":<column><suffix><condition><index>". A matcher without any condition
// renders as the no-op "( 1 = 1 )" with nil parameters; a single condition is
// returned as is; several conditions are joined with conjunction and wrapped
// in parentheses.
//
// A value that cannot be encoded drops its whole condition from the clause
// and is reported through the returned error; the remaining conditions are
// still rendered.
func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	// Table of all supported conditions in rendering order.
	conditions := []struct {
		set      bool
		operator string
		tag      string
		list     bool
		values   []interface{}
	}{
		{match.Equal != nil, "=", "eq", false, []interface{}{match.Equal}},
		{match.NotEqual != nil, "!=", "ne", false, []interface{}{match.NotEqual}},
		{match.In != nil, "IN", "in", true, match.In},
		{match.NotIn != nil, "NOT IN", "notin", true, match.NotIn},
		{match.Like != "", "LIKE", "like", false, []interface{}{match.Like}},
	}

	var (
		clauses []string
		params  = make(map[string]interface{})
		errs    = new(multierror.Error)
	)

nextCondition:
	for _, cond := range conditions {
		if !cond.set {
			continue
		}

		names := make([]string, 0, len(cond.values))
		for idx, value := range cond.values {
			encoded, err := orm.EncodeValue(ctx, &colDef, value, encoderConfig)
			if err != nil {
				errs.Errors = append(errs.Errors,
					fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
				)
				// Parameters bound for earlier values of this condition stay
				// in params, the condition itself is skipped.
				continue nextCondition
			}

			name := fmt.Sprintf(":%s%s%s%d", colDef.Name, suffix, cond.tag, idx)
			names = append(names, name)
			params[name] = encoded
		}

		// Scalar operators take the bare placeholder, list operators always
		// get a parenthesized list - even for one or zero values.
		if !cond.list && len(names) == 1 {
			clauses = append(clauses, fmt.Sprintf("%s %s %s", colDef.Name, cond.operator, names[0]))
			continue
		}
		clauses = append(clauses, fmt.Sprintf("%s %s ( %s )", colDef.Name, cond.operator, strings.Join(names, ", ")))
	}

	switch len(clauses) {
	case 0:
		// this is an empty matcher without a single condition.
		// we convert that to a no-op TRUE value
		return "( 1 = 1 )", nil, errs.ErrorOrNil()
	case 1:
		return clauses[0], params, errs.ErrorOrNil()
	default:
		return "( " + strings.Join(clauses, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
	}
}
// toSQLWhereClause renders the query as the body of an SQL WHERE clause
// (without the WHERE keyword) and returns it together with the named
// parameters it refers to.
//
// Columns are processed in sorted order so the output is stable. The matchers
// of one column are joined with OR, the per-column groups with AND. suffix is
// appended to the parameter names of every matcher, followed by the index of
// the matcher.
//
// Columns that are not part of the schema m, columns without any matcher and
// matchers that fail to render are collected and returned as one combined
// error. The clause returned alongside an error must not be used.
// An empty query yields an empty clause, nil parameters and no error.
func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
	if len(query) == 0 {
		return "", nil, nil
	}

	// create a lookup map to validate column names
	lm := make(map[string]orm.ColumnDef, len(m.Columns))
	for _, col := range m.Columns {
		lm[col.Name] = col
	}

	paramMap := make(map[string]interface{})
	columnStmts := make([]string, 0, len(query))

	// get all keys and sort them so we get a stable output
	queryKeys := make([]string, 0, len(query))
	for column := range query {
		queryKeys = append(queryKeys, column)
	}
	sort.Strings(queryKeys)

	// actually create the WHERE clause parts for each
	// column in query.
	errs := new(multierror.Error)
	for _, column := range queryKeys {
		values := query[column]
		colDef, ok := lm[column]
		if !ok {
			errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))

			continue
		}

		// An empty matcher list would render as "(  )", which is not a valid
		// SQL expression. Report it here instead of failing when the
		// statement is prepared.
		if len(values) == 0 {
			errs.Errors = append(errs.Errors, fmt.Errorf("no matchers specified for column %s", column))

			continue
		}

		queryParts := make([]string, len(values))
		for idx, val := range values {
			matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
			if err != nil {
				errs.Errors = append(errs.Errors,
					fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
				)

				continue
			}

			// merge parameters up into the superior parameter map
			for key, val := range params {
				if _, ok := paramMap[key]; ok {
					// This is solely a developer mistake when implementing a matcher so no forgiving ...
					panic("sqlite parameter collision")
				}

				paramMap[key] = val
			}

			queryParts[idx] = matcherQuery
		}

		columnStmts = append(columnStmts,
			fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
		)
	}

	whereClause := strings.Join(columnStmts, " AND ")

	return whereClause, paramMap, errs.ErrorOrNil()
}
// UnmarshalJSON decodes a select list from JSON. Three notations are
// accepted:
//
//   - an array, decoded element by element into a []Select,
//   - an object, decoded into a single Select,
//   - a string, taken as the field name of a single Select.
//
// JSON null leaves the receiver untouched. Empty input yields
// io.ErrUnexpectedEOF; any other value is rejected by the JSON decoder.
func (sel *Selects) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	switch blob[0] {
	case '[':
		// a list of selects
		var result []Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		*sel = result

		return nil

	case '{':
		// a single select object
		var result Select
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		*sel = []Select{result}

		return nil
	}

	// By convention a JSON null is a no-op for custom unmarshalers.
	if string(blob) == "null" {
		return nil
	}

	// otherwise this is just the field name
	var field string
	if err := json.Unmarshal(blob, &field); err != nil {
		return err
	}
	*sel = []Select{{Field: field}}

	return nil
}
// UnmarshalJSON decodes a single select entry from JSON. A JSON object is
// decoded field by field ("field", "$count", "$sum", "$distinct"); any other
// value must be a string and is stored as the field name, leaving the other
// members untouched.
//
// A non-empty "$count.as" alias has to match charOnlyRegexp, otherwise an
// error is returned - at that point the receiver has already been updated.
// Empty input yields io.ErrUnexpectedEOF.
func (sel *Select) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	// Shorthand notation: only the name of the field.
	if blob[0] != '{' {
		var name string
		if err := json.Unmarshal(blob, &name); err != nil {
			return err
		}
		sel.Field = name

		return nil
	}

	// Full notation. A method-less shadow struct is used so that decoding
	// does not recurse into this method.
	var decoded struct {
		Field    string  `json:"field"`
		Count    *Count  `json:"$count"`
		Sum      *Sum    `json:"$sum"`
		Distinct *string `json:"$distinct"`
	}
	if err := json.Unmarshal(blob, &decoded); err != nil {
		return err
	}

	sel.Field, sel.Count, sel.Sum, sel.Distinct = decoded.Field, decoded.Count, decoded.Sum, decoded.Distinct

	if count := sel.Count; count == nil || count.As == "" || charOnlyRegexp.MatchString(count.As) {
		return nil
	}

	return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
}
// UnmarshalJSON decodes a list of sort instructions from JSON. Three
// notations are accepted:
//
//   - an array, decoded element by element into a []OrderBy,
//   - an object, decoded into a single OrderBy,
//   - a string, taken as the field name of a single ascending OrderBy.
//
// JSON null leaves the receiver untouched instead of producing an entry with
// an empty field name. Empty input yields io.ErrUnexpectedEOF; any other
// value is rejected by the JSON decoder.
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	switch blob[0] {
	case '[':
		var result []OrderBy
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		*orderBys = result

		return nil

	case '{':
		var result OrderBy
		if err := json.Unmarshal(blob, &result); err != nil {
			return err
		}
		*orderBys = []OrderBy{result}

		return nil
	}

	// By convention a JSON null is a no-op for custom unmarshalers.
	if string(blob) == "null" {
		return nil
	}

	// Shorthand notation: only the field name, sorted ascending.
	var field string
	if err := json.Unmarshal(blob, &field); err != nil {
		return err
	}
	*orderBys = []OrderBy{
		{
			Field: field,
			Desc:  false,
		},
	}

	return nil
}
// UnmarshalJSON decodes a single sort instruction from JSON. A JSON object
// provides "field" and "desc"; any other value must be a string and is taken
// as the field name with ascending order. Empty input yields
// io.ErrUnexpectedEOF.
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
	if len(blob) == 0 {
		return io.ErrUnexpectedEOF
	}

	// Shorthand notation: only the field name, sorted ascending.
	if blob[0] != '{' {
		var name string
		if err := json.Unmarshal(blob, &name); err != nil {
			return err
		}
		orderBy.Field, orderBy.Desc = name, false

		return nil
	}

	// Full notation. A method-less shadow struct is used so that decoding
	// does not recurse into this method.
	var decoded struct {
		Field string `json:"field"`
		Desc  bool   `json:"desc"`
	}
	if err := json.Unmarshal(blob, &decoded); err != nil {
		return err
	}
	orderBy.Field, orderBy.Desc = decoded.Field, decoded.Desc

	return nil
}