diff --git a/.golangci.yml b/.golangci.yml index 3c2d6b3..de08de7 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,5 +7,11 @@ linters: - funlen - whitespace - wsl - - godox + - gomnd +linters-settings: + godox: + # report any comments starting with keywords, this is useful for TODO or FIXME comments that + # might be left in the code accidentally and should be resolved before merging + keywords: + - FIXME diff --git a/Gopkg.lock b/Gopkg.lock index 11cdb8e..164b8c6 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -161,14 +161,6 @@ revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" version = "v0.8.1" -[[projects]] - branch = "develop" - digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93" - name = "github.com/safing/portmaster" - packages = ["core/structure"] - pruneopts = "UT" - revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6" - [[projects]] digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925" name = "github.com/satori/go.uuid" @@ -332,7 +324,6 @@ "github.com/gorilla/mux", "github.com/gorilla/websocket", "github.com/hashicorp/go-version", - "github.com/safing/portmaster/core/structure", "github.com/satori/go.uuid", "github.com/seehuhn/fortuna", "github.com/shirou/gopsutil/host", diff --git a/api/database.go b/api/database.go index 106a8d9..8a365c9 100644 --- a/api/database.go +++ b/api/database.go @@ -243,8 +243,8 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) { if err == nil { data, err = r.Marshal(r, record.JSON) } - if err == nil { - api.send(opID, dbMsgTypeError, err.Error(), nil) //nolint:nilness // FIXME: possibly false positive (golangci-lint govet/nilness) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) return } api.send(opID, dbMsgTypeOk, r.Key(), data) @@ -384,9 +384,9 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { default: api.send(opID, dbMsgTypeUpd, r.Key(), data) } - } else if sub.Err != nil { + } else { // sub feed ended - api.send(opID, dbMsgTypeError, sub.Err.Error(), nil) + api.send(opID, dbMsgTypeDone, "", nil) } } } @@ -435,13 +435,13 @@ func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create b return } - // FIXME: remove transition code - if data[0] != record.JSON { - typedData := make([]byte, len(data)+1) - typedData[0] = record.JSON - copy(typedData[1:], data) - data = typedData - } + // TODO - staged for deletion: remove transition code + // if data[0] != record.JSON { + // typedData := make([]byte, len(data)+1) + // typedData[0] = record.JSON + // copy(typedData[1:], data) + // data = typedData + // } r, err := record.NewWrapper(key, nil, data[0], data[1:]) if err != nil { diff --git a/api/main.go b/api/main.go index f0ea824..9d015f2 100644 --- a/api/main.go +++ b/api/main.go @@ -17,7 +17,7 @@ var ( ) func init() { - module = modules.Register("api", prep, start, stop, "base", "database", "config") + module = modules.Register("api", prep, start, stop, "database", "config") } func prep() error { diff --git a/config/database.go b/config/database.go index 460aa01..109e575 100644 --- a/config/database.go +++ b/config/database.go @@ -36,26 +36,30 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (s *StorageInterface) Put(r record.Record) error { +func (s *StorageInterface) Put(r record.Record) (record.Record, error) { if r.Meta().Deleted > 0 { - return setConfigOption(r.DatabaseKey(), nil, false) + return r, setConfigOption(r.DatabaseKey(), nil, false) } acc := r.GetAccessor(r) if acc == nil { - return errors.New("invalid data") + return nil, errors.New("invalid data") } val, ok := acc.Get("Value") if !ok || val == nil { - return setConfigOption(r.DatabaseKey(), nil, false) + err := setConfigOption(r.DatabaseKey(), nil, false) + if err != nil { + return nil, err + } + return s.Get(r.DatabaseKey()) } optionsLock.RLock() option, ok := options[r.DatabaseKey()] optionsLock.RUnlock() if !ok { - return errors.New("config option does not exist") + return nil, errors.New("config option does not exist") } var value interface{} @@ -70,14 +74,14 @@ func (s *StorageInterface) Put(r record.Record) error { value, ok = acc.GetBool("Value") } if !ok { - return errors.New("received invalid value in \"Value\"") + return nil, errors.New("received invalid value in \"Value\"") } err := setConfigOption(r.DatabaseKey(), value, false) if err != nil { - return err + return nil, err } - return nil + return option.Export() } // Delete deletes a record from the database. diff --git a/config/expertise.go b/config/expertise.go index 9f515d4..277758b 100644 --- a/config/expertise.go +++ b/config/expertise.go @@ -5,6 +5,8 @@ package config import ( "fmt" "sync/atomic" + + "github.com/tevino/abool" ) // Expertise Level constants @@ -21,7 +23,9 @@ const ( ) var ( - expertiseLevel *int32 + expertiseLevel *int32 + expertiseLevelOption *Option + expertiseLevelOptionFlag = abool.New() ) func init() { @@ -32,7 +36,7 @@ func init() { } func registerExpertiseLevelOption() { - err := Register(&Option{ + expertiseLevelOption = &Option{ Name: "Expertise Level", Key: expertiseLevelKey, Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)", @@ -46,15 +50,31 @@ func registerExpertiseLevelOption() { ExternalOptType: "string list", ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper), - }) + } + + err := Register(expertiseLevelOption) if err != nil { panic(err) } + + expertiseLevelOptionFlag.Set() } func updateExpertiseLevel() { - new := findStringValue(expertiseLevelKey, "") - switch new { + // check if already registered + if !expertiseLevelOptionFlag.IsSet() { + return + } + // get value + value := expertiseLevelOption.activeFallbackValue + if expertiseLevelOption.activeValue != nil { + value = expertiseLevelOption.activeValue + } + if expertiseLevelOption.activeDefaultValue != nil { + value = expertiseLevelOption.activeDefaultValue + } + // set atomic value + switch value.stringVal { case ExpertiseLevelNameUser: atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser)) case ExpertiseLevelNameExpert: diff --git a/config/get-safe.go b/config/get-safe.go index a8eaff8..d48e94d 100644 --- a/config/get-safe.go +++ b/config/get-safe.go @@ -12,14 +12,24 @@ var ( // GetAsString returns a function that returns the wanted string with high performance. func (cs *safe) GetAsString(name string, fallback string) StringOption { valid := getValidityFlag() - value := findStringValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeString) + value := fallback + if valueCache != nil { + value = valueCache.stringVal + } var lock sync.Mutex + return func() string { lock.Lock() defer lock.Unlock() if !valid.IsSet() { valid = getValidityFlag() - value = findStringValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeString) + if valueCache != nil { + value = valueCache.stringVal + } else { + value = fallback + } } return value } @@ -28,14 +38,24 @@ func (cs *safe) GetAsString(name string, fallback string) StringOption { // GetAsStringArray returns a function that returns the wanted string with high performance. func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOption { valid := getValidityFlag() - value := findStringArrayValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeStringArray) + value := fallback + if valueCache != nil { + value = valueCache.stringArrayVal + } var lock sync.Mutex + return func() []string { lock.Lock() defer lock.Unlock() if !valid.IsSet() { valid = getValidityFlag() - value = findStringArrayValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeStringArray) + if valueCache != nil { + value = valueCache.stringArrayVal + } else { + value = fallback + } } return value } @@ -44,14 +64,24 @@ func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOpti // GetAsInt returns a function that returns the wanted int with high performance. func (cs *safe) GetAsInt(name string, fallback int64) IntOption { valid := getValidityFlag() - value := findIntValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeInt) + value := fallback + if valueCache != nil { + value = valueCache.intVal + } var lock sync.Mutex + return func() int64 { lock.Lock() defer lock.Unlock() if !valid.IsSet() { valid = getValidityFlag() - value = findIntValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeInt) + if valueCache != nil { + value = valueCache.intVal + } else { + value = fallback + } } return value } @@ -60,14 +90,24 @@ func (cs *safe) GetAsInt(name string, fallback int64) IntOption { // GetAsBool returns a function that returns the wanted int with high performance. func (cs *safe) GetAsBool(name string, fallback bool) BoolOption { valid := getValidityFlag() - value := findBoolValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeBool) + value := fallback + if valueCache != nil { + value = valueCache.boolVal + } var lock sync.Mutex + return func() bool { lock.Lock() defer lock.Unlock() if !valid.IsSet() { valid = getValidityFlag() - value = findBoolValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeBool) + if valueCache != nil { + value = valueCache.boolVal + } else { + value = fallback + } } return value } diff --git a/config/get.go b/config/get.go index 7bc0501..5e5755f 100644 --- a/config/get.go +++ b/config/get.go @@ -15,14 +15,59 @@ type ( BoolOption func() bool ) +func getValueCache(name string, option *Option, requestedType uint8) (*Option, *valueCache) { + // get option + if option == nil { + var ok bool + optionsLock.RLock() + option, ok = options[name] + optionsLock.RUnlock() + if !ok { + log.Errorf("config: request for unregistered option: %s", name) + return nil, nil + } + } + + // check type + if requestedType != option.OptType { + log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType)) + return option, nil + } + + // lock option + option.Lock() + defer option.Unlock() + + // check release level + if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil { + return option, option.activeValue + } + + if option.activeDefaultValue != nil { + return option, option.activeDefaultValue + } + + return option, option.activeFallbackValue +} + // GetAsString returns a function that returns the wanted string with high performance. func GetAsString(name string, fallback string) StringOption { valid := getValidityFlag() - value := findStringValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeString) + value := fallback + if valueCache != nil { + value = valueCache.stringVal + } + return func() string { if !valid.IsSet() { valid = getValidityFlag() - value = findStringValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeString) + if valueCache != nil { + value = valueCache.stringVal + } else { + value = fallback + } } return value } @@ -31,11 +76,21 @@ func GetAsString(name string, fallback string) StringOption { // GetAsStringArray returns a function that returns the wanted string with high performance. func GetAsStringArray(name string, fallback []string) StringArrayOption { valid := getValidityFlag() - value := findStringArrayValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeStringArray) + value := fallback + if valueCache != nil { + value = valueCache.stringArrayVal + } + return func() []string { if !valid.IsSet() { valid = getValidityFlag() - value = findStringArrayValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeStringArray) + if valueCache != nil { + value = valueCache.stringArrayVal + } else { + value = fallback + } } return value } @@ -44,11 +99,21 @@ func GetAsStringArray(name string, fallback []string) StringArrayOption { // GetAsInt returns a function that returns the wanted int with high performance. func GetAsInt(name string, fallback int64) IntOption { valid := getValidityFlag() - value := findIntValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeInt) + value := fallback + if valueCache != nil { + value = valueCache.intVal + } + return func() int64 { if !valid.IsSet() { valid = getValidityFlag() - value = findIntValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeInt) + if valueCache != nil { + value = valueCache.intVal + } else { + value = fallback + } } return value } @@ -57,18 +122,28 @@ func GetAsInt(name string, fallback int64) IntOption { // GetAsBool returns a function that returns the wanted int with high performance. func GetAsBool(name string, fallback bool) BoolOption { valid := getValidityFlag() - value := findBoolValue(name, fallback) + option, valueCache := getValueCache(name, nil, OptTypeBool) + value := fallback + if valueCache != nil { + value = valueCache.boolVal + } + return func() bool { if !valid.IsSet() { valid = getValidityFlag() - value = findBoolValue(name, fallback) + option, valueCache = getValueCache(name, option, OptTypeBool) + if valueCache != nil { + value = valueCache.boolVal + } else { + value = fallback + } } return value } } -// findValue find the correct value in the user or default config. -func findValue(key string) interface{} { +/* +func getAndFindValue(key string) interface{} { optionsLock.RLock() option, ok := options[key] optionsLock.RUnlock() @@ -77,6 +152,13 @@ func findValue(key string) interface{} { return nil } + return option.findValue() +} +*/ + +/* +// findValue finds the preferred value in the user or default config. +func (option *Option) findValue() interface{} { // lock option option.Lock() defer option.Unlock() @@ -91,88 +173,4 @@ func findValue(key string) interface{} { return option.DefaultValue } - -// findStringValue validates and returns the value with the given key. -func findStringValue(key string, fallback string) (value string) { - result := findValue(key) - if result == nil { - return fallback - } - v, ok := result.(string) - if ok { - return v - } - return fallback -} - -// findStringArrayValue validates and returns the value with the given key. -func findStringArrayValue(key string, fallback []string) (value []string) { - result := findValue(key) - if result == nil { - return fallback - } - - v, ok := result.([]interface{}) - if ok { - new := make([]string, len(v)) - for i, val := range v { - s, ok := val.(string) - if ok { - new[i] = s - } else { - return fallback - } - } - return new - } - - return fallback -} - -// findIntValue validates and returns the value with the given key. -func findIntValue(key string, fallback int64) (value int64) { - result := findValue(key) - if result == nil { - return fallback - } - switch v := result.(type) { - case int: - return int64(v) - case int8: - return int64(v) - case int16: - return int64(v) - case int32: - return int64(v) - case int64: - return v - case uint: - return int64(v) - case uint8: - return int64(v) - case uint16: - return int64(v) - case uint32: - return int64(v) - case uint64: - return int64(v) - case float32: - return int64(v) - case float64: - return int64(v) - } - return fallback -} - -// findBoolValue validates and returns the value with the given key. -func findBoolValue(key string, fallback bool) (value bool) { - result := findValue(key) - if result == nil { - return fallback - } - v, ok := result.(bool) - if ok { - return v - } - return fallback -} +*/ diff --git a/config/get_test.go b/config/get_test.go index 3028619..bef788a 100644 --- a/config/get_test.go +++ b/config/get_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/json" "testing" "github.com/safing/portbase/log" @@ -39,7 +40,7 @@ func quickRegister(t *testing.T, key string, optType uint8, defaultValue interfa } } -func TestGet(t *testing.T) { +func TestGet(t *testing.T) { //nolint:gocognit // reset options = make(map[string]*Option) @@ -48,41 +49,41 @@ func TestGet(t *testing.T) { t.Fatal(err) } - quickRegister(t, "monkey", OptTypeInt, -1) + quickRegister(t, "monkey", OptTypeString, "c") quickRegister(t, "zebras/zebra", OptTypeStringArray, []string{"a", "b"}) quickRegister(t, "elephant", OptTypeInt, -1) quickRegister(t, "hot", OptTypeBool, false) quickRegister(t, "cold", OptTypeBool, true) err = parseAndSetConfig(` - { - "monkey": "1", + { + "monkey": "a", "zebras": { "zebra": ["black", "white"] }, - "elephant": 2, + "elephant": 2, "hot": true, "cold": false - } - `) + } + `) if err != nil { t.Fatal(err) } err = parseAndSetDefaultConfig(` - { - "monkey": "0", - "snake": "0", - "elephant": 0 - } - `) + { + "monkey": "b", + "snake": "0", + "elephant": 0 + } + `) if err != nil { t.Fatal(err) } monkey := GetAsString("monkey", "none") - if monkey() != "1" { - t.Errorf("monkey should be 1, is %s", monkey()) + if monkey() != "a" { + t.Errorf("monkey should be a, is %s", monkey()) } zebra := GetAsStringArray("zebras/zebra", []string{}) @@ -106,10 +107,10 @@ func TestGet(t *testing.T) { } err = parseAndSetConfig(` - { - "monkey": "3" - } - `) + { + "monkey": "3" + } + `) if err != nil { t.Fatal(err) } @@ -131,6 +132,53 @@ func TestGet(t *testing.T) { GetAsInt("elephant", -1)() GetAsBool("hot", false)() + // perspective + + // load data + pLoaded := make(map[string]interface{}) + err = json.Unmarshal([]byte(`{ + "monkey": "a", + "zebras": { + "zebra": ["black", "white"] + }, + "elephant": 2, + "hot": true, + "cold": false + }`), &pLoaded) + if err != nil { + t.Fatal(err) + } + + // create + p, err := NewPerspective(pLoaded) + if err != nil { + t.Fatal(err) + } + + monkeyVal, ok := p.GetAsString("monkey") + if !ok || monkeyVal != "a" { + t.Errorf("[perspective] monkey should be a, is %+v", monkeyVal) + } + + zebraVal, ok := p.GetAsStringArray("zebras/zebra") + if !ok || len(zebraVal) != 2 || zebraVal[0] != "black" || zebraVal[1] != "white" { + t.Errorf("[perspective] zebra should be [\"black\", \"white\"], is %+v", zebraVal) + } + + elephantVal, ok := p.GetAsInt("elephant") + if !ok || elephantVal != 2 { + t.Errorf("[perspective] elephant should be 2, is %+v", elephantVal) + } + + hotVal, ok := p.GetAsBool("hot") + if !ok || !hotVal { + t.Errorf("[perspective] hot should be true, is %+v", hotVal) + } + + coldVal, ok := p.GetAsBool("cold") + if !ok || coldVal { + t.Errorf("[perspective] cold should be false, is %+v", coldVal) + } } func TestReleaseLevel(t *testing.T) { @@ -236,11 +284,9 @@ func BenchmarkGetAsStringCached(b *testing.B) { options = make(map[string]*Option) // Setup - err := parseAndSetConfig(` - { - "monkey": "banana" - } - `) + err := parseAndSetConfig(`{ + "monkey": "banana" + }`) if err != nil { b.Fatal(err) } @@ -257,11 +303,9 @@ func BenchmarkGetAsStringCached(b *testing.B) { func BenchmarkGetAsStringRefetch(b *testing.B) { // Setup - err := parseAndSetConfig(` - { - "monkey": "banana" - } - `) + err := parseAndSetConfig(`{ + "monkey": "banana" + }`) if err != nil { b.Fatal(err) } @@ -271,38 +315,34 @@ func BenchmarkGetAsStringRefetch(b *testing.B) { // Start benchmark for i := 0; i < b.N; i++ { - findStringValue("monkey", "no banana") + getValueCache("monkey", nil, OptTypeString) } } func BenchmarkGetAsIntCached(b *testing.B) { // Setup - err := parseAndSetConfig(` - { - "monkey": 1 - } - `) + err := parseAndSetConfig(`{ + "elephant": 1 + }`) if err != nil { b.Fatal(err) } - monkey := GetAsInt("monkey", -1) + elephant := GetAsInt("elephant", -1) // Reset timer for precise results b.ResetTimer() // Start benchmark for i := 0; i < b.N; i++ { - monkey() + elephant() } } func BenchmarkGetAsIntRefetch(b *testing.B) { // Setup - err := parseAndSetConfig(` - { - "monkey": 1 - } - `) + err := parseAndSetConfig(`{ + "elephant": 1 + }`) if err != nil { b.Fatal(err) } @@ -312,6 +352,6 @@ func BenchmarkGetAsIntRefetch(b *testing.B) { // Start benchmark for i := 0; i < b.N; i++ { - findIntValue("monkey", 1) + getValueCache("elephant", nil, OptTypeInt) } } diff --git a/config/main.go b/config/main.go index a37f8f6..df432c2 100644 --- a/config/main.go +++ b/config/main.go @@ -5,9 +5,9 @@ import ( "os" "path/filepath" + "github.com/safing/portbase/dataroot" "github.com/safing/portbase/modules" "github.com/safing/portbase/utils" - "github.com/safing/portmaster/core/structure" ) const ( @@ -27,12 +27,12 @@ func SetDataRoot(root *utils.DirStructure) { } func init() { - module = modules.Register("config", prep, start, nil, "base", "database") + module = modules.Register("config", prep, start, nil, "database") module.RegisterEvent(configChangeEvent) } func prep() error { - SetDataRoot(structure.Root()) + SetDataRoot(dataroot.Root()) if dataRoot == nil { return errors.New("data root is not set") } diff --git a/config/option.go b/config/option.go index db33e3a..b4e928e 100644 --- a/config/option.go +++ b/config/option.go @@ -41,6 +41,7 @@ type Option struct { Name string Key string // in path format: category/sub/key Description string + Help string OptType uint8 ExpertiseLevel uint8 @@ -52,9 +53,10 @@ type Option struct { ExternalOptType string ValidationRegex string - activeValue interface{} // runtime value (loaded from config file or set by user) - activeDefaultValue interface{} // runtime default value (may be set internally) - compiledRegex *regexp.Regexp + activeValue *valueCache // runtime value (loaded from config file or set by user) + activeDefaultValue *valueCache // runtime default value (may be set internally) + activeFallbackValue *valueCache // default value from option registration + compiledRegex *regexp.Regexp } // Export expors an option to a Record. @@ -68,14 +70,14 @@ func (option *Option) Export() (record.Record, error) { } if option.activeValue != nil { - data, err = sjson.SetBytes(data, "Value", option.activeValue) + data, err = sjson.SetBytes(data, "Value", option.activeValue.getData(option)) if err != nil { return nil, err } } if option.activeDefaultValue != nil { - data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue) + data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue.getData(option)) if err != nil { return nil, err } diff --git a/config/persistence.go b/config/persistence.go index dd2a182..4ffbde3 100644 --- a/config/persistence.go +++ b/config/persistence.go @@ -47,7 +47,7 @@ func saveConfig() error { for key, option := range options { option.Lock() if option.activeValue != nil { - activeValues[key] = option.activeValue + activeValues[key] = option.activeValue.getData(option) } option.Unlock() } diff --git a/config/perspective.go b/config/perspective.go new file mode 100644 index 0000000..e863449 --- /dev/null +++ b/config/perspective.go @@ -0,0 +1,128 @@ +package config + +import ( + "fmt" + + "github.com/safing/portbase/log" +) + +// Perspective is a view on configuration data without interfering with the configuration system. +type Perspective struct { + config map[string]*perspectiveOption +} + +type perspectiveOption struct { + option *Option + valueCache *valueCache +} + +// NewPerspective parses the given config and returns it as a new perspective. +func NewPerspective(config map[string]interface{}) (*Perspective, error) { + // flatten config structure + flatten(config, config, "") + + perspective := &Perspective{ + config: make(map[string]*perspectiveOption), + } + var firstErr error + var errCnt int + + optionsLock.Lock() +optionsLoop: + for key, option := range options { + // get option key from config + configValue, ok := config[key] + if !ok { + continue + } + // validate value + valueCache, err := validateValue(option, configValue) + if err != nil { + errCnt++ + if firstErr == nil { + firstErr = err + } + continue optionsLoop + } + + // add to perspective + perspective.config[key] = &perspectiveOption{ + option: option, + valueCache: valueCache, + } + } + optionsLock.Unlock() + + if firstErr != nil { + if errCnt > 0 { + return perspective, fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr) + } + return perspective, firstErr + } + + return perspective, nil +} + +func (p *Perspective) getPerspectiveValueCache(name string, requestedType uint8) *valueCache { + // get option + pOption, ok := p.config[name] + if !ok { + // check if option exists at all + optionsLock.RLock() + _, ok = options[name] + optionsLock.RUnlock() + if !ok { + log.Errorf("config: request for unregistered option: %s", name) + } + return nil + } + + // check type + if requestedType != pOption.option.OptType { + log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType)) + return nil + } + + // check release level + if pOption.option.ReleaseLevel > getReleaseLevel() { + return nil + } + + return pOption.valueCache +} + +// GetAsString returns a function that returns the wanted string with high performance. +func (p *Perspective) GetAsString(name string) (value string, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeString) + if valueCache != nil { + return valueCache.stringVal, true + } + return "", false +} + +// GetAsStringArray returns a function that returns the wanted string with high performance. +func (p *Perspective) GetAsStringArray(name string) (value []string, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeStringArray) + if valueCache != nil { + return valueCache.stringArrayVal, true + } + return nil, false +} + +// GetAsInt returns a function that returns the wanted int with high performance. +func (p *Perspective) GetAsInt(name string) (value int64, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeInt) + if valueCache != nil { + return valueCache.intVal, true + } + return 0, false +} + +// GetAsBool returns a function that returns the wanted int with high performance. +func (p *Perspective) GetAsBool(name string) (value bool, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeBool) + if valueCache != nil { + return valueCache.boolVal, true + } + return false, false +} diff --git a/config/registry.go b/config/registry.go index 0b2c221..11bfd76 100644 --- a/config/registry.go +++ b/config/registry.go @@ -26,14 +26,20 @@ func Register(option *Option) error { return fmt.Errorf("failed to register option: please set option.OptType") } + var err error + if option.ValidationRegex != "" { - var err error option.compiledRegex, err = regexp.Compile(option.ValidationRegex) if err != nil { return fmt.Errorf("config: could not compile option.ValidationRegex: %s", err) } } + option.activeFallbackValue, err = validateValue(option, option.DefaultValue) + if err != nil { + return fmt.Errorf("config: invalid default value: %s", err) + } + optionsLock.Lock() defer optionsLock.Unlock() options[option.Key] = option diff --git a/config/registry_test.go b/config/registry_test.go index 66fcb7f..3f5593a 100644 --- a/config/registry_test.go +++ b/config/registry_test.go @@ -15,7 +15,7 @@ func TestRegistry(t *testing.T) { ReleaseLevel: ReleaseLevelStable, ExpertiseLevel: ExpertiseLevelUser, OptType: OptTypeString, - DefaultValue: "default", + DefaultValue: "water", ValidationRegex: "^(banana|water)$", }); err != nil { t.Error(err) diff --git a/config/release.go b/config/release.go index 14b239c..f72f93e 100644 --- a/config/release.go +++ b/config/release.go @@ -5,6 +5,8 @@ package config import ( "fmt" "sync/atomic" + + "github.com/tevino/abool" ) // Release Level constants @@ -21,7 +23,9 @@ const ( ) var ( - releaseLevel *int32 + releaseLevel *int32 + releaseLevelOption *Option + releaseLevelOptionFlag = abool.New() ) func init() { @@ -32,7 +36,7 @@ func init() { } func registerReleaseLevelOption() { - err := Register(&Option{ + releaseLevelOption = &Option{ Name: "Release Level", Key: releaseLevelKey, Description: "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.", @@ -46,15 +50,31 @@ func registerReleaseLevelOption() { ExternalOptType: "string list", ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental), - }) + } + + err := Register(releaseLevelOption) if err != nil { panic(err) } + + releaseLevelOptionFlag.Set() } func updateReleaseLevel() { - new := findStringValue(releaseLevelKey, "") - switch new { + // check if already registered + if !releaseLevelOptionFlag.IsSet() { + return + } + // get value + value := releaseLevelOption.activeFallbackValue + if releaseLevelOption.activeValue != nil { + value = releaseLevelOption.activeValue + } + if releaseLevelOption.activeDefaultValue != nil { + value = releaseLevelOption.activeDefaultValue + } + // set atomic value + switch value.stringVal { case ReleaseLevelNameStable: atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable)) case ReleaseLevelNameBeta: diff --git a/config/set.go b/config/set.go index 8e55714..0c007ac 100644 --- a/config/set.go +++ b/config/set.go @@ -19,6 +19,7 @@ var ( validityFlagLock sync.RWMutex ) +// getValidityFlag returns a flag that signifies if the configuration has been changed. This flag must not be changed, only read. func getValidityFlag() *abool.AtomicBool { validityFlagLock.RLock() defer validityFlagLock.RUnlock() @@ -41,14 +42,24 @@ func signalChanges() { // setConfig sets the (prioritized) user defined config. func setConfig(newValues map[string]interface{}) error { + var firstErr error + var errCnt int + optionsLock.Lock() for key, option := range options { newValue, ok := newValues[key] option.Lock() + option.activeValue = nil if ok { - option.activeValue = newValue - } else { - option.activeValue = nil + valueCache, err := validateValue(option, newValue) + if err == nil { + option.activeValue = valueCache + } else { + errCnt++ + if firstErr == nil { + firstErr = err + } + } } option.Unlock() } @@ -56,19 +67,37 @@ func setConfig(newValues map[string]interface{}) error { signalChanges() go pushFullUpdate() + + if firstErr != nil { + if errCnt > 0 { + return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr) + } + return firstErr + } + return nil } // SetDefaultConfig sets the (fallback) default config. func SetDefaultConfig(newValues map[string]interface{}) error { + var firstErr error + var errCnt int + optionsLock.Lock() for key, option := range options { newValue, ok := newValues[key] option.Lock() + option.activeDefaultValue = nil if ok { - option.activeDefaultValue = newValue - } else { - option.activeDefaultValue = nil + valueCache, err := validateValue(option, newValue) + if err == nil { + option.activeDefaultValue = valueCache + } else { + errCnt++ + if firstErr == nil { + firstErr = err + } + } } option.Unlock() } @@ -76,51 +105,15 @@ func SetDefaultConfig(newValues map[string]interface{}) error { signalChanges() go pushFullUpdate() - return nil -} -func validateValue(option *Option, value interface{}) error { - switch v := value.(type) { - case string: - if option.OptType != OptTypeString { - return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) + if firstErr != nil { + if errCnt > 0 { + return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr) } - if option.compiledRegex != nil { - if !option.compiledRegex.MatchString(v) { - return fmt.Errorf("validation failed: string \"%s\" did not match regex for option %s", v, option.Key) - } - } - return nil - case []string: - if option.OptType != OptTypeStringArray { - return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) - } - if option.compiledRegex != nil { - for pos, entry := range v { - if !option.compiledRegex.MatchString(entry) { - return fmt.Errorf("validation failed: string \"%s\" at index %d did not match regex for option %s", entry, pos, option.Key) - } - } - } - return nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - if option.OptType != OptTypeInt { - return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) - } - if option.compiledRegex != nil { - if !option.compiledRegex.MatchString(fmt.Sprintf("%d", v)) { - return fmt.Errorf("validation failed: number \"%d\" did not match regex for option %s", v, option.Key) - } - } - return nil - case bool: - if option.OptType != OptTypeBool { - return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) - } - return nil - default: - return fmt.Errorf("invalid option value type: %T", value) + return firstErr } + + return nil } // SetConfigOption sets a single value in the (prioritized) user defined config. @@ -140,9 +133,10 @@ func setConfigOption(key string, value interface{}, push bool) (err error) { if value == nil { option.activeValue = nil } else { - err = validateValue(option, value) + var valueCache *valueCache + valueCache, err = validateValue(option, value) if err == nil { - option.activeValue = value + option.activeValue = valueCache } } option.Unlock() @@ -175,9 +169,10 @@ func setDefaultConfigOption(key string, value interface{}, push bool) (err error if value == nil { option.activeDefaultValue = nil } else { - err = validateValue(option, value) + var valueCache *valueCache + valueCache, err = validateValue(option, value) if err == nil { - option.activeDefaultValue = value + option.activeDefaultValue = valueCache } } option.Unlock() diff --git a/config/validate.go b/config/validate.go new file mode 100644 index 0000000..3db4871 --- /dev/null +++ b/config/validate.go @@ -0,0 +1,120 @@ +package config + +import ( + "errors" + "fmt" + "math" +) + +type valueCache struct { + stringVal string + stringArrayVal []string + intVal int64 + boolVal bool +} + +func (vc *valueCache) getData(opt *Option) interface{} { + switch opt.OptType { + case OptTypeBool: + return vc.boolVal + case OptTypeInt: + return vc.intVal + case OptTypeString: + return vc.stringVal + case OptTypeStringArray: + return vc.stringArrayVal + default: + return nil + } +} + +func validateValue(option *Option, value interface{}) (*valueCache, error) { //nolint:gocyclo + switch v := value.(type) { + case string: + if option.OptType != OptTypeString { + return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) + } + if option.compiledRegex != nil { + if !option.compiledRegex.MatchString(v) { + return nil, fmt.Errorf("validation of option %s failed: string \"%s\" did not match validation regex for option", option.Key, v) + } + } + return &valueCache{stringVal: v}, nil + case []interface{}: + vConverted := make([]string, len(v)) + for pos, entry := range v { + s, ok := entry.(string) + if !ok { + return nil, fmt.Errorf("validation of option %s failed: element %+v at index %d is not a string", option.Key, entry, pos) + + } + vConverted[pos] = s + } + // continue to next case + return validateValue(option, vConverted) + case []string: + if option.OptType != OptTypeStringArray { + return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) + } + if option.compiledRegex != nil { + for pos, entry := range v { + if !option.compiledRegex.MatchString(entry) { + return nil, fmt.Errorf("validation of option %s failed: string \"%s\" at index %d did not match validation regex", option.Key, entry, pos) + } + } + } + return &valueCache{stringArrayVal: v}, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64: + // uint64 is omitted, as it does not fit in a int64 + if option.OptType != OptTypeInt { + return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) + } + if option.compiledRegex != nil { + // we need to use %v here so we handle float and int correctly. + if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) { + return nil, fmt.Errorf("validation of option %s failed: number \"%d\" did not match validation regex", option.Key, v) + } + } + switch v := value.(type) { + case int: + return &valueCache{intVal: int64(v)}, nil + case int8: + return &valueCache{intVal: int64(v)}, nil + case int16: + return &valueCache{intVal: int64(v)}, nil + case int32: + return &valueCache{intVal: int64(v)}, nil + case int64: + return &valueCache{intVal: v}, nil + case uint: + return &valueCache{intVal: int64(v)}, nil + case uint8: + return &valueCache{intVal: int64(v)}, nil + case uint16: + return &valueCache{intVal: int64(v)}, nil + case uint32: + return &valueCache{intVal: int64(v)}, nil + case float32: + // convert if float has no decimals + if math.Remainder(float64(v), 1) == 0 { + return &valueCache{intVal: int64(v)}, nil + } + return nil, fmt.Errorf("failed to convert float32 to int64 for option %s, got value %+v", option.Key, v) + case float64: + // convert if float has no decimals + if math.Remainder(v, 1) == 0 { + return &valueCache{intVal: int64(v)}, nil + } + return nil, fmt.Errorf("failed to convert float64 to int64 for option %s, got value %+v", option.Key, v) + default: + return nil, errors.New("internal error") + } + case bool: + if option.OptType != OptTypeBool { + return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v) + } + return &valueCache{boolVal: v}, nil + default: + return nil, fmt.Errorf("invalid option value type for option %s: %T", option.Key, value) + } +} diff --git a/config/validity.go b/config/validity.go new file mode 100644 index 0000000..d42e91b --- /dev/null +++ b/config/validity.go @@ -0,0 +1,30 @@ +package config + +import ( + "github.com/tevino/abool" +) + +// ValidityFlag is a flag that signifies if the configuration has been changed. It is not safe for concurrent use. +type ValidityFlag struct { + flag *abool.AtomicBool +} + +// NewValidityFlag returns a flag that signifies if the configuration has been changed. +func NewValidityFlag() *ValidityFlag { + vf := &ValidityFlag{} + vf.Refresh() + return vf +} + +// IsValid returns if the configuration is still valid. +func (vf *ValidityFlag) IsValid() bool { + return vf.flag.IsSet() +} + +// Refresh refreshes the flag and makes it reusable. +func (vf *ValidityFlag) Refresh() { + validityFlagLock.RLock() + defer validityFlagLock.RUnlock() + + vf.flag = validityFlag +} diff --git a/database/controller.go b/database/controller.go index 51699f4..e5e94fc 100644 --- a/database/controller.go +++ b/database/controller.go @@ -1,6 +1,7 @@ package database import ( + "errors" "sync" "github.com/tevino/abool" @@ -119,10 +120,13 @@ func (c *Controller) Put(r record.Record) (err error) { } } - err = c.storage.Put(r) + r, err = c.storage.Put(r) if err != nil { return err } + if r == nil { + return errors.New("storage returned nil record after successful put operation") + } // process subscriptions for _, sub := range c.subscriptions { @@ -137,6 +141,32 @@ func (c *Controller) Put(r record.Record) (err error) { return nil } +// PutMany stores many records in the database. +func (c *Controller) PutMany() (chan<- record.Record, <-chan error) { + c.writeLock.RLock() + defer c.writeLock.RUnlock() + + if shuttingDown.IsSet() { + errs := make(chan error, 1) + errs <- ErrShuttingDown + return make(chan record.Record), errs + } + + if c.ReadOnly() { + errs := make(chan error, 1) + errs <- ErrReadOnly + return make(chan record.Record), errs + } + + if batcher, ok := c.storage.(storage.Batcher); ok { + return batcher.PutMany() + } + + errs := make(chan error, 1) + errs <- ErrNotImplemented + return make(chan record.Record), errs +} + // Query executes the given query on the database. func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { c.readLock.RLock() diff --git a/database/database_test.go b/database/database_test.go index 3d59556..6367648 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -10,10 +10,13 @@ import ( "testing" "time" + "github.com/safing/portbase/database/record" + q "github.com/safing/portbase/database/query" _ "github.com/safing/portbase/database/storage/badger" _ "github.com/safing/portbase/database/storage/bbolt" _ "github.com/safing/portbase/database/storage/fstree" + _ "github.com/safing/portbase/database/storage/hashmap" ) func makeKey(dbName, key string) string { @@ -39,7 +42,10 @@ func testDatabase(t *testing.T, storageType string) { } // interface - db := NewInterface(nil) + db := NewInterface(&Options{ + Local: true, + Internal: true, + }) // sub sub, err := db.Subscribe(q.New(dbName).MustBeValid()) @@ -107,6 +113,18 @@ func testDatabase(t *testing.T, storageType string) { t.Fatalf("expected two records, got %d", cnt) } + switch storageType { + case "bbolt", "hashmap": + batchPut := db.PutMany(dbName) + records := []record.Record{A, B, C, nil} // nil is to signify finish + for _, r := range records { + err = batchPut(r) + if err != nil { + t.Fatal(err) + } + } + } + err = hook.Cancel() if err != nil { t.Fatal(err) @@ -128,12 +146,12 @@ func TestDatabaseSystem(t *testing.T) { os.Exit(1) }() - testDir, err := ioutil.TempDir("", "testing-") + testDir, err := ioutil.TempDir("", "portbase-database-testing-") if err != nil { t.Fatal(err) } - err = Initialize(testDir, nil) + err = InitializeWithPath(testDir) if err != nil { t.Fatal(err) } @@ -142,6 +160,7 @@ func TestDatabaseSystem(t *testing.T) { testDatabase(t, "badger") testDatabase(t, "bbolt") testDatabase(t, "fstree") + testDatabase(t, "hashmap") err = MaintainRecordStates() if err != nil { diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go index 13903c6..9a4c430 100644 --- a/database/dbmodule/db.go +++ b/database/dbmodule/db.go @@ -4,41 +4,44 @@ import ( "errors" "github.com/safing/portbase/database" + "github.com/safing/portbase/dataroot" "github.com/safing/portbase/modules" "github.com/safing/portbase/utils" ) var ( - databasePath string databaseStructureRoot *utils.DirStructure module *modules.Module ) func init() { - module = modules.Register("database", prep, start, stop, "base") + module = modules.Register("database", prep, start, stop) } // SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. -func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) { - databasePath = dirPath - databaseStructureRoot = dirStructureRoot +func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) { + if databaseStructureRoot == nil { + databaseStructureRoot = dirStructureRoot + } } func prep() error { - if databasePath == "" && databaseStructureRoot == nil { - return errors.New("no database location specified") + SetDatabaseLocation(dataroot.Root()) + if databaseStructureRoot == nil { + return errors.New("database location not specified") } + return nil } func start() error { - err := database.Initialize(databasePath, databaseStructureRoot) + err := database.Initialize(databaseStructureRoot) if err != nil { return err } - registerMaintenanceTasks() + startMaintenanceTasks() return nil } diff --git a/database/dbmodule/maintenance.go b/database/dbmodule/maintenance.go index 8958299..6a1755c 100644 --- a/database/dbmodule/maintenance.go +++ b/database/dbmodule/maintenance.go @@ -5,33 +5,23 @@ import ( "time" "github.com/safing/portbase/database" - "github.com/safing/portbase/log" "github.com/safing/portbase/modules" ) -func registerMaintenanceTasks() { +func startMaintenanceTasks() { module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute) module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) } -func maintainBasic(ctx context.Context, task *modules.Task) { - err := database.Maintain() - if err != nil { - log.Errorf("database: maintenance error: %s", err) - } +func maintainBasic(ctx context.Context, task *modules.Task) error { + return database.Maintain() } -func maintainThorough(ctx context.Context, task *modules.Task) { - err := database.MaintainThorough() - if err != nil { - log.Errorf("database: thorough maintenance error: %s", err) - } +func maintainThorough(ctx context.Context, task *modules.Task) error { + return database.MaintainThorough() } -func maintainRecords(ctx context.Context, task *modules.Task) { - err := database.MaintainRecordStates() - if err != nil { - log.Errorf("database: record states maintenance error: %s", err) - } +func maintainRecords(ctx context.Context, task *modules.Task) error { + return database.MaintainRecordStates() } diff --git a/database/errors.go b/database/errors.go index cae280d..9d22b53 100644 --- a/database/errors.go +++ b/database/errors.go @@ -10,4 +10,5 @@ var ( ErrPermissionDenied = errors.New("access to database record denied") ErrReadOnly = errors.New("database is read only") ErrShuttingDown = errors.New("database system is shutting down") + ErrNotImplemented = errors.New("not implemented by this storage") ) diff --git a/database/interface.go b/database/interface.go index 9e9b9f7..a939960 100644 --- a/database/interface.go +++ b/database/interface.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/tevino/abool" + "github.com/bluele/gcache" "github.com/safing/portbase/database/accessor" @@ -170,10 +172,19 @@ func (i *Interface) InsertValue(key string, attribute string, value interface{}) } // Put saves a record to the database. -func (i *Interface) Put(r record.Record) error { - _, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) - if err != nil && err != ErrNotFound { - return err +func (i *Interface) Put(r record.Record) (err error) { + // get record or only database + var db *Controller + if !i.options.Internal || !i.options.Local { + _, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) + if err != nil && err != ErrNotFound { + return err + } + } else { + db, err = getController(r.DatabaseKey()) + if err != nil { + return err + } } r.Lock() @@ -186,24 +197,122 @@ func (i *Interface) Put(r record.Record) error { } // PutNew saves a record to the database as a new record (ie. with new timestamps). -func (i *Interface) PutNew(r record.Record) error { - _, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) - if err != nil && err != ErrNotFound { - return err +func (i *Interface) PutNew(r record.Record) (err error) { + // get record or only database + var db *Controller + if !i.options.Internal || !i.options.Local { + _, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true) + if err != nil && err != ErrNotFound { + return err + } + } else { + db, err = getController(r.DatabaseKey()) + if err != nil { + return err + } } r.Lock() defer r.Unlock() - if r.Meta() == nil { - r.CreateMeta() + if r.Meta() != nil { + r.Meta().Reset() } - r.Meta().Reset() i.options.Apply(r) i.updateCache(r) return db.Put(r) } +// PutMany stores many records in the database. Warning: This is nearly a direct database access and omits many things: +// - Record locking +// - Hooks +// - Subscriptions +// - Caching +func (i *Interface) PutMany(dbName string) (put func(record.Record) error) { + interfaceBatch := make(chan record.Record, 100) + + // permission check + if !i.options.Internal || !i.options.Local { + return func(r record.Record) error { + return ErrPermissionDenied + } + } + + // get database + db, err := getController(dbName) + if err != nil { + return func(r record.Record) error { + return err + } + } + + // start database access + dbBatch, errs := db.PutMany() + finished := abool.New() + var internalErr error + + // interface options proxy + go func() { + defer close(dbBatch) // signify that we are finished + for { + select { + case r := <-interfaceBatch: + // finished? + if r == nil { + return + } + // apply options + i.options.Apply(r) + // pass along + dbBatch <- r + case <-time.After(1 * time.Second): + // bail out + internalErr = errors.New("timeout: putmany unused for too long") + finished.Set() + return + } + } + }() + + return func(r record.Record) error { + // finished? + if finished.IsSet() { + // check for internal error + if internalErr != nil { + return internalErr + } + // check for previous error + select { + case err := <-errs: + return err + default: + return errors.New("batch is closed") + } + } + + // finish? + if r == nil { + finished.Set() + interfaceBatch <- nil // signify that we are finished + // do not close, as this fn could be called again with nil. + return <-errs + } + + // check record scope + if r.DatabaseName() != dbName { + return errors.New("record out of database scope") + } + + // submit + select { + case interfaceBatch <- r: + return nil + case err := <-errs: + return err + } + } +} + // SetAbsoluteExpiry sets an absolute record expiry. func (i *Interface) SetAbsoluteExpiry(key string, time int64) error { r, db, err := i.getRecord(getDBFromKey, key, true, true) diff --git a/database/main.go b/database/main.go index d49962c..74c31a3 100644 --- a/database/main.go +++ b/database/main.go @@ -23,15 +23,15 @@ var ( databasesStructure *utils.DirStructure ) -// Initialize initializes the database at the specified location. Supply either a path or dir structure. -func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error { - if initialized.SetToIf(false, true) { +// InitializeWithPath initializes the database at the specified location using a path. +func InitializeWithPath(dirPath string) error { + return Initialize(utils.NewDirStructure(dirPath, 0755)) +} - if dirStructureRoot != nil { - rootStructure = dirStructureRoot - } else { - rootStructure = utils.NewDirStructure(dirPath, 0755) - } +// Initialize initializes the database at the specified location using a dir structure. +func Initialize(dirStructureRoot *utils.DirStructure) error { + if initialized.SetToIf(false, true) { + rootStructure = dirStructureRoot // ensure root and databases dirs databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700) diff --git a/database/record/record_test.go b/database/record/record_test.go index f5e315d..5912d67 100644 --- a/database/record/record_test.go +++ b/database/record/record_test.go @@ -1,16 +1,10 @@ package record -import "sync" +import ( + "sync" +) type TestRecord struct { Base - lock sync.Mutex -} - -func (tm *TestRecord) Lock() { - tm.lock.Lock() -} - -func (tm *TestRecord) Unlock() { - tm.lock.Unlock() + sync.Mutex } diff --git a/database/record/wrapper.go b/database/record/wrapper.go index a456edd..b89a733 100644 --- a/database/record/wrapper.go +++ b/database/record/wrapper.go @@ -37,27 +37,19 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { offset += n newMeta := &Meta{} - if len(metaSection) == 34 && metaSection[4] == 0 { - // TODO: remove in 2020 - // backward compatibility: - // format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15) - // this must be gencode without format - _, err = newMeta.GenCodeUnmarshal(metaSection) - if err != nil { - return nil, fmt.Errorf("could not unmarshal meta section: %s", err) - } - } else { - _, err = dsd.Load(metaSection, newMeta) - if err != nil { - return nil, fmt.Errorf("could not unmarshal meta section: %s", err) - } + _, err = dsd.Load(metaSection, newMeta) + if err != nil { + return nil, fmt.Errorf("could not unmarshal meta section: %s", err) } - format, n, err := varint.Unpack8(data[offset:]) - if err != nil { - return nil, fmt.Errorf("could not get dsd format: %s", err) + var format uint8 = dsd.NONE + if !newMeta.IsDeleted() { + format, n, err = varint.Unpack8(data[offset:]) + if err != nil { + return nil, fmt.Errorf("could not get dsd format: %s", err) + } + offset += n } - offset += n return &Wrapper{ Base{ diff --git a/database/record/wrapper_test.go b/database/record/wrapper_test.go index 460e6e9..de4d5ae 100644 --- a/database/record/wrapper_test.go +++ b/database/record/wrapper_test.go @@ -2,10 +2,7 @@ package record import ( "bytes" - "errors" "testing" - - "github.com/safing/portbase/container" ) func TestWrapper(t *testing.T) { @@ -54,43 +51,4 @@ func TestWrapper(t *testing.T) { if !bytes.Equal(testData, wrapper2.Data) { t.Error("marshal mismatch") } - - // test new format - oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper) - if err != nil { - t.Fatal(err) - } - - wrapper3, err := NewRawWrapper("test", "a", oldRaw) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(testData, wrapper3.Data) { - t.Error("marshal mismatch") - } -} - -func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) { - if w.Meta() == nil { - return nil, errors.New("missing meta") - } - - // version - c := container.New([]byte{1}) - - // meta - metaSection, err := w.meta.GenCodeMarshal(nil) - if err != nil { - return nil, err - } - c.AppendAsBlock(metaSection) - - // data - dataSection, err := w.Marshal(r, JSON) - if err != nil { - return nil, err - } - c.Append(dataSection) - - return c.CompileData(), nil } diff --git a/database/registry.go b/database/registry.go index b6d1984..22e6bab 100644 --- a/database/registry.go +++ b/database/registry.go @@ -139,7 +139,7 @@ func saveRegistry(lock bool) error { } // write file - // FIXME: write atomically (best effort) + // TODO: write atomically (best effort) filePath := path.Join(rootStructure.Path, registryFileName) return ioutil.WriteFile(filePath, data, 0600) } diff --git a/database/storage/badger/badger.go b/database/storage/badger/badger.go index 8f473cd..82283d4 100644 --- a/database/storage/badger/badger.go +++ b/database/storage/badger/badger.go @@ -82,16 +82,19 @@ func (b *Badger) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (b *Badger) Put(r record.Record) error { +func (b *Badger) Put(r record.Record) (record.Record, error) { data, err := r.MarshalRecord(r) if err != nil { - return err + return nil, err } err = b.db.Update(func(txn *badger.Txn) error { return txn.Set([]byte(r.DatabaseKey()), data) }) - return err + if err != nil { + return nil, err + } + return r, nil } // Delete deletes a record from the database. diff --git a/database/storage/badger/badger_test.go b/database/storage/badger/badger_test.go index 47e10b6..8cfae5a 100644 --- a/database/storage/badger/badger_test.go +++ b/database/storage/badger/badger_test.go @@ -65,7 +65,7 @@ func TestBadger(t *testing.T) { a.SetKey("test:A") // put record - err = db.Put(a) + _, err = db.Put(a) if err != nil { t.Fatal(err) } diff --git a/database/storage/bbolt/bbolt.go b/database/storage/bbolt/bbolt.go index d37bf60..26666d9 100644 --- a/database/storage/bbolt/bbolt.go +++ b/database/storage/bbolt/bbolt.go @@ -86,10 +86,10 @@ func (b *BBolt) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (b *BBolt) Put(r record.Record) error { +func (b *BBolt) Put(r record.Record) (record.Record, error) { data, err := r.MarshalRecord(r) if err != nil { - return err + return nil, err } err = b.db.Update(func(tx *bbolt.Tx) error { @@ -100,9 +100,38 @@ func (b *BBolt) Put(r record.Record) error { return nil }) if err != nil { - return err + return nil, err } - return nil + return r, nil +} + +// PutMany stores many records in the database. +func (b *BBolt) PutMany() (chan<- record.Record, <-chan error) { + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + go func() { + err := b.db.Batch(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(bucketName) + for r := range batch { + // marshal + data, txErr := r.MarshalRecord(r) + if txErr != nil { + return txErr + } + + // put + txErr = bucket.Put([]byte(r.DatabaseKey()), data) + if txErr != nil { + return txErr + } + } + return nil + }) + errs <- err + }() + + return batch, errs } // Delete deletes a record from the database. diff --git a/database/storage/bbolt/bbolt_test.go b/database/storage/bbolt/bbolt_test.go index e5cdd3e..84586a3 100644 --- a/database/storage/bbolt/bbolt_test.go +++ b/database/storage/bbolt/bbolt_test.go @@ -65,7 +65,7 @@ func TestBBolt(t *testing.T) { a.SetKey("test:A") // put record - err = db.Put(a) + _, err = db.Put(a) if err != nil { t.Fatal(err) } @@ -100,15 +100,15 @@ func TestBBolt(t *testing.T) { qZ.SetKey("test:z") qZ.CreateMeta() // put - err = db.Put(qA) + _, err = db.Put(qA) if err == nil { - err = db.Put(qB) + _, err = db.Put(qB) } if err == nil { - err = db.Put(qC) + _, err = db.Put(qC) } if err == nil { - err = db.Put(qZ) + _, err = db.Put(qZ) } if err != nil { t.Fatal(err) diff --git a/database/storage/fstree/fstree.go b/database/storage/fstree/fstree.go index fa160de..34fc9f3 100644 --- a/database/storage/fstree/fstree.go +++ b/database/storage/fstree/fstree.go @@ -104,15 +104,15 @@ func (fst *FSTree) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (fst *FSTree) Put(r record.Record) error { +func (fst *FSTree) Put(r record.Record) (record.Record, error) { dstPath, err := fst.buildFilePath(r.DatabaseKey(), true) if err != nil { - return err + return nil, err } data, err := r.MarshalRecord(r) if err != nil { - return err + return nil, err } err = writeFile(dstPath, data, defaultFileMode) @@ -120,15 +120,15 @@ func (fst *FSTree) Put(r record.Record) error { // create dir and try again err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode) if err != nil { - return fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err) + return nil, fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err) } err = writeFile(dstPath, data, defaultFileMode) if err != nil { - return fmt.Errorf("fstree: could not write file %s: %s", dstPath, err) + return nil, fmt.Errorf("fstree: could not write file %s: %s", dstPath, err) } } - return nil + return r, nil } // Delete deletes a record from the database. diff --git a/database/storage/hashmap/map.go b/database/storage/hashmap/map.go index c6bb65d..ba6afd6 100644 --- a/database/storage/hashmap/map.go +++ b/database/storage/hashmap/map.go @@ -44,12 +44,33 @@ func (hm *HashMap) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (hm *HashMap) Put(r record.Record) error { +func (hm *HashMap) Put(r record.Record) (record.Record, error) { hm.dbLock.Lock() defer hm.dbLock.Unlock() hm.db[r.DatabaseKey()] = r - return nil + return r, nil +} + +// PutMany stores many records in the database. +func (hm *HashMap) PutMany() (chan<- record.Record, <-chan error) { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + // we could lock for every record, but we want to have the same behaviour + // as the other storage backends, especially for testing. + + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + // start handler + go func() { + for r := range batch { + hm.db[r.DatabaseKey()] = r + } + errs <- nil + }() + + return batch, errs } // Delete deletes a record from the database. diff --git a/database/storage/hashmap/map_test.go b/database/storage/hashmap/map_test.go index 911dc5b..a546809 100644 --- a/database/storage/hashmap/map_test.go +++ b/database/storage/hashmap/map_test.go @@ -57,7 +57,7 @@ func TestHashMap(t *testing.T) { a.SetKey("test:A") // put record - err = db.Put(a) + _, err = db.Put(a) if err != nil { t.Fatal(err) } @@ -86,15 +86,15 @@ func TestHashMap(t *testing.T) { qZ.SetKey("test:z") qZ.CreateMeta() // put - err = db.Put(qA) + _, err = db.Put(qA) if err == nil { - err = db.Put(qB) + _, err = db.Put(qB) } if err == nil { - err = db.Put(qC) + _, err = db.Put(qC) } if err == nil { - err = db.Put(qZ) + _, err = db.Put(qZ) } if err != nil { t.Fatal(err) diff --git a/database/storage/injectbase.go b/database/storage/injectbase.go index 316c8d6..182291c 100644 --- a/database/storage/injectbase.go +++ b/database/storage/injectbase.go @@ -21,8 +21,16 @@ func (i *InjectBase) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (i *InjectBase) Put(m record.Record) error { - return errNotImplemented +func (i *InjectBase) Put(m record.Record) (record.Record, error) { + return nil, errNotImplemented +} + +// PutMany stores many records in the database. +func (i *InjectBase) PutMany() (batch chan record.Record, err chan error) { + batch = make(chan record.Record) + err = make(chan error, 1) + err <- errNotImplemented + return } // Delete deletes a record from the database. diff --git a/database/storage/interface.go b/database/storage/interface.go index 4a6ce1e..bdf12e4 100644 --- a/database/storage/interface.go +++ b/database/storage/interface.go @@ -9,7 +9,7 @@ import ( // Interface defines the database storage API. type Interface interface { Get(key string) (record.Record, error) - Put(m record.Record) error + Put(m record.Record) (record.Record, error) Delete(key string) error Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) @@ -19,3 +19,8 @@ type Interface interface { MaintainThorough() error Shutdown() error } + +// Batcher defines the database storage API for backends that support batch operations. +type Batcher interface { + PutMany() (batch chan<- record.Record, errs <-chan error) +} diff --git a/database/storage/sinkhole/sinkhole.go b/database/storage/sinkhole/sinkhole.go index 9c51d29..dd3b065 100644 --- a/database/storage/sinkhole/sinkhole.go +++ b/database/storage/sinkhole/sinkhole.go @@ -36,8 +36,24 @@ func (s *Sinkhole) Get(key string) (record.Record, error) { } // Put stores a record in the database. -func (s *Sinkhole) Put(m record.Record) error { - return nil +func (s *Sinkhole) Put(r record.Record) (record.Record, error) { + return r, nil +} + +// PutMany stores many records in the database. +func (s *Sinkhole) PutMany() (chan<- record.Record, <-chan error) { + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + // start handler + go func() { + for range batch { + // nom, nom, nom + } + errs <- nil + }() + + return batch, errs } // Delete deletes a record from the database. diff --git a/database/subscription.go b/database/subscription.go index 6d29b7e..ffe516a 100644 --- a/database/subscription.go +++ b/database/subscription.go @@ -13,7 +13,6 @@ type Subscription struct { canceled bool Feed chan record.Record - Err error } // Cancel cancels the subscription. diff --git a/dataroot/root.go b/dataroot/root.go new file mode 100644 index 0000000..84c5275 --- /dev/null +++ b/dataroot/root.go @@ -0,0 +1,27 @@ +package dataroot + +import ( + "errors" + "os" + + "github.com/safing/portbase/utils" +) + +var ( + root *utils.DirStructure +) + +// Initialize initializes the data root directory +func Initialize(rootDir string, perm os.FileMode) error { + if root != nil { + return errors.New("already initialized") + } + + root = utils.NewDirStructure(rootDir, perm) + return root.Ensure() +} + +// Root returns the data root directory. +func Root() *utils.DirStructure { + return root +} diff --git a/log/flags.go b/log/flags.go index 0fcfb75..eb01929 100644 --- a/log/flags.go +++ b/log/flags.go @@ -8,6 +8,6 @@ var ( ) func init() { - flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]") + flag.StringVar(&logLevelFlag, "log", "", "set log level to [trace|debug|info|warning|error|critical]") flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug") } diff --git a/log/formatting_darwin.go b/log/formatting_darwin.go new file mode 100644 index 0000000..f24cf32 --- /dev/null +++ b/log/formatting_darwin.go @@ -0,0 +1,38 @@ +package log + +const ( + rightArrow = "ā–¶" + leftArrow = "ā—€" +) + +const ( + // colorBlack = "\033[30m" + colorRed = "\033[31m" + // colorGreen = "\033[32m" + colorYellow = "\033[33m" + colorBlue = "\033[34m" + colorMagenta = "\033[35m" + colorCyan = "\033[36m" + // colorWhite = "\033[37m" +) + +func (s Severity) color() string { + switch s { + case DebugLevel: + return colorCyan + case InfoLevel: + return colorBlue + case WarningLevel: + return colorYellow + case ErrorLevel: + return colorRed + case CriticalLevel: + return colorMagenta + default: + return "" + } +} + +func endColor() string { + return "\033[0m" +} diff --git a/log/input.go b/log/input.go index e8b6579..60f4af6 100644 --- a/log/input.go +++ b/log/input.go @@ -12,7 +12,7 @@ func log(level Severity, msg string, tracer *ContextTracer) { if !started.IsSet() { // a bit resource intense, but keeps logs before logging started. - // FIXME: create option to disable logging + // TODO: create option to disable logging go func() { <-startedSignal log(level, msg, tracer) diff --git a/log/logging.go b/log/logging.go index 2d605c8..7d606a3 100644 --- a/log/logging.go +++ b/log/logging.go @@ -82,6 +82,7 @@ var ( logsWaiting = make(chan struct{}, 4) logsWaitingFlag = abool.NewBool(false) + shutdownFlag = abool.NewBool(false) shutdownSignal = make(chan struct{}) shutdownWaitGroup sync.WaitGroup @@ -136,12 +137,14 @@ func Start() (err error) { logBuffer = make(chan *logLine, 1024) - initialLogLevel := ParseLevel(logLevelFlag) - if initialLogLevel > 0 { + if logLevelFlag != "" { + initialLogLevel := ParseLevel(logLevelFlag) + if initialLogLevel == 0 { + fmt.Fprintf(os.Stderr, "log warning: invalid log level \"%s\", falling back to level info\n", logLevelFlag) + initialLogLevel = InfoLevel + } + SetLogLevel(initialLogLevel) - } else { - err = fmt.Errorf("log warning: invalid log level \"%s\", falling back to level info", logLevelFlag) - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) } // get and set file loglevels @@ -179,6 +182,8 @@ func Start() (err error) { // Shutdown writes remaining log lines and then stops the log system. func Shutdown() { - close(shutdownSignal) + if shutdownFlag.SetToIf(false, true) { + close(shutdownSignal) + } shutdownWaitGroup.Wait() } diff --git a/log/output.go b/log/output.go index 28bde37..0b5c799 100644 --- a/log/output.go +++ b/log/output.go @@ -41,7 +41,7 @@ func writeLine(line *logLine, duplicates uint64) { } func startWriter() { - fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor())) + fmt.Printf("%s%s %s BOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()) shutdownWaitGroup.Add(1) go writerManager() @@ -168,7 +168,7 @@ func finalizeWriting() { case line := <-logBuffer: writeLine(line, 0) case <-time.After(10 * time.Millisecond): - fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor())) + fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()) return } } diff --git a/log/trace.go b/log/trace.go index a345f5d..f90b306 100644 --- a/log/trace.go +++ b/log/trace.go @@ -89,7 +89,7 @@ func (tracer *ContextTracer) Submit() { if !started.IsSet() { // a bit resource intense, but keeps logs before logging started. - // FIXME: create option to disable logging + // TODO: create option to disable logging go func() { <-startedSignal tracer.Submit() diff --git a/modules/events.go b/modules/events.go index 0768ebd..552a2c3 100644 --- a/modules/events.go +++ b/modules/events.go @@ -17,7 +17,9 @@ type eventHook struct { // TriggerEvent executes all hook functions registered to the specified event. func (m *Module) TriggerEvent(event string, data interface{}) { - go m.processEventTrigger(event, data) + if m.OnlineSoon() { + go m.processEventTrigger(event, data) + } } func (m *Module) processEventTrigger(event string, data interface{}) { @@ -31,18 +33,35 @@ func (m *Module) processEventTrigger(event string, data interface{}) { } for _, hook := range hooks { - if !hook.hookingModule.ShutdownInProgress() { + if hook.hookingModule.OnlineSoon() { go m.runEventHook(hook, event, data) } } } func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) { - if !hook.hookingModule.Started.IsSet() { + // check if source module is ready for handling + if m.Status() != StatusOnline { // target module has not yet fully started, wait until start is complete select { - case <-startCompleteSignal: - case <-shutdownSignal: + case <-m.StartCompleted(): + // continue with hook execution + case <-hook.hookingModule.Stopping(): + return + case <-m.Stopping(): + return + } + } + + // check if destionation module is ready for handling + if hook.hookingModule.Status() != StatusOnline { + // target module has not yet fully started, wait until start is complete + select { + case <-hook.hookingModule.StartCompleted(): + // continue with hook execution + case <-hook.hookingModule.Stopping(): + return + case <-m.Stopping(): return } } @@ -69,7 +88,7 @@ func (m *Module) RegisterEvent(event string) { } } -// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until all modules completed starting. +// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until the modules completed starting. func (m *Module) RegisterEventHook(module string, event string, description string, fn func(context.Context, interface{}) error) error { // get target module var eventModule *Module @@ -77,9 +96,7 @@ func (m *Module) RegisterEventHook(module string, event string, description stri eventModule = m } else { var ok bool - modulesLock.RLock() eventModule, ok = modules[module] - modulesLock.RUnlock() if !ok { return fmt.Errorf(`module "%s" does not exist`, module) } diff --git a/modules/flags.go b/modules/flags.go index d9ad1e8..7c5da26 100644 --- a/modules/flags.go +++ b/modules/flags.go @@ -1,18 +1,22 @@ package modules -import "flag" +import ( + "flag" + "fmt" +) var ( // HelpFlag triggers printing flag.Usage. It's exported for custom help handling. - HelpFlag bool + HelpFlag bool + printGraphFlag bool ) func init() { flag.BoolVar(&HelpFlag, "help", false, "print help") + flag.BoolVar(&printGraphFlag, "print-module-graph", false, "print the module dependency graph") } func parseFlags() error { - // parse flags flag.Parse() @@ -21,5 +25,36 @@ func parseFlags() error { return ErrCleanExit } + if printGraphFlag { + printGraph() + return ErrCleanExit + } + return nil } + +func printGraph() { + // mark roots + for _, module := range modules { + if len(module.depReverse) == 0 { + // is root, dont print deps in dep tree + module.stopFlag.Set() + } + } + // print + for _, module := range modules { + if module.stopFlag.IsSet() { + // print from root + printModuleGraph("", module, true) + } + } +} + +func printModuleGraph(prefix string, module *Module, root bool) { + fmt.Printf("%sā”œā”€ā”€ %s\n", prefix, module.Name) + if root || !module.stopFlag.IsSet() { + for _, dep := range module.Dependencies() { + printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false) + } + } +} diff --git a/modules/mgmt.go b/modules/mgmt.go new file mode 100644 index 0000000..f5f082e --- /dev/null +++ b/modules/mgmt.go @@ -0,0 +1,107 @@ +package modules + +import ( + "context" + + "github.com/safing/portbase/log" + "github.com/tevino/abool" +) + +var ( + moduleMgmtEnabled = abool.NewBool(false) + modulesChangeNotifyFn func(*Module) +) + +// Enable enables the module. Only has an effect if module management is enabled. +func (m *Module) Enable() (changed bool) { + return m.enabled.SetToIf(false, true) +} + +// Disable disables the module. Only has an effect if module management is enabled. +func (m *Module) Disable() (changed bool) { + return m.enabled.SetToIf(true, false) +} + +// SetEnabled sets the module to the desired enabled state. Only has an effect if module management is enabled. +func (m *Module) SetEnabled(enable bool) (changed bool) { + if enable { + return m.Enable() + } + return m.Disable() +} + +// Enabled returns wether or not the module is currently enabled. +func (m *Module) Enabled() bool { + return m.enabled.IsSet() +} + +// EnableModuleManagement enables the module management functionality within modules. The supplied notify function will be called whenever the status of a module changes. The affected module will be in the parameter. You will need to manually enable modules, else nothing will start. +func EnableModuleManagement(changeNotifyFn func(*Module)) { + if moduleMgmtEnabled.SetToIf(false, true) { + modulesChangeNotifyFn = changeNotifyFn + } +} + +func (m *Module) notifyOfChange() { + if moduleMgmtEnabled.IsSet() && modulesChangeNotifyFn != nil { + m.StartWorker("notify of change", func(ctx context.Context) error { + modulesChangeNotifyFn(m) + return nil + }) + } +} + +// ManageModules triggers the module manager to react to recent changes of enabled modules. +func ManageModules() error { + // check if enabled + if !moduleMgmtEnabled.IsSet() { + return nil + } + + // lock mgmt + mgmtLock.Lock() + defer mgmtLock.Unlock() + + log.Info("modules: managing changes") + + // build new dependency tree + buildEnabledTree() + + // stop unneeded modules + lastErr := stopModules() + if lastErr != nil { + log.Warning(lastErr.Error()) + } + + // start needed modules + err := startModules() + if err != nil { + log.Warning(err.Error()) + lastErr = err + } + + log.Info("modules: finished managing") + return lastErr +} + +func buildEnabledTree() { + // reset marked dependencies + for _, m := range modules { + m.enabledAsDependency.UnSet() + } + + // mark dependencies + for _, m := range modules { + if m.enabled.IsSet() { + m.markDependencies() + } + } +} + +func (m *Module) markDependencies() { + for _, dep := range m.depModules { + if dep.enabledAsDependency.SetToIf(false, true) { + dep.markDependencies() + } + } +} diff --git a/modules/mgmt_test.go b/modules/mgmt_test.go new file mode 100644 index 0000000..3e6009b --- /dev/null +++ b/modules/mgmt_test.go @@ -0,0 +1,165 @@ +package modules + +import ( + "testing" +) + +func testModuleMgmt(t *testing.T) { + + // enable module management + EnableModuleManagement(nil) + + registerTestModule(t, "base") + registerTestModule(t, "feature1", "base") + registerTestModule(t, "base2", "base") + registerTestModule(t, "feature2", "base2") + registerTestModule(t, "sub-feature", "base") + registerTestModule(t, "feature3", "sub-feature") + registerTestModule(t, "feature4", "sub-feature") + + // enable core module + core := modules["base"] + core.Enable() + + // start and check order + err := Start() + if err != nil { + t.Error(err) + } + if changeHistory != " on:base" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // enable feature1 + feature1 := modules["feature1"] + feature1.Enable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " on:feature1" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // enable feature2 + feature2 := modules["feature2"] + feature2.Enable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " on:base2 on:feature2" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // enable feature3 + feature3 := modules["feature3"] + feature3.Enable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " on:sub-feature on:feature3" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // enable feature4 + feature4 := modules["feature4"] + feature4.Enable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " on:feature4" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // disable feature1 + feature1.Disable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " off:feature1" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // disable feature3 + feature3.Disable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + // disable feature4 + feature4.Disable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " off:feature3 off:feature4 off:sub-feature" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // enable feature4 + feature4.Enable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " on:sub-feature on:feature4" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + // disable feature4 + feature4.Disable() + // manage modules and check + err = ManageModules() + if err != nil { + t.Fatal(err) + return + } + if changeHistory != " off:feature4 off:sub-feature" { + t.Errorf("order mismatch, was %s", changeHistory) + } + changeHistory = "" + + err = Shutdown() + if err != nil { + t.Error(err) + } + if changeHistory != " off:feature2 off:base2 off:base" { + t.Errorf("order mismatch, was %s", changeHistory) + } + + // reset history + changeHistory = "" + + // disable module management + moduleMgmtEnabled.UnSet() + + resetTestEnvironment() +} diff --git a/modules/microtasks.go b/modules/microtasks.go index a43d427..cf14832 100644 --- a/modules/microtasks.go +++ b/modules/microtasks.go @@ -172,7 +172,7 @@ func microTaskScheduler() { microTaskManageLoop: for { - if shutdownSignalClosed.IsSet() { + if shutdownFlag.IsSet() { close(mediumPriorityClearance) close(lowPriorityClearance) return diff --git a/modules/modules.go b/modules/modules.go index 66c3bd4..b330ddc 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -13,32 +13,44 @@ import ( ) var ( - modulesLock sync.RWMutex - modules = make(map[string]*Module) + modules = make(map[string]*Module) + mgmtLock sync.Mutex + + // lock modules when starting + modulesLocked = abool.New() // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. ErrCleanExit = errors.New("clean exit requested") ) // Module represents a module. -type Module struct { +type Module struct { //nolint:maligned // not worth the effort + sync.RWMutex + Name string - // lifecycle mgmt - Prepped *abool.AtomicBool - Started *abool.AtomicBool - Stopped *abool.AtomicBool - inTransition *abool.AtomicBool + // status mgmt + enabled *abool.AtomicBool + enabledAsDependency *abool.AtomicBool + status uint8 + + // failure status + failureStatus uint8 + failureID string + failureMsg string // lifecycle callback functions - prep func() error - start func() error - stop func() error + prepFn func() error + startFn func() error + stopFn func() error - // shutdown mgmt - Ctx context.Context - cancelCtx func() - shutdownFlag *abool.AtomicBool + // lifecycle mgmt + // start + startComplete chan struct{} + // stop + Ctx context.Context + cancelCtx func() + stopFlag *abool.AtomicBool // workers/tasks workerCnt *int32 @@ -56,30 +68,177 @@ type Module struct { depReverse []*Module } -// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead. -func (m *Module) ShutdownInProgress() bool { - return m.shutdownFlag.IsSet() +// StartCompleted returns a channel read that triggers when the module has finished starting. +func (m *Module) StartCompleted() <-chan struct{} { + m.RLock() + defer m.RUnlock() + return m.startComplete } -// ShuttingDown lets you listen for the shutdown signal. -func (m *Module) ShuttingDown() <-chan struct{} { +// Stopping returns a channel read that triggers when the module has initiated the stop procedure. +func (m *Module) Stopping() <-chan struct{} { + m.RLock() + defer m.RUnlock() return m.Ctx.Done() } -func (m *Module) shutdown() error { - // signal shutdown - m.shutdownFlag.Set() - m.cancelCtx() +// IsStopping returns whether the module has started shutting down. In most cases, you should use Stopping instead. +func (m *Module) IsStopping() bool { + return m.stopFlag.IsSet() +} - // start shutdown function - m.waitGroup.Add(1) - stopFnError := make(chan error, 1) +// Dependencies returns the module's dependencies. +func (m *Module) Dependencies() []*Module { + m.RLock() + defer m.RUnlock() + return m.depModules +} + +func (m *Module) prep(reports chan *report) { + // check and set intermediate status + m.Lock() + if m.status != StatusDead { + m.Unlock() + go func() { + reports <- &report{ + module: m, + err: fmt.Errorf("module already prepped"), + } + }() + return + } + m.status = StatusPreparing + m.Unlock() + + // run prep function go func() { - stopFnError <- m.runCtrlFn("stop module", m.stop) - m.waitGroup.Done() + var err error + if m.prepFn != nil { + // execute function + err = m.runCtrlFnWithTimeout( + "prep module", + 10*time.Second, + m.prepFn, + ) + } + // set status + if err != nil { + m.Error( + "module-failed-prep", + fmt.Sprintf("failed to prep module: %s", err.Error()), + ) + } else { + m.Lock() + m.status = StatusOffline + m.Unlock() + m.notifyOfChange() + } + // send report + reports <- &report{ + module: m, + err: err, + } }() +} - // wait for workers +func (m *Module) start(reports chan *report) { + // check and set intermediate status + m.Lock() + if m.status != StatusOffline { + m.Unlock() + go func() { + reports <- &report{ + module: m, + err: fmt.Errorf("module not offline"), + } + }() + return + } + m.status = StatusStarting + + // reset stop management + if m.cancelCtx != nil { + // trigger cancel just to be sure + m.cancelCtx() + } + m.Ctx, m.cancelCtx = context.WithCancel(context.Background()) + m.stopFlag.UnSet() + + m.Unlock() + + // run start function + go func() { + var err error + if m.startFn != nil { + // execute function + err = m.runCtrlFnWithTimeout( + "start module", + 10*time.Second, + m.startFn, + ) + } + // set status + if err != nil { + m.Error( + "module-failed-start", + fmt.Sprintf("failed to start module: %s", err.Error()), + ) + } else { + m.Lock() + m.status = StatusOnline + // init start management + close(m.startComplete) + m.Unlock() + m.notifyOfChange() + } + // send report + reports <- &report{ + module: m, + err: err, + } + }() +} + +func (m *Module) stop(reports chan *report) { + // check and set intermediate status + m.Lock() + if m.status != StatusOnline { + m.Unlock() + go func() { + reports <- &report{ + module: m, + err: fmt.Errorf("module not online"), + } + }() + return + } + m.status = StatusStopping + + // reset start management + m.startComplete = make(chan struct{}) + // init stop management + m.cancelCtx() + m.stopFlag.Set() + + m.Unlock() + + go m.stopAllTasks(reports) +} + +func (m *Module) stopAllTasks(reports chan *report) { + // start shutdown function + stopFnFinished := abool.NewBool(false) + var stopFnError error + if m.stopFn != nil { + m.waitGroup.Add(1) + go func() { + stopFnError = m.runCtrlFn("stop module", m.stopFn) + stopFnFinished.Set() + m.waitGroup.Done() + }() + } + + // wait for workers and stop fn done := make(chan struct{}) go func() { m.waitGroup.Wait() @@ -91,8 +250,9 @@ func (m *Module) shutdown() error { case <-done: case <-time.After(30 * time.Second): log.Warningf( - "%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...", + "%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...", m.Name, + stopFnFinished.IsSet(), atomic.LoadInt32(m.workerCnt), atomic.LoadInt32(m.taskCnt), atomic.LoadInt32(m.microTaskCnt), @@ -100,24 +260,37 @@ func (m *Module) shutdown() error { } // collect error - select { - case err := <-stopFnError: - return err - default: - log.Warningf( - "%s: timed out while waiting for stop function to finish, continuing shutdown...", - m.Name, + var err error + if stopFnFinished.IsSet() && stopFnError != nil { + err = stopFnError + } + // set status + if err != nil { + m.Error( + "module-failed-stop", + fmt.Sprintf("failed to stop module: %s", err.Error()), ) - return nil + } else { + m.Lock() + m.status = StatusOffline + m.Unlock() + m.notifyOfChange() + } + // send report + reports <- &report{ + module: m, + err: err, } } // Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished. func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { + if modulesLocked.IsSet() { + return nil + } + newModule := initNewModule(name, prep, start, stop, dependencies...) - modulesLock.Lock() - defer modulesLock.Unlock() // check for already existing module _, ok := modules[name] if ok { @@ -136,23 +309,22 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ... var microTaskCnt int32 newModule := &Module{ - Name: name, - Prepped: abool.NewBool(false), - Started: abool.NewBool(false), - Stopped: abool.NewBool(false), - inTransition: abool.NewBool(false), - Ctx: ctx, - cancelCtx: cancelCtx, - shutdownFlag: abool.NewBool(false), - waitGroup: sync.WaitGroup{}, - workerCnt: &workerCnt, - taskCnt: &taskCnt, - microTaskCnt: µTaskCnt, - prep: prep, - start: start, - stop: stop, - eventHooks: make(map[string][]*eventHook), - depNames: dependencies, + Name: name, + enabled: abool.NewBool(false), + enabledAsDependency: abool.NewBool(false), + prepFn: prep, + startFn: start, + stopFn: stop, + startComplete: make(chan struct{}), + Ctx: ctx, + cancelCtx: cancelCtx, + stopFlag: abool.NewBool(false), + workerCnt: &workerCnt, + taskCnt: &taskCnt, + microTaskCnt: µTaskCnt, + waitGroup: sync.WaitGroup{}, + eventHooks: make(map[string][]*eventHook), + depNames: dependencies, } return newModule @@ -177,49 +349,3 @@ func initDependencies() error { return nil } - -// ReadyToPrep returns whether all dependencies are ready for this module to prep. -func (m *Module) ReadyToPrep() bool { - if m.inTransition.IsSet() || m.Prepped.IsSet() { - return false - } - - for _, dep := range m.depModules { - if !dep.Prepped.IsSet() { - return false - } - } - - return true -} - -// ReadyToStart returns whether all dependencies are ready for this module to start. -func (m *Module) ReadyToStart() bool { - if m.inTransition.IsSet() || m.Started.IsSet() { - return false - } - - for _, dep := range m.depModules { - if !dep.Started.IsSet() { - return false - } - } - - return true -} - -// ReadyToStop returns whether all dependencies are ready for this module to stop. -func (m *Module) ReadyToStop() bool { - if !m.Started.IsSet() || m.inTransition.IsSet() || m.Stopped.IsSet() { - return false - } - - for _, revDep := range m.depReverse { - // not ready if a reverse dependency was started, but not yet stopped - if revDep.Started.IsSet() && !revDep.Stopped.IsSet() { - return false - } - } - - return true -} diff --git a/modules/modules_test.go b/modules/modules_test.go index 5989d7a..9d40fca 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -5,40 +5,36 @@ import ( "fmt" "sync" "testing" - "time" ) var ( - orderLock sync.Mutex - startOrder string - shutdownOrder string + changeHistoryLock sync.Mutex + changeHistory string ) -func testPrep(t *testing.T, name string) func() error { - return func() error { - t.Logf("prep %s\n", name) - return nil - } -} - -func testStart(t *testing.T, name string) func() error { - return func() error { - orderLock.Lock() - defer orderLock.Unlock() - t.Logf("start %s\n", name) - startOrder = fmt.Sprintf("%s>%s", startOrder, name) - return nil - } -} - -func testStop(t *testing.T, name string) func() error { - return func() error { - orderLock.Lock() - defer orderLock.Unlock() - t.Logf("stop %s\n", name) - shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name) - return nil - } +func registerTestModule(t *testing.T, name string, dependencies ...string) { + Register( + name, + func() error { + t.Logf("prep %s\n", name) + return nil + }, + func() error { + changeHistoryLock.Lock() + defer changeHistoryLock.Unlock() + t.Logf("start %s\n", name) + changeHistory = fmt.Sprintf("%s on:%s", changeHistory, name) + return nil + }, + func() error { + changeHistoryLock.Lock() + defer changeHistoryLock.Unlock() + t.Logf("stop %s\n", name) + changeHistory = fmt.Sprintf("%s off:%s", changeHistory, name) + return nil + }, + dependencies..., + ) } func testFail() error { @@ -53,135 +49,94 @@ func TestModules(t *testing.T) { t.Parallel() // Not really, just a workaround for running these tests last. t.Run("TestModuleOrder", testModuleOrder) + t.Run("TestModuleMgmt", testModuleMgmt) t.Run("TestModuleErrors", testModuleErrors) } func testModuleOrder(t *testing.T) { - Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database")) - Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database") - Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database") - Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database") + registerTestModule(t, "database") + registerTestModule(t, "stats", "database") + registerTestModule(t, "service", "database") + registerTestModule(t, "analytics", "stats", "database") err := Start() if err != nil { t.Error(err) } - if startOrder != ">database>service>stats>analytics" && - startOrder != ">database>stats>service>analytics" && - startOrder != ">database>stats>analytics>service" { - t.Errorf("start order mismatch, was %s", startOrder) + if changeHistory != " on:database on:service on:stats on:analytics" && + changeHistory != " on:database on:stats on:service on:analytics" && + changeHistory != " on:database on:stats on:analytics on:service" { + t.Errorf("start order mismatch, was %s", changeHistory) } + changeHistory = "" - var wg sync.WaitGroup - wg.Add(1) - go func() { - select { - case <-ShuttingDown(): - case <-time.After(1 * time.Second): - t.Error("did not receive shutdown signal") - } - wg.Done() - }() err = Shutdown() if err != nil { t.Error(err) } - if shutdownOrder != ">analytics>service>stats>database" && - shutdownOrder != ">analytics>stats>service>database" && - shutdownOrder != ">service>analytics>stats>database" { - t.Errorf("shutdown order mismatch, was %s", shutdownOrder) + if changeHistory != " off:analytics off:service off:stats off:database" && + changeHistory != " off:analytics off:stats off:service off:database" && + changeHistory != " off:service off:analytics off:stats off:database" { + t.Errorf("shutdown order mismatch, was %s", changeHistory) } + changeHistory = "" - wg.Wait() - - printAndRemoveModules() -} - -func printAndRemoveModules() { - modulesLock.Lock() - defer modulesLock.Unlock() - - fmt.Printf("All %d modules:\n", len(modules)) - for _, m := range modules { - fmt.Printf("module %s: %+v\n", m.Name, m) - } - - modules = make(map[string]*Module) + resetTestEnvironment() } func testModuleErrors(t *testing.T) { - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) - // test prep error - Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail")) + Register("prepfail", testFail, nil, nil) err := Start() if err == nil { t.Error("should fail") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test prep clean exit - Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit")) + Register("prepcleanexit", testCleanExit, nil, nil) err = Start() if err != ErrCleanExit { t.Error("should fail with clean exit") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test invalid dependency - Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid") + Register("database", nil, nil, nil, "invalid") err = Start() if err == nil { t.Error("should fail") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test dependency loop - Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper") - Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database") + registerTestModule(t, "database", "helper") + registerTestModule(t, "helper", "database") err = Start() if err == nil { t.Error("should fail") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test failing module start - Register("startfail", nil, testFail, testStop(t, "startfail")) + Register("startfail", nil, testFail, nil) err = Start() if err == nil { t.Error("should fail") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test failing module stop - Register("stopfail", nil, testStart(t, "stopfail"), testFail) + Register("stopfail", nil, nil, testFail) err = Start() if err != nil { t.Error("should not fail") @@ -191,10 +146,7 @@ func testModuleErrors(t *testing.T) { t.Error("should fail") } - // reset modules - modules = make(map[string]*Module) - startComplete.UnSet() - startCompleteSignal = make(chan struct{}) + resetTestEnvironment() // test help flag HelpFlag = true @@ -204,4 +156,20 @@ func testModuleErrors(t *testing.T) { } HelpFlag = false + resetTestEnvironment() +} + +func printModules() { //nolint:unused,deadcode + fmt.Printf("All %d modules:\n", len(modules)) + for _, m := range modules { + fmt.Printf("module %s: %+v\n", m.Name, m) + } +} + +func resetTestEnvironment() { + modules = make(map[string]*Module) + shutdownSignal = make(chan struct{}) + shutdownCompleteSignal = make(chan struct{}) + shutdownFlag.UnSet() + modulesLocked.UnSet() } diff --git a/modules/start.go b/modules/start.go index 6c945f3..fb2d24a 100644 --- a/modules/start.go +++ b/modules/start.go @@ -1,34 +1,38 @@ package modules import ( + "errors" "fmt" "os" "runtime" - "time" + + "github.com/tevino/abool" "github.com/safing/portbase/log" - "github.com/tevino/abool" ) var ( - startComplete = abool.NewBool(false) - startCompleteSignal = make(chan struct{}) + initialStartCompleted = abool.NewBool(false) + globalPrepFn func() error ) -// StartCompleted returns whether starting has completed. -func StartCompleted() bool { - return startComplete.IsSet() -} - -// WaitForStartCompletion returns as soon as starting has completed. -func WaitForStartCompletion() <-chan struct{} { - return startCompleteSignal +// SetGlobalPrepFn sets a global prep function that is run before all modules. This can be used to pre-initialize modules, such as setting the data root or database path. +// SetGlobalPrepFn sets a global prep function that is run before all modules. +func SetGlobalPrepFn(fn func() error) { + if globalPrepFn == nil { + globalPrepFn = fn + } } // Start starts all modules in the correct order. In case of an error, it will automatically shutdown again. func Start() error { - modulesLock.RLock() - defer modulesLock.RUnlock() + if !modulesLocked.SetToIf(false, true) { + return errors.New("module system already started") + } + + // lock mgmt + mgmtLock.Lock() + defer mgmtLock.Unlock() // start microtask scheduler go microTaskScheduler() @@ -44,10 +48,23 @@ func Start() error { // parse flags err = parseFlags() if err != nil { - fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err) + if err != ErrCleanExit { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err) + } return err } + // execute global prep fn + if globalPrepFn != nil { + err = globalPrepFn() + if err != nil { + if err != ErrCleanExit { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: %s\n", err) + } + return err + } + } + // prep modules err = prepareModules() if err != nil { @@ -65,6 +82,9 @@ func Start() error { return err } + // build dependency tree + buildEnabledTree() + // start modules log.Info("modules: initiating...") err = startModules() @@ -74,14 +94,16 @@ func Start() error { } // complete startup - log.Infof("modules: started %d modules", len(modules)) - if startComplete.SetToIf(false, true) { - close(startCompleteSignal) + if moduleMgmtEnabled.IsSet() { + log.Info("modules: initiated subsystems manager") + } else { + log.Infof("modules: started %d modules", len(modules)) } go taskQueueHandler() go taskScheduleHandler() + initialStartCompleted.Set() return nil } @@ -97,45 +119,36 @@ func prepareModules() error { reportCnt := 0 for { + waiting := 0 + // find modules to exec for _, m := range modules { - if m.ReadyToPrep() { + switch m.readyToPrep() { + case statusNothingToDo: + case statusWaiting: + waiting++ + case statusReady: execCnt++ - m.inTransition.Set() - - execM := m - go func() { - reports <- &report{ - module: execM, - err: execM.runCtrlFnWithTimeout( - "prep module", - 10*time.Second, - execM.prep, - ), - } - }() + m.prep(reports) } } - // check for dep loop - if execCnt == reportCnt { - return fmt.Errorf("modules: dependency loop detected, cannot continue") - } - - // wait for reports - rep = <-reports - rep.module.inTransition.UnSet() - if rep.err != nil { - if rep.err == ErrCleanExit { - return rep.err + if reportCnt < execCnt { + // wait for reports + rep = <-reports + if rep.err != nil { + if rep.err == ErrCleanExit { + return rep.err + } + return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + } else { + // finished + if waiting > 0 { + // check for dep loop + return fmt.Errorf("modules: dependency loop detected, cannot continue") } - return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err) - } - reportCnt++ - rep.module.Prepped.Set() - - // exit if done - if reportCnt == len(modules) { return nil } @@ -149,45 +162,36 @@ func startModules() error { reportCnt := 0 for { + waiting := 0 + // find modules to exec for _, m := range modules { - if m.ReadyToStart() { + switch m.readyToStart() { + case statusNothingToDo: + case statusWaiting: + waiting++ + case statusReady: execCnt++ - m.inTransition.Set() - - execM := m - go func() { - reports <- &report{ - module: execM, - err: execM.runCtrlFnWithTimeout( - "start module", - 60*time.Second, - execM.start, - ), - } - }() + m.start(reports) } } - // check for dep loop - if execCnt == reportCnt { - return fmt.Errorf("modules: dependency loop detected, cannot continue") - } - - // wait for reports - rep = <-reports - rep.module.inTransition.UnSet() - if rep.err != nil { - return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err) - } - reportCnt++ - rep.module.Started.Set() - log.Infof("modules: started %s", rep.module.Name) - - // exit if done - if reportCnt == len(modules) { + if reportCnt < execCnt { + // wait for reports + rep = <-reports + if rep.err != nil { + return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + log.Infof("modules: started %s", rep.module.Name) + } else { + // finished + if waiting > 0 { + // check for dep loop + return fmt.Errorf("modules: dependency loop detected, cannot continue") + } + // return last error return nil } - } } diff --git a/modules/status.go b/modules/status.go new file mode 100644 index 0000000..94ccc1b --- /dev/null +++ b/modules/status.go @@ -0,0 +1,171 @@ +package modules + +// Module Status Values +const ( + StatusDead uint8 = 0 // not prepared, not started + StatusPreparing uint8 = 1 + StatusOffline uint8 = 2 // prepared, not started + StatusStopping uint8 = 3 + StatusStarting uint8 = 4 + StatusOnline uint8 = 5 // online and running +) + +// Module Failure Status Values +const ( + FailureNone uint8 = 0 + FailureHint uint8 = 1 + FailureWarning uint8 = 2 + FailureError uint8 = 3 +) + +// ready status +const ( + statusWaiting uint8 = iota + statusReady + statusNothingToDo +) + +// Online returns whether the module is online. +func (m *Module) Online() bool { + return m.Status() == StatusOnline +} + +// OnlineSoon returns whether the module is or is about to be online. +func (m *Module) OnlineSoon() bool { + if moduleMgmtEnabled.IsSet() && + !m.enabled.IsSet() && + !m.enabledAsDependency.IsSet() { + return false + } + return !m.stopFlag.IsSet() +} + +// Status returns the current module status. +func (m *Module) Status() uint8 { + m.RLock() + defer m.RUnlock() + + return m.status +} + +// FailureStatus returns the current failure status, ID and message. +func (m *Module) FailureStatus() (failureStatus uint8, failureID, failureMsg string) { + m.RLock() + defer m.RUnlock() + + return m.failureStatus, m.failureID, m.failureMsg +} + +// Hint sets failure status to hint. This is a somewhat special failure status, as the module is believed to be working correctly, but there is an important module specific information to convey. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans. +func (m *Module) Hint(failureID, failureMsg string) { + m.Lock() + defer m.Unlock() + + m.failureStatus = FailureHint + m.failureID = failureID + m.failureMsg = failureMsg + + m.notifyOfChange() +} + +// Warning sets failure status to warning. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans. +func (m *Module) Warning(failureID, failureMsg string) { + m.Lock() + defer m.Unlock() + + m.failureStatus = FailureWarning + m.failureID = failureID + m.failureMsg = failureMsg + + m.notifyOfChange() +} + +// Error sets failure status to error. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans. +func (m *Module) Error(failureID, failureMsg string) { + m.Lock() + defer m.Unlock() + + m.failureStatus = FailureError + m.failureID = failureID + m.failureMsg = failureMsg + + m.notifyOfChange() +} + +// Resolve removes the failure state from the module if the given failureID matches the current failure ID. If the given failureID is an empty string, Resolve removes any failure state. +func (m *Module) Resolve(failureID string) { + m.Lock() + defer m.Unlock() + + if failureID == "" || failureID == m.failureID { + m.failureStatus = FailureNone + m.failureID = "" + m.failureMsg = "" + } + + m.notifyOfChange() +} + +// readyToPrep returns whether all dependencies are ready for this module to prep. +func (m *Module) readyToPrep() uint8 { + // check if valid state for prepping + if m.Status() != StatusDead { + return statusNothingToDo + } + + for _, dep := range m.depModules { + if dep.Status() < StatusOffline { + return statusWaiting + } + } + + return statusReady +} + +// readyToStart returns whether all dependencies are ready for this module to start. +func (m *Module) readyToStart() uint8 { + // check if start is wanted + if moduleMgmtEnabled.IsSet() { + if !m.enabled.IsSet() && !m.enabledAsDependency.IsSet() { + return statusNothingToDo + } + } + + // check if valid state for starting + if m.Status() != StatusOffline { + return statusNothingToDo + } + + // check if all dependencies are ready + for _, dep := range m.depModules { + if dep.Status() < StatusOnline { + return statusWaiting + } + } + + return statusReady +} + +// readyToStop returns whether all dependencies are ready for this module to stop. +func (m *Module) readyToStop() uint8 { + // check if stop is wanted + if moduleMgmtEnabled.IsSet() && !shutdownFlag.IsSet() { + if m.enabled.IsSet() || m.enabledAsDependency.IsSet() { + return statusNothingToDo + } + } + + // check if valid state for stopping + if m.Status() != StatusOnline { + return statusNothingToDo + } + + for _, revDep := range m.depReverse { + // not ready if a reverse dependency was started, but not yet stopped + if revDep.Status() > StatusOffline { + return statusWaiting + } + } + + return statusReady +} diff --git a/modules/stop.go b/modules/stop.go index 17e8b7b..cbfefad 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -10,12 +10,17 @@ import ( ) var ( - shutdownSignal = make(chan struct{}) - shutdownSignalClosed = abool.NewBool(false) + shutdownSignal = make(chan struct{}) + shutdownFlag = abool.NewBool(false) shutdownCompleteSignal = make(chan struct{}) ) +// IsShuttingDown returns whether the global shutdown is in progress. +func IsShuttingDown() bool { + return shutdownFlag.IsSet() +} + // ShuttingDown returns a channel read on the global shutdown signal. func ShuttingDown() <-chan struct{} { return shutdownSignal @@ -23,18 +28,19 @@ func ShuttingDown() <-chan struct{} { // Shutdown stops all modules in the correct order. func Shutdown() error { + // lock mgmt + mgmtLock.Lock() + defer mgmtLock.Unlock() - if shutdownSignalClosed.SetToIf(false, true) { + if shutdownFlag.SetToIf(false, true) { close(shutdownSignal) } else { // shutdown was already issued return errors.New("shutdown already initiated") } - if startComplete.IsSet() { + if initialStartCompleted.IsSet() { log.Warning("modules: starting shutdown...") - modulesLock.Lock() - defer modulesLock.Unlock() } else { log.Warning("modules: aborting, shutting down...") } @@ -61,46 +67,42 @@ func stopModules() error { // get number of started modules startedCnt := 0 for _, m := range modules { - if m.Started.IsSet() { + if m.Status() >= StatusStarting { startedCnt++ } } for { + waiting := 0 + // find modules to exec for _, m := range modules { - if m.ReadyToStop() { + switch m.readyToStop() { + case statusNothingToDo: + case statusWaiting: + waiting++ + case statusReady: execCnt++ - m.inTransition.Set() - - execM := m - go func() { - reports <- &report{ - module: execM, - err: execM.shutdown(), - } - }() + m.stop(reports) } } - // check for dep loop - if execCnt == reportCnt { - return fmt.Errorf("modules: dependency loop detected, cannot continue") - } - - // wait for reports - rep = <-reports - rep.module.inTransition.UnSet() - if rep.err != nil { - lastErr = rep.err - log.Warningf("modules: could not stop module %s: %s", rep.module.Name, rep.err) - } - reportCnt++ - rep.module.Stopped.Set() - log.Infof("modules: stopped %s", rep.module.Name) - - // exit if done - if reportCnt == startedCnt { + if reportCnt < execCnt { + // wait for reports + rep = <-reports + if rep.err != nil { + lastErr = rep.err + log.Warningf("modules: could not stop module %s: %s", rep.module.Name, rep.err) + } + reportCnt++ + log.Infof("modules: stopped %s", rep.module.Name) + } else { + // finished + if waiting > 0 { + // check for dep loop + return fmt.Errorf("modules: dependency loop detected, cannot continue") + } + // return last error return lastErr } } diff --git a/modules/subsystems/module.go b/modules/subsystems/module.go new file mode 100644 index 0000000..b7e4a61 --- /dev/null +++ b/modules/subsystems/module.go @@ -0,0 +1,123 @@ +package subsystems + +import ( + "context" + "flag" + "fmt" + "strings" + + "github.com/safing/portbase/database" + _ "github.com/safing/portbase/database/dbmodule" // database module is required + "github.com/safing/portbase/modules" +) + +const ( + configChangeEvent = "config change" + subsystemsStatusChange = "status change" +) + +var ( + module *modules.Module + printGraphFlag bool + + databaseKeySpace string + db = database.NewInterface(nil) +) + +func init() { + // enable partial starting + modules.EnableModuleManagement(handleModuleChanges) + + // register module and enable it for starting + module = modules.Register("subsystems", prep, start, nil, "config", "database", "base") + module.Enable() + + // register event for changes in the subsystem + module.RegisterEvent(subsystemsStatusChange) + + flag.BoolVar(&printGraphFlag, "print-subsystem-graph", false, "print the subsystem module dependency graph") +} + +func prep() error { + if printGraphFlag { + printGraph() + return modules.ErrCleanExit + } + + return module.RegisterEventHook("config", configChangeEvent, "control subsystems", handleConfigChanges) +} + +func start() error { + // lock registration + subsystemsLocked.Set() + + // lock slice and map + subsystemsLock.Lock() + // go through all dependencies + seen := make(map[string]struct{}) + for _, sub := range subsystems { + // mark subsystem module as seen + seen[sub.module.Name] = struct{}{} + } + for _, sub := range subsystems { + // add main module + sub.Modules = append(sub.Modules, statusFromModule(sub.module)) + // add dependencies + sub.addDependencies(sub.module, seen) + } + // unlock + subsystemsLock.Unlock() + + // apply config + module.StartWorker("initial subsystem configuration", func(ctx context.Context) error { + return handleConfigChanges(module.Ctx, nil) + }) + return nil +} + +func (sub *Subsystem) addDependencies(module *modules.Module, seen map[string]struct{}) { + for _, module := range module.Dependencies() { + _, ok := seen[module.Name] + if !ok { + // add dependency to modules + sub.Modules = append(sub.Modules, statusFromModule(module)) + // mark as seen + seen[module.Name] = struct{}{} + // add further dependencies + sub.addDependencies(module, seen) + } + } +} + +// SetDatabaseKeySpace sets a key space where subsystem status +func SetDatabaseKeySpace(keySpace string) { + if databaseKeySpace == "" { + databaseKeySpace = keySpace + + if !strings.HasSuffix(databaseKeySpace, "/") { + databaseKeySpace += "/" + } + } +} + +func printGraph() { + // unmark subsystems module + module.Disable() + // mark roots + for _, sub := range subsystems { + sub.module.Enable() // mark as tree root + } + // print + for _, sub := range subsystems { + printModuleGraph("", sub.module, true) + } +} + +func printModuleGraph(prefix string, module *modules.Module, root bool) { + fmt.Printf("%sā”œā”€ā”€ %s\n", prefix, module.Name) + if root || !module.Enabled() { + for _, dep := range module.Dependencies() { + printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false) + } + } +} diff --git a/modules/subsystems/subsystem.go b/modules/subsystems/subsystem.go new file mode 100644 index 0000000..1a8bec4 --- /dev/null +++ b/modules/subsystems/subsystem.go @@ -0,0 +1,116 @@ +package subsystems + +import ( + "sync" + + "github.com/safing/portbase/config" + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/log" + "github.com/safing/portbase/modules" +) + +// Subsystem describes a subset of modules that represent a part of a service or program to the user. +type Subsystem struct { //nolint:maligned // not worth the effort + record.Base + sync.Mutex + + ID string + Name string + Description string + module *modules.Module + + Modules []*ModuleStatus + FailureStatus uint8 // summary: worst status + + ToggleOptionKey string + toggleOption *config.Option + toggleValue func() bool + ExpertiseLevel uint8 // copied from toggleOption + ReleaseLevel uint8 // copied from toggleOption + + ConfigKeySpace string +} + +// ModuleStatus describes the status of a module. +type ModuleStatus struct { + Name string + module *modules.Module + + // status mgmt + Enabled bool + Status uint8 + + // failure status + FailureStatus uint8 + FailureID string + FailureMsg string +} + +// Save saves the Subsystem Status to the database. +func (sub *Subsystem) Save() { + if databaseKeySpace != "" { + if !sub.KeyIsSet() { + sub.SetKey(databaseKeySpace + sub.ID) + } + err := db.Put(sub) + if err != nil { + log.Errorf("subsystems: could not save subsystem status to database: %s", err) + } + } +} + +func statusFromModule(module *modules.Module) *ModuleStatus { + status := &ModuleStatus{ + Name: module.Name, + module: module, + Enabled: module.Enabled(), + Status: module.Status(), + } + status.FailureStatus, status.FailureID, status.FailureMsg = module.FailureStatus() + + return status +} + +func compareAndUpdateStatus(module *modules.Module, status *ModuleStatus) (changed bool) { + // check if enabled + enabled := module.Enabled() + if status.Enabled != enabled { + status.Enabled = enabled + changed = true + } + + // check status + statusLvl := module.Status() + if status.Status != statusLvl { + status.Status = statusLvl + changed = true + } + + // check failure status + failureStatus, failureID, failureMsg := module.FailureStatus() + if status.FailureStatus != failureStatus || + status.FailureID != failureID { + status.FailureStatus = failureStatus + status.FailureID = failureID + status.FailureMsg = failureMsg + changed = true + } + + return +} + +func (sub *Subsystem) makeSummary() { + // find worst failing module + worstFailing := &ModuleStatus{} + for _, depStatus := range sub.Modules { + if depStatus.FailureStatus > worstFailing.FailureStatus { + worstFailing = depStatus + } + } + + if worstFailing != nil { + sub.FailureStatus = worstFailing.FailureStatus + } else { + sub.FailureStatus = 0 + } +} diff --git a/modules/subsystems/subsystems.go b/modules/subsystems/subsystems.go new file mode 100644 index 0000000..1726a53 --- /dev/null +++ b/modules/subsystems/subsystems.go @@ -0,0 +1,161 @@ +package subsystems + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/tevino/abool" + + "github.com/safing/portbase/config" + "github.com/safing/portbase/modules" +) + +var ( + subsystems []*Subsystem + subsystemsMap = make(map[string]*Subsystem) + subsystemsLock sync.Mutex + subsystemsLocked = abool.New() + + handlingConfigChanges = abool.New() +) + +// Register registers a new subsystem. The given option must be a bool option. Should be called in init() directly after the modules.Register() function. The config option must not yet be registered and will be registered for you. Pass a nil option to force enable. +func Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) { + // lock slice and map + subsystemsLock.Lock() + defer subsystemsLock.Unlock() + + // check if registration is closed + if subsystemsLocked.IsSet() { + panic("subsystems can only be registered in prep phase or earlier") + } + + // check if already registered + _, ok := subsystemsMap[name] + if ok { + panic(fmt.Sprintf(`subsystem "%s" already registered`, name)) + } + + // create new + new := &Subsystem{ + ID: id, + Name: name, + Description: description, + module: module, + toggleOption: option, + ConfigKeySpace: configKeySpace, + } + if new.toggleOption != nil { + new.ToggleOptionKey = new.toggleOption.Key + new.ExpertiseLevel = new.toggleOption.ExpertiseLevel + new.ReleaseLevel = new.toggleOption.ReleaseLevel + } + + // register config + if option != nil { + err := config.Register(option) + if err != nil { + panic(fmt.Sprintf("failed to register config: %s", err)) + } + new.toggleValue = config.GetAsBool(new.ToggleOptionKey, false) + } else { + // force enabled + new.toggleValue = func() bool { return true } + } + + // add to lists + subsystemsMap[name] = new + subsystems = append(subsystems, new) +} + +func handleModuleChanges(m *modules.Module) { + // check if ready + if !subsystemsLocked.IsSet() { + return + } + + // check if shutting down + if modules.IsShuttingDown() { + return + } + + // find module status + var moduleSubsystem *Subsystem + var moduleStatus *ModuleStatus +subsystemLoop: + for _, subsystem := range subsystems { + for _, status := range subsystem.Modules { + if m.Name == status.Name { + moduleSubsystem = subsystem + moduleStatus = status + break subsystemLoop + } + } + } + // abort if not found + if moduleSubsystem == nil || moduleStatus == nil { + return + } + + // update status + moduleSubsystem.Lock() + changed := compareAndUpdateStatus(m, moduleStatus) + if changed { + moduleSubsystem.makeSummary() + } + moduleSubsystem.Unlock() + + // save + if changed { + moduleSubsystem.Save() + } +} + +func handleConfigChanges(ctx context.Context, data interface{}) error { + // check if ready + if !subsystemsLocked.IsSet() { + return nil + } + + // potentially catch multiple changes + if handlingConfigChanges.SetToIf(false, true) { + time.Sleep(100 * time.Millisecond) + handlingConfigChanges.UnSet() + } else { + return nil + } + + // don't do anything if we are already shutting down globally + if modules.IsShuttingDown() { + return nil + } + + // only run one instance at any time + subsystemsLock.Lock() + defer subsystemsLock.Unlock() + + var changed bool + for _, subsystem := range subsystems { + if subsystem.module.SetEnabled(subsystem.toggleValue()) { + // if changed + changed = true + } + } + + // trigger module management if any setting was changed + if changed { + err := modules.ManageModules() + if err != nil { + module.Error( + "modulemgmt-failed", + fmt.Sprintf("The subsystem framework failed to start or stop one or more modules.\nError: %s\nCheck logs for more information.", err), + ) + } else { + module.Resolve("modulemgmt-failed") + } + } + + return nil +} diff --git a/modules/subsystems/subsystems_test.go b/modules/subsystems/subsystems_test.go new file mode 100644 index 0000000..601b668 --- /dev/null +++ b/modules/subsystems/subsystems_test.go @@ -0,0 +1,123 @@ +package subsystems + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/safing/portbase/config" + "github.com/safing/portbase/dataroot" + "github.com/safing/portbase/modules" +) + +func TestSubsystems(t *testing.T) { + // tmp dir for data root (db & config) + tmpDir, err := ioutil.TempDir("", "portbase-testing-") + // initialize data dir + if err == nil { + err = dataroot.Initialize(tmpDir, 0755) + } + // handle setup error + if err != nil { + t.Fatal(err) + } + + // register + + baseModule := modules.Register("base", nil, nil, nil) + Register( + "base", + "Base", + "Framework Groundwork", + baseModule, + "config:base", + nil, + ) + + feature1 := modules.Register("feature1", nil, nil, nil) + Register( + "feature-one", + "Feature One", + "Provides feature one", + feature1, + "config:feature1", + &config.Option{ + Name: "Enable Feature One", + Key: "config:subsystems/feature1", + Description: "This option enables feature 1", + OptType: config.OptTypeBool, + DefaultValue: false, + }, + ) + sub1 := subsystemsMap["Feature One"] + + feature2 := modules.Register("feature2", nil, nil, nil) + Register( + "feature-two", + "Feature Two", + "Provides feature two", + feature2, + "config:feature2", + &config.Option{ + Name: "Enable Feature One", + Key: "config:subsystems/feature2", + Description: "This option enables feature 2", + OptType: config.OptTypeBool, + DefaultValue: false, + }, + ) + + // start + err = modules.Start() + if err != nil { + t.Fatal(err) + } + + // test + + // let module fail + feature1.Error("test-fail", "Testing Fail") + time.Sleep(10 * time.Millisecond) + if sub1.FailureStatus != modules.FailureError { + t.Fatal("error did not propagate") + } + + // resolve + feature1.Resolve("test-fail") + time.Sleep(10 * time.Millisecond) + if sub1.FailureStatus != modules.FailureNone { + t.Fatal("error resolving did not propagate") + } + + // update settings + err = config.SetConfigOption("config:subsystems/feature2", true) + if err != nil { + t.Fatal(err) + return + } + time.Sleep(200 * time.Millisecond) + if !feature2.Enabled() { + t.Fatal("failed to enable feature2") + } + if feature2.Status() != modules.StatusOnline { + t.Fatal("feature2 did not start") + } + + // update settings + err = config.SetConfigOption("config:subsystems/feature2", false) + if err != nil { + t.Fatal(err) + return + } + time.Sleep(200 * time.Millisecond) + if feature2.Enabled() { + t.Fatal("failed to disable feature2") + } + if feature2.Status() != modules.StatusOffline { + t.Fatal("feature2 did not stop") + } + + // clean up and exit + os.RemoveAll(tmpDir) +} diff --git a/modules/tasks.go b/modules/tasks.go index 2af4fb2..2fe08d9 100644 --- a/modules/tasks.go +++ b/modules/tasks.go @@ -8,21 +8,24 @@ import ( "sync/atomic" "time" - "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/tevino/abool" ) // Task is managed task bound to a module. type Task struct { name string module *Module - taskFn func(context.Context, *Task) + taskFn func(context.Context, *Task) error - queued bool - canceled bool - executing bool - cancelFunc func() + queued bool + canceled bool + executing bool + + // these are populated at task creation + // ctx is canceled when task is shutdown -> all tasks become canceled + ctx context.Context + cancelCtx func() executeAt time.Time repeat time.Duration @@ -59,25 +62,46 @@ const ( ) // NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed. -func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task { +func (m *Module) NewTask(name string, fn func(context.Context, *Task) error) *Task { + m.Lock() + defer m.Unlock() + if m == nil { log.Errorf(`modules: cannot create task "%s" with nil module`, name) + return &Task{ + name: name, + module: &Module{Name: "[NONE]"}, + canceled: true, + } + } + if m.Ctx == nil || !m.OnlineSoon() { + log.Errorf(`modules: tasks should only be started when the module is online or starting`) + return &Task{ + name: name, + module: m, + canceled: true, + } } - return &Task{ + // create new task + new := &Task{ name: name, module: m, taskFn: fn, maxDelay: defaultMaxDelay, } + + // create context + new.ctx, new.cancelCtx = context.WithCancel(m.Ctx) + + return new } func (t *Task) isActive() bool { - if t.module == nil { + if t.canceled { return false } - - return !t.canceled && !t.module.ShutdownInProgress() + return t.module.OnlineSoon() } func (t *Task) prepForQueueing() (ok bool) { @@ -197,45 +221,15 @@ func (t *Task) Repeat(interval time.Duration) *Task { func (t *Task) Cancel() { t.lock.Lock() t.canceled = true - if t.cancelFunc != nil { - t.cancelFunc() + if t.cancelCtx != nil { + t.cancelCtx() } t.lock.Unlock() } -func (t *Task) runWithLocking() { - if t.module == nil { - return - } - - // wait for good timeslot regarding microtasks - select { - case <-taskTimeslot: - case <-time.After(maxTimeslotWait): - } - - t.lock.Lock() - - // check state, return if already executing or inactive - if t.executing || !t.isActive() { - t.lock.Unlock() - return - } - t.executing = true - - // get list elements - queueElement := t.queueElement - prioritizedQueueElement := t.prioritizedQueueElement - scheduleListElement := t.scheduleListElement - - // create context - var taskCtx context.Context - taskCtx, t.cancelFunc = context.WithCancel(t.module.Ctx) - - t.lock.Unlock() - +func (t *Task) removeFromQueues() { // remove from lists - if queueElement != nil { + if t.queueElement != nil { queuesLock.Lock() taskQueue.Remove(t.queueElement) queuesLock.Unlock() @@ -243,7 +237,7 @@ func (t *Task) runWithLocking() { t.queueElement = nil t.lock.Unlock() } - if prioritizedQueueElement != nil { + if t.prioritizedQueueElement != nil { queuesLock.Lock() prioritizedTaskQueue.Remove(t.prioritizedQueueElement) queuesLock.Unlock() @@ -251,7 +245,7 @@ func (t *Task) runWithLocking() { t.prioritizedQueueElement = nil t.lock.Unlock() } - if scheduleListElement != nil { + if t.scheduleListElement != nil { scheduleLock.Lock() taskSchedule.Remove(t.scheduleListElement) scheduleLock.Unlock() @@ -259,14 +253,62 @@ func (t *Task) runWithLocking() { t.scheduleListElement = nil t.lock.Unlock() } +} + +func (t *Task) runWithLocking() { + t.lock.Lock() + + // check if task is already executing + if t.executing { + t.lock.Unlock() + return + } + + // check if task is active + if !t.isActive() { + t.removeFromQueues() + t.lock.Unlock() + return + } + + // check if module was stopped + select { + case <-t.ctx.Done(): // check if module is stopped + t.removeFromQueues() + t.lock.Unlock() + return + default: + } + + t.executing = true + t.lock.Unlock() + + // wait for good timeslot regarding microtasks + select { + case <-taskTimeslot: + case <-time.After(maxTimeslotWait): + } + + // wait for module start + if !t.module.Online() { + if t.module.OnlineSoon() { + // wait + <-t.module.StartCompleted() + } else { + t.lock.Lock() + t.removeFromQueues() + t.lock.Unlock() + return + } + } // add to queue workgroup queueWg.Add(1) - go t.executeWithLocking(taskCtx, t.cancelFunc) + go t.executeWithLocking() go func() { select { - case <-taskCtx.Done(): + case <-t.ctx.Done(): case <-time.After(maxExecutionWait): } // complete queue worker (early) to allow next worker @@ -274,7 +316,7 @@ func (t *Task) runWithLocking() { }() } -func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) { +func (t *Task) executeWithLocking() { // start for module // hint: only queueWg global var is important for scheduling, others can be set here atomic.AddInt32(t.module.taskCnt, 1) @@ -306,11 +348,16 @@ func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) { t.lock.Unlock() // notify that we finished - cancelFunc() + if t.cancelCtx != nil { + t.cancelCtx() + } }() // run - t.taskFn(ctx, t) + err := t.taskFn(t.ctx, t) + if err != nil { + log.Errorf("%s: task %s failed: %s", t.module.Name, t.name, err) + } } func (t *Task) getExecuteAtWithLocking() time.Time { @@ -320,6 +367,10 @@ func (t *Task) getExecuteAtWithLocking() time.Time { } func (t *Task) addToSchedule() { + if !t.isActive() { + return + } + scheduleLock.Lock() defer scheduleLock.Unlock() // defer printTaskList(taskSchedule) // for debugging @@ -395,7 +446,7 @@ func taskQueueHandler() { queueWg.Wait() // check for shutdown - if shutdownSignalClosed.IsSet() { + if shutdownFlag.IsSet() { return } diff --git a/modules/tasks_test.go b/modules/tasks_test.go index d32e770..a7ff01d 100644 --- a/modules/tasks_test.go +++ b/modules/tasks_test.go @@ -35,22 +35,29 @@ func init() { var qtWg sync.WaitGroup var qtOutputChannel chan string var qtSleepDuration time.Duration -var qtModule = initNewModule("task test module", nil, nil, nil) +var qtModule *Module + +func init() { + qtModule = initNewModule("task test module", nil, nil, nil) + qtModule.status = StatusOnline +} // functions func queuedTaskTester(s string) { - qtModule.NewTask(s, func(ctx context.Context, t *Task) { + qtModule.NewTask(s, func(ctx context.Context, t *Task) error { time.Sleep(qtSleepDuration * 2) qtOutputChannel <- s qtWg.Done() + return nil }).Queue() } func prioritizedTaskTester(s string) { - qtModule.NewTask(s, func(ctx context.Context, t *Task) { + qtModule.NewTask(s, func(ctx context.Context, t *Task) error { time.Sleep(qtSleepDuration * 2) qtOutputChannel <- s qtWg.Done() + return nil }).Prioritize() } @@ -109,10 +116,11 @@ var stWaitCh chan bool // functions func scheduledTaskTester(s string, sched time.Time) { - qtModule.NewTask(s, func(ctx context.Context, t *Task) { + qtModule.NewTask(s, func(ctx context.Context, t *Task) error { time.Sleep(stSleepDuration) stOutputChannel <- s stWg.Done() + return nil }).Schedule(sched) } diff --git a/modules/worker.go b/modules/worker.go index d0481f9..12682e6 100644 --- a/modules/worker.go +++ b/modules/worker.go @@ -73,7 +73,7 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn lastFail := time.Now() for { - if m.ShutdownInProgress() { + if m.IsStopping() { return } diff --git a/notifications/database.go b/notifications/database.go index debeffc..58b455c 100644 --- a/notifications/database.go +++ b/notifications/database.go @@ -111,26 +111,26 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { } // Put stores a record in the database. -func (s *StorageInterface) Put(r record.Record) error { +func (s *StorageInterface) Put(r record.Record) (record.Record, error) { // record is already locked! key := r.DatabaseKey() n, err := EnsureNotification(r) if err != nil { - return ErrInvalidData + return nil, ErrInvalidData } // transform key if strings.HasPrefix(key, "all/") { key = strings.TrimPrefix(key, "all/") } else { - return ErrInvalidPath + return nil, ErrInvalidPath } // continue in goroutine go UpdateNotification(n, key) - return nil + return n, nil } // UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change. diff --git a/notifications/module.go b/notifications/module.go index 39985ae..b523a99 100644 --- a/notifications/module.go +++ b/notifications/module.go @@ -11,7 +11,7 @@ var ( ) func init() { - module = modules.Register("notifications", nil, start, nil, "base", "database") + module = modules.Register("notifications", nil, start, nil, "database", "base") } func start() error { diff --git a/notifications/notification.go b/notifications/notification.go index 52af0fc..635e43d 100644 --- a/notifications/notification.go +++ b/notifications/notification.go @@ -93,6 +93,7 @@ func (n *Notification) Save() *Notification { nots[n.ID] = n // push update + log.Tracef("notifications: pushing update for %s to subscribers", n.Key()) dbController.PushUpdate(n) // persist @@ -152,10 +153,11 @@ func (n *Notification) MakeAck() *Notification { // Response waits for the user to respond to the notification and returns the selected action. func (n *Notification) Response() <-chan string { n.lock.Lock() + defer n.lock.Unlock() + if n.actionTrigger == nil { n.actionTrigger = make(chan string) } - n.lock.Unlock() return n.actionTrigger } @@ -213,10 +215,11 @@ func (n *Notification) Delete() error { // Expired notifies the caller when the notification has expired. func (n *Notification) Expired() <-chan struct{} { n.lock.Lock() + defer n.lock.Unlock() + if n.expiredTrigger == nil { n.expiredTrigger = make(chan struct{}) } - n.lock.Unlock() return n.expiredTrigger } diff --git a/portbase.go b/portbase.go index 61c1662..8d9385e 100644 --- a/portbase.go +++ b/portbase.go @@ -1,49 +1,19 @@ package main import ( - "fmt" "os" - "os/signal" - "syscall" "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portbase/run" + // include packages here + _ "github.com/safing/portbase/api" ) func main() { - // Set Info info.Set("Portbase", "0.0.1", "GPLv3", false) - // Start - err := modules.Start() - if err != nil { - if err == modules.ErrCleanExit { - os.Exit(0) - } else { - os.Exit(1) - } - } - - // Shutdown - // catch interrupt for clean shutdown - signalCh := make(chan os.Signal, 3) - signal.Notify( - signalCh, - os.Interrupt, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - ) - select { - case <-signalCh: - fmt.Println(" ") - log.Warning("main: program was interrupted, shutting down.") - _ = modules.Shutdown() - case <-modules.ShuttingDown(): - } - + // Run + os.Exit(run.Run()) } diff --git a/rng/rng.go b/rng/rng.go index 85528f1..d0f7578 100644 --- a/rng/rng.go +++ b/rng/rng.go @@ -23,7 +23,7 @@ var ( ) func init() { - modules.Register("random", prep, Start, nil, "base") + modules.Register("random", prep, Start, nil) } func prep() error { diff --git a/template/module.go b/template/module.go new file mode 100644 index 0000000..b921d16 --- /dev/null +++ b/template/module.go @@ -0,0 +1,116 @@ +package template + +import ( + "context" + "time" + + "github.com/safing/portbase/config" + "github.com/safing/portbase/modules" + "github.com/safing/portbase/modules/subsystems" +) + +const ( + eventStateUpdate = "state update" +) + +var ( + module *modules.Module +) + +func init() { + // register base module, for database initialization + modules.Register("base", nil, nil, nil) + + // register module + module = modules.Register("template", prep, start, stop) // add dependencies... + subsystems.Register( + "template-subsystem", // ID + "Template Subsystem", // name + "This subsystem is a template for quick setup", // description + module, + "config:template", // key space for configuration options registered + &config.Option{ + Name: "Enable Template Subsystem", + Key: "config:subsystems/template", + Description: "This option enables the Template Subsystem [TEMPLATE]", + OptType: config.OptTypeBool, + DefaultValue: false, + }, + ) + + // register events that other modules can subscribe to + module.RegisterEvent(eventStateUpdate) +} + +func prep() error { + // register options + err := config.Register(&config.Option{ + Name: "language", + Key: "config:template/language", + Description: "Sets the language for the template [TEMPLATE]", + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelUser, // default + ReleaseLevel: config.ReleaseLevelStable, // default + RequiresRestart: false, // default + DefaultValue: "en", + ValidationRegex: "^[a-z]{2}$", + }) + if err != nil { + return err + } + + // register event hooks + // do this in prep() and not in start(), as we don't want to register again if module is turned off and on again + err = module.RegisterEventHook( + "template", // event source module name + "state update", // event source name + "react to state changes", // description of hook function + eventHandler, // hook function + ) + if err != nil { + return err + } + + // hint: event hooks and tasks will not be run if module isn't online + return nil +} + +func start() error { + // register tasks + module.NewTask("do something", taskFn).Queue() + + // start service worker + module.StartServiceWorker("do something", 0, serviceWorker) + + return nil +} + +func stop() error { + return nil +} + +func serviceWorker(ctx context.Context) error { + for { + select { + case <-time.After(1 * time.Second): + err := do() + if err != nil { + return err + } + case <-ctx.Done(): + return nil + } + } +} + +func taskFn(ctx context.Context, task *modules.Task) error { + return do() +} + +func eventHandler(ctx context.Context, data interface{}) error { + return do() +} + +func do() error { + return nil +} diff --git a/template/module_test.go b/template/module_test.go new file mode 100644 index 0000000..987ef74 --- /dev/null +++ b/template/module_test.go @@ -0,0 +1,51 @@ +package template + +import ( + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/safing/portbase/dataroot" + "github.com/safing/portbase/modules" +) + +func TestMain(m *testing.M) { + // enable module for testing + module.Enable() + + // tmp dir for data root (db & config) + tmpDir, err := ioutil.TempDir("", "portbase-testing-") + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) + os.Exit(1) + } + // initialize data dir + err = dataroot.Initialize(tmpDir, 0755) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) + os.Exit(1) + } + + // start modules + var exitCode int + err = modules.Start() + if err != nil { + // starting failed + fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) + exitCode = 1 + } else { + // run tests + exitCode = m.Run() + } + + // shutdown + _ = modules.Shutdown() + if modules.GetExitStatusCode() != 0 { + exitCode = modules.GetExitStatusCode() + fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) + } + // clean up and exit + os.RemoveAll(tmpDir) + os.Exit(exitCode) +} diff --git a/updater/fetch.go b/updater/fetch.go index e7e1373..8b7853b 100644 --- a/updater/fetch.go +++ b/updater/fetch.go @@ -50,6 +50,10 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status) + } + // download and write file n, err := io.Copy(atomicFile, resp.Body) if err != nil { @@ -96,6 +100,10 @@ func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status) + } + // download and write file buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength)) n, err := io.Copy(buf, resp.Body)