mirror of
https://github.com/safing/portmaster
synced 2025-04-19 10:29:11 +00:00
675 lines
16 KiB
Go
675 lines
16 KiB
Go
package netquery
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"golang.org/x/exp/slices"
|
|
"zombiezen.com/go/sqlite"
|
|
|
|
"github.com/safing/portmaster/service/netquery/orm"
|
|
)
|
|
|
|
// DatabaseName is a database name constant.
|
|
type DatabaseName string
|
|
|
|
// Databases.
|
|
const (
|
|
LiveDatabase = DatabaseName("main")
|
|
HistoryDatabase = DatabaseName("history")
|
|
)
|
|
|
|
// Collection of Query and Matcher types.
|
|
// NOTE: whenever adding support for new operators make sure
|
|
// to update UnmarshalJSON as well.
|
|
//
|
|
//nolint:golint
|
|
type (
|
|
Query map[string][]Matcher
|
|
|
|
MatchType interface {
|
|
Operator() string
|
|
}
|
|
|
|
Equal interface{}
|
|
|
|
Matcher struct {
|
|
Equal interface{} `json:"$eq,omitempty"`
|
|
NotEqual interface{} `json:"$ne,omitempty"`
|
|
In []interface{} `json:"$in,omitempty"`
|
|
NotIn []interface{} `json:"$notIn,omitempty"`
|
|
Like string `json:"$like,omitempty"`
|
|
Greater *float64 `json:"$gt,omitempty"`
|
|
GreaterOrEqual *float64 `json:"$ge,omitempty"`
|
|
Less *float64 `json:"$lt,omitempty"`
|
|
LessOrEqual *float64 `json:"$le,omitempty"`
|
|
}
|
|
|
|
Count struct {
|
|
As string `json:"as"`
|
|
Field string `json:"field"`
|
|
Distinct bool `json:"distinct"`
|
|
}
|
|
|
|
Sum struct {
|
|
Condition Query `json:"condition"`
|
|
Field string `json:"field"`
|
|
As string `json:"as"`
|
|
Distinct bool `json:"distinct"`
|
|
}
|
|
|
|
Min struct {
|
|
Condition *Query `json:"condition,omitempty"`
|
|
Field string `json:"field"`
|
|
As string `json:"as"`
|
|
Distinct bool `json:"distinct"`
|
|
}
|
|
|
|
FieldSelect struct {
|
|
Field string `json:"field"`
|
|
As string `json:"as"`
|
|
}
|
|
|
|
Select struct {
|
|
Field string `json:"field"`
|
|
FieldSelect *FieldSelect `json:"$field"`
|
|
Count *Count `json:"$count,omitempty"`
|
|
Sum *Sum `json:"$sum,omitempty"`
|
|
Min *Min `json:"$min,omitempty"`
|
|
Distinct *string `json:"$distinct,omitempty"`
|
|
}
|
|
|
|
Selects []Select
|
|
|
|
TextSearch struct {
|
|
Fields []string `json:"fields"`
|
|
Value string `json:"value"`
|
|
}
|
|
|
|
QueryActiveConnectionChartPayload struct {
|
|
Query Query `json:"query"`
|
|
TextSearch *TextSearch `json:"textSearch"`
|
|
}
|
|
|
|
OrderBy struct {
|
|
Field string `json:"field"`
|
|
Desc bool `json:"desc"`
|
|
}
|
|
|
|
OrderBys []OrderBy
|
|
|
|
Pagination struct {
|
|
PageSize int `json:"pageSize"`
|
|
Page int `json:"page"`
|
|
}
|
|
)
|
|
|
|
// UnmarshalJSON unmarshals a Query from json.
|
|
func (query *Query) UnmarshalJSON(blob []byte) error {
|
|
if *query == nil {
|
|
*query = make(Query)
|
|
}
|
|
|
|
var model map[string]json.RawMessage
|
|
|
|
if err := json.Unmarshal(blob, &model); err != nil {
|
|
return err
|
|
}
|
|
|
|
for columnName, rawColumnQuery := range model {
|
|
if len(rawColumnQuery) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch rawColumnQuery[0] {
|
|
case '{':
|
|
m, err := parseMatcher(rawColumnQuery)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
(*query)[columnName] = []Matcher{*m}
|
|
|
|
case '[':
|
|
var rawMatchers []json.RawMessage
|
|
if err := json.Unmarshal(rawColumnQuery, &rawMatchers); err != nil {
|
|
return err
|
|
}
|
|
|
|
(*query)[columnName] = make([]Matcher, len(rawMatchers))
|
|
for idx, val := range rawMatchers {
|
|
// this should not happen
|
|
if len(val) == 0 {
|
|
continue
|
|
}
|
|
|
|
// if val starts with a { we have a matcher definition
|
|
if val[0] == '{' {
|
|
m, err := parseMatcher(val)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
(*query)[columnName][idx] = *m
|
|
|
|
continue
|
|
} else if val[0] == '[' {
|
|
return fmt.Errorf("invalid token [ in query for column %s", columnName)
|
|
}
|
|
|
|
// val is a dedicated JSON primitive and not an object or array
|
|
// so we treat that as an EQUAL condition.
|
|
var x interface{}
|
|
if err := json.Unmarshal(val, &x); err != nil {
|
|
return err
|
|
}
|
|
|
|
(*query)[columnName][idx] = Matcher{
|
|
Equal: x,
|
|
}
|
|
}
|
|
|
|
default:
|
|
// value is a JSON primitive and not an object or array
|
|
// so we treat that as an EQUAL condition.
|
|
var x interface{}
|
|
if err := json.Unmarshal(rawColumnQuery, &x); err != nil {
|
|
return err
|
|
}
|
|
|
|
(*query)[columnName] = []Matcher{
|
|
{Equal: x},
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TODO(ppacher): right now we only support LIMIT and OFFSET for pagination but that
|
|
// has an issue that loading the same page twice might yield different results due to
|
|
// new records shifting the result slice. To overcome this, return a "PageToken" to the
|
|
// user that includes the time the initial query was created so paginated queries can
|
|
// ensure new records don't end up in the result set.
|
|
func (page *Pagination) toSQLLimitOffsetClause() string {
|
|
limit := page.PageSize
|
|
|
|
// default and cap the limit to at most 100 items
|
|
// per page to avoid out-of-memory conditions when loading
|
|
// thousands of results at once.
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 100
|
|
}
|
|
|
|
sql := fmt.Sprintf("LIMIT %d", limit)
|
|
|
|
if page.Page > 0 {
|
|
sql += fmt.Sprintf(" OFFSET %d", page.Page*limit)
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
func parseMatcher(raw json.RawMessage) (*Matcher, error) {
|
|
var m Matcher
|
|
if err := json.Unmarshal(raw, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := m.Validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid query matcher: %w", err)
|
|
}
|
|
|
|
// log.Printf("parsed matcher %s: %+v", string(raw), m)
|
|
return &m, nil
|
|
}
|
|
|
|
// Validate validates the matcher.
|
|
func (match Matcher) Validate() error {
|
|
found := 0
|
|
|
|
if match.Equal != nil {
|
|
found++
|
|
}
|
|
|
|
if match.NotEqual != nil {
|
|
found++
|
|
}
|
|
|
|
if match.In != nil {
|
|
found++
|
|
}
|
|
|
|
if match.NotIn != nil {
|
|
found++
|
|
}
|
|
|
|
if match.Like != "" {
|
|
found++
|
|
}
|
|
|
|
if match.Greater != nil {
|
|
found++
|
|
}
|
|
|
|
if match.GreaterOrEqual != nil {
|
|
found++
|
|
}
|
|
|
|
if match.Less != nil {
|
|
found++
|
|
}
|
|
|
|
if match.LessOrEqual != nil {
|
|
found++
|
|
}
|
|
|
|
if found == 0 {
|
|
return fmt.Errorf("no conditions specified")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (text TextSearch) toSQLConditionClause(_ context.Context, schema *orm.TableSchema, suffix string, _ orm.EncodeConfig) (string, map[string]interface{}, error) {
|
|
var (
|
|
queryParts = make([]string, 0, len(text.Fields))
|
|
params = make(map[string]interface{})
|
|
)
|
|
|
|
key := fmt.Sprintf(":t%s", suffix)
|
|
params[key] = fmt.Sprintf("%%%s%%", text.Value)
|
|
|
|
for _, field := range text.Fields {
|
|
colDef := schema.GetColumnDef(field)
|
|
if colDef == nil {
|
|
return "", nil, fmt.Errorf("column %s is not allowed in text-search", colDef.Name)
|
|
}
|
|
if colDef.Type != sqlite.TypeText {
|
|
return "", nil, fmt.Errorf("type of column %s cannot be used in text-search", colDef.Name)
|
|
}
|
|
|
|
queryParts = append(queryParts, fmt.Sprintf("%s LIKE %s", colDef.Name, key))
|
|
}
|
|
|
|
if len(queryParts) == 0 {
|
|
return "", nil, nil
|
|
}
|
|
|
|
return "( " + strings.Join(queryParts, " OR ") + " )", params, nil
|
|
}
|
|
|
|
func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
|
var (
|
|
queryParts []string
|
|
params = make(map[string]interface{})
|
|
errs = new(multierror.Error)
|
|
key = fmt.Sprintf("%s%s", colDef.Name, suffix)
|
|
)
|
|
|
|
add := func(operator, suffix string, list bool, values ...interface{}) {
|
|
var placeholder []string
|
|
|
|
for idx, value := range values {
|
|
var (
|
|
encodedValue any
|
|
err error
|
|
)
|
|
|
|
kind := orm.NormalizeKind(reflect.TypeOf(value).Kind())
|
|
isNumber := slices.Contains([]reflect.Kind{
|
|
reflect.Uint,
|
|
reflect.Int,
|
|
reflect.Float64,
|
|
}, kind)
|
|
|
|
// if we query a time-field that is queried as a number, don't do any encoding
|
|
// here as the orm.DateTimeEncoder would convert the number to a string.
|
|
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
|
encodedValue = value
|
|
} else {
|
|
encodedValue, err = orm.EncodeValue(ctx, &colDef, value, encoderConfig)
|
|
if err != nil {
|
|
errs.Errors = append(errs.Errors,
|
|
fmt.Errorf("failed to encode %v for column %s: %w", value, colDef.Name, err),
|
|
)
|
|
return
|
|
}
|
|
}
|
|
|
|
uniqKey := fmt.Sprintf(":%s%s%d", key, suffix, idx)
|
|
placeholder = append(placeholder, uniqKey)
|
|
params[uniqKey] = encodedValue
|
|
}
|
|
|
|
nameStmt := colDef.Name
|
|
|
|
if len(values) > 0 {
|
|
// NOTE(ppacher): for now we assume that the type of each element of values
|
|
// is the same. We also can be sure that there is always at least one value.
|
|
//
|
|
// TODO(ppacher): if we start supporting values of different types here
|
|
// we need to revisit the whole behavior as we might need to do more boolean
|
|
// expression nesting to support that.
|
|
kind := orm.NormalizeKind(reflect.TypeOf(values[0]).Kind())
|
|
isNumber := slices.Contains([]reflect.Kind{
|
|
reflect.Uint,
|
|
reflect.Int,
|
|
reflect.Float64,
|
|
}, kind)
|
|
|
|
// if this is a time column that is stored in "text" format and the provided
|
|
// value is a number type, we need to wrap the property in a strftime() method
|
|
// call.
|
|
if colDef.IsTime && colDef.Type == sqlite.TypeText && isNumber {
|
|
nameStmt = fmt.Sprintf("strftime('%%s', %s)+0", nameStmt)
|
|
}
|
|
}
|
|
|
|
if len(placeholder) == 1 && !list {
|
|
queryParts = append(queryParts, fmt.Sprintf("%s %s %s", nameStmt, operator, placeholder[0]))
|
|
} else {
|
|
queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", nameStmt, operator, strings.Join(placeholder, ", ")))
|
|
}
|
|
}
|
|
|
|
if match.Equal != nil {
|
|
add("=", "eq", false, match.Equal)
|
|
}
|
|
|
|
if match.NotEqual != nil {
|
|
add("!=", "ne", false, match.NotEqual)
|
|
}
|
|
|
|
if match.In != nil {
|
|
add("IN", "in", true, match.In...)
|
|
}
|
|
|
|
if match.NotIn != nil {
|
|
add("NOT IN", "notin", true, match.NotIn...)
|
|
}
|
|
|
|
if match.Like != "" {
|
|
add("LIKE", "like", false, match.Like)
|
|
}
|
|
|
|
if match.Greater != nil {
|
|
add(">", "gt", false, *match.Greater)
|
|
}
|
|
|
|
if match.GreaterOrEqual != nil {
|
|
add(">=", "ge", false, *match.GreaterOrEqual)
|
|
}
|
|
|
|
if match.Less != nil {
|
|
add("<", "lt", false, *match.Less)
|
|
}
|
|
|
|
if match.LessOrEqual != nil {
|
|
add("<=", "le", false, *match.LessOrEqual)
|
|
}
|
|
|
|
if len(queryParts) == 0 {
|
|
// this is an empty matcher without a single condition.
|
|
// we convert that to a no-op TRUE value
|
|
return "( 1 = 1 )", nil, errs.ErrorOrNil()
|
|
}
|
|
|
|
if len(queryParts) == 1 {
|
|
return queryParts[0], params, errs.ErrorOrNil()
|
|
}
|
|
|
|
return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
|
|
}
|
|
|
|
func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
|
|
if len(query) == 0 {
|
|
return "", nil, nil
|
|
}
|
|
|
|
// create a lookup map to validate column names
|
|
lm := make(map[string]orm.ColumnDef, len(m.Columns))
|
|
for _, col := range m.Columns {
|
|
lm[col.Name] = col
|
|
}
|
|
|
|
paramMap := make(map[string]interface{})
|
|
columnStmts := make([]string, 0, len(query))
|
|
|
|
// get all keys and sort them so we get a stable output
|
|
queryKeys := make([]string, 0, len(query))
|
|
for column := range query {
|
|
queryKeys = append(queryKeys, column)
|
|
}
|
|
sort.Strings(queryKeys)
|
|
|
|
// actually create the WHERE clause parts for each
|
|
// column in query.
|
|
errs := new(multierror.Error)
|
|
for _, column := range queryKeys {
|
|
values := query[column]
|
|
colDef, ok := lm[column]
|
|
if !ok {
|
|
|
|
errs.Errors = append(errs.Errors, fmt.Errorf("column %s is not allowed", column))
|
|
|
|
continue
|
|
}
|
|
|
|
queryParts := make([]string, len(values))
|
|
for idx, val := range values {
|
|
matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
|
|
if err != nil {
|
|
errs.Errors = append(errs.Errors,
|
|
fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
|
|
)
|
|
|
|
continue
|
|
}
|
|
|
|
// merge parameters up into the superior parameter map
|
|
for key, val := range params {
|
|
if _, ok := paramMap[key]; ok {
|
|
// This is solely a developer mistake when implementing a matcher so no forgiving ...
|
|
panic("sqlite parameter collision")
|
|
}
|
|
|
|
paramMap[key] = val
|
|
}
|
|
|
|
queryParts[idx] = matcherQuery
|
|
}
|
|
|
|
columnStmts = append(columnStmts,
|
|
fmt.Sprintf("( %s )", strings.Join(queryParts, " OR ")),
|
|
)
|
|
}
|
|
|
|
whereClause := strings.Join(columnStmts, " AND ")
|
|
|
|
return whereClause, paramMap, errs.ErrorOrNil()
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals a Selects from json.
|
|
func (sel *Selects) UnmarshalJSON(blob []byte) error {
|
|
if len(blob) == 0 {
|
|
return io.ErrUnexpectedEOF
|
|
}
|
|
|
|
// if we are looking at a slice directly decode into
|
|
// a []Select
|
|
if blob[0] == '[' {
|
|
var result []Select
|
|
if err := json.Unmarshal(blob, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
(*sel) = result
|
|
|
|
return nil
|
|
}
|
|
|
|
// if it's an object decode into a single select
|
|
if blob[0] == '{' {
|
|
var result Select
|
|
if err := json.Unmarshal(blob, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
*sel = []Select{result}
|
|
|
|
return nil
|
|
}
|
|
|
|
// otherwise this is just the field name
|
|
var field string
|
|
if err := json.Unmarshal(blob, &field); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals a Select from json.
|
|
func (sel *Select) UnmarshalJSON(blob []byte) error {
|
|
if len(blob) == 0 {
|
|
return io.ErrUnexpectedEOF
|
|
}
|
|
|
|
// if we have an object at hand decode the select
|
|
// directly
|
|
if blob[0] == '{' {
|
|
var res struct {
|
|
Field string `json:"field"`
|
|
Count *Count `json:"$count"`
|
|
Sum *Sum `json:"$sum"`
|
|
Min *Min `json:"$min"`
|
|
Distinct *string `json:"$distinct"`
|
|
FieldSelect *FieldSelect `json:"$field"`
|
|
}
|
|
|
|
if err := json.Unmarshal(blob, &res); err != nil {
|
|
return err
|
|
}
|
|
|
|
sel.Count = res.Count
|
|
sel.Field = res.Field
|
|
sel.FieldSelect = res.FieldSelect
|
|
sel.Distinct = res.Distinct
|
|
sel.Sum = res.Sum
|
|
sel.Min = res.Min
|
|
|
|
if sel.Count != nil && sel.Count.As != "" {
|
|
if !charOnlyRegexp.MatchString(sel.Count.As) {
|
|
return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+")
|
|
}
|
|
}
|
|
if sel.Sum != nil && sel.Sum.As != "" {
|
|
if !charOnlyRegexp.MatchString(sel.Sum.As) {
|
|
return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+")
|
|
}
|
|
}
|
|
if sel.Min != nil && sel.Min.As != "" {
|
|
if !charOnlyRegexp.MatchString(sel.Min.As) {
|
|
return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+")
|
|
}
|
|
}
|
|
if sel.FieldSelect != nil && sel.FieldSelect.As != "" {
|
|
if !charOnlyRegexp.MatchString(sel.FieldSelect.As) {
|
|
return fmt.Errorf("invalid characters in $field.as, value must match [a-zA-Z]+")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
var x string
|
|
if err := json.Unmarshal(blob, &x); err != nil {
|
|
return err
|
|
}
|
|
|
|
sel.Field = x
|
|
|
|
return nil
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals a OrderBys from json.
|
|
func (orderBys *OrderBys) UnmarshalJSON(blob []byte) error {
|
|
if len(blob) == 0 {
|
|
return io.ErrUnexpectedEOF
|
|
}
|
|
|
|
if blob[0] == '[' {
|
|
var result []OrderBy
|
|
if err := json.Unmarshal(blob, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
*orderBys = result
|
|
|
|
return nil
|
|
}
|
|
|
|
if blob[0] == '{' {
|
|
var result OrderBy
|
|
if err := json.Unmarshal(blob, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
*orderBys = []OrderBy{result}
|
|
|
|
return nil
|
|
}
|
|
|
|
var field string
|
|
if err := json.Unmarshal(blob, &field); err != nil {
|
|
return err
|
|
}
|
|
|
|
*orderBys = []OrderBy{
|
|
{
|
|
Field: field,
|
|
Desc: false,
|
|
},
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UnmarshalJSON unmarshals a OrderBy from json.
|
|
func (orderBy *OrderBy) UnmarshalJSON(blob []byte) error {
|
|
if len(blob) == 0 {
|
|
return io.ErrUnexpectedEOF
|
|
}
|
|
|
|
if blob[0] == '{' {
|
|
var res struct {
|
|
Field string `json:"field"`
|
|
Desc bool `json:"desc"`
|
|
}
|
|
|
|
if err := json.Unmarshal(blob, &res); err != nil {
|
|
return err
|
|
}
|
|
|
|
orderBy.Desc = res.Desc
|
|
orderBy.Field = res.Field
|
|
|
|
return nil
|
|
}
|
|
|
|
var field string
|
|
if err := json.Unmarshal(blob, &field); err != nil {
|
|
return err
|
|
}
|
|
|
|
orderBy.Field = field
|
|
orderBy.Desc = false
|
|
|
|
return nil
|
|
}
|