safing-portmaster/service/netquery/query_request.go

320 lines
7.7 KiB
Go

package netquery
import (
"context"
"fmt"
"strings"
"golang.org/x/exp/slices"
"github.com/safing/portmaster/service/netquery/orm"
)
type (
	// QueryRequestPayload describes the payload of a netquery query.
	//
	// The exported fields carry json tags and describe what the client
	// wants to query. The unexported fields are internal scratch state that
	// is filled in while generateSQL builds the SQL statement for the
	// request.
	QueryRequestPayload struct {
		Select     Selects     `json:"select"`
		Query      Query       `json:"query"`
		OrderBy    OrderBys    `json:"orderBy"`
		GroupBy    []string    `json:"groupBy"`
		TextSearch *TextSearch `json:"textSearch"`
		// A list of databases to query. If left empty, only the
		// LiveDatabase is queried (generateSQL fills in that default).
		Databases []DatabaseName `json:"databases"`
		Pagination
		// selectedFields holds the rendered SQL select expressions,
		// built from Select (or from GroupBy when nothing is selected).
		selectedFields []string
		// whitelistedFields holds result column names (aliases and
		// DISTINCT columns) that GROUP BY / ORDER BY may reference in
		// addition to the schema's columns.
		whitelistedFields []string
		// paramMap collects the named parameters that must be bound when
		// executing the generated SQL.
		paramMap map[string]interface{}
	}
	// BatchQueryRequestPayload describes the payload of a batch netquery
	// query. The map key is used in the response to identify the results
	// for each query of the batch request.
	BatchQueryRequestPayload map[string]QueryRequestPayload
)
// generateSQL builds the complete SQL statement for the request against the
// given table schema. It returns the statement together with the named
// parameters that must be bound when executing it.
//
// The WHERE clause (payload query plus optional text search) is applied to
// the "connections" table of every requested database. The per-database
// selects are combined with UNION, and the select list, GROUP BY, ORDER BY
// and LIMIT/OFFSET clauses are applied to the combined result.
//
// generateSQL mutates req: it records selected/whitelisted fields and
// parameters, and defaults an empty req.Databases to LiveDatabase.
// Database names are interpolated into the SQL text, so only plain
// identifiers (ASCII letters, digits, underscores; not starting with a
// digit) are accepted.
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	if err := req.prepareSelectedFields(ctx, schema); err != nil {
		return "", nil, fmt.Errorf("preparing selected fields: %w", err)
	}

	// build the SQL where clause from the payload query
	whereClause, paramMap, err := req.Query.toSQLWhereClause(
		ctx,
		"",
		schema,
		orm.DefaultEncodeConfig,
	)
	if err != nil {
		return "", nil, fmt.Errorf("generating where clause: %w", err)
	}
	req.mergeParams(paramMap)

	if req.TextSearch != nil {
		textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
		if err != nil {
			return "", nil, fmt.Errorf("generating text-search clause: %w", err)
		}
		if textClause != "" {
			if whereClause != "" {
				// Parenthesize both operands so that a top-level OR in
				// either clause cannot change how the AND binds.
				whereClause = "(" + whereClause + ") AND (" + textClause + ")"
			} else {
				whereClause = textClause
			}
			req.mergeParams(textParams)
		}
	}

	// The group-by clause must be generated before the select clause
	// because it may populate req.selectedFields.
	groupByClause, err := req.generateGroupByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating group-by clause: %w", err)
	}
	orderByClause, err := req.generateOrderByClause(schema)
	if err != nil {
		return "", nil, fmt.Errorf("generating order-by clause: %w", err)
	}
	selectClause := req.generateSelectClause()

	if whereClause != "" {
		whereClause = "WHERE " + whereClause
	}

	// if no database is specified we default to LiveDatabase only.
	if len(req.Databases) == 0 {
		req.Databases = []DatabaseName{LiveDatabase}
	}

	// isPlainIdentifier reports whether s can be used unquoted as an SQL
	// identifier: non-empty, ASCII letters/digits/underscores only, and not
	// starting with a digit.
	isPlainIdentifier := func(s string) bool {
		if s == "" {
			return false
		}
		for i, r := range s {
			switch {
			case r == '_', 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z':
			case i > 0 && '0' <= r && r <= '9':
			default:
				return false
			}
		}
		return true
	}

	sources := make([]string, len(req.Databases))
	for idx, db := range req.Databases {
		// Database names come from the request payload and end up verbatim
		// in the SQL text; reject anything that is not a plain identifier.
		name := fmt.Sprint(db)
		if !isPlainIdentifier(name) {
			return "", nil, fmt.Errorf("invalid database name %q", name)
		}
		sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", name, whereClause)
	}
	source := strings.Join(sources, " UNION ")

	query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
	query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()

	return strings.TrimSpace(query), req.paramMap, nil
}
// prepareSelectedFields translates req.Select into SQL select expressions.
//
// For every selection it appends the rendered expression to
// req.selectedFields. Where the selection introduces a result column name
// (an alias or a DISTINCT column), it records that name in
// req.whitelistedFields so later GROUP BY / ORDER BY clauses may refer to
// it. Parameters produced by $sum/$min conditions are merged into
// req.paramMap using the per-selection suffix "sel<idx>".
//
// Column references are resolved through validateColumnName, and the
// resolved name is what ends up in the SQL. User-supplied aliases are
// written into the SQL text too, so they must be plain identifiers.
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
	// checkAlias rejects user-supplied aliases that are not plain SQL
	// identifiers (ASCII letters, digits and underscores, not starting with
	// a digit). Aliases are written verbatim into the query and may be
	// referenced again via whitelistedFields, so arbitrary text here would
	// allow SQL injection.
	checkAlias := func(op, as string) error {
		for i, r := range as {
			switch {
			case r == '_', 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z':
			case i > 0 && '0' <= r && r <= '9':
			default:
				return fmt.Errorf("invalid alias %q for %s: only letters, digits and underscores are allowed", as, op)
			}
		}
		return nil
	}

	for idx, s := range req.Select {
		// Determine the column this selection refers to. Condition-based
		// $sum/$min do not reference a column and use "*" as placeholder.
		var field string
		switch {
		case s.Count != nil:
			field = s.Count.Field
		case s.Distinct != nil:
			field = *s.Distinct
		case s.Sum != nil:
			if s.Sum.Field != "" {
				field = s.Sum.Field
			} else {
				field = "*"
			}
		case s.Min != nil:
			if s.Min.Field != "" {
				field = s.Min.Field
			} else {
				field = "*"
			}
		case s.FieldSelect != nil:
			field = s.FieldSelect.Field
		default:
			field = s.Field
		}

		// "*" is accepted without validation only for $count, $sum and
		// $min. $min must be listed here as well, otherwise every
		// condition-based $min would be rejected as an unknown column.
		colName := "*"
		if field != "*" || (s.Count == nil && s.Sum == nil && s.Min == nil) {
			var err error
			colName, err = req.validateColumnName(schema, field)
			if err != nil {
				return err
			}
		}

		switch {
		case s.FieldSelect != nil:
			as := s.FieldSelect.As
			if as == "" {
				as = s.FieldSelect.Field
			} else if err := checkAlias("field selection", as); err != nil {
				return err
			}
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("%s AS %s", colName, as),
			)
			req.whitelistedFields = append(req.whitelistedFields, as)

		case s.Count != nil:
			as := s.Count.As
			switch {
			case as != "":
				if err := checkAlias("$count", as); err != nil {
					return err
				}
			case colName == "*":
				// colName+"_count" would yield "*_count", which is not a
				// valid SQL identifier.
				as = "count"
			default:
				as = fmt.Sprintf("%s_count", colName)
			}
			distinct := ""
			if s.Count.Distinct {
				distinct = "DISTINCT "
			}
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
			)
			req.whitelistedFields = append(req.whitelistedFields, as)

		case s.Sum != nil:
			if s.Sum.As == "" {
				return fmt.Errorf("missing 'as' for $sum")
			}
			if err := checkAlias("$sum", s.Sum.As); err != nil {
				return err
			}
			var (
				clause string
				params map[string]any
			)
			if s.Sum.Field != "" {
				clause = colName
			} else {
				var err error
				clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
				if err != nil {
					return fmt.Errorf("in $sum: %w", err)
				}
			}
			req.mergeParams(params)
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
			)
			req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)

		case s.Min != nil:
			if s.Min.As == "" {
				return fmt.Errorf("missing 'as' for $min")
			}
			if err := checkAlias("$min", s.Min.As); err != nil {
				return err
			}
			var (
				clause string
				params map[string]any
			)
			if s.Min.Field != "" {
				clause = colName
			} else {
				var err error
				clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
				if err != nil {
					return fmt.Errorf("in $min: %w", err)
				}
			}
			req.mergeParams(params)
			req.selectedFields = append(
				req.selectedFields,
				fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
			)
			req.whitelistedFields = append(req.whitelistedFields, s.Min.As)

		case s.Distinct != nil:
			req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
			req.whitelistedFields = append(req.whitelistedFields, colName)

		default:
			req.selectedFields = append(req.selectedFields, colName)
		}
	}

	return nil
}
// mergeParams copies every entry of params into req.paramMap, overwriting
// keys that already exist. req.paramMap is guaranteed to be non-nil
// afterwards, even when params is empty or nil.
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
	target := req.paramMap
	if target == nil {
		target = make(map[string]any, len(params))
		req.paramMap = target
	}
	for name, val := range params {
		target[name] = val
	}
}
// generateGroupByClause renders req.GroupBy as a "GROUP BY ..." clause, with
// each entry resolved through validateColumnName. An empty GroupBy yields an
// empty clause.
//
// When no fields have been selected explicitly, the grouped columns also
// become the select list, since that is what callers usually want.
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
	if len(req.GroupBy) == 0 {
		return "", nil
	}

	columns := make([]string, 0, len(req.GroupBy))
	for _, requested := range req.GroupBy {
		resolved, err := req.validateColumnName(schema, requested)
		if err != nil {
			return "", err
		}
		columns = append(columns, resolved)
	}

	if len(req.selectedFields) == 0 {
		req.selectedFields = append(req.selectedFields, columns...)
	}

	return "GROUP BY " + strings.Join(columns, ", "), nil
}
// generateSelectClause returns the comma-separated list of selected
// expressions, or "*" when nothing has been selected.
func (req *QueryRequestPayload) generateSelectClause() string {
	if len(req.selectedFields) == 0 {
		return "*"
	}
	return strings.Join(req.selectedFields, ", ")
}
// generateOrderByClause renders req.OrderBy as an "ORDER BY ..." clause.
// Every sort field is resolved through validateColumnName and sorted ASC
// unless Desc is set. An empty OrderBy yields an empty clause.
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
	if len(req.OrderBy) == 0 {
		return "", nil
	}

	terms := make([]string, 0, len(req.OrderBy))
	for _, ob := range req.OrderBy {
		col, err := req.validateColumnName(schema, ob.Field)
		if err != nil {
			return "", err
		}
		direction := "ASC"
		if ob.Desc {
			direction = "DESC"
		}
		terms = append(terms, col+" "+direction)
	}

	return "ORDER BY " + strings.Join(terms, ", "), nil
}
// validateColumnName resolves field to a name that may be used in the
// generated SQL.
//
// A column known to the schema resolves to its schema name. Otherwise the
// field is accepted as-is if it was whitelisted (an alias or DISTINCT
// column) or already appears in the selected fields. Anything else is
// rejected with an error.
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
	if def := schema.GetColumnDef(field); def != nil {
		return def.Name, nil
	}

	known := slices.Contains(req.whitelistedFields, field) ||
		slices.Contains(req.selectedFields, field)
	if !known {
		return "", fmt.Errorf("column name %q not allowed", field)
	}
	return field, nil
}