Merge pull request from safing/fix/cleanup

Code cleanup
This commit is contained in:
Daniel 2019-09-23 23:55:43 +02:00 committed by GitHub
commit f311c2864d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
88 changed files with 1920 additions and 1759 deletions

6
.golangci.yml Normal file
View file

@ -0,0 +1,6 @@
linters:
enable-all: true
disable:
- lll
- gochecknoinits
- gochecknoglobals

View file

@ -101,6 +101,7 @@ func authMiddleware(next http.Handler) http.Handler {
Name: cookieName,
Value: tokenString,
HttpOnly: true,
MaxAge: int(cookieTTL.Seconds()),
})
// serve

View file

@ -6,7 +6,7 @@ import (
"time"
"github.com/safing/portbase/log"
"github.com/gorilla/websocket"
"github.com/tevino/abool"
)
@ -35,7 +35,6 @@ type Client struct {
operations map[string]*Operation
nextOpID uint64
wsConn *websocket.Conn
lastError string
}
@ -72,12 +71,12 @@ func (c *Client) Connect() error {
func (c *Client) StayConnected() {
log.Infof("client: connecting to Portmaster at %s", c.server)
c.Connect()
_ = c.Connect()
for {
select {
case <-time.After(backOffTimer):
log.Infof("client: reconnecting...")
c.Connect()
_ = c.Connect()
case <-c.shutdownSignal:
return
}

View file

@ -9,10 +9,12 @@ import (
"github.com/tevino/abool"
)
// Client errors
var (
ErrMalformedMessage = errors.New("malformed message")
)
// Message is an API message.
type Message struct {
OpID string
Type string
@ -22,6 +24,7 @@ type Message struct {
sent *abool.AtomicBool
}
// ParseMessage parses the given raw data and returns a Message.
func ParseMessage(data []byte) (*Message, error) {
parts := bytes.SplitN(data, apiSeperatorBytes, 4)
if len(parts) < 2 {
@ -68,6 +71,7 @@ func ParseMessage(data []byte) (*Message, error) {
return m, nil
}
// Pack serializes a message into a []byte slice.
func (m *Message) Pack() ([]byte, error) {
c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type))
@ -90,28 +94,3 @@ func (m *Message) Pack() ([]byte, error) {
return c.CompileData(), nil
}
func (m *Message) IsOk() bool {
return m.Type == MsgOk
}
func (m *Message) IsDone() bool {
return m.Type == MsgDone
}
func (m *Message) IsError() bool {
return m.Type == MsgError
}
func (m *Message) IsUpdate() bool {
return m.Type == MsgUpdate
}
func (m *Message) IsNew() bool {
return m.Type == MsgNew
}
func (m *Message) IsDelete() bool {
return m.Type == MsgDelete
}
func (m *Message) IsWarning() bool {
return m.Type == MsgWarning
}
func (m *Message) GetMessage() string {
return m.Key
}

View file

@ -191,7 +191,7 @@ func (api *DatabaseAPI) writer() {
select {
// prioritize direct writes
case data = <-api.sendQueue:
if data == nil || len(data) == 0 {
if len(data) == 0 {
api.shutdown()
return
}
@ -242,8 +242,9 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
r, err := api.db.Get(key)
if err == nil {
data, err = r.Marshal(r, record.JSON)
} else {
api.send(opID, dbMsgTypeError, err.Error(), nil)
}
if err == nil {
api.send(opID, dbMsgTypeError, err.Error(), nil) //nolint:nilness // FIXME: possibly false positive (golangci-lint govet/nilness)
return
}
api.send(opID, dbMsgTypeOk, r.Key(), data)
@ -357,7 +358,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
select {
case <-api.shutdownSignal:
// cancel sub and return
sub.Cancel()
_ = sub.Cancel()
return
case r := <-sub.Feed:
// process sub feed
@ -373,17 +374,19 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// TODO: use upd, new and delete msgTypes
r.Lock()
isDeleted := r.Meta().IsDeleted()
new := r.Meta().Created == r.Meta().Modified
r.Unlock()
if isDeleted {
switch {
case isDeleted:
api.send(opID, dbMsgTypeDel, r.Key(), nil)
} else {
case new:
api.send(opID, dbMsgTypeNew, r.Key(), data)
default:
api.send(opID, dbMsgTypeUpd, r.Key(), data)
}
} else {
} else if sub.Err != nil {
// sub feed ended
if sub.Err != nil {
api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
}
api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
}
}
}
@ -489,10 +492,7 @@ func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) {
return false
}
insertError = acc.Set(key.String(), value.Value())
if insertError != nil {
return false
}
return true
return insertError == nil
})
if insertError != nil {

View file

@ -99,12 +99,12 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato
}
}
go s.processQuery(q, it, opts)
go s.processQuery(it, opts)
return it, nil
}
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) {
func (s *StorageInterface) processQuery(it *iterator.Iterator, opts []*Option) {
sort.Sort(sortableOptions(opts))

View file

@ -9,12 +9,6 @@ import (
var (
validityFlag = abool.NewBool(true)
validityFlagLock sync.RWMutex
tableLock sync.RWMutex
stringTable map[string]string
intTable map[string]int
boolTable map[string]bool
)
type (

View file

@ -2,8 +2,8 @@ package config
import (
"errors"
"sync"
"fmt"
"sync"
"github.com/safing/portbase/log"
)
@ -20,7 +20,7 @@ var (
// ErrInvalidOptionType is returned by SetConfigOption and SetDefaultConfigOption if given an unsupported option type.
ErrInvalidOptionType = errors.New("invalid option value type")
changedSignal = make(chan struct{}, 0)
changedSignal = make(chan struct{})
)
// Changed signals if any config option was changed.
@ -34,7 +34,7 @@ func Changed() <-chan struct{} {
func triggerChange() {
// must be locked!
close(changedSignal)
changedSignal = make(chan struct{}, 0)
changedSignal = make(chan struct{})
}
// setConfig sets the (prioritized) user defined config.
@ -141,7 +141,7 @@ func setConfigOption(name string, value interface{}, push bool) error {
if err == nil {
resetValidityFlag()
go saveConfig()
go saveConfig() //nolint:errcheck // error is logged
triggerChange()
}

View file

@ -1,3 +1,4 @@
//nolint:goconst,errcheck
package config
import "testing"

View file

@ -39,16 +39,19 @@ func getTypeName(t uint8) string {
// Option describes a configuration option.
type Option struct {
Name string
Key string // category/sub/key
Description string
Name string
Key string // in path format: category/sub/key
Description string
ExpertiseLevel uint8
OptType uint8
RequiresRestart bool
DefaultValue interface{}
ExternalOptType string
ValidationRegex string
RequiresRestart bool
compiledRegex *regexp.Regexp
compiledRegex *regexp.Regexp
}
// Export expors an option to a Record.

View file

@ -82,7 +82,7 @@ func MapToJSON(mapData map[string]interface{}) ([]byte, error) {
for key, value := range mapData {
new[key] = value
}
expand(new)
return json.MarshalIndent(new, "", " ")
}

View file

@ -1,145 +0,0 @@
package hash
import (
"crypto/sha256"
"crypto/sha512"
"hash"
"golang.org/x/crypto/sha3"
)
type Algorithm uint8
const (
SHA2_224 Algorithm = 1 + iota
SHA2_256
SHA2_512_224
SHA2_512_256
SHA2_384
SHA2_512
SHA3_224
SHA3_256
SHA3_384
SHA3_512
BLAKE2S_256
BLAKE2B_256
BLAKE2B_384
BLAKE2B_512
)
var (
attributes = map[Algorithm][]uint8{
// block size, output size, security strength - in bytes
SHA2_224: []uint8{64, 28, 14},
SHA2_256: []uint8{64, 32, 16},
SHA2_512_224: []uint8{128, 28, 14},
SHA2_512_256: []uint8{128, 32, 16},
SHA2_384: []uint8{128, 48, 24},
SHA2_512: []uint8{128, 64, 32},
SHA3_224: []uint8{144, 28, 14},
SHA3_256: []uint8{136, 32, 16},
SHA3_384: []uint8{104, 48, 24},
SHA3_512: []uint8{72, 64, 32},
BLAKE2S_256: []uint8{64, 32, 16},
BLAKE2B_256: []uint8{128, 32, 16},
BLAKE2B_384: []uint8{128, 48, 24},
BLAKE2B_512: []uint8{128, 64, 32},
}
functions = map[Algorithm]func() hash.Hash{
SHA2_224: sha256.New224,
SHA2_256: sha256.New,
SHA2_512_224: sha512.New512_224,
SHA2_512_256: sha512.New512_256,
SHA2_384: sha512.New384,
SHA2_512: sha512.New,
SHA3_224: sha3.New224,
SHA3_256: sha3.New256,
SHA3_384: sha3.New384,
SHA3_512: sha3.New512,
BLAKE2S_256: NewBlake2s256,
BLAKE2B_256: NewBlake2b256,
BLAKE2B_384: NewBlake2b384,
BLAKE2B_512: NewBlake2b512,
}
// just ordered by strength and establishment, no research conducted yet.
orderedByRecommendation = []Algorithm{
SHA3_512, // {72, 64, 32}
SHA2_512, // {128, 64, 32}
BLAKE2B_512, // {128, 64, 32}
SHA3_384, // {104, 48, 24}
SHA2_384, // {128, 48, 24}
BLAKE2B_384, // {128, 48, 24}
SHA3_256, // {136, 32, 16}
SHA2_512_256, // {128, 32, 16}
SHA2_256, // {64, 32, 16}
BLAKE2B_256, // {128, 32, 16}
BLAKE2S_256, // {64, 32, 16}
SHA3_224, // {144, 28, 14}
SHA2_512_224, // {128, 28, 14}
SHA2_224, // {64, 28, 14}
}
// names
names = map[Algorithm]string{
SHA2_224: "SHA2-224",
SHA2_256: "SHA2-256",
SHA2_512_224: "SHA2-512/224",
SHA2_512_256: "SHA2-512/256",
SHA2_384: "SHA2-384",
SHA2_512: "SHA2-512",
SHA3_224: "SHA3-224",
SHA3_256: "SHA3-256",
SHA3_384: "SHA3-384",
SHA3_512: "SHA3-512",
BLAKE2S_256: "Blake2s-256",
BLAKE2B_256: "Blake2b-256",
BLAKE2B_384: "Blake2b-384",
BLAKE2B_512: "Blake2b-512",
}
)
func (a Algorithm) BlockSize() uint8 {
att, ok := attributes[a]
if !ok {
return 0
}
return att[0]
}
func (a Algorithm) Size() uint8 {
att, ok := attributes[a]
if !ok {
return 0
}
return att[1]
}
func (a Algorithm) SecurityStrength() uint8 {
att, ok := attributes[a]
if !ok {
return 0
}
return att[2]
}
func (a Algorithm) String() string {
return a.Name()
}
func (a Algorithm) Name() string {
name, ok := names[a]
if !ok {
return ""
}
return name
}
func (a Algorithm) New() hash.Hash {
fn, ok := functions[a]
if !ok {
return nil
}
return fn()
}

View file

@ -1,54 +0,0 @@
package hash
import "testing"
func TestAttributes(t *testing.T) {
for alg, att := range attributes {
name, ok := names[alg]
if !ok {
t.Errorf("hash test: name missing for Algorithm ID %d", alg)
}
_ = alg.String()
_, ok = functions[alg]
if !ok {
t.Errorf("hash test: function missing for Algorithm %s", name)
}
hash := alg.New()
if len(att) != 3 {
t.Errorf("hash test: Algorithm %s does not have exactly 3 attributes", name)
}
if hash.BlockSize() != int(alg.BlockSize()) {
t.Errorf("hash test: block size mismatch at Algorithm %s", name)
}
if hash.Size() != int(alg.Size()) {
t.Errorf("hash test: size mismatch at Algorithm %s", name)
}
if alg.Size()/2 != alg.SecurityStrength() {
t.Errorf("hash test: possible strength error at Algorithm %s", name)
}
}
noAlg := Algorithm(255)
if noAlg.String() != "" {
t.Error("hash test: invalid Algorithm error")
}
if noAlg.BlockSize() != 0 {
t.Error("hash test: invalid Algorithm error")
}
if noAlg.Size() != 0 {
t.Error("hash test: invalid Algorithm error")
}
if noAlg.SecurityStrength() != 0 {
t.Error("hash test: invalid Algorithm error")
}
if noAlg.New() != nil {
t.Error("hash test: invalid Algorithm error")
}
}

View file

@ -1,131 +0,0 @@
package hash
import (
"bytes"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"github.com/safing/portbase/formats/varint"
)
type Hash struct {
Algorithm Algorithm
Sum []byte
}
func FromBytes(bytes []byte) (*Hash, int, error) {
hash := &Hash{}
alg, read, err := varint.Unpack8(bytes)
hash.Algorithm = Algorithm(alg)
if err != nil {
return nil, 0, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
}
// TODO: check if length is correct
hash.Sum = bytes[read:]
return hash, 0, nil
}
func (h *Hash) Bytes() []byte {
return append(varint.Pack8(uint8(h.Algorithm)), h.Sum...)
}
func FromSafe64(s string) (*Hash, error) {
bytes, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return nil, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
}
hash, _, err := FromBytes(bytes)
return hash, err
}
func (h *Hash) Safe64() string {
return base64.RawURLEncoding.EncodeToString(h.Bytes())
}
func FromHex(s string) (*Hash, error) {
bytes, err := hex.DecodeString(s)
if err != nil {
return nil, errors.New(fmt.Sprintf("hash: failed to parse: %s", err))
}
hash, _, err := FromBytes(bytes)
return hash, err
}
func (h *Hash) Hex() string {
return hex.EncodeToString(h.Bytes())
}
func (h *Hash) Equal(other *Hash) bool {
if h.Algorithm != other.Algorithm {
return false
}
return bytes.Equal(h.Sum, other.Sum)
}
func Sum(data []byte, alg Algorithm) *Hash {
hasher := alg.New()
hasher.Write(data)
return &Hash{
Algorithm: alg,
Sum: hasher.Sum(nil),
}
}
func SumString(data string, alg Algorithm) *Hash {
hasher := alg.New()
io.WriteString(hasher, data)
return &Hash{
Algorithm: alg,
Sum: hasher.Sum(nil),
}
}
func SumReader(reader io.Reader, alg Algorithm) (*Hash, error) {
hasher := alg.New()
_, err := io.Copy(hasher, reader)
if err != nil {
return nil, err
}
return &Hash{
Algorithm: alg,
Sum: hasher.Sum(nil),
}, nil
}
func SumAndCompare(data []byte, other Hash) (bool, *Hash) {
newHash := Sum(data, other.Algorithm)
return other.Equal(newHash), newHash
}
func SumReaderAndCompare(reader io.Reader, other Hash) (bool, *Hash, error) {
newHash, err := SumReader(reader, other.Algorithm)
if err != nil {
return false, nil, err
}
return other.Equal(newHash), newHash, nil
}
func RecommendedAlg(strengthInBits uint16) Algorithm {
strengthInBytes := uint8(strengthInBits / 8)
if strengthInBits%8 != 0 {
strengthInBytes++
}
if strengthInBytes == 0 {
strengthInBytes = uint8(0xFF)
}
chosenAlg := orderedByRecommendation[0]
for _, alg := range orderedByRecommendation {
strength := alg.SecurityStrength()
if strength < strengthInBytes {
break
}
chosenAlg = alg
if strength == strengthInBytes {
break
}
}
return chosenAlg
}

View file

@ -1,82 +0,0 @@
package hash
import (
"bytes"
"testing"
)
var (
testEmpty = []byte("")
testFox = []byte("The quick brown fox jumps over the lazy dog")
)
func testAlgorithm(t *testing.T, alg Algorithm, emptyHex, foxHex string) {
var err error
// testEmpty
hash := Sum(testEmpty, alg)
if err != nil {
t.Errorf("test Sum %s (empty): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != emptyHex {
t.Errorf("test Sum %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
}
// testFox
hash = Sum(testFox, alg)
if err != nil {
t.Errorf("test Sum %s (fox): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != foxHex {
t.Errorf("test Sum %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
}
// testEmpty
hash = SumString(string(testEmpty), alg)
if err != nil {
t.Errorf("test SumString %s (empty): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != emptyHex {
t.Errorf("test SumString %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
}
// testFox
hash = SumString(string(testFox), alg)
if err != nil {
t.Errorf("test SumString %s (fox): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != foxHex {
t.Errorf("test SumString %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
}
// testEmpty
hash, err = SumReader(bytes.NewReader(testEmpty), alg)
if err != nil {
t.Errorf("test SumReader %s (empty): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != emptyHex {
t.Errorf("test SumReader %s (empty): hex sum mismatch, expected %s, got %s", alg.String(), emptyHex, hash.Hex())
}
// testFox
hash, err = SumReader(bytes.NewReader(testFox), alg)
if err != nil {
t.Errorf("test SumReader %s (fox): error occured: %s", alg.String(), err)
}
if hash.Hex()[2:] != foxHex {
t.Errorf("test SumReader %s (fox): hex sum mismatch, expected %s, got %s", alg.String(), foxHex, hash.Hex())
}
}
func TestHash(t *testing.T) {
testAlgorithm(t, SHA2_512,
"cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e",
"07e547d9586f6a73f73fbac0435ed76951218fb7d0c8d788a309d785436bbb642e93a252a954f23912547d1e8a3b5ed6e1bfd7097821233fa0538f3db854fee6",
)
testAlgorithm(t, SHA3_512,
"a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26",
"01dedd5de4ef14642445ba5f5b97c15e47b9ad931326e4b0727cd94cefc44fff23f07bf543139939b49128caf436dc1bdee54fcb24023a08d9403f9b4bf0d450",
)
}

View file

@ -1,28 +0,0 @@
package hash
import (
"hash"
"golang.org/x/crypto/blake2b"
"golang.org/x/crypto/blake2s"
)
func NewBlake2s256() hash.Hash {
h, _ := blake2s.New256(nil)
return h
}
func NewBlake2b256() hash.Hash {
h, _ := blake2b.New256(nil)
return h
}
func NewBlake2b384() hash.Hash {
h, _ := blake2b.New384(nil)
return h
}
func NewBlake2b512() hash.Hash {
h, _ := blake2b.New512(nil)
return h
}

View file

@ -10,23 +10,10 @@ import (
)
var (
rngFeeder = make(chan []byte, 0)
rngFeeder = make(chan []byte)
minFeedEntropy config.IntOption
)
func init() {
config.Register(&config.Option{
Name: "Minimum Feed Entropy",
Key: "random/min_feed_entropy",
Description: "The minimum amount of entropy before a entropy source is feed to the RNG, in bits.",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 256,
ValidationRegex: "^[0-9]{3,5}$",
})
minFeedEntropy = config.Concurrent.GetAsInt("random/min_feed_entropy", 256)
}
// The Feeder is used to feed entropy to the RNG.
type Feeder struct {
input chan *entropyData
@ -43,7 +30,7 @@ type entropyData struct {
// NewFeeder returns a new entropy Feeder.
func NewFeeder() *Feeder {
new := &Feeder{
input: make(chan *entropyData, 0),
input: make(chan *entropyData),
needsEntropy: abool.NewBool(true),
buffer: container.New(),
}

View file

@ -25,28 +25,6 @@ var (
type reader struct{}
func init() {
config.Register(&config.Option{
Name: "Reseed after x seconds",
Key: "random/reseed_after_seconds",
Description: "Number of seconds until reseed",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 360, // ten minutes
ValidationRegex: "^[1-9][0-9]{1,5}$",
})
reseedAfterSeconds = config.Concurrent.GetAsInt("random/reseed_after_seconds", 360)
config.Register(&config.Option{
Name: "Reseed after x bytes",
Key: "random/reseed_after_bytes",
Description: "Number of fetched bytes until reseed",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 1000000, // one megabyte
ValidationRegex: "^[1-9][0-9]{2,9}$",
})
reseedAfterBytes = config.GetAsInt("random/reseed_after_bytes", 1000000)
Reader = reader{}
}
@ -55,7 +33,7 @@ func checkEntropy() (err error) {
return errors.New("RNG is not ready yet")
}
if rngBytesRead > reseedAfterBytes() ||
int64(time.Now().Sub(rngLastFeed).Seconds()) > reseedAfterSeconds() {
int64(time.Since(rngLastFeed).Seconds()) > reseedAfterSeconds() {
select {
case r := <-rngFeeder:
rng.Reseed(r)

View file

@ -19,13 +19,15 @@ var (
rngReady = false
rngCipherOption config.StringOption
shutdownSignal = make(chan struct{}, 0)
shutdownSignal = make(chan struct{})
)
func init() {
modules.Register("random", prep, Start, stop, "base")
modules.Register("random", prep, Start, nil, "base")
}
config.Register(&config.Option{
func prep() error {
err := config.Register(&config.Option{
Name: "RNG Cipher",
Key: "random/rng_cipher",
Description: "Cipher to use for the Fortuna RNG. Requires restart to take effect.",
@ -35,10 +37,53 @@ func init() {
DefaultValue: "aes",
ValidationRegex: "^(aes|serpent)$",
})
if err != nil {
return err
}
rngCipherOption = config.GetAsString("random/rng_cipher", "aes")
}
func prep() error {
err = config.Register(&config.Option{
Name: "Minimum Feed Entropy",
Key: "random/min_feed_entropy",
Description: "The minimum amount of entropy before a entropy source is feed to the RNG, in bits.",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 256,
ValidationRegex: "^[0-9]{3,5}$",
})
if err != nil {
return err
}
minFeedEntropy = config.Concurrent.GetAsInt("random/min_feed_entropy", 256)
err = config.Register(&config.Option{
Name: "Reseed after x seconds",
Key: "random/reseed_after_seconds",
Description: "Number of seconds until reseed",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 360, // ten minutes
ValidationRegex: "^[1-9][0-9]{1,5}$",
})
if err != nil {
return err
}
reseedAfterSeconds = config.Concurrent.GetAsInt("random/reseed_after_seconds", 360)
err = config.Register(&config.Option{
Name: "Reseed after x bytes",
Key: "random/reseed_after_bytes",
Description: "Number of fetched bytes until reseed",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeInt,
DefaultValue: 1000000, // one megabyte
ValidationRegex: "^[1-9][0-9]{2,9}$",
})
if err != nil {
return err
}
reseedAfterBytes = config.GetAsInt("random/reseed_after_bytes", 1000000)
return nil
}
@ -73,7 +118,3 @@ func Start() (err error) {
return nil
}
func stop() error {
return nil
}

View file

@ -7,21 +7,34 @@ import (
)
func init() {
prep()
Start()
err := prep()
if err != nil {
panic(err)
}
err = Start()
if err != nil {
panic(err)
}
}
func TestRNG(t *testing.T) {
key := make([]byte, 16)
config.SetConfigOption("random.rng_cipher", "aes")
_, err := newCipher(key)
err := config.SetConfigOption("random/rng_cipher", "aes")
if err != nil {
t.Errorf("failed to set random/rng_cipher config: %s", err)
}
_, err = newCipher(key)
if err != nil {
t.Errorf("failed to create aes cipher: %s", err)
}
rng.Reseed(key)
config.SetConfigOption("random.rng_cipher", "serpent")
err = config.SetConfigOption("random/rng_cipher", "serpent")
if err != nil {
t.Errorf("failed to set random/rng_cipher config: %s", err)
}
_, err = newCipher(key)
if err != nil {
t.Errorf("failed to create serpent cipher: %s", err)

View file

@ -55,7 +55,10 @@ func main() {
switch os.Args[1] {
case "fortuna":
random.Start()
err := random.Start()
if err != nil {
panic(err)
}
for {
b, err := random.Bytes(64)

View file

@ -1,8 +1,6 @@
package accessor
import (
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@ -23,23 +21,9 @@ func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor {
func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
result := gjson.GetBytes(*ja.json, key)
if result.Exists() {
switch value.(type) {
case string:
if result.Type != gjson.String {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if result.Type != gjson.Number {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case bool:
if result.Type != gjson.True && result.Type != gjson.False {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case []string:
if !result.IsArray() {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
err := checkJSONValueType(result, key, value)
if err != nil {
return err
}
}

View file

@ -23,23 +23,9 @@ func NewJSONAccessor(json *string) *JSONAccessor {
func (ja *JSONAccessor) Set(key string, value interface{}) error {
result := gjson.Get(*ja.json, key)
if result.Exists() {
switch value.(type) {
case string:
if result.Type != gjson.String {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if result.Type != gjson.Number {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case bool:
if result.Type != gjson.True && result.Type != gjson.False {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
case []string:
if !result.IsArray() {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, result.Type.String(), value)
}
err := checkJSONValueType(result, key, value)
if err != nil {
return err
}
}
@ -51,6 +37,28 @@ func (ja *JSONAccessor) Set(key string, value interface{}) error {
return nil
}
func checkJSONValueType(jsonValue gjson.Result, key string, value interface{}) error {
switch value.(type) {
case string:
if jsonValue.Type != gjson.String {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
if jsonValue.Type != gjson.Number {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
}
case bool:
if jsonValue.Type != gjson.True && jsonValue.Type != gjson.False {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
}
case []string:
if !jsonValue.IsArray() {
return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value)
}
}
return nil
}
// Get returns the value found by the given json key and whether it could be successfully extracted.
func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) {
result := gjson.Get(*ja.json, key)

View file

@ -160,10 +160,7 @@ func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) {
// Exists returns the whether the given key exists.
func (sa *StructAccessor) Exists(key string) bool {
field := sa.object.FieldByName(key)
if field.IsValid() {
return true
}
return false
return field.IsValid()
}
// Type returns the accessor type as a string.

View file

@ -1,3 +1,4 @@
//nolint:maligned,unparam
package accessor
import (

View file

@ -30,12 +30,12 @@ type Controller struct {
}
// newController creates a new controller for a storage.
func newController(storageInt storage.Interface) (*Controller, error) {
func newController(storageInt storage.Interface) *Controller {
return &Controller{
storage: storageInt,
migrating: abool.NewBool(false),
hibernating: abool.NewBool(false),
}, nil
}
}
// ReadOnly returns whether the storage is read only.
@ -221,7 +221,7 @@ func (c *Controller) MaintainThorough() error {
// Shutdown shuts down the storage.
func (c *Controller) Shutdown() error {
// aquire full locks
// acquire full locks
c.readLock.Lock()
defer c.readLock.Unlock()
c.writeLock.Lock()

View file

@ -51,12 +51,7 @@ func getController(name string) (*Controller, error) {
return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
}
// create controller
controller, err = newController(storageInt)
if err != nil {
return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
}
controller = newController(storageInt)
controllers[name] = controller
return controller, nil
}
@ -87,11 +82,7 @@ func InjectDatabase(name string, storageInt storage.Interface) (*Controller, err
return nil, fmt.Errorf(`database not of type "injected"`)
}
controller, err := newController(storageInt)
if err != nil {
return nil, fmt.Errorf(`could not create controller for database %s: %s`, name, err)
}
controller := newController(storageInt)
controllers[name] = controller
return controller, nil
}

View file

@ -97,7 +97,7 @@ func testDatabase(t *testing.T, storageType string) {
}
cnt := 0
for _ = range it.Next {
for range it.Next {
cnt++
}
if it.Err() != nil {
@ -124,7 +124,7 @@ func TestDatabaseSystem(t *testing.T) {
go func() {
time.Sleep(10 * time.Second)
fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====")
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
os.Exit(1)
}()

View file

@ -38,7 +38,7 @@ func start() error {
return err
}
startMaintainer()
registerMaintenanceTasks()
return nil
}

View file

@ -1,45 +1,37 @@
package dbmodule
import (
"context"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
var (
maintenanceShortTickDuration = 10 * time.Minute
maintenanceLongTickDuration = 1 * time.Hour
)
func startMaintainer() {
module.AddWorkers(1)
go maintenanceWorker()
func registerMaintenanceTasks() {
module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute)
module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
}
func maintenanceWorker() {
ticker := time.NewTicker(maintenanceShortTickDuration)
longTicker := time.NewTicker(maintenanceLongTickDuration)
for {
select {
case <-ticker.C:
err := database.Maintain()
if err != nil {
log.Errorf("database: maintenance error: %s", err)
}
case <-longTicker.C:
err := database.MaintainRecordStates()
if err != nil {
log.Errorf("database: record states maintenance error: %s", err)
}
err = database.MaintainThorough()
if err != nil {
log.Errorf("database: thorough maintenance error: %s", err)
}
case <-module.ShuttingDown():
module.FinishWorker()
return
}
func maintainBasic(ctx context.Context, task *modules.Task) {
err := database.Maintain()
if err != nil {
log.Errorf("database: maintenance error: %s", err)
}
}
func maintainThorough(ctx context.Context, task *modules.Task) {
err := database.MaintainThorough()
if err != nil {
log.Errorf("database: thorough maintenance error: %s", err)
}
}
func maintainRecords(ctx context.Context, task *modules.Task) {
err := database.MaintainRecordStates()
if err != nil {
log.Errorf("database: record states maintenance error: %s", err)
}
}

View file

@ -1,10 +0,0 @@
package dbutils
type Meta struct {
Created int64 `json:"c,omitempty" bson:"c,omitempty"`
Modified int64 `json:"m,omitempty" bson:"m,omitempty"`
Expires int64 `json:"e,omitempty" bson:"e,omitempty"`
Deleted int64 `json:"d,omitempty" bson:"d,omitempty"`
Secret bool `json:"s,omitempty" bson:"s,omitempty"` // secrets must not be sent to clients, only synced between cores
Cronjewel bool `json:"j,omitempty" bson:"j,omitempty"` // crownjewels must never leave the instance
}

View file

@ -23,7 +23,7 @@ type RegisteredHook struct {
h Hook
}
// RegisterHook registeres a hook for records matching the given query in the database.
// RegisterHook registers a hook for records matching the given query in the database.
func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) {
_, err := q.Check()
if err != nil {

View file

@ -80,7 +80,7 @@ func (i *Interface) checkCache(key string) (record.Record, bool) {
func (i *Interface) updateCache(r record.Record) {
if i.cache != nil {
i.cache.Set(r.Key(), r)
_ = i.cache.Set(r.Key(), r)
}
}

View file

@ -69,11 +69,17 @@ func MaintainRecordStates() error {
}
for _, r := range toDelete {
c.storage.Delete(r.DatabaseKey())
err := c.storage.Delete(r.DatabaseKey())
if err != nil {
return err
}
}
for _, r := range toExpire {
r.Meta().Delete()
return c.Put(r)
err := c.Put(r)
if err != nil {
return err
}
}
}

View file

@ -38,7 +38,7 @@ func (c *andCond) check() (err error) {
}
func (c *andCond) string() string {
var all []string
all := make([]string, 0, len(c.conditions))
for _, cond := range c.conditions {
all = append(all, cond.string())
}

View file

@ -28,7 +28,7 @@ func newIntCondition(key string, operator uint8, value interface{}) *intConditio
case int32:
parsedValue = int64(v)
case int64:
parsedValue = int64(v)
parsedValue = v
case uint:
parsedValue = int64(v)
case uint8:

View file

@ -38,7 +38,7 @@ func (c *orCond) check() (err error) {
}
func (c *orCond) string() string {
var all []string
all := make([]string, 0, len(c.conditions))
for _, cond := range c.conditions {
all = append(all, cond.string())
}

View file

@ -205,7 +205,7 @@ func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() in
for {
if !expectingMore && rootCondition && remainingSnippets() == 0 {
// advance snippetsPos by one, as it will be set back by 1
getSnippet()
getSnippet() //nolint:errcheck
if len(conditions) == 1 {
return conditions[0], nil
}
@ -330,7 +330,7 @@ func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error))
}
var (
escapeReplacer = regexp.MustCompile("\\\\([^\\\\])")
escapeReplacer = regexp.MustCompile(`\\([^\\])`)
)
// prepToken removes surrounding parenthesis and escape characters.

View file

@ -10,36 +10,36 @@ import (
func TestExtractSnippets(t *testing.T) {
text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ `
result1 := []*snippet{
&snippet{text: "query", globalPosition: 1},
&snippet{text: "test:", globalPosition: 7},
&snippet{text: "where", globalPosition: 13},
&snippet{text: "(", globalPosition: 19},
&snippet{text: "bananas", globalPosition: 21},
&snippet{text: ">", globalPosition: 31},
&snippet{text: "100", globalPosition: 33},
&snippet{text: "and", globalPosition: 37},
&snippet{text: "monkeys.#", globalPosition: 41},
&snippet{text: "<=", globalPosition: 51},
&snippet{text: "12", globalPosition: 54},
&snippet{text: ")", globalPosition: 58},
&snippet{text: "or", globalPosition: 59},
&snippet{text: "(", globalPosition: 61},
&snippet{text: "coconuts", globalPosition: 62},
&snippet{text: "<", globalPosition: 71},
&snippet{text: "10", globalPosition: 73},
&snippet{text: "and", globalPosition: 76},
&snippet{text: "area", globalPosition: 82},
&snippet{text: ">", globalPosition: 87},
&snippet{text: "50", globalPosition: 89},
&snippet{text: ")", globalPosition: 91},
&snippet{text: "or", globalPosition: 93},
&snippet{text: "name", globalPosition: 96},
&snippet{text: "sameas", globalPosition: 101},
&snippet{text: "Julian", globalPosition: 108},
&snippet{text: "or", globalPosition: 115},
&snippet{text: "name", globalPosition: 118},
&snippet{text: "matches", globalPosition: 123},
&snippet{text: "^King ", globalPosition: 131},
{text: "query", globalPosition: 1},
{text: "test:", globalPosition: 7},
{text: "where", globalPosition: 13},
{text: "(", globalPosition: 19},
{text: "bananas", globalPosition: 21},
{text: ">", globalPosition: 31},
{text: "100", globalPosition: 33},
{text: "and", globalPosition: 37},
{text: "monkeys.#", globalPosition: 41},
{text: "<=", globalPosition: 51},
{text: "12", globalPosition: 54},
{text: ")", globalPosition: 58},
{text: "or", globalPosition: 59},
{text: "(", globalPosition: 61},
{text: "coconuts", globalPosition: 62},
{text: "<", globalPosition: 71},
{text: "10", globalPosition: 73},
{text: "and", globalPosition: 76},
{text: "area", globalPosition: 82},
{text: ">", globalPosition: 87},
{text: "50", globalPosition: 89},
{text: ")", globalPosition: 91},
{text: "or", globalPosition: 93},
{text: "name", globalPosition: 96},
{text: "sameas", globalPosition: 101},
{text: "Julian", globalPosition: 108},
{text: "or", globalPosition: 115},
{text: "name", globalPosition: 118},
{text: "matches", globalPosition: 123},
{text: "^King ", globalPosition: 131},
}
snippets, err := extractSnippets(text1)

View file

@ -96,10 +96,7 @@ func (q *Query) IsChecked() bool {
// MatchesKey checks whether the query matches the supplied database key (key without database prefix).
func (q *Query) MatchesKey(dbKey string) bool {
if !strings.HasPrefix(dbKey, q.dbKeyPrefix) {
return false
}
return true
return strings.HasPrefix(dbKey, q.dbKeyPrefix)
}
// MatchesRecord checks whether the query matches the supplied database record (value only).

View file

@ -1,3 +1,4 @@
//nolint:unparam
package query
import (

View file

@ -148,10 +148,10 @@ func registryWriter() {
select {
case <-time.After(1 * time.Hour):
if writeRegistrySoon.SetToIf(true, false) {
saveRegistry(true)
_ = saveRegistry(true)
}
case <-shutdownSignal:
saveRegistry(true)
_ = saveRegistry(true)
return
}
}

View file

@ -21,7 +21,7 @@ type Badger struct {
}
func init() {
storage.Register("badger", NewBadger)
_ = storage.Register("badger", NewBadger)
}
// NewBadger opens/creates a badger database.
@ -190,7 +190,7 @@ func (b *Badger) Injected() bool {
// Maintain runs a light maintenance operation on the database.
func (b *Badger) Maintain() error {
b.db.RunValueLogGC(0.7)
_ = b.db.RunValueLogGC(0.7)
return nil
}

View file

@ -1,3 +1,4 @@
//nolint:unparam,maligned
package badger
import (
@ -13,7 +14,7 @@ import (
type TestRecord struct {
record.Base
lock sync.Mutex
sync.Mutex
S string
I int
I8 int8
@ -30,12 +31,6 @@ type TestRecord struct {
B bool
}
func (tr *TestRecord) Lock() {
}
func (tr *TestRecord) Unlock() {
}
func TestBadger(t *testing.T) {
testDir, err := ioutil.TempDir("", "testing-")
if err != nil {
@ -98,7 +93,7 @@ func TestBadger(t *testing.T) {
t.Fatal(err)
}
cnt := 0
for _ = range it.Next {
for range it.Next {
cnt++
}
if it.Err() != nil {

View file

@ -26,7 +26,7 @@ type BBolt struct {
}
func init() {
storage.Register("bbolt", NewBBolt)
_ = storage.Register("bbolt", NewBBolt)
}
// NewBBolt opens/creates a bbolt database.

View file

@ -1,3 +1,4 @@
//nolint:unparam,maligned
package bbolt
import (
@ -13,7 +14,7 @@ import (
type TestRecord struct {
record.Base
lock sync.Mutex
sync.Mutex
S string
I int
I8 int8
@ -30,12 +31,6 @@ type TestRecord struct {
B bool
}
func (tr *TestRecord) Lock() {
}
func (tr *TestRecord) Unlock() {
}
func TestBadger(t *testing.T) {
testDir, err := ioutil.TempDir("", "testing-")
if err != nil {
@ -126,7 +121,7 @@ func TestBadger(t *testing.T) {
t.Fatal(err)
}
cnt := 0
for _ = range it.Next {
for range it.Next {
cnt++
}
if it.Err() != nil {

View file

@ -35,7 +35,7 @@ type FSTree struct {
}
func init() {
storage.Register("fstree", NewFSTree)
_ = storage.Register("fstree", NewFSTree)
}
// NewFSTree returns a (new) FSTree database.
@ -160,15 +160,14 @@ func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterat
}
fileInfo, err := os.Stat(walkPrefix)
var walkRoot string
if err == nil {
if fileInfo.IsDir() {
walkRoot = walkPrefix
} else {
walkRoot = filepath.Dir(walkPrefix)
}
} else if os.IsNotExist(err) {
switch {
case err == nil && fileInfo.IsDir():
walkRoot = walkPrefix
case err == nil:
walkRoot = filepath.Dir(walkPrefix)
} else {
case os.IsNotExist(err):
walkRoot = filepath.Dir(walkPrefix)
default: // err != nil
return nil, fmt.Errorf("fstree: could not stat query root %s: %s", walkPrefix, err)
}
@ -279,7 +278,7 @@ func writeFile(filename string, data []byte, perm os.FileMode) error {
if err != nil {
return err
}
defer t.Cleanup()
defer t.Cleanup() //nolint:errcheck
// Set permissions before writing data, in case the data is sensitive.
if !onWindows {

View file

@ -15,7 +15,7 @@ type Sinkhole struct {
}
func init() {
storage.Register("sinkhole", NewSinkhole)
_ = storage.Register("sinkhole", NewSinkhole)
}
// NewSinkhole creates a dummy database.

View file

@ -1 +0,0 @@
package kvops

View file

@ -25,7 +25,6 @@ const (
// define errors
var errNoMoreSpace = errors.New("dsd: no more space left after reading dsd type")
var errUnknownType = errors.New("dsd: tried to unpack unknown type")
var errNotImplemented = errors.New("dsd: this type is not yet implemented")
// Load loads an dsd structured data blob into the given interface.
@ -58,7 +57,8 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
}
return t, nil
// case BSON:
case BSON:
return nil, errNotImplemented
// err := bson.Unmarshal(data[read:], t)
// if err != nil {
// return nil, err
@ -92,7 +92,7 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
}
}
f := varint.Pack8(uint8(format))
f := varint.Pack8(format)
var data []byte
var err error
switch format {
@ -106,7 +106,8 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
if err != nil {
return nil, err
}
// case BSON:
case BSON:
return nil, errNotImplemented
// data, err = bson.Marshal(t)
// if err != nil {
// return nil, err

View file

@ -1,3 +1,4 @@
//nolint:maligned,unparam,gocyclo
package dsd
import (
@ -97,8 +98,7 @@ func TestConversion(t *testing.T) {
}
bString := "b"
var bBytes byte
bBytes = 0x02
var bBytes byte = 0x02
complexSubject := ComplexTestStruct{
-1,

View file

@ -1,3 +1,4 @@
//nolint:nakedret,unconvert
package dsd
import (

View file

@ -9,5 +9,5 @@ var (
func init() {
flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,firewall=debug")
flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug")
}

View file

@ -7,9 +7,12 @@ import (
var counter uint16
const maxCount uint16 = 999
const (
maxCount uint16 = 999
timeFormat string = "060102 15:04:05.000"
)
func (s severity) String() string {
func (s Severity) String() string {
switch s {
case TraceLevel:
return "TRAC"
@ -41,25 +44,25 @@ func formatLine(line *logLine, duplicates uint64, useColor bool) string {
var fLine string
if line.line == 0 {
fLine = fmt.Sprintf("%s%s ? %s %s %03d%s%s %s", colorStart, line.timestamp.Format("060102 15:04:05.000"), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
fLine = fmt.Sprintf("%s%s ? %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
} else {
fLen := len(line.file)
fPartStart := fLen - 10
if fPartStart < 0 {
fPartStart = 0
}
fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format("060102 15:04:05.000"), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg)
}
if line.trace != nil {
if line.tracer != nil {
// append full trace time
if len(line.trace.actions) > 0 {
fLine += fmt.Sprintf(" Σ=%s", line.timestamp.Sub(line.trace.actions[0].timestamp))
if len(line.tracer.logs) > 0 {
fLine += fmt.Sprintf(" Σ=%s", line.timestamp.Sub(line.tracer.logs[0].timestamp))
}
// append all trace actions
var d time.Duration
for i, action := range line.trace.actions {
for i, action := range line.tracer.logs {
// set color
if useColor {
colorStart = action.level.color()
@ -71,10 +74,10 @@ func formatLine(line *logLine, duplicates uint64, useColor bool) string {
fPartStart = 0
}
// format
if i == len(line.trace.actions)-1 { // last
if i == len(line.tracer.logs)-1 { // last
d = line.timestamp.Sub(action.timestamp)
} else {
d = line.trace.actions[i+1].timestamp.Sub(action.timestamp)
d = line.tracer.logs[i+1].timestamp.Sub(action.timestamp)
}
fLine += fmt.Sprintf("\n%s%19s %s:%03d %s %s%s %s", colorStart, d, action.file[fPartStart:], action.line, rightArrow, action.level.String(), colorEnd, action.msg)
}

View file

@ -16,7 +16,7 @@ const (
// colorWhite = "\033[37m"
)
func (s severity) color() string {
func (s Severity) color() string {
switch s {
case DebugLevel:
return colorCyan

View file

@ -28,7 +28,7 @@ func init() {
colorsSupported = osdetail.EnableColorSupport()
}
func (s severity) color() string {
func (s Severity) color() string {
if colorsSupported {
switch s {
case DebugLevel:

View file

@ -8,33 +8,18 @@ import (
"time"
)
func fastcheck(level severity) bool {
if pkgLevelsActive.IsSet() {
return true
}
if uint32(level) < atomic.LoadUint32(logLevel) {
return false
}
return true
}
func log(level severity, msg string, trace *ContextTracer) {
func log(level Severity, msg string, tracer *ContextTracer) {
if !started.IsSet() {
// a bit resouce intense, but keeps logs before logging started.
// a bit resource intense, but keeps logs before logging started.
// FIXME: create option to disable logging
go func() {
<-startedSignal
log(level, msg, trace)
log(level, msg, tracer)
}()
return
}
// check if level is enabled
if !pkgLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) {
return
}
// get time
now := time.Now()
@ -53,26 +38,33 @@ func log(level severity, msg string, trace *ContextTracer) {
// check if level is enabled for file or generally
if pkgLevelsActive.IsSet() {
fileOnly := strings.Split(file, "/")
if len(fileOnly) < 2 {
pathSegments := strings.Split(file, "/")
if len(pathSegments) < 2 {
// file too short for package levels
return
}
sev, ok := pkgLevels[fileOnly[len(fileOnly)-2]]
pkgLevelsLock.Lock()
severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]]
pkgLevelsLock.Unlock()
if ok {
if level < sev {
if level < severity {
return
}
} else {
// no package level set, check against global level
if uint32(level) < atomic.LoadUint32(logLevel) {
return
}
}
} else if uint32(level) < atomic.LoadUint32(logLevel) {
// no package levels set, check against global level
return
}
// create log object
log := &logLine{
msg: msg,
trace: trace,
tracer: tracer,
level: level,
timestamp: now,
file: file,
@ -83,93 +75,107 @@ func log(level severity, msg string, trace *ContextTracer) {
select {
case logBuffer <- log:
default:
forceEmptyingOfBuffer <- true
forceEmptyingOfBuffer <- struct{}{}
logBuffer <- log
}
// wake up writer if necessary
if logsWaitingFlag.SetToIf(false, true) {
logsWaiting <- true
logsWaiting <- struct{}{}
}
}
func Tracef(things ...interface{}) {
if fastcheck(TraceLevel) {
log(TraceLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
func fastcheck(level Severity) bool {
if pkgLevelsActive.IsSet() {
return true
}
if uint32(level) >= atomic.LoadUint32(logLevel) {
return true
}
return false
}
// Trace is used to log tiny steps. Log traces to context if you can!
func Trace(msg string) {
if fastcheck(TraceLevel) {
log(TraceLevel, msg, nil)
}
}
func Debugf(things ...interface{}) {
if fastcheck(DebugLevel) {
log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
// Tracef is used to log tiny steps. Log traces to context if you can!
func Tracef(format string, things ...interface{}) {
if fastcheck(TraceLevel) {
log(TraceLevel, fmt.Sprintf(format, things...), nil)
}
}
// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem.
func Debug(msg string) {
if fastcheck(DebugLevel) {
log(DebugLevel, msg, nil)
}
}
func Infof(things ...interface{}) {
if fastcheck(InfoLevel) {
log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem.
func Debugf(format string, things ...interface{}) {
if fastcheck(DebugLevel) {
log(DebugLevel, fmt.Sprintf(format, things...), nil)
}
}
// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
func Info(msg string) {
if fastcheck(InfoLevel) {
log(InfoLevel, msg, nil)
}
}
func Warningf(things ...interface{}) {
if fastcheck(WarningLevel) {
log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
func Infof(format string, things ...interface{}) {
if fastcheck(InfoLevel) {
log(InfoLevel, fmt.Sprintf(format, things...), nil)
}
}
// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
func Warning(msg string) {
if fastcheck(WarningLevel) {
log(WarningLevel, msg, nil)
}
}
func Errorf(things ...interface{}) {
if fastcheck(ErrorLevel) {
log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
func Warningf(format string, things ...interface{}) {
if fastcheck(WarningLevel) {
log(WarningLevel, fmt.Sprintf(format, things...), nil)
}
}
// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed.
func Error(msg string) {
if fastcheck(ErrorLevel) {
log(ErrorLevel, msg, nil)
}
}
func Criticalf(things ...interface{}) {
if fastcheck(CriticalLevel) {
log(CriticalLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational.
func Errorf(format string, things ...interface{}) {
if fastcheck(ErrorLevel) {
log(ErrorLevel, fmt.Sprintf(format, things...), nil)
}
}
// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
func Critical(msg string) {
if fastcheck(CriticalLevel) {
log(CriticalLevel, msg, nil)
}
}
func Testf(things ...interface{}) {
fmt.Printf(things[0].(string), things[1:]...)
}
func Test(msg string) {
fmt.Println(msg)
// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
func Criticalf(format string, things ...interface{}) {
if fastcheck(CriticalLevel) {
log(CriticalLevel, fmt.Sprintf(format, things...), nil)
}
}

View file

@ -30,12 +30,13 @@ import (
- logging is configured by main module and is supplied access to configuration and taskmanager
*/
type severity uint32
// Severity describes a log level.
type Severity uint32
type logLine struct {
msg string
trace *ContextTracer
level severity
tracer *ContextTracer
level Severity
timestamp time.Time
file string
line int
@ -45,7 +46,7 @@ func (ll *logLine) Equal(ol *logLine) bool {
switch {
case ll.msg != ol.msg:
return false
case ll.trace != nil || ol.trace != nil:
case ll.tracer != nil || ol.tracer != nil:
return false
case ll.file != ol.file:
return false
@ -57,55 +58,58 @@ func (ll *logLine) Equal(ol *logLine) bool {
return true
}
// Log Levels
const (
TraceLevel severity = 1
DebugLevel severity = 2
InfoLevel severity = 3
WarningLevel severity = 4
ErrorLevel severity = 5
CriticalLevel severity = 6
TraceLevel Severity = 1
DebugLevel Severity = 2
InfoLevel Severity = 3
WarningLevel Severity = 4
ErrorLevel Severity = 5
CriticalLevel Severity = 6
)
var (
logBuffer chan *logLine
forceEmptyingOfBuffer chan bool
forceEmptyingOfBuffer chan struct{}
logLevelInt = uint32(3)
logLevel = &logLevelInt
pkgLevelsActive = abool.NewBool(false)
pkgLevels = make(map[string]severity)
pkgLevels = make(map[string]Severity)
pkgLevelsLock sync.Mutex
logsWaiting = make(chan bool, 1)
logsWaiting = make(chan struct{}, 1)
logsWaitingFlag = abool.NewBool(false)
shutdownSignal = make(chan struct{}, 0)
shutdownSignal = make(chan struct{})
shutdownWaitGroup sync.WaitGroup
initializing = abool.NewBool(false)
started = abool.NewBool(false)
startedSignal = make(chan struct{}, 0)
testErrors = abool.NewBool(false)
startedSignal = make(chan struct{})
)
func SetPkgLevels(levels map[string]severity) {
// SetPkgLevels sets individual log levels for packages.
func SetPkgLevels(levels map[string]Severity) {
pkgLevelsLock.Lock()
pkgLevels = levels
pkgLevelsLock.Unlock()
pkgLevelsActive.Set()
}
// UnSetPkgLevels removes all individual log levels for packages.
func UnSetPkgLevels() {
pkgLevelsActive.UnSet()
}
func SetLogLevel(level severity) {
// SetLogLevel sets a new log level.
func SetLogLevel(level Severity) {
atomic.StoreUint32(logLevel, uint32(level))
}
func ParseLevel(level string) severity {
// ParseLevel returns the level severity of a log level name.
func ParseLevel(level string) Severity {
switch strings.ToLower(level) {
case "trace":
return 1
@ -123,14 +127,15 @@ func ParseLevel(level string) severity {
return 0
}
// Start starts the logging system. Must be called in order to see logs.
func Start() (err error) {
if !initializing.SetToIf(false, true) {
return nil
}
logBuffer = make(chan *logLine, 8192)
forceEmptyingOfBuffer = make(chan bool, 4)
logBuffer = make(chan *logLine, 1024)
forceEmptyingOfBuffer = make(chan struct{}, 16)
initialLogLevel := ParseLevel(logLevelFlag)
if initialLogLevel > 0 {
@ -143,7 +148,7 @@ func Start() (err error) {
// get and set file loglevels
pkgLogLevels := pkgLogLevelsFlag
if len(pkgLogLevels) > 0 {
newPkgLevels := make(map[string]severity)
newPkgLevels := make(map[string]Severity)
for _, pair := range strings.Split(pkgLogLevels, ",") {
splitted := strings.Split(pair, "=")
if len(splitted) != 2 {
@ -162,6 +167,9 @@ func Start() (err error) {
SetPkgLevels(newPkgLevels)
}
if !schedulingEnabled {
close(writeTrigger)
}
startWriter()
started.Set()
@ -170,7 +178,7 @@ func Start() (err error) {
return err
}
// Shutdown writes remaining log lines and then stops the logger.
// Shutdown writes remaining log lines and then stops the log system.
func Shutdown() {
close(shutdownSignal)
shutdownWaitGroup.Wait()

View file

@ -1,17 +1,20 @@
package log
import (
"fmt"
"testing"
"time"
)
// test waiting
func TestLogging(t *testing.T) {
func init() {
err := Start()
if err != nil {
t.Errorf("start failed: %s", err)
panic(fmt.Sprintf("start failed: %s", err))
}
}
// test waiting
func TestLogging(t *testing.T) {
// skip
if testing.Short() {

View file

@ -3,10 +3,35 @@ package log
import (
"fmt"
"time"
"github.com/safing/portbase/taskmanager"
)
var (
schedulingEnabled = false
writeTrigger = make(chan struct{})
)
// EnableScheduling enables external scheduling of the logger. This will require to manually trigger writes via TriggerWrite whenevery logs should be written. Please note that full buffers will also trigger writing. Must be called before Start() to have an effect.
func EnableScheduling() {
if !initializing.IsSet() {
schedulingEnabled = true
}
}
// TriggerWriter triggers log output writing.
func TriggerWriter() {
if started.IsSet() && schedulingEnabled {
select {
case writeTrigger <- struct{}{}:
default:
}
}
}
// TriggerWriterChannel returns the channel to trigger log writing. Returned channel will close if EnableScheduling() is not called correctly.
func TriggerWriterChannel() chan struct{} {
return writeTrigger
}
func writeLine(line *logLine, duplicates uint64) {
fmt.Println(formatLine(line, duplicates, true))
// TODO: implement file logging and setting console/file logging
@ -15,7 +40,7 @@ func writeLine(line *logLine, duplicates uint64) {
func startWriter() {
shutdownWaitGroup.Add(1)
fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format("060102 15:04:05.000"), rightArrow, endColor()))
fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()))
go writer()
}
@ -23,13 +48,12 @@ func writer() {
var line *logLine
var lastLine *logLine
var duplicates uint64
startedTask := false
defer shutdownWaitGroup.Done()
for {
// reset
line = nil
lastLine = nil
lastLine = nil //nolint:ineffassign // only ineffectual in first loop
duplicates = 0
// wait until logs need to be processed
@ -37,23 +61,17 @@ func writer() {
case <-logsWaiting:
logsWaitingFlag.UnSet()
case <-shutdownSignal:
finalizeWriting()
return
}
// wait for timeslot to log, or when buffer is full
select {
case <-taskmanager.StartVeryLowPriorityMicroTask():
startedTask = true
case <-writeTrigger:
case <-forceEmptyingOfBuffer:
case <-shutdownSignal:
for {
select {
case line = <-logBuffer:
writeLine(line, duplicates)
case <-time.After(10 * time.Millisecond):
fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format("060102 15:04:05.000"), leftArrow, endColor()))
return
}
}
finalizeWriting()
return
}
// write all the logs!
@ -77,22 +95,17 @@ func writer() {
// deduplication
if !line.Equal(lastLine) {
// no duplicate
writeLine(lastLine, duplicates)
duplicates = 0
} else {
// duplicate
duplicates++
break dedupLoop
}
// duplicate
duplicates++
}
// write actual line
writeLine(line, duplicates)
duplicates = 0
default:
if startedTask {
taskmanager.EndMicroTask()
startedTask = false
}
break writeLoop
}
}
@ -101,7 +114,21 @@ func writer() {
select {
case <-time.After(10 * time.Millisecond):
case <-shutdownSignal:
finalizeWriting()
return
}
}
}
func finalizeWriting() {
for {
select {
case line := <-logBuffer:
writeLine(line, 0)
case <-time.After(10 * time.Millisecond):
fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()))
return
}
}
}

View file

@ -4,42 +4,73 @@ import (
"context"
"fmt"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
)
// Key for context value
// ContextTracerKey is the key used for the context key/value storage.
type ContextTracerKey struct{}
// ContextTracer is attached to a context in order bind logs to a context.
type ContextTracer struct {
sync.Mutex
actions []*Action
}
type Action struct {
timestamp time.Time
level severity
msg string
file string
line int
logs []*logLine
}
var (
key = ContextTracerKey{}
nilTracer *ContextTracer
key = ContextTracerKey{}
)
func AddTracer(ctx context.Context) context.Context {
if ctx != nil && fastcheckLevel(TraceLevel) {
// AddTracer adds a ContextTracer to the returned Context. Will return a nil ContextTracer if logging level is not set to trace. Will return a nil ContextTracer if one already exists. Will return a nil ContextTracer in case of an error. Will return a nil context if nil.
func AddTracer(ctx context.Context) (context.Context, *ContextTracer) {
if ctx != nil && fastcheck(TraceLevel) {
// check pkg levels
if pkgLevelsActive.IsSet() {
// get file
_, file, _, ok := runtime.Caller(1)
if !ok {
// cannot get file, ignore
return ctx, nil
}
pathSegments := strings.Split(file, "/")
if len(pathSegments) < 2 {
// file too short for package levels
return ctx, nil
}
pkgLevelsLock.Lock()
severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]]
pkgLevelsLock.Unlock()
if ok {
// check against package level
if TraceLevel < severity {
return ctx, nil
}
} else {
// no package level set, check against global level
if uint32(TraceLevel) < atomic.LoadUint32(logLevel) {
return ctx, nil
}
}
} else if uint32(TraceLevel) < atomic.LoadUint32(logLevel) {
// no package levels set, check against global level
return ctx, nil
}
// check for existing tracer
_, ok := ctx.Value(key).(*ContextTracer)
if !ok {
return context.WithValue(ctx, key, &ContextTracer{})
// add and return new tracer
tracer := &ContextTracer{}
return context.WithValue(ctx, key, tracer), tracer
}
}
return ctx
return ctx, nil
}
// Tracer returns the ContextTracer previously added to the given Context.
func Tracer(ctx context.Context) *ContextTracer {
if ctx != nil {
tracer, ok := ctx.Value(key).(*ContextTracer)
@ -47,10 +78,59 @@ func Tracer(ctx context.Context) *ContextTracer {
return tracer
}
}
return nilTracer
return nil
}
func (ct *ContextTracer) logTrace(level severity, msg string) {
// Submit collected logs on the context for further processing/outputting. Does nothing if called on a nil ContextTracer.
func (tracer *ContextTracer) Submit() {
if tracer != nil {
return
}
if !started.IsSet() {
// a bit resource intense, but keeps logs before logging started.
// FIXME: create option to disable logging
go func() {
<-startedSignal
tracer.Submit()
}()
return
}
if len(tracer.logs) == 0 {
return
}
// extract last line as main line
mainLine := tracer.logs[len(tracer.logs)-1]
tracer.logs = tracer.logs[:len(tracer.logs)-1]
// create log object
log := &logLine{
msg: mainLine.msg,
tracer: tracer,
level: mainLine.level,
timestamp: mainLine.timestamp,
file: mainLine.file,
line: mainLine.line,
}
// send log to processing
select {
case logBuffer <- log:
default:
forceEmptyingOfBuffer <- struct{}{}
logBuffer <- log
}
// wake up writer if necessary
if logsWaitingFlag.SetToIf(false, true) {
logsWaiting <- struct{}{}
}
}
func (tracer *ContextTracer) log(level Severity, msg string) {
// get file and line
_, file, line, ok := runtime.Caller(2)
if !ok {
@ -64,9 +144,9 @@ func (ct *ContextTracer) logTrace(level severity, msg string) {
}
}
ct.Lock()
defer ct.Unlock()
ct.actions = append(ct.actions, &Action{
tracer.Lock()
defer tracer.Unlock()
tracer.logs = append(tracer.logs, &logLine{
timestamp: time.Now(),
level: level,
msg: msg,
@ -75,146 +155,122 @@ func (ct *ContextTracer) logTrace(level severity, msg string) {
})
}
func (ct *ContextTracer) Tracef(things ...interface{}) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(TraceLevel, fmt.Sprintf(things[0].(string), things[1:]...))
}
return true
// Trace is used to log tiny steps. Log traces to context if you can!
func (tracer *ContextTracer) Trace(msg string) {
switch {
case tracer != nil:
tracer.log(TraceLevel, msg)
case fastcheck(TraceLevel):
log(TraceLevel, msg, nil)
}
return false
}
func (ct *ContextTracer) Trace(msg string) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(TraceLevel, msg)
}
return true
// Tracef is used to log tiny steps. Log traces to context if you can!
func (tracer *ContextTracer) Tracef(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(TraceLevel, fmt.Sprintf(format, things...))
case fastcheck(TraceLevel):
log(TraceLevel, fmt.Sprintf(format, things...), nil)
}
return false
}
func (ct *ContextTracer) Warningf(things ...interface{}) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...))
}
return true
// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem.
func (tracer *ContextTracer) Debug(msg string) {
switch {
case tracer != nil:
tracer.log(DebugLevel, msg)
case fastcheck(DebugLevel):
log(DebugLevel, msg, nil)
}
return false
}
func (ct *ContextTracer) Warning(msg string) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(WarningLevel, msg)
}
return true
// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem.
func (tracer *ContextTracer) Debugf(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(DebugLevel, fmt.Sprintf(format, things...))
case fastcheck(DebugLevel):
log(DebugLevel, fmt.Sprintf(format, things...), nil)
}
return false
}
func (ct *ContextTracer) Errorf(things ...interface{}) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...))
}
return true
// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
func (tracer *ContextTracer) Info(msg string) {
switch {
case tracer != nil:
tracer.log(InfoLevel, msg)
case fastcheck(InfoLevel):
log(InfoLevel, msg, nil)
}
return false
}
func (ct *ContextTracer) Error(msg string) (ok bool) {
if ct != nil {
if fastcheckLevel(TraceLevel) {
ct.logTrace(ErrorLevel, msg)
}
return true
// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen.
func (tracer *ContextTracer) Infof(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(InfoLevel, fmt.Sprintf(format, things...))
case fastcheck(InfoLevel):
log(InfoLevel, fmt.Sprintf(format, things...), nil)
}
return false
}
func DebugTrace(ctx context.Context, msg string) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(DebugLevel, msg, tracer)
return true
// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
func (tracer *ContextTracer) Warning(msg string) {
switch {
case tracer != nil:
tracer.log(WarningLevel, msg)
case fastcheck(WarningLevel):
log(WarningLevel, msg, nil)
}
log(DebugLevel, msg, nil)
return false
}
func DebugTracef(ctx context.Context, things ...interface{}) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
return true
// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet.
func (tracer *ContextTracer) Warningf(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(WarningLevel, fmt.Sprintf(format, things...))
case fastcheck(WarningLevel):
log(WarningLevel, fmt.Sprintf(format, things...), nil)
}
log(DebugLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
return false
}
func InfoTrace(ctx context.Context, msg string) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(InfoLevel, msg, tracer)
return true
// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed.
func (tracer *ContextTracer) Error(msg string) {
switch {
case tracer != nil:
tracer.log(ErrorLevel, msg)
case fastcheck(ErrorLevel):
log(ErrorLevel, msg, nil)
}
log(InfoLevel, msg, nil)
return false
}
func InfoTracef(ctx context.Context, things ...interface{}) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
return true
// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational.
func (tracer *ContextTracer) Errorf(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(ErrorLevel, fmt.Sprintf(format, things...))
case fastcheck(ErrorLevel):
log(ErrorLevel, fmt.Sprintf(format, things...), nil)
}
log(InfoLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
return false
}
func WarningTrace(ctx context.Context, msg string) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(WarningLevel, msg, tracer)
return true
// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
func (tracer *ContextTracer) Critical(msg string) {
switch {
case tracer != nil:
tracer.log(CriticalLevel, msg)
case fastcheck(CriticalLevel):
log(CriticalLevel, msg, nil)
}
log(WarningLevel, msg, nil)
return false
}
func WarningTracef(ctx context.Context, things ...interface{}) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
return true
// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed.
func (tracer *ContextTracer) Criticalf(format string, things ...interface{}) {
switch {
case tracer != nil:
tracer.log(CriticalLevel, fmt.Sprintf(format, things...))
case fastcheck(CriticalLevel):
log(CriticalLevel, fmt.Sprintf(format, things...), nil)
}
log(WarningLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
return false
}
func ErrorTrace(ctx context.Context, msg string) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(ErrorLevel, msg, tracer)
return true
}
log(ErrorLevel, msg, nil)
return false
}
func ErrorTracef(ctx context.Context, things ...interface{}) (ok bool) {
tracer, ok := ctx.Value(key).(*ContextTracer)
if ok && fastcheckLevel(TraceLevel) {
log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), tracer)
return true
}
log(ErrorLevel, fmt.Sprintf(things[0].(string), things[1:]...), nil)
return false
}
func fastcheckLevel(level severity) bool {
return uint32(level) >= atomic.LoadUint32(logLevel)
}

View file

@ -12,20 +12,22 @@ func TestContextTracer(t *testing.T) {
t.Skip()
}
ctx := AddTracer(context.Background())
ctx, tracer := AddTracer(context.Background())
_ = Tracer(ctx)
Tracer(ctx).Trace("api: request received, checking security")
tracer.Trace("api: request received, checking security")
time.Sleep(1 * time.Millisecond)
Tracer(ctx).Trace("login: logging in user")
tracer.Trace("login: logging in user")
time.Sleep(1 * time.Millisecond)
Tracer(ctx).Trace("database: fetching requested resources")
tracer.Trace("database: fetching requested resources")
time.Sleep(10 * time.Millisecond)
Tracer(ctx).Warning("database: partial failure")
tracer.Warning("database: partial failure")
time.Sleep(10 * time.Microsecond)
Tracer(ctx).Trace("renderer: rendering output")
tracer.Trace("renderer: rendering output")
time.Sleep(1 * time.Millisecond)
Tracer(ctx).Trace("api: returning request")
tracer.Trace("api: returning request")
DebugTrace(ctx, "api: completed request")
tracer.Trace("api: completed request")
tracer.Submit()
time.Sleep(100 * time.Millisecond)
}

18
modules/doc.go Normal file
View file

@ -0,0 +1,18 @@
/*
Package modules provides a full module and task management ecosystem to cleanly put all big and small moving parts of a service together.
Modules are started in a multi-stage process and may depend on other modules:
- Go's init(): register flags
- prep: check flags, register config variables
- start: start actual work, access config
- stop: gracefully shut down
Workers: A simple function that is run by the module while catching panics and reporting them. Ideal for long running (possibly) idle goroutines. Can be automatically restarted if execution ends with an error.
Tasks: Functions that take somewhere between a couple seconds and a couple minutes to execute and should be queued, scheduled or repeated.
MicroTasks: Functions that take less than a second to execute, but require lots of resources. Running such functions as MicroTasks will reduce concurrent execution and shall improve performance.
Ideally, _any_ execution by a module is done through these methods. This will not only ensure that all panics are caught, but will also give better insights into how your service performs.
*/
package modules

90
modules/error.go Normal file
View file

@ -0,0 +1,90 @@
package modules
import (
"fmt"
"runtime/debug"
)
var (
errorReportingChannel chan *ModuleError
)
// ModuleError wraps a panic, error or message into an error that can be reported.
type ModuleError struct {
Message string
ModuleName string
TaskName string
TaskType string // one of "worker", "task", "microtask" or custom
Severity string // one of "info", "error", "panic" or custom
PanicValue interface{}
StackTrace string
}
// NewInfoMessage creates a new, reportable, info message (including a stack trace).
func (m *Module) NewInfoMessage(message string) *ModuleError {
return &ModuleError{
Message: message,
ModuleName: m.Name,
Severity: "info",
StackTrace: string(debug.Stack()),
}
}
// NewErrorMessage creates a new, reportable, error message (including a stack trace).
func (m *Module) NewErrorMessage(taskName string, err error) *ModuleError {
return &ModuleError{
Message: err.Error(),
ModuleName: m.Name,
Severity: "error",
StackTrace: string(debug.Stack()),
}
}
// NewPanicError creates a new, reportable, panic error message (including a stack trace).
func (m *Module) NewPanicError(taskName, taskType string, panicValue interface{}) *ModuleError {
me := &ModuleError{
Message: fmt.Sprintf("panic: %s", panicValue),
ModuleName: m.Name,
TaskName: taskName,
TaskType: taskType,
Severity: "panic",
PanicValue: panicValue,
StackTrace: string(debug.Stack()),
}
me.Message = me.Error()
return me
}
// Error returns the string representation of the error.
func (me *ModuleError) Error() string {
return me.Message
}
// Report reports the error through the configured reporting channel.
func (me *ModuleError) Report() {
if errorReportingChannel != nil {
select {
case errorReportingChannel <- me:
default:
}
}
}
// IsPanic returns whether the given error is a wrapped panic by the modules package and additionally returns it, if true.
func IsPanic(err error) (bool, *ModuleError) {
switch val := err.(type) {
case *ModuleError:
return true, val
default:
return false, nil
}
}
// SetErrorReportingChannel sets the channel to report module errors through. By default only panics are reported, all other errors need to be manually wrapped into a *ModuleError and reported.
func SetErrorReportingChannel(reportingChannel chan *ModuleError) {
if errorReportingChannel == nil {
errorReportingChannel = reportingChannel
}
}

157
modules/microtasks.go Normal file
View file

@ -0,0 +1,157 @@
package modules
import (
"context"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
// TODO: getting some errors when in nanosecond precision for tests:
// (1) panic: sync: WaitGroup is reused before previous Wait has returned - should theoretically not happen
// (2) sometimes there seems to some kind of race condition stuff, the test hangs and does not complete
var (
microTasks *int32
microTasksThreshhold *int32
microTaskFinished = make(chan struct{}, 1)
mediumPriorityClearance = make(chan struct{})
lowPriorityClearance = make(chan struct{})
triggerLogWriting = log.TriggerWriterChannel()
)
const (
mediumPriorityMaxDelay = 1 * time.Second
lowPriorityMaxDelay = 3 * time.Second
)
func init() {
var microTasksVal int32
microTasks = &microTasksVal
var microTasksThreshholdVal int32
microTasksThreshhold = &microTasksThreshholdVal
}
// SetMaxConcurrentMicroTasks sets the maximum number of microtasks that should be run concurrently.
func SetMaxConcurrentMicroTasks(n int) {
if n < 4 {
atomic.StoreInt32(microTasksThreshhold, 4)
} else {
atomic.StoreInt32(microTasksThreshhold, int32(n))
}
}
// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) error {
atomic.AddInt32(microTasks, 1)
return m.runMicroTask(name, fn)
}
// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. It will wait until given a go (max 3 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) error {
// check if we can go immediately
select {
case <-mediumPriorityClearance:
default:
// wait for go or max delay
select {
case <-mediumPriorityClearance:
case <-time.After(mediumPriorityMaxDelay):
}
}
return m.runMicroTask(name, fn)
}
// StartLowPriorityMicroTask starts a new MicroTask with low priority. It will wait until given a go (max 15 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) error {
// check if we can go immediately
select {
case <-lowPriorityClearance:
default:
// wait for go or max delay
select {
case <-lowPriorityClearance:
case <-time.After(lowPriorityMaxDelay):
}
}
return m.runMicroTask(name, fn)
}
func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err error) {
// start for module
// hint: only microTasks global var is important for scheduling, others can be set here
atomic.AddInt32(m.microTaskCnt, 1)
m.waitGroup.Add(1)
// set up recovery
defer func() {
// recover from panic
panicVal := recover()
if panicVal != nil {
me := m.NewPanicError(*name, "microtask", panicVal)
me.Report()
log.Errorf("%s: microtask %s panicked: %s\n%s", m.Name, *name, panicVal, me.StackTrace)
err = me
}
// finish for module
atomic.AddInt32(m.microTaskCnt, -1)
m.waitGroup.Done()
// finish and possibly trigger next task
atomic.AddInt32(microTasks, -1)
select {
case microTaskFinished <- struct{}{}:
default:
}
}()
// run
err = fn(m.Ctx)
return //nolint:nakedret // need to use named return val in order to change in defer
}
var (
microTaskSchedulerStarted = abool.NewBool(false)
)
func microTaskScheduler() {
// only ever start once
if !microTaskSchedulerStarted.SetToIf(false, true) {
return
}
microTaskManageLoop:
for {
if shutdownSignalClosed.IsSet() {
close(mediumPriorityClearance)
close(lowPriorityClearance)
return
}
if atomic.LoadInt32(microTasks) < atomic.LoadInt32(microTasksThreshhold) { // space left for firing task
select {
case mediumPriorityClearance <- struct{}{}:
default:
select {
case taskTimeslot <- struct{}{}:
continue microTaskManageLoop
case triggerLogWriting <- struct{}{}:
continue microTaskManageLoop
case mediumPriorityClearance <- struct{}{}:
case lowPriorityClearance <- struct{}{}:
}
}
// increase task counter
atomic.AddInt32(microTasks, 1)
} else {
// wait for signal that a task was completed
<-microTaskFinished
}
}
}

View file

@ -1,60 +1,97 @@
package taskmanager
package modules
import (
"context"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
var (
mtTestName = "microtask test"
mtModule = initNewModule("microtask test module", nil, nil, nil)
)
func init() {
go microTaskScheduler()
}
// test waiting
func TestMicroTaskWaiting(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode.")
t.Skip("skipping test in short mode, as it is not fully deterministic")
}
// init
mtwWaitGroup := new(sync.WaitGroup)
mtwOutputChannel := make(chan string, 100)
mtwExpectedOutput := "123456"
mtwExpectedOutput := "1234567"
mtwSleepDuration := 10 * time.Millisecond
// TEST
mtwWaitGroup.Add(3)
mtwWaitGroup.Add(4)
// ensure we only execute one microtask at once
atomic.StoreInt32(microTasksThreshhold, 1)
// High Priority - slot 1-5
go func() {
defer mtwWaitGroup.Done()
StartMicroTask()
mtwOutputChannel <- "1"
time.Sleep(mtwSleepDuration * 5)
mtwOutputChannel <- "2"
EndMicroTask()
// exec at slot 1
_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
mtwOutputChannel <- "1" // slot 1
time.Sleep(mtwSleepDuration * 5)
mtwOutputChannel <- "2" // slot 5
return nil
})
}()
time.Sleep(mtwSleepDuration * 2)
time.Sleep(mtwSleepDuration * 1)
// clear clearances
_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
return nil
})
// Low Priority - slot 16
go func() {
defer mtwWaitGroup.Done()
// exec at slot 2
_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtwOutputChannel <- "7" // slot 16
return nil
})
}()
time.Sleep(mtwSleepDuration * 1)
// High Priority - slot 10-15
go func() {
defer mtwWaitGroup.Done()
time.Sleep(mtwSleepDuration * 8)
StartMicroTask()
mtwOutputChannel <- "4"
time.Sleep(mtwSleepDuration * 5)
mtwOutputChannel <- "6"
EndMicroTask()
// exec at slot 10
_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
mtwOutputChannel <- "4" // slot 10
time.Sleep(mtwSleepDuration * 5)
mtwOutputChannel <- "6" // slot 15
return nil
})
}()
// Medium Priority - Waits at slot 3, should execute in slot 6-13
// Medium Priority - slot 6-13
go func() {
defer mtwWaitGroup.Done()
<-StartMediumPriorityMicroTask()
mtwOutputChannel <- "3"
time.Sleep(mtwSleepDuration * 7)
mtwOutputChannel <- "5"
EndMicroTask()
// exec at slot 3
_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtwOutputChannel <- "3" // slot 6
time.Sleep(mtwSleepDuration * 7)
mtwOutputChannel <- "5" // slot 13
return nil
})
}()
// wait for test to finish
@ -67,6 +104,7 @@ func TestMicroTaskWaiting(t *testing.T) {
completeOutput += s
}
// check if test succeeded
t.Logf("microTask wait order: %s", completeOutput)
if completeOutput != mtwExpectedOutput {
t.Errorf("MicroTask waiting test failed, expected sequence %s, got %s", mtwExpectedOutput, completeOutput)
}
@ -78,34 +116,27 @@ func TestMicroTaskWaiting(t *testing.T) {
// globals
var mtoWaitGroup sync.WaitGroup
var mtoOutputChannel chan string
var mtoWaitCh chan bool
var mtoWaitCh chan struct{}
// functions
func mediumPrioTaskTester() {
defer mtoWaitGroup.Done()
<-mtoWaitCh
<-StartMediumPriorityMicroTask()
mtoOutputChannel <- "1"
time.Sleep(2 * time.Millisecond)
EndMicroTask()
_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtoOutputChannel <- "1"
time.Sleep(2 * time.Millisecond)
return nil
})
}
func lowPrioTaskTester() {
defer mtoWaitGroup.Done()
<-mtoWaitCh
<-StartLowPriorityMicroTask()
mtoOutputChannel <- "2"
time.Sleep(2 * time.Millisecond)
EndMicroTask()
}
func veryLowPrioTaskTester() {
defer mtoWaitGroup.Done()
<-mtoWaitCh
<-StartVeryLowPriorityMicroTask()
mtoOutputChannel <- "3"
time.Sleep(2 * time.Millisecond)
EndMicroTask()
_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
mtoOutputChannel <- "2"
time.Sleep(2 * time.Millisecond)
return nil
})
}
// test
@ -113,53 +144,51 @@ func TestMicroTaskOrdering(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode.")
t.Skip("skipping test in short mode, as it is not fully deterministic")
}
// init
mtoOutputChannel = make(chan string, 100)
mtoWaitCh = make(chan bool, 0)
mtoWaitCh = make(chan struct{})
// TEST
mtoWaitGroup.Add(30)
mtoWaitGroup.Add(20)
// ensure we only execute one microtask at once
atomic.StoreInt32(microTasksThreshhold, 1)
// kick off
go mediumPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go lowPrioTaskTester()
go veryLowPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go veryLowPrioTaskTester()
go mediumPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go mediumPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go veryLowPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go lowPrioTaskTester()
go mediumPrioTaskTester()
go veryLowPrioTaskTester()
go lowPrioTaskTester()
go veryLowPrioTaskTester()
// wait for all goroutines to be ready
time.Sleep(10 * time.Millisecond)
// sync all goroutines
close(mtoWaitCh)
// trigger
select {
case microTaskFinished <- struct{}{}:
default:
}
// wait for test to finish
mtoWaitGroup.Wait()
@ -171,7 +200,8 @@ func TestMicroTaskOrdering(t *testing.T) {
completeOutput += s
}
// check if test succeeded
if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") || !strings.Contains(completeOutput, "33333") {
t.Logf("microTask exec order: %s", completeOutput)
if !strings.Contains(completeOutput, "11111") || !strings.Contains(completeOutput, "22222") {
t.Errorf("MicroTask ordering test failed, output was %s. This happens occasionally, please run the test multiple times to verify", completeOutput)
}

View file

@ -39,8 +39,12 @@ type Module struct {
Ctx context.Context
cancelCtx func()
shutdownFlag *abool.AtomicBool
workerGroup sync.WaitGroup
// workers/tasks
workerCnt *int32
taskCnt *int32
microTaskCnt *int32
waitGroup sync.WaitGroup
// dependency mgmt
depNames []string
@ -48,27 +52,6 @@ type Module struct {
depReverse []*Module
}
// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) AddWorkers(n uint) {
if !m.ShutdownInProgress() {
if atomic.AddInt32(m.workerCnt, int32(n)) > 0 {
// only add to workgroup if cnt is positive (try to compensate wrong usage)
m.workerGroup.Add(int(n))
}
}
}
// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) FinishWorker() {
// check worker cnt
if atomic.AddInt32(m.workerCnt, -1) < 0 {
log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name)
return
}
// also mark worker done in workgroup
m.workerGroup.Done()
}
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
return m.shutdownFlag.IsSet()
@ -87,13 +70,19 @@ func (m *Module) shutdown() error {
// wait for workers
done := make(chan struct{})
go func() {
m.workerGroup.Wait()
m.waitGroup.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
return errors.New("timed out while waiting for module workers to finish")
log.Warningf(
"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
m.Name,
atomic.LoadInt32(m.workerCnt),
atomic.LoadInt32(m.taskCnt),
atomic.LoadInt32(m.microTaskCnt),
)
}
// call shutdown function
@ -106,8 +95,19 @@ func dummyAction() error {
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
newModule := initNewModule(name, prep, start, stop, dependencies...)
modulesLock.Lock()
defer modulesLock.Unlock()
modules[name] = newModule
return newModule
}
func initNewModule(name string, prep, start, stop func() error, dependencies ...string) *Module {
ctx, cancelCtx := context.WithCancel(context.Background())
var workerCnt int32
var taskCnt int32
var microTaskCnt int32
newModule := &Module{
Name: name,
@ -118,8 +118,10 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin
Ctx: ctx,
cancelCtx: cancelCtx,
shutdownFlag: abool.NewBool(false),
workerGroup: sync.WaitGroup{},
waitGroup: sync.WaitGroup{},
workerCnt: &workerCnt,
taskCnt: &taskCnt,
microTaskCnt: &microTaskCnt,
prep: prep,
start: start,
stop: stop,
@ -137,9 +139,6 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin
newModule.stop = dummyAction
}
modulesLock.Lock()
defer modulesLock.Unlock()
modules[name] = newModule
return newModule
}

View file

@ -14,28 +14,28 @@ var (
shutdownOrder string
)
func testPrep(name string) func() error {
func testPrep(t *testing.T, name string) func() error {
return func() error {
// fmt.Printf("prep %s\n", name)
t.Logf("prep %s\n", name)
return nil
}
}
func testStart(name string) func() error {
func testStart(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
// fmt.Printf("start %s\n", name)
t.Logf("start %s\n", name)
startOrder = fmt.Sprintf("%s>%s", startOrder, name)
return nil
}
}
func testStop(name string) func() error {
func testStop(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
// fmt.Printf("stop %s\n", name)
t.Logf("stop %s\n", name)
shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
return nil
}
@ -49,12 +49,19 @@ func testCleanExit() error {
return ErrCleanExit
}
func TestOrdering(t *testing.T) {
func TestModules(t *testing.T) {
t.Parallel() // Not really, just a workaround for running these tests last.
Register("database", testPrep("database"), testStart("database"), testStop("database"))
Register("stats", testPrep("stats"), testStart("stats"), testStop("stats"), "database")
Register("service", testPrep("service"), testStart("service"), testStop("service"), "database")
Register("analytics", testPrep("analytics"), testStart("analytics"), testStop("analytics"), "stats", "database")
t.Run("TestModuleOrder", testModuleOrder)
t.Run("TestModuleErrors", testModuleErrors)
}
func testModuleOrder(t *testing.T) {
Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database"))
Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database")
Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database")
Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database")
err := Start()
if err != nil {
@ -105,19 +112,7 @@ func printAndRemoveModules() {
modules = make(map[string]*Module)
}
func resetModules() {
for _, module := range modules {
module.Prepped.UnSet()
module.Started.UnSet()
module.Stopped.UnSet()
module.inTransition.UnSet()
module.depModules = make([]*Module, 0)
module.depModules = make([]*Module, 0)
}
}
func TestErrors(t *testing.T) {
func testModuleErrors(t *testing.T) {
// reset modules
modules = make(map[string]*Module)
@ -125,7 +120,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test prep error
Register("prepfail", testFail, testStart("prepfail"), testStop("prepfail"))
Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail"))
err := Start()
if err == nil {
t.Error("should fail")
@ -137,7 +132,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test prep clean exit
Register("prepcleanexit", testCleanExit, testStart("prepcleanexit"), testStop("prepcleanexit"))
Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit"))
err = Start()
if err != ErrCleanExit {
t.Error("should fail with clean exit")
@ -149,7 +144,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test invalid dependency
Register("database", nil, testStart("database"), testStop("database"), "invalid")
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid")
err = Start()
if err == nil {
t.Error("should fail")
@ -161,8 +156,8 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test dependency loop
Register("database", nil, testStart("database"), testStop("database"), "helper")
Register("helper", nil, testStart("helper"), testStop("helper"), "database")
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper")
Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database")
err = Start()
if err == nil {
t.Error("should fail")
@ -174,7 +169,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test failing module start
Register("startfail", nil, testFail, testStop("startfail"))
Register("startfail", nil, testFail, testStop(t, "startfail"))
err = Start()
if err == nil {
t.Error("should fail")
@ -186,7 +181,7 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test failing module stop
Register("stopfail", nil, testStart("stopfail"), testFail)
Register("stopfail", nil, testStart(t, "stopfail"), testFail)
err = Start()
if err != nil {
t.Error("should not fail")

View file

@ -3,6 +3,7 @@ package modules
import (
"fmt"
"os"
"runtime"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
@ -13,13 +14,6 @@ var (
startCompleteSignal = make(chan struct{})
)
// markStartComplete marks the startup as completed.
func markStartComplete() {
if startComplete.SetToIf(false, true) {
close(startCompleteSignal)
}
}
// StartCompleted returns whether starting has completed.
func StartCompleted() bool {
return startComplete.IsSet()
@ -35,6 +29,10 @@ func Start() error {
modulesLock.Lock()
defer modulesLock.Unlock()
// start microtask scheduler
go microTaskScheduler()
SetMaxConcurrentMicroTasks(runtime.GOMAXPROCS(0) * 2)
// inter-link modules
err := initDependencies()
if err != nil {
@ -59,6 +57,7 @@ func Start() error {
}
// start logging
log.EnableScheduling()
err = log.Start()
if err != nil {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to start logging: %s\n", err)
@ -79,6 +78,9 @@ func Start() error {
close(startCompleteSignal)
}
go taskQueueHandler()
go taskScheduleHandler()
return nil
}

441
modules/tasks.go Normal file
View file

@ -0,0 +1,441 @@
package modules
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
)
// Task is managed task bound to a module.
type Task struct {
name string
module *Module
taskFn func(context.Context, *Task)
queued bool
canceled bool
executing bool
cancelFunc func()
executeAt time.Time
repeat time.Duration
maxDelay time.Duration
queueElement *list.Element
prioritizedQueueElement *list.Element
scheduleListElement *list.Element
lock sync.Mutex
}
var (
taskQueue = list.New()
prioritizedTaskQueue = list.New()
queuesLock sync.Mutex
queueWg sync.WaitGroup
taskSchedule = list.New()
scheduleLock sync.Mutex
waitForever chan time.Time
queueIsFilled = make(chan struct{}, 1) // kick off queue handler
recalculateNextScheduledTask = make(chan struct{}, 1)
taskTimeslot = make(chan struct{})
)
const (
maxTimeslotWait = 30 * time.Second
minRepeatDuration = 1 * time.Minute
maxExecutionWait = 1 * time.Minute
defaultMaxDelay = 1 * time.Minute
)
// NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
return &Task{
name: name,
module: m,
taskFn: fn,
maxDelay: defaultMaxDelay,
}
}
func (t *Task) isActive() bool {
return !t.canceled && !t.module.ShutdownInProgress()
}
func (t *Task) prepForQueueing() (ok bool) {
if !t.isActive() {
return false
}
t.queued = true
if t.maxDelay != 0 {
t.executeAt = time.Now().Add(t.maxDelay)
t.addToSchedule()
}
return true
}
func notifyQueue() {
select {
case queueIsFilled <- struct{}{}:
default:
}
}
// Queue queues the Task for execution.
func (t *Task) Queue() *Task {
t.lock.Lock()
if !t.prepForQueueing() {
t.lock.Unlock()
return t
}
t.lock.Unlock()
if t.queueElement == nil {
queuesLock.Lock()
t.queueElement = taskQueue.PushBack(t)
queuesLock.Unlock()
}
notifyQueue()
return t
}
// Prioritize puts the task in the prioritized queue.
func (t *Task) Prioritize() *Task {
t.lock.Lock()
if !t.prepForQueueing() {
t.lock.Unlock()
return t
}
t.lock.Unlock()
if t.prioritizedQueueElement == nil {
queuesLock.Lock()
t.prioritizedQueueElement = prioritizedTaskQueue.PushBack(t)
queuesLock.Unlock()
}
notifyQueue()
return t
}
// StartASAP schedules the task to be executed next.
func (t *Task) StartASAP() *Task {
t.lock.Lock()
if !t.prepForQueueing() {
t.lock.Unlock()
return t
}
t.lock.Unlock()
queuesLock.Lock()
if t.prioritizedQueueElement == nil {
t.prioritizedQueueElement = prioritizedTaskQueue.PushFront(t)
} else {
prioritizedTaskQueue.MoveToFront(t.prioritizedQueueElement)
}
queuesLock.Unlock()
notifyQueue()
return t
}
// MaxDelay sets a maximum delay within the task should be executed from being queued. Scheduled tasks are queued when they are triggered. The default delay is 3 minutes.
func (t *Task) MaxDelay(maxDelay time.Duration) *Task {
t.lock.Lock()
t.maxDelay = maxDelay
t.lock.Unlock()
return t
}
// Schedule schedules the task for execution at the given time.
func (t *Task) Schedule(executeAt time.Time) *Task {
t.lock.Lock()
t.executeAt = executeAt
t.addToSchedule()
t.lock.Unlock()
return t
}
// Repeat sets the task to be executed in endless repeat at the specified interval. First execution will be after interval. Minimum repeat interval is one minute.
func (t *Task) Repeat(interval time.Duration) *Task {
// check minimum interval duration
if interval < minRepeatDuration {
interval = minRepeatDuration
}
t.lock.Lock()
t.repeat = interval
t.executeAt = time.Now().Add(t.repeat)
t.lock.Unlock()
return t
}
// Cancel cancels the current and any future execution of the Task. This is not reversible by any other functions.
func (t *Task) Cancel() {
t.lock.Lock()
t.canceled = true
if t.cancelFunc != nil {
t.cancelFunc()
}
t.lock.Unlock()
}
func (t *Task) runWithLocking() {
// wait for good timeslot regarding microtasks
select {
case <-taskTimeslot:
case <-time.After(maxTimeslotWait):
}
t.lock.Lock()
// check state, return if already executing or inactive
if t.executing || !t.isActive() {
t.lock.Unlock()
return
}
t.executing = true
// get list elements
queueElement := t.queueElement
prioritizedQueueElement := t.prioritizedQueueElement
scheduleListElement := t.scheduleListElement
// create context
var taskCtx context.Context
taskCtx, t.cancelFunc = context.WithCancel(t.module.Ctx)
t.lock.Unlock()
// remove from lists
if queueElement != nil {
queuesLock.Lock()
taskQueue.Remove(t.queueElement)
queuesLock.Unlock()
t.lock.Lock()
t.queueElement = nil
t.lock.Unlock()
}
if prioritizedQueueElement != nil {
queuesLock.Lock()
prioritizedTaskQueue.Remove(t.prioritizedQueueElement)
queuesLock.Unlock()
t.lock.Lock()
t.prioritizedQueueElement = nil
t.lock.Unlock()
}
if scheduleListElement != nil {
scheduleLock.Lock()
taskSchedule.Remove(t.scheduleListElement)
scheduleLock.Unlock()
t.lock.Lock()
t.scheduleListElement = nil
t.lock.Unlock()
}
// add to queue workgroup
queueWg.Add(1)
go t.executeWithLocking(taskCtx, t.cancelFunc)
go func() {
select {
case <-taskCtx.Done():
case <-time.After(maxExecutionWait):
}
// complete queue worker (early) to allow next worker
queueWg.Done()
}()
}
func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
// start for module
// hint: only queueWg global var is important for scheduling, others can be set here
atomic.AddInt32(t.module.taskCnt, 1)
t.module.waitGroup.Add(1)
defer func() {
// recover from panic
panicVal := recover()
if panicVal != nil {
me := t.module.NewPanicError(t.name, "task", panicVal)
me.Report()
log.Errorf("%s: task %s panicked: %s\n%s", t.module.Name, t.name, panicVal, me.StackTrace)
}
// finish for module
atomic.AddInt32(t.module.taskCnt, -1)
t.module.waitGroup.Done()
// reset
t.lock.Lock()
// reset state
t.executing = false
t.queued = false
// repeat?
if t.isActive() && t.repeat != 0 {
t.executeAt = time.Now().Add(t.repeat)
t.addToSchedule()
}
t.lock.Unlock()
// notify that we finished
cancelFunc()
}()
// run
t.taskFn(ctx, t)
}
func (t *Task) getExecuteAtWithLocking() time.Time {
t.lock.Lock()
defer t.lock.Unlock()
return t.executeAt
}
func (t *Task) addToSchedule() {
scheduleLock.Lock()
defer scheduleLock.Unlock()
// notify scheduler
defer func() {
select {
case recalculateNextScheduledTask <- struct{}{}:
default:
}
}()
// insert task into schedule
for e := taskSchedule.Front(); e != nil; e = e.Next() {
// check for self
eVal := e.Value.(*Task)
if eVal == t {
continue
}
// compare
if t.executeAt.Before(eVal.getExecuteAtWithLocking()) {
// insert/move task
if t.scheduleListElement == nil {
t.scheduleListElement = taskSchedule.InsertBefore(t, e)
} else {
taskSchedule.MoveBefore(t.scheduleListElement, e)
}
return
}
}
// add/move to end
if t.scheduleListElement == nil {
t.scheduleListElement = taskSchedule.PushBack(t)
} else {
taskSchedule.MoveToBack(t.scheduleListElement)
}
}
func waitUntilNextScheduledTask() <-chan time.Time {
scheduleLock.Lock()
defer scheduleLock.Unlock()
if taskSchedule.Len() > 0 {
return time.After(time.Until(taskSchedule.Front().Value.(*Task).executeAt))
}
return waitForever
}
var (
taskQueueHandlerStarted = abool.NewBool(false)
taskScheduleHandlerStarted = abool.NewBool(false)
)
func taskQueueHandler() {
// only ever start once
if !taskQueueHandlerStarted.SetToIf(false, true) {
return
}
for {
// wait
select {
case <-shutdownSignal:
return
case <-queueIsFilled:
}
// execute
execLoop:
for {
// wait for execution slot
queueWg.Wait()
// check for shutdown
if shutdownSignalClosed.IsSet() {
return
}
// get next Task
queuesLock.Lock()
e := prioritizedTaskQueue.Front()
if e != nil {
prioritizedTaskQueue.Remove(e)
} else {
e = taskQueue.Front()
if e != nil {
taskQueue.Remove(e)
}
}
queuesLock.Unlock()
// lists are empty
if e == nil {
break execLoop
}
// value -> Task
t := e.Value.(*Task)
// run
t.runWithLocking()
}
}
}
func taskScheduleHandler() {
// only ever start once
if !taskScheduleHandlerStarted.SetToIf(false, true) {
return
}
for {
select {
case <-shutdownSignal:
return
case <-recalculateNextScheduledTask:
case <-waitUntilNextScheduledTask():
// get first task in schedule
scheduleLock.Lock()
e := taskSchedule.Front()
scheduleLock.Unlock()
t := e.Value.(*Task)
// process Task
if t.queued {
// already queued and maxDelay reached
t.runWithLocking()
} else {
// place in front of prioritized queue
t.StartASAP()
}
}
}
}

162
modules/tasks_test.go Normal file
View file

@ -0,0 +1,162 @@
package modules
import (
"context"
"fmt"
"os"
"runtime/pprof"
"sync"
"testing"
"time"
)
func init() {
go taskQueueHandler()
go taskScheduleHandler()
go func() {
<-time.After(10 * time.Second)
fmt.Fprintln(os.Stderr, "taking too long")
_ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2)
os.Exit(1)
}()
// always trigger task timeslot for testing
go func() {
for {
taskTimeslot <- struct{}{}
}
}()
}
// test waiting
// globals
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration
var qtModule = initNewModule("task test module", nil, nil, nil)
// functions
func queuedTaskTester(s string) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
qtWg.Done()
}).Queue()
}
func prioritizedTaskTester(s string) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
qtWg.Done()
}).Prioritize()
}
// test
func TestQueuedTask(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode, as it is not fully deterministic")
}
// init
expectedOutput := "0123456789"
qtSleepDuration = 20 * time.Millisecond
qtOutputChannel = make(chan string, 100)
qtWg.Add(10)
// TEST
queuedTaskTester("0")
queuedTaskTester("1")
queuedTaskTester("3")
queuedTaskTester("4")
queuedTaskTester("6")
queuedTaskTester("7")
queuedTaskTester("9")
time.Sleep(qtSleepDuration * 3)
prioritizedTaskTester("2")
time.Sleep(qtSleepDuration * 6)
prioritizedTaskTester("5")
time.Sleep(qtSleepDuration * 6)
prioritizedTaskTester("8")
// wait for test to finish
qtWg.Wait()
// collect output
close(qtOutputChannel)
completeOutput := ""
for s := <-qtOutputChannel; s != ""; s = <-qtOutputChannel {
completeOutput += s
}
// check if test succeeded
if completeOutput != expectedOutput {
t.Errorf("QueuedTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
}
}
// test scheduled tasks
// globals
var stWg sync.WaitGroup
var stOutputChannel chan string
var stSleepDuration time.Duration
var stWaitCh chan bool
// functions
func scheduledTaskTester(s string, sched time.Time) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
time.Sleep(stSleepDuration)
stOutputChannel <- s
stWg.Done()
}).Schedule(sched)
}
// test
func TestScheduledTaskWaiting(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode, as it is not fully deterministic")
}
// init
expectedOutput := "0123456789"
stSleepDuration = 10 * time.Millisecond
stOutputChannel = make(chan string, 100)
stWaitCh = make(chan bool)
stWg.Add(10)
// TEST
scheduledTaskTester("4", time.Now().Add(stSleepDuration*8))
scheduledTaskTester("0", time.Now().Add(stSleepDuration*0))
scheduledTaskTester("8", time.Now().Add(stSleepDuration*16))
scheduledTaskTester("1", time.Now().Add(stSleepDuration*2))
scheduledTaskTester("7", time.Now().Add(stSleepDuration*14))
scheduledTaskTester("9", time.Now().Add(stSleepDuration*18))
scheduledTaskTester("3", time.Now().Add(stSleepDuration*6))
scheduledTaskTester("2", time.Now().Add(stSleepDuration*4))
scheduledTaskTester("6", time.Now().Add(stSleepDuration*12))
scheduledTaskTester("5", time.Now().Add(stSleepDuration*10))
// wait for test to finish
close(stWaitCh)
stWg.Wait()
// collect output
close(stOutputChannel)
completeOutput := ""
for s := <-stOutputChannel; s != ""; s = <-stOutputChannel {
completeOutput += s
}
// check if test succeeded
if completeOutput != expectedOutput {
t.Errorf("ScheduledTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
}
}

79
modules/worker.go Normal file
View file

@ -0,0 +1,79 @@
package modules
import (
"context"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
)
// Worker Default Configuration
const (
DefaultBackoffDuration = 2 * time.Second
)
// RunWorker directly runs a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to RunWorker blocks until the worker is finished.
func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1)
defer func() {
atomic.AddInt32(m.workerCnt, -1)
m.waitGroup.Done()
}()
return m.runWorker(name, fn)
}
// StartServiceWorker starts a generic worker, which is automatically restarted in case of an error. A call to StartServiceWorker runs the service-worker in a new goroutine and returns immediately. `backoffDuration` specifies how to long to wait before restarts, multiplied by the number of failed attempts. Pass `0` for the default backoff duration. For custom error remediation functionality, build your own error handling procedure using calls to RunWorker.
func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
go m.runServiceWorker(name, backoffDuration, fn)
}
func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1)
defer func() {
atomic.AddInt32(m.workerCnt, -1)
m.waitGroup.Done()
}()
if backoffDuration == 0 {
backoffDuration = DefaultBackoffDuration
}
failCnt := 0
for {
if m.ShutdownInProgress() {
return
}
err := m.runWorker(name, fn)
if err != nil {
// log error and restart
failCnt++
sleepFor := time.Duration(failCnt) * backoffDuration
log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor)
time.Sleep(sleepFor)
} else {
// finish
return
}
}
}
func (m *Module) runWorker(name string, fn func(context.Context) error) (err error) {
defer func() {
// recover from panic
panicVal := recover()
if panicVal != nil {
me := m.NewPanicError(name, "worker", panicVal)
me.Report()
err = me
}
}()
// run
err = fn(m.Ctx)
return
}

65
modules/worker_test.go Normal file
View file

@ -0,0 +1,65 @@
package modules
import (
"context"
"errors"
"sync"
"testing"
"time"
)
var (
wModule = initNewModule("worker test module", nil, nil, nil)
errTest = errors.New("test error")
)
func TestWorker(t *testing.T) {
// test basic functionality
err := wModule.RunWorker("test worker", func(ctx context.Context) error {
return nil
})
if err != nil {
t.Errorf("worker failed (should not): %s", err)
}
// test returning an error
err = wModule.RunWorker("test worker", func(ctx context.Context) error {
return errTest
})
if err != errTest {
t.Errorf("worker failed with unexpected error: %s", err)
}
// test service functionality
failCnt := 0
var sWTestGroup sync.WaitGroup
sWTestGroup.Add(1)
wModule.StartServiceWorker("test service-worker", 2*time.Millisecond, func(ctx context.Context) error {
failCnt++
t.Logf("service-worker test run #%d", failCnt)
if failCnt >= 3 {
sWTestGroup.Done()
return nil
}
return errTest
})
// wait for service-worker to complete test
sWTestGroup.Wait()
if failCnt != 3 {
t.Errorf("service-worker failed to restart")
}
// test panic recovery
err = wModule.RunWorker("test worker", func(ctx context.Context) error {
var a []byte
_ = a[0] //nolint // we want to runtime panic!
return nil
})
t.Logf("panic error message: %s", err)
panicked, mErr := IsPanic(err)
if !panicked {
t.Errorf("failed to return *ModuleError, got %+v", err)
} else {
t.Logf("panic stack trace:\n%s", mErr.StackTrace)
}
}

View file

@ -1,19 +1,21 @@
package notifications
import (
"context"
"time"
"github.com/safing/portbase/log"
)
func cleaner() {
shutdownWg.Add(1)
select {
case <-shutdownSignal:
shutdownWg.Done()
return
case <-time.After(5 * time.Second):
cleanNotifications()
//nolint:unparam // must conform to interface
func cleaner(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
case <-time.After(5 * time.Second):
cleanNotifications()
}
}
}

View file

@ -138,7 +138,7 @@ func UpdateNotification(n *Notification, key string) {
n.Lock()
defer n.Unlock()
// seperate goroutine in order to correctly lock notsLock
// separate goroutine in order to correctly lock notsLock
notsLock.RLock()
origN, ok := nots[key]
notsLock.RUnlock()

View file

@ -1,14 +1,13 @@
package notifications
import (
"sync"
"time"
"github.com/safing/portbase/modules"
)
var (
shutdownSignal = make(chan struct{})
shutdownWg sync.WaitGroup
module *modules.Module
)
func init() {
@ -21,12 +20,6 @@ func start() error {
return err
}
go cleaner()
return nil
}
func stop() error {
close(shutdownSignal)
shutdownWg.Wait()
go module.StartServiceWorker("cleaner", 1*time.Second, cleaner)
return nil
}

View file

@ -32,15 +32,15 @@ type Notification struct {
DataSubject sync.Locker
Type uint8
AvailableActions []*Action
SelectedActionID string
Persistent bool // this notification persists until it is handled and survives restarts
Created int64 // creation timestamp, notification "starts"
Expires int64 // expiry timestamp, notification is expected to be canceled at this time and may be cleaned up afterwards
Responded int64 // response timestamp, notification "ends"
Executed int64 // execution timestamp, notification will be deleted soon
AvailableActions []*Action
SelectedActionID string
lock sync.Mutex
actionFunction func(*Notification) // call function to process action
actionTrigger chan string // and/or send to a channel
@ -54,7 +54,6 @@ type Action struct {
}
func noOpAction(n *Notification) {
return
}
// Get returns the notification identifed by the given id or nil if it doesn't exist.
@ -125,7 +124,7 @@ func (n *Notification) Save() *Notification {
}
// SetActionFunction sets a trigger function to be executed when the user reacted on the notification.
// The provided funtion will be started as its own goroutine and will have to lock everything it accesses, even the provided notification.
// The provided function will be started as its own goroutine and will have to lock everything it accesses, even the provided notification.
func (n *Notification) SetActionFunction(fn func(*Notification)) *Notification {
n.lock.Lock()
defer n.lock.Unlock()
@ -139,7 +138,7 @@ func (n *Notification) MakeAck() *Notification {
defer n.lock.Unlock()
n.AvailableActions = []*Action{
&Action{
{
ID: "ack",
Text: "OK",
},

View file

@ -29,7 +29,7 @@ func main() {
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal)
signalCh := make(chan os.Signal, 3)
signal.Notify(
signalCh,
os.Interrupt,
@ -42,7 +42,7 @@ func main() {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
modules.Shutdown()
_ = modules.Shutdown()
case <-modules.ShuttingDown():
}

View file

@ -1,176 +0,0 @@
package taskmanager
import (
"sync/atomic"
"time"
"github.com/tevino/abool"
)
// TODO: getting some errors when in nanosecond precision for tests:
// (1) panic: sync: WaitGroup is reused before previous Wait has returned - should theoretically not happen
// (2) sometimes there seems to some kind of race condition stuff, the test hangs and does not complete
var (
closedChannel chan bool
tasks *int32
mediumPriorityClearance chan bool
lowPriorityClearance chan bool
veryLowPriorityClearance chan bool
tasksDone chan bool
tasksDoneFlag *abool.AtomicBool
tasksWaiting chan bool
tasksWaitingFlag *abool.AtomicBool
shutdownSignal = make(chan struct{}, 0)
suttingDown = abool.NewBool(false)
)
// StartMicroTask starts a new MicroTask. It will start immediately.
func StartMicroTask() {
atomic.AddInt32(tasks, 1)
tasksDoneFlag.UnSet()
}
// EndMicroTask MUST be always called when a MicroTask was previously started.
func EndMicroTask() {
c := atomic.AddInt32(tasks, -1)
if c < 1 {
if tasksDoneFlag.SetToIf(false, true) {
tasksDone <- true
}
}
}
func newTaskIsWaiting() {
tasksWaiting <- true
}
// StartMediumPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartMediumPriorityMicroTask() chan bool {
if suttingDown.IsSet() {
return closedChannel
}
if tasksWaitingFlag.SetToIf(false, true) {
defer newTaskIsWaiting()
}
return mediumPriorityClearance
}
// StartLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartLowPriorityMicroTask() chan bool {
if suttingDown.IsSet() {
return closedChannel
}
if tasksWaitingFlag.SetToIf(false, true) {
defer newTaskIsWaiting()
}
return lowPriorityClearance
}
// StartVeryLowPriorityMicroTask starts a new MicroTask (waiting its turn) if channel receives.
func StartVeryLowPriorityMicroTask() chan bool {
if suttingDown.IsSet() {
return closedChannel
}
if tasksWaitingFlag.SetToIf(false, true) {
defer newTaskIsWaiting()
}
return veryLowPriorityClearance
}
func start() error {
return nil
}
func stop() error {
close(shutdownSignal)
return nil
}
func init() {
closedChannel = make(chan bool, 0)
close(closedChannel)
var t int32 = 0
tasks = &t
mediumPriorityClearance = make(chan bool, 0)
lowPriorityClearance = make(chan bool, 0)
veryLowPriorityClearance = make(chan bool, 0)
tasksDone = make(chan bool, 1)
tasksDoneFlag = abool.NewBool(true)
tasksWaiting = make(chan bool, 1)
tasksWaitingFlag = abool.NewBool(false)
timoutTimerDuration := 1 * time.Second
// timoutTimer := time.NewTimer(timoutTimerDuration)
go func() {
microTaskManageLoop:
for {
// wait for an event to start new tasks
if !suttingDown.IsSet() {
// reset timer
// https://golang.org/pkg/time/#Timer.Reset
// if !timoutTimer.Stop() {
// <-timoutTimer.C
// }
// timoutTimer.Reset(timoutTimerDuration)
// wait for event to start a new task
select {
case <-tasksWaiting:
if !tasksDoneFlag.IsSet() {
continue microTaskManageLoop
}
case <-time.After(timoutTimerDuration):
case <-tasksDone:
case <-shutdownSignal:
}
} else {
// execute tasks until no tasks are waiting anymore
if !tasksWaitingFlag.IsSet() {
// wait until tasks are finished
if !tasksDoneFlag.IsSet() {
<-tasksDone
}
// signal module completion
// microTasksModule.StopComplete()
// exit
return
}
}
// start new task, if none is started, check if we are shutting down
select {
case mediumPriorityClearance <- true:
StartMicroTask()
default:
select {
case lowPriorityClearance <- true:
StartMicroTask()
default:
select {
case veryLowPriorityClearance <- true:
StartMicroTask()
default:
tasksWaitingFlag.UnSet()
}
}
}
}
}()
}

View file

@ -1,152 +0,0 @@
package taskmanager
import (
"container/list"
"time"
"github.com/tevino/abool"
)
type Task struct {
name string
start chan bool
started *abool.AtomicBool
schedule *time.Time
}
var taskQueue *list.List
var prioritizedTaskQueue *list.List
var addToQueue chan *Task
var addToPrioritizedQueue chan *Task
var addAsNextTask chan *Task
var finishedQueuedTask chan bool
var queuedTaskRunning *abool.AtomicBool
var getQueueLengthREQ chan bool
var getQueueLengthREP chan int
func newUnqeuedTask(name string) *Task {
t := &Task{
name,
make(chan bool),
abool.NewBool(false),
nil,
}
return t
}
func NewQueuedTask(name string) *Task {
t := newUnqeuedTask(name)
addToQueue <- t
return t
}
func NewPrioritizedQueuedTask(name string) *Task {
t := newUnqeuedTask(name)
addToPrioritizedQueue <- t
return t
}
func (t *Task) addToPrioritizedQueue() {
addToPrioritizedQueue <- t
}
func (t *Task) WaitForStart() chan bool {
return t.start
}
func (t *Task) StartAnyway() {
addAsNextTask <- t
}
func (t *Task) Done() {
if !t.started.SetToIf(false, true) {
finishedQueuedTask <- true
}
}
func TotalQueuedTasks() int {
getQueueLengthREQ <- true
return <-getQueueLengthREP
}
func checkQueueStatus() {
if queuedTaskRunning.SetToIf(false, true) {
finishedQueuedTask <- true
}
}
func fireNextTask() {
if prioritizedTaskQueue.Len() > 0 {
for e := prioritizedTaskQueue.Front(); prioritizedTaskQueue.Len() > 0; e.Next() {
t := e.Value.(*Task)
prioritizedTaskQueue.Remove(e)
if t.started.SetToIf(false, true) {
close(t.start)
return
}
}
}
if taskQueue.Len() > 0 {
for e := taskQueue.Front(); taskQueue.Len() > 0; e.Next() {
t := e.Value.(*Task)
taskQueue.Remove(e)
if t.started.SetToIf(false, true) {
close(t.start)
return
}
}
}
queuedTaskRunning.UnSet()
}
func init() {
taskQueue = list.New()
prioritizedTaskQueue = list.New()
addToQueue = make(chan *Task, 1)
addToPrioritizedQueue = make(chan *Task, 1)
addAsNextTask = make(chan *Task, 1)
finishedQueuedTask = make(chan bool, 1)
queuedTaskRunning = abool.NewBool(false)
getQueueLengthREQ = make(chan bool, 1)
getQueueLengthREP = make(chan int, 1)
go func() {
for {
select {
case <-shutdownSignal:
// TODO: work off queue?
return
case <-getQueueLengthREQ:
// TODO: maybe clean queues before replying
if queuedTaskRunning.IsSet() {
getQueueLengthREP <- prioritizedTaskQueue.Len() + taskQueue.Len() + 1
} else {
getQueueLengthREP <- prioritizedTaskQueue.Len() + taskQueue.Len()
}
case t := <-addToQueue:
taskQueue.PushBack(t)
checkQueueStatus()
case t := <-addToPrioritizedQueue:
prioritizedTaskQueue.PushBack(t)
checkQueueStatus()
case t := <-addAsNextTask:
prioritizedTaskQueue.PushFront(t)
checkQueueStatus()
case <-finishedQueuedTask:
fireNextTask()
}
}
}()
}

View file

@ -1,110 +0,0 @@
package taskmanager
import (
"sync"
"testing"
"time"
)
// test waiting
// globals
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration
// functions
func queuedTaskTester(s string) {
t := NewQueuedTask(s)
go func() {
<-t.WaitForStart()
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
t.Done()
qtWg.Done()
}()
}
func prioritizedTastTester(s string) {
t := NewPrioritizedQueuedTask(s)
go func() {
<-t.WaitForStart()
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
t.Done()
qtWg.Done()
}()
}
// test
func TestQueuedTask(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// init
expectedOutput := "0123456789"
qtSleepDuration = 10 * time.Millisecond
qtOutputChannel = make(chan string, 100)
qtWg.Add(10)
// test queue length
c := TotalQueuedTasks()
if c != 0 {
t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
}
// TEST
queuedTaskTester("0")
queuedTaskTester("1")
queuedTaskTester("3")
queuedTaskTester("4")
queuedTaskTester("6")
queuedTaskTester("7")
queuedTaskTester("9")
// test queue length
c = TotalQueuedTasks()
if c != 7 {
t.Errorf("Error in calculating Task Queue, expected 7, got %d", c)
}
time.Sleep(qtSleepDuration * 3)
prioritizedTastTester("2")
time.Sleep(qtSleepDuration * 6)
prioritizedTastTester("5")
time.Sleep(qtSleepDuration * 6)
prioritizedTastTester("8")
// test queue length
c = TotalQueuedTasks()
if c != 3 {
t.Errorf("Error in calculating Task Queue, expected 3, got %d", c)
}
// time.Sleep(qtSleepDuration * 100)
// panic("")
// wait for test to finish
qtWg.Wait()
// test queue length
c = TotalQueuedTasks()
if c != 0 {
t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
}
// collect output
close(qtOutputChannel)
completeOutput := ""
for s := <-qtOutputChannel; s != ""; s = <-qtOutputChannel {
completeOutput += s
}
// check if test succeeded
if completeOutput != expectedOutput {
t.Errorf("QueuedTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
}
}

View file

@ -1,73 +0,0 @@
package taskmanager
import (
"container/list"
"time"
)
var taskSchedule *list.List
var addToSchedule chan *Task
var waitForever chan time.Time
var getScheduleLengthREQ chan bool
var getScheduleLengthREP chan int
func NewScheduledTask(name string, schedule time.Time) *Task {
t := newUnqeuedTask(name)
t.schedule = &schedule
addToSchedule <- t
return t
}
func TotalScheduledTasks() int {
getScheduleLengthREQ <- true
return <-getScheduleLengthREP
}
func (t *Task) addToSchedule() {
for e := taskSchedule.Back(); e != nil; e = e.Prev() {
if t.schedule.After(*e.Value.(*Task).schedule) {
taskSchedule.InsertAfter(t, e)
return
}
}
taskSchedule.PushFront(t)
}
func waitUntilNextScheduledTask() <-chan time.Time {
if taskSchedule.Len() > 0 {
return time.After(taskSchedule.Front().Value.(*Task).schedule.Sub(time.Now()))
}
return waitForever
}
func init() {
taskSchedule = list.New()
addToSchedule = make(chan *Task, 1)
waitForever = make(chan time.Time, 1)
getScheduleLengthREQ = make(chan bool, 1)
getScheduleLengthREP = make(chan int, 1)
go func() {
for {
select {
case <-shutdownSignal:
return
case <-getScheduleLengthREQ:
// TODO: maybe clean queues before replying
getScheduleLengthREP <- prioritizedTaskQueue.Len() + taskSchedule.Len()
case t := <-addToSchedule:
t.addToSchedule()
case <-waitUntilNextScheduledTask():
e := taskSchedule.Front()
t := e.Value.(*Task)
t.addToPrioritizedQueue()
taskSchedule.Remove(e)
}
}
}()
}

View file

@ -1,93 +0,0 @@
package taskmanager
import (
"sync"
"testing"
"time"
)
// test waiting
// globals
var stWg sync.WaitGroup
var stOutputChannel chan string
var stSleepDuration time.Duration
var stWaitCh chan bool
// functions
func scheduledTaskTester(s string, sched time.Time) {
t := NewScheduledTask(s, sched)
go func() {
<-stWaitCh
<-t.WaitForStart()
time.Sleep(stSleepDuration)
stOutputChannel <- s
t.Done()
stWg.Done()
}()
}
// test
func TestScheduledTaskWaiting(t *testing.T) {
// skip
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// init
expectedOutput := "0123456789"
stSleepDuration = 10 * time.Millisecond
stOutputChannel = make(chan string, 100)
stWaitCh = make(chan bool, 0)
// test queue length
c := TotalScheduledTasks()
if c != 0 {
t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
}
stWg.Add(10)
// TEST
scheduledTaskTester("4", time.Now().Add(stSleepDuration*4))
scheduledTaskTester("0", time.Now().Add(stSleepDuration*1))
scheduledTaskTester("8", time.Now().Add(stSleepDuration*8))
scheduledTaskTester("1", time.Now().Add(stSleepDuration*2))
scheduledTaskTester("7", time.Now().Add(stSleepDuration*7))
// test queue length
time.Sleep(1 * time.Millisecond)
c = TotalScheduledTasks()
if c != 5 {
t.Errorf("Error in calculating Task Queue, expected 5, got %d", c)
}
scheduledTaskTester("9", time.Now().Add(stSleepDuration*9))
scheduledTaskTester("3", time.Now().Add(stSleepDuration*3))
scheduledTaskTester("2", time.Now().Add(stSleepDuration*2))
scheduledTaskTester("6", time.Now().Add(stSleepDuration*6))
scheduledTaskTester("5", time.Now().Add(stSleepDuration*5))
// wait for test to finish
close(stWaitCh)
stWg.Wait()
// test queue length
c = TotalScheduledTasks()
if c != 0 {
t.Errorf("Error in calculating Task Queue, expected 0, got %d", c)
}
// collect output
close(stOutputChannel)
completeOutput := ""
for s := <-stOutputChannel; s != ""; s = <-stOutputChannel {
completeOutput += s
}
// check if test succeeded
if completeOutput != expectedOutput {
t.Errorf("ScheduledTask test failed, expected sequence %s, got %s", expectedOutput, completeOutput)
}
}

134
test
View file

@ -4,6 +4,23 @@ warnings=0
errors=0
scripted=0
goUp="\\e[1A"
all=0
fullTestFlags="-short"
install=0
function help {
echo "usage: $0 [command] [options]"
echo ""
echo "commands:"
echo " <none> run baseline tests"
echo " all run all tests"
echo " install install deps for running baseline tests"
echo " install all install deps for running all tests"
echo ""
echo "options:"
echo " --scripted dont jump console lines (still use colors)"
echo " [package] run tests only on this package"
}
function run {
if [[ $scripted -eq 0 ]]; then
@ -36,8 +53,8 @@ function run {
if [[ $output == *"build constraints exclude all Go files"* ]]; then
echo -e "${goUp}[ !=OS ] $*"
else
echo -e "${goUp}[\e[01;31m FAIL \e[00m] $*" >/dev/stderr
cat $tmpfile >/dev/stderr
echo -e "${goUp}[\e[01;31m FAIL \e[00m] $*"
cat $tmpfile
errors=$((errors+1))
fi
fi
@ -45,27 +62,124 @@ function run {
rm -f $tmpfile
}
function checkformat {
if [[ $scripted -eq 0 ]]; then
echo "[......] gofmt $1"
fi
output=$(gofmt -l $GOPATH/src/$1/*.go)
if [[ $output == "" ]]; then
echo -e "${goUp}[\e[01;32m OK \e[00m] gofmt $*"
else
echo -e "${goUp}[\e[01;31m FAIL \e[00m] gofmt $*"
echo "The following files do not conform to gofmt:"
gofmt -l $GOPATH/src/$1/*.go # keeps format
errors=$((errors+1))
fi
}
# get and switch to script dir
baseDir="$( cd "$(dirname "$0")" && pwd )"
cd "$baseDir"
# change output format if being run in script
if [[ $1 == "--scripted" ]]; then
scripted=1
goUp=""
# args
while true; do
case "$1" in
"-h"|"help"|"--help")
help
exit 0
;;
"--scripted")
scripted=1
goUp=""
shift 1
;;
"install")
install=1
shift 1
;;
"all")
all=1
fullTestFlags=""
shift 1
;;
*)
break
;;
esac
done
# check if $GOPATH/bin is in $PATH
if [[ $PATH != *"$GOPATH/bin"* ]]; then
export PATH=$GOPATH/bin:$PATH
fi
# install
if [[ $install -eq 1 ]]; then
echo "installing dependencies..."
echo "$ go get -u golang.org/x/lint/golint"
go get -u golang.org/x/lint/golint
if [[ $all -eq 1 ]]; then
echo "$ go get -u github.com/golangci/golangci-lint/cmd/golangci-lint"
go get -u github.com/golangci/golangci-lint/cmd/golangci-lint
fi
fi
# check dependencies
if [[ $(which go) == "" ]]; then
echo "go command not found"
exit 1
fi
if [[ $(which gofmt) == "" ]]; then
echo "gofmt command not found"
exit 1
fi
if [[ $(which golint) == "" ]]; then
echo "golint command not found"
echo "install with: go get -u golang.org/x/lint/golint"
echo "or run: ./test install"
exit 1
fi
if [[ $all -eq 1 ]]; then
if [[ $(which golangci-lint) == "" ]]; then
echo "golangci-lint command not found"
echo "install locally with: go get -u github.com/golangci/golangci-lint/cmd/golangci-lint"
echo "or run: ./test install all"
echo ""
echo "hint: install for CI with: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin vX.Y.Z"
echo "don't forget to specify the version you want"
exit 1
fi
fi
# target selection
if [[ "$1" == "" ]]; then
# get all packages
packages=$(go list ./...)
else
# single package testing
packages=$(go list)/$1
if [[ ! -d "$GOPATH/src/$packages" ]]; then
echo "go package $packages does not exist"
help
exit 1
fi
echo "note: only running tests for package $packages"
fi
# platform info
platformInfo=$(go env GOOS GOARCH)
echo "running tests for ${platformInfo//$'\n'/ }:"
# get all packages
packages=$(go list ./...)
# run vet/test on packages
for package in $packages; do
checkformat $package
run golint -set_exit_status -min_confidence 1.0 $package
run go vet $package
run go test -cover $package
run go test -cover $fullTestFlags $package
if [[ $all -eq 1 ]]; then
run golangci-lint run $GOPATH/src/$package
fi
done
echo ""

View file

@ -8,7 +8,7 @@ import (
const isWindows = runtime.GOOS == "windows"
// EnsureDirectory ensures that the given directoy exists and that is has the given permissions set.
// EnsureDirectory ensures that the given directory exists and that is has the given permissions set.
// If path is a file, it is deleted and a directory created.
// If a directory is created, also all missing directories up to the required one are created with the given permissions.
func EnsureDirectory(path string, perm os.FileMode) error {

View file

@ -59,7 +59,7 @@ func ExampleDirStructure() {
fmt.Println(err)
}
filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
_ = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
if err == nil {
dir := strings.TrimPrefix(path, basePath)
if dir == "" {

View file

@ -1,8 +0,0 @@
package testutils
import "runtime"
func GetLineNumberOfCaller(levels int) int {
_, _, line, _ := runtime.Caller(levels + 1)
return line
}