Add query and chart support with multiple fixes to ORM package

This commit is contained in:
Patrick Pacher 2022-05-03 16:11:12 +02:00
parent 0d2ec9df75
commit 25aceaf103
No known key found for this signature in database
GPG key ID: E8CD2DA160925A6D
11 changed files with 535 additions and 117 deletions

118
netquery/chart_handler.go Normal file
View file

@ -0,0 +1,118 @@
package netquery
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/safing/portmaster/netquery/orm"
)
type ChartHandler struct {
Database *Database
}
func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
requestPayload, err := ch.parseRequest(req)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
if err != nil {
http.Error(resp, err.Error(), http.StatusBadRequest)
return
}
// actually execute the query against the database and collect the result
var result []map[string]interface{}
if err := ch.Database.Execute(
req.Context(),
query,
orm.WithNamedArgs(paramMap),
orm.WithResult(&result),
orm.WithSchema(*ch.Database.Schema),
); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
return
}
// send the HTTP status code
resp.WriteHeader(http.StatusOK)
// prepare the result encoder.
enc := json.NewEncoder(resp)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
enc.Encode(map[string]interface{}{
"results": result,
"query": query,
"params": paramMap,
})
}
func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) {
var body io.Reader
switch req.Method {
case http.MethodPost, http.MethodPut:
body = req.Body
case http.MethodGet:
body = strings.NewReader(req.URL.Query().Get("q"))
default:
return nil, fmt.Errorf("invalid HTTP method")
}
var requestPayload QueryActiveConnectionChartPayload
blob, err := ioutil.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("failed to read body" + err.Error())
}
body = bytes.NewReader(blob)
dec := json.NewDecoder(body)
dec.DisallowUnknownFields()
if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("invalid query: %w", err)
}
return &requestPayload, nil
}
func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
template := `
WITH RECURSIVE epoch(x) AS (
SELECT strftime('%%s')-600
UNION ALL
SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, COUNT(*) AS value FROM epoch
JOIN connections
ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 > timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`
clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
if err != nil {
return "", nil, err
}
if clause == "" {
return fmt.Sprintf(template, ""), map[string]interface{}{}, nil
}
return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil
}

View file

