diff --git a/config/integration/api.go b/config/integration/api.go new file mode 100644 index 0000000..e6aa1b2 --- /dev/null +++ b/config/integration/api.go @@ -0,0 +1,3 @@ +package integration + +// API diff --git a/config/integration/module.go b/config/integration/module.go new file mode 100644 index 0000000..c9bba1b --- /dev/null +++ b/config/integration/module.go @@ -0,0 +1,3 @@ +package integration + +// register as module diff --git a/config/integration/persistence.go b/config/integration/persistence.go new file mode 100644 index 0000000..09cd938 --- /dev/null +++ b/config/integration/persistence.go @@ -0,0 +1,4 @@ +package integration + +// persist config file +// create callback function in config to get updates diff --git a/config/registry.go b/config/registry.go index b67c981..a998987 100644 --- a/config/registry.go +++ b/config/registry.go @@ -7,7 +7,7 @@ import ( "sync" ) -// Variable Type IDs for frontend Identification. Values over 100 are free for custom use. +// Variable Type IDs for frontend Identification. Values from 100 are free for custom use. const ( OptTypeString uint8 = 1 OptTypeStringArray uint8 = 2 diff --git a/database/accessor/accessor-json-bytes.go b/database/accessor/accessor-json-bytes.go new file mode 100644 index 0000000..08330d7 --- /dev/null +++ b/database/accessor/accessor-json-bytes.go @@ -0,0 +1,101 @@ +package accessor + +import ( + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// JSONBytesAccessor is a json string with get functions. +type JSONBytesAccessor struct { + json *[]byte +} + +// NewJSONBytesAccessor adds the Accessor interface to a JSON bytes string. +func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor { + return &JSONBytesAccessor{ + json: json, + } +} + +// Set sets the value identified by key. +func (ja *JSONBytesAccessor) Set(key string, value interface{}) error { + result := gjson.GetBytes(*ja.json, key) + if result.Exists() { + switch value.(type) { + case string: + if result.Type != gjson.String { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + if result.Type != gjson.Number { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + case bool: + if result.Type != gjson.True && result.Type != gjson.False { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + } + } + + new, err := sjson.SetBytes(*ja.json, key, value) + if err != nil { + return err + } + *ja.json = new + return nil +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.String { + return emptyString, false + } + return result.String(), true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetInt(key string) (value int64, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Int(), true +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetFloat(key string) (value float64, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Float(), true +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetBool(key string) (value bool, ok bool) { + result := gjson.GetBytes(*ja.json, key) + switch { + case !result.Exists(): + return false, false + case result.Type == gjson.True: + return true, true + case result.Type == gjson.False: + return false, true + default: + return false, false + } +} + +// Exists returns the whether the given key exists. +func (ja *JSONBytesAccessor) Exists(key string) bool { + result := gjson.GetBytes(*ja.json, key) + return result.Exists() +} + +// Type returns the accessor type as a string. +func (ja *JSONBytesAccessor) Type() string { + return "JSONBytesAccessor" +} diff --git a/database/accessor/accessor-json-string.go b/database/accessor/accessor-json-string.go new file mode 100644 index 0000000..1170418 --- /dev/null +++ b/database/accessor/accessor-json-string.go @@ -0,0 +1,101 @@ +package accessor + +import ( + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// JSONAccessor is a json string with get functions. +type JSONAccessor struct { + json *string +} + +// NewJSONAccessor adds the Accessor interface to a JSON string. +func NewJSONAccessor(json *string) *JSONAccessor { + return &JSONAccessor{ + json: json, + } +} + +// Set sets the value identified by key. +func (ja *JSONAccessor) Set(key string, value interface{}) error { + result := gjson.Get(*ja.json, key) + if result.Exists() { + switch value.(type) { + case string: + if result.Type != gjson.String { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + if result.Type != gjson.Number { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + case bool: + if result.Type != gjson.True && result.Type != gjson.False { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } + } + } + + new, err := sjson.Set(*ja.json, key, value) + if err != nil { + return err + } + *ja.json = new + return nil +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetString(key string) (value string, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.String { + return emptyString, false + } + return result.String(), true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetInt(key string) (value int64, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Int(), true +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetFloat(key string) (value float64, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Float(), true +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetBool(key string) (value bool, ok bool) { + result := gjson.Get(*ja.json, key) + switch { + case !result.Exists(): + return false, false + case result.Type == gjson.True: + return true, true + case result.Type == gjson.False: + return false, true + default: + return false, false + } +} + +// Exists returns the whether the given key exists. +func (ja *JSONAccessor) Exists(key string) bool { + result := gjson.Get(*ja.json, key) + return result.Exists() +} + +// Type returns the accessor type as a string. +func (ja *JSONAccessor) Type() string { + return "JSONAccessor" +} diff --git a/database/accessor/accessor-struct.go b/database/accessor/accessor-struct.go new file mode 100644 index 0000000..b56c9b7 --- /dev/null +++ b/database/accessor/accessor-struct.go @@ -0,0 +1,149 @@ +package accessor + +import ( + "errors" + "fmt" + "reflect" +) + +// StructAccessor is a json string with get functions. +type StructAccessor struct { + object reflect.Value +} + +// NewStructAccessor adds the Accessor interface to a JSON string. +func NewStructAccessor(object interface{}) *StructAccessor { + return &StructAccessor{ + object: reflect.ValueOf(object).Elem(), + } +} + +// Set sets the value identified by key. +func (sa *StructAccessor) Set(key string, value interface{}) error { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return errors.New("struct field does not exist") + } + if !field.CanSet() { + return fmt.Errorf("field %s or struct is immutable", field.String()) + } + + newVal := reflect.ValueOf(value) + + // set directly if type matches + if newVal.Kind() == field.Kind() { + field.Set(newVal) + return nil + } + + // handle special cases + switch field.Kind() { + + // ints + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var newInt int64 + switch newVal.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + newInt = newVal.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + newInt = int64(newVal.Uint()) + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + if field.OverflowInt(newInt) { + return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newInt) + } + field.SetInt(newInt) + + // uints + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var newUint uint64 + switch newVal.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + newUint = uint64(newVal.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + newUint = newVal.Uint() + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + if field.OverflowUint(newUint) { + return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newUint) + } + field.SetUint(newUint) + + // floats + case reflect.Float32, reflect.Float64: + switch newVal.Kind() { + case reflect.Float32, reflect.Float64: + field.SetFloat(newVal.Float()) + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + + return nil +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetString(key string) (value string, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.String { + return "", false + } + return field.String(), true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return 0, false + } + switch field.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return field.Int(), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(field.Uint()), true + default: + return 0, false + } +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetFloat(key string) (value float64, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return 0, false + } + switch field.Kind() { + case reflect.Float32, reflect.Float64: + return field.Float(), true + default: + return 0, false + } +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.Bool { + return false, false + } + return field.Bool(), true +} + +// Exists returns the whether the given key exists. +func (sa *StructAccessor) Exists(key string) bool { + field := sa.object.FieldByName(key) + if field.IsValid() { + return true + } + return false +} + +// Type returns the accessor type as a string. +func (sa *StructAccessor) Type() string { + return "StructAccessor" +} diff --git a/database/accessor/accessor.go b/database/accessor/accessor.go new file mode 100644 index 0000000..aedad26 --- /dev/null +++ b/database/accessor/accessor.go @@ -0,0 +1,18 @@ +package accessor + +const ( + emptyString = "" +) + +// Accessor provides an interface to supply the query matcher a method to retrieve values from an object. +type Accessor interface { + GetString(key string) (value string, ok bool) + GetInt(key string) (value int64, ok bool) + GetFloat(key string) (value float64, ok bool) + GetBool(key string) (value bool, ok bool) + Exists(key string) bool + + Set(key string, value interface{}) error + + Type() string +} diff --git a/database/accessor/accessor_test.go b/database/accessor/accessor_test.go new file mode 100644 index 0000000..d2a3a20 --- /dev/null +++ b/database/accessor/accessor_test.go @@ -0,0 +1,248 @@ +package accessor + +import ( + "encoding/json" + "testing" +) + +type TestStruct struct { + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +var ( + testStruct = &TestStruct{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + testJSONBytes, _ = json.Marshal(testStruct) + testJSON = string(testJSONBytes) +) + +func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue string) { + v, ok := acc.GetString(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get string with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get string with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) { + v, ok := acc.GetInt(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get int with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get int with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue float64) { + v, ok := acc.GetFloat(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get float with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get float with key %s, it returned %v", acc.Type(), key, v) + } + if int64(v) != int64(expectedValue) { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue bool) { + v, ok := acc.GetBool(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get bool with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get bool with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) { + ok := acc.Exists(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s should report key %s as existing", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should report key %s as non-existing", acc.Type(), key) + } +} + +func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) { + err := acc.Set(key, valueToSet) + switch { + case err != nil && shouldSucceed: + t.Errorf("%s failed to set %s to %+v: %s", acc.Type(), key, valueToSet, err) + case err == nil && !shouldSucceed: + t.Errorf("%s should have failed to set %s to %+v", acc.Type(), key, valueToSet) + } +} + +func TestAccessor(t *testing.T) { + + // Test interface compliance + accs := []Accessor{ + NewJSONAccessor(&testJSON), + NewJSONBytesAccessor(&testJSONBytes), + NewStructAccessor(testStruct), + } + + // get + for _, acc := range accs { + testGetString(t, acc, "S", true, "banana") + testGetInt(t, acc, "I", true, 42) + testGetInt(t, acc, "I8", true, 42) + testGetInt(t, acc, "I16", true, 42) + testGetInt(t, acc, "I32", true, 42) + testGetInt(t, acc, "I64", true, 42) + testGetInt(t, acc, "UI", true, 42) + testGetInt(t, acc, "UI8", true, 42) + testGetInt(t, acc, "UI16", true, 42) + testGetInt(t, acc, "UI32", true, 42) + testGetInt(t, acc, "UI64", true, 42) + testGetFloat(t, acc, "F32", true, 42.42) + testGetFloat(t, acc, "F64", true, 42.42) + testGetBool(t, acc, "B", true, true) + } + + // set + for _, acc := range accs { + testSet(t, acc, "S", true, "coconut") + testSet(t, acc, "I", true, uint32(44)) + testSet(t, acc, "I8", true, uint64(44)) + testSet(t, acc, "I16", true, uint8(44)) + testSet(t, acc, "I32", true, uint16(44)) + testSet(t, acc, "I64", true, 44) + testSet(t, acc, "UI", true, 44) + testSet(t, acc, "UI8", true, int64(44)) + testSet(t, acc, "UI16", true, int32(44)) + testSet(t, acc, "UI32", true, int8(44)) + testSet(t, acc, "UI64", true, int16(44)) + testSet(t, acc, "F32", true, 44.44) + testSet(t, acc, "F64", true, 44.44) + testSet(t, acc, "B", true, false) + } + + // get again to check if new values were set + for _, acc := range accs { + testGetString(t, acc, "S", true, "coconut") + testGetInt(t, acc, "I", true, 44) + testGetInt(t, acc, "I8", true, 44) + testGetInt(t, acc, "I16", true, 44) + testGetInt(t, acc, "I32", true, 44) + testGetInt(t, acc, "I64", true, 44) + testGetInt(t, acc, "UI", true, 44) + testGetInt(t, acc, "UI8", true, 44) + testGetInt(t, acc, "UI16", true, 44) + testGetInt(t, acc, "UI32", true, 44) + testGetInt(t, acc, "UI64", true, 44) + testGetFloat(t, acc, "F32", true, 44.44) + testGetFloat(t, acc, "F64", true, 44.44) + testGetBool(t, acc, "B", true, false) + } + + // failures + for _, acc := range accs { + testSet(t, acc, "S", false, true) + testSet(t, acc, "S", false, false) + testSet(t, acc, "S", false, 1) + testSet(t, acc, "S", false, 1.1) + + testSet(t, acc, "I", false, "1") + testSet(t, acc, "I8", false, "1") + testSet(t, acc, "I16", false, "1") + testSet(t, acc, "I32", false, "1") + testSet(t, acc, "I64", false, "1") + testSet(t, acc, "UI", false, "1") + testSet(t, acc, "UI8", false, "1") + testSet(t, acc, "UI16", false, "1") + testSet(t, acc, "UI32", false, "1") + testSet(t, acc, "UI64", false, "1") + + testSet(t, acc, "F32", false, "1.1") + testSet(t, acc, "F64", false, "1.1") + + testSet(t, acc, "B", false, "false") + testSet(t, acc, "B", false, 1) + testSet(t, acc, "B", false, 1.1) + } + + // get again to check if values werent changed when an error occurred + for _, acc := range accs { + testGetString(t, acc, "S", true, "coconut") + testGetInt(t, acc, "I", true, 44) + testGetInt(t, acc, "I8", true, 44) + testGetInt(t, acc, "I16", true, 44) + testGetInt(t, acc, "I32", true, 44) + testGetInt(t, acc, "I64", true, 44) + testGetInt(t, acc, "UI", true, 44) + testGetInt(t, acc, "UI8", true, 44) + testGetInt(t, acc, "UI16", true, 44) + testGetInt(t, acc, "UI32", true, 44) + testGetInt(t, acc, "UI64", true, 44) + testGetFloat(t, acc, "F32", true, 44.44) + testGetFloat(t, acc, "F64", true, 44.44) + testGetBool(t, acc, "B", true, false) + } + + // test existence + for _, acc := range accs { + testExists(t, acc, "S", true) + testExists(t, acc, "I", true) + testExists(t, acc, "I8", true) + testExists(t, acc, "I16", true) + testExists(t, acc, "I32", true) + testExists(t, acc, "I64", true) + testExists(t, acc, "UI", true) + testExists(t, acc, "UI8", true) + testExists(t, acc, "UI16", true) + testExists(t, acc, "UI32", true) + testExists(t, acc, "UI64", true) + testExists(t, acc, "F32", true) + testExists(t, acc, "F64", true) + testExists(t, acc, "B", true) + } + + // test non-existence + for _, acc := range accs { + testExists(t, acc, "X", false) + } + +} diff --git a/database/base.go b/database/base.go deleted file mode 100644 index a9bf424..0000000 --- a/database/base.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "errors" - "strings" - - "github.com/Safing/safing-core/database/dbutils" - - "github.com/ipfs/go-datastore" - uuid "github.com/satori/go.uuid" -) - -type Base struct { - dbKey *datastore.Key - meta *dbutils.Meta -} - -func (m *Base) SetKey(key *datastore.Key) { - m.dbKey = key -} - -func (m *Base) GetKey() *datastore.Key { - return m.dbKey -} - -func (m *Base) FmtKey() string { - return m.dbKey.String() -} - -func (m *Base) Meta() *dbutils.Meta { - return m.meta -} - -func (m *Base) CreateObject(namespace *datastore.Key, name string, model Model) error { - var newKey datastore.Key - if name == "" { - newKey = NewInstance(namespace.ChildString(getTypeName(model)), strings.Replace(uuid.NewV4().String(), "-", "", -1)) - } else { - newKey = NewInstance(namespace.ChildString(getTypeName(model)), name) - } - m.dbKey = &newKey - return Create(*m.dbKey, model) -} - -func (m *Base) SaveObject(model Model) error { - if m.dbKey == nil { - return errors.New("cannot save new object, use Create() instead") - } - return Update(*m.dbKey, model) -} - -func (m *Base) Delete() error { - if m.dbKey == nil { - return errors.New("cannot delete object unsaved object") - } - return Delete(*m.dbKey) -} - -func NewInstance(k datastore.Key, s string) datastore.Key { - return datastore.NewKey(k.String() + ":" + s) -} diff --git a/database/boilerplate_test.go b/database/boilerplate_test.go new file mode 100644 index 0000000..3fd5dac --- /dev/null +++ b/database/boilerplate_test.go @@ -0,0 +1,64 @@ +package database + +import ( + "fmt" + "sync" + + "github.com/Safing/portbase/database/record" +) + +type Example struct { + record.Base + sync.Mutex + + Name string + Score int +} + +var ( + exampleDB = NewInterface(nil) +) + +// GetExample gets an Example from the database. +func GetExample(key string) (*Example, error) { + r, err := exampleDB.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &Example{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*Example) + if !ok { + return nil, fmt.Errorf("record not of type *Example, but %T", r) + } + return new, nil +} + +func (e *Example) Save() error { + return exampleDB.Put(e) +} + +func (e *Example) SaveAs(key string) error { + e.SetKey(key) + return exampleDB.PutNew(e) +} + +func NewExample(key, name string, score int) *Example { + new := &Example{ + Name: name, + Score: score, + } + new.SetKey(key) + return new +} diff --git a/database/controller.go b/database/controller.go new file mode 100644 index 0000000..438072f --- /dev/null +++ b/database/controller.go @@ -0,0 +1,183 @@ +package database + +import ( + "sync" + + "github.com/tevino/abool" + + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/database/storage" +) + +// A Controller takes care of all the extra database logic. +type Controller struct { + storage storage.Interface + + hooks []*RegisteredHook + subscriptions []*Subscription + + writeLock sync.RWMutex + readLock sync.RWMutex + migrating *abool.AtomicBool // TODO + hibernating *abool.AtomicBool // TODO +} + +// newController creates a new controller for a storage. +func newController(storageInt storage.Interface) (*Controller, error) { + return &Controller{ + storage: storageInt, + migrating: abool.NewBool(false), + hibernating: abool.NewBool(false), + }, nil +} + +// ReadOnly returns whether the storage is read only. +func (c *Controller) ReadOnly() bool { + return c.storage.ReadOnly() +} + +// Injected returns whether the storage is injected. +func (c *Controller) Injected() bool { + return c.storage.Injected() +} + +// Get return the record with the given key. +func (c *Controller) Get(key string) (record.Record, error) { + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + c.readLock.RLock() + defer c.readLock.RUnlock() + + // process hooks + for _, hook := range c.hooks { + if hook.h.UsesPreGet() && hook.q.MatchesKey(key) { + err := hook.h.PreGet(key) + if err != nil { + return nil, err + } + } + } + + r, err := c.storage.Get(key) + if err != nil { + // replace not found error + if err == storage.ErrNotFound { + return nil, ErrNotFound + } + return nil, err + } + + r.Lock() + defer r.Unlock() + + // process hooks + for _, hook := range c.hooks { + if hook.h.UsesPostGet() && hook.q.Matches(r) { + r, err = hook.h.PostGet(r) + if err != nil { + return nil, err + } + } + } + + if !r.Meta().CheckValidity() { + return nil, ErrNotFound + } + + return r, nil +} + +// Put saves a record in the database. +func (c *Controller) Put(r record.Record) (err error) { + if shuttingDown.IsSet() { + return ErrShuttingDown + } + + if c.ReadOnly() { + return ErrReadOnly + } + + r.Lock() + defer r.Unlock() + + // process hooks + for _, hook := range c.hooks { + if hook.h.UsesPrePut() && hook.q.Matches(r) { + r, err = hook.h.PrePut(r) + if err != nil { + return err + } + } + } + + if r.Meta() == nil { + r.SetMeta(&record.Meta{}) + } + r.Meta().Update() + + c.writeLock.RLock() + defer c.writeLock.RUnlock() + + err = c.storage.Put(r) + if err != nil { + return err + } + + // process subscriptions + for _, sub := range c.subscriptions { + if sub.q.Matches(r) { + select { + case sub.Feed <- r: + default: + } + } + } + + return nil +} + +// Query executes the given query on the database. +func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + c.readLock.RLock() + it, err := c.storage.Query(q, local, internal) + if err != nil { + c.readLock.RUnlock() + return nil, err + } + + go c.readUnlockerAfterQuery(it) + return it, nil +} + +func (c *Controller) readUnlockerAfterQuery(it *iterator.Iterator) { + <- it.Done + c.readLock.RUnlock() +} + +// Maintain runs the Maintain method no the storage. +func (c *Controller) Maintain() error { + c.writeLock.RLock() + defer c.writeLock.RUnlock() + return c.storage.Maintain() +} + +// MaintainThorough runs the MaintainThorough method no the storage. +func (c *Controller) MaintainThorough() error { + c.writeLock.RLock() + defer c.writeLock.RUnlock() + return c.storage.MaintainThorough() +} + +// Shutdown shuts down the storage. +func (c *Controller) Shutdown() error { + // TODO: should we wait for gets/puts/queries to complete? + return c.storage.Shutdown() +} diff --git a/database/controllers.go b/database/controllers.go new file mode 100644 index 0000000..5b9d002 --- /dev/null +++ b/database/controllers.go @@ -0,0 +1,87 @@ +package database + +import ( + "errors" + "sync" + "fmt" + + "github.com/Safing/portbase/database/storage" +) + +var ( + controllers = make(map[string]*Controller) + controllersLock sync.Mutex +) + +func getController(name string) (*Controller, error) { + if !initialized.IsSet() { + return nil, errors.New("database not initialized") + } + + controllersLock.Lock() + defer controllersLock.Unlock() + + // return database if already started + controller, ok := controllers[name] + if ok { + return controller, nil + } + + // get db registration + registeredDB, err := getDatabase(name) + if err != nil { + return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) + } + + // get location + dbLocation, err := getLocation(name, registeredDB.StorageType) + if err != nil { + return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) + } + + // start database + storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation) + if err != nil { + return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) + } + + // create controller + controller, err = newController(storageInt) + if err != nil { + return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err) + } + + controllers[name] = controller + return controller, nil +} + +// InjectDatabase injects an already running database into the system. +func InjectDatabase(name string, storageInt storage.Interface) error { + controllersLock.Lock() + defer controllersLock.Unlock() + + _, ok := controllers[name] + if ok { + return errors.New(`database "%s" already loaded`) + } + + registryLock.Lock() + defer registryLock.Unlock() + + // check if database is registered + registeredDB, ok := registry[name] + if !ok { + return fmt.Errorf(`database "%s" not registered`, name) + } + if registeredDB.StorageType != "injected" { + return fmt.Errorf(`database not of type "injected"`) + } + + controller, err := newController(storageInt) + if err != nil { + return fmt.Errorf(`could not create controller for database %s: %s`, name, err) + } + + controllers[name] = controller + return nil +} diff --git a/database/database.go b/database/database.go index b1b7a5f..e8e4504 100644 --- a/database/database.go +++ b/database/database.go @@ -1,151 +1,32 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package database import ( "errors" - "fmt" - "os" - "path" - "strings" - - ds "github.com/ipfs/go-datastore" - dsq "github.com/ipfs/go-datastore/query" - mount "github.com/ipfs/go-datastore/syncmount" - - "github.com/Safing/safing-core/database/dbutils" - "github.com/Safing/safing-core/database/ds/channelshim" - "github.com/Safing/safing-core/database/ds/leveldb" - "github.com/Safing/safing-core/log" - "github.com/Safing/safing-core/meta" + "time" ) -// TODO: do not let other modules panic, even if database module crashes. -var db ds.Datastore - -var ErrNotFound = errors.New("database: entry could not be found") - -func init() { - if strings.HasSuffix(os.Args[0], ".test") { - // testing setup - log.Warning("===== DATABASE RUNNING IN TEST MODE =====") - db = channelshim.NewChanneledDatastore(ds.NewMapDatastore()) - return - } - - // sfsDB, err := simplefs.NewDatastore(meta.DatabaseDir()) - // if err != nil { - // fmt.Fprintf(os.Stderr, "FATAL ERROR: could not init simplefs database: %s\n", err) - // os.Exit(1) - // } - - ldb, err := leveldb.NewDatastore(path.Join(meta.DatabaseDir(), "leveldb"), &leveldb.Options{}) - if err != nil { - fmt.Fprintf(os.Stderr, "FATAL ERROR: could not init simplefs database: %s\n", err) - os.Exit(1) - } - - mapDB := ds.NewMapDatastore() - - db = channelshim.NewChanneledDatastore(mount.New([]mount.Mount{ - mount.Mount{ - Prefix: ds.NewKey("/Run"), - Datastore: mapDB, - }, - mount.Mount{ - Prefix: ds.NewKey("/"), - Datastore: ldb, - }, - })) - +// Database holds information about registered databases +type Database struct { + Name string + Description string + StorageType string + PrimaryAPI string + Registered time.Time + LastUpdated time.Time + LastLoaded time.Time } -// func Batch() (ds.Batch, error) { -// return db.Batch() -// } - -// func Close() error { -// return db.Close() -// } - -func Get(key *ds.Key) (Model, error) { - data, err := db.Get(*key) - if err != nil { - switch err { - case ds.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - } - } - model, ok := data.(Model) - if !ok { - return nil, errors.New("database did not return model") - } - return model, nil +// MigrateTo migrates the database to another storage type. +func (db *Database) MigrateTo(newStorageType string) error { + return errors.New("not implemented yet") // TODO } -func GetAndEnsureModel(namespace *ds.Key, name string, model Model) (Model, error) { - newKey := namespace.ChildString(getTypeName(model)).Instance(name) - - data, err := Get(&newKey) - if err != nil { - return nil, err - } - - newModel, err := EnsureModel(data, model) - if err != nil { - return nil, err - } - - newModel.SetKey(&newKey) - - return newModel, nil +// Loaded updates the LastLoaded timestamp. +func (db *Database) Loaded() { + db.LastLoaded = time.Now().Round(time.Second) } -func Has(key ds.Key) (exists bool, err error) { - return db.Has(key) -} - -func Create(key ds.Key, model Model) (err error) { - handleCreateSubscriptions(model) - err = db.Put(key, model) - if err != nil { - log.Tracef("database: failed to create entry %s: %s", key, err) - } - return err -} - -func Update(key ds.Key, model Model) (err error) { - handleUpdateSubscriptions(model) - err = db.Put(key, model) - if err != nil { - log.Tracef("database: failed to update entry %s: %s", key, err) - } - return err -} - -func Delete(key ds.Key) (err error) { - handleDeleteSubscriptions(&key) - return db.Delete(key) -} - -func Query(q dsq.Query) (dsq.Results, error) { - return db.Query(q) -} - -func RawGet(key ds.Key) (*dbutils.Wrapper, error) { - data, err := db.Get(key) - if err != nil { - return nil, err - } - wrapped, ok := data.(*dbutils.Wrapper) - if !ok { - return nil, errors.New("returned data is not a wrapper") - } - return wrapped, nil -} - -func RawPut(key ds.Key, value interface{}) error { - return db.Put(key, value) +// Updated updates the LastUpdated timestamp. +func (db *Database) Updated() { + db.LastUpdated = time.Now().Round(time.Second) } diff --git a/database/database_test.go b/database/database_test.go new file mode 100644 index 0000000..4cc94c3 --- /dev/null +++ b/database/database_test.go @@ -0,0 +1,162 @@ +package database + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "reflect" + "runtime/pprof" + "testing" + "time" + + q "github.com/Safing/portbase/database/query" + _ "github.com/Safing/portbase/database/storage/badger" +) + +func makeKey(dbName, key string) string { + return fmt.Sprintf("%s:%s", dbName, key) +} + +func testDatabase(t *testing.T, storageType string) { + dbName := fmt.Sprintf("testing-%s", storageType) + _, err := Register(&Database{ + Name: dbName, + Description: fmt.Sprintf("Unit Test Database for %s", storageType), + StorageType: storageType, + PrimaryAPI: "", + }) + if err != nil { + t.Fatal(err) + } + + // hook + hook, err := RegisterHook(q.New(dbName).MustBeValid(), &HookBase{}) + if err != nil { + t.Fatal(err) + } + + // sub + sub, err := Subscribe(q.New(dbName).MustBeValid()) + if err != nil { + t.Fatal(err) + } + + // interface + db := NewInterface(nil) + + A := NewExample(makeKey(dbName, "A"), "Herbert", 411) + err = A.Save() + if err != nil { + t.Fatal(err) + } + + B := NewExample(makeKey(dbName, "B"), "Fritz", 347) + err = B.Save() + if err != nil { + t.Fatal(err) + } + + C := NewExample(makeKey(dbName, "C"), "Norbert", 217) + err = C.Save() + if err != nil { + t.Fatal(err) + } + + exists, err := db.Exists(makeKey(dbName, "A")) + if err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("record %s should exist!", makeKey(dbName, "A")) + } + + A1, err := GetExample(makeKey(dbName, "A")) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(A, A1) { + log.Fatalf("A and A1 mismatch, A1: %v", A1) + } + + query, err := q.New(dbName).Where( + q.And( + q.Where("Name", q.EndsWith, "bert"), + q.Where("Score", q.GreaterThan, 100), + ), + ).Check() + if err != nil { + t.Fatal(err) + } + + it, err := db.Query(query) + if err != nil { + t.Fatal(err) + } + + cnt := 0 + for _ = range it.Next { + cnt++ + } + if it.Error != nil { + t.Fatal(it.Error) + } + if cnt != 2 { + t.Fatal("expected two records") + } + + err = hook.Cancel() + if err != nil { + t.Fatal(err) + } + err = sub.Cancel() + if err != nil { + t.Fatal(err) + } + +} + +func TestDatabaseSystem(t *testing.T) { + + // panic after 10 seconds, to check for locks + go func() { + time.Sleep(10 * time.Second) + fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + os.Exit(1) + }() + + testDir, err := ioutil.TempDir("", "testing-") + if err != nil { + t.Fatal(err) + } + + err = Initialize(testDir) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(testDir) // clean up + + testDatabase(t, "badger") + + err = MaintainRecordStates() + if err != nil { + t.Fatal(err) + } + + err = Maintain() + if err != nil { + t.Fatal(err) + } + + err = MaintainThorough() + if err != nil { + t.Fatal(err) + } + + err = Shutdown() + if err != nil { + t.Fatal(err) + } + +} diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go new file mode 100644 index 0000000..1b7d1d8 --- /dev/null +++ b/database/dbmodule/db.go @@ -0,0 +1,43 @@ +package dbmodule + +import ( + "errors" + "flag" + "sync" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/modules" +) + +var ( + databaseDir string + shutdownSignal = make(chan struct{}) + maintenanceWg sync.WaitGroup +) + +func init() { + flag.StringVar(&databaseDir, "db", "", "set database directory") + + modules.Register("database", prep, start, stop) +} + +func prep() error { + if databaseDir == "" { + return errors.New("no database location specified, set with `-db=/path/to/db`") + } + return nil +} + +func start() error { + err := database.Initialize(databaseDir) + if err == nil { + go maintainer() + } + return err +} + +func stop() error { + close(shutdownSignal) + maintenanceWg.Wait() + return database.Shutdown() +} diff --git a/database/dbmodule/maintenance.go b/database/dbmodule/maintenance.go new file mode 100644 index 0000000..4bdd6fe --- /dev/null +++ b/database/dbmodule/maintenance.go @@ -0,0 +1,36 @@ +package dbmodule + +import ( + "time" + + "github.com/Safing/portbase/database" + "github.com/Safing/portbase/log" +) + +func maintainer() { + ticker := time.NewTicker(10 * time.Minute) + longTicker := time.NewTicker(1 * time.Hour) + maintenanceWg.Add(1) + + for { + select { + case <- ticker.C: + err := database.Maintain() + if err != nil { + log.Errorf("database: maintenance error: %s", err) + } + case <- longTicker.C: + err := database.MaintainRecordStates() + if err != nil { + log.Errorf("database: record states maintenance error: %s", err) + } + err = database.MaintainThorough() + if err != nil { + log.Errorf("database: thorough maintenance error: %s", err) + } + case <-shutdownSignal: + maintenanceWg.Done() + return + } + } +} diff --git a/database/dbutils/wrapper.go b/database/dbutils/wrapper.go deleted file mode 100644 index 98ef68c..0000000 --- a/database/dbutils/wrapper.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -/* -Package dbutils provides important function for datastore backends without creating an import loop. -*/ -package dbutils - -import ( - "errors" - "fmt" - - "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/formats/dsd" - "github.com/Safing/safing-core/formats/varint" -) - -type Wrapper struct { - dbKey *datastore.Key - meta *Meta - Format uint8 - Data []byte -} - -func NewWrapper(key *datastore.Key, data []byte) (*Wrapper, error) { - // line crashes with: panic: runtime error: index out of range - format, _, err := varint.Unpack8(data) - if err != nil { - return nil, fmt.Errorf("database: could not get dsd format: %s", err) - } - - new := &Wrapper{ - Format: format, - Data: data, - } - new.SetKey(key) - - return new, nil -} - -func (w *Wrapper) SetKey(key *datastore.Key) { - w.dbKey = key -} - -func (w *Wrapper) GetKey() *datastore.Key { - return w.dbKey -} - -func (w *Wrapper) FmtKey() string { - return w.dbKey.String() -} - -func DumpModel(uncertain interface{}, storageType uint8) ([]byte, error) { - wrapped, ok := uncertain.(*Wrapper) - if ok { - if storageType != dsd.AUTO && storageType != wrapped.Format { - return nil, errors.New("could not dump model, wrapped object format mismatch") - } - return wrapped.Data, nil - } - - dumped, err := dsd.Dump(uncertain, storageType) - if err != nil { - return nil, err - } - return dumped, nil -} diff --git a/database/doc.go b/database/doc.go index f5c0a9c..d93ff5f 100644 --- a/database/doc.go +++ b/database/doc.go @@ -3,98 +3,63 @@ /* Package database provides a universal interface for interacting with the database. -The Lazy Database +A Lazy Database The database system can handle Go structs as well as serialized data by the dsd package. While data is in transit within the system, it does not know which form it currently has. Only when it reaches its destination, it must ensure that it is either of a certain type or dump it. -Internals +Record Interface -The database system uses the Model interface to transparently handle all types of structs that get saved in the database. Structs include Base struct to fulfill most parts of the Model interface. +The database system uses the Record interface to transparently handle all types of structs that get saved in the database. Structs include the Base struct to fulfill most parts of the Record interface. -Boilerplate Code +Boilerplate Code: -Receiving model, using as struct: + type Example struct { + record.Base + sync.Mutex - // At some point, declare a pointer to your model. - // This is only used to identify the model, so you can reuse it safely for this purpose - var cowModel *Cow // only use this as parameter for database.EnsureModel-like functions - - receivedModel := <- models // chan database.Model - cow, ok := database.SilentEnsureModel(receivedModel, cowModel).(*Cow) - if !ok { - panic("received model does not match expected model") - } - - // more verbose, in case you need better error handling - receivedModel := <- models // chan database.Model - genericModel, err := database.EnsureModel(receivedModel, cowModel) - if err != nil { - panic(err) - } - cow, ok := genericModel.(*Cow) - if !ok { - panic("received model does not match expected model") - } - -Receiving a model, dumping: - - // receivedModel <- chan database.Model - bytes, err := database.DumpModel(receivedModel, dsd.JSON) // or other dsd format - if err != nil { - panic(err) - } - -Model definition: - - // Cow makes moo. - type Cow struct { - database.Base - // Fields... - } - - var cowModel *Cow // only use this as parameter for database.EnsureModel-like functions - - func init() { - database.RegisterModel(cowModel, func() database.Model { return new(Cow) }) - } - - // this all you need, but you might find the following code helpful: - - var cowNamespace = datastore.NewKey("/Cow") - - // Create saves Cow with the provided name in the default namespace. - func (m *Cow) Create(name string) error { - return m.CreateObject(&cowNamespace, name, m) - } - - // CreateInNamespace saves Cow with the provided name in the provided namespace. - func (m *Cow) CreateInNamespace(namespace *datastore.Key, name string) error { - return m.CreateObject(namespace, name, m) - } - - // Save saves Cow. - func (m *Cow) Save() error { - return m.SaveObject(m) - } - - // GetCow fetches Cow with the provided name from the default namespace. - func GetCow(name string) (*Cow, error) { - return GetCowFromNamespace(&cowNamespace, name) - } - - // GetCowFromNamespace fetches Cow with the provided name from the provided namespace. - func GetCowFromNamespace(namespace *datastore.Key, name string) (*Cow, error) { - object, err := database.GetAndEnsureModel(namespace, name, cowModel) - if err != nil { - return nil, err + Name string + Score int } - model, ok := object.(*Cow) - if !ok { - return nil, database.NewMismatchError(object, cowModel) + + var ( + db = database.NewInterface(nil) + ) + + // GetExample gets an Example from the database. + func GetExample(key string) (*Example, error) { + r, err := db.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &Example{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*Example) + if !ok { + return nil, fmt.Errorf("record not of type *Example, but %T", r) + } + return new, nil + } + + func (e *Example) Save() error { + return db.Put(e) + } + + func (e *Example) SaveAs(key string) error { + e.SetKey(key) + return db.PutNew(e) } - return model, nil - } */ package database diff --git a/database/easyquery.go b/database/easyquery.go deleted file mode 100644 index ee4db82..0000000 --- a/database/easyquery.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "errors" - "fmt" - "strings" - - dsq "github.com/ipfs/go-datastore/query" -) - -type FilterMaxDepth struct { - MaxDepth int -} - -func (f FilterMaxDepth) Filter(entry dsq.Entry) bool { - return strings.Count(entry.Key, "/") <= f.MaxDepth -} - -type FilterKeyLength struct { - Length int -} - -func (f FilterKeyLength) Filter(entry dsq.Entry) bool { - return len(entry.Key) == f.Length -} - -func EasyQueryIterator(subscriptionKey string) (dsq.Results, error) { - query := dsq.Query{} - - namespaces := strings.Split(subscriptionKey, "/")[1:] - lastSpace := "" - if len(namespaces) != 0 { - lastSpace = namespaces[len(namespaces)-1] - } - - switch { - case lastSpace == "": - // get all children - query.Prefix = subscriptionKey - case strings.HasPrefix(lastSpace, "*"): - // get children to defined depth - query.Prefix = strings.Trim(subscriptionKey, "*") - query.Filters = []dsq.Filter{ - FilterMaxDepth{len(lastSpace) + len(namespaces) - 1}, - } - case strings.Contains(lastSpace, ":"): - query.Prefix = subscriptionKey - query.Filters = []dsq.Filter{ - FilterKeyLength{len(query.Prefix)}, - } - default: - // get only from this location and this type - query.Prefix = subscriptionKey + ":" - query.Filters = []dsq.Filter{ - FilterMaxDepth{len(namespaces)}, - } - } - - // log.Tracef("easyquery: %s has prefix %s", subscriptionKey, query.Prefix) - - results, err := db.Query(query) - if err != nil { - return nil, errors.New(fmt.Sprintf("easyquery: %s", err)) - } - - return results, nil -} - -func EasyQuery(subscriptionKey string) (*[]dsq.Entry, error) { - - results, err := EasyQueryIterator(subscriptionKey) - if err != nil { - return nil, err - } - - entries, err := results.Rest() - if err != nil { - return nil, errors.New(fmt.Sprintf("easyquery: %s", err)) - } - - return &entries, nil -} diff --git a/database/easyquery_test.go b/database/easyquery_test.go deleted file mode 100644 index 914c0b8..0000000 --- a/database/easyquery_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "testing" - - datastore "github.com/ipfs/go-datastore" -) - -func testQuery(t *testing.T, queryString string, expecting []string) { - - entries, err := EasyQuery(queryString) - if err != nil { - t.Errorf("error in query %s: %s", queryString, err) - } - - totalExcepted := len(expecting) - total := 0 - fail := false - - keys := datastore.EntryKeys(*entries) - -resultLoop: - for _, key := range keys { - total++ - for _, expectedName := range expecting { - if key.Name() == expectedName { - continue resultLoop - } - } - fail = true - break - } - - if !fail && total == totalExcepted { - return - } - - t.Errorf("Query %s got %s, expected %s", queryString, keys, expecting) - -} - -func TestEasyQuery(t *testing.T) { - - // setup test data - (&(TestingModel{})).CreateInNamespace("EasyQuery", "1") - (&(TestingModel{})).CreateInNamespace("EasyQuery", "2") - (&(TestingModel{})).CreateInNamespace("EasyQuery", "3") - (&(TestingModel{})).CreateInNamespace("EasyQuery/A", "4") - (&(TestingModel{})).CreateInNamespace("EasyQuery/A/B", "5") - (&(TestingModel{})).CreateInNamespace("EasyQuery/A/B/C", "6") - (&(TestingModel{})).CreateInNamespace("EasyQuery/A/B/C/D", "7") - - (&(TestingModel{})).CreateWithTypeName("EasyQuery", "ConfigModel", "X") - (&(TestingModel{})).CreateWithTypeName("EasyQuery", "ConfigModel", "Y") - (&(TestingModel{})).CreateWithTypeName("EasyQuery/A", "ConfigModel", "Z") - - testQuery(t, "/Tests/EasyQuery/TestingModel", []string{"1", "2", "3"}) - testQuery(t, "/Tests/EasyQuery/TestingModel:1", []string{"1"}) - - testQuery(t, "/Tests/EasyQuery/ConfigModel", []string{"X", "Y"}) - testQuery(t, "/Tests/EasyQuery/ConfigModel:Y", []string{"Y"}) - - testQuery(t, "/Tests/EasyQuery/A/", []string{"Z", "4", "5", "6", "7"}) - testQuery(t, "/Tests/EasyQuery/A/B/**", []string{"5", "6"}) - -} diff --git a/database/errors.go b/database/errors.go new file mode 100644 index 0000000..55d42e6 --- /dev/null +++ b/database/errors.go @@ -0,0 +1,15 @@ +// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. + +package database + +import ( + "errors" +) + +// Errors +var ( + ErrNotFound = errors.New("database entry could not be found") + ErrPermissionDenied = errors.New("access to database record denied") + ErrReadOnly = errors.New("database is read only") + ErrShuttingDown = errors.New("database system is shutting down") +) diff --git a/database/hook.go b/database/hook.go new file mode 100644 index 0000000..1edac94 --- /dev/null +++ b/database/hook.go @@ -0,0 +1,70 @@ +package database + +import ( + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +// Hook describes a hook +type Hook interface { + UsesPreGet() bool + PreGet(dbKey string) error + + UsesPostGet() bool + PostGet(r record.Record) (record.Record, error) + + UsesPrePut() bool + PrePut(r record.Record) (record.Record, error) +} + +// RegisteredHook is a registered database hook. +type RegisteredHook struct { + q *query.Query + h Hook +} + +// RegisterHook registeres a hook for records matching the given query in the database. +func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + c, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + rh := &RegisteredHook{ + q: q, + h: hook, + } + c.hooks = append(c.hooks, rh) + return rh, nil +} + +// Cancel unhooks the hook. +func (h *RegisteredHook) Cancel() error { + c, err := getController(h.q.DatabaseName()) + if err != nil { + return err + } + + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + for key, hook := range c.hooks { + if hook.q == h.q { + c.hooks = append(c.hooks[:key], c.hooks[key+1:]...) + return nil + } + } + return nil +} diff --git a/database/hookbase.go b/database/hookbase.go new file mode 100644 index 0000000..fd42748 --- /dev/null +++ b/database/hookbase.go @@ -0,0 +1,39 @@ +package database + +import ( + "github.com/Safing/portbase/database/record" +) + +// HookBase implements the Hook interface and provides dummy functions to reduce boilerplate. +type HookBase struct { +} + +// UsesPreGet implements the Hook interface and returns false. +func (b *HookBase) UsesPreGet() bool { + return false +} + +// UsesPostGet implements the Hook interface and returns false. +func (b *HookBase) UsesPostGet() bool { + return false +} + +// UsesPrePut implements the Hook interface and returns false. +func (b *HookBase) UsesPrePut() bool { + return false +} + +// PreGet implements the Hook interface. +func (b *HookBase) PreGet(dbKey string) error { + return nil +} + +// PostGet implements the Hook interface. +func (b *HookBase) PostGet(r record.Record) (record.Record, error) { + return r, nil +} + +// PrePut implements the Hook interface. +func (b *HookBase) PrePut(r record.Record) (record.Record, error) { + return r, nil +} diff --git a/database/interface.go b/database/interface.go new file mode 100644 index 0000000..de5910c --- /dev/null +++ b/database/interface.go @@ -0,0 +1,237 @@ +package database + +import ( + "errors" + "fmt" + + "github.com/Safing/portbase/database/accessor" + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +const ( + getDBFromKey = "" +) + +// Interface provides a method to access the database with attached options. +type Interface struct { + options *Options +} + +// Options holds options that may be set for an Interface instance. +type Options struct { + Local bool + Internal bool + AlwaysMakeSecret bool + AlwaysMakeCrownjewel bool +} + +// Apply applies options to the record metadata. +func (o *Options) Apply(r record.Record) { + if r.Meta() == nil { + r.SetMeta(&record.Meta{}) + } + if o.AlwaysMakeSecret { + r.Meta().MakeSecret() + } + if o.AlwaysMakeCrownjewel { + r.Meta().MakeCrownJewel() + } +} + +// NewInterface returns a new Interface to the database. +func NewInterface(opts *Options) *Interface { + if opts == nil { + opts = &Options{} + } + + return &Interface{ + options: opts, + } +} + +// Exists return whether a record with the given key exists. +func (i *Interface) Exists(key string) (bool, error) { + _, _, err := i.getRecord(getDBFromKey, key, false, false) + if err != nil { + if err == ErrNotFound { + return false, nil + } + return false, err + } + return true, nil +} + +// Get return the record with the given key. +func (i *Interface) Get(key string) (record.Record, error) { + r, _, err := i.getRecord(getDBFromKey, key, true, false) + return r, err +} + +func (i *Interface) getRecord(dbName string, dbKey string, check bool, mustBeWriteable bool) (r record.Record, db *Controller, err error) { + if dbName == "" { + dbName, dbKey = record.ParseKey(dbKey) + } + + db, err = getController(dbName) + if err != nil { + return nil, nil, err + } + + if mustBeWriteable && db.ReadOnly() { + return nil, nil, ErrReadOnly + } + + r, err = db.Get(dbKey) + if err != nil { + if err == ErrNotFound { + return nil, db, err + } + return nil, nil, err + } + + if check && !r.Meta().CheckPermission(i.options.Local, i.options.Internal) { + return nil, nil, ErrPermissionDenied + } + + return r, db, nil +} + +// InsertValue inserts a value into a record. +func (i *Interface) InsertValue(key string, attribute string, value interface{}) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + var acc accessor.Accessor + if r.IsWrapped() { + wrapper, ok := r.(*record.Wrapper) + if !ok { + return errors.New("record is malformed (reports to be wrapped but is not of type *record.Wrapper)") + } + acc = accessor.NewJSONBytesAccessor(&wrapper.Data) + } else { + acc = accessor.NewStructAccessor(r) + } + + err = acc.Set(attribute, value) + if err != nil { + return fmt.Errorf("failed to set value with %s: %s", acc.Type(), err) + } + + i.options.Apply(r) + return db.Put(r) +} + +// Put saves a record to the database. +func (i *Interface) Put(r record.Record) error { + _, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) + if err != nil && err != ErrNotFound { + return err + } + + i.options.Apply(r) + return db.Put(r) +} + +// PutNew saves a record to the database as a new record (ie. with new timestamps). +func (i *Interface) PutNew(r record.Record) error { + _, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) + if err != nil && err != ErrNotFound { + return err + } + + i.options.Apply(r) + r.Meta().Reset() + return db.Put(r) +} + +// SetAbsoluteExpiry sets an absolute record expiry. +func (i *Interface) SetAbsoluteExpiry(key string, time int64) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().SetAbsoluteExpiry(time) + return db.Put(r) +} + +// SetRelativateExpiry sets a relative (self-updating) record expiry. +func (i *Interface) SetRelativateExpiry(key string, duration int64) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().SetRelativateExpiry(duration) + return db.Put(r) +} + +// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record. +func (i *Interface) MakeSecret(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().MakeSecret() + return db.Put(r) +} + +// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally. +func (i *Interface) MakeCrownJewel(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().MakeCrownJewel() + return db.Put(r) +} + +// Delete deletes a record from the database. +func (i *Interface) Delete(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().Delete() + return db.Put(r) +} + +// Query executes the given query on the database. +func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) { + db, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + return db.Query(q, i.options.Local, i.options.Internal) +} diff --git a/database/iterator/iterator.go b/database/iterator/iterator.go new file mode 100644 index 0000000..7a1dff4 --- /dev/null +++ b/database/iterator/iterator.go @@ -0,0 +1,20 @@ +package iterator + +import ( + "github.com/Safing/portbase/database/record" +) + +// Iterator defines the iterator structure. +type Iterator struct { + Next chan record.Record + Done chan struct{} + Error error +} + +// New creates a new Iterator. +func New() *Iterator { + return &Iterator{ + Next: make(chan record.Record, 10), + Done: make(chan struct{}), + } +} diff --git a/database/location.go b/database/location.go new file mode 100644 index 0000000..0c095b6 --- /dev/null +++ b/database/location.go @@ -0,0 +1,52 @@ +package database + +import ( + "errors" + "fmt" + "os" + "path" +) + +const ( + databasesSubDir = "databases" +) + +var ( + rootDir string +) + +func ensureDirectory(dirPath string) error { + // open dir + dir, err := os.Open(dirPath) + if err != nil { + if os.IsNotExist(err) { + return os.MkdirAll(dirPath, 0700) + } + return err + } + defer dir.Close() + + fileInfo, err := dir.Stat() + if err != nil { + return err + } + if !fileInfo.IsDir() { + return errors.New("path exists and is not a directory") + } + if fileInfo.Mode().Perm() != 0700 { + return dir.Chmod(0700) + } + return nil +} + +// getLocation returns the storage location for the given name and type. +func getLocation(name, storageType string) (string, error) { + location := path.Join(rootDir, databasesSubDir, name, storageType) + + // check location + err := ensureDirectory(location) + if err != nil { + return "", fmt.Errorf("location (%s) invalid: %s", location, err) + } + return location, nil +} diff --git a/database/main.go b/database/main.go new file mode 100644 index 0000000..7bb781d --- /dev/null +++ b/database/main.go @@ -0,0 +1,55 @@ +package database + +import ( + "errors" + "fmt" + "path" + + "github.com/tevino/abool" +) + +var ( + initialized = abool.NewBool(false) + + shuttingDown = abool.NewBool(false) + shutdownSignal = make(chan struct{}) +) + +// Initialize initialized the database +func Initialize(location string) error { + if initialized.SetToIf(false, true) { + rootDir = location + + err := ensureDirectory(rootDir) + if err != nil { + return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err) + } + + err = loadRegistry() + if err != nil { + return fmt.Errorf("could not load database registry (%s): %s", path.Join(rootDir, registryFileName), err) + } + + // start registry writer + go registryWriter() + + return nil + } + return errors.New("database already initialized") +} + +// Shutdown shuts down the whole database system. +func Shutdown() (err error) { + if shuttingDown.SetToIf(false, true) { + close(shutdownSignal) + } + + all := duplicateControllers() + for _, c := range all { + err = c.Shutdown() + if err != nil { + return + } + } + return +} diff --git a/database/maintenance.go b/database/maintenance.go new file mode 100644 index 0000000..d1399bd --- /dev/null +++ b/database/maintenance.go @@ -0,0 +1,92 @@ +package database + +import ( + "time" + + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +// Maintain runs the Maintain method on all storages. +func Maintain() (err error) { + controllers := duplicateControllers() + for _, c := range controllers { + err = c.Maintain() + if err != nil { + return + } + } + return +} + +// MaintainThorough runs the MaintainThorough method on all storages. +func MaintainThorough() (err error) { + all := duplicateControllers() + for _, c := range all { + err = c.MaintainThorough() + if err != nil { + return + } + } + return +} + +// MaintainRecordStates runs record state lifecycle maintenance on all storages. +func MaintainRecordStates() error { + all := duplicateControllers() + now := time.Now().Unix() + thirtyDaysAgo := time.Now().Add(-30*24*time.Hour).Unix() + + for _, c := range all { + + if c.ReadOnly() || c.Injected() { + continue + } + + q, err := query.New("").Check() + if err != nil { + return err + } + + it, err := c.Query(q, true, true) + if err != nil { + return err + } + + var toDelete []record.Record + var toExpire []record.Record + + for r := range it.Next { + switch { + case r.Meta().Deleted < thirtyDaysAgo: + toDelete = append(toDelete, r) + case r.Meta().Expires < now: + toExpire = append(toExpire, r) + } + } + if it.Error != nil { + return err + } + + for _, r := range toDelete { + c.storage.Delete(r.DatabaseKey()) + } + for _, r := range toExpire { + r.Meta().Delete() + return c.Put(r) + } + + } + return nil +} + +func duplicateControllers() (all []*Controller) { + controllersLock.Lock() + defer controllersLock.Unlock() + + for _, c := range controllers { + all = append(all, c) + } + + return +} diff --git a/database/model.go b/database/model.go deleted file mode 100644 index 4283ba2..0000000 --- a/database/model.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "fmt" - "strings" - "sync" - - "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/database/dbutils" - "github.com/Safing/safing-core/formats/dsd" -) - -type Model interface { - SetKey(*datastore.Key) - GetKey() *datastore.Key - FmtKey() string - // Type() string - // DefaultNamespace() datastore.Key - // Create(string) error - // CreateInLocation(datastore.Key, string) error - // CreateObject(*datastore.Key, string, Model) error - // Save() error - // Delete() error - // CastError(interface{}, interface{}) error -} - -func getTypeName(model interface{}) string { - full := fmt.Sprintf("%T", model) - return full[strings.LastIndex(full, ".")+1:] -} - -func TypeAssertError(model Model, object interface{}) error { - return fmt.Errorf("database: could not assert %s to type %T (is type %T)", model.FmtKey(), model, object) -} - -// Model Registration - -var ( - registeredModels = make(map[string]func() Model) - registeredModelsLock sync.RWMutex -) - -func RegisterModel(model Model, constructor func() Model) { - registeredModelsLock.Lock() - defer registeredModelsLock.Unlock() - registeredModels[fmt.Sprintf("%T", model)] = constructor -} - -func NewModel(model Model) (Model, error) { - registeredModelsLock.RLock() - defer registeredModelsLock.RUnlock() - constructor, ok := registeredModels[fmt.Sprintf("%T", model)] - if !ok { - return nil, fmt.Errorf("database: cannot create new %T, not registered", model) - } - return constructor(), nil -} - -func EnsureModel(uncertain, model Model) (Model, error) { - wrappedObj, ok := uncertain.(*dbutils.Wrapper) - if !ok { - return uncertain, nil - } - newModel, err := NewModel(model) - if err != nil { - return nil, err - } - _, err = dsd.Load(wrappedObj.Data, &newModel) - if err != nil { - return nil, fmt.Errorf("database: failed to unwrap %T: %s", model, err) - } - newModel.SetKey(wrappedObj.GetKey()) - model = newModel - return newModel, nil -} - -func SilentEnsureModel(uncertain, model Model) Model { - obj, err := EnsureModel(uncertain, model) - if err != nil { - return nil - } - return obj -} - -func NewMismatchError(got, expected interface{}) error { - return fmt.Errorf("database: entry (%T) does not match expected model (%T)", got, expected) -} diff --git a/database/model_test.go b/database/model_test.go deleted file mode 100644 index 9d1f486..0000000 --- a/database/model_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "testing" - - datastore "github.com/ipfs/go-datastore" -) - -type TestingModel struct { - Base - Name string - Value string -} - -var testingModel *TestingModel - -func init() { - RegisterModel(testingModel, func() Model { return new(TestingModel) }) -} - -func (m *TestingModel) Create(name string) error { - return m.CreateObject(&Tests, name, m) -} - -func (m *TestingModel) CreateInNamespace(namespace string, name string) error { - testsNamescace := Tests.ChildString(namespace) - return m.CreateObject(&testsNamescace, name, m) -} - -func (m *TestingModel) CreateWithTypeName(namespace string, typeName string, name string) error { - customNamespace := Tests.ChildString(namespace).ChildString(typeName).Instance(name) - - m.dbKey = &customNamespace - handleCreateSubscriptions(m) - return Create(*m.dbKey, m) -} - -func (m *TestingModel) Save() error { - return m.SaveObject(m) -} - -func GetTestingModel(name string) (*TestingModel, error) { - return GetTestingModelFromNamespace(&Tests, name) -} - -func GetTestingModelFromNamespace(namespace *datastore.Key, name string) (*TestingModel, error) { - object, err := GetAndEnsureModel(namespace, name, testingModel) - if err != nil { - return nil, err - } - model, ok := object.(*TestingModel) - if !ok { - return nil, NewMismatchError(object, testingModel) - } - return model, nil -} - -func TestModel(t *testing.T) { - - // create - m := TestingModel{ - Name: "a", - Value: "b", - } - // newKey := datastore.NewKey("/Tests/TestingModel:test") - // m.dbKey = &newKey - // err := Put(*m.dbKey, m) - err := m.Create("") - if err != nil { - t.Errorf("database test: could not create object: %s", err) - } - - // get - o, err := GetTestingModel(m.dbKey.Name()) - if err != nil { - t.Errorf("database test: failed to get model: %s (%s)", err, m.dbKey.Name()) - } - - // check fetched object - if o.Name != "a" || o.Value != "b" { - t.Errorf("database test: values do not match: got Name=%s and Value=%s", o.Name, o.Value) - } - - // o, err := Get(*m.dbKey) - // if err != nil { - // t.Errorf("database: could not get object: %s", err) - // } - // n, ok := o.(*TestingModel) - // if !ok { - // t.Errorf("database: wrong type, got type %T from %s", o, m.dbKey.String()) - // } - - // save - o.Value = "c" - err = o.Save() - if err != nil { - t.Errorf("database test: could not save object: %s", err) - } - - // delete - err = o.Delete() - if err != nil { - t.Errorf("database test: could not delete object: %s", err) - } - -} diff --git a/database/namespaces.go b/database/namespaces.go deleted file mode 100644 index 4dca334..0000000 --- a/database/namespaces.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import datastore "github.com/ipfs/go-datastore" - -var ( - // Persistent data that is fetched or gathered, entries may be deleted - Cache = datastore.NewKey("/Cache") - DNSCache = Cache.ChildString("Dns") - IntelCache = Cache.ChildString("Intel") - FileInfoCache = Cache.ChildString("FileInfo") - ProfileCache = Cache.ChildString("Profile") - IPInfoCache = Cache.ChildString("IPInfo") - CertCache = Cache.ChildString("Cert") - CARevocationInfoCache = Cache.ChildString("CARevocationInfo") - - // Volatile, in-memory (recommended) namespace for storing runtime information, cleans itself - Run = datastore.NewKey("/Run") - Processes = Run.ChildString("Processes") - OrphanedConnection = Run.ChildString("OrphanedConnections") - OrphanedLink = Run.ChildString("OrphanedLinks") - Api = Run.ChildString("Api") - ApiSessions = Api.ChildString("ApiSessions") - - // Namespace for current device, will be mounted into /Devices/[device] - Me = datastore.NewKey("/Me") - - // Holds data of all Devices - Devices = datastore.NewKey("/Devices") - - // Holds persistent data - Data = datastore.NewKey("/Data") - Profiles = Data.ChildString("Profiles") - - // Holds data distributed by the System (coming from the Community and Devs) - Dist = datastore.NewKey("/Dist") - DistProfiles = Dist.ChildString("Profiles") - DistUpdates = Dist.ChildString("Updates") - - // Holds data issued by company - Company = datastore.NewKey("/Company") - CompanyProfiles = Company.ChildString("Profiles") - CompanyUpdates = Company.ChildString("Updates") - - // Server - // The Authority namespace is used by authoritative servers (Safing or Company) to store data (Intel, Profiles, ...) to be served to clients - Authority = datastore.NewKey("/Authority") - AthoritativeIntel = Authority.ChildString("Intel") - AthoritativeProfiles = Authority.ChildString("Profiles") - // The Staging namespace is the same as the Authority namespace, but for rolling out new things first to a selected list of clients for testing - AuthorityStaging = datastore.NewKey("/Staging") - AthoritativeStagingProfiles = AuthorityStaging.ChildString("Profiles") - - // Holds data of Apps - Apps = datastore.NewKey("/Apps") - - // Test & Invalid namespace - Tests = datastore.NewKey("/Tests") - Invalid = datastore.NewKey("/Invalid") -) diff --git a/database/queries.go b/database/queries.go deleted file mode 100644 index 3c5ea13..0000000 --- a/database/queries.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "time" - - "github.com/Safing/safing-core/formats/dsd" - "github.com/Safing/safing-core/log" - - dsq "github.com/ipfs/go-datastore/query" -) - -func init() { - // go dumper() -} - -func dumper() { - for { - time.Sleep(10 * time.Second) - result, err := db.Query(dsq.Query{Prefix: "/Run/Process"}) - if err != nil { - log.Warningf("Query failed: %s", err) - continue - } - log.Infof("Dumping all processes:") - for model, ok := result.NextSync(); ok; model, ok = result.NextSync() { - bytes, err := dsd.Dump(model, dsd.AUTO) - if err != nil { - log.Warningf("Error dumping: %s", err) - continue - } - log.Info(string(bytes)) - } - log.Infof("END") - } -} diff --git a/database/query/README.md b/database/query/README.md new file mode 100644 index 0000000..9311417 --- /dev/null +++ b/database/query/README.md @@ -0,0 +1,55 @@ +# Query + +## Control Flow + +- Grouping with `(` and `)` +- Chaining with `and` and `or` + - _NO_ mixing! Be explicit and use grouping. +- Negation with `not` + - in front of expression for group: `not (...)` + - inside expression for clause: `name not matches "^King "` + +## Selectors + +Supported by all feeders: +- root level field: `field` +- sub level field: `field.sub` +- array/slice/map access: `map.0` +- array/slice/map length: `map.#` + +Please note that some feeders may have other special characters. It is advised to only use alphanumeric characters for keys. + +## Operators + +| Name | Textual | Req. Type | Internal Type | Compared with | +|---|---|---|---| +| Equals | `==` | int | int64 | `==` | +| GreaterThan | `>` | int | int64 | `>` | +| GreaterThanOrEqual | `>=` | int | int64 | `>=` | +| LessThan | `<` | int | int64 | `<` | +| LessThanOrEqual | `<=` | int | int64 | `<=` | +| FloatEquals | `f==` | float | float64 | `==` | +| FloatGreaterThan | `f>` | float | float64 | `>` | +| FloatGreaterThanOrEqual | `f>=` | float | float64 | `>=` | +| FloatLessThan | `f<` | float | float64 | `<` | +| FloatLessThanOrEqual | `f<=` | float | float64 | `<=` | +| SameAs | `sameas`, `s==` | string | string | `==` | +| Contains | `contains`, `co` | string | string | `strings.Contains()` | +| StartsWith | `startswith`, `sw` | string | string | `strings.HasPrefix()` | +| EndsWith | `endswith`, `ew` | string | string | `strings.HasSuffix()` | +| In | `in` | string | string | for loop with `==` | +| Matches | `matches`, `re` | string | int64 | `regexp.Regexp.Matches()` | +| Is | `is` | bool* | bool | `==` | +| Exists | `exists`, `ex` | any | n/a | n/a | + +\*accepts strings: 1, t, T, TRUE, true, True, 0, f, F, FALSE + +## Escaping + +If you need to use a control character within a value (ie. not for controlling), escape it with `\`. +It is recommended to wrap a word into parenthesis instead of escaping control characters, when possible. + +| Location | Characters to be escaped | +|---|---| +| Within parenthesis (`"`) | `"`, `\` | +| Everywhere else | `(`, `)`, `"`, `\`, `\t`, `\r`, `\n`, ` ` (space) | diff --git a/database/query/condition-and.go b/database/query/condition-and.go new file mode 100644 index 0000000..74304b9 --- /dev/null +++ b/database/query/condition-and.go @@ -0,0 +1,46 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" +) + +// And combines multiple conditions with a logical _AND_ operator. +func And(conditions ...Condition) Condition { + return &andCond{ + conditions: conditions, + } +} + +type andCond struct { + conditions []Condition +} + +func (c *andCond) complies(acc accessor.Accessor) bool { + for _, cond := range c.conditions { + if !cond.complies(acc) { + return false + } + } + return true +} + +func (c *andCond) check() (err error) { + for _, cond := range c.conditions { + err = cond.check() + if err != nil { + return err + } + } + return nil +} + +func (c *andCond) string() string { + var all []string + for _, cond := range c.conditions { + all = append(all, cond.string()) + } + return fmt.Sprintf("(%s)", strings.Join(all, " and ")) +} diff --git a/database/query/condition-bool.go b/database/query/condition-bool.go new file mode 100644 index 0000000..834b592 --- /dev/null +++ b/database/query/condition-bool.go @@ -0,0 +1,70 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/Safing/portbase/database/accessor" +) + +type boolCondition struct { + key string + operator uint8 + value bool +} + +func newBoolCondition(key string, operator uint8, value interface{}) *boolCondition { + + var parsedValue bool + + switch v := value.(type) { + case bool: + parsedValue = v + case string: + var err error + parsedValue, err = strconv.ParseBool(v) + if err != nil { + return &boolCondition{ + key: fmt.Sprintf("could not parse \"%s\" to bool: %s", v, err), + operator: errorPresent, + } + } + default: + return &boolCondition{ + key: fmt.Sprintf("incompatible value %v for int64", value), + operator: errorPresent, + } + } + + return &boolCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *boolCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetBool(c.key) + if !ok { + return false + } + + switch c.operator { + case Is: + return comp == c.value + default: + return false + } +} + +func (c *boolCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *boolCondition) string() string { + return fmt.Sprintf("%s %s %t", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/database/query/condition-error.go b/database/query/condition-error.go new file mode 100644 index 0000000..a46c36b --- /dev/null +++ b/database/query/condition-error.go @@ -0,0 +1,27 @@ +package query + +import ( + "github.com/Safing/portbase/database/accessor" +) + +type errorCondition struct { + err error +} + +func newErrorCondition(err error) *errorCondition { + return &errorCondition{ + err: err, + } +} + +func (c *errorCondition) complies(acc accessor.Accessor) bool { + return false +} + +func (c *errorCondition) check() error { + return c.err +} + +func (c *errorCondition) string() string { + return "[ERROR]" +} diff --git a/database/query/condition-exists.go b/database/query/condition-exists.go new file mode 100644 index 0000000..567360f --- /dev/null +++ b/database/query/condition-exists.go @@ -0,0 +1,35 @@ +package query + +import ( + "errors" + "fmt" + + "github.com/Safing/portbase/database/accessor" +) + +type existsCondition struct { + key string + operator uint8 +} + +func newExistsCondition(key string, operator uint8) *existsCondition { + return &existsCondition{ + key: key, + operator: operator, + } +} + +func (c *existsCondition) complies(acc accessor.Accessor) bool { + return acc.Exists(c.key) +} + +func (c *existsCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *existsCondition) string() string { + return fmt.Sprintf("%s %s", escapeString(c.key), getOpName(c.operator)) +} diff --git a/database/query/condition-float.go b/database/query/condition-float.go new file mode 100644 index 0000000..4416594 --- /dev/null +++ b/database/query/condition-float.go @@ -0,0 +1,98 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/Safing/portbase/database/accessor" +) + +type floatCondition struct { + key string + operator uint8 + value float64 +} + +func newFloatCondition(key string, operator uint8, value interface{}) *floatCondition { + + var parsedValue float64 + + switch v := value.(type) { + case int: + parsedValue = float64(v) + case int8: + parsedValue = float64(v) + case int16: + parsedValue = float64(v) + case int32: + parsedValue = float64(v) + case int64: + parsedValue = float64(v) + case uint: + parsedValue = float64(v) + case uint8: + parsedValue = float64(v) + case uint16: + parsedValue = float64(v) + case uint32: + parsedValue = float64(v) + case float32: + parsedValue = float64(v) + case float64: + parsedValue = v + case string: + var err error + parsedValue, err = strconv.ParseFloat(v, 64) + if err != nil { + return &floatCondition{ + key: fmt.Sprintf("could not parse %s to float64: %s", v, err), + operator: errorPresent, + } + } + default: + return &floatCondition{ + key: fmt.Sprintf("incompatible value %v for float64", value), + operator: errorPresent, + } + } + + return &floatCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *floatCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetFloat(c.key) + if !ok { + return false + } + + switch c.operator { + case FloatEquals: + return comp == c.value + case FloatGreaterThan: + return comp > c.value + case FloatGreaterThanOrEqual: + return comp >= c.value + case FloatLessThan: + return comp < c.value + case FloatLessThanOrEqual: + return comp <= c.value + default: + return false + } +} + +func (c *floatCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *floatCondition) string() string { + return fmt.Sprintf("%s %s %g", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/database/query/condition-int.go b/database/query/condition-int.go new file mode 100644 index 0000000..dccac28 --- /dev/null +++ b/database/query/condition-int.go @@ -0,0 +1,94 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/Safing/portbase/database/accessor" +) + +type intCondition struct { + key string + operator uint8 + value int64 +} + +func newIntCondition(key string, operator uint8, value interface{}) *intCondition { + + var parsedValue int64 + + switch v := value.(type) { + case int: + parsedValue = int64(v) + case int8: + parsedValue = int64(v) + case int16: + parsedValue = int64(v) + case int32: + parsedValue = int64(v) + case int64: + parsedValue = int64(v) + case uint: + parsedValue = int64(v) + case uint8: + parsedValue = int64(v) + case uint16: + parsedValue = int64(v) + case uint32: + parsedValue = int64(v) + case string: + var err error + parsedValue, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return &intCondition{ + key: fmt.Sprintf("could not parse %s to int64: %s (hint: use \"sameas\" to compare strings)", v, err), + operator: errorPresent, + } + } + default: + return &intCondition{ + key: fmt.Sprintf("incompatible value %v for int64", value), + operator: errorPresent, + } + } + + return &intCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *intCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetInt(c.key) + if !ok { + return false + } + + switch c.operator { + case Equals: + return comp == c.value + case GreaterThan: + return comp > c.value + case GreaterThanOrEqual: + return comp >= c.value + case LessThan: + return comp < c.value + case LessThanOrEqual: + return comp <= c.value + default: + return false + } +} + +func (c *intCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *intCondition) string() string { + return fmt.Sprintf("%s %s %d", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/database/query/condition-not.go b/database/query/condition-not.go new file mode 100644 index 0000000..cac04a7 --- /dev/null +++ b/database/query/condition-not.go @@ -0,0 +1,36 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" +) + +// Not negates the supplied condition. +func Not(c Condition) Condition { + return ¬Cond{ + notC: c, + } +} + +type notCond struct { + notC Condition +} + +func (c *notCond) complies(acc accessor.Accessor) bool { + return !c.notC.complies(acc) +} + +func (c *notCond) check() error { + return c.notC.check() +} + +func (c *notCond) string() string { + next := c.notC.string() + if strings.HasPrefix(next, "(") { + return fmt.Sprintf("not %s", c.notC.string()) + } + splitted := strings.Split(next, " ") + return strings.Join(append([]string{splitted[0], "not"}, splitted[1:]...), " ") +} diff --git a/database/query/condition-or.go b/database/query/condition-or.go new file mode 100644 index 0000000..25fd37b --- /dev/null +++ b/database/query/condition-or.go @@ -0,0 +1,46 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" +) + +// Or combines multiple conditions with a logical _OR_ operator. +func Or(conditions ...Condition) Condition { + return &orCond{ + conditions: conditions, + } +} + +type orCond struct { + conditions []Condition +} + +func (c *orCond) complies(acc accessor.Accessor) bool { + for _, cond := range c.conditions { + if cond.complies(acc) { + return true + } + } + return false +} + +func (c *orCond) check() (err error) { + for _, cond := range c.conditions { + err = cond.check() + if err != nil { + return err + } + } + return nil +} + +func (c *orCond) string() string { + var all []string + for _, cond := range c.conditions { + all = append(all, cond.string()) + } + return fmt.Sprintf("(%s)", strings.Join(all, " or ")) +} diff --git a/database/query/condition-regex.go b/database/query/condition-regex.go new file mode 100644 index 0000000..e808fcd --- /dev/null +++ b/database/query/condition-regex.go @@ -0,0 +1,63 @@ +package query + +import ( + "errors" + "fmt" + "regexp" + + "github.com/Safing/portbase/database/accessor" +) + +type regexCondition struct { + key string + operator uint8 + regex *regexp.Regexp +} + +func newRegexCondition(key string, operator uint8, value interface{}) *regexCondition { + switch v := value.(type) { + case string: + r, err := regexp.Compile(v) + if err != nil { + return ®exCondition{ + key: fmt.Sprintf("could not compile regex \"%s\": %s", v, err), + operator: errorPresent, + } + } + return ®exCondition{ + key: key, + operator: operator, + regex: r, + } + default: + return ®exCondition{ + key: fmt.Sprintf("incompatible value %v for string", value), + operator: errorPresent, + } + } +} + +func (c *regexCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case Matches: + return c.regex.MatchString(comp) + default: + return false + } +} + +func (c *regexCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *regexCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.regex.String())) +} diff --git a/database/query/condition-string.go b/database/query/condition-string.go new file mode 100644 index 0000000..ddbf1b1 --- /dev/null +++ b/database/query/condition-string.go @@ -0,0 +1,62 @@ +package query + +import ( + "errors" + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" +) + +type stringCondition struct { + key string + operator uint8 + value string +} + +func newStringCondition(key string, operator uint8, value interface{}) *stringCondition { + switch v := value.(type) { + case string: + return &stringCondition{ + key: key, + operator: operator, + value: v, + } + default: + return &stringCondition{ + key: fmt.Sprintf("incompatible value %v for string", value), + operator: errorPresent, + } + } +} + +func (c *stringCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case SameAs: + return c.value == comp + case Contains: + return strings.Contains(comp, c.value) + case StartsWith: + return strings.HasPrefix(comp, c.value) + case EndsWith: + return strings.HasSuffix(comp, c.value) + default: + return false + } +} + +func (c *stringCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *stringCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.value)) +} diff --git a/database/query/condition-stringslice.go b/database/query/condition-stringslice.go new file mode 100644 index 0000000..ffc6643 --- /dev/null +++ b/database/query/condition-stringslice.go @@ -0,0 +1,71 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" + "github.com/Safing/portbase/utils" +) + +type stringSliceCondition struct { + key string + operator uint8 + value []string +} + +func newStringSliceCondition(key string, operator uint8, value interface{}) *stringSliceCondition { + + switch v := value.(type) { + case string: + parsedValue := strings.Split(v, ",") + if len(parsedValue) < 2 { + return &stringSliceCondition{ + key: v, + operator: errorPresent, + } + } + return &stringSliceCondition{ + key: key, + operator: operator, + value: parsedValue, + } + case []string: + return &stringSliceCondition{ + key: key, + operator: operator, + value: v, + } + default: + return &stringSliceCondition{ + key: fmt.Sprintf("incompatible value %v for []string", value), + operator: errorPresent, + } + } + +} + +func (c *stringSliceCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case In: + return utils.StringInSlice(c.value, comp) + default: + return false + } +} + +func (c *stringSliceCondition) check() error { + if c.operator == errorPresent { + return fmt.Errorf("could not parse \"%s\" to []string", c.key) + } + return nil +} + +func (c *stringSliceCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(strings.Join(c.value, ","))) +} diff --git a/database/query/condition.go b/database/query/condition.go new file mode 100644 index 0000000..52f4e89 --- /dev/null +++ b/database/query/condition.go @@ -0,0 +1,71 @@ +package query + +import ( + "fmt" + + "github.com/Safing/portbase/database/accessor" +) + +// Condition is an interface to provide a common api to all condition types. +type Condition interface { + complies(acc accessor.Accessor) bool + check() error + string() string +} + +// Operators +const ( + Equals uint8 = iota // int + GreaterThan // int + GreaterThanOrEqual // int + LessThan // int + LessThanOrEqual // int + FloatEquals // float + FloatGreaterThan // float + FloatGreaterThanOrEqual // float + FloatLessThan // float + FloatLessThanOrEqual // float + SameAs // string + Contains // string + StartsWith // string + EndsWith // string + In // stringSlice + Matches // regex + Is // bool: accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE + Exists // any + + errorPresent uint8 = 255 +) + +// Where returns a condition to add to a query. +func Where(key string, operator uint8, value interface{}) Condition { + switch operator { + case Equals, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual: + return newIntCondition(key, operator, value) + case FloatEquals, + FloatGreaterThan, + FloatGreaterThanOrEqual, + FloatLessThan, + FloatLessThanOrEqual: + return newFloatCondition(key, operator, value) + case SameAs, + Contains, + StartsWith, + EndsWith: + return newStringCondition(key, operator, value) + case In: + return newStringSliceCondition(key, operator, value) + case Matches: + return newRegexCondition(key, operator, value) + case Is: + return newBoolCondition(key, operator, value) + case Exists: + return newExistsCondition(key, operator) + default: + return newErrorCondition(fmt.Errorf("no operator with ID %d", operator)) + } +} diff --git a/database/query/condition_test.go b/database/query/condition_test.go new file mode 100644 index 0000000..eb871a7 --- /dev/null +++ b/database/query/condition_test.go @@ -0,0 +1,76 @@ +package query + +import "testing" + +func testSuccess(t *testing.T, c Condition) { + err := c.check() + if err != nil { + t.Errorf("failed: %s", err) + } +} + +func TestInterfaces(t *testing.T) { + testSuccess(t, newIntCondition("banana", Equals, uint(1))) + testSuccess(t, newIntCondition("banana", Equals, uint8(1))) + testSuccess(t, newIntCondition("banana", Equals, uint16(1))) + testSuccess(t, newIntCondition("banana", Equals, uint32(1))) + testSuccess(t, newIntCondition("banana", Equals, int(1))) + testSuccess(t, newIntCondition("banana", Equals, int8(1))) + testSuccess(t, newIntCondition("banana", Equals, int16(1))) + testSuccess(t, newIntCondition("banana", Equals, int32(1))) + testSuccess(t, newIntCondition("banana", Equals, int64(1))) + testSuccess(t, newIntCondition("banana", Equals, "1")) + + testSuccess(t, newFloatCondition("banana", FloatEquals, uint(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint8(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint16(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int8(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int16(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int64(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, float32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, float64(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, "1.1")) + + testSuccess(t, newStringCondition("banana", SameAs, "coconut")) + testSuccess(t, newRegexCondition("banana", Matches, "coconut")) + testSuccess(t, newStringSliceCondition("banana", FloatEquals, []string{"banana", "coconut"})) + testSuccess(t, newStringSliceCondition("banana", FloatEquals, "banana,coconut")) +} + +func testCondError(t *testing.T, c Condition) { + err := c.check() + if err == nil { + t.Error("should fail") + } +} + +func TestConditionErrors(t *testing.T) { + // test invalid value types + testCondError(t, newBoolCondition("banana", Is, 1)) + testCondError(t, newFloatCondition("banana", FloatEquals, true)) + testCondError(t, newIntCondition("banana", Equals, true)) + testCondError(t, newStringCondition("banana", SameAs, 1)) + testCondError(t, newRegexCondition("banana", Matches, 1)) + testCondError(t, newStringSliceCondition("banana", Matches, 1)) + + // test error presence + testCondError(t, newBoolCondition("banana", errorPresent, true)) + testCondError(t, And(newBoolCondition("banana", errorPresent, true))) + testCondError(t, Or(newBoolCondition("banana", errorPresent, true))) + testCondError(t, newExistsCondition("banana", errorPresent)) + testCondError(t, newFloatCondition("banana", errorPresent, 1.1)) + testCondError(t, newIntCondition("banana", errorPresent, 1)) + testCondError(t, newStringCondition("banana", errorPresent, "coconut")) + testCondError(t, newRegexCondition("banana", errorPresent, "coconut")) +} + +func TestWhere(t *testing.T) { + c := Where("", 254, nil) + err := c.check() + if err == nil { + t.Error("should fail") + } +} diff --git a/database/query/operators.go b/database/query/operators.go new file mode 100644 index 0000000..bbd21ee --- /dev/null +++ b/database/query/operators.go @@ -0,0 +1,53 @@ +package query + +var ( + operatorNames = map[string]uint8{ + "==": Equals, + ">": GreaterThan, + ">=": GreaterThanOrEqual, + "<": LessThan, + "<=": LessThanOrEqual, + "f==": FloatEquals, + "f>": FloatGreaterThan, + "f>=": FloatGreaterThanOrEqual, + "f<": FloatLessThan, + "f<=": FloatLessThanOrEqual, + "sameas": SameAs, + "s==": SameAs, + "contains": Contains, + "co": Contains, + "startswith": StartsWith, + "sw": StartsWith, + "endswith": EndsWith, + "ew": EndsWith, + "in": In, + "matches": Matches, + "re": Matches, + "is": Is, + "exists": Exists, + "ex": Exists, + } + + primaryNames = make(map[uint8]string) +) + +func init() { + for opName, opID := range operatorNames { + name, ok := primaryNames[opID] + if ok { + if len(name) < len(opName) { + primaryNames[opID] = opName + } + } else { + primaryNames[opID] = opName + } + } +} + +func getOpName(operator uint8) string { + name, ok := primaryNames[operator] + if ok { + return name + } + return "[unknown]" +} diff --git a/database/query/operators_test.go b/database/query/operators_test.go new file mode 100644 index 0000000..3f4fe81 --- /dev/null +++ b/database/query/operators_test.go @@ -0,0 +1,9 @@ +package query + +import "testing" + +func TestGetOpName(t *testing.T) { + if getOpName(254) != "[unknown]" { + t.Error("unexpected output") + } +} diff --git a/database/query/parser.go b/database/query/parser.go new file mode 100644 index 0000000..d14b847 --- /dev/null +++ b/database/query/parser.go @@ -0,0 +1,349 @@ +package query + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +type snippet struct { + text string + globalPosition int +} + +// ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces. +func ParseQuery(query string) (*Query, error) { + snippets, err := extractSnippets(query) + if err != nil { + return nil, err + } + snippetsPos := 0 + + getSnippet := func() (*snippet, error) { + // order is important, as parseAndOr will always consume one additional snippet. + snippetsPos++ + if snippetsPos > len(snippets) { + return nil, fmt.Errorf("unexpected end at position %d", len(query)) + } + return snippets[snippetsPos-1], nil + } + remainingSnippets := func() int { + return len(snippets) - snippetsPos + } + + // check for query word + queryWord, err := getSnippet() + if err != nil { + return nil, err + } + if queryWord.text != "query" { + return nil, errors.New("queries must start with \"query\"") + } + + // get prefix + prefix, err := getSnippet() + if err != nil { + return nil, err + } + q := New(prefix.text) + + for remainingSnippets() > 0 { + command, err := getSnippet() + if err != nil { + return nil, err + } + + switch command.text { + case "where": + if q.where != nil { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + // parse conditions + condition, err := parseAndOr(getSnippet, remainingSnippets, true) + if err != nil { + return nil, err + } + // go one back, as parseAndOr had to check if its done + snippetsPos-- + + q.Where(condition) + case "orderby": + if q.orderBy != "" { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + orderBySnippet, err := getSnippet() + if err != nil { + return nil, err + } + + q.OrderBy(orderBySnippet.text) + case "limit": + if q.limit != 0 { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + limitSnippet, err := getSnippet() + if err != nil { + return nil, err + } + limit, err := strconv.ParseUint(limitSnippet.text, 10, 31) + if err != nil { + return nil, fmt.Errorf("could not parse integer (%s) at position %d", limitSnippet.text, limitSnippet.globalPosition) + } + + q.Limit(int(limit)) + case "offset": + if q.offset != 0 { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + offsetSnippet, err := getSnippet() + if err != nil { + return nil, err + } + offset, err := strconv.ParseUint(offsetSnippet.text, 10, 31) + if err != nil { + return nil, fmt.Errorf("could not parse integer (%s) at position %d", offsetSnippet.text, offsetSnippet.globalPosition) + } + + q.Offset(int(offset)) + default: + return nil, fmt.Errorf("unknown clause \"%s\" at position %d", command.text, command.globalPosition) + } + } + + return q.Check() +} + +func extractSnippets(text string) (snippets []*snippet, err error) { + + skip := false + start := -1 + inParenthesis := false + var pos int + var char rune + + for pos, char = range text { + + // skip + if skip { + skip = false + continue + } + if char == '\\' { + skip = true + } + + // wait for parenthesis to be overs + if inParenthesis { + if char == '"' { + snippets = append(snippets, &snippet{ + text: prepToken(text[start+1 : pos]), + globalPosition: start + 1, + }) + start = -1 + inParenthesis = false + } + continue + } + + // handle segments + switch char { + case '\t', '\n', '\r', ' ', '(', ')': + if start >= 0 { + snippets = append(snippets, &snippet{ + text: prepToken(text[start:pos]), + globalPosition: start + 1, + }) + start = -1 + } + default: + if start == -1 { + start = pos + } + } + + // handle special segment characters + switch char { + case '(', ')': + snippets = append(snippets, &snippet{ + text: text[pos : pos+1], + globalPosition: pos + 1, + }) + case '"': + if start < pos { + return nil, fmt.Errorf("parenthesis ('\"') may not be used within words, please escape with '\\' (position: %d)", pos+1) + } + inParenthesis = true + } + + } + + // add last + if start >= 0 { + snippets = append(snippets, &snippet{ + text: prepToken(text[start : pos+1]), + globalPosition: start + 1, + }) + } + + return snippets, nil + +} + +func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) { + var isOr = false + var typeSet = false + var wrapInNot = false + var expectingMore = true + var conditions []Condition + + for { + if !expectingMore && rootCondition && remainingSnippets() == 0 { + // advance snippetsPos by one, as it will be set back by 1 + getSnippet() + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + } + + firstSnippet, err := getSnippet() + if err != nil { + return nil, err + } + + if !expectingMore && rootCondition { + switch firstSnippet.text { + case "orderby", "limit", "offset": + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + } + } + + switch firstSnippet.text { + case "(": + condition, err := parseAndOr(getSnippet, remainingSnippets, false) + if err != nil { + return nil, err + } + if wrapInNot { + conditions = append(conditions, Not(condition)) + wrapInNot = false + } else { + conditions = append(conditions, condition) + } + expectingMore = true + case ")": + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + case "and": + if typeSet && isOr { + return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition) + } + isOr = false + typeSet = true + expectingMore = true + case "or": + if typeSet && !isOr { + return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition) + } + isOr = true + typeSet = true + expectingMore = true + case "not": + wrapInNot = true + expectingMore = true + default: + condition, err := parseCondition(firstSnippet, getSnippet) + if err != nil { + return nil, err + } + if wrapInNot { + conditions = append(conditions, Not(condition)) + wrapInNot = false + } else { + conditions = append(conditions, condition) + } + expectingMore = false + } + } +} + +func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error)) (Condition, error) { + wrapInNot := false + + // get operator name + opName, err := getSnippet() + if err != nil { + return nil, err + } + // negate? + if opName.text == "not" { + wrapInNot = true + opName, err = getSnippet() + if err != nil { + return nil, err + } + } + + // get operator + operator, ok := operatorNames[opName.text] + if !ok { + return nil, fmt.Errorf("unknown operator at position %d", opName.globalPosition) + } + + // don't need a value for "exists" + if operator == Exists { + if wrapInNot { + return Not(Where(firstSnippet.text, operator, nil)), nil + } + return Where(firstSnippet.text, operator, nil), nil + } + + // get value + value, err := getSnippet() + if err != nil { + return nil, err + } + if wrapInNot { + return Not(Where(firstSnippet.text, operator, value.text)), nil + } + return Where(firstSnippet.text, operator, value.text), nil +} + +var ( + escapeReplacer = regexp.MustCompile("\\\\([^\\\\])") +) + +// prepToken removes surrounding parenthesis and escape characters. +func prepToken(text string) string { + return escapeReplacer.ReplaceAllString(strings.Trim(text, "\""), "$1") +} + +// escapeString correctly escapes a snippet for printing +func escapeString(token string) string { + // check if token contains characters that need to be escaped + if strings.ContainsAny(token, "()\"\\\t\r\n ") { + // put the token in parenthesis and only escape \ and " + return fmt.Sprintf("\"%s\"", strings.Replace(token, "\"", "\\\"", -1)) + } + return token +} diff --git a/database/query/parser_test.go b/database/query/parser_test.go new file mode 100644 index 0000000..7bd2bda --- /dev/null +++ b/database/query/parser_test.go @@ -0,0 +1,167 @@ +package query + +import ( + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestExtractSnippets(t *testing.T) { + text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ ` + result1 := []*snippet{ + &snippet{text: "query", globalPosition: 1}, + &snippet{text: "test:", globalPosition: 7}, + &snippet{text: "where", globalPosition: 13}, + &snippet{text: "(", globalPosition: 19}, + &snippet{text: "bananas", globalPosition: 21}, + &snippet{text: ">", globalPosition: 31}, + &snippet{text: "100", globalPosition: 33}, + &snippet{text: "and", globalPosition: 37}, + &snippet{text: "monkeys.#", globalPosition: 41}, + &snippet{text: "<=", globalPosition: 51}, + &snippet{text: "12", globalPosition: 54}, + &snippet{text: ")", globalPosition: 58}, + &snippet{text: "or", globalPosition: 59}, + &snippet{text: "(", globalPosition: 61}, + &snippet{text: "coconuts", globalPosition: 62}, + &snippet{text: "<", globalPosition: 71}, + &snippet{text: "10", globalPosition: 73}, + &snippet{text: "and", globalPosition: 76}, + &snippet{text: "area", globalPosition: 82}, + &snippet{text: ">", globalPosition: 87}, + &snippet{text: "50", globalPosition: 89}, + &snippet{text: ")", globalPosition: 91}, + &snippet{text: "or", globalPosition: 93}, + &snippet{text: "name", globalPosition: 96}, + &snippet{text: "sameas", globalPosition: 101}, + &snippet{text: "Julian", globalPosition: 108}, + &snippet{text: "or", globalPosition: 115}, + &snippet{text: "name", globalPosition: 118}, + &snippet{text: "matches", globalPosition: 123}, + &snippet{text: "^King ", globalPosition: 131}, + } + + snippets, err := extractSnippets(text1) + if err != nil { + t.Errorf("failed to extract snippets: %s", err) + } + + if !reflect.DeepEqual(result1, snippets) { + t.Errorf("unexpected results:") + for _, el := range snippets { + t.Errorf("%+v", el) + } + } + + // t.Error(spew.Sprintf("%v", treeElement)) +} + +func testParsing(t *testing.T, queryText string, expectedResult *Query) { + _, err := expectedResult.Check() + if err != nil { + t.Errorf("failed to create query: %s", err) + return + } + + q, err := ParseQuery(queryText) + if err != nil { + t.Errorf("failed to parse query: %s", err) + return + } + + if queryText != q.Print() { + t.Errorf("string match failed: %s", q.Print()) + return + } + if !reflect.DeepEqual(expectedResult, q) { + t.Error("deepqual match failed.") + t.Error("got:") + t.Error(spew.Sdump(q)) + t.Error("expected:") + t.Error(spew.Sdump(expectedResult)) + } +} + +func TestParseQuery(t *testing.T) { + text1 := `query test: where (bananas > 100 and monkeys.# <= 12) or not (coconuts < 10 and area not > 50) or name sameas Julian or name matches "^King " orderby name limit 10 offset 20` + result1 := New("test:").Where(Or( + And( + Where("bananas", GreaterThan, 100), + Where("monkeys.#", LessThanOrEqual, 12), + ), + Not(And( + Where("coconuts", LessThan, 10), + Not(Where("area", GreaterThan, 50)), + )), + Where("name", SameAs, "Julian"), + Where("name", Matches, "^King "), + )).OrderBy("name").Limit(10).Offset(20) + testParsing(t, text1, result1) + + testParsing(t, `query test: orderby name`, New("test:").OrderBy("name")) + testParsing(t, `query test: limit 10`, New("test:").Limit(10)) + testParsing(t, `query test: offset 10`, New("test:").Offset(10)) + testParsing(t, `query test: where banana matches ^ban`, New("test:").Where(Where("banana", Matches, "^ban"))) + testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil))) + testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil)))) + + // test all operators + testParsing(t, `query test: where banana == 1`, New("test:").Where(Where("banana", Equals, 1))) + testParsing(t, `query test: where banana > 1`, New("test:").Where(Where("banana", GreaterThan, 1))) + testParsing(t, `query test: where banana >= 1`, New("test:").Where(Where("banana", GreaterThanOrEqual, 1))) + testParsing(t, `query test: where banana < 1`, New("test:").Where(Where("banana", LessThan, 1))) + testParsing(t, `query test: where banana <= 1`, New("test:").Where(Where("banana", LessThanOrEqual, 1))) + testParsing(t, `query test: where banana f== 1.1`, New("test:").Where(Where("banana", FloatEquals, 1.1))) + testParsing(t, `query test: where banana f> 1.1`, New("test:").Where(Where("banana", FloatGreaterThan, 1.1))) + testParsing(t, `query test: where banana f>= 1.1`, New("test:").Where(Where("banana", FloatGreaterThanOrEqual, 1.1))) + testParsing(t, `query test: where banana f< 1.1`, New("test:").Where(Where("banana", FloatLessThan, 1.1))) + testParsing(t, `query test: where banana f<= 1.1`, New("test:").Where(Where("banana", FloatLessThanOrEqual, 1.1))) + testParsing(t, `query test: where banana sameas banana`, New("test:").Where(Where("banana", SameAs, "banana"))) + testParsing(t, `query test: where banana contains banana`, New("test:").Where(Where("banana", Contains, "banana"))) + testParsing(t, `query test: where banana startswith banana`, New("test:").Where(Where("banana", StartsWith, "banana"))) + testParsing(t, `query test: where banana endswith banana`, New("test:").Where(Where("banana", EndsWith, "banana"))) + testParsing(t, `query test: where banana in banana,coconut`, New("test:").Where(Where("banana", In, []string{"banana", "coconut"}))) + testParsing(t, `query test: where banana matches banana`, New("test:").Where(Where("banana", Matches, "banana"))) + testParsing(t, `query test: where banana is true`, New("test:").Where(Where("banana", Is, true))) + testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil))) + + // special + testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil)))) +} + +func testParseError(t *testing.T, queryText string, expectedErrorString string) { + _, err := ParseQuery(queryText) + if err == nil { + t.Errorf("should fail to parse: %s", queryText) + return + } + if err.Error() != expectedErrorString { + t.Errorf("unexpected error for query: %s\nwanted: %s\n got: %s", queryText, expectedErrorString, err) + } +} + +func TestParseErrors(t *testing.T) { + // syntax + testParseError(t, `query`, `unexpected end at position 5`) + testParseError(t, `query test: where`, `unexpected end at position 17`) + testParseError(t, `query test: where (`, `unexpected end at position 19`) + testParseError(t, `query test: where )`, `unknown clause ")" at position 19`) + testParseError(t, `query test: where not`, `unexpected end at position 21`) + testParseError(t, `query test: where banana`, `unexpected end at position 24`) + testParseError(t, `query test: where banana >`, `unexpected end at position 26`) + testParseError(t, `query test: where banana nope`, `unknown operator at position 26`) + testParseError(t, `query test: where banana exists or`, `unexpected end at position 34`) + testParseError(t, `query test: where banana exists and`, `unexpected end at position 35`) + testParseError(t, `query test: where banana exists and (`, `unexpected end at position 37`) + testParseError(t, `query test: where banana exists and banana is true or`, `you may not mix "and" and "or" (position: 52)`) + testParseError(t, `query test: where banana exists or banana is true and`, `you may not mix "and" and "or" (position: 51)`) + // testParseError(t, `query test: where banana exists and (`, ``) + + // value parsing error + testParseError(t, `query test: where banana == banana`, `could not parse banana to int64: strconv.ParseInt: parsing "banana": invalid syntax (hint: use "sameas" to compare strings)`) + testParseError(t, `query test: where banana f== banana`, `could not parse banana to float64: strconv.ParseFloat: parsing "banana": invalid syntax`) + testParseError(t, `query test: where banana in banana`, `could not parse "banana" to []string`) + testParseError(t, `query test: where banana matches [banana`, "could not compile regex \"[banana\": error parsing regexp: missing closing ]: `[banana`") + testParseError(t, `query test: where banana is great`, `could not parse "great" to bool: strconv.ParseBool: parsing "great": invalid syntax`) +} diff --git a/database/query/query.go b/database/query/query.go new file mode 100644 index 0000000..d9ecb7f --- /dev/null +++ b/database/query/query.go @@ -0,0 +1,173 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/Safing/portbase/database/accessor" + "github.com/Safing/portbase/database/record" +) + +// Example: +// q.New("core:/", +// q.Where("a", q.GreaterThan, 0), +// q.Where("b", q.Equals, 0), +// q.Or( +// q.Where("c", q.StartsWith, "x"), +// q.Where("d", q.Contains, "y") +// ) +// ) + +// Query contains a compiled query. +type Query struct { + checked bool + dbName string + dbKeyPrefix string + where Condition + orderBy string + limit int + offset int +} + +// New creates a new query with the supplied prefix. +func New(prefix string) *Query { + dbName, dbKeyPrefix := record.ParseKey(prefix) + return &Query{ + dbName: dbName, + dbKeyPrefix: dbKeyPrefix, + } +} + +// Where adds filtering. +func (q *Query) Where(condition Condition) *Query { + q.where = condition + return q +} + +// Limit limits the number of returned results. +func (q *Query) Limit(limit int) *Query { + q.limit = limit + return q +} + +// Offset sets the query offset. +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +// OrderBy orders the results by the given key. +func (q *Query) OrderBy(key string) *Query { + q.orderBy = key + return q +} + +// Check checks for errors in the query. +func (q *Query) Check() (*Query, error) { + if q.checked { + return q, nil + } + + // check condition + if q.where != nil { + err := q.where.check() + if err != nil { + return nil, err + } + } + + q.checked = true + return q, nil +} + +// MustBeValid checks for errors in the query and panics if there is an error. +func (q *Query) MustBeValid() *Query { + _, err := q.Check() + if err != nil { + panic(err) + } + return q +} + +// IsChecked returns whether they query was checked. +func (q *Query) IsChecked() bool { + return q.checked +} + +// MatchesKey checks whether the query matches the supplied database key (key without database prefix). +func (q *Query) MatchesKey(dbKey string) bool { + if !strings.HasPrefix(dbKey, q.dbKeyPrefix) { + return false + } + return true +} + +// MatchesRecord checks whether the query matches the supplied database record (value only). +func (q *Query) MatchesRecord(r record.Record) bool { + if q.where == nil { + return true + } + + acc := r.GetAccessor(r) + if acc == nil { + return false + } + return q.where.complies(acc) +} + +// MatchesAccessor checks whether the query matches the supplied accessor (value only). +func (q *Query) MatchesAccessor(acc accessor.Accessor) bool { + if q.where == nil { + return true + } + return q.where.complies(acc) +} + +// Matches checks whether the query matches the supplied database record. +func (q *Query) Matches(r record.Record) bool { + if q.MatchesKey(r.DatabaseKey()) { + return true + } + return q.MatchesRecord(r) +} + +// Print returns the string representation of the query. +func (q *Query) Print() string { + var where string + if q.where != nil { + where = q.where.string() + if where != "" { + if strings.HasPrefix(where, "(") { + where = where[1 : len(where)-1] + } + where = fmt.Sprintf(" where %s", where) + } + } + + var orderBy string + if q.orderBy != "" { + orderBy = fmt.Sprintf(" orderby %s", q.orderBy) + } + + var limit string + if q.limit > 0 { + limit = fmt.Sprintf(" limit %d", q.limit) + } + + var offset string + if q.offset > 0 { + offset = fmt.Sprintf(" offset %d", q.offset) + } + + return fmt.Sprintf("query %s:%s%s%s%s%s", q.dbName, q.dbKeyPrefix, where, orderBy, limit, offset) +} + +// DatabaseName returns the name of the database. +func (q *Query) DatabaseName() string { + return q.dbName +} + +// DatabaseKeyPrefix returns the key prefix for the database. +func (q *Query) DatabaseKeyPrefix() string { + return q.dbKeyPrefix +} diff --git a/database/query/query_test.go b/database/query/query_test.go new file mode 100644 index 0000000..4645a61 --- /dev/null +++ b/database/query/query_test.go @@ -0,0 +1,112 @@ +package query + +import ( + "testing" + + "github.com/Safing/portbase/database/record" +) + +var ( + // copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go + testJSON = `{"age":100, "name":{"here":"B\\\"R"}, + "noop":{"what is a wren?":"a bird"}, + "happy":true,"immortal":false, + "items":[1,2,3,{"tags":[1,2,3],"points":[[1,2],[3,4]]},4,5,6,7], + "arr":["1",2,"3",{"hello":"world"},"4",5], + "vals":[1,2,3,{"sadf":sdf"asdf"}],"name":{"first":"tom","last":null}, + "created":"2014-05-16T08:28:06.989Z", + "loggy":{ + "programmers": [ + { + "firstName": "Brett", + "lastName": "McLaughlin", + "email": "aaaa", + "tag": "good" + }, + { + "firstName": "Jason", + "lastName": "Hunter", + "email": "bbbb", + "tag": "bad" + }, + { + "firstName": "Elliotte", + "lastName": "Harold", + "email": "cccc", + "tag":, "good" + }, + { + "firstName": 1002.3, + "age": 101 + } + ] + }, + "lastly":{"yay":"final"}, + "temperature": 120.413 +}` +) + +func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condition) { + q := New("test:").Where(condition).MustBeValid() + + // fmt.Printf("%s\n", q.Print()) + + matched := q.Matches(r) + switch { + case !matched && shouldMatch: + t.Errorf("should match: %s", q.Print()) + case matched && !shouldMatch: + t.Errorf("should not match: %s", q.Print()) + } +} + +func TestQuery(t *testing.T) { + + // if !gjson.Valid(testJSON) { + // t.Fatal("test json is invalid") + // } + r, err := record.NewWrapper("", nil, append([]byte("J"), []byte(testJSON)...)) + if err != nil { + t.Fatal(err) + } + + testQuery(t, r, true, Where("age", Equals, 100)) + testQuery(t, r, true, Where("age", GreaterThan, uint8(99))) + testQuery(t, r, true, Where("age", GreaterThanOrEqual, 99)) + testQuery(t, r, true, Where("age", GreaterThanOrEqual, 100)) + testQuery(t, r, true, Where("age", LessThan, 101)) + testQuery(t, r, true, Where("age", LessThanOrEqual, "101")) + testQuery(t, r, true, Where("age", LessThanOrEqual, 100)) + + testQuery(t, r, true, Where("temperature", FloatEquals, 120.413)) + testQuery(t, r, true, Where("temperature", FloatGreaterThan, 120)) + testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120)) + testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120.413)) + testQuery(t, r, true, Where("temperature", FloatLessThan, 121)) + testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "121")) + testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "120.413")) + + testQuery(t, r, true, Where("lastly.yay", SameAs, "final")) + testQuery(t, r, true, Where("lastly.yay", Contains, "ina")) + testQuery(t, r, true, Where("lastly.yay", StartsWith, "fin")) + testQuery(t, r, true, Where("lastly.yay", EndsWith, "nal")) + testQuery(t, r, true, Where("lastly.yay", In, "draft,final")) + testQuery(t, r, true, Where("lastly.yay", In, "final,draft")) + + testQuery(t, r, true, Where("happy", Is, true)) + testQuery(t, r, true, Where("happy", Is, "true")) + testQuery(t, r, true, Where("happy", Is, "t")) + testQuery(t, r, true, Not(Where("happy", Is, "0"))) + testQuery(t, r, true, And( + Where("happy", Is, "1"), + Not(Or( + Where("happy", Is, false), + Where("happy", Is, "f"), + )), + )) + + testQuery(t, r, true, Where("happy", Exists, nil)) + + testQuery(t, r, true, Where("created", Matches, "^2014-[0-9]{2}-[0-9]{2}T")) + +} diff --git a/database/record/base.go b/database/record/base.go new file mode 100644 index 0000000..0c7597f --- /dev/null +++ b/database/record/base.go @@ -0,0 +1,106 @@ +package record + +import ( + "errors" + "fmt" + + "github.com/Safing/portbase/container" + "github.com/Safing/portbase/database/accessor" + "github.com/Safing/portbase/formats/dsd" +) + +// Base provides a quick way to comply with the Model interface. +type Base struct { + dbName string + dbKey string + meta *Meta +} + +// Key returns the key of the database record. +func (b *Base) Key() string { + return fmt.Sprintf("%s:%s", b.dbName, b.dbKey) +} + +// DatabaseName returns the name of the database. +func (b *Base) DatabaseName() string { + return b.dbName +} + +// DatabaseKey returns the database key of the database record. +func (b *Base) DatabaseKey() string { + return b.dbKey +} + +// SetKey sets the key on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key. +func (b *Base) SetKey(key string) { + b.dbName, b.dbKey = ParseKey(key) +} + +// MoveTo sets a new key for the record and resets all metadata, except for the secret and crownjewel status. +func (b *Base) MoveTo(key string) { + b.SetKey(key) + b.meta.Reset() +} + +// Meta returns the metadata object for this record. +func (b *Base) Meta() *Meta { + return b.meta +} + +// SetMeta sets the metadata on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key. +func (b *Base) SetMeta(meta *Meta) { + b.meta = meta +} + +// Marshal marshals the object, without the database key or metadata +func (b *Base) Marshal(self Record, format uint8) ([]byte, error) { + if b.Meta() == nil { + return nil, errors.New("missing meta") + } + + if b.Meta().Deleted > 0 { + return nil, nil + } + + dumped, err := dsd.Dump(self, format) + if err != nil { + return nil, err + } + return dumped, nil +} + +// MarshalRecord packs the object, including metadata, into a byte array for saving in a database. +func (b *Base) MarshalRecord(self Record) ([]byte, error) { + if b.Meta() == nil { + return nil, errors.New("missing meta") + } + + // version + c := container.New([]byte{1}) + + // meta + metaSection, err := b.meta.GenCodeMarshal(nil) + if err != nil { + return nil, err + } + c.AppendAsBlock(metaSection) + + // data + dataSection, err := b.Marshal(self, dsd.JSON) + if err != nil { + return nil, err + } + c.Append(dataSection) + + return c.CompileData(), nil +} + +// IsWrapped returns whether the record is a Wrapper. +func (b *Base) IsWrapped() bool { + return false +} + +// GetAccessor returns an accessor for this record, if available. +func (b *Base) GetAccessor(self Record) accessor.Accessor { + return accessor.NewStructAccessor(self) +} diff --git a/database/record/base_test.go b/database/record/base_test.go new file mode 100644 index 0000000..f207bb1 --- /dev/null +++ b/database/record/base_test.go @@ -0,0 +1,13 @@ +package record + +import "testing" + +func TestBaseRecord(t *testing.T) { + + // check model interface compliance + var m Record + b := &TestRecord{} + m = b + _ = m + +} diff --git a/database/record/formats.go b/database/record/formats.go new file mode 100644 index 0000000..d453337 --- /dev/null +++ b/database/record/formats.go @@ -0,0 +1,15 @@ +package record + +import ( + "github.com/Safing/portbase/formats/dsd" +) + +// Reimport DSD storage types +const ( + AUTO = dsd.AUTO + STRING = dsd.STRING // S + BYTES = dsd.BYTES // X + JSON = dsd.JSON // J + BSON = dsd.BSON // B + GenCode = dsd.GenCode // G (reserved) +) diff --git a/database/record/key.go b/database/record/key.go new file mode 100644 index 0000000..b02eecf --- /dev/null +++ b/database/record/key.go @@ -0,0 +1,14 @@ +package record + +import ( + "strings" +) + +// ParseKey splits a key into it's database name and key parts. +func ParseKey(key string) (dbName, dbKey string) { + splitted := strings.SplitN(key, ":", 2) + if len(splitted) == 2 { + return splitted[0], splitted[1] + } + return splitted[0], "" +} diff --git a/database/record/meta-bench_test.go b/database/record/meta-bench_test.go new file mode 100644 index 0000000..ca845c6 --- /dev/null +++ b/database/record/meta-bench_test.go @@ -0,0 +1,466 @@ +package record + +// Benchmark: +// BenchmarkAllocateBytes-8 2000000000 0.76 ns/op +// BenchmarkAllocateStruct1-8 2000000000 0.76 ns/op +// BenchmarkAllocateStruct2-8 2000000000 0.79 ns/op +// BenchmarkMetaSerializeContainer-8 1000000 1703 ns/op +// BenchmarkMetaUnserializeContainer-8 2000000 950 ns/op +// BenchmarkMetaSerializeVarInt-8 3000000 457 ns/op +// BenchmarkMetaUnserializeVarInt-8 20000000 62.9 ns/op +// BenchmarkMetaSerializeWithXDR2-8 1000000 2360 ns/op +// BenchmarkMetaUnserializeWithXDR2-8 500000 3189 ns/op +// BenchmarkMetaSerializeWithColfer-8 10000000 237 ns/op +// BenchmarkMetaUnserializeWithColfer-8 20000000 51.7 ns/op +// BenchmarkMetaSerializeWithCodegen-8 50000000 23.7 ns/op +// BenchmarkMetaUnserializeWithCodegen-8 100000000 18.9 ns/op +// BenchmarkMetaSerializeWithDSDJSON-8 1000000 2398 ns/op +// BenchmarkMetaUnserializeWithDSDJSON-8 300000 6264 ns/op + +import ( + "testing" + "time" + + "github.com/Safing/portbase/container" + "github.com/Safing/portbase/formats/dsd" + "github.com/Safing/portbase/formats/varint" + // Colfer + // "github.com/Safing/portbase/database/model/model" + // XDR + // xdr2 "github.com/davecgh/go-xdr/xdr2" +) + +var ( + testMeta = &Meta{ + Created: time.Now().Unix(), + Modified: time.Now().Unix(), + Expires: time.Now().Unix(), + Deleted: time.Now().Unix(), + secret: true, + cronjewel: true, + } +) + +func BenchmarkAllocateBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = make([]byte, 33) + } +} + +func BenchmarkAllocateStruct1(b *testing.B) { + for i := 0; i < b.N; i++ { + var new Meta + _ = new + } +} + +func BenchmarkAllocateStruct2(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Meta{} + } +} + +func BenchmarkMetaSerializeContainer(b *testing.B) { + + // Start benchmark + for i := 0; i < b.N; i++ { + c := container.New() + c.AppendNumber(uint64(testMeta.Created)) + c.AppendNumber(uint64(testMeta.Modified)) + c.AppendNumber(uint64(testMeta.Expires)) + c.AppendNumber(uint64(testMeta.Deleted)) + switch { + case testMeta.secret && testMeta.cronjewel: + c.AppendNumber(3) + case testMeta.secret: + c.AppendNumber(1) + case testMeta.cronjewel: + c.AppendNumber(2) + default: + c.AppendNumber(0) + } + } + +} + +func BenchmarkMetaUnserializeContainer(b *testing.B) { + + // Setup + c := container.New() + c.AppendNumber(uint64(testMeta.Created)) + c.AppendNumber(uint64(testMeta.Modified)) + c.AppendNumber(uint64(testMeta.Expires)) + c.AppendNumber(uint64(testMeta.Deleted)) + switch { + case testMeta.secret && testMeta.cronjewel: + c.AppendNumber(3) + case testMeta.secret: + c.AppendNumber(1) + case testMeta.cronjewel: + c.AppendNumber(2) + default: + c.AppendNumber(0) + } + encodedData := c.CompileData() + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + var err error + var num uint64 + c := container.New(encodedData) + num, err = c.GetNextN64() + newMeta.Created = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Modified = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Expires = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Deleted = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + + flags, err := c.GetNextN8() + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + + switch flags { + case 3: + newMeta.secret = true + newMeta.cronjewel = true + case 2: + newMeta.cronjewel = true + case 1: + newMeta.secret = true + case 0: + default: + b.Errorf("invalid flag value: %d", flags) + return + } + } + +} + +func BenchmarkMetaSerializeVarInt(b *testing.B) { + + // Start benchmark + for i := 0; i < b.N; i++ { + encoded := make([]byte, 33) + offset := 0 + data := varint.Pack64(uint64(testMeta.Created)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Modified)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Expires)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Deleted)) + for _, part := range data { + encoded[offset] = part + offset++ + } + + switch { + case testMeta.secret && testMeta.cronjewel: + encoded[offset] = 3 + case testMeta.secret: + encoded[offset] = 1 + case testMeta.cronjewel: + encoded[offset] = 2 + default: + encoded[offset] = 0 + } + offset++ + } + +} + +func BenchmarkMetaUnserializeVarInt(b *testing.B) { + + // Setup + encoded := make([]byte, 33) + offset := 0 + data := varint.Pack64(uint64(testMeta.Created)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Modified)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Expires)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Deleted)) + for _, part := range data { + encoded[offset] = part + offset++ + } + + switch { + case testMeta.secret && testMeta.cronjewel: + encoded[offset] = 3 + case testMeta.secret: + encoded[offset] = 1 + case testMeta.cronjewel: + encoded[offset] = 2 + default: + encoded[offset] = 0 + } + offset++ + encodedData := encoded[:offset] + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + offset = 0 + + num, n, err := varint.Unpack64(encodedData) + if err != nil { + b.Error(err) + return + } + testMeta.Created = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Modified = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Expires = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Deleted = int64(num) + offset += n + + switch encodedData[offset] { + case 3: + newMeta.secret = true + newMeta.cronjewel = true + case 2: + newMeta.cronjewel = true + case 1: + newMeta.secret = true + case 0: + default: + b.Errorf("invalid flag value: %d", encodedData[offset]) + return + } + } + +} + +// func BenchmarkMetaSerializeWithXDR2(b *testing.B) { +// +// // Setup +// var w bytes.Buffer +// +// // Reset timer for precise results +// b.ResetTimer() +// +// // Start benchmark +// for i := 0; i < b.N; i++ { +// w.Reset() +// _, err := xdr2.Marshal(&w, testMeta) +// if err != nil { +// b.Errorf("failed to serialize with xdr2: %s", err) +// return +// } +// } +// +// } + +// func BenchmarkMetaUnserializeWithXDR2(b *testing.B) { +// +// // Setup +// var w bytes.Buffer +// _, err := xdr2.Marshal(&w, testMeta) +// if err != nil { +// b.Errorf("failed to serialize with xdr2: %s", err) +// } +// encodedData := w.Bytes() +// +// // Reset timer for precise results +// b.ResetTimer() +// +// // Start benchmark +// for i := 0; i < b.N; i++ { +// var newMeta Meta +// _, err := xdr2.Unmarshal(bytes.NewReader(encodedData), &newMeta) +// if err != nil { +// b.Errorf("failed to unserialize with xdr2: %s", err) +// return +// } +// } +// +// } + +// func BenchmarkMetaSerializeWithColfer(b *testing.B) { +// +// testColf := &model.Course{ +// Created: time.Now().Unix(), +// Modified: time.Now().Unix(), +// Expires: time.Now().Unix(), +// Deleted: time.Now().Unix(), +// Secret: true, +// Cronjewel: true, +// } +// +// // Setup +// for i := 0; i < b.N; i++ { +// _, err := testColf.MarshalBinary() +// if err != nil { +// b.Errorf("failed to serialize with colfer: %s", err) +// return +// } +// } +// +// } + +// func BenchmarkMetaUnserializeWithColfer(b *testing.B) { +// +// testColf := &model.Course{ +// Created: time.Now().Unix(), +// Modified: time.Now().Unix(), +// Expires: time.Now().Unix(), +// Deleted: time.Now().Unix(), +// Secret: true, +// Cronjewel: true, +// } +// encodedData, err := testColf.MarshalBinary() +// if err != nil { +// b.Errorf("failed to serialize with colfer: %s", err) +// return +// } +// +// // Setup +// for i := 0; i < b.N; i++ { +// var testUnColf model.Course +// err := testUnColf.UnmarshalBinary(encodedData) +// if err != nil { +// b.Errorf("failed to unserialize with colfer: %s", err) +// return +// } +// } +// +// } + +func BenchmarkMetaSerializeWithCodegen(b *testing.B) { + + for i := 0; i < b.N; i++ { + _, err := testMeta.GenCodeMarshal(nil) + if err != nil { + b.Errorf("failed to serialize with codegen: %s", err) + return + } + } + +} + +func BenchmarkMetaUnserializeWithCodegen(b *testing.B) { + + // Setup + encodedData, err := testMeta.GenCodeMarshal(nil) + if err != nil { + b.Errorf("failed to serialize with codegen: %s", err) + return + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + _, err := newMeta.GenCodeUnmarshal(encodedData) + if err != nil { + b.Errorf("failed to unserialize with codegen: %s", err) + return + } + } + +} + +func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { + + for i := 0; i < b.N; i++ { + _, err := dsd.Dump(testMeta, dsd.JSON) + if err != nil { + b.Errorf("failed to serialize with DSD/JSON: %s", err) + return + } + } + +} + +func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) { + + // Setup + encodedData, err := dsd.Dump(testMeta, dsd.JSON) + if err != nil { + b.Errorf("failed to serialize with DSD/JSON: %s", err) + return + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + _, err := dsd.Load(encodedData, &newMeta) + if err != nil { + b.Errorf("failed to unserialize with DSD/JSON: %s", err) + return + } + } + +} diff --git a/database/record/meta-gencode.go b/database/record/meta-gencode.go new file mode 100644 index 0000000..c0a2142 --- /dev/null +++ b/database/record/meta-gencode.go @@ -0,0 +1,162 @@ +package record + +import ( + "fmt" + "io" + "time" + "unsafe" +) + +var ( + _ = unsafe.Sizeof(0) + _ = io.ReadFull + _ = time.Now() +) + +// GenCodeSize returns the size of the gencode marshalled byte slice +func (d *Meta) GenCodeSize() (s int) { + s += 34 + return +} + +// GenCodeMarshal gencode marshalls Meta into the given byte array, or a new one if its too small. +func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { + size := d.GenCodeSize() + { + if cap(buf) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + + buf[0+0] = byte(d.Created >> 0) + + buf[1+0] = byte(d.Created >> 8) + + buf[2+0] = byte(d.Created >> 16) + + buf[3+0] = byte(d.Created >> 24) + + buf[4+0] = byte(d.Created >> 32) + + buf[5+0] = byte(d.Created >> 40) + + buf[6+0] = byte(d.Created >> 48) + + buf[7+0] = byte(d.Created >> 56) + + } + { + + buf[0+8] = byte(d.Modified >> 0) + + buf[1+8] = byte(d.Modified >> 8) + + buf[2+8] = byte(d.Modified >> 16) + + buf[3+8] = byte(d.Modified >> 24) + + buf[4+8] = byte(d.Modified >> 32) + + buf[5+8] = byte(d.Modified >> 40) + + buf[6+8] = byte(d.Modified >> 48) + + buf[7+8] = byte(d.Modified >> 56) + + } + { + + buf[0+16] = byte(d.Expires >> 0) + + buf[1+16] = byte(d.Expires >> 8) + + buf[2+16] = byte(d.Expires >> 16) + + buf[3+16] = byte(d.Expires >> 24) + + buf[4+16] = byte(d.Expires >> 32) + + buf[5+16] = byte(d.Expires >> 40) + + buf[6+16] = byte(d.Expires >> 48) + + buf[7+16] = byte(d.Expires >> 56) + + } + { + + buf[0+24] = byte(d.Deleted >> 0) + + buf[1+24] = byte(d.Deleted >> 8) + + buf[2+24] = byte(d.Deleted >> 16) + + buf[3+24] = byte(d.Deleted >> 24) + + buf[4+24] = byte(d.Deleted >> 32) + + buf[5+24] = byte(d.Deleted >> 40) + + buf[6+24] = byte(d.Deleted >> 48) + + buf[7+24] = byte(d.Deleted >> 56) + + } + { + if d.secret { + buf[32] = 1 + } else { + buf[32] = 0 + } + } + { + if d.cronjewel { + buf[33] = 1 + } else { + buf[33] = 0 + } + } + return buf[:i+34], nil +} + +// GenCodeUnmarshal gencode unmarshalls Meta and returns the bytes read. +func (d *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) { + if len(buf) < d.GenCodeSize() { + return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), d.GenCodeSize()) + } + + i := uint64(0) + + { + + d.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56) + + } + { + + d.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56) + + } + { + + d.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56) + + } + { + + d.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56) + + } + { + d.secret = buf[32] == 1 + } + { + d.cronjewel = buf[33] == 1 + } + return i + 34, nil +} diff --git a/database/record/meta-gencode_test.go b/database/record/meta-gencode_test.go new file mode 100644 index 0000000..7050e7d --- /dev/null +++ b/database/record/meta-gencode_test.go @@ -0,0 +1,35 @@ +package record + +import ( + "reflect" + "testing" + "time" +) + +var ( + genCodeTestMeta = &Meta{ + Created: time.Now().Unix(), + Modified: time.Now().Unix(), + Expires: time.Now().Unix(), + Deleted: time.Now().Unix(), + secret: true, + cronjewel: true, + } +) + +func TestGenCode(t *testing.T) { + encoded, err := genCodeTestMeta.GenCodeMarshal(nil) + if err != nil { + t.Fatal(err) + } + + new := &Meta{} + _, err = new.GenCodeUnmarshal(encoded) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(genCodeTestMeta, new) { + t.Errorf("objects are not equal, got: %v", new) + } +} diff --git a/database/record/meta.colf b/database/record/meta.colf new file mode 100644 index 0000000..0072e92 --- /dev/null +++ b/database/record/meta.colf @@ -0,0 +1,10 @@ +package record + +type course struct { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + Secret bool + Cronjewel bool +} diff --git a/database/record/meta.gencode b/database/record/meta.gencode new file mode 100644 index 0000000..7592e2d --- /dev/null +++ b/database/record/meta.gencode @@ -0,0 +1,8 @@ +struct Meta { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + Secret bool + Cronjewel bool +} diff --git a/database/record/meta.go b/database/record/meta.go new file mode 100644 index 0000000..a2c8bb4 --- /dev/null +++ b/database/record/meta.go @@ -0,0 +1,103 @@ +package record + +import "time" + +// Meta holds +type Meta struct { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + secret bool // secrets must not be sent to the UI, only synced between nodes + cronjewel bool // crownjewels must never leave the instance, but may be read by the UI +} + +// SetAbsoluteExpiry sets an absolute expiry time (in seconds), that is not affected when the record is updated. +func (m *Meta) SetAbsoluteExpiry(seconds int64) { + m.Expires = seconds + m.Deleted = 0 +} + +// SetRelativateExpiry sets a relative expiry time (ie. TTL in seconds) that is automatically updated whenever the record is updated/saved. +func (m *Meta) SetRelativateExpiry(seconds int64) { + if seconds >= 0 { + m.Deleted = -seconds + } +} + +// GetAbsoluteExpiry returns the absolute expiry time. +func (m *Meta) GetAbsoluteExpiry() int64 { + return m.Expires +} + +// GetRelativeExpiry returns the current relative expiry time - ie. seconds until expiry. +func (m *Meta) GetRelativeExpiry() int64 { + if m.Deleted < 0 { + return -m.Deleted + } + + abs := m.Expires - time.Now().Unix() + if abs < 0 { + return 0 + } + return abs +} + +// MakeCrownJewel marks the database records as a crownjewel, meaning that it will not be sent/synced to other devices. +func (m *Meta) MakeCrownJewel() { + m.cronjewel = true +} + +// MakeSecret sets the database record as secret, meaning that it may only be used internally, and not by interfacing processes, such as the UI. +func (m *Meta) MakeSecret() { + m.secret = true +} + +// Update updates the internal meta states and should be called before writing the record to the database. +func (m *Meta) Update() { + now := time.Now().Unix() + m.Modified = now + if m.Created == 0 { + m.Created = now + } + if m.Deleted < 0 { + m.Expires = now - m.Deleted + } +} + +// Reset resets all metadata, except for the secret and crownjewel status. +func (m *Meta) Reset() { + m.Created = 0 + m.Modified = 0 + m.Expires = 0 + m.Deleted = 0 +} + +// Delete marks the record as deleted. +func (m *Meta) Delete() { + m.Deleted = time.Now().Unix() +} + +// CheckValidity checks whether the database record is valid. +func (m *Meta) CheckValidity() (valid bool) { + switch { + case m.Deleted > 0: + return false + case m.Expires > 0 && m.Expires < time.Now().Unix(): + return false + default: + return true + } +} + +// CheckPermission checks whether the database record may be accessed with the following scope. +func (m *Meta) CheckPermission(local, internal bool) (permitted bool) { + switch { + case !local && m.cronjewel: + return false + case !internal && m.secret: + return false + default: + return true + } +} diff --git a/database/record/record.go b/database/record/record.go new file mode 100644 index 0000000..3895292 --- /dev/null +++ b/database/record/record.go @@ -0,0 +1,26 @@ +package record + +import ( + "github.com/Safing/portbase/database/accessor" +) + +// Record provides an interface for uniformally handling database records. +type Record interface { + Key() string // test:config + DatabaseName() string // test + DatabaseKey() string // config + + SetKey(key string) // test:config + MoveTo(key string) // test:config + Meta() *Meta + SetMeta(meta *Meta) + + Marshal(self Record, format uint8) ([]byte, error) + MarshalRecord(self Record) ([]byte, error) + GetAccessor(self Record) accessor.Accessor + + Lock() + Unlock() + + IsWrapped() bool +} diff --git a/database/record/record_test.go b/database/record/record_test.go new file mode 100644 index 0000000..f5e315d --- /dev/null +++ b/database/record/record_test.go @@ -0,0 +1,16 @@ +package record + +import "sync" + +type TestRecord struct { + Base + lock sync.Mutex +} + +func (tm *TestRecord) Lock() { + tm.lock.Lock() +} + +func (tm *TestRecord) Unlock() { + tm.lock.Unlock() +} diff --git a/database/record/wrapper.go b/database/record/wrapper.go new file mode 100644 index 0000000..63bdb66 --- /dev/null +++ b/database/record/wrapper.go @@ -0,0 +1,167 @@ +package record + +import ( + "errors" + "fmt" + "sync" + + "github.com/Safing/portbase/container" + "github.com/Safing/portbase/database/accessor" + "github.com/Safing/portbase/formats/dsd" + "github.com/Safing/portbase/formats/varint" +) + +// Wrapper wraps raw data and implements the Record interface. +type Wrapper struct { + Base + sync.Mutex + + Format uint8 + Data []byte +} + +// NewRawWrapper returns a record wrapper for the given data, including metadata. This is normally only used by storage backends when loading records. +func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { + version, offset, err := varint.Unpack8(data) + if err != nil { + return nil, err + } + if version != 1 { + return nil, fmt.Errorf("incompatible record version: %d", version) + } + + metaSection, n, err := varint.GetNextBlock(data[offset:]) + if err != nil { + return nil, fmt.Errorf("could not get meta section: %s", err) + } + offset += n + + newMeta := &Meta{} + _, err = newMeta.GenCodeUnmarshal(metaSection) + if err != nil { + return nil, fmt.Errorf("could not unmarshal meta section: %s", err) + } + + format, _, err := varint.Unpack8(data[offset:]) + if err != nil { + return nil, fmt.Errorf("could not get dsd format: %s", err) + } + + return &Wrapper{ + Base{ + database, + key, + newMeta, + }, + sync.Mutex{}, + format, + data[offset:], + }, nil +} + +// NewWrapper returns a new record wrapper for the given data. +func NewWrapper(key string, meta *Meta, data []byte) (*Wrapper, error) { + format, _, err := varint.Unpack8(data) + if err != nil { + return nil, fmt.Errorf("could not get dsd format: %s", err) + } + + dbName, dbKey := ParseKey(key) + + return &Wrapper{ + Base{ + dbName: dbName, + dbKey: dbKey, + meta: meta, + }, + sync.Mutex{}, + format, + data, + }, nil +} + +// Marshal marshals the object, without the database key or metadata +func (w *Wrapper) Marshal(r Record, storageType uint8) ([]byte, error) { + if w.Meta() == nil { + return nil, errors.New("missing meta") + } + + if w.Meta().Deleted > 0 { + return nil, nil + } + + if storageType != dsd.AUTO && storageType != w.Format { + return nil, errors.New("could not dump model, wrapped object format mismatch") + } + return w.Data, nil +} + +// MarshalRecord packs the object, including metadata, into a byte array for saving in a database. +func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) { + // Duplication necessary, as the version from Base would call Base.Marshal instead of Wrapper.Marshal + + if w.Meta() == nil { + return nil, errors.New("missing meta") + } + + // version + c := container.New([]byte{1}) + + // meta + metaSection, err := w.meta.GenCodeMarshal(nil) + if err != nil { + return nil, err + } + c.AppendAsBlock(metaSection) + + // data + dataSection, err := w.Marshal(r, dsd.JSON) + if err != nil { + return nil, err + } + c.Append(dataSection) + + return c.CompileData(), nil +} + +// // Lock locks the record. +// func (w *Wrapper) Lock() { +// w.lock.Lock() +// } +// +// // Unlock unlocks the record. +// func (w *Wrapper) Unlock() { +// w.lock.Unlock() +// } + +// IsWrapped returns whether the record is a Wrapper. +func (w *Wrapper) IsWrapped() bool { + return true +} + +// Unwrap unwraps data into a record. +func Unwrap(wrapped, new Record) error { + wrapper, ok := wrapped.(*Wrapper) + if !ok { + return fmt.Errorf("cannot unwrap %T", wrapped) + } + + _, err := dsd.Load(wrapper.Data, new) + if err != nil { + return fmt.Errorf("failed to unwrap %T: %s", new, err) + } + + new.SetKey(wrapped.Key()) + new.SetMeta(wrapped.Meta()) + + return nil +} + +// GetAccessor returns an accessor for this record, if available. +func (w *Wrapper) GetAccessor(self Record) accessor.Accessor { + if len(w.Data) > 1 && w.Data[0] == JSON { + jsonData := w.Data[1:] + return accessor.NewJSONBytesAccessor(&jsonData) + } + return nil +} diff --git a/database/record/wrapper_test.go b/database/record/wrapper_test.go new file mode 100644 index 0000000..e2988f0 --- /dev/null +++ b/database/record/wrapper_test.go @@ -0,0 +1,55 @@ +package record + +import ( + "bytes" + "testing" + + "github.com/Safing/portbase/formats/dsd" +) + +func TestWrapper(t *testing.T) { + + // check model interface compliance + var m Record + w := &Wrapper{} + m = w + _ = m + + // create test data + testData := []byte(`J{"a": "b"}`) + + // test wrapper + wrapper, err := NewWrapper("test:a", &Meta{}, testData) + if err != nil { + t.Fatal(err) + } + if wrapper.Format != dsd.JSON { + t.Error("format mismatch") + } + if !bytes.Equal(testData, wrapper.Data) { + t.Error("data mismatch") + } + + encoded, err := wrapper.Marshal(wrapper, dsd.JSON) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testData, encoded) { + t.Error("marshal mismatch") + } + + wrapper.SetMeta(&Meta{}) + raw, err := wrapper.MarshalRecord(wrapper) + if err != nil { + t.Fatal(err) + } + + wrapper2, err := NewRawWrapper("test", "a", raw) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testData, wrapper2.Data) { + t.Error("marshal mismatch") + } + +} diff --git a/database/registry.go b/database/registry.go new file mode 100644 index 0000000..1322854 --- /dev/null +++ b/database/registry.go @@ -0,0 +1,158 @@ +package database + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "regexp" + "sync" + "time" + + "github.com/tevino/abool" +) + +const ( + registryFileName = "databases.json" +) + +var ( + writeRegistrySoon = abool.NewBool(false) + + registry map[string]*Database + registryLock sync.Mutex + + nameConstraint = regexp.MustCompile("^[A-Za-z0-9_-]{4,}$") +) + +// Register registers a new database. +// If the database is already registered, only +// the description and the primary API will be +// updated and the effective object will be returned. +func Register(new *Database) (*Database, error) { + if !initialized.IsSet() { + return nil, errors.New("database not initialized") + } + + registryLock.Lock() + defer registryLock.Unlock() + + registeredDB, ok := registry[new.Name] + save := false + + if ok { + // update database + if registeredDB.Description != new.Description { + registeredDB.Description = new.Description + save = true + } + if registeredDB.PrimaryAPI != new.PrimaryAPI { + registeredDB.PrimaryAPI = new.PrimaryAPI + save = true + } + } else { + // register new database + if !nameConstraint.MatchString(new.Name) { + return nil, errors.New("database name must only contain alphanumeric and `_-` characters and must be at least 4 characters long") + } + + now := time.Now().Round(time.Second) + new.Registered = now + new.LastUpdated = now + new.LastLoaded = time.Time{} + + registry[new.Name] = new + save = true + } + + if save { + if ok { + registeredDB.Updated() + } + err := saveRegistry(false) + if err != nil { + return nil, err + } + } + + if ok { + return registeredDB, nil + } + return nil, nil +} + +func getDatabase(name string) (*Database, error) { + registryLock.Lock() + defer registryLock.Unlock() + + registeredDB, ok := registry[name] + if !ok { + return nil, fmt.Errorf(`database "%s" not registered`, name) + } + if time.Now().Add(-24 * time.Hour).After(registeredDB.LastLoaded) { + writeRegistrySoon.Set() + } + registeredDB.Loaded() + + return registeredDB, nil +} + +func loadRegistry() error { + registryLock.Lock() + defer registryLock.Unlock() + + // read file + filePath := path.Join(rootDir, registryFileName) + data, err := ioutil.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + registry = make(map[string]*Database) + return nil + } + return err + } + + // parse + new := make(map[string]*Database) + err = json.Unmarshal(data, new) + if err != nil { + return err + } + + // set + registry = new + return nil +} + +func saveRegistry(lock bool) error { + if lock { + registryLock.Lock() + defer registryLock.Unlock() + } + + // marshal + data, err := json.MarshalIndent(registry, "", "\t") + if err != nil { + return err + } + + // write file + filePath := path.Join(rootDir, registryFileName) + return ioutil.WriteFile(filePath, data, 0600) +} + +func registryWriter() { + for { + select { + case <-time.After(1 * time.Hour): + if writeRegistrySoon.SetToIf(true, false) { + saveRegistry(true) + } + case <-shutdownSignal: + saveRegistry(true) + return + } + } +} diff --git a/database/storage/badger/badger.go b/database/storage/badger/badger.go new file mode 100644 index 0000000..0ff6abe --- /dev/null +++ b/database/storage/badger/badger.go @@ -0,0 +1,199 @@ +package badger + +import ( + "errors" + "fmt" + "time" + + "github.com/dgraph-io/badger" + + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/database/storage" +) + +// Badger database made pluggable for portbase. +type Badger struct { + name string + db *badger.DB +} + +func init() { + storage.Register("badger", NewBadger) +} + +// NewBadger opens/creates a badger database. +func NewBadger(name, location string) (storage.Interface, error) { + opts := badger.DefaultOptions + opts.Dir = location + opts.ValueDir = location + + db, err := badger.Open(opts) + if err != nil { + return nil, err + } + + return &Badger{ + name: name, + db: db, + }, nil +} + +// Get returns a database record. +func (b *Badger) Get(key string) (record.Record, error) { + var item *badger.Item + + err := b.db.View(func(txn *badger.Txn) error { + var err error + item, err = txn.Get([]byte(key)) + if err != nil { + if err == badger.ErrKeyNotFound { + return storage.ErrNotFound + } + return err + } + return nil + }) + if err != nil { + return nil, err + } + + // DO NOT check for this, as we got our own machanism for that. + // if item.IsDeletedOrExpired() { + // return nil, storage.ErrNotFound + // } + + data, err := item.ValueCopy(nil) + if err != nil { + return nil, err + } + + m, err := record.NewRawWrapper(b.name, string(item.Key()), data) + if err != nil { + return nil, err + } + return m, nil +} + +// Put stores a record in the database. +func (b *Badger) Put(r record.Record) error { + data, err := r.MarshalRecord(r) + if err != nil { + return err + } + + err = b.db.Update(func(txn *badger.Txn) error { + return txn.Set([]byte(r.DatabaseKey()), data) + }) + return err +} + +// Delete deletes a record from the database. +func (b *Badger) Delete(key string) error { + return b.db.Update(func(txn *badger.Txn) error { + err := txn.Delete([]byte(key)) + if err != nil && err != badger.ErrKeyNotFound { + return err + } + return nil + }) +} + +// Query returns a an iterator for the supplied query. +func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %s", err) + } + + queryIter := iterator.New() + + go b.queryExecutor(queryIter, q, local, internal) + return queryIter, nil +} + +func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + err := b.db.View(func(txn *badger.Txn) error { + it := txn.NewIterator(badger.DefaultIteratorOptions) + defer it.Close() + prefix := []byte(q.DatabaseKeyPrefix()) + for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { + item := it.Item() + + data, err := item.Value() + if err != nil { + return err + } + + r, err := record.NewRawWrapper(b.name, string(item.Key()), data) + if err != nil { + return err + } + + if !r.Meta().CheckValidity() { + continue + } + if !r.Meta().CheckPermission(local, internal) { + continue + } + + if q.MatchesRecord(r) { + copiedData, err := item.ValueCopy(nil) + if err != nil { + return err + } + new, err := record.NewRawWrapper(b.name, r.DatabaseKey(), copiedData) + if err != nil { + return err + } + select { + case queryIter.Next <- new: + default: + select { + case queryIter.Next <- new: + case <-time.After(1 * time.Minute): + return errors.New("query timeout") + } + } + } + + } + return nil + }) + + if err != nil { + queryIter.Error = err + } + close(queryIter.Next) + close(queryIter.Done) +} + +// ReadOnly returns whether the database is read only. +func (b *Badger) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (b *Badger) Injected() bool { + return false +} + +// Maintain runs a light maintenance operation on the database. +func (b *Badger) Maintain() error { + b.db.RunValueLogGC(0.7) + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (b *Badger) MaintainThorough() (err error) { + for err == nil { + err = b.db.RunValueLogGC(0.7) + } + return nil +} + +// Shutdown shuts down the database. +func (b *Badger) Shutdown() error { + return b.db.Close() +} diff --git a/database/storage/badger/badger_test.go b/database/storage/badger/badger_test.go new file mode 100644 index 0000000..f89ee90 --- /dev/null +++ b/database/storage/badger/badger_test.go @@ -0,0 +1,138 @@ +package badger + +import ( + "io/ioutil" + "os" + "reflect" + "sync" + "testing" + + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +type TestRecord struct { + record.Base + lock sync.Mutex + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +func (tr *TestRecord) Lock() { +} + +func (tr *TestRecord) Unlock() { +} + +func TestBadger(t *testing.T) { + testDir, err := ioutil.TempDir("", "testing-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(testDir) // clean up + + // start + db, err := NewBadger("test", testDir) + if err != nil { + t.Fatal(err) + } + + a := &TestRecord{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + a.SetMeta(&record.Meta{}) + a.Meta().Update() + a.SetKey("test:A") + + // put record + err = db.Put(a) + if err != nil { + t.Fatal(err) + } + + // get and compare + r1, err := db.Get("A") + if err != nil { + t.Fatal(err) + } + + a1 := &TestRecord{} + err = record.Unwrap(r1, a1) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(a, a1) { + t.Fatalf("mismatch, got %v", a1) + } + + // test query + q := query.New("").MustBeValid() + it, err := db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt := 0 + for _ = range it.Next { + cnt++ + } + if it.Error != nil { + t.Fatal(err) + } + if cnt != 1 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // delete + err = db.Delete("A") + if err != nil { + t.Fatal(err) + } + + // check if its gone + _, err = db.Get("A") + if err == nil { + t.Fatal("should fail") + } + + // maintenance + err = db.Maintain() + if err != nil { + t.Fatal(err) + } + err = db.MaintainThorough() + if err != nil { + t.Fatal(err) + } + + // shutdown + err = db.Shutdown() + if err != nil { + t.Fatal(err) + } +} diff --git a/database/storage/errors.go b/database/storage/errors.go new file mode 100644 index 0000000..a280296 --- /dev/null +++ b/database/storage/errors.go @@ -0,0 +1,8 @@ +package storage + +import "errors" + +// Errors for storages +var ( + ErrNotFound = errors.New("storage entry could not be found") +) diff --git a/database/storage/interface.go b/database/storage/interface.go new file mode 100644 index 0000000..73b1a5f --- /dev/null +++ b/database/storage/interface.go @@ -0,0 +1,21 @@ +package storage + +import ( + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +// Interface defines the database storage API. +type Interface interface { + Get(key string) (record.Record, error) + Put(m record.Record) error + Delete(key string) error + Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) + + ReadOnly() bool + Injected() bool + Maintain() error + MaintainThorough() error + Shutdown() error +} diff --git a/database/storage/sinkhole/sinkhole.go b/database/storage/sinkhole/sinkhole.go new file mode 100644 index 0000000..b1447a7 --- /dev/null +++ b/database/storage/sinkhole/sinkhole.go @@ -0,0 +1,71 @@ +package sinkhole + +import ( + "errors" + + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" + "github.com/Safing/portbase/database/storage" +) + +// Sinkhole is a dummy storage. +type Sinkhole struct { + name string +} + +func init() { + storage.Register("sinkhole", NewSinkhole) +} + +// NewSinkhole creates a dummy database. +func NewSinkhole(name, location string) (storage.Interface, error) { + return &Sinkhole{ + name: name, + }, nil +} + +// Exists returns whether an entry with the given key exists. +func (s *Sinkhole) Exists(key string) (bool, error) { + return false, nil +} + +// Get returns a database record. +func (s *Sinkhole) Get(key string) (record.Record, error) { + return nil, storage.ErrNotFound +} + +// Put stores a record in the database. +func (s *Sinkhole) Put(m record.Record) error { + return nil +} + +// Delete deletes a record from the database. +func (s *Sinkhole) Delete(key string) error { + return nil +} + +// Query returns a an iterator for the supplied query. +func (s *Sinkhole) Query(q *query.Query) (*iterator.Iterator, error) { + return nil, errors.New("query not implemented by sinkhole") +} + +// ReadOnly returns whether the database is read only. +func (s *Sinkhole) ReadOnly() bool { + return false +} + +// Maintain runs a light maintenance operation on the database. +func (s *Sinkhole) Maintain() error { + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (s *Sinkhole) MaintainThorough() (err error) { + return nil +} + +// Shutdown shuts down the database. +func (s *Sinkhole) Shutdown() error { + return nil +} diff --git a/database/storage/storages.go b/database/storage/storages.go new file mode 100644 index 0000000..1fa7448 --- /dev/null +++ b/database/storage/storages.go @@ -0,0 +1,47 @@ +package storage + +import ( + "errors" + "fmt" + "sync" +) + +// A Factory creates a new database of it's type. +type Factory func(name, location string) (Interface, error) + +var ( + storages = make(map[string]Factory) + storagesLock sync.Mutex +) + +// Register registers a new storage type. +func Register(name string, factory Factory) error { + storagesLock.Lock() + defer storagesLock.Unlock() + + _, ok := storages[name] + if ok { + return errors.New("factory for this type already exists") + } + + storages[name] = factory + return nil +} + +// CreateDatabase starts a new database with the given name and storageType at location. +func CreateDatabase(name, storageType, location string) (Interface, error) { + return nil, nil +} + +// StartDatabase starts a new database with the given name and storageType at location. +func StartDatabase(name, storageType, location string) (Interface, error) { + storagesLock.Lock() + defer storagesLock.Unlock() + + factory, ok := storages[storageType] + if !ok { + return nil, fmt.Errorf("storage type %s not registered", storageType) + } + + return factory(name, location) +} diff --git a/database/subscription.go b/database/subscription.go new file mode 100644 index 0000000..d95ac94 --- /dev/null +++ b/database/subscription.go @@ -0,0 +1,59 @@ +package database + +import ( + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +// Subscription is a database subscription for updates. +type Subscription struct { + q *query.Query + Feed chan record.Record + Err error +} + +// Subscribe subscribes to updates matching the given query. +func Subscribe(q *query.Query) (*Subscription, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + c, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + sub := &Subscription{ + q: q, + Feed: make(chan record.Record, 100), + } + c.subscriptions = append(c.subscriptions, sub) + return sub, nil +} + +// Cancel cancels the subscription. +func (s *Subscription) Cancel() error { + c, err := getController(s.q.DatabaseName()) + if err != nil { + return err + } + + c.readLock.Lock() + defer c.readLock.Unlock() + c.writeLock.Lock() + defer c.writeLock.Unlock() + + for key, sub := range c.subscriptions { + if sub.q == s.q { + c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...) + return nil + } + } + return nil +} diff --git a/database/subscriptions.go b/database/subscriptions.go deleted file mode 100644 index 684016b..0000000 --- a/database/subscriptions.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "fmt" - "strings" - "sync" - - "github.com/Safing/safing-core/modules" - "github.com/Safing/safing-core/taskmanager" - - "github.com/ipfs/go-datastore" - "github.com/tevino/abool" -) - -var subscriptionModule *modules.Module -var subscriptions []*Subscription -var subLock sync.Mutex - -var databaseUpdate chan Model -var databaseCreate chan Model -var databaseDelete chan *datastore.Key - -var workIsWaiting chan *struct{} -var workIsWaitingFlag *abool.AtomicBool -var forceProcessing chan *struct{} - -type Subscription struct { - typeAndLocation map[string]bool - exactObject map[string]bool - children map[string]uint8 - Created chan Model - Updated chan Model - Deleted chan *datastore.Key -} - -func NewSubscription() *Subscription { - subLock.Lock() - defer subLock.Unlock() - sub := &Subscription{ - typeAndLocation: make(map[string]bool), - exactObject: make(map[string]bool), - children: make(map[string]uint8), - Created: make(chan Model, 128), - Updated: make(chan Model, 128), - Deleted: make(chan *datastore.Key, 128), - } - subscriptions = append(subscriptions, sub) - return sub -} - -func (sub *Subscription) Subscribe(subKey string) { - subLock.Lock() - defer subLock.Unlock() - - namespaces := strings.Split(subKey, "/")[1:] - lastSpace := "" - if len(namespaces) != 0 { - lastSpace = namespaces[len(namespaces)-1] - } - - switch { - case lastSpace == "": - // save key without leading "/" - // save with depth 255 to get all - sub.children[strings.Trim(subKey, "/")] = 0xFF - case strings.HasPrefix(lastSpace, "*"): - // save key without leading or trailing "/" or "*" - // save full wanted depth - this makes comparison easier - sub.children[strings.Trim(subKey, "/*")] = uint8(len(lastSpace) + len(namespaces) - 1) - case strings.Contains(lastSpace, ":"): - sub.exactObject[subKey] = true - default: - sub.typeAndLocation[subKey] = true - } -} - -func (sub *Subscription) Unsubscribe(subKey string) { - subLock.Lock() - defer subLock.Unlock() - - namespaces := strings.Split(subKey, "/")[1:] - lastSpace := "" - if len(namespaces) != 0 { - lastSpace = namespaces[len(namespaces)-1] - } - - switch { - case lastSpace == "": - delete(sub.children, strings.Trim(subKey, "/")) - case strings.HasPrefix(lastSpace, "*"): - delete(sub.children, strings.Trim(subKey, "/*")) - case strings.Contains(lastSpace, ":"): - delete(sub.exactObject, subKey) - default: - delete(sub.typeAndLocation, subKey) - } -} - -func (sub *Subscription) Destroy() { - subLock.Lock() - defer subLock.Unlock() - - for k, v := range subscriptions { - if v.Created == sub.Created { - defer func() { - subscriptions = append(subscriptions[:k], subscriptions[k+1:]...) - }() - close(sub.Created) - close(sub.Updated) - close(sub.Deleted) - return - } - } -} - -func (sub *Subscription) Subscriptions() *[]string { - subStrings := make([]string, 0) - for subString := range sub.exactObject { - subStrings = append(subStrings, subString) - } - for subString := range sub.typeAndLocation { - subStrings = append(subStrings, subString) - } - for subString, depth := range sub.children { - if depth == 0xFF { - subStrings = append(subStrings, fmt.Sprintf("/%s/", subString)) - } else { - subStrings = append(subStrings, fmt.Sprintf("/%s/%s", subString, strings.Repeat("*", int(depth)-len(strings.Split(subString, "/"))))) - } - } - return &subStrings -} - -func (sub *Subscription) String() string { - return fmt.Sprintf("", strings.Join(*sub.Subscriptions(), " ")) -} - -func (sub *Subscription) send(key *datastore.Key, model Model, created bool) { - if model == nil { - sub.Deleted <- key - } else if created { - sub.Created <- model - } else { - sub.Updated <- model - } -} - -func process(key *datastore.Key, model Model, created bool) { - subLock.Lock() - defer subLock.Unlock() - - stringRep := key.String() - // "/Comedy/MontyPython/Actor:JohnCleese" - typeAndLocation := key.Path().String() - // "/Comedy/MontyPython/Actor" - namespaces := key.Namespaces() - // ["Comedy", "MontyPython", "Actor:JohnCleese"] - depth := uint8(len(namespaces)) - // 3 - -subscriptionLoop: - for _, sub := range subscriptions { - if _, ok := sub.exactObject[stringRep]; ok { - sub.send(key, model, created) - continue subscriptionLoop - } - if _, ok := sub.typeAndLocation[typeAndLocation]; ok { - sub.send(key, model, created) - continue subscriptionLoop - } - for i := 0; i < len(namespaces); i++ { - if subscribedDepth, ok := sub.children[strings.Join(namespaces[:i], "/")]; ok { - if subscribedDepth >= depth { - sub.send(key, model, created) - continue subscriptionLoop - } - } - } - } - -} - -func init() { - subscriptionModule = modules.Register("Database:Subscriptions", 128) - subscriptions = make([]*Subscription, 0) - subLock = sync.Mutex{} - - databaseUpdate = make(chan Model, 32) - databaseCreate = make(chan Model, 32) - databaseDelete = make(chan *datastore.Key, 32) - - workIsWaiting = make(chan *struct{}, 0) - workIsWaitingFlag = abool.NewBool(false) - forceProcessing = make(chan *struct{}, 0) - - go run() -} - -func run() { - for { - select { - case <-subscriptionModule.Stop: - subscriptionModule.StopComplete() - return - case <-workIsWaiting: - work() - } - } -} - -func work() { - defer workIsWaitingFlag.UnSet() - - // wait - select { - case <-taskmanager.StartMediumPriorityMicroTask(): - defer taskmanager.EndMicroTask() - case <-forceProcessing: - } - - // work - for { - select { - case model := <-databaseCreate: - process(model.GetKey(), model, true) - case model := <-databaseUpdate: - process(model.GetKey(), model, false) - case key := <-databaseDelete: - process(key, nil, false) - default: - return - } - } -} - -func handleCreateSubscriptions(model Model) { - select { - case databaseCreate <- model: - default: - forceProcessing <- nil - databaseCreate <- model - } - if workIsWaitingFlag.SetToIf(false, true) { - workIsWaiting <- nil - } -} - -func handleUpdateSubscriptions(model Model) { - select { - case databaseUpdate <- model: - default: - forceProcessing <- nil - databaseUpdate <- model - } - if workIsWaitingFlag.SetToIf(false, true) { - workIsWaiting <- nil - } -} - -func handleDeleteSubscriptions(key *datastore.Key) { - select { - case databaseDelete <- key: - default: - forceProcessing <- nil - databaseDelete <- key - } - if workIsWaitingFlag.SetToIf(false, true) { - workIsWaiting <- nil - } -} diff --git a/database/subscriptions_test.go b/database/subscriptions_test.go deleted file mode 100644 index 817a9df..0000000 --- a/database/subscriptions_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "strconv" - "strings" - "sync" - "testing" -) - -var subTestWg sync.WaitGroup - -func waitForSubs(t *testing.T, sub *Subscription, highest int) { - defer subTestWg.Done() - expecting := 1 - var subbedModel Model -forLoop: - for { - select { - case subbedModel = <-sub.Created: - case subbedModel = <-sub.Updated: - } - t.Logf("got model from subscription: %s", subbedModel.GetKey().String()) - if !strings.HasPrefix(subbedModel.GetKey().Name(), "sub") { - // not a model that we use for testing, other tests might be interfering - continue forLoop - } - number, err := strconv.Atoi(strings.TrimPrefix(subbedModel.GetKey().Name(), "sub")) - if err != nil || number != expecting { - t.Errorf("test subscription: got unexpected model %s, expected sub%d", subbedModel.GetKey().String(), expecting) - continue forLoop - } - if number == highest { - return - } - expecting++ - } -} - -func TestSubscriptions(t *testing.T) { - - // create subscription - sub := NewSubscription() - - // FIRST TEST - - subTestWg.Add(1) - go waitForSubs(t, sub, 3) - sub.Subscribe("/Tests/") - t.Log(sub.String()) - - (&(TestingModel{})).CreateInNamespace("", "sub1") - (&(TestingModel{})).CreateInNamespace("A", "sub2") - (&(TestingModel{})).CreateInNamespace("A/B/C/D/E", "sub3") - - subTestWg.Wait() - - // SECOND TEST - - subTestWg.Add(1) - go waitForSubs(t, sub, 3) - sub.Unsubscribe("/Tests/") - sub.Subscribe("/Tests/A/****") - t.Log(sub.String()) - - (&(TestingModel{})).CreateInNamespace("", "subX") - (&(TestingModel{})).CreateInNamespace("A", "sub1") - (&(TestingModel{})).CreateInNamespace("A/B/C/D", "sub2") - (&(TestingModel{})).CreateInNamespace("A/B/C/D/E", "subX") - (&(TestingModel{})).CreateInNamespace("A", "sub3") - - subTestWg.Wait() - - // THIRD TEST - - subTestWg.Add(1) - go waitForSubs(t, sub, 3) - sub.Unsubscribe("/Tests/A/****") - sub.Subscribe("/Tests/TestingModel:sub1") - sub.Subscribe("/Tests/TestingModel:sub1/TestingModel") - t.Log(sub.String()) - - (&(TestingModel{})).CreateInNamespace("", "sub1") - (&(TestingModel{})).CreateInNamespace("", "subX") - (&(TestingModel{})).CreateInNamespace("TestingModel:sub1", "sub2") - (&(TestingModel{})).CreateInNamespace("TestingModel:sub1/A", "subX") - (&(TestingModel{})).CreateInNamespace("TestingModel:sub1", "sub3") - - subTestWg.Wait() - - // FINAL STUFF - - model := &TestingModel{} - model.CreateInNamespace("Invalid", "subX") - model.Save() - - sub.Destroy() - - // time.Sleep(1 * time.Second) - // pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - -} diff --git a/database/utils/kvops/kvops.go b/database/utils/kvops/kvops.go new file mode 100644 index 0000000..bf8ee70 --- /dev/null +++ b/database/utils/kvops/kvops.go @@ -0,0 +1 @@ +package kvops diff --git a/database/wrapper.go b/database/wrapper.go deleted file mode 100644 index 74d3089..0000000 --- a/database/wrapper.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "github.com/ipfs/go-datastore" - - "github.com/Safing/safing-core/database/dbutils" -) - -func NewWrapper(key *datastore.Key, data []byte) (*dbutils.Wrapper, error) { - return dbutils.NewWrapper(key, data) -} - -func DumpModel(uncertain interface{}, storageType uint8) ([]byte, error) { - return dbutils.DumpModel(uncertain, storageType) -} diff --git a/database/wrapper_test.go b/database/wrapper_test.go deleted file mode 100644 index 27ce5f2..0000000 --- a/database/wrapper_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package database - -import ( - "testing" - - "github.com/Safing/safing-core/formats/dsd" -) - -func TestWrapper(t *testing.T) { - - // create Model - new := &TestingModel{ - Name: "a", - Value: "b", - } - newTwo := &TestingModel{ - Name: "c", - Value: "d", - } - - // dump - bytes, err := DumpModel(new, dsd.JSON) - if err != nil { - panic(err) - } - bytesTwo, err := DumpModel(newTwo, dsd.JSON) - if err != nil { - panic(err) - } - - // wrap - wrapped, err := NewWrapper(nil, bytes) - if err != nil { - panic(err) - } - wrappedTwo, err := NewWrapper(nil, bytesTwo) - if err != nil { - panic(err) - } - - // model definition for unwrapping - var model *TestingModel - - // unwrap - myModel, ok := SilentEnsureModel(wrapped, model).(*TestingModel) - if !ok { - panic("received model does not match expected model") - } - if myModel.Name != "a" || myModel.Value != "b" { - panic("model value mismatch") - } - - // verbose unwrap - genericModel, err := EnsureModel(wrappedTwo, model) - if err != nil { - panic(err) - } - myModelTwo, ok := genericModel.(*TestingModel) - if !ok { - panic("received model does not match expected model") - } - if myModelTwo.Name != "c" || myModelTwo.Value != "d" { - panic("model value mismatch") - } - -} diff --git a/formats/dsd/dsd.go b/formats/dsd/dsd.go index c21f75f..7b3fcd2 100644 --- a/formats/dsd/dsd.go +++ b/formats/dsd/dsd.go @@ -10,19 +10,19 @@ import ( "errors" "fmt" - "github.com/pkg/bson" + // "github.com/pkg/bson" "github.com/Safing/safing-core/formats/varint" ) // define types const ( - AUTO = 0 - STRING = 83 // S - BYTES = 88 // X - JSON = 74 // J - BSON = 66 // B - // MSGP + AUTO = 0 + STRING = 83 // S + BYTES = 88 // X + JSON = 74 // J + BSON = 66 // B + GenCode = 71 // G (reserved) ) // define errors @@ -56,12 +56,12 @@ func Load(data []byte, t interface{}) (interface{}, error) { return nil, err } return t, nil - case BSON: - err := bson.Unmarshal(data[read:], t) - if err != nil { - return nil, err - } - return t, nil + // case BSON: + // err := bson.Unmarshal(data[read:], t) + // if err != nil { + // return nil, err + // } + // return t, nil // case MSGP: // err := t.UnmarshalMsg(data[read:]) // if err != nil { @@ -101,11 +101,11 @@ func Dump(t interface{}, format uint8) ([]byte, error) { if err != nil { return nil, err } - case BSON: - data, err = bson.Marshal(t) - if err != nil { - return nil, err - } + // case BSON: + // data, err = bson.Marshal(t) + // if err != nil { + // return nil, err + // } // case MSGP: // data, err := t.MarshalMsg(nil) // if err != nil {