Improve and clean up modules package to also consider dependencies in prepping phase

This commit is contained in:
Daniel 2019-03-12 22:56:23 +01:00
parent 3188134203
commit d6ef9a62f2
4 changed files with 283 additions and 180 deletions

View file

@ -4,15 +4,15 @@ package modules
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"github.com/tevino/abool" "github.com/tevino/abool"
) )
var ( var (
modulesLock sync.Mutex modulesLock sync.Mutex
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder []*Module
// ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag.
ErrCleanExit = errors.New("clean exit requested") ErrCleanExit = errors.New("clean exit requested")
@ -21,14 +21,18 @@ var (
// Module represents a module. // Module represents a module.
type Module struct { type Module struct {
Name string Name string
Active *abool.AtomicBool Prepped *abool.AtomicBool
Started *abool.AtomicBool
Stopped *abool.AtomicBool
inTransition *abool.AtomicBool inTransition *abool.AtomicBool
prep func() error prep func() error
start func() error start func() error
stop func() error stop func() error
dependencies []string depNames []string
depModules []*Module
depReverse []*Module
} }
func dummyAction() error { func dummyAction() error {
@ -39,12 +43,14 @@ func dummyAction() error {
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
newModule := &Module{ newModule := &Module{
Name: name, Name: name,
Active: abool.NewBool(false), Prepped: abool.NewBool(false),
Started: abool.NewBool(false),
Stopped: abool.NewBool(false),
inTransition: abool.NewBool(false), inTransition: abool.NewBool(false),
prep: prep, prep: prep,
start: start, start: start,
stop: stop, stop: stop,
dependencies: dependencies, depNames: dependencies,
} }
// replace nil arguments with dummy action // replace nil arguments with dummy action
@ -60,7 +66,72 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin
modulesLock.Lock() modulesLock.Lock()
defer modulesLock.Unlock() defer modulesLock.Unlock()
modulesOrder = append(modulesOrder, newModule)
modules[name] = newModule modules[name] = newModule
return newModule return newModule
} }
func initDependencies() error {
for _, m := range modules {
for _, depName := range m.depNames {
// get dependency
depModule, ok := modules[depName]
if !ok {
return fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName)
}
// link together
m.depModules = append(m.depModules, depModule)
depModule.depReverse = append(depModule.depReverse, m)
}
}
return nil
}
// ReadyToPrep returns whether all dependencies are ready for this module to prep.
func (m *Module) ReadyToPrep() bool {
if m.inTransition.IsSet() || m.Prepped.IsSet() {
return false
}
for _, dep := range m.depModules {
if !dep.Prepped.IsSet() {
return false
}
}
return true
}
// ReadyToStart returns whether all dependencies are ready for this module to start.
func (m *Module) ReadyToStart() bool {
if m.inTransition.IsSet() || m.Started.IsSet() {
return false
}
for _, dep := range m.depModules {
if !dep.Started.IsSet() {
return false
}
}
return true
}
// ReadyToStop returns whether all dependencies are ready for this module to stop.
func (m *Module) ReadyToStop() bool {
if !m.Started.IsSet() || m.inTransition.IsSet() || m.Stopped.IsSet() {
return false
}
for _, revDep := range m.depReverse {
// not ready if a reverse dependency was started, but not yet stopped
if revDep.Started.IsSet() && !revDep.Stopped.IsSet() {
return false
}
}
return true
}

View file

@ -11,16 +11,23 @@ import (
) )
var ( var (
orderLock sync.Mutex
startOrder string startOrder string
shutdownOrder string shutdownOrder string
) )
func testPrep() error { func testPrep(name string) func() error {
return nil return func() error {
// fmt.Printf("prep %s\n", name)
return nil
}
} }
func testStart(name string) func() error { func testStart(name string) func() error {
return func() error { return func() error {
orderLock.Lock()
defer orderLock.Unlock()
// fmt.Printf("start %s\n", name)
startOrder = fmt.Sprintf("%s>%s", startOrder, name) startOrder = fmt.Sprintf("%s>%s", startOrder, name)
return nil return nil
} }
@ -28,6 +35,9 @@ func testStart(name string) func() error {
func testStop(name string) func() error { func testStop(name string) func() error {
return func() error { return func() error {
orderLock.Lock()
defer orderLock.Unlock()
// fmt.Printf("stop %s\n", name)
shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name) shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
return nil return nil
} }
@ -43,12 +53,21 @@ func testCleanExit() error {
func TestOrdering(t *testing.T) { func TestOrdering(t *testing.T) {
Register("database", testPrep, testStart("database"), testStop("database")) Register("database", testPrep("database"), testStart("database"), testStop("database"))
Register("stats", testPrep, testStart("stats"), testStop("stats"), "database") Register("stats", testPrep("stats"), testStart("stats"), testStop("stats"), "database")
Register("service", testPrep, testStart("service"), testStop("service"), "database") Register("service", testPrep("service"), testStart("service"), testStop("service"), "database")
Register("analytics", testPrep, testStart("analytics"), testStop("analytics"), "stats", "database") Register("analytics", testPrep("analytics"), testStart("analytics"), testStop("analytics"), "stats", "database")
Start() err := Start()
if err != nil {
t.Error(err)
}
if startOrder != ">database>service>stats>analytics" &&
startOrder != ">database>stats>service>analytics" &&
startOrder != ">database>stats>analytics>service" {
t.Errorf("start order mismatch, was %s", startOrder)
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -60,13 +79,11 @@ func TestOrdering(t *testing.T) {
} }
wg.Done() wg.Done()
}() }()
Shutdown() err = Shutdown()
if err != nil {
if startOrder != ">database>service>stats>analytics" && t.Error(err)
startOrder != ">database>stats>service>analytics" &&
startOrder != ">database>stats>analytics>service" {
t.Errorf("start order mismatch, was %s", startOrder)
} }
if shutdownOrder != ">analytics>service>stats>database" && if shutdownOrder != ">analytics>service>stats>database" &&
shutdownOrder != ">analytics>stats>service>database" && shutdownOrder != ">analytics>stats>service>database" &&
shutdownOrder != ">service>analytics>stats>database" { shutdownOrder != ">service>analytics>stats>database" {
@ -74,12 +91,31 @@ func TestOrdering(t *testing.T) {
} }
wg.Wait() wg.Wait()
printAndRemoveModules()
}
func printAndRemoveModules() {
modulesLock.Lock()
defer modulesLock.Unlock()
fmt.Printf("All %d modules:\n", len(modules))
for _, m := range modules {
fmt.Printf("module %s: %+v\n", m.Name, m)
}
modules = make(map[string]*Module)
} }
func resetModules() { func resetModules() {
for _, module := range modules { for _, module := range modules {
module.Active.UnSet() module.Prepped.UnSet()
module.Started.UnSet()
module.Stopped.UnSet()
module.inTransition.UnSet() module.inTransition.UnSet()
module.depModules = make([]*Module, 0)
module.depModules = make([]*Module, 0)
} }
} }
@ -87,7 +123,6 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
@ -100,7 +135,6 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
@ -113,18 +147,11 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
// test invalid dependency // test invalid dependency
Register("database", testPrep, testStart("database"), testStop("database"), "invalid") Register("database", nil, testStart("database"), testStop("database"), "invalid")
// go func() {
// time.Sleep(1 * time.Second)
// fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====")
// pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
// os.Exit(1)
// }()
err = Start() err = Start()
if err == nil { if err == nil {
t.Error("should fail") t.Error("should fail")
@ -132,13 +159,12 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
// test dependency loop // test dependency loop
Register("database", testPrep, testStart("database"), testStop("database"), "helper") Register("database", nil, testStart("database"), testStop("database"), "helper")
Register("helper", testPrep, testStart("helper"), testStop("helper"), "database") Register("helper", nil, testStart("helper"), testStop("helper"), "database")
err = Start() err = Start()
if err == nil { if err == nil {
t.Error("should fail") t.Error("should fail")
@ -146,12 +172,11 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
// test failing module start // test failing module start
Register("startfail", testPrep, testFail, testStop("startfail")) Register("startfail", nil, testFail, testStop("startfail"))
err = Start() err = Start()
if err == nil { if err == nil {
t.Error("should fail") t.Error("should fail")
@ -159,12 +184,11 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
// test failing module stop // test failing module stop
Register("stopfail", testPrep, testStart("stopfail"), testFail) Register("stopfail", nil, testStart("stopfail"), testFail)
err = Start() err = Start()
if err != nil { if err != nil {
t.Error("should not fail") t.Error("should not fail")
@ -176,7 +200,6 @@ func TestErrors(t *testing.T) {
// reset modules // reset modules
modules = make(map[string]*Module) modules = make(map[string]*Module)
modulesOrder = make([]*Module, 0)
startComplete.UnSet() startComplete.UnSet()
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})

View file

@ -3,7 +3,6 @@ package modules
import ( import (
"fmt" "fmt"
"os" "os"
"sync"
"github.com/Safing/portbase/log" "github.com/Safing/portbase/log"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -36,8 +35,14 @@ func Start() error {
modulesLock.Lock() modulesLock.Lock()
defer modulesLock.Unlock() defer modulesLock.Unlock()
// inter-link modules
err := initDependencies()
if err != nil {
return err
}
// parse flags // parse flags
err := parseFlags() err = parseFlags()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err) fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err)
return err return err
@ -76,93 +81,100 @@ func Start() error {
return nil return nil
} }
type report struct {
module *Module
err error
}
func prepareModules() error { func prepareModules() error {
for _, module := range modulesOrder { var rep *report
err := module.prep() reports := make(chan *report)
if err != nil { execCnt := 0
if err == ErrCleanExit { reportCnt := 0
return ErrCleanExit
}
return fmt.Errorf("failed to prep module %s: %s", module.Name, err)
}
}
return nil
}
func checkStartStatus() (readyToStart []*Module, done bool, err error) {
active := 0
modulesInProgress := false
// go through all modules
moduleLoop:
for _, module := range modules {
switch {
case module.Active.IsSet():
active++
case module.inTransition.IsSet():
modulesInProgress = true
default:
for _, depName := range module.dependencies {
depModule, ok := modules[depName]
if !ok {
return nil, false, fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", module.Name, depName)
}
if !depModule.Active.IsSet() {
continue moduleLoop
}
}
readyToStart = append(readyToStart, module)
}
}
// detect dependency loop
if active < len(modules) && !modulesInProgress && len(readyToStart) == 0 {
return nil, false, fmt.Errorf("modules: dependency loop detected, cannot continue")
}
if active == len(modules) {
return nil, true, nil
}
return readyToStart, false, nil
}
func startModules() error {
var modulesStarting sync.WaitGroup
reports := make(chan error, 10)
for { for {
readyToStart, done, err := checkStartStatus() // find modules to exec
if err != nil { for _, m := range modules {
return err if m.ReadyToPrep() {
execCnt++
m.inTransition.Set()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.prep(),
}
}()
}
} }
if done { // check for dep loop
if execCnt == reportCnt {
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
if rep.err == ErrCleanExit {
return rep.err
}
return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Prepped.Set()
// exit if done
if reportCnt == len(modules) {
return nil return nil
} }
for _, module := range readyToStart { }
modulesStarting.Add(1) }
module.inTransition.Set()
nextModule := module // workaround go vet alert func startModules() error {
go func() { var rep *report
startErr := nextModule.start() reports := make(chan *report)
if startErr != nil { execCnt := 0
reports <- fmt.Errorf("modules: could not start module %s: %s", nextModule.Name, startErr) reportCnt := 0
} else {
log.Infof("modules: started %s", nextModule.Name) for {
nextModule.Active.Set() // find modules to exec
nextModule.inTransition.UnSet() for _, m := range modules {
reports <- nil if m.ReadyToStart() {
} execCnt++
modulesStarting.Done() m.inTransition.Set()
}()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.start(),
}
}()
}
} }
err = <-reports // check for dep loop
if err != nil { if execCnt == reportCnt {
modulesStarting.Wait() return fmt.Errorf("modules: dependency loop detected, cannot continue")
return err }
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Started.Set()
log.Infof("modules: started %s", rep.module.Name)
// exit if done
if reportCnt == len(modules) {
return nil
} }
} }

View file

@ -19,38 +19,6 @@ func ShuttingDown() <-chan struct{} {
return shutdownSignal return shutdownSignal
} }
func checkStopStatus() (readyToStop []*Module, done bool) {
active := 0
// collect all active modules
activeModules := make(map[string]*Module)
for _, module := range modules {
if module.Active.IsSet() {
active++
activeModules[module.Name] = module
}
}
if active == 0 {
return nil, true
}
// remove modules that others depend on
for _, module := range activeModules {
for _, depName := range module.dependencies {
delete(activeModules, depName)
}
}
// make list out of map, minus modules in transition
for _, module := range activeModules {
if !module.inTransition.IsSet() {
readyToStop = append(readyToStop, module)
}
}
return readyToStop, false
}
// Shutdown stops all modules in the correct order. // Shutdown stops all modules in the correct order.
func Shutdown() error { func Shutdown() error {
@ -69,38 +37,67 @@ func Shutdown() error {
log.Warning("modules: aborting, shutting down...") log.Warning("modules: aborting, shutting down...")
} }
reports := make(chan error, 10) err := stopModules()
for { if err != nil {
readyToStop, done := checkStopStatus() log.Error(err.Error())
return err
if done {
break
}
for _, module := range readyToStop {
module.inTransition.Set()
nextModule := module // workaround go vet alert
go func() {
err := nextModule.stop()
if err != nil {
reports <- fmt.Errorf("modules: could not stop module %s: %s", nextModule.Name, err)
} else {
reports <- nil
}
nextModule.Active.UnSet()
nextModule.inTransition.UnSet()
}()
}
err := <-reports
if err != nil {
log.Error(err.Error())
return err
}
} }
log.Info("modules: shutdown complete") log.Info("modules: shutdown complete")
log.Shutdown() log.Shutdown()
return nil return nil
} }
func stopModules() error {
var rep *report
reports := make(chan *report)
execCnt := 0
reportCnt := 0
// get number of started modules
startedCnt := 0
for _, m := range modules {
if m.Started.IsSet() {
startedCnt++
}
}
for {
// find modules to exec
for _, m := range modules {
if m.ReadyToStop() {
execCnt++
m.inTransition.Set()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.stop(),
}
}()
}
}
// check for dep loop
if execCnt == reportCnt {
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
return fmt.Errorf("modules: could not stop module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Stopped.Set()
log.Infof("modules: stopped %s", rep.module.Name)
// exit if done
if reportCnt == startedCnt {
return nil
}
}
}