@ -63,32 +63,35 @@ type (
// time. We cannot just use the network.Connection.ID because it is only unique // time. We cannot just use the network.Connection.ID because it is only unique
// as long as the connection is still active and might be, although unlikely, // as long as the connection is still active and might be, although unlikely,
// reused afterwards. // reused afterwards.
ID string `sqlite:"id,primary"` ID string `sqlite:"id,primary"`
ProfileID string `sqlite:"profile"` ProfileID string `sqlite:"profile"`
Path string `sqlite:"path"` ProfileSource string `sqlite:"profileSource"`
Type string `sqlite:"type,varchar(8)"` Path string `sqlite:"path"`
External bool `sqlite:"external"` Type string `sqlite:"type,varchar(8)"`
IPVersion packet.IPVersion `sqlite:"ip_version"` External bool `sqlite:"external"`
IPProtocol packet.IPProtocol `sqlite:"ip_protocol"` IPVersion packet.IPVersion `sqlite:"ip_version"`
LocalIP string `sqlite:"local_ip"` IPProtocol packet.IPProtocol `sqlite:"ip_protocol"`
LocalPort uint16 `sqlite:"local_port"` LocalIP string `sqlite:"local_ip"`
RemoteIP string `sqlite:"remote_ip"` LocalPort uint16 `sqlite:"local_port"`
RemotePort uint16 `sqlite:"remote_port"` RemoteIP string `sqlite:"remote_ip"`
Domain string `sqlite:"domain"` RemotePort uint16 `sqlite:"remote_port"`
Country string `sqlite:"country,varchar(2)"` Domain string `sqlite:"domain"`
ASN uint `sqlite:"asn"` Country string `sqlite:"country,varchar(2)"`
ASOwner string `sqlite:"as_owner"` ASN uint `sqlite:"asn"`
Latitude float64 `sqlite:"latitude"` ASOwner string `sqlite:"as_owner"`
Longitude float64 `sqlite:"longitude"` Latitude float64 `sqlite:"latitude"`
Scope netutils.IPScope `sqlite:"scope"` Longitude float64 `sqlite:"longitude"`
Verdict network.Verdict `sqlite:"verdict"` Scope netutils.IPScope `sqlite:"scope"`
Started time.Time `sqlite:"started,text,time"` Verdict network.Verdict `sqlite:"verdict"`
Ended *time.Time `sqlite:"ended,text,time"` Started time.Time `sqlite:"started,text,time"`
Tunneled bool `sqlite:"tunneled"` Ended *time.Time `sqlite:"ended,text,time"`
Encrypted bool `sqlite:"encrypted"` Tunneled bool `sqlite:"tunneled"`
Internal bool `sqlite:"internal"` Encrypted bool `sqlite:"encrypted"`
Inbound bool `sqlite:"inbound"` Internal bool `sqlite:"internal"`
ExtraData json.RawMessage `sqlite:"extra_data"` Direction string `sqlite:"direction"`
ExtraData json.RawMessage `sqlite:"extra_data"`
Allowed *bool `sqlite:"allowed"`
ProfileRevision int `sqlite:"profile_revision"`
} }
) )
@ -190,11 +193,11 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
// probably not worth the cylces... // probably not worth the cylces...
func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) { func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
where := `WHERE ended IS NOT NULL where := `WHERE ended IS NOT NULL
AND datetime(ended) < :threshold` AND datetime(ended) < datetime(:threshold)`
sql := "DELETE FROM connections " + where + ";" sql := "DELETE FROM connections " + where + ";"
args := orm.WithNamedArgs(map[string]interface{}{ args := orm.WithNamedArgs(map[string]interface{}{
":threshold": threshold, ":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
}) })
var result []struct { var result []struct {
@ -232,7 +235,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{ if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{
ResultFunc: func(stmt *sqlite.Stmt) error { ResultFunc: func(stmt *sqlite.Stmt) error {
var c Conn var c Conn
if err := orm.DecodeStmt(ctx, stmt, &c, orm.DefaultDecodeConfig); err != nil { if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil {
return err return err
} }

View file

@ -155,23 +155,42 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn C
func convertConnection(conn *network.Connection) (*Conn, error) { func convertConnection(conn *network.Connection) (*Conn, error) {
conn.Lock() conn.Lock()
defer conn.Unlock() defer conn.Unlock()
direction := "outbound"
if conn.Inbound {
direction = "inbound"
}
c := Conn{ c := Conn{
ID: genConnID(conn), ID: genConnID(conn),
External: conn.External, External: conn.External,
IPVersion: conn.IPVersion, IPVersion: conn.IPVersion,
IPProtocol: conn.IPProtocol, IPProtocol: conn.IPProtocol,
LocalIP: conn.LocalIP.String(), LocalIP: conn.LocalIP.String(),
LocalPort: conn.LocalPort, LocalPort: conn.LocalPort,
Verdict: conn.Verdict, Verdict: conn.Verdict,
Started: time.Unix(conn.Started, 0), Started: time.Unix(conn.Started, 0),
Tunneled: conn.Tunneled, Tunneled: conn.Tunneled,
Encrypted: conn.Encrypted, Encrypted: conn.Encrypted,
Internal: conn.Internal, Internal: conn.Internal,
Inbound: conn.Inbound, Direction: direction,
Type: ConnectionTypeToString[conn.Type], Type: ConnectionTypeToString[conn.Type],
ProfileID: conn.ProcessContext.ProfileName, ProfileID: conn.ProcessContext.Profile,
Path: conn.ProcessContext.BinaryPath, ProfileSource: conn.ProcessContext.Source,
Path: conn.ProcessContext.BinaryPath,
ProfileRevision: int(conn.ProfileRevisionCounter),
}
switch conn.Type {
case network.DNSRequest:
c.Type = "dns"
case network.IPConnection:
c.Type = "ip"
}
switch conn.Verdict {
case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
accepted := true
c.Allowed = &accepted
} }
if conn.Ended > 0 { if conn.Ended > 0 {
@ -181,6 +200,10 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
extraData := map[string]interface{}{} extraData := map[string]interface{}{}
if conn.TunnelContext != nil {
extraData["tunnel"] = conn.TunnelContext
}
if conn.Entity != nil { if conn.Entity != nil {
extraData["cname"] = conn.Entity.CNAME extraData["cname"] = conn.Entity.CNAME
extraData["blockedByLists"] = conn.Entity.BlockedByLists extraData["blockedByLists"] = conn.Entity.BlockedByLists

View file

@ -63,6 +63,10 @@ func (m *Module) Prepare() error {
IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
} }
chartHandler := &ChartHandler{
Database: m.sqlStore,
}
// FIXME(ppacher): use appropriate permissions for this // FIXME(ppacher): use appropriate permissions for this
if err := api.RegisterEndpoint(api.Endpoint{ if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/query", Path: "netquery/query",
@ -77,6 +81,19 @@ func (m *Module) Prepare() error {
return fmt.Errorf("failed to register API endpoint: %w", err) return fmt.Errorf("failed to register API endpoint: %w", err)
} }
if err := api.RegisterEndpoint(api.Endpoint{
Path: "netquery/charts/connection-active",
MimeType: "application/json",
Read: api.PermitAnyone,
Write: api.PermitAnyone,
BelongsTo: m.Module,
HandlerFunc: chartHandler.ServeHTTP,
Name: "Query In-Memory Database",
Description: "Query the in-memory sqlite database",
}); err != nil {
return fmt.Errorf("failed to register API endpoint: %w", err)
}
return nil return nil
} }
@ -120,11 +137,12 @@ func (mod *Module) Start() error {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold)) threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
count, err := mod.sqlStore.Cleanup(ctx, threshold)
if err != nil { if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %s", err) log.Errorf("netquery: failed to count number of rows in memory: %s", err)
} else { } else {
log.Infof("netquery: successfully removed %d old rows", count) log.Infof("netquery: successfully removed %d old rows that ended before %s", count, threshold)
} }
} }
} }
@ -135,7 +153,7 @@ func (mod *Module) Start() error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-time.After(5 * time.Second): case <-time.After(1 * time.Second):
count, err := mod.sqlStore.CountRows(ctx) count, err := mod.sqlStore.CountRows(ctx)
if err != nil { if err != nil {
log.Errorf("netquery: failed to count number of rows in memory: %s", err) log.Errorf("netquery: failed to count number of rows in memory: %s", err)

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -51,7 +52,7 @@ type (
} }
// DecodeFunc is called for each non-basic type during decoding. // DecodeFunc is called for each non-basic type during decoding.
DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)
DecodeConfig struct { DecodeConfig struct {
DecodeHooks []DecodeFunc DecodeHooks []DecodeFunc
@ -63,7 +64,7 @@ type (
// be specified to provide support for special types. // be specified to provide support for special types.
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion. // See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
// //
func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error { func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
// make sure we got something to decode into ... // make sure we got something to decode into ...
if result == nil { if result == nil {
return fmt.Errorf("%w, got %T", errStructPointerExpected, result) return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
@ -71,7 +72,7 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
// fast path for decoding into a map // fast path for decoding into a map
if mp, ok := result.(*map[string]interface{}); ok { if mp, ok := result.(*map[string]interface{}); ok {
return decodeIntoMap(ctx, stmt, mp) return decodeIntoMap(ctx, schema, stmt, mp, cfg)
} }
// make sure we got a pointer in result // make sure we got a pointer in result
@ -147,10 +148,13 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
value = storage.Elem() value = storage.Elem()
} }
colDef := schema.GetColumnDef(colName)
// execute all decode hooks but make sure we use decodeBasic() as the // execute all decode hooks but make sure we use decodeBasic() as the
// last one. // last one.
columnValue, err := runDecodeHooks( columnValue, err := runDecodeHooks(
i, i,
colDef,
stmt, stmt,
fieldType, fieldType,
value, value,
@ -188,10 +192,19 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
// FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing // FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
// //
func DatetimeDecoder(loc *time.Location) DecodeFunc { func DatetimeDecoder(loc *time.Location) DecodeFunc {
return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) { return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
// if we have the column definition available we
// use the target go type from there.
outType := outval.Type()
if colDef != nil {
outType = colDef.GoType
}
// we only care about "time.Time" here // we only care about "time.Time" here
if outval.Type().String() != "time.Time" { if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
return nil, nil log.Printf("not decoding %s %v", outType, colDef)
return nil, false, nil
} }
switch stmt.ColumnType(colIdx) { switch stmt.ColumnType(colIdx) {
@ -201,39 +214,61 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
// TODO(ppacher): actually split the tag value at "," and search // TODO(ppacher): actually split the tag value at "," and search
// the slice for "unixnano" // the slice for "unixnano"
if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") { if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
} }
return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil
case sqlite.TypeText: case sqlite.TypeText:
// stored ISO8601 but does not have any timezone information // stored ISO8601 but does not have any timezone information
// assigned so we always treat it as loc here. // assigned so we always treat it as loc here.
t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc) t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err) return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
} }
return t, nil return t, true, nil
case sqlite.TypeFloat: case sqlite.TypeFloat:
// stored as Julian day numbers // stored as Julian day numbers
return nil, fmt.Errorf("REAL storage type not support for time.Time") return nil, false, fmt.Errorf("REAL storage type not support for time.Time")
case sqlite.TypeNull:
return nil, true, nil
default: default:
return nil, fmt.Errorf("unsupported storage type for time.Time: %s", outval.Type()) return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
} }
} }
} }
func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error { func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
if *mp == nil { if *mp == nil {
*mp = make(map[string]interface{}) *mp = make(map[string]interface{})
} }
for i := 0; i < stmt.ColumnCount(); i++ { for i := 0; i < stmt.ColumnCount(); i++ {
var x interface{} var x interface{}
val, err := decodeBasic()(i, stmt, reflect.StructField{}, reflect.ValueOf(&x).Elem())
colDef := schema.GetColumnDef(stmt.ColumnName(i))
outVal := reflect.ValueOf(&x).Elem()
fieldType := reflect.StructField{}
if colDef != nil {
outVal = reflect.New(colDef.GoType).Elem()
fieldType = reflect.StructField{
Type: colDef.GoType,
}
}
val, err := runDecodeHooks(
i,
colDef,
stmt,
fieldType,
outVal,
append(cfg.DecodeHooks, decodeBasic()),
)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err) return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
} }
@ -245,56 +280,99 @@ func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) e
} }
func decodeBasic() DecodeFunc { func decodeBasic() DecodeFunc {
return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) { return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
valueKind := getKind(outval) valueKind := getKind(outval)
colType := stmt.ColumnType(colIdx) colType := stmt.ColumnType(colIdx)
colName := stmt.ColumnName(colIdx) colName := stmt.ColumnName(colIdx)
errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type()) errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())
// if we have the column definition available we
// use the target go type from there.
if colDef != nil {
valueKind = normalizeKind(colDef.GoType.Kind())
// if we have a column defintion we try to convert the value to
// the actual Go-type that was used in the model.
// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
// or that type aliases like (type myInt int) are decoded into myInt instead of int
defer func() {
if handled {
t := reflect.New(colDef.GoType).Elem()
if result == nil || reflect.ValueOf(result).IsZero() {
return
}
if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
}
t.Set(reflect.ValueOf(result))
result = t.Interface()
}
}()
}
log.Printf("decoding %s into kind %s", colName, valueKind)
if colType == sqlite.TypeNull {
if colDef != nil && colDef.Nullable {
return nil, true, nil
}
if colDef != nil && !colDef.Nullable {
return reflect.New(colDef.GoType).Elem().Interface(), true, nil
}
if outval.Kind() == reflect.Ptr {
return nil, true, nil
}
}
switch valueKind { switch valueKind {
case reflect.String: case reflect.String:
if colType != sqlite.TypeText { if colType != sqlite.TypeText {
return nil, errInvalidType return nil, false, errInvalidType
} }
return stmt.ColumnText(colIdx), nil return stmt.ColumnText(colIdx), true, nil
case reflect.Bool: case reflect.Bool:
// sqlite does not have a BOOL type, it rather stores a 1/0 in a column // sqlite does not have a BOOL type, it rather stores a 1/0 in a column
// with INTEGER affinity. // with INTEGER affinity.
if colType != sqlite.TypeInteger { if colType != sqlite.TypeInteger {
return nil, errInvalidType return nil, false, errInvalidType
} }
return stmt.ColumnBool(colIdx), nil return stmt.ColumnBool(colIdx), true, nil
case reflect.Float64: case reflect.Float64:
if colType != sqlite.TypeFloat { if colType != sqlite.TypeFloat {
return nil, errInvalidType return nil, false, errInvalidType
} }
return stmt.ColumnFloat(colIdx), nil return stmt.ColumnFloat(colIdx), true, nil
case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ... case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
if colType != sqlite.TypeInteger { if colType != sqlite.TypeInteger {
return nil, errInvalidType return nil, false, errInvalidType
} }
return stmt.ColumnInt(colIdx), nil return stmt.ColumnInt(colIdx), true, nil
case reflect.Slice: case reflect.Slice:
if outval.Type().Elem().Kind() != reflect.Uint8 { if outval.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("slices other than []byte for BLOB are not supported") return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
} }
if colType != sqlite.TypeBlob { if colType != sqlite.TypeBlob {
return nil, errInvalidType return nil, false, errInvalidType
} }
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx)) columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
} }
return columnValue, nil return columnValue, true, nil
case reflect.Interface: case reflect.Interface:
var ( var (
@ -306,7 +384,7 @@ func decodeBasic() DecodeFunc {
t = reflect.TypeOf([]byte{}) t = reflect.TypeOf([]byte{})
columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx)) columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err) return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
} }
x = columnValue x = columnValue
@ -327,20 +405,20 @@ func decodeBasic() DecodeFunc {
x = nil x = nil
default: default:
return nil, fmt.Errorf("unsupported column type %s", colType) return nil, false, fmt.Errorf("unsupported column type %s", colType)
} }
if t == nil { if t == nil {
return nil, nil return nil, true, nil
} }
target := reflect.New(t).Elem() target := reflect.New(t).Elem()
target.Set(reflect.ValueOf(x)) target.Set(reflect.ValueOf(x))
return target.Interface(), nil return target.Interface(), true, nil
default: default:
return nil, fmt.Errorf("cannot decode into %s", valueKind) return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
} }
} }
} }
@ -362,14 +440,14 @@ func sqlColumnName(fieldType reflect.StructField) string {
// runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks. // runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
// The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is // The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
// returned by a decode hook runDecodeHooks stops the error is returned to the caller. // returned by a decode hook runDecodeHooks stops the error is returned to the caller.
func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) { func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
for _, fn := range hooks { for _, fn := range hooks {
res, err := fn(colIdx, stmt, fieldDef, outval) res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval)
if err != nil { if err != nil {
return res, err return res, err
} }
if res != nil { if end {
return res, nil return res, nil
} }
} }

