mirror of https://github.com/safing/portmaster
synced 2025-04-19 10:29:11 +00:00
320 lines · 7.7 KiB · Go
package netquery
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
"github.com/safing/portmaster/service/netquery/orm"
|
|
)
|
|
|
|
type (
|
|
// QueryRequestPayload describes the payload of a netquery query.
|
|
QueryRequestPayload struct {
|
|
Select Selects `json:"select"`
|
|
Query Query `json:"query"`
|
|
OrderBy OrderBys `json:"orderBy"`
|
|
GroupBy []string `json:"groupBy"`
|
|
TextSearch *TextSearch `json:"textSearch"`
|
|
// A list of databases to query. If left empty,
|
|
// both, the LiveDatabase and the HistoryDatabase are queried
|
|
Databases []DatabaseName `json:"databases"`
|
|
|
|
Pagination
|
|
|
|
selectedFields []string
|
|
whitelistedFields []string
|
|
paramMap map[string]interface{}
|
|
}
|
|
|
|
// BatchQueryRequestPayload describes the payload of a batch netquery
|
|
// query. The map key is used in the response to identify the results
|
|
// for each query of the batch request.
|
|
BatchQueryRequestPayload map[string]QueryRequestPayload
|
|
)
|
|
|
|
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
|
if err := req.prepareSelectedFields(ctx, schema); err != nil {
|
|
return "", nil, fmt.Errorf("perparing selected fields: %w", err)
|
|
}
|
|
|
|
// build the SQL where clause from the payload query
|
|
whereClause, paramMap, err := req.Query.toSQLWhereClause(
|
|
ctx,
|
|
"",
|
|
schema,
|
|
orm.DefaultEncodeConfig,
|
|
)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
|
}
|
|
|
|
req.mergeParams(paramMap)
|
|
|
|
if req.TextSearch != nil {
|
|
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("generating text-search clause: %w", err)
|
|
}
|
|
|
|
if textClause != "" {
|
|
if whereClause != "" {
|
|
whereClause += " AND "
|
|
}
|
|
|
|
whereClause += textClause
|
|
|
|
req.mergeParams(textParams)
|
|
}
|
|
}
|
|
|
|
groupByClause, err := req.generateGroupByClause(schema)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("generating group-by clause: %w", err)
|
|
}
|
|
|
|
orderByClause, err := req.generateOrderByClause(schema)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("generating order-by clause: %w", err)
|
|
}
|
|
|
|
selectClause := req.generateSelectClause()
|
|
|
|
if whereClause != "" {
|
|
whereClause = "WHERE " + whereClause
|
|
}
|
|
|
|
// if no database is specified we default to LiveDatabase only.
|
|
if len(req.Databases) == 0 {
|
|
req.Databases = []DatabaseName{LiveDatabase}
|
|
}
|
|
|
|
sources := make([]string, len(req.Databases))
|
|
for idx, db := range req.Databases {
|
|
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
|
}
|
|
|
|
source := strings.Join(sources, " UNION ")
|
|
|
|
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
|
|
|
query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()
|
|
|
|
return strings.TrimSpace(query), req.paramMap, nil
|
|
}
|
|
|
|
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
|
for idx, s := range req.Select {
|
|
var field string
|
|
|
|
switch {
|
|
case s.Count != nil:
|
|
field = s.Count.Field
|
|
case s.Distinct != nil:
|
|
field = *s.Distinct
|
|
case s.Sum != nil:
|
|
if s.Sum.Field != "" {
|
|
field = s.Sum.Field
|
|
} else {
|
|
field = "*"
|
|
}
|
|
case s.Min != nil:
|
|
if s.Min.Field != "" {
|
|
field = s.Min.Field
|
|
} else {
|
|
field = "*"
|
|
}
|
|
case s.FieldSelect != nil:
|
|
field = s.FieldSelect.Field
|
|
default:
|
|
field = s.Field
|
|
}
|
|
|
|
colName := "*"
|
|
if field != "*" || (s.Count == nil && s.Sum == nil) {
|
|
var err error
|
|
|
|
colName, err = req.validateColumnName(schema, field)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case s.FieldSelect != nil:
|
|
as := s.FieldSelect.As
|
|
if as == "" {
|
|
as = s.FieldSelect.Field
|
|
}
|
|
|
|
req.selectedFields = append(
|
|
req.selectedFields,
|
|
fmt.Sprintf("%s AS %s", s.FieldSelect.Field, as),
|
|
)
|
|
req.whitelistedFields = append(req.whitelistedFields, as)
|
|
|
|
case s.Count != nil:
|
|
as := s.Count.As
|
|
if as == "" {
|
|
as = fmt.Sprintf("%s_count", colName)
|
|
}
|
|
distinct := ""
|
|
if s.Count.Distinct {
|
|
distinct = "DISTINCT "
|
|
}
|
|
req.selectedFields = append(
|
|
req.selectedFields,
|
|
fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
|
|
)
|
|
req.whitelistedFields = append(req.whitelistedFields, as)
|
|
|
|
case s.Sum != nil:
|
|
if s.Sum.As == "" {
|
|
return fmt.Errorf("missing 'as' for $sum")
|
|
}
|
|
|
|
var (
|
|
clause string
|
|
params map[string]any
|
|
)
|
|
|
|
if s.Sum.Field != "" {
|
|
clause = s.Sum.Field
|
|
} else {
|
|
var err error
|
|
clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("in $sum: %w", err)
|
|
}
|
|
}
|
|
|
|
req.mergeParams(params)
|
|
req.selectedFields = append(
|
|
req.selectedFields,
|
|
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
|
)
|
|
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
|
|
|
case s.Min != nil:
|
|
if s.Min.As == "" {
|
|
return fmt.Errorf("missing 'as' for $min")
|
|
}
|
|
|
|
var (
|
|
clause string
|
|
params map[string]any
|
|
)
|
|
|
|
if s.Min.Field != "" {
|
|
clause = field
|
|
} else {
|
|
var err error
|
|
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("in $min: %w", err)
|
|
}
|
|
}
|
|
|
|
req.mergeParams(params)
|
|
req.selectedFields = append(
|
|
req.selectedFields,
|
|
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
|
)
|
|
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
|
|
|
case s.Distinct != nil:
|
|
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
|
req.whitelistedFields = append(req.whitelistedFields, colName)
|
|
|
|
default:
|
|
req.selectedFields = append(req.selectedFields, colName)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
|
if req.paramMap == nil {
|
|
req.paramMap = make(map[string]any)
|
|
}
|
|
|
|
for key, value := range params {
|
|
req.paramMap[key] = value
|
|
}
|
|
}
|
|
|
|
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
|
if len(req.GroupBy) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
groupBys := make([]string, len(req.GroupBy))
|
|
for idx, name := range req.GroupBy {
|
|
colName, err := req.validateColumnName(schema, name)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
groupBys[idx] = colName
|
|
}
|
|
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
|
|
|
// if there are no explicitly selected fields we default to the
|
|
// group-by columns as that's what's expected most of the time anyway...
|
|
if len(req.selectedFields) == 0 {
|
|
req.selectedFields = append(req.selectedFields, groupBys...)
|
|
}
|
|
|
|
return groupByClause, nil
|
|
}
|
|
|
|
func (req *QueryRequestPayload) generateSelectClause() string {
|
|
selectClause := "*"
|
|
if len(req.selectedFields) > 0 {
|
|
selectClause = strings.Join(req.selectedFields, ", ")
|
|
}
|
|
|
|
return selectClause
|
|
}
|
|
|
|
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
|
|
if len(req.OrderBy) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
orderBys := make([]string, len(req.OrderBy))
|
|
for idx, sort := range req.OrderBy {
|
|
colName, err := req.validateColumnName(schema, sort.Field)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if sort.Desc {
|
|
orderBys[idx] = fmt.Sprintf("%s DESC", colName)
|
|
} else {
|
|
orderBys[idx] = fmt.Sprintf("%s ASC", colName)
|
|
}
|
|
}
|
|
|
|
return "ORDER BY " + strings.Join(orderBys, ", "), nil
|
|
}
|
|
|
|
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
|
|
colDef := schema.GetColumnDef(field)
|
|
if colDef != nil {
|
|
return colDef.Name, nil
|
|
}
|
|
|
|
if slices.Contains(req.whitelistedFields, field) {
|
|
return field, nil
|
|
}
|
|
|
|
if slices.Contains(req.selectedFields, field) {
|
|
return field, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("column name %q not allowed", field)
|
|
}
|