From d6ef9a62f2c8de886f36ba6b1e23929e22e6d2be Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 12 Mar 2019 22:56:23 +0100 Subject: [PATCH] Improve and clean up modules package to also consider dependencies in prepping phase --- modules/modules.go | 87 ++++++++++++++++++-- modules/modules_test.go | 87 ++++++++++++-------- modules/start.go | 170 +++++++++++++++++++++------------------- modules/stop.go | 119 ++++++++++++++-------------- 4 files changed, 283 insertions(+), 180 deletions(-) diff --git a/modules/modules.go b/modules/modules.go index f3ed836..b9c6435 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -4,15 +4,15 @@ package modules import ( "errors" + "fmt" "sync" "github.com/tevino/abool" ) var ( - modulesLock sync.Mutex - modules = make(map[string]*Module) - modulesOrder []*Module + modulesLock sync.Mutex + modules = make(map[string]*Module) // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. ErrCleanExit = errors.New("clean exit requested") @@ -21,14 +21,18 @@ var ( // Module represents a module. type Module struct { Name string - Active *abool.AtomicBool + Prepped *abool.AtomicBool + Started *abool.AtomicBool + Stopped *abool.AtomicBool inTransition *abool.AtomicBool prep func() error start func() error stop func() error - dependencies []string + depNames []string + depModules []*Module + depReverse []*Module } func dummyAction() error { @@ -39,12 +43,14 @@ func dummyAction() error { func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { newModule := &Module{ Name: name, - Active: abool.NewBool(false), + Prepped: abool.NewBool(false), + Started: abool.NewBool(false), + Stopped: abool.NewBool(false), inTransition: abool.NewBool(false), prep: prep, start: start, stop: stop, - dependencies: dependencies, + depNames: dependencies, } // replace nil arguments with dummy action @@ -60,7 +66,72 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin modulesLock.Lock() defer modulesLock.Unlock() - modulesOrder = append(modulesOrder, newModule) modules[name] = newModule return newModule } + +func initDependencies() error { + for _, m := range modules { + for _, depName := range m.depNames { + + // get dependency + depModule, ok := modules[depName] + if !ok { + return fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName) + } + + // link together + m.depModules = append(m.depModules, depModule) + depModule.depReverse = append(depModule.depReverse, m) + + } + } + + return nil +} + +// ReadyToPrep returns whether all dependencies are ready for this module to prep. +func (m *Module) ReadyToPrep() bool { + if m.inTransition.IsSet() || m.Prepped.IsSet() { + return false + } + + for _, dep := range m.depModules { + if !dep.Prepped.IsSet() { + return false + } + } + + return true +} + +// ReadyToStart returns whether all dependencies are ready for this module to start. +func (m *Module) ReadyToStart() bool { + if m.inTransition.IsSet() || m.Started.IsSet() { + return false + } + + for _, dep := range m.depModules { + if !dep.Started.IsSet() { + return false + } + } + + return true +} + +// ReadyToStop returns whether all dependencies are ready for this module to stop. +func (m *Module) ReadyToStop() bool { + if !m.Started.IsSet() || m.inTransition.IsSet() || m.Stopped.IsSet() { + return false + } + + for _, revDep := range m.depReverse { + // not ready if a reverse dependency was started, but not yet stopped + if revDep.Started.IsSet() && !revDep.Stopped.IsSet() { + return false + } + } + + return true +} diff --git a/modules/modules_test.go b/modules/modules_test.go index 2a53b49..f18c506 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -11,16 +11,23 @@ import ( ) var ( + orderLock sync.Mutex startOrder string shutdownOrder string ) -func testPrep() error { - return nil +func testPrep(name string) func() error { + return func() error { + // fmt.Printf("prep %s\n", name) + return nil + } } func testStart(name string) func() error { return func() error { + orderLock.Lock() + defer orderLock.Unlock() + // fmt.Printf("start %s\n", name) startOrder = fmt.Sprintf("%s>%s", startOrder, name) return nil } @@ -28,6 +35,9 @@ func testStart(name string) func() error { func testStop(name string) func() error { return func() error { + orderLock.Lock() + defer orderLock.Unlock() + // fmt.Printf("stop %s\n", name) shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name) return nil } @@ -43,12 +53,21 @@ func testCleanExit() error { func TestOrdering(t *testing.T) { - Register("database", testPrep, testStart("database"), testStop("database")) - Register("stats", testPrep, testStart("stats"), testStop("stats"), "database") - Register("service", testPrep, testStart("service"), testStop("service"), "database") - Register("analytics", testPrep, testStart("analytics"), testStop("analytics"), "stats", "database") + Register("database", testPrep("database"), testStart("database"), testStop("database")) + Register("stats", testPrep("stats"), testStart("stats"), testStop("stats"), "database") + Register("service", testPrep("service"), testStart("service"), testStop("service"), "database") + Register("analytics", testPrep("analytics"), testStart("analytics"), testStop("analytics"), "stats", "database") - Start() + err := Start() + if err != nil { + t.Error(err) + } + + if startOrder != ">database>service>stats>analytics" && + startOrder != ">database>stats>service>analytics" && + startOrder != ">database>stats>analytics>service" { + t.Errorf("start order mismatch, was %s", startOrder) + } var wg sync.WaitGroup wg.Add(1) @@ -60,13 +79,11 @@ func TestOrdering(t *testing.T) { } wg.Done() }() - Shutdown() - - if startOrder != ">database>service>stats>analytics" && - startOrder != ">database>stats>service>analytics" && - startOrder != ">database>stats>analytics>service" { - t.Errorf("start order mismatch, was %s", startOrder) + err = Shutdown() + if err != nil { + t.Error(err) } + if shutdownOrder != ">analytics>service>stats>database" && shutdownOrder != ">analytics>stats>service>database" && shutdownOrder != ">service>analytics>stats>database" { @@ -74,12 +91,31 @@ func TestOrdering(t *testing.T) { } wg.Wait() + + printAndRemoveModules() +} + +func printAndRemoveModules() { + modulesLock.Lock() + defer modulesLock.Unlock() + + fmt.Printf("All %d modules:\n", len(modules)) + for _, m := range modules { + fmt.Printf("module %s: %+v\n", m.Name, m) + } + + modules = make(map[string]*Module) } func resetModules() { for _, module := range modules { - module.Active.UnSet() + module.Prepped.UnSet() + module.Started.UnSet() + module.Stopped.UnSet() module.inTransition.UnSet() + + module.depModules = make([]*Module, 0) + module.depModules = make([]*Module, 0) } } @@ -87,7 +123,6 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) @@ -100,7 +135,6 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) @@ -113,18 +147,11 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) // test invalid dependency - Register("database", testPrep, testStart("database"), testStop("database"), "invalid") - // go func() { - // time.Sleep(1 * time.Second) - // fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") - // pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - // os.Exit(1) - // }() + Register("database", nil, testStart("database"), testStop("database"), "invalid") err = Start() if err == nil { t.Error("should fail") @@ -132,13 +159,12 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) // test dependency loop - Register("database", testPrep, testStart("database"), testStop("database"), "helper") - Register("helper", testPrep, testStart("helper"), testStop("helper"), "database") + Register("database", nil, testStart("database"), testStop("database"), "helper") + Register("helper", nil, testStart("helper"), testStop("helper"), "database") err = Start() if err == nil { t.Error("should fail") @@ -146,12 +172,11 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) // test failing module start - Register("startfail", testPrep, testFail, testStop("startfail")) + Register("startfail", nil, testFail, testStop("startfail")) err = Start() if err == nil { t.Error("should fail") @@ -159,12 +184,11 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) // test failing module stop - Register("stopfail", testPrep, testStart("stopfail"), testFail) + Register("stopfail", nil, testStart("stopfail"), testFail) err = Start() if err != nil { t.Error("should not fail") @@ -176,7 +200,6 @@ func TestErrors(t *testing.T) { // reset modules modules = make(map[string]*Module) - modulesOrder = make([]*Module, 0) startComplete.UnSet() startCompleteSignal = make(chan struct{}) diff --git a/modules/start.go b/modules/start.go index f059979..01836a3 100644 --- a/modules/start.go +++ b/modules/start.go @@ -3,7 +3,6 @@ package modules import ( "fmt" "os" - "sync" "github.com/Safing/portbase/log" "github.com/tevino/abool" @@ -36,8 +35,14 @@ func Start() error { modulesLock.Lock() defer modulesLock.Unlock() + // inter-link modules + err := initDependencies() + if err != nil { + return err + } + // parse flags - err := parseFlags() + err = parseFlags() if err != nil { fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err) return err @@ -76,93 +81,100 @@ func Start() error { return nil } +type report struct { + module *Module + err error +} + func prepareModules() error { - for _, module := range modulesOrder { - err := module.prep() - if err != nil { - if err == ErrCleanExit { - return ErrCleanExit - } - return fmt.Errorf("failed to prep module %s: %s", module.Name, err) - } - } - return nil -} + var rep *report + reports := make(chan *report) + execCnt := 0 + reportCnt := 0 -func checkStartStatus() (readyToStart []*Module, done bool, err error) { - active := 0 - modulesInProgress := false - - // go through all modules -moduleLoop: - for _, module := range modules { - switch { - case module.Active.IsSet(): - active++ - case module.inTransition.IsSet(): - modulesInProgress = true - default: - for _, depName := range module.dependencies { - depModule, ok := modules[depName] - if !ok { - return nil, false, fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", module.Name, depName) - } - if !depModule.Active.IsSet() { - continue moduleLoop - } - } - - readyToStart = append(readyToStart, module) - } - } - - // detect dependency loop - if active < len(modules) && !modulesInProgress && len(readyToStart) == 0 { - return nil, false, fmt.Errorf("modules: dependency loop detected, cannot continue") - } - - if active == len(modules) { - return nil, true, nil - } - return readyToStart, false, nil -} - -func startModules() error { - var modulesStarting sync.WaitGroup - - reports := make(chan error, 10) for { - readyToStart, done, err := checkStartStatus() - if err != nil { - return err + // find modules to exec + for _, m := range modules { + if m.ReadyToPrep() { + execCnt++ + m.inTransition.Set() + + execM := m + go func() { + reports <- &report{ + module: execM, + err: execM.prep(), + } + }() + } } - if done { + // check for dep loop + if execCnt == reportCnt { + return fmt.Errorf("modules: dependency loop detected, cannot continue") + } + + // wait for reports + rep = <-reports + rep.module.inTransition.UnSet() + if rep.err != nil { + if rep.err == ErrCleanExit { + return rep.err + } + return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + rep.module.Prepped.Set() + + // exit if done + if reportCnt == len(modules) { return nil } - for _, module := range readyToStart { - modulesStarting.Add(1) - module.inTransition.Set() - nextModule := module // workaround go vet alert - go func() { - startErr := nextModule.start() - if startErr != nil { - reports <- fmt.Errorf("modules: could not start module %s: %s", nextModule.Name, startErr) - } else { - log.Infof("modules: started %s", nextModule.Name) - nextModule.Active.Set() - nextModule.inTransition.UnSet() - reports <- nil - } - modulesStarting.Done() - }() + } +} + +func startModules() error { + var rep *report + reports := make(chan *report) + execCnt := 0 + reportCnt := 0 + + for { + // find modules to exec + for _, m := range modules { + if m.ReadyToStart() { + execCnt++ + m.inTransition.Set() + + execM := m + go func() { + reports <- &report{ + module: execM, + err: execM.start(), + } + }() + } } - err = <-reports - if err != nil { - modulesStarting.Wait() - return err + // check for dep loop + if execCnt == reportCnt { + return fmt.Errorf("modules: dependency loop detected, cannot continue") + } + + // wait for reports + rep = <-reports + rep.module.inTransition.UnSet() + if rep.err != nil { + return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + rep.module.Started.Set() + log.Infof("modules: started %s", rep.module.Name) + + // exit if done + if reportCnt == len(modules) { + return nil } } diff --git a/modules/stop.go b/modules/stop.go index 109ecdb..d4b302d 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -19,38 +19,6 @@ func ShuttingDown() <-chan struct{} { return shutdownSignal } -func checkStopStatus() (readyToStop []*Module, done bool) { - active := 0 - - // collect all active modules - activeModules := make(map[string]*Module) - for _, module := range modules { - if module.Active.IsSet() { - active++ - activeModules[module.Name] = module - } - } - if active == 0 { - return nil, true - } - - // remove modules that others depend on - for _, module := range activeModules { - for _, depName := range module.dependencies { - delete(activeModules, depName) - } - } - - // make list out of map, minus modules in transition - for _, module := range activeModules { - if !module.inTransition.IsSet() { - readyToStop = append(readyToStop, module) - } - } - - return readyToStop, false -} - // Shutdown stops all modules in the correct order. func Shutdown() error { @@ -69,38 +37,67 @@ func Shutdown() error { log.Warning("modules: aborting, shutting down...") } - reports := make(chan error, 10) - for { - readyToStop, done := checkStopStatus() - - if done { - break - } - - for _, module := range readyToStop { - module.inTransition.Set() - nextModule := module // workaround go vet alert - go func() { - err := nextModule.stop() - if err != nil { - reports <- fmt.Errorf("modules: could not stop module %s: %s", nextModule.Name, err) - } else { - reports <- nil - } - nextModule.Active.UnSet() - nextModule.inTransition.UnSet() - }() - } - - err := <-reports - if err != nil { - log.Error(err.Error()) - return err - } - + err := stopModules() + if err != nil { + log.Error(err.Error()) + return err } log.Info("modules: shutdown complete") log.Shutdown() return nil } + +func stopModules() error { + var rep *report + reports := make(chan *report) + execCnt := 0 + reportCnt := 0 + + // get number of started modules + startedCnt := 0 + for _, m := range modules { + if m.Started.IsSet() { + startedCnt++ + } + } + + for { + // find modules to exec + for _, m := range modules { + if m.ReadyToStop() { + execCnt++ + m.inTransition.Set() + + execM := m + go func() { + reports <- &report{ + module: execM, + err: execM.stop(), + } + }() + } + } + + // check for dep loop + if execCnt == reportCnt { + return fmt.Errorf("modules: dependency loop detected, cannot continue") + } + + // wait for reports + rep = <-reports + rep.module.inTransition.UnSet() + if rep.err != nil { + return fmt.Errorf("modules: could not stop module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + rep.module.Stopped.Set() + log.Infof("modules: stopped %s", rep.module.Name) + + // exit if done + if reportCnt == startedCnt { + return nil + } + + } +}