diff --git a/modules/flags.go b/modules/flags.go new file mode 100644 index 0000000..45328d9 --- /dev/null +++ b/modules/flags.go @@ -0,0 +1,24 @@ +package modules + +import "flag" + +var ( + helpFlag bool +) + +func init() { + flag.BoolVar(&helpFlag, "help", false, "print help") +} + +func parseFlags() error { + + // parse flags + flag.Parse() + + if helpFlag { + flag.Usage() + return ErrCleanExit + } + + return nil +} diff --git a/modules/logging.go b/modules/logging.go deleted file mode 100644 index 4215c41..0000000 --- a/modules/logging.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - -package modules - -var logger Logger -var loggerRegistered chan struct{} - -type Logger interface { - Tracef(things ...interface{}) - Trace(msg string) - Debugf(things ...interface{}) - Debug(msg string) - Infof(things ...interface{}) - Info(msg string) - Warningf(things ...interface{}) - Warning(msg string) - Errorf(things ...interface{}) - Error(msg string) - Criticalf(things ...interface{}) - Critical(msg string) -} - -func RegisterLogger(newLogger Logger) { - if logger == nil { - logger = newLogger - loggerRegistered <- struct{}{} - } -} - -func GetLogger() Logger { - return logger -} diff --git a/modules/modules.go b/modules/modules.go index f873699..df9c065 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -3,145 +3,49 @@ package modules import ( - "container/list" - "os" - "time" + "errors" + "sync" "github.com/tevino/abool" ) -var modules *list.List -var addModule chan *Module -var GlobalShutdown chan struct{} -var loggingActive bool +var ( + startComplete = abool.NewBool(false) + modulesLock sync.Mutex + modules = make(map[string]*Module) + modulesOrder []*Module + + // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. + ErrCleanExit = errors.New("clean exit requested") +) + +// Module represents a module. type Module struct { - Name string - Order uint8 + Name string + Active *abool.AtomicBool - Start chan struct{} - Active *abool.AtomicBool - startComplete chan struct{} + prep func() error + start func() error + starting bool + stop func() error - Stop chan struct{} - Stopped *abool.AtomicBool - stopComplete chan struct{} + dependencies []string } -func Register(name string, order uint8) *Module { +// Register registers a new module. +func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { newModule := &Module{ - Name: name, - Order: order, - Start: make(chan struct{}), - Active: abool.NewBool(true), - startComplete: make(chan struct{}), - - Stop: make(chan struct{}), - Stopped: abool.NewBool(false), - stopComplete: make(chan struct{}), + Name: name, + Active: abool.NewBool(false), + prep: prep, + start: start, + stop: stop, + dependencies: dependencies, } - addModule <- newModule + modulesLock.Lock() + defer modulesLock.Unlock() + modulesOrder = append(modulesOrder, newModule) + modules[name] = newModule return newModule } - -func (module *Module) addToList() { - if loggingActive { - logger.Infof("Modules: starting %s", module.Name) - } - for e := modules.Back(); e != nil; e = e.Prev() { - if module.Order > e.Value.(*Module).Order { - modules.InsertAfter(module, e) - return - } - } - modules.PushFront(module) -} - -func (module *Module) stop() { - module.Active.UnSet() - defer module.Stopped.Set() - for { - select { - case module.Stop <- struct{}{}: - case <-module.stopComplete: - return - case <-time.After(1 * time.Second): - if loggingActive { - logger.Warningf("Modules: waiting for %s to stop...", module.Name) - } - } - } -} - -func (module *Module) StopComplete() { - if loggingActive { - logger.Warningf("Modules: stopped %s", module.Name) - } - module.stopComplete <- struct{}{} -} - -func (module *Module) start() { - module.Stopped.UnSet() - defer module.Active.Set() - for { - select { - case module.Start <- struct{}{}: - case <-module.startComplete: - return - } - } -} - -func (module *Module) StartComplete() { - if loggingActive { - logger.Infof("Modules: starting %s", module.Name) - } - module.startComplete <- struct{}{} -} - -func InitiateFullShutdown() { - close(GlobalShutdown) -} - -func fullStop() { - for e := modules.Back(); e != nil; e = e.Prev() { - if e.Value.(*Module).Active.IsSet() { - e.Value.(*Module).stop() - } - } -} - -func run() { - select { - case <-loggerRegistered: - logger.Info("Modules: starting") - loggingActive = true - case <-time.After(1 * time.Second): - } - - for { - select { - case <-GlobalShutdown: - if loggingActive { - logger.Warning("Modules: stopping") - } - fullStop() - os.Exit(0) - case m := <-addModule: - m.addToList() - // go m.start() - } - } -} - -func init() { - - modules = list.New() - addModule = make(chan *Module, 10) - GlobalShutdown = make(chan struct{}) - loggerRegistered = make(chan struct{}, 1) - loggingActive = false - - go run() - -} diff --git a/modules/modules_test.go b/modules/modules_test.go index 5198fda..9665e59 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -3,48 +3,182 @@ package modules import ( + "errors" "fmt" + "testing" + "sync" "time" ) -func newTestModule(name string, order uint8) { +var ( + startOrder string + shutdownOrder string +) - fmt.Printf("up %s\n", name) - module := Register("TestModule", order) +func testPrep() error { + return nil +} +func testStart(name string) func() error { + return func() error { + startOrder = fmt.Sprintf("%s>%s", startOrder, name) + return nil + } +} + +func testStop(name string) func() error { + return func() error { + shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name) + return nil + } +} + +func testFail() error { + return errors.New("test error") +} + +func testCleanExit() error { + return ErrCleanExit +} + +func TestOrdering(t *testing.T) { + + Register("database", testPrep, testStart("database"), testStop("database")) + Register("stats", testPrep, testStart("stats"), testStop("stats"), "database") + Register("service", testPrep, testStart("service"), testStop("service"), "database") + Register("analytics", testPrep, testStart("analytics"), testStop("analytics"), "stats", "database") + + Start() + + var wg sync.WaitGroup + wg.Add(1) go func() { - <-module.Stop - fmt.Printf("down %s\n", name) - module.StopComplete() + select { + case <-ShuttingDown(): + case <-time.After(1 * time.Second): + t.Error("did not receive shutdown signal") + } + wg.Done() }() + Shutdown() + if startOrder != ">database>service>stats>analytics" && + startOrder != ">database>stats>service>analytics" && + startOrder != ">database>stats>analytics>service" { + t.Errorf("start order mismatch, was %s", startOrder) + } + if shutdownOrder != ">analytics>service>stats>database" && + shutdownOrder != ">analytics>stats>service>database" && + shutdownOrder != ">service>analytics>stats>database" { + t.Errorf("shutdown order mismatch, was %s", shutdownOrder) + } + + wg.Wait() } -func Example() { +func resetModules() { + for _, module := range modules { + module.Active.UnSet() + module.starting = false + } +} - // wait for logger registration timeout - time.Sleep(1010 * time.Millisecond) +func TestErrors(t *testing.T) { - newTestModule("1", 1) - newTestModule("4", 4) - newTestModule("3", 3) - newTestModule("2", 2) - newTestModule("5", 5) + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() - InitiateFullShutdown() + // test prep error + Register("prepfail", testFail, testStart("prepfail"), testStop("prepfail")) + err := Start() + if err == nil { + t.Error("should fail") + } - time.Sleep(10 * time.Millisecond) + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() - // Output: - // up 1 - // up 4 - // up 3 - // up 2 - // up 5 - // down 5 - // down 4 - // down 3 - // down 2 - // down 1 + // test prep clean exit + Register("prepcleanexit", testCleanExit, testStart("prepcleanexit"), testStop("prepcleanexit")) + err = Start() + if err != ErrCleanExit { + t.Error("should fail with clean exit") + } + + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() + + // test invalid dependency + Register("database", testPrep, testStart("database"), testStop("database"), "invalid") + // go func() { + // time.Sleep(1 * time.Second) + // fmt.Println("===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") + // pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + // os.Exit(1) + // }() + err = Start() + if err == nil { + t.Error("should fail") + } + + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() + + // test dependency loop + Register("database", testPrep, testStart("database"), testStop("database"), "helper") + Register("helper", testPrep, testStart("helper"), testStop("helper"), "database") + err = Start() + if err == nil { + t.Error("should fail") + } + + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() + + // test failing module start + Register("startfail", testPrep, testFail, testStop("startfail")) + err = Start() + if err == nil { + t.Error("should fail") + } + + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() + + // test failing module stop + Register("stopfail", testPrep, testStart("stopfail"), testFail) + err = Start() + if err != nil { + t.Error("should not fail") + } + err = Shutdown() + if err == nil { + t.Error("should fail") + } + + // reset modules + modules = make(map[string]*Module) + modulesOrder = make([]*Module, 0) + startComplete.UnSet() + + // test help flag + helpFlag = true + err = Start() + if err == nil { + t.Error("should fail") + } + helpFlag = false } diff --git a/modules/start.go b/modules/start.go new file mode 100644 index 0000000..b9b5f8e --- /dev/null +++ b/modules/start.go @@ -0,0 +1,142 @@ +package modules + +import ( + "fmt" + "os" + "sync" + + "github.com/Safing/portbase/log" +) + +// Start starts all modules in the correct order. In case of an error, it will automatically shutdown again. +func Start() error { + modulesLock.Lock() + defer modulesLock.Unlock() + + // parse flags + err := parseFlags() + if err != nil { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err) + return err + } + + // prep modules + err = prepareModules() + if err != nil { + if err != ErrCleanExit { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: %s\n", err) + } + return err + } + + // start logging + err = log.Start() + if err != nil { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to start logging: %s\n", err) + return err + } + + // start modules + log.Info("modules: initiating...") + err = startModules() + if err != nil { + log.Critical(err.Error()) + Shutdown() + return err + } + + startComplete.Set() + log.Infof("modules: started %d modules", len(modules)) + return nil +} + +func prepareModules() error { + for _, module := range modulesOrder { + err := module.prep() + if err != nil { + if err == ErrCleanExit { + return ErrCleanExit + } + return fmt.Errorf("failed to prep module %s: %s", module.Name, err) + } + } + return nil +} + +func checkStartStatus() (readyToStart []*Module, done bool, err error) { + active := 0 + modulesInProgress := false + + // go through all modules +moduleLoop: + for _, module := range modules { + switch { + case module.Active.IsSet(): + active++ + case module.starting: + modulesInProgress = true + default: + for _, depName := range module.dependencies { + depModule, ok := modules[depName] + if !ok { + return nil, false, fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", module.Name, depName) + } + if !depModule.Active.IsSet() { + continue moduleLoop + } + } + + readyToStart = append(readyToStart, module) + } + } + + // detect dependency loop + if active < len(modules) && !modulesInProgress && len(readyToStart) == 0 { + return nil, false, fmt.Errorf("modules: dependency loop detected, cannot continue") + } + + if active == len(modules) { + return nil, true, nil + } + return readyToStart, false, nil +} + +func startModules() error { + var modulesStarting sync.WaitGroup + + reports := make(chan error, 0) + for { + readyToStart, done, err := checkStartStatus() + if err != nil { + return err + } + + if done { + return nil + } + + for _, module := range readyToStart { + modulesStarting.Add(1) + module.starting = true + nextModule := module // workaround go vet alert + go func() { + startErr := nextModule.start() + if startErr != nil { + reports <- fmt.Errorf("modules: could not start module %s: %s", nextModule.Name, err) + } else { + log.Debugf("modules: started %s", nextModule.Name) + nextModule.Active.Set() + reports <- nil + } + modulesStarting.Done() + }() + } + + err = <-reports + if err != nil { + modulesStarting.Wait() + return err + } + + } +} diff --git a/modules/stop.go b/modules/stop.go new file mode 100644 index 0000000..06b6fad --- /dev/null +++ b/modules/stop.go @@ -0,0 +1,96 @@ +package modules + +import ( + "fmt" + + "github.com/tevino/abool" + + "github.com/Safing/portbase/log" +) + +var ( + shutdownSignal = make(chan struct{}) + shutdownSignalClosed = abool.NewBool(false) +) + +// ShuttingDown returns a channel read on the global shutdown signal. +func ShuttingDown() <-chan struct{} { + return shutdownSignal +} + +func checkStopStatus() (readyToStop []*Module, done bool) { + active := 0 + + // collect all active modules + activeModules := make(map[string]*Module) + for _, module := range modules { + if module.Active.IsSet() { + active++ + activeModules[module.Name] = module + } + } + if active == 0 { + return nil, true + } + + // remove modules that others depend on + for _, module := range activeModules { + for _, depName := range module.dependencies { + delete(activeModules, depName) + } + } + + // make list out of map + for _, module := range activeModules { + readyToStop = append(readyToStop, module) + } + + return readyToStop, false +} + +// Shutdown stops all modules in the correct order. +func Shutdown() error { + + if startComplete.IsSet() { + log.Warning("modules: starting shutdown...") + modulesLock.Lock() + defer modulesLock.Unlock() + } else { + log.Warning("modules: aborting, shutting down...") + } + + if shutdownSignalClosed.SetToIf(false, true) { + close(shutdownSignal) + } + + reports := make(chan error, 0) + for { + readyToStop, done := checkStopStatus() + + if done { + log.Info("modules: shutdown complete") + return nil + } + + for _, module := range readyToStop { + module.starting = false + nextModule := module // workaround go vet alert + go func() { + err := nextModule.stop() + nextModule.Active.UnSet() + if err != nil { + reports <- fmt.Errorf("modules: could not stop module %s: %s", nextModule.Name, err) + } else { + reports <- nil + } + }() + } + + err := <-reports + if err != nil { + log.Error(err.Error()) + return err + } + + } +}