diff --git a/api/main.go b/api/main.go index ea74793..830fe79 100644 --- a/api/main.go +++ b/api/main.go @@ -13,7 +13,7 @@ var ( ) func init() { - modules.Register("api", prep, start, nil, "core") + modules.Register("api", prep, start, stop, "base", "database", "config") } func prep() error { diff --git a/config/main.go b/config/main.go index 2c5003e..3b0d7d5 100644 --- a/config/main.go +++ b/config/main.go @@ -22,7 +22,7 @@ func SetDataRoot(root *utils.DirStructure) { } func init() { - modules.Register("config", prep, start, nil, "core") + modules.Register("config", prep, start, nil, "base", "database") } func prep() error { diff --git a/crypto/random/rng.go b/crypto/random/rng.go index 464b363..ef461a5 100644 --- a/crypto/random/rng.go +++ b/crypto/random/rng.go @@ -23,7 +23,7 @@ var ( ) func init() { - modules.Register("random", prep, Start, stop) + modules.Register("random", prep, Start, stop, "base") config.Register(&config.Option{ Name: "RNG Cipher", diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go index 765af96..780ae7e 100644 --- a/database/dbmodule/db.go +++ b/database/dbmodule/db.go @@ -2,51 +2,46 @@ package dbmodule import ( "errors" - "flag" - "sync" "github.com/safing/portbase/database" "github.com/safing/portbase/modules" + "github.com/safing/portbase/utils" ) var ( - databaseDir string - shutdownSignal = make(chan struct{}) - maintenanceWg sync.WaitGroup + databasePath string + databaseStructureRoot *utils.DirStructure + + module *modules.Module ) -// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests. -func SetDatabaseLocation(location string) { - databaseDir = location +func init() { + module = modules.Register("database", prep, start, stop, "base") } -func init() { - flag.StringVar(&databaseDir, "db", "", "set database directory") - - modules.Register("database", prep, start, stop) +// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. +func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) { + databasePath = dirPath + databaseStructureRoot = dirStructureRoot } func prep() error { - if databaseDir == "" { - return errors.New("no database location specified, set with `-db=/path/to/db`") - } - ok := database.SetLocation(databaseDir) - if !ok { - return errors.New("database location already set") + if databasePath == "" && databaseStructureRoot == nil { + return errors.New("no database location specified") } return nil } func start() error { - err := database.Initialize() - if err == nil { - startMaintainer() + err := database.Initialize(databasePath, databaseStructureRoot) + if err != nil { + return err } - return err + + startMaintainer() + return nil } func stop() error { - close(shutdownSignal) - maintenanceWg.Wait() return database.Shutdown() } diff --git a/database/dbmodule/maintenance.go b/database/dbmodule/maintenance.go index 1635a20..82c6f98 100644 --- a/database/dbmodule/maintenance.go +++ b/database/dbmodule/maintenance.go @@ -13,7 +13,7 @@ var ( ) func startMaintainer() { - maintenanceWg.Add(1) + module.AddWorkers(1) go maintenanceWorker() } @@ -37,8 +37,8 @@ func maintenanceWorker() { if err != nil { log.Errorf("database: thorough maintenance error: %s", err) } - case <-shutdownSignal: - maintenanceWg.Done() + case <-module.ShuttingDown(): + module.FinishWorker() return } } diff --git a/database/main.go b/database/main.go index 3a51e8e..d49962c 100644 --- a/database/main.go +++ b/database/main.go @@ -23,7 +23,7 @@ var ( databasesStructure *utils.DirStructure ) -// Initialize initialized the database at the specified location. Supply either a path or dir structure. +// Initialize initializes the database at the specified location. Supply either a path or dir structure. func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error { if initialized.SetToIf(false, true) { diff --git a/modules/flags.go b/modules/flags.go index 2450119..d9ad1e8 100644 --- a/modules/flags.go +++ b/modules/flags.go @@ -3,6 +3,7 @@ package modules import "flag" var ( + // HelpFlag triggers printing flag.Usage. It's exported for custom help handling. HelpFlag bool ) diff --git a/modules/modules.go b/modules/modules.go index 76f45f7..5d55784 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -1,10 +1,14 @@ package modules import ( + "context" "errors" "fmt" "sync" + "sync/atomic" + "time" + "github.com/safing/portbase/log" "github.com/tevino/abool" ) @@ -18,33 +22,104 @@ var ( // Module represents a module. type Module struct { - Name string + Name string + + // lifecycle mgmt Prepped *abool.AtomicBool Started *abool.AtomicBool Stopped *abool.AtomicBool inTransition *abool.AtomicBool + // lifecycle callback functions prep func() error start func() error stop func() error + // shutdown mgmt + Ctx context.Context + cancelCtx func() + shutdownFlag *abool.AtomicBool + workerGroup sync.WaitGroup + workerCnt *int32 + + // dependency mgmt depNames []string depModules []*Module depReverse []*Module } +// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup. +func (m *Module) AddWorkers(n uint) { + if !m.ShutdownInProgress() { + if atomic.AddInt32(m.workerCnt, int32(n)) > 0 { + // only add to workgroup if cnt is positive (try to compensate wrong usage) + m.workerGroup.Add(int(n)) + } + } +} + +// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup. +func (m *Module) FinishWorker() { + // check worker cnt + if atomic.AddInt32(m.workerCnt, -1) < 0 { + log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name) + return + } + // also mark worker done in workgroup + m.workerGroup.Done() +} + +// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead. +func (m *Module) ShutdownInProgress() bool { + return m.shutdownFlag.IsSet() +} + +// ShuttingDown lets you listen for the shutdown signal. +func (m *Module) ShuttingDown() <-chan struct{} { + return m.Ctx.Done() +} + +func (m *Module) shutdown() error { + // signal shutdown + m.shutdownFlag.Set() + m.cancelCtx() + + // wait for workers + done := make(chan struct{}) + go func() { + m.workerGroup.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(3 * time.Second): + return errors.New("timed out while waiting for module workers to finish") + } + + // call shutdown function + return m.stop() +} + func dummyAction() error { return nil } -// Register registers a new module. +// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished. func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { + ctx, cancelCtx := context.WithCancel(context.Background()) + var workerCnt int32 + newModule := &Module{ Name: name, Prepped: abool.NewBool(false), Started: abool.NewBool(false), Stopped: abool.NewBool(false), inTransition: abool.NewBool(false), + Ctx: ctx, + cancelCtx: cancelCtx, + shutdownFlag: abool.NewBool(false), + workerGroup: sync.WaitGroup{}, + workerCnt: &workerCnt, prep: prep, start: start, stop: stop, diff --git a/modules/modules_test.go b/modules/modules_test.go index 966587b..5a85378 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -202,11 +202,11 @@ func TestErrors(t *testing.T) { startCompleteSignal = make(chan struct{}) // test help flag - helpFlag = true + HelpFlag = true err = Start() if err == nil { t.Error("should fail") } - helpFlag = false + HelpFlag = false } diff --git a/modules/stop.go b/modules/stop.go index 9bb90b0..ad1a79e 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -74,7 +74,7 @@ func stopModules() error { go func() { reports <- &report{ module: execM, - err: execM.stop(), + err: execM.shutdown(), } }() } diff --git a/notifications/module.go b/notifications/module.go index 21e5997..63902e2 100644 --- a/notifications/module.go +++ b/notifications/module.go @@ -12,7 +12,7 @@ var ( ) func init() { - modules.Register("notifications", nil, start, nil, "core") + modules.Register("notifications", nil, start, nil, "base", "database") } func start() error {