mirror of
https://github.com/safing/portbase
synced 2025-09-02 02:29:59 +00:00
Release to master
This commit is contained in:
commit
4899e44f60
73 changed files with 2821 additions and 849 deletions
|
@ -7,5 +7,11 @@ linters:
|
|||
- funlen
|
||||
- whitespace
|
||||
- wsl
|
||||
- godox
|
||||
- gomnd
|
||||
|
||||
linters-settings:
|
||||
godox:
|
||||
# report any comments starting with keywords, this is useful for TODO or FIXME comments that
|
||||
# might be left in the code accidentally and should be resolved before merging
|
||||
keywords:
|
||||
- FIXME
|
||||
|
|
9
Gopkg.lock
generated
9
Gopkg.lock
generated
|
@ -161,14 +161,6 @@
|
|||
revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4"
|
||||
version = "v0.8.1"
|
||||
|
||||
[[projects]]
|
||||
branch = "develop"
|
||||
digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93"
|
||||
name = "github.com/safing/portmaster"
|
||||
packages = ["core/structure"]
|
||||
pruneopts = "UT"
|
||||
revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925"
|
||||
name = "github.com/satori/go.uuid"
|
||||
|
@ -332,7 +324,6 @@
|
|||
"github.com/gorilla/mux",
|
||||
"github.com/gorilla/websocket",
|
||||
"github.com/hashicorp/go-version",
|
||||
"github.com/safing/portmaster/core/structure",
|
||||
"github.com/satori/go.uuid",
|
||||
"github.com/seehuhn/fortuna",
|
||||
"github.com/shirou/gopsutil/host",
|
||||
|
|
|
@ -243,8 +243,8 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
|
|||
if err == nil {
|
||||
data, err = r.Marshal(r, record.JSON)
|
||||
}
|
||||
if err == nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil) //nolint:nilness // FIXME: possibly false positive (golangci-lint govet/nilness)
|
||||
if err != nil {
|
||||
api.send(opID, dbMsgTypeError, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
api.send(opID, dbMsgTypeOk, r.Key(), data)
|
||||
|
@ -384,9 +384,9 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
|
|||
default:
|
||||
api.send(opID, dbMsgTypeUpd, r.Key(), data)
|
||||
}
|
||||
} else if sub.Err != nil {
|
||||
} else {
|
||||
// sub feed ended
|
||||
api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
|
||||
api.send(opID, dbMsgTypeDone, "", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -435,13 +435,13 @@ func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create b
|
|||
return
|
||||
}
|
||||
|
||||
// FIXME: remove transition code
|
||||
if data[0] != record.JSON {
|
||||
typedData := make([]byte, len(data)+1)
|
||||
typedData[0] = record.JSON
|
||||
copy(typedData[1:], data)
|
||||
data = typedData
|
||||
}
|
||||
// TODO - staged for deletion: remove transition code
|
||||
// if data[0] != record.JSON {
|
||||
// typedData := make([]byte, len(data)+1)
|
||||
// typedData[0] = record.JSON
|
||||
// copy(typedData[1:], data)
|
||||
// data = typedData
|
||||
// }
|
||||
|
||||
r, err := record.NewWrapper(key, nil, data[0], data[1:])
|
||||
if err != nil {
|
||||
|
|
|
@ -17,7 +17,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("api", prep, start, stop, "base", "database", "config")
|
||||
module = modules.Register("api", prep, start, stop, "database", "config")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
|
|
|
@ -36,26 +36,30 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (s *StorageInterface) Put(r record.Record) error {
|
||||
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
|
||||
if r.Meta().Deleted > 0 {
|
||||
return setConfigOption(r.DatabaseKey(), nil, false)
|
||||
return r, setConfigOption(r.DatabaseKey(), nil, false)
|
||||
}
|
||||
|
||||
acc := r.GetAccessor(r)
|
||||
if acc == nil {
|
||||
return errors.New("invalid data")
|
||||
return nil, errors.New("invalid data")
|
||||
}
|
||||
|
||||
val, ok := acc.Get("Value")
|
||||
if !ok || val == nil {
|
||||
return setConfigOption(r.DatabaseKey(), nil, false)
|
||||
err := setConfigOption(r.DatabaseKey(), nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.Get(r.DatabaseKey())
|
||||
}
|
||||
|
||||
optionsLock.RLock()
|
||||
option, ok := options[r.DatabaseKey()]
|
||||
optionsLock.RUnlock()
|
||||
if !ok {
|
||||
return errors.New("config option does not exist")
|
||||
return nil, errors.New("config option does not exist")
|
||||
}
|
||||
|
||||
var value interface{}
|
||||
|
@ -70,14 +74,14 @@ func (s *StorageInterface) Put(r record.Record) error {
|
|||
value, ok = acc.GetBool("Value")
|
||||
}
|
||||
if !ok {
|
||||
return errors.New("received invalid value in \"Value\"")
|
||||
return nil, errors.New("received invalid value in \"Value\"")
|
||||
}
|
||||
|
||||
err := setConfigOption(r.DatabaseKey(), value, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
return option.Export()
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -5,6 +5,8 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// Expertise Level constants
|
||||
|
@ -21,7 +23,9 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
expertiseLevel *int32
|
||||
expertiseLevel *int32
|
||||
expertiseLevelOption *Option
|
||||
expertiseLevelOptionFlag = abool.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -32,7 +36,7 @@ func init() {
|
|||
}
|
||||
|
||||
func registerExpertiseLevelOption() {
|
||||
err := Register(&Option{
|
||||
expertiseLevelOption = &Option{
|
||||
Name: "Expertise Level",
|
||||
Key: expertiseLevelKey,
|
||||
Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)",
|
||||
|
@ -46,15 +50,31 @@ func registerExpertiseLevelOption() {
|
|||
|
||||
ExternalOptType: "string list",
|
||||
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper),
|
||||
})
|
||||
}
|
||||
|
||||
err := Register(expertiseLevelOption)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
expertiseLevelOptionFlag.Set()
|
||||
}
|
||||
|
||||
func updateExpertiseLevel() {
|
||||
new := findStringValue(expertiseLevelKey, "")
|
||||
switch new {
|
||||
// check if already registered
|
||||
if !expertiseLevelOptionFlag.IsSet() {
|
||||
return
|
||||
}
|
||||
// get value
|
||||
value := expertiseLevelOption.activeFallbackValue
|
||||
if expertiseLevelOption.activeValue != nil {
|
||||
value = expertiseLevelOption.activeValue
|
||||
}
|
||||
if expertiseLevelOption.activeDefaultValue != nil {
|
||||
value = expertiseLevelOption.activeDefaultValue
|
||||
}
|
||||
// set atomic value
|
||||
switch value.stringVal {
|
||||
case ExpertiseLevelNameUser:
|
||||
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
|
||||
case ExpertiseLevelNameExpert:
|
||||
|
|
|
@ -12,14 +12,24 @@ var (
|
|||
// GetAsString returns a function that returns the wanted string with high performance.
|
||||
func (cs *safe) GetAsString(name string, fallback string) StringOption {
|
||||
valid := getValidityFlag()
|
||||
value := findStringValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeString)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() string {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findStringValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeString)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -28,14 +38,24 @@ func (cs *safe) GetAsString(name string, fallback string) StringOption {
|
|||
// GetAsStringArray returns a function that returns the wanted string with high performance.
|
||||
func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOption {
|
||||
valid := getValidityFlag()
|
||||
value := findStringArrayValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() []string {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findStringArrayValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -44,14 +64,24 @@ func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOpti
|
|||
// GetAsInt returns a function that returns the wanted int with high performance.
|
||||
func (cs *safe) GetAsInt(name string, fallback int64) IntOption {
|
||||
valid := getValidityFlag()
|
||||
value := findIntValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeInt)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() int64 {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findIntValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -60,14 +90,24 @@ func (cs *safe) GetAsInt(name string, fallback int64) IntOption {
|
|||
// GetAsBool returns a function that returns the wanted int with high performance.
|
||||
func (cs *safe) GetAsBool(name string, fallback bool) BoolOption {
|
||||
valid := getValidityFlag()
|
||||
value := findBoolValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeBool)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
}
|
||||
var lock sync.Mutex
|
||||
|
||||
return func() bool {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findBoolValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
|
188
config/get.go
188
config/get.go
|
@ -15,14 +15,59 @@ type (
|
|||
BoolOption func() bool
|
||||
)
|
||||
|
||||
func getValueCache(name string, option *Option, requestedType uint8) (*Option, *valueCache) {
|
||||
// get option
|
||||
if option == nil {
|
||||
var ok bool
|
||||
optionsLock.RLock()
|
||||
option, ok = options[name]
|
||||
optionsLock.RUnlock()
|
||||
if !ok {
|
||||
log.Errorf("config: request for unregistered option: %s", name)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// check type
|
||||
if requestedType != option.OptType {
|
||||
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType))
|
||||
return option, nil
|
||||
}
|
||||
|
||||
// lock option
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
||||
// check release level
|
||||
if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil {
|
||||
return option, option.activeValue
|
||||
}
|
||||
|
||||
if option.activeDefaultValue != nil {
|
||||
return option, option.activeDefaultValue
|
||||
}
|
||||
|
||||
return option, option.activeFallbackValue
|
||||
}
|
||||
|
||||
// GetAsString returns a function that returns the wanted string with high performance.
|
||||
func GetAsString(name string, fallback string) StringOption {
|
||||
valid := getValidityFlag()
|
||||
value := findStringValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeString)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
}
|
||||
|
||||
return func() string {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findStringValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeString)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -31,11 +76,21 @@ func GetAsString(name string, fallback string) StringOption {
|
|||
// GetAsStringArray returns a function that returns the wanted string with high performance.
|
||||
func GetAsStringArray(name string, fallback []string) StringArrayOption {
|
||||
valid := getValidityFlag()
|
||||
value := findStringArrayValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
}
|
||||
|
||||
return func() []string {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findStringArrayValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
value = valueCache.stringArrayVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -44,11 +99,21 @@ func GetAsStringArray(name string, fallback []string) StringArrayOption {
|
|||
// GetAsInt returns a function that returns the wanted int with high performance.
|
||||
func GetAsInt(name string, fallback int64) IntOption {
|
||||
valid := getValidityFlag()
|
||||
value := findIntValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeInt)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
}
|
||||
|
||||
return func() int64 {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findIntValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
value = valueCache.intVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
@ -57,18 +122,28 @@ func GetAsInt(name string, fallback int64) IntOption {
|
|||
// GetAsBool returns a function that returns the wanted int with high performance.
|
||||
func GetAsBool(name string, fallback bool) BoolOption {
|
||||
valid := getValidityFlag()
|
||||
value := findBoolValue(name, fallback)
|
||||
option, valueCache := getValueCache(name, nil, OptTypeBool)
|
||||
value := fallback
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
}
|
||||
|
||||
return func() bool {
|
||||
if !valid.IsSet() {
|
||||
valid = getValidityFlag()
|
||||
value = findBoolValue(name, fallback)
|
||||
option, valueCache = getValueCache(name, option, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
value = valueCache.boolVal
|
||||
} else {
|
||||
value = fallback
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// findValue find the correct value in the user or default config.
|
||||
func findValue(key string) interface{} {
|
||||
/*
|
||||
func getAndFindValue(key string) interface{} {
|
||||
optionsLock.RLock()
|
||||
option, ok := options[key]
|
||||
optionsLock.RUnlock()
|
||||
|
@ -77,6 +152,13 @@ func findValue(key string) interface{} {
|
|||
return nil
|
||||
}
|
||||
|
||||
return option.findValue()
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
// findValue finds the preferred value in the user or default config.
|
||||
func (option *Option) findValue() interface{} {
|
||||
// lock option
|
||||
option.Lock()
|
||||
defer option.Unlock()
|
||||
|
@ -91,88 +173,4 @@ func findValue(key string) interface{} {
|
|||
|
||||
return option.DefaultValue
|
||||
}
|
||||
|
||||
// findStringValue validates and returns the value with the given key.
|
||||
func findStringValue(key string, fallback string) (value string) {
|
||||
result := findValue(key)
|
||||
if result == nil {
|
||||
return fallback
|
||||
}
|
||||
v, ok := result.(string)
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// findStringArrayValue validates and returns the value with the given key.
|
||||
func findStringArrayValue(key string, fallback []string) (value []string) {
|
||||
result := findValue(key)
|
||||
if result == nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
v, ok := result.([]interface{})
|
||||
if ok {
|
||||
new := make([]string, len(v))
|
||||
for i, val := range v {
|
||||
s, ok := val.(string)
|
||||
if ok {
|
||||
new[i] = s
|
||||
} else {
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
return new
|
||||
}
|
||||
|
||||
return fallback
|
||||
}
|
||||
|
||||
// findIntValue validates and returns the value with the given key.
|
||||
func findIntValue(key string, fallback int64) (value int64) {
|
||||
result := findValue(key)
|
||||
if result == nil {
|
||||
return fallback
|
||||
}
|
||||
switch v := result.(type) {
|
||||
case int:
|
||||
return int64(v)
|
||||
case int8:
|
||||
return int64(v)
|
||||
case int16:
|
||||
return int64(v)
|
||||
case int32:
|
||||
return int64(v)
|
||||
case int64:
|
||||
return v
|
||||
case uint:
|
||||
return int64(v)
|
||||
case uint8:
|
||||
return int64(v)
|
||||
case uint16:
|
||||
return int64(v)
|
||||
case uint32:
|
||||
return int64(v)
|
||||
case uint64:
|
||||
return int64(v)
|
||||
case float32:
|
||||
return int64(v)
|
||||
case float64:
|
||||
return int64(v)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// findBoolValue validates and returns the value with the given key.
|
||||
func findBoolValue(key string, fallback bool) (value bool) {
|
||||
result := findValue(key)
|
||||
if result == nil {
|
||||
return fallback
|
||||
}
|
||||
v, ok := result.(bool)
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
|
@ -39,7 +40,7 @@ func quickRegister(t *testing.T, key string, optType uint8, defaultValue interfa
|
|||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
func TestGet(t *testing.T) { //nolint:gocognit
|
||||
// reset
|
||||
options = make(map[string]*Option)
|
||||
|
||||
|
@ -48,41 +49,41 @@ func TestGet(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
quickRegister(t, "monkey", OptTypeInt, -1)
|
||||
quickRegister(t, "monkey", OptTypeString, "c")
|
||||
quickRegister(t, "zebras/zebra", OptTypeStringArray, []string{"a", "b"})
|
||||
quickRegister(t, "elephant", OptTypeInt, -1)
|
||||
quickRegister(t, "hot", OptTypeBool, false)
|
||||
quickRegister(t, "cold", OptTypeBool, true)
|
||||
|
||||
err = parseAndSetConfig(`
|
||||
{
|
||||
"monkey": "1",
|
||||
{
|
||||
"monkey": "a",
|
||||
"zebras": {
|
||||
"zebra": ["black", "white"]
|
||||
},
|
||||
"elephant": 2,
|
||||
"elephant": 2,
|
||||
"hot": true,
|
||||
"cold": false
|
||||
}
|
||||
`)
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = parseAndSetDefaultConfig(`
|
||||
{
|
||||
"monkey": "0",
|
||||
"snake": "0",
|
||||
"elephant": 0
|
||||
}
|
||||
`)
|
||||
{
|
||||
"monkey": "b",
|
||||
"snake": "0",
|
||||
"elephant": 0
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
monkey := GetAsString("monkey", "none")
|
||||
if monkey() != "1" {
|
||||
t.Errorf("monkey should be 1, is %s", monkey())
|
||||
if monkey() != "a" {
|
||||
t.Errorf("monkey should be a, is %s", monkey())
|
||||
}
|
||||
|
||||
zebra := GetAsStringArray("zebras/zebra", []string{})
|
||||
|
@ -106,10 +107,10 @@ func TestGet(t *testing.T) {
|
|||
}
|
||||
|
||||
err = parseAndSetConfig(`
|
||||
{
|
||||
"monkey": "3"
|
||||
}
|
||||
`)
|
||||
{
|
||||
"monkey": "3"
|
||||
}
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -131,6 +132,53 @@ func TestGet(t *testing.T) {
|
|||
GetAsInt("elephant", -1)()
|
||||
GetAsBool("hot", false)()
|
||||
|
||||
// perspective
|
||||
|
||||
// load data
|
||||
pLoaded := make(map[string]interface{})
|
||||
err = json.Unmarshal([]byte(`{
|
||||
"monkey": "a",
|
||||
"zebras": {
|
||||
"zebra": ["black", "white"]
|
||||
},
|
||||
"elephant": 2,
|
||||
"hot": true,
|
||||
"cold": false
|
||||
}`), &pLoaded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create
|
||||
p, err := NewPerspective(pLoaded)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
monkeyVal, ok := p.GetAsString("monkey")
|
||||
if !ok || monkeyVal != "a" {
|
||||
t.Errorf("[perspective] monkey should be a, is %+v", monkeyVal)
|
||||
}
|
||||
|
||||
zebraVal, ok := p.GetAsStringArray("zebras/zebra")
|
||||
if !ok || len(zebraVal) != 2 || zebraVal[0] != "black" || zebraVal[1] != "white" {
|
||||
t.Errorf("[perspective] zebra should be [\"black\", \"white\"], is %+v", zebraVal)
|
||||
}
|
||||
|
||||
elephantVal, ok := p.GetAsInt("elephant")
|
||||
if !ok || elephantVal != 2 {
|
||||
t.Errorf("[perspective] elephant should be 2, is %+v", elephantVal)
|
||||
}
|
||||
|
||||
hotVal, ok := p.GetAsBool("hot")
|
||||
if !ok || !hotVal {
|
||||
t.Errorf("[perspective] hot should be true, is %+v", hotVal)
|
||||
}
|
||||
|
||||
coldVal, ok := p.GetAsBool("cold")
|
||||
if !ok || coldVal {
|
||||
t.Errorf("[perspective] cold should be false, is %+v", coldVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseLevel(t *testing.T) {
|
||||
|
@ -236,11 +284,9 @@ func BenchmarkGetAsStringCached(b *testing.B) {
|
|||
options = make(map[string]*Option)
|
||||
|
||||
// Setup
|
||||
err := parseAndSetConfig(`
|
||||
{
|
||||
"monkey": "banana"
|
||||
}
|
||||
`)
|
||||
err := parseAndSetConfig(`{
|
||||
"monkey": "banana"
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -257,11 +303,9 @@ func BenchmarkGetAsStringCached(b *testing.B) {
|
|||
|
||||
func BenchmarkGetAsStringRefetch(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndSetConfig(`
|
||||
{
|
||||
"monkey": "banana"
|
||||
}
|
||||
`)
|
||||
err := parseAndSetConfig(`{
|
||||
"monkey": "banana"
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -271,38 +315,34 @@ func BenchmarkGetAsStringRefetch(b *testing.B) {
|
|||
|
||||
// Start benchmark
|
||||
for i := 0; i < b.N; i++ {
|
||||
findStringValue("monkey", "no banana")
|
||||
getValueCache("monkey", nil, OptTypeString)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsIntCached(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndSetConfig(`
|
||||
{
|
||||
"monkey": 1
|
||||
}
|
||||
`)
|
||||
err := parseAndSetConfig(`{
|
||||
"elephant": 1
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
monkey := GetAsInt("monkey", -1)
|
||||
elephant := GetAsInt("elephant", -1)
|
||||
|
||||
// Reset timer for precise results
|
||||
b.ResetTimer()
|
||||
|
||||
// Start benchmark
|
||||
for i := 0; i < b.N; i++ {
|
||||
monkey()
|
||||
elephant()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetAsIntRefetch(b *testing.B) {
|
||||
// Setup
|
||||
err := parseAndSetConfig(`
|
||||
{
|
||||
"monkey": 1
|
||||
}
|
||||
`)
|
||||
err := parseAndSetConfig(`{
|
||||
"elephant": 1
|
||||
}`)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -312,6 +352,6 @@ func BenchmarkGetAsIntRefetch(b *testing.B) {
|
|||
|
||||
// Start benchmark
|
||||
for i := 0; i < b.N; i++ {
|
||||
findIntValue("monkey", 1)
|
||||
getValueCache("elephant", nil, OptTypeInt)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,9 +5,9 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/utils"
|
||||
"github.com/safing/portmaster/core/structure"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -27,12 +27,12 @@ func SetDataRoot(root *utils.DirStructure) {
|
|||
}
|
||||
|
||||
func init() {
|
||||
module = modules.Register("config", prep, start, nil, "base", "database")
|
||||
module = modules.Register("config", prep, start, nil, "database")
|
||||
module.RegisterEvent(configChangeEvent)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
SetDataRoot(structure.Root())
|
||||
SetDataRoot(dataroot.Root())
|
||||
if dataRoot == nil {
|
||||
return errors.New("data root is not set")
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ type Option struct {
|
|||
Name string
|
||||
Key string // in path format: category/sub/key
|
||||
Description string
|
||||
Help string
|
||||
|
||||
OptType uint8
|
||||
ExpertiseLevel uint8
|
||||
|
@ -52,9 +53,10 @@ type Option struct {
|
|||
ExternalOptType string
|
||||
ValidationRegex string
|
||||
|
||||
activeValue interface{} // runtime value (loaded from config file or set by user)
|
||||
activeDefaultValue interface{} // runtime default value (may be set internally)
|
||||
compiledRegex *regexp.Regexp
|
||||
activeValue *valueCache // runtime value (loaded from config file or set by user)
|
||||
activeDefaultValue *valueCache // runtime default value (may be set internally)
|
||||
activeFallbackValue *valueCache // default value from option registration
|
||||
compiledRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// Export expors an option to a Record.
|
||||
|
@ -68,14 +70,14 @@ func (option *Option) Export() (record.Record, error) {
|
|||
}
|
||||
|
||||
if option.activeValue != nil {
|
||||
data, err = sjson.SetBytes(data, "Value", option.activeValue)
|
||||
data, err = sjson.SetBytes(data, "Value", option.activeValue.getData(option))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if option.activeDefaultValue != nil {
|
||||
data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue)
|
||||
data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue.getData(option))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func saveConfig() error {
|
|||
for key, option := range options {
|
||||
option.Lock()
|
||||
if option.activeValue != nil {
|
||||
activeValues[key] = option.activeValue
|
||||
activeValues[key] = option.activeValue.getData(option)
|
||||
}
|
||||
option.Unlock()
|
||||
}
|
||||
|
|
128
config/perspective.go
Normal file
128
config/perspective.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
)
|
||||
|
||||
// Perspective is a view on configuration data without interfering with the configuration system.
|
||||
type Perspective struct {
|
||||
config map[string]*perspectiveOption
|
||||
}
|
||||
|
||||
type perspectiveOption struct {
|
||||
option *Option
|
||||
valueCache *valueCache
|
||||
}
|
||||
|
||||
// NewPerspective parses the given config and returns it as a new perspective.
|
||||
func NewPerspective(config map[string]interface{}) (*Perspective, error) {
|
||||
// flatten config structure
|
||||
flatten(config, config, "")
|
||||
|
||||
perspective := &Perspective{
|
||||
config: make(map[string]*perspectiveOption),
|
||||
}
|
||||
var firstErr error
|
||||
var errCnt int
|
||||
|
||||
optionsLock.Lock()
|
||||
optionsLoop:
|
||||
for key, option := range options {
|
||||
// get option key from config
|
||||
configValue, ok := config[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// validate value
|
||||
valueCache, err := validateValue(option, configValue)
|
||||
if err != nil {
|
||||
errCnt++
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue optionsLoop
|
||||
}
|
||||
|
||||
// add to perspective
|
||||
perspective.config[key] = &perspectiveOption{
|
||||
option: option,
|
||||
valueCache: valueCache,
|
||||
}
|
||||
}
|
||||
optionsLock.Unlock()
|
||||
|
||||
if firstErr != nil {
|
||||
if errCnt > 0 {
|
||||
return perspective, fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
|
||||
}
|
||||
return perspective, firstErr
|
||||
}
|
||||
|
||||
return perspective, nil
|
||||
}
|
||||
|
||||
func (p *Perspective) getPerspectiveValueCache(name string, requestedType uint8) *valueCache {
|
||||
// get option
|
||||
pOption, ok := p.config[name]
|
||||
if !ok {
|
||||
// check if option exists at all
|
||||
optionsLock.RLock()
|
||||
_, ok = options[name]
|
||||
optionsLock.RUnlock()
|
||||
if !ok {
|
||||
log.Errorf("config: request for unregistered option: %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check type
|
||||
if requestedType != pOption.option.OptType {
|
||||
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// check release level
|
||||
if pOption.option.ReleaseLevel > getReleaseLevel() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pOption.valueCache
|
||||
}
|
||||
|
||||
// GetAsString returns a function that returns the wanted string with high performance.
|
||||
func (p *Perspective) GetAsString(name string) (value string, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeString)
|
||||
if valueCache != nil {
|
||||
return valueCache.stringVal, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetAsStringArray returns a function that returns the wanted string with high performance.
|
||||
func (p *Perspective) GetAsStringArray(name string) (value []string, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeStringArray)
|
||||
if valueCache != nil {
|
||||
return valueCache.stringArrayVal, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetAsInt returns a function that returns the wanted int with high performance.
|
||||
func (p *Perspective) GetAsInt(name string) (value int64, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeInt)
|
||||
if valueCache != nil {
|
||||
return valueCache.intVal, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetAsBool returns a function that returns the wanted int with high performance.
|
||||
func (p *Perspective) GetAsBool(name string) (value bool, ok bool) {
|
||||
valueCache := p.getPerspectiveValueCache(name, OptTypeBool)
|
||||
if valueCache != nil {
|
||||
return valueCache.boolVal, true
|
||||
}
|
||||
return false, false
|
||||
}
|
|
@ -26,14 +26,20 @@ func Register(option *Option) error {
|
|||
return fmt.Errorf("failed to register option: please set option.OptType")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if option.ValidationRegex != "" {
|
||||
var err error
|
||||
option.compiledRegex, err = regexp.Compile(option.ValidationRegex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: could not compile option.ValidationRegex: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
option.activeFallbackValue, err = validateValue(option, option.DefaultValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: invalid default value: %s", err)
|
||||
}
|
||||
|
||||
optionsLock.Lock()
|
||||
defer optionsLock.Unlock()
|
||||
options[option.Key] = option
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestRegistry(t *testing.T) {
|
|||
ReleaseLevel: ReleaseLevelStable,
|
||||
ExpertiseLevel: ExpertiseLevelUser,
|
||||
OptType: OptTypeString,
|
||||
DefaultValue: "default",
|
||||
DefaultValue: "water",
|
||||
ValidationRegex: "^(banana|water)$",
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
|
|
|
@ -5,6 +5,8 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// Release Level constants
|
||||
|
@ -21,7 +23,9 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
releaseLevel *int32
|
||||
releaseLevel *int32
|
||||
releaseLevelOption *Option
|
||||
releaseLevelOptionFlag = abool.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -32,7 +36,7 @@ func init() {
|
|||
}
|
||||
|
||||
func registerReleaseLevelOption() {
|
||||
err := Register(&Option{
|
||||
releaseLevelOption = &Option{
|
||||
Name: "Release Level",
|
||||
Key: releaseLevelKey,
|
||||
Description: "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.",
|
||||
|
@ -46,15 +50,31 @@ func registerReleaseLevelOption() {
|
|||
|
||||
ExternalOptType: "string list",
|
||||
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental),
|
||||
})
|
||||
}
|
||||
|
||||
err := Register(releaseLevelOption)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
releaseLevelOptionFlag.Set()
|
||||
}
|
||||
|
||||
func updateReleaseLevel() {
|
||||
new := findStringValue(releaseLevelKey, "")
|
||||
switch new {
|
||||
// check if already registered
|
||||
if !releaseLevelOptionFlag.IsSet() {
|
||||
return
|
||||
}
|
||||
// get value
|
||||
value := releaseLevelOption.activeFallbackValue
|
||||
if releaseLevelOption.activeValue != nil {
|
||||
value = releaseLevelOption.activeValue
|
||||
}
|
||||
if releaseLevelOption.activeDefaultValue != nil {
|
||||
value = releaseLevelOption.activeDefaultValue
|
||||
}
|
||||
// set atomic value
|
||||
switch value.stringVal {
|
||||
case ReleaseLevelNameStable:
|
||||
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
|
||||
case ReleaseLevelNameBeta:
|
||||
|
|
|
@ -19,6 +19,7 @@ var (
|
|||
validityFlagLock sync.RWMutex
|
||||
)
|
||||
|
||||
// getValidityFlag returns a flag that signifies if the configuration has been changed. This flag must not be changed, only read.
|
||||
func getValidityFlag() *abool.AtomicBool {
|
||||
validityFlagLock.RLock()
|
||||
defer validityFlagLock.RUnlock()
|
||||
|
@ -41,14 +42,24 @@ func signalChanges() {
|
|||
|
||||
// setConfig sets the (prioritized) user defined config.
|
||||
func setConfig(newValues map[string]interface{}) error {
|
||||
var firstErr error
|
||||
var errCnt int
|
||||
|
||||
optionsLock.Lock()
|
||||
for key, option := range options {
|
||||
newValue, ok := newValues[key]
|
||||
option.Lock()
|
||||
option.activeValue = nil
|
||||
if ok {
|
||||
option.activeValue = newValue
|
||||
} else {
|
||||
option.activeValue = nil
|
||||
valueCache, err := validateValue(option, newValue)
|
||||
if err == nil {
|
||||
option.activeValue = valueCache
|
||||
} else {
|
||||
errCnt++
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
option.Unlock()
|
||||
}
|
||||
|
@ -56,19 +67,37 @@ func setConfig(newValues map[string]interface{}) error {
|
|||
|
||||
signalChanges()
|
||||
go pushFullUpdate()
|
||||
|
||||
if firstErr != nil {
|
||||
if errCnt > 0 {
|
||||
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultConfig sets the (fallback) default config.
|
||||
func SetDefaultConfig(newValues map[string]interface{}) error {
|
||||
var firstErr error
|
||||
var errCnt int
|
||||
|
||||
optionsLock.Lock()
|
||||
for key, option := range options {
|
||||
newValue, ok := newValues[key]
|
||||
option.Lock()
|
||||
option.activeDefaultValue = nil
|
||||
if ok {
|
||||
option.activeDefaultValue = newValue
|
||||
} else {
|
||||
option.activeDefaultValue = nil
|
||||
valueCache, err := validateValue(option, newValue)
|
||||
if err == nil {
|
||||
option.activeDefaultValue = valueCache
|
||||
} else {
|
||||
errCnt++
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
option.Unlock()
|
||||
}
|
||||
|
@ -76,51 +105,15 @@ func SetDefaultConfig(newValues map[string]interface{}) error {
|
|||
|
||||
signalChanges()
|
||||
go pushFullUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateValue(option *Option, value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if option.OptType != OptTypeString {
|
||||
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
if firstErr != nil {
|
||||
if errCnt > 0 {
|
||||
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
if !option.compiledRegex.MatchString(v) {
|
||||
return fmt.Errorf("validation failed: string \"%s\" did not match regex for option %s", v, option.Key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []string:
|
||||
if option.OptType != OptTypeStringArray {
|
||||
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
for pos, entry := range v {
|
||||
if !option.compiledRegex.MatchString(entry) {
|
||||
return fmt.Errorf("validation failed: string \"%s\" at index %d did not match regex for option %s", entry, pos, option.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
if option.OptType != OptTypeInt {
|
||||
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
if !option.compiledRegex.MatchString(fmt.Sprintf("%d", v)) {
|
||||
return fmt.Errorf("validation failed: number \"%d\" did not match regex for option %s", v, option.Key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case bool:
|
||||
if option.OptType != OptTypeBool {
|
||||
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("invalid option value type: %T", value)
|
||||
return firstErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConfigOption sets a single value in the (prioritized) user defined config.
|
||||
|
@ -140,9 +133,10 @@ func setConfigOption(key string, value interface{}, push bool) (err error) {
|
|||
if value == nil {
|
||||
option.activeValue = nil
|
||||
} else {
|
||||
err = validateValue(option, value)
|
||||
var valueCache *valueCache
|
||||
valueCache, err = validateValue(option, value)
|
||||
if err == nil {
|
||||
option.activeValue = value
|
||||
option.activeValue = valueCache
|
||||
}
|
||||
}
|
||||
option.Unlock()
|
||||
|
@ -175,9 +169,10 @@ func setDefaultConfigOption(key string, value interface{}, push bool) (err error
|
|||
if value == nil {
|
||||
option.activeDefaultValue = nil
|
||||
} else {
|
||||
err = validateValue(option, value)
|
||||
var valueCache *valueCache
|
||||
valueCache, err = validateValue(option, value)
|
||||
if err == nil {
|
||||
option.activeDefaultValue = value
|
||||
option.activeDefaultValue = valueCache
|
||||
}
|
||||
}
|
||||
option.Unlock()
|
||||
|
|
120
config/validate.go
Normal file
120
config/validate.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
type valueCache struct {
|
||||
stringVal string
|
||||
stringArrayVal []string
|
||||
intVal int64
|
||||
boolVal bool
|
||||
}
|
||||
|
||||
func (vc *valueCache) getData(opt *Option) interface{} {
|
||||
switch opt.OptType {
|
||||
case OptTypeBool:
|
||||
return vc.boolVal
|
||||
case OptTypeInt:
|
||||
return vc.intVal
|
||||
case OptTypeString:
|
||||
return vc.stringVal
|
||||
case OptTypeStringArray:
|
||||
return vc.stringArrayVal
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateValue(option *Option, value interface{}) (*valueCache, error) { //nolint:gocyclo
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if option.OptType != OptTypeString {
|
||||
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
if !option.compiledRegex.MatchString(v) {
|
||||
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" did not match validation regex for option", option.Key, v)
|
||||
}
|
||||
}
|
||||
return &valueCache{stringVal: v}, nil
|
||||
case []interface{}:
|
||||
vConverted := make([]string, len(v))
|
||||
for pos, entry := range v {
|
||||
s, ok := entry.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("validation of option %s failed: element %+v at index %d is not a string", option.Key, entry, pos)
|
||||
|
||||
}
|
||||
vConverted[pos] = s
|
||||
}
|
||||
// continue to next case
|
||||
return validateValue(option, vConverted)
|
||||
case []string:
|
||||
if option.OptType != OptTypeStringArray {
|
||||
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
for pos, entry := range v {
|
||||
if !option.compiledRegex.MatchString(entry) {
|
||||
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" at index %d did not match validation regex", option.Key, entry, pos)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &valueCache{stringArrayVal: v}, nil
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64:
|
||||
// uint64 is omitted, as it does not fit in a int64
|
||||
if option.OptType != OptTypeInt {
|
||||
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
if option.compiledRegex != nil {
|
||||
// we need to use %v here so we handle float and int correctly.
|
||||
if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) {
|
||||
return nil, fmt.Errorf("validation of option %s failed: number \"%d\" did not match validation regex", option.Key, v)
|
||||
}
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case int8:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case int16:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case int32:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case int64:
|
||||
return &valueCache{intVal: v}, nil
|
||||
case uint:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case uint8:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case uint16:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case uint32:
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
case float32:
|
||||
// convert if float has no decimals
|
||||
if math.Remainder(float64(v), 1) == 0 {
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to convert float32 to int64 for option %s, got value %+v", option.Key, v)
|
||||
case float64:
|
||||
// convert if float has no decimals
|
||||
if math.Remainder(v, 1) == 0 {
|
||||
return &valueCache{intVal: int64(v)}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to convert float64 to int64 for option %s, got value %+v", option.Key, v)
|
||||
default:
|
||||
return nil, errors.New("internal error")
|
||||
}
|
||||
case bool:
|
||||
if option.OptType != OptTypeBool {
|
||||
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
|
||||
}
|
||||
return &valueCache{boolVal: v}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid option value type for option %s: %T", option.Key, value)
|
||||
}
|
||||
}
|
30
config/validity.go
Normal file
30
config/validity.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// ValidityFlag is a flag that signifies if the configuration has been changed. It is not safe for concurrent use.
|
||||
type ValidityFlag struct {
|
||||
flag *abool.AtomicBool
|
||||
}
|
||||
|
||||
// NewValidityFlag returns a flag that signifies if the configuration has been changed.
|
||||
func NewValidityFlag() *ValidityFlag {
|
||||
vf := &ValidityFlag{}
|
||||
vf.Refresh()
|
||||
return vf
|
||||
}
|
||||
|
||||
// IsValid returns if the configuration is still valid.
|
||||
func (vf *ValidityFlag) IsValid() bool {
|
||||
return vf.flag.IsSet()
|
||||
}
|
||||
|
||||
// Refresh refreshes the flag and makes it reusable.
|
||||
func (vf *ValidityFlag) Refresh() {
|
||||
validityFlagLock.RLock()
|
||||
defer validityFlagLock.RUnlock()
|
||||
|
||||
vf.flag = validityFlag
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
@ -119,10 +120,13 @@ func (c *Controller) Put(r record.Record) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
err = c.storage.Put(r)
|
||||
r, err = c.storage.Put(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r == nil {
|
||||
return errors.New("storage returned nil record after successful put operation")
|
||||
}
|
||||
|
||||
// process subscriptions
|
||||
for _, sub := range c.subscriptions {
|
||||
|
@ -137,6 +141,32 @@ func (c *Controller) Put(r record.Record) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
|
||||
c.writeLock.RLock()
|
||||
defer c.writeLock.RUnlock()
|
||||
|
||||
if shuttingDown.IsSet() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrShuttingDown
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if c.ReadOnly() {
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrReadOnly
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
if batcher, ok := c.storage.(storage.Batcher); ok {
|
||||
return batcher.PutMany()
|
||||
}
|
||||
|
||||
errs := make(chan error, 1)
|
||||
errs <- ErrNotImplemented
|
||||
return make(chan record.Record), errs
|
||||
}
|
||||
|
||||
// Query executes the given query on the database.
|
||||
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
|
||||
c.readLock.RLock()
|
||||
|
|
|
@ -10,10 +10,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/database/record"
|
||||
|
||||
q "github.com/safing/portbase/database/query"
|
||||
_ "github.com/safing/portbase/database/storage/badger"
|
||||
_ "github.com/safing/portbase/database/storage/bbolt"
|
||||
_ "github.com/safing/portbase/database/storage/fstree"
|
||||
_ "github.com/safing/portbase/database/storage/hashmap"
|
||||
)
|
||||
|
||||
func makeKey(dbName, key string) string {
|
||||
|
@ -39,7 +42,10 @@ func testDatabase(t *testing.T, storageType string) {
|
|||
}
|
||||
|
||||
// interface
|
||||
db := NewInterface(nil)
|
||||
db := NewInterface(&Options{
|
||||
Local: true,
|
||||
Internal: true,
|
||||
})
|
||||
|
||||
// sub
|
||||
sub, err := db.Subscribe(q.New(dbName).MustBeValid())
|
||||
|
@ -107,6 +113,18 @@ func testDatabase(t *testing.T, storageType string) {
|
|||
t.Fatalf("expected two records, got %d", cnt)
|
||||
}
|
||||
|
||||
switch storageType {
|
||||
case "bbolt", "hashmap":
|
||||
batchPut := db.PutMany(dbName)
|
||||
records := []record.Record{A, B, C, nil} // nil is to signify finish
|
||||
for _, r := range records {
|
||||
err = batchPut(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = hook.Cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -128,12 +146,12 @@ func TestDatabaseSystem(t *testing.T) {
|
|||
os.Exit(1)
|
||||
}()
|
||||
|
||||
testDir, err := ioutil.TempDir("", "testing-")
|
||||
testDir, err := ioutil.TempDir("", "portbase-database-testing-")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = Initialize(testDir, nil)
|
||||
err = InitializeWithPath(testDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -142,6 +160,7 @@ func TestDatabaseSystem(t *testing.T) {
|
|||
testDatabase(t, "badger")
|
||||
testDatabase(t, "bbolt")
|
||||
testDatabase(t, "fstree")
|
||||
testDatabase(t, "hashmap")
|
||||
|
||||
err = MaintainRecordStates()
|
||||
if err != nil {
|
||||
|
|
|
@ -4,41 +4,44 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
databasePath string
|
||||
databaseStructureRoot *utils.DirStructure
|
||||
|
||||
module *modules.Module
|
||||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("database", prep, start, stop, "base")
|
||||
module = modules.Register("database", prep, start, stop)
|
||||
}
|
||||
|
||||
// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
|
||||
func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) {
|
||||
databasePath = dirPath
|
||||
databaseStructureRoot = dirStructureRoot
|
||||
func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) {
|
||||
if databaseStructureRoot == nil {
|
||||
databaseStructureRoot = dirStructureRoot
|
||||
}
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
if databasePath == "" && databaseStructureRoot == nil {
|
||||
return errors.New("no database location specified")
|
||||
SetDatabaseLocation(dataroot.Root())
|
||||
if databaseStructureRoot == nil {
|
||||
return errors.New("database location not specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
err := database.Initialize(databasePath, databaseStructureRoot)
|
||||
err := database.Initialize(databaseStructureRoot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
registerMaintenanceTasks()
|
||||
startMaintenanceTasks()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,33 +5,23 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
func registerMaintenanceTasks() {
|
||||
func startMaintenanceTasks() {
|
||||
module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute)
|
||||
module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
|
||||
module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
|
||||
}
|
||||
|
||||
func maintainBasic(ctx context.Context, task *modules.Task) {
|
||||
err := database.Maintain()
|
||||
if err != nil {
|
||||
log.Errorf("database: maintenance error: %s", err)
|
||||
}
|
||||
func maintainBasic(ctx context.Context, task *modules.Task) error {
|
||||
return database.Maintain()
|
||||
}
|
||||
|
||||
func maintainThorough(ctx context.Context, task *modules.Task) {
|
||||
err := database.MaintainThorough()
|
||||
if err != nil {
|
||||
log.Errorf("database: thorough maintenance error: %s", err)
|
||||
}
|
||||
func maintainThorough(ctx context.Context, task *modules.Task) error {
|
||||
return database.MaintainThorough()
|
||||
}
|
||||
|
||||
func maintainRecords(ctx context.Context, task *modules.Task) {
|
||||
err := database.MaintainRecordStates()
|
||||
if err != nil {
|
||||
log.Errorf("database: record states maintenance error: %s", err)
|
||||
}
|
||||
func maintainRecords(ctx context.Context, task *modules.Task) error {
|
||||
return database.MaintainRecordStates()
|
||||
}
|
||||
|
|
|
@ -10,4 +10,5 @@ var (
|
|||
ErrPermissionDenied = errors.New("access to database record denied")
|
||||
ErrReadOnly = errors.New("database is read only")
|
||||
ErrShuttingDown = errors.New("database system is shutting down")
|
||||
ErrNotImplemented = errors.New("not implemented by this storage")
|
||||
)
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/bluele/gcache"
|
||||
|
||||
"github.com/safing/portbase/database/accessor"
|
||||
|
@ -170,10 +172,19 @@ func (i *Interface) InsertValue(key string, attribute string, value interface{})
|
|||
}
|
||||
|
||||
// Put saves a record to the database.
|
||||
func (i *Interface) Put(r record.Record) error {
|
||||
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return err
|
||||
func (i *Interface) Put(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.Internal || !i.options.Local {
|
||||
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
|
@ -186,24 +197,122 @@ func (i *Interface) Put(r record.Record) error {
|
|||
}
|
||||
|
||||
// PutNew saves a record to the database as a new record (ie. with new timestamps).
|
||||
func (i *Interface) PutNew(r record.Record) error {
|
||||
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return err
|
||||
func (i *Interface) PutNew(r record.Record) (err error) {
|
||||
// get record or only database
|
||||
var db *Controller
|
||||
if !i.options.Internal || !i.options.Local {
|
||||
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
|
||||
if err != nil && err != ErrNotFound {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
db, err = getController(r.DatabaseKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if r.Meta() == nil {
|
||||
r.CreateMeta()
|
||||
if r.Meta() != nil {
|
||||
r.Meta().Reset()
|
||||
}
|
||||
r.Meta().Reset()
|
||||
i.options.Apply(r)
|
||||
i.updateCache(r)
|
||||
return db.Put(r)
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database. Warning: This is nearly a direct database access and omits many things:
|
||||
// - Record locking
|
||||
// - Hooks
|
||||
// - Subscriptions
|
||||
// - Caching
|
||||
func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
|
||||
interfaceBatch := make(chan record.Record, 100)
|
||||
|
||||
// permission check
|
||||
if !i.options.Internal || !i.options.Local {
|
||||
return func(r record.Record) error {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
// get database
|
||||
db, err := getController(dbName)
|
||||
if err != nil {
|
||||
return func(r record.Record) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// start database access
|
||||
dbBatch, errs := db.PutMany()
|
||||
finished := abool.New()
|
||||
var internalErr error
|
||||
|
||||
// interface options proxy
|
||||
go func() {
|
||||
defer close(dbBatch) // signify that we are finished
|
||||
for {
|
||||
select {
|
||||
case r := <-interfaceBatch:
|
||||
// finished?
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
// apply options
|
||||
i.options.Apply(r)
|
||||
// pass along
|
||||
dbBatch <- r
|
||||
case <-time.After(1 * time.Second):
|
||||
// bail out
|
||||
internalErr = errors.New("timeout: putmany unused for too long")
|
||||
finished.Set()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func(r record.Record) error {
|
||||
// finished?
|
||||
if finished.IsSet() {
|
||||
// check for internal error
|
||||
if internalErr != nil {
|
||||
return internalErr
|
||||
}
|
||||
// check for previous error
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
default:
|
||||
return errors.New("batch is closed")
|
||||
}
|
||||
}
|
||||
|
||||
// finish?
|
||||
if r == nil {
|
||||
finished.Set()
|
||||
interfaceBatch <- nil // signify that we are finished
|
||||
// do not close, as this fn could be called again with nil.
|
||||
return <-errs
|
||||
}
|
||||
|
||||
// check record scope
|
||||
if r.DatabaseName() != dbName {
|
||||
return errors.New("record out of database scope")
|
||||
}
|
||||
|
||||
// submit
|
||||
select {
|
||||
case interfaceBatch <- r:
|
||||
return nil
|
||||
case err := <-errs:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetAbsoluteExpiry sets an absolute record expiry.
|
||||
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
|
||||
r, db, err := i.getRecord(getDBFromKey, key, true, true)
|
||||
|
|
|
@ -23,15 +23,15 @@ var (
|
|||
databasesStructure *utils.DirStructure
|
||||
)
|
||||
|
||||
// Initialize initializes the database at the specified location. Supply either a path or dir structure.
|
||||
func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error {
|
||||
if initialized.SetToIf(false, true) {
|
||||
// InitializeWithPath initializes the database at the specified location using a path.
|
||||
func InitializeWithPath(dirPath string) error {
|
||||
return Initialize(utils.NewDirStructure(dirPath, 0755))
|
||||
}
|
||||
|
||||
if dirStructureRoot != nil {
|
||||
rootStructure = dirStructureRoot
|
||||
} else {
|
||||
rootStructure = utils.NewDirStructure(dirPath, 0755)
|
||||
}
|
||||
// Initialize initializes the database at the specified location using a dir structure.
|
||||
func Initialize(dirStructureRoot *utils.DirStructure) error {
|
||||
if initialized.SetToIf(false, true) {
|
||||
rootStructure = dirStructureRoot
|
||||
|
||||
// ensure root and databases dirs
|
||||
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700)
|
||||
|
|
|
@ -1,16 +1,10 @@
|
|||
package record
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TestRecord struct {
|
||||
Base
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (tm *TestRecord) Lock() {
|
||||
tm.lock.Lock()
|
||||
}
|
||||
|
||||
func (tm *TestRecord) Unlock() {
|
||||
tm.lock.Unlock()
|
||||
sync.Mutex
|
||||
}
|
||||
|
|
|
@ -37,27 +37,19 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
|
|||
offset += n
|
||||
|
||||
newMeta := &Meta{}
|
||||
if len(metaSection) == 34 && metaSection[4] == 0 {
|
||||
// TODO: remove in 2020
|
||||
// backward compatibility:
|
||||
// format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15)
|
||||
// this must be gencode without format
|
||||
_, err = newMeta.GenCodeUnmarshal(metaSection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
|
||||
}
|
||||
} else {
|
||||
_, err = dsd.Load(metaSection, newMeta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
|
||||
}
|
||||
_, err = dsd.Load(metaSection, newMeta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
|
||||
}
|
||||
|
||||
format, n, err := varint.Unpack8(data[offset:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get dsd format: %s", err)
|
||||
var format uint8 = dsd.NONE
|
||||
if !newMeta.IsDeleted() {
|
||||
format, n, err = varint.Unpack8(data[offset:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get dsd format: %s", err)
|
||||
}
|
||||
offset += n
|
||||
}
|
||||
offset += n
|
||||
|
||||
return &Wrapper{
|
||||
Base{
|
||||
|
|
|
@ -2,10 +2,7 @@ package record
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/container"
|
||||
)
|
||||
|
||||
func TestWrapper(t *testing.T) {
|
||||
|
@ -54,43 +51,4 @@ func TestWrapper(t *testing.T) {
|
|||
if !bytes.Equal(testData, wrapper2.Data) {
|
||||
t.Error("marshal mismatch")
|
||||
}
|
||||
|
||||
// test new format
|
||||
oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrapper3, err := NewRawWrapper("test", "a", oldRaw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(testData, wrapper3.Data) {
|
||||
t.Error("marshal mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) {
|
||||
if w.Meta() == nil {
|
||||
return nil, errors.New("missing meta")
|
||||
}
|
||||
|
||||
// version
|
||||
c := container.New([]byte{1})
|
||||
|
||||
// meta
|
||||
metaSection, err := w.meta.GenCodeMarshal(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.AppendAsBlock(metaSection)
|
||||
|
||||
// data
|
||||
dataSection, err := w.Marshal(r, JSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Append(dataSection)
|
||||
|
||||
return c.CompileData(), nil
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ func saveRegistry(lock bool) error {
|
|||
}
|
||||
|
||||
// write file
|
||||
// FIXME: write atomically (best effort)
|
||||
// TODO: write atomically (best effort)
|
||||
filePath := path.Join(rootStructure.Path, registryFileName)
|
||||
return ioutil.WriteFile(filePath, data, 0600)
|
||||
}
|
||||
|
|
|
@ -82,16 +82,19 @@ func (b *Badger) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (b *Badger) Put(r record.Record) error {
|
||||
func (b *Badger) Put(r record.Record) (record.Record, error) {
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.db.Update(func(txn *badger.Txn) error {
|
||||
return txn.Set([]byte(r.DatabaseKey()), data)
|
||||
})
|
||||
return err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestBadger(t *testing.T) {
|
|||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
err = db.Put(a)
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -86,10 +86,10 @@ func (b *BBolt) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (b *BBolt) Put(r record.Record) error {
|
||||
func (b *BBolt) Put(r record.Record) (record.Record, error) {
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.db.Update(func(tx *bbolt.Tx) error {
|
||||
|
@ -100,9 +100,38 @@ func (b *BBolt) Put(r record.Record) error {
|
|||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (b *BBolt) PutMany() (chan<- record.Record, <-chan error) {
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := b.db.Batch(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(bucketName)
|
||||
for r := range batch {
|
||||
// marshal
|
||||
data, txErr := r.MarshalRecord(r)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
|
||||
// put
|
||||
txErr = bucket.Put([]byte(r.DatabaseKey()), data)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
errs <- err
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestBBolt(t *testing.T) {
|
|||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
err = db.Put(a)
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -100,15 +100,15 @@ func TestBBolt(t *testing.T) {
|
|||
qZ.SetKey("test:z")
|
||||
qZ.CreateMeta()
|
||||
// put
|
||||
err = db.Put(qA)
|
||||
_, err = db.Put(qA)
|
||||
if err == nil {
|
||||
err = db.Put(qB)
|
||||
_, err = db.Put(qB)
|
||||
}
|
||||
if err == nil {
|
||||
err = db.Put(qC)
|
||||
_, err = db.Put(qC)
|
||||
}
|
||||
if err == nil {
|
||||
err = db.Put(qZ)
|
||||
_, err = db.Put(qZ)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -104,15 +104,15 @@ func (fst *FSTree) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (fst *FSTree) Put(r record.Record) error {
|
||||
func (fst *FSTree) Put(r record.Record) (record.Record, error) {
|
||||
dstPath, err := fst.buildFilePath(r.DatabaseKey(), true)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := r.MarshalRecord(r)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = writeFile(dstPath, data, defaultFileMode)
|
||||
|
@ -120,15 +120,15 @@ func (fst *FSTree) Put(r record.Record) error {
|
|||
// create dir and try again
|
||||
err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err)
|
||||
return nil, fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err)
|
||||
}
|
||||
err = writeFile(dstPath, data, defaultFileMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fstree: could not write file %s: %s", dstPath, err)
|
||||
return nil, fmt.Errorf("fstree: could not write file %s: %s", dstPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -44,12 +44,33 @@ func (hm *HashMap) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (hm *HashMap) Put(r record.Record) error {
|
||||
func (hm *HashMap) Put(r record.Record) (record.Record, error) {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
|
||||
hm.db[r.DatabaseKey()] = r
|
||||
return nil
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (hm *HashMap) PutMany() (chan<- record.Record, <-chan error) {
|
||||
hm.dbLock.Lock()
|
||||
defer hm.dbLock.Unlock()
|
||||
// we could lock for every record, but we want to have the same behaviour
|
||||
// as the other storage backends, especially for testing.
|
||||
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
// start handler
|
||||
go func() {
|
||||
for r := range batch {
|
||||
hm.db[r.DatabaseKey()] = r
|
||||
}
|
||||
errs <- nil
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -57,7 +57,7 @@ func TestHashMap(t *testing.T) {
|
|||
a.SetKey("test:A")
|
||||
|
||||
// put record
|
||||
err = db.Put(a)
|
||||
_, err = db.Put(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -86,15 +86,15 @@ func TestHashMap(t *testing.T) {
|
|||
qZ.SetKey("test:z")
|
||||
qZ.CreateMeta()
|
||||
// put
|
||||
err = db.Put(qA)
|
||||
_, err = db.Put(qA)
|
||||
if err == nil {
|
||||
err = db.Put(qB)
|
||||
_, err = db.Put(qB)
|
||||
}
|
||||
if err == nil {
|
||||
err = db.Put(qC)
|
||||
_, err = db.Put(qC)
|
||||
}
|
||||
if err == nil {
|
||||
err = db.Put(qZ)
|
||||
_, err = db.Put(qZ)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -21,8 +21,16 @@ func (i *InjectBase) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (i *InjectBase) Put(m record.Record) error {
|
||||
return errNotImplemented
|
||||
func (i *InjectBase) Put(m record.Record) (record.Record, error) {
|
||||
return nil, errNotImplemented
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (i *InjectBase) PutMany() (batch chan record.Record, err chan error) {
|
||||
batch = make(chan record.Record)
|
||||
err = make(chan error, 1)
|
||||
err <- errNotImplemented
|
||||
return
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
// Interface defines the database storage API.
|
||||
type Interface interface {
|
||||
Get(key string) (record.Record, error)
|
||||
Put(m record.Record) error
|
||||
Put(m record.Record) (record.Record, error)
|
||||
Delete(key string) error
|
||||
Query(q *query.Query, local, internal bool) (*iterator.Iterator, error)
|
||||
|
||||
|
@ -19,3 +19,8 @@ type Interface interface {
|
|||
MaintainThorough() error
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
// Batcher defines the database storage API for backends that support batch operations.
|
||||
type Batcher interface {
|
||||
PutMany() (batch chan<- record.Record, errs <-chan error)
|
||||
}
|
||||
|
|
|
@ -36,8 +36,24 @@ func (s *Sinkhole) Get(key string) (record.Record, error) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (s *Sinkhole) Put(m record.Record) error {
|
||||
return nil
|
||||
func (s *Sinkhole) Put(r record.Record) (record.Record, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// PutMany stores many records in the database.
|
||||
func (s *Sinkhole) PutMany() (chan<- record.Record, <-chan error) {
|
||||
batch := make(chan record.Record, 100)
|
||||
errs := make(chan error, 1)
|
||||
|
||||
// start handler
|
||||
go func() {
|
||||
for range batch {
|
||||
// nom, nom, nom
|
||||
}
|
||||
errs <- nil
|
||||
}()
|
||||
|
||||
return batch, errs
|
||||
}
|
||||
|
||||
// Delete deletes a record from the database.
|
||||
|
|
|
@ -13,7 +13,6 @@ type Subscription struct {
|
|||
canceled bool
|
||||
|
||||
Feed chan record.Record
|
||||
Err error
|
||||
}
|
||||
|
||||
// Cancel cancels the subscription.
|
||||
|
|
27
dataroot/root.go
Normal file
27
dataroot/root.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package dataroot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/safing/portbase/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
root *utils.DirStructure
|
||||
)
|
||||
|
||||
// Initialize initializes the data root directory
|
||||
func Initialize(rootDir string, perm os.FileMode) error {
|
||||
if root != nil {
|
||||
return errors.New("already initialized")
|
||||
}
|
||||
|
||||
root = utils.NewDirStructure(rootDir, perm)
|
||||
return root.Ensure()
|
||||
}
|
||||
|
||||
// Root returns the data root directory.
|
||||
func Root() *utils.DirStructure {
|
||||
return root
|
||||
}
|
|
@ -8,6 +8,6 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
|
||||
flag.StringVar(&logLevelFlag, "log", "", "set log level to [trace|debug|info|warning|error|critical]")
|
||||
flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug")
|
||||
}
|
||||
|
|
38
log/formatting_darwin.go
Normal file
38
log/formatting_darwin.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package log
|
||||
|
||||
const (
|
||||
rightArrow = "▶"
|
||||
leftArrow = "◀"
|
||||
)
|
||||
|
||||
const (
|
||||
// colorBlack = "\033[30m"
|
||||
colorRed = "\033[31m"
|
||||
// colorGreen = "\033[32m"
|
||||
colorYellow = "\033[33m"
|
||||
colorBlue = "\033[34m"
|
||||
colorMagenta = "\033[35m"
|
||||
colorCyan = "\033[36m"
|
||||
// colorWhite = "\033[37m"
|
||||
)
|
||||
|
||||
func (s Severity) color() string {
|
||||
switch s {
|
||||
case DebugLevel:
|
||||
return colorCyan
|
||||
case InfoLevel:
|
||||
return colorBlue
|
||||
case WarningLevel:
|
||||
return colorYellow
|
||||
case ErrorLevel:
|
||||
return colorRed
|
||||
case CriticalLevel:
|
||||
return colorMagenta
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func endColor() string {
|
||||
return "\033[0m"
|
||||
}
|
|
@ -12,7 +12,7 @@ func log(level Severity, msg string, tracer *ContextTracer) {
|
|||
|
||||
if !started.IsSet() {
|
||||
// a bit resource intense, but keeps logs before logging started.
|
||||
// FIXME: create option to disable logging
|
||||
// TODO: create option to disable logging
|
||||
go func() {
|
||||
<-startedSignal
|
||||
log(level, msg, tracer)
|
||||
|
|
|
@ -82,6 +82,7 @@ var (
|
|||
logsWaiting = make(chan struct{}, 4)
|
||||
logsWaitingFlag = abool.NewBool(false)
|
||||
|
||||
shutdownFlag = abool.NewBool(false)
|
||||
shutdownSignal = make(chan struct{})
|
||||
shutdownWaitGroup sync.WaitGroup
|
||||
|
||||
|
@ -136,12 +137,14 @@ func Start() (err error) {
|
|||
|
||||
logBuffer = make(chan *logLine, 1024)
|
||||
|
||||
initialLogLevel := ParseLevel(logLevelFlag)
|
||||
if initialLogLevel > 0 {
|
||||
if logLevelFlag != "" {
|
||||
initialLogLevel := ParseLevel(logLevelFlag)
|
||||
if initialLogLevel == 0 {
|
||||
fmt.Fprintf(os.Stderr, "log warning: invalid log level \"%s\", falling back to level info\n", logLevelFlag)
|
||||
initialLogLevel = InfoLevel
|
||||
}
|
||||
|
||||
SetLogLevel(initialLogLevel)
|
||||
} else {
|
||||
err = fmt.Errorf("log warning: invalid log level \"%s\", falling back to level info", logLevelFlag)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
|
||||
}
|
||||
|
||||
// get and set file loglevels
|
||||
|
@ -179,6 +182,8 @@ func Start() (err error) {
|
|||
|
||||
// Shutdown writes remaining log lines and then stops the log system.
|
||||
func Shutdown() {
|
||||
close(shutdownSignal)
|
||||
if shutdownFlag.SetToIf(false, true) {
|
||||
close(shutdownSignal)
|
||||
}
|
||||
shutdownWaitGroup.Wait()
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func writeLine(line *logLine, duplicates uint64) {
|
|||
}
|
||||
|
||||
func startWriter() {
|
||||
fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()))
|
||||
fmt.Printf("%s%s %s BOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor())
|
||||
|
||||
shutdownWaitGroup.Add(1)
|
||||
go writerManager()
|
||||
|
@ -168,7 +168,7 @@ func finalizeWriting() {
|
|||
case line := <-logBuffer:
|
||||
writeLine(line, 0)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()))
|
||||
fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ func (tracer *ContextTracer) Submit() {
|
|||
|
||||
if !started.IsSet() {
|
||||
// a bit resource intense, but keeps logs before logging started.
|
||||
// FIXME: create option to disable logging
|
||||
// TODO: create option to disable logging
|
||||
go func() {
|
||||
<-startedSignal
|
||||
tracer.Submit()
|
||||
|
|
|
@ -17,7 +17,9 @@ type eventHook struct {
|
|||
|
||||
// TriggerEvent executes all hook functions registered to the specified event.
|
||||
func (m *Module) TriggerEvent(event string, data interface{}) {
|
||||
go m.processEventTrigger(event, data)
|
||||
if m.OnlineSoon() {
|
||||
go m.processEventTrigger(event, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) processEventTrigger(event string, data interface{}) {
|
||||
|
@ -31,18 +33,35 @@ func (m *Module) processEventTrigger(event string, data interface{}) {
|
|||
}
|
||||
|
||||
for _, hook := range hooks {
|
||||
if !hook.hookingModule.ShutdownInProgress() {
|
||||
if hook.hookingModule.OnlineSoon() {
|
||||
go m.runEventHook(hook, event, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) {
|
||||
if !hook.hookingModule.Started.IsSet() {
|
||||
// check if source module is ready for handling
|
||||
if m.Status() != StatusOnline {
|
||||
// target module has not yet fully started, wait until start is complete
|
||||
select {
|
||||
case <-startCompleteSignal:
|
||||
case <-shutdownSignal:
|
||||
case <-m.StartCompleted():
|
||||
// continue with hook execution
|
||||
case <-hook.hookingModule.Stopping():
|
||||
return
|
||||
case <-m.Stopping():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// check if destionation module is ready for handling
|
||||
if hook.hookingModule.Status() != StatusOnline {
|
||||
// target module has not yet fully started, wait until start is complete
|
||||
select {
|
||||
case <-hook.hookingModule.StartCompleted():
|
||||
// continue with hook execution
|
||||
case <-hook.hookingModule.Stopping():
|
||||
return
|
||||
case <-m.Stopping():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +88,7 @@ func (m *Module) RegisterEvent(event string) {
|
|||
}
|
||||
}
|
||||
|
||||
// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until all modules completed starting.
|
||||
// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until the modules completed starting.
|
||||
func (m *Module) RegisterEventHook(module string, event string, description string, fn func(context.Context, interface{}) error) error {
|
||||
// get target module
|
||||
var eventModule *Module
|
||||
|
@ -77,9 +96,7 @@ func (m *Module) RegisterEventHook(module string, event string, description stri
|
|||
eventModule = m
|
||||
} else {
|
||||
var ok bool
|
||||
modulesLock.RLock()
|
||||
eventModule, ok = modules[module]
|
||||
modulesLock.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf(`module "%s" does not exist`, module)
|
||||
}
|
||||
|
|
|
@ -1,18 +1,22 @@
|
|||
package modules
|
||||
|
||||
import "flag"
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// HelpFlag triggers printing flag.Usage. It's exported for custom help handling.
|
||||
HelpFlag bool
|
||||
HelpFlag bool
|
||||
printGraphFlag bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.BoolVar(&HelpFlag, "help", false, "print help")
|
||||
flag.BoolVar(&printGraphFlag, "print-module-graph", false, "print the module dependency graph")
|
||||
}
|
||||
|
||||
func parseFlags() error {
|
||||
|
||||
// parse flags
|
||||
flag.Parse()
|
||||
|
||||
|
@ -21,5 +25,36 @@ func parseFlags() error {
|
|||
return ErrCleanExit
|
||||
}
|
||||
|
||||
if printGraphFlag {
|
||||
printGraph()
|
||||
return ErrCleanExit
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printGraph() {
|
||||
// mark roots
|
||||
for _, module := range modules {
|
||||
if len(module.depReverse) == 0 {
|
||||
// is root, dont print deps in dep tree
|
||||
module.stopFlag.Set()
|
||||
}
|
||||
}
|
||||
// print
|
||||
for _, module := range modules {
|
||||
if module.stopFlag.IsSet() {
|
||||
// print from root
|
||||
printModuleGraph("", module, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func printModuleGraph(prefix string, module *Module, root bool) {
|
||||
fmt.Printf("%s├── %s\n", prefix, module.Name)
|
||||
if root || !module.stopFlag.IsSet() {
|
||||
for _, dep := range module.Dependencies() {
|
||||
printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
107
modules/mgmt.go
Normal file
107
modules/mgmt.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package modules
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
moduleMgmtEnabled = abool.NewBool(false)
|
||||
modulesChangeNotifyFn func(*Module)
|
||||
)
|
||||
|
||||
// Enable enables the module. Only has an effect if module management is enabled.
|
||||
func (m *Module) Enable() (changed bool) {
|
||||
return m.enabled.SetToIf(false, true)
|
||||
}
|
||||
|
||||
// Disable disables the module. Only has an effect if module management is enabled.
|
||||
func (m *Module) Disable() (changed bool) {
|
||||
return m.enabled.SetToIf(true, false)
|
||||
}
|
||||
|
||||
// SetEnabled sets the module to the desired enabled state. Only has an effect if module management is enabled.
|
||||
func (m *Module) SetEnabled(enable bool) (changed bool) {
|
||||
if enable {
|
||||
return m.Enable()
|
||||
}
|
||||
return m.Disable()
|
||||
}
|
||||
|
||||
// Enabled returns wether or not the module is currently enabled.
|
||||
func (m *Module) Enabled() bool {
|
||||
return m.enabled.IsSet()
|
||||
}
|
||||
|
||||
// EnableModuleManagement enables the module management functionality within modules. The supplied notify function will be called whenever the status of a module changes. The affected module will be in the parameter. You will need to manually enable modules, else nothing will start.
|
||||
func EnableModuleManagement(changeNotifyFn func(*Module)) {
|
||||
if moduleMgmtEnabled.SetToIf(false, true) {
|
||||
modulesChangeNotifyFn = changeNotifyFn
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) notifyOfChange() {
|
||||
if moduleMgmtEnabled.IsSet() && modulesChangeNotifyFn != nil {
|
||||
m.StartWorker("notify of change", func(ctx context.Context) error {
|
||||
modulesChangeNotifyFn(m)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ManageModules triggers the module manager to react to recent changes of enabled modules.
|
||||
func ManageModules() error {
|
||||
// check if enabled
|
||||
if !moduleMgmtEnabled.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// lock mgmt
|
||||
mgmtLock.Lock()
|
||||
defer mgmtLock.Unlock()
|
||||
|
||||
log.Info("modules: managing changes")
|
||||
|
||||
// build new dependency tree
|
||||
buildEnabledTree()
|
||||
|
||||
// stop unneeded modules
|
||||
lastErr := stopModules()
|
||||
if lastErr != nil {
|
||||
log.Warning(lastErr.Error())
|
||||
}
|
||||
|
||||
// start needed modules
|
||||
err := startModules()
|
||||
if err != nil {
|
||||
log.Warning(err.Error())
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
log.Info("modules: finished managing")
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func buildEnabledTree() {
|
||||
// reset marked dependencies
|
||||
for _, m := range modules {
|
||||
m.enabledAsDependency.UnSet()
|
||||
}
|
||||
|
||||
// mark dependencies
|
||||
for _, m := range modules {
|
||||
if m.enabled.IsSet() {
|
||||
m.markDependencies()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) markDependencies() {
|
||||
for _, dep := range m.depModules {
|
||||
if dep.enabledAsDependency.SetToIf(false, true) {
|
||||
dep.markDependencies()
|
||||
}
|
||||
}
|
||||
}
|
165
modules/mgmt_test.go
Normal file
165
modules/mgmt_test.go
Normal file
|
@ -0,0 +1,165 @@
|
|||
package modules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testModuleMgmt(t *testing.T) {
|
||||
|
||||
// enable module management
|
||||
EnableModuleManagement(nil)
|
||||
|
||||
registerTestModule(t, "base")
|
||||
registerTestModule(t, "feature1", "base")
|
||||
registerTestModule(t, "base2", "base")
|
||||
registerTestModule(t, "feature2", "base2")
|
||||
registerTestModule(t, "sub-feature", "base")
|
||||
registerTestModule(t, "feature3", "sub-feature")
|
||||
registerTestModule(t, "feature4", "sub-feature")
|
||||
|
||||
// enable core module
|
||||
core := modules["base"]
|
||||
core.Enable()
|
||||
|
||||
// start and check order
|
||||
err := Start()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if changeHistory != " on:base" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// enable feature1
|
||||
feature1 := modules["feature1"]
|
||||
feature1.Enable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " on:feature1" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// enable feature2
|
||||
feature2 := modules["feature2"]
|
||||
feature2.Enable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " on:base2 on:feature2" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// enable feature3
|
||||
feature3 := modules["feature3"]
|
||||
feature3.Enable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " on:sub-feature on:feature3" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// enable feature4
|
||||
feature4 := modules["feature4"]
|
||||
feature4.Enable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " on:feature4" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// disable feature1
|
||||
feature1.Disable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " off:feature1" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// disable feature3
|
||||
feature3.Disable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
// disable feature4
|
||||
feature4.Disable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " off:feature3 off:feature4 off:sub-feature" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// enable feature4
|
||||
feature4.Enable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " on:sub-feature on:feature4" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
// disable feature4
|
||||
feature4.Disable()
|
||||
// manage modules and check
|
||||
err = ManageModules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if changeHistory != " off:feature4 off:sub-feature" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
err = Shutdown()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if changeHistory != " off:feature2 off:base2 off:base" {
|
||||
t.Errorf("order mismatch, was %s", changeHistory)
|
||||
}
|
||||
|
||||
// reset history
|
||||
changeHistory = ""
|
||||
|
||||
// disable module management
|
||||
moduleMgmtEnabled.UnSet()
|
||||
|
||||
resetTestEnvironment()
|
||||
}
|
|
@ -172,7 +172,7 @@ func microTaskScheduler() {
|
|||
|
||||
microTaskManageLoop:
|
||||
for {
|
||||
if shutdownSignalClosed.IsSet() {
|
||||
if shutdownFlag.IsSet() {
|
||||
close(mediumPriorityClearance)
|
||||
close(lowPriorityClearance)
|
||||
return
|
||||
|
|
|
@ -13,32 +13,44 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
modulesLock sync.RWMutex
|
||||
modules = make(map[string]*Module)
|
||||
modules = make(map[string]*Module)
|
||||
mgmtLock sync.Mutex
|
||||
|
||||
// lock modules when starting
|
||||
modulesLocked = abool.New()
|
||||
|
||||
// ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag.
|
||||
ErrCleanExit = errors.New("clean exit requested")
|
||||
)
|
||||
|
||||
// Module represents a module.
|
||||
type Module struct {
|
||||
type Module struct { //nolint:maligned // not worth the effort
|
||||
sync.RWMutex
|
||||
|
||||
Name string
|
||||
|
||||
// lifecycle mgmt
|
||||
Prepped *abool.AtomicBool
|
||||
Started *abool.AtomicBool
|
||||
Stopped *abool.AtomicBool
|
||||
inTransition *abool.AtomicBool
|
||||
// status mgmt
|
||||
enabled *abool.AtomicBool
|
||||
enabledAsDependency *abool.AtomicBool
|
||||
status uint8
|
||||
|
||||
// failure status
|
||||
failureStatus uint8
|
||||
failureID string
|
||||
failureMsg string
|
||||
|
||||
// lifecycle callback functions
|
||||
prep func() error
|
||||
start func() error
|
||||
stop func() error
|
||||
prepFn func() error
|
||||
startFn func() error
|
||||
stopFn func() error
|
||||
|
||||
// shutdown mgmt
|
||||
Ctx context.Context
|
||||
cancelCtx func()
|
||||
shutdownFlag *abool.AtomicBool
|
||||
// lifecycle mgmt
|
||||
// start
|
||||
startComplete chan struct{}
|
||||
// stop
|
||||
Ctx context.Context
|
||||
cancelCtx func()
|
||||
stopFlag *abool.AtomicBool
|
||||
|
||||
// workers/tasks
|
||||
workerCnt *int32
|
||||
|
@ -56,30 +68,177 @@ type Module struct {
|
|||
depReverse []*Module
|
||||
}
|
||||
|
||||
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
|
||||
func (m *Module) ShutdownInProgress() bool {
|
||||
return m.shutdownFlag.IsSet()
|
||||
// StartCompleted returns a channel read that triggers when the module has finished starting.
|
||||
func (m *Module) StartCompleted() <-chan struct{} {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.startComplete
|
||||
}
|
||||
|
||||
// ShuttingDown lets you listen for the shutdown signal.
|
||||
func (m *Module) ShuttingDown() <-chan struct{} {
|
||||
// Stopping returns a channel read that triggers when the module has initiated the stop procedure.
|
||||
func (m *Module) Stopping() <-chan struct{} {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.Ctx.Done()
|
||||
}
|
||||
|
||||
func (m *Module) shutdown() error {
|
||||
// signal shutdown
|
||||
m.shutdownFlag.Set()
|
||||
m.cancelCtx()
|
||||
// IsStopping returns whether the module has started shutting down. In most cases, you should use Stopping instead.
|
||||
func (m *Module) IsStopping() bool {
|
||||
return m.stopFlag.IsSet()
|
||||
}
|
||||
|
||||
// start shutdown function
|
||||
m.waitGroup.Add(1)
|
||||
stopFnError := make(chan error, 1)
|
||||
// Dependencies returns the module's dependencies.
|
||||
func (m *Module) Dependencies() []*Module {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.depModules
|
||||
}
|
||||
|
||||
func (m *Module) prep(reports chan *report) {
|
||||
// check and set intermediate status
|
||||
m.Lock()
|
||||
if m.status != StatusDead {
|
||||
m.Unlock()
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: fmt.Errorf("module already prepped"),
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
m.status = StatusPreparing
|
||||
m.Unlock()
|
||||
|
||||
// run prep function
|
||||
go func() {
|
||||
stopFnError <- m.runCtrlFn("stop module", m.stop)
|
||||
m.waitGroup.Done()
|
||||
var err error
|
||||
if m.prepFn != nil {
|
||||
// execute function
|
||||
err = m.runCtrlFnWithTimeout(
|
||||
"prep module",
|
||||
10*time.Second,
|
||||
m.prepFn,
|
||||
)
|
||||
}
|
||||
// set status
|
||||
if err != nil {
|
||||
m.Error(
|
||||
"module-failed-prep",
|
||||
fmt.Sprintf("failed to prep module: %s", err.Error()),
|
||||
)
|
||||
} else {
|
||||
m.Lock()
|
||||
m.status = StatusOffline
|
||||
m.Unlock()
|
||||
m.notifyOfChange()
|
||||
}
|
||||
// send report
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// wait for workers
|
||||
func (m *Module) start(reports chan *report) {
|
||||
// check and set intermediate status
|
||||
m.Lock()
|
||||
if m.status != StatusOffline {
|
||||
m.Unlock()
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: fmt.Errorf("module not offline"),
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
m.status = StatusStarting
|
||||
|
||||
// reset stop management
|
||||
if m.cancelCtx != nil {
|
||||
// trigger cancel just to be sure
|
||||
m.cancelCtx()
|
||||
}
|
||||
m.Ctx, m.cancelCtx = context.WithCancel(context.Background())
|
||||
m.stopFlag.UnSet()
|
||||
|
||||
m.Unlock()
|
||||
|
||||
// run start function
|
||||
go func() {
|
||||
var err error
|
||||
if m.startFn != nil {
|
||||
// execute function
|
||||
err = m.runCtrlFnWithTimeout(
|
||||
"start module",
|
||||
10*time.Second,
|
||||
m.startFn,
|
||||
)
|
||||
}
|
||||
// set status
|
||||
if err != nil {
|
||||
m.Error(
|
||||
"module-failed-start",
|
||||
fmt.Sprintf("failed to start module: %s", err.Error()),
|
||||
)
|
||||
} else {
|
||||
m.Lock()
|
||||
m.status = StatusOnline
|
||||
// init start management
|
||||
close(m.startComplete)
|
||||
m.Unlock()
|
||||
m.notifyOfChange()
|
||||
}
|
||||
// send report
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Module) stop(reports chan *report) {
|
||||
// check and set intermediate status
|
||||
m.Lock()
|
||||
if m.status != StatusOnline {
|
||||
m.Unlock()
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: fmt.Errorf("module not online"),
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
m.status = StatusStopping
|
||||
|
||||
// reset start management
|
||||
m.startComplete = make(chan struct{})
|
||||
// init stop management
|
||||
m.cancelCtx()
|
||||
m.stopFlag.Set()
|
||||
|
||||
m.Unlock()
|
||||
|
||||
go m.stopAllTasks(reports)
|
||||
}
|
||||
|
||||
func (m *Module) stopAllTasks(reports chan *report) {
|
||||
// start shutdown function
|
||||
stopFnFinished := abool.NewBool(false)
|
||||
var stopFnError error
|
||||
if m.stopFn != nil {
|
||||
m.waitGroup.Add(1)
|
||||
go func() {
|
||||
stopFnError = m.runCtrlFn("stop module", m.stopFn)
|
||||
stopFnFinished.Set()
|
||||
m.waitGroup.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait for workers and stop fn
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.waitGroup.Wait()
|
||||
|
@ -91,8 +250,9 @@ func (m *Module) shutdown() error {
|
|||
case <-done:
|
||||
case <-time.After(30 * time.Second):
|
||||
log.Warningf(
|
||||
"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
|
||||
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
|
||||
m.Name,
|
||||
stopFnFinished.IsSet(),
|
||||
atomic.LoadInt32(m.workerCnt),
|
||||
atomic.LoadInt32(m.taskCnt),
|
||||
atomic.LoadInt32(m.microTaskCnt),
|
||||
|
@ -100,24 +260,37 @@ func (m *Module) shutdown() error {
|
|||
}
|
||||
|
||||
// collect error
|
||||
select {
|
||||
case err := <-stopFnError:
|
||||
return err
|
||||
default:
|
||||
log.Warningf(
|
||||
"%s: timed out while waiting for stop function to finish, continuing shutdown...",
|
||||
m.Name,
|
||||
var err error
|
||||
if stopFnFinished.IsSet() && stopFnError != nil {
|
||||
err = stopFnError
|
||||
}
|
||||
// set status
|
||||
if err != nil {
|
||||
m.Error(
|
||||
"module-failed-stop",
|
||||
fmt.Sprintf("failed to stop module: %s", err.Error()),
|
||||
)
|
||||
return nil
|
||||
} else {
|
||||
m.Lock()
|
||||
m.status = StatusOffline
|
||||
m.Unlock()
|
||||
m.notifyOfChange()
|
||||
}
|
||||
// send report
|
||||
reports <- &report{
|
||||
module: m,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
|
||||
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
|
||||
if modulesLocked.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
newModule := initNewModule(name, prep, start, stop, dependencies...)
|
||||
|
||||
modulesLock.Lock()
|
||||
defer modulesLock.Unlock()
|
||||
// check for already existing module
|
||||
_, ok := modules[name]
|
||||
if ok {
|
||||
|
@ -136,23 +309,22 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ...
|
|||
var microTaskCnt int32
|
||||
|
||||
newModule := &Module{
|
||||
Name: name,
|
||||
Prepped: abool.NewBool(false),
|
||||
Started: abool.NewBool(false),
|
||||
Stopped: abool.NewBool(false),
|
||||
inTransition: abool.NewBool(false),
|
||||
Ctx: ctx,
|
||||
cancelCtx: cancelCtx,
|
||||
shutdownFlag: abool.NewBool(false),
|
||||
waitGroup: sync.WaitGroup{},
|
||||
workerCnt: &workerCnt,
|
||||
taskCnt: &taskCnt,
|
||||
microTaskCnt: µTaskCnt,
|
||||
prep: prep,
|
||||
start: start,
|
||||
stop: stop,
|
||||
eventHooks: make(map[string][]*eventHook),
|
||||
depNames: dependencies,
|
||||
Name: name,
|
||||
enabled: abool.NewBool(false),
|
||||
enabledAsDependency: abool.NewBool(false),
|
||||
prepFn: prep,
|
||||
startFn: start,
|
||||
stopFn: stop,
|
||||
startComplete: make(chan struct{}),
|
||||
Ctx: ctx,
|
||||
cancelCtx: cancelCtx,
|
||||
stopFlag: abool.NewBool(false),
|
||||
workerCnt: &workerCnt,
|
||||
taskCnt: &taskCnt,
|
||||
microTaskCnt: µTaskCnt,
|
||||
waitGroup: sync.WaitGroup{},
|
||||
eventHooks: make(map[string][]*eventHook),
|
||||
depNames: dependencies,
|
||||
}
|
||||
|
||||
return newModule
|
||||
|
@ -177,49 +349,3 @@ func initDependencies() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadyToPrep returns whether all dependencies are ready for this module to prep.
|
||||
func (m *Module) ReadyToPrep() bool {
|
||||
if m.inTransition.IsSet() || m.Prepped.IsSet() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, dep := range m.depModules {
|
||||
if !dep.Prepped.IsSet() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ReadyToStart returns whether all dependencies are ready for this module to start.
|
||||
func (m *Module) ReadyToStart() bool {
|
||||
if m.inTransition.IsSet() || m.Started.IsSet() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, dep := range m.depModules {
|
||||
if !dep.Started.IsSet() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ReadyToStop returns whether all dependencies are ready for this module to stop.
|
||||
func (m *Module) ReadyToStop() bool {
|
||||
if !m.Started.IsSet() || m.inTransition.IsSet() || m.Stopped.IsSet() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, revDep := range m.depReverse {
|
||||
// not ready if a reverse dependency was started, but not yet stopped
|
||||
if revDep.Started.IsSet() && !revDep.Stopped.IsSet() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -5,40 +5,36 @@ import (
|
|||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
orderLock sync.Mutex
|
||||
startOrder string
|
||||
shutdownOrder string
|
||||
changeHistoryLock sync.Mutex
|
||||
changeHistory string
|
||||
)
|
||||
|
||||
func testPrep(t *testing.T, name string) func() error {
|
||||
return func() error {
|
||||
t.Logf("prep %s\n", name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func testStart(t *testing.T, name string) func() error {
|
||||
return func() error {
|
||||
orderLock.Lock()
|
||||
defer orderLock.Unlock()
|
||||
t.Logf("start %s\n", name)
|
||||
startOrder = fmt.Sprintf("%s>%s", startOrder, name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func testStop(t *testing.T, name string) func() error {
|
||||
return func() error {
|
||||
orderLock.Lock()
|
||||
defer orderLock.Unlock()
|
||||
t.Logf("stop %s\n", name)
|
||||
shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
|
||||
return nil
|
||||
}
|
||||
func registerTestModule(t *testing.T, name string, dependencies ...string) {
|
||||
Register(
|
||||
name,
|
||||
func() error {
|
||||
t.Logf("prep %s\n", name)
|
||||
return nil
|
||||
},
|
||||
func() error {
|
||||
changeHistoryLock.Lock()
|
||||
defer changeHistoryLock.Unlock()
|
||||
t.Logf("start %s\n", name)
|
||||
changeHistory = fmt.Sprintf("%s on:%s", changeHistory, name)
|
||||
return nil
|
||||
},
|
||||
func() error {
|
||||
changeHistoryLock.Lock()
|
||||
defer changeHistoryLock.Unlock()
|
||||
t.Logf("stop %s\n", name)
|
||||
changeHistory = fmt.Sprintf("%s off:%s", changeHistory, name)
|
||||
return nil
|
||||
},
|
||||
dependencies...,
|
||||
)
|
||||
}
|
||||
|
||||
func testFail() error {
|
||||
|
@ -53,135 +49,94 @@ func TestModules(t *testing.T) {
|
|||
t.Parallel() // Not really, just a workaround for running these tests last.
|
||||
|
||||
t.Run("TestModuleOrder", testModuleOrder)
|
||||
t.Run("TestModuleMgmt", testModuleMgmt)
|
||||
t.Run("TestModuleErrors", testModuleErrors)
|
||||
}
|
||||
|
||||
func testModuleOrder(t *testing.T) {
|
||||
|
||||
Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database"))
|
||||
Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database")
|
||||
Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database")
|
||||
Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database")
|
||||
registerTestModule(t, "database")
|
||||
registerTestModule(t, "stats", "database")
|
||||
registerTestModule(t, "service", "database")
|
||||
registerTestModule(t, "analytics", "stats", "database")
|
||||
|
||||
err := Start()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if startOrder != ">database>service>stats>analytics" &&
|
||||
startOrder != ">database>stats>service>analytics" &&
|
||||
startOrder != ">database>stats>analytics>service" {
|
||||
t.Errorf("start order mismatch, was %s", startOrder)
|
||||
if changeHistory != " on:database on:service on:stats on:analytics" &&
|
||||
changeHistory != " on:database on:stats on:service on:analytics" &&
|
||||
changeHistory != " on:database on:stats on:analytics on:service" {
|
||||
t.Errorf("start order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
select {
|
||||
case <-ShuttingDown():
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("did not receive shutdown signal")
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
err = Shutdown()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if shutdownOrder != ">analytics>service>stats>database" &&
|
||||
shutdownOrder != ">analytics>stats>service>database" &&
|
||||
shutdownOrder != ">service>analytics>stats>database" {
|
||||
t.Errorf("shutdown order mismatch, was %s", shutdownOrder)
|
||||
if changeHistory != " off:analytics off:service off:stats off:database" &&
|
||||
changeHistory != " off:analytics off:stats off:service off:database" &&
|
||||
changeHistory != " off:service off:analytics off:stats off:database" {
|
||||
t.Errorf("shutdown order mismatch, was %s", changeHistory)
|
||||
}
|
||||
changeHistory = ""
|
||||
|
||||
wg.Wait()
|
||||
|
||||
printAndRemoveModules()
|
||||
}
|
||||
|
||||
func printAndRemoveModules() {
|
||||
modulesLock.Lock()
|
||||
defer modulesLock.Unlock()
|
||||
|
||||
fmt.Printf("All %d modules:\n", len(modules))
|
||||
for _, m := range modules {
|
||||
fmt.Printf("module %s: %+v\n", m.Name, m)
|
||||
}
|
||||
|
||||
modules = make(map[string]*Module)
|
||||
resetTestEnvironment()
|
||||
}
|
||||
|
||||
func testModuleErrors(t *testing.T) {
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
|
||||
// test prep error
|
||||
Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail"))
|
||||
Register("prepfail", testFail, nil, nil)
|
||||
err := Start()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test prep clean exit
|
||||
Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit"))
|
||||
Register("prepcleanexit", testCleanExit, nil, nil)
|
||||
err = Start()
|
||||
if err != ErrCleanExit {
|
||||
t.Error("should fail with clean exit")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test invalid dependency
|
||||
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid")
|
||||
Register("database", nil, nil, nil, "invalid")
|
||||
err = Start()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test dependency loop
|
||||
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper")
|
||||
Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database")
|
||||
registerTestModule(t, "database", "helper")
|
||||
registerTestModule(t, "helper", "database")
|
||||
err = Start()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test failing module start
|
||||
Register("startfail", nil, testFail, testStop(t, "startfail"))
|
||||
Register("startfail", nil, testFail, nil)
|
||||
err = Start()
|
||||
if err == nil {
|
||||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test failing module stop
|
||||
Register("stopfail", nil, testStart(t, "stopfail"), testFail)
|
||||
Register("stopfail", nil, nil, testFail)
|
||||
err = Start()
|
||||
if err != nil {
|
||||
t.Error("should not fail")
|
||||
|
@ -191,10 +146,7 @@ func testModuleErrors(t *testing.T) {
|
|||
t.Error("should fail")
|
||||
}
|
||||
|
||||
// reset modules
|
||||
modules = make(map[string]*Module)
|
||||
startComplete.UnSet()
|
||||
startCompleteSignal = make(chan struct{})
|
||||
resetTestEnvironment()
|
||||
|
||||
// test help flag
|
||||
HelpFlag = true
|
||||
|
@ -204,4 +156,20 @@ func testModuleErrors(t *testing.T) {
|
|||
}
|
||||
HelpFlag = false
|
||||
|
||||
resetTestEnvironment()
|
||||
}
|
||||
|
||||
func printModules() { //nolint:unused,deadcode
|
||||
fmt.Printf("All %d modules:\n", len(modules))
|
||||
for _, m := range modules {
|
||||
fmt.Printf("module %s: %+v\n", m.Name, m)
|
||||
}
|
||||
}
|
||||
|
||||
func resetTestEnvironment() {
|
||||
modules = make(map[string]*Module)
|
||||
shutdownSignal = make(chan struct{})
|
||||
shutdownCompleteSignal = make(chan struct{})
|
||||
shutdownFlag.UnSet()
|
||||
modulesLocked.UnSet()
|
||||
}
|
||||
|
|
168
modules/start.go
168
modules/start.go
|
@ -1,34 +1,38 @@
|
|||
package modules
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
var (
|
||||
startComplete = abool.NewBool(false)
|
||||
startCompleteSignal = make(chan struct{})
|
||||
initialStartCompleted = abool.NewBool(false)
|
||||
globalPrepFn func() error
|
||||
)
|
||||
|
||||
// StartCompleted returns whether starting has completed.
|
||||
func StartCompleted() bool {
|
||||
return startComplete.IsSet()
|
||||
}
|
||||
|
||||
// WaitForStartCompletion returns as soon as starting has completed.
|
||||
func WaitForStartCompletion() <-chan struct{} {
|
||||
return startCompleteSignal
|
||||
// SetGlobalPrepFn sets a global prep function that is run before all modules. This can be used to pre-initialize modules, such as setting the data root or database path.
|
||||
// SetGlobalPrepFn sets a global prep function that is run before all modules.
|
||||
func SetGlobalPrepFn(fn func() error) {
|
||||
if globalPrepFn == nil {
|
||||
globalPrepFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts all modules in the correct order. In case of an error, it will automatically shutdown again.
|
||||
func Start() error {
|
||||
modulesLock.RLock()
|
||||
defer modulesLock.RUnlock()
|
||||
if !modulesLocked.SetToIf(false, true) {
|
||||
return errors.New("module system already started")
|
||||
}
|
||||
|
||||
// lock mgmt
|
||||
mgmtLock.Lock()
|
||||
defer mgmtLock.Unlock()
|
||||
|
||||
// start microtask scheduler
|
||||
go microTaskScheduler()
|
||||
|
@ -44,10 +48,23 @@ func Start() error {
|
|||
// parse flags
|
||||
err = parseFlags()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err)
|
||||
if err != ErrCleanExit {
|
||||
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// execute global prep fn
|
||||
if globalPrepFn != nil {
|
||||
err = globalPrepFn()
|
||||
if err != nil {
|
||||
if err != ErrCleanExit {
|
||||
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: %s\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// prep modules
|
||||
err = prepareModules()
|
||||
if err != nil {
|
||||
|
@ -65,6 +82,9 @@ func Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// build dependency tree
|
||||
buildEnabledTree()
|
||||
|
||||
// start modules
|
||||
log.Info("modules: initiating...")
|
||||
err = startModules()
|
||||
|
@ -74,14 +94,16 @@ func Start() error {
|
|||
}
|
||||
|
||||
// complete startup
|
||||
log.Infof("modules: started %d modules", len(modules))
|
||||
if startComplete.SetToIf(false, true) {
|
||||
close(startCompleteSignal)
|
||||
if moduleMgmtEnabled.IsSet() {
|
||||
log.Info("modules: initiated subsystems manager")
|
||||
} else {
|
||||
log.Infof("modules: started %d modules", len(modules))
|
||||
}
|
||||
|
||||
go taskQueueHandler()
|
||||
go taskScheduleHandler()
|
||||
|
||||
initialStartCompleted.Set()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -97,45 +119,36 @@ func prepareModules() error {
|
|||
reportCnt := 0
|
||||
|
||||
for {
|
||||
waiting := 0
|
||||
|
||||
// find modules to exec
|
||||
for _, m := range modules {
|
||||
if m.ReadyToPrep() {
|
||||
switch m.readyToPrep() {
|
||||
case statusNothingToDo:
|
||||
case statusWaiting:
|
||||
waiting++
|
||||
case statusReady:
|
||||
execCnt++
|
||||
m.inTransition.Set()
|
||||
|
||||
execM := m
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: execM,
|
||||
err: execM.runCtrlFnWithTimeout(
|
||||
"prep module",
|
||||
10*time.Second,
|
||||
execM.prep,
|
||||
),
|
||||
}
|
||||
}()
|
||||
m.prep(reports)
|
||||
}
|
||||
}
|
||||
|
||||
// check for dep loop
|
||||
if execCnt == reportCnt {
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
rep.module.inTransition.UnSet()
|
||||
if rep.err != nil {
|
||||
if rep.err == ErrCleanExit {
|
||||
return rep.err
|
||||
if reportCnt < execCnt {
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
if rep.err != nil {
|
||||
if rep.err == ErrCleanExit {
|
||||
return rep.err
|
||||
}
|
||||
return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
} else {
|
||||
// finished
|
||||
if waiting > 0 {
|
||||
// check for dep loop
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
rep.module.Prepped.Set()
|
||||
|
||||
// exit if done
|
||||
if reportCnt == len(modules) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -149,45 +162,36 @@ func startModules() error {
|
|||
reportCnt := 0
|
||||
|
||||
for {
|
||||
waiting := 0
|
||||
|
||||
// find modules to exec
|
||||
for _, m := range modules {
|
||||
if m.ReadyToStart() {
|
||||
switch m.readyToStart() {
|
||||
case statusNothingToDo:
|
||||
case statusWaiting:
|
||||
waiting++
|
||||
case statusReady:
|
||||
execCnt++
|
||||
m.inTransition.Set()
|
||||
|
||||
execM := m
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: execM,
|
||||
err: execM.runCtrlFnWithTimeout(
|
||||
"start module",
|
||||
60*time.Second,
|
||||
execM.start,
|
||||
),
|
||||
}
|
||||
}()
|
||||
m.start(reports)
|
||||
}
|
||||
}
|
||||
|
||||
// check for dep loop
|
||||
if execCnt == reportCnt {
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
rep.module.inTransition.UnSet()
|
||||
if rep.err != nil {
|
||||
return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
rep.module.Started.Set()
|
||||
log.Infof("modules: started %s", rep.module.Name)
|
||||
|
||||
// exit if done
|
||||
if reportCnt == len(modules) {
|
||||
if reportCnt < execCnt {
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
if rep.err != nil {
|
||||
return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
log.Infof("modules: started %s", rep.module.Name)
|
||||
} else {
|
||||
// finished
|
||||
if waiting > 0 {
|
||||
// check for dep loop
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
// return last error
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
171
modules/status.go
Normal file
171
modules/status.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package modules
|
||||
|
||||
// Module Status Values
|
||||
const (
|
||||
StatusDead uint8 = 0 // not prepared, not started
|
||||
StatusPreparing uint8 = 1
|
||||
StatusOffline uint8 = 2 // prepared, not started
|
||||
StatusStopping uint8 = 3
|
||||
StatusStarting uint8 = 4
|
||||
StatusOnline uint8 = 5 // online and running
|
||||
)
|
||||
|
||||
// Module Failure Status Values
|
||||
const (
|
||||
FailureNone uint8 = 0
|
||||
FailureHint uint8 = 1
|
||||
FailureWarning uint8 = 2
|
||||
FailureError uint8 = 3
|
||||
)
|
||||
|
||||
// ready status
|
||||
const (
|
||||
statusWaiting uint8 = iota
|
||||
statusReady
|
||||
statusNothingToDo
|
||||
)
|
||||
|
||||
// Online returns whether the module is online.
|
||||
func (m *Module) Online() bool {
|
||||
return m.Status() == StatusOnline
|
||||
}
|
||||
|
||||
// OnlineSoon returns whether the module is or is about to be online.
|
||||
func (m *Module) OnlineSoon() bool {
|
||||
if moduleMgmtEnabled.IsSet() &&
|
||||
!m.enabled.IsSet() &&
|
||||
!m.enabledAsDependency.IsSet() {
|
||||
return false
|
||||
}
|
||||
return !m.stopFlag.IsSet()
|
||||
}
|
||||
|
||||
// Status returns the current module status.
|
||||
func (m *Module) Status() uint8 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
return m.status
|
||||
}
|
||||
|
||||
// FailureStatus returns the current failure status, ID and message.
|
||||
func (m *Module) FailureStatus() (failureStatus uint8, failureID, failureMsg string) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
return m.failureStatus, m.failureID, m.failureMsg
|
||||
}
|
||||
|
||||
// Hint sets failure status to hint. This is a somewhat special failure status, as the module is believed to be working correctly, but there is an important module specific information to convey. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
|
||||
func (m *Module) Hint(failureID, failureMsg string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.failureStatus = FailureHint
|
||||
m.failureID = failureID
|
||||
m.failureMsg = failureMsg
|
||||
|
||||
m.notifyOfChange()
|
||||
}
|
||||
|
||||
// Warning sets failure status to warning. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
|
||||
func (m *Module) Warning(failureID, failureMsg string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.failureStatus = FailureWarning
|
||||
m.failureID = failureID
|
||||
m.failureMsg = failureMsg
|
||||
|
||||
m.notifyOfChange()
|
||||
}
|
||||
|
||||
// Error sets failure status to error. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
|
||||
func (m *Module) Error(failureID, failureMsg string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.failureStatus = FailureError
|
||||
m.failureID = failureID
|
||||
m.failureMsg = failureMsg
|
||||
|
||||
m.notifyOfChange()
|
||||
}
|
||||
|
||||
// Resolve removes the failure state from the module if the given failureID matches the current failure ID. If the given failureID is an empty string, Resolve removes any failure state.
|
||||
func (m *Module) Resolve(failureID string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if failureID == "" || failureID == m.failureID {
|
||||
m.failureStatus = FailureNone
|
||||
m.failureID = ""
|
||||
m.failureMsg = ""
|
||||
}
|
||||
|
||||
m.notifyOfChange()
|
||||
}
|
||||
|
||||
// readyToPrep returns whether all dependencies are ready for this module to prep.
|
||||
func (m *Module) readyToPrep() uint8 {
|
||||
// check if valid state for prepping
|
||||
if m.Status() != StatusDead {
|
||||
return statusNothingToDo
|
||||
}
|
||||
|
||||
for _, dep := range m.depModules {
|
||||
if dep.Status() < StatusOffline {
|
||||
return statusWaiting
|
||||
}
|
||||
}
|
||||
|
||||
return statusReady
|
||||
}
|
||||
|
||||
// readyToStart returns whether all dependencies are ready for this module to start.
|
||||
func (m *Module) readyToStart() uint8 {
|
||||
// check if start is wanted
|
||||
if moduleMgmtEnabled.IsSet() {
|
||||
if !m.enabled.IsSet() && !m.enabledAsDependency.IsSet() {
|
||||
return statusNothingToDo
|
||||
}
|
||||
}
|
||||
|
||||
// check if valid state for starting
|
||||
if m.Status() != StatusOffline {
|
||||
return statusNothingToDo
|
||||
}
|
||||
|
||||
// check if all dependencies are ready
|
||||
for _, dep := range m.depModules {
|
||||
if dep.Status() < StatusOnline {
|
||||
return statusWaiting
|
||||
}
|
||||
}
|
||||
|
||||
return statusReady
|
||||
}
|
||||
|
||||
// readyToStop returns whether all dependencies are ready for this module to stop.
|
||||
func (m *Module) readyToStop() uint8 {
|
||||
// check if stop is wanted
|
||||
if moduleMgmtEnabled.IsSet() && !shutdownFlag.IsSet() {
|
||||
if m.enabled.IsSet() || m.enabledAsDependency.IsSet() {
|
||||
return statusNothingToDo
|
||||
}
|
||||
}
|
||||
|
||||
// check if valid state for stopping
|
||||
if m.Status() != StatusOnline {
|
||||
return statusNothingToDo
|
||||
}
|
||||
|
||||
for _, revDep := range m.depReverse {
|
||||
// not ready if a reverse dependency was started, but not yet stopped
|
||||
if revDep.Status() > StatusOffline {
|
||||
return statusWaiting
|
||||
}
|
||||
}
|
||||
|
||||
return statusReady
|
||||
}
|
|
@ -10,12 +10,17 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
shutdownSignal = make(chan struct{})
|
||||
shutdownSignalClosed = abool.NewBool(false)
|
||||
shutdownSignal = make(chan struct{})
|
||||
shutdownFlag = abool.NewBool(false)
|
||||
|
||||
shutdownCompleteSignal = make(chan struct{})
|
||||
)
|
||||
|
||||
// IsShuttingDown returns whether the global shutdown is in progress.
|
||||
func IsShuttingDown() bool {
|
||||
return shutdownFlag.IsSet()
|
||||
}
|
||||
|
||||
// ShuttingDown returns a channel read on the global shutdown signal.
|
||||
func ShuttingDown() <-chan struct{} {
|
||||
return shutdownSignal
|
||||
|
@ -23,18 +28,19 @@ func ShuttingDown() <-chan struct{} {
|
|||
|
||||
// Shutdown stops all modules in the correct order.
|
||||
func Shutdown() error {
|
||||
// lock mgmt
|
||||
mgmtLock.Lock()
|
||||
defer mgmtLock.Unlock()
|
||||
|
||||
if shutdownSignalClosed.SetToIf(false, true) {
|
||||
if shutdownFlag.SetToIf(false, true) {
|
||||
close(shutdownSignal)
|
||||
} else {
|
||||
// shutdown was already issued
|
||||
return errors.New("shutdown already initiated")
|
||||
}
|
||||
|
||||
if startComplete.IsSet() {
|
||||
if initialStartCompleted.IsSet() {
|
||||
log.Warning("modules: starting shutdown...")
|
||||
modulesLock.Lock()
|
||||
defer modulesLock.Unlock()
|
||||
} else {
|
||||
log.Warning("modules: aborting, shutting down...")
|
||||
}
|
||||
|
@ -61,46 +67,42 @@ func stopModules() error {
|
|||
// get number of started modules
|
||||
startedCnt := 0
|
||||
for _, m := range modules {
|
||||
if m.Started.IsSet() {
|
||||
if m.Status() >= StatusStarting {
|
||||
startedCnt++
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
waiting := 0
|
||||
|
||||
// find modules to exec
|
||||
for _, m := range modules {
|
||||
if m.ReadyToStop() {
|
||||
switch m.readyToStop() {
|
||||
case statusNothingToDo:
|
||||
case statusWaiting:
|
||||
waiting++
|
||||
case statusReady:
|
||||
execCnt++
|
||||
m.inTransition.Set()
|
||||
|
||||
execM := m
|
||||
go func() {
|
||||
reports <- &report{
|
||||
module: execM,
|
||||
err: execM.shutdown(),
|
||||
}
|
||||
}()
|
||||
m.stop(reports)
|
||||
}
|
||||
}
|
||||
|
||||
// check for dep loop
|
||||
if execCnt == reportCnt {
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
rep.module.inTransition.UnSet()
|
||||
if rep.err != nil {
|
||||
lastErr = rep.err
|
||||
log.Warningf("modules: could not stop module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
rep.module.Stopped.Set()
|
||||
log.Infof("modules: stopped %s", rep.module.Name)
|
||||
|
||||
// exit if done
|
||||
if reportCnt == startedCnt {
|
||||
if reportCnt < execCnt {
|
||||
// wait for reports
|
||||
rep = <-reports
|
||||
if rep.err != nil {
|
||||
lastErr = rep.err
|
||||
log.Warningf("modules: could not stop module %s: %s", rep.module.Name, rep.err)
|
||||
}
|
||||
reportCnt++
|
||||
log.Infof("modules: stopped %s", rep.module.Name)
|
||||
} else {
|
||||
// finished
|
||||
if waiting > 0 {
|
||||
// check for dep loop
|
||||
return fmt.Errorf("modules: dependency loop detected, cannot continue")
|
||||
}
|
||||
// return last error
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
|
123
modules/subsystems/module.go
Normal file
123
modules/subsystems/module.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package subsystems
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/safing/portbase/database"
|
||||
_ "github.com/safing/portbase/database/dbmodule" // database module is required
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
const (
|
||||
configChangeEvent = "config change"
|
||||
subsystemsStatusChange = "status change"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
printGraphFlag bool
|
||||
|
||||
databaseKeySpace string
|
||||
db = database.NewInterface(nil)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// enable partial starting
|
||||
modules.EnableModuleManagement(handleModuleChanges)
|
||||
|
||||
// register module and enable it for starting
|
||||
module = modules.Register("subsystems", prep, start, nil, "config", "database", "base")
|
||||
module.Enable()
|
||||
|
||||
// register event for changes in the subsystem
|
||||
module.RegisterEvent(subsystemsStatusChange)
|
||||
|
||||
flag.BoolVar(&printGraphFlag, "print-subsystem-graph", false, "print the subsystem module dependency graph")
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
if printGraphFlag {
|
||||
printGraph()
|
||||
return modules.ErrCleanExit
|
||||
}
|
||||
|
||||
return module.RegisterEventHook("config", configChangeEvent, "control subsystems", handleConfigChanges)
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// lock registration
|
||||
subsystemsLocked.Set()
|
||||
|
||||
// lock slice and map
|
||||
subsystemsLock.Lock()
|
||||
// go through all dependencies
|
||||
seen := make(map[string]struct{})
|
||||
for _, sub := range subsystems {
|
||||
// mark subsystem module as seen
|
||||
seen[sub.module.Name] = struct{}{}
|
||||
}
|
||||
for _, sub := range subsystems {
|
||||
// add main module
|
||||
sub.Modules = append(sub.Modules, statusFromModule(sub.module))
|
||||
// add dependencies
|
||||
sub.addDependencies(sub.module, seen)
|
||||
}
|
||||
// unlock
|
||||
subsystemsLock.Unlock()
|
||||
|
||||
// apply config
|
||||
module.StartWorker("initial subsystem configuration", func(ctx context.Context) error {
|
||||
return handleConfigChanges(module.Ctx, nil)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sub *Subsystem) addDependencies(module *modules.Module, seen map[string]struct{}) {
|
||||
for _, module := range module.Dependencies() {
|
||||
_, ok := seen[module.Name]
|
||||
if !ok {
|
||||
// add dependency to modules
|
||||
sub.Modules = append(sub.Modules, statusFromModule(module))
|
||||
// mark as seen
|
||||
seen[module.Name] = struct{}{}
|
||||
// add further dependencies
|
||||
sub.addDependencies(module, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetDatabaseKeySpace sets a key space where subsystem status
|
||||
func SetDatabaseKeySpace(keySpace string) {
|
||||
if databaseKeySpace == "" {
|
||||
databaseKeySpace = keySpace
|
||||
|
||||
if !strings.HasSuffix(databaseKeySpace, "/") {
|
||||
databaseKeySpace += "/"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func printGraph() {
|
||||
// unmark subsystems module
|
||||
module.Disable()
|
||||
// mark roots
|
||||
for _, sub := range subsystems {
|
||||
sub.module.Enable() // mark as tree root
|
||||
}
|
||||
// print
|
||||
for _, sub := range subsystems {
|
||||
printModuleGraph("", sub.module, true)
|
||||
}
|
||||
}
|
||||
|
||||
func printModuleGraph(prefix string, module *modules.Module, root bool) {
|
||||
fmt.Printf("%s├── %s\n", prefix, module.Name)
|
||||
if root || !module.Enabled() {
|
||||
for _, dep := range module.Dependencies() {
|
||||
printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false)
|
||||
}
|
||||
}
|
||||
}
|
116
modules/subsystems/subsystem.go
Normal file
116
modules/subsystems/subsystem.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package subsystems
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/database/record"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
// Subsystem describes a subset of modules that represent a part of a service or program to the user.
|
||||
type Subsystem struct { //nolint:maligned // not worth the effort
|
||||
record.Base
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
Name string
|
||||
Description string
|
||||
module *modules.Module
|
||||
|
||||
Modules []*ModuleStatus
|
||||
FailureStatus uint8 // summary: worst status
|
||||
|
||||
ToggleOptionKey string
|
||||
toggleOption *config.Option
|
||||
toggleValue func() bool
|
||||
ExpertiseLevel uint8 // copied from toggleOption
|
||||
ReleaseLevel uint8 // copied from toggleOption
|
||||
|
||||
ConfigKeySpace string
|
||||
}
|
||||
|
||||
// ModuleStatus describes the status of a module.
|
||||
type ModuleStatus struct {
|
||||
Name string
|
||||
module *modules.Module
|
||||
|
||||
// status mgmt
|
||||
Enabled bool
|
||||
Status uint8
|
||||
|
||||
// failure status
|
||||
FailureStatus uint8
|
||||
FailureID string
|
||||
FailureMsg string
|
||||
}
|
||||
|
||||
// Save saves the Subsystem Status to the database.
|
||||
func (sub *Subsystem) Save() {
|
||||
if databaseKeySpace != "" {
|
||||
if !sub.KeyIsSet() {
|
||||
sub.SetKey(databaseKeySpace + sub.ID)
|
||||
}
|
||||
err := db.Put(sub)
|
||||
if err != nil {
|
||||
log.Errorf("subsystems: could not save subsystem status to database: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func statusFromModule(module *modules.Module) *ModuleStatus {
|
||||
status := &ModuleStatus{
|
||||
Name: module.Name,
|
||||
module: module,
|
||||
Enabled: module.Enabled(),
|
||||
Status: module.Status(),
|
||||
}
|
||||
status.FailureStatus, status.FailureID, status.FailureMsg = module.FailureStatus()
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
func compareAndUpdateStatus(module *modules.Module, status *ModuleStatus) (changed bool) {
|
||||
// check if enabled
|
||||
enabled := module.Enabled()
|
||||
if status.Enabled != enabled {
|
||||
status.Enabled = enabled
|
||||
changed = true
|
||||
}
|
||||
|
||||
// check status
|
||||
statusLvl := module.Status()
|
||||
if status.Status != statusLvl {
|
||||
status.Status = statusLvl
|
||||
changed = true
|
||||
}
|
||||
|
||||
// check failure status
|
||||
failureStatus, failureID, failureMsg := module.FailureStatus()
|
||||
if status.FailureStatus != failureStatus ||
|
||||
status.FailureID != failureID {
|
||||
status.FailureStatus = failureStatus
|
||||
status.FailureID = failureID
|
||||
status.FailureMsg = failureMsg
|
||||
changed = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (sub *Subsystem) makeSummary() {
|
||||
// find worst failing module
|
||||
worstFailing := &ModuleStatus{}
|
||||
for _, depStatus := range sub.Modules {
|
||||
if depStatus.FailureStatus > worstFailing.FailureStatus {
|
||||
worstFailing = depStatus
|
||||
}
|
||||
}
|
||||
|
||||
if worstFailing != nil {
|
||||
sub.FailureStatus = worstFailing.FailureStatus
|
||||
} else {
|
||||
sub.FailureStatus = 0
|
||||
}
|
||||
}
|
161
modules/subsystems/subsystems.go
Normal file
161
modules/subsystems/subsystems.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package subsystems
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
var (
|
||||
subsystems []*Subsystem
|
||||
subsystemsMap = make(map[string]*Subsystem)
|
||||
subsystemsLock sync.Mutex
|
||||
subsystemsLocked = abool.New()
|
||||
|
||||
handlingConfigChanges = abool.New()
|
||||
)
|
||||
|
||||
// Register registers a new subsystem. The given option must be a bool option. Should be called in init() directly after the modules.Register() function. The config option must not yet be registered and will be registered for you. Pass a nil option to force enable.
|
||||
func Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) {
|
||||
// lock slice and map
|
||||
subsystemsLock.Lock()
|
||||
defer subsystemsLock.Unlock()
|
||||
|
||||
// check if registration is closed
|
||||
if subsystemsLocked.IsSet() {
|
||||
panic("subsystems can only be registered in prep phase or earlier")
|
||||
}
|
||||
|
||||
// check if already registered
|
||||
_, ok := subsystemsMap[name]
|
||||
if ok {
|
||||
panic(fmt.Sprintf(`subsystem "%s" already registered`, name))
|
||||
}
|
||||
|
||||
// create new
|
||||
new := &Subsystem{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Description: description,
|
||||
module: module,
|
||||
toggleOption: option,
|
||||
ConfigKeySpace: configKeySpace,
|
||||
}
|
||||
if new.toggleOption != nil {
|
||||
new.ToggleOptionKey = new.toggleOption.Key
|
||||
new.ExpertiseLevel = new.toggleOption.ExpertiseLevel
|
||||
new.ReleaseLevel = new.toggleOption.ReleaseLevel
|
||||
}
|
||||
|
||||
// register config
|
||||
if option != nil {
|
||||
err := config.Register(option)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to register config: %s", err))
|
||||
}
|
||||
new.toggleValue = config.GetAsBool(new.ToggleOptionKey, false)
|
||||
} else {
|
||||
// force enabled
|
||||
new.toggleValue = func() bool { return true }
|
||||
}
|
||||
|
||||
// add to lists
|
||||
subsystemsMap[name] = new
|
||||
subsystems = append(subsystems, new)
|
||||
}
|
||||
|
||||
func handleModuleChanges(m *modules.Module) {
|
||||
// check if ready
|
||||
if !subsystemsLocked.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
// check if shutting down
|
||||
if modules.IsShuttingDown() {
|
||||
return
|
||||
}
|
||||
|
||||
// find module status
|
||||
var moduleSubsystem *Subsystem
|
||||
var moduleStatus *ModuleStatus
|
||||
subsystemLoop:
|
||||
for _, subsystem := range subsystems {
|
||||
for _, status := range subsystem.Modules {
|
||||
if m.Name == status.Name {
|
||||
moduleSubsystem = subsystem
|
||||
moduleStatus = status
|
||||
break subsystemLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
// abort if not found
|
||||
if moduleSubsystem == nil || moduleStatus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// update status
|
||||
moduleSubsystem.Lock()
|
||||
changed := compareAndUpdateStatus(m, moduleStatus)
|
||||
if changed {
|
||||
moduleSubsystem.makeSummary()
|
||||
}
|
||||
moduleSubsystem.Unlock()
|
||||
|
||||
// save
|
||||
if changed {
|
||||
moduleSubsystem.Save()
|
||||
}
|
||||
}
|
||||
|
||||
func handleConfigChanges(ctx context.Context, data interface{}) error {
|
||||
// check if ready
|
||||
if !subsystemsLocked.IsSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// potentially catch multiple changes
|
||||
if handlingConfigChanges.SetToIf(false, true) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
handlingConfigChanges.UnSet()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// don't do anything if we are already shutting down globally
|
||||
if modules.IsShuttingDown() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only run one instance at any time
|
||||
subsystemsLock.Lock()
|
||||
defer subsystemsLock.Unlock()
|
||||
|
||||
var changed bool
|
||||
for _, subsystem := range subsystems {
|
||||
if subsystem.module.SetEnabled(subsystem.toggleValue()) {
|
||||
// if changed
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
// trigger module management if any setting was changed
|
||||
if changed {
|
||||
err := modules.ManageModules()
|
||||
if err != nil {
|
||||
module.Error(
|
||||
"modulemgmt-failed",
|
||||
fmt.Sprintf("The subsystem framework failed to start or stop one or more modules.\nError: %s\nCheck logs for more information.", err),
|
||||
)
|
||||
} else {
|
||||
module.Resolve("modulemgmt-failed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
123
modules/subsystems/subsystems_test.go
Normal file
123
modules/subsystems/subsystems_test.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package subsystems
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
func TestSubsystems(t *testing.T) {
|
||||
// tmp dir for data root (db & config)
|
||||
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
|
||||
// initialize data dir
|
||||
if err == nil {
|
||||
err = dataroot.Initialize(tmpDir, 0755)
|
||||
}
|
||||
// handle setup error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// register
|
||||
|
||||
baseModule := modules.Register("base", nil, nil, nil)
|
||||
Register(
|
||||
"base",
|
||||
"Base",
|
||||
"Framework Groundwork",
|
||||
baseModule,
|
||||
"config:base",
|
||||
nil,
|
||||
)
|
||||
|
||||
feature1 := modules.Register("feature1", nil, nil, nil)
|
||||
Register(
|
||||
"feature-one",
|
||||
"Feature One",
|
||||
"Provides feature one",
|
||||
feature1,
|
||||
"config:feature1",
|
||||
&config.Option{
|
||||
Name: "Enable Feature One",
|
||||
Key: "config:subsystems/feature1",
|
||||
Description: "This option enables feature 1",
|
||||
OptType: config.OptTypeBool,
|
||||
DefaultValue: false,
|
||||
},
|
||||
)
|
||||
sub1 := subsystemsMap["Feature One"]
|
||||
|
||||
feature2 := modules.Register("feature2", nil, nil, nil)
|
||||
Register(
|
||||
"feature-two",
|
||||
"Feature Two",
|
||||
"Provides feature two",
|
||||
feature2,
|
||||
"config:feature2",
|
||||
&config.Option{
|
||||
Name: "Enable Feature One",
|
||||
Key: "config:subsystems/feature2",
|
||||
Description: "This option enables feature 2",
|
||||
OptType: config.OptTypeBool,
|
||||
DefaultValue: false,
|
||||
},
|
||||
)
|
||||
|
||||
// start
|
||||
err = modules.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test
|
||||
|
||||
// let module fail
|
||||
feature1.Error("test-fail", "Testing Fail")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if sub1.FailureStatus != modules.FailureError {
|
||||
t.Fatal("error did not propagate")
|
||||
}
|
||||
|
||||
// resolve
|
||||
feature1.Resolve("test-fail")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if sub1.FailureStatus != modules.FailureNone {
|
||||
t.Fatal("error resolving did not propagate")
|
||||
}
|
||||
|
||||
// update settings
|
||||
err = config.SetConfigOption("config:subsystems/feature2", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if !feature2.Enabled() {
|
||||
t.Fatal("failed to enable feature2")
|
||||
}
|
||||
if feature2.Status() != modules.StatusOnline {
|
||||
t.Fatal("feature2 did not start")
|
||||
}
|
||||
|
||||
// update settings
|
||||
err = config.SetConfigOption("config:subsystems/feature2", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if feature2.Enabled() {
|
||||
t.Fatal("failed to disable feature2")
|
||||
}
|
||||
if feature2.Status() != modules.StatusOffline {
|
||||
t.Fatal("feature2 did not stop")
|
||||
}
|
||||
|
||||
// clean up and exit
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
159
modules/tasks.go
159
modules/tasks.go
|
@ -8,21 +8,24 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
// Task is managed task bound to a module.
|
||||
type Task struct {
|
||||
name string
|
||||
module *Module
|
||||
taskFn func(context.Context, *Task)
|
||||
taskFn func(context.Context, *Task) error
|
||||
|
||||
queued bool
|
||||
canceled bool
|
||||
executing bool
|
||||
cancelFunc func()
|
||||
queued bool
|
||||
canceled bool
|
||||
executing bool
|
||||
|
||||
// these are populated at task creation
|
||||
// ctx is canceled when task is shutdown -> all tasks become canceled
|
||||
ctx context.Context
|
||||
cancelCtx func()
|
||||
|
||||
executeAt time.Time
|
||||
repeat time.Duration
|
||||
|
@ -59,25 +62,46 @@ const (
|
|||
)
|
||||
|
||||
// NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
|
||||
func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
|
||||
func (m *Module) NewTask(name string, fn func(context.Context, *Task) error) *Task {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if m == nil {
|
||||
log.Errorf(`modules: cannot create task "%s" with nil module`, name)
|
||||
return &Task{
|
||||
name: name,
|
||||
module: &Module{Name: "[NONE]"},
|
||||
canceled: true,
|
||||
}
|
||||
}
|
||||
if m.Ctx == nil || !m.OnlineSoon() {
|
||||
log.Errorf(`modules: tasks should only be started when the module is online or starting`)
|
||||
return &Task{
|
||||
name: name,
|
||||
module: m,
|
||||
canceled: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &Task{
|
||||
// create new task
|
||||
new := &Task{
|
||||
name: name,
|
||||
module: m,
|
||||
taskFn: fn,
|
||||
maxDelay: defaultMaxDelay,
|
||||
}
|
||||
|
||||
// create context
|
||||
new.ctx, new.cancelCtx = context.WithCancel(m.Ctx)
|
||||
|
||||
return new
|
||||
}
|
||||
|
||||
func (t *Task) isActive() bool {
|
||||
if t.module == nil {
|
||||
if t.canceled {
|
||||
return false
|
||||
}
|
||||
|
||||
return !t.canceled && !t.module.ShutdownInProgress()
|
||||
return t.module.OnlineSoon()
|
||||
}
|
||||
|
||||
func (t *Task) prepForQueueing() (ok bool) {
|
||||
|
@ -197,45 +221,15 @@ func (t *Task) Repeat(interval time.Duration) *Task {
|
|||
func (t *Task) Cancel() {
|
||||
t.lock.Lock()
|
||||
t.canceled = true
|
||||
if t.cancelFunc != nil {
|
||||
t.cancelFunc()
|
||||
if t.cancelCtx != nil {
|
||||
t.cancelCtx()
|
||||
}
|
||||
t.lock.Unlock()
|
||||
}
|
||||
|
||||
func (t *Task) runWithLocking() {
|
||||
if t.module == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for good timeslot regarding microtasks
|
||||
select {
|
||||
case <-taskTimeslot:
|
||||
case <-time.After(maxTimeslotWait):
|
||||
}
|
||||
|
||||
t.lock.Lock()
|
||||
|
||||
// check state, return if already executing or inactive
|
||||
if t.executing || !t.isActive() {
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
t.executing = true
|
||||
|
||||
// get list elements
|
||||
queueElement := t.queueElement
|
||||
prioritizedQueueElement := t.prioritizedQueueElement
|
||||
scheduleListElement := t.scheduleListElement
|
||||
|
||||
// create context
|
||||
var taskCtx context.Context
|
||||
taskCtx, t.cancelFunc = context.WithCancel(t.module.Ctx)
|
||||
|
||||
t.lock.Unlock()
|
||||
|
||||
func (t *Task) removeFromQueues() {
|
||||
// remove from lists
|
||||
if queueElement != nil {
|
||||
if t.queueElement != nil {
|
||||
queuesLock.Lock()
|
||||
taskQueue.Remove(t.queueElement)
|
||||
queuesLock.Unlock()
|
||||
|
@ -243,7 +237,7 @@ func (t *Task) runWithLocking() {
|
|||
t.queueElement = nil
|
||||
t.lock.Unlock()
|
||||
}
|
||||
if prioritizedQueueElement != nil {
|
||||
if t.prioritizedQueueElement != nil {
|
||||
queuesLock.Lock()
|
||||
prioritizedTaskQueue.Remove(t.prioritizedQueueElement)
|
||||
queuesLock.Unlock()
|
||||
|
@ -251,7 +245,7 @@ func (t *Task) runWithLocking() {
|
|||
t.prioritizedQueueElement = nil
|
||||
t.lock.Unlock()
|
||||
}
|
||||
if scheduleListElement != nil {
|
||||
if t.scheduleListElement != nil {
|
||||
scheduleLock.Lock()
|
||||
taskSchedule.Remove(t.scheduleListElement)
|
||||
scheduleLock.Unlock()
|
||||
|
@ -259,14 +253,62 @@ func (t *Task) runWithLocking() {
|
|||
t.scheduleListElement = nil
|
||||
t.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task) runWithLocking() {
|
||||
t.lock.Lock()
|
||||
|
||||
// check if task is already executing
|
||||
if t.executing {
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// check if task is active
|
||||
if !t.isActive() {
|
||||
t.removeFromQueues()
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// check if module was stopped
|
||||
select {
|
||||
case <-t.ctx.Done(): // check if module is stopped
|
||||
t.removeFromQueues()
|
||||
t.lock.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
t.executing = true
|
||||
t.lock.Unlock()
|
||||
|
||||
// wait for good timeslot regarding microtasks
|
||||
select {
|
||||
case <-taskTimeslot:
|
||||
case <-time.After(maxTimeslotWait):
|
||||
}
|
||||
|
||||
// wait for module start
|
||||
if !t.module.Online() {
|
||||
if t.module.OnlineSoon() {
|
||||
// wait
|
||||
<-t.module.StartCompleted()
|
||||
} else {
|
||||
t.lock.Lock()
|
||||
t.removeFromQueues()
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// add to queue workgroup
|
||||
queueWg.Add(1)
|
||||
|
||||
go t.executeWithLocking(taskCtx, t.cancelFunc)
|
||||
go t.executeWithLocking()
|
||||
go func() {
|
||||
select {
|
||||
case <-taskCtx.Done():
|
||||
case <-t.ctx.Done():
|
||||
case <-time.After(maxExecutionWait):
|
||||
}
|
||||
// complete queue worker (early) to allow next worker
|
||||
|
@ -274,7 +316,7 @@ func (t *Task) runWithLocking() {
|
|||
}()
|
||||
}
|
||||
|
||||
func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
|
||||
func (t *Task) executeWithLocking() {
|
||||
// start for module
|
||||
// hint: only queueWg global var is important for scheduling, others can be set here
|
||||
atomic.AddInt32(t.module.taskCnt, 1)
|
||||
|
@ -306,11 +348,16 @@ func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
|
|||
t.lock.Unlock()
|
||||
|
||||
// notify that we finished
|
||||
cancelFunc()
|
||||
if t.cancelCtx != nil {
|
||||
t.cancelCtx()
|
||||
}
|
||||
}()
|
||||
|
||||
// run
|
||||
t.taskFn(ctx, t)
|
||||
err := t.taskFn(t.ctx, t)
|
||||
if err != nil {
|
||||
log.Errorf("%s: task %s failed: %s", t.module.Name, t.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task) getExecuteAtWithLocking() time.Time {
|
||||
|
@ -320,6 +367,10 @@ func (t *Task) getExecuteAtWithLocking() time.Time {
|
|||
}
|
||||
|
||||
func (t *Task) addToSchedule() {
|
||||
if !t.isActive() {
|
||||
return
|
||||
}
|
||||
|
||||
scheduleLock.Lock()
|
||||
defer scheduleLock.Unlock()
|
||||
// defer printTaskList(taskSchedule) // for debugging
|
||||
|
@ -395,7 +446,7 @@ func taskQueueHandler() {
|
|||
queueWg.Wait()
|
||||
|
||||
// check for shutdown
|
||||
if shutdownSignalClosed.IsSet() {
|
||||
if shutdownFlag.IsSet() {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -35,22 +35,29 @@ func init() {
|
|||
var qtWg sync.WaitGroup
|
||||
var qtOutputChannel chan string
|
||||
var qtSleepDuration time.Duration
|
||||
var qtModule = initNewModule("task test module", nil, nil, nil)
|
||||
var qtModule *Module
|
||||
|
||||
func init() {
|
||||
qtModule = initNewModule("task test module", nil, nil, nil)
|
||||
qtModule.status = StatusOnline
|
||||
}
|
||||
|
||||
// functions
|
||||
func queuedTaskTester(s string) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
|
||||
time.Sleep(qtSleepDuration * 2)
|
||||
qtOutputChannel <- s
|
||||
qtWg.Done()
|
||||
return nil
|
||||
}).Queue()
|
||||
}
|
||||
|
||||
func prioritizedTaskTester(s string) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
|
||||
time.Sleep(qtSleepDuration * 2)
|
||||
qtOutputChannel <- s
|
||||
qtWg.Done()
|
||||
return nil
|
||||
}).Prioritize()
|
||||
}
|
||||
|
||||
|
@ -109,10 +116,11 @@ var stWaitCh chan bool
|
|||
|
||||
// functions
|
||||
func scheduledTaskTester(s string, sched time.Time) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
|
||||
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
|
||||
time.Sleep(stSleepDuration)
|
||||
stOutputChannel <- s
|
||||
stWg.Done()
|
||||
return nil
|
||||
}).Schedule(sched)
|
||||
}
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn
|
|||
lastFail := time.Now()
|
||||
|
||||
for {
|
||||
if m.ShutdownInProgress() {
|
||||
if m.IsStopping() {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -111,26 +111,26 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
|
|||
}
|
||||
|
||||
// Put stores a record in the database.
|
||||
func (s *StorageInterface) Put(r record.Record) error {
|
||||
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
|
||||
// record is already locked!
|
||||
key := r.DatabaseKey()
|
||||
n, err := EnsureNotification(r)
|
||||
|
||||
if err != nil {
|
||||
return ErrInvalidData
|
||||
return nil, ErrInvalidData
|
||||
}
|
||||
|
||||
// transform key
|
||||
if strings.HasPrefix(key, "all/") {
|
||||
key = strings.TrimPrefix(key, "all/")
|
||||
} else {
|
||||
return ErrInvalidPath
|
||||
return nil, ErrInvalidPath
|
||||
}
|
||||
|
||||
// continue in goroutine
|
||||
go UpdateNotification(n, key)
|
||||
|
||||
return nil
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change.
|
||||
|
|
|
@ -11,7 +11,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
module = modules.Register("notifications", nil, start, nil, "base", "database")
|
||||
module = modules.Register("notifications", nil, start, nil, "database", "base")
|
||||
}
|
||||
|
||||
func start() error {
|
||||
|
|
|
@ -93,6 +93,7 @@ func (n *Notification) Save() *Notification {
|
|||
nots[n.ID] = n
|
||||
|
||||
// push update
|
||||
log.Tracef("notifications: pushing update for %s to subscribers", n.Key())
|
||||
dbController.PushUpdate(n)
|
||||
|
||||
// persist
|
||||
|
@ -152,10 +153,11 @@ func (n *Notification) MakeAck() *Notification {
|
|||
// Response waits for the user to respond to the notification and returns the selected action.
|
||||
func (n *Notification) Response() <-chan string {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
|
||||
if n.actionTrigger == nil {
|
||||
n.actionTrigger = make(chan string)
|
||||
}
|
||||
n.lock.Unlock()
|
||||
|
||||
return n.actionTrigger
|
||||
}
|
||||
|
@ -213,10 +215,11 @@ func (n *Notification) Delete() error {
|
|||
// Expired notifies the caller when the notification has expired.
|
||||
func (n *Notification) Expired() <-chan struct{} {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
|
||||
if n.expiredTrigger == nil {
|
||||
n.expiredTrigger = make(chan struct{})
|
||||
}
|
||||
n.lock.Unlock()
|
||||
|
||||
return n.expiredTrigger
|
||||
}
|
||||
|
|
40
portbase.go
40
portbase.go
|
@ -1,49 +1,19 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/safing/portbase/info"
|
||||
"github.com/safing/portbase/log"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/run"
|
||||
|
||||
// include packages here
|
||||
_ "github.com/safing/portbase/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// Set Info
|
||||
info.Set("Portbase", "0.0.1", "GPLv3", false)
|
||||
|
||||
// Start
|
||||
err := modules.Start()
|
||||
if err != nil {
|
||||
if err == modules.ErrCleanExit {
|
||||
os.Exit(0)
|
||||
} else {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown
|
||||
// catch interrupt for clean shutdown
|
||||
signalCh := make(chan os.Signal, 3)
|
||||
signal.Notify(
|
||||
signalCh,
|
||||
os.Interrupt,
|
||||
syscall.SIGHUP,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT,
|
||||
)
|
||||
select {
|
||||
case <-signalCh:
|
||||
fmt.Println(" <INTERRUPT>")
|
||||
log.Warning("main: program was interrupted, shutting down.")
|
||||
_ = modules.Shutdown()
|
||||
case <-modules.ShuttingDown():
|
||||
}
|
||||
|
||||
// Run
|
||||
os.Exit(run.Run())
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
modules.Register("random", prep, Start, nil, "base")
|
||||
modules.Register("random", prep, Start, nil)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
|
|
116
template/module.go
Normal file
116
template/module.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package template
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/safing/portbase/config"
|
||||
"github.com/safing/portbase/modules"
|
||||
"github.com/safing/portbase/modules/subsystems"
|
||||
)
|
||||
|
||||
const (
|
||||
eventStateUpdate = "state update"
|
||||
)
|
||||
|
||||
var (
|
||||
module *modules.Module
|
||||
)
|
||||
|
||||
func init() {
|
||||
// register base module, for database initialization
|
||||
modules.Register("base", nil, nil, nil)
|
||||
|
||||
// register module
|
||||
module = modules.Register("template", prep, start, stop) // add dependencies...
|
||||
subsystems.Register(
|
||||
"template-subsystem", // ID
|
||||
"Template Subsystem", // name
|
||||
"This subsystem is a template for quick setup", // description
|
||||
module,
|
||||
"config:template", // key space for configuration options registered
|
||||
&config.Option{
|
||||
Name: "Enable Template Subsystem",
|
||||
Key: "config:subsystems/template",
|
||||
Description: "This option enables the Template Subsystem [TEMPLATE]",
|
||||
OptType: config.OptTypeBool,
|
||||
DefaultValue: false,
|
||||
},
|
||||
)
|
||||
|
||||
// register events that other modules can subscribe to
|
||||
module.RegisterEvent(eventStateUpdate)
|
||||
}
|
||||
|
||||
func prep() error {
|
||||
// register options
|
||||
err := config.Register(&config.Option{
|
||||
Name: "language",
|
||||
Key: "config:template/language",
|
||||
Description: "Sets the language for the template [TEMPLATE]",
|
||||
OptType: config.OptTypeString,
|
||||
ExpertiseLevel: config.ExpertiseLevelUser, // default
|
||||
ReleaseLevel: config.ReleaseLevelStable, // default
|
||||
RequiresRestart: false, // default
|
||||
DefaultValue: "en",
|
||||
ValidationRegex: "^[a-z]{2}$",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// register event hooks
|
||||
// do this in prep() and not in start(), as we don't want to register again if module is turned off and on again
|
||||
err = module.RegisterEventHook(
|
||||
"template", // event source module name
|
||||
"state update", // event source name
|
||||
"react to state changes", // description of hook function
|
||||
eventHandler, // hook function
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// hint: event hooks and tasks will not be run if module isn't online
|
||||
return nil
|
||||
}
|
||||
|
||||
func start() error {
|
||||
// register tasks
|
||||
module.NewTask("do something", taskFn).Queue()
|
||||
|
||||
// start service worker
|
||||
module.StartServiceWorker("do something", 0, serviceWorker)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func serviceWorker(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
err := do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func taskFn(ctx context.Context, task *modules.Task) error {
|
||||
return do()
|
||||
}
|
||||
|
||||
func eventHandler(ctx context.Context, data interface{}) error {
|
||||
return do()
|
||||
}
|
||||
|
||||
func do() error {
|
||||
return nil
|
||||
}
|
51
template/module_test.go
Normal file
51
template/module_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package template
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/safing/portbase/dataroot"
|
||||
"github.com/safing/portbase/modules"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// enable module for testing
|
||||
module.Enable()
|
||||
|
||||
// tmp dir for data root (db & config)
|
||||
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// initialize data dir
|
||||
err = dataroot.Initialize(tmpDir, 0755)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// start modules
|
||||
var exitCode int
|
||||
err = modules.Start()
|
||||
if err != nil {
|
||||
// starting failed
|
||||
fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err)
|
||||
exitCode = 1
|
||||
} else {
|
||||
// run tests
|
||||
exitCode = m.Run()
|
||||
}
|
||||
|
||||
// shutdown
|
||||
_ = modules.Shutdown()
|
||||
if modules.GetExitStatusCode() != 0 {
|
||||
exitCode = modules.GetExitStatusCode()
|
||||
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
|
||||
}
|
||||
// clean up and exit
|
||||
os.RemoveAll(tmpDir)
|
||||
os.Exit(exitCode)
|
||||
}
|
|
@ -50,6 +50,10 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status)
|
||||
}
|
||||
|
||||
// download and write file
|
||||
n, err := io.Copy(atomicFile, resp.Body)
|
||||
if err != nil {
|
||||
|
@ -96,6 +100,10 @@ func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte,
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status)
|
||||
}
|
||||
|
||||
// download and write file
|
||||
buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength))
|
||||
n, err := io.Copy(buf, resp.Body)
|
||||
|
|
Loading…
Add table
Reference in a new issue