Release to master

This commit is contained in:
Daniel 2020-04-15 20:55:33 +02:00
commit 4899e44f60
73 changed files with 2821 additions and 849 deletions

View file

@ -7,5 +7,11 @@ linters:
- funlen
- whitespace
- wsl
- godox
- gomnd
linters-settings:
godox:
# report any comments starting with keywords, this is useful for TODO or FIXME comments that
# might be left in the code accidentally and should be resolved before merging
keywords:
- FIXME

9
Gopkg.lock generated
View file

@ -161,14 +161,6 @@
revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4"
version = "v0.8.1"
[[projects]]
branch = "develop"
digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93"
name = "github.com/safing/portmaster"
packages = ["core/structure"]
pruneopts = "UT"
revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6"
[[projects]]
digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925"
name = "github.com/satori/go.uuid"
@ -332,7 +324,6 @@
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
"github.com/hashicorp/go-version",
"github.com/safing/portmaster/core/structure",
"github.com/satori/go.uuid",
"github.com/seehuhn/fortuna",
"github.com/shirou/gopsutil/host",

View file

@ -243,8 +243,8 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
if err == nil {
data, err = r.Marshal(r, record.JSON)
}
if err == nil {
api.send(opID, dbMsgTypeError, err.Error(), nil) //nolint:nilness // FIXME: possibly false positive (golangci-lint govet/nilness)
if err != nil {
api.send(opID, dbMsgTypeError, err.Error(), nil)
return
}
api.send(opID, dbMsgTypeOk, r.Key(), data)
@ -384,9 +384,9 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
default:
api.send(opID, dbMsgTypeUpd, r.Key(), data)
}
} else if sub.Err != nil {
} else {
// sub feed ended
api.send(opID, dbMsgTypeError, sub.Err.Error(), nil)
api.send(opID, dbMsgTypeDone, "", nil)
}
}
}
@ -435,13 +435,13 @@ func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create b
return
}
// FIXME: remove transition code
if data[0] != record.JSON {
typedData := make([]byte, len(data)+1)
typedData[0] = record.JSON
copy(typedData[1:], data)
data = typedData
}
// TODO - staged for deletion: remove transition code
// if data[0] != record.JSON {
// typedData := make([]byte, len(data)+1)
// typedData[0] = record.JSON
// copy(typedData[1:], data)
// data = typedData
// }
r, err := record.NewWrapper(key, nil, data[0], data[1:])
if err != nil {

View file

@ -17,7 +17,7 @@ var (
)
func init() {
module = modules.Register("api", prep, start, stop, "base", "database", "config")
module = modules.Register("api", prep, start, stop, "database", "config")
}
func prep() error {

View file

@ -36,26 +36,30 @@ func (s *StorageInterface) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (s *StorageInterface) Put(r record.Record) error {
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
if r.Meta().Deleted > 0 {
return setConfigOption(r.DatabaseKey(), nil, false)
return r, setConfigOption(r.DatabaseKey(), nil, false)
}
acc := r.GetAccessor(r)
if acc == nil {
return errors.New("invalid data")
return nil, errors.New("invalid data")
}
val, ok := acc.Get("Value")
if !ok || val == nil {
return setConfigOption(r.DatabaseKey(), nil, false)
err := setConfigOption(r.DatabaseKey(), nil, false)
if err != nil {
return nil, err
}
return s.Get(r.DatabaseKey())
}
optionsLock.RLock()
option, ok := options[r.DatabaseKey()]
optionsLock.RUnlock()
if !ok {
return errors.New("config option does not exist")
return nil, errors.New("config option does not exist")
}
var value interface{}
@ -70,14 +74,14 @@ func (s *StorageInterface) Put(r record.Record) error {
value, ok = acc.GetBool("Value")
}
if !ok {
return errors.New("received invalid value in \"Value\"")
return nil, errors.New("received invalid value in \"Value\"")
}
err := setConfigOption(r.DatabaseKey(), value, false)
if err != nil {
return err
return nil, err
}
return nil
return option.Export()
}
// Delete deletes a record from the database.

View file

@ -5,6 +5,8 @@ package config
import (
"fmt"
"sync/atomic"
"github.com/tevino/abool"
)
// Expertise Level constants
@ -22,6 +24,8 @@ const (
var (
expertiseLevel *int32
expertiseLevelOption *Option
expertiseLevelOptionFlag = abool.New()
)
func init() {
@ -32,7 +36,7 @@ func init() {
}
func registerExpertiseLevelOption() {
err := Register(&Option{
expertiseLevelOption = &Option{
Name: "Expertise Level",
Key: expertiseLevelKey,
Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)",
@ -46,15 +50,31 @@ func registerExpertiseLevelOption() {
ExternalOptType: "string list",
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper),
})
}
err := Register(expertiseLevelOption)
if err != nil {
panic(err)
}
expertiseLevelOptionFlag.Set()
}
func updateExpertiseLevel() {
new := findStringValue(expertiseLevelKey, "")
switch new {
// check if already registered
if !expertiseLevelOptionFlag.IsSet() {
return
}
// get value
value := expertiseLevelOption.activeFallbackValue
if expertiseLevelOption.activeValue != nil {
value = expertiseLevelOption.activeValue
}
if expertiseLevelOption.activeDefaultValue != nil {
value = expertiseLevelOption.activeDefaultValue
}
// set atomic value
switch value.stringVal {
case ExpertiseLevelNameUser:
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
case ExpertiseLevelNameExpert:

View file

@ -12,14 +12,24 @@ var (
// GetAsString returns a function that returns the wanted string with high performance.
func (cs *safe) GetAsString(name string, fallback string) StringOption {
valid := getValidityFlag()
value := findStringValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeString)
value := fallback
if valueCache != nil {
value = valueCache.stringVal
}
var lock sync.Mutex
return func() string {
lock.Lock()
defer lock.Unlock()
if !valid.IsSet() {
valid = getValidityFlag()
value = findStringValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeString)
if valueCache != nil {
value = valueCache.stringVal
} else {
value = fallback
}
}
return value
}
@ -28,14 +38,24 @@ func (cs *safe) GetAsString(name string, fallback string) StringOption {
// GetAsStringArray returns a function that returns the wanted string with high performance.
func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOption {
valid := getValidityFlag()
value := findStringArrayValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
value := fallback
if valueCache != nil {
value = valueCache.stringArrayVal
}
var lock sync.Mutex
return func() []string {
lock.Lock()
defer lock.Unlock()
if !valid.IsSet() {
valid = getValidityFlag()
value = findStringArrayValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeStringArray)
if valueCache != nil {
value = valueCache.stringArrayVal
} else {
value = fallback
}
}
return value
}
@ -44,14 +64,24 @@ func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOpti
// GetAsInt returns a function that returns the wanted int with high performance.
func (cs *safe) GetAsInt(name string, fallback int64) IntOption {
valid := getValidityFlag()
value := findIntValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeInt)
value := fallback
if valueCache != nil {
value = valueCache.intVal
}
var lock sync.Mutex
return func() int64 {
lock.Lock()
defer lock.Unlock()
if !valid.IsSet() {
valid = getValidityFlag()
value = findIntValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeInt)
if valueCache != nil {
value = valueCache.intVal
} else {
value = fallback
}
}
return value
}
@ -60,14 +90,24 @@ func (cs *safe) GetAsInt(name string, fallback int64) IntOption {
// GetAsBool returns a function that returns the wanted int with high performance.
func (cs *safe) GetAsBool(name string, fallback bool) BoolOption {
valid := getValidityFlag()
value := findBoolValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeBool)
value := fallback
if valueCache != nil {
value = valueCache.boolVal
}
var lock sync.Mutex
return func() bool {
lock.Lock()
defer lock.Unlock()
if !valid.IsSet() {
valid = getValidityFlag()
value = findBoolValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeBool)
if valueCache != nil {
value = valueCache.boolVal
} else {
value = fallback
}
}
return value
}

View file

@ -15,14 +15,59 @@ type (
BoolOption func() bool
)
func getValueCache(name string, option *Option, requestedType uint8) (*Option, *valueCache) {
// get option
if option == nil {
var ok bool
optionsLock.RLock()
option, ok = options[name]
optionsLock.RUnlock()
if !ok {
log.Errorf("config: request for unregistered option: %s", name)
return nil, nil
}
}
// check type
if requestedType != option.OptType {
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType))
return option, nil
}
// lock option
option.Lock()
defer option.Unlock()
// check release level
if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil {
return option, option.activeValue
}
if option.activeDefaultValue != nil {
return option, option.activeDefaultValue
}
return option, option.activeFallbackValue
}
// GetAsString returns a function that returns the wanted string with high performance.
func GetAsString(name string, fallback string) StringOption {
valid := getValidityFlag()
value := findStringValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeString)
value := fallback
if valueCache != nil {
value = valueCache.stringVal
}
return func() string {
if !valid.IsSet() {
valid = getValidityFlag()
value = findStringValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeString)
if valueCache != nil {
value = valueCache.stringVal
} else {
value = fallback
}
}
return value
}
@ -31,11 +76,21 @@ func GetAsString(name string, fallback string) StringOption {
// GetAsStringArray returns a function that returns the wanted string with high performance.
func GetAsStringArray(name string, fallback []string) StringArrayOption {
valid := getValidityFlag()
value := findStringArrayValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeStringArray)
value := fallback
if valueCache != nil {
value = valueCache.stringArrayVal
}
return func() []string {
if !valid.IsSet() {
valid = getValidityFlag()
value = findStringArrayValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeStringArray)
if valueCache != nil {
value = valueCache.stringArrayVal
} else {
value = fallback
}
}
return value
}
@ -44,11 +99,21 @@ func GetAsStringArray(name string, fallback []string) StringArrayOption {
// GetAsInt returns a function that returns the wanted int with high performance.
func GetAsInt(name string, fallback int64) IntOption {
valid := getValidityFlag()
value := findIntValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeInt)
value := fallback
if valueCache != nil {
value = valueCache.intVal
}
return func() int64 {
if !valid.IsSet() {
valid = getValidityFlag()
value = findIntValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeInt)
if valueCache != nil {
value = valueCache.intVal
} else {
value = fallback
}
}
return value
}
@ -57,18 +122,28 @@ func GetAsInt(name string, fallback int64) IntOption {
// GetAsBool returns a function that returns the wanted int with high performance.
func GetAsBool(name string, fallback bool) BoolOption {
valid := getValidityFlag()
value := findBoolValue(name, fallback)
option, valueCache := getValueCache(name, nil, OptTypeBool)
value := fallback
if valueCache != nil {
value = valueCache.boolVal
}
return func() bool {
if !valid.IsSet() {
valid = getValidityFlag()
value = findBoolValue(name, fallback)
option, valueCache = getValueCache(name, option, OptTypeBool)
if valueCache != nil {
value = valueCache.boolVal
} else {
value = fallback
}
}
return value
}
}
// findValue find the correct value in the user or default config.
func findValue(key string) interface{} {
/*
func getAndFindValue(key string) interface{} {
optionsLock.RLock()
option, ok := options[key]
optionsLock.RUnlock()
@ -77,6 +152,13 @@ func findValue(key string) interface{} {
return nil
}
return option.findValue()
}
*/
/*
// findValue finds the preferred value in the user or default config.
func (option *Option) findValue() interface{} {
// lock option
option.Lock()
defer option.Unlock()
@ -91,88 +173,4 @@ func findValue(key string) interface{} {
return option.DefaultValue
}
// findStringValue validates and returns the value with the given key.
func findStringValue(key string, fallback string) (value string) {
result := findValue(key)
if result == nil {
return fallback
}
v, ok := result.(string)
if ok {
return v
}
return fallback
}
// findStringArrayValue validates and returns the value with the given key.
func findStringArrayValue(key string, fallback []string) (value []string) {
result := findValue(key)
if result == nil {
return fallback
}
v, ok := result.([]interface{})
if ok {
new := make([]string, len(v))
for i, val := range v {
s, ok := val.(string)
if ok {
new[i] = s
} else {
return fallback
}
}
return new
}
return fallback
}
// findIntValue validates and returns the value with the given key.
func findIntValue(key string, fallback int64) (value int64) {
result := findValue(key)
if result == nil {
return fallback
}
switch v := result.(type) {
case int:
return int64(v)
case int8:
return int64(v)
case int16:
return int64(v)
case int32:
return int64(v)
case int64:
return v
case uint:
return int64(v)
case uint8:
return int64(v)
case uint16:
return int64(v)
case uint32:
return int64(v)
case uint64:
return int64(v)
case float32:
return int64(v)
case float64:
return int64(v)
}
return fallback
}
// findBoolValue validates and returns the value with the given key.
func findBoolValue(key string, fallback bool) (value bool) {
result := findValue(key)
if result == nil {
return fallback
}
v, ok := result.(bool)
if ok {
return v
}
return fallback
}
*/

View file

@ -1,6 +1,7 @@
package config
import (
"encoding/json"
"testing"
"github.com/safing/portbase/log"
@ -39,7 +40,7 @@ func quickRegister(t *testing.T, key string, optType uint8, defaultValue interfa
}
}
func TestGet(t *testing.T) {
func TestGet(t *testing.T) { //nolint:gocognit
// reset
options = make(map[string]*Option)
@ -48,7 +49,7 @@ func TestGet(t *testing.T) {
t.Fatal(err)
}
quickRegister(t, "monkey", OptTypeInt, -1)
quickRegister(t, "monkey", OptTypeString, "c")
quickRegister(t, "zebras/zebra", OptTypeStringArray, []string{"a", "b"})
quickRegister(t, "elephant", OptTypeInt, -1)
quickRegister(t, "hot", OptTypeBool, false)
@ -56,7 +57,7 @@ func TestGet(t *testing.T) {
err = parseAndSetConfig(`
{
"monkey": "1",
"monkey": "a",
"zebras": {
"zebra": ["black", "white"]
},
@ -71,7 +72,7 @@ func TestGet(t *testing.T) {
err = parseAndSetDefaultConfig(`
{
"monkey": "0",
"monkey": "b",
"snake": "0",
"elephant": 0
}
@ -81,8 +82,8 @@ func TestGet(t *testing.T) {
}
monkey := GetAsString("monkey", "none")
if monkey() != "1" {
t.Errorf("monkey should be 1, is %s", monkey())
if monkey() != "a" {
t.Errorf("monkey should be a, is %s", monkey())
}
zebra := GetAsStringArray("zebras/zebra", []string{})
@ -131,6 +132,53 @@ func TestGet(t *testing.T) {
GetAsInt("elephant", -1)()
GetAsBool("hot", false)()
// perspective
// load data
pLoaded := make(map[string]interface{})
err = json.Unmarshal([]byte(`{
"monkey": "a",
"zebras": {
"zebra": ["black", "white"]
},
"elephant": 2,
"hot": true,
"cold": false
}`), &pLoaded)
if err != nil {
t.Fatal(err)
}
// create
p, err := NewPerspective(pLoaded)
if err != nil {
t.Fatal(err)
}
monkeyVal, ok := p.GetAsString("monkey")
if !ok || monkeyVal != "a" {
t.Errorf("[perspective] monkey should be a, is %+v", monkeyVal)
}
zebraVal, ok := p.GetAsStringArray("zebras/zebra")
if !ok || len(zebraVal) != 2 || zebraVal[0] != "black" || zebraVal[1] != "white" {
t.Errorf("[perspective] zebra should be [\"black\", \"white\"], is %+v", zebraVal)
}
elephantVal, ok := p.GetAsInt("elephant")
if !ok || elephantVal != 2 {
t.Errorf("[perspective] elephant should be 2, is %+v", elephantVal)
}
hotVal, ok := p.GetAsBool("hot")
if !ok || !hotVal {
t.Errorf("[perspective] hot should be true, is %+v", hotVal)
}
coldVal, ok := p.GetAsBool("cold")
if !ok || coldVal {
t.Errorf("[perspective] cold should be false, is %+v", coldVal)
}
}
func TestReleaseLevel(t *testing.T) {
@ -236,11 +284,9 @@ func BenchmarkGetAsStringCached(b *testing.B) {
options = make(map[string]*Option)
// Setup
err := parseAndSetConfig(`
{
err := parseAndSetConfig(`{
"monkey": "banana"
}
`)
}`)
if err != nil {
b.Fatal(err)
}
@ -257,11 +303,9 @@ func BenchmarkGetAsStringCached(b *testing.B) {
func BenchmarkGetAsStringRefetch(b *testing.B) {
// Setup
err := parseAndSetConfig(`
{
err := parseAndSetConfig(`{
"monkey": "banana"
}
`)
}`)
if err != nil {
b.Fatal(err)
}
@ -271,38 +315,34 @@ func BenchmarkGetAsStringRefetch(b *testing.B) {
// Start benchmark
for i := 0; i < b.N; i++ {
findStringValue("monkey", "no banana")
getValueCache("monkey", nil, OptTypeString)
}
}
func BenchmarkGetAsIntCached(b *testing.B) {
// Setup
err := parseAndSetConfig(`
{
"monkey": 1
}
`)
err := parseAndSetConfig(`{
"elephant": 1
}`)
if err != nil {
b.Fatal(err)
}
monkey := GetAsInt("monkey", -1)
elephant := GetAsInt("elephant", -1)
// Reset timer for precise results
b.ResetTimer()
// Start benchmark
for i := 0; i < b.N; i++ {
monkey()
elephant()
}
}
func BenchmarkGetAsIntRefetch(b *testing.B) {
// Setup
err := parseAndSetConfig(`
{
"monkey": 1
}
`)
err := parseAndSetConfig(`{
"elephant": 1
}`)
if err != nil {
b.Fatal(err)
}
@ -312,6 +352,6 @@ func BenchmarkGetAsIntRefetch(b *testing.B) {
// Start benchmark
for i := 0; i < b.N; i++ {
findIntValue("monkey", 1)
getValueCache("elephant", nil, OptTypeInt)
}
}

View file

@ -5,9 +5,9 @@ import (
"os"
"path/filepath"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/core/structure"
)
const (
@ -27,12 +27,12 @@ func SetDataRoot(root *utils.DirStructure) {
}
func init() {
module = modules.Register("config", prep, start, nil, "base", "database")
module = modules.Register("config", prep, start, nil, "database")
module.RegisterEvent(configChangeEvent)
}
func prep() error {
SetDataRoot(structure.Root())
SetDataRoot(dataroot.Root())
if dataRoot == nil {
return errors.New("data root is not set")
}

View file

@ -41,6 +41,7 @@ type Option struct {
Name string
Key string // in path format: category/sub/key
Description string
Help string
OptType uint8
ExpertiseLevel uint8
@ -52,8 +53,9 @@ type Option struct {
ExternalOptType string
ValidationRegex string
activeValue interface{} // runtime value (loaded from config file or set by user)
activeDefaultValue interface{} // runtime default value (may be set internally)
activeValue *valueCache // runtime value (loaded from config file or set by user)
activeDefaultValue *valueCache // runtime default value (may be set internally)
activeFallbackValue *valueCache // default value from option registration
compiledRegex *regexp.Regexp
}
@ -68,14 +70,14 @@ func (option *Option) Export() (record.Record, error) {
}
if option.activeValue != nil {
data, err = sjson.SetBytes(data, "Value", option.activeValue)
data, err = sjson.SetBytes(data, "Value", option.activeValue.getData(option))
if err != nil {
return nil, err
}
}
if option.activeDefaultValue != nil {
data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue)
data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue.getData(option))
if err != nil {
return nil, err
}

View file

@ -47,7 +47,7 @@ func saveConfig() error {
for key, option := range options {
option.Lock()
if option.activeValue != nil {
activeValues[key] = option.activeValue
activeValues[key] = option.activeValue.getData(option)
}
option.Unlock()
}

128
config/perspective.go Normal file
View file

@ -0,0 +1,128 @@
package config
import (
"fmt"
"github.com/safing/portbase/log"
)
// Perspective is a view on configuration data without interfering with the configuration system.
type Perspective struct {
config map[string]*perspectiveOption
}
type perspectiveOption struct {
option *Option
valueCache *valueCache
}
// NewPerspective parses the given config and returns it as a new perspective.
func NewPerspective(config map[string]interface{}) (*Perspective, error) {
// flatten config structure
flatten(config, config, "")
perspective := &Perspective{
config: make(map[string]*perspectiveOption),
}
var firstErr error
var errCnt int
optionsLock.Lock()
optionsLoop:
for key, option := range options {
// get option key from config
configValue, ok := config[key]
if !ok {
continue
}
// validate value
valueCache, err := validateValue(option, configValue)
if err != nil {
errCnt++
if firstErr == nil {
firstErr = err
}
continue optionsLoop
}
// add to perspective
perspective.config[key] = &perspectiveOption{
option: option,
valueCache: valueCache,
}
}
optionsLock.Unlock()
if firstErr != nil {
if errCnt > 0 {
return perspective, fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return perspective, firstErr
}
return perspective, nil
}
func (p *Perspective) getPerspectiveValueCache(name string, requestedType uint8) *valueCache {
// get option
pOption, ok := p.config[name]
if !ok {
// check if option exists at all
optionsLock.RLock()
_, ok = options[name]
optionsLock.RUnlock()
if !ok {
log.Errorf("config: request for unregistered option: %s", name)
}
return nil
}
// check type
if requestedType != pOption.option.OptType {
log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType))
return nil
}
// check release level
if pOption.option.ReleaseLevel > getReleaseLevel() {
return nil
}
return pOption.valueCache
}
// GetAsString returns a function that returns the wanted string with high performance.
func (p *Perspective) GetAsString(name string) (value string, ok bool) {
valueCache := p.getPerspectiveValueCache(name, OptTypeString)
if valueCache != nil {
return valueCache.stringVal, true
}
return "", false
}
// GetAsStringArray returns a function that returns the wanted string with high performance.
func (p *Perspective) GetAsStringArray(name string) (value []string, ok bool) {
valueCache := p.getPerspectiveValueCache(name, OptTypeStringArray)
if valueCache != nil {
return valueCache.stringArrayVal, true
}
return nil, false
}
// GetAsInt returns a function that returns the wanted int with high performance.
func (p *Perspective) GetAsInt(name string) (value int64, ok bool) {
valueCache := p.getPerspectiveValueCache(name, OptTypeInt)
if valueCache != nil {
return valueCache.intVal, true
}
return 0, false
}
// GetAsBool returns a function that returns the wanted int with high performance.
func (p *Perspective) GetAsBool(name string) (value bool, ok bool) {
valueCache := p.getPerspectiveValueCache(name, OptTypeBool)
if valueCache != nil {
return valueCache.boolVal, true
}
return false, false
}

View file

@ -26,14 +26,20 @@ func Register(option *Option) error {
return fmt.Errorf("failed to register option: please set option.OptType")
}
if option.ValidationRegex != "" {
var err error
if option.ValidationRegex != "" {
option.compiledRegex, err = regexp.Compile(option.ValidationRegex)
if err != nil {
return fmt.Errorf("config: could not compile option.ValidationRegex: %s", err)
}
}
option.activeFallbackValue, err = validateValue(option, option.DefaultValue)
if err != nil {
return fmt.Errorf("config: invalid default value: %s", err)
}
optionsLock.Lock()
defer optionsLock.Unlock()
options[option.Key] = option

View file

@ -15,7 +15,7 @@ func TestRegistry(t *testing.T) {
ReleaseLevel: ReleaseLevelStable,
ExpertiseLevel: ExpertiseLevelUser,
OptType: OptTypeString,
DefaultValue: "default",
DefaultValue: "water",
ValidationRegex: "^(banana|water)$",
}); err != nil {
t.Error(err)

View file

@ -5,6 +5,8 @@ package config
import (
"fmt"
"sync/atomic"
"github.com/tevino/abool"
)
// Release Level constants
@ -22,6 +24,8 @@ const (
var (
releaseLevel *int32
releaseLevelOption *Option
releaseLevelOptionFlag = abool.New()
)
func init() {
@ -32,7 +36,7 @@ func init() {
}
func registerReleaseLevelOption() {
err := Register(&Option{
releaseLevelOption = &Option{
Name: "Release Level",
Key: releaseLevelKey,
Description: "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.",
@ -46,15 +50,31 @@ func registerReleaseLevelOption() {
ExternalOptType: "string list",
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental),
})
}
err := Register(releaseLevelOption)
if err != nil {
panic(err)
}
releaseLevelOptionFlag.Set()
}
func updateReleaseLevel() {
new := findStringValue(releaseLevelKey, "")
switch new {
// check if already registered
if !releaseLevelOptionFlag.IsSet() {
return
}
// get value
value := releaseLevelOption.activeFallbackValue
if releaseLevelOption.activeValue != nil {
value = releaseLevelOption.activeValue
}
if releaseLevelOption.activeDefaultValue != nil {
value = releaseLevelOption.activeDefaultValue
}
// set atomic value
switch value.stringVal {
case ReleaseLevelNameStable:
atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
case ReleaseLevelNameBeta:

View file

@ -19,6 +19,7 @@ var (
validityFlagLock sync.RWMutex
)
// getValidityFlag returns a flag that signifies if the configuration has been changed. This flag must not be changed, only read.
func getValidityFlag() *abool.AtomicBool {
validityFlagLock.RLock()
defer validityFlagLock.RUnlock()
@ -41,14 +42,24 @@ func signalChanges() {
// setConfig sets the (prioritized) user defined config.
func setConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
optionsLock.Lock()
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
if ok {
option.activeValue = newValue
} else {
option.activeValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
errCnt++
if firstErr == nil {
firstErr = err
}
}
}
option.Unlock()
}
@ -56,19 +67,37 @@ func setConfig(newValues map[string]interface{}) error {
signalChanges()
go pushFullUpdate()
if firstErr != nil {
if errCnt > 0 {
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return firstErr
}
return nil
}
// SetDefaultConfig sets the (fallback) default config.
func SetDefaultConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
optionsLock.Lock()
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
if ok {
option.activeDefaultValue = newValue
} else {
option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
errCnt++
if firstErr == nil {
firstErr = err
}
}
}
option.Unlock()
}
@ -76,51 +105,15 @@ func SetDefaultConfig(newValues map[string]interface{}) error {
signalChanges()
go pushFullUpdate()
return nil
if firstErr != nil {
if errCnt > 0 {
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return firstErr
}
func validateValue(option *Option, value interface{}) error {
switch v := value.(type) {
case string:
if option.OptType != OptTypeString {
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
if !option.compiledRegex.MatchString(v) {
return fmt.Errorf("validation failed: string \"%s\" did not match regex for option %s", v, option.Key)
}
}
return nil
case []string:
if option.OptType != OptTypeStringArray {
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
for pos, entry := range v {
if !option.compiledRegex.MatchString(entry) {
return fmt.Errorf("validation failed: string \"%s\" at index %d did not match regex for option %s", entry, pos, option.Key)
}
}
}
return nil
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if option.OptType != OptTypeInt {
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
if !option.compiledRegex.MatchString(fmt.Sprintf("%d", v)) {
return fmt.Errorf("validation failed: number \"%d\" did not match regex for option %s", v, option.Key)
}
}
return nil
case bool:
if option.OptType != OptTypeBool {
return fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
return nil
default:
return fmt.Errorf("invalid option value type: %T", value)
}
}
// SetConfigOption sets a single value in the (prioritized) user defined config.
@ -140,9 +133,10 @@ func setConfigOption(key string, value interface{}, push bool) (err error) {
if value == nil {
option.activeValue = nil
} else {
err = validateValue(option, value)
var valueCache *valueCache
valueCache, err = validateValue(option, value)
if err == nil {
option.activeValue = value
option.activeValue = valueCache
}
}
option.Unlock()
@ -175,9 +169,10 @@ func setDefaultConfigOption(key string, value interface{}, push bool) (err error
if value == nil {
option.activeDefaultValue = nil
} else {
err = validateValue(option, value)
var valueCache *valueCache
valueCache, err = validateValue(option, value)
if err == nil {
option.activeDefaultValue = value
option.activeDefaultValue = valueCache
}
}
option.Unlock()

120
config/validate.go Normal file
View file

@ -0,0 +1,120 @@
package config
import (
"errors"
"fmt"
"math"
)
type valueCache struct {
stringVal string
stringArrayVal []string
intVal int64
boolVal bool
}
func (vc *valueCache) getData(opt *Option) interface{} {
switch opt.OptType {
case OptTypeBool:
return vc.boolVal
case OptTypeInt:
return vc.intVal
case OptTypeString:
return vc.stringVal
case OptTypeStringArray:
return vc.stringArrayVal
default:
return nil
}
}
func validateValue(option *Option, value interface{}) (*valueCache, error) { //nolint:gocyclo
switch v := value.(type) {
case string:
if option.OptType != OptTypeString {
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
if !option.compiledRegex.MatchString(v) {
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" did not match validation regex for option", option.Key, v)
}
}
return &valueCache{stringVal: v}, nil
case []interface{}:
vConverted := make([]string, len(v))
for pos, entry := range v {
s, ok := entry.(string)
if !ok {
return nil, fmt.Errorf("validation of option %s failed: element %+v at index %d is not a string", option.Key, entry, pos)
}
vConverted[pos] = s
}
// continue to next case
return validateValue(option, vConverted)
case []string:
if option.OptType != OptTypeStringArray {
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
for pos, entry := range v {
if !option.compiledRegex.MatchString(entry) {
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" at index %d did not match validation regex", option.Key, entry, pos)
}
}
}
return &valueCache{stringArrayVal: v}, nil
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64:
// uint64 is omitted, as it does not fit in a int64
if option.OptType != OptTypeInt {
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
// we need to use %v here so we handle float and int correctly.
if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) {
return nil, fmt.Errorf("validation of option %s failed: number \"%d\" did not match validation regex", option.Key, v)
}
}
switch v := value.(type) {
case int:
return &valueCache{intVal: int64(v)}, nil
case int8:
return &valueCache{intVal: int64(v)}, nil
case int16:
return &valueCache{intVal: int64(v)}, nil
case int32:
return &valueCache{intVal: int64(v)}, nil
case int64:
return &valueCache{intVal: v}, nil
case uint:
return &valueCache{intVal: int64(v)}, nil
case uint8:
return &valueCache{intVal: int64(v)}, nil
case uint16:
return &valueCache{intVal: int64(v)}, nil
case uint32:
return &valueCache{intVal: int64(v)}, nil
case float32:
// convert if float has no decimals
if math.Remainder(float64(v), 1) == 0 {
return &valueCache{intVal: int64(v)}, nil
}
return nil, fmt.Errorf("failed to convert float32 to int64 for option %s, got value %+v", option.Key, v)
case float64:
// convert if float has no decimals
if math.Remainder(v, 1) == 0 {
return &valueCache{intVal: int64(v)}, nil
}
return nil, fmt.Errorf("failed to convert float64 to int64 for option %s, got value %+v", option.Key, v)
default:
return nil, errors.New("internal error")
}
case bool:
if option.OptType != OptTypeBool {
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
return &valueCache{boolVal: v}, nil
default:
return nil, fmt.Errorf("invalid option value type for option %s: %T", option.Key, value)
}
}

30
config/validity.go Normal file
View file

@ -0,0 +1,30 @@
package config
import (
"github.com/tevino/abool"
)
// ValidityFlag is a flag that signifies if the configuration has been changed. It is not safe for concurrent use.
type ValidityFlag struct {
flag *abool.AtomicBool
}
// NewValidityFlag returns a flag that signifies if the configuration has been changed.
func NewValidityFlag() *ValidityFlag {
vf := &ValidityFlag{}
vf.Refresh()
return vf
}
// IsValid returns if the configuration is still valid.
func (vf *ValidityFlag) IsValid() bool {
return vf.flag.IsSet()
}
// Refresh refreshes the flag and makes it reusable.
func (vf *ValidityFlag) Refresh() {
validityFlagLock.RLock()
defer validityFlagLock.RUnlock()
vf.flag = validityFlag
}

View file

@ -1,6 +1,7 @@
package database
import (
"errors"
"sync"
"github.com/tevino/abool"
@ -119,10 +120,13 @@ func (c *Controller) Put(r record.Record) (err error) {
}
}
err = c.storage.Put(r)
r, err = c.storage.Put(r)
if err != nil {
return err
}
if r == nil {
return errors.New("storage returned nil record after successful put operation")
}
// process subscriptions
for _, sub := range c.subscriptions {
@ -137,6 +141,32 @@ func (c *Controller) Put(r record.Record) (err error) {
return nil
}
// PutMany stores many records in the database.
func (c *Controller) PutMany() (chan<- record.Record, <-chan error) {
c.writeLock.RLock()
defer c.writeLock.RUnlock()
if shuttingDown.IsSet() {
errs := make(chan error, 1)
errs <- ErrShuttingDown
return make(chan record.Record), errs
}
if c.ReadOnly() {
errs := make(chan error, 1)
errs <- ErrReadOnly
return make(chan record.Record), errs
}
if batcher, ok := c.storage.(storage.Batcher); ok {
return batcher.PutMany()
}
errs := make(chan error, 1)
errs <- ErrNotImplemented
return make(chan record.Record), errs
}
// Query executes the given query on the database.
func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
c.readLock.RLock()

View file

@ -10,10 +10,13 @@ import (
"testing"
"time"
"github.com/safing/portbase/database/record"
q "github.com/safing/portbase/database/query"
_ "github.com/safing/portbase/database/storage/badger"
_ "github.com/safing/portbase/database/storage/bbolt"
_ "github.com/safing/portbase/database/storage/fstree"
_ "github.com/safing/portbase/database/storage/hashmap"
)
func makeKey(dbName, key string) string {
@ -39,7 +42,10 @@ func testDatabase(t *testing.T, storageType string) {
}
// interface
db := NewInterface(nil)
db := NewInterface(&Options{
Local: true,
Internal: true,
})
// sub
sub, err := db.Subscribe(q.New(dbName).MustBeValid())
@ -107,6 +113,18 @@ func testDatabase(t *testing.T, storageType string) {
t.Fatalf("expected two records, got %d", cnt)
}
switch storageType {
case "bbolt", "hashmap":
batchPut := db.PutMany(dbName)
records := []record.Record{A, B, C, nil} // nil is to signify finish
for _, r := range records {
err = batchPut(r)
if err != nil {
t.Fatal(err)
}
}
}
err = hook.Cancel()
if err != nil {
t.Fatal(err)
@ -128,12 +146,12 @@ func TestDatabaseSystem(t *testing.T) {
os.Exit(1)
}()
testDir, err := ioutil.TempDir("", "testing-")
testDir, err := ioutil.TempDir("", "portbase-database-testing-")
if err != nil {
t.Fatal(err)
}
err = Initialize(testDir, nil)
err = InitializeWithPath(testDir)
if err != nil {
t.Fatal(err)
}
@ -142,6 +160,7 @@ func TestDatabaseSystem(t *testing.T) {
testDatabase(t, "badger")
testDatabase(t, "bbolt")
testDatabase(t, "fstree")
testDatabase(t, "hashmap")
err = MaintainRecordStates()
if err != nil {

View file

@ -4,41 +4,44 @@ import (
"errors"
"github.com/safing/portbase/database"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
)
var (
databasePath string
databaseStructureRoot *utils.DirStructure
module *modules.Module
)
func init() {
module = modules.Register("database", prep, start, stop, "base")
module = modules.Register("database", prep, start, stop)
}
// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) {
databasePath = dirPath
func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) {
if databaseStructureRoot == nil {
databaseStructureRoot = dirStructureRoot
}
}
func prep() error {
if databasePath == "" && databaseStructureRoot == nil {
return errors.New("no database location specified")
SetDatabaseLocation(dataroot.Root())
if databaseStructureRoot == nil {
return errors.New("database location not specified")
}
return nil
}
func start() error {
err := database.Initialize(databasePath, databaseStructureRoot)
err := database.Initialize(databaseStructureRoot)
if err != nil {
return err
}
registerMaintenanceTasks()
startMaintenanceTasks()
return nil
}

View file

@ -5,33 +5,23 @@ import (
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
func registerMaintenanceTasks() {
func startMaintenanceTasks() {
module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute)
module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour)
}
func maintainBasic(ctx context.Context, task *modules.Task) {
err := database.Maintain()
if err != nil {
log.Errorf("database: maintenance error: %s", err)
}
func maintainBasic(ctx context.Context, task *modules.Task) error {
return database.Maintain()
}
func maintainThorough(ctx context.Context, task *modules.Task) {
err := database.MaintainThorough()
if err != nil {
log.Errorf("database: thorough maintenance error: %s", err)
}
func maintainThorough(ctx context.Context, task *modules.Task) error {
return database.MaintainThorough()
}
func maintainRecords(ctx context.Context, task *modules.Task) {
err := database.MaintainRecordStates()
if err != nil {
log.Errorf("database: record states maintenance error: %s", err)
}
func maintainRecords(ctx context.Context, task *modules.Task) error {
return database.MaintainRecordStates()
}

View file

@ -10,4 +10,5 @@ var (
ErrPermissionDenied = errors.New("access to database record denied")
ErrReadOnly = errors.New("database is read only")
ErrShuttingDown = errors.New("database system is shutting down")
ErrNotImplemented = errors.New("not implemented by this storage")
)

View file

@ -5,6 +5,8 @@ import (
"fmt"
"time"
"github.com/tevino/abool"
"github.com/bluele/gcache"
"github.com/safing/portbase/database/accessor"
@ -170,11 +172,20 @@ func (i *Interface) InsertValue(key string, attribute string, value interface{})
}
// Put saves a record to the database.
func (i *Interface) Put(r record.Record) error {
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
func (i *Interface) Put(r record.Record) (err error) {
// get record or only database
var db *Controller
if !i.options.Internal || !i.options.Local {
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
if err != nil && err != ErrNotFound {
return err
}
} else {
db, err = getController(r.DatabaseKey())
if err != nil {
return err
}
}
r.Lock()
defer r.Unlock()
@ -186,24 +197,122 @@ func (i *Interface) Put(r record.Record) error {
}
// PutNew saves a record to the database as a new record (ie. with new timestamps).
func (i *Interface) PutNew(r record.Record) error {
_, db, err := i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
func (i *Interface) PutNew(r record.Record) (err error) {
// get record or only database
var db *Controller
if !i.options.Internal || !i.options.Local {
_, db, err = i.getRecord(r.DatabaseName(), r.DatabaseKey(), true, true)
if err != nil && err != ErrNotFound {
return err
}
} else {
db, err = getController(r.DatabaseKey())
if err != nil {
return err
}
}
r.Lock()
defer r.Unlock()
if r.Meta() == nil {
r.CreateMeta()
}
if r.Meta() != nil {
r.Meta().Reset()
}
i.options.Apply(r)
i.updateCache(r)
return db.Put(r)
}
// PutMany stores many records in the database. Warning: This is nearly a direct database access and omits many things:
// - Record locking
// - Hooks
// - Subscriptions
// - Caching
func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
interfaceBatch := make(chan record.Record, 100)
// permission check
if !i.options.Internal || !i.options.Local {
return func(r record.Record) error {
return ErrPermissionDenied
}
}
// get database
db, err := getController(dbName)
if err != nil {
return func(r record.Record) error {
return err
}
}
// start database access
dbBatch, errs := db.PutMany()
finished := abool.New()
var internalErr error
// interface options proxy
go func() {
defer close(dbBatch) // signify that we are finished
for {
select {
case r := <-interfaceBatch:
// finished?
if r == nil {
return
}
// apply options
i.options.Apply(r)
// pass along
dbBatch <- r
case <-time.After(1 * time.Second):
// bail out
internalErr = errors.New("timeout: putmany unused for too long")
finished.Set()
return
}
}
}()
return func(r record.Record) error {
// finished?
if finished.IsSet() {
// check for internal error
if internalErr != nil {
return internalErr
}
// check for previous error
select {
case err := <-errs:
return err
default:
return errors.New("batch is closed")
}
}
// finish?
if r == nil {
finished.Set()
interfaceBatch <- nil // signify that we are finished
// do not close, as this fn could be called again with nil.
return <-errs
}
// check record scope
if r.DatabaseName() != dbName {
return errors.New("record out of database scope")
}
// submit
select {
case interfaceBatch <- r:
return nil
case err := <-errs:
return err
}
}
}
// SetAbsoluteExpiry sets an absolute record expiry.
func (i *Interface) SetAbsoluteExpiry(key string, time int64) error {
r, db, err := i.getRecord(getDBFromKey, key, true, true)

View file

@ -23,16 +23,16 @@ var (
databasesStructure *utils.DirStructure
)
// Initialize initializes the database at the specified location. Supply either a path or dir structure.
func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error {
if initialized.SetToIf(false, true) {
if dirStructureRoot != nil {
rootStructure = dirStructureRoot
} else {
rootStructure = utils.NewDirStructure(dirPath, 0755)
// InitializeWithPath initializes the database at the specified location using a path.
func InitializeWithPath(dirPath string) error {
return Initialize(utils.NewDirStructure(dirPath, 0755))
}
// Initialize initializes the database at the specified location using a dir structure.
func Initialize(dirStructureRoot *utils.DirStructure) error {
if initialized.SetToIf(false, true) {
rootStructure = dirStructureRoot
// ensure root and databases dirs
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700)
err := databasesStructure.Ensure()

View file

@ -1,16 +1,10 @@
package record
import "sync"
import (
"sync"
)
type TestRecord struct {
Base
lock sync.Mutex
}
func (tm *TestRecord) Lock() {
tm.lock.Lock()
}
func (tm *TestRecord) Unlock() {
tm.lock.Unlock()
sync.Mutex
}

View file

@ -37,27 +37,19 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
offset += n
newMeta := &Meta{}
if len(metaSection) == 34 && metaSection[4] == 0 {
// TODO: remove in 2020
// backward compatibility:
// format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15)
// this must be gencode without format
_, err = newMeta.GenCodeUnmarshal(metaSection)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
} else {
_, err = dsd.Load(metaSection, newMeta)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
}
format, n, err := varint.Unpack8(data[offset:])
var format uint8 = dsd.NONE
if !newMeta.IsDeleted() {
format, n, err = varint.Unpack8(data[offset:])
if err != nil {
return nil, fmt.Errorf("could not get dsd format: %s", err)
}
offset += n
}
return &Wrapper{
Base{

View file

@ -2,10 +2,7 @@ package record
import (
"bytes"
"errors"
"testing"
"github.com/safing/portbase/container"
)
func TestWrapper(t *testing.T) {
@ -54,43 +51,4 @@ func TestWrapper(t *testing.T) {
if !bytes.Equal(testData, wrapper2.Data) {
t.Error("marshal mismatch")
}
// test new format
oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper)
if err != nil {
t.Fatal(err)
}
wrapper3, err := NewRawWrapper("test", "a", oldRaw)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(testData, wrapper3.Data) {
t.Error("marshal mismatch")
}
}
func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) {
if w.Meta() == nil {
return nil, errors.New("missing meta")
}
// version
c := container.New([]byte{1})
// meta
metaSection, err := w.meta.GenCodeMarshal(nil)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := w.Marshal(r, JSON)
if err != nil {
return nil, err
}
c.Append(dataSection)
return c.CompileData(), nil
}

View file

@ -139,7 +139,7 @@ func saveRegistry(lock bool) error {
}
// write file
// FIXME: write atomically (best effort)
// TODO: write atomically (best effort)
filePath := path.Join(rootStructure.Path, registryFileName)
return ioutil.WriteFile(filePath, data, 0600)
}

View file

@ -82,16 +82,19 @@ func (b *Badger) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (b *Badger) Put(r record.Record) error {
func (b *Badger) Put(r record.Record) (record.Record, error) {
data, err := r.MarshalRecord(r)
if err != nil {
return err
return nil, err
}
err = b.db.Update(func(txn *badger.Txn) error {
return txn.Set([]byte(r.DatabaseKey()), data)
})
return err
if err != nil {
return nil, err
}
return r, nil
}
// Delete deletes a record from the database.

View file

@ -65,7 +65,7 @@ func TestBadger(t *testing.T) {
a.SetKey("test:A")
// put record
err = db.Put(a)
_, err = db.Put(a)
if err != nil {
t.Fatal(err)
}

View file

@ -86,10 +86,10 @@ func (b *BBolt) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (b *BBolt) Put(r record.Record) error {
func (b *BBolt) Put(r record.Record) (record.Record, error) {
data, err := r.MarshalRecord(r)
if err != nil {
return err
return nil, err
}
err = b.db.Update(func(tx *bbolt.Tx) error {
@ -100,9 +100,38 @@ func (b *BBolt) Put(r record.Record) error {
return nil
})
if err != nil {
return err
return nil, err
}
return r, nil
}
// PutMany stores many records in the database.
func (b *BBolt) PutMany() (chan<- record.Record, <-chan error) {
batch := make(chan record.Record, 100)
errs := make(chan error, 1)
go func() {
err := b.db.Batch(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(bucketName)
for r := range batch {
// marshal
data, txErr := r.MarshalRecord(r)
if txErr != nil {
return txErr
}
// put
txErr = bucket.Put([]byte(r.DatabaseKey()), data)
if txErr != nil {
return txErr
}
}
return nil
})
errs <- err
}()
return batch, errs
}
// Delete deletes a record from the database.

View file

@ -65,7 +65,7 @@ func TestBBolt(t *testing.T) {
a.SetKey("test:A")
// put record
err = db.Put(a)
_, err = db.Put(a)
if err != nil {
t.Fatal(err)
}
@ -100,15 +100,15 @@ func TestBBolt(t *testing.T) {
qZ.SetKey("test:z")
qZ.CreateMeta()
// put
err = db.Put(qA)
_, err = db.Put(qA)
if err == nil {
err = db.Put(qB)
_, err = db.Put(qB)
}
if err == nil {
err = db.Put(qC)
_, err = db.Put(qC)
}
if err == nil {
err = db.Put(qZ)
_, err = db.Put(qZ)
}
if err != nil {
t.Fatal(err)

View file

@ -104,15 +104,15 @@ func (fst *FSTree) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (fst *FSTree) Put(r record.Record) error {
func (fst *FSTree) Put(r record.Record) (record.Record, error) {
dstPath, err := fst.buildFilePath(r.DatabaseKey(), true)
if err != nil {
return err
return nil, err
}
data, err := r.MarshalRecord(r)
if err != nil {
return err
return nil, err
}
err = writeFile(dstPath, data, defaultFileMode)
@ -120,15 +120,15 @@ func (fst *FSTree) Put(r record.Record) error {
// create dir and try again
err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode)
if err != nil {
return fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err)
return nil, fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err)
}
err = writeFile(dstPath, data, defaultFileMode)
if err != nil {
return fmt.Errorf("fstree: could not write file %s: %s", dstPath, err)
return nil, fmt.Errorf("fstree: could not write file %s: %s", dstPath, err)
}
}
return nil
return r, nil
}
// Delete deletes a record from the database.

View file

@ -44,12 +44,33 @@ func (hm *HashMap) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (hm *HashMap) Put(r record.Record) error {
func (hm *HashMap) Put(r record.Record) (record.Record, error) {
hm.dbLock.Lock()
defer hm.dbLock.Unlock()
hm.db[r.DatabaseKey()] = r
return nil
return r, nil
}
// PutMany stores many records in the database.
func (hm *HashMap) PutMany() (chan<- record.Record, <-chan error) {
hm.dbLock.Lock()
defer hm.dbLock.Unlock()
// we could lock for every record, but we want to have the same behaviour
// as the other storage backends, especially for testing.
batch := make(chan record.Record, 100)
errs := make(chan error, 1)
// start handler
go func() {
for r := range batch {
hm.db[r.DatabaseKey()] = r
}
errs <- nil
}()
return batch, errs
}
// Delete deletes a record from the database.

View file

@ -57,7 +57,7 @@ func TestHashMap(t *testing.T) {
a.SetKey("test:A")
// put record
err = db.Put(a)
_, err = db.Put(a)
if err != nil {
t.Fatal(err)
}
@ -86,15 +86,15 @@ func TestHashMap(t *testing.T) {
qZ.SetKey("test:z")
qZ.CreateMeta()
// put
err = db.Put(qA)
_, err = db.Put(qA)
if err == nil {
err = db.Put(qB)
_, err = db.Put(qB)
}
if err == nil {
err = db.Put(qC)
_, err = db.Put(qC)
}
if err == nil {
err = db.Put(qZ)
_, err = db.Put(qZ)
}
if err != nil {
t.Fatal(err)

View file

@ -21,8 +21,16 @@ func (i *InjectBase) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (i *InjectBase) Put(m record.Record) error {
return errNotImplemented
func (i *InjectBase) Put(m record.Record) (record.Record, error) {
return nil, errNotImplemented
}
// PutMany stores many records in the database.
func (i *InjectBase) PutMany() (batch chan record.Record, err chan error) {
batch = make(chan record.Record)
err = make(chan error, 1)
err <- errNotImplemented
return
}
// Delete deletes a record from the database.

View file

@ -9,7 +9,7 @@ import (
// Interface defines the database storage API.
type Interface interface {
Get(key string) (record.Record, error)
Put(m record.Record) error
Put(m record.Record) (record.Record, error)
Delete(key string) error
Query(q *query.Query, local, internal bool) (*iterator.Iterator, error)
@ -19,3 +19,8 @@ type Interface interface {
MaintainThorough() error
Shutdown() error
}
// Batcher defines the database storage API for backends that support batch operations.
type Batcher interface {
PutMany() (batch chan<- record.Record, errs <-chan error)
}

View file

@ -36,8 +36,24 @@ func (s *Sinkhole) Get(key string) (record.Record, error) {
}
// Put stores a record in the database.
func (s *Sinkhole) Put(m record.Record) error {
return nil
func (s *Sinkhole) Put(r record.Record) (record.Record, error) {
return r, nil
}
// PutMany stores many records in the database.
func (s *Sinkhole) PutMany() (chan<- record.Record, <-chan error) {
batch := make(chan record.Record, 100)
errs := make(chan error, 1)
// start handler
go func() {
for range batch {
// nom, nom, nom
}
errs <- nil
}()
return batch, errs
}
// Delete deletes a record from the database.

View file

@ -13,7 +13,6 @@ type Subscription struct {
canceled bool
Feed chan record.Record
Err error
}
// Cancel cancels the subscription.

27
dataroot/root.go Normal file
View file

@ -0,0 +1,27 @@
package dataroot
import (
"errors"
"os"
"github.com/safing/portbase/utils"
)
var (
root *utils.DirStructure
)
// Initialize initializes the data root directory
func Initialize(rootDir string, perm os.FileMode) error {
if root != nil {
return errors.New("already initialized")
}
root = utils.NewDirStructure(rootDir, perm)
return root.Ensure()
}
// Root returns the data root directory.
func Root() *utils.DirStructure {
return root
}

View file

@ -8,6 +8,6 @@ var (
)
func init() {
flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
flag.StringVar(&logLevelFlag, "log", "", "set log level to [trace|debug|info|warning|error|critical]")
flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug")
}

38
log/formatting_darwin.go Normal file
View file

@ -0,0 +1,38 @@
package log
const (
rightArrow = "▶"
leftArrow = "◀"
)
const (
// colorBlack = "\033[30m"
colorRed = "\033[31m"
// colorGreen = "\033[32m"
colorYellow = "\033[33m"
colorBlue = "\033[34m"
colorMagenta = "\033[35m"
colorCyan = "\033[36m"
// colorWhite = "\033[37m"
)
func (s Severity) color() string {
switch s {
case DebugLevel:
return colorCyan
case InfoLevel:
return colorBlue
case WarningLevel:
return colorYellow
case ErrorLevel:
return colorRed
case CriticalLevel:
return colorMagenta
default:
return ""
}
}
func endColor() string {
return "\033[0m"
}

View file

@ -12,7 +12,7 @@ func log(level Severity, msg string, tracer *ContextTracer) {
if !started.IsSet() {
// a bit resource intense, but keeps logs before logging started.
// FIXME: create option to disable logging
// TODO: create option to disable logging
go func() {
<-startedSignal
log(level, msg, tracer)

View file

@ -82,6 +82,7 @@ var (
logsWaiting = make(chan struct{}, 4)
logsWaitingFlag = abool.NewBool(false)
shutdownFlag = abool.NewBool(false)
shutdownSignal = make(chan struct{})
shutdownWaitGroup sync.WaitGroup
@ -136,12 +137,14 @@ func Start() (err error) {
logBuffer = make(chan *logLine, 1024)
if logLevelFlag != "" {
initialLogLevel := ParseLevel(logLevelFlag)
if initialLogLevel > 0 {
if initialLogLevel == 0 {
fmt.Fprintf(os.Stderr, "log warning: invalid log level \"%s\", falling back to level info\n", logLevelFlag)
initialLogLevel = InfoLevel
}
SetLogLevel(initialLogLevel)
} else {
err = fmt.Errorf("log warning: invalid log level \"%s\", falling back to level info", logLevelFlag)
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
}
// get and set file loglevels
@ -179,6 +182,8 @@ func Start() (err error) {
// Shutdown writes remaining log lines and then stops the log system.
func Shutdown() {
if shutdownFlag.SetToIf(false, true) {
close(shutdownSignal)
}
shutdownWaitGroup.Wait()
}

View file

@ -41,7 +41,7 @@ func writeLine(line *logLine, duplicates uint64) {
}
func startWriter() {
fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()))
fmt.Printf("%s%s %s BOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor())
shutdownWaitGroup.Add(1)
go writerManager()
@ -168,7 +168,7 @@ func finalizeWriting() {
case line := <-logBuffer:
writeLine(line, 0)
case <-time.After(10 * time.Millisecond):
fmt.Println(fmt.Sprintf("%s%s %s EOF%s", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()))
fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor())
return
}
}

View file

@ -89,7 +89,7 @@ func (tracer *ContextTracer) Submit() {
if !started.IsSet() {
// a bit resource intense, but keeps logs before logging started.
// FIXME: create option to disable logging
// TODO: create option to disable logging
go func() {
<-startedSignal
tracer.Submit()

View file

@ -17,8 +17,10 @@ type eventHook struct {
// TriggerEvent executes all hook functions registered to the specified event.
func (m *Module) TriggerEvent(event string, data interface{}) {
if m.OnlineSoon() {
go m.processEventTrigger(event, data)
}
}
func (m *Module) processEventTrigger(event string, data interface{}) {
m.eventHooksLock.RLock()
@ -31,18 +33,35 @@ func (m *Module) processEventTrigger(event string, data interface{}) {
}
for _, hook := range hooks {
if !hook.hookingModule.ShutdownInProgress() {
if hook.hookingModule.OnlineSoon() {
go m.runEventHook(hook, event, data)
}
}
}
func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) {
if !hook.hookingModule.Started.IsSet() {
// check if source module is ready for handling
if m.Status() != StatusOnline {
// target module has not yet fully started, wait until start is complete
select {
case <-startCompleteSignal:
case <-shutdownSignal:
case <-m.StartCompleted():
// continue with hook execution
case <-hook.hookingModule.Stopping():
return
case <-m.Stopping():
return
}
}
// check if destionation module is ready for handling
if hook.hookingModule.Status() != StatusOnline {
// target module has not yet fully started, wait until start is complete
select {
case <-hook.hookingModule.StartCompleted():
// continue with hook execution
case <-hook.hookingModule.Stopping():
return
case <-m.Stopping():
return
}
}
@ -69,7 +88,7 @@ func (m *Module) RegisterEvent(event string) {
}
}
// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until all modules completed starting.
// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until the modules completed starting.
func (m *Module) RegisterEventHook(module string, event string, description string, fn func(context.Context, interface{}) error) error {
// get target module
var eventModule *Module
@ -77,9 +96,7 @@ func (m *Module) RegisterEventHook(module string, event string, description stri
eventModule = m
} else {
var ok bool
modulesLock.RLock()
eventModule, ok = modules[module]
modulesLock.RUnlock()
if !ok {
return fmt.Errorf(`module "%s" does not exist`, module)
}

View file

@ -1,18 +1,22 @@
package modules
import "flag"
import (
"flag"
"fmt"
)
var (
// HelpFlag triggers printing flag.Usage. It's exported for custom help handling.
HelpFlag bool
printGraphFlag bool
)
func init() {
flag.BoolVar(&HelpFlag, "help", false, "print help")
flag.BoolVar(&printGraphFlag, "print-module-graph", false, "print the module dependency graph")
}
func parseFlags() error {
// parse flags
flag.Parse()
@ -21,5 +25,36 @@ func parseFlags() error {
return ErrCleanExit
}
if printGraphFlag {
printGraph()
return ErrCleanExit
}
return nil
}
func printGraph() {
// mark roots
for _, module := range modules {
if len(module.depReverse) == 0 {
// is root, dont print deps in dep tree
module.stopFlag.Set()
}
}
// print
for _, module := range modules {
if module.stopFlag.IsSet() {
// print from root
printModuleGraph("", module, true)
}
}
}
func printModuleGraph(prefix string, module *Module, root bool) {
fmt.Printf("%s├── %s\n", prefix, module.Name)
if root || !module.stopFlag.IsSet() {
for _, dep := range module.Dependencies() {
printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false)
}
}
}

107
modules/mgmt.go Normal file
View file

@ -0,0 +1,107 @@
package modules
import (
"context"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
var (
moduleMgmtEnabled = abool.NewBool(false)
modulesChangeNotifyFn func(*Module)
)
// Enable enables the module. Only has an effect if module management is enabled.
func (m *Module) Enable() (changed bool) {
return m.enabled.SetToIf(false, true)
}
// Disable disables the module. Only has an effect if module management is enabled.
func (m *Module) Disable() (changed bool) {
return m.enabled.SetToIf(true, false)
}
// SetEnabled sets the module to the desired enabled state. Only has an effect if module management is enabled.
func (m *Module) SetEnabled(enable bool) (changed bool) {
if enable {
return m.Enable()
}
return m.Disable()
}
// Enabled returns wether or not the module is currently enabled.
func (m *Module) Enabled() bool {
return m.enabled.IsSet()
}
// EnableModuleManagement enables the module management functionality within modules. The supplied notify function will be called whenever the status of a module changes. The affected module will be in the parameter. You will need to manually enable modules, else nothing will start.
func EnableModuleManagement(changeNotifyFn func(*Module)) {
if moduleMgmtEnabled.SetToIf(false, true) {
modulesChangeNotifyFn = changeNotifyFn
}
}
func (m *Module) notifyOfChange() {
if moduleMgmtEnabled.IsSet() && modulesChangeNotifyFn != nil {
m.StartWorker("notify of change", func(ctx context.Context) error {
modulesChangeNotifyFn(m)
return nil
})
}
}
// ManageModules triggers the module manager to react to recent changes of enabled modules.
func ManageModules() error {
// check if enabled
if !moduleMgmtEnabled.IsSet() {
return nil
}
// lock mgmt
mgmtLock.Lock()
defer mgmtLock.Unlock()
log.Info("modules: managing changes")
// build new dependency tree
buildEnabledTree()
// stop unneeded modules
lastErr := stopModules()
if lastErr != nil {
log.Warning(lastErr.Error())
}
// start needed modules
err := startModules()
if err != nil {
log.Warning(err.Error())
lastErr = err
}
log.Info("modules: finished managing")
return lastErr
}
func buildEnabledTree() {
// reset marked dependencies
for _, m := range modules {
m.enabledAsDependency.UnSet()
}
// mark dependencies
for _, m := range modules {
if m.enabled.IsSet() {
m.markDependencies()
}
}
}
func (m *Module) markDependencies() {
for _, dep := range m.depModules {
if dep.enabledAsDependency.SetToIf(false, true) {
dep.markDependencies()
}
}
}

165
modules/mgmt_test.go Normal file
View file

@ -0,0 +1,165 @@
package modules
import (
"testing"
)
func testModuleMgmt(t *testing.T) {
// enable module management
EnableModuleManagement(nil)
registerTestModule(t, "base")
registerTestModule(t, "feature1", "base")
registerTestModule(t, "base2", "base")
registerTestModule(t, "feature2", "base2")
registerTestModule(t, "sub-feature", "base")
registerTestModule(t, "feature3", "sub-feature")
registerTestModule(t, "feature4", "sub-feature")
// enable core module
core := modules["base"]
core.Enable()
// start and check order
err := Start()
if err != nil {
t.Error(err)
}
if changeHistory != " on:base" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// enable feature1
feature1 := modules["feature1"]
feature1.Enable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " on:feature1" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// enable feature2
feature2 := modules["feature2"]
feature2.Enable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " on:base2 on:feature2" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// enable feature3
feature3 := modules["feature3"]
feature3.Enable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " on:sub-feature on:feature3" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// enable feature4
feature4 := modules["feature4"]
feature4.Enable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " on:feature4" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// disable feature1
feature1.Disable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " off:feature1" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// disable feature3
feature3.Disable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
// disable feature4
feature4.Disable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " off:feature3 off:feature4 off:sub-feature" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// enable feature4
feature4.Enable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " on:sub-feature on:feature4" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
// disable feature4
feature4.Disable()
// manage modules and check
err = ManageModules()
if err != nil {
t.Fatal(err)
return
}
if changeHistory != " off:feature4 off:sub-feature" {
t.Errorf("order mismatch, was %s", changeHistory)
}
changeHistory = ""
err = Shutdown()
if err != nil {
t.Error(err)
}
if changeHistory != " off:feature2 off:base2 off:base" {
t.Errorf("order mismatch, was %s", changeHistory)
}
// reset history
changeHistory = ""
// disable module management
moduleMgmtEnabled.UnSet()
resetTestEnvironment()
}

View file

@ -172,7 +172,7 @@ func microTaskScheduler() {
microTaskManageLoop:
for {
if shutdownSignalClosed.IsSet() {
if shutdownFlag.IsSet() {
close(mediumPriorityClearance)
close(lowPriorityClearance)
return

View file

@ -13,32 +13,44 @@ import (
)
var (
modulesLock sync.RWMutex
modules = make(map[string]*Module)
mgmtLock sync.Mutex
// lock modules when starting
modulesLocked = abool.New()
// ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag.
ErrCleanExit = errors.New("clean exit requested")
)
// Module represents a module.
type Module struct {
type Module struct { //nolint:maligned // not worth the effort
sync.RWMutex
Name string
// lifecycle mgmt
Prepped *abool.AtomicBool
Started *abool.AtomicBool
Stopped *abool.AtomicBool
inTransition *abool.AtomicBool
// status mgmt
enabled *abool.AtomicBool
enabledAsDependency *abool.AtomicBool
status uint8
// failure status
failureStatus uint8
failureID string
failureMsg string
// lifecycle callback functions
prep func() error
start func() error
stop func() error
prepFn func() error
startFn func() error
stopFn func() error
// shutdown mgmt
// lifecycle mgmt
// start
startComplete chan struct{}
// stop
Ctx context.Context
cancelCtx func()
shutdownFlag *abool.AtomicBool
stopFlag *abool.AtomicBool
// workers/tasks
workerCnt *int32
@ -56,30 +68,177 @@ type Module struct {
depReverse []*Module
}
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
return m.shutdownFlag.IsSet()
// StartCompleted returns a channel read that triggers when the module has finished starting.
func (m *Module) StartCompleted() <-chan struct{} {
m.RLock()
defer m.RUnlock()
return m.startComplete
}
// ShuttingDown lets you listen for the shutdown signal.
func (m *Module) ShuttingDown() <-chan struct{} {
// Stopping returns a channel read that triggers when the module has initiated the stop procedure.
func (m *Module) Stopping() <-chan struct{} {
m.RLock()
defer m.RUnlock()
return m.Ctx.Done()
}
func (m *Module) shutdown() error {
// signal shutdown
m.shutdownFlag.Set()
m.cancelCtx()
// IsStopping returns whether the module has started shutting down. In most cases, you should use Stopping instead.
func (m *Module) IsStopping() bool {
return m.stopFlag.IsSet()
}
// start shutdown function
m.waitGroup.Add(1)
stopFnError := make(chan error, 1)
// Dependencies returns the module's dependencies.
func (m *Module) Dependencies() []*Module {
m.RLock()
defer m.RUnlock()
return m.depModules
}
func (m *Module) prep(reports chan *report) {
// check and set intermediate status
m.Lock()
if m.status != StatusDead {
m.Unlock()
go func() {
stopFnError <- m.runCtrlFn("stop module", m.stop)
reports <- &report{
module: m,
err: fmt.Errorf("module already prepped"),
}
}()
return
}
m.status = StatusPreparing
m.Unlock()
// run prep function
go func() {
var err error
if m.prepFn != nil {
// execute function
err = m.runCtrlFnWithTimeout(
"prep module",
10*time.Second,
m.prepFn,
)
}
// set status
if err != nil {
m.Error(
"module-failed-prep",
fmt.Sprintf("failed to prep module: %s", err.Error()),
)
} else {
m.Lock()
m.status = StatusOffline
m.Unlock()
m.notifyOfChange()
}
// send report
reports <- &report{
module: m,
err: err,
}
}()
}
func (m *Module) start(reports chan *report) {
// check and set intermediate status
m.Lock()
if m.status != StatusOffline {
m.Unlock()
go func() {
reports <- &report{
module: m,
err: fmt.Errorf("module not offline"),
}
}()
return
}
m.status = StatusStarting
// reset stop management
if m.cancelCtx != nil {
// trigger cancel just to be sure
m.cancelCtx()
}
m.Ctx, m.cancelCtx = context.WithCancel(context.Background())
m.stopFlag.UnSet()
m.Unlock()
// run start function
go func() {
var err error
if m.startFn != nil {
// execute function
err = m.runCtrlFnWithTimeout(
"start module",
10*time.Second,
m.startFn,
)
}
// set status
if err != nil {
m.Error(
"module-failed-start",
fmt.Sprintf("failed to start module: %s", err.Error()),
)
} else {
m.Lock()
m.status = StatusOnline
// init start management
close(m.startComplete)
m.Unlock()
m.notifyOfChange()
}
// send report
reports <- &report{
module: m,
err: err,
}
}()
}
func (m *Module) stop(reports chan *report) {
// check and set intermediate status
m.Lock()
if m.status != StatusOnline {
m.Unlock()
go func() {
reports <- &report{
module: m,
err: fmt.Errorf("module not online"),
}
}()
return
}
m.status = StatusStopping
// reset start management
m.startComplete = make(chan struct{})
// init stop management
m.cancelCtx()
m.stopFlag.Set()
m.Unlock()
go m.stopAllTasks(reports)
}
func (m *Module) stopAllTasks(reports chan *report) {
// start shutdown function
stopFnFinished := abool.NewBool(false)
var stopFnError error
if m.stopFn != nil {
m.waitGroup.Add(1)
go func() {
stopFnError = m.runCtrlFn("stop module", m.stopFn)
stopFnFinished.Set()
m.waitGroup.Done()
}()
}
// wait for workers
// wait for workers and stop fn
done := make(chan struct{})
go func() {
m.waitGroup.Wait()
@ -91,8 +250,9 @@ func (m *Module) shutdown() error {
case <-done:
case <-time.After(30 * time.Second):
log.Warningf(
"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
"%s: timed out while waiting for stopfn/workers/tasks to finish: stopFn=%v workers=%d tasks=%d microtasks=%d, continuing shutdown...",
m.Name,
stopFnFinished.IsSet(),
atomic.LoadInt32(m.workerCnt),
atomic.LoadInt32(m.taskCnt),
atomic.LoadInt32(m.microTaskCnt),
@ -100,24 +260,37 @@ func (m *Module) shutdown() error {
}
// collect error
select {
case err := <-stopFnError:
return err
default:
log.Warningf(
"%s: timed out while waiting for stop function to finish, continuing shutdown...",
m.Name,
var err error
if stopFnFinished.IsSet() && stopFnError != nil {
err = stopFnError
}
// set status
if err != nil {
m.Error(
"module-failed-stop",
fmt.Sprintf("failed to stop module: %s", err.Error()),
)
return nil
} else {
m.Lock()
m.status = StatusOffline
m.Unlock()
m.notifyOfChange()
}
// send report
reports <- &report{
module: m,
err: err,
}
}
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
if modulesLocked.IsSet() {
return nil
}
newModule := initNewModule(name, prep, start, stop, dependencies...)
modulesLock.Lock()
defer modulesLock.Unlock()
// check for already existing module
_, ok := modules[name]
if ok {
@ -137,20 +310,19 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ...
newModule := &Module{
Name: name,
Prepped: abool.NewBool(false),
Started: abool.NewBool(false),
Stopped: abool.NewBool(false),
inTransition: abool.NewBool(false),
enabled: abool.NewBool(false),
enabledAsDependency: abool.NewBool(false),
prepFn: prep,
startFn: start,
stopFn: stop,
startComplete: make(chan struct{}),
Ctx: ctx,
cancelCtx: cancelCtx,
shutdownFlag: abool.NewBool(false),
waitGroup: sync.WaitGroup{},
stopFlag: abool.NewBool(false),
workerCnt: &workerCnt,
taskCnt: &taskCnt,
microTaskCnt: &microTaskCnt,
prep: prep,
start: start,
stop: stop,
waitGroup: sync.WaitGroup{},
eventHooks: make(map[string][]*eventHook),
depNames: dependencies,
}
@ -177,49 +349,3 @@ func initDependencies() error {
return nil
}
// ReadyToPrep returns whether all dependencies are ready for this module to prep.
func (m *Module) ReadyToPrep() bool {
if m.inTransition.IsSet() || m.Prepped.IsSet() {
return false
}
for _, dep := range m.depModules {
if !dep.Prepped.IsSet() {
return false
}
}
return true
}
// ReadyToStart returns whether all dependencies are ready for this module to start.
func (m *Module) ReadyToStart() bool {
if m.inTransition.IsSet() || m.Started.IsSet() {
return false
}
for _, dep := range m.depModules {
if !dep.Started.IsSet() {
return false
}
}
return true
}
// ReadyToStop returns whether all dependencies are ready for this module to stop.
func (m *Module) ReadyToStop() bool {
if !m.Started.IsSet() || m.inTransition.IsSet() || m.Stopped.IsSet() {
return false
}
for _, revDep := range m.depReverse {
// not ready if a reverse dependency was started, but not yet stopped
if revDep.Started.IsSet() && !revDep.Stopped.IsSet() {
return false
}
}
return true
}

View file

@ -5,40 +5,36 @@ import (
"fmt"
"sync"
"testing"
"time"
)
var (
orderLock sync.Mutex
startOrder string
shutdownOrder string
changeHistoryLock sync.Mutex
changeHistory string
)
func testPrep(t *testing.T, name string) func() error {
return func() error {
func registerTestModule(t *testing.T, name string, dependencies ...string) {
Register(
name,
func() error {
t.Logf("prep %s\n", name)
return nil
}
}
func testStart(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
},
func() error {
changeHistoryLock.Lock()
defer changeHistoryLock.Unlock()
t.Logf("start %s\n", name)
startOrder = fmt.Sprintf("%s>%s", startOrder, name)
changeHistory = fmt.Sprintf("%s on:%s", changeHistory, name)
return nil
}
}
func testStop(t *testing.T, name string) func() error {
return func() error {
orderLock.Lock()
defer orderLock.Unlock()
},
func() error {
changeHistoryLock.Lock()
defer changeHistoryLock.Unlock()
t.Logf("stop %s\n", name)
shutdownOrder = fmt.Sprintf("%s>%s", shutdownOrder, name)
changeHistory = fmt.Sprintf("%s off:%s", changeHistory, name)
return nil
}
},
dependencies...,
)
}
func testFail() error {
@ -53,135 +49,94 @@ func TestModules(t *testing.T) {
t.Parallel() // Not really, just a workaround for running these tests last.
t.Run("TestModuleOrder", testModuleOrder)
t.Run("TestModuleMgmt", testModuleMgmt)
t.Run("TestModuleErrors", testModuleErrors)
}
func testModuleOrder(t *testing.T) {
Register("database", testPrep(t, "database"), testStart(t, "database"), testStop(t, "database"))
Register("stats", testPrep(t, "stats"), testStart(t, "stats"), testStop(t, "stats"), "database")
Register("service", testPrep(t, "service"), testStart(t, "service"), testStop(t, "service"), "database")
Register("analytics", testPrep(t, "analytics"), testStart(t, "analytics"), testStop(t, "analytics"), "stats", "database")
registerTestModule(t, "database")
registerTestModule(t, "stats", "database")
registerTestModule(t, "service", "database")
registerTestModule(t, "analytics", "stats", "database")
err := Start()
if err != nil {
t.Error(err)
}
if startOrder != ">database>service>stats>analytics" &&
startOrder != ">database>stats>service>analytics" &&
startOrder != ">database>stats>analytics>service" {
t.Errorf("start order mismatch, was %s", startOrder)
if changeHistory != " on:database on:service on:stats on:analytics" &&
changeHistory != " on:database on:stats on:service on:analytics" &&
changeHistory != " on:database on:stats on:analytics on:service" {
t.Errorf("start order mismatch, was %s", changeHistory)
}
changeHistory = ""
var wg sync.WaitGroup
wg.Add(1)
go func() {
select {
case <-ShuttingDown():
case <-time.After(1 * time.Second):
t.Error("did not receive shutdown signal")
}
wg.Done()
}()
err = Shutdown()
if err != nil {
t.Error(err)
}
if shutdownOrder != ">analytics>service>stats>database" &&
shutdownOrder != ">analytics>stats>service>database" &&
shutdownOrder != ">service>analytics>stats>database" {
t.Errorf("shutdown order mismatch, was %s", shutdownOrder)
if changeHistory != " off:analytics off:service off:stats off:database" &&
changeHistory != " off:analytics off:stats off:service off:database" &&
changeHistory != " off:service off:analytics off:stats off:database" {
t.Errorf("shutdown order mismatch, was %s", changeHistory)
}
changeHistory = ""
wg.Wait()
printAndRemoveModules()
}
func printAndRemoveModules() {
modulesLock.Lock()
defer modulesLock.Unlock()
fmt.Printf("All %d modules:\n", len(modules))
for _, m := range modules {
fmt.Printf("module %s: %+v\n", m.Name, m)
}
modules = make(map[string]*Module)
resetTestEnvironment()
}
func testModuleErrors(t *testing.T) {
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
// test prep error
Register("prepfail", testFail, testStart(t, "prepfail"), testStop(t, "prepfail"))
Register("prepfail", testFail, nil, nil)
err := Start()
if err == nil {
t.Error("should fail")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test prep clean exit
Register("prepcleanexit", testCleanExit, testStart(t, "prepcleanexit"), testStop(t, "prepcleanexit"))
Register("prepcleanexit", testCleanExit, nil, nil)
err = Start()
if err != ErrCleanExit {
t.Error("should fail with clean exit")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test invalid dependency
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "invalid")
Register("database", nil, nil, nil, "invalid")
err = Start()
if err == nil {
t.Error("should fail")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test dependency loop
Register("database", nil, testStart(t, "database"), testStop(t, "database"), "helper")
Register("helper", nil, testStart(t, "helper"), testStop(t, "helper"), "database")
registerTestModule(t, "database", "helper")
registerTestModule(t, "helper", "database")
err = Start()
if err == nil {
t.Error("should fail")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test failing module start
Register("startfail", nil, testFail, testStop(t, "startfail"))
Register("startfail", nil, testFail, nil)
err = Start()
if err == nil {
t.Error("should fail")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test failing module stop
Register("stopfail", nil, testStart(t, "stopfail"), testFail)
Register("stopfail", nil, nil, testFail)
err = Start()
if err != nil {
t.Error("should not fail")
@ -191,10 +146,7 @@ func testModuleErrors(t *testing.T) {
t.Error("should fail")
}
// reset modules
modules = make(map[string]*Module)
startComplete.UnSet()
startCompleteSignal = make(chan struct{})
resetTestEnvironment()
// test help flag
HelpFlag = true
@ -204,4 +156,20 @@ func testModuleErrors(t *testing.T) {
}
HelpFlag = false
resetTestEnvironment()
}
func printModules() { //nolint:unused,deadcode
fmt.Printf("All %d modules:\n", len(modules))
for _, m := range modules {
fmt.Printf("module %s: %+v\n", m.Name, m)
}
}
func resetTestEnvironment() {
modules = make(map[string]*Module)
shutdownSignal = make(chan struct{})
shutdownCompleteSignal = make(chan struct{})
shutdownFlag.UnSet()
modulesLocked.UnSet()
}

View file

@ -1,34 +1,38 @@
package modules
import (
"errors"
"fmt"
"os"
"runtime"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
var (
startComplete = abool.NewBool(false)
startCompleteSignal = make(chan struct{})
initialStartCompleted = abool.NewBool(false)
globalPrepFn func() error
)
// StartCompleted returns whether starting has completed.
func StartCompleted() bool {
return startComplete.IsSet()
// SetGlobalPrepFn sets a global prep function that is run before all modules. This can be used to pre-initialize modules, such as setting the data root or database path.
// SetGlobalPrepFn sets a global prep function that is run before all modules.
func SetGlobalPrepFn(fn func() error) {
if globalPrepFn == nil {
globalPrepFn = fn
}
// WaitForStartCompletion returns as soon as starting has completed.
func WaitForStartCompletion() <-chan struct{} {
return startCompleteSignal
}
// Start starts all modules in the correct order. In case of an error, it will automatically shutdown again.
func Start() error {
modulesLock.RLock()
defer modulesLock.RUnlock()
if !modulesLocked.SetToIf(false, true) {
return errors.New("module system already started")
}
// lock mgmt
mgmtLock.Lock()
defer mgmtLock.Unlock()
// start microtask scheduler
go microTaskScheduler()
@ -44,10 +48,23 @@ func Start() error {
// parse flags
err = parseFlags()
if err != nil {
if err != ErrCleanExit {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to parse flags: %s\n", err)
}
return err
}
// execute global prep fn
if globalPrepFn != nil {
err = globalPrepFn()
if err != nil {
if err != ErrCleanExit {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: %s\n", err)
}
return err
}
}
// prep modules
err = prepareModules()
if err != nil {
@ -65,6 +82,9 @@ func Start() error {
return err
}
// build dependency tree
buildEnabledTree()
// start modules
log.Info("modules: initiating...")
err = startModules()
@ -74,14 +94,16 @@ func Start() error {
}
// complete startup
if moduleMgmtEnabled.IsSet() {
log.Info("modules: initiated subsystems manager")
} else {
log.Infof("modules: started %d modules", len(modules))
if startComplete.SetToIf(false, true) {
close(startCompleteSignal)
}
go taskQueueHandler()
go taskScheduleHandler()
initialStartCompleted.Set()
return nil
}
@ -97,34 +119,23 @@ func prepareModules() error {
reportCnt := 0
for {
waiting := 0
// find modules to exec
for _, m := range modules {
if m.ReadyToPrep() {
switch m.readyToPrep() {
case statusNothingToDo:
case statusWaiting:
waiting++
case statusReady:
execCnt++
m.inTransition.Set()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.runCtrlFnWithTimeout(
"prep module",
10*time.Second,
execM.prep,
),
}
}()
m.prep(reports)
}
}
// check for dep loop
if execCnt == reportCnt {
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
if reportCnt < execCnt {
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
if rep.err == ErrCleanExit {
return rep.err
@ -132,10 +143,12 @@ func prepareModules() error {
return fmt.Errorf("failed to prep module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Prepped.Set()
// exit if done
if reportCnt == len(modules) {
} else {
// finished
if waiting > 0 {
// check for dep loop
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
return nil
}
@ -149,45 +162,36 @@ func startModules() error {
reportCnt := 0
for {
waiting := 0
// find modules to exec
for _, m := range modules {
if m.ReadyToStart() {
switch m.readyToStart() {
case statusNothingToDo:
case statusWaiting:
waiting++
case statusReady:
execCnt++
m.inTransition.Set()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.runCtrlFnWithTimeout(
"start module",
60*time.Second,
execM.start,
),
}
}()
m.start(reports)
}
}
// check for dep loop
if execCnt == reportCnt {
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
if reportCnt < execCnt {
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
return fmt.Errorf("modules: could not start module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Started.Set()
log.Infof("modules: started %s", rep.module.Name)
// exit if done
if reportCnt == len(modules) {
} else {
// finished
if waiting > 0 {
// check for dep loop
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
// return last error
return nil
}
}
}

171
modules/status.go Normal file
View file

@ -0,0 +1,171 @@
package modules
// Module Status Values
const (
StatusDead uint8 = 0 // not prepared, not started
StatusPreparing uint8 = 1
StatusOffline uint8 = 2 // prepared, not started
StatusStopping uint8 = 3
StatusStarting uint8 = 4
StatusOnline uint8 = 5 // online and running
)
// Module Failure Status Values
const (
FailureNone uint8 = 0
FailureHint uint8 = 1
FailureWarning uint8 = 2
FailureError uint8 = 3
)
// ready status
const (
statusWaiting uint8 = iota
statusReady
statusNothingToDo
)
// Online returns whether the module is online.
func (m *Module) Online() bool {
return m.Status() == StatusOnline
}
// OnlineSoon returns whether the module is or is about to be online.
func (m *Module) OnlineSoon() bool {
if moduleMgmtEnabled.IsSet() &&
!m.enabled.IsSet() &&
!m.enabledAsDependency.IsSet() {
return false
}
return !m.stopFlag.IsSet()
}
// Status returns the current module status.
func (m *Module) Status() uint8 {
m.RLock()
defer m.RUnlock()
return m.status
}
// FailureStatus returns the current failure status, ID and message.
func (m *Module) FailureStatus() (failureStatus uint8, failureID, failureMsg string) {
m.RLock()
defer m.RUnlock()
return m.failureStatus, m.failureID, m.failureMsg
}
// Hint sets failure status to hint. This is a somewhat special failure status, as the module is believed to be working correctly, but there is an important module specific information to convey. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
func (m *Module) Hint(failureID, failureMsg string) {
m.Lock()
defer m.Unlock()
m.failureStatus = FailureHint
m.failureID = failureID
m.failureMsg = failureMsg
m.notifyOfChange()
}
// Warning sets failure status to warning. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
func (m *Module) Warning(failureID, failureMsg string) {
m.Lock()
defer m.Unlock()
m.failureStatus = FailureWarning
m.failureID = failureID
m.failureMsg = failureMsg
m.notifyOfChange()
}
// Error sets failure status to error. The supplied failureID is for improved automatic handling within connected systems, the failureMsg is for humans.
func (m *Module) Error(failureID, failureMsg string) {
m.Lock()
defer m.Unlock()
m.failureStatus = FailureError
m.failureID = failureID
m.failureMsg = failureMsg
m.notifyOfChange()
}
// Resolve removes the failure state from the module if the given failureID matches the current failure ID. If the given failureID is an empty string, Resolve removes any failure state.
func (m *Module) Resolve(failureID string) {
m.Lock()
defer m.Unlock()
if failureID == "" || failureID == m.failureID {
m.failureStatus = FailureNone
m.failureID = ""
m.failureMsg = ""
}
m.notifyOfChange()
}
// readyToPrep returns whether all dependencies are ready for this module to prep.
func (m *Module) readyToPrep() uint8 {
// check if valid state for prepping
if m.Status() != StatusDead {
return statusNothingToDo
}
for _, dep := range m.depModules {
if dep.Status() < StatusOffline {
return statusWaiting
}
}
return statusReady
}
// readyToStart returns whether all dependencies are ready for this module to start.
func (m *Module) readyToStart() uint8 {
// check if start is wanted
if moduleMgmtEnabled.IsSet() {
if !m.enabled.IsSet() && !m.enabledAsDependency.IsSet() {
return statusNothingToDo
}
}
// check if valid state for starting
if m.Status() != StatusOffline {
return statusNothingToDo
}
// check if all dependencies are ready
for _, dep := range m.depModules {
if dep.Status() < StatusOnline {
return statusWaiting
}
}
return statusReady
}
// readyToStop returns whether all dependencies are ready for this module to stop.
func (m *Module) readyToStop() uint8 {
// check if stop is wanted
if moduleMgmtEnabled.IsSet() && !shutdownFlag.IsSet() {
if m.enabled.IsSet() || m.enabledAsDependency.IsSet() {
return statusNothingToDo
}
}
// check if valid state for stopping
if m.Status() != StatusOnline {
return statusNothingToDo
}
for _, revDep := range m.depReverse {
// not ready if a reverse dependency was started, but not yet stopped
if revDep.Status() > StatusOffline {
return statusWaiting
}
}
return statusReady
}

View file

@ -11,11 +11,16 @@ import (
var (
shutdownSignal = make(chan struct{})
shutdownSignalClosed = abool.NewBool(false)
shutdownFlag = abool.NewBool(false)
shutdownCompleteSignal = make(chan struct{})
)
// IsShuttingDown returns whether the global shutdown is in progress.
func IsShuttingDown() bool {
return shutdownFlag.IsSet()
}
// ShuttingDown returns a channel read on the global shutdown signal.
func ShuttingDown() <-chan struct{} {
return shutdownSignal
@ -23,18 +28,19 @@ func ShuttingDown() <-chan struct{} {
// Shutdown stops all modules in the correct order.
func Shutdown() error {
// lock mgmt
mgmtLock.Lock()
defer mgmtLock.Unlock()
if shutdownSignalClosed.SetToIf(false, true) {
if shutdownFlag.SetToIf(false, true) {
close(shutdownSignal)
} else {
// shutdown was already issued
return errors.New("shutdown already initiated")
}
if startComplete.IsSet() {
if initialStartCompleted.IsSet() {
log.Warning("modules: starting shutdown...")
modulesLock.Lock()
defer modulesLock.Unlock()
} else {
log.Warning("modules: aborting, shutting down...")
}
@ -61,46 +67,42 @@ func stopModules() error {
// get number of started modules
startedCnt := 0
for _, m := range modules {
if m.Started.IsSet() {
if m.Status() >= StatusStarting {
startedCnt++
}
}
for {
waiting := 0
// find modules to exec
for _, m := range modules {
if m.ReadyToStop() {
switch m.readyToStop() {
case statusNothingToDo:
case statusWaiting:
waiting++
case statusReady:
execCnt++
m.inTransition.Set()
execM := m
go func() {
reports <- &report{
module: execM,
err: execM.shutdown(),
}
}()
m.stop(reports)
}
}
// check for dep loop
if execCnt == reportCnt {
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
if reportCnt < execCnt {
// wait for reports
rep = <-reports
rep.module.inTransition.UnSet()
if rep.err != nil {
lastErr = rep.err
log.Warningf("modules: could not stop module %s: %s", rep.module.Name, rep.err)
}
reportCnt++
rep.module.Stopped.Set()
log.Infof("modules: stopped %s", rep.module.Name)
// exit if done
if reportCnt == startedCnt {
} else {
// finished
if waiting > 0 {
// check for dep loop
return fmt.Errorf("modules: dependency loop detected, cannot continue")
}
// return last error
return lastErr
}
}

View file

@ -0,0 +1,123 @@
package subsystems
import (
"context"
"flag"
"fmt"
"strings"
"github.com/safing/portbase/database"
_ "github.com/safing/portbase/database/dbmodule" // database module is required
"github.com/safing/portbase/modules"
)
const (
configChangeEvent = "config change"
subsystemsStatusChange = "status change"
)
var (
module *modules.Module
printGraphFlag bool
databaseKeySpace string
db = database.NewInterface(nil)
)
func init() {
// enable partial starting
modules.EnableModuleManagement(handleModuleChanges)
// register module and enable it for starting
module = modules.Register("subsystems", prep, start, nil, "config", "database", "base")
module.Enable()
// register event for changes in the subsystem
module.RegisterEvent(subsystemsStatusChange)
flag.BoolVar(&printGraphFlag, "print-subsystem-graph", false, "print the subsystem module dependency graph")
}
func prep() error {
if printGraphFlag {
printGraph()
return modules.ErrCleanExit
}
return module.RegisterEventHook("config", configChangeEvent, "control subsystems", handleConfigChanges)
}
func start() error {
// lock registration
subsystemsLocked.Set()
// lock slice and map
subsystemsLock.Lock()
// go through all dependencies
seen := make(map[string]struct{})
for _, sub := range subsystems {
// mark subsystem module as seen
seen[sub.module.Name] = struct{}{}
}
for _, sub := range subsystems {
// add main module
sub.Modules = append(sub.Modules, statusFromModule(sub.module))
// add dependencies
sub.addDependencies(sub.module, seen)
}
// unlock
subsystemsLock.Unlock()
// apply config
module.StartWorker("initial subsystem configuration", func(ctx context.Context) error {
return handleConfigChanges(module.Ctx, nil)
})
return nil
}
func (sub *Subsystem) addDependencies(module *modules.Module, seen map[string]struct{}) {
for _, module := range module.Dependencies() {
_, ok := seen[module.Name]
if !ok {
// add dependency to modules
sub.Modules = append(sub.Modules, statusFromModule(module))
// mark as seen
seen[module.Name] = struct{}{}
// add further dependencies
sub.addDependencies(module, seen)
}
}
}
// SetDatabaseKeySpace sets a key space where subsystem status
func SetDatabaseKeySpace(keySpace string) {
if databaseKeySpace == "" {
databaseKeySpace = keySpace
if !strings.HasSuffix(databaseKeySpace, "/") {
databaseKeySpace += "/"
}
}
}
func printGraph() {
// unmark subsystems module
module.Disable()
// mark roots
for _, sub := range subsystems {
sub.module.Enable() // mark as tree root
}
// print
for _, sub := range subsystems {
printModuleGraph("", sub.module, true)
}
}
func printModuleGraph(prefix string, module *modules.Module, root bool) {
fmt.Printf("%s├── %s\n", prefix, module.Name)
if root || !module.Enabled() {
for _, dep := range module.Dependencies() {
printModuleGraph(fmt.Sprintf("│ %s", prefix), dep, false)
}
}
}

View file

@ -0,0 +1,116 @@
package subsystems
import (
"sync"
"github.com/safing/portbase/config"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
// Subsystem describes a subset of modules that represent a part of a service or program to the user.
type Subsystem struct { //nolint:maligned // not worth the effort
record.Base
sync.Mutex
ID string
Name string
Description string
module *modules.Module
Modules []*ModuleStatus
FailureStatus uint8 // summary: worst status
ToggleOptionKey string
toggleOption *config.Option
toggleValue func() bool
ExpertiseLevel uint8 // copied from toggleOption
ReleaseLevel uint8 // copied from toggleOption
ConfigKeySpace string
}
// ModuleStatus describes the status of a module.
type ModuleStatus struct {
Name string
module *modules.Module
// status mgmt
Enabled bool
Status uint8
// failure status
FailureStatus uint8
FailureID string
FailureMsg string
}
// Save saves the Subsystem Status to the database.
func (sub *Subsystem) Save() {
if databaseKeySpace != "" {
if !sub.KeyIsSet() {
sub.SetKey(databaseKeySpace + sub.ID)
}
err := db.Put(sub)
if err != nil {
log.Errorf("subsystems: could not save subsystem status to database: %s", err)
}
}
}
func statusFromModule(module *modules.Module) *ModuleStatus {
status := &ModuleStatus{
Name: module.Name,
module: module,
Enabled: module.Enabled(),
Status: module.Status(),
}
status.FailureStatus, status.FailureID, status.FailureMsg = module.FailureStatus()
return status
}
func compareAndUpdateStatus(module *modules.Module, status *ModuleStatus) (changed bool) {
// check if enabled
enabled := module.Enabled()
if status.Enabled != enabled {
status.Enabled = enabled
changed = true
}
// check status
statusLvl := module.Status()
if status.Status != statusLvl {
status.Status = statusLvl
changed = true
}
// check failure status
failureStatus, failureID, failureMsg := module.FailureStatus()
if status.FailureStatus != failureStatus ||
status.FailureID != failureID {
status.FailureStatus = failureStatus
status.FailureID = failureID
status.FailureMsg = failureMsg
changed = true
}
return
}
func (sub *Subsystem) makeSummary() {
// find worst failing module
worstFailing := &ModuleStatus{}
for _, depStatus := range sub.Modules {
if depStatus.FailureStatus > worstFailing.FailureStatus {
worstFailing = depStatus
}
}
if worstFailing != nil {
sub.FailureStatus = worstFailing.FailureStatus
} else {
sub.FailureStatus = 0
}
}

View file

@ -0,0 +1,161 @@
package subsystems
import (
"context"
"fmt"
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/config"
"github.com/safing/portbase/modules"
)
var (
subsystems []*Subsystem
subsystemsMap = make(map[string]*Subsystem)
subsystemsLock sync.Mutex
subsystemsLocked = abool.New()
handlingConfigChanges = abool.New()
)
// Register registers a new subsystem. The given option must be a bool option. Should be called in init() directly after the modules.Register() function. The config option must not yet be registered and will be registered for you. Pass a nil option to force enable.
func Register(id, name, description string, module *modules.Module, configKeySpace string, option *config.Option) {
// lock slice and map
subsystemsLock.Lock()
defer subsystemsLock.Unlock()
// check if registration is closed
if subsystemsLocked.IsSet() {
panic("subsystems can only be registered in prep phase or earlier")
}
// check if already registered
_, ok := subsystemsMap[name]
if ok {
panic(fmt.Sprintf(`subsystem "%s" already registered`, name))
}
// create new
new := &Subsystem{
ID: id,
Name: name,
Description: description,
module: module,
toggleOption: option,
ConfigKeySpace: configKeySpace,
}
if new.toggleOption != nil {
new.ToggleOptionKey = new.toggleOption.Key
new.ExpertiseLevel = new.toggleOption.ExpertiseLevel
new.ReleaseLevel = new.toggleOption.ReleaseLevel
}
// register config
if option != nil {
err := config.Register(option)
if err != nil {
panic(fmt.Sprintf("failed to register config: %s", err))
}
new.toggleValue = config.GetAsBool(new.ToggleOptionKey, false)
} else {
// force enabled
new.toggleValue = func() bool { return true }
}
// add to lists
subsystemsMap[name] = new
subsystems = append(subsystems, new)
}
func handleModuleChanges(m *modules.Module) {
// check if ready
if !subsystemsLocked.IsSet() {
return
}
// check if shutting down
if modules.IsShuttingDown() {
return
}
// find module status
var moduleSubsystem *Subsystem
var moduleStatus *ModuleStatus
subsystemLoop:
for _, subsystem := range subsystems {
for _, status := range subsystem.Modules {
if m.Name == status.Name {
moduleSubsystem = subsystem
moduleStatus = status
break subsystemLoop
}
}
}
// abort if not found
if moduleSubsystem == nil || moduleStatus == nil {
return
}
// update status
moduleSubsystem.Lock()
changed := compareAndUpdateStatus(m, moduleStatus)
if changed {
moduleSubsystem.makeSummary()
}
moduleSubsystem.Unlock()
// save
if changed {
moduleSubsystem.Save()
}
}
func handleConfigChanges(ctx context.Context, data interface{}) error {
// check if ready
if !subsystemsLocked.IsSet() {
return nil
}
// potentially catch multiple changes
if handlingConfigChanges.SetToIf(false, true) {
time.Sleep(100 * time.Millisecond)
handlingConfigChanges.UnSet()
} else {
return nil
}
// don't do anything if we are already shutting down globally
if modules.IsShuttingDown() {
return nil
}
// only run one instance at any time
subsystemsLock.Lock()
defer subsystemsLock.Unlock()
var changed bool
for _, subsystem := range subsystems {
if subsystem.module.SetEnabled(subsystem.toggleValue()) {
// if changed
changed = true
}
}
// trigger module management if any setting was changed
if changed {
err := modules.ManageModules()
if err != nil {
module.Error(
"modulemgmt-failed",
fmt.Sprintf("The subsystem framework failed to start or stop one or more modules.\nError: %s\nCheck logs for more information.", err),
)
} else {
module.Resolve("modulemgmt-failed")
}
}
return nil
}

View file

@ -0,0 +1,123 @@
package subsystems
import (
"io/ioutil"
"os"
"testing"
"time"
"github.com/safing/portbase/config"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
)
func TestSubsystems(t *testing.T) {
// tmp dir for data root (db & config)
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
// initialize data dir
if err == nil {
err = dataroot.Initialize(tmpDir, 0755)
}
// handle setup error
if err != nil {
t.Fatal(err)
}
// register
baseModule := modules.Register("base", nil, nil, nil)
Register(
"base",
"Base",
"Framework Groundwork",
baseModule,
"config:base",
nil,
)
feature1 := modules.Register("feature1", nil, nil, nil)
Register(
"feature-one",
"Feature One",
"Provides feature one",
feature1,
"config:feature1",
&config.Option{
Name: "Enable Feature One",
Key: "config:subsystems/feature1",
Description: "This option enables feature 1",
OptType: config.OptTypeBool,
DefaultValue: false,
},
)
sub1 := subsystemsMap["Feature One"]
feature2 := modules.Register("feature2", nil, nil, nil)
Register(
"feature-two",
"Feature Two",
"Provides feature two",
feature2,
"config:feature2",
&config.Option{
Name: "Enable Feature One",
Key: "config:subsystems/feature2",
Description: "This option enables feature 2",
OptType: config.OptTypeBool,
DefaultValue: false,
},
)
// start
err = modules.Start()
if err != nil {
t.Fatal(err)
}
// test
// let module fail
feature1.Error("test-fail", "Testing Fail")
time.Sleep(10 * time.Millisecond)
if sub1.FailureStatus != modules.FailureError {
t.Fatal("error did not propagate")
}
// resolve
feature1.Resolve("test-fail")
time.Sleep(10 * time.Millisecond)
if sub1.FailureStatus != modules.FailureNone {
t.Fatal("error resolving did not propagate")
}
// update settings
err = config.SetConfigOption("config:subsystems/feature2", true)
if err != nil {
t.Fatal(err)
return
}
time.Sleep(200 * time.Millisecond)
if !feature2.Enabled() {
t.Fatal("failed to enable feature2")
}
if feature2.Status() != modules.StatusOnline {
t.Fatal("feature2 did not start")
}
// update settings
err = config.SetConfigOption("config:subsystems/feature2", false)
if err != nil {
t.Fatal(err)
return
}
time.Sleep(200 * time.Millisecond)
if feature2.Enabled() {
t.Fatal("failed to disable feature2")
}
if feature2.Status() != modules.StatusOffline {
t.Fatal("feature2 did not stop")
}
// clean up and exit
os.RemoveAll(tmpDir)
}

View file

@ -8,21 +8,24 @@ import (
"sync/atomic"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
// Task is managed task bound to a module.
type Task struct {
name string
module *Module
taskFn func(context.Context, *Task)
taskFn func(context.Context, *Task) error
queued bool
canceled bool
executing bool
cancelFunc func()
// these are populated at task creation
// ctx is canceled when task is shutdown -> all tasks become canceled
ctx context.Context
cancelCtx func()
executeAt time.Time
repeat time.Duration
@ -59,25 +62,46 @@ const (
)
// NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
func (m *Module) NewTask(name string, fn func(context.Context, *Task) error) *Task {
m.Lock()
defer m.Unlock()
if m == nil {
log.Errorf(`modules: cannot create task "%s" with nil module`, name)
return &Task{
name: name,
module: &Module{Name: "[NONE]"},
canceled: true,
}
}
if m.Ctx == nil || !m.OnlineSoon() {
log.Errorf(`modules: tasks should only be started when the module is online or starting`)
return &Task{
name: name,
module: m,
canceled: true,
}
}
return &Task{
// create new task
new := &Task{
name: name,
module: m,
taskFn: fn,
maxDelay: defaultMaxDelay,
}
// create context
new.ctx, new.cancelCtx = context.WithCancel(m.Ctx)
return new
}
func (t *Task) isActive() bool {
if t.module == nil {
if t.canceled {
return false
}
return !t.canceled && !t.module.ShutdownInProgress()
return t.module.OnlineSoon()
}
func (t *Task) prepForQueueing() (ok bool) {
@ -197,45 +221,15 @@ func (t *Task) Repeat(interval time.Duration) *Task {
func (t *Task) Cancel() {
t.lock.Lock()
t.canceled = true
if t.cancelFunc != nil {
t.cancelFunc()
if t.cancelCtx != nil {
t.cancelCtx()
}
t.lock.Unlock()
}
func (t *Task) runWithLocking() {
if t.module == nil {
return
}
// wait for good timeslot regarding microtasks
select {
case <-taskTimeslot:
case <-time.After(maxTimeslotWait):
}
t.lock.Lock()
// check state, return if already executing or inactive
if t.executing || !t.isActive() {
t.lock.Unlock()
return
}
t.executing = true
// get list elements
queueElement := t.queueElement
prioritizedQueueElement := t.prioritizedQueueElement
scheduleListElement := t.scheduleListElement
// create context
var taskCtx context.Context
taskCtx, t.cancelFunc = context.WithCancel(t.module.Ctx)
t.lock.Unlock()
func (t *Task) removeFromQueues() {
// remove from lists
if queueElement != nil {
if t.queueElement != nil {
queuesLock.Lock()
taskQueue.Remove(t.queueElement)
queuesLock.Unlock()
@ -243,7 +237,7 @@ func (t *Task) runWithLocking() {
t.queueElement = nil
t.lock.Unlock()
}
if prioritizedQueueElement != nil {
if t.prioritizedQueueElement != nil {
queuesLock.Lock()
prioritizedTaskQueue.Remove(t.prioritizedQueueElement)
queuesLock.Unlock()
@ -251,7 +245,7 @@ func (t *Task) runWithLocking() {
t.prioritizedQueueElement = nil
t.lock.Unlock()
}
if scheduleListElement != nil {
if t.scheduleListElement != nil {
scheduleLock.Lock()
taskSchedule.Remove(t.scheduleListElement)
scheduleLock.Unlock()
@ -259,14 +253,62 @@ func (t *Task) runWithLocking() {
t.scheduleListElement = nil
t.lock.Unlock()
}
}
func (t *Task) runWithLocking() {
t.lock.Lock()
// check if task is already executing
if t.executing {
t.lock.Unlock()
return
}
// check if task is active
if !t.isActive() {
t.removeFromQueues()
t.lock.Unlock()
return
}
// check if module was stopped
select {
case <-t.ctx.Done(): // check if module is stopped
t.removeFromQueues()
t.lock.Unlock()
return
default:
}
t.executing = true
t.lock.Unlock()
// wait for good timeslot regarding microtasks
select {
case <-taskTimeslot:
case <-time.After(maxTimeslotWait):
}
// wait for module start
if !t.module.Online() {
if t.module.OnlineSoon() {
// wait
<-t.module.StartCompleted()
} else {
t.lock.Lock()
t.removeFromQueues()
t.lock.Unlock()
return
}
}
// add to queue workgroup
queueWg.Add(1)
go t.executeWithLocking(taskCtx, t.cancelFunc)
go t.executeWithLocking()
go func() {
select {
case <-taskCtx.Done():
case <-t.ctx.Done():
case <-time.After(maxExecutionWait):
}
// complete queue worker (early) to allow next worker
@ -274,7 +316,7 @@ func (t *Task) runWithLocking() {
}()
}
func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
func (t *Task) executeWithLocking() {
// start for module
// hint: only queueWg global var is important for scheduling, others can be set here
atomic.AddInt32(t.module.taskCnt, 1)
@ -306,11 +348,16 @@ func (t *Task) executeWithLocking(ctx context.Context, cancelFunc func()) {
t.lock.Unlock()
// notify that we finished
cancelFunc()
if t.cancelCtx != nil {
t.cancelCtx()
}
}()
// run
t.taskFn(ctx, t)
err := t.taskFn(t.ctx, t)
if err != nil {
log.Errorf("%s: task %s failed: %s", t.module.Name, t.name, err)
}
}
func (t *Task) getExecuteAtWithLocking() time.Time {
@ -320,6 +367,10 @@ func (t *Task) getExecuteAtWithLocking() time.Time {
}
func (t *Task) addToSchedule() {
if !t.isActive() {
return
}
scheduleLock.Lock()
defer scheduleLock.Unlock()
// defer printTaskList(taskSchedule) // for debugging
@ -395,7 +446,7 @@ func taskQueueHandler() {
queueWg.Wait()
// check for shutdown
if shutdownSignalClosed.IsSet() {
if shutdownFlag.IsSet() {
return
}

View file

@ -35,22 +35,29 @@ func init() {
var qtWg sync.WaitGroup
var qtOutputChannel chan string
var qtSleepDuration time.Duration
var qtModule = initNewModule("task test module", nil, nil, nil)
var qtModule *Module
func init() {
qtModule = initNewModule("task test module", nil, nil, nil)
qtModule.status = StatusOnline
}
// functions
func queuedTaskTester(s string) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
qtWg.Done()
return nil
}).Queue()
}
func prioritizedTaskTester(s string) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
time.Sleep(qtSleepDuration * 2)
qtOutputChannel <- s
qtWg.Done()
return nil
}).Prioritize()
}
@ -109,10 +116,11 @@ var stWaitCh chan bool
// functions
func scheduledTaskTester(s string, sched time.Time) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) {
qtModule.NewTask(s, func(ctx context.Context, t *Task) error {
time.Sleep(stSleepDuration)
stOutputChannel <- s
stWg.Done()
return nil
}).Schedule(sched)
}

View file

@ -73,7 +73,7 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn
lastFail := time.Now()
for {
if m.ShutdownInProgress() {
if m.IsStopping() {
return
}

View file

@ -111,26 +111,26 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
}
// Put stores a record in the database.
func (s *StorageInterface) Put(r record.Record) error {
func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
// record is already locked!
key := r.DatabaseKey()
n, err := EnsureNotification(r)
if err != nil {
return ErrInvalidData
return nil, ErrInvalidData
}
// transform key
if strings.HasPrefix(key, "all/") {
key = strings.TrimPrefix(key, "all/")
} else {
return ErrInvalidPath
return nil, ErrInvalidPath
}
// continue in goroutine
go UpdateNotification(n, key)
return nil
return n, nil
}
// UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change.

View file

@ -11,7 +11,7 @@ var (
)
func init() {
module = modules.Register("notifications", nil, start, nil, "base", "database")
module = modules.Register("notifications", nil, start, nil, "database", "base")
}
func start() error {

View file

@ -93,6 +93,7 @@ func (n *Notification) Save() *Notification {
nots[n.ID] = n
// push update
log.Tracef("notifications: pushing update for %s to subscribers", n.Key())
dbController.PushUpdate(n)
// persist
@ -152,10 +153,11 @@ func (n *Notification) MakeAck() *Notification {
// Response waits for the user to respond to the notification and returns the selected action.
func (n *Notification) Response() <-chan string {
n.lock.Lock()
defer n.lock.Unlock()
if n.actionTrigger == nil {
n.actionTrigger = make(chan string)
}
n.lock.Unlock()
return n.actionTrigger
}
@ -213,10 +215,11 @@ func (n *Notification) Delete() error {
// Expired notifies the caller when the notification has expired.
func (n *Notification) Expired() <-chan struct{} {
n.lock.Lock()
defer n.lock.Unlock()
if n.expiredTrigger == nil {
n.expiredTrigger = make(chan struct{})
}
n.lock.Unlock()
return n.expiredTrigger
}

View file

@ -1,49 +1,19 @@
package main
import (
"fmt"
"os"
"os/signal"
"syscall"
"github.com/safing/portbase/info"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/run"
// include packages here
_ "github.com/safing/portbase/api"
)
func main() {
// Set Info
info.Set("Portbase", "0.0.1", "GPLv3", false)
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
os.Exit(0)
} else {
os.Exit(1)
}
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal, 3)
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
)
select {
case <-signalCh:
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
_ = modules.Shutdown()
case <-modules.ShuttingDown():
}
// Run
os.Exit(run.Run())
}

View file

@ -23,7 +23,7 @@ var (
)
func init() {
modules.Register("random", prep, Start, nil, "base")
modules.Register("random", prep, Start, nil)
}
func prep() error {

116
template/module.go Normal file
View file

@ -0,0 +1,116 @@
package template
import (
"context"
"time"
"github.com/safing/portbase/config"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/modules/subsystems"
)
const (
eventStateUpdate = "state update"
)
var (
module *modules.Module
)
func init() {
// register base module, for database initialization
modules.Register("base", nil, nil, nil)
// register module
module = modules.Register("template", prep, start, stop) // add dependencies...
subsystems.Register(
"template-subsystem", // ID
"Template Subsystem", // name
"This subsystem is a template for quick setup", // description
module,
"config:template", // key space for configuration options registered
&config.Option{
Name: "Enable Template Subsystem",
Key: "config:subsystems/template",
Description: "This option enables the Template Subsystem [TEMPLATE]",
OptType: config.OptTypeBool,
DefaultValue: false,
},
)
// register events that other modules can subscribe to
module.RegisterEvent(eventStateUpdate)
}
func prep() error {
// register options
err := config.Register(&config.Option{
Name: "language",
Key: "config:template/language",
Description: "Sets the language for the template [TEMPLATE]",
OptType: config.OptTypeString,
ExpertiseLevel: config.ExpertiseLevelUser, // default
ReleaseLevel: config.ReleaseLevelStable, // default
RequiresRestart: false, // default
DefaultValue: "en",
ValidationRegex: "^[a-z]{2}$",
})
if err != nil {
return err
}
// register event hooks
// do this in prep() and not in start(), as we don't want to register again if module is turned off and on again
err = module.RegisterEventHook(
"template", // event source module name
"state update", // event source name
"react to state changes", // description of hook function
eventHandler, // hook function
)
if err != nil {
return err
}
// hint: event hooks and tasks will not be run if module isn't online
return nil
}
func start() error {
// register tasks
module.NewTask("do something", taskFn).Queue()
// start service worker
module.StartServiceWorker("do something", 0, serviceWorker)
return nil
}
func stop() error {
return nil
}
func serviceWorker(ctx context.Context) error {
for {
select {
case <-time.After(1 * time.Second):
err := do()
if err != nil {
return err
}
case <-ctx.Done():
return nil
}
}
}
func taskFn(ctx context.Context, task *modules.Task) error {
return do()
}
func eventHandler(ctx context.Context, data interface{}) error {
return do()
}
func do() error {
return nil
}

51
template/module_test.go Normal file
View file

@ -0,0 +1,51 @@
package template
import (
"fmt"
"io/ioutil"
"os"
"testing"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
)
func TestMain(m *testing.M) {
// enable module for testing
module.Enable()
// tmp dir for data root (db & config)
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err)
os.Exit(1)
}
// initialize data dir
err = dataroot.Initialize(tmpDir, 0755)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
os.Exit(1)
}
// start modules
var exitCode int
err = modules.Start()
if err != nil {
// starting failed
fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err)
exitCode = 1
} else {
// run tests
exitCode = m.Run()
}
// shutdown
_ = modules.Shutdown()
if modules.GetExitStatusCode() != 0 {
exitCode = modules.GetExitStatusCode()
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
}
// clean up and exit
os.RemoveAll(tmpDir)
os.Exit(exitCode)
}

View file

@ -50,6 +50,10 @@ func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error {
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status)
}
// download and write file
n, err := io.Copy(atomicFile, resp.Body)
if err != nil {
@ -96,6 +100,10 @@ func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte,
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, resp.Status)
}
// download and write file
buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength))
n, err := io.Copy(buf, resp.Body)