diff --git a/database/accessor/accessor-json-bytes.go b/database/accessor/accessor-json-bytes.go index 08330d7..3115bf6 100644 --- a/database/accessor/accessor-json-bytes.go +++ b/database/accessor/accessor-json-bytes.go @@ -36,6 +36,10 @@ func (ja *JSONBytesAccessor) Set(key string, value interface{}) error { if result.Type != gjson.True && result.Type != gjson.False { return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) } + case []string: + if !result.IsArray() { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } } } @@ -47,6 +51,15 @@ func (ja *JSONBytesAccessor) Set(key string, value interface{}) error { return nil } +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) Get(key string) (value interface{}, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() { + return nil, false + } + return result.Value(), true +} + // GetString returns the string found by the given json key and whether it could be successfully extracted. func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) { result := gjson.GetBytes(*ja.json, key) @@ -56,6 +69,24 @@ func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) { return result.String(), true } +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetStringArray(key string) (value []string, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() && !result.IsArray() { + return nil, false + } + slice := result.Array() + new := make([]string, len(slice)) + for i, res := range slice { + if res.Type == gjson.String { + new[i] = res.String() + } else { + return nil, false + } + } + return new, true +} + // GetInt returns the int found by the given json key and whether it could be successfully extracted. func (ja *JSONBytesAccessor) GetInt(key string) (value int64, ok bool) { result := gjson.GetBytes(*ja.json, key) diff --git a/database/accessor/accessor-json-string.go b/database/accessor/accessor-json-string.go index 1170418..c5c215d 100644 --- a/database/accessor/accessor-json-string.go +++ b/database/accessor/accessor-json-string.go @@ -36,6 +36,10 @@ func (ja *JSONAccessor) Set(key string, value interface{}) error { if result.Type != gjson.True && result.Type != gjson.False { return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) } + case []string: + if !result.IsArray() { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value) + } } } @@ -47,6 +51,15 @@ func (ja *JSONAccessor) Set(key string, value interface{}) error { return nil } +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() { + return nil, false + } + return result.Value(), true +} + // GetString returns the string found by the given json key and whether it could be successfully extracted. func (ja *JSONAccessor) GetString(key string) (value string, ok bool) { result := gjson.Get(*ja.json, key) @@ -56,6 +69,24 @@ func (ja *JSONAccessor) GetString(key string) (value string, ok bool) { return result.String(), true } +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetStringArray(key string) (value []string, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() && !result.IsArray() { + return nil, false + } + slice := result.Array() + new := make([]string, len(slice)) + for i, res := range slice { + if res.Type == gjson.String { + new[i] = res.String() + } else { + return nil, false + } + } + return new, true +} + // GetInt returns the int found by the given json key and whether it could be successfully extracted. func (ja *JSONAccessor) GetInt(key string) (value int64, ok bool) { result := gjson.Get(*ja.json, key) diff --git a/database/accessor/accessor-struct.go b/database/accessor/accessor-struct.go index b56c9b7..6f245aa 100644 --- a/database/accessor/accessor-struct.go +++ b/database/accessor/accessor-struct.go @@ -86,6 +86,15 @@ func (sa *StructAccessor) Set(key string, value interface{}) error { return nil } +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) Get(key string) (value interface{}, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || !field.CanInterface() { + return nil, false + } + return field.Interface(), true +} + // GetString returns the string found by the given json key and whether it could be successfully extracted. func (sa *StructAccessor) GetString(key string) (value string, ok bool) { field := sa.object.FieldByName(key) @@ -95,6 +104,20 @@ func (sa *StructAccessor) GetString(key string) (value string, ok bool) { return field.String(), true } +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetStringArray(key string) (value []string, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.Slice || !field.CanInterface() { + return nil, false + } + v := field.Interface() + slice, ok := v.([]string) + if !ok { + return nil, false + } + return slice, true +} + // GetInt returns the int found by the given json key and whether it could be successfully extracted. func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) { field := sa.object.FieldByName(key) diff --git a/database/accessor/accessor.go b/database/accessor/accessor.go index aedad26..9fde0e0 100644 --- a/database/accessor/accessor.go +++ b/database/accessor/accessor.go @@ -6,7 +6,9 @@ const ( // Accessor provides an interface to supply the query matcher a method to retrieve values from an object. type Accessor interface { + Get(key string) (value interface{}, ok bool) GetString(key string) (value string, ok bool) + GetStringArray(key string) (value []string, ok bool) GetInt(key string) (value int64, ok bool) GetFloat(key string) (value float64, ok bool) GetBool(key string) (value bool, ok bool) diff --git a/database/accessor/accessor_test.go b/database/accessor/accessor_test.go index d2a3a20..aab5137 100644 --- a/database/accessor/accessor_test.go +++ b/database/accessor/accessor_test.go @@ -3,10 +3,13 @@ package accessor import ( "encoding/json" "testing" + + "github.com/Safing/portbase/utils" ) type TestStruct struct { S string + A []string I int I8 int8 I16 int16 @@ -25,6 +28,7 @@ type TestStruct struct { var ( testStruct = &TestStruct{ S: "banana", + A: []string{"black", "white"}, I: 42, I8: 42, I16: 42, @@ -56,6 +60,19 @@ func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, e } } +func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue []string) { + v, ok := acc.GetStringArray(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get []string with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get []string with key %s, it returned %v", acc.Type(), key, v) + } + if !utils.StringSliceEqual(v, expectedValue) { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) { v, ok := acc.GetInt(key) switch { @@ -127,6 +144,7 @@ func TestAccessor(t *testing.T) { // get for _, acc := range accs { testGetString(t, acc, "S", true, "banana") + testGetStringArray(t, acc, "A", true, []string{"black", "white"}) testGetInt(t, acc, "I", true, 42) testGetInt(t, acc, "I8", true, 42) testGetInt(t, acc, "I16", true, 42) @@ -145,6 +163,7 @@ func TestAccessor(t *testing.T) { // set for _, acc := range accs { testSet(t, acc, "S", true, "coconut") + testSet(t, acc, "A", true, []string{"green", "blue"}) testSet(t, acc, "I", true, uint32(44)) testSet(t, acc, "I8", true, uint64(44)) testSet(t, acc, "I16", true, uint8(44)) @@ -163,6 +182,7 @@ func TestAccessor(t *testing.T) { // get again to check if new values were set for _, acc := range accs { testGetString(t, acc, "S", true, "coconut") + testGetStringArray(t, acc, "A", true, []string{"green", "blue"}) testGetInt(t, acc, "I", true, 44) testGetInt(t, acc, "I8", true, 44) testGetInt(t, acc, "I16", true, 44) @@ -185,6 +205,12 @@ func TestAccessor(t *testing.T) { testSet(t, acc, "S", false, 1) testSet(t, acc, "S", false, 1.1) + testSet(t, acc, "A", false, "1") + testSet(t, acc, "A", false, true) + testSet(t, acc, "A", false, false) + testSet(t, acc, "A", false, 1) + testSet(t, acc, "A", false, 1.1) + testSet(t, acc, "I", false, "1") testSet(t, acc, "I8", false, "1") testSet(t, acc, "I16", false, "1") @@ -207,6 +233,7 @@ func TestAccessor(t *testing.T) { // get again to check if values werent changed when an error occurred for _, acc := range accs { testGetString(t, acc, "S", true, "coconut") + testGetStringArray(t, acc, "A", true, []string{"green", "blue"}) testGetInt(t, acc, "I", true, 44) testGetInt(t, acc, "I8", true, 44) testGetInt(t, acc, "I16", true, 44) @@ -225,6 +252,7 @@ func TestAccessor(t *testing.T) { // test existence for _, acc := range accs { testExists(t, acc, "S", true) + testExists(t, acc, "A", true) testExists(t, acc, "I", true) testExists(t, acc, "I8", true) testExists(t, acc, "I16", true) diff --git a/database/controller.go b/database/controller.go index 02e8d63..d981fc3 100644 --- a/database/controller.go +++ b/database/controller.go @@ -157,6 +157,20 @@ func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iter return it, nil } +// PushUpdate pushes a record update to subscribers. +func (c *Controller) PushUpdate(r record.Record) { + if c != nil { + for _, sub := range c.subscriptions { + if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) { + select { + case sub.Feed <- r: + default: + } + } + } + } +} + func (c *Controller) readUnlockerAfterQuery(it *iterator.Iterator) { <- it.Done c.readLock.RUnlock() diff --git a/database/controllers.go b/database/controllers.go index 5b9d002..38d1591 100644 --- a/database/controllers.go +++ b/database/controllers.go @@ -30,7 +30,7 @@ func getController(name string) (*Controller, error) { // get db registration registeredDB, err := getDatabase(name) if err != nil { - return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) + return nil, fmt.Errorf(`could not start database %s: %s`, name, err) } // get location @@ -56,13 +56,13 @@ func getController(name string) (*Controller, error) { } // InjectDatabase injects an already running database into the system. -func InjectDatabase(name string, storageInt storage.Interface) error { +func InjectDatabase(name string, storageInt storage.Interface) (*Controller, error) { controllersLock.Lock() defer controllersLock.Unlock() _, ok := controllers[name] if ok { - return errors.New(`database "%s" already loaded`) + return nil, errors.New(`database "%s" already loaded`) } registryLock.Lock() @@ -71,17 +71,17 @@ func InjectDatabase(name string, storageInt storage.Interface) error { // check if database is registered registeredDB, ok := registry[name] if !ok { - return fmt.Errorf(`database "%s" not registered`, name) + return nil, fmt.Errorf(`database "%s" not registered`, name) } if registeredDB.StorageType != "injected" { - return fmt.Errorf(`database not of type "injected"`) + return nil, fmt.Errorf(`database not of type "injected"`) } controller, err := newController(storageInt) if err != nil { - return fmt.Errorf(`could not create controller for database %s: %s`, name, err) + return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err) } controllers[name] = controller - return nil + return controller, nil } diff --git a/database/database_test.go b/database/database_test.go index e997760..9aa5e7d 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -98,8 +98,8 @@ func testDatabase(t *testing.T, storageType string) { for _ = range it.Next { cnt++ } - if it.Error != nil { - t.Fatal(it.Error) + if it.Err != nil { + t.Fatal(it.Err) } if cnt != 2 { t.Fatal("expected two records") diff --git a/database/iterator/iterator.go b/database/iterator/iterator.go index 7a1dff4..70451ea 100644 --- a/database/iterator/iterator.go +++ b/database/iterator/iterator.go @@ -6,9 +6,9 @@ import ( // Iterator defines the iterator structure. type Iterator struct { - Next chan record.Record - Done chan struct{} - Error error + Next chan record.Record + Done chan struct{} + Err error } // New creates a new Iterator. @@ -18,3 +18,9 @@ func New() *Iterator { Done: make(chan struct{}), } } + +func (it *Iterator) Finish(err error) { + close(it.Next) + close(it.Done) + it.Err = err +} diff --git a/database/location.go b/database/location.go index 0c095b6..80250cc 100644 --- a/database/location.go +++ b/database/location.go @@ -39,6 +39,11 @@ func ensureDirectory(dirPath string) error { return nil } +// GetDatabaseRoot returns the root directory of the database. +func GetDatabaseRoot() string { + return rootDir +} + // getLocation returns the storage location for the given name and type. func getLocation(name, storageType string) (string, error) { location := path.Join(rootDir, databasesSubDir, name, storageType) diff --git a/database/maintenance.go b/database/maintenance.go index d1399bd..05d8182 100644 --- a/database/maintenance.go +++ b/database/maintenance.go @@ -64,7 +64,7 @@ func MaintainRecordStates() error { toExpire = append(toExpire, r) } } - if it.Error != nil { + if it.Err != nil { return err } diff --git a/database/record/wrapper.go b/database/record/wrapper.go index 63bdb66..03c46bc 100644 --- a/database/record/wrapper.go +++ b/database/record/wrapper.go @@ -42,10 +42,11 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { return nil, fmt.Errorf("could not unmarshal meta section: %s", err) } - format, _, err := varint.Unpack8(data[offset:]) + format, n, err := varint.Unpack8(data[offset:]) if err != nil { return nil, fmt.Errorf("could not get dsd format: %s", err) } + offset += n return &Wrapper{ Base{ @@ -60,12 +61,7 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { } // NewWrapper returns a new record wrapper for the given data. -func NewWrapper(key string, meta *Meta, data []byte) (*Wrapper, error) { - format, _, err := varint.Unpack8(data) - if err != nil { - return nil, fmt.Errorf("could not get dsd format: %s", err) - } - +func NewWrapper(key string, meta *Meta, format uint8, data []byte) (*Wrapper, error) { dbName, dbKey := ParseKey(key) return &Wrapper{ @@ -81,7 +77,7 @@ func NewWrapper(key string, meta *Meta, data []byte) (*Wrapper, error) { } // Marshal marshals the object, without the database key or metadata -func (w *Wrapper) Marshal(r Record, storageType uint8) ([]byte, error) { +func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) { if w.Meta() == nil { return nil, errors.New("missing meta") } @@ -90,10 +86,15 @@ func (w *Wrapper) Marshal(r Record, storageType uint8) ([]byte, error) { return nil, nil } - if storageType != dsd.AUTO && storageType != w.Format { + if format != dsd.AUTO && format != w.Format { return nil, errors.New("could not dump model, wrapped object format mismatch") } - return w.Data, nil + + data := make([]byte, len(w.Data)+1) + data[0] = w.Format + copy(data[1:], w.Data) + + return data, nil } // MarshalRecord packs the object, including metadata, into a byte array for saving in a database. @@ -146,7 +147,7 @@ func Unwrap(wrapped, new Record) error { return fmt.Errorf("cannot unwrap %T", wrapped) } - _, err := dsd.Load(wrapper.Data, new) + _, err := dsd.LoadAsFormat(wrapper.Data, wrapper.Format, new) if err != nil { return fmt.Errorf("failed to unwrap %T: %s", new, err) } @@ -159,9 +160,8 @@ func Unwrap(wrapped, new Record) error { // GetAccessor returns an accessor for this record, if available. func (w *Wrapper) GetAccessor(self Record) accessor.Accessor { - if len(w.Data) > 1 && w.Data[0] == JSON { - jsonData := w.Data[1:] - return accessor.NewJSONBytesAccessor(&jsonData) + if w.Format == JSON && len(w.Data) > 0 { + return accessor.NewJSONBytesAccessor(&w.Data) } return nil } diff --git a/database/record/wrapper_test.go b/database/record/wrapper_test.go index e2988f0..0e14156 100644 --- a/database/record/wrapper_test.go +++ b/database/record/wrapper_test.go @@ -16,10 +16,11 @@ func TestWrapper(t *testing.T) { _ = m // create test data - testData := []byte(`J{"a": "b"}`) + testData := []byte(`{"a": "b"}`) + encodedTestData := []byte(`J{"a": "b"}`) // test wrapper - wrapper, err := NewWrapper("test:a", &Meta{}, testData) + wrapper, err := NewWrapper("test:a", &Meta{}, JSON, testData) if err != nil { t.Fatal(err) } @@ -34,7 +35,7 @@ func TestWrapper(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal(testData, encoded) { + if !bytes.Equal(encodedTestData, encoded) { t.Error("marshal mismatch") } diff --git a/database/registry.go b/database/registry.go index 1322854..7195bc9 100644 --- a/database/registry.go +++ b/database/registry.go @@ -116,7 +116,7 @@ func loadRegistry() error { // parse new := make(map[string]*Database) - err = json.Unmarshal(data, new) + err = json.Unmarshal(data, &new) if err != nil { return err } diff --git a/database/storage/badger/badger.go b/database/storage/badger/badger.go index 0ff6abe..3df5a37 100644 --- a/database/storage/badger/badger.go +++ b/database/storage/badger/badger.go @@ -163,7 +163,7 @@ func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loc }) if err != nil { - queryIter.Error = err + queryIter.Err = err } close(queryIter.Next) close(queryIter.Done) diff --git a/database/storage/badger/badger_test.go b/database/storage/badger/badger_test.go index f89ee90..9a66172 100644 --- a/database/storage/badger/badger_test.go +++ b/database/storage/badger/badger_test.go @@ -101,7 +101,7 @@ func TestBadger(t *testing.T) { for _ = range it.Next { cnt++ } - if it.Error != nil { + if it.Err != nil { t.Fatal(err) } if cnt != 1 { diff --git a/database/storage/injectbase.go b/database/storage/injectbase.go new file mode 100644 index 0000000..1d3a734 --- /dev/null +++ b/database/storage/injectbase.go @@ -0,0 +1,61 @@ +package storage + +import ( + "errors" + + "github.com/Safing/portbase/database/iterator" + "github.com/Safing/portbase/database/query" + "github.com/Safing/portbase/database/record" +) + +var ( + errNotImplemented = errors.New("not implemented") +) + +// InjectBase is a dummy base structure to reduce boilerplate code for injected storage interfaces. +type InjectBase struct{} + +// Get returns a database record. +func (i *InjectBase) Get(key string) (record.Record, error) { + return nil, errNotImplemented +} + +// Put stores a record in the database. +func (i *InjectBase) Put(m record.Record) error { + return errNotImplemented +} + +// Delete deletes a record from the database. +func (i *InjectBase) Delete(key string) error { + return errNotImplemented +} + +// Query returns a an iterator for the supplied query. +func (i *InjectBase) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + return nil, errNotImplemented +} + +// ReadOnly returns whether the database is read only. +func (i *InjectBase) ReadOnly() bool { + return true +} + +// Injected returns whether the database is injected. +func (i *InjectBase) Injected() bool { + return true +} + +// Maintain runs a light maintenance operation on the database. +func (i *InjectBase) Maintain() error { + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (i *InjectBase) MaintainThorough() error { + return nil +} + +// Shutdown shuts down the database. +func (i *InjectBase) Shutdown() error { + return nil +}