From ba72c204d3ea9f53b4793b10149b9a18abd6a369 Mon Sep 17 00:00:00 2001 From: Patrick Pacher Date: Thu, 14 Sep 2023 08:39:15 +0200 Subject: [PATCH] netquery: split up query payload into a dedicated file --- netquery/query.go | 17 --- netquery/query_handler.go | 277 +--------------------------------- netquery/query_request.go | 304 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 306 insertions(+), 292 deletions(-) create mode 100644 netquery/query_request.go diff --git a/netquery/query.go b/netquery/query.go index edbf3309..a4502849 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -86,23 +86,6 @@ type ( Value string `json:"value"` } - QueryRequestPayload struct { - Select Selects `json:"select"` - Query Query `json:"query"` - OrderBy OrderBys `json:"orderBy"` - GroupBy []string `json:"groupBy"` - TextSearch *TextSearch `json:"textSearch"` - // A list of databases to query. If left empty, - // both, the LiveDatabase and the HistoryDatabase are queried - Databases []DatabaseName `json:"databases"` - - Pagination - - selectedFields []string - whitelistedFields []string - paramMap map[string]interface{} - } - QueryActiveConnectionChartPayload struct { Query Query `json:"query"` TextSearch *TextSearch `json:"textSearch"` diff --git a/netquery/query_handler.go b/netquery/query_handler.go index bca2eac3..bf2933d5 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -2,7 +2,6 @@ package netquery import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -12,8 +11,6 @@ import ( "strings" "time" - "golang.org/x/exp/slices" - "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" ) @@ -32,7 +29,7 @@ type ( func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { start := time.Now() - requestPayload, err := qh.parseRequest(req) + requestPayload, err := parseQueryRequestPayload(req) if err != nil { http.Error(resp, err.Error(), http.StatusBadRequest) @@ -108,7 +105,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) { } } -func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl +func parseQueryRequestPayload(req *http.Request) (*QueryRequestPayload, error) { //nolint:dupl var body io.Reader switch req.Method { @@ -138,275 +135,5 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e return &requestPayload, nil } -func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { - if err := req.prepareSelectedFields(ctx, schema); err != nil { - return "", nil, fmt.Errorf("perparing selected fields: %w", err) - } - - // build the SQL where clause from the payload query - whereClause, paramMap, err := req.Query.toSQLWhereClause( - ctx, - "", - schema, - orm.DefaultEncodeConfig, - ) - if err != nil { - return "", nil, fmt.Errorf("generating where clause: %w", err) - } - - req.mergeParams(paramMap) - - if req.TextSearch != nil { - textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig) - if err != nil { - return "", nil, fmt.Errorf("generating text-search clause: %w", err) - } - - if textClause != "" { - if whereClause != "" { - whereClause += " AND " - } - - whereClause += textClause - - req.mergeParams(textParams) - } - } - - groupByClause, err := req.generateGroupByClause(schema) - if err != nil { - return "", nil, fmt.Errorf("generating group-by clause: %w", err) - } - - orderByClause, err := req.generateOrderByClause(schema) - if err != nil { - return "", nil, fmt.Errorf("generating order-by clause: %w", err) - } - - selectClause := req.generateSelectClause() - - if whereClause != "" { - whereClause = "WHERE " + whereClause - } - - // if no database is specified we default to LiveDatabase only. - if len(req.Databases) == 0 { - req.Databases = []DatabaseName{LiveDatabase} - } - - sources := make([]string, len(req.Databases)) - for idx, db := range req.Databases { - sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause) - } - - source := strings.Join(sources, " UNION ") - - query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) ` - - query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause() - - return strings.TrimSpace(query), req.paramMap, nil -} - -func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { - for idx, s := range req.Select { - var field string - - switch { - case s.Count != nil: - field = s.Count.Field - case s.Distinct != nil: - field = *s.Distinct - case s.Sum != nil: - if s.Sum.Field != "" { - field = s.Sum.Field - } else { - field = "*" - } - case s.Min != nil: - if s.Min.Field != "" { - field = s.Min.Field - } else { - field = "*" - } - default: - field = s.Field - } - - colName := "*" - if field != "*" || (s.Count == nil && s.Sum == nil) { - var err error - - colName, err = req.validateColumnName(schema, field) - if err != nil { - return err - } - } - - switch { - case s.Count != nil: - as := s.Count.As - if as == "" { - as = fmt.Sprintf("%s_count", colName) - } - distinct := "" - if s.Count.Distinct { - distinct = "DISTINCT " - } - req.selectedFields = append( - req.selectedFields, - fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as), - ) - req.whitelistedFields = append(req.whitelistedFields, as) - - case s.Sum != nil: - if s.Sum.As == "" { - return fmt.Errorf("missing 'as' for $sum") - } - - var ( - clause string - params map[string]any - ) - - if s.Sum.Field != "" { - clause = s.Sum.Field - } else { - var err error - clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) - if err != nil { - return fmt.Errorf("in $sum: %w", err) - } - } - - req.mergeParams(params) - req.selectedFields = append( - req.selectedFields, - fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), - ) - req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) - - case s.Min != nil: - if s.Min.As == "" { - return fmt.Errorf("missing 'as' for $min") - } - - var ( - clause string - params map[string]any - ) - - if s.Min.Field != "" { - clause = field - } else { - var err error - clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) - if err != nil { - return fmt.Errorf("in $min: %w", err) - } - } - - req.mergeParams(params) - req.selectedFields = append( - req.selectedFields, - fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As), - ) - req.whitelistedFields = append(req.whitelistedFields, s.Min.As) - - case s.Distinct != nil: - req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) - req.whitelistedFields = append(req.whitelistedFields, colName) - - default: - req.selectedFields = append(req.selectedFields, colName) - } - } - - return nil -} - -func (req *QueryRequestPayload) mergeParams(params map[string]any) { - if req.paramMap == nil { - req.paramMap = make(map[string]any) - } - - for key, value := range params { - req.paramMap[key] = value - } -} - -func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { - if len(req.GroupBy) == 0 { - return "", nil - } - - groupBys := make([]string, len(req.GroupBy)) - for idx, name := range req.GroupBy { - colName, err := req.validateColumnName(schema, name) - if err != nil { - return "", err - } - - groupBys[idx] = colName - } - groupByClause := "GROUP BY " + strings.Join(groupBys, ", ") - - // if there are no explicitly selected fields we default to the - // group-by columns as that's what's expected most of the time anyway... - if len(req.selectedFields) == 0 { - req.selectedFields = append(req.selectedFields, groupBys...) - } - - return groupByClause, nil -} - -func (req *QueryRequestPayload) generateSelectClause() string { - selectClause := "*" - if len(req.selectedFields) > 0 { - selectClause = strings.Join(req.selectedFields, ", ") - } - - return selectClause -} - -func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) { - if len(req.OrderBy) == 0 { - return "", nil - } - - orderBys := make([]string, len(req.OrderBy)) - for idx, sort := range req.OrderBy { - colName, err := req.validateColumnName(schema, sort.Field) - if err != nil { - return "", err - } - - if sort.Desc { - orderBys[idx] = fmt.Sprintf("%s DESC", colName) - } else { - orderBys[idx] = fmt.Sprintf("%s ASC", colName) - } - } - - return "ORDER BY " + strings.Join(orderBys, ", "), nil -} - -func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) { - colDef := schema.GetColumnDef(field) - if colDef != nil { - return colDef.Name, nil - } - - if slices.Contains(req.whitelistedFields, field) { - return field, nil - } - - if slices.Contains(req.selectedFields, field) { - return field, nil - } - - return "", fmt.Errorf("column name %q not allowed", field) -} - // Compile time check. var _ http.Handler = new(QueryHandler) diff --git a/netquery/query_request.go b/netquery/query_request.go new file mode 100644 index 00000000..e294d6ea --- /dev/null +++ b/netquery/query_request.go @@ -0,0 +1,304 @@ +package netquery + +import ( + "context" + "fmt" + "strings" + + "github.com/safing/portmaster/netquery/orm" + "golang.org/x/exp/slices" +) + +type ( + QueryRequestPayload struct { + Select Selects `json:"select"` + Query Query `json:"query"` + OrderBy OrderBys `json:"orderBy"` + GroupBy []string `json:"groupBy"` + TextSearch *TextSearch `json:"textSearch"` + // A list of databases to query. If left empty, + // both, the LiveDatabase and the HistoryDatabase are queried + Databases []DatabaseName `json:"databases"` + + Pagination + + selectedFields []string + whitelistedFields []string + paramMap map[string]interface{} + } + + // BatchQueryRequestPayload describes the payload of a batch netquery + // query. The map key is used in the response to identify the results + // for each query of the batch request. + BatchQueryRequestPayload map[string]QueryRequestPayload +) + +func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { + if err := req.prepareSelectedFields(ctx, schema); err != nil { + return "", nil, fmt.Errorf("perparing selected fields: %w", err) + } + + // build the SQL where clause from the payload query + whereClause, paramMap, err := req.Query.toSQLWhereClause( + ctx, + "", + schema, + orm.DefaultEncodeConfig, + ) + if err != nil { + return "", nil, fmt.Errorf("generating where clause: %w", err) + } + + req.mergeParams(paramMap) + + if req.TextSearch != nil { + textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig) + if err != nil { + return "", nil, fmt.Errorf("generating text-search clause: %w", err) + } + + if textClause != "" { + if whereClause != "" { + whereClause += " AND " + } + + whereClause += textClause + + req.mergeParams(textParams) + } + } + + groupByClause, err := req.generateGroupByClause(schema) + if err != nil { + return "", nil, fmt.Errorf("generating group-by clause: %w", err) + } + + orderByClause, err := req.generateOrderByClause(schema) + if err != nil { + return "", nil, fmt.Errorf("generating order-by clause: %w", err) + } + + selectClause := req.generateSelectClause() + + if whereClause != "" { + whereClause = "WHERE " + whereClause + } + + // if no database is specified we default to LiveDatabase only. + if len(req.Databases) == 0 { + req.Databases = []DatabaseName{LiveDatabase} + } + + sources := make([]string, len(req.Databases)) + for idx, db := range req.Databases { + sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause) + } + + source := strings.Join(sources, " UNION ") + + query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) ` + + query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause() + + return strings.TrimSpace(query), req.paramMap, nil +} + +func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { + for idx, s := range req.Select { + var field string + + switch { + case s.Count != nil: + field = s.Count.Field + case s.Distinct != nil: + field = *s.Distinct + case s.Sum != nil: + if s.Sum.Field != "" { + field = s.Sum.Field + } else { + field = "*" + } + case s.Min != nil: + if s.Min.Field != "" { + field = s.Min.Field + } else { + field = "*" + } + default: + field = s.Field + } + + colName := "*" + if field != "*" || (s.Count == nil && s.Sum == nil) { + var err error + + colName, err = req.validateColumnName(schema, field) + if err != nil { + return err + } + } + + switch { + case s.Count != nil: + as := s.Count.As + if as == "" { + as = fmt.Sprintf("%s_count", colName) + } + distinct := "" + if s.Count.Distinct { + distinct = "DISTINCT " + } + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as), + ) + req.whitelistedFields = append(req.whitelistedFields, as) + + case s.Sum != nil: + if s.Sum.As == "" { + return fmt.Errorf("missing 'as' for $sum") + } + + var ( + clause string + params map[string]any + ) + + if s.Sum.Field != "" { + clause = s.Sum.Field + } else { + var err error + clause, params, err = s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) + if err != nil { + return fmt.Errorf("in $sum: %w", err) + } + } + + req.mergeParams(params) + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), + ) + req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) + + case s.Min != nil: + if s.Min.As == "" { + return fmt.Errorf("missing 'as' for $min") + } + + var ( + clause string + params map[string]any + ) + + if s.Min.Field != "" { + clause = field + } else { + var err error + clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) + if err != nil { + return fmt.Errorf("in $min: %w", err) + } + } + + req.mergeParams(params) + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As), + ) + req.whitelistedFields = append(req.whitelistedFields, s.Min.As) + + case s.Distinct != nil: + req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) + req.whitelistedFields = append(req.whitelistedFields, colName) + + default: + req.selectedFields = append(req.selectedFields, colName) + } + } + + return nil +} + +func (req *QueryRequestPayload) mergeParams(params map[string]any) { + if req.paramMap == nil { + req.paramMap = make(map[string]any) + } + + for key, value := range params { + req.paramMap[key] = value + } +} + +func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { + if len(req.GroupBy) == 0 { + return "", nil + } + + groupBys := make([]string, len(req.GroupBy)) + for idx, name := range req.GroupBy { + colName, err := req.validateColumnName(schema, name) + if err != nil { + return "", err + } + + groupBys[idx] = colName + } + groupByClause := "GROUP BY " + strings.Join(groupBys, ", ") + + // if there are no explicitly selected fields we default to the + // group-by columns as that's what's expected most of the time anyway... + if len(req.selectedFields) == 0 { + req.selectedFields = append(req.selectedFields, groupBys...) + } + + return groupByClause, nil +} + +func (req *QueryRequestPayload) generateSelectClause() string { + selectClause := "*" + if len(req.selectedFields) > 0 { + selectClause = strings.Join(req.selectedFields, ", ") + } + + return selectClause +} + +func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) { + if len(req.OrderBy) == 0 { + return "", nil + } + + orderBys := make([]string, len(req.OrderBy)) + for idx, sort := range req.OrderBy { + colName, err := req.validateColumnName(schema, sort.Field) + if err != nil { + return "", err + } + + if sort.Desc { + orderBys[idx] = fmt.Sprintf("%s DESC", colName) + } else { + orderBys[idx] = fmt.Sprintf("%s ASC", colName) + } + } + + return "ORDER BY " + strings.Join(orderBys, ", "), nil +} + +func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, field string) (string, error) { + colDef := schema.GetColumnDef(field) + if colDef != nil { + return colDef.Name, nil + } + + if slices.Contains(req.whitelistedFields, field) { + return field, nil + } + + if slices.Contains(req.selectedFields, field) { + return field, nil + } + + return "", fmt.Errorf("column name %q not allowed", field) +}