Block any database interaction when shutting down

This commit is contained in:
Daniel 2019-03-12 22:57:29 +01:00
parent d6ef9a62f2
commit 52b56450c7
3 changed files with 84 additions and 45 deletions

View file

@ -50,13 +50,13 @@ func (c *Controller) Injected() bool {
// Get return the record with the given key. // Get return the record with the given key.
func (c *Controller) Get(key string) (record.Record, error) { func (c *Controller) Get(key string) (record.Record, error) {
c.readLock.RLock()
defer c.readLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return nil, ErrShuttingDown return nil, ErrShuttingDown
} }
c.readLock.RLock()
defer c.readLock.RUnlock()
// process hooks // process hooks
for _, hook := range c.hooks { for _, hook := range c.hooks {
if hook.h.UsesPreGet() && hook.q.MatchesKey(key) { if hook.h.UsesPreGet() && hook.q.MatchesKey(key) {
@ -98,6 +98,9 @@ func (c *Controller) Get(key string) (record.Record, error) {
// Put saves a record in the database. // Put saves a record in the database.
func (c *Controller) Put(r record.Record) (err error) { func (c *Controller) Put(r record.Record) (err error) {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
return ErrShuttingDown return ErrShuttingDown
} }
@ -116,9 +119,6 @@ func (c *Controller) Put(r record.Record) (err error) {
} }
} }
c.writeLock.RLock()
defer c.writeLock.RUnlock()
err = c.storage.Put(r) err = c.storage.Put(r)
if err != nil { if err != nil {
return err return err
@ -139,11 +139,13 @@ func (c *Controller) Put(r record.Record) (err error) {
// Query executes the given query on the database. // Query executes the given query on the database.
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
c.readLock.RLock()
if shuttingDown.IsSet() { if shuttingDown.IsSet() {
c.readLock.RUnlock()
return nil, ErrShuttingDown return nil, ErrShuttingDown
} }
c.readLock.RLock()
it, err := c.storage.Query(q, local, internal) it, err := c.storage.Query(q, local, internal)
if err != nil { if err != nil {
c.readLock.RUnlock() c.readLock.RUnlock()
@ -160,6 +162,10 @@ func (c *Controller) PushUpdate(r record.Record) {
c.readLock.RLock() c.readLock.RLock()
defer c.readLock.RUnlock() defer c.readLock.RUnlock()
if shuttingDown.IsSet() {
return
}
for _, sub := range c.subscriptions { for _, sub := range c.subscriptions {
if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) { if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) {
select { select {
@ -177,6 +183,10 @@ func (c *Controller) addSubscription(sub *Subscription) {
c.writeLock.Lock() c.writeLock.Lock()
defer c.writeLock.Unlock() defer c.writeLock.Unlock()
if shuttingDown.IsSet() {
return
}
c.subscriptions = append(c.subscriptions, sub) c.subscriptions = append(c.subscriptions, sub)
} }
@ -189,6 +199,11 @@ func (c *Controller) readUnlockerAfterQuery(it *iterator.Iterator) {
func (c *Controller) Maintain() error { func (c *Controller) Maintain() error {
c.writeLock.RLock() c.writeLock.RLock()
defer c.writeLock.RUnlock() defer c.writeLock.RUnlock()
if shuttingDown.IsSet() {
return nil
}
return c.storage.Maintain() return c.storage.Maintain()
} }
@ -196,11 +211,21 @@ func (c *Controller) Maintain() error {
func (c *Controller) MaintainThorough() error { func (c *Controller) MaintainThorough() error {
c.writeLock.RLock() c.writeLock.RLock()
defer c.writeLock.RUnlock() defer c.writeLock.RUnlock()
if shuttingDown.IsSet() {
return nil
}
return c.storage.MaintainThorough() return c.storage.MaintainThorough()
} }
// Shutdown shuts down the storage. // Shutdown shuts down the storage.
func (c *Controller) Shutdown() error { func (c *Controller) Shutdown() error {
// TODO: should we wait for gets/puts/queries to complete? // aquire full locks
c.readLock.Lock()
defer c.readLock.Unlock()
c.writeLock.Lock()
defer c.writeLock.Unlock()
return c.storage.Shutdown() return c.storage.Shutdown()
} }

View file

@ -2,29 +2,35 @@ package database
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"fmt"
"github.com/Safing/portbase/database/storage" "github.com/Safing/portbase/database/storage"
) )
var ( var (
controllers = make(map[string]*Controller) controllers = make(map[string]*Controller)
controllersLock sync.Mutex controllersLock sync.RWMutex
) )
func getController(name string) (*Controller, error) { func getController(name string) (*Controller, error) {
if !initialized.IsSet() { if !initialized.IsSet() {
return nil, errors.New("database not initialized") return nil, errors.New("database not initialized")
} }
// return database if already started
controllersLock.RLock()
controller, ok := controllers[name]
controllersLock.RUnlock()
if ok {
return controller, nil
}
controllersLock.Lock() controllersLock.Lock()
defer controllersLock.Unlock() defer controllersLock.Unlock()
// return database if already started if shuttingDown.IsSet() {
controller, ok := controllers[name] return nil, ErrShuttingDown
if ok {
return controller, nil
} }
// get db registration // get db registration
@ -39,20 +45,20 @@ func getController(name string) (*Controller, error) {
return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
} }
// start database // start database
storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation) storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation)
if err != nil { if err != nil {
return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err) return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
} }
// create controller // create controller
controller, err = newController(storageInt) controller, err = newController(storageInt)
if err != nil { if err != nil {
return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err) return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
} }
controllers[name] = controller controllers[name] = controller
return controller, nil return controller, nil
} }
// InjectDatabase injects an already running database into the system. // InjectDatabase injects an already running database into the system.
@ -60,27 +66,31 @@ func InjectDatabase(name string, storageInt storage.Interface) (*Controller, err
controllersLock.Lock() controllersLock.Lock()
defer controllersLock.Unlock() defer controllersLock.Unlock()
if shuttingDown.IsSet() {
return nil, ErrShuttingDown
}
_, ok := controllers[name] _, ok := controllers[name]
if ok { if ok {
return nil, errors.New(`database "%s" already loaded`) return nil, errors.New(`database "%s" already loaded`)
} }
registryLock.Lock() registryLock.Lock()
defer registryLock.Unlock() defer registryLock.Unlock()
// check if database is registered // check if database is registered
registeredDB, ok := registry[name] registeredDB, ok := registry[name]
if !ok { if !ok {
return nil, fmt.Errorf(`database "%s" not registered`, name) return nil, fmt.Errorf(`database "%s" not registered`, name)
} }
if registeredDB.StorageType != "injected" { if registeredDB.StorageType != "injected" {
return nil, fmt.Errorf(`database not of type "injected"`) return nil, fmt.Errorf(`database not of type "injected"`)
} }
controller, err := newController(storageInt) controller, err := newController(storageInt)
if err != nil { if err != nil {
return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err) return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
} }
controllers[name] = controller controllers[name] = controller
return controller, nil return controller, nil

View file

@ -50,10 +50,14 @@ func Initialize() error {
func Shutdown() (err error) { func Shutdown() (err error) {
if shuttingDown.SetToIf(false, true) { if shuttingDown.SetToIf(false, true) {
close(shutdownSignal) close(shutdownSignal)
} else {
return
} }
all := duplicateControllers() controllersLock.RLock()
for _, c := range all { defer controllersLock.RUnlock()
for _, c := range controllers {
err = c.Shutdown() err = c.Shutdown()
if err != nil { if err != nil {
return return