Finish minimal feature set, start with tests

This commit is contained in:
Daniel 2018-09-10 19:01:28 +02:00
parent 3d60431376
commit 06a34f931e
34 changed files with 651 additions and 346 deletions

View file

@ -1,6 +1,8 @@
package record package accessor
import ( import (
"fmt"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@ -19,6 +21,24 @@ func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor {
// Set sets the value identified by key. // Set sets the value identified by key.
func (ja *JSONBytesAccessor) Set(key string, value interface{}) error { func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
result := gjson.GetBytes(*ja.json, key)
if result.Exists() {
switch value.(type) {
case string:
if result.Type != gjson.String {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if result.Type != gjson.Number {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case bool:
if result.Type != gjson.True && result.Type != gjson.False {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
}
}
new, err := sjson.SetBytes(*ja.json, key, value) new, err := sjson.SetBytes(*ja.json, key, value)
if err != nil { if err != nil {
return err return err

View file

@ -1,6 +1,8 @@
package record package accessor
import ( import (
"fmt"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@ -19,6 +21,24 @@ func NewJSONAccessor(json *string) *JSONAccessor {
// Set sets the value identified by key. // Set sets the value identified by key.
func (ja *JSONAccessor) Set(key string, value interface{}) error { func (ja *JSONAccessor) Set(key string, value interface{}) error {
result := gjson.Get(*ja.json, key)
if result.Exists() {
switch value.(type) {
case string:
if result.Type != gjson.String {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if result.Type != gjson.Number {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case bool:
if result.Type != gjson.True && result.Type != gjson.False {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
}
}
new, err := sjson.Set(*ja.json, key, value) new, err := sjson.Set(*ja.json, key, value)
if err != nil { if err != nil {
return err return err

View file

@ -1,4 +1,4 @@
package record package accessor
import ( import (
"errors" "errors"

View file

@ -1,4 +1,4 @@
package record package accessor
const ( const (
emptyString = "" emptyString = ""

View file

@ -1,4 +1,4 @@
package record package accessor
import ( import (
"encoding/json" "encoding/json"
@ -95,6 +95,16 @@ func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, exp
} }
} }
func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) {
ok := acc.Exists(key)
switch {
case !ok && shouldSucceed:
t.Errorf("%s should report key %s as existing", acc.Type(), key)
case ok && !shouldSucceed:
t.Errorf("%s should report key %s as non-existing", acc.Type(), key)
}
}
func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) { func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) {
err := acc.Set(key, valueToSet) err := acc.Set(key, valueToSet)
switch { switch {
@ -150,7 +160,7 @@ func TestAccessor(t *testing.T) {
testSet(t, acc, "B", true, false) testSet(t, acc, "B", true, false)
} }
// get again // get again to check if new values were set
for _, acc := range accs { for _, acc := range accs {
testGetString(t, acc, "S", true, "coconut") testGetString(t, acc, "S", true, "coconut")
testGetInt(t, acc, "I", true, 44) testGetInt(t, acc, "I", true, 44)
@ -170,19 +180,69 @@ func TestAccessor(t *testing.T) {
// failures // failures
for _, acc := range accs { for _, acc := range accs {
testGetString(t, acc, "S", false, 1) testSet(t, acc, "S", false, true)
testGetInt(t, acc, "I", false, 44) testSet(t, acc, "S", false, false)
testGetInt(t, acc, "I8", false, 512) testSet(t, acc, "S", false, 1)
testGetInt(t, acc, "I16", false, 1000000) testSet(t, acc, "S", false, 1.1)
testGetInt(t, acc, "I32", false, 44)
testGetInt(t, acc, "I64", false, "44") testSet(t, acc, "I", false, "1")
testGetInt(t, acc, "UI", false, 44) testSet(t, acc, "I8", false, "1")
testGetInt(t, acc, "UI8", false, 44) testSet(t, acc, "I16", false, "1")
testGetInt(t, acc, "UI16", false, 44) testSet(t, acc, "I32", false, "1")
testGetInt(t, acc, "UI32", false, 44) testSet(t, acc, "I64", false, "1")
testGetInt(t, acc, "UI64", false, 44) testSet(t, acc, "UI", false, "1")
testGetFloat(t, acc, "F32", false, 44.44) testSet(t, acc, "UI8", false, "1")
testGetFloat(t, acc, "F64", false, 44.44) testSet(t, acc, "UI16", false, "1")
testGetBool(t, acc, "B", false, false) testSet(t, acc, "UI32", false, "1")
testSet(t, acc, "UI64", false, "1")
testSet(t, acc, "F32", false, "1.1")
testSet(t, acc, "F64", false, "1.1")
testSet(t, acc, "B", false, "false")
testSet(t, acc, "B", false, 1)
testSet(t, acc, "B", false, 1.1)
} }
// get again to check if values werent changed when an error occurred
for _, acc := range accs {
testGetString(t, acc, "S", true, "coconut")
testGetInt(t, acc, "I", true, 44)
testGetInt(t, acc, "I8", true, 44)
testGetInt(t, acc, "I16", true, 44)
testGetInt(t, acc, "I32", true, 44)
testGetInt(t, acc, "I64", true, 44)
testGetInt(t, acc, "UI", true, 44)
testGetInt(t, acc, "UI8", true, 44)
testGetInt(t, acc, "UI16", true, 44)
testGetInt(t, acc, "UI32", true, 44)
testGetInt(t, acc, "UI64", true, 44)
testGetFloat(t, acc, "F32", true, 44.44)
testGetFloat(t, acc, "F64", true, 44.44)
testGetBool(t, acc, "B", true, false)
}
// test existence
for _, acc := range accs {
testExists(t, acc, "S", true)
testExists(t, acc, "I", true)
testExists(t, acc, "I8", true)
testExists(t, acc, "I16", true)
testExists(t, acc, "I32", true)
testExists(t, acc, "I64", true)
testExists(t, acc, "UI", true)
testExists(t, acc, "UI8", true)
testExists(t, acc, "UI16", true)
testExists(t, acc, "UI32", true)
testExists(t, acc, "UI64", true)
testExists(t, acc, "F32", true)
testExists(t, acc, "F64", true)
testExists(t, acc, "B", true)
}
// test non-existence
for _, acc := range accs {
testExists(t, acc, "X", false)
}
} }

View file

@ -1,7 +1,6 @@
package database package database
import ( import (
"errors"
"sync" "sync"
"time" "time"
@ -29,8 +28,17 @@ func newController(storageInt storage.Interface) (*Controller, error) {
}, nil }, nil
} }
// ReadOnly returns whether the storage is read only.
func (c *Controller) ReadOnly() bool {
return c.storage.ReadOnly()
}
// Get return the record with the given key. // Get return the record with the given key.
func (c *Controller) Get(key string) (record.Record, error) { func (c *Controller) Get(key string) (record.Record, error) {
if shuttingDown.IsSet() {
return nil, ErrShuttingDown
}
r, err := c.storage.Get(key) r, err := c.storage.Get(key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,101 +56,41 @@ func (c *Controller) Get(key string) (record.Record, error) {
// Put saves a record in the database. // Put saves a record in the database.
func (c *Controller) Put(r record.Record) error { func (c *Controller) Put(r record.Record) error {
if shuttingDown.IsSet() {
return ErrShuttingDown
}
if c.storage.ReadOnly() { if c.storage.ReadOnly() {
return ErrReadOnly return ErrReadOnly
} }
if r.Meta() == nil {
r.SetMeta(&record.Meta{})
}
r.Meta().Update()
return c.storage.Put(r) return c.storage.Put(r)
} }
// Delete a record from the database. // Query executes the given query on the database.
func (c *Controller) Delete(key string) error {
if c.storage.ReadOnly() {
return ErrReadOnly
}
r, err := c.Get(key)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().Deleted = time.Now().Unix()
return c.Put(r)
}
// Partial
// What happens if I mutate a value that does not yet exist? How would I know its type?
func (c *Controller) InsertPartial(key string, partialObject interface{}) error {
if c.storage.ReadOnly() {
return ErrReadOnly
}
return nil
}
func (c *Controller) InsertValue(key string, attribute string, value interface{}) error {
if c.storage.ReadOnly() {
return ErrReadOnly
}
r, err := c.Get(key)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
if r.IsWrapped() {
wrapper, ok := r.(*record.Wrapper)
if !ok {
return errors.New("record is malformed")
}
} else {
}
return nil
}
// Query
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
return nil, nil if shuttingDown.IsSet() {
} return nil, ErrShuttingDown
// Meta
func (c *Controller) SetAbsoluteExpiry(key string, time int64) error {
if c.storage.ReadOnly() {
return ErrReadOnly
} }
return c.storage.Query(q, local, internal)
return nil
} }
func (c *Controller) SetRelativateExpiry(key string, duration int64) error { // Maintain runs the Maintain method no the storage.
if c.storage.ReadOnly() { func (c *Controller) Maintain() error {
return ErrReadOnly return c.storage.Maintain()
}
return nil
} }
func (c *Controller) MakeCrownJewel(key string) error { // MaintainThorough runs the MaintainThorough method no the storage.
if c.storage.ReadOnly() { func (c *Controller) MaintainThorough() error {
return ErrReadOnly return c.storage.MaintainThorough()
}
return nil
} }
func (c *Controller) MakeSecret(key string) error { // Shutdown shuts down the storage.
if c.storage.ReadOnly() { func (c *Controller) Shutdown() error {
return ErrReadOnly return c.storage.Shutdown()
}
return nil
} }

View file

@ -1,121 +0,0 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package database
import (
"errors"
)
// Errors
var (
ErrNotFound = errors.New("database entry could not be found")
ErrPermissionDenied = errors.New("access to database record denied")
ErrReadOnly = errors.New("database is read only")
)
func init() {
// if strings.HasSuffix(os.Args[0], ".test") {
// // testing setup
// log.Warning("===== DATABASE RUNNING IN TEST MODE =====")
// db = channelshim.NewChanneledDatastore(ds.NewMapDatastore())
// return
// }
// sfsDB, err := simplefs.NewDatastore(meta.DatabaseDir())
// if err != nil {
// fmt.Fprintf(os.Stderr, "FATAL ERROR: could not init simplefs database: %s\n", err)
// os.Exit(1)
// }
// ldb, err := leveldb.NewDatastore(path.Join(meta.DatabaseDir(), "leveldb"), &leveldb.Options{})
// if err != nil {
// fmt.Fprintf(os.Stderr, "FATAL ERROR: could not init simplefs database: %s\n", err)
// os.Exit(1)
// }
//
// mapDB := ds.NewMapDatastore()
//
// db = channelshim.NewChanneledDatastore(mount.New([]mount.Mount{
// mount.Mount{
// Prefix: ds.NewKey("/Run"),
// Datastore: mapDB,
// },
// mount.Mount{
// Prefix: ds.NewKey("/"),
// Datastore: ldb,
// },
// }))
}
// func Batch() (ds.Batch, error) {
// return db.Batch()
// }
// func Close() error {
// return db.Close()
// }
// func Get(key *ds.Key) (Model, error) {
// data, err := db.Get(*key)
// if err != nil {
// switch err {
// case ds.ErrNotFound:
// return nil, ErrNotFound
// default:
// return nil, err
// }
// }
// model, ok := data.(Model)
// if !ok {
// return nil, errors.New("database did not return model")
// }
// return model, nil
// }
// func Has(key ds.Key) (exists bool, err error) {
// return db.Has(key)
// }
//
// func Create(key ds.Key, model Model) (err error) {
// handleCreateSubscriptions(model)
// err = db.Put(key, model)
// if err != nil {
// log.Tracef("database: failed to create entry %s: %s", key, err)
// }
// return err
// }
//
// func Update(key ds.Key, model Model) (err error) {
// handleUpdateSubscriptions(model)
// err = db.Put(key, model)
// if err != nil {
// log.Tracef("database: failed to update entry %s: %s", key, err)
// }
// return err
// }
//
// func Delete(key ds.Key) (err error) {
// handleDeleteSubscriptions(&key)
// return db.Delete(key)
// }
//
// func Query(q dsq.Query) (dsq.Results, error) {
// return db.Query(q)
// }
//
// func RawGet(key ds.Key) (*dbutils.Wrapper, error) {
// data, err := db.Get(key)
// if err != nil {
// return nil, err
// }
// wrapped, ok := data.(*dbutils.Wrapper)
// if !ok {
// return nil, errors.New("returned data is not a wrapper")
// }
// return wrapped, nil
// }
//
// func RawPut(key ds.Key, value interface{}) error {
// return db.Put(key, value)
// }

69
database/database_test.go Normal file
View file

@ -0,0 +1,69 @@
package database
import (
"io/ioutil"
"os"
"sync"
"testing"
"github.com/Safing/portbase/database/record"
)
type TestRecord struct {
record.Base
lock sync.Mutex
S string
I int
I8 int8
I16 int16
I32 int32
I64 int64
UI uint
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
F32 float32
F64 float64
B bool
}
func (tr *TestRecord) Lock() {
}
func (tr *TestRecord) Unlock() {
}
func TestDatabase(t *testing.T) {
testDir, err := ioutil.TempDir("", "testing-")
if err != nil {
t.Fatal(err)
}
err = Initialize(testDir)
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(testDir) // clean up
err = RegisterDatabase(&RegisteredDatabase{
Name: "testing",
Description: "Unit Test Database",
StorageType: "badger",
PrimaryAPI: "",
})
if err != nil {
t.Fatal(err)
}
db := NewInterface(nil)
new := &TestRecord{}
new.SetKey("testing:A")
err = db.Put(new)
if err != nil {
t.Fatal(err)
}
}

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"path" "path"
"github.com/tevino/abool"
"github.com/Safing/portbase/database/storage" "github.com/Safing/portbase/database/storage"
"github.com/Safing/portbase/database/record" "github.com/Safing/portbase/database/record"
) )
@ -13,14 +14,16 @@ import (
var ( var (
databases = make(map[string]*Controller) databases = make(map[string]*Controller)
databasesLock sync.Mutex databasesLock sync.Mutex
shuttingDown = abool.NewBool(false)
) )
func splitKeyAndGetDatabase(key string) (dbKey string, db *Controller, err error) { func splitKeyAndGetDatabase(key string) (db *Controller, dbKey string, err error) {
var dbName string var dbName string
dbName, dbKey = record.ParseKey(key) dbName, dbKey = record.ParseKey(key)
db, err = getDatabase(dbName) db, err = getDatabase(dbName)
if err != nil { if err != nil {
return "", nil, err return nil, "", err
} }
return return
} }

View file

@ -1,29 +1,43 @@
package dbmodule package dbmodule
import ( import (
"github.com/Safing/portbase/database" "errors"
"flag"
"sync"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/modules"
) )
var ( var (
databaseDir string databaseDir string
shutdownSignal = make(chan struct{})
maintenanceWg sync.WaitGroup
) )
func init() { func init() {
flag.StringVar(&databaseDir, "db", "", "set database directory") flag.StringVar(&databaseDir, "db", "", "set database directory")
modules.Register("database", prep, start, stop) modules.Register("database", prep, start, stop)
} }
func prep() error { func prep() error {
if databaseDir == "" { if databaseDir == "" {
return errors.New("no database location specified, set with `-db=/path/to/db`") return errors.New("no database location specified, set with `-db=/path/to/db`")
} }
return nil
} }
func start() error { func start() error {
return database.Initialize(databaseDir) err := database.Initialize(databaseDir)
if err == nil {
go maintainer()
}
return err
} }
func stop() { func stop() error {
return database.Shutdown() close(shutdownSignal)
maintenanceWg.Wait()
return database.Shutdown()
} }

View file

@ -0,0 +1,32 @@
package dbmodule
import (
"time"
"github.com/Safing/portbase/database"
"github.com/Safing/portbase/log"
)
func maintainer() {
ticker := time.NewTicker(1 * time.Hour)
tickerThorough := time.NewTicker(10 * time.Minute)
maintenanceWg.Add(1)
for {
select {
case <- ticker.C:
err := database.Maintain()
if err != nil {
log.Errorf("database: maintenance error: %s", err)
}
case <- ticker.C:
err := database.MaintainThorough()
if err != nil {
log.Errorf("database: maintenance (thorough) error: %s", err)
}
case <-shutdownSignal:
maintenanceWg.Done()
return
}
}
}

15
database/errors.go Normal file
View file

@ -0,0 +1,15 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package database
import (
"errors"
)
// Errors
var (
ErrNotFound = errors.New("database entry could not be found")
ErrPermissionDenied = errors.New("access to database record denied")
ErrReadOnly = errors.New("database is read only")
ErrShuttingDown = errors.New("database system is shutting down")
)

View file

@ -1,9 +1,19 @@
package database package database
import ( import (
"errors"
"fmt"
"github.com/Safing/portbase/database/accessor"
"github.com/Safing/portbase/database/iterator"
"github.com/Safing/portbase/database/query"
"github.com/Safing/portbase/database/record" "github.com/Safing/portbase/database/record"
) )
const (
getDBFromKey = ""
)
// Interface provides a method to access the database with attached options. // Interface provides a method to access the database with attached options.
type Interface struct { type Interface struct {
options *Options options *Options
@ -30,7 +40,7 @@ func NewInterface(opts *Options) *Interface {
// Exists return whether a record with the given key exists. // Exists return whether a record with the given key exists.
func (i *Interface) Exists(key string) (bool, error) { func (i *Interface) Exists(key string) (bool, error) {
_, err := i.getRecord(key) _, _, err := i.getRecord(getDBFromKey, key, false, false)
if err != nil { if err != nil {
if err == ErrNotFound { if err == ErrNotFound {
return false, nil return false, nil
@ -42,28 +52,161 @@ func (i *Interface) Exists(key string) (bool, error) {
// Get return the record with the given key. // Get return the record with the given key.
func (i *Interface) Get(key string) (record.Record, error) { func (i *Interface) Get(key string) (record.Record, error) {
r, err := i.getRecord(key) r, _, err := i.getRecord(getDBFromKey, key, true, false)
if err != nil { return r, err
return nil, err
}
if !r.Meta().CheckPermission(i.options.Local, i.options.Internal) {
return nil, ErrPermissionDenied
}
return r, nil
} }
func (i *Interface) getRecord(key string) (record.Record, error) { func (i *Interface) getRecord(dbName string, dbKey string, check bool, mustBeWriteable bool) (r record.Record, db *Controller, err error) {
dbKey, db, err := splitKeyAndGetDatabase(key) if dbName == "" {
if err != nil { dbName, dbKey = record.ParseKey(dbKey)
return nil, err
} }
r, err := db.Get(dbKey) db, err = getDatabase(dbName)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return r, nil if mustBeWriteable && db.ReadOnly() {
return nil, nil, ErrReadOnly
}
r, err = db.Get(dbKey)
if err != nil {
return nil, nil, err
}
if check && !r.Meta().CheckPermission(i.options.Local, i.options.Internal) {
return nil, nil, ErrPermissionDenied
}
return r, db, nil
}
// InsertValue inserts a value into a record.
func (i *Interface) InsertValue(key string, attribute string, value interface{}) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
var acc accessor.Accessor
if r.IsWrapped() {
wrapper, ok := r.(*record.Wrapper)
if !ok {
return errors.New("record is malformed (reports to be wrapped but is not of type *record.Wrapper)")
}
acc = accessor.NewJSONBytesAccessor(&wrapper.Data)
} else {
acc = accessor.NewStructAccessor(r)
}
err = acc.Set(attribute, value)
if err != nil {
return fmt.Errorf("failed to set value with %s: %s", acc.Type(), err)
}
return db.Put(r)
}
// Put saves a record to the database.
func (i *Interface) Put(r record.Record) error {
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
if err != nil {
return err
}
return db.Put(r)
}
// PutNew saves a record to the database as a new record (ie. with a new creation timestamp).
func (i *Interface) PutNew(r record.Record) error {
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
if err != nil && err != ErrNotFound {
return err
}
r.SetMeta(&record.Meta{})
return db.Put(r)
}
// SetAbsoluteExpiry sets an absolute record expiry.
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().SetAbsoluteExpiry(time)
return db.Put(r)
}
// SetRelativateExpiry sets a relative (self-updating) record expiry.
func (i *Interface) SetRelativateExpiry(key string, duration int64) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().SetRelativateExpiry(duration)
return db.Put(r)
}
// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record.
func (i *Interface) MakeSecret(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().MakeSecret()
return db.Put(r)
}
// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally.
func (i *Interface) MakeCrownJewel(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().MakeCrownJewel()
return db.Put(r)
}
// Delete deletes a record from the database.
func (i *Interface) Delete(key string) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.Meta().Delete()
return db.Put(r)
}
// Query executes the given query on the database.
func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
db, err := getDatabase(q.DatabaseName())
if err != nil {
return nil, err
}
return db.Query(q, i.options.Local, i.options.Internal)
} }

View file

@ -2,12 +2,57 @@ package database
import ( import (
"path" "path"
"os"
"fmt"
"errors"
) )
var ( var (
rootDir string rootDir string
) )
// Initialize initialized the database
func Initialize(location string) error {
if initialized.SetToIf(false, true) {
rootDir = location
err := checkRootDir()
if err != nil {
return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err)
}
err = loadRegistry()
if err != nil {
return fmt.Errorf("could not load database registry (%s): %s", path.Join(rootDir, registryFileName), err)
}
return nil
}
return errors.New("database already initialized")
}
func checkRootDir() error {
// open dir
dir, err := os.Open(rootDir)
if err != nil {
if err == os.ErrNotExist {
return os.MkdirAll(rootDir, 0700)
}
return err
}
defer dir.Close()
fileInfo, err := dir.Stat()
if err != nil {
return err
}
if fileInfo.Mode().Perm() != 0700 {
return dir.Chmod(0700)
}
return nil
}
// getLocation returns the storage location for the given name and type. // getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (location string, err error) { func getLocation(name, storageType string) (location string, err error) {
return path.Join(rootDir, name, storageType), nil return path.Join(rootDir, name, storageType), nil

50
database/maintainence.go Normal file
View file

@ -0,0 +1,50 @@
package database
// Maintain runs the Maintain method on all storages.
func Maintain() (err error) {
controllers := duplicateControllers()
for _, c := range controllers {
err = c.Maintain()
if err != nil {
return
}
}
return
}
// MaintainThorough runs the MaintainThorough method on all storages.
func MaintainThorough() (err error) {
controllers := duplicateControllers()
for _, c := range controllers {
err = c.MaintainThorough()
if err != nil {
return
}
}
return
}
// Shutdown shuts down the whole database system.
func Shutdown() (err error) {
shuttingDown.Set()
controllers := duplicateControllers()
for _, c := range controllers {
err = c.Shutdown()
if err != nil {
return
}
}
return
}
func duplicateControllers() (controllers []*Controller) {
databasesLock.Lock()
defer databasesLock.Unlock()
for _, c := range databases {
controllers = append(controllers, c)
}
return
}

View file

@ -3,6 +3,8 @@ package query
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
) )
// And combines multiple conditions with a logical _AND_ operator. // And combines multiple conditions with a logical _AND_ operator.
@ -16,9 +18,9 @@ type andCond struct {
conditions []Condition conditions []Condition
} }
func (c *andCond) complies(f Fetcher) bool { func (c *andCond) complies(acc accessor.Accessor) bool {
for _, cond := range c.conditions { for _, cond := range c.conditions {
if !cond.complies(f) { if !cond.complies(acc) {
return false return false
} }
} }

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"github.com/Safing/portbase/database/accessor"
) )
type boolCondition struct { type boolCondition struct {
@ -42,8 +44,8 @@ func newBoolCondition(key string, operator uint8, value interface{}) *boolCondit
} }
} }
func (c *boolCondition) complies(f Fetcher) bool { func (c *boolCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetBool(c.key) comp, ok := acc.GetBool(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -1,5 +1,9 @@
package query package query
import (
"github.com/Safing/portbase/database/accessor"
)
type errorCondition struct { type errorCondition struct {
err error err error
} }
@ -10,7 +14,7 @@ func newErrorCondition(err error) *errorCondition {
} }
} }
func (c *errorCondition) complies(f Fetcher) bool { func (c *errorCondition) complies(acc accessor.Accessor) bool {
return false return false
} }

View file

@ -3,6 +3,8 @@ package query
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/Safing/portbase/database/accessor"
) )
type existsCondition struct { type existsCondition struct {
@ -17,8 +19,8 @@ func newExistsCondition(key string, operator uint8) *existsCondition {
} }
} }
func (c *existsCondition) complies(f Fetcher) bool { func (c *existsCondition) complies(acc accessor.Accessor) bool {
return f.Exists(c.key) return acc.Exists(c.key)
} }
func (c *existsCondition) check() error { func (c *existsCondition) check() error {

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"github.com/Safing/portbase/database/accessor"
) )
type floatCondition struct { type floatCondition struct {
@ -62,8 +64,8 @@ func newFloatCondition(key string, operator uint8, value interface{}) *floatCond
} }
} }
func (c *floatCondition) complies(f Fetcher) bool { func (c *floatCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetFloat(c.key) comp, ok := acc.GetFloat(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"github.com/Safing/portbase/database/accessor"
) )
type intCondition struct { type intCondition struct {
@ -58,8 +60,8 @@ func newIntCondition(key string, operator uint8, value interface{}) *intConditio
} }
} }
func (c *intCondition) complies(f Fetcher) bool { func (c *intCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetInt(c.key) comp, ok := acc.GetInt(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -1,9 +1,13 @@
package query package query
import (
"github.com/Safing/portbase/database/accessor"
)
type noCond struct { type noCond struct {
} }
func (c *noCond) complies(f Fetcher) bool { func (c *noCond) complies(acc accessor.Accessor) bool {
return true return true
} }

View file

@ -3,6 +3,8 @@ package query
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
) )
// Not negates the supplied condition. // Not negates the supplied condition.
@ -16,8 +18,8 @@ type notCond struct {
notC Condition notC Condition
} }
func (c *notCond) complies(f Fetcher) bool { func (c *notCond) complies(acc accessor.Accessor) bool {
return !c.notC.complies(f) return !c.notC.complies(acc)
} }
func (c *notCond) check() error { func (c *notCond) check() error {

View file

@ -3,6 +3,8 @@ package query
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
) )
// Or combines multiple conditions with a logical _OR_ operator. // Or combines multiple conditions with a logical _OR_ operator.
@ -16,9 +18,9 @@ type orCond struct {
conditions []Condition conditions []Condition
} }
func (c *orCond) complies(f Fetcher) bool { func (c *orCond) complies(acc accessor.Accessor) bool {
for _, cond := range c.conditions { for _, cond := range c.conditions {
if cond.complies(f) { if cond.complies(acc) {
return true return true
} }
} }

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"github.com/Safing/portbase/database/accessor"
) )
type regexCondition struct { type regexCondition struct {
@ -35,8 +37,8 @@ func newRegexCondition(key string, operator uint8, value interface{}) *regexCond
} }
} }
func (c *regexCondition) complies(f Fetcher) bool { func (c *regexCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetString(c.key) comp, ok := acc.GetString(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
) )
type stringCondition struct { type stringCondition struct {
@ -28,8 +30,8 @@ func newStringCondition(key string, operator uint8, value interface{}) *stringCo
} }
} }
func (c *stringCondition) complies(f Fetcher) bool { func (c *stringCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetString(c.key) comp, ok := acc.GetString(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
"github.com/Safing/portbase/utils" "github.com/Safing/portbase/utils"
) )
@ -44,8 +45,8 @@ func newStringSliceCondition(key string, operator uint8, value interface{}) *str
} }
func (c *stringSliceCondition) complies(f Fetcher) bool { func (c *stringSliceCondition) complies(acc accessor.Accessor) bool {
comp, ok := f.GetString(c.key) comp, ok := acc.GetString(c.key)
if !ok { if !ok {
return false return false
} }

View file

@ -1,10 +1,14 @@
package query package query
import "fmt" import (
"fmt"
"github.com/Safing/portbase/database/accessor"
)
// Condition is an interface to provide a common api to all condition types. // Condition is an interface to provide a common api to all condition types.
type Condition interface { type Condition interface {
complies(f Fetcher) bool complies(acc accessor.Accessor) bool
check() error check() error
string() string string() string
} }

View file

@ -144,7 +144,6 @@ func testParseError(t *testing.T, queryText string, expectedErrorString string)
func TestParseErrors(t *testing.T) { func TestParseErrors(t *testing.T) {
// syntax // syntax
testParseError(t, `query`, `unexpected end at position 5`) testParseError(t, `query`, `unexpected end at position 5`)
testParseError(t, `query test`, `invalid prefix: test`)
testParseError(t, `query test: where`, `unexpected end at position 17`) testParseError(t, `query test: where`, `unexpected end at position 17`)
testParseError(t, `query test: where (`, `unexpected end at position 19`) testParseError(t, `query test: where (`, `unexpected end at position 19`)
testParseError(t, `query test: where )`, `unknown clause ")" at position 19`) testParseError(t, `query test: where )`, `unknown clause ")" at position 19`)

View file

@ -2,8 +2,10 @@ package query
import ( import (
"fmt" "fmt"
"regexp"
"strings" "strings"
"github.com/Safing/portbase/database/accessor"
"github.com/Safing/portbase/database/record"
) )
// Example: // Example:
@ -16,24 +18,23 @@ import (
// ) // )
// ) // )
var (
prefixExpr = regexp.MustCompile("^[a-z-]+:")
)
// Query contains a compiled query. // Query contains a compiled query.
type Query struct { type Query struct {
checked bool checked bool
prefix string dbName string
where Condition dbKeyPrefix string
orderBy string where Condition
limit int orderBy string
offset int limit int
offset int
} }
// New creates a new query with the supplied prefix. // New creates a new query with the supplied prefix.
func New(prefix string) *Query { func New(prefix string) *Query {
dbName, dbKeyPrefix := record.ParseKey(prefix)
return &Query{ return &Query{
prefix: prefix, dbName: dbName,
dbKeyPrefix: dbKeyPrefix,
} }
} }
@ -67,11 +68,6 @@ func (q *Query) Check() (*Query, error) {
return q, nil return q, nil
} }
// check prefix
if !prefixExpr.MatchString(q.prefix) {
return nil, fmt.Errorf("invalid prefix: %s", q.prefix)
}
// check condition // check condition
if q.where != nil { if q.where != nil {
err := q.where.check() err := q.where.check()
@ -101,8 +97,8 @@ func (q *Query) IsChecked() bool {
} }
// Matches checks whether the query matches the supplied data object. // Matches checks whether the query matches the supplied data object.
func (q *Query) Matches(f Fetcher) bool { func (q *Query) Matches(acc accessor.Accessor) bool {
return q.where.complies(f) return q.where.complies(acc)
} }
// Print returns the string representation of the query. // Print returns the string representation of the query.
@ -130,5 +126,15 @@ func (q *Query) Print() string {
offset = fmt.Sprintf(" offset %d", q.offset) offset = fmt.Sprintf(" offset %d", q.offset)
} }
return fmt.Sprintf("query %s%s%s%s%s", q.prefix, where, orderBy, limit, offset) return fmt.Sprintf("query %s:%s%s%s%s%s", q.dbName, q.dbKeyPrefix, where, orderBy, limit, offset)
}
// DatabaseName returns the name of the database.
func (q *Query) DatabaseName() string {
return q.dbName
}
// DatabaseKeyPrefix returns the key prefix for the database.
func (q *Query) DatabaseKeyPrefix() string {
return q.dbKeyPrefix
} }

View file

@ -2,6 +2,8 @@ package query
import ( import (
"testing" "testing"
"github.com/Safing/portbase/database/accessor"
) )
var ( var (
@ -44,12 +46,12 @@ var (
}` }`
) )
func testQuery(t *testing.T, f Fetcher, shouldMatch bool, condition Condition) { func testQuery(t *testing.T, acc accessor.Accessor, shouldMatch bool, condition Condition) {
q := New("test:").Where(condition).MustBeValid() q := New("test:").Where(condition).MustBeValid()
// fmt.Printf("%s\n", q.String()) // fmt.Printf("%s\n", q.String())
matched := q.Matches(f) matched := q.Matches(acc)
switch { switch {
case !matched && shouldMatch: case !matched && shouldMatch:
t.Errorf("should match: %s", q.Print()) t.Errorf("should match: %s", q.Print())
@ -63,7 +65,7 @@ func TestQuery(t *testing.T) {
// if !gjson.Valid(testJSON) { // if !gjson.Valid(testJSON) {
// t.Fatal("test json is invalid") // t.Fatal("test json is invalid")
// } // }
f := NewJSONFetcher(testJSON) f := accessor.NewJSONAccessor(&testJSON)
testQuery(t, f, true, Where("age", Equals, 100)) testQuery(t, f, true, Where("age", Equals, 100))
testQuery(t, f, true, Where("age", GreaterThan, uint8(99))) testQuery(t, f, true, Where("age", GreaterThan, uint8(99)))

View file

@ -73,6 +73,11 @@ func (m *Meta) Reset() {
m.Deleted = 0 m.Deleted = 0
} }
// Delete marks the record as deleted.
func (m *Meta) Delete() {
m.Deleted = time.Now().Unix()
}
// CheckValidity checks whether the database record is valid. // CheckValidity checks whether the database record is valid.
func (m *Meta) CheckValidity(now int64) (valid bool) { func (m *Meta) CheckValidity(now int64) (valid bool) {
switch { switch {

View file

@ -3,10 +3,10 @@ package database
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"regexp"
"sync" "sync"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -40,6 +40,8 @@ var (
registry map[string]*RegisteredDatabase registry map[string]*RegisteredDatabase
registryLock sync.Mutex registryLock sync.Mutex
nameConstraint = regexp.MustCompile("^[A-Za-z0-9_-]{5,}$")
) )
// RegisterDatabase registers a new database. // RegisterDatabase registers a new database.
@ -48,6 +50,10 @@ func RegisterDatabase(new *RegisteredDatabase) error {
return errors.New("database not initialized") return errors.New("database not initialized")
} }
if !nameConstraint.MatchString(new.Name) {
return errors.New("database name must only contain alphanumeric and `_-` characters and must be at least 5 characters long")
}
registryLock.Lock() registryLock.Lock()
defer registryLock.Unlock() defer registryLock.Unlock()
@ -60,48 +66,6 @@ func RegisterDatabase(new *RegisteredDatabase) error {
return nil return nil
} }
// Initialize initialized the database
func Initialize(location string) error {
if initialized.SetToIf(false, true) {
rootDir = location
err := checkRootDir()
if err != nil {
return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err)
}
err = loadRegistry()
if err != nil {
return fmt.Errorf("could not load database registry (%s): %s", path.Join(rootDir, registryFileName), err)
}
return nil
}
return errors.New("database already initialized")
}
func checkRootDir() error {
// open dir
dir, err := os.Open(rootDir)
if err != nil {
if err == os.ErrNotExist {
return os.MkdirAll(rootDir, 0700)
}
return err
}
defer dir.Close()
fileInfo, err := dir.Stat()
if err != nil {
return err
}
if fileInfo.Mode().Perm() != 0700 {
return dir.Chmod(0700)
}
return nil
}
func loadRegistry() error { func loadRegistry() error {
registryLock.Lock() registryLock.Lock()
defer registryLock.Unlock() defer registryLock.Unlock()

View file

@ -11,7 +11,7 @@ type Interface interface {
Get(key string) (record.Record, error) Get(key string) (record.Record, error)
Put(m record.Record) error Put(m record.Record) error
Delete(key string) error Delete(key string) error
Query(q *query.Query) (*iterator.Iterator, error) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error)
ReadOnly() bool ReadOnly() bool
Maintain() error Maintain() error