mirror of
https://github.com/safing/portmaster
synced 2025-09-02 18:49:14 +00:00
netquery: split up query payload into a dedicated file
This commit is contained in:
parent
d5c4495991
commit
ba72c204d3
3 changed files with 306 additions and 292 deletions
|
@ -86,23 +86,6 @@ type (
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryRequestPayload struct {
|
|
||||||
Select Selects `json:"select"`
|
|
||||||
Query Query `json:"query"`
|
|
||||||
OrderBy OrderBys `json:"orderBy"`
|
|
||||||
GroupBy []string `json:"groupBy"`
|
|
||||||
TextSearch *TextSearch `json:"textSearch"`
|
|
||||||
// A list of databases to query. If left empty,
|
|
||||||
// both, the LiveDatabase and the HistoryDatabase are queried
|
|
||||||
Databases []DatabaseName `json:"databases"`
|
|
||||||
|
|
||||||
Pagination
|
|
||||||
|
|
||||||
selectedFields []string
|
|
||||||
whitelistedFields []string
|
|
||||||
paramMap map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
QueryActiveConnectionChartPayload struct {
|
QueryActiveConnectionChartPayload struct {
|
||||||
Query Query `json:"query"`
|
Query Query `json:"query"`
|
||||||
TextSearch *TextSearch `json:"textSearch"`
|
TextSearch *TextSearch `json:"textSearch"`
|
||||||
|
|
|
@ -2,7 +2,6 @@ package netquery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -12,8 +11,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
|
|
||||||
"github.com/safing/portbase/log"
|
"github.com/safing/portbase/log"
|
||||||
"github.com/safing/portmaster/netquery/orm"
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
)
|
)
|
||||||
|
@ -32,7 +29,7 @@ type (
|
||||||
|
|
||||||
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
requestPayload, err := qh.parseRequest(req)
|
requestPayload, err := parseQueryRequestPayload(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(resp, err.Error(), http.StatusBadRequest)
|
http.Error(resp, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
@ -108,7 +105,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl
|
func parseQueryRequestPayload(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl
|
||||||
var body io.Reader
|
var body io.Reader
|
||||||
|
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
|
@ -138,275 +135,5 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e
|
||||||
return &requestPayload, nil
|
return &requestPayload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
|
||||||
if err := req.prepareSelectedFields(ctx, schema); err != nil {
|
|
||||||
return "", nil, fmt.Errorf("perparing selected fields: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// build the SQL where clause from the payload query
|
|
||||||
whereClause, paramMap, err := req.Query.toSQLWhereClause(
|
|
||||||
ctx,
|
|
||||||
"",
|
|
||||||
schema,
|
|
||||||
orm.DefaultEncodeConfig,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.mergeParams(paramMap)
|
|
||||||
|
|
||||||
if req.TextSearch != nil {
|
|
||||||
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("generating text-search clause: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if textClause != "" {
|
|
||||||
if whereClause != "" {
|
|
||||||
whereClause += " AND "
|
|
||||||
}
|
|
||||||
|
|
||||||
whereClause += textClause
|
|
||||||
|
|
||||||
req.mergeParams(textParams)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groupByClause, err := req.generateGroupByClause(schema)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("generating group-by clause: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
orderByClause, err := req.generateOrderByClause(schema)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("generating order-by clause: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
selectClause := req.generateSelectClause()
|
|
||||||
|
|
||||||
if whereClause != "" {
|
|
||||||
whereClause = "WHERE " + whereClause
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no database is specified we default to LiveDatabase only.
|
|
||||||
if len(req.Databases) == 0 {
|
|
||||||
req.Databases = []DatabaseName{LiveDatabase}
|
|
||||||
}
|
|
||||||
|
|
||||||
sources := make([]string, len(req.Databases))
|
|
||||||
for idx, db := range req.Databases {
|
|
||||||
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
|
||||||
}
|
|
||||||
|
|
||||||
source := strings.Join(sources, " UNION ")
|
|
||||||
|
|
||||||
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
|
||||||
|
|
||||||
query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()
|
|
||||||
|
|
||||||
return strings.TrimSpace(query), req.paramMap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
|
||||||
for idx, s := range req.Select {
|
|
||||||
var field string
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case s.Count != nil:
|
|
||||||
field = s.Count.Field
|
|
||||||
case s.Distinct != nil:
|
|
||||||
field = *s.Distinct
|
|
||||||
case s.Sum != nil:
|
|
||||||
if s.Sum.Field != "" {
|
|
||||||
field = s.Sum.Field
|
|
||||||
} else {
|
|
||||||
field = "*"
|
|
||||||
}
|
|
||||||
case s.Min != nil:
|
|
||||||
if s.Min.Field != "" {
|
|
||||||
field = s.Min.Field
|
|
||||||
} else {
|
|
||||||
field = "*"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
field = s.Field
|
|
||||||
}
|
|
||||||
|
|
||||||
colName := "*"
|
|
||||||
if field != "*" || (s.Count == nil && s.Sum == nil) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
colName, err = req.validateColumnName(schema, field)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case s.Count != nil:
|
|
||||||
as := s.Count.As
|
|
||||||
if as == "" {
|
|
||||||
as = fmt.Sprintf("%s_count", colName)
|
|
||||||
}
|
|
||||||
distinct := ""
|
|
||||||
if s.Count.Distinct {
|
|
||||||
distinct = "DISTINCT "
|
|
||||||
}
|
|
||||||
req.selectedFields = append(
|
|
||||||
req.selectedFields,
|
|
||||||
fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
|
|
||||||
)
|
|
||||||
req.whitelistedFields = append(req.whitelistedFields, as)
|
|
||||||
|
|
||||||
case s.Sum != nil:
|
|
||||||
if s.Sum.As == "" {
|
|
||||||
return fmt.Errorf("missing 'as' for $sum")
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
clause string
|
|
||||||
params map[string]any
|
|
||||||
)
|
|
||||||
|
|
||||||
if s.Sum.Field != "" {
|
|
||||||
clause = s.Sum.Field
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("in $sum: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.mergeParams(params)
|
|
||||||
req.selectedFields = append(
|
|
||||||
req.selectedFields,
|
|
||||||
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
|
||||||
)
|
|
||||||
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
|
||||||
|
|
||||||
case s.Min != nil:
|
|
||||||
if s.Min.As == "" {
|
|
||||||
return fmt.Errorf("missing 'as' for $min")
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
clause string
|
|
||||||
params map[string]any
|
|
||||||
)
|
|
||||||
|
|
||||||
if s.Min.Field != "" {
|
|
||||||
clause = field
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("in $min: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.mergeParams(params)
|
|
||||||
req.selectedFields = append(
|
|
||||||
req.selectedFields,
|
|
||||||
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
|
||||||
)
|
|
||||||
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
|
||||||
|
|
||||||
case s.Distinct != nil:
|
|
||||||
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
|
||||||
req.whitelistedFields = append(req.whitelistedFields, colName)
|
|
||||||
|
|
||||||
default:
|
|
||||||
req.selectedFields = append(req.selectedFields, colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
|
||||||
if req.paramMap == nil {
|
|
||||||
req.paramMap = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range params {
|
|
||||||
req.paramMap[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
|
||||||
if len(req.GroupBy) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
groupBys := make([]string, len(req.GroupBy))
|
|
||||||
for idx, name := range req.GroupBy {
|
|
||||||
colName, err := req.validateColumnName(schema, name)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
groupBys[idx] = colName
|
|
||||||
}
|
|
||||||
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
|
||||||
|
|
||||||
// if there are no explicitly selected fields we default to the
|
|
||||||
// group-by columns as that's what's expected most of the time anyway...
|
|
||||||
if len(req.selectedFields) == 0 {
|
|
||||||
req.selectedFields = append(req.selectedFields, groupBys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groupByClause, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) generateSelectClause() string {
|
|
||||||
selectClause := "*"
|
|
||||||
if len(req.selectedFields) > 0 {
|
|
||||||
selectClause = strings.Join(req.selectedFields, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
return selectClause
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
|
|
||||||
if len(req.OrderBy) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
orderBys := make([]string, len(req.OrderBy))
|
|
||||||
for idx, sort := range req.OrderBy {
|
|
||||||
colName, err := req.validateColumnName(schema, sort.Field)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if sort.Desc {
|
|
||||||
orderBys[idx] = fmt.Sprintf("%s DESC", colName)
|
|
||||||
} else {
|
|
||||||
orderBys[idx] = fmt.Sprintf("%s ASC", colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "ORDER BY " + strings.Join(orderBys, ", "), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
|
|
||||||
colDef := schema.GetColumnDef(field)
|
|
||||||
if colDef != nil {
|
|
||||||
return colDef.Name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if slices.Contains(req.whitelistedFields, field) {
|
|
||||||
return field, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if slices.Contains(req.selectedFields, field) {
|
|
||||||
return field, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("column name %q not allowed", field)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile time check.
|
// Compile time check.
|
||||||
var _ http.Handler = new(QueryHandler)
|
var _ http.Handler = new(QueryHandler)
|
||||||
|
|
304
netquery/query_request.go
Normal file
304
netquery/query_request.go
Normal file
|
@ -0,0 +1,304 @@
|
||||||
|
package netquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/safing/portmaster/netquery/orm"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
QueryRequestPayload struct {
|
||||||
|
Select Selects `json:"select"`
|
||||||
|
Query Query `json:"query"`
|
||||||
|
OrderBy OrderBys `json:"orderBy"`
|
||||||
|
GroupBy []string `json:"groupBy"`
|
||||||
|
TextSearch *TextSearch `json:"textSearch"`
|
||||||
|
// A list of databases to query. If left empty,
|
||||||
|
// both, the LiveDatabase and the HistoryDatabase are queried
|
||||||
|
Databases []DatabaseName `json:"databases"`
|
||||||
|
|
||||||
|
Pagination
|
||||||
|
|
||||||
|
selectedFields []string
|
||||||
|
whitelistedFields []string
|
||||||
|
paramMap map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchQueryRequestPayload describes the payload of a batch netquery
|
||||||
|
// query. The map key is used in the response to identify the results
|
||||||
|
// for each query of the batch request.
|
||||||
|
BatchQueryRequestPayload map[string]QueryRequestPayload
|
||||||
|
)
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
|
||||||
|
if err := req.prepareSelectedFields(ctx, schema); err != nil {
|
||||||
|
return "", nil, fmt.Errorf("perparing selected fields: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the SQL where clause from the payload query
|
||||||
|
whereClause, paramMap, err := req.Query.toSQLWhereClause(
|
||||||
|
ctx,
|
||||||
|
"",
|
||||||
|
schema,
|
||||||
|
orm.DefaultEncodeConfig,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating where clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.mergeParams(paramMap)
|
||||||
|
|
||||||
|
if req.TextSearch != nil {
|
||||||
|
textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating text-search clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if textClause != "" {
|
||||||
|
if whereClause != "" {
|
||||||
|
whereClause += " AND "
|
||||||
|
}
|
||||||
|
|
||||||
|
whereClause += textClause
|
||||||
|
|
||||||
|
req.mergeParams(textParams)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupByClause, err := req.generateGroupByClause(schema)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating group-by clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
orderByClause, err := req.generateOrderByClause(schema)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("generating order-by clause: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
selectClause := req.generateSelectClause()
|
||||||
|
|
||||||
|
if whereClause != "" {
|
||||||
|
whereClause = "WHERE " + whereClause
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no database is specified we default to LiveDatabase only.
|
||||||
|
if len(req.Databases) == 0 {
|
||||||
|
req.Databases = []DatabaseName{LiveDatabase}
|
||||||
|
}
|
||||||
|
|
||||||
|
sources := make([]string, len(req.Databases))
|
||||||
|
for idx, db := range req.Databases {
|
||||||
|
sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause)
|
||||||
|
}
|
||||||
|
|
||||||
|
source := strings.Join(sources, " UNION ")
|
||||||
|
|
||||||
|
query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) `
|
||||||
|
|
||||||
|
query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause()
|
||||||
|
|
||||||
|
return strings.TrimSpace(query), req.paramMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
|
||||||
|
for idx, s := range req.Select {
|
||||||
|
var field string
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.Count != nil:
|
||||||
|
field = s.Count.Field
|
||||||
|
case s.Distinct != nil:
|
||||||
|
field = *s.Distinct
|
||||||
|
case s.Sum != nil:
|
||||||
|
if s.Sum.Field != "" {
|
||||||
|
field = s.Sum.Field
|
||||||
|
} else {
|
||||||
|
field = "*"
|
||||||
|
}
|
||||||
|
case s.Min != nil:
|
||||||
|
if s.Min.Field != "" {
|
||||||
|
field = s.Min.Field
|
||||||
|
} else {
|
||||||
|
field = "*"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
field = s.Field
|
||||||
|
}
|
||||||
|
|
||||||
|
colName := "*"
|
||||||
|
if field != "*" || (s.Count == nil && s.Sum == nil) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
colName, err = req.validateColumnName(schema, field)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.Count != nil:
|
||||||
|
as := s.Count.As
|
||||||
|
if as == "" {
|
||||||
|
as = fmt.Sprintf("%s_count", colName)
|
||||||
|
}
|
||||||
|
distinct := ""
|
||||||
|
if s.Count.Distinct {
|
||||||
|
distinct = "DISTINCT "
|
||||||
|
}
|
||||||
|
req.selectedFields = append(
|
||||||
|
req.selectedFields,
|
||||||
|
fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
|
||||||
|
)
|
||||||
|
req.whitelistedFields = append(req.whitelistedFields, as)
|
||||||
|
|
||||||
|
case s.Sum != nil:
|
||||||
|
if s.Sum.As == "" {
|
||||||
|
return fmt.Errorf("missing 'as' for $sum")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
clause string
|
||||||
|
params map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
|
if s.Sum.Field != "" {
|
||||||
|
clause = s.Sum.Field
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("in $sum: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req.mergeParams(params)
|
||||||
|
req.selectedFields = append(
|
||||||
|
req.selectedFields,
|
||||||
|
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
|
||||||
|
)
|
||||||
|
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
|
||||||
|
|
||||||
|
case s.Min != nil:
|
||||||
|
if s.Min.As == "" {
|
||||||
|
return fmt.Errorf("missing 'as' for $min")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
clause string
|
||||||
|
params map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
|
if s.Min.Field != "" {
|
||||||
|
clause = field
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("in $min: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req.mergeParams(params)
|
||||||
|
req.selectedFields = append(
|
||||||
|
req.selectedFields,
|
||||||
|
fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As),
|
||||||
|
)
|
||||||
|
req.whitelistedFields = append(req.whitelistedFields, s.Min.As)
|
||||||
|
|
||||||
|
case s.Distinct != nil:
|
||||||
|
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
|
||||||
|
req.whitelistedFields = append(req.whitelistedFields, colName)
|
||||||
|
|
||||||
|
default:
|
||||||
|
req.selectedFields = append(req.selectedFields, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) mergeParams(params map[string]any) {
|
||||||
|
if req.paramMap == nil {
|
||||||
|
req.paramMap = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range params {
|
||||||
|
req.paramMap[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) {
|
||||||
|
if len(req.GroupBy) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groupBys := make([]string, len(req.GroupBy))
|
||||||
|
for idx, name := range req.GroupBy {
|
||||||
|
colName, err := req.validateColumnName(schema, name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
groupBys[idx] = colName
|
||||||
|
}
|
||||||
|
groupByClause := "GROUP BY " + strings.Join(groupBys, ", ")
|
||||||
|
|
||||||
|
// if there are no explicitly selected fields we default to the
|
||||||
|
// group-by columns as that's what's expected most of the time anyway...
|
||||||
|
if len(req.selectedFields) == 0 {
|
||||||
|
req.selectedFields = append(req.selectedFields, groupBys...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupByClause, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateSelectClause() string {
|
||||||
|
selectClause := "*"
|
||||||
|
if len(req.selectedFields) > 0 {
|
||||||
|
selectClause = strings.Join(req.selectedFields, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectClause
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
|
||||||
|
if len(req.OrderBy) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
orderBys := make([]string, len(req.OrderBy))
|
||||||
|
for idx, sort := range req.OrderBy {
|
||||||
|
colName, err := req.validateColumnName(schema, sort.Field)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if sort.Desc {
|
||||||
|
orderBys[idx] = fmt.Sprintf("%s DESC", colName)
|
||||||
|
} else {
|
||||||
|
orderBys[idx] = fmt.Sprintf("%s ASC", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "ORDER BY " + strings.Join(orderBys, ", "), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) {
|
||||||
|
colDef := schema.GetColumnDef(field)
|
||||||
|
if colDef != nil {
|
||||||
|
return colDef.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(req.whitelistedFields, field) {
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(req.selectedFields, field) {
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("column name %q not allowed", field)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue