Finish modules/tasks revamp

Daniel, 2019-09-12 09:37:08 +02:00
parent 4e99dd2153
commit 71dabc1f23
11 changed files with 542 additions and 296 deletions

modules/doc.go (new file, 18 lines)

@@ -0,0 +1,18 @@
/*
Package modules provides a full module and task management ecosystem that cleanly ties together all the big and small moving parts of a service.

Modules are started in a multi-stage process and may depend on other modules:
- Go's init(): register flags
- prep: check flags, register config variables
- start: start actual work, access config
- stop: gracefully shut down

Workers: simple functions run by the module while catching panics and reporting them. Ideal for long-running (and possibly mostly idle) goroutines. Workers can be automatically restarted if execution ends with an error.

Tasks: functions that take somewhere between a couple of seconds and a couple of minutes to execute and should be queued, scheduled, or repeated.

MicroTasks: functions that take less than a second to execute but require a lot of resources. Running such functions as MicroTasks limits how many of them run concurrently and should improve performance.

Ideally, _any_ execution by a module goes through one of these mechanisms. This not only ensures that all panics are caught, but also gives better insight into how your service performs.
*/
package modules
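
A minimal sketch of the staged lifecycle from a consumer's perspective, based on the APIs in this commit (the import path is assumed to be github.com/safing/portbase/modules; the module name, dependency and function bodies are illustrative):

package example

import (
	"flag"

	"github.com/safing/portbase/modules"
)

// registered during Go's init() stage
var dataDir = flag.String("data", "", "data directory")

// register the module and its dependencies; modules.Start() later runs prep/start in dependency order
var module = modules.Register("example", prep, start, stop, "database")

func prep() error {
	// check flags, register config variables
	return nil
}

func start() error {
	// start actual work: workers, tasks and microtasks are launched from here
	return nil
}

func stop() error {
	// gracefully shut down
	return nil
}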

modules/error.go (new file, 90 lines)

@@ -0,0 +1,90 @@
package modules
import (
"fmt"
"runtime/debug"
)
var (
errorReportingChannel chan *ModuleError
)
// ModuleError wraps a panic, error or message into an error that can be reported.
type ModuleError struct {
Message string
ModuleName string
TaskName string
TaskType string // one of "worker", "task", "microtask" or custom
Severity string // one of "info", "error", "panic" or custom
PanicValue interface{}
StackTrace string
}
// NewInfoMessage creates a new, reportable, info message (including a stack trace).
func (m *Module) NewInfoMessage(message string) *ModuleError {
return &ModuleError{
Message: message,
ModuleName: m.Name,
Severity: "info",
StackTrace: string(debug.Stack()),
}
}
// NewErrorMessage creates a new, reportable, error message (including a stack trace).
func (m *Module) NewErrorMessage(taskName string, err error) *ModuleError {
return &ModuleError{
Message: err.Error(),
ModuleName: m.Name,
Severity: "error",
StackTrace: string(debug.Stack()),
}
}
// NewPanicError creates a new, reportable, panic error message (including a stack trace).
func (m *Module) NewPanicError(taskName, taskType string, panicValue interface{}) *ModuleError {
me := &ModuleError{
Message: fmt.Sprintf("panic: %s", panicValue),
ModuleName: m.Name,
TaskName: taskName,
TaskType: taskType,
Severity: "panic",
PanicValue: panicValue,
StackTrace: string(debug.Stack()),
}
me.Message = me.Error()
return me
}
// Error returns the string representation of the error.
func (me *ModuleError) Error() string {
return me.Message
}
// Report reports the error through the configured reporting channel.
func (me *ModuleError) Report() {
if errorReportingChannel != nil {
select {
case errorReportingChannel <- me:
default:
}
}
}
// IsPanic returns whether the given error is a panic that was wrapped by the modules package, and returns the wrapped *ModuleError if it is.
func IsPanic(err error) (bool, *ModuleError) {
switch val := err.(type) {
case *ModuleError:
return true, val
default:
return false, nil
}
}
// SetErrorReportingChannel sets the channel to report module errors through. By default, only panics are reported; all other errors need to be manually wrapped into a *ModuleError and reported.
func SetErrorReportingChannel(reportingChannel chan *ModuleError) {
if errorReportingChannel == nil {
errorReportingChannel = reportingChannel
}
}
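
The reporting channel is consumed by whoever calls SetErrorReportingChannel. A minimal sketch of such a consumer (channel size, function name and log format are illustrative); note that Report() uses a non-blocking send and drops the error if the channel is full, so keep the consumer fast or the buffer generous:

package example

import (
	"log"

	"github.com/safing/portbase/modules"
)

func setupErrorReporting() {
	// buffered so that Report()'s non-blocking send rarely drops errors
	reports := make(chan *modules.ModuleError, 64)
	modules.SetErrorReportingChannel(reports)

	go func() {
		for me := range reports {
			// forward to logging, metrics or a crash reporter
			log.Printf("[%s] module=%s task=%s type=%s: %s", me.Severity, me.ModuleName, me.TaskName, me.TaskType, me.Message)
		}
	}()
}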

(changed file)

@@ -1,9 +1,11 @@
package modules

import (
+"context"
"sync/atomic"
"time"

+"github.com/safing/portbase/log"
"github.com/tevino/abool"
)

@@ -12,162 +14,140 @@ import (
// (2) sometimes there seems to some kind of race condition stuff, the test hangs and does not complete

var (
-closedChannel chan bool
-tasks *int32
-mediumPriorityClearance chan bool
-lowPriorityClearance chan bool
-veryLowPriorityClearance chan bool
-tasksDone chan bool
-tasksDoneFlag *abool.AtomicBool
-tasksWaiting chan bool
-tasksWaitingFlag *abool.AtomicBool
+microTasks *int32
+microTasksThreshhold *int32
+microTaskFinished = make(chan struct{}, 1)
+mediumPriorityClearance = make(chan struct{})
+lowPriorityClearance = make(chan struct{})
)

-// StartMicroTask starts a new MicroTask. It will start immediately.
-func StartMicroTask() {
-atomic.AddInt32(tasks, 1)
-tasksDoneFlag.UnSet()
-}
-
-// EndMicroTask MUST be always called when a MicroTask was previously started.
-func EndMicroTask() {
-c := atomic.AddInt32(tasks, -1)
-if c < 1 {
-if tasksDoneFlag.SetToIf(false, true) {
-tasksDone <- true
-}
-}
-}
-
-func newTaskIsWaiting() {
-tasksWaiting <- true
-}
-
-// StartMediumPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
-func StartMediumPriorityMicroTask() chan bool {
-if shutdownSignalClosed.IsSet() {
-return closedChannel
-}
-if tasksWaitingFlag.SetToIf(false, true) {
-defer newTaskIsWaiting()
-}
-return mediumPriorityClearance
-}
-
-// StartLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
-func StartLowPriorityMicroTask() chan bool {
-if shutdownSignalClosed.IsSet() {
-return closedChannel
-}
-if tasksWaitingFlag.SetToIf(false, true) {
-defer newTaskIsWaiting()
-}
-return lowPriorityClearance
-}
-
-// StartVeryLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
-func StartVeryLowPriorityMicroTask() chan bool {
-if shutdownSignalClosed.IsSet() {
-return closedChannel
-}
-if tasksWaitingFlag.SetToIf(false, true) {
-defer newTaskIsWaiting()
-}
-return veryLowPriorityClearance
-}
-
-func start() error {
-return nil
-}
-
-func stop() error {
-close(shutdownSignal)
-return nil
-}
+const (
+mediumPriorityMaxDelay = 1 * time.Second
+lowPriorityMaxDelay = 3 * time.Second
+)

func init() {
+var microTasksVal int32
+microTasks = &microTasksVal
+var microTasksThreshholdVal int32
+microTasksThreshhold = &microTasksThreshholdVal
+}

-closedChannel = make(chan bool, 0)
-close(closedChannel)
+// SetMaxConcurrentMicroTasks sets the maximum number of microtasks that should be run concurrently.
+func SetMaxConcurrentMicroTasks(n int) {
+if n < 4 {
+atomic.StoreInt32(microTasksThreshhold, 4)
+} else {
+atomic.StoreInt32(microTasksThreshhold, int32(n))
+}
+}

-var t int32
-tasks = &t
+// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
+func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) error {
+atomic.AddInt32(microTasks, 1)
+return m.runMicroTask(name, fn)
+}

-mediumPriorityClearance = make(chan bool, 0)
-lowPriorityClearance = make(chan bool, 0)
-veryLowPriorityClearance = make(chan bool, 0)
+// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. It will wait until given a go (max 3 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
+func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) error {
+// check if we can go immediately
+select {
+case <-mediumPriorityClearance:
+default:
+// wait for go or max delay
+select {
+case <-mediumPriorityClearance:
+case <-time.After(mediumPriorityMaxDelay):
+}
+}
+return m.runMicroTask(name, fn)
+}

-tasksDone = make(chan bool, 1)
-tasksDoneFlag = abool.NewBool(true)
-tasksWaiting = make(chan bool, 1)
-tasksWaitingFlag = abool.NewBool(false)
+// StartLowPriorityMicroTask starts a new MicroTask with low priority. It will wait until given a go (max 15 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
+func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) error {
+// check if we can go immediately
+select {
+case <-lowPriorityClearance:
+default:
+// wait for go or max delay
+select {
+case <-lowPriorityClearance:
+case <-time.After(lowPriorityMaxDelay):
+}
+}
+return m.runMicroTask(name, fn)
+}

-timoutTimerDuration := 1 * time.Second
-// timoutTimer := time.NewTimer(timoutTimerDuration)
+func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err error) {
+// start for module
+// hint: only microTasks global var is important for scheduling, others can be set here
+atomic.AddInt32(m.microTaskCnt, 1)
+m.waitGroup.Add(1)

-go func() {
-microTaskManageLoop:
-for {
+// set up recovery
+defer func() {
+// recover from panic
+panicVal := recover()
+if panicVal != nil {
+me := m.NewPanicError(*name, "microtask", panicVal)
+me.Report()
+log.Errorf("%s: microtask %s panicked: %s\n%s", m.Name, *name, panicVal, me.StackTrace)
+err = me
+}

-// wait for an event to start new tasks
-if !shutdownSignalClosed.IsSet() {
+// finish for module
+atomic.AddInt32(m.microTaskCnt, -1)
+m.waitGroup.Done()

-// reset timer
-// https://golang.org/pkg/time/#Timer.Reset
-// if !timoutTimer.Stop() {
-// <-timoutTimer.C
-// }
-// timoutTimer.Reset(timoutTimerDuration)
-
-// wait for event to start a new task
-select {
-case <-tasksWaiting:
-if !tasksDoneFlag.IsSet() {
-continue microTaskManageLoop
-}
-case <-time.After(timoutTimerDuration):
-case <-tasksDone:
-case <-shutdownSignal:
-}
-} else {
-// execute tasks until no tasks are waiting anymore
-if !tasksWaitingFlag.IsSet() {
-// wait until tasks are finished
-if !tasksDoneFlag.IsSet() {
-<-tasksDone
-}
-// signal module completion
-// microTasksModule.StopComplete()
-// exit
-return
-}
-}
-
-// start new task, if none is started, check if we are shutting down
-select {
-case mediumPriorityClearance <- true:
-StartMicroTask()
-default:
-select {
-case lowPriorityClearance <- true:
-StartMicroTask()
-default:
-select {
-case veryLowPriorityClearance <- true:
-StartMicroTask()
-default:
-tasksWaitingFlag.UnSet()
-}
-}
-}
+// finish and possibly trigger next task
+atomic.AddInt32(microTasks, -1)
+select {
+case microTaskFinished <- struct{}{}:
+default:
}
}()

+// run
+err = fn(m.Ctx)
+return //nolint:nakedret // need to use named return val in order to change in defer
+}

+var (
+microTaskSchedulerStarted = abool.NewBool(false)
+)

+func microTaskScheduler() {
+// only ever start once
+if !microTaskSchedulerStarted.SetToIf(false, true) {
+return
+}

+microTaskManageLoop:
+for {
+if shutdownSignalClosed.IsSet() {
+close(mediumPriorityClearance)
+close(lowPriorityClearance)
+return
+}

+if atomic.LoadInt32(microTasks) < atomic.LoadInt32(microTasksThreshhold) { // space left for firing task
+select {
+case mediumPriorityClearance <- struct{}{}:
+default:
+select {
+case taskTimeslot <- struct{}{}:
+continue microTaskManageLoop
+case mediumPriorityClearance <- struct{}{}:
+case lowPriorityClearance <- struct{}{}:
+}
+}
+// increase task counter
+atomic.AddInt32(microTasks, 1)
+} else {
+// wait for signal that a task was completed
+<-microTaskFinished
+}
+}
}
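
Caller-side, the new microtask API looks roughly like this (a hedged sketch; the module setup is assumed and the task name is illustrative). Note that the name is passed as a pointer and retained, so it should live in a stable variable:

package example

import (
	"context"

	"github.com/safing/portbase/modules"
)

// the name pointer is kept by the microtask, so never change this variable
var cleanupTaskName = "cache cleanup"

func cleanCache(m *modules.Module) error {
	// waits up to lowPriorityMaxDelay for a clearance slot, then runs the function
	return m.StartLowPriorityMicroTask(&cleanupTaskName, func(ctx context.Context) error {
		// short, resource-heavy work goes here
		return nil
	})
}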

(changed file)

@@ -1,60 +1,97 @@
package modules

import (
+"context"
"strings"
"sync"
+"sync/atomic"
"testing"
"time"
)

+var (
+mtTestName = "microtask test"
+mtModule = initNewModule("microtask test module", nil, nil, nil)
+)

+func init() {
+go microTaskScheduler()
+}

// test waiting
func TestMicroTaskWaiting(t *testing.T) {
// skip
if testing.Short() {
-t.Skip("skipping test in short mode.")
+t.Skip("skipping test in short mode, as it is not fully deterministic")
}

// init
mtwWaitGroup := new(sync.WaitGroup)
mtwOutputChannel := make(chan string, 100)
-mtwExpectedOutput := "123456"
+mtwExpectedOutput := "1234567"
mtwSleepDuration := 10 * time.Millisecond

// TEST
-mtwWaitGroup.Add(3)
+mtwWaitGroup.Add(4)

+// ensure we only execute one microtask at once
+atomic.StoreInt32(microTasksThreshhold, 1)

// High Priority - slot 1-5
go func() {
defer mtwWaitGroup.Done()
-StartMicroTask()
-mtwOutputChannel <- "1"
-time.Sleep(mtwSleepDuration * 5)
-mtwOutputChannel <- "2"
-EndMicroTask()
+// exec at slot 1
+_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+mtwOutputChannel <- "1" // slot 1
+time.Sleep(mtwSleepDuration * 5)
+mtwOutputChannel <- "2" // slot 5
+return nil
+})
}()
-time.Sleep(mtwSleepDuration * 2)
+time.Sleep(mtwSleepDuration * 1)

+// clear clearances
+_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+return nil
+})

+// Low Priority - slot 16
+go func() {
+defer mtwWaitGroup.Done()
+// exec at slot 2
+_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+mtwOutputChannel <- "7" // slot 16
+return nil
+})
+}()
+time.Sleep(mtwSleepDuration * 1)

// High Priority - slot 10-15
go func() {
defer mtwWaitGroup.Done()
time.Sleep(mtwSleepDuration * 8)
-StartMicroTask()
-mtwOutputChannel <- "4"
-time.Sleep(mtwSleepDuration * 5)
-mtwOutputChannel <- "6"
-EndMicroTask()
+// exec at slot 10
+_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+mtwOutputChannel <- "4" // slot 10
+time.Sleep(mtwSleepDuration * 5)
+mtwOutputChannel <- "6" // slot 15
+return nil
+})
}()

-// Medium Priority - Waits at slot 3, should execute in slot 6-13
+// Medium Priority - slot 6-13
go func() {
defer mtwWaitGroup.Done()
-<-StartMediumPriorityMicroTask()
-mtwOutputChannel <- "3"
-time.Sleep(mtwSleepDuration * 7)
-mtwOutputChannel <- "5"
-EndMicroTask()
+// exec at slot 3
+_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+mtwOutputChannel <- "3" // slot 6
+time.Sleep(mtwSleepDuration * 7)
+mtwOutputChannel <- "5" // slot 13
+return nil
+})
}()

// wait for test to finish

@@ -67,6 +104,7 @@ func TestMicroTaskWaiting(t *testing.T) {
completeOutput += s
}
// check if test succeeded
+t.Logf("microTask wait order: %s", completeOutput)
if completeOutput != mtwExpectedOutput {
t.Errorf("MicroTask waiting test failed, expected sequence %s, got %s", mtwExpectedOutput, completeOutput)
}
@@ -78,34 +116,27 @@
// globals
var mtoWaitGroup sync.WaitGroup
var mtoOutputChannel chan string
-var mtoWaitCh chan bool
+var mtoWaitCh chan struct{}

// functions
func mediumPrioTaskTester() {
defer mtoWaitGroup.Done()
<-mtoWaitCh
-<-StartMediumPriorityMicroTask()
+_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtoOutputChannel <- "1"
time.Sleep(2 * time.Millisecond)
-EndMicroTask()
+return nil
+})
}

func lowPrioTaskTester() {
defer mtoWaitGroup.Done()
<-mtoWaitCh
-<-StartLowPriorityMicroTask()
+_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtoOutputChannel <- "2"
time.Sleep(2 * time.Millisecond)
-EndMicroTask()
-}
-
-func veryLowPrioTaskTester() {
-defer mtoWaitGroup.Done()
-<-mtoWaitCh
-<-StartVeryLowPriorityMicroTask()
-mtoOutputChannel <- "3"
-time.Sleep(2 * time.Millisecond)
-EndMicroTask()
+return nil
+})
}

// test

@@ -113,53 +144,51 @@
// skip
if testing.Short() {
-t.Skip("skipping test in short mode.")
+t.Skip("skipping test in short mode, as it is not fully deterministic")
}

// init
mtoOutputChannel = make(chan string, 100)
-mtoWaitCh = make(chan bool, 0)
+mtoWaitCh = make(chan struct{})

// TEST
-mtoWaitGroup.Add(30)
+mtoWaitGroup.Add(20)

+// ensure we only execute one microtask at once
+atomic.StoreInt32(microTasksThreshhold, 1)

// kick off
go mediumPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go lowPrioTaskTester()
-go veryLowPrioTaskTester()
-go veryLowPrioTaskTester()
-go lowPrioTaskTester()
-go veryLowPrioTaskTester()
-go mediumPrioTaskTester()
-go veryLowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
-go veryLowPrioTaskTester()
+go lowPrioTaskTester()
+go mediumPrioTaskTester()
go mediumPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
-go veryLowPrioTaskTester()
-go veryLowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
-go veryLowPrioTaskTester()
go lowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
-go veryLowPrioTaskTester()
go lowPrioTaskTester()
-go veryLowPrioTaskTester()

// wait for all goroutines to be ready
time.Sleep(10 * time.Millisecond)

// sync all goroutines
close(mtoWaitCh)
+// trigger
+select {
+case microTaskFinished <- struct{}{}:
+default:
+}

// wait for test to finish
mtoWaitGroup.Wait()

@@ -171,7 +200,8 @@ func TestMicroTaskOrdering(t *testing.T) {
completeOutput += s
}
// check if test succeeded
-if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") || !strings.Contains(completeOutput, "33333") {
+t.Logf("microTask exec order: %s", completeOutput)
+if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") {
t.Errorf("MicroTask ordering test failed, output was %s. This happens occasionally, please run the test multiple times to verify", completeOutput)
}

(changed file)

@@ -39,8 +39,12 @@ type Module struct {
Ctx context.Context
cancelCtx func()
shutdownFlag *abool.AtomicBool
-workerGroup sync.WaitGroup

+// workers/tasks
workerCnt *int32
+taskCnt *int32
+microTaskCnt *int32
+waitGroup sync.WaitGroup

// dependency mgmt
depNames []string

@@ -48,27 +52,6 @@ type Module struct {
depReverse []*Module
}

-// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
-func (m *Module) AddWorkers(n uint) {
-if !m.ShutdownInProgress() {
-if atomic.AddInt32(m.workerCnt, int32(n)) > 0 {
-// only add to workgroup if cnt is positive (try to compensate wrong usage)
-m.workerGroup.Add(int(n))
-}
-}
-}
-
-// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
-func (m *Module) FinishWorker() {
-// check worker cnt
-if atomic.AddInt32(m.workerCnt, -1) < 0 {
-log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name)
-return
-}
-// also mark worker done in workgroup
-m.workerGroup.Done()
-}

// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
return m.shutdownFlag.IsSet()

@@ -87,13 +70,19 @@ func (m *Module) shutdown() error {
// wait for workers
done := make(chan struct{})
go func() {
-m.workerGroup.Wait()
+m.waitGroup.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
-return errors.New("timed out while waiting for module workers to finish")
+log.Warningf(
+"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
+m.Name,
+atomic.LoadInt32(m.workerCnt),
+atomic.LoadInt32(m.taskCnt),
+atomic.LoadInt32(m.microTaskCnt),
+)
}

// call shutdown function

@@ -106,8 +95,19 @@ func dummyAction() error {
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
+newModule := initNewModule(name, prep, start, stop, dependencies...)

+modulesLock.Lock()
+defer modulesLock.Unlock()
+modules[name] = newModule

+return newModule
+}

+func initNewModule(name string, prep, start, stop func() error, dependencies ...string) *Module {
ctx, cancelCtx := context.WithCancel(context.Background())
var workerCnt int32
+var taskCnt int32
+var microTaskCnt int32

newModule := &Module{
Name: name,

@@ -118,8 +118,10 @@ func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
Ctx: ctx,
cancelCtx: cancelCtx,
shutdownFlag: abool.NewBool(false),
-workerGroup: sync.WaitGroup{},
+waitGroup: sync.WaitGroup{},
workerCnt: &workerCnt,
+taskCnt: &taskCnt,
+microTaskCnt: &microTaskCnt,
prep: prep,
start: start,
stop: stop,

@@ -137,9 +139,6 @@ func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
newModule.stop = dummyAction
}

-modulesLock.Lock()
-defer modulesLock.Unlock()
-modules[name] = newModule

return newModule
}

(changed file)

@@ -14,28 +14,28 @@
shutdownOrder string
)

-func testPrep(name string) func() error {
+func testPrep(t *testing.T, name string) func() error {
return func() error {
-// fmt.Printf("prep %s\n", name)
+t.Logf("prep %s\n", name)
return nil
}
}

-func testStart(name string) func() error {
+func testStart(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
-// fmt.Printf("start %s\n", name)
+t.Logf("start %s\n", name)
startOrder = fmt.Sprintf("%s>%s", startOrder, name)
return nil
}
}

-func testStop(name string) func() error {
+func testStop(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
-// fmt.Printf("stop %s\n", name)
+t.Logf("stop %s\n", name)
shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
return nil
}

@@ -49,12 +49,19 @@ func testCleanExit() error {
return ErrCleanExit
}

-func TestOrdering(t *testing.T) {
-Register("database", testPrep("database"), testStart("database"), testStop("database"))
-Register("stats", testPrep("stats"), testStart("stats"), testStop("stats"), "database")
-Register("service", testPrep("service"), testStart("service"), testStop("service"), "database")
-Register("analytics", testPrep("analytics"), testStart("analytics"), testStop("analytics"), "stats", "database")
+func TestModules(t *testing.T) {
+t.Parallel() // Not really, just a workaround for running these tests last.
+t.Run("TestModuleOrder", testModuleOrder)
+t.Run("TestModuleErrors", testModuleErrors)
+}

+func testModuleOrder(t *testing.T) {
+Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database"))
+Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database")
+Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database")
+Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database")

err := Start()
if err != nil {

@@ -105,19 +112,7 @@ func printAndRemoveModules() {
modules = make(map[string]*Module)
}

-func resetModules() {
-for _, module := range modules {
-module.Prepped.UnSet()
-module.Started.UnSet()
-module.Stopped.UnSet()
-module.inTransition.UnSet()
-module.depModules = make([]*Module, 0)
-module.depModules = make([]*Module, 0)
-}
-}
-
-func TestErrors(t *testing.T) {
+func testModuleErrors(t *testing.T) {
// reset modules
modules = make(map[string]*Module)

@@ -125,7 +120,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})

// test prep error
-Register("prepfail", testFail, testStart("prepfail"), testStop("prepfail"))
+Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail"))
err := Start()
if err == nil {
t.Error("should fail")

@@ -137,7 +132,7 @@
startCompleteSignal = make(chan struct{})

// test prep clean exit
-Register("prepcleanexit", testCleanExit, testStart("prepcleanexit"), testStop("prepcleanexit"))
+Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit"))
err = Start()
if err != ErrCleanExit {
t.Error("should fail with clean exit")

@@ -149,7 +144,7 @@
startCompleteSignal = make(chan struct{})

// test invalid dependency
-Register("database", nil, testStart("database"), testStop("database"), "invalid")
+Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid")
err = Start()
if err == nil {
t.Error("should fail")

@@ -161,8 +156,8 @@
startCompleteSignal = make(chan struct{})

// test dependency loop
-Register("database", nil, testStart("database"), testStop("database"), "helper")
-Register("helper", nil, testStart("helper"), testStop("helper"), "database")
+Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper")
+Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database")
err = Start()
if err == nil {
t.Error("should fail")

@@ -174,7 +169,7 @@
startCompleteSignal = make(chan struct{})

// test failing module start
-Register("startfail", nil, testFail, testStop("startfail"))
+Register("startfail", nil, testFail, testStop(t, "startfail"))
err = Start()
if err == nil {
t.Error("should fail")

@@ -186,7 +181,7 @@
startCompleteSignal = make(chan struct{})

// test failing module stop
-Register("stopfail", nil, testStart("stopfail"), testFail)
+Register("stopfail", nil, testStart(t, "stopfail"), testFail)
err = Start()
if err != nil {
t.Error("should not fail")
(changed file)

@@ -3,6 +3,7 @@ package modules
import (
"fmt"
"os"
+"runtime"

"github.com/safing/portbase/log"
"github.com/tevino/abool"

@@ -13,13 +14,6 @@
startCompleteSignal = make(chan struct{})
)

-// markStartComplete marks the startup as completed.
-func markStartComplete() {
-if startComplete.SetToIf(false, true) {
-close(startCompleteSignal)
-}
-}

// StartCompleted returns whether starting has completed.
func StartCompleted() bool {
return startComplete.IsSet()

@@ -35,6 +29,10 @@ func Start() error {
modulesLock.Lock()
defer modulesLock.Unlock()

+// start microtask scheduler
+go microTaskScheduler()
+SetMaxConcurrentMicroTasks(runtime.GOMAXPROCS(0) * 2)

// inter-link modules
err := initDependencies()
if err != nil {

(changed file)

@@ -4,6 +4,7 @@ import (
"container/list"
"context"
"sync"
+"sync/atomic"
"time"

"github.com/tevino/abool"

@@ -15,7 +16,7 @@
type Task struct {
name string
module *Module
-taskFn TaskFn
+taskFn func(context.Context, *Task)
queued bool
canceled bool

@@ -33,9 +34,6 @@ type Task struct {
lock sync.Mutex
}

-// TaskFn is the function signature for creating Tasks.
-type TaskFn func(ctx context.Context, task *Task)

var (
taskQueue = list.New()
prioritizedTaskQueue = list.New()

@@ -49,20 +47,22 @@ var (
queueIsFilled = make(chan struct{}, 1) // kick off queue handler
recalculateNextScheduledTask = make(chan struct{}, 1)
+taskTimeslot = make(chan struct{})
)

const (
+maxTimeslotWait = 30 * time.Second
+minRepeatDuration = 1 * time.Minute
maxExecutionWait = 1 * time.Minute
-defaultMaxDelay = 5 * time.Minute
-minRepeatDuration = 1 * time.Second
+defaultMaxDelay = 1 * time.Minute
)

// NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
-func (m *Module) NewTask(name string, taskFn TaskFn) *Task {
+func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
return &Task{
name: name,
module: m,
-taskFn: taskFn,
+taskFn: fn,
maxDelay: defaultMaxDelay,
}
}

@@ -168,7 +168,7 @@ func (t *Task) Schedule(executeAt time.Time) *Task {
return t
}

-// Repeat sets the task to be executed in endless repeat at the specified interval. First execution will be after interval. Minimum repeat interval is one second.
+// Repeat sets the task to be executed in endless repeat at the specified interval. First execution will be after interval. Minimum repeat interval is one minute.
func (t *Task) Repeat(interval time.Duration) *Task {
// check minimum interval duration
if interval < minRepeatDuration {

@@ -194,6 +194,12 @@ func (t *Task) Cancel() {
}

func (t *Task) runWithLocking() {
+// wait for good timeslot regarding microtasks
+select {
+case <-taskTimeslot:
+case <-time.After(maxTimeslotWait):
+}

t.lock.Lock()

// check state, return if already executing or inactive

@@ -240,8 +246,6 @@ func (t *Task) runWithLocking() {
t.lock.Unlock()
}

-// add to module workers
-t.module.AddWorkers(1)

// add to queue workgroup
queueWg.Add(1)

@@ -257,15 +261,23 @@
}

func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
+// start for module
+// hint: only queueWg global var is important for scheduling, others can be set here
+atomic.AddInt32(t.module.taskCnt, 1)
+t.module.waitGroup.Add(1)
defer func() {
-// log result if error
+// recover from panic
panicVal := recover()
if panicVal != nil {
-log.Errorf("%s: task %s panicked: %s", t.module.Name, t.name, panicVal)
+me := t.module.NewPanicError(t.name, "task", panicVal)
+me.Report()
+log.Errorf("%s: task %s panicked: %s\n%s", t.module.Name, t.name, panicVal, me.StackTrace)
}

-// mark task as completed
-t.module.FinishWorker()
+// finish for module
+atomic.AddInt32(t.module.taskCnt, -1)
+t.module.waitGroup.Done()

// reset
t.lock.Lock()

@@ -282,6 +294,8 @@ func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
// notify that we finished
cancelFunc()
}()

+// run
t.taskFn(ctx, t)
}

@@ -335,7 +349,7 @@ func waitUntilNextScheduledTask() <-chan time.Time {
defer scheduleLock.Unlock()

if taskSchedule.Len() > 0 {
-return time.After(taskSchedule.Front().Value.(*Task).executeAt.Sub(time.Now()))
+return time.After(time.Until(taskSchedule.Front().Value.(*Task).executeAt))
}
return waitForever
}

@@ -346,6 +360,7 @@ var (
)

func taskQueueHandler() {
+// only ever start once
if !taskQueueHandlerStarted.SetToIf(false, true) {
return
}

@@ -396,6 +411,7 @@
}

func taskScheduleHandler() {
+// only ever start once
if !taskScheduleHandlerStarted.SetToIf(false, true) {
return
}
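
For callers, the reshaped Task API looks roughly like this (a hedged sketch; the module parameter, task name and interval are illustrative):

package example

import (
	"context"
	"time"

	"github.com/safing/portbase/modules"
)

func scheduleSync(m *modules.Module) {
	// the task function now takes the context and the *Task itself
	task := m.NewTask("hourly sync", func(ctx context.Context, t *modules.Task) {
		// do the work; the task may reschedule itself via t.Schedule() or t.Queue()
	})
	// repeat interval must now be at least one minute
	task.Repeat(1 * time.Hour)
}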

(changed file)

@@ -8,8 +8,6 @@ import (
"sync"
"testing"
"time"

-"github.com/tevino/abool"
)

func init() {

@@ -19,9 +17,16 @@ func init() {
go func() {
<-time.After(10 * time.Second)
fmt.Fprintln(os.Stderr, "taking too long")
-pprof.Lookup("goroutine").WriteTo(os.Stderr, 1)
+_ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2)
os.Exit(1)
}()

+// always trigger task timeslot for testing
+go func() {
+for {
+taskTimeslot <- struct{}{}
+}
+}()
}

// test waiting

@@ -30,18 +35,7 @@ func init() {
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration
-var qtWorkerCnt int32
-var qtModule = &Module{
-Name: "task test module",
-Prepped: abool.NewBool(false),
-Started: abool.NewBool(false),
-Stopped: abool.NewBool(false),
-inTransition: abool.NewBool(false),
-Ctx: context.Background(),
-shutdownFlag: abool.NewBool(false),
-workerGroup: sync.WaitGroup{},
-workerCnt: &qtWorkerCnt,
-}
+var qtModule = initNewModule("task test module", nil, nil, nil)

// functions
func queuedTaskTester(s string) {

@@ -64,7 +58,7 @@ func prioritizedTaskTester(s string) {
func TestQueuedTask(t *testing.T) {
// skip
if testing.Short() {
-t.Skip("skipping test in short mode.")
+t.Skip("skipping test in short mode, as it is not fully deterministic")
}

// init

@@ -127,14 +121,14 @@ func TestScheduledTaskWaiting(t *testing.T) {
// skip
if testing.Short() {
-t.Skip("skipping test in short mode.")
+t.Skip("skipping test in short mode, as it is not fully deterministic")
}

// init
expectedOutput := "0123456789"
stSleepDuration = 10 * time.Millisecond
stOutputChannel = make(chan string, 100)
-stWaitCh = make(chan bool, 0)
+stWaitCh = make(chan bool)
stWg.Add(10)

modules/worker.go (new file, 63 lines)

@@ -0,0 +1,63 @@
package modules
import (
"context"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
)
var (
serviceBackoffDuration = 2 * time.Second
)
// StartWorker starts a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. You may declare a worker as a service, which will then be automatically restarted in case of an error.
func (m *Module) StartWorker(name string, service bool, fn func(context.Context) error) error {
atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1)
defer func() {
atomic.AddInt32(m.workerCnt, -1)
m.waitGroup.Done()
}()
failCnt := 0
if service {
for {
if m.ShutdownInProgress() {
return nil
}
err := m.runWorker(name, fn)
if err != nil {
// log error and restart
failCnt++
sleepFor := time.Duration(failCnt) * serviceBackoffDuration
log.Errorf("module %s service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor)
time.Sleep(sleepFor)
} else {
// clean finish
return nil
}
}
} else {
return m.runWorker(name, fn)
}
}
func (m *Module) runWorker(name string, fn func(context.Context) error) (err error) {
defer func() {
// recover from panic
panicVal := recover()
if panicVal != nil {
me := m.NewPanicError(name, "worker", panicVal)
me.Report()
err = me
}
}()
// run
err = fn(m.Ctx)
return
}
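
Note that StartWorker runs the given function in the calling goroutine and only returns once the worker is done, so a module that wants a background service typically launches it itself. A hedged sketch (function name, worker name and body are illustrative):

package example

import (
	"context"

	"github.com/safing/portbase/log"
	"github.com/safing/portbase/modules"
)

func startSessionKeeper(m *modules.Module) {
	go func() {
		// service=true: restarted with increasing backoff if it returns an error
		err := m.StartWorker("session keeper", true, func(ctx context.Context) error {
			<-ctx.Done() // long-running and mostly idle; stop when the module context ends
			return nil
		})
		if err != nil {
			log.Errorf("session keeper stopped: %s", err)
		}
	}()
}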

modules/worker_test.go (new file, 63 lines)

@@ -0,0 +1,63 @@
package modules
import (
"context"
"errors"
"testing"
"time"
)
var (
wModule = initNewModule("worker test module", nil, nil, nil)
errTest = errors.New("test error")
)
func TestWorker(t *testing.T) {
// test basic functionality
err := wModule.StartWorker("test worker", false, func(ctx context.Context) error {
return nil
})
if err != nil {
t.Errorf("worker failed (should not): %s", err)
}
// test returning an error
err = wModule.StartWorker("test worker", false, func(ctx context.Context) error {
return errTest
})
if err != errTest {
t.Errorf("worker failed with unexpected error: %s", err)
}
// test service functionality
serviceBackoffDuration = 2 * time.Millisecond // speed up backoff
failCnt := 0
err = wModule.StartWorker("test worker", true, func(ctx context.Context) error {
failCnt++
t.Logf("service-worker test run #%d", failCnt)
if failCnt >= 3 {
return nil
}
return errTest
})
if err == errTest {
t.Errorf("service-worker failed with unexpected error: %s", err)
}
if failCnt != 3 {
t.Errorf("service-worker failed to restart")
}
// test panic recovery
err = wModule.StartWorker("test worker", false, func(ctx context.Context) error {
var a []byte
_ = a[0] //nolint // we want to runtime panic!
return nil
})
t.Logf("panic error message: %s", err)
panicked, mErr := IsPanic(err)
if !panicked {
t.Errorf("failed to return *ModuleError, got %+v", err)
} else {
t.Logf("panic stack trace:\n%s", mErr.StackTrace)
}
}