diff --git a/modules/modules.go b/modules/modules.go index f9ebb89..f3ed836 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -10,8 +10,6 @@ import ( ) var ( - startComplete = abool.NewBool(false) - modulesLock sync.Mutex modules = make(map[string]*Module) modulesOrder []*Module diff --git a/modules/modules_test.go b/modules/modules_test.go index 20e8236..2a53b49 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -89,6 +89,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test prep error Register("prepfail", testFail, testStart("prepfail"), testStop("prepfail")) @@ -101,6 +102,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test prep clean exit Register("prepcleanexit", testCleanExit, testStart("prepcleanexit"), testStop("prepcleanexit")) @@ -113,6 +115,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test invalid dependency Register("database", testPrep, testStart("database"), testStop("database"), "invalid") @@ -131,6 +134,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test dependency loop Register("database", testPrep, testStart("database"), testStop("database"), "helper") @@ -144,6 +148,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test failing module start Register("startfail", testPrep, testFail, testStop("startfail")) @@ -156,6 +161,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test failing module stop Register("stopfail", testPrep, testStart("stopfail"), testFail) @@ -172,6 +178,7 @@ func TestErrors(t *testing.T) { modules = make(map[string]*Module) modulesOrder = make([]*Module, 0) startComplete.UnSet() + startCompleteSignal = make(chan struct{}) // test help flag helpFlag = true diff --git a/modules/start.go b/modules/start.go index cc3290a..f059979 100644 --- a/modules/start.go +++ b/modules/start.go @@ -6,8 +6,31 @@ import ( "sync" "github.com/Safing/portbase/log" + "github.com/tevino/abool" ) +var ( + startComplete = abool.NewBool(false) + startCompleteSignal = make(chan struct{}) +) + +// markStartComplete marks the startup as completed. +func markStartComplete() { + if startComplete.SetToIf(false, true) { + close(startCompleteSignal) + } +} + +// StartCompleted returns whether starting has completed. +func StartCompleted() bool { + return startComplete.IsSet() +} + +// WaitForStartCompletion returns as soon as starting has completed. +func WaitForStartCompletion() <-chan struct{} { + return startCompleteSignal +} + // Start starts all modules in the correct order. In case of an error, it will automatically shutdown again. func Start() error { modulesLock.Lock() @@ -44,8 +67,12 @@ func Start() error { return err } - startComplete.Set() + // complete startup log.Infof("modules: started %d modules", len(modules)) + if startComplete.SetToIf(false, true) { + close(startCompleteSignal) + } + return nil }