View file

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"log"
"reflect"
"testing" "testing"
"time" "time"
@ -82,6 +84,8 @@ func (ett *exampleTimeTypes) Equal(other interface{}) bool {
return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil) return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
} }
type myInt int
type exampleTimeNano struct { type exampleTimeNano struct {
T time.Time `sqlite:",unixnano"` T time.Time `sqlite:",unixnano"`
} }
@ -100,10 +104,11 @@ func Test_Decoder(t *testing.T) {
refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC) refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)
cases := []struct { cases := []struct {
Desc string Desc string
Stmt testStmt Stmt testStmt
Result interface{} ColumnDef []ColumnDef
Expected interface{} Result interface{}
Expected interface{}
}{ }{
{ {
"Decoding into nil is not allowed", "Decoding into nil is not allowed",
@ -114,6 +119,7 @@ func Test_Decoder(t *testing.T) {
}, },
nil, nil,
nil, nil,
nil,
}, },
{ {
"Decoding into basic types", "Decoding into basic types",
@ -132,6 +138,7 @@ func Test_Decoder(t *testing.T) {
true, true,
}, },
}, },
nil,
&exampleFieldTypes{}, &exampleFieldTypes{},
&exampleFieldTypes{ &exampleFieldTypes{
S: "string value", S: "string value",
@ -157,6 +164,7 @@ func Test_Decoder(t *testing.T) {
1.2, 1.2,
}, },
}, },
nil,
&exampleFieldTypes{}, &exampleFieldTypes{},
&exampleFieldTypes{ &exampleFieldTypes{
S: "string value", S: "string value",
@ -178,6 +186,7 @@ func Test_Decoder(t *testing.T) {
true, true,
}, },
}, },
nil,
&exampleFieldTypes{}, &exampleFieldTypes{},
&exampleFieldTypes{ &exampleFieldTypes{
F: 1.2, F: 1.2,
@ -201,6 +210,7 @@ func Test_Decoder(t *testing.T) {
true, true,
}, },
}, },
nil,
&examplePointerTypes{}, &examplePointerTypes{},
func() interface{} { func() interface{} {
s := "string value" s := "string value"
@ -231,6 +241,7 @@ func Test_Decoder(t *testing.T) {
true, true,
}, },
}, },
nil,
&examplePointerTypes{}, &examplePointerTypes{},
func() interface{} { func() interface{} {
s := "string value" s := "string value"
@ -255,6 +266,7 @@ func Test_Decoder(t *testing.T) {
1, 1,
}, },
}, },
nil,
&exampleStructTags{}, &exampleStructTags{},
&exampleStructTags{ &exampleStructTags{
S: "string value", S: "string value",
@ -280,6 +292,7 @@ func Test_Decoder(t *testing.T) {
1, 1,
}, },
}, },
nil,
&exampleIntConv{}, &exampleIntConv{},
&exampleIntConv{ &exampleIntConv{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
@ -301,6 +314,7 @@ func Test_Decoder(t *testing.T) {
1.0, 1.0,
}, },
}, },
nil,
&exampleFieldTypes{}, &exampleFieldTypes{},
&exampleFieldTypes{ &exampleFieldTypes{
F: 1.0, F: 1.0,
@ -322,6 +336,7 @@ func Test_Decoder(t *testing.T) {
1.0, 1.0,
}, },
}, },
nil,
&examplePointerTypes{}, &examplePointerTypes{},
func() interface{} { func() interface{} {
f := 1.0 f := 1.0
@ -340,6 +355,7 @@ func Test_Decoder(t *testing.T) {
([]byte)("hello world"), ([]byte)("hello world"),
}, },
}, },
nil,
&exampleBlobTypes{}, &exampleBlobTypes{},
&exampleBlobTypes{ &exampleBlobTypes{
B: ([]byte)("hello world"), B: ([]byte)("hello world"),
@ -356,6 +372,7 @@ func Test_Decoder(t *testing.T) {
([]byte)("hello world"), ([]byte)("hello world"),
}, },
}, },
nil,
&exampleJSONRawTypes{}, &exampleJSONRawTypes{},
&exampleJSONRawTypes{ &exampleJSONRawTypes{
B: (json.RawMessage)("hello world"), B: (json.RawMessage)("hello world"),
@ -374,6 +391,7 @@ func Test_Decoder(t *testing.T) {
int(refTime.Unix()), int(refTime.Unix()),
}, },
}, },
nil,
&exampleTimeTypes{}, &exampleTimeTypes{},
&exampleTimeTypes{ &exampleTimeTypes{
T: refTime, T: refTime,
@ -393,6 +411,7 @@ func Test_Decoder(t *testing.T) {
int(refTime.UnixNano()), int(refTime.UnixNano()),
}, },
}, },
nil,
&exampleTimeNano{}, &exampleTimeNano{},
&exampleTimeNano{ &exampleTimeNano{
T: refTime, T: refTime,
@ -411,6 +430,7 @@ func Test_Decoder(t *testing.T) {
"value2", "value2",
}, },
}, },
nil,
&exampleInterface{}, &exampleInterface{},
func() interface{} { func() interface{} {
var x interface{} var x interface{}
@ -439,6 +459,7 @@ func Test_Decoder(t *testing.T) {
[]byte("blob value"), []byte("blob value"),
}, },
}, },
nil,
new(map[string]interface{}), new(map[string]interface{}),
&map[string]interface{}{ &map[string]interface{}{
"I": 1, "I": 1,
@ -447,14 +468,91 @@ func Test_Decoder(t *testing.T) {
"B": []byte("blob value"), "B": []byte("blob value"),
}, },
}, },
{
"Decoding using type-hints",
testStmt{
columns: []string{"B", "T"},
types: []sqlite.ColumnType{
sqlite.TypeInteger,
sqlite.TypeText,
},
values: []interface{}{
true,
refTime.Format(SqliteTimeFormat),
},
},
[]ColumnDef{
{
Name: "B",
Type: sqlite.TypeInteger,
GoType: reflect.TypeOf(true),
},
{
Name: "T",
Type: sqlite.TypeText,
GoType: reflect.TypeOf(time.Time{}),
IsTime: true,
},
},
new(map[string]interface{}),
&map[string]interface{}{
"B": true,
"T": refTime,
},
},
{
"Decoding into type aliases",
testStmt{
columns: []string{"B"},
types: []sqlite.ColumnType{
sqlite.TypeBlob,
},
values: []interface{}{
[]byte(`{"foo": "bar}`),
},
},
[]ColumnDef{
{
Name: "B",
Type: sqlite.TypeBlob,
GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
},
},
new(map[string]interface{}),
&map[string]interface{}{
"B": json.RawMessage(`{"foo": "bar}`),
},
},
{
"Decoding into type aliases #2",
testStmt{
columns: []string{"I"},
types: []sqlite.ColumnType{sqlite.TypeInteger},
values: []interface{}{
10,
},
},
[]ColumnDef{
{
Name: "I",
Type: sqlite.TypeInteger,
GoType: reflect.TypeOf(myInt(0)),
},
},
new(map[string]interface{}),
&map[string]interface{}{
"I": myInt(10),
},
},
} }
for idx := range cases { for idx := range cases {
c := cases[idx] c := cases[idx]
t.Run(c.Desc, func(t *testing.T) { t.Run(c.Desc, func(t *testing.T) {
t.Parallel() //t.Parallel()
err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig) log.Println(c.Desc)
err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
if fn, ok := c.Expected.(func() interface{}); ok { if fn, ok := c.Expected.(func() interface{}); ok {
c.Expected = fn() c.Expected = fn()
} }

View file

@ -20,6 +20,7 @@ type (
NamedArgs map[string]interface{} NamedArgs map[string]interface{}
Result interface{} Result interface{}
DecodeConfig DecodeConfig DecodeConfig DecodeConfig
Schema TableSchema
} }
) )
@ -56,6 +57,12 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
} }
} }
func WithSchema(tbl TableSchema) QueryOption {
return func(opts *queryOpts) {
opts.Schema = tbl
}
}
// WithResult sets the result receiver. result is expected to // WithResult sets the result receiver. result is expected to
// be a pointer to a slice of struct or map types. // be a pointer to a slice of struct or map types.
// //
@ -136,7 +143,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q
currentField = reflect.New(valElemType) currentField = reflect.New(valElemType)
if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil { if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
return err return err
} }

