mirror of
https://github.com/safing/portmaster
synced 2025-04-21 11:29:09 +00:00
259 lines
6.1 KiB
Go
259 lines
6.1 KiB
Go
package orm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"time"
|
|
|
|
"golang.org/x/exp/slices"
|
|
"zombiezen.com/go/sqlite"
|
|
)
|
|
|
|
type (
|
|
// EncodeFunc is called for each non-basic type during encoding.
|
|
EncodeFunc func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error)
|
|
|
|
// EncodeConfig holds encoding functions.
|
|
EncodeConfig struct {
|
|
EncodeHooks []EncodeFunc
|
|
}
|
|
)
|
|
|
|
// ToParamMap returns a map that contains the sqlite compatible value of each struct field of
|
|
// r using the sqlite column name as a map key. It either uses the name of the
|
|
// exported struct field or the value of the "sqlite" tag.
|
|
func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) {
|
|
// make sure we work on a struct type
|
|
val := reflect.Indirect(reflect.ValueOf(r))
|
|
if val.Kind() != reflect.Struct {
|
|
return nil, fmt.Errorf("%w, got %T", errStructExpected, r)
|
|
}
|
|
|
|
res := make(map[string]interface{}, val.NumField())
|
|
|
|
for i := 0; i < val.NumField(); i++ {
|
|
fieldType := val.Type().Field(i)
|
|
field := val.Field(i)
|
|
|
|
// skip unexported fields
|
|
if !fieldType.IsExported() {
|
|
continue
|
|
}
|
|
|
|
colDef, err := getColumnDef(fieldType)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err)
|
|
}
|
|
|
|
if slices.Contains(skipFields, colDef.Name) {
|
|
continue
|
|
}
|
|
|
|
x, found, err := runEncodeHooks(
|
|
colDef,
|
|
fieldType.Type,
|
|
field,
|
|
append(
|
|
cfg.EncodeHooks,
|
|
encodeBasic(),
|
|
),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
|
}
|
|
|
|
if !found {
|
|
if reflect.Indirect(field).IsValid() {
|
|
x = reflect.Indirect(field).Interface()
|
|
}
|
|
}
|
|
|
|
res[keyPrefix+sqlColumnName(fieldType)] = x
|
|
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// EncodeValue encodes the given value.
|
|
func EncodeValue(ctx context.Context, colDef *ColumnDef, val interface{}, cfg EncodeConfig) (interface{}, error) {
|
|
fieldValue := reflect.ValueOf(val)
|
|
fieldType := reflect.TypeOf(val)
|
|
|
|
x, found, err := runEncodeHooks(
|
|
colDef,
|
|
fieldType,
|
|
fieldValue,
|
|
append(
|
|
cfg.EncodeHooks,
|
|
encodeBasic(),
|
|
),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to run encode hooks: %w", err)
|
|
}
|
|
|
|
if !found {
|
|
if reflect.Indirect(fieldValue).IsValid() {
|
|
x = reflect.Indirect(fieldValue).Interface()
|
|
}
|
|
}
|
|
|
|
return x, nil
|
|
}
|
|
|
|
func encodeBasic() EncodeFunc {
|
|
return func(col *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
|
kind := valType.Kind()
|
|
if kind == reflect.Ptr {
|
|
valType = valType.Elem()
|
|
kind = valType.Kind()
|
|
|
|
if val.IsNil() {
|
|
if !col.Nullable {
|
|
// we need to set the zero value here since the column
|
|
// is not marked as nullable
|
|
return reflect.New(valType).Elem().Interface(), true, nil
|
|
}
|
|
|
|
return nil, true, nil
|
|
}
|
|
|
|
val = val.Elem()
|
|
}
|
|
|
|
switch NormalizeKind(kind) { //nolint:exhaustive
|
|
case reflect.String,
|
|
reflect.Float64,
|
|
reflect.Bool,
|
|
reflect.Int,
|
|
reflect.Uint:
|
|
// sqlite package handles conversion of those types
|
|
// already
|
|
return val.Interface(), true, nil
|
|
|
|
case reflect.Slice:
|
|
if valType.Elem().Kind() == reflect.Uint8 {
|
|
// this is []byte
|
|
return val.Interface(), true, nil
|
|
}
|
|
fallthrough
|
|
|
|
default:
|
|
return nil, false, fmt.Errorf("cannot convert value of kind %s for use in SQLite", kind)
|
|
}
|
|
}
|
|
}
|
|
|
|
// DatetimeEncoder returns a new datetime encoder for the given time zone.
|
|
func DatetimeEncoder(loc *time.Location) EncodeFunc {
|
|
return func(colDef *ColumnDef, valType reflect.Type, val reflect.Value) (interface{}, bool, error) {
|
|
// if fieldType holds a pointer we need to dereference the value
|
|
ft := valType.String()
|
|
if valType.Kind() == reflect.Ptr {
|
|
ft = valType.Elem().String()
|
|
val = reflect.Indirect(val)
|
|
}
|
|
|
|
normalizedKind := NormalizeKind(valType.Kind())
|
|
|
|
// we only care about "time.Time" here
|
|
var t time.Time
|
|
switch {
|
|
case ft == "time.Time":
|
|
// handle the zero time as a NULL.
|
|
if !val.IsValid() || val.IsZero() {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var ok bool
|
|
valInterface := val.Interface()
|
|
t, ok = valInterface.(time.Time)
|
|
if !ok {
|
|
return nil, false, errors.New("cannot convert reflect value to time.Time")
|
|
}
|
|
|
|
case valType.Kind() == reflect.String && colDef.IsTime:
|
|
var err error
|
|
t, err = time.Parse(time.RFC3339, val.String())
|
|
if err != nil {
|
|
return nil, false, fmt.Errorf("failed to parse time as RFC3339: %w", err)
|
|
}
|
|
|
|
case (normalizedKind == reflect.Int || normalizedKind == reflect.Uint || normalizedKind == reflect.Float64) && colDef.IsTime:
|
|
seconds := int64(0)
|
|
switch normalizedKind { //nolint:exhaustive // Previous switch case assures these types.
|
|
case reflect.Int:
|
|
seconds = val.Int()
|
|
case reflect.Uint:
|
|
seconds = int64(val.Uint())
|
|
case reflect.Float64:
|
|
seconds = int64(val.Float())
|
|
}
|
|
|
|
t = time.Unix(seconds, 0)
|
|
|
|
default:
|
|
// we don't care ...
|
|
return nil, false, nil
|
|
}
|
|
|
|
switch colDef.Type { //nolint:exhaustive
|
|
case sqlite.TypeInteger:
|
|
if colDef.UnixNano {
|
|
return t.UnixNano(), true, nil
|
|
}
|
|
return t.Unix(), true, nil
|
|
|
|
case sqlite.TypeText:
|
|
str := t.In(loc).Format(SqliteTimeFormat)
|
|
|
|
return str, true, nil
|
|
}
|
|
|
|
return nil, false, fmt.Errorf("cannot store time.Time in %s", colDef.Type)
|
|
}
|
|
}
|
|
|
|
func runEncodeHooks(colDef *ColumnDef, valType reflect.Type, val reflect.Value, hooks []EncodeFunc) (interface{}, bool, error) {
|
|
if valType == nil {
|
|
if !colDef.Nullable {
|
|
switch colDef.Type { //nolint:exhaustive
|
|
case sqlite.TypeBlob:
|
|
return []byte{}, true, nil
|
|
case sqlite.TypeFloat:
|
|
return 0.0, true, nil
|
|
case sqlite.TypeText:
|
|
return "", true, nil
|
|
case sqlite.TypeInteger:
|
|
return 0, true, nil
|
|
default:
|
|
return nil, false, fmt.Errorf("unsupported sqlite data type: %s", colDef.Type)
|
|
}
|
|
}
|
|
|
|
return nil, true, nil
|
|
}
|
|
|
|
for _, fn := range hooks {
|
|
res, end, err := fn(colDef, valType, val)
|
|
if err != nil {
|
|
return res, false, err
|
|
}
|
|
|
|
if end {
|
|
return res, true, nil
|
|
}
|
|
}
|
|
|
|
return nil, false, nil
|
|
}
|
|
|
|
// DefaultEncodeConfig holds the default encoding configuration.
|
|
var DefaultEncodeConfig = EncodeConfig{
|
|
EncodeHooks: []EncodeFunc{
|
|
DatetimeEncoder(time.UTC),
|
|
},
|
|
}
|