safing-portmaster/service/netquery/orm/decoder.go
2024-03-27 16:17:58 +01:00

483 lines
14 KiB
Go

package orm
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"reflect"
"strings"
"time"
"zombiezen.com/go/sqlite"
)
// Sentinel errors used by the orm package. They are wrapped with additional
// context before being returned, so callers should test for them with errors.Is.

// errStructExpected carries the message for attempts to encode something that
// is not a struct.
var errStructExpected = errors.New("encode: can only encode structs to maps")

// errStructPointerExpected is wrapped by DecodeStmt when the decode target is
// neither a pointer to a struct nor a *map[string]interface{}.
var errStructPointerExpected = errors.New("decode: result must be pointer to a struct type or map[string]interface{}")

// errUnexpectedColumnType is wrapped by the basic decoder when the SQLite
// storage class of a column does not match the kind of the target Go type.
var errUnexpectedColumnType = errors.New("decode: unexpected column type")
// SqliteTimeFormat is the time layout used when transforming time.Time values
// to and from SQLite TEXT columns. It is the string representation expected by
// the SQLite DATETIME functions.
//
// SQLite itself has no DATETIME column type; dates and times are stored as
// INTEGER, TEXT or REAL instead. This package supports time.Time stored as
// TEXT (interpreted in a preconfigured timezone, UTC by default) or as INTEGER,
// where the user can choose between unix-epoch seconds and unix-epoch
// nanoseconds (the nanosecond variant is not officially supported by SQLite).
var SqliteTimeFormat = "2006-01-02 15:04:05"
// Stmt describes the column accessors a statement must provide so that its
// current result row can be decoded with DecodeStmt. Every accessor takes the
// zero-based index of the column to read. This interface is implemented by
// *sqlite.Stmt.
type Stmt interface {
	ColumnCount() int
	ColumnName(col int) string
	ColumnType(col int) sqlite.ColumnType
	ColumnText(col int) string
	ColumnBool(col int) bool
	ColumnFloat(col int) float64
	ColumnInt(col int) int
	ColumnReader(col int) *bytes.Reader
}

// DecodeFunc is a decoding hook. It is invoked for a single column (colIdx) of
// stmt and receives the column definition from the table schema (which may be
// nil), the struct field being populated and the value that will receive the
// result.
//
// It returns the decoded value, a flag reporting whether the hook handled the
// column, and an error. Hooks are tried in order: the first one that reports
// the column as handled, or that returns an error, ends the chain.
type DecodeFunc func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error)

// DecodeConfig holds the decoding hooks that are executed, in order, before
// the built-in decoder for basic types.
type DecodeConfig struct {
	DecodeHooks []DecodeFunc
}
// DecodeStmt decodes the current result row loaded in stmt into result.
//
// result must be either a non-nil *map[string]interface{} or a non-nil pointer
// to a struct; anything else yields an error wrapping errStructPointerExpected.
// A nil pointer is rejected rather than allocated: the pointer is passed by
// value, so newly allocated storage could never be handed back to the caller.
//
// For struct targets, every column is matched against the exported fields of
// the struct using the name from the `sqlite` struct tag (or the Go field name
// if there is none). Columns without a matching field are skipped, and columns
// reported as NULL leave the field untouched. Pointer fields always receive
// freshly allocated storage, so decoding repeatedly into the same struct never
// writes through a pointer that a previously decoded copy may still hold.
//
// Decoding hooks configured in cfg are executed before trying to decode basic
// types and may be specified to provide support for special types.
// See DatetimeDecoder() for an example of a DecodeHook that handles graceful
// time.Time conversion.
func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result interface{}, cfg DecodeConfig) error {
	// make sure we got something to decode into ...
	if result == nil {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	// fast path for decoding into a map
	if mp, ok := result.(*map[string]interface{}); ok {
		if mp == nil {
			return fmt.Errorf("%w, got nil %T", errStructPointerExpected, result)
		}

		return decodeIntoMap(ctx, schema, stmt, mp, cfg)
	}

	// make sure we got a pointer to a struct type. The static type is
	// inspected (instead of the pointed-to value) so that a nil pointer
	// cannot trip up reflection.
	resultType := reflect.TypeOf(result)
	if resultType.Kind() != reflect.Ptr || resultType.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("%w, got %T", errStructPointerExpected, result)
	}

	resultValue := reflect.ValueOf(result)
	if resultValue.IsNil() {
		return fmt.Errorf("%w, got nil %T", errStructPointerExpected, result)
	}

	// we need access to the struct directly and not to the pointer.
	t := resultType.Elem()
	target := resultValue.Elem()

	// create a lookup map from column name (field name or sqlite:"" tag)
	// to the index of the struct field.
	fields := make(map[string]int, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		fieldType := t.Field(i)

		// skip unexported fields
		if !fieldType.IsExported() {
			continue
		}

		fields[sqlColumnName(fieldType)] = i
	}

	// Build the hook chain once, with decodeBasic() as the last hook. The full
	// slice expression caps the capacity so append has to copy and can never
	// write into the backing array of the caller's cfg.DecodeHooks.
	hooks := append(cfg.DecodeHooks[:len(cfg.DecodeHooks):len(cfg.DecodeHooks)], decodeBasic())

	// iterate over all columns and assign them to the correct fields
	for i := 0; i < stmt.ColumnCount(); i++ {
		colName := stmt.ColumnName(i)

		fieldIdx, ok := fields[colName]
		if !ok {
			// there's no target field for this column so we can skip it
			continue
		}

		// if the column is reported as NULL we keep the field as it is.
		// TODO(ppacher): should we set it to nil here?
		colType := stmt.ColumnType(i)
		if colType == sqlite.TypeNull {
			continue
		}

		fieldType := t.Field(fieldIdx)
		value := target.Field(fieldIdx)

		// pointer fields get fresh storage and the decoded value is written
		// to the dereferenced target.
		if value.Kind() == reflect.Ptr {
			storage := reflect.New(fieldType.Type.Elem())
			value.Set(storage)
			value = storage.Elem()
		}

		columnValue, err := runDecodeHooks(
			i,
			schema.GetColumnDef(colName),
			stmt,
			fieldType,
			value,
			hooks,
		)
		if err != nil {
			return err
		}

		// no hook (including the basic decoder) produced a value
		if columnValue == nil {
			return fmt.Errorf("cannot decode column %d (type=%s)", i, colType)
		}

		// convert it to the target type if conversion is possible
		newValue := reflect.ValueOf(columnValue)
		if newValue.Type().ConvertibleTo(value.Type()) {
			newValue = newValue.Convert(value.Type())
		}

		// reflect.Value.Set panics on non-assignable values, so report a
		// mismatch (e.g. from a misbehaving hook) as an error instead.
		if !newValue.Type().AssignableTo(value.Type()) {
			return fmt.Errorf("cannot assign decoded value of type %s to field %s of type %s (column %s)", newValue.Type(), fieldType.Name, value.Type(), colName)
		}

		// assign the new value to the struct field.
		value.Set(newValue)
	}

	return nil
}
// DatetimeDecoder returns a DecodeFunc that decodes the SQLite INTEGER or TEXT
// storage classes into time.Time.
//
// The hook only handles columns whose target type is time.Time. If a column
// definition is available, its GoType is used as the target type and the
// definition must additionally be flagged IsTime; otherwise the hook declines
// the column so that later hooks can handle it.
//
//   - INTEGER values are interpreted as Unix epoch seconds, or as Unix epoch
//     nanoseconds if the `sqlite` struct tag of the field carries the
//     "unixnano" option (e.g. `sqlite:"created,unixnano"`).
//   - TEXT values are parsed using SqliteTimeFormat. They carry no timezone
//     information and are therefore always interpreted in loc. A nil loc
//     defaults to time.UTC.
//   - REAL values (Julian day numbers) and all other storage classes are
//     rejected with an error.
//   - NULL is reported as handled with a nil value.
func DatetimeDecoder(loc *time.Location) DecodeFunc {
	// time.ParseInLocation panics when called with a nil location.
	if loc == nil {
		loc = time.UTC
	}

	timeType := reflect.TypeOf(time.Time{})

	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (interface{}, bool, error) {
		// if we have the column definition available we
		// use the target go type from there.
		outType := outval.Type()
		if colDef != nil {
			outType = colDef.GoType
		}

		// we only care about time.Time here
		if outType != timeType || (colDef != nil && !colDef.IsTime) {
			return nil, false, nil
		}

		colType := stmt.ColumnType(colIdx)

		switch colType { //nolint:exhaustive // Only selecting specific types.
		case sqlite.TypeInteger:
			// NOTE: Stmt only exposes ColumnInt, which returns an int. Where int
			// is 32 bits wide the value cannot hold a nanosecond timestamp.
			ts := int64(stmt.ColumnInt(colIdx))

			// stored as unix-epoch; if the "unixnano" option is set in the struct
			// field tag we parse it with nano-second resolution. The first tag
			// element is the column name, only the remaining ones are options.
			tagParts := strings.Split(fieldDef.Tag.Get("sqlite"), ",")
			for _, option := range tagParts[1:] {
				if option == "unixnano" {
					return time.Unix(0, ts), true, nil
				}
			}

			return time.Unix(ts, 0), true, nil

		case sqlite.TypeText:
			// stored ISO8601 but does not have any timezone information
			// assigned so we always treat it as loc here.
			text := stmt.ColumnText(colIdx)

			t, err := time.ParseInLocation(SqliteTimeFormat, text, loc)
			if err != nil {
				return nil, false, fmt.Errorf("failed to parse %q in %s: %w", text, fieldDef.Name, err)
			}

			return t, true, nil

		case sqlite.TypeFloat:
			// stored as Julian day numbers
			return nil, false, errors.New("REAL storage type not support for time.Time")

		case sqlite.TypeNull:
			return nil, true, nil

		default:
			return nil, false, fmt.Errorf("unsupported storage type for time.Time: %s", colType)
		}
	}
}
// decodeIntoMap decodes every column of the current result row of stmt into
// *mp, keyed by column name. The map is allocated if *mp is nil; mp itself must
// not be nil.
//
// If schema knows a column, the value is decoded towards the Go type of the
// column definition; otherwise it is decoded into an interface{} and the
// resulting Go type follows the SQLite storage class. The hooks in cfg are
// tried first, followed by the basic decoder. The first decoding error aborts
// the row and is returned wrapped with the name of the offending column.
func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[string]interface{}, cfg DecodeConfig) error {
	if *mp == nil {
		*mp = make(map[string]interface{})
	}

	// Build the hook chain once, with decodeBasic() as the last hook. The full
	// slice expression caps the capacity so append has to copy and can never
	// write into the backing array of the caller's cfg.DecodeHooks.
	hooks := append(cfg.DecodeHooks[:len(cfg.DecodeHooks):len(cfg.DecodeHooks)], decodeBasic())

	for i := 0; i < stmt.ColumnCount(); i++ {
		colName := stmt.ColumnName(i)
		colDef := schema.GetColumnDef(colName)

		// without a column definition we decode into a plain interface{}
		var (
			x         interface{}
			outVal    = reflect.ValueOf(&x).Elem()
			fieldType reflect.StructField
		)

		if colDef != nil {
			outVal = reflect.New(colDef.GoType).Elem()
			fieldType = reflect.StructField{
				Type: colDef.GoType,
			}
		}

		val, err := runDecodeHooks(i, colDef, stmt, fieldType, outVal, hooks)
		if err != nil {
			return fmt.Errorf("failed to decode column %s: %w", colName, err)
		}

		(*mp)[colName] = val
	}

	return nil
}
// decodeBasic returns the DecodeFunc that handles the basic Go kinds. It is
// meant to be the last hook of a decode chain.
//
// The kind to decode into is taken from the column definition if one is
// available and from outval otherwise; integer and float kinds are normalized
// first. Supported kinds and the storage class they require:
//
//   - string: TEXT
//   - bool: INTEGER (SQLite has no BOOL type)
//   - floats: REAL
//   - signed and unsigned integers: INTEGER
//   - []byte (and other slices of bytes): BLOB
//   - interface: any storage class; the Go type follows the storage class
//
// A storage class that does not match yields an error wrapping
// errUnexpectedColumnType. NULL columns yield nil for nullable column
// definitions and pointer targets, and the zero value of the Go type for
// non-nullable column definitions.
//
// With a column definition, non-zero results are converted to the exact Go
// type of the model. This ensures, for example, that a []byte is decoded into
// json.RawMessage, or that a defined type such as `type myInt int` is decoded
// into myInt instead of int. Nil and zero results are returned unconverted.
func decodeBasic() DecodeFunc {
	return func(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value) (result interface{}, handled bool, err error) {
		valueKind := getKind(outval)
		targetType := outval.Type()
		colType := stmt.ColumnType(colIdx)
		colName := stmt.ColumnName(colIdx)

		// The mismatch error is only built when it is actually returned; this
		// function runs for every column of every row.
		invalidType := func() error {
			return fmt.Errorf("%w %s for column %s with field type %s", errUnexpectedColumnType, colType.String(), colName, outval.Type())
		}

		// if we have the column definition available we
		// use the target go type from there.
		if colDef != nil {
			valueKind = NormalizeKind(colDef.GoType.Kind())
			targetType = colDef.GoType

			// convert the decoded value to the actual Go type that was used in
			// the model once the switch below has produced a result.
			defer func() {
				if !handled || result == nil {
					return
				}

				decoded := reflect.ValueOf(result)
				if decoded.IsZero() {
					return
				}

				// Report impossible conversions as an error instead of letting
				// reflection panic.
				if !decoded.Type().ConvertibleTo(colDef.GoType) {
					result, handled = nil, false
					err = fmt.Errorf("cannot convert %s value of column %s to %s", decoded.Type(), colName, colDef.GoType)

					return
				}

				result = decoded.Convert(colDef.GoType).Interface()
			}()
		}

		if colType == sqlite.TypeNull {
			switch {
			case colDef != nil && colDef.Nullable:
				return nil, true, nil
			case colDef != nil:
				// not nullable: fall back to the zero value of the model type
				return reflect.New(colDef.GoType).Elem().Interface(), true, nil
			case outval.Kind() == reflect.Ptr:
				return nil, true, nil
			}
		}

		switch valueKind { //nolint:exhaustive
		case reflect.String:
			if colType != sqlite.TypeText {
				return nil, false, invalidType()
			}

			return stmt.ColumnText(colIdx), true, nil

		case reflect.Bool:
			// sqlite does not have a BOOL type, it rather stores a 1/0 in a column
			// with INTEGER affinity.
			if colType != sqlite.TypeInteger {
				return nil, false, invalidType()
			}

			return stmt.ColumnBool(colIdx), true, nil

		case reflect.Float64:
			if colType != sqlite.TypeFloat {
				return nil, false, invalidType()
			}

			return stmt.ColumnFloat(colIdx), true, nil

		case reflect.Int, reflect.Uint: // all ints are normalized to reflect.Int/Uint because sqlite doesn't really care ...
			if colType != sqlite.TypeInteger {
				return nil, false, invalidType()
			}

			return stmt.ColumnInt(colIdx), true, nil

		case reflect.Slice:
			// targetType is the type valueKind was derived from, so it is
			// guaranteed to be a slice type here and Elem() cannot panic.
			if targetType.Elem().Kind() != reflect.Uint8 {
				return nil, false, errors.New("slices other than []byte for BLOB are not supported")
			}

			if colType != sqlite.TypeBlob {
				return nil, false, invalidType()
			}

			blob, readErr := io.ReadAll(stmt.ColumnReader(colIdx))
			if readErr != nil {
				return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, readErr)
			}

			return blob, true, nil

		case reflect.Interface:
			// no concrete target type: pick the Go type that matches the
			// storage class of the column.
			switch colType {
			case sqlite.TypeBlob:
				blob, readErr := io.ReadAll(stmt.ColumnReader(colIdx))
				if readErr != nil {
					return nil, false, fmt.Errorf("failed to read blob for column %s: %w", fieldDef.Name, readErr)
				}

				return blob, true, nil

			case sqlite.TypeFloat:
				return stmt.ColumnFloat(colIdx), true, nil

			case sqlite.TypeInteger:
				return stmt.ColumnInt(colIdx), true, nil

			case sqlite.TypeText:
				return stmt.ColumnText(colIdx), true, nil

			case sqlite.TypeNull:
				return nil, true, nil

			default:
				return nil, false, fmt.Errorf("unsupported column type %s", colType)
			}

		default:
			return nil, false, fmt.Errorf("cannot decode into %s", valueKind)
		}
	}
}
// sqlColumnName returns the SQL column name for the given struct field. That
// is the first comma-separated element of the field's `sqlite` struct tag or,
// if the tag is missing or that element is empty, the Go name of the field.
func sqlColumnName(fieldType reflect.StructField) string {
	// Tag.Get yields "" for a missing tag, which falls back to the field name
	// just like an explicitly empty name does.
	name := fieldType.Tag.Get("sqlite")

	// everything after the first comma are options, not part of the name
	if comma := strings.IndexByte(name, ','); comma >= 0 {
		name = name[:comma]
	}

	if name == "" {
		return fieldType.Name
	}

	return name
}
// runDecodeHooks decodes the column at index colIdx of stmt by calling hooks
// one after another, in order, with the given column definition, field
// definition and output value.
//
// The chain ends at the first hook that either reports the column as handled
// or returns an error; the value (and error) of that hook is returned and the
// remaining hooks are not executed. If no hook handles the column, a nil value
// and a nil error are returned.
func runDecodeHooks(colIdx int, colDef *ColumnDef, stmt Stmt, fieldDef reflect.StructField, outval reflect.Value, hooks []DecodeFunc) (interface{}, error) {
	for idx := 0; idx < len(hooks); idx++ {
		value, handled, err := hooks[idx](colIdx, colDef, stmt, fieldDef, outval)

		// err is nil when the hook merely reported the column as handled.
		if err != nil || handled {
			return value, err
		}
	}

	return nil, nil
}
// getKind returns the kind of val after passing it through NormalizeKind, so
// that the Int, Uint and Float variants are reported as their base kind.
func getKind(val reflect.Value) reflect.Kind {
	return NormalizeKind(val.Kind())
}
// NormalizeKind collapses the sized variants of a numeric kind into one
// representative kind: all signed integer kinds become reflect.Int, all
// unsigned integer kinds (not including reflect.Uintptr) become reflect.Uint
// and both float kinds become reflect.Float64. Every other kind is returned
// unchanged.
func NormalizeKind(kind reflect.Kind) reflect.Kind {
	switch kind { //nolint:exhaustive // Everything else is returned as is.
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return reflect.Int
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return reflect.Uint
	case reflect.Float32, reflect.Float64:
		return reflect.Float64
	}

	return kind
}
// DefaultDecodeConfig is the default decoding configuration. Its only hook is
// a DatetimeDecoder created with time.UTC.
var DefaultDecodeConfig = DecodeConfig{DecodeHooks: []DecodeFunc{DatetimeDecoder(time.UTC)}}