View file

@ -45,6 +45,7 @@ type (
Name string Name string
Nullable bool Nullable bool
Type sqlite.ColumnType Type sqlite.ColumnType
GoType reflect.Type
Length int Length int
PrimaryKey bool PrimaryKey bool
AutoIncrement bool AutoIncrement bool
@ -145,6 +146,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
ft = fieldType.Type.Elem() ft = fieldType.Type.Elem()
} }
def.GoType = ft
kind := normalizeKind(ft.Kind()) kind := normalizeKind(ft.Kind())
switch kind { switch kind {

View file

@ -16,6 +16,12 @@ import (
type ( type (
Query map[string][]Matcher Query map[string][]Matcher
MatchType interface {
Operator() string
}
Equal interface{}
Matcher struct { Matcher struct {
Equal interface{} `json:"$eq,omitempty"` Equal interface{} `json:"$eq,omitempty"`
NotEqual interface{} `json:"$ne,omitempty"` NotEqual interface{} `json:"$ne,omitempty"`
@ -27,12 +33,22 @@ type (
Count struct { Count struct {
As string `json:"as"` As string `json:"as"`
Field string `json:"field"` Field string `json:"field"`
Distinct bool `json:"distict"` Distinct bool `json:"distinct"`
} }
Sum struct {
Condition Query `json:"condition"`
As string `json:"as"`
Distinct bool `json:"distinct"`
}
// NOTE: whenever adding support for new operators make sure
// to update UnmarshalJSON as well.
Select struct { Select struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count"` Count *Count `json:"$count,omitempty"`
Sum *Sum `json:"$sum,omitempty"`
Distinct *string `json:"$distinct"`
} }
Selects []Select Selects []Select
@ -45,6 +61,11 @@ type (
selectedFields []string selectedFields []string
whitelistedFields []string whitelistedFields []string
paramMap map[string]interface{}
}
QueryActiveConnectionChartPayload struct {
Query Query `json:"query"`
} }
OrderBy struct { OrderBy struct {
@ -179,15 +200,15 @@ func (match Matcher) Validate() error {
return nil return nil
} }
func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
var ( var (
queryParts []string queryParts []string
params = make(map[string]interface{}) params = make(map[string]interface{})
errs = new(multierror.Error) errs = new(multierror.Error)
key = fmt.Sprintf("%s%d", colDef.Name, idx) key = fmt.Sprintf("%s%s", colDef.Name, suffix)
) )
add := func(operator, suffix string, values ...interface{}) { add := func(operator, suffix string, list bool, values ...interface{}) {
var placeholder []string var placeholder []string
for idx, value := range values { for idx, value := range values {
@ -204,7 +225,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
params[uniqKey] = encodedValue params[uniqKey] = encodedValue
} }
if len(placeholder) == 1 { if len(placeholder) == 1 && !list {
queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0])) queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
} else { } else {
queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", "))) queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))
@ -212,23 +233,23 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
} }
if match.Equal != nil { if match.Equal != nil {
add("=", "eq", match.Equal) add("=", "eq", false, match.Equal)
} }
if match.NotEqual != nil { if match.NotEqual != nil {
add("!=", "ne", match.NotEqual) add("!=", "ne", false, match.NotEqual)
} }
if match.In != nil { if match.In != nil {
add("IN", "in", match.In...) add("IN", "in", true, match.In...)
} }
if match.NotIn != nil { if match.NotIn != nil {
add("NOT IN", "notin", match.NotIn...) add("NOT IN", "notin", true, match.NotIn...)
} }
if match.Like != "" { if match.Like != "" {
add("LIKE", "like", match.Like) add("LIKE", "like", false, match.Like)
} }
if len(queryParts) == 0 { if len(queryParts) == 0 {
@ -244,7 +265,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil() return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
} }
func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) { func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
if len(query) == 0 { if len(query) == 0 {
return "", nil, nil return "", nil, nil
} }
@ -279,7 +300,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, enc
queryParts := make([]string, len(values)) queryParts := make([]string, len(values))
for idx, val := range values { for idx, val := range values {
matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig) matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
if err != nil { if err != nil {
errs.Errors = append(errs.Errors, errs.Errors = append(errs.Errors,
fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err), fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),
@ -359,8 +380,10 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
// directly // directly
if blob[0] == '{' { if blob[0] == '{' {
var res struct { var res struct {
Field string `json:"field"` Field string `json:"field"`
Count *Count `json:"$count"` Count *Count `json:"$count"`
Sum *Sum `json:"$sum"`
Distinct *string `json:"$distinct"`
} }
if err := json.Unmarshal(blob, &res); err != nil { if err := json.Unmarshal(blob, &res); err != nil {
@ -369,6 +392,8 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
sel.Count = res.Count sel.Count = res.Count
sel.Field = res.Field sel.Field = res.Field
sel.Distinct = res.Distinct
sel.Sum = res.Sum
if sel.Count != nil && sel.Count.As != "" { if sel.Count != nil && sel.Count.As != "" {
if !charOnlyRegexp.MatchString(sel.Count.As) { if !charOnlyRegexp.MatchString(sel.Count.As) {

View file

@ -58,6 +58,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
query, query,
orm.WithNamedArgs(paramMap), orm.WithNamedArgs(paramMap),
orm.WithResult(&result), orm.WithResult(&result),
orm.WithSchema(*qh.Database.Schema),
); err != nil { ); err != nil {
http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError) http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)
@ -139,13 +140,14 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e
} }
func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) { func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
if err := req.prepareSelectedFields(schema); err != nil { if err := req.prepareSelectedFields(ctx, schema); err != nil {
return "", nil, fmt.Errorf("perparing selected fields: %w", err) return "", nil, fmt.Errorf("perparing selected fields: %w", err)
} }
// build the SQL where clause from the payload query // build the SQL where clause from the payload query
whereClause, paramMap, err := req.Query.toSQLWhereClause( whereClause, paramMap, err := req.Query.toSQLWhereClause(
ctx, ctx,
"",
schema, schema,
orm.DefaultEncodeConfig, orm.DefaultEncodeConfig,
) )
@ -153,6 +155,14 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
return "", nil, fmt.Errorf("ganerating where clause: %w", err) return "", nil, fmt.Errorf("ganerating where clause: %w", err)
} }
if req.paramMap == nil {
req.paramMap = make(map[string]interface{})
}
for key, val := range paramMap {
req.paramMap[key] = val
}
// build the actual SQL query statement // build the actual SQL query statement
// FIXME(ppacher): add support for group-by and sort-by // FIXME(ppacher): add support for group-by and sort-by
@ -173,20 +183,26 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
} }
query += " " + groupByClause + " " + orderByClause query += " " + groupByClause + " " + orderByClause
return query, paramMap, nil return query, req.paramMap, nil
} }
func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error { func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
for _, s := range req.Select { for idx, s := range req.Select {
var field string var field string
if s.Count != nil { switch {
case s.Count != nil:
field = s.Count.Field field = s.Count.Field
} else { case s.Distinct != nil:
field = *s.Distinct
case s.Sum != nil:
// field is not used in case of $sum
field = "*"
default:
field = s.Field field = s.Field
} }
colName := "*" colName := "*"
if field != "*" || s.Count == nil { if field != "*" || (s.Count == nil && s.Sum == nil) {
var err error var err error
colName, err = req.validateColumnName(schema, field) colName, err = req.validateColumnName(schema, field)
@ -195,7 +211,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
} }
} }
if s.Count != nil { switch {
case s.Count != nil:
var as = s.Count.As var as = s.Count.As
if as == "" { if as == "" {
as = fmt.Sprintf("%s_count", colName) as = fmt.Sprintf("%s_count", colName)
@ -204,9 +221,34 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
if s.Count.Distinct { if s.Count.Distinct {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as)) req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
)
req.whitelistedFields = append(req.whitelistedFields, as) req.whitelistedFields = append(req.whitelistedFields, as)
} else {
case s.Sum != nil:
if s.Sum.As == "" {
return fmt.Errorf("missing 'as' for $sum")
}
clause, params, err := s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
if err != nil {
return fmt.Errorf("in $sum: %w", err)
}
req.paramMap = params
req.selectedFields = append(
req.selectedFields,
fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
)
req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
case s.Distinct != nil:
req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
req.whitelistedFields = append(req.whitelistedFields, colName)
default:
req.selectedFields = append(req.selectedFields, colName) req.selectedFields = append(req.selectedFields, colName)
} }
} }
@ -251,6 +293,10 @@ func (req *QueryRequestPayload) generateSelectClause() string {
} }
func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) { func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
if len(req.OrderBy) == 0 {
return "", nil
}
var orderBys = make([]string, len(req.OrderBy)) var orderBys = make([]string, len(req.OrderBy))
for idx, sort := range req.OrderBy { for idx, sort := range req.OrderBy {
colName, err := req.validateColumnName(schema, sort.Field) colName, err := req.validateColumnName(schema, sort.Field)
@ -286,7 +332,7 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
} }
} }
return "", fmt.Errorf("column name %s not allowed", field) return "", fmt.Errorf("column name %q not allowed", field)
} }
// compile time check // compile time check

View file

@ -228,7 +228,7 @@ func Test_QueryBuilder(t *testing.T) {
for idx, c := range cases { for idx, c := range cases {
t.Run(c.N, func(t *testing.T) { t.Run(c.N, func(t *testing.T) {
//t.Parallel() //t.Parallel()
str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig) str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)
if c.E != nil { if c.E != nil {
if assert.Error(t, err) { if assert.Error(t, err) {