diff --git a/modules/modules.go b/modules/modules.go index ebae2d0..ef20171 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -57,10 +57,11 @@ type Module struct { //nolint:maligned // start startComplete chan struct{} // stop - Ctx context.Context - cancelCtx func() - stopFlag *abool.AtomicBool - stopComplete chan struct{} + Ctx context.Context + cancelCtx func() + stopFlag *abool.AtomicBool + stopCompleted *abool.AtomicBool + stopComplete chan struct{} // workers/tasks ctrlFuncRunning *abool.AtomicBool @@ -255,12 +256,10 @@ func (m *Module) checkIfStopComplete() { atomic.LoadInt32(m.taskCnt) == 0 && atomic.LoadInt32(m.microTaskCnt) == 0 { - m.Lock() - defer m.Unlock() - - if m.stopComplete != nil { + if m.stopCompleted.SetToIf(false, true) { + m.Lock() + defer m.Unlock() close(m.stopComplete) - m.stopComplete = nil } } } @@ -283,60 +282,56 @@ func (m *Module) stop(reports chan *report) { // Reset start/stop signal channels. m.startComplete = make(chan struct{}) m.stopComplete = make(chan struct{}) + m.stopCompleted.SetTo(false) - // Make a copy of the stop channel. - stopComplete := m.stopComplete - - // Set status and cancel context. + // Set status. m.status = StatusStopping - m.stopFlag.Set() - m.cancelCtx() - go m.stopAllTasks(reports, stopComplete) + go m.stopAllTasks(reports) } -func (m *Module) stopAllTasks(reports chan *report, stopComplete chan struct{}) { - // start shutdown function - var stopFnError error - stopFuncRunning := abool.New() - if m.stopFn != nil { - stopFuncRunning.Set() - go func() { - stopFnError = m.runCtrlFn("stop module", m.stopFn) - stopFuncRunning.UnSet() - m.checkIfStopComplete() - }() - } else { - m.checkIfStopComplete() - } +func (m *Module) stopAllTasks(reports chan *report) { + // Manually set the control function flag in order to stop completion by race + // condition before stop function has even started. + m.ctrlFuncRunning.Set() + + // Set stop flag for everyone checking this flag before we activate any stop trigger. + m.stopFlag.Set() + + // Cancel the context to notify all workers and tasks. + m.cancelCtx() + + // Start stop function. + stopFnError := m.startCtrlFn("stop module", m.stopFn) // wait for results select { - case <-stopComplete: - // case <-time.After(moduleStopTimeout): + case <-m.stopComplete: + // Complete! case <-time.After(moduleStopTimeout): log.Warningf( - "%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v/%v workers=%d tasks=%d microtasks=%d, continuing shutdown...", + "%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...", m.Name, - stopFuncRunning.IsSet(), m.ctrlFuncRunning.IsSet(), + m.ctrlFuncRunning.IsSet(), atomic.LoadInt32(m.workerCnt), atomic.LoadInt32(m.taskCnt), atomic.LoadInt32(m.microTaskCnt), ) } - // collect error + // Check for stop fn status. var err error - if stopFuncRunning.IsNotSet() && stopFnError != nil { - err = stopFnError - } - // set status - if err != nil { - m.Error( - fmt.Sprintf("%s:stop-failed", m.Name), - fmt.Sprintf("Stopping module %s failed", m.Name), - fmt.Sprintf("Failed to stop module: %s", err.Error()), - ) + select { + case err = <-stopFnError: + if err != nil { + // Set error as module error. + m.Error( + fmt.Sprintf("%s:stop-failed", m.Name), + fmt.Sprintf("Stopping module %s failed", m.Name), + fmt.Sprintf("Failed to stop module: %s", err.Error()), + ) + } + default: } // Always set to offline in order to let other modules shutdown in order. @@ -384,7 +379,7 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ... Name: name, enabled: abool.NewBool(false), enabledAsDependency: abool.NewBool(false), - sleepMode: abool.NewBool(true), + sleepMode: abool.NewBool(true), // Change (for init) is triggered below. sleepWaitingChannel: make(chan time.Time), prepFn: prep, startFn: start, @@ -393,6 +388,7 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ... Ctx: ctx, cancelCtx: cancelCtx, stopFlag: abool.NewBool(false), + stopCompleted: abool.NewBool(true), ctrlFuncRunning: abool.NewBool(false), workerCnt: &workerCnt, taskCnt: &taskCnt, @@ -401,7 +397,7 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ... depNames: dependencies, } - // Sleep mode is disabled by default + // Sleep mode is disabled by default. newModule.Sleep(false) return newModule diff --git a/modules/worker.go b/modules/worker.go index 249be06..61b01d3 100644 --- a/modules/worker.go +++ b/modules/worker.go @@ -135,10 +135,7 @@ func (m *Module) runWorker(name string, fn func(context.Context) error) (err err } func (m *Module) runCtrlFnWithTimeout(name string, timeout time.Duration, fn func() error) error { - stopFnError := make(chan error) - go func() { - stopFnError <- m.runCtrlFn(name, fn) - }() + stopFnError := m.startCtrlFn(name, fn) // wait for results select { @@ -149,26 +146,44 @@ func (m *Module) runCtrlFnWithTimeout(name string, timeout time.Duration, fn fun } } -func (m *Module) runCtrlFn(name string, fn func() error) (err error) { +func (m *Module) startCtrlFn(name string, fn func() error) chan error { + ctrlFnError := make(chan error, 1) + + // If no function is given, still act as if it was run. if fn == nil { - return + // Signal finish. + m.ctrlFuncRunning.UnSet() + m.checkIfStopComplete() + + // Report nil error and return. + ctrlFnError <- nil + return ctrlFnError } - if m.ctrlFuncRunning.SetToIf(false, true) { - defer m.ctrlFuncRunning.SetToIf(true, false) - } + // Signal that a control function is running. + m.ctrlFuncRunning.Set() - defer func() { - // recover from panic - panicVal := recover() - if panicVal != nil { - me := m.NewPanicError(name, "module-control", panicVal) - me.Report() - err = me - } + // Start control function in goroutine. + go func() { + // Recover from panic and reset control function signal. + defer func() { + // recover from panic + panicVal := recover() + if panicVal != nil { + me := m.NewPanicError(name, "module-control", panicVal) + me.Report() + ctrlFnError <- fmt.Errorf("panic: %s", panicVal) + } + + // Signal finish. + m.ctrlFuncRunning.UnSet() + m.checkIfStopComplete() + }() + + // Run control function and report error. + err := fn() + ctrlFnError <- err }() - // run - err = fn() - return + return ctrlFnError }