Mirror of https://github.com/safing/portmaster (synced 2025-09-11 07:24:36 +00:00)

Commit: Add query and chart support with multiple fixes to ORM package

This commit is contained in:
  parent 0d2ec9df75
  commit 25aceaf103

11 changed files with 535 additions and 117 deletions
netquery/chart_handler.go (new file, 118 lines)
@@ -0,0 +1,118 @@
package netquery

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strings"

	"github.com/safing/portmaster/netquery/orm"
)

type ChartHandler struct {
	Database *Database
}

func (ch *ChartHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	requestPayload, err := ch.parseRequest(req)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)

		return
	}

	query, paramMap, err := requestPayload.generateSQL(req.Context(), ch.Database.Schema)
	if err != nil {
		http.Error(resp, err.Error(), http.StatusBadRequest)

		return
	}

	// actually execute the query against the database and collect the result
	var result []map[string]interface{}
	if err := ch.Database.Execute(
		req.Context(),
		query,
		orm.WithNamedArgs(paramMap),
		orm.WithResult(&result),
		orm.WithSchema(*ch.Database.Schema),
	); err != nil {
		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)

		return
	}

	// send the HTTP status code
	resp.WriteHeader(http.StatusOK)

	// prepare the result encoder.
	enc := json.NewEncoder(resp)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")

	enc.Encode(map[string]interface{}{
		"results": result,
		"query":   query,
		"params":  paramMap,
	})
}

func (ch *ChartHandler) parseRequest(req *http.Request) (*QueryActiveConnectionChartPayload, error) {
	var body io.Reader

	switch req.Method {
	case http.MethodPost, http.MethodPut:
		body = req.Body
	case http.MethodGet:
		body = strings.NewReader(req.URL.Query().Get("q"))
	default:
		return nil, fmt.Errorf("invalid HTTP method")
	}

	var requestPayload QueryActiveConnectionChartPayload
	blob, err := ioutil.ReadAll(body)
	if err != nil {
		return nil, fmt.Errorf("failed to read body" + err.Error())
	}

	body = bytes.NewReader(blob)

	dec := json.NewDecoder(body)
	dec.DisallowUnknownFields()

	if err := json.Unmarshal(blob, &requestPayload); err != nil && !errors.Is(err, io.EOF) {
		return nil, fmt.Errorf("invalid query: %w", err)
	}

	return &requestPayload, nil
}

func (req *QueryActiveConnectionChartPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
	template := `
WITH RECURSIVE epoch(x) AS (
	SELECT strftime('%%s')-600
	UNION ALL
	SELECT x+1 FROM epoch WHERE x+1 < strftime('%%s')+0
)
SELECT x as timestamp, COUNT(*) AS value FROM epoch
JOIN connections
	ON strftime('%%s', connections.started)+0 <= timestamp+0 AND (connections.ended IS NULL OR strftime('%%s', connections.ended)+0 > timestamp+0)
%s
GROUP BY round(timestamp/10, 0)*10;`

	clause, params, err := req.Query.toSQLWhereClause(ctx, "", schema, orm.DefaultEncodeConfig)
	if err != nil {
		return "", nil, err
	}

	if clause == "" {
		return fmt.Sprintf(template, ""), map[string]interface{}{}, nil
	}

	return fmt.Sprintf(template, "WHERE ( "+clause+")"), params, nil
}
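Aside: a minimal sketch of how a client could exercise the chart endpoint that this commit registers at "netquery/charts/connection-active". The base URL (host, port, and the /api prefix) is an assumption about a local Portmaster API setup and is not part of this commit; only the endpoint path, the {"query": ...} payload shape, and the "$eq" matcher come from the diff.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// hypothetical local API address; adjust to your setup
	url := "http://127.0.0.1:817/api/netquery/charts/connection-active"

	// QueryActiveConnectionChartPayload carries a single "query" object of
	// column matchers; here we only chart tunneled connections.
	payload := []byte(`{"query": {"tunneled": [{"$eq": true}]}}`)

	resp, err := http.Post(url, "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // {"results": [...], "query": "...", "params": {...}}
}

The handler also accepts GET requests, in which case the same JSON document is passed URL-encoded in the "q" query parameter.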
@@ -63,32 +63,35 @@ type (
 		// time. We cannot just use the network.Connection.ID because it is only unique
 		// as long as the connection is still active and might be, although unlikely,
 		// reused afterwards.
 		ID            string            `sqlite:"id,primary"`
 		ProfileID     string            `sqlite:"profile"`
+		ProfileSource string            `sqlite:"profileSource"`
 		Path          string            `sqlite:"path"`
 		Type          string            `sqlite:"type,varchar(8)"`
 		External      bool              `sqlite:"external"`
 		IPVersion     packet.IPVersion  `sqlite:"ip_version"`
 		IPProtocol    packet.IPProtocol `sqlite:"ip_protocol"`
 		LocalIP       string            `sqlite:"local_ip"`
 		LocalPort     uint16            `sqlite:"local_port"`
 		RemoteIP      string            `sqlite:"remote_ip"`
 		RemotePort    uint16            `sqlite:"remote_port"`
 		Domain        string            `sqlite:"domain"`
 		Country       string            `sqlite:"country,varchar(2)"`
 		ASN           uint              `sqlite:"asn"`
 		ASOwner       string            `sqlite:"as_owner"`
 		Latitude      float64           `sqlite:"latitude"`
 		Longitude     float64           `sqlite:"longitude"`
 		Scope         netutils.IPScope  `sqlite:"scope"`
 		Verdict       network.Verdict   `sqlite:"verdict"`
 		Started       time.Time         `sqlite:"started,text,time"`
 		Ended         *time.Time        `sqlite:"ended,text,time"`
 		Tunneled      bool              `sqlite:"tunneled"`
 		Encrypted     bool              `sqlite:"encrypted"`
 		Internal      bool              `sqlite:"internal"`
-		Inbound       bool              `sqlite:"inbound"`
+		Direction     string            `sqlite:"direction"`
 		ExtraData     json.RawMessage   `sqlite:"extra_data"`
+		Allowed         *bool `sqlite:"allowed"`
+		ProfileRevision int   `sqlite:"profile_revision"`
 	}
 )
@@ -190,11 +193,11 @@ func (db *Database) CountRows(ctx context.Context) (int, error) {
 // probably not worth the cylces...
 func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) {
 	where := `WHERE ended IS NOT NULL
-			AND datetime(ended) < :threshold`
+			AND datetime(ended) < datetime(:threshold)`
 	sql := "DELETE FROM connections " + where + ";"

 	args := orm.WithNamedArgs(map[string]interface{}{
-		":threshold": threshold,
+		":threshold": threshold.UTC().Format(orm.SqliteTimeFormat),
 	})

 	var result []struct {

@@ -232,7 +235,7 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error {
 	if err := sqlitex.ExecuteTransient(db.conn, "SELECT * FROM connections", &sqlitex.ExecOptions{
 		ResultFunc: func(stmt *sqlite.Stmt) error {
 			var c Conn
-			if err := orm.DecodeStmt(ctx, stmt, &c, orm.DefaultDecodeConfig); err != nil {
+			if err := orm.DecodeStmt(ctx, db.Schema, stmt, &c, orm.DefaultDecodeConfig); err != nil {
 				return err
 			}
@@ -155,23 +155,42 @@ func (mng *Manager) pushConnUpdate(ctx context.Context, meta record.Meta, conn C
 func convertConnection(conn *network.Connection) (*Conn, error) {
 	conn.Lock()
 	defer conn.Unlock()

+	direction := "outbound"
+	if conn.Inbound {
+		direction = "inbound"
+	}
+
 	c := Conn{
 		ID:         genConnID(conn),
 		External:   conn.External,
 		IPVersion:  conn.IPVersion,
 		IPProtocol: conn.IPProtocol,
 		LocalIP:    conn.LocalIP.String(),
 		LocalPort:  conn.LocalPort,
 		Verdict:    conn.Verdict,
 		Started:    time.Unix(conn.Started, 0),
 		Tunneled:   conn.Tunneled,
 		Encrypted:  conn.Encrypted,
 		Internal:   conn.Internal,
-		Inbound:    conn.Inbound,
+		Direction:  direction,
 		Type:       ConnectionTypeToString[conn.Type],
-		ProfileID:  conn.ProcessContext.ProfileName,
-		Path:       conn.ProcessContext.BinaryPath,
+		ProfileID:       conn.ProcessContext.Profile,
+		ProfileSource:   conn.ProcessContext.Source,
+		Path:            conn.ProcessContext.BinaryPath,
+		ProfileRevision: int(conn.ProfileRevisionCounter),
+	}
+
+	switch conn.Type {
+	case network.DNSRequest:
+		c.Type = "dns"
+	case network.IPConnection:
+		c.Type = "ip"
+	}
+
+	switch conn.Verdict {
+	case network.VerdictAccept, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
+		accepted := true
+		c.Allowed = &accepted
 	}

 	if conn.Ended > 0 {
@@ -181,6 +200,10 @@ func convertConnection(conn *network.Connection) (*Conn, error) {

 	extraData := map[string]interface{}{}

+	if conn.TunnelContext != nil {
+		extraData["tunnel"] = conn.TunnelContext
+	}
+
 	if conn.Entity != nil {
 		extraData["cname"] = conn.Entity.CNAME
 		extraData["blockedByLists"] = conn.Entity.BlockedByLists
@@ -63,6 +63,10 @@ func (m *Module) Prepare() error {
 		IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false),
 	}

+	chartHandler := &ChartHandler{
+		Database: m.sqlStore,
+	}
+
 	// FIXME(ppacher): use appropriate permissions for this
 	if err := api.RegisterEndpoint(api.Endpoint{
 		Path: "netquery/query",

@@ -77,6 +81,19 @@ func (m *Module) Prepare() error {
 		return fmt.Errorf("failed to register API endpoint: %w", err)
 	}

+	if err := api.RegisterEndpoint(api.Endpoint{
+		Path:        "netquery/charts/connection-active",
+		MimeType:    "application/json",
+		Read:        api.PermitAnyone,
+		Write:       api.PermitAnyone,
+		BelongsTo:   m.Module,
+		HandlerFunc: chartHandler.ServeHTTP,
+		Name:        "Query In-Memory Database",
+		Description: "Query the in-memory sqlite database",
+	}); err != nil {
+		return fmt.Errorf("failed to register API endpoint: %w", err)
+	}
+
 	return nil
 }

@@ -120,11 +137,12 @@ func (mod *Module) Start() error {
 			case <-ctx.Done():
 				return nil
 			case <-time.After(10 * time.Second):
-				count, err := mod.sqlStore.Cleanup(ctx, time.Now().Add(-network.DeleteConnsAfterEndedThreshold))
+				threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold)
+				count, err := mod.sqlStore.Cleanup(ctx, threshold)
 				if err != nil {
 					log.Errorf("netquery: failed to count number of rows in memory: %s", err)
 				} else {
-					log.Infof("netquery: successfully removed %d old rows", count)
+					log.Infof("netquery: successfully removed %d old rows that ended before %s", count, threshold)
 				}
 			}
 		}

@@ -135,7 +153,7 @@ func (mod *Module) Start() error {
 		select {
 		case <-ctx.Done():
 			return nil
-		case <-time.After(5 * time.Second):
+		case <-time.After(1 * time.Second):
 			count, err := mod.sqlStore.CountRows(ctx)
 			if err != nil {
 				log.Errorf("netquery: failed to count number of rows in memory: %s", err)
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"reflect"
 	"strings"
 	"time"

@@ -51,7 +52,7 @@ type (
 	}

 	// DecodeFunc is called for each non-basic type during decoding.
-	DecodeFunc func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error)
+	DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)

 	DecodeConfig struct {
 		DecodeHooks []DecodeFunc

@@ -63,7 +64,7 @@ type (
 // be specified to provide support for special types.
 // See DatetimeDecoder() for an example of a DecodeHook that handles graceful time.Time conversion.
 //
-func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeConfig) error {
+func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
 	// make sure we got something to decode into ...
 	if result == nil {
 		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)

@@ -71,7 +72,7 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo

 	// fast path for decoding into a map
 	if mp, ok := result.(*map[string]interface{}); ok {
-		return decodeIntoMap(ctx, stmt, mp)
+		return decodeIntoMap(ctx, schema, stmt, mp, cfg)
 	}

 	// make sure we got a pointer in result

@@ -147,10 +148,13 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
 			value = storage.Elem()
 		}

+		colDef := schema.GetColumnDef(colName)
+
 		// execute all decode hooks but make sure we use decodeBasic() as the
 		// last one.
 		columnValue, err := runDecodeHooks(
 			i,
+			colDef,
 			stmt,
 			fieldType,
 			value,

@@ -188,10 +192,19 @@ func DecodeStmt(ctx context.Context, stmt Stmt, result interface{}, cfg DecodeCo
 // FIXME(ppacher): update comment about loc parameter and TEXT storage class parsing
 //
 func DatetimeDecoder(loc *time.Location) DecodeFunc {
-	return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
+	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
+		// if we have the column definition available we
+		// use the target go type from there.
+		outType := outval.Type()
+		if colDef != nil {
+			outType = colDef.GoType
+		}
+
 		// we only care about "time.Time" here
-		if outval.Type().String() != "time.Time" {
-			return nil, nil
+		if outType.String() != "time.Time" || (colDef != nil && !colDef.IsTime) {
+			log.Printf("not decoding %s %v", outType, colDef)
+
+			return nil, false, nil
 		}

 		switch stmt.ColumnType(colIdx) {

@@ -201,39 +214,61 @@ func DatetimeDecoder(loc *time.Location) DecodeFunc {
 			// TODO(ppacher): actually split the tag value at "," and search
 			// the slice for "unixnano"
 			if strings.Contains(fieldDef.Tag.Get("sqlite"), ",unixnano") {
-				return time.Unix(0, int64(stmt.ColumnInt(colIdx))), nil
+				return time.Unix(0, int64(stmt.ColumnInt(colIdx))), true, nil
 			}

-			return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), nil
+			return time.Unix(int64(stmt.ColumnInt(colIdx)), 0), true, nil

 		case sqlite.TypeText:
 			// stored ISO8601 but does not have any timezone information
 			// assigned so we always treat it as loc here.
 			t, err := time.ParseInLocation(SqliteTimeFormat, stmt.ColumnText(colIdx), loc)
 			if err != nil {
-				return nil, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
+				return nil, false, fmt.Errorf("failed to parse %q in %s: %w", stmt.ColumnText(colIdx), fieldDef.Name, err)
 			}

-			return t, nil
+			return t, true, nil

 		case sqlite.TypeFloat:
 			// stored as Julian day numbers
-			return nil, fmt.Errorf("REAL storage type not support for time.Time")
+			return nil, false, fmt.Errorf("REAL storage type not support for time.Time")

+		case sqlite.TypeNull:
+			return nil, true, nil
+
 		default:
-			return nil, fmt.Errorf("unsupported storage type for time.Time: %s", outval.Type())
+			return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", stmt.ColumnType(colIdx))
 		}
 	}
 }

-func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) error {
+func decodeIntoMap(ctx context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
 	if *mp == nil {
 		*mp = make(map[string]interface{})
 	}

 	for i := 0; i < stmt.ColumnCount(); i++ {
 		var x interface{}
-		val, err := decodeBasic()(i, stmt, reflect.StructField{}, reflect.ValueOf(&x).Elem())
+
+		colDef := schema.GetColumnDef(stmt.ColumnName(i))
+
+		outVal := reflect.ValueOf(&x).Elem()
+		fieldType := reflect.StructField{}
+		if colDef != nil {
+			outVal = reflect.New(colDef.GoType).Elem()
+			fieldType = reflect.StructField{
+				Type: colDef.GoType,
+			}
+		}
+
+		val, err := runDecodeHooks(
+			i,
+			colDef,
+			stmt,
+			fieldType,
+			outVal,
+			append(cfg.DecodeHooks, decodeBasic()),
+		)
 		if err != nil {
 			return fmt.Errorf("failed to decode column %s: %w", stmt.ColumnName(i), err)
 		}

@@ -245,56 +280,99 @@ func decodeIntoMap(ctx context.Context, stmt Stmt, mp *map[string]interface{}) e
 	}

 func decodeBasic() DecodeFunc {
-	return func(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, error) {
+	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
 		valueKind := getKind(outval)
 		colType := stmt.ColumnType(colIdx)
 		colName := stmt.ColumnName(colIdx)

 		errInvalidType := fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())

+		// if we have the column definition available we
+		// use the target go type from there.
+		if colDef != nil {
+			valueKind = normalizeKind(colDef.GoType.Kind())
+
+			// if we have a column defintion we try to convert the value to
+			// the actual Go-type that was used in the model.
+			// this is useful, for example, to ensure a []byte{} is always decoded into json.RawMessage
+			// or that type aliases like (type myInt int) are decoded into myInt instead of int
+			defer func() {
+				if handled {
+					t := reflect.New(colDef.GoType).Elem()
+
+					if result == nil || reflect.ValueOf(result).IsZero() {
+						return
+					}
+
+					if reflect.ValueOf(result).Type().ConvertibleTo(colDef.GoType) {
+						result = reflect.ValueOf(result).Convert(colDef.GoType).Interface()
+					}
+					t.Set(reflect.ValueOf(result))
+
+					result = t.Interface()
+				}
+			}()
+		}
+
+		log.Printf("decoding %s into kind %s", colName, valueKind)
+
+		if colType == sqlite.TypeNull {
+			if colDef != nil && colDef.Nullable {
+				return nil, true, nil
+			}
+
+			if colDef != nil && !colDef.Nullable {
+				return reflect.New(colDef.GoType).Elem().Interface(), true, nil
+			}
+
+			if outval.Kind() == reflect.Ptr {
+				return nil, true, nil
+			}
+		}
+
 		switch valueKind {
 		case reflect.String:
 			if colType != sqlite.TypeText {
-				return nil, errInvalidType
+				return nil, false, errInvalidType
 			}
-			return stmt.ColumnText(colIdx), nil
+			return stmt.ColumnText(colIdx), true, nil

 		case reflect.Bool:
 			// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
 			// with INTEGER affinity.
 			if colType != sqlite.TypeInteger {
-				return nil, errInvalidType
+				return nil, false, errInvalidType
 			}
-			return stmt.ColumnBool(colIdx), nil
+			return stmt.ColumnBool(colIdx), true, nil

 		case reflect.Float64:
 			if colType != sqlite.TypeFloat {
-				return nil, errInvalidType
+				return nil, false, errInvalidType
 			}
-			return stmt.ColumnFloat(colIdx), nil
+			return stmt.ColumnFloat(colIdx), true, nil

 		case reflect.Int, reflect.Uint: // getKind() normalizes all ints to reflect.Int/Uint because sqlite doesn't really care ...
 			if colType != sqlite.TypeInteger {
-				return nil, errInvalidType
+				return nil, false, errInvalidType
 			}

-			return stmt.ColumnInt(colIdx), nil
+			return stmt.ColumnInt(colIdx), true, nil

 		case reflect.Slice:
 			if outval.Type().Elem().Kind() != reflect.Uint8 {
-				return nil, fmt.Errorf("slices other than []byte for BLOB are not supported")
+				return nil, false, fmt.Errorf("slices other than []byte for BLOB are not supported")
 			}

 			if colType != sqlite.TypeBlob {
-				return nil, errInvalidType
+				return nil, false, errInvalidType
 			}

 			columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
 			if err != nil {
-				return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
+				return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
 			}

-			return columnValue, nil
+			return columnValue, true, nil

 		case reflect.Interface:
 			var (

@@ -306,7 +384,7 @@ func decodeBasic() DecodeFunc {
 				t = reflect.TypeOf([]byte{})
 				columnValue, err := io.ReadAll(stmt.ColumnReader(colIdx))
 				if err != nil {
-					return nil, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
+					return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, err)
 				}
 				x = columnValue

@@ -327,20 +405,20 @@ func decodeBasic() DecodeFunc {
 				x = nil

 			default:
-				return nil, fmt.Errorf("unsupported column type %s", colType)
+				return nil, false, fmt.Errorf("unsupported column type %s", colType)
 			}

 			if t == nil {
-				return nil, nil
+				return nil, true, nil
 			}

 			target := reflect.New(t).Elem()
 			target.Set(reflect.ValueOf(x))

-			return target.Interface(), nil
+			return target.Interface(), true, nil

 		default:
-			return nil, fmt.Errorf("cannot decode into %s", valueKind)
+			return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
 		}
 	}
 }

@@ -362,14 +440,14 @@ func sqlColumnName(fieldType reflect.StructField) string {
 // runDecodeHooks tries to decode the column value of stmt at index colIdx into outval by running all decode hooks.
 // The first hook that returns a non-nil interface wins, other hooks will not be executed. If an error is
 // returned by a decode hook runDecodeHooks stops the error is returned to the caller.
-func runDecodeHooks(colIdx int, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
+func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
 	for _, fn := range hooks {
-		res, err := fn(colIdx, stmt, fieldDef, outval)
+		res, end, err := fn(colIdx, colDef, stmt, fieldDef, outval)
 		if err != nil {
 			return res, err
 		}

-		if res != nil {
+		if end {
 			return res, nil
 		}
 	}
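Aside: a minimal sketch of a custom decode hook written against the new DecodeFunc signature (the colDef parameter and the explicit handled flag). The hook name, the net.IP use case, and the import path of the sqlite bindings are illustrative assumptions, not part of this commit; it is sketched as if it lived inside the orm package so the Stmt and ColumnDef types resolve.

package orm

import (
	"fmt"
	"net"
	"reflect"

	"zombiezen.com/go/sqlite" // assumed: the sqlite bindings this package already uses
)

// IPDecoder is a hypothetical hook: it converts TEXT columns into net.IP fields
// and leaves every other column untouched so later hooks (and decodeBasic) run.
func IPDecoder() DecodeFunc {
	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
		// only handle TEXT columns that should end up in a net.IP value
		if outval.Type() != reflect.TypeOf(net.IP{}) || stmt.ColumnType(colIdx) != sqlite.TypeText {
			return nil, false, nil // not handled, try the next hook
		}

		ip := net.ParseIP(stmt.ColumnText(colIdx))
		if ip == nil {
			return nil, false, fmt.Errorf("invalid IP in column %s", stmt.ColumnName(colIdx))
		}

		return ip, true, nil // handled, runDecodeHooks stops here
	}
}

A hook like this would be appended to DecodeConfig.DecodeHooks, ahead of the built-in decodeBasic() fallback.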
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"log"
+	"reflect"
 	"testing"
 	"time"

@@ -82,6 +84,8 @@ func (ett *exampleTimeTypes) Equal(other interface{}) bool {
 	return ett.T.Equal(oett.T) && (ett.TP != nil && oett.TP != nil && ett.TP.Equal(*oett.TP)) || (ett.TP == nil && oett.TP == nil)
 }

+type myInt int
+
 type exampleTimeNano struct {
 	T time.Time `sqlite:",unixnano"`
 }

@@ -100,10 +104,11 @@ func Test_Decoder(t *testing.T) {
 	refTime := time.Date(2022, time.February, 15, 9, 51, 00, 00, time.UTC)

 	cases := []struct {
 		Desc      string
 		Stmt      testStmt
+		ColumnDef []ColumnDef
 		Result    interface{}
 		Expected  interface{}
 	}{
 		{
 			"Decoding into nil is not allowed",

@@ -114,6 +119,7 @@ func Test_Decoder(t *testing.T) {
 			},
 			nil,
 			nil,
+			nil,
 		},
 		{
 			"Decoding into basic types",

@@ -132,6 +138,7 @@
 					true,
 				},
 			},
+			nil,
 			&exampleFieldTypes{},
 			&exampleFieldTypes{
 				S: "string value",

@@ -157,6 +164,7 @@
 					1.2,
 				},
 			},
+			nil,
 			&exampleFieldTypes{},
 			&exampleFieldTypes{
 				S: "string value",

@@ -178,6 +186,7 @@
 					true,
 				},
 			},
+			nil,
 			&exampleFieldTypes{},
 			&exampleFieldTypes{
 				F: 1.2,

@@ -201,6 +210,7 @@
 					true,
 				},
 			},
+			nil,
 			&examplePointerTypes{},
 			func() interface{} {
 				s := "string value"

@@ -231,6 +241,7 @@
 					true,
 				},
 			},
+			nil,
 			&examplePointerTypes{},
 			func() interface{} {
 				s := "string value"

@@ -255,6 +266,7 @@
 					1,
 				},
 			},
+			nil,
 			&exampleStructTags{},
 			&exampleStructTags{
 				S: "string value",

@@ -280,6 +292,7 @@
 					1,
 				},
 			},
+			nil,
 			&exampleIntConv{},
 			&exampleIntConv{
 				1, 1, 1, 1, 1,

@@ -301,6 +314,7 @@
 					1.0,
 				},
 			},
+			nil,
 			&exampleFieldTypes{},
 			&exampleFieldTypes{
 				F: 1.0,

@@ -322,6 +336,7 @@
 					1.0,
 				},
 			},
+			nil,
 			&examplePointerTypes{},
 			func() interface{} {
 				f := 1.0

@@ -340,6 +355,7 @@
 					([]byte)("hello world"),
 				},
 			},
+			nil,
 			&exampleBlobTypes{},
 			&exampleBlobTypes{
 				B: ([]byte)("hello world"),

@@ -356,6 +372,7 @@
 					([]byte)("hello world"),
 				},
 			},
+			nil,
 			&exampleJSONRawTypes{},
 			&exampleJSONRawTypes{
 				B: (json.RawMessage)("hello world"),

@@ -374,6 +391,7 @@
 					int(refTime.Unix()),
 				},
 			},
+			nil,
 			&exampleTimeTypes{},
 			&exampleTimeTypes{
 				T: refTime,

@@ -393,6 +411,7 @@
 					int(refTime.UnixNano()),
 				},
 			},
+			nil,
 			&exampleTimeNano{},
 			&exampleTimeNano{
 				T: refTime,

@@ -411,6 +430,7 @@
 					"value2",
 				},
 			},
+			nil,
 			&exampleInterface{},
 			func() interface{} {
 				var x interface{}

@@ -439,6 +459,7 @@
 					[]byte("blob value"),
 				},
 			},
+			nil,
 			new(map[string]interface{}),
 			&map[string]interface{}{
 				"I": 1,

@@ -447,14 +468,91 @@ func Test_Decoder(t *testing.T) {
 				"B": []byte("blob value"),
 			},
 		},
+		{
+			"Decoding using type-hints",
+			testStmt{
+				columns: []string{"B", "T"},
+				types: []sqlite.ColumnType{
+					sqlite.TypeInteger,
+					sqlite.TypeText,
+				},
+				values: []interface{}{
+					true,
+					refTime.Format(SqliteTimeFormat),
+				},
+			},
+			[]ColumnDef{
+				{
+					Name:   "B",
+					Type:   sqlite.TypeInteger,
+					GoType: reflect.TypeOf(true),
+				},
+				{
+					Name:   "T",
+					Type:   sqlite.TypeText,
+					GoType: reflect.TypeOf(time.Time{}),
+					IsTime: true,
+				},
+			},
+			new(map[string]interface{}),
+			&map[string]interface{}{
+				"B": true,
+				"T": refTime,
+			},
+		},
+		{
+			"Decoding into type aliases",
+			testStmt{
+				columns: []string{"B"},
+				types: []sqlite.ColumnType{
+					sqlite.TypeBlob,
+				},
+				values: []interface{}{
+					[]byte(`{"foo": "bar}`),
+				},
+			},
+			[]ColumnDef{
+				{
+					Name:   "B",
+					Type:   sqlite.TypeBlob,
+					GoType: reflect.TypeOf(json.RawMessage(`{"foo": "bar}`)),
+				},
+			},
+			new(map[string]interface{}),
+			&map[string]interface{}{
+				"B": json.RawMessage(`{"foo": "bar}`),
+			},
+		},
+		{
+			"Decoding into type aliases #2",
+			testStmt{
+				columns: []string{"I"},
+				types:   []sqlite.ColumnType{sqlite.TypeInteger},
+				values: []interface{}{
+					10,
+				},
+			},
+			[]ColumnDef{
+				{
+					Name:   "I",
+					Type:   sqlite.TypeInteger,
+					GoType: reflect.TypeOf(myInt(0)),
+				},
+			},
+			new(map[string]interface{}),
+			&map[string]interface{}{
+				"I": myInt(10),
+			},
+		},
 	}

 	for idx := range cases {
 		c := cases[idx]
 		t.Run(c.Desc, func(t *testing.T) {
-			t.Parallel()
+			//t.Parallel()

-			err := DecodeStmt(ctx, c.Stmt, c.Result, DefaultDecodeConfig)
+			log.Println(c.Desc)
+			err := DecodeStmt(ctx, &TableSchema{Columns: c.ColumnDef}, c.Stmt, c.Result, DefaultDecodeConfig)
 			if fn, ok := c.Expected.(func() interface{}); ok {
 				c.Expected = fn()
 			}
@@ -20,6 +20,7 @@ type (
 		NamedArgs    map[string]interface{}
 		Result       interface{}
 		DecodeConfig DecodeConfig
+		Schema       TableSchema
 	}
 )

@@ -56,6 +57,12 @@ func WithNamedArgs(args map[string]interface{}) QueryOption {
 	}
 }

+func WithSchema(tbl TableSchema) QueryOption {
+	return func(opts *queryOpts) {
+		opts.Schema = tbl
+	}
+}
+
 // WithResult sets the result receiver. result is expected to
 // be a pointer to a slice of struct or map types.
 //

@@ -136,7 +143,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q

 			currentField = reflect.New(valElemType)

-			if err := DecodeStmt(ctx, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
+			if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil {
 				return err
 			}
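Aside: a minimal sketch of how a caller might combine the new WithSchema option with RunQuery so that map results benefit from the per-column type hints. The helper name, the package placement, the SQL text, and the sqlite import path are assumptions; only RunQuery, WithNamedArgs, WithSchema, and WithResult come from the package itself.

package netquery // hypothetical placement; adjust to wherever the helper lives

import (
	"context"

	"github.com/safing/portmaster/netquery/orm"
	"zombiezen.com/go/sqlite" // assumed import path of the sqlite bindings
)

// countPerProfile is a hypothetical helper: it assumes an open *sqlite.Conn and
// a *orm.TableSchema describing the connections table built in this commit.
func countPerProfile(ctx context.Context, conn *sqlite.Conn, schema *orm.TableSchema) ([]map[string]interface{}, error) {
	var result []map[string]interface{}

	err := orm.RunQuery(
		ctx, conn,
		"SELECT profile, COUNT(*) AS total FROM connections WHERE started >= :since GROUP BY profile",
		orm.WithNamedArgs(map[string]interface{}{
			":since": "2022-02-15 09:51:00", // illustrative value in the stored text time format
		}),
		orm.WithSchema(*schema), // column type hints for the decoder
		orm.WithResult(&result), // decoded into one map per row
	)

	return result, err
}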
@@ -45,6 +45,7 @@ type (
 		Name          string
 		Nullable      bool
 		Type          sqlite.ColumnType
+		GoType        reflect.Type
 		Length        int
 		PrimaryKey    bool
 		AutoIncrement bool

@@ -145,6 +146,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) {
 		ft = fieldType.Type.Elem()
 	}

+	def.GoType = ft
 	kind := normalizeKind(ft.Kind())

 	switch kind {
@@ -16,6 +16,12 @@ import (
 type (
 	Query map[string][]Matcher

+	MatchType interface {
+		Operator() string
+	}
+
+	Equal interface{}
+
 	Matcher struct {
 		Equal    interface{} `json:"$eq,omitempty"`
 		NotEqual interface{} `json:"$ne,omitempty"`

@@ -27,12 +33,22 @@ type (
 	Count struct {
 		As       string `json:"as"`
 		Field    string `json:"field"`
-		Distinct bool   `json:"distict"`
+		Distinct bool   `json:"distinct"`
 	}

+	Sum struct {
+		Condition Query  `json:"condition"`
+		As        string `json:"as"`
+		Distinct  bool   `json:"distinct"`
+	}
+
+	// NOTE: whenever adding support for new operators make sure
+	// to update UnmarshalJSON as well.
 	Select struct {
 		Field    string  `json:"field"`
-		Count    *Count  `json:"$count"`
+		Count    *Count  `json:"$count,omitempty"`
+		Sum      *Sum    `json:"$sum,omitempty"`
+		Distinct *string `json:"$distinct"`
 	}

 	Selects []Select

@@ -45,6 +61,11 @@ type (

 		selectedFields    []string
 		whitelistedFields []string
+		paramMap          map[string]interface{}
+	}
+
+	QueryActiveConnectionChartPayload struct {
+		Query Query `json:"query"`
 	}

 	OrderBy struct {

@@ -179,15 +200,15 @@ func (match Matcher) Validate() error {
 	return nil
 }

-func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+func (match Matcher) toSQLConditionClause(ctx context.Context, suffix string, conjunction string, colDef orm.ColumnDef, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
 	var (
 		queryParts []string
 		params     = make(map[string]interface{})
 		errs       = new(multierror.Error)
-		key        = fmt.Sprintf("%s%d", colDef.Name, idx)
+		key        = fmt.Sprintf("%s%s", colDef.Name, suffix)
 	)

-	add := func(operator, suffix string, values ...interface{}) {
+	add := func(operator, suffix string, list bool, values ...interface{}) {
 		var placeholder []string

 		for idx, value := range values {

@@ -204,7 +225,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
 			params[uniqKey] = encodedValue
 		}

-		if len(placeholder) == 1 {
+		if len(placeholder) == 1 && !list {
 			queryParts = append(queryParts, fmt.Sprintf("%s %s %s", colDef.Name, operator, placeholder[0]))
 		} else {
 			queryParts = append(queryParts, fmt.Sprintf("%s %s ( %s )", colDef.Name, operator, strings.Join(placeholder, ", ")))

@@ -212,23 +233,23 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
 	}

 	if match.Equal != nil {
-		add("=", "eq", match.Equal)
+		add("=", "eq", false, match.Equal)
 	}

 	if match.NotEqual != nil {
-		add("!=", "ne", match.NotEqual)
+		add("!=", "ne", false, match.NotEqual)
 	}

 	if match.In != nil {
-		add("IN", "in", match.In...)
+		add("IN", "in", true, match.In...)
 	}

 	if match.NotIn != nil {
-		add("NOT IN", "notin", match.NotIn...)
+		add("NOT IN", "notin", true, match.NotIn...)
 	}

 	if match.Like != "" {
-		add("LIKE", "like", match.Like)
+		add("LIKE", "like", false, match.Like)
 	}

 	if len(queryParts) == 0 {

@@ -244,7 +265,7 @@ func (match Matcher) toSQLConditionClause(ctx context.Context, idx int, conjunct
 	return "( " + strings.Join(queryParts, " "+conjunction+" ") + " )", params, errs.ErrorOrNil()
 }

-func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
+func (query Query) toSQLWhereClause(ctx context.Context, suffix string, m *orm.TableSchema, encoderConfig orm.EncodeConfig) (string, map[string]interface{}, error) {
 	if len(query) == 0 {
 		return "", nil, nil
 	}

@@ -279,7 +300,7 @@ func (query Query) toSQLWhereClause(ctx context.Context, m *orm.TableSchema, enc

 	queryParts := make([]string, len(values))
 	for idx, val := range values {
-		matcherQuery, params, err := val.toSQLConditionClause(ctx, idx, "AND", colDef, encoderConfig)
+		matcherQuery, params, err := val.toSQLConditionClause(ctx, fmt.Sprintf("%s%d", suffix, idx), "AND", colDef, encoderConfig)
 		if err != nil {
 			errs.Errors = append(errs.Errors,
 				fmt.Errorf("invalid matcher at index %d for column %s: %w", idx, colDef.Name, err),

@@ -359,8 +380,10 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {
 	// directly
 	if blob[0] == '{' {
 		var res struct {
 			Field    string  `json:"field"`
 			Count    *Count  `json:"$count"`
+			Sum      *Sum    `json:"$sum"`
+			Distinct *string `json:"$distinct"`
 		}

 		if err := json.Unmarshal(blob, &res); err != nil {

@@ -369,6 +392,8 @@ func (sel *Select) UnmarshalJSON(blob []byte) error {

 		sel.Count = res.Count
 		sel.Field = res.Field
+		sel.Distinct = res.Distinct
+		sel.Sum = res.Sum

 		if sel.Count != nil && sel.Count.As != "" {
 			if !charOnlyRegexp.MatchString(sel.Count.As) {
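Aside: a minimal sketch of the JSON shape the new $sum operator accepts next to the existing $count, decoded through Select.UnmarshalJSON above. The concrete field values and the surrounding function are illustrative assumptions; only the Selects, Select, Sum, Query, and Matcher types and their JSON tags come from this commit.

package netquery

import (
	"encoding/json"
	"fmt"
)

// exampleSumSelect is a hypothetical snippet: count all rows and, via $sum over a
// condition, count how many of them were allowed.
func exampleSumSelect() {
	raw := []byte(`[
		{"$count": {"as": "total", "field": "*"}},
		{"$sum": {"condition": {"allowed": [{"$eq": true}]}, "as": "accepted"}}
	]`)

	var sel Selects
	if err := json.Unmarshal(raw, &sel); err != nil {
		panic(err)
	}

	fmt.Println(sel[1].Sum.As) // prints "accepted"
}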
@@ -58,6 +58,7 @@ func (qh *QueryHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
 		query,
 		orm.WithNamedArgs(paramMap),
 		orm.WithResult(&result),
+		orm.WithSchema(*qh.Database.Schema),
 	); err != nil {
 		http.Error(resp, "Failed to execute query: "+err.Error(), http.StatusInternalServerError)

@@ -139,13 +140,14 @@ func (qh *QueryHandler) parseRequest(req *http.Request) (*QueryRequestPayload, e
 }

 func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.TableSchema) (string, map[string]interface{}, error) {
-	if err := req.prepareSelectedFields(schema); err != nil {
+	if err := req.prepareSelectedFields(ctx, schema); err != nil {
 		return "", nil, fmt.Errorf("perparing selected fields: %w", err)
 	}

 	// build the SQL where clause from the payload query
 	whereClause, paramMap, err := req.Query.toSQLWhereClause(
 		ctx,
+		"",
 		schema,
 		orm.DefaultEncodeConfig,
 	)

@@ -153,6 +155,14 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
 		return "", nil, fmt.Errorf("ganerating where clause: %w", err)
 	}

+	if req.paramMap == nil {
+		req.paramMap = make(map[string]interface{})
+	}
+
+	for key, val := range paramMap {
+		req.paramMap[key] = val
+	}
+
 	// build the actual SQL query statement
 	// FIXME(ppacher): add support for group-by and sort-by

@@ -173,20 +183,26 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab
 	}
 	query += " " + groupByClause + " " + orderByClause

-	return query, paramMap, nil
+	return query, req.paramMap, nil
 }

-func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) error {
-	for _, s := range req.Select {
+func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error {
+	for idx, s := range req.Select {
 		var field string
-		if s.Count != nil {
+
+		switch {
+		case s.Count != nil:
 			field = s.Count.Field
-		} else {
+		case s.Distinct != nil:
+			field = *s.Distinct
+		case s.Sum != nil:
+			// field is not used in case of $sum
+			field = "*"
+		default:
 			field = s.Field
 		}

 		colName := "*"
-		if field != "*" || s.Count == nil {
+		if field != "*" || (s.Count == nil && s.Sum == nil) {
 			var err error

 			colName, err = req.validateColumnName(schema, field)

@@ -195,7 +211,8 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
 			}
 		}

-		if s.Count != nil {
+		switch {
+		case s.Count != nil:
 			var as = s.Count.As
 			if as == "" {
 				as = fmt.Sprintf("%s_count", colName)

@@ -204,9 +221,34 @@ func (req *QueryRequestPayload) prepareSelectedFields(schema *orm.TableSchema) e
 			if s.Count.Distinct {
 				distinct = "DISTINCT "
 			}
-			req.selectedFields = append(req.selectedFields, fmt.Sprintf("COUNT(%s%s) as %s", distinct, colName, as))
+			req.selectedFields = append(
+				req.selectedFields,
+				fmt.Sprintf("COUNT(%s%s) AS %s", distinct, colName, as),
+			)
 			req.whitelistedFields = append(req.whitelistedFields, as)
-		} else {
+
+		case s.Sum != nil:
+			if s.Sum.As == "" {
+				return fmt.Errorf("missing 'as' for $sum")
+			}
+
+			clause, params, err := s.Sum.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig)
+			if err != nil {
+				return fmt.Errorf("in $sum: %w", err)
+			}
+
+			req.paramMap = params
+			req.selectedFields = append(
+				req.selectedFields,
+				fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As),
+			)
+			req.whitelistedFields = append(req.whitelistedFields, s.Sum.As)
+
+		case s.Distinct != nil:
+			req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName))
+			req.whitelistedFields = append(req.whitelistedFields, colName)
+
+		default:
 			req.selectedFields = append(req.selectedFields, colName)
 		}
 	}

@@ -251,6 +293,10 @@ func (req *QueryRequestPayload) generateSelectClause() string {
 }

 func (req *QueryRequestPayload) generateOrderByClause(schema *orm.TableSchema) (string, error) {
+	if len(req.OrderBy) == 0 {
+		return "", nil
+	}
+
 	var orderBys = make([]string, len(req.OrderBy))
 	for idx, sort := range req.OrderBy {
 		colName, err := req.validateColumnName(schema, sort.Field)

@@ -286,7 +332,7 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel
 		}
 	}

-	return "", fmt.Errorf("column name %s not allowed", field)
+	return "", fmt.Errorf("column name %q not allowed", field)
 }

 // compile time check
@@ -228,7 +228,7 @@ func Test_QueryBuilder(t *testing.T) {
 	for idx, c := range cases {
 		t.Run(c.N, func(t *testing.T) {
 			//t.Parallel()
-			str, params, err := c.Q.toSQLWhereClause(context.TODO(), tbl, orm.DefaultEncodeConfig)
+			str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig)

 			if c.E != nil {
 				if assert.Error(t, err) {