Mirror of https://github.com/safing/portmaster (synced 2025-09-14 08:49:40 +00:00)

Commit 25aceaf103 (parent 0d2ec9df75)
Add query and chart support with multiple fixes to ORM package

11 changed files with 535 additions and 117 deletions
netquery/chart_handler.go (new file, 118 lines)
@@ -0,0 +1,118 @@
package netquery

import (
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "io/ioutil"
    "net/http"
    "strings"

    "github.com/safing/portmaster/netquery/orm"
)

type ChartHandler struct {
    Database *Database
}

func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
    requestPayload, err := ch.parseRequest(req)
    if err != nil {
        http.Error(resp, err.Error(), http.StatusBadRequest)

        return
    }

    query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
    if err != nil {
        http.Error(resp, err.Error(), http.StatusBadRequest)

        return
    }

    // actually execute the query against the database and collect the result
    var result []map[string]interface{}
    if err := ch.Database.Execute(
        req.Context(),
        query,
        orm.WithNamedArgs(paramMap),
        orm.WithResult(&result),
        orm.WithSchema(*ch.Database.Schema),
    ); err != nil {
        http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)

        return
    }

    // send the HTTP status code
    resp.WriteHeader(http.StatusOK)

    // prepare the result encoder.
    enc := json.NewEncoder(resp)
    enc.SetEscapeHTML(false)
    enc.SetIndent("", " ")

    enc.Encode(map[string]interface{}{
        "results": result,
        "query":   query,
        "params":  paramMap,
    })
}

func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) {
    var body io.Reader

    switch req.Method {
    case http.MethodPost, http.MethodPut:
        body = req.Body
    case http.MethodGet:
        body = strings.NewReader(req.URL.Query().Get("q"))
    default:
        return nil, fmt.Errorf("invalid HTTP method")
    }

    var requestPayload QueryActiveConnectionChartPayload
    blob, err := ioutil.ReadAll(body)
    if err != nil {
        return nil, fmt.Errorf("failed to read body" + err.Error())
    }

    body = bytes.NewReader(blob)

    dec := json.NewDecoder(body)
    dec.DisallowUnknownFields()

    if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
        return nil, fmt.Errorf("invalid query: %w", err)
    }

    return &requestPayload, nil
}

func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
    template := `
WITH RECURSIVE epoch(x) AS (
    SELECT strftime('%%s')-600
    UNION ALL
    SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, COUNT(*) AS value FROM epoch
JOIN connections
    ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 > timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`

    clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
    if err != nil {
        return "", nil, err
    }

    if clause == "" {
        return fmt.Sprintf(template, ""), map[string]interface{}{}, nil
    }

    return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil
}
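For orientation, a minimal sketch of what generateSQL returns when the chart payload carries no query: the WHERE placeholder stays empty and only the recursive epoch template with its 10-second GROUP BY bucket is produced. This is a hypothetical in-package test (generateSQL is unexported), not part of the commit; the nil schema is safe only because the schema is not consulted for an empty query.

package netquery

import (
    "context"
    "testing"
)

// Hypothetical sketch: an empty chart payload yields the bare template,
// with no WHERE clause and no named parameters.
func TestChartSQLWithoutFilter(t *testing.T) {
    payload := &QueryActiveConnectionChartPayload{}

    // the schema is only consulted when a filter is present,
    // so nil is fine for the empty-payload case.
    sql, params, err := payload.generateSQL(context.Background(), nil)
    if err != nil {
        t.Fatal(err)
    }
    if len(params) != 0 {
        t.Fatalf("expected no named parameters, got %v", params)
    }

    t.Log(sql) // WITH RECURSIVE epoch(x) AS (...) ... GROUP BY round(timestamp/10, 0)*10;
}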
@@ -65,6 +65,7 @@ type (
    // reused afterwards.
    ID            string `sqlite:"id,primary"`
    ProfileID     string `sqlite:"profile"`
    ProfileSource string `sqlite:"profileSource"`
    Path          string `sqlite:"path"`
    Type          string `sqlite:"type,varchar(8)"`
    External      bool   `sqlite:"external"`

@@ -87,8 +88,10 @@ type (
    Tunneled        bool            `sqlite:"tunneled"`
    Encrypted       bool            `sqlite:"encrypted"`
    Internal        bool            `sqlite:"internal"`
    Inbound         bool            `sqlite:"inbound"`
+   Direction       string          `sqlite:"direction"`
    ExtraData       json.RawMessage `sqlite:"extra_data"`
    Allowed         *bool           `sqlite:"allowed"`
+   ProfileRevision int             `sqlite:"profile_revision"`
    }
)

@@ -190,11 +193,11 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
// probably not worth the cylces...
func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
    where := `WHERE ended IS NOT NULL
-       AND datetime(ended) < :threshold`
+       AND datetime(ended) < datetime(:threshold)`
    sql := "DELETE FROM connections " + where + ";"

    args := orm.WithNamedArgs(map[string]interface{}{
-       ":threshold": threshold,
+       ":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
    })

    var result []struct {

@@ -232,7 +235,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
    if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{
        ResultFunc: func(stmt *sqlite.Stmt) error {
            var c Conn
-           if err := orm.DecodeStmt(ctx, stmt, &c, orm.DefaultDecodeConfig); err != nil {
+           if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil {
                return err
            }
@@ -155,6 +155,10 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn C
func convertConnection(conn *network.Connection) (*Conn, error) {
    conn.Lock()
    defer conn.Unlock()

    direction := "outbound"
    if conn.Inbound {
        direction = "inbound"
    }

    c := Conn{
        ID: genConnID(conn),

@@ -168,10 +172,25 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
        Tunneled:        conn.Tunneled,
        Encrypted:       conn.Encrypted,
        Internal:        conn.Internal,
        Inbound:         conn.Inbound,
        Direction:       direction,
-       Type:            ConnectionTypeToString[conn.Type],
-       ProfileID:       conn.ProcessContext.ProfileName,
+       ProfileID:       conn.ProcessContext.Profile,
        ProfileSource:   conn.ProcessContext.Source,
        Path:            conn.ProcessContext.BinaryPath,
        ProfileRevision: int(conn.ProfileRevisionCounter),
    }

    switch conn.Type {
    case network.DNSRequest:
        c.Type = "dns"
    case network.IPConnection:
        c.Type = "ip"
    }

    switch conn.Verdict {
    case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
        accepted := true
        c.Allowed = &accepted
    }

    if conn.Ended > 0 {

@@ -181,6 +200,10 @@ func convertConnection(conn *network.Connection) (*Conn, error) {

    extraData := map[string]interface{}{}

    if conn.TunnelContext != nil {
        extraData["tunnel"] = conn.TunnelContext
    }

    if conn.Entity != nil {
        extraData["cname"] = conn.Entity.CNAME
        extraData["blockedByLists"] = conn.Entity.BlockedByLists
@@ -63,6 +63,10 @@ func (m *Module) Prepare() error {
        IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
    }

    chartHandler := &ChartHandler{
        Database: m.sqlStore,
    }

    // FIXME(ppacher): use appropriate permissions for this
    if err := api.RegisterEndpoint(api.Endpoint{
        Path: "netquery/query",

@@ -77,6 +81,19 @@ func (m *Module) Prepare() error {
        return fmt.Errorf("failed to register API endpoint: %w", err)
    }

    if err := api.RegisterEndpoint(api.Endpoint{
        Path:        "netquery/charts/connection-active",
        MimeType:    "application/json",
        Read:        api.PermitAnyone,
        Write:       api.PermitAnyone,
        BelongsTo:   m.Module,
        HandlerFunc: chartHandler.ServeHTTP,
        Name:        "Query In-Memory Database",
        Description: "Query the in-memory sqlite database",
    }); err != nil {
        return fmt.Errorf("failed to register API endpoint: %w", err)
    }

    return nil
}

@@ -120,11 +137,12 @@ func (mod *Module) Start() error {
        case <-ctx.Done():
            return nil
        case <-time.After(10 * time.Second):
-           count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
+           threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
+           count, err := mod.sqlStore.Cleanup(ctx, threshold)
            if err != nil {
                log.Errorf("netquery: failed to count number of rows in memory: %s", err)
            } else {
-               log.Infof("netquery: successfully removed %d old rows", count)
+               log.Infof("netquery: successfully removed %d old rows that ended before %s", count, threshold)
            }
        }
    }

@@ -135,7 +153,7 @@ func (mod *Module) Start() error {
        select {
        case <-ctx.Done():
            return nil
-       case <-time.After(5 * time.Second):
+       case <-time.After(1 * time.Second):
            count, err := mod.sqlStore.CountRows(ctx)
            if err != nil {
                log.Errorf("netquery: failed to count number of rows in memory: %s", err)
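A hedged usage sketch for the newly registered chart endpoint: the handler accepts a GET request with the JSON payload in the q query parameter (see parseRequest above). The local address 127.0.0.1:817 and the /api/v1 path prefix are assumptions about the Portmaster API and may differ on a given install; the column name "profile" comes from the Conn struct tags shown earlier.

package main

import (
    "fmt"
    "io"
    "net/http"
    "net/url"
)

func main() {
    // chart payload: only chart connections that belong to one profile;
    // the shape matches QueryActiveConnectionChartPayload ({"query": ...}).
    payload := `{"query": {"profile": [{"$eq": "my-profile"}]}}`

    u := url.URL{
        Scheme:   "http",
        Host:     "127.0.0.1:817",                             // assumed local API address
        Path:     "/api/v1/netquery/charts/connection-active", // assumed prefix + registered path
        RawQuery: url.Values{"q": []string{payload}}.Encode(),
    }

    resp, err := http.Get(u.String())
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    body, _ := io.ReadAll(resp.Body)
    fmt.Println(string(body)) // {"results": [...], "query": "...", "params": {...}}
}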
@@ -6,6 +6,7 @@ import (
    "errors"
    "fmt"
    "io"
    "log"
    "reflect"
    "strings"
    "time"

@@ -51,7 +52,7 @@ type (
    }

    // DecodeFunc is called for each non-basic type during decoding.
-   DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error)
+   DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)

    DecodeConfig struct {
        DecodeHooks []DecodeFunc

@@ -63,7 +64,7 @@ type (
// be specified to provide support for special types.
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
//
-func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error {
+func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
    // make sure we got something to decode into ...
    if result == nil {
        return fmt.Errorf("%w, got %T", errStructPointerExpected, result)

@@ -71,7 +72,7 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo

    // fast path for decoding into a map
    if mp, ok := result.(*map[string]interface{}); ok {
-       return decodeIntoMap(ctx, stmt, mp)
+       return decodeIntoMap(ctx, schema, stmt, mp, cfg)
    }

    // make sure we got a pointer in result

@@ -147,10 +148,13 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
            value = storage.Elem()
        }

        colDef := schema.GetColumnDef(colName)

        // execute all decode hooks but make sure we use decodeBasic() as the
        // last one.
        columnValue, err := runDecodeHooks(
            i,
            colDef,
            stmt,
            fieldType,
            value,
@@ -188,10 +192,19 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
//
func DatetimeDecoder(loc *time.Location) DecodeFunc {
-   return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
+   return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
        // if we have the column definition available we
        // use the target go type from there.
        outType := outval.Type()

        if colDef != nil {
            outType = colDef.GoType
        }

        // we only care about "time.Time" here
-       if outval.Type().String() != "time.Time" {
-           return nil, nil
+       if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
+           log.Printf("not decoding %s %v", outType, colDef)
+           return nil, false, nil
        }

        switch stmt.ColumnType(colIdx) {

@@ -201,39 +214,61 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
            // TODO(ppacher): actually split the tag value at "," and search
            // the slice for "unixnano"
            if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
-               return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil
+               return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
            }

-           return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil
+           return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil

        case sqlite.TypeText:
            // stored ISO8601 but does not have any timezone information
            // assigned so we always treat it as loc here.
            t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
            if err != nil {
-               return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
+               return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
            }

-           return t, nil
+           return t, true, nil

        case sqlite.TypeFloat:
            // stored as Julian day numbers
-           return nil, fmt.Errorf("REAL storage type not support for time.Time")
+           return nil, false, fmt.Errorf("REAL storage type not support for time.Time")

+       case sqlite.TypeNull:
+           return nil, true, nil

        default:
-           return nil, fmt.Errorf("unsupported storage type for time.Time: %s", outval.Type())
+           return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
        }
    }
}

-func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error {
+func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
    if *mp == nil {
        *mp = make(map[string]interface{})
    }

    for i := 0; i < stmt.ColumnCount(); i++ {
        var x interface{}
-       val, err := decodeBasic()(i, stmt, reflect.StructField{}, reflect.ValueOf(&x).Elem())

        colDef := schema.GetColumnDef(stmt.ColumnName(i))

        outVal := reflect.ValueOf(&x).Elem()
        fieldType := reflect.StructField{}
        if colDef != nil {
            outVal = reflect.New(colDef.GoType).Elem()
            fieldType = reflect.StructField{
                Type: colDef.GoType,
            }
        }

        val, err := runDecodeHooks(
            i,
            colDef,
            stmt,
            fieldType,
            outVal,
            append(cfg.DecodeHooks, decodeBasic()),
        )
        if err != nil {
            return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
        }
@@ -245,56 +280,99 @@ func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) e
}

func decodeBasic() DecodeFunc {
-   return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
+   return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
        valueKind := getKind(outval)
        colType := stmt.ColumnType(colIdx)
        colName := stmt.ColumnName(colIdx)

        errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())

        // if we have the column definition available we
        // use the target go type from there.
        if colDef != nil {
            valueKind = normalizeKind(colDef.GoType.Kind())

            // if we have a column defintion we try to convert the value to
            // the actual Go-type that was used in the model.
            // this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
            // or that type aliases like (type myInt int) are decoded into myInt instead of int
            defer func() {
                if handled {
                    t := reflect.New(colDef.GoType).Elem()

                    if result == nil || reflect.ValueOf(result).IsZero() {
                        return
                    }

                    if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
                        result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
                    }
                    t.Set(reflect.ValueOf(result))

                    result = t.Interface()
                }
            }()
        }

        log.Printf("decoding %s into kind %s", colName, valueKind)

        if colType == sqlite.TypeNull {
            if colDef != nil && colDef.Nullable {
                return nil, true, nil
            }

            if colDef != nil && !colDef.Nullable {
                return reflect.New(colDef.GoType).Elem().Interface(), true, nil
            }

            if outval.Kind() == reflect.Ptr {
                return nil, true, nil
            }
        }

        switch valueKind {
        case reflect.String:
            if colType != sqlite.TypeText {
-               return nil, errInvalidType
+               return nil, false, errInvalidType
            }
-           return stmt.ColumnText(colIdx), nil
+           return stmt.ColumnText(colIdx), true, nil

        case reflect.Bool:
            // sqlite does not have a BOOL type, it rather stores a 1/0 in a column
            // with INTEGER affinity.
            if colType != sqlite.TypeInteger {
-               return nil, errInvalidType
+               return nil, false, errInvalidType
            }
-           return stmt.ColumnBool(colIdx), nil
+           return stmt.ColumnBool(colIdx), true, nil

        case reflect.Float64:
            if colType != sqlite.TypeFloat {
-               return nil, errInvalidType
+               return nil, false, errInvalidType
            }
-           return stmt.ColumnFloat(colIdx), nil
+           return stmt.ColumnFloat(colIdx), true, nil

        case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
            if colType != sqlite.TypeInteger {
-               return nil, errInvalidType
+               return nil, false, errInvalidType
            }

-           return stmt.ColumnInt(colIdx), nil
+           return stmt.ColumnInt(colIdx), true, nil

        case reflect.Slice:
            if outval.Type().Elem().Kind() != reflect.Uint8 {
-               return nil, fmt.Errorf("slices other than []byte for BLOB are not supported")
+               return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
            }

            if colType != sqlite.TypeBlob {
-               return nil, errInvalidType
+               return nil, false, errInvalidType
            }

            columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
            if err != nil {
-               return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
+               return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
            }

-           return columnValue, nil
+           return columnValue, true, nil

        case reflect.Interface:
            var (

@@ -306,7 +384,7 @@ func decodeBasic() DecodeFunc {
                t = reflect.TypeOf([]byte{})
                columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
                if err != nil {
-                   return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
+                   return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
                }
                x = columnValue

@@ -327,20 +405,20 @@ func decodeBasic() DecodeFunc {
                x = nil

            default:
-               return nil, fmt.Errorf("unsupported column type %s", colType)
+               return nil, false, fmt.Errorf("unsupported column type %s", colType)
            }

            if t == nil {
-               return nil, nil
+               return nil, true, nil
            }

            target := reflect.New(t).Elem()
            target.Set(reflect.ValueOf(x))

-           return target.Interface(), nil
+           return target.Interface(), true, nil

        default:
-           return nil, fmt.Errorf("cannot decode into %s", valueKind)
+           return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
        }
    }
}

@@ -362,14 +440,14 @@ func sqlColumnName(fieldType reflect.StructField) string {
// runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
// The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
// returned by a decode hook runDecodeHooks stops the error is returned to the caller.
-func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
+func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
    for _, fn := range hooks {
-       res, err := fn(colIdx, stmt, fieldDef, outval)
+       res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval)
        if err != nil {
            return res, err
        }

-       if res != nil {
+       if end {
            return res, nil
        }
    }
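Because the DecodeFunc signature now carries the column definition and a handled flag, custom hooks have to return three values. Below is a hedged sketch of such a hook; it is not part of the commit. It parses TEXT columns into net.IP and otherwise defers to the remaining hooks by returning (nil, false, nil). The zombiezen.com/go/sqlite import path is an assumption about the sqlite package used by the orm package.

package main

import (
    "net"
    "reflect"

    "zombiezen.com/go/sqlite" // assumed sqlite driver used by the orm package

    "github.com/safing/portmaster/netquery/orm"
)

// ipDecoder is a hypothetical decode hook using the new signature: the second
// return value tells runDecodeHooks whether the column was handled, so
// (nil, false, nil) passes the column on to the next hook or to decodeBasic.
func ipDecoder(colIdx int, colDef *orm.ColumnDef, stmt orm.Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
    if outval.Type() != reflect.TypeOf(net.IP{}) {
        return nil, false, nil // not a net.IP field, let other hooks try
    }
    if stmt.ColumnType(colIdx) != sqlite.TypeText {
        return nil, false, nil
    }

    return net.ParseIP(stmt.ColumnText(colIdx)), true, nil
}

func main() {
    cfg := orm.DefaultDecodeConfig
    // run the custom hook before the default ones (for example DatetimeDecoder).
    cfg.DecodeHooks = append([]orm.DecodeFunc{ipDecoder}, cfg.DecodeHooks...)
    _ = cfg // hand cfg to orm.DecodeStmt(...) when decoding rows
}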
@@ -4,6 +4,8 @@ import (
    "bytes"
    "context"
    "encoding/json"
    "log"
    "reflect"
    "testing"
    "time"

@@ -82,6 +84,8 @@ func (ett *exampleTimeTypes) Equal(other interface{}) bool {
    return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
}

type myInt int

type exampleTimeNano struct {
    T time.Time `sqlite:",unixnano"`
}

@@ -102,6 +106,7 @@ func Test_Decoder(t *testing.T) {
    cases := []struct {
        Desc      string
        Stmt      testStmt
        ColumnDef []ColumnDef
        Result    interface{}
        Expected  interface{}
    }{

@@ -114,6 +119,7 @@ func Test_Decoder(t *testing.T) {
            },
            nil,
            nil,
            nil,
        },
        {
            "Decoding into basic types",

@@ -132,6 +138,7 @@ func Test_Decoder(t *testing.T) {
                    true,
                },
            },
            nil,
            &exampleFieldTypes{},
            &exampleFieldTypes{
                S: "string value",

@@ -157,6 +164,7 @@ func Test_Decoder(t *testing.T) {
                    1.2,
                },
            },
            nil,
            &exampleFieldTypes{},
            &exampleFieldTypes{
                S: "string value",

@@ -178,6 +186,7 @@ func Test_Decoder(t *testing.T) {
                    true,
                },
            },
            nil,
            &exampleFieldTypes{},
            &exampleFieldTypes{
                F: 1.2,

@@ -201,6 +210,7 @@ func Test_Decoder(t *testing.T) {
                    true,
                },
            },
            nil,
            &examplePointerTypes{},
            func() interface{} {
                s := "string value"

@@ -231,6 +241,7 @@ func Test_Decoder(t *testing.T) {
                    true,
                },
            },
            nil,
            &examplePointerTypes{},
            func() interface{} {
                s := "string value"

@@ -255,6 +266,7 @@ func Test_Decoder(t *testing.T) {
                    1,
                },
            },
            nil,
            &exampleStructTags{},
            &exampleStructTags{
                S: "string value",

@@ -280,6 +292,7 @@ func Test_Decoder(t *testing.T) {
                    1,
                },
            },
            nil,
            &exampleIntConv{},
            &exampleIntConv{
                1, 1, 1, 1, 1,

@@ -301,6 +314,7 @@ func Test_Decoder(t *testing.T) {
                    1.0,
                },
            },
            nil,
            &exampleFieldTypes{},
            &exampleFieldTypes{
                F: 1.0,

@@ -322,6 +336,7 @@ func Test_Decoder(t *testing.T) {
                    1.0,
                },
            },
            nil,
            &examplePointerTypes{},
            func() interface{} {
                f := 1.0

@@ -340,6 +355,7 @@ func Test_Decoder(t *testing.T) {
                    ([]byte)("hello world"),
                },
            },
            nil,
            &exampleBlobTypes{},
            &exampleBlobTypes{
                B: ([]byte)("hello world"),

@@ -356,6 +372,7 @@ func Test_Decoder(t *testing.T) {
                    ([]byte)("hello world"),
                },
            },
            nil,
            &exampleJSONRawTypes{},
            &exampleJSONRawTypes{
                B: (json.RawMessage)("hello world"),

@@ -374,6 +391,7 @@ func Test_Decoder(t *testing.T) {
                    int(refTime.Unix()),
                },
            },
            nil,
            &exampleTimeTypes{},
            &exampleTimeTypes{
                T: refTime,

@@ -393,6 +411,7 @@ func Test_Decoder(t *testing.T) {
                    int(refTime.UnixNano()),
                },
            },
            nil,
            &exampleTimeNano{},
            &exampleTimeNano{
                T: refTime,

@@ -411,6 +430,7 @@ func Test_Decoder(t *testing.T) {
                    "value2",
                },
            },
            nil,
            &exampleInterface{},
            func() interface{} {
                var x interface{}

@@ -439,6 +459,7 @@ func Test_Decoder(t *testing.T) {
                    []byte("blob value"),
                },
            },
            nil,
            new(map[string]interface{}),
            &map[string]interface{}{
                "I": 1,

@@ -447,14 +468,91 @@ func Test_Decoder(t *testing.T) {
                "B": []byte("blob value"),
            },
        },
        {
            "Decoding using type-hints",
            testStmt{
                columns: []string{"B", "T"},
                types: []sqlite.ColumnType{
                    sqlite.TypeInteger,
                    sqlite.TypeText,
                },
                values: []interface{}{
                    true,
                    refTime.Format(SqliteTimeFormat),
                },
            },
            []ColumnDef{
                {
                    Name:   "B",
                    Type:   sqlite.TypeInteger,
                    GoType: reflect.TypeOf(true),
                },
                {
                    Name:   "T",
                    Type:   sqlite.TypeText,
                    GoType: reflect.TypeOf(time.Time{}),
                    IsTime: true,
                },
            },
            new(map[string]interface{}),
            &map[string]interface{}{
                "B": true,
                "T": refTime,
            },
        },
        {
            "Decoding into type aliases",
            testStmt{
                columns: []string{"B"},
                types: []sqlite.ColumnType{
                    sqlite.TypeBlob,
                },
                values: []interface{}{
                    []byte(`{"foo": "bar}`),
                },
            },
            []ColumnDef{
                {
                    Name:   "B",
                    Type:   sqlite.TypeBlob,
                    GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
                },
            },
            new(map[string]interface{}),
            &map[string]interface{}{
                "B": json.RawMessage(`{"foo": "bar}`),
            },
        },
        {
            "Decoding into type aliases #2",
            testStmt{
                columns: []string{"I"},
                types:   []sqlite.ColumnType{sqlite.TypeInteger},
                values: []interface{}{
                    10,
                },
            },
            []ColumnDef{
                {
                    Name:   "I",
                    Type:   sqlite.TypeInteger,
                    GoType: reflect.TypeOf(myInt(0)),
                },
            },
            new(map[string]interface{}),
            &map[string]interface{}{
                "I": myInt(10),
            },
        },
    }

    for idx := range cases {
        c := cases[idx]
        t.Run(c.Desc, func(t *testing.T) {
-           t.Parallel()
+           //t.Parallel()

-           err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig)
+           log.Println(c.Desc)
+           err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
            if fn, ok := c.Expected.(func() interface{}); ok {
                c.Expected = fn()
            }
@@ -20,6 +20,7 @@ type (
        NamedArgs    map[string]interface{}
        Result       interface{}
        DecodeConfig DecodeConfig
        Schema       TableSchema
    }
)

@@ -56,6 +57,12 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
    }
}

func WithSchema(tbl TableSchema) QueryOption {
    return func(opts *queryOpts) {
        opts.Schema = tbl
    }
}

// WithResult sets the result receiver. result is expected to
// be a pointer to a slice of struct or map types.
//

@@ -136,7 +143,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q

    currentField = reflect.New(valElemType)

-   if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
+   if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
        return err
    }
@@ -45,6 +45,7 @@ type (
        Name          string
        Nullable      bool
        Type          sqlite.ColumnType
        GoType        reflect.Type
        Length        int
        PrimaryKey    bool
        AutoIncrement bool

@@ -145,6 +146,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
        ft = fieldType.Type.Elem()
    }

    def.GoType = ft
    kind := normalizeKind(ft.Kind())

    switch kind {
@@ -16,6 +16,12 @@ import (
type (
    Query map[string][]Matcher

    MatchType interface {
        Operator() string
    }

    Equal interface{}

    Matcher struct {
        Equal    interface{} `json:"$eq,omitempty"`
        NotEqual interface{} `json:"$ne,omitempty"`

@@ -27,12 +33,22 @@ type (
    Count struct {
        As       string `json:"as"`
        Field    string `json:"field"`
-       Distinct bool   `json:"distict"`
+       Distinct bool   `json:"distinct"`
    }

    Sum struct {
        Condition Query  `json:"condition"`
        As        string `json:"as"`
        Distinct  bool   `json:"distinct"`
    }

    // NOTE: whenever adding support for new operators make sure
    // to update UnmarshalJSON as well.
    Select struct {
        Field    string  `json:"field"`
-       Count    *Count  `json:"$count"`
+       Count    *Count  `json:"$count,omitempty"`
+       Sum      *Sum    `json:"$sum,omitempty"`
        Distinct *string `json:"$distinct"`
    }

    Selects []Select

@@ -45,6 +61,11 @@ type (

        selectedFields    []string
        whitelistedFields []string
        paramMap          map[string]interface{}
    }

    QueryActiveConnectionChartPayload struct {
        Query Query `json:"query"`
    }

    OrderBy struct {

@@ -179,15 +200,15 @@ func (match Matcher) Validate() error {
    return nil
}

-func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
    var (
        queryParts []string
        params     = make(map[string]interface{})
        errs       = new(multierror.Error)
-       key        = fmt.Sprintf("%s%d", colDef.Name, idx)
+       key        = fmt.Sprintf("%s%s", colDef.Name, suffix)
    )

-   add := func(operator, suffix string, values ...interface{}) {
+   add := func(operator, suffix string, list bool, values ...interface{}) {
        var placeholder []string

        for idx, value := range values {

@@ -204,7 +225,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
            params[uniqKey] = encodedValue
        }

-       if len(placeholder) == 1 {
+       if len(placeholder) == 1 && !list {
            queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
        } else {
            queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))

@@ -212,23 +233,23 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
    }

    if match.Equal != nil {
-       add("=", "eq", match.Equal)
+       add("=", "eq", false, match.Equal)
    }

    if match.NotEqual != nil {
-       add("!=", "ne", match.NotEqual)
+       add("!=", "ne", false, match.NotEqual)
    }

    if match.In != nil {
-       add("IN", "in", match.In...)
+       add("IN", "in", true, match.In...)
    }

    if match.NotIn != nil {
-       add("NOT IN", "notin", match.NotIn...)
+       add("NOT IN", "notin", true, match.NotIn...)
    }

    if match.Like != "" {
-       add("LIKE", "like", match.Like)
+       add("LIKE", "like", false, match.Like)
    }

    if len(queryParts) == 0 {

@@ -244,7 +265,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
    return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
}

-func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
    if len(query) == 0 {
        return "", nil, nil
    }

@@ -279,7 +300,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, enc

    queryParts := make([]string, len(values))
    for idx, val := range values {
-       matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig)
+       matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
        if err != nil {
            errs.Errors = append(errs.Errors,
                fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),

@@ -361,6 +382,8 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
    var res struct {
        Field    string  `json:"field"`
        Count    *Count  `json:"$count"`
        Sum      *Sum    `json:"$sum"`
        Distinct *string `json:"$distinct"`
    }

    if err := json.Unmarshal(blob, &res); err != nil {

@@ -369,6 +392,8 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {

    sel.Count = res.Count
    sel.Field = res.Field
    sel.Distinct = res.Distinct
    sel.Sum = res.Sum

    if sel.Count != nil && sel.Count.As != "" {
        if !charOnlyRegexp.MatchString(sel.Count.As) {
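A hedged sketch of what the new $sum select accepts on the wire, based on the struct tags above: the condition is an ordinary Query, and the aggregated value is exposed under the name given in "as". The column name "allowed" and the surrounding program are illustrative only, not taken from the commit.

package main

import (
    "encoding/json"
    "fmt"

    "github.com/safing/portmaster/netquery"
)

func main() {
    // effectively "count the rows where allowed = false", because SQLite
    // evaluates the embedded condition to 1/0 inside SUM(...).
    blob := []byte(`{
        "$sum": {
            "condition": { "allowed": [{ "$eq": false }] },
            "as": "countBlocked"
        }
    }`)

    var sel netquery.Select
    if err := json.Unmarshal(blob, &sel); err != nil {
        panic(err)
    }

    fmt.Println(sel.Sum.As)        // countBlocked
    fmt.Println(sel.Sum.Condition) // the parsed Query for the condition
}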
@@ -58,6 +58,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
        query,
        orm.WithNamedArgs(paramMap),
        orm.WithResult(&result),
        orm.WithSchema(*qh.Database.Schema),
    ); err != nil {
        http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)

@@ -139,13 +140,14 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e
}

func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
-   if err := req.prepareSelectedFields(schema); err != nil {
+   if err := req.prepareSelectedFields(ctx, schema); err != nil {
        return "", nil, fmt.Errorf("perparing selected fields: %w", err)
    }

    // build the SQL where clause from the payload query
    whereClause, paramMap, err := req.Query.toSQLWhereClause(
        ctx,
        "",
        schema,
        orm.DefaultEncodeConfig,
    )

@@ -153,6 +155,14 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
        return "", nil, fmt.Errorf("ganerating where clause: %w", err)
    }

    if req.paramMap == nil {
        req.paramMap = make(map[string]interface{})
    }

    for key, val := range paramMap {
        req.paramMap[key] = val
    }

    // build the actual SQL query statement
    // FIXME(ppacher): add support for group-by and sort-by

@@ -173,20 +183,26 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
    }
    query += " " + groupByClause + " " + orderByClause

-   return query, paramMap, nil
+   return query, req.paramMap, nil
}

-func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error {
-   for _, s := range req.Select {
+func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
+   for idx, s := range req.Select {
        var field string

-       if s.Count != nil {
+       switch {
+       case s.Count != nil:
            field = s.Count.Field
-       } else {
+       case s.Distinct != nil:
+           field = *s.Distinct
+       case s.Sum != nil:
+           // field is not used in case of $sum
+           field = "*"
+       default:
            field = s.Field
        }

        colName := "*"
-       if field != "*" || s.Count == nil {
+       if field != "*" || (s.Count == nil && s.Sum == nil) {
            var err error

            colName, err = req.validateColumnName(schema, field)

@@ -195,7 +211,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
            }
        }

-       if s.Count != nil {
+       switch {
+       case s.Count != nil:
            var as = s.Count.As
            if as == "" {
                as = fmt.Sprintf("%s_count", colName)

@@ -204,9 +221,34 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
            if s.Count.Distinct {
                distinct = "DISTINCT "
            }
-           req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as))
+           req.selectedFields = append(
+               req.selectedFields,
+               fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
+           )
            req.whitelistedFields = append(req.whitelistedFields, as)
-       } else {

+       case s.Sum != nil:
+           if s.Sum.As == "" {
+               return fmt.Errorf("missing 'as' for $sum")
+           }
+
+           clause, params, err := s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
+           if err != nil {
+               return fmt.Errorf("in $sum: %w", err)
+           }
+
+           req.paramMap = params
+           req.selectedFields = append(
+               req.selectedFields,
+               fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
+           )
+           req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
+
+       case s.Distinct != nil:
+           req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
+           req.whitelistedFields = append(req.whitelistedFields, colName)
+
+       default:
            req.selectedFields = append(req.selectedFields, colName)
        }
    }

@@ -251,6 +293,10 @@ func (req *QueryRequestPayload) generateSelectClause() string {
}

func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
    if len(req.OrderBy) == 0 {
        return "", nil
    }

    var orderBys = make([]string, len(req.OrderBy))
    for idx, sort := range req.OrderBy {
        colName, err := req.validateColumnName(schema, sort.Field)

@@ -286,7 +332,7 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
        }
    }

-   return "", fmt.Errorf("column name %s not allowed", field)
+   return "", fmt.Errorf("column name %q not allowed", field)
}

// compile time check
@@ -228,7 +228,7 @@ func Test_QueryBuilder(t *testing.T) {
    for idx, c := range cases {
        t.Run(c.N, func(t *testing.T) {
            //t.Parallel()
-           str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig)
+           str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)

            if c.E != nil {
                if assert.Error(t, err) {