Merge pull request #136 from safing/feature/patch-set-2

Container and modules improvements
This commit is contained in:
Daniel 2021-09-27 14:12:37 +02:00 committed by GitHub
commit f61528737b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 206 additions and 88 deletions

View file

@ -287,7 +287,7 @@ func checkAuth(w http.ResponseWriter, r *http.Request, authRequired bool) (token
// Return authentication failure message if authentication is required.
if authRequired {
log.Tracer(r.Context()).Warningf("api: denying api access to %s", r.RemoteAddr)
log.Tracer(r.Context()).Warningf("api: denying api access from %s", r.RemoteAddr)
http.Error(w, err.Error(), http.StatusForbidden)
return nil, true
}

View file

@ -272,7 +272,7 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Wait for the owning module to be ready.
if !moduleIsReady(e.BelongsTo) {
http.Error(w, "The API endpoint is not ready yet. Please try again later.", http.StatusServiceUnavailable)
http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Please try again later.", http.StatusServiceUnavailable)
return
}

View file

@ -44,11 +44,21 @@ func (c *Container) Append(data []byte) {
c.compartments = append(c.compartments, data)
}
// PrependNumber prepends a number (varint encoded).
func (c *Container) PrependNumber(n uint64) {
c.Prepend(varint.Pack64(n))
}
// AppendNumber appends a number (varint encoded).
func (c *Container) AppendNumber(n uint64) {
c.compartments = append(c.compartments, varint.Pack64(n))
}
// PrependInt prepends an int (varint encoded).
func (c *Container) PrependInt(n int) {
c.Prepend(varint.Pack64(uint64(n)))
}
// AppendInt appends an int (varint encoded).
func (c *Container) AppendInt(n int) {
c.compartments = append(c.compartments, varint.Pack64(uint64(n)))
@ -60,6 +70,12 @@ func (c *Container) AppendAsBlock(data []byte) {
c.Append(data)
}
// PrependAsBlock prepends the length of the data and the data itself. Data will NOT be copied.
func (c *Container) PrependAsBlock(data []byte) {
c.Prepend(data)
c.PrependNumber(uint64(len(data)))
}
// AppendContainer appends another Container. Data will NOT be copied.
func (c *Container) AppendContainer(data *Container) {
c.compartments = append(c.compartments, data.compartments...)
@ -71,6 +87,16 @@ func (c *Container) AppendContainerAsBlock(data *Container) {
c.compartments = append(c.compartments, data.compartments...)
}
// HoldsData returns true if the Container holds any data.
func (c *Container) HoldsData() bool {
for i := c.offset; i < len(c.compartments); i++ {
if len(c.compartments[i]) > 0 {
return true
}
}
return false
}
// Length returns the full length of all bytes held by the container.
func (c *Container) Length() (length int) {
for i := c.offset; i < len(c.compartments); i++ {
@ -109,6 +135,14 @@ func (c *Container) Get(n int) ([]byte, error) {
return buf, nil
}
// GetAll returns all data. Data MAY be copied and IS consumed.
func (c *Container) GetAll() []byte {
// TODO: Improve.
buf := c.gather(c.Length())
c.skip(len(buf))
return buf
}
// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed.
func (c *Container) GetAsContainer(n int) (*Container, error) {
new := c.gatherAsContainer(n)
@ -198,6 +232,9 @@ func (c *Container) checkOffset() {
// Error Handling
/*
DEPRECATING... like.... NOW.
// SetError sets an error.
func (c *Container) SetError(err error) {
c.err = err
@ -227,6 +264,7 @@ func (c *Container) Error() error {
func (c *Container) ErrString() string {
return c.err.Error()
}
*/
// Block Handling
@ -236,11 +274,17 @@ func (c *Container) PrependLength() {
}
func (c *Container) gather(n int) []byte {
// check if first slice holds enough data
// Check requested length.
if n <= 0 {
return nil
}
// Check if the first slice holds enough data.
if len(c.compartments[c.offset]) >= n {
return c.compartments[c.offset][:n]
}
// start gathering data
// Start gathering data.
slice := make([]byte, n)
copySlice := slice
n = 0
@ -257,6 +301,13 @@ func (c *Container) gather(n int) []byte {
}
func (c *Container) gatherAsContainer(n int) (new *Container) {
// Check requested length.
if n < 0 {
return nil
} else if n == 0 {
return &Container{}
}
new = &Container{}
for i := c.offset; i < len(c.compartments); i++ {
if n >= len(c.compartments[i]) {
@ -345,7 +396,7 @@ func (c *Container) GetNextN32() (uint32, error) {
// GetNextN64 parses and returns a varint of type uint64.
func (c *Container) GetNextN64() (uint64, error) {
buf := c.gather(9)
buf := c.gather(10)
num, n, err := varint.Unpack64(buf)
if err != nil {
return 0, err

View file

@ -2,7 +2,6 @@ package container
import (
"bytes"
"errors"
"testing"
"github.com/safing/portbase/utils"
@ -82,38 +81,6 @@ func compareMany(t *testing.T, reference []byte, other ...[]byte) {
}
}
func TestContainerErrorHandling(t *testing.T) {
c1 := New(nil)
if c1.HasError() {
t.Error("should not have error")
}
c1.SetError(errors.New("test error"))
if !c1.HasError() {
t.Error("should have error")
}
c2 := New(append([]byte{0}, []byte("test error")...))
if c2.HasError() {
t.Error("should not have error")
}
c2.CheckError()
if !c2.HasError() {
t.Error("should have error")
}
if c2.Error().Error() != "test error" {
t.Errorf("error message mismatch, was %s", c2.Error())
}
}
func TestDataFetching(t *testing.T) {
c1 := New(utils.DuplicateBytes(testData))
data := c1.GetMax(1)

View file

@ -530,6 +530,10 @@ func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
return nil, err
}
// FIXME:
// Flush the cache before we query the database.
// i.FlushCache()
return db.Query(q, i.options.Local, i.options.Internal)
}

View file

@ -11,6 +11,7 @@ import (
// "github.com/pkg/bson"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/utils"
)
// define types
@ -64,7 +65,7 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
case JSON:
err := json.Unmarshal(data, t)
if err != nil {
return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
return nil, fmt.Errorf("dsd: failed to unpack json: %s, data: %s", err, utils.SafeFirst16Bytes(data))
}
return t, nil
case BSON:
@ -81,11 +82,11 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
}
_, err := genCodeStruct.GenCodeUnmarshal(data)
if err != nil {
return nil, fmt.Errorf("dsd: failed to unpack gencode data: %s", err)
return nil, fmt.Errorf("dsd: failed to unpack gencode: %s, data: %s", err, utils.SafeFirst16Bytes(data))
}
return t, nil
default:
return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %v", format, data)
return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %s", format, utils.SafeFirst16Bytes(data))
}
}

View file

@ -20,3 +20,29 @@ func GetNextBlock(data []byte) ([]byte, int, error) {
}
return data[n:totalLength], totalLength, nil
}
// EncodedSize returns the size required to varint-encode an uint.
func EncodedSize(n uint64) (size int) {
switch {
case n < 1<<7: // < 128
return 1
case n < 1<<14: // < 16384
return 2
case n < 1<<21: // < 2097152
return 3
case n < 1<<28: // < 268435456
return 4
case n < 1<<35: // < 34359738368
return 5
case n < 1<<42: // < 4398046511104
return 6
case n < 1<<49: // < 562949953421312
return 7
case n < 1<<56: // < 72057594037927936
return 8
case n < 1<<63: // < 9223372036854775808
return 9
default:
return 10
}
}

View file

@ -5,6 +5,9 @@ import (
"errors"
)
// ErrBufTooSmall is returned when there is not enough data for parsing a varint.
var ErrBufTooSmall = errors.New("varint: buf too small")
// Pack8 packs a uint8 into a VarInt.
func Pack8(n uint8) []byte {
if n < 128 {
@ -37,13 +40,13 @@ func Pack64(n uint64) []byte {
// Unpack8 unpacks a VarInt into a uint8. It returns the extracted int, how many bytes were used and an error.
func Unpack8(blob []byte) (uint8, int, error) {
if len(blob) < 1 {
return 0, 0, errors.New("varint: buf has zero length")
return 0, 0, ErrBufTooSmall
}
if blob[0] < 128 {
return blob[0], 1, nil
}
if len(blob) < 2 {
return 0, 0, errors.New("varint: buf too small")
return 0, 0, ErrBufTooSmall
}
if blob[1] != 0x01 {
return 0, 0, errors.New("varint: encoded integer greater than 255 (uint8)")
@ -55,7 +58,7 @@ func Unpack8(blob []byte) (uint8, int, error) {
func Unpack16(blob []byte) (uint16, int, error) {
n, r := binary.Uvarint(blob)
if r == 0 {
return 0, 0, errors.New("varint: buf too small")
return 0, 0, ErrBufTooSmall
}
if r < 0 {
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
@ -70,7 +73,7 @@ func Unpack16(blob []byte) (uint16, int, error) {
func Unpack32(blob []byte) (uint32, int, error) {
n, r := binary.Uvarint(blob)
if r == 0 {
return 0, 0, errors.New("varint: buf too small")
return 0, 0, ErrBufTooSmall
}
if r < 0 {
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")
@ -85,7 +88,7 @@ func Unpack32(blob []byte) (uint32, int, error) {
func Unpack64(blob []byte) (uint64, int, error) {
n, r := binary.Uvarint(blob)
if r == 0 {
return 0, 0, errors.New("varint: buf too small")
return 0, 0, ErrBufTooSmall
}
if r < 0 {
return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)")

View file

@ -130,7 +130,6 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err
// start for module
// hint: only microTasks global var is important for scheduling, others can be set here
atomic.AddInt32(m.microTaskCnt, 1)
m.waitGroup.Add(1)
// set up recovery
defer func() {
@ -145,7 +144,7 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err
// finish for module
atomic.AddInt32(m.microTaskCnt, -1)
m.waitGroup.Done()
m.checkIfStopComplete()
// finish and possibly trigger next task
atomic.AddInt32(microTasks, -1)

View file

@ -52,15 +52,16 @@ type Module struct { //nolint:maligned // not worth the effort
// start
startComplete chan struct{}
// stop
Ctx context.Context
cancelCtx func()
stopFlag *abool.AtomicBool
Ctx context.Context
cancelCtx func()
stopFlag *abool.AtomicBool
stopComplete chan struct{}
// workers/tasks
workerCnt *int32
taskCnt *int32
microTaskCnt *int32
waitGroup sync.WaitGroup
ctrlFuncRunning *abool.AtomicBool
workerCnt *int32
taskCnt *int32
microTaskCnt *int32
// events
eventHooks map[string]*eventHooks
@ -205,11 +206,29 @@ func (m *Module) start(reports chan *report) {
}()
}
func (m *Module) checkIfStopComplete() {
if m.stopFlag.IsSet() &&
m.ctrlFuncRunning.IsNotSet() &&
atomic.LoadInt32(m.workerCnt) == 0 &&
atomic.LoadInt32(m.taskCnt) == 0 &&
atomic.LoadInt32(m.microTaskCnt) == 0 {
m.Lock()
defer m.Unlock()
if m.stopComplete != nil {
close(m.stopComplete)
m.stopComplete = nil
}
}
}
func (m *Module) stop(reports chan *report) {
// check and set intermediate status
m.Lock()
defer m.Unlock()
// check and set intermediate status
if m.status != StatusOnline {
m.Unlock()
go func() {
reports <- &report{
module: m,
@ -218,47 +237,46 @@ func (m *Module) stop(reports chan *report) {
}()
return
}
m.status = StatusStopping
// reset start management
// Reset start/stop signal channels.
m.startComplete = make(chan struct{})
// init stop management
m.cancelCtx()
m.stopComplete = make(chan struct{})
// Make a copy of the stop channel.
stopComplete := m.stopComplete
// Set status and cancel context.
m.status = StatusStopping
m.stopFlag.Set()
m.cancelCtx()
m.Unlock()
go m.stopAllTasks(reports)
go m.stopAllTasks(reports, stopComplete)
}
func (m *Module) stopAllTasks(reports chan *report) {
func (m *Module) stopAllTasks(reports chan *report, stopComplete chan struct{}) {
// start shutdown function
stopFnFinished := abool.NewBool(false)
var stopFnError error
stopFuncRunning := abool.New()
if m.stopFn != nil {
m.waitGroup.Add(1)
stopFuncRunning.Set()
go func() {
stopFnError = m.runCtrlFn("stop module", m.stopFn)
stopFnFinished.Set()
m.waitGroup.Done()
stopFuncRunning.UnSet()
m.checkIfStopComplete()
}()
} else {
m.checkIfStopComplete()
}
// wait for workers and stop fn
done := make(chan struct{})
go func() {
m.waitGroup.Wait()
close(done)
}()
// wait for results
select {
case <-done:
case <-time.After(moduleStopTimeout):
case <-stopComplete:
// case <-time.After(moduleStopTimeout):
case <-time.After(3 * time.Second):
log.Warningf(
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v/%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
m.Name,
stopFnFinished.IsSet(),
stopFuncRunning.IsSet(), m.ctrlFuncRunning.IsSet(),
atomic.LoadInt32(m.workerCnt),
atomic.LoadInt32(m.taskCnt),
atomic.LoadInt32(m.microTaskCnt),
@ -267,7 +285,7 @@ func (m *Module) stopAllTasks(reports chan *report) {
// collect error
var err error
if stopFnFinished.IsSet() && stopFnError != nil {
if stopFuncRunning.IsNotSet() && stopFnError != nil {
err = stopFnError
}
// set status
@ -328,10 +346,10 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ...
Ctx: ctx,
cancelCtx: cancelCtx,
stopFlag: abool.NewBool(false),
ctrlFuncRunning: abool.NewBool(false),
workerCnt: &workerCnt,
taskCnt: &taskCnt,
microTaskCnt: &microTaskCnt,
waitGroup: sync.WaitGroup{},
eventHooks: make(map[string]*eventHooks),
depNames: dependencies,
}

View file

@ -330,7 +330,6 @@ func (t *Task) executeWithLocking() {
// start for module
// hint: only queueWg global var is important for scheduling, others can be set here
atomic.AddInt32(t.module.taskCnt, 1)
t.module.waitGroup.Add(1)
defer func() {
// recover from panic
@ -343,7 +342,7 @@ func (t *Task) executeWithLocking() {
// finish for module
atomic.AddInt32(t.module.taskCnt, -1)
t.module.waitGroup.Done()
t.module.checkIfStopComplete()
t.lock.Lock()

View file

@ -39,10 +39,9 @@ func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
}
atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1)
defer func() {
atomic.AddInt32(m.workerCnt, -1)
m.waitGroup.Done()
m.checkIfStopComplete()
}()
return m.runWorker(name, fn)
@ -60,10 +59,9 @@ func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration,
func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1)
defer func() {
atomic.AddInt32(m.workerCnt, -1)
m.waitGroup.Done()
m.checkIfStopComplete()
}()
if backoffDuration == 0 {
@ -143,6 +141,10 @@ func (m *Module) runCtrlFn(name string, fn func() error) (err error) {
return
}
if m.ctrlFuncRunning.SetToIf(false, true) {
defer m.ctrlFuncRunning.SetToIf(true, false)
}
defer func() {
// recover from panic
panicVal := recover()

21
utils/safe.go Normal file
View file

@ -0,0 +1,21 @@
package utils
import (
"encoding/hex"
"strings"
)
func SafeFirst16Bytes(data []byte) string {
if len(data) == 0 {
return "<empty>"
}
return strings.TrimPrefix(
strings.SplitN(hex.Dump(data), "\n", 2)[0],
"00000000 ",
)
}
func SafeFirst16Chars(s string) string {
return SafeFirst16Bytes([]byte(s))
}

27
utils/safe_test.go Normal file
View file

@ -0,0 +1,27 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSafeFirst16(t *testing.T) {
assert.Equal(t,
"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
SafeFirst16Bytes([]byte("Go is an open source programming language.")),
)
assert.Equal(t,
"47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|",
SafeFirst16Chars("Go is an open source programming language."),
)
assert.Equal(t,
"<empty>",
SafeFirst16Bytes(nil),
)
assert.Equal(t,
"<empty>",
SafeFirst16Chars(""),
)
}