mirror of
https://github.com/safing/portbase
synced 2025-09-01 18:19:57 +00:00
Merge pull request #136 from safing/feature/patch-set-2
Container and modules improvements
This commit is contained in:
commit
f61528737b
14 changed files with 206 additions and 88 deletions
|
@ -287,7 +287,7 @@ func checkAuth(w http.ResponseWriter, r *http.Request, authRequired bool) (token
|
||||||
|
|
||||||
// Return authentication failure message if authentication is required.
|
// Return authentication failure message if authentication is required.
|
||||||
if authRequired {
|
if authRequired {
|
||||||
log.Tracer(r.Context()).Warningf("api: denying api access to %s", r.RemoteAddr)
|
log.Tracer(r.Context()).Warningf("api: denying api access from %s", r.RemoteAddr)
|
||||||
http.Error(w, err.Error(), http.StatusForbidden)
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
|
@ -272,7 +272,7 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Wait for the owning module to be ready.
|
// Wait for the owning module to be ready.
|
||||||
if !moduleIsReady(e.BelongsTo) {
|
if !moduleIsReady(e.BelongsTo) {
|
||||||
http.Error(w, "The API endpoint is not ready yet. Please try again later.", http.StatusServiceUnavailable)
|
http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Please try again later.", http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,11 +44,21 @@ func (c *Container) Append(data []byte) {
|
||||||
c.compartments = append(c.compartments, data)
|
c.compartments = append(c.compartments, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrependNumber prepends a number (varint encoded).
|
||||||
|
func (c *Container) PrependNumber(n uint64) {
|
||||||
|
c.Prepend(varint.Pack64(n))
|
||||||
|
}
|
||||||
|
|
||||||
// AppendNumber appends a number (varint encoded).
|
// AppendNumber appends a number (varint encoded).
|
||||||
func (c *Container) AppendNumber(n uint64) {
|
func (c *Container) AppendNumber(n uint64) {
|
||||||
c.compartments = append(c.compartments, varint.Pack64(n))
|
c.compartments = append(c.compartments, varint.Pack64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrependInt prepends an int (varint encoded).
|
||||||
|
func (c *Container) PrependInt(n int) {
|
||||||
|
c.Prepend(varint.Pack64(uint64(n)))
|
||||||
|
}
|
||||||
|
|
||||||
// AppendInt appends an int (varint encoded).
|
// AppendInt appends an int (varint encoded).
|
||||||
func (c *Container) AppendInt(n int) {
|
func (c *Container) AppendInt(n int) {
|
||||||
c.compartments = append(c.compartments, varint.Pack64(uint64(n)))
|
c.compartments = append(c.compartments, varint.Pack64(uint64(n)))
|
||||||
|
@ -60,6 +70,12 @@ func (c *Container) AppendAsBlock(data []byte) {
|
||||||
c.Append(data)
|
c.Append(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrependAsBlock prepends the length of the data and the data itself. Data will NOT be copied.
|
||||||
|
func (c *Container) PrependAsBlock(data []byte) {
|
||||||
|
c.Prepend(data)
|
||||||
|
c.PrependNumber(uint64(len(data)))
|
||||||
|
}
|
||||||
|
|
||||||
// AppendContainer appends another Container. Data will NOT be copied.
|
// AppendContainer appends another Container. Data will NOT be copied.
|
||||||
func (c *Container) AppendContainer(data *Container) {
|
func (c *Container) AppendContainer(data *Container) {
|
||||||
c.compartments = append(c.compartments, data.compartments...)
|
c.compartments = append(c.compartments, data.compartments...)
|
||||||
|
@ -71,6 +87,16 @@ func (c *Container) AppendContainerAsBlock(data *Container) {
|
||||||
c.compartments = append(c.compartments, data.compartments...)
|
c.compartments = append(c.compartments, data.compartments...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HoldsData returns true if the Container holds any data.
|
||||||
|
func (c *Container) HoldsData() bool {
|
||||||
|
for i := c.offset; i < len(c.compartments); i++ {
|
||||||
|
if len(c.compartments[i]) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Length returns the full length of all bytes held by the container.
|
// Length returns the full length of all bytes held by the container.
|
||||||
func (c *Container) Length() (length int) {
|
func (c *Container) Length() (length int) {
|
||||||
for i := c.offset; i < len(c.compartments); i++ {
|
for i := c.offset; i < len(c.compartments); i++ {
|
||||||
|
@ -109,6 +135,14 @@ func (c *Container) Get(n int) ([]byte, error) {
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAll returns all data. Data MAY be copied and IS consumed.
|
||||||
|
func (c *Container) GetAll() []byte {
|
||||||
|
// TODO: Improve.
|
||||||
|
buf := c.gather(c.Length())
|
||||||
|
c.skip(len(buf))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed.
|
// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed.
|
||||||
func (c *Container) GetAsContainer(n int) (*Container, error) {
|
func (c *Container) GetAsContainer(n int) (*Container, error) {
|
||||||
new := c.gatherAsContainer(n)
|
new := c.gatherAsContainer(n)
|
||||||
|
@ -198,6 +232,9 @@ func (c *Container) checkOffset() {
|
||||||
|
|
||||||
// Error Handling
|
// Error Handling
|
||||||
|
|
||||||
|
/*
|
||||||
|
DEPRECATING... like.... NOW.
|
||||||
|
|
||||||
// SetError sets an error.
|
// SetError sets an error.
|
||||||
func (c *Container) SetError(err error) {
|
func (c *Container) SetError(err error) {
|
||||||
c.err = err
|
c.err = err
|
||||||
|
@ -227,6 +264,7 @@ func (c *Container) Error() error {
|
||||||
func (c *Container) ErrString() string {
|
func (c *Container) ErrString() string {
|
||||||
return c.err.Error()
|
return c.err.Error()
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// Block Handling
|
// Block Handling
|
||||||
|
|
||||||
|
@ -236,11 +274,17 @@ func (c *Container) PrependLength() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Container) gather(n int) []byte {
|
func (c *Container) gather(n int) []byte {
|
||||||
// check if first slice holds enough data
|
// Check requested length.
|
||||||
|
if n <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the first slice holds enough data.
|
||||||
if len(c.compartments[c.offset]) >= n {
|
if len(c.compartments[c.offset]) >= n {
|
||||||
return c.compartments[c.offset][:n]
|
return c.compartments[c.offset][:n]
|
||||||
}
|
}
|
||||||
// start gathering data
|
|
||||||
|
// Start gathering data.
|
||||||
slice := make([]byte, n)
|
slice := make([]byte, n)
|
||||||
copySlice := slice
|
copySlice := slice
|
||||||
n = 0
|
n = 0
|
||||||
|
@ -257,6 +301,13 @@ func (c *Container) gather(n int) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Container) gatherAsContainer(n int) (new *Container) {
|
func (c *Container) gatherAsContainer(n int) (new *Container) {
|
||||||
|
// Check requested length.
|
||||||
|
if n < 0 {
|
||||||
|
return nil
|
||||||
|
} else if n == 0 {
|
||||||
|
return &Container{}
|
||||||
|
}
|
||||||
|
|
||||||
new = &Container{}
|
new = &Container{}
|
||||||
for i := c.offset; i < len(c.compartments); i++ {
|
for i := c.offset; i < len(c.compartments); i++ {
|
||||||
if n >= len(c.compartments[i]) {
|
if n >= len(c.compartments[i]) {
|
||||||
|
@ -345,7 +396,7 @@ func (c *Container) GetNextN32() (uint32, error) {
|
||||||
|
|
||||||
// GetNextN64 parses and returns a varint of type uint64.
|
// GetNextN64 parses and returns a varint of type uint64.
|
||||||
func (c *Container) GetNextN64() (uint64, error) {
|
func (c *Container) GetNextN64() (uint64, error) {
|
||||||
buf := c.gather(9)
|
buf := c.gather(10)
|
||||||
num, n, err := varint.Unpack64(buf)
|
num, n, err := varint.Unpack64(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
|
@ -2,7 +2,6 @@ package container
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/safing/portbase/utils"
|
"github.com/safing/portbase/utils"
|
||||||
|
@ -82,38 +81,6 @@ func compareMany(t *testing.T, reference []byte, other ...[]byte) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContainerErrorHandling(t *testing.T) {
|
|
||||||
|
|
||||||
c1 := New(nil)
|
|
||||||
|
|
||||||
if c1.HasError() {
|
|
||||||
t.Error("should not have error")
|
|
||||||
}
|
|
||||||
|
|
||||||
c1.SetError(errors.New("test error"))
|
|
||||||
|
|
||||||
if !c1.HasError() {
|
|
||||||
t.Error("should have error")
|
|
||||||
}
|
|
||||||
|
|
||||||
c2 := New(append([]byte{0}, []byte("test error")...))
|
|
||||||
|
|
||||||
if c2.HasError() {
|
|
||||||
t.Error("should not have error")
|
|
||||||
}
|
|
||||||
|
|
||||||
c2.CheckError()
|
|
||||||
|
|
||||||
if !c2.HasError() {
|
|
||||||
t.Error("should have error")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c2.Error().Error() != "test error" {
|
|
||||||
t.Errorf("error message mismatch, was %s", c2.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDataFetching(t *testing.T) {
|
func TestDataFetching(t *testing.T) {
|
||||||
c1 := New(utils.DuplicateBytes(testData))
|
c1 := New(utils.DuplicateBytes(testData))
|
||||||
data := c1.GetMax(1)
|
data := c1.GetMax(1)
|
||||||
|
|
|
@ -530,6 +530,10 @@ func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME:
|
||||||
|
// Flush the cache before we query the database.
|
||||||
|
// i.FlushCache()
|
||||||
|
|
||||||
return db.Query(q, i.options.Local, i.options.Internal)
|
return db.Query(q, i.options.Local, i.options.Internal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
// "github.com/pkg/bson"
|
// "github.com/pkg/bson"
|
||||||
|
|
||||||
"github.com/safing/portbase/formats/varint"
|
"github.com/safing/portbase/formats/varint"
|
||||||
|
"github.com/safing/portbase/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// define types
|
// define types
|
||||||
|
@ -64,7 +65,7 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
|
||||||
case JSON:
|
case JSON:
|
||||||
err := json.Unmarshal(data, t)
|
err := json.Unmarshal(data, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
|
return nil, fmt.Errorf("dsd: failed to unpack json: %s, data: %s", err, utils.SafeFirst16Bytes(data))
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
case BSON:
|
case BSON:
|
||||||
|
@ -81,11 +82,11 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
|
||||||
}
|
}
|
||||||
_, err := genCodeStruct.GenCodeUnmarshal(data)
|
_, err := genCodeStruct.GenCodeUnmarshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dsd: failed to unpack gencode data: %s", err)
|
return nil, fmt.Errorf("dsd: failed to unpack gencode: %s, data: %s", err, utils.SafeFirst16Bytes(data))
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %v", format, data)
|
return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %s", format, utils.SafeFirst16Bytes(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,3 +20,29 @@ func GetNextBlock(data []byte) ([]byte, int, error) {
|
||||||
}
|
}
|
||||||
return data[n:totalLength], totalLength, nil
|
return data[n:totalLength], totalLength, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EncodedSize returns the size required to varint-encode an uint.
|
||||||
|
func EncodedSize(n uint64) (size int) {
|
||||||
|
switch {
|
||||||
|
case n < 1<<7: // < 128
|
||||||
|
return 1
|
||||||
|
case n < 1<<14: // < 16384
|
||||||
|
return 2
|
||||||
|
case n < 1<<21: // < 2097152
|
||||||
|
return 3
|
||||||
|
case n < 1<<28: // < 268435456
|
||||||
|
return 4
|
||||||
|
case n < 1<<35: // < 34359738368
|
||||||
|
return 5
|
||||||
|
case n < 1<<42: // < 4398046511104
|
||||||
|
return 6
|
||||||
|
case n < 1<<49: // < 562949953421312
|
||||||
|
return 7
|
||||||
|
case n < 1<<56: // < 72057594037927936
|
||||||
|
return 8
|
||||||
|
case n < 1<<63: // < 9223372036854775808
|
||||||
|
return 9
|
||||||
|
default:
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,9 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrBufTooSmall is returned when there is not enough data for parsing a varint.
|
||||||
|
var ErrBufTooSmall = errors.New("varint: buf too small")
|
||||||
|
|
||||||
// Pack8 packs a uint8 into a VarInt.
|
// Pack8 packs a uint8 into a VarInt.
|
||||||
func Pack8(n uint8) []byte {
|
func Pack8(n uint8) []byte {
|
||||||
if n < 128 {
|
if n < 128 {
|
||||||
|
@ -37,13 +40,13 @@ func Pack64(n uint64) []byte {
|
||||||
// Unpack8 unpacks a VarInt into a uint8. It returns the extracted int, how many bytes were used and an error.
|
// Unpack8 unpacks a VarInt into a uint8. It returns the extracted int, how many bytes were used and an error.
|
||||||
func Unpack8(blob []byte) (uint8, int, error) {
|
func Unpack8(blob []byte) (uint8, int, error) {
|
||||||
if len(blob) < 1 {
|
if len(blob) < 1 {
|
||||||
return 0, 0, errors.New("varint: buf has zero length")
|
return 0, 0, ErrBufTooSmall
|
||||||
}
|
}
|
||||||
if blob[0] < 128 {
|
if blob[0] < 128 {
|
||||||
return blob[0], 1, nil
|
return blob[0], 1, nil
|
||||||
}
|
}
|
||||||
if len(blob) < 2 {
|
if len(blob) < 2 {
|
||||||
return 0, 0, errors.New("varint: buf too small")
|
return 0, 0, ErrBufTooSmall
|
||||||
}
|
}
|
||||||
if blob[1] != 0x01 {
|
if blob[1] != 0x01 {
|
||||||
return 0, 0, errors.New("varint: encoded integer greater than 255 (uint8)")
|
return 0, 0, errors.New("varint: encoded integer greater than 255 (uint8)")
|
||||||
|
@ -55,7 +58,7 @@ func Unpack8(blob []byte) (uint8, int, error) {
|
||||||
func Unpack16(blob []byte) (uint16, int, error) {
|
func Unpack16(blob []byte) (uint16, int, error) {
|
||||||
n, r := binary.Uvarint(blob)
|
n, r := binary.Uvarint(blob)
|
||||||
if r == 0 {
|
if r == 0 {
|
||||||
return 0, 0, errors.New("varint: buf too small")
|
return 0, 0, ErrBufTooSmall
|
||||||
}
|
}
|
||||||
if r < 0 {
|
if r < 0 {
|
||||||
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
||||||
|
@ -70,7 +73,7 @@ func Unpack16(blob []byte) (uint16, int, error) {
|
||||||
func Unpack32(blob []byte) (uint32, int, error) {
|
func Unpack32(blob []byte) (uint32, int, error) {
|
||||||
n, r := binary.Uvarint(blob)
|
n, r := binary.Uvarint(blob)
|
||||||
if r == 0 {
|
if r == 0 {
|
||||||
return 0, 0, errors.New("varint: buf too small")
|
return 0, 0, ErrBufTooSmall
|
||||||
}
|
}
|
||||||
if r < 0 {
|
if r < 0 {
|
||||||
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
||||||
|
@ -85,7 +88,7 @@ func Unpack32(blob []byte) (uint32, int, error) {
|
||||||
func Unpack64(blob []byte) (uint64, int, error) {
|
func Unpack64(blob []byte) (uint64, int, error) {
|
||||||
n, r := binary.Uvarint(blob)
|
n, r := binary.Uvarint(blob)
|
||||||
if r == 0 {
|
if r == 0 {
|
||||||
return 0, 0, errors.New("varint: buf too small")
|
return 0, 0, ErrBufTooSmall
|
||||||
}
|
}
|
||||||
if r < 0 {
|
if r < 0 {
|
||||||
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
|
||||||
|
|
|
@ -130,7 +130,6 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err
|
||||||
// start for module
|
// start for module
|
||||||
// hint: only microTasks global var is important for scheduling, others can be set here
|
// hint: only microTasks global var is important for scheduling, others can be set here
|
||||||
atomic.AddInt32(m.microTaskCnt, 1)
|
atomic.AddInt32(m.microTaskCnt, 1)
|
||||||
m.waitGroup.Add(1)
|
|
||||||
|
|
||||||
// set up recovery
|
// set up recovery
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -145,7 +144,7 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err
|
||||||
|
|
||||||
// finish for module
|
// finish for module
|
||||||
atomic.AddInt32(m.microTaskCnt, -1)
|
atomic.AddInt32(m.microTaskCnt, -1)
|
||||||
m.waitGroup.Done()
|
m.checkIfStopComplete()
|
||||||
|
|
||||||
// finish and possibly trigger next task
|
// finish and possibly trigger next task
|
||||||
atomic.AddInt32(microTasks, -1)
|
atomic.AddInt32(microTasks, -1)
|
||||||
|
|
|
@ -52,15 +52,16 @@ type Module struct { //nolint:maligned // not worth the effort
|
||||||
// start
|
// start
|
||||||
startComplete chan struct{}
|
startComplete chan struct{}
|
||||||
// stop
|
// stop
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
cancelCtx func()
|
cancelCtx func()
|
||||||
stopFlag *abool.AtomicBool
|
stopFlag *abool.AtomicBool
|
||||||
|
stopComplete chan struct{}
|
||||||
|
|
||||||
// workers/tasks
|
// workers/tasks
|
||||||
workerCnt *int32
|
ctrlFuncRunning *abool.AtomicBool
|
||||||
taskCnt *int32
|
workerCnt *int32
|
||||||
microTaskCnt *int32
|
taskCnt *int32
|
||||||
waitGroup sync.WaitGroup
|
microTaskCnt *int32
|
||||||
|
|
||||||
// events
|
// events
|
||||||
eventHooks map[string]*eventHooks
|
eventHooks map[string]*eventHooks
|
||||||
|
@ -205,11 +206,29 @@ func (m *Module) start(reports chan *report) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Module) checkIfStopComplete() {
|
||||||
|
if m.stopFlag.IsSet() &&
|
||||||
|
m.ctrlFuncRunning.IsNotSet() &&
|
||||||
|
atomic.LoadInt32(m.workerCnt) == 0 &&
|
||||||
|
atomic.LoadInt32(m.taskCnt) == 0 &&
|
||||||
|
atomic.LoadInt32(m.microTaskCnt) == 0 {
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
if m.stopComplete != nil {
|
||||||
|
close(m.stopComplete)
|
||||||
|
m.stopComplete = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Module) stop(reports chan *report) {
|
func (m *Module) stop(reports chan *report) {
|
||||||
// check and set intermediate status
|
|
||||||
m.Lock()
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
// check and set intermediate status
|
||||||
if m.status != StatusOnline {
|
if m.status != StatusOnline {
|
||||||
m.Unlock()
|
|
||||||
go func() {
|
go func() {
|
||||||
reports <- &report{
|
reports <- &report{
|
||||||
module: m,
|
module: m,
|
||||||
|
@ -218,47 +237,46 @@ func (m *Module) stop(reports chan *report) {
|
||||||
}()
|
}()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.status = StatusStopping
|
|
||||||
|
|
||||||
// reset start management
|
// Reset start/stop signal channels.
|
||||||
m.startComplete = make(chan struct{})
|
m.startComplete = make(chan struct{})
|
||||||
// init stop management
|
m.stopComplete = make(chan struct{})
|
||||||
m.cancelCtx()
|
|
||||||
|
// Make a copy of the stop channel.
|
||||||
|
stopComplete := m.stopComplete
|
||||||
|
|
||||||
|
// Set status and cancel context.
|
||||||
|
m.status = StatusStopping
|
||||||
m.stopFlag.Set()
|
m.stopFlag.Set()
|
||||||
|
m.cancelCtx()
|
||||||
|
|
||||||
m.Unlock()
|
go m.stopAllTasks(reports, stopComplete)
|
||||||
|
|
||||||
go m.stopAllTasks(reports)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Module) stopAllTasks(reports chan *report) {
|
func (m *Module) stopAllTasks(reports chan *report, stopComplete chan struct{}) {
|
||||||
// start shutdown function
|
// start shutdown function
|
||||||
stopFnFinished := abool.NewBool(false)
|
|
||||||
var stopFnError error
|
var stopFnError error
|
||||||
|
stopFuncRunning := abool.New()
|
||||||
if m.stopFn != nil {
|
if m.stopFn != nil {
|
||||||
m.waitGroup.Add(1)
|
stopFuncRunning.Set()
|
||||||
go func() {
|
go func() {
|
||||||
stopFnError = m.runCtrlFn("stop module", m.stopFn)
|
stopFnError = m.runCtrlFn("stop module", m.stopFn)
|
||||||
stopFnFinished.Set()
|
stopFuncRunning.UnSet()
|
||||||
m.waitGroup.Done()
|
m.checkIfStopComplete()
|
||||||
}()
|
}()
|
||||||
|
} else {
|
||||||
|
m.checkIfStopComplete()
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for workers and stop fn
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
m.waitGroup.Wait()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// wait for results
|
// wait for results
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-stopComplete:
|
||||||
case <-time.After(moduleStopTimeout):
|
// case <-time.After(moduleStopTimeout):
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
log.Warningf(
|
log.Warningf(
|
||||||
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
|
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v/%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
|
||||||
m.Name,
|
m.Name,
|
||||||
stopFnFinished.IsSet(),
|
stopFuncRunning.IsSet(), m.ctrlFuncRunning.IsSet(),
|
||||||
atomic.LoadInt32(m.workerCnt),
|
atomic.LoadInt32(m.workerCnt),
|
||||||
atomic.LoadInt32(m.taskCnt),
|
atomic.LoadInt32(m.taskCnt),
|
||||||
atomic.LoadInt32(m.microTaskCnt),
|
atomic.LoadInt32(m.microTaskCnt),
|
||||||
|
@ -267,7 +285,7 @@ func (m *Module) stopAllTasks(reports chan *report) {
|
||||||
|
|
||||||
// collect error
|
// collect error
|
||||||
var err error
|
var err error
|
||||||
if stopFnFinished.IsSet() && stopFnError != nil {
|
if stopFuncRunning.IsNotSet() && stopFnError != nil {
|
||||||
err = stopFnError
|
err = stopFnError
|
||||||
}
|
}
|
||||||
// set status
|
// set status
|
||||||
|
@ -328,10 +346,10 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ...
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
cancelCtx: cancelCtx,
|
cancelCtx: cancelCtx,
|
||||||
stopFlag: abool.NewBool(false),
|
stopFlag: abool.NewBool(false),
|
||||||
|
ctrlFuncRunning: abool.NewBool(false),
|
||||||
workerCnt: &workerCnt,
|
workerCnt: &workerCnt,
|
||||||
taskCnt: &taskCnt,
|
taskCnt: &taskCnt,
|
||||||
microTaskCnt: µTaskCnt,
|
microTaskCnt: µTaskCnt,
|
||||||
waitGroup: sync.WaitGroup{},
|
|
||||||
eventHooks: make(map[string]*eventHooks),
|
eventHooks: make(map[string]*eventHooks),
|
||||||
depNames: dependencies,
|
depNames: dependencies,
|
||||||
}
|
}
|
||||||
|
|
|
@ -330,7 +330,6 @@ func (t *Task) executeWithLocking() {
|
||||||
// start for module
|
// start for module
|
||||||
// hint: only queueWg global var is important for scheduling, others can be set here
|
// hint: only queueWg global var is important for scheduling, others can be set here
|
||||||
atomic.AddInt32(t.module.taskCnt, 1)
|
atomic.AddInt32(t.module.taskCnt, 1)
|
||||||
t.module.waitGroup.Add(1)
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// recover from panic
|
// recover from panic
|
||||||
|
@ -343,7 +342,7 @@ func (t *Task) executeWithLocking() {
|
||||||
|
|
||||||
// finish for module
|
// finish for module
|
||||||
atomic.AddInt32(t.module.taskCnt, -1)
|
atomic.AddInt32(t.module.taskCnt, -1)
|
||||||
t.module.waitGroup.Done()
|
t.module.checkIfStopComplete()
|
||||||
|
|
||||||
t.lock.Lock()
|
t.lock.Lock()
|
||||||
|
|
||||||
|
|
|
@ -39,10 +39,9 @@ func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
atomic.AddInt32(m.workerCnt, 1)
|
atomic.AddInt32(m.workerCnt, 1)
|
||||||
m.waitGroup.Add(1)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.AddInt32(m.workerCnt, -1)
|
atomic.AddInt32(m.workerCnt, -1)
|
||||||
m.waitGroup.Done()
|
m.checkIfStopComplete()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return m.runWorker(name, fn)
|
return m.runWorker(name, fn)
|
||||||
|
@ -60,10 +59,9 @@ func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration,
|
||||||
|
|
||||||
func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
|
func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
|
||||||
atomic.AddInt32(m.workerCnt, 1)
|
atomic.AddInt32(m.workerCnt, 1)
|
||||||
m.waitGroup.Add(1)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.AddInt32(m.workerCnt, -1)
|
atomic.AddInt32(m.workerCnt, -1)
|
||||||
m.waitGroup.Done()
|
m.checkIfStopComplete()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if backoffDuration == 0 {
|
if backoffDuration == 0 {
|
||||||
|
@ -143,6 +141,10 @@ func (m *Module) runCtrlFn(name string, fn func() error) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.ctrlFuncRunning.SetToIf(false, true) {
|
||||||
|
defer m.ctrlFuncRunning.SetToIf(true, false)
|
||||||
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// recover from panic
|
// recover from panic
|
||||||
panicVal := recover()
|
panicVal := recover()
|
||||||
|
|
21
utils/safe.go
Normal file
21
utils/safe.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SafeFirst16Bytes(data []byte) string {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "<empty>"
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimPrefix(
|
||||||
|
strings.SplitN(hex.Dump(data), "\n", 2)[0],
|
||||||
|
"00000000 ",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SafeFirst16Chars(s string) string {
|
||||||
|
return SafeFirst16Bytes([]byte(s))
|
||||||
|
}
|
27
utils/safe_test.go
Normal file
27
utils/safe_test.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSafeFirst16(t *testing.T) {
|
||||||
|
assert.Equal(t,
|
||||||
|
"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
|
||||||
|
SafeFirst16Bytes([]byte("Go is an open source programming language.")),
|
||||||
|
)
|
||||||
|
assert.Equal(t,
|
||||||
|
"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
|
||||||
|
SafeFirst16Chars("Go is an open source programming language."),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
"<empty>",
|
||||||
|
SafeFirst16Bytes(nil),
|
||||||
|
)
|
||||||
|
assert.Equal(t,
|
||||||
|
"<empty>",
|
||||||
|
SafeFirst16Chars(""),
|
||||||
|
)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue