Mirror of https://github.com/safing/portbase
Synced 2025-04-23 10:49:09 +00:00

Commit f311c2864d
88 changed files with 1920 additions and 1759 deletions
Changed paths:

.golangci.yml
api
config
crypto
  hash
  random
database
  accessor
  controller.go, controllers.go, database_test.go
  dbmodule
  dbutils
  hook.go, interface.go, maintenance.go
  query
  registry.go
  storage
  utils/kvops
formats/dsd
log
  flags.go, formatting.go, formatting_linux.go, formatting_windows.go, input.go, logging.go, logging_test.go, output.go, trace.go, trace_test.go
modules
  doc.go, error.go, microtasks.go, microtasks_test.go, modules.go, modules_test.go, start.go, tasks.go, tasks_test.go, worker.go, worker_test.go
notifications
portbase.go, taskmanager
testutils
.golangci.yml (new file, 6 additions)

@@ -0,0 +1,6 @@
+linters:
+  enable-all: true
+  disable:
+    - lll
+    - gochecknoinits
+    - gochecknoglobals
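The new config turns on every linter and opts out of only three. The rest of the commit leans on inline `//nolint` directives (added to several files below) to silence the remaining findings. A minimal sketch of how a file-level directive interacts with this config — the file and function here are hypothetical, not part of portbase:

```go
//nolint:errcheck // hypothetical example file; errors below are ignored on purpose
package example

import "os"

// cleanup deliberately drops the error from os.Remove. With enable-all in
// .golangci.yml, errcheck would flag this line if the file-level nolint
// directive above were missing.
func cleanup(path string) {
	os.Remove(path)
}
```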
@@ -101,6 +101,7 @@ func authMiddleware(next http.Handler) http.Handler {
 		Name:     cookieName,
 		Value:    tokenString,
+		HttpOnly: true,
 		MaxAge:   int(cookieTTL.Seconds()),
 	})

 	// serve
@@ -6,7 +6,7 @@ import (
 	"time"

 	"github.com/safing/portbase/log"
 	"github.com/gorilla/websocket"

 	"github.com/tevino/abool"
 )
@@ -35,7 +35,6 @@ type Client struct {
 	operations map[string]*Operation
 	nextOpID   uint64

 	wsConn    *websocket.Conn
 	lastError string
 }
@@ -72,12 +71,12 @@ func (c *Client) Connect() error {
 func (c *Client) StayConnected() {
 	log.Infof("client: connecting to Portmaster at %s", c.server)

-	c.Connect()
+	_ = c.Connect()
 	for {
 		select {
 		case <-time.After(backOffTimer):
 			log.Infof("client: reconnecting...")
-			c.Connect()
+			_ = c.Connect()
 		case <-c.shutdownSignal:
 			return
 		}
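The `_ = c.Connect()` rewrites make the ignored return value explicit for errcheck: StayConnected retries on a timer regardless of the result, so the error is discarded on purpose. A self-contained sketch of the same pattern — the function names and the fixed 5-second backoff are illustrative, not taken from the package:

```go
package example

import "time"

// stayConnected retries connect on a fixed backoff until shutdown is closed.
// Errors are deliberately discarded with `_ =` because the loop retries
// regardless of the outcome.
func stayConnected(connect func() error, shutdown <-chan struct{}) {
	_ = connect() // initial attempt; failure is handled by retrying below
	for {
		select {
		case <-time.After(5 * time.Second):
			_ = connect() // reconnect attempt
		case <-shutdown:
			return
		}
	}
}
```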
@@ -9,10 +9,12 @@ import (
 	"github.com/tevino/abool"
 )

+// Client errors
 var (
 	ErrMalformedMessage = errors.New("malformed message")
 )

+// Message is an API message.
 type Message struct {
 	OpID string
 	Type string

@@ -22,6 +24,7 @@ type Message struct {
 	sent *abool.AtomicBool
 }

+// ParseMessage parses the given raw data and returns a Message.
 func ParseMessage(data []byte) (*Message, error) {
 	parts := bytes.SplitN(data, apiSeperatorBytes, 4)
 	if len(parts) < 2 {

@@ -68,6 +71,7 @@ func ParseMessage(data []byte) (*Message, error) {
 	return m, nil
 }

+// Pack serializes a message into a []byte slice.
 func (m *Message) Pack() ([]byte, error) {
 	c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type))
@@ -90,28 +94,3 @@ func (m *Message) Pack() ([]byte, error) {

 	return c.CompileData(), nil
 }
-
-func (m *Message) IsOk() bool {
-	return m.Type == MsgOk
-}
-func (m *Message) IsDone() bool {
-	return m.Type == MsgDone
-}
-func (m *Message) IsError() bool {
-	return m.Type == MsgError
-}
-func (m *Message) IsUpdate() bool {
-	return m.Type == MsgUpdate
-}
-func (m *Message) IsNew() bool {
-	return m.Type == MsgNew
-}
-func (m *Message) IsDelete() bool {
-	return m.Type == MsgDelete
-}
-func (m *Message) IsWarning() bool {
-	return m.Type == MsgWarning
-}
-func (m *Message) GetMessage() string {
-	return m.Key
-}
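With the Is*/GetMessage helpers deleted, callers are presumably expected to branch on Message.Type directly. A hedged, self-contained sketch of what such a dispatch could look like — only the names come from the methods removed above; the struct mirror and constant values are assumptions made to keep the example compilable:

```go
package example

// Message and the Msg* constants mirror the api client package only so this
// sketch is self-contained; the actual constant values are assumptions.
type Message struct {
	OpID string
	Type string
	Key  string
}

const (
	MsgOk      = "ok"
	MsgDone    = "done"
	MsgError   = "error"
	MsgWarning = "warning"
	MsgNew     = "new"
	MsgUpdate  = "upd"
	MsgDelete  = "del"
)

// handle dispatches on Message.Type directly, replacing the removed
// IsOk/IsDone/... helpers.
func handle(m *Message) string {
	switch m.Type {
	case MsgOk, MsgDone:
		return "success: " + m.Key
	case MsgError, MsgWarning:
		return "problem: " + m.Key // m.Key carried the text (former GetMessage)
	default:
		return "record change: " + m.Key
	}
}
```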
@@ -191,7 +191,7 @@ func (api *DatabaseAPI) writer() {
 		select {
 		// prioritize direct writes
 		case data = <-api.sendQueue:
-			if data == nil || len(data) == 0 {
+			if len(data) == 0 {
 				api.shutdown()
 				return
 			}
@@ -242,8 +242,9 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
 	r, err := api.db.Get(key)
 	if err == nil {
 		data, err = r.Marshal(r, record.JSON)
-	} else {
-		api.send(opID, dbMsgTypeError, err.Error(), nil)
+	}
+	if err != nil {
+		api.send(opID, dbMsgTypeError, err.Error(), nil) //nolint:nilness // FIXME: possibly false positive (golangci-lint govet/nilness)
 		return
 	}
 	api.send(opID, dbMsgTypeOk, r.Key(), data)
@@ -357,7 +358,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
 		select {
 		case <-api.shutdownSignal:
 			// cancel sub and return
-			sub.Cancel()
+			_ = sub.Cancel()
 			return
 		case r := <-sub.Feed:
 			// process sub feed
@@ -373,17 +374,19 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
 			// TODO: use upd, new and delete msgTypes
 			r.Lock()
 			isDeleted := r.Meta().IsDeleted()
 			new := r.Meta().Created == r.Meta().Modified
 			r.Unlock()
-			if isDeleted {
+			switch {
+			case isDeleted:
 				api.send(opID, dbMsgTypeDel, r.Key(), nil)
-			} else {
+			case new:
+				api.send(opID, dbMsgTypeNew, r.Key(), data)
+			default:
 				api.send(opID, dbMsgTypeUpd, r.Key(), data)
 			}
-		} else {
+		} else if sub.Err != nil {
 			// sub feed ended
-			if sub.Err != nil {
-				api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
-			}
+			api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
 		}
 	}
 }
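This is the ifElseChain-style fix golangci-lint pushes for: a tagless switch replaces the nested if/else and picks up the new-record case along the way. The same shape in isolation, with return values standing in for the api.send calls:

```go
package example

// classify mirrors the refactored dispatch in processSub: deleted wins,
// then newly-created, everything else is an update.
func classify(isDeleted, isNew bool) string {
	switch {
	case isDeleted:
		return "del"
	case isNew:
		return "new"
	default:
		return "upd"
	}
}
```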
@@ -489,10 +492,7 @@ func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) {
 			return false
 		}
 		insertError = acc.Set(key.String(), value.Value())
-		if insertError != nil {
-			return false
-		}
-		return true
+		return insertError == nil
 	})

 	if insertError != nil {
@@ -99,12 +99,12 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
 		}
 	}

-	go s.processQuery(q, it, opts)
+	go s.processQuery(it, opts)

 	return it, nil
 }

-func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) {
+func (s *StorageInterface) processQuery(it *iterator.Iterator, opts []*Option) {

 	sort.Sort(sortableOptions(opts))
@@ -9,12 +9,6 @@ import (
 var (
 	validityFlag     = abool.NewBool(true)
 	validityFlagLock sync.RWMutex
-
-	tableLock sync.RWMutex
-
-	stringTable map[string]string
-	intTable    map[string]int
-	boolTable   map[string]bool
 )

 type (
@@ -2,8 +2,8 @@ package config

 import (
 	"errors"
-	"sync"
+	"fmt"
 	"sync"

 	"github.com/safing/portbase/log"
 )
@@ -20,7 +20,7 @@ var (
 	// ErrInvalidOptionType is returned by SetConfigOption and SetDefaultConfigOption if given an unsupported option type.
 	ErrInvalidOptionType = errors.New("invalid option value type")

-	changedSignal = make(chan struct{}, 0)
+	changedSignal = make(chan struct{})
 )

 // Changed signals if any config option was changed.
@@ -34,7 +34,7 @@ func Changed() <-chan struct{} {
 func triggerChange() {
 	// must be locked!
 	close(changedSignal)
-	changedSignal = make(chan struct{}, 0)
+	changedSignal = make(chan struct{})
 }

 // setConfig sets the (prioritized) user defined config.
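Both `make(chan struct{}, 0)` edits are cosmetic (an unbuffered channel either way), but the idiom around them is worth spelling out: closing changedSignal wakes every goroutine waiting on Changed() at once, and a fresh channel is installed for the next round of waiters. A self-contained sketch of that broadcast pattern, with an explicit mutex standing in for whatever lock the "must be locked!" comment refers to:

```go
package example

import "sync"

var (
	mu            sync.Mutex
	changedSignal = make(chan struct{})
)

// changed returns the current signal channel; it will be closed on the
// next change, waking all receivers at once.
func changed() <-chan struct{} {
	mu.Lock()
	defer mu.Unlock()
	return changedSignal
}

// triggerChange broadcasts by closing the old channel and installs a
// fresh one for future waiters.
func triggerChange() {
	mu.Lock()
	defer mu.Unlock()
	close(changedSignal)
	changedSignal = make(chan struct{})
}
```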
@@ -141,7 +141,7 @@ func setConfigOption(name string, value interface{}, push bool) error {

 	if err == nil {
 		resetValidityFlag()
-		go saveConfig()
+		go saveConfig() //nolint:errcheck // error is logged
 		triggerChange()
 	}
@@ -1,3 +1,4 @@
+//nolint:goconst,errcheck
 package config

 import "testing"
@@ -39,16 +39,19 @@ func getTypeName(t uint8) string {

 // Option describes a configuration option.
 type Option struct {
-	Name        string
-	Key         string // category/sub/key
-	Description string
+	Name        string
+	Key         string // in path format: category/sub/key
+	Description string
+
 	ExpertiseLevel  uint8
 	OptType         uint8
+	RequiresRestart bool
 	DefaultValue    interface{}
+
 	ExternalOptType string
 	ValidationRegex string
-	RequiresRestart bool
-	compiledRegex   *regexp.Regexp
+
+	compiledRegex *regexp.Regexp
 }

 // Export expors an option to a Record.
@@ -82,7 +82,7 @@ func MapToJSON(mapData map[string]interface{}) ([]byte, error) {
 	for key, value := range mapData {
 		new[key] = value
 	}

 	expand(new)
 	return json.MarshalIndent(new, "", " ")
 }
@@ -1,145 +0,0 @@
-package hash
-
-import (
-	"crypto/sha256"
-	"crypto/sha512"
-	"hash"
-
-	"golang.org/x/crypto/sha3"
-)
-
-type Algorithm uint8
-
-const (
-	SHA2_224 Algorithm = 1 + iota
-	SHA2_256
-	SHA2_512_224
-	SHA2_512_256
-	SHA2_384
-	SHA2_512
-	SHA3_224
-	SHA3_256
-	SHA3_384
-	SHA3_512
-	BLAKE2S_256
-	BLAKE2B_256
-	BLAKE2B_384
-	BLAKE2B_512
-)
-
-var (
-	attributes = map[Algorithm][]uint8{
-		// block size, output size, security strength - in bytes
-		SHA2_224:     []uint8{64, 28, 14},
-		SHA2_256:     []uint8{64, 32, 16},
-		SHA2_512_224: []uint8{128, 28, 14},
-		SHA2_512_256: []uint8{128, 32, 16},
-		SHA2_384:     []uint8{128, 48, 24},
-		SHA2_512:     []uint8{128, 64, 32},
-		SHA3_224:     []uint8{144, 28, 14},
-		SHA3_256:     []uint8{136, 32, 16},
-		SHA3_384:     []uint8{104, 48, 24},
-		SHA3_512:     []uint8{72, 64, 32},
-		BLAKE2S_256:  []uint8{64, 32, 16},
-		BLAKE2B_256:  []uint8{128, 32, 16},
-		BLAKE2B_384:  []uint8{128, 48, 24},
-		BLAKE2B_512:  []uint8{128, 64, 32},
-	}
-
-	functions = map[Algorithm]func() hash.Hash{
-		SHA2_224:     sha256.New224,
-		SHA2_256:     sha256.New,
-		SHA2_512_224: sha512.New512_224,
-		SHA2_512_256: sha512.New512_256,
-		SHA2_384:     sha512.New384,
-		SHA2_512:     sha512.New,
-		SHA3_224:     sha3.New224,
-		SHA3_256:     sha3.New256,
-		SHA3_384:     sha3.New384,
-		SHA3_512:     sha3.New512,
-		BLAKE2S_256:  NewBlake2s256,
-		BLAKE2B_256:  NewBlake2b256,
-		BLAKE2B_384:  NewBlake2b384,
-		BLAKE2B_512:  NewBlake2b512,
-	}
-
-	// just ordered by strength and establishment, no research conducted yet.
-	orderedByRecommendation = []Algorithm{
-		SHA3_512,     // {72, 64, 32}
-		SHA2_512,     // {128, 64, 32}
-		BLAKE2B_512,  // {128, 64, 32}
-		SHA3_384,     // {104, 48, 24}
-		SHA2_384,     // {128, 48, 24}
-		BLAKE2B_384,  // {128, 48, 24}
-		SHA3_256,     // {136, 32, 16}
-		SHA2_512_256, // {128, 32, 16}
-		SHA2_256,     // {64, 32, 16}
-		BLAKE2B_256,  // {128, 32, 16}
-		BLAKE2S_256,  // {64, 32, 16}
-		SHA3_224,     // {144, 28, 14}
-		SHA2_512_224, // {128, 28, 14}
-		SHA2_224,     // {64, 28, 14}
-	}
-
-	// names
-	names = map[Algorithm]string{
-		SHA2_224:     "SHA2-224",
-		SHA2_256:     "SHA2-256",
-		SHA2_512_224: "SHA2-512/224",
-		SHA2_512_256: "SHA2-512/256",
-		SHA2_384:     "SHA2-384",
-		SHA2_512:     "SHA2-512",
-		SHA3_224:     "SHA3-224",
-		SHA3_256:     "SHA3-256",
-		SHA3_384:     "SHA3-384",
-		SHA3_512:     "SHA3-512",
-		BLAKE2S_256:  "Blake2s-256",
-		BLAKE2B_256:  "Blake2b-256",
-		BLAKE2B_384:  "Blake2b-384",
-		BLAKE2B_512:  "Blake2b-512",
-	}
-)
-
-func (a Algorithm) BlockSize() uint8 {
-	att, ok := attributes[a]
-	if !ok {
-		return 0
-	}
-	return att[0]
-}
-
-func (a Algorithm) Size() uint8 {
-	att, ok := attributes[a]
-	if !ok {
-		return 0
-	}
-	return att[1]
-}
-
-func (a Algorithm) SecurityStrength() uint8 {
-	att, ok := attributes[a]
-	if !ok {
-		return 0
-	}
-	return att[2]
-}
-
-func (a Algorithm) String() string {
-	return a.Name()
-}
-
-func (a Algorithm) Name() string {
-	name, ok := names[a]
-	if !ok {
-		return ""
-	}
-	return name
-}
-
-func (a Algorithm) New() hash.Hash {
-	fn, ok := functions[a]
-	if !ok {
-		return nil
-	}
-	return fn()
-}

@@ -1,54 +0,0 @@
-package hash
-
-import "testing"
-
-func TestAttributes(t *testing.T) {
-
-	for alg, att := range attributes {
-
-		name, ok := names[alg]
-		if !ok {
-			t.Errorf("hash test: name missing for Algorithm ID %d", alg)
-		}
-		_ = alg.String()
-
-		_, ok = functions[alg]
-		if !ok {
-			t.Errorf("hash test: function missing for Algorithm %s", name)
-		}
-		hash := alg.New()
-
-		if len(att) != 3 {
-			t.Errorf("hash test: Algorithm %s does not have exactly 3 attributes", name)
-		}
-
-		if hash.BlockSize() != int(alg.BlockSize()) {
-			t.Errorf("hash test: block size mismatch at Algorithm %s", name)
-		}
-		if hash.Size() != int(alg.Size()) {
-			t.Errorf("hash test: size mismatch at Algorithm %s", name)
-		}
-		if alg.Size()/2 != alg.SecurityStrength() {
-			t.Errorf("hash test: possible strength error at Algorithm %s", name)
-		}
-
-	}
-
-	noAlg := Algorithm(255)
-	if noAlg.String() != "" {
-		t.Error("hash test: invalid Algorithm error")
-	}
-	if noAlg.BlockSize() != 0 {
-		t.Error("hash test: invalid Algorithm error")
-	}
-	if noAlg.Size() != 0 {
-		t.Error("hash test: invalid Algorithm error")
-	}
-	if noAlg.SecurityStrength() != 0 {
-		t.Error("hash test: invalid Algorithm error")
-	}
-	if noAlg.New() != nil {
-		t.Error("hash test: invalid Algorithm error")
-	}
-
-}

@@ -1,131 +0,0 @@
-package hash
-
-import (
-	"bytes"
-	"encoding/base64"
-	"encoding/hex"
-	"errors"
-	"fmt"
-	"io"
-
-	"github.com/safing/portbase/formats/varint"
-)
-
-type Hash struct {
-	Algorithm Algorithm
-	Sum       []byte
-}
-
-func FromBytes(bytes []byte) (*Hash, int, error) {
-	hash := &Hash{}
-	alg, read, err := varint.Unpack8(bytes)
-	hash.Algorithm = Algorithm(alg)
-	if err != nil {
-		return nil, 0, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
-	}
-	// TODO: check if length is correct
-	hash.Sum = bytes[read:]
-	return hash, 0, nil
-}
-
-func (h *Hash) Bytes() []byte {
-	return append(varint.Pack8(uint8(h.Algorithm)), h.Sum...)
-}
-
-func FromSafe64(s string) (*Hash, error) {
-	bytes, err := base64.RawURLEncoding.DecodeString(s)
-	if err != nil {
-		return nil, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
-	}
-	hash, _, err := FromBytes(bytes)
-	return hash, err
-}
-
-func (h *Hash) Safe64() string {
-	return base64.RawURLEncoding.EncodeToString(h.Bytes())
-}
-
-func FromHex(s string) (*Hash, error) {
-	bytes, err := hex.DecodeString(s)
-	if err != nil {
-		return nil, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
-	}
-	hash, _, err := FromBytes(bytes)
-	return hash, err
-}
-
-func (h *Hash) Hex() string {
-	return hex.EncodeToString(h.Bytes())
-}
-
-func (h *Hash) Equal(other *Hash) bool {
-	if h.Algorithm != other.Algorithm {
-		return false
-	}
-	return bytes.Equal(h.Sum, other.Sum)
-}
-
-func Sum(data []byte, alg Algorithm) *Hash {
-	hasher := alg.New()
-	hasher.Write(data)
-	return &Hash{
-		Algorithm: alg,
-		Sum:       hasher.Sum(nil),
-	}
-}
-
-func SumString(data string, alg Algorithm) *Hash {
-	hasher := alg.New()
-	io.WriteString(hasher, data)
-	return &Hash{
-		Algorithm: alg,
-		Sum:       hasher.Sum(nil),
-	}
-}
-
-func SumReader(reader io.Reader, alg Algorithm) (*Hash, error) {
-	hasher := alg.New()
-	_, err := io.Copy(hasher, reader)
-	if err != nil {
-		return nil, err
-	}
-	return &Hash{
-		Algorithm: alg,
-		Sum:       hasher.Sum(nil),
-	}, nil
-}
-
-func SumAndCompare(data []byte, other Hash) (bool, *Hash) {
-	newHash := Sum(data, other.Algorithm)
-	return other.Equal(newHash), newHash
-}
-
-func SumReaderAndCompare(reader io.Reader, other Hash) (bool, *Hash, error) {
-	newHash, err := SumReader(reader, other.Algorithm)
-	if err != nil {
-		return false, nil, err
-	}
-	return other.Equal(newHash), newHash, nil
-}
-
-func RecommendedAlg(strengthInBits uint16) Algorithm {
-	strengthInBytes := uint8(strengthInBits / 8)
-	if strengthInBits%8 != 0 {
-		strengthInBytes++
-	}
-	if strengthInBytes == 0 {
-		strengthInBytes = uint8(0xFF)
-	}
-	chosenAlg := orderedByRecommendation[0]
-	for _, alg := range orderedByRecommendation {
-		strength := alg.SecurityStrength()
-		if strength < strengthInBytes {
-			break
-		}
-		chosenAlg = alg
-		if strength == strengthInBytes {
-			break
-		}
-	}
-	return chosenAlg
-}

@@ -1,82 +0,0 @@
-package hash
-
-import (
-	"bytes"
-	"testing"
-)
-
-var (
-	testEmpty = []byte("")
-	testFox   = []byte("The quick brown fox jumps over the lazy dog")
-)
-
-func testAlgorithm(t *testing.T, alg Algorithm, emptyHex, foxHex string) {
-
-	var err error
-
-	// testEmpty
-	hash := Sum(testEmpty, alg)
-	if err != nil {
-		t.Errorf("test Sum %s (empty): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != emptyHex {
-		t.Errorf("test Sum %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
-	}
-
-	// testFox
-	hash = Sum(testFox, alg)
-	if err != nil {
-		t.Errorf("test Sum %s (fox): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != foxHex {
-		t.Errorf("test Sum %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
-	}
-
-	// testEmpty
-	hash = SumString(string(testEmpty), alg)
-	if err != nil {
-		t.Errorf("test SumString %s (empty): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != emptyHex {
-		t.Errorf("test SumString %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
-	}
-
-	// testFox
-	hash = SumString(string(testFox), alg)
-	if err != nil {
-		t.Errorf("test SumString %s (fox): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != foxHex {
-		t.Errorf("test SumString %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
-	}
-
-	// testEmpty
-	hash, err = SumReader(bytes.NewReader(testEmpty), alg)
-	if err != nil {
-		t.Errorf("test SumReader %s (empty): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != emptyHex {
-		t.Errorf("test SumReader %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
-	}
-
-	// testFox
-	hash, err = SumReader(bytes.NewReader(testFox), alg)
-	if err != nil {
-		t.Errorf("test SumReader %s (fox): error occured: %s", alg.String(), err)
-	}
-	if hash.Hex()[2:] != foxHex {
-		t.Errorf("test SumReader %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
-	}
-
-}
-
-func TestHash(t *testing.T) {
-	testAlgorithm(t, SHA2_512,
-		"cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
-		"07e547d9586f6a73f73fbac0435ed76951218fb7d0c8d788a309d785436bbb642e93a252a954f23912547d1e8a3b5ed6e1bfd7097821233fa0538f3db854fee6",
-	)
-	testAlgorithm(t, SHA3_512,
-		"a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26",
-		"01dedd5de4ef14642445ba5f5b97c15e47b9ad931326e4b0727cd94cefc44fff23f07bf543139939b49128caf436dc1bdee54fcb24023a08d9403f9b4bf0d450",
-	)
-}

@@ -1,28 +0,0 @@
-package hash
-
-import (
-	"hash"
-
-	"golang.org/x/crypto/blake2b"
-	"golang.org/x/crypto/blake2s"
-)
-
-func NewBlake2s256() hash.Hash {
-	h, _ := blake2s.New256(nil)
-	return h
-}
-
-func NewBlake2b256() hash.Hash {
-	h, _ := blake2b.New256(nil)
-	return h
-}
-
-func NewBlake2b384() hash.Hash {
-	h, _ := blake2b.New384(nil)
-	return h
-}
-
-func NewBlake2b512() hash.Hash {
-	h, _ := blake2b.New512(nil)
-	return h
-}
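The whole crypto/hash package — algorithm registry, self-describing Hash type, Blake2 helpers and all tests — is deleted, with no replacement visible in this diff. Callers most likely fall back to the standard library directly; a minimal sketch of that, with no claim that this is what portbase actually does:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func main() {
	// Plain stdlib hashing; the deleted package's Sum(data, SHA2_256)
	// reduced to this plus a self-describing algorithm-ID prefix.
	sum := sha256.Sum256([]byte("The quick brown fox jumps over the lazy dog"))
	fmt.Println(hex.EncodeToString(sum[:]))
}
```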
@@ -10,23 +10,10 @@ import (
 )

 var (
-	rngFeeder = make(chan []byte, 0)
+	rngFeeder      = make(chan []byte)
 	minFeedEntropy config.IntOption
 )

-func init() {
-	config.Register(&config.Option{
-		Name:            "Minimum Feed Entropy",
-		Key:             "random/min_feed_entropy",
-		Description:     "The minimum amount of entropy before a entropy source is feed to the RNG, in bits.",
-		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
-		OptType:         config.OptTypeInt,
-		DefaultValue:    256,
-		ValidationRegex: "^[0-9]{3,5}$",
-	})
-	minFeedEntropy = config.Concurrent.GetAsInt("random/min_feed_entropy", 256)
-}
-
 // The Feeder is used to feed entropy to the RNG.
 type Feeder struct {
 	input chan *entropyData

@@ -43,7 +30,7 @@ type entropyData struct {
 // NewFeeder returns a new entropy Feeder.
 func NewFeeder() *Feeder {
 	new := &Feeder{
-		input:        make(chan *entropyData, 0),
+		input:        make(chan *entropyData),
 		needsEntropy: abool.NewBool(true),
 		buffer:       container.New(),
 	}
@@ -25,28 +25,6 @@ var (
 type reader struct{}

 func init() {
-	config.Register(&config.Option{
-		Name:            "Reseed after x seconds",
-		Key:             "random/reseed_after_seconds",
-		Description:     "Number of seconds until reseed",
-		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
-		OptType:         config.OptTypeInt,
-		DefaultValue:    360, // ten minutes
-		ValidationRegex: "^[1-9][0-9]{1,5}$",
-	})
-	reseedAfterSeconds = config.Concurrent.GetAsInt("random/reseed_after_seconds", 360)
-
-	config.Register(&config.Option{
-		Name:            "Reseed after x bytes",
-		Key:             "random/reseed_after_bytes",
-		Description:     "Number of fetched bytes until reseed",
-		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
-		OptType:         config.OptTypeInt,
-		DefaultValue:    1000000, // one megabyte
-		ValidationRegex: "^[1-9][0-9]{2,9}$",
-	})
-	reseedAfterBytes = config.GetAsInt("random/reseed_after_bytes", 1000000)
-
 	Reader = reader{}
 }
@@ -55,7 +33,7 @@ func checkEntropy() (err error) {
 		return errors.New("RNG is not ready yet")
 	}
 	if rngBytesRead > reseedAfterBytes() ||
-		int64(time.Now().Sub(rngLastFeed).Seconds()) > reseedAfterSeconds() {
+		int64(time.Since(rngLastFeed).Seconds()) > reseedAfterSeconds() {
 		select {
 		case r := <-rngFeeder:
 			rng.Reseed(r)
@@ -19,13 +19,15 @@ var (
 	rngReady        = false
 	rngCipherOption config.StringOption

-	shutdownSignal = make(chan struct{}, 0)
+	shutdownSignal = make(chan struct{})
 )

 func init() {
-	modules.Register("random", prep, Start, stop, "base")
+	modules.Register("random", prep, Start, nil, "base")
+}

-	config.Register(&config.Option{
+func prep() error {
+	err := config.Register(&config.Option{
 		Name:        "RNG Cipher",
 		Key:         "random/rng_cipher",
 		Description: "Cipher to use for the Fortuna RNG. Requires restart to take effect.",

@@ -35,10 +37,53 @@ func init() {
 		DefaultValue:    "aes",
 		ValidationRegex: "^(aes|serpent)$",
 	})
+	if err != nil {
+		return err
+	}
 	rngCipherOption = config.GetAsString("random/rng_cipher", "aes")
-}

-func prep() error {
+	err = config.Register(&config.Option{
+		Name:            "Minimum Feed Entropy",
+		Key:             "random/min_feed_entropy",
+		Description:     "The minimum amount of entropy before a entropy source is feed to the RNG, in bits.",
+		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
+		OptType:         config.OptTypeInt,
+		DefaultValue:    256,
+		ValidationRegex: "^[0-9]{3,5}$",
+	})
+	if err != nil {
+		return err
+	}
+	minFeedEntropy = config.Concurrent.GetAsInt("random/min_feed_entropy", 256)
+
+	err = config.Register(&config.Option{
+		Name:            "Reseed after x seconds",
+		Key:             "random/reseed_after_seconds",
+		Description:     "Number of seconds until reseed",
+		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
+		OptType:         config.OptTypeInt,
+		DefaultValue:    360, // ten minutes
+		ValidationRegex: "^[1-9][0-9]{1,5}$",
+	})
+	if err != nil {
+		return err
+	}
+	reseedAfterSeconds = config.Concurrent.GetAsInt("random/reseed_after_seconds", 360)
+
+	err = config.Register(&config.Option{
+		Name:            "Reseed after x bytes",
+		Key:             "random/reseed_after_bytes",
+		Description:     "Number of fetched bytes until reseed",
+		ExpertiseLevel:  config.ExpertiseLevelDeveloper,
+		OptType:         config.OptTypeInt,
+		DefaultValue:    1000000, // one megabyte
+		ValidationRegex: "^[1-9][0-9]{2,9}$",
+	})
+	if err != nil {
+		return err
+	}
+	reseedAfterBytes = config.GetAsInt("random/reseed_after_bytes", 1000000)
+
+	return nil
 }

@@ -73,7 +118,3 @@ func Start() (err error) {

 	return nil
 }
-
-func stop() error {
-	return nil
-}
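Moving the config registrations from init() into prep() lets registration failures surface as module startup errors instead of being silently dropped at import time. The repeated register-then-check blocks above reduce to a loop of this shape — config.Option is trimmed to two fields here and the register callback is a stand-in, so this is a sketch, not the package's code:

```go
package example

import "fmt"

// option stands in for config.Option; only the fields needed for the
// sketch are kept.
type option struct{ Name, Key string }

// prepAll registers each option in turn and aborts on the first failure,
// mirroring the new prep()'s repeated `if err != nil { return err }` blocks.
func prepAll(register func(*option) error) error {
	opts := []*option{
		{Name: "RNG Cipher", Key: "random/rng_cipher"},
		{Name: "Minimum Feed Entropy", Key: "random/min_feed_entropy"},
		{Name: "Reseed after x seconds", Key: "random/reseed_after_seconds"},
		{Name: "Reseed after x bytes", Key: "random/reseed_after_bytes"},
	}
	for _, o := range opts {
		if err := register(o); err != nil {
			return fmt.Errorf("failed to register %s: %s", o.Key, err)
		}
	}
	return nil
}
```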
@@ -7,21 +7,34 @@ import (
 )

 func init() {
-	prep()
-	Start()
+	err := prep()
+	if err != nil {
+		panic(err)
+	}
+
+	err = Start()
+	if err != nil {
+		panic(err)
+	}
 }

 func TestRNG(t *testing.T) {
 	key := make([]byte, 16)

-	config.SetConfigOption("random.rng_cipher", "aes")
-	_, err := newCipher(key)
+	err := config.SetConfigOption("random/rng_cipher", "aes")
+	if err != nil {
+		t.Errorf("failed to set random/rng_cipher config: %s", err)
+	}
+	_, err = newCipher(key)
 	if err != nil {
 		t.Errorf("failed to create aes cipher: %s", err)
 	}
 	rng.Reseed(key)

-	config.SetConfigOption("random.rng_cipher", "serpent")
+	err = config.SetConfigOption("random/rng_cipher", "serpent")
+	if err != nil {
+		t.Errorf("failed to set random/rng_cipher config: %s", err)
+	}
 	_, err = newCipher(key)
 	if err != nil {
 		t.Errorf("failed to create serpent cipher: %s", err)
@@ -55,7 +55,10 @@ func main() {
 	switch os.Args[1] {
 	case "fortuna":

-		random.Start()
+		err := random.Start()
+		if err != nil {
+			panic(err)
+		}

 		for {
 			b, err := random.Bytes(64)
@@ -1,8 +1,6 @@
 package accessor

 import (
-	"fmt"
-
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 )

@@ -23,23 +21,9 @@ func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor {
 func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
 	result := gjson.GetBytes(*ja.json, key)
 	if result.Exists() {
-		switch value.(type) {
-		case string:
-			if result.Type != gjson.String {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
-			if result.Type != gjson.Number {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case bool:
-			if result.Type != gjson.True && result.Type != gjson.False {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case []string:
-			if !result.IsArray() {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		}
+		err := checkJSONValueType(result, key, value)
+		if err != nil {
+			return err
+		}
 	}
@@ -23,23 +23,9 @@ func NewJSONAccessor(json *string) *JSONAccessor {
 func (ja *JSONAccessor) Set(key string, value interface{}) error {
 	result := gjson.Get(*ja.json, key)
 	if result.Exists() {
-		switch value.(type) {
-		case string:
-			if result.Type != gjson.String {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
-			if result.Type != gjson.Number {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case bool:
-			if result.Type != gjson.True && result.Type != gjson.False {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		case []string:
-			if !result.IsArray() {
-				return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
-			}
-		}
+		err := checkJSONValueType(result, key, value)
+		if err != nil {
+			return err
+		}
 	}

@@ -51,6 +37,28 @@ func (ja *JSONAccessor) Set(key string, value interface{}) error {
 	return nil
 }

+func checkJSONValueType(jsonValue gjson.Result, key string, value interface{}) error {
+	switch value.(type) {
+	case string:
+		if jsonValue.Type != gjson.String {
+			return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
+		}
+	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
+		if jsonValue.Type != gjson.Number {
+			return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
+		}
+	case bool:
+		if jsonValue.Type != gjson.True && jsonValue.Type != gjson.False {
+			return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
+		}
+	case []string:
+		if !jsonValue.IsArray() {
+			return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
+		}
+	}
+	return nil
+}
+
 // Get returns the value found by the given json key and whether it could be successfully extracted.
 func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) {
 	result := gjson.Get(*ja.json, key)
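Extracting checkJSONValueType removes an identical 18-line type-check switch from both accessors. A small runnable demonstration of the mismatch it guards against, using the real gjson API (the document JSON is invented for the demo):

```go
package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

// The JSON field "count" is a number, so assigning a Go string to it must
// be rejected — exactly the check the shared helper performs.
func main() {
	result := gjson.Get(`{"name":"fox","count":3}`, "count")
	var value interface{} = "three" // deliberately the wrong kind
	if _, ok := value.(string); ok && result.Type != gjson.String {
		fmt.Printf("tried to set field count (%s) to a %T value\n", result.Type.String(), value)
	}
}
```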
@@ -160,10 +160,7 @@ func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) {
 // Exists returns the whether the given key exists.
 func (sa *StructAccessor) Exists(key string) bool {
 	field := sa.object.FieldByName(key)
-	if field.IsValid() {
-		return true
-	}
-	return false
+	return field.IsValid()
 }

 // Type returns the accessor type as a string.
@@ -1,3 +1,4 @@
+//nolint:maligned,unparam
 package accessor

 import (
@@ -30,12 +30,12 @@ type Controller struct {
 }

 // newController creates a new controller for a storage.
-func newController(storageInt storage.Interface) (*Controller, error) {
+func newController(storageInt storage.Interface) *Controller {
 	return &Controller{
 		storage:     storageInt,
 		migrating:   abool.NewBool(false),
 		hibernating: abool.NewBool(false),
-	}, nil
+	}
 }

 // ReadOnly returns whether the storage is read only.

@@ -221,7 +221,7 @@ func (c *Controller) MaintainThorough() error {

 // Shutdown shuts down the storage.
 func (c *Controller) Shutdown() error {
-	// aquire full locks
+	// acquire full locks
 	c.readLock.Lock()
 	defer c.readLock.Unlock()
 	c.writeLock.Lock()
@@ -51,12 +51,7 @@ func getController(name string) (*Controller, error) {
 		return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
 	}

-	// create controller
-	controller, err = newController(storageInt)
-	if err != nil {
-		return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
-	}
-
+	controller = newController(storageInt)
 	controllers[name] = controller
 	return controller, nil
 }

@@ -87,11 +82,7 @@ func InjectDatabase(name string, storageInt storage.Interface) (*Controller, error) {
 		return nil, fmt.Errorf(`database not of type "injected"`)
 	}

-	controller, err := newController(storageInt)
-	if err != nil {
-		return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
-	}
-
+	controller := newController(storageInt)
 	controllers[name] = controller
 	return controller, nil
 }
@@ -97,7 +97,7 @@ func testDatabase(t *testing.T, storageType string) {
 	}

 	cnt := 0
-	for _ = range it.Next {
+	for range it.Next {
 		cnt++
 	}
 	if it.Err() != nil {

@@ -124,7 +124,7 @@ func TestDatabaseSystem(t *testing.T) {
 	go func() {
 		time.Sleep(10 * time.Second)
 		fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====")
-		pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
+		_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
 		os.Exit(1)
 	}()
@@ -38,7 +38,7 @@ func start() error {
 		return err
 	}

-	startMaintainer()
+	registerMaintenanceTasks()
 	return nil
 }
@@ -1,45 +1,37 @@
 package dbmodule

 import (
+	"context"
 	"time"

 	"github.com/safing/portbase/database"
 	"github.com/safing/portbase/log"
+	"github.com/safing/portbase/modules"
 )

-var (
-	maintenanceShortTickDuration = 10 * time.Minute
-	maintenanceLongTickDuration  = 1 * time.Hour
-)
-
-func startMaintainer() {
-	module.AddWorkers(1)
-	go maintenanceWorker()
+func registerMaintenanceTasks() {
+	module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute)
+	module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
+	module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
 }

-func maintenanceWorker() {
-	ticker := time.NewTicker(maintenanceShortTickDuration)
-	longTicker := time.NewTicker(maintenanceLongTickDuration)
-
-	for {
-		select {
-		case <-ticker.C:
-			err := database.Maintain()
-			if err != nil {
-				log.Errorf("database: maintenance error: %s", err)
-			}
-		case <-longTicker.C:
-			err := database.MaintainRecordStates()
-			if err != nil {
-				log.Errorf("database: record states maintenance error: %s", err)
-			}
-			err = database.MaintainThorough()
-			if err != nil {
-				log.Errorf("database: thorough maintenance error: %s", err)
-			}
-		case <-module.ShuttingDown():
-			module.FinishWorker()
-			return
-		}
+func maintainBasic(ctx context.Context, task *modules.Task) {
+	err := database.Maintain()
+	if err != nil {
+		log.Errorf("database: maintenance error: %s", err)
 	}
 }
+
+func maintainThorough(ctx context.Context, task *modules.Task) {
+	err := database.MaintainThorough()
+	if err != nil {
+		log.Errorf("database: thorough maintenance error: %s", err)
+	}
+}
+
+func maintainRecords(ctx context.Context, task *modules.Task) {
+	err := database.MaintainRecordStates()
+	if err != nil {
+		log.Errorf("database: record states maintenance error: %s", err)
+	}
+}
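Maintenance moves from a hand-rolled ticker goroutine onto the modules task scheduler, and the worker bookkeeping (AddWorkers/FinishWorker/ShuttingDown) disappears with it. A sketch of registering one more job through the same NewTask/Repeat/MaxDelay chain — the task name and body are illustrative, and `module` is the package-level *modules.Module this file already relies on:

```go
package dbmodule

import (
	"context"
	"time"

	"github.com/safing/portbase/modules"
)

// registerExampleTask schedules a hypothetical extra repeating job using
// the same chain as registerMaintenanceTasks above.
func registerExampleTask() {
	module.NewTask("example maintenance", func(ctx context.Context, task *modules.Task) {
		// periodic work goes here; return promptly if ctx is cancelled
	}).Repeat(30 * time.Minute).MaxDelay(30 * time.Minute)
}
```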
@@ -1,10 +0,0 @@
-package dbutils
-
-type Meta struct {
-	Created   int64 `json:"c,omitempty" bson:"c,omitempty"`
-	Modified  int64 `json:"m,omitempty" bson:"m,omitempty"`
-	Expires   int64 `json:"e,omitempty" bson:"e,omitempty"`
-	Deleted   int64 `json:"d,omitempty" bson:"d,omitempty"`
-	Secret    bool  `json:"s,omitempty" bson:"s,omitempty"` // secrets must not be sent to clients, only synced between cores
-	Cronjewel bool  `json:"j,omitempty" bson:"j,omitempty"` // crownjewels must never leave the instance
-}
@@ -23,7 +23,7 @@ type RegisteredHook struct {
 	h Hook
 }

-// RegisterHook registeres a hook for records matching the given query in the database.
+// RegisterHook registers a hook for records matching the given query in the database.
 func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
 	_, err := q.Check()
 	if err != nil {
@@ -80,7 +80,7 @@ func (i *Interface) checkCache(key string) (record.Record, bool) {

 func (i *Interface) updateCache(r record.Record) {
 	if i.cache != nil {
-		i.cache.Set(r.Key(), r)
+		_ = i.cache.Set(r.Key(), r)
 	}
 }
@@ -69,11 +69,17 @@ func MaintainRecordStates() error {
 		}

 		for _, r := range toDelete {
-			c.storage.Delete(r.DatabaseKey())
+			err := c.storage.Delete(r.DatabaseKey())
+			if err != nil {
+				return err
+			}
 		}
 		for _, r := range toExpire {
 			r.Meta().Delete()
-			return c.Put(r)
+			err := c.Put(r)
+			if err != nil {
+				return err
+			}
 		}

 	}
@@ -38,7 +38,7 @@ func (c *andCond) check() (err error) {
 }

 func (c *andCond) string() string {
-	var all []string
+	all := make([]string, 0, len(c.conditions))
 	for _, cond := range c.conditions {
 		all = append(all, cond.string())
 	}
@@ -28,7 +28,7 @@ func newIntCondition(key string, operator uint8, value interface{}) *intCondition {
 	case int32:
 		parsedValue = int64(v)
 	case int64:
-		parsedValue = int64(v)
+		parsedValue = v
 	case uint:
 		parsedValue = int64(v)
 	case uint8:
@@ -38,7 +38,7 @@ func (c *orCond) check() (err error) {
 }

 func (c *orCond) string() string {
-	var all []string
+	all := make([]string, 0, len(c.conditions))
 	for _, cond := range c.conditions {
 		all = append(all, cond.string())
 	}
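Both andCond.string and orCond.string get the prealloc treatment: the final length is known up front, so reserving capacity avoids intermediate re-allocations during append. The same pattern in isolation (function name and joiner are illustrative):

```go
package example

import (
	"fmt"
	"strings"
)

// describeAll preallocates the exact capacity before appending, the same
// fix applied to andCond.string and orCond.string above.
func describeAll(conds []fmt.Stringer) string {
	all := make([]string, 0, len(conds)) // capacity known up front
	for _, c := range conds {
		all = append(all, c.String())
	}
	return "(" + strings.Join(all, " and ") + ")"
}
```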
@@ -205,7 +205,7 @@ func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) {
 	for {
 		if !expectingMore && rootCondition && remainingSnippets() == 0 {
 			// advance snippetsPos by one, as it will be set back by 1
-			getSnippet()
+			getSnippet() //nolint:errcheck
 			if len(conditions) == 1 {
 				return conditions[0], nil
 			}
@@ -330,7 +330,7 @@ func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error)) (Condition, error) {
 }

 var (
-	escapeReplacer = regexp.MustCompile("\\\\([^\\\\])")
+	escapeReplacer = regexp.MustCompile(`\\([^\\])`)
 )

 // prepToken removes surrounding parenthesis and escape characters.
@@ -10,36 +10,36 @@ import (
 func TestExtractSnippets(t *testing.T) {
 	text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ `
 	result1 := []*snippet{
-		&snippet{text: "query", globalPosition: 1},
-		&snippet{text: "test:", globalPosition: 7},
-		&snippet{text: "where", globalPosition: 13},
-		&snippet{text: "(", globalPosition: 19},
-		&snippet{text: "bananas", globalPosition: 21},
-		&snippet{text: ">", globalPosition: 31},
-		&snippet{text: "100", globalPosition: 33},
-		&snippet{text: "and", globalPosition: 37},
-		&snippet{text: "monkeys.#", globalPosition: 41},
-		&snippet{text: "<=", globalPosition: 51},
-		&snippet{text: "12", globalPosition: 54},
-		&snippet{text: ")", globalPosition: 58},
-		&snippet{text: "or", globalPosition: 59},
-		&snippet{text: "(", globalPosition: 61},
-		&snippet{text: "coconuts", globalPosition: 62},
-		&snippet{text: "<", globalPosition: 71},
-		&snippet{text: "10", globalPosition: 73},
-		&snippet{text: "and", globalPosition: 76},
-		&snippet{text: "area", globalPosition: 82},
-		&snippet{text: ">", globalPosition: 87},
-		&snippet{text: "50", globalPosition: 89},
-		&snippet{text: ")", globalPosition: 91},
-		&snippet{text: "or", globalPosition: 93},
-		&snippet{text: "name", globalPosition: 96},
-		&snippet{text: "sameas", globalPosition: 101},
-		&snippet{text: "Julian", globalPosition: 108},
-		&snippet{text: "or", globalPosition: 115},
-		&snippet{text: "name", globalPosition: 118},
-		&snippet{text: "matches", globalPosition: 123},
-		&snippet{text: "^King ", globalPosition: 131},
+		{text: "query", globalPosition: 1},
+		{text: "test:", globalPosition: 7},
+		{text: "where", globalPosition: 13},
+		{text: "(", globalPosition: 19},
+		{text: "bananas", globalPosition: 21},
+		{text: ">", globalPosition: 31},
+		{text: "100", globalPosition: 33},
+		{text: "and", globalPosition: 37},
+		{text: "monkeys.#", globalPosition: 41},
+		{text: "<=", globalPosition: 51},
+		{text: "12", globalPosition: 54},
+		{text: ")", globalPosition: 58},
+		{text: "or", globalPosition: 59},
+		{text: "(", globalPosition: 61},
+		{text: "coconuts", globalPosition: 62},
+		{text: "<", globalPosition: 71},
+		{text: "10", globalPosition: 73},
+		{text: "and", globalPosition: 76},
+		{text: "area", globalPosition: 82},
+		{text: ">", globalPosition: 87},
+		{text: "50", globalPosition: 89},
+		{text: ")", globalPosition: 91},
+		{text: "or", globalPosition: 93},
+		{text: "name", globalPosition: 96},
+		{text: "sameas", globalPosition: 101},
+		{text: "Julian", globalPosition: 108},
+		{text: "or", globalPosition: 115},
+		{text: "name", globalPosition: 118},
+		{text: "matches", globalPosition: 123},
+		{text: "^King ", globalPosition: 131},
 	}

 	snippets, err := extractSnippets(text1)
@@ -96,10 +96,7 @@ func (q *Query) IsChecked() bool {

 // MatchesKey checks whether the query matches the supplied database key (key without database prefix).
 func (q *Query) MatchesKey(dbKey string) bool {
-	if !strings.HasPrefix(dbKey, q.dbKeyPrefix) {
-		return false
-	}
-	return true
+	return strings.HasPrefix(dbKey, q.dbKeyPrefix)
 }

 // MatchesRecord checks whether the query matches the supplied database record (value only).
@@ -1,3 +1,4 @@
+//nolint:unparam
 package query

 import (
@@ -148,10 +148,10 @@ func registryWriter() {
 		select {
 		case <-time.After(1 * time.Hour):
 			if writeRegistrySoon.SetToIf(true, false) {
-				saveRegistry(true)
+				_ = saveRegistry(true)
 			}
 		case <-shutdownSignal:
-			saveRegistry(true)
+			_ = saveRegistry(true)
 			return
 		}
 	}
@@ -21,7 +21,7 @@ type Badger struct {
 }

 func init() {
-	storage.Register("badger", NewBadger)
+	_ = storage.Register("badger", NewBadger)
 }

 // NewBadger opens/creates a badger database.
@@ -190,7 +190,7 @@ func (b *Badger) Injected() bool {

 // Maintain runs a light maintenance operation on the database.
 func (b *Badger) Maintain() error {
-	b.db.RunValueLogGC(0.7)
+	_ = b.db.RunValueLogGC(0.7)
 	return nil
 }
@@ -1,3 +1,4 @@
+//nolint:unparam,maligned
 package badger

 import (

@@ -13,7 +14,7 @@ import (

 type TestRecord struct {
 	record.Base
-	lock sync.Mutex
+	sync.Mutex
 	S    string
 	I    int
 	I8   int8

@@ -30,12 +31,6 @@ type TestRecord struct {
 	B    bool
 }

-func (tr *TestRecord) Lock() {
-}
-
-func (tr *TestRecord) Unlock() {
-}
-
 func TestBadger(t *testing.T) {
 	testDir, err := ioutil.TempDir("", "testing-")
 	if err != nil {

@@ -98,7 +93,7 @@ func TestBadger(t *testing.T) {
 		t.Fatal(err)
 	}
 	cnt := 0
-	for _ = range it.Next {
+	for range it.Next {
 		cnt++
 	}
 	if it.Err() != nil {
@@ -26,7 +26,7 @@ type BBolt struct {
 }

 func init() {
-	storage.Register("bbolt", NewBBolt)
+	_ = storage.Register("bbolt", NewBBolt)
 }

 // NewBBolt opens/creates a bbolt database.
@@ -1,3 +1,4 @@
+//nolint:unparam,maligned
 package bbolt

 import (

@@ -13,7 +14,7 @@ import (

 type TestRecord struct {
 	record.Base
-	lock sync.Mutex
+	sync.Mutex
 	S    string
 	I    int
 	I8   int8

@@ -30,12 +31,6 @@ type TestRecord struct {
 	B    bool
 }

-func (tr *TestRecord) Lock() {
-}
-
-func (tr *TestRecord) Unlock() {
-}
-
 func TestBadger(t *testing.T) {
 	testDir, err := ioutil.TempDir("", "testing-")
 	if err != nil {

@@ -126,7 +121,7 @@ func TestBadger(t *testing.T) {
 		t.Fatal(err)
 	}
 	cnt := 0
-	for _ = range it.Next {
+	for range it.Next {
 		cnt++
 	}
 	if it.Err() != nil {
@@ -35,7 +35,7 @@ type FSTree struct {
 }

 func init() {
-	storage.Register("fstree", NewFSTree)
+	_ = storage.Register("fstree", NewFSTree)
 }

 // NewFSTree returns a (new) FSTree database.
@@ -160,15 +160,14 @@ func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
 	}
 	fileInfo, err := os.Stat(walkPrefix)
 	var walkRoot string
-	if err == nil {
-		if fileInfo.IsDir() {
-			walkRoot = walkPrefix
-		} else {
-			walkRoot = filepath.Dir(walkPrefix)
-		}
-	} else if os.IsNotExist(err) {
+	switch {
+	case err == nil && fileInfo.IsDir():
+		walkRoot = walkPrefix
+	case err == nil:
+		walkRoot = filepath.Dir(walkPrefix)
+	case os.IsNotExist(err):
 		walkRoot = filepath.Dir(walkPrefix)
-	} else {
+	default: // err != nil
 		return nil, fmt.Errorf("fstree: could not stat query root %s: %s", walkPrefix, err)
 	}
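The fstree change is another ifElseChain-to-switch conversion, this time over a stat result: a directory is its own walk root, a plain file or a missing path falls back to its parent directory, and any other stat error aborts. Extracted into a standalone function, the same logic reads like this (a sketch; the function name and error prefix are illustrative):

```go
package example

import (
	"fmt"
	"os"
	"path/filepath"
)

// walkRootFor mirrors the reworked fstree logic above.
func walkRootFor(walkPrefix string) (string, error) {
	fileInfo, err := os.Stat(walkPrefix)
	switch {
	case err == nil && fileInfo.IsDir():
		return walkPrefix, nil // a directory is its own walk root
	case err == nil:
		return filepath.Dir(walkPrefix), nil // a file: walk its parent
	case os.IsNotExist(err):
		return filepath.Dir(walkPrefix), nil // missing: walk the parent
	default: // any other stat error
		return "", fmt.Errorf("could not stat query root %s: %s", walkPrefix, err)
	}
}
```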
@@ -279,7 +278,7 @@ func writeFile(filename string, data []byte, perm os.FileMode) error {
 	if err != nil {
 		return err
 	}
-	defer t.Cleanup()
+	defer t.Cleanup() //nolint:errcheck

 	// Set permissions before writing data, in case the data is sensitive.
 	if !onWindows {
@@ -15,7 +15,7 @@ type Sinkhole struct {
 }

 func init() {
-	storage.Register("sinkhole", NewSinkhole)
+	_ = storage.Register("sinkhole", NewSinkhole)
 }

 // NewSinkhole creates a dummy database.
@@ -1 +0,0 @@
-package kvops
@@ -25,7 +25,6 @@ const (

 // define errors
 var errNoMoreSpace = errors.New("dsd: no more space left after reading dsd type")
-var errUnknownType = errors.New("dsd: tried to unpack unknown type")
 var errNotImplemented = errors.New("dsd: this type is not yet implemented")

 // Load loads an dsd structured data blob into the given interface.
@@ -58,7 +57,8 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) {
 			return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
 		}
 		return t, nil
-	// case BSON:
+	case BSON:
+		return nil, errNotImplemented
 		// err := bson.Unmarshal(data[read:], t)
 		// if err != nil {
 		// 	return nil, err
@@ -92,7 +92,7 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
 		}
 	}

-	f := varint.Pack8(uint8(format))
+	f := varint.Pack8(format)
 	var data []byte
 	var err error
 	switch format {
@@ -106,7 +106,8 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
 		if err != nil {
 			return nil, err
 		}
-	// case BSON:
+	case BSON:
+		return nil, errNotImplemented
 		// data, err = bson.Marshal(t)
 		// if err != nil {
 		// 	return nil, err
@@ -1,3 +1,4 @@
+//nolint:maligned,unparam,gocyclo
 package dsd

 import (
@@ -97,8 +98,7 @@ func TestConversion(t *testing.T) {
 	}

 	bString := "b"
-	var bBytes byte
-	bBytes = 0x02
+	var bBytes byte = 0x02

 	complexSubject := ComplexTestStruct{
 		-1,
@@ -1,3 +1,4 @@
+//nolint:nakedret,unconvert
 package dsd

 import (
@@ -9,5 +9,5 @@ var (

 func init() {
 	flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
-	flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,firewall=debug")
+	flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug")
 }
@@ -7,9 +7,12 @@ import (

 var counter uint16

-const maxCount uint16 = 999
+const (
+	maxCount   uint16 = 999
+	timeFormat string = "060102 15:04:05.000"
+)

-func (s severity) String() string {
+func (s Severity) String() string {
 	switch s {
 	case TraceLevel:
 		return "TRAC"

@@ -41,25 +44,25 @@ func formatLine(line *logLine, duplicates uint64, useColor bool) string {

 	var fLine string
 	if line.line == 0 {
-		fLine = fmt.Sprintf("%s%s ? %s %s %03d%s%s %s", colorStart, line.timestamp.Format("060102 15:04:05.000"), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
+		fLine = fmt.Sprintf("%s%s ? %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
 	} else {
 		fLen := len(line.file)
 		fPartStart := fLen - 10
 		if fPartStart < 0 {
 			fPartStart = 0
 		}
-		fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format("060102 15:04:05.000"), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
+		fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
 	}

-	if line.trace != nil {
+	if line.tracer != nil {
 		// append full trace time
-		if len(line.trace.actions) > 0 {
-			fLine += fmt.Sprintf(" Σ=%s", line.timestamp.Sub(line.trace.actions[0].timestamp))
+		if len(line.tracer.logs) > 0 {
+			fLine += fmt.Sprintf(" Σ=%s", line.timestamp.Sub(line.tracer.logs[0].timestamp))
 		}

 		// append all trace actions
 		var d time.Duration
-		for i, action := range line.trace.actions {
+		for i, action := range line.tracer.logs {
 			// set color
 			if useColor {
 				colorStart = action.level.color()

@@ -71,10 +74,10 @@ func formatLine(line *logLine, duplicates uint64, useColor bool) string {
 				fPartStart = 0
 			}
 			// format
-			if i == len(line.trace.actions)-1 { // last
+			if i == len(line.tracer.logs)-1 { // last
 				d = line.timestamp.Sub(action.timestamp)
 			} else {
-				d = line.trace.actions[i+1].timestamp.Sub(action.timestamp)
+				d = line.tracer.logs[i+1].timestamp.Sub(action.timestamp)
 			}
 			fLine += fmt.Sprintf("\n%s%19s %s:%03d %s %s%s %s", colorStart, d, action.file[fPartStart:], action.line, rightArrow, action.level.String(), colorEnd, action.msg)
 		}
@@ -16,7 +16,7 @@ const (
 	// colorWhite = "\033[37m"
 )

-func (s severity) color() string {
+func (s Severity) color() string {
 	switch s {
 	case DebugLevel:
 		return colorCyan
@@ -28,7 +28,7 @@ func init() {
 	colorsSupported = osdetail.EnableColorSupport()
 }

-func (s severity) color() string {
+func (s Severity) color() string {
 	if colorsSupported {
 		switch s {
 		case DebugLevel:
log/input.go (104 changes)
@@ -8,33 +8,18 @@ import (
	"time"
)

-func fastcheck(level severity) bool {
-	if pkgLevelsActive.IsSet() {
-		return true
-	}
-	if uint32(level) < atomic.LoadUint32(logLevel) {
-		return false
-	}
-	return true
-}
-
-func log(level severity, msg string, trace *ContextTracer) {
+func log(level Severity, msg string, tracer *ContextTracer) {

	if !started.IsSet() {
-		// a bit resouce intense, but keeps logs before logging started.
+		// a bit resource intense, but keeps logs before logging started.
		// FIXME: create option to disable logging
		go func() {
			<-startedSignal
-			log(level, msg, trace)
+			log(level, msg, tracer)
		}()
		return
	}

-	// check if level is enabled
-	if !pkgLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) {
-		return
-	}
-
	// get time
	now := time.Now()

@@ -53,26 +38,33 @@ func log(level severity, msg string, trace *ContextTracer) {

	// check if level is enabled for file or generally
	if pkgLevelsActive.IsSet() {
-		fileOnly := strings.Split(file, "/")
-		if len(fileOnly) < 2 {
+		pathSegments := strings.Split(file, "/")
+		if len(pathSegments) < 2 {
			// file too short for package levels
			return
		}
-		sev, ok := pkgLevels[fileOnly[len(fileOnly)-2]]
+		pkgLevelsLock.Lock()
+		severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]]
+		pkgLevelsLock.Unlock()
		if ok {
-			if level < sev {
+			if level < severity {
				return
			}
+		} else {
+			// no package level set, check against global level
+			if uint32(level) < atomic.LoadUint32(logLevel) {
+				return
+			}
		}
+	} else if uint32(level) < atomic.LoadUint32(logLevel) {
+		// no package levels set, check against global level
		return
	}

	// create log object
	log := &logLine{
		msg:       msg,
-		trace:     trace,
+		tracer:    tracer,
		level:     level,
		timestamp: now,
		file:      file,

@@ -83,93 +75,107 @@ func log(level severity, msg string, trace *ContextTracer) {
	select {
	case logBuffer <- log:
	default:
-		forceEmptyingOfBuffer <- true
+		forceEmptyingOfBuffer <- struct{}{}
		logBuffer <- log
	}

	// wake up writer if necessary
	if logsWaitingFlag.SetToIf(false, true) {
-		logsWaiting <- true
+		logsWaiting <- struct{}{}
	}

}

-func Tracef(things ...interface{}) {
-	if fastcheck(TraceLevel) {
-		log(TraceLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+func fastcheck(level Severity) bool {
+	if pkgLevelsActive.IsSet() {
+		return true
+	}
+	if uint32(level) >= atomic.LoadUint32(logLevel) {
+		return true
+	}
+	return false
+}
+
+// Trace is used to log tiny steps. Log traces to context if you can!
+func Trace(msg string) {
+	if fastcheck(TraceLevel) {
+		log(TraceLevel, msg, nil)
+	}
+}

-func Debugf(things ...interface{}) {
-	if fastcheck(DebugLevel) {
-		log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+// Tracef is used to log tiny steps. Log traces to context if you can!
+func Tracef(format string, things ...interface{}) {
+	if fastcheck(TraceLevel) {
+		log(TraceLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
+
+// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in themselves, but they might hint at a bigger problem.
+func Debug(msg string) {
+	if fastcheck(DebugLevel) {
+		log(DebugLevel, msg, nil)
+	}
+}

-func Infof(things ...interface{}) {
-	if fastcheck(InfoLevel) {
-		log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in themselves, but they might hint at a bigger problem.
+func Debugf(format string, things ...interface{}) {
+	if fastcheck(DebugLevel) {
+		log(DebugLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
+
+// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
+func Info(msg string) {
+	if fastcheck(InfoLevel) {
+		log(InfoLevel, msg, nil)
+	}
+}

-func Warningf(things ...interface{}) {
-	if fastcheck(WarningLevel) {
-		log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
+func Infof(format string, things ...interface{}) {
+	if fastcheck(InfoLevel) {
+		log(InfoLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
+
+// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
+func Warning(msg string) {
+	if fastcheck(WarningLevel) {
+		log(WarningLevel, msg, nil)
+	}
+}

-func Errorf(things ...interface{}) {
-	if fastcheck(ErrorLevel) {
-		log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
+func Warningf(format string, things ...interface{}) {
+	if fastcheck(WarningLevel) {
+		log(WarningLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
+
+// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed.
+func Error(msg string) {
+	if fastcheck(ErrorLevel) {
+		log(ErrorLevel, msg, nil)
+	}
+}

-func Criticalf(things ...interface{}) {
-	if fastcheck(CriticalLevel) {
-		log(CriticalLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	}
-}
+// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational.
+func Errorf(format string, things ...interface{}) {
+	if fastcheck(ErrorLevel) {
+		log(ErrorLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
+
+// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
+func Critical(msg string) {
+	if fastcheck(CriticalLevel) {
+		log(CriticalLevel, msg, nil)
+	}
+}

-func Testf(things ...interface{}) {
-	fmt.Printf(things[0].(string), things[1:]...)
-}
-
-func Test(msg string) {
-	fmt.Println(msg)
-}
+// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
+func Criticalf(format string, things ...interface{}) {
+	if fastcheck(CriticalLevel) {
+		log(CriticalLevel, fmt.Sprintf(format, things...), nil)
+	}
+}
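The old printf-only helpers taking `things ...interface{}` are replaced above by paired message/format functions, each guarded by fastcheck. A usage sketch of the new API (not part of this changeset):

package main

import "github.com/safing/portbase/log"

func main() {
	// logging must be started before lines are actually written
	if err := log.Start(); err != nil {
		panic(err)
	}
	defer log.Shutdown()

	log.SetLogLevel(log.DebugLevel)

	log.Debug("cache: warmed up")
	log.Warningf("cache: %d entries evicted", 42)
}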
@@ -30,12 +30,13 @@ import (
- logging is configured by main module and is supplied access to configuration and taskmanager
*/

-type severity uint32
+// Severity describes a log level.
+type Severity uint32

type logLine struct {
	msg       string
-	trace     *ContextTracer
-	level     severity
+	tracer    *ContextTracer
+	level     Severity
	timestamp time.Time
	file      string
	line      int

@@ -45,7 +46,7 @@ func (ll *logLine) Equal(ol *logLine) bool {
	switch {
	case ll.msg != ol.msg:
		return false
-	case ll.trace != nil || ol.trace != nil:
+	case ll.tracer != nil || ol.tracer != nil:
		return false
	case ll.file != ol.file:
		return false

@@ -57,55 +58,58 @@ func (ll *logLine) Equal(ol *logLine) bool {
	return true
}

// Log Levels
const (
-	TraceLevel    severity = 1
-	DebugLevel    severity = 2
-	InfoLevel     severity = 3
-	WarningLevel  severity = 4
-	ErrorLevel    severity = 5
-	CriticalLevel severity = 6
+	TraceLevel    Severity = 1
+	DebugLevel    Severity = 2
+	InfoLevel     Severity = 3
+	WarningLevel  Severity = 4
+	ErrorLevel    Severity = 5
+	CriticalLevel Severity = 6
)

var (
	logBuffer             chan *logLine
-	forceEmptyingOfBuffer chan bool
+	forceEmptyingOfBuffer chan struct{}

	logLevelInt = uint32(3)
	logLevel    = &logLevelInt

	pkgLevelsActive = abool.NewBool(false)
-	pkgLevels       = make(map[string]severity)
+	pkgLevels       = make(map[string]Severity)
	pkgLevelsLock   sync.Mutex

-	logsWaiting     = make(chan bool, 1)
+	logsWaiting     = make(chan struct{}, 1)
	logsWaitingFlag = abool.NewBool(false)

-	shutdownSignal    = make(chan struct{}, 0)
+	shutdownSignal    = make(chan struct{})
	shutdownWaitGroup sync.WaitGroup

	initializing = abool.NewBool(false)
	started      = abool.NewBool(false)
-	startedSignal = make(chan struct{}, 0)
-
-	testErrors = abool.NewBool(false)
+	startedSignal = make(chan struct{})
)

-func SetPkgLevels(levels map[string]severity) {
+// SetPkgLevels sets individual log levels for packages.
+func SetPkgLevels(levels map[string]Severity) {
	pkgLevelsLock.Lock()
	pkgLevels = levels
	pkgLevelsLock.Unlock()
	pkgLevelsActive.Set()
}

+// UnSetPkgLevels removes all individual log levels for packages.
func UnSetPkgLevels() {
	pkgLevelsActive.UnSet()
}

-func SetLogLevel(level severity) {
+// SetLogLevel sets a new log level.
+func SetLogLevel(level Severity) {
	atomic.StoreUint32(logLevel, uint32(level))
}

-func ParseLevel(level string) severity {
+// ParseLevel returns the level severity of a log level name.
+func ParseLevel(level string) Severity {
	switch strings.ToLower(level) {
	case "trace":
		return 1
@@ -123,14 +127,15 @@ func ParseLevel(level string) severity {
	return 0
}

+// Start starts the logging system. Must be called in order to see logs.
func Start() (err error) {

	if !initializing.SetToIf(false, true) {
		return nil
	}

-	logBuffer = make(chan *logLine, 8192)
-	forceEmptyingOfBuffer = make(chan bool, 4)
+	logBuffer = make(chan *logLine, 1024)
+	forceEmptyingOfBuffer = make(chan struct{}, 16)

	initialLogLevel := ParseLevel(logLevelFlag)
	if initialLogLevel > 0 {

@@ -143,7 +148,7 @@ func Start() (err error) {
	// get and set file loglevels
	pkgLogLevels := pkgLogLevelsFlag
	if len(pkgLogLevels) > 0 {
-		newPkgLevels := make(map[string]severity)
+		newPkgLevels := make(map[string]Severity)
		for _, pair := range strings.Split(pkgLogLevels, ",") {
			splitted := strings.Split(pair, "=")
			if len(splitted) != 2 {

@@ -162,6 +167,9 @@ func Start() (err error) {
		SetPkgLevels(newPkgLevels)
	}

+	if !schedulingEnabled {
+		close(writeTrigger)
+	}
	startWriter()

	started.Set()

@@ -170,7 +178,7 @@ func Start() (err error) {
	return err
}

-// Shutdown writes remaining log lines and then stops the logger.
+// Shutdown writes remaining log lines and then stops the log system.
func Shutdown() {
	close(shutdownSignal)
	shutdownWaitGroup.Wait()
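Package levels are keyed by the package directory name, i.e. the second-to-last path segment of the calling file ("/.../portbase/log/input.go" maps to "log"). A sketch of combining the global level with a per-package override (not part of this changeset):

package main

import "github.com/safing/portbase/log"

func main() {
	// global default: info and above
	log.SetLogLevel(log.InfoLevel)

	// let a single package (keyed by its directory name) log traces
	log.SetPkgLevels(map[string]log.Severity{
		"database": log.TraceLevel,
	})

	// later: drop all per-package overrides again
	log.UnSetPkgLevels()
}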
@@ -1,17 +1,20 @@
package log

import (
	"fmt"
	"testing"
	"time"
)

-// test waiting
-func TestLogging(t *testing.T) {
-
+func init() {
	err := Start()
	if err != nil {
-		t.Errorf("start failed: %s", err)
+		panic(fmt.Sprintf("start failed: %s", err))
	}
}

+// test waiting
+func TestLogging(t *testing.T) {
+
	// skip
	if testing.Short() {
@@ -3,10 +3,35 @@ package log

import (
	"fmt"
	"time"
-
-	"github.com/safing/portbase/taskmanager"
)

+var (
+	schedulingEnabled = false
+	writeTrigger      = make(chan struct{})
+)
+
+// EnableScheduling enables external scheduling of the logger. Writes must then be triggered manually via TriggerWriter whenever logs should be written. Note that full buffers also trigger writing. Must be called before Start() to have an effect.
+func EnableScheduling() {
+	if !initializing.IsSet() {
+		schedulingEnabled = true
+	}
+}
+
+// TriggerWriter triggers log output writing.
+func TriggerWriter() {
+	if started.IsSet() && schedulingEnabled {
+		select {
+		case writeTrigger <- struct{}{}:
+		default:
+		}
+	}
+}
+
+// TriggerWriterChannel returns the channel used to trigger log writing. The returned channel is closed if scheduling is not enabled correctly.
+func TriggerWriterChannel() chan struct{} {
+	return writeTrigger
+}
+
func writeLine(line *logLine, duplicates uint64) {
	fmt.Println(formatLine(line, duplicates, true))
	// TODO: implement file logging and setting console/file logging
@@ -15,7 +40,7 @@ func writeLine(line *logLine, duplicates uint64) {

func startWriter() {
	shutdownWaitGroup.Add(1)
-	fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format("060102 15:04:05.000"), rightArrow, endColor()))
+	fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()))
	go writer()
}

@@ -23,13 +48,12 @@ func writer() {
	var line *logLine
	var lastLine *logLine
	var duplicates uint64
-	startedTask := false
	defer shutdownWaitGroup.Done()

	for {
		// reset
		line = nil
-		lastLine = nil
+		lastLine = nil //nolint:ineffassign // only ineffectual in first loop
		duplicates = 0

		// wait until logs need to be processed

@@ -37,23 +61,17 @@ func writer() {
		case <-logsWaiting:
			logsWaitingFlag.UnSet()
		case <-shutdownSignal:
+			finalizeWriting()
			return
		}

		// wait for timeslot to log, or when buffer is full
		select {
-		case <-taskmanager.StartVeryLowPriorityMicroTask():
-			startedTask = true
+		case <-writeTrigger:
		case <-forceEmptyingOfBuffer:
		case <-shutdownSignal:
-			for {
-				select {
-				case line = <-logBuffer:
-					writeLine(line, duplicates)
-				case <-time.After(10 * time.Millisecond):
-					fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format("060102 15:04:05.000"), leftArrow, endColor()))
-					return
-				}
-			}
+			finalizeWriting()
+			return
		}

		// write all the logs!

@@ -77,22 +95,17 @@ func writer() {
			// deduplication
			if !line.Equal(lastLine) {
				// no duplicate
				writeLine(lastLine, duplicates)
				duplicates = 0
			} else {
				// duplicate
				duplicates++
+				break dedupLoop
			}
-
-			// duplicate
-			duplicates++
		}

		// write actual line
		writeLine(line, duplicates)
		duplicates = 0
		default:
-			if startedTask {
-				taskmanager.EndMicroTask()
-				startedTask = false
-			}
			break writeLoop
		}
	}

@@ -101,7 +114,21 @@ func writer() {
		select {
		case <-time.After(10 * time.Millisecond):
		case <-shutdownSignal:
+			finalizeWriting()
			return
		}

	}
}

+func finalizeWriting() {
+	for {
+		select {
+		case line := <-logBuffer:
+			writeLine(line, 0)
+		case <-time.After(10 * time.Millisecond):
+			fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()))
+			return
+		}
+	}
+}
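With scheduling enabled, writing is driven from outside (the modules package wires TriggerWriterChannel into its microtask scheduler); full buffers still flush on their own. A standalone sketch with an arbitrary ticker interval (not part of this changeset):

package main

import (
	"time"

	"github.com/safing/portbase/log"
)

func main() {
	log.EnableScheduling() // must run before Start() to take effect
	if err := log.Start(); err != nil {
		panic(err)
	}
	defer log.Shutdown()

	// hand out a write timeslot twice a second
	go func() {
		for range time.Tick(500 * time.Millisecond) {
			log.TriggerWriter()
		}
	}()

	log.Info("scheduled logging active")
	time.Sleep(2 * time.Second)
}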
314 log/trace.go
@@ -4,42 +4,73 @@ import (
	"context"
	"fmt"
	"runtime"
+	"strings"
	"sync"
+	"sync/atomic"
	"time"
)

-// Key for context value
+// ContextTracerKey is the key used for the context key/value storage.
type ContextTracerKey struct{}

+// ContextTracer is attached to a context in order to bind logs to a context.
type ContextTracer struct {
	sync.Mutex
-	actions []*Action
-}
-
-type Action struct {
-	timestamp time.Time
-	level     severity
-	msg       string
-	file      string
-	line      int
+	logs []*logLine
}

var (
-	key       = ContextTracerKey{}
-	nilTracer *ContextTracer
+	key = ContextTracerKey{}
)

-func AddTracer(ctx context.Context) context.Context {
-	if ctx != nil && fastcheckLevel(TraceLevel) {
+// AddTracer adds a ContextTracer to the returned Context. The returned ContextTracer is nil if the log level does not cover trace, if a tracer already exists, or in case of an error. A nil context is returned unchanged.
+func AddTracer(ctx context.Context) (context.Context, *ContextTracer) {
+	if ctx != nil && fastcheck(TraceLevel) {
+		// check pkg levels
+		if pkgLevelsActive.IsSet() {
+			// get file
+			_, file, _, ok := runtime.Caller(1)
+			if !ok {
+				// cannot get file, ignore
+				return ctx, nil
+			}
+
+			pathSegments := strings.Split(file, "/")
+			if len(pathSegments) < 2 {
+				// file too short for package levels
+				return ctx, nil
+			}
+			pkgLevelsLock.Lock()
+			severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]]
+			pkgLevelsLock.Unlock()
+			if ok {
+				// check against package level
+				if TraceLevel < severity {
+					return ctx, nil
+				}
+			} else {
+				// no package level set, check against global level
+				if uint32(TraceLevel) < atomic.LoadUint32(logLevel) {
+					return ctx, nil
+				}
+			}
+		} else if uint32(TraceLevel) < atomic.LoadUint32(logLevel) {
+			// no package levels set, check against global level
+			return ctx, nil
+		}
+
		// check for existing tracer
		_, ok := ctx.Value(key).(*ContextTracer)
		if !ok {
-			return context.WithValue(ctx, key, &ContextTracer{})
+			// add and return new tracer
+			tracer := &ContextTracer{}
+			return context.WithValue(ctx, key, tracer), tracer
		}
	}
-	return ctx
+	return ctx, nil
}

+// Tracer returns the ContextTracer previously added to the given Context.
func Tracer(ctx context.Context) *ContextTracer {
	if ctx != nil {
		tracer, ok := ctx.Value(key).(*ContextTracer)
@@ -47,10 +78,59 @@ func Tracer(ctx context.Context) *ContextTracer {
			return tracer
		}
	}
-	return nilTracer
+	return nil
}

-func (ct *ContextTracer) logTrace(level severity, msg string) {
+// Submit submits the collected logs on the context for further processing/output. Does nothing when called on a nil ContextTracer.
+func (tracer *ContextTracer) Submit() {
+	if tracer == nil {
+		return
+	}
+
+	if !started.IsSet() {
+		// a bit resource intense, but keeps logs before logging started.
+		// FIXME: create option to disable logging
+		go func() {
+			<-startedSignal
+			tracer.Submit()
+		}()
+		return
+	}
+
+	if len(tracer.logs) == 0 {
+		return
+	}
+
+	// extract last line as main line
+	mainLine := tracer.logs[len(tracer.logs)-1]
+	tracer.logs = tracer.logs[:len(tracer.logs)-1]
+
+	// create log object
+	log := &logLine{
+		msg:       mainLine.msg,
+		tracer:    tracer,
+		level:     mainLine.level,
+		timestamp: mainLine.timestamp,
+		file:      mainLine.file,
+		line:      mainLine.line,
+	}
+
+	// send log to processing
+	select {
+	case logBuffer <- log:
+	default:
+		forceEmptyingOfBuffer <- struct{}{}
+		logBuffer <- log
+	}
+
+	// wake up writer if necessary
+	if logsWaitingFlag.SetToIf(false, true) {
+		logsWaiting <- struct{}{}
+	}
+}
+
+func (tracer *ContextTracer) log(level Severity, msg string) {
	// get file and line
	_, file, line, ok := runtime.Caller(2)
	if !ok {
@@ -64,9 +144,9 @@ func (ct *ContextTracer) logTrace(level severity, msg string) {
		}
	}

-	ct.Lock()
-	defer ct.Unlock()
-	ct.actions = append(ct.actions, &Action{
+	tracer.Lock()
+	defer tracer.Unlock()
+	tracer.logs = append(tracer.logs, &logLine{
		timestamp: time.Now(),
		level:     level,
		msg:       msg,

@@ -75,146 +155,122 @@ func (ct *ContextTracer) logTrace(level severity, msg string) {
	})
}

-func (ct *ContextTracer) Tracef(things ...interface{}) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(TraceLevel, fmt.Sprintf(things[0].(string), things[1:]...))
-		}
-		return true
-	}
-	return false
-}
+// Trace is used to log tiny steps. Log traces to context if you can!
+func (tracer *ContextTracer) Trace(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(TraceLevel, msg)
+	case fastcheck(TraceLevel):
+		log(TraceLevel, msg, nil)
+	}
+}

-func (ct *ContextTracer) Trace(msg string) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(TraceLevel, msg)
-		}
-		return true
-	}
-	return false
-}
+// Tracef is used to log tiny steps. Log traces to context if you can!
+func (tracer *ContextTracer) Tracef(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(TraceLevel, fmt.Sprintf(format, things...))
+	case fastcheck(TraceLevel):
+		log(TraceLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func (ct *ContextTracer) Warningf(things ...interface{}) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...))
-		}
-		return true
-	}
-	return false
-}
+// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in themselves, but they might hint at a bigger problem.
+func (tracer *ContextTracer) Debug(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(DebugLevel, msg)
+	case fastcheck(DebugLevel):
+		log(DebugLevel, msg, nil)
+	}
+}

-func (ct *ContextTracer) Warning(msg string) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(WarningLevel, msg)
-		}
-		return true
-	}
-	return false
-}
+// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in themselves, but they might hint at a bigger problem.
+func (tracer *ContextTracer) Debugf(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(DebugLevel, fmt.Sprintf(format, things...))
+	case fastcheck(DebugLevel):
+		log(DebugLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func (ct *ContextTracer) Errorf(things ...interface{}) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...))
-		}
-		return true
-	}
-	return false
-}
+// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
+func (tracer *ContextTracer) Info(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(InfoLevel, msg)
+	case fastcheck(InfoLevel):
+		log(InfoLevel, msg, nil)
+	}
+}

-func (ct *ContextTracer) Error(msg string) (ok bool) {
-	if ct != nil {
-		if fastcheckLevel(TraceLevel) {
-			ct.logTrace(ErrorLevel, msg)
-		}
-		return true
-	}
-	return false
-}
+// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
+func (tracer *ContextTracer) Infof(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(InfoLevel, fmt.Sprintf(format, things...))
+	case fastcheck(InfoLevel):
+		log(InfoLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func DebugTrace(ctx context.Context, msg string) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(DebugLevel, msg, tracer)
-		return true
-	}
-	log(DebugLevel, msg, nil)
-	return false
-}
+// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
+func (tracer *ContextTracer) Warning(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(WarningLevel, msg)
+	case fastcheck(WarningLevel):
+		log(WarningLevel, msg, nil)
+	}
+}

-func DebugTracef(ctx context.Context, things ...interface{}) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
-		return true
-	}
-	log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	return false
-}
+// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
+func (tracer *ContextTracer) Warningf(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(WarningLevel, fmt.Sprintf(format, things...))
+	case fastcheck(WarningLevel):
+		log(WarningLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func InfoTrace(ctx context.Context, msg string) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(InfoLevel, msg, tracer)
-		return true
-	}
-	log(InfoLevel, msg, nil)
-	return false
-}
+// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed.
+func (tracer *ContextTracer) Error(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(ErrorLevel, msg)
+	case fastcheck(ErrorLevel):
+		log(ErrorLevel, msg, nil)
+	}
+}

-func InfoTracef(ctx context.Context, things ...interface{}) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
-		return true
-	}
-	log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	return false
-}
+// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational.
+func (tracer *ContextTracer) Errorf(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(ErrorLevel, fmt.Sprintf(format, things...))
+	case fastcheck(ErrorLevel):
+		log(ErrorLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func WarningTrace(ctx context.Context, msg string) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(WarningLevel, msg, tracer)
-		return true
-	}
-	log(WarningLevel, msg, nil)
-	return false
-}
+// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
+func (tracer *ContextTracer) Critical(msg string) {
+	switch {
+	case tracer != nil:
+		tracer.log(CriticalLevel, msg)
+	case fastcheck(CriticalLevel):
+		log(CriticalLevel, msg, nil)
+	}
+}

-func WarningTracef(ctx context.Context, things ...interface{}) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
-		return true
-	}
-	log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	return false
-}
+// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed.
+func (tracer *ContextTracer) Criticalf(format string, things ...interface{}) {
+	switch {
+	case tracer != nil:
+		tracer.log(CriticalLevel, fmt.Sprintf(format, things...))
+	case fastcheck(CriticalLevel):
+		log(CriticalLevel, fmt.Sprintf(format, things...), nil)
+	}
+}

-func ErrorTrace(ctx context.Context, msg string) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(ErrorLevel, msg, tracer)
-		return true
-	}
-	log(ErrorLevel, msg, nil)
-	return false
-}
-
-func ErrorTracef(ctx context.Context, things ...interface{}) (ok bool) {
-	tracer, ok := ctx.Value(key).(*ContextTracer)
-	if ok && fastcheckLevel(TraceLevel) {
-		log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
-		return true
-	}
-	log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
-	return false
-}
-
-func fastcheckLevel(level severity) bool {
-	return uint32(level) >= atomic.LoadUint32(logLevel)
-}
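Tying the tracer pieces together: attach a tracer at the entry point, collect trace lines along the way (deeper layers recover it from the context), then flush everything as one block with Submit. A sketch, not part of this changeset:

package main

import (
	"context"

	"github.com/safing/portbase/log"
)

func main() {
	handleRequest(context.Background())
}

func handleRequest(ctx context.Context) {
	// tracer is nil if the trace level is not active; all methods
	// below are safe to call on a nil tracer and fall back to
	// regular logging
	ctx, tracer := log.AddTracer(ctx)

	tracer.Trace("api: request received")
	fetch(ctx)
	tracer.Trace("api: returning response")

	// flush the collected lines as one unit
	tracer.Submit()
}

func fetch(ctx context.Context) {
	// deeper layers pick the tracer back up from the context
	log.Tracer(ctx).Trace("database: fetching resources")
}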
@@ -12,20 +12,22 @@ func TestContextTracer(t *testing.T) {
		t.Skip()
	}

-	ctx := AddTracer(context.Background())
+	ctx, tracer := AddTracer(context.Background())
+	_ = Tracer(ctx)

-	Tracer(ctx).Trace("api: request received, checking security")
+	tracer.Trace("api: request received, checking security")
	time.Sleep(1 * time.Millisecond)
-	Tracer(ctx).Trace("login: logging in user")
+	tracer.Trace("login: logging in user")
	time.Sleep(1 * time.Millisecond)
-	Tracer(ctx).Trace("database: fetching requested resources")
+	tracer.Trace("database: fetching requested resources")
	time.Sleep(10 * time.Millisecond)
-	Tracer(ctx).Warning("database: partial failure")
+	tracer.Warning("database: partial failure")
	time.Sleep(10 * time.Microsecond)
-	Tracer(ctx).Trace("renderer: rendering output")
+	tracer.Trace("renderer: rendering output")
	time.Sleep(1 * time.Millisecond)
-	Tracer(ctx).Trace("api: returning request")
+	tracer.Trace("api: returning request")

-	DebugTrace(ctx, "api: completed request")
+	tracer.Trace("api: completed request")
+	tracer.Submit()
	time.Sleep(100 * time.Millisecond)
}
18 modules/doc.go Normal file
@@ -0,0 +1,18 @@
/*
Package modules provides a full module and task management ecosystem to cleanly put all big and small moving parts of a service together.

Modules are started in a multi-stage process and may depend on other modules:
- Go's init(): register flags
- prep: check flags, register config variables
- start: start actual work, access config
- stop: gracefully shut down

Workers: A simple function that is run by the module while catching panics and reporting them. Ideal for long running (possibly) idle goroutines. Can be automatically restarted if execution ends with an error.

Tasks: Functions that take somewhere between a couple seconds and a couple minutes to execute and should be queued, scheduled or repeated.

MicroTasks: Functions that take less than a second to execute, but require lots of resources. Running such functions as MicroTasks will reduce concurrent execution and should improve performance.

Ideally, _any_ execution by a module is done through these methods. This will not only ensure that all panics are caught, but will also give better insights into how your service performs.
*/
package modules
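A sketch of the staged lifecycle from a service's point of view; the module names are made up and not part of this changeset:

package main

import "github.com/safing/portbase/modules"

var (
	_ = modules.Register("database", nil, func() error { return nil }, func() error { return nil })
	_ = modules.Register("cache", nil, func() error { return nil }, func() error { return nil }, "database")
)

func main() {
	// starts "database" before "cache", honoring the dependency graph
	if err := modules.Start(); err != nil {
		panic(err)
	}
}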
90 modules/error.go Normal file
@@ -0,0 +1,90 @@
package modules

import (
	"fmt"
	"runtime/debug"
)

var (
	errorReportingChannel chan *ModuleError
)

// ModuleError wraps a panic, error or message into an error that can be reported.
type ModuleError struct {
	Message string

	ModuleName string
	TaskName   string
	TaskType   string // one of "worker", "task", "microtask" or custom
	Severity   string // one of "info", "error", "panic" or custom

	PanicValue interface{}
	StackTrace string
}

// NewInfoMessage creates a new, reportable, info message (including a stack trace).
func (m *Module) NewInfoMessage(message string) *ModuleError {
	return &ModuleError{
		Message:    message,
		ModuleName: m.Name,
		Severity:   "info",
		StackTrace: string(debug.Stack()),
	}
}

// NewErrorMessage creates a new, reportable, error message (including a stack trace).
func (m *Module) NewErrorMessage(taskName string, err error) *ModuleError {
	return &ModuleError{
		Message:    err.Error(),
		ModuleName: m.Name,
		Severity:   "error",
		StackTrace: string(debug.Stack()),
	}
}

// NewPanicError creates a new, reportable, panic error message (including a stack trace).
func (m *Module) NewPanicError(taskName, taskType string, panicValue interface{}) *ModuleError {
	me := &ModuleError{
		Message:    fmt.Sprintf("panic: %s", panicValue),
		ModuleName: m.Name,
		TaskName:   taskName,
		TaskType:   taskType,
		Severity:   "panic",
		PanicValue: panicValue,
		StackTrace: string(debug.Stack()),
	}
	me.Message = me.Error()
	return me
}

// Error returns the string representation of the error.
func (me *ModuleError) Error() string {
	return me.Message
}

// Report reports the error through the configured reporting channel.
func (me *ModuleError) Report() {
	if errorReportingChannel != nil {
		select {
		case errorReportingChannel <- me:
		default:
		}
	}
}

// IsPanic returns whether the given error is a wrapped panic by the modules package and additionally returns it, if true.
func IsPanic(err error) (bool, *ModuleError) {
	switch val := err.(type) {
	case *ModuleError:
		return true, val
	default:
		return false, nil
	}
}

// SetErrorReportingChannel sets the channel to report module errors through. By default only panics are reported, all other errors need to be manually wrapped into a *ModuleError and reported.
func SetErrorReportingChannel(reportingChannel chan *ModuleError) {
	if errorReportingChannel == nil {
		errorReportingChannel = reportingChannel
	}
}
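A sketch of consuming reported errors on the service side; the printing is illustrative only and not part of this changeset:

package main

import (
	"fmt"

	"github.com/safing/portbase/modules"
)

func main() {
	// buffer the channel: Report() sends non-blocking and drops
	// errors when no capacity is available
	reports := make(chan *modules.ModuleError, 16)
	modules.SetErrorReportingChannel(reports)

	go func() {
		for me := range reports {
			fmt.Printf("module %s reported %s: %s\n", me.ModuleName, me.Severity, me.Message)
		}
	}()
}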
157 modules/microtasks.go Normal file
@@ -0,0 +1,157 @@
package modules

import (
	"context"
	"sync/atomic"
	"time"

	"github.com/safing/portbase/log"
	"github.com/tevino/abool"
)

// TODO: getting some errors when in nanosecond precision for tests:
// (1) panic: sync: WaitGroup is reused before previous Wait has returned - should theoretically not happen
// (2) sometimes there seems to be some kind of race condition and the test hangs without completing

var (
	microTasks           *int32
	microTasksThreshhold *int32
	microTaskFinished    = make(chan struct{}, 1)

	mediumPriorityClearance = make(chan struct{})
	lowPriorityClearance    = make(chan struct{})

	triggerLogWriting = log.TriggerWriterChannel()
)

const (
	mediumPriorityMaxDelay = 1 * time.Second
	lowPriorityMaxDelay    = 3 * time.Second
)

func init() {
	var microTasksVal int32
	microTasks = &microTasksVal
	var microTasksThreshholdVal int32
	microTasksThreshhold = &microTasksThreshholdVal
}

// SetMaxConcurrentMicroTasks sets the maximum number of microtasks that should be run concurrently.
func SetMaxConcurrentMicroTasks(n int) {
	if n < 4 {
		atomic.StoreInt32(microTasksThreshhold, 4)
	} else {
		atomic.StoreInt32(microTasksThreshhold, int32(n))
	}
}

// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) error {
	atomic.AddInt32(microTasks, 1)
	return m.runMicroTask(name, fn)
}

// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. It will wait until given a go (max 1 second). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) error {
	// check if we can go immediately
	select {
	case <-mediumPriorityClearance:
	default:
		// wait for go or max delay
		select {
		case <-mediumPriorityClearance:
		case <-time.After(mediumPriorityMaxDelay):
		}
	}
	return m.runMicroTask(name, fn)
}

// StartLowPriorityMicroTask starts a new MicroTask with low priority. It will wait until given a go (max 3 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) error {
	// check if we can go immediately
	select {
	case <-lowPriorityClearance:
	default:
		// wait for go or max delay
		select {
		case <-lowPriorityClearance:
		case <-time.After(lowPriorityMaxDelay):
		}
	}
	return m.runMicroTask(name, fn)
}

func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err error) {
	// start for module
	// hint: only microTasks global var is important for scheduling, others can be set here
	atomic.AddInt32(m.microTaskCnt, 1)
	m.waitGroup.Add(1)

	// set up recovery
	defer func() {
		// recover from panic
		panicVal := recover()
		if panicVal != nil {
			me := m.NewPanicError(*name, "microtask", panicVal)
			me.Report()
			log.Errorf("%s: microtask %s panicked: %s\n%s", m.Name, *name, panicVal, me.StackTrace)
			err = me
		}

		// finish for module
		atomic.AddInt32(m.microTaskCnt, -1)
		m.waitGroup.Done()

		// finish and possibly trigger next task
		atomic.AddInt32(microTasks, -1)
		select {
		case microTaskFinished <- struct{}{}:
		default:
		}
	}()

	// run
	err = fn(m.Ctx)
	return //nolint:nakedret // need to use named return val in order to change in defer
}

var (
	microTaskSchedulerStarted = abool.NewBool(false)
)

func microTaskScheduler() {
	// only ever start once
	if !microTaskSchedulerStarted.SetToIf(false, true) {
		return
	}

microTaskManageLoop:
	for {
		if shutdownSignalClosed.IsSet() {
			close(mediumPriorityClearance)
			close(lowPriorityClearance)
			return
		}

		if atomic.LoadInt32(microTasks) < atomic.LoadInt32(microTasksThreshhold) { // space left for firing task
			select {
			case mediumPriorityClearance <- struct{}{}:
			default:
				select {
				case taskTimeslot <- struct{}{}:
					continue microTaskManageLoop
				case triggerLogWriting <- struct{}{}:
					continue microTaskManageLoop
				case mediumPriorityClearance <- struct{}{}:
				case lowPriorityClearance <- struct{}{}:
				}
			}
			// increase task counter
			atomic.AddInt32(microTasks, 1)
		} else {
			// wait for signal that a task was completed
			<-microTaskFinished
		}
	}
}
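A sketch of running a resource-hungry operation as a microtask; the module and the work are placeholders, not part of this changeset:

package main

import (
	"context"

	"github.com/safing/portbase/modules"
)

// the name is passed by pointer and must outlive the call
var hashTaskName = "hash block"

func hashBlock(m *modules.Module, data []byte) error {
	// lower priorities (medium/low) wait for a clearance slot;
	// high priority starts immediately
	return m.StartMicroTask(&hashTaskName, func(ctx context.Context) error {
		_ = data // resource-intensive work goes here
		return nil
	})
}

func main() {}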
@@ -1,60 +1,97 @@
-package taskmanager
+package modules

import (
+	"context"
	"strings"
	"sync"
+	"sync/atomic"
	"testing"
	"time"
)

+var (
+	mtTestName = "microtask test"
+	mtModule   = initNewModule("microtask test module", nil, nil, nil)
+)
+
+func init() {
+	go microTaskScheduler()
+}
+
// test waiting
func TestMicroTaskWaiting(t *testing.T) {

	// skip
	if testing.Short() {
-		t.Skip("skipping test in short mode.")
+		t.Skip("skipping test in short mode, as it is not fully deterministic")
	}

	// init
	mtwWaitGroup := new(sync.WaitGroup)
	mtwOutputChannel := make(chan string, 100)
-	mtwExpectedOutput := "123456"
+	mtwExpectedOutput := "1234567"
	mtwSleepDuration := 10 * time.Millisecond

	// TEST
-	mtwWaitGroup.Add(3)
+	mtwWaitGroup.Add(4)
+
+	// ensure we only execute one microtask at once
+	atomic.StoreInt32(microTasksThreshhold, 1)

	// High Priority - slot 1-5
	go func() {
		defer mtwWaitGroup.Done()
-		StartMicroTask()
-		mtwOutputChannel <- "1"
-		time.Sleep(mtwSleepDuration * 5)
-		mtwOutputChannel <- "2"
-		EndMicroTask()
+		// exec at slot 1
+		_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+			mtwOutputChannel <- "1" // slot 1
+			time.Sleep(mtwSleepDuration * 5)
+			mtwOutputChannel <- "2" // slot 5
+			return nil
+		})
	}()

-	time.Sleep(mtwSleepDuration * 2)
+	time.Sleep(mtwSleepDuration * 1)
+
+	// clear clearances
+	_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+		return nil
+	})
+
+	// Low Priority - slot 16
+	go func() {
+		defer mtwWaitGroup.Done()
+		// exec at slot 2
+		_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+			mtwOutputChannel <- "7" // slot 16
+			return nil
+		})
+	}()
+
+	time.Sleep(mtwSleepDuration * 1)

	// High Priority - slot 10-15
	go func() {
		defer mtwWaitGroup.Done()
		time.Sleep(mtwSleepDuration * 8)
-		StartMicroTask()
-		mtwOutputChannel <- "4"
-		time.Sleep(mtwSleepDuration * 5)
-		mtwOutputChannel <- "6"
-		EndMicroTask()
+		// exec at slot 10
+		_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+			mtwOutputChannel <- "4" // slot 10
+			time.Sleep(mtwSleepDuration * 5)
+			mtwOutputChannel <- "6" // slot 15
+			return nil
+		})
	}()

-	// Medium Priority - Waits at slot 3, should execute in slot 6-13
+	// Medium Priority - slot 6-13
	go func() {
		defer mtwWaitGroup.Done()
-		<-StartMediumPriorityMicroTask()
-		mtwOutputChannel <- "3"
-		time.Sleep(mtwSleepDuration * 7)
-		mtwOutputChannel <- "5"
-		EndMicroTask()
+		// exec at slot 3
+		_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+			mtwOutputChannel <- "3" // slot 6
+			time.Sleep(mtwSleepDuration * 7)
+			mtwOutputChannel <- "5" // slot 13
+			return nil
+		})
	}()

	// wait for test to finish

@@ -67,6 +104,7 @@ func TestMicroTaskWaiting(t *testing.T) {
		completeOutput += s
	}
	// check if test succeeded
+	t.Logf("microTask wait order: %s", completeOutput)
	if completeOutput != mtwExpectedOutput {
		t.Errorf("MicroTask waiting test failed, expected sequence %s, got %s", mtwExpectedOutput, completeOutput)
	}

@@ -78,34 +116,27 @@ func TestMicroTaskWaiting(t *testing.T) {
// globals
var mtoWaitGroup sync.WaitGroup
var mtoOutputChannel chan string
-var mtoWaitCh chan bool
+var mtoWaitCh chan struct{}

// functions
func mediumPrioTaskTester() {
	defer mtoWaitGroup.Done()
	<-mtoWaitCh
-	<-StartMediumPriorityMicroTask()
-	mtoOutputChannel <- "1"
-	time.Sleep(2 * time.Millisecond)
-	EndMicroTask()
+	_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+		mtoOutputChannel <- "1"
+		time.Sleep(2 * time.Millisecond)
+		return nil
+	})
}

func lowPrioTaskTester() {
	defer mtoWaitGroup.Done()
	<-mtoWaitCh
-	<-StartLowPriorityMicroTask()
-	mtoOutputChannel <- "2"
-	time.Sleep(2 * time.Millisecond)
-	EndMicroTask()
-}
-
-func veryLowPrioTaskTester() {
-	defer mtoWaitGroup.Done()
-	<-mtoWaitCh
-	<-StartVeryLowPriorityMicroTask()
-	mtoOutputChannel <- "3"
-	time.Sleep(2 * time.Millisecond)
-	EndMicroTask()
+	_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+		mtoOutputChannel <- "2"
+		time.Sleep(2 * time.Millisecond)
+		return nil
+	})
}

// test

@@ -113,53 +144,51 @@ func TestMicroTaskOrdering(t *testing.T) {

	// skip
	if testing.Short() {
-		t.Skip("skipping test in short mode.")
+		t.Skip("skipping test in short mode, as it is not fully deterministic")
	}

	// init
	mtoOutputChannel = make(chan string, 100)
-	mtoWaitCh = make(chan bool, 0)
+	mtoWaitCh = make(chan struct{})

	// TEST
-	mtoWaitGroup.Add(30)
+	mtoWaitGroup.Add(20)
+
+	// ensure we only execute one microtask at once
+	atomic.StoreInt32(microTasksThreshhold, 1)

	// kick off
	go mediumPrioTaskTester()
	go mediumPrioTaskTester()
	go lowPrioTaskTester()
	go lowPrioTaskTester()
-	go veryLowPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
-	go veryLowPrioTaskTester()
	go mediumPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
	go mediumPrioTaskTester()
	go mediumPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
-	go veryLowPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
	go lowPrioTaskTester()
	go mediumPrioTaskTester()
-	go veryLowPrioTaskTester()
	go lowPrioTaskTester()
-	go veryLowPrioTaskTester()

	// wait for all goroutines to be ready
	time.Sleep(10 * time.Millisecond)

	// sync all goroutines
	close(mtoWaitCh)
+	// trigger
+	select {
+	case microTaskFinished <- struct{}{}:
+	default:
+	}

	// wait for test to finish
	mtoWaitGroup.Wait()

@@ -171,7 +200,8 @@ func TestMicroTaskOrdering(t *testing.T) {
		completeOutput += s
	}
	// check if test succeeded
-	if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") || !strings.Contains(completeOutput, "33333") {
+	t.Logf("microTask exec order: %s", completeOutput)
+	if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") {
		t.Errorf("MicroTask ordering test failed, output was %s. This happens occasionally, please run the test multiple times to verify", completeOutput)
	}
@@ -39,8 +39,12 @@ type Module struct {
	Ctx          context.Context
	cancelCtx    func()
	shutdownFlag *abool.AtomicBool
-	workerGroup  sync.WaitGroup

+	// workers/tasks
	workerCnt    *int32
+	taskCnt      *int32
+	microTaskCnt *int32
+	waitGroup    sync.WaitGroup

	// dependency mgmt
	depNames []string

@@ -48,27 +52,6 @@ type Module struct {
	depReverse []*Module
}

-// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
-func (m *Module) AddWorkers(n uint) {
-	if !m.ShutdownInProgress() {
-		if atomic.AddInt32(m.workerCnt, int32(n)) > 0 {
-			// only add to workgroup if cnt is positive (try to compensate wrong usage)
-			m.workerGroup.Add(int(n))
-		}
-	}
-}
-
-// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
-func (m *Module) FinishWorker() {
-	// check worker cnt
-	if atomic.AddInt32(m.workerCnt, -1) < 0 {
-		log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name)
-		return
-	}
-	// also mark worker done in workgroup
-	m.workerGroup.Done()
-}
-
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
	return m.shutdownFlag.IsSet()
@@ -87,13 +70,19 @@ func (m *Module) shutdown() error {
	// wait for workers
	done := make(chan struct{})
	go func() {
-		m.workerGroup.Wait()
+		m.waitGroup.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(3 * time.Second):
-		return errors.New("timed out while waiting for module workers to finish")
+		log.Warningf(
+			"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
+			m.Name,
+			atomic.LoadInt32(m.workerCnt),
+			atomic.LoadInt32(m.taskCnt),
+			atomic.LoadInt32(m.microTaskCnt),
+		)
	}

	// call shutdown function

@@ -106,8 +95,19 @@ func dummyAction() error {

// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
+	newModule := initNewModule(name, prep, start, stop, dependencies...)
+
+	modulesLock.Lock()
+	defer modulesLock.Unlock()
+	modules[name] = newModule
+	return newModule
+}
+
+func initNewModule(name string, prep, start, stop func() error, dependencies ...string) *Module {
	ctx, cancelCtx := context.WithCancel(context.Background())
	var workerCnt int32
+	var taskCnt int32
+	var microTaskCnt int32

	newModule := &Module{
		Name: name,

@@ -118,8 +118,10 @@ func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
		Ctx:          ctx,
		cancelCtx:    cancelCtx,
		shutdownFlag: abool.NewBool(false),
-		workerGroup:  sync.WaitGroup{},
+		waitGroup:    sync.WaitGroup{},
		workerCnt:    &workerCnt,
+		taskCnt:      &taskCnt,
+		microTaskCnt: &microTaskCnt,
		prep:         prep,
		start:        start,
		stop:         stop,

@@ -137,9 +139,6 @@ func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
		newModule.stop = dummyAction
	}

-	modulesLock.Lock()
-	defer modulesLock.Unlock()
-	modules[name] = newModule
	return newModule
}
@@ -14,28 +14,28 @@ var (
	shutdownOrder string
)

-func testPrep(name string) func() error {
+func testPrep(t *testing.T, name string) func() error {
	return func() error {
-		// fmt.Printf("prep %s\n", name)
+		t.Logf("prep %s\n", name)
		return nil
	}
}

-func testStart(name string) func() error {
+func testStart(t *testing.T, name string) func() error {
	return func() error {
		orderLock.Lock()
		defer orderLock.Unlock()
-		// fmt.Printf("start %s\n", name)
+		t.Logf("start %s\n", name)
		startOrder = fmt.Sprintf("%s>%s", startOrder, name)
		return nil
	}
}

-func testStop(name string) func() error {
+func testStop(t *testing.T, name string) func() error {
	return func() error {
		orderLock.Lock()
		defer orderLock.Unlock()
-		// fmt.Printf("stop %s\n", name)
+		t.Logf("stop %s\n", name)
		shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
		return nil
	}

@@ -49,12 +49,19 @@ func testCleanExit() error {
	return ErrCleanExit
}

-func TestOrdering(t *testing.T) {
-
-	Register("database", testPrep("database"), testStart("database"), testStop("database"))
-	Register("stats", testPrep("stats"), testStart("stats"), testStop("stats"), "database")
-	Register("service", testPrep("service"), testStart("service"), testStop("service"), "database")
-	Register("analytics", testPrep("analytics"), testStart("analytics"), testStop("analytics"), "stats", "database")
+func TestModules(t *testing.T) {
+	t.Parallel() // Not really, just a workaround for running these tests last.
+
+	t.Run("TestModuleOrder", testModuleOrder)
+	t.Run("TestModuleErrors", testModuleErrors)
+}
+
+func testModuleOrder(t *testing.T) {
+
+	Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database"))
+	Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database")
+	Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database")
+	Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database")

	err := Start()
	if err != nil {

@@ -105,19 +112,7 @@ func printAndRemoveModules() {
	modules = make(map[string]*Module)
}

-func resetModules() {
-	for _, module := range modules {
-		module.Prepped.UnSet()
-		module.Started.UnSet()
-		module.Stopped.UnSet()
-		module.inTransition.UnSet()
-
-		module.depModules = make([]*Module, 0)
-	}
-}
-
-func TestErrors(t *testing.T) {
+func testModuleErrors(t *testing.T) {

	// reset modules
	modules = make(map[string]*Module)
@@ -125,7 +120,7 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test prep error
-	Register("prepfail", testFail, testStart("prepfail"), testStop("prepfail"))
+	Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail"))
	err := Start()
	if err == nil {
		t.Error("should fail")

@@ -137,7 +132,7 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test prep clean exit
-	Register("prepcleanexit", testCleanExit, testStart("prepcleanexit"), testStop("prepcleanexit"))
+	Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit"))
	err = Start()
	if err != ErrCleanExit {
		t.Error("should fail with clean exit")

@@ -149,7 +144,7 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test invalid dependency
-	Register("database", nil, testStart("database"), testStop("database"), "invalid")
+	Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid")
	err = Start()
	if err == nil {
		t.Error("should fail")

@@ -161,8 +156,8 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test dependency loop
-	Register("database", nil, testStart("database"), testStop("database"), "helper")
-	Register("helper", nil, testStart("helper"), testStop("helper"), "database")
+	Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper")
+	Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database")
	err = Start()
	if err == nil {
		t.Error("should fail")

@@ -174,7 +169,7 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test failing module start
-	Register("startfail", nil, testFail, testStop("startfail"))
+	Register("startfail", nil, testFail, testStop(t, "startfail"))
	err = Start()
	if err == nil {
		t.Error("should fail")

@@ -186,7 +181,7 @@ func testModuleErrors(t *testing.T) {
	startCompleteSignal = make(chan struct{})

	// test failing module stop
-	Register("stopfail", nil, testStart("stopfail"), testFail)
+	Register("stopfail", nil, testStart(t, "stopfail"), testFail)
	err = Start()
	if err != nil {
		t.Error("should not fail")
@@ -3,6 +3,7 @@ package modules

import (
	"fmt"
	"os"
	"runtime"

	"github.com/safing/portbase/log"
	"github.com/tevino/abool"

@@ -13,13 +14,6 @@ var (
	startCompleteSignal = make(chan struct{})
)

// markStartComplete marks the startup as completed.
func markStartComplete() {
	if startComplete.SetToIf(false, true) {
		close(startCompleteSignal)
	}
}

// StartCompleted returns whether starting has completed.
func StartCompleted() bool {
	return startComplete.IsSet()

@@ -35,6 +29,10 @@ func Start() error {
	modulesLock.Lock()
	defer modulesLock.Unlock()

	// start microtask scheduler
	go microTaskScheduler()
	SetMaxConcurrentMicroTasks(runtime.GOMAXPROCS(0) * 2)

	// inter-link modules
	err := initDependencies()
	if err != nil {

@@ -59,6 +57,7 @@ func Start() error {
	}

	// start logging
	log.EnableScheduling()
	err = log.Start()
	if err != nil {
		fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to start logging: %s\n", err)

@@ -79,6 +78,9 @@ func Start() error {
		close(startCompleteSignal)
	}

	go taskQueueHandler()
	go taskScheduleHandler()

	return nil
}
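Taken together, Start() now boots the microtask scheduler, inter-links module dependencies, starts logging, and finally launches the task queue and schedule handlers. A minimal sketch of driving this lifecycle from a caller, using the in-package Register signature seen in the tests above; the "example" module and its prep/start/stop funcs are hypothetical:

	// prep, start and stop are hypothetical func() error implementations.
	Register("example", prep, start, stop)

	switch err := Start(); err {
	case nil:
		// all modules are up; taskQueueHandler and taskScheduleHandler are running
	case ErrCleanExit:
		// a module requested a clean exit during prep (compare testModuleErrors above)
	default:
		// startup failed; report and shut down
	}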
441	modules/tasks.go	Normal file

@@ -0,0 +1,441 @@
package modules

import (
	"container/list"
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/tevino/abool"

	"github.com/safing/portbase/log"
)

// Task is a managed task bound to a module.
type Task struct {
	name   string
	module *Module
	taskFn func(context.Context, *Task)

	queued     bool
	canceled   bool
	executing  bool
	cancelFunc func()

	executeAt time.Time
	repeat    time.Duration
	maxDelay  time.Duration

	queueElement            *list.Element
	prioritizedQueueElement *list.Element
	scheduleListElement     *list.Element

	lock sync.Mutex
}

var (
	taskQueue            = list.New()
	prioritizedTaskQueue = list.New()
	queuesLock           sync.Mutex
	queueWg              sync.WaitGroup

	taskSchedule = list.New()
	scheduleLock sync.Mutex

	waitForever chan time.Time

	queueIsFilled                = make(chan struct{}, 1) // kick off queue handler
	recalculateNextScheduledTask = make(chan struct{}, 1)
	taskTimeslot                 = make(chan struct{})
)

const (
	maxTimeslotWait   = 30 * time.Second
	minRepeatDuration = 1 * time.Minute
	maxExecutionWait  = 1 * time.Minute
	defaultMaxDelay   = 1 * time.Minute
)

// NewTask creates a new task with a descriptive name (non-unique) and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
	return &Task{
		name:     name,
		module:   m,
		taskFn:   fn,
		maxDelay: defaultMaxDelay,
	}
}

func (t *Task) isActive() bool {
	return !t.canceled && !t.module.ShutdownInProgress()
}

func (t *Task) prepForQueueing() (ok bool) {
	if !t.isActive() {
		return false
	}

	t.queued = true
	if t.maxDelay != 0 {
		t.executeAt = time.Now().Add(t.maxDelay)
		t.addToSchedule()
	}

	return true
}

func notifyQueue() {
	select {
	case queueIsFilled <- struct{}{}:
	default:
	}
}
// Queue queues the Task for execution.
func (t *Task) Queue() *Task {
	t.lock.Lock()
	if !t.prepForQueueing() {
		t.lock.Unlock()
		return t
	}
	t.lock.Unlock()

	if t.queueElement == nil {
		queuesLock.Lock()
		t.queueElement = taskQueue.PushBack(t)
		queuesLock.Unlock()
	}

	notifyQueue()
	return t
}

// Prioritize puts the task in the prioritized queue.
func (t *Task) Prioritize() *Task {
	t.lock.Lock()
	if !t.prepForQueueing() {
		t.lock.Unlock()
		return t
	}
	t.lock.Unlock()

	if t.prioritizedQueueElement == nil {
		queuesLock.Lock()
		t.prioritizedQueueElement = prioritizedTaskQueue.PushBack(t)
		queuesLock.Unlock()
	}

	notifyQueue()
	return t
}

// StartASAP schedules the task to be executed next.
func (t *Task) StartASAP() *Task {
	t.lock.Lock()
	if !t.prepForQueueing() {
		t.lock.Unlock()
		return t
	}
	t.lock.Unlock()

	queuesLock.Lock()
	if t.prioritizedQueueElement == nil {
		t.prioritizedQueueElement = prioritizedTaskQueue.PushFront(t)
	} else {
		prioritizedTaskQueue.MoveToFront(t.prioritizedQueueElement)
	}
	queuesLock.Unlock()

	notifyQueue()
	return t
}

// MaxDelay sets a maximum delay within which the task should be executed after being queued. Scheduled tasks are queued when they are triggered. The default delay is one minute.
func (t *Task) MaxDelay(maxDelay time.Duration) *Task {
	t.lock.Lock()
	t.maxDelay = maxDelay
	t.lock.Unlock()
	return t
}

// Schedule schedules the task for execution at the given time.
func (t *Task) Schedule(executeAt time.Time) *Task {
	t.lock.Lock()
	t.executeAt = executeAt
	t.addToSchedule()
	t.lock.Unlock()
	return t
}

// Repeat sets the task to be executed in endless repeat at the specified interval. First execution will be after interval. Minimum repeat interval is one minute.
func (t *Task) Repeat(interval time.Duration) *Task {
	// check minimum interval duration
	if interval < minRepeatDuration {
		interval = minRepeatDuration
	}

	t.lock.Lock()
	t.repeat = interval
	t.executeAt = time.Now().Add(t.repeat)
	t.addToSchedule() // place on the schedule, or the first run would never fire
	t.lock.Unlock()

	return t
}

// Cancel cancels the current and any future execution of the Task. This is not reversible by any other functions.
func (t *Task) Cancel() {
	t.lock.Lock()
	t.canceled = true
	if t.cancelFunc != nil {
		t.cancelFunc()
	}
	t.lock.Unlock()
}
func (t *Task) runWithLocking() {
	// wait for good timeslot regarding microtasks
	select {
	case <-taskTimeslot:
	case <-time.After(maxTimeslotWait):
	}

	t.lock.Lock()

	// check state, return if already executing or inactive
	if t.executing || !t.isActive() {
		t.lock.Unlock()
		return
	}
	t.executing = true

	// get list elements
	queueElement := t.queueElement
	prioritizedQueueElement := t.prioritizedQueueElement
	scheduleListElement := t.scheduleListElement

	// create context
	var taskCtx context.Context
	taskCtx, t.cancelFunc = context.WithCancel(t.module.Ctx)

	t.lock.Unlock()

	// remove from lists
	if queueElement != nil {
		queuesLock.Lock()
		taskQueue.Remove(t.queueElement)
		queuesLock.Unlock()
		t.lock.Lock()
		t.queueElement = nil
		t.lock.Unlock()
	}
	if prioritizedQueueElement != nil {
		queuesLock.Lock()
		prioritizedTaskQueue.Remove(t.prioritizedQueueElement)
		queuesLock.Unlock()
		t.lock.Lock()
		t.prioritizedQueueElement = nil
		t.lock.Unlock()
	}
	if scheduleListElement != nil {
		scheduleLock.Lock()
		taskSchedule.Remove(t.scheduleListElement)
		scheduleLock.Unlock()
		t.lock.Lock()
		t.scheduleListElement = nil
		t.lock.Unlock()
	}

	// add to queue workgroup
	queueWg.Add(1)

	go t.executeWithLocking(taskCtx, t.cancelFunc)
	go func() {
		select {
		case <-taskCtx.Done():
		case <-time.After(maxExecutionWait):
		}
		// complete queue worker (early) to allow next worker
		queueWg.Done()
	}()
}

func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
	// start for module
	// hint: only queueWg global var is important for scheduling, others can be set here
	atomic.AddInt32(t.module.taskCnt, 1)
	t.module.waitGroup.Add(1)

	defer func() {
		// recover from panic
		panicVal := recover()
		if panicVal != nil {
			me := t.module.NewPanicError(t.name, "task", panicVal)
			me.Report()
			log.Errorf("%s: task %s panicked: %s\n%s", t.module.Name, t.name, panicVal, me.StackTrace)
		}

		// finish for module
		atomic.AddInt32(t.module.taskCnt, -1)
		t.module.waitGroup.Done()

		// reset
		t.lock.Lock()
		// reset state
		t.executing = false
		t.queued = false
		// repeat?
		if t.isActive() && t.repeat != 0 {
			t.executeAt = time.Now().Add(t.repeat)
			t.addToSchedule()
		}
		t.lock.Unlock()

		// notify that we finished
		cancelFunc()
	}()

	// run
	t.taskFn(ctx, t)
}

func (t *Task) getExecuteAtWithLocking() time.Time {
	t.lock.Lock()
	defer t.lock.Unlock()
	return t.executeAt
}

func (t *Task) addToSchedule() {
	scheduleLock.Lock()
	defer scheduleLock.Unlock()

	// notify scheduler
	defer func() {
		select {
		case recalculateNextScheduledTask <- struct{}{}:
		default:
		}
	}()

	// insert task into schedule
	for e := taskSchedule.Front(); e != nil; e = e.Next() {
		// check for self
		eVal := e.Value.(*Task)
		if eVal == t {
			continue
		}
		// compare
		if t.executeAt.Before(eVal.getExecuteAtWithLocking()) {
			// insert/move task
			if t.scheduleListElement == nil {
				t.scheduleListElement = taskSchedule.InsertBefore(t, e)
			} else {
				taskSchedule.MoveBefore(t.scheduleListElement, e)
			}
			return
		}
	}

	// add/move to end
	if t.scheduleListElement == nil {
		t.scheduleListElement = taskSchedule.PushBack(t)
	} else {
		taskSchedule.MoveToBack(t.scheduleListElement)
	}
}

func waitUntilNextScheduledTask() <-chan time.Time {
	scheduleLock.Lock()
	defer scheduleLock.Unlock()

	if taskSchedule.Len() > 0 {
		return time.After(time.Until(taskSchedule.Front().Value.(*Task).executeAt))
	}
	return waitForever
}

var (
	taskQueueHandlerStarted    = abool.NewBool(false)
	taskScheduleHandlerStarted = abool.NewBool(false)
)

func taskQueueHandler() {
	// only ever start once
	if !taskQueueHandlerStarted.SetToIf(false, true) {
		return
	}

	for {
		// wait
		select {
		case <-shutdownSignal:
			return
		case <-queueIsFilled:
		}

		// execute
	execLoop:
		for {
			// wait for execution slot
			queueWg.Wait()

			// check for shutdown
			if shutdownSignalClosed.IsSet() {
				return
			}

			// get next Task
			queuesLock.Lock()
			e := prioritizedTaskQueue.Front()
			if e != nil {
				prioritizedTaskQueue.Remove(e)
			} else {
				e = taskQueue.Front()
				if e != nil {
					taskQueue.Remove(e)
				}
			}
			queuesLock.Unlock()

			// lists are empty
			if e == nil {
				break execLoop
			}

			// value -> Task
			t := e.Value.(*Task)
			// run
			t.runWithLocking()
		}
	}
}

func taskScheduleHandler() {
	// only ever start once
	if !taskScheduleHandlerStarted.SetToIf(false, true) {
		return
	}

	for {
		select {
		case <-shutdownSignal:
			return
		case <-recalculateNextScheduledTask:
		case <-waitUntilNextScheduledTask():
			// get first task in schedule
			scheduleLock.Lock()
			e := taskSchedule.Front()
			scheduleLock.Unlock()
			t := e.Value.(*Task)

			// process Task
			if t.queued {
				// already queued and maxDelay reached
				t.runWithLocking()
			} else {
				// place in front of prioritized queue
				t.StartASAP()
			}
		}
	}
}
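A usage sketch of the fluent Task API defined above, assuming a *Module named m obtained via Register; the task name and its work are made up for illustration:

	cleanup := m.NewTask("cache cleanup", func(ctx context.Context, t *Task) {
		select {
		case <-ctx.Done():
			return // module shutdown or Cancel() ends the run early
		default:
			// short unit of work goes here
		}
	})
	// repeat roughly every 15 minutes; once queued, run no later than 20 minutes in
	cleanup.MaxDelay(20 * time.Minute).Repeat(15 * time.Minute)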
162	modules/tasks_test.go	Normal file

@@ -0,0 +1,162 @@
package modules

import (
	"context"
	"fmt"
	"os"
	"runtime/pprof"
	"sync"
	"testing"
	"time"
)

func init() {
	go taskQueueHandler()
	go taskScheduleHandler()

	go func() {
		<-time.After(10 * time.Second)
		fmt.Fprintln(os.Stderr, "taking too long")
		_ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2)
		os.Exit(1)
	}()

	// always trigger task timeslot for testing
	go func() {
		for {
			taskTimeslot <- struct{}{}
		}
	}()
}

// test waiting

// globals
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration
var qtModule = initNewModule("task test module", nil, nil, nil)

// functions
func queuedTaskTester(s string) {
	qtModule.NewTask(s, func(ctx context.Context, t *Task) {
		time.Sleep(qtSleepDuration * 2)
		qtOutputChannel <- s
		qtWg.Done()
	}).Queue()
}

func prioritizedTaskTester(s string) {
	qtModule.NewTask(s, func(ctx context.Context, t *Task) {
		time.Sleep(qtSleepDuration * 2)
		qtOutputChannel <- s
		qtWg.Done()
	}).Prioritize()
}

// test
func TestQueuedTask(t *testing.T) {
	// skip
	if testing.Short() {
		t.Skip("skipping test in short mode, as it is not fully deterministic")
	}

	// init
	expectedOutput := "0123456789"
	qtSleepDuration = 20 * time.Millisecond
	qtOutputChannel = make(chan string, 100)
	qtWg.Add(10)

	// TEST
	queuedTaskTester("0")
	queuedTaskTester("1")
	queuedTaskTester("3")
	queuedTaskTester("4")
	queuedTaskTester("6")
	queuedTaskTester("7")
	queuedTaskTester("9")

	time.Sleep(qtSleepDuration * 3)
	prioritizedTaskTester("2")
	time.Sleep(qtSleepDuration * 6)
	prioritizedTaskTester("5")
	time.Sleep(qtSleepDuration * 6)
	prioritizedTaskTester("8")

	// wait for test to finish
	qtWg.Wait()

	// collect output
	close(qtOutputChannel)
	completeOutput := ""
	for s := <-qtOutputChannel; s != ""; s = <-qtOutputChannel {
		completeOutput += s
	}
	// check if test succeeded
	if completeOutput != expectedOutput {
		t.Errorf("QueuedTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
	}
}

// test scheduled tasks

// globals
var stWg sync.WaitGroup
var stOutputChannel chan string
var stSleepDuration time.Duration
var stWaitCh chan bool

// functions
func scheduledTaskTester(s string, sched time.Time) {
	qtModule.NewTask(s, func(ctx context.Context, t *Task) {
		time.Sleep(stSleepDuration)
		stOutputChannel <- s
		stWg.Done()
	}).Schedule(sched)
}

// test
func TestScheduledTaskWaiting(t *testing.T) {
	// skip
	if testing.Short() {
		t.Skip("skipping test in short mode, as it is not fully deterministic")
	}

	// init
	expectedOutput := "0123456789"
	stSleepDuration = 10 * time.Millisecond
	stOutputChannel = make(chan string, 100)
	stWaitCh = make(chan bool)

	stWg.Add(10)

	// TEST
	scheduledTaskTester("4", time.Now().Add(stSleepDuration*8))
	scheduledTaskTester("0", time.Now().Add(stSleepDuration*0))
	scheduledTaskTester("8", time.Now().Add(stSleepDuration*16))
	scheduledTaskTester("1", time.Now().Add(stSleepDuration*2))
	scheduledTaskTester("7", time.Now().Add(stSleepDuration*14))
	scheduledTaskTester("9", time.Now().Add(stSleepDuration*18))
	scheduledTaskTester("3", time.Now().Add(stSleepDuration*6))
	scheduledTaskTester("2", time.Now().Add(stSleepDuration*4))
	scheduledTaskTester("6", time.Now().Add(stSleepDuration*12))
	scheduledTaskTester("5", time.Now().Add(stSleepDuration*10))

	// wait for test to finish
	close(stWaitCh)
	stWg.Wait()

	// collect output
	close(stOutputChannel)
	completeOutput := ""
	for s := <-stOutputChannel; s != ""; s = <-stOutputChannel {
		completeOutput += s
	}
	// check if test succeeded
	if completeOutput != expectedOutput {
		t.Errorf("ScheduledTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
	}
}
79	modules/worker.go	Normal file

@@ -0,0 +1,79 @@
package modules

import (
	"context"
	"sync/atomic"
	"time"

	"github.com/safing/portbase/log"
)

// Worker Default Configuration
const (
	DefaultBackoffDuration = 2 * time.Second
)

// RunWorker directly runs a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to RunWorker blocks until the worker is finished.
func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
	atomic.AddInt32(m.workerCnt, 1)
	m.waitGroup.Add(1)
	defer func() {
		atomic.AddInt32(m.workerCnt, -1)
		m.waitGroup.Done()
	}()

	return m.runWorker(name, fn)
}

// StartServiceWorker starts a generic worker, which is automatically restarted in case of an error. A call to StartServiceWorker runs the service-worker in a new goroutine and returns immediately. `backoffDuration` specifies how long to wait before restarts, multiplied by the number of failed attempts. Pass `0` for the default backoff duration. For custom error remediation functionality, build your own error handling procedure using calls to RunWorker.
func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
	go m.runServiceWorker(name, backoffDuration, fn)
}

func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
	atomic.AddInt32(m.workerCnt, 1)
	m.waitGroup.Add(1)
	defer func() {
		atomic.AddInt32(m.workerCnt, -1)
		m.waitGroup.Done()
	}()

	if backoffDuration == 0 {
		backoffDuration = DefaultBackoffDuration
	}
	failCnt := 0

	for {
		if m.ShutdownInProgress() {
			return
		}

		err := m.runWorker(name, fn)
		if err != nil {
			// log error and restart
			failCnt++
			sleepFor := time.Duration(failCnt) * backoffDuration
			log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor)
			time.Sleep(sleepFor)
		} else {
			// finish
			return
		}
	}
}

func (m *Module) runWorker(name string, fn func(context.Context) error) (err error) {
	defer func() {
		// recover from panic
		panicVal := recover()
		if panicVal != nil {
			me := m.NewPanicError(name, "worker", panicVal)
			me.Report()
			err = me
		}
	}()

	// run
	err = fn(m.Ctx)
	return
}
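A sketch of the two worker styles above, assuming a *Module named m; serveSession and listenAndServe are hypothetical stand-ins for real work:

	// RunWorker blocks until the worker returns; panics are recovered into a *ModuleError.
	if err := m.RunWorker("session", func(ctx context.Context) error {
		return serveSession(ctx) // hypothetical long-running, mostly idle work
	}); err != nil {
		// a returned error or a recovered panic lands here
	}

	// StartServiceWorker returns immediately and restarts on error with growing backoff.
	m.StartServiceWorker("listener", 0, func(ctx context.Context) error {
		return listenAndServe(ctx) // hypothetical; passing 0 selects DefaultBackoffDuration
	})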
65	modules/worker_test.go	Normal file

@@ -0,0 +1,65 @@
package modules

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"
)

var (
	wModule = initNewModule("worker test module", nil, nil, nil)
	errTest = errors.New("test error")
)

func TestWorker(t *testing.T) {
	// test basic functionality
	err := wModule.RunWorker("test worker", func(ctx context.Context) error {
		return nil
	})
	if err != nil {
		t.Errorf("worker failed (should not): %s", err)
	}

	// test returning an error
	err = wModule.RunWorker("test worker", func(ctx context.Context) error {
		return errTest
	})
	if err != errTest {
		t.Errorf("worker failed with unexpected error: %s", err)
	}

	// test service functionality
	failCnt := 0
	var sWTestGroup sync.WaitGroup
	sWTestGroup.Add(1)
	wModule.StartServiceWorker("test service-worker", 2*time.Millisecond, func(ctx context.Context) error {
		failCnt++
		t.Logf("service-worker test run #%d", failCnt)
		if failCnt >= 3 {
			sWTestGroup.Done()
			return nil
		}
		return errTest
	})
	// wait for service-worker to complete test
	sWTestGroup.Wait()
	if failCnt != 3 {
		t.Errorf("service-worker failed to restart")
	}

	// test panic recovery
	err = wModule.RunWorker("test worker", func(ctx context.Context) error {
		var a []byte
		_ = a[0] //nolint // we want to runtime panic!
		return nil
	})
	t.Logf("panic error message: %s", err)
	panicked, mErr := IsPanic(err)
	if !panicked {
		t.Errorf("failed to return *ModuleError, got %+v", err)
	} else {
		t.Logf("panic stack trace:\n%s", mErr.StackTrace)
	}
}
@@ -1,19 +1,21 @@
package notifications

import (
	"context"
	"time"

	"github.com/safing/portbase/log"
)

func cleaner() {
	shutdownWg.Add(1)
	select {
	case <-shutdownSignal:
		shutdownWg.Done()
		return
	case <-time.After(5 * time.Second):
		cleanNotifications()
//nolint:unparam // must conform to interface
func cleaner(ctx context.Context) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-time.After(5 * time.Second):
			cleanNotifications()
		}
	}
}
@@ -138,7 +138,7 @@ func UpdateNotification(n *Notification, key string) {
	n.Lock()
	defer n.Unlock()

	// seperate goroutine in order to correctly lock notsLock
	// separate goroutine in order to correctly lock notsLock
	notsLock.RLock()
	origN, ok := nots[key]
	notsLock.RUnlock()
@@ -1,14 +1,13 @@
package notifications

import (
	"sync"
	"time"

	"github.com/safing/portbase/modules"
)

var (
	shutdownSignal = make(chan struct{})
	shutdownWg     sync.WaitGroup
	module *modules.Module
)

func init() {

@@ -21,12 +20,6 @@ func start() error {
	return err
}

	go cleaner()
	return nil
}

func stop() error {
	close(shutdownSignal)
	shutdownWg.Wait()
	go module.StartServiceWorker("cleaner", 1*time.Second, cleaner)
	return nil
}
@@ -32,15 +32,15 @@ type Notification struct {
	DataSubject sync.Locker
	Type        uint8

	AvailableActions []*Action
	SelectedActionID string

	Persistent bool  // this notification persists until it is handled and survives restarts
	Created    int64 // creation timestamp, notification "starts"
	Expires    int64 // expiry timestamp, notification is expected to be canceled at this time and may be cleaned up afterwards
	Responded  int64 // response timestamp, notification "ends"
	Executed   int64 // execution timestamp, notification will be deleted soon

	AvailableActions []*Action
	SelectedActionID string

	lock           sync.Mutex
	actionFunction func(*Notification) // call function to process action
	actionTrigger  chan string         // and/or send to a channel

@@ -54,7 +54,6 @@ type Action struct {
}

func noOpAction(n *Notification) {
	return
}

// Get returns the notification identified by the given id or nil if it doesn't exist.

@@ -125,7 +124,7 @@ func (n *Notification) Save() *Notification {
}

// SetActionFunction sets a trigger function to be executed when the user reacted on the notification.
// The provided funtion will be started as its own goroutine and will have to lock everything it accesses, even the provided notification.
// The provided function will be started as its own goroutine and will have to lock everything it accesses, even the provided notification.
func (n *Notification) SetActionFunction(fn func(*Notification)) *Notification {
	n.lock.Lock()
	defer n.lock.Unlock()

@@ -139,7 +138,7 @@ func (n *Notification) MakeAck() *Notification {
	defer n.lock.Unlock()

	n.AvailableActions = []*Action{
		&Action{
		{
			ID:   "ack",
			Text: "OK",
		},
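A sketch of the notification flow enabled by this API; the method chain uses only methods shown in this diff, while the field set of the literal is abbreviated and partly assumed:

	n := &Notification{
		Type: 0, // ID, message and further fields are defined elsewhere in this package
	}
	n.MakeAck().SetActionFunction(func(n *Notification) {
		// runs in its own goroutine once the user reacts; lock everything you access
	}).Save()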
@@ -29,7 +29,7 @@ func main() {

	// Shutdown
	// catch interrupt for clean shutdown
	signalCh := make(chan os.Signal)
	signalCh := make(chan os.Signal, 3)
	signal.Notify(
		signalCh,
		os.Interrupt,

@@ -42,7 +42,7 @@ func main() {
	case <-signalCh:
		fmt.Println(" <INTERRUPT>")
		log.Warning("main: program was interrupted, shutting down.")
		modules.Shutdown()
		_ = modules.Shutdown()
	case <-modules.ShuttingDown():
	}
@@ -1,176 +0,0 @@
package taskmanager

import (
	"sync/atomic"
	"time"

	"github.com/tevino/abool"
)

// TODO: getting some errors when in nanosecond precision for tests:
// (1) panic: sync: WaitGroup is reused before previous Wait has returned - should theoretically not happen
// (2) sometimes there seems to be some kind of race condition, the test hangs and does not complete

var (
	closedChannel chan bool

	tasks *int32

	mediumPriorityClearance  chan bool
	lowPriorityClearance     chan bool
	veryLowPriorityClearance chan bool

	tasksDone        chan bool
	tasksDoneFlag    *abool.AtomicBool
	tasksWaiting     chan bool
	tasksWaitingFlag *abool.AtomicBool

	shutdownSignal = make(chan struct{}, 0)
	suttingDown    = abool.NewBool(false)
)

// StartMicroTask starts a new MicroTask. It will start immediately.
func StartMicroTask() {
	atomic.AddInt32(tasks, 1)
	tasksDoneFlag.UnSet()
}

// EndMicroTask MUST always be called when a MicroTask was previously started.
func EndMicroTask() {
	c := atomic.AddInt32(tasks, -1)
	if c < 1 {
		if tasksDoneFlag.SetToIf(false, true) {
			tasksDone <- true
		}
	}
}

func newTaskIsWaiting() {
	tasksWaiting <- true
}

// StartMediumPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartMediumPriorityMicroTask() chan bool {
	if suttingDown.IsSet() {
		return closedChannel
	}
	if tasksWaitingFlag.SetToIf(false, true) {
		defer newTaskIsWaiting()
	}
	return mediumPriorityClearance
}

// StartLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartLowPriorityMicroTask() chan bool {
	if suttingDown.IsSet() {
		return closedChannel
	}
	if tasksWaitingFlag.SetToIf(false, true) {
		defer newTaskIsWaiting()
	}
	return lowPriorityClearance
}

// StartVeryLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartVeryLowPriorityMicroTask() chan bool {
	if suttingDown.IsSet() {
		return closedChannel
	}
	if tasksWaitingFlag.SetToIf(false, true) {
		defer newTaskIsWaiting()
	}
	return veryLowPriorityClearance
}

func start() error {
	return nil
}

func stop() error {
	close(shutdownSignal)
	return nil
}

func init() {
	closedChannel = make(chan bool, 0)
	close(closedChannel)

	var t int32 = 0
	tasks = &t

	mediumPriorityClearance = make(chan bool, 0)
	lowPriorityClearance = make(chan bool, 0)
	veryLowPriorityClearance = make(chan bool, 0)

	tasksDone = make(chan bool, 1)
	tasksDoneFlag = abool.NewBool(true)
	tasksWaiting = make(chan bool, 1)
	tasksWaitingFlag = abool.NewBool(false)

	timoutTimerDuration := 1 * time.Second
	// timoutTimer := time.NewTimer(timoutTimerDuration)

	go func() {
	microTaskManageLoop:
		for {

			// wait for an event to start new tasks
			if !suttingDown.IsSet() {

				// reset timer
				// https://golang.org/pkg/time/#Timer.Reset
				// if !timoutTimer.Stop() {
				// 	<-timoutTimer.C
				// }
				// timoutTimer.Reset(timoutTimerDuration)

				// wait for event to start a new task
				select {
				case <-tasksWaiting:
					if !tasksDoneFlag.IsSet() {
						continue microTaskManageLoop
					}
				case <-time.After(timoutTimerDuration):
				case <-tasksDone:
				case <-shutdownSignal:
				}

			} else {

				// execute tasks until no tasks are waiting anymore
				if !tasksWaitingFlag.IsSet() {
					// wait until tasks are finished
					if !tasksDoneFlag.IsSet() {
						<-tasksDone
					}
					// signal module completion
					// microTasksModule.StopComplete()
					// exit
					return
				}

			}

			// start new task, if none is started, check if we are shutting down
			select {
			case mediumPriorityClearance <- true:
				StartMicroTask()
			default:
				select {
				case lowPriorityClearance <- true:
					StartMicroTask()
				default:
					select {
					case veryLowPriorityClearance <- true:
						StartMicroTask()
					default:
						tasksWaitingFlag.UnSet()
					}
				}
			}
		}
	}()
}
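For contrast with the new Task API, this is roughly how the removed clearance-channel API was consumed; the call site is reconstructed from the functions above and is not part of this diff:

	// block until the scheduler grants a low-priority slot, then do a short unit of work
	<-StartLowPriorityMicroTask()
	defer EndMicroTask() // must always follow a started micro task
	// ... short unit of work ...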
@@ -1,152 +0,0 @@
package taskmanager

import (
	"container/list"
	"time"

	"github.com/tevino/abool"
)

type Task struct {
	name     string
	start    chan bool
	started  *abool.AtomicBool
	schedule *time.Time
}

var taskQueue *list.List
var prioritizedTaskQueue *list.List
var addToQueue chan *Task
var addToPrioritizedQueue chan *Task
var addAsNextTask chan *Task

var finishedQueuedTask chan bool
var queuedTaskRunning *abool.AtomicBool

var getQueueLengthREQ chan bool
var getQueueLengthREP chan int

func newUnqeuedTask(name string) *Task {
	t := &Task{
		name,
		make(chan bool),
		abool.NewBool(false),
		nil,
	}
	return t
}

func NewQueuedTask(name string) *Task {
	t := newUnqeuedTask(name)
	addToQueue <- t
	return t
}

func NewPrioritizedQueuedTask(name string) *Task {
	t := newUnqeuedTask(name)
	addToPrioritizedQueue <- t
	return t
}

func (t *Task) addToPrioritizedQueue() {
	addToPrioritizedQueue <- t
}

func (t *Task) WaitForStart() chan bool {
	return t.start
}

func (t *Task) StartAnyway() {
	addAsNextTask <- t
}

func (t *Task) Done() {
	if !t.started.SetToIf(false, true) {
		finishedQueuedTask <- true
	}
}

func TotalQueuedTasks() int {
	getQueueLengthREQ <- true
	return <-getQueueLengthREP
}

func checkQueueStatus() {
	if queuedTaskRunning.SetToIf(false, true) {
		finishedQueuedTask <- true
	}
}

func fireNextTask() {
	if prioritizedTaskQueue.Len() > 0 {
		for e := prioritizedTaskQueue.Front(); prioritizedTaskQueue.Len() > 0; e.Next() {
			t := e.Value.(*Task)
			prioritizedTaskQueue.Remove(e)
			if t.started.SetToIf(false, true) {
				close(t.start)
				return
			}
		}
	}

	if taskQueue.Len() > 0 {
		for e := taskQueue.Front(); taskQueue.Len() > 0; e.Next() {
			t := e.Value.(*Task)
			taskQueue.Remove(e)
			if t.started.SetToIf(false, true) {
				close(t.start)
				return
			}
		}
	}

	queuedTaskRunning.UnSet()
}

func init() {
	taskQueue = list.New()
	prioritizedTaskQueue = list.New()
	addToQueue = make(chan *Task, 1)
	addToPrioritizedQueue = make(chan *Task, 1)
	addAsNextTask = make(chan *Task, 1)

	finishedQueuedTask = make(chan bool, 1)
	queuedTaskRunning = abool.NewBool(false)

	getQueueLengthREQ = make(chan bool, 1)
	getQueueLengthREP = make(chan int, 1)

	go func() {
		for {
			select {
			case <-shutdownSignal:
				// TODO: work off queue?
				return
			case <-getQueueLengthREQ:
				// TODO: maybe clean queues before replying
				if queuedTaskRunning.IsSet() {
					getQueueLengthREP <- prioritizedTaskQueue.Len() + taskQueue.Len() + 1
				} else {
					getQueueLengthREP <- prioritizedTaskQueue.Len() + taskQueue.Len()
				}
			case t := <-addToQueue:
				taskQueue.PushBack(t)
				checkQueueStatus()
			case t := <-addToPrioritizedQueue:
				prioritizedTaskQueue.PushBack(t)
				checkQueueStatus()
			case t := <-addAsNextTask:
				prioritizedTaskQueue.PushFront(t)
				checkQueueStatus()
			case <-finishedQueuedTask:
				fireNextTask()
			}
		}
	}()
}
@@ -1,110 +0,0 @@
package taskmanager

import (
	"sync"
	"testing"
	"time"
)

// test waiting

// globals
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration

// functions
func queuedTaskTester(s string) {
	t := NewQueuedTask(s)
	go func() {
		<-t.WaitForStart()
		time.Sleep(qtSleepDuration * 2)
		qtOutputChannel <- s
		t.Done()
		qtWg.Done()
	}()
}

func prioritizedTastTester(s string) {
	t := NewPrioritizedQueuedTask(s)
	go func() {
		<-t.WaitForStart()
		time.Sleep(qtSleepDuration * 2)
		qtOutputChannel <- s
		t.Done()
		qtWg.Done()
	}()
}

// test
func TestQueuedTask(t *testing.T) {
	// skip
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	// init
	expectedOutput := "0123456789"
	qtSleepDuration = 10 * time.Millisecond
	qtOutputChannel = make(chan string, 100)
	qtWg.Add(10)

	// test queue length
	c := TotalQueuedTasks()
	if c != 0 {
		t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
	}

	// TEST
	queuedTaskTester("0")
	queuedTaskTester("1")
	queuedTaskTester("3")
	queuedTaskTester("4")
	queuedTaskTester("6")
	queuedTaskTester("7")
	queuedTaskTester("9")

	// test queue length
	c = TotalQueuedTasks()
	if c != 7 {
		t.Errorf("Error in calculating Task Queue, expected 7, got %d", c)
	}

	time.Sleep(qtSleepDuration * 3)
	prioritizedTastTester("2")
	time.Sleep(qtSleepDuration * 6)
	prioritizedTastTester("5")
	time.Sleep(qtSleepDuration * 6)
	prioritizedTastTester("8")

	// test queue length
	c = TotalQueuedTasks()
	if c != 3 {
		t.Errorf("Error in calculating Task Queue, expected 3, got %d", c)
	}

	// time.Sleep(qtSleepDuration * 100)
	// panic("")

	// wait for test to finish
	qtWg.Wait()

	// test queue length
	c = TotalQueuedTasks()
	if c != 0 {
		t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
	}

	// collect output
	close(qtOutputChannel)
	completeOutput := ""
	for s := <-qtOutputChannel; s != ""; s = <-qtOutputChannel {
		completeOutput += s
	}
	// check if test succeeded
	if completeOutput != expectedOutput {
		t.Errorf("QueuedTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
	}
}
@@ -1,73 +0,0 @@
package taskmanager

import (
	"container/list"
	"time"
)

var taskSchedule *list.List
var addToSchedule chan *Task
var waitForever chan time.Time

var getScheduleLengthREQ chan bool
var getScheduleLengthREP chan int

func NewScheduledTask(name string, schedule time.Time) *Task {
	t := newUnqeuedTask(name)
	t.schedule = &schedule
	addToSchedule <- t
	return t
}

func TotalScheduledTasks() int {
	getScheduleLengthREQ <- true
	return <-getScheduleLengthREP
}

func (t *Task) addToSchedule() {
	for e := taskSchedule.Back(); e != nil; e = e.Prev() {
		if t.schedule.After(*e.Value.(*Task).schedule) {
			taskSchedule.InsertAfter(t, e)
			return
		}
	}
	taskSchedule.PushFront(t)
}

func waitUntilNextScheduledTask() <-chan time.Time {
	if taskSchedule.Len() > 0 {
		return time.After(taskSchedule.Front().Value.(*Task).schedule.Sub(time.Now()))
	}
	return waitForever
}

func init() {
	taskSchedule = list.New()
	addToSchedule = make(chan *Task, 1)
	waitForever = make(chan time.Time, 1)

	getScheduleLengthREQ = make(chan bool, 1)
	getScheduleLengthREP = make(chan int, 1)

	go func() {
		for {
			select {
			case <-shutdownSignal:
				return
			case <-getScheduleLengthREQ:
				// TODO: maybe clean queues before replying
				getScheduleLengthREP <- prioritizedTaskQueue.Len() + taskSchedule.Len()
			case t := <-addToSchedule:
				t.addToSchedule()
			case <-waitUntilNextScheduledTask():
				e := taskSchedule.Front()
				t := e.Value.(*Task)
				t.addToPrioritizedQueue()
				taskSchedule.Remove(e)
			}
		}
	}()
}
@@ -1,93 +0,0 @@
package taskmanager

import (
	"sync"
	"testing"
	"time"
)

// test waiting

// globals
var stWg sync.WaitGroup
var stOutputChannel chan string
var stSleepDuration time.Duration
var stWaitCh chan bool

// functions
func scheduledTaskTester(s string, sched time.Time) {
	t := NewScheduledTask(s, sched)
	go func() {
		<-stWaitCh
		<-t.WaitForStart()
		time.Sleep(stSleepDuration)
		stOutputChannel <- s
		t.Done()
		stWg.Done()
	}()
}

// test
func TestScheduledTaskWaiting(t *testing.T) {
	// skip
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	// init
	expectedOutput := "0123456789"
	stSleepDuration = 10 * time.Millisecond
	stOutputChannel = make(chan string, 100)
	stWaitCh = make(chan bool, 0)

	// test queue length
	c := TotalScheduledTasks()
	if c != 0 {
		t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
	}

	stWg.Add(10)

	// TEST
	scheduledTaskTester("4", time.Now().Add(stSleepDuration*4))
	scheduledTaskTester("0", time.Now().Add(stSleepDuration*1))
	scheduledTaskTester("8", time.Now().Add(stSleepDuration*8))
	scheduledTaskTester("1", time.Now().Add(stSleepDuration*2))
	scheduledTaskTester("7", time.Now().Add(stSleepDuration*7))

	// test queue length
	time.Sleep(1 * time.Millisecond)
	c = TotalScheduledTasks()
	if c != 5 {
		t.Errorf("Error in calculating Task Queue, expected 5, got %d", c)
	}

	scheduledTaskTester("9", time.Now().Add(stSleepDuration*9))
	scheduledTaskTester("3", time.Now().Add(stSleepDuration*3))
	scheduledTaskTester("2", time.Now().Add(stSleepDuration*2))
	scheduledTaskTester("6", time.Now().Add(stSleepDuration*6))
	scheduledTaskTester("5", time.Now().Add(stSleepDuration*5))

	// wait for test to finish
	close(stWaitCh)
	stWg.Wait()

	// test queue length
	c = TotalScheduledTasks()
	if c != 0 {
		t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
	}

	// collect output
	close(stOutputChannel)
	completeOutput := ""
	for s := <-stOutputChannel; s != ""; s = <-stOutputChannel {
		completeOutput += s
	}
	// check if test succeeded
	if completeOutput != expectedOutput {
		t.Errorf("ScheduledTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
	}
}
134	test

@@ -4,6 +4,23 @@ warnings=0
errors=0
scripted=0
goUp="\\e[1A"
all=0
fullTestFlags="-short"
install=0

function help {
	echo "usage: $0 [command] [options]"
	echo ""
	echo "commands:"
	echo "  <none>        run baseline tests"
	echo "  all           run all tests"
	echo "  install       install deps for running baseline tests"
	echo "  install all   install deps for running all tests"
	echo ""
	echo "options:"
	echo "  --scripted    don't jump console lines (still use colors)"
	echo "  [package]     run tests only on this package"
}

function run {
	if [[ $scripted -eq 0 ]]; then

@@ -36,8 +53,8 @@ function run {
	if [[ $output == *"build constraints exclude all Go files"* ]]; then
		echo -e "${goUp}[ !=OS ] $*"
	else
		echo -e "${goUp}[\e[01;31m FAIL \e[00m] $*" >/dev/stderr
		cat $tmpfile >/dev/stderr
		echo -e "${goUp}[\e[01;31m FAIL \e[00m] $*"
		cat $tmpfile
		errors=$((errors+1))
	fi
	fi

@@ -45,27 +62,124 @@ function run {
	rm -f $tmpfile
}

function checkformat {
	if [[ $scripted -eq 0 ]]; then
		echo "[......] gofmt $1"
	fi

	output=$(gofmt -l $GOPATH/src/$1/*.go)
	if [[ $output == "" ]]; then
		echo -e "${goUp}[\e[01;32m OK \e[00m] gofmt $*"
	else
		echo -e "${goUp}[\e[01;31m FAIL \e[00m] gofmt $*"
		echo "The following files do not conform to gofmt:"
		gofmt -l $GOPATH/src/$1/*.go # keeps format
		errors=$((errors+1))
	fi
}

# get and switch to script dir
baseDir="$( cd "$(dirname "$0")" && pwd )"
cd "$baseDir"

# change output format if being run in script
if [[ $1 == "--scripted" ]]; then
	scripted=1
	goUp=""
# args
while true; do
	case "$1" in
	"-h"|"help"|"--help")
		help
		exit 0
		;;
	"--scripted")
		scripted=1
		goUp=""
		shift 1
		;;
	"install")
		install=1
		shift 1
		;;
	"all")
		all=1
		fullTestFlags=""
		shift 1
		;;
	*)
		break
		;;
	esac
done

# check if $GOPATH/bin is in $PATH
if [[ $PATH != *"$GOPATH/bin"* ]]; then
	export PATH=$GOPATH/bin:$PATH
fi

# install
if [[ $install -eq 1 ]]; then
	echo "installing dependencies..."
	echo "$ go get -u golang.org/x/lint/golint"
	go get -u golang.org/x/lint/golint
	if [[ $all -eq 1 ]]; then
		echo "$ go get -u github.com/golangci/golangci-lint/cmd/golangci-lint"
		go get -u github.com/golangci/golangci-lint/cmd/golangci-lint
	fi
fi

# check dependencies
if [[ $(which go) == "" ]]; then
	echo "go command not found"
	exit 1
fi
if [[ $(which gofmt) == "" ]]; then
	echo "gofmt command not found"
	exit 1
fi
if [[ $(which golint) == "" ]]; then
	echo "golint command not found"
	echo "install with: go get -u golang.org/x/lint/golint"
	echo "or run: ./test install"
	exit 1
fi
if [[ $all -eq 1 ]]; then
	if [[ $(which golangci-lint) == "" ]]; then
		echo "golangci-lint command not found"
		echo "install locally with: go get -u github.com/golangci/golangci-lint/cmd/golangci-lint"
		echo "or run: ./test install all"
		echo ""
		echo "hint: install for CI with: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin vX.Y.Z"
		echo "don't forget to specify the version you want"
		exit 1
	fi
fi

# target selection
if [[ "$1" == "" ]]; then
	# get all packages
	packages=$(go list ./...)
else
	# single package testing
	packages=$(go list)/$1
	if [[ ! -d "$GOPATH/src/$packages" ]]; then
		echo "go package $packages does not exist"
		help
		exit 1
	fi
	echo "note: only running tests for package $packages"
fi

# platform info
platformInfo=$(go env GOOS GOARCH)
echo "running tests for ${platformInfo//$'\n'/ }:"

# get all packages
packages=$(go list ./...)

# run vet/test on packages
for package in $packages; do
	checkformat $package
	run golint -set_exit_status -min_confidence 1.0 $package
	run go vet $package
	run go test -cover $package
	run go test -cover $fullTestFlags $package
	if [[ $all -eq 1 ]]; then
		run golangci-lint run $GOPATH/src/$package
	fi
done

echo ""
@@ -8,7 +8,7 @@ import (

const isWindows = runtime.GOOS == "windows"

// EnsureDirectory ensures that the given directoy exists and that is has the given permissions set.
// EnsureDirectory ensures that the given directory exists and that it has the given permissions set.
// If path is a file, it is deleted and a directory created.
// If a directory is created, also all missing directories up to the required one are created with the given permissions.
func EnsureDirectory(path string, perm os.FileMode) error {
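A sketch of calling the helper documented above, assuming the package import path github.com/safing/portbase/utils; the path and permissions are made up:

	if err := utils.EnsureDirectory("/tmp/portbase-example", 0700); err != nil {
		// the path existed as a file and could not be replaced, or mkdir failed
	}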
@@ -59,7 +59,7 @@ func ExampleDirStructure() {
	fmt.Println(err)
}

filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
_ = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
	if err == nil {
		dir := strings.TrimPrefix(path, basePath)
		if dir == "" {
@@ -1,8 +0,0 @@
package testutils

import "runtime"

func GetLineNumberOfCaller(levels int) int {
	_, _, line, _ := runtime.Caller(levels + 1)
	return